Bayesian Poisson factorization

A probabilistic recommender system implemented with TensorFlow Probability layers.
bayesian modeling
variational inference
tensorflow-probability
machine learning
Author

Yves Barmaz

Published

February 23, 2022

Recommender systems

A lot of recommender systems are built on matrix factorization models, where the partially observed matrix of user/item interactions is approximated by the product of two matrices encoding latent characteristics of users and items. The approximation can be corrected by user and item bias terms, and passed through activation functions that map to the data type of the observed interactions (e.g. binary kudos on Strava activities, counts of visits to a YouTube channel, ratings on Tripadvisor, or time spent watching a TikTok video before swiping up).

In machine learning, these matrix factorizations are often implemented as embeddings of users and items into a latent space, followed by scalar products between user and item latent vectors, and the model is fit to historical data with a gradient descent algorithm. A Keras tutorial on collaborative filtering describes the methodology with the MovieLens dataset.
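For reference, the classical point-estimate approach can be sketched in a few lines of pure Python: one embedding table for users, one for items, a scalar-product prediction, and plain stochastic gradient descent on a squared error. This is not the Keras tutorial's code; the toy ratings, learning rate, embedding size and function name are arbitrary assumptions for illustration.

```python
import random


def train_mf(interactions, num_users, num_items, dim=4, lr=0.05, epochs=200, seed=0):
    """Fit classical (point-estimate) user and item embeddings with SGD.

    interactions: list of (user_index, item_index, rating) triples.
    Returns both embedding tables and the mean squared error before and after training.
    """
    rng = random.Random(seed)
    users = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(num_users)]
    items = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(num_items)]

    def predict(u, i):
        # The prediction is the scalar product of the two latent vectors.
        return sum(a * b for a, b in zip(users[u], items[i]))

    def mse():
        return sum((r - predict(u, i)) ** 2 for u, i, r in interactions) / len(interactions)

    initial_mse = mse()
    for _ in range(epochs):
        for u, i, r in interactions:
            err = predict(u, i) - r
            for d in range(dim):
                # Gradients of 0.5 * err**2, computed before either vector is updated.
                grad_user = err * items[i][d]
                grad_item = err * users[u][d]
                users[u][d] -= lr * grad_user
                items[i][d] -= lr * grad_item
    return users, items, initial_mse, mse()


ratings = [(0, 0, 5.0), (0, 1, 1.0), (1, 0, 4.0), (1, 2, 2.0), (2, 1, 3.0), (2, 2, 4.0)]
users, items, before, after = train_mf(ratings, num_users=3, num_items=3)
print(f"MSE before: {before:.3f}, after: {after:.3f}")
```

Whatever the seed, the fitted model returns a single number per user/item pair, with no indication of how much it should be trusted.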

While the resulting model only provides point estimates of future interactions between users and items, a Bayesian treatment of the problem would also approximate the uncertainty of these estimates. In principle, we could express the same model with priors on the vector embeddings in a probabilistic programming library and sample the posterior with an MCMC algorithm, but evaluating the likelihood of all past interactions at every step can be impractical for very large datasets. Moreover, because the model is invariant under permutations of the embedding dimensions, the posterior is multimodal, and MCMC sampling would be a nightmare. On the other hand, variational inference, which minimizes the Kullback-Leibler divergence \(D_{KL}\left(Q(Z) \Vert P(Z\vert X)\right)\) between a surrogate posterior and the true posterior distribution, exhibits a mode-seeking behavior (see for instance these lecture notes), a bit like gradient descent in “classical” machine learning finds a local minimum.

In this blog post, we will explore how to implement a model inspired by Gopalan, Hofman and Blei (Scalable Recommendation with Poisson Factorization) with TensorFlow Probability.

Bayesian models and variational inference

As a reminder, minimizing the (intractable) Kullback-Leibler divergence \(D_{KL}\left(Q(Z) \Vert P(Z\vert X)\right)\) between the variational distribution \(Q(Z)\) and the true posterior \(P(Z\vert X)\) is equivalent to maximizing the (computable) evidence lower bound

\[ \mathrm{ELBO} = \mathbb{E}_Q\left[ \log P(X \vert Z) \right] - D_{KL}(Q(Z) \Vert P(Z)), \]

where the first term is the expectation, under the surrogate distribution, of the log-likelihood of the model, and the second term is the negative Kullback-Leibler divergence between the surrogate distribution and the prior distribution of the model parameters.

In the mean field approximation, where we assume that the variational distribution \(Q\) factorizes over the latent variables, we can implement variational inference with probabilistic layers from the tfp.layers module, which automatically keep track of the variational parameters during training. These layers are conveniently combined in a Keras model, where the last layer is typically a distribution layer corresponding to the observed data \(X\), whose log_prob method computes the log-likelihood of the model. In their forward pass, the probabilistic layers draw samples from the surrogate distribution \(Q\) they implement, and we can use these samples to compute Monte Carlo estimates of \(\mathbb{E}_Q\left[ \log P(X \vert Z) \right]\), the first part of the ELBO.

The second part, \(D_{KL}(Q(Z) \Vert P(Z))\), can be implemented either through an activity regularizer, activity_regularizer = tfpl.KLDivergenceRegularizer(prior_distribution), or through a custom loss term added in the call method of the layer. The former is suitable for global parameters of the model that are shared by all training examples, and the latter is required for latent variables associated with specific training examples.

We will wrap up these initial theoretical considerations with a couple of observations that are important for a good implementation of these methods. The first is that Keras models are trained by minimizing a loss function, so instead of maximizing the ELBO, we will be minimizing the negative ELBO. Concretely, the loss function passed to the model.compile method will be the negative log-likelihood, and \(D_{KL}(Q(Z) \Vert P(Z))\) will be added to the loss rather than subtracted from it. The second is that Keras model training routines evaluate the loss function example by example and aggregate it over the training examples of each batch (by default as a batch average). While this is straightforward for the likelihood part of the loss when observations are assumed to be conditionally independent,

\[ - \log P(X \vert Z) = - \sum_i \log P(X_i \vert Z), \]

extra care needs to be taken for the Kullback-Leibler term

\[ D_{KL}(Q(Z) \Vert P(Z)) = \mathbb{E}_Q\left[ \log Q(Z) - \log P(Z)\right], \]

which has to be expressed as a sum over the training examples like the log-likelihood. In practice, this can often be achieved through a weighted sum. For instance, if we have only parameters that are shared by all training examples, we can apply a \(1/N\) weight, where \(N\) is the number of data points \(X_i\),

\[ - \mathrm{ELBO} = \mathbb{E}_Q\left[ \sum_i \left( - \log P(X_i \vert Z) + \frac{1}{N}( \log Q(Z) - \log P(Z))\right) \right]. \]
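To make the \(1/N\) weighting concrete, here is a small pure-Python sketch (not TensorFlow Probability) for a toy model with a single global parameter: observations \(X_i \sim \mathrm{Poisson}(z)\), a prior \(z \sim \mathrm{Gamma}(a, b)\) and a Gamma surrogate \(Q\). Each example's loss adds \(1/N\) of a Monte Carlo sample of \(\log Q(z) - \log P(z)\) to its negative log-likelihood, as in the formula above. Because this toy model is conjugate, the exact posterior is \(\mathrm{Gamma}(a + \sum_i x_i, b + N)\), and at that \(Q\) the negative ELBO equals the negative log evidence, which gives a convenient check. The data, hyperparameters and function names are assumptions for illustration.

```python
import math
import random


def gamma_log_pdf(z, concentration, rate):
    # Log density of Gamma(concentration, rate) at z > 0 (rate parameterization).
    return (concentration * math.log(rate) - math.lgamma(concentration)
            + (concentration - 1.0) * math.log(z) - rate * z)


def poisson_log_pmf(x, rate):
    return x * math.log(rate) - rate - math.lgamma(x + 1.0)


def negative_elbo(data, prior, surrogate, num_samples=2000, seed=0):
    """Monte Carlo estimate of -ELBO as a sum of per-example terms.

    Each example contributes -log P(x_i | z) + (log Q(z) - log P(z)) / N,
    averaged over samples z ~ Q.
    """
    rng = random.Random(seed)
    n = len(data)
    a, b = prior
    alpha, beta = surrogate
    total = 0.0
    for _ in range(num_samples):
        # random.gammavariate takes a shape and a scale, i.e. 1 / rate.
        z = rng.gammavariate(alpha, 1.0 / beta)
        kl_sample = gamma_log_pdf(z, alpha, beta) - gamma_log_pdf(z, a, b)
        total += sum(-poisson_log_pmf(x, z) + kl_sample / n for x in data)
    return total / num_samples


def negative_log_evidence(data, prior):
    # Closed-form -log P(X) for the conjugate Poisson-Gamma model.
    a, b = prior
    n, s = len(data), sum(data)
    return -(a * math.log(b) - math.lgamma(a) + math.lgamma(a + s)
             - (a + s) * math.log(b + n) - sum(math.lgamma(x + 1.0) for x in data))


data = [2, 0, 3, 1, 4, 2]
prior = (1.0, 0.8)
exact_posterior = (prior[0] + sum(data), prior[1] + len(data))
print(negative_elbo(data, prior, exact_posterior), negative_log_evidence(data, prior))
print(negative_elbo(data, prior, (1.0, 1.0)))  # a poor surrogate gives a larger value
```

The gap between the two printed values for a poor surrogate is a Monte Carlo estimate of \(D_{KL}\left(Q(Z) \Vert P(Z\vert X)\right)\), which is exactly what variational inference drives down.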

For parameters linked to only subsets of the data points, or even to individual data points (in which case we often speak of latent variables), a bit of algebra might be needed to find the proper weights.

Constant weights can be passed as optional parameters to helpers like tfpl.KLDivergenceRegularizer (through its weight argument), but non-uniform weights have to be applied in a loss term added with add_loss inside the layer's call method, which is why this second approach is required for latent variables, as mentioned earlier.

Probabilistic embeddings

While a standard embedding maps discrete values to vectors in a latent space, a probabilistic embedding maps them to probability distributions over the latent space. It can be built from a parameter layer, which holds the variational parameters, followed by a TensorFlow Probability distribution layer that turns these parameters into a distribution we can sample from; a prior distribution is also needed for the KL term. The parameter layer can be expressed as a standard embedding into a space of dimension \(n_{params} \times D\), where \(D\) is the dimension of the original latent space and \(n_{params}\) the number of variational parameters of the corresponding component of the variational distribution \(Q\) (two for a Gamma distribution: concentration and rate).
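As a rough, framework-free illustration of this layout, the pure-Python sketch below stores one row of \(2 \times D\) unconstrained parameters per class, maps them to positive Gamma concentrations and rates with an exponential (as the Keras layer below does), and draws a latent vector by sampling each coordinate independently. The class name and the initialization scale are hypothetical.

```python
import math
import random


class GammaEmbeddingTable:
    """Hypothetical stand-in for a probabilistic embedding with Gamma components."""

    def __init__(self, num_classes, embedding_size, seed=0):
        rng = random.Random(seed)
        self.embedding_size = embedding_size
        # One row of n_params * D = 2 * D unconstrained variational parameters per class.
        self.params = [[rng.gauss(0.0, 0.1) for _ in range(2 * embedding_size)]
                       for _ in range(num_classes)]

    def distribution(self, index):
        """Per-dimension (concentration, rate) pairs for one class."""
        row = self.params[index]
        d = self.embedding_size
        # The exponential keeps both Gamma parameters strictly positive.
        return [(math.exp(c), math.exp(r)) for c, r in zip(row[:d], row[d:])]

    def mean(self, index):
        return [c / r for c, r in self.distribution(index)]

    def sample(self, index, rng):
        # random.gammavariate takes a shape and a scale, i.e. 1 / rate.
        return [rng.gammavariate(c, 1.0 / r) for c, r in self.distribution(index)]


table = GammaEmbeddingTable(num_classes=5, embedding_size=3)
print(table.mean(2), table.sample(2, random.Random(1)))
```

A classical embedding would return the same vector for class 2 on every lookup; here each call to sample returns a different draw around the mean.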

A custom Keras layer wraps the construction of a probabilistic embedding, with its __init__ method receiving the hyperparameters. In the following code example, we picked Gamma distributions for both the priors and the variational distributions, with the variational parameters stored on a log scale and mapped to positive values with an exponential. Note how the KL term is added to the loss with self.add_loss and weighted by latent_kl_weights, passed as an argument to the call method of the custom Keras layer. This lets us specify a separate weight for each training example as part of the inputs.

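import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions


class GammaEmbedding(tfkl.Layer):
    """Maps class indices to independent Gamma distributions over a latent space."""

    def __init__(self, num_classes, embedding_size,
                 embedding_concentration, embedding_rate,
                 **kwargs):
        super().__init__(**kwargs)

        # Unconstrained variational parameters: log-concentrations in the first
        # embedding_size columns, log-rates in the last embedding_size columns.
        self.embedding_parameters = tfkl.Embedding(
            num_classes,
            2 * embedding_size,
            embeddings_initializer="he_normal"
        )

        # Surrogate posterior Q; the exponential maps the parameters to positive values.
        self.embedding_distribution = tfpl.DistributionLambda(
            lambda x: tfd.Independent(
                tfd.Gamma(tf.math.exp(x[:, :embedding_size]),
                          rate=tf.math.exp(x[:, embedding_size:])),
                reinterpreted_batch_ndims=1)
        )

        # Fixed prior P, shared by all classes.
        self.embedding_prior = tfd.Independent(
            tfd.Gamma(
                embedding_concentration *
                tf.ones(shape=(embedding_size,), dtype=tf.float32),
                rate=embedding_rate),
            reinterpreted_batch_ndims=1
        )

    def call(self, inputs, latent_kl_weights):
        embedding_param = self.embedding_parameters(inputs)
        embedding_distribution = self.embedding_distribution(embedding_param)
        # Weighted per-example KL terms, averaged over the batch in the same way
        # as Keras averages the compiled negative log-likelihood loss.
        self.add_loss(
            tf.reduce_mean(latent_kl_weights *
                           embedding_distribution.kl_divergence(self.embedding_prior))
        )
        return embedding_distribution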
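TensorFlow Probability registers an analytic KL divergence between two Gamma distributions, and tfd.Independent sums it over the event dimension, so the KL term of the layer above is available in closed form. The pure-Python sketch below spells out that formula, checks it against a Monte Carlo estimate of \(\mathbb{E}_Q\left[ \log Q(Z) - \log P(Z)\right]\), and computes one weighted KL term per example, which is what the layer aggregates before passing it to add_loss. The standard library has no digamma function, so the sketch uses a recurrence plus an asymptotic series; all function names are hypothetical.

```python
import math
import random


def digamma(x):
    """Digamma function for x > 0 (recurrence, then asymptotic series)."""
    result = 0.0
    while x < 6.0:
        # psi(x) = psi(x + 1) - 1 / x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def gamma_kl(alpha_q, beta_q, alpha_p, beta_p):
    """Closed-form KL(Gamma(alpha_q, beta_q) || Gamma(alpha_p, beta_p)), rate parameterization."""
    return ((alpha_q - alpha_p) * digamma(alpha_q)
            - math.lgamma(alpha_q) + math.lgamma(alpha_p)
            + alpha_p * (math.log(beta_q) - math.log(beta_p))
            + alpha_q * (beta_p - beta_q) / beta_q)


def weighted_kl_terms(batch_params, prior, weights):
    """One weighted KL term per example; like tfd.Independent, sum over dimensions.

    batch_params: one row per example, each a list of (concentration, rate) pairs.
    """
    alpha_p, beta_p = prior
    return [w * sum(gamma_kl(a, b, alpha_p, beta_p) for a, b in row)
            for row, w in zip(batch_params, weights)]


def gamma_log_pdf(z, alpha, beta):
    return alpha * math.log(beta) - math.lgamma(alpha) + (alpha - 1.0) * math.log(z) - beta * z


def monte_carlo_kl(alpha_q, beta_q, alpha_p, beta_p, num_samples=100_000, seed=0):
    # E_Q[log Q(z) - log P(z)], estimated with samples z ~ Q.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gammavariate(alpha_q, 1.0 / beta_q)
        total += gamma_log_pdf(z, alpha_q, beta_q) - gamma_log_pdf(z, alpha_p, beta_p)
    return total / num_samples


print(gamma_kl(2.0, 1.0, 1.0, 0.8), monte_carlo_kl(2.0, 1.0, 1.0, 0.8))
```

The closed form avoids the sampling noise of the Monte Carlo estimate, which is one reason to keep the prior and the surrogate in the same family.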

The user and item bias terms mentioned earlier could be implemented as one-dimensional probabilistic embeddings and added to the scalar products of user and item embeddings, but Gopalan et al. suggest introducing these degrees of freedom as random rate parameters of the embedding Gamma distributions, through a hierarchical model. Compared to the simple Gamma embedding, where the embedding priors were fixed and specified in the __init__ method, the priors now have to evolve with the variational parameters of their parent distribution during training, so they need to be recomputed every time the layer is called. To model the random rate parameter, we can reuse the Gamma embedding layer we already constructed, with an embedding dimension of 1.

class RateAdjustedGammaEmbedding(tfkl.Layer):
    """Gamma embedding whose prior rate is itself a per-class Gamma latent variable."""

    def __init__(self, num_classes, embedding_size,
                 parent_concentration, parent_rate,
                 embedding_concentration, **kwargs):
        super().__init__(**kwargs)

        self.embedding_size = embedding_size
        self.embedding_concentration = embedding_concentration

        # One-dimensional probabilistic embedding for the random prior rate; it adds
        # its own KL term against Gamma(parent_concentration, parent_rate).
        self.rate_distribution = GammaEmbedding(
            num_classes=num_classes,
            embedding_size=1,
            embedding_concentration=parent_concentration,
            embedding_rate=parent_rate
        )

        self.embedding_parameters = tfkl.Embedding(
            num_classes,
            2 * embedding_size,
            embeddings_initializer="he_normal"
        )

        self.embedding_distribution = tfpl.DistributionLambda(
            lambda x: tfd.Independent(
                tfd.Gamma(tf.math.exp(x[:, :embedding_size]),
                          rate=tf.math.exp(x[:, embedding_size:])),
                reinterpreted_batch_ndims=1
            )
        )

    def call(self, inputs, latent_kl_weights):
        embedding_param = self.embedding_parameters(inputs)
        embedding_distribution = self.embedding_distribution(embedding_param)

        # Sampled rate of shape (batch, 1), broadcast to every embedding dimension,
        # so the prior is rebuilt at each call.
        embedding_rate = tf.convert_to_tensor(
            self.rate_distribution(inputs, latent_kl_weights))
        embedding_prior = tfd.Independent(
            tfd.Gamma(self.embedding_concentration,
                      rate=embedding_rate * tf.ones((1, self.embedding_size))),
            reinterpreted_batch_ndims=1)
        # Weighted KL terms, averaged over the batch like the negative log-likelihood.
        self.add_loss(
            tf.reduce_mean(latent_kl_weights *
                           embedding_distribution.kl_divergence(embedding_prior))
        )

        return embedding_distribution
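The generative story behind this layer fits in a few lines of pure Python: for each user or item, one rate is drawn from Gamma(parent_concentration, parent_rate), and every coordinate of the embedding is then drawn from Gamma(embedding_concentration, rate) with that shared rate, which is what broadcasting the one-dimensional rate sample against tf.ones((1, embedding_size)) achieves in the layer. The function name is hypothetical; the default hyperparameters mirror the values used in the model below.

```python
import random


def sample_rate_adjusted_embedding(rng, embedding_size, parent_concentration=1.0,
                                   parent_rate=0.8, embedding_concentration=1.0):
    """Draw (rate, embedding) from the two-level Gamma prior for one user or item."""
    # A single rate per class; random.gammavariate takes a scale, i.e. 1 / rate.
    rate = rng.gammavariate(parent_concentration, 1.0 / parent_rate)
    # Every embedding coordinate shares the same rate.
    embedding = [rng.gammavariate(embedding_concentration, 1.0 / rate)
                 for _ in range(embedding_size)]
    return rate, embedding


rng = random.Random(0)
for _ in range(3):
    rate, embedding = sample_rate_adjusted_embedding(rng, embedding_size=4)
    print(f"rate={rate:.2f}", [round(x, 2) for x in embedding])
```

Since a Gamma(c, r) variable is distributed as a Gamma(c, 1) variable divided by r, the shared rate rescales the whole embedding vector, and therefore every predicted Poisson rate involving that user or item, by 1 / r. It acts as a multiplicative bias rather than an additive one.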

Probabilistic recommender

With the probabilistic embeddings defined as custom layers, the full model only needs a few lines of code. The user and movie embeddings are constructed as rate-adjusted Gamma embeddings, and their scalar product will be the rate of the Poisson distribution that generates the observations, implemented as a distribution lambda layer.

The KL weights need to be passed to the corresponding probabilistic layers, so we need to include them in the input of the model, for instance in two additional columns. For the user part, we observe that the KL divergence term is decomposed as a sum over user terms,

\[ \mathbb{E}_Q\left[ \log Q(Z_{users}) - \log P(Z_{users}) \right] = \sum_u \mathbb{E}_Q\left[ \log Q(Z_u) - \log P(Z_u) \right], \]

but we need to express it as a sum over all user/movie interactions of the training set. If we simply replace the sum over users by a sum over interactions, each user is counted once for every movie they have rated, so we divide each term by this number,

\[ \mathbb{E}_Q\left[ \log Q(Z_{users}) - \log P(Z_{users}) \right] = \sum_i \frac{1}{N_{u[i]}} \mathbb{E}_Q\left[ \log Q(Z_{u[i]}) - \log P(Z_{u[i]}) \right], \]

where \(u[i]\) denotes the user of interaction \(i\), and \(N_{u[i]}\) the number of movies rated by this user. The KL weights of the movie part can be derived in the same way.
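These weights can be precomputed from the list of interactions. The pure-Python sketch below assumes the interactions come as (user_index, movie_index, count) triples and builds rows in the column order that the model's call method reads (user, movie, user KL weight, movie KL weight, i.e. inputs[:, 0] to inputs[:, 3]). It also checks the identity above: weighting a per-user quantity by \(1/N_{u[i]}\) and summing over interactions gives the plain sum over distinct users. The function name and the toy data are hypothetical.

```python
from collections import Counter


def add_kl_weight_columns(interactions):
    """Build rows [user, movie, 1 / N_user, 1 / N_movie] and the matching targets.

    interactions: list of (user_index, movie_index, observed_count) triples.
    """
    user_counts = Counter(u for u, _, _ in interactions)
    movie_counts = Counter(m for _, m, _ in interactions)
    inputs = [[float(u), float(m), 1.0 / user_counts[u], 1.0 / movie_counts[m]]
              for u, m, _ in interactions]
    targets = [float(y) for _, _, y in interactions]
    return inputs, targets


interactions = [(0, 10, 3), (0, 11, 1), (0, 12, 2), (1, 10, 5), (2, 11, 1), (2, 12, 4)]
inputs, targets = add_kl_weight_columns(interactions)

# Check the identity: with arbitrary per-user KL values, the weighted sum over
# interactions equals the plain sum over distinct users.
per_user_kl = {0: 1.5, 1: 0.2, 2: 0.7}
over_interactions = sum(per_user_kl[int(row[0])] * row[2] for row in inputs)
print(over_interactions, sum(per_user_kl.values()))
```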

class ProbabilisticRecommender(tfk.Model):
    def __init__(self, num_users, num_movies, embedding_size, **kwargs):
        super().__init__(**kwargs)
        self.num_users = num_users
        self.num_movies = num_movies
        self.embedding_size = embedding_size

        self.user_embedding = RateAdjustedGammaEmbedding(
            num_users,
            embedding_size,
            parent_concentration=1.,
            parent_rate=.8,
            embedding_concentration=1.
        )

        self.movie_embedding = RateAdjustedGammaEmbedding(
            num_movies,
            embedding_size,
            parent_concentration=1.,
            parent_rate=.8,
            embedding_concentration=1.
        )

        # Observation model: Poisson counts whose rate is the user/movie scalar product.
        self.head = tfpl.DistributionLambda(lambda t: tfd.Poisson(rate=t))

    def call(self, inputs):
        # Input columns: user index, movie index, user KL weight, movie KL weight.
        user_vector = tf.convert_to_tensor(
            self.user_embedding(inputs[:, 0], inputs[:, 2]))
        movie_vector = tf.convert_to_tensor(
            self.movie_embedding(inputs[:, 1], inputs[:, 3]))
        dot_user_movie = tf.reduce_sum(user_vector * movie_vector, axis=-1)
        return self.head(dot_user_movie)

As mentioned in the preliminary observations, this model requires a negative log-likelihood loss function; beyond that, it can be trained like any other Keras model.

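EMBEDDING_SIZE = 20

BATCH_SIZE = 1024  # batch size for prob_model.fit on the training data

# Negative log-likelihood of the observed counts under the predicted Poisson distribution.
negloglik = lambda y, rv_y: -rv_y.log_prob(y)

# num_users and num_movies are the numbers of distinct users and movies in the dataset.
prob_model = ProbabilisticRecommender(num_users,
                                      num_movies,
                                      embedding_size=EMBEDDING_SIZE
                                     )
prob_model.compile(
    loss=negloglik, optimizer=tfk.optimizers.Adam(learning_rate=0.01)
)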
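To spell out what the loss evaluates for a single interaction, the pure-Python sketch below mirrors one forward pass: given one sampled user vector and one sampled movie vector, their scalar product is the Poisson rate, and the per-example loss is \(-\log P(y \vert \lambda) = \lambda - y \log \lambda + \log y!\), which is the value negloglik returns for a Poisson output. The KL terms collected through add_loss come on top of it. The vectors are made up and the function names are hypothetical.

```python
import math


def poisson_negloglik(y, rate):
    """-log Poisson(y | rate), the per-example value of negloglik for a Poisson output."""
    return rate - y * math.log(rate) + math.lgamma(y + 1.0)


def forward(user_vector, movie_vector, y):
    # The scalar product of the sampled embeddings is the Poisson rate.
    rate = sum(u * m for u, m in zip(user_vector, movie_vector))
    return rate, poisson_negloglik(y, rate)


rate, loss = forward([0.5, 1.0], [2.0, 1.0], y=2)
print(rate, loss)
```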

Once the model has been trained, we can call it on new user/item pairs to produce a Poisson distribution of posterior predicted observations. We can directly sample user ratings from this distribution with .sample(), or call its .rate_parameter() method to get the Poisson rate. The latter offers a higher resolution to rank items for a given user (it is a continuous variable rather than an integer) and is therefore more practical for recommender systems.

When the model is called, each probabilistic layer returns a single sample from its learned variational distribution. To estimate the posterior mean of the Poisson rate of a user/item interaction, one can call the model several times and average the resulting rates. In real-world applications, scoring an item with a single Poisson rate draw, or only a few of them, rather than with the estimated posterior mean might prove more useful, as it offers a broader variety of suggestions to users whose tastes are less certain, that is, users whose embedding distributions have higher variance. It addresses the exploration-exploitation trade-off with a mechanism similar to Thompson sampling (see also this previous blog post).
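The two scoring strategies can be compared in a pure-Python sketch, assuming each user and item embedding coordinate follows an independent Gamma distribution with known (concentration, rate) pairs, as in the mean-field surrogate. The posterior mean rate is estimated by averaging the rates of many sampled forward passes, while the Thompson-sampling-style pick ranks items with a single joint draw. Under the mean-field assumption, the average should approach \(\sum_d \mathbb{E}[\theta_{ud}]\,\mathbb{E}[\beta_{id}]\), which serves as a sanity check. All names and parameter values are hypothetical.

```python
import random


def sample_vector(rng, params):
    # params: list of (concentration, rate) pairs; gammavariate takes scale = 1 / rate.
    return [rng.gammavariate(c, 1.0 / r) for c, r in params]


def sampled_rate(rng, user_params, item_params):
    # One forward pass: draw both embeddings and take their scalar product.
    user = sample_vector(rng, user_params)
    item = sample_vector(rng, item_params)
    return sum(u * v for u, v in zip(user, item))


def posterior_mean_rate(rng, user_params, item_params, num_samples=1000):
    return sum(sampled_rate(rng, user_params, item_params)
               for _ in range(num_samples)) / num_samples


def thompson_pick(rng, user_params, items_params):
    """Index of the item with the highest rate in a single posterior draw."""
    user = sample_vector(rng, user_params)
    scores = [sum(u * v for u, v in zip(user, sample_vector(rng, p)))
              for p in items_params]
    return max(range(len(scores)), key=scores.__getitem__)


rng = random.Random(0)
user = [(4.0, 2.0), (2.0, 2.0)]
items = [[(3.0, 1.0), (1.0, 1.0)], [(1.0, 1.0), (3.0, 1.0)]]
for item in items:
    analytic = sum((a / b) * (c / d) for (a, b), (c, d) in zip(user, item))
    print(posterior_mean_rate(rng, user, item), analytic)
print([thompson_pick(rng, user, items) for _ in range(10)])
```

Ranking by the posterior mean always favors the first item here, whereas single draws occasionally pick the second one, more often when the user's distributions are wider.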

In general, users who have provided fewer ratings, or good ratings across a large spectrum of items, will get less defined embeddings, and this approach will give them recommendations that explore the item landscape more broadly than for users with tighter embeddings. Users with better-known tastes will still occasionally get suggestions relatively far from their usual preferences, albeit less frequently than users with less defined embeddings. This cannot happen with models based on classical embeddings, which always return the same score for a given user/item pair.

This mechanism is also interesting when acquiring new users who have not yet provided ratings. We can initialize their embeddings to match the priors, or the learned distributions of similar users with a wider variance, and use the model without further training to draw recommendations that are compatible with our prior knowledge. When they start giving ratings, we can train the corresponding user embedding layers to incorporate the new knowledge, while freezing the item embedding layers for stability and increased speed.
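One simple way to start from a similar user's learned distributions with a wider variance, sketched here in pure Python under the assumption of independent Gamma coordinates, is to divide both the concentration c and the rate r by a factor: the mean c / r is unchanged while the variance c / r² is multiplied by that factor. With the log-scale parameters of the Keras layers above, this would amount to subtracting log(factor) from both halves of the corresponding embedding row. The learned values below and the function names are hypothetical.

```python
import random


def widen_gamma(params, factor):
    """Multiply the variance c / r**2 of each Gamma(c, r) by `factor`, keeping its mean c / r."""
    return [(c / factor, r / factor) for c, r in params]


def sample_vector(rng, params):
    # random.gammavariate takes a shape and a scale, i.e. 1 / rate.
    return [rng.gammavariate(c, 1.0 / r) for c, r in params]


# Hypothetical learned (concentration, rate) pairs of a similar existing user.
similar_user = [(40.0, 20.0), (9.0, 3.0), (25.0, 50.0)]
new_user = widen_gamma(similar_user, factor=4.0)
draws = [sample_vector(random.Random(seed), new_user) for seed in range(3)]
print(new_user)
print(draws)
```

The draws from new_user scatter more widely around the similar user's means, so the first recommendations explore more before any rating is available.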

With TensorFlow Probability layers, we can thus add a Bayesian flavor to more traditional recommender systems and address issues such as exploration-exploitation trade-offs or cold starts in a more principled way. From another angle, we can express probabilistic models such as matrix factorization models as Keras models and take advantage of the tf.data.Dataset API for batch training with potentially large datasets.