Inference in Bayesian Modelling

Purpose

Understanding the ideas behind Bayesian modelling can be challenging. This article's purpose is to stitch together the fundamental ideas behind the design and utility of Bayesian models. We will discuss inference, and how it can be calculated exactly and analytically. It will be more of a theoretical exercise than a practical one. A following post will cover approximate inference techniques.

Inference

What is inference exactly?

In the standard feedforward network setting, there is a dataset consisting of inputs and outputs. For this dataset, there exists a true generating function, often defined by nature, that maps the inputs to the outputs. For example, consider the task of identifying different breeds of dogs. Certain dog breeds look the way they do simply because of nature. We humans can easily approximate this natural function because we are accustomed to seeing dogs.

Often, however, this generating function cannot be extracted directly, so a feedforward network with a more limited set of parameters is used to approximate it. We generally start with an initial (often random) approximation of this generating function and optimize it, typically with a gradient-based method such as stochastic gradient descent (for neural networks the optimization problem is generally non-convex). Inference in this setting is just evaluation: taking a test input and running it through the optimized network.

In the Bayesian setting, inference is much more expansive: instead of considering the parameters of a single, optimized model, we consider all possible parameters for a model given a prior, and aim to determine which ones are more likely to represent the data (or approximate the true generating function). Instead of evaluating with a single set of weights, inference is the process of integrating over model parameters to determine the probability that a given input maps to a given output. The model described in this paragraph is often a Bayesian neural network (a feedforward network whose weights are given a prior), though this form of inference can be applied to any parametric model. A Bayesian neural network places a probability distribution over its weights and can be described as a directed graphical model; it is not the main focus of this post.

Thinking using Bayes’ Theorem

Bayes’ theorem is defined as:

$$P(X|Y) = \frac{P(Y|X)\,P(X)}{P(Y)} \tag{1}$$

It relates the conditional probability of X given Y to the conditional probability of Y given X. In a statistical learning setting, we have an additional variable, w, our model weights, and Bayes’ theorem becomes:

$$P(w|X,Y) = \frac{P(Y|X,w)\,P(w)}{P(Y|X)} \tag{2}$$

Equation (2) has four important components, discussed extensively in the literature, so it’s important to understand their purpose.

(1) P(w) is the prior. This is a difficult concept to fully understand; essentially, it is an assumption about the data you’re trying to model, expressed as a probability distribution over the candidate models (weights) before any data is seen. As an example, suppose you look at a dataset of human heights in a population and assume the heights are normally distributed with a mean of 70 inches, but you are unsure of the standard deviation. A possible prior is N(μ=70, σ=[1,2,3,4,5]). Here, you make the explicit assumption that the data is described by a unimodal normal distribution with a mean of 70 inches and one of five standard deviations. Concretely, you hypothesize that one of these five normal curves describes the data. Another possible prior is exp(λ=[1,2,3]), where you make the explicit assumption that one of three exponential curves describes the data. The exact value of P(w=i) depends on how likely you think it is that a specific model i describes the data, and will be covered in the next section.

(2) P(Y|X,w) is the likelihood: the probability of the observed data under one particular model w from the prior. We are not looking at all the models in the prior, but at a specific one. Assuming the data points are independent given w, the product rule of probability lets us factor it over the N data points:

$$P(Y|X,w) = \prod_{n=1}^{N} P(y_{n}|x_{n},w) \tag{3}$$

Concretely, we take every data point, calculate the pdf at that point according to the specific distribution we selected from the prior, and multiply the resulting densities together.

(3) P(Y|X) is known as the model evidence, or more intuitively, the normalizer. It is given by:

$$P(Y|X) = \int P(Y|X,w)\,P(w)\,dw \tag{4}$$

In other words, it is the numerator of equation (2) summed (or integrated) over all possible weights (or models) defined in the prior, which ensures the posterior is a true probability distribution (it sums to 1 across all possibilities).

(4) P(w|X,Y) is the posterior, and is the quantity we want to compute. It essentially asks the question: given a set of data points X and Y, how well does a specific set of weights (or model) w in the prior represent the data points?
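When the prior contains finitely many models, as in the heights example in (1), the four components fit together in a few lines of code and the evidence integral becomes a sum. The sketch below is a minimal illustration under stated assumptions, not a general implementation: it uses Python's `statistics.NormalDist` for the candidate models, gives each of the five models an equal prior weight (an assumption; the post leaves these values open), and feeds in hypothetical heights rather than real data.

```python
import math
from statistics import NormalDist

def likelihood(data, model):
    # P(Y|w): product of the model's density at every data point,
    # assuming the points are independent given the model.
    return math.prod(model.pdf(y) for y in data)

def posterior(prior, data):
    """Return P(w|Y) for a finite prior.

    prior: list of (model, prior_probability) pairs, i.e. P(w).
    Returns a list of (model, posterior_probability) pairs.
    """
    # Numerator of Bayes' theorem (equation (2)): P(Y|w) P(w) for every model.
    joint = [(model, likelihood(data, model) * p) for model, p in prior]
    # Model evidence P(Y): with finitely many models the integral is a sum.
    evidence = sum(j for _, j in joint)
    # Dividing by the evidence makes the posterior sum to 1.
    return [(model, j / evidence) for model, j in joint]

# Heights prior from (1): N(mu=70, sigma) for sigma in [1, 2, 3, 4, 5],
# each assumed equally likely a priori.
prior = [(NormalDist(mu=70, sigma=s), 1 / 5) for s in [1, 2, 3, 4, 5]]

# Hypothetical observed heights (inches), for illustration only.
post = posterior(prior, [68.0, 71.5, 73.0])
for model, p in post:
    print(f"sigma={model.stdev}: posterior={p:.3f}")
```

For larger datasets the product of densities underflows, so practical code usually sums log-densities instead.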

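Finally, to calculate the probability of a new point $y^{*}$, given its input $x^{*}$ and the data points X and Y, we need to integrate over all weights in the prior, weighting each model's prediction by its posterior probability:

$$P(y^{*}|x^{*},X,Y) = \int P(y^{*}|x^{*},w)\,P(w|X,Y)\,dw \tag{5}$$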
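With finitely many models, the integral in equation (5) becomes a sum: the predictive density is a mixture of each model's density, weighted by that model's posterior probability. A minimal sketch, assuming the posterior is given as (model, probability) pairs of `statistics.NormalDist` objects; the weights in the usage lines are illustrative, not computed from data.

```python
from statistics import NormalDist

def predictive_density(y_star, posterior):
    """Equation (5) with a finite set of models.

    posterior: list of (model, posterior_probability) pairs, i.e. P(w|X,Y).
    Returns sum over models of P(y_star|w) * P(w|X,Y).
    """
    return sum(p * model.pdf(y_star) for model, p in posterior)

# Illustrative posterior over two candidate models.
posterior = [(NormalDist(0.0, 1.0), 0.25), (NormalDist(2.0, 1.0), 0.75)]
print(predictive_density(1.0, posterior))
```

Because the result is a weighted average of valid densities, it is itself a density that integrates to 1.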

Exact Inference: Fully Worked-out Problem

To understand how to perform exact inference, let’s consider a toy example. Suppose we have three subjects, whose heights are 5.7, 6.2, and 6.8 feet, respectively. Let’s say our prior is N(μ=[6.1, 6.3], σ=0.5). Thus, we are considering two normal distributions with the same standard deviation but different means. To jog your memory, the pdf of the normal distribution is:

$$\mathcal{N}(x; \mu, \sigma^{2}) = \frac{1}{\sigma\sqrt{2\pi}}\exp\left(-\frac{(x-\mu)^{2}}{2\sigma^{2}}\right) \tag{6}$$

Note: in this problem setting, there are no inputs (X), so the likelihoods depend only on the weights (here, the mean μ). In most real-world problems this isn't the case, and the likelihood is calculated as some function of the input. The images below show the two distributions with their corresponding likelihoods. The blue lines mark the data points, and the density values (the height of the normal curve at each point) are given along the x-axis.

[Figure: the normal curves for μ = 6.1 and μ = 6.3, with the three data points marked]

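First, let us make the assumption (as part of the prior) that each of the two normal curves is equally likely to describe the data. Thus, P(w) is 0.5 for both models. Next, we calculate each model's likelihood as the product of the normal pdf (equation (6)) at the three data points:

$$P(Y|\mu=6.1) = \mathcal{N}(5.7; 6.1, 0.5^{2})\,\mathcal{N}(6.2; 6.1, 0.5^{2})\,\mathcal{N}(6.8; 6.1, 0.5^{2}) \approx 0.5794 \times 0.7821 \times 0.2995 \approx 0.1357$$
$$P(Y|\mu=6.3) = \mathcal{N}(5.7; 6.3, 0.5^{2})\,\mathcal{N}(6.2; 6.3, 0.5^{2})\,\mathcal{N}(6.8; 6.3, 0.5^{2}) \approx 0.3884 \times 0.7821 \times 0.4839 \approx 0.1470$$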

Now, let's calculate the model evidence. Since there are just two models in the prior, the integral boils down to a sum:

$$P(Y) = \sum_{\mu} P(Y|\mu)\,P(\mu) \approx 0.5 \times 0.1357 + 0.5 \times 0.1470 \approx 0.1413$$

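Now that we know the model evidence, we can easily calculate the posteriors with equation (2):

$$P(\mu=6.1|Y) \approx \frac{0.5 \times 0.1357}{0.1413} \approx 0.480, \qquad P(\mu=6.3|Y) \approx \frac{0.5 \times 0.1470}{0.1413} \approx 0.520$$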

So given our prior, we can conclude the model with $\mu$=6.3 is more likely to describe our data.
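The whole worked example fits in a few lines of Python. This sketch writes out the normal pdf from equation (6) directly with the `math` module and represents the two-model prior as a dictionary from each candidate mean to its prior probability; it computes the likelihoods, evidence, and posteriors for the toy problem.

```python
import math

def normal_pdf(x, mu, sigma):
    # Equation (6): density of a normal distribution with mean mu and std sigma.
    return math.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))

heights = [5.7, 6.2, 6.8]      # the three subjects (feet)
sigma = 0.5                    # shared standard deviation
prior = {6.1: 0.5, 6.3: 0.5}   # P(w): candidate means, equally likely

# Likelihood of each model: product of the pdf at every data point.
likelihoods = {mu: math.prod(normal_pdf(h, mu, sigma) for h in heights) for mu in prior}
# Model evidence: with two models, the integral is a sum.
evidence = sum(likelihoods[mu] * prior[mu] for mu in prior)
# Posterior via Bayes' theorem.
posteriors = {mu: likelihoods[mu] * prior[mu] / evidence for mu in prior}

print(f"evidence={evidence:.4f}")  # evidence=0.1413
for mu in prior:
    print(f"mu={mu}: likelihood={likelihoods[mu]:.4f}, posterior={posteriors[mu]:.3f}")
# mu=6.1: likelihood=0.1357, posterior=0.480
# mu=6.3: likelihood=0.1470, posterior=0.520
```

The sample mean of the three heights is about 6.23 feet, closer to 6.3 than to 6.1, which is why the second model ends up slightly more probable.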

Analytic Inference

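In this section, we'll go into more detail on how inference is performed analytically. NOTE: These notes are adapted from Yarin Gal's talk at MLSS 2019 (bdl101.ml) and are not my original work. We are going to derive two useful properties of the posterior, its mean and covariance, and then use them to analytically compute the predictive mean of $y^{*}$ (from equation (5)). Let's start with the problem statement. Given a set of data points X and Y, we want to model the relationship between X and Y using two quantities: w, the weights, and $\phi$, a "frozen" feature map that turns an input x into a K-dimensional feature vector $\phi(x)$. The feature map is frozen because it does not change (it is not learned). Instead, we place a distribution, or prior, over the weights w in order to perform inference. Let's set some baselines with the following equations:

$$P(w) = \mathcal{N}(w; 0, s^{2}I_{K}) \tag{7}$$
$$P(y|x,w) = \mathcal{N}(y; w^{T}\phi(x), \sigma^{2}) \tag{8}$$
$$P(w|X,Y) \propto P(Y|X,w)\,P(w) \tag{9}$$
$$= \prod_{n=1}^{N}\mathcal{N}(y_{n}; w^{T}\phi(x_{n}), \sigma^{2})\;\mathcal{N}(w; 0, s^{2}I_{K}) \tag{10}$$
$$= \prod_{n=1}^{N} C\exp\left(-\frac{(y_{n} - w^{T}\phi(x_{n}))^{2}}{2\sigma^{2}}\right) \cdot C'\exp\left(-\frac{w^{T}w}{2s^{2}}\right) \tag{11}$$
$$\propto \exp\left(-\frac{1}{2}\left[\sigma^{-2}\sum_{n=1}^{N}(y_{n} - w^{T}\phi(x_{n}))^{2} + s^{-2}w^{T}w\right]\right) \tag{12}$$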

In equation (7), we establish the prior: the weights w are Gaussian with mean 0 and covariance $s^{2}I_{K}$, so each weight has standard deviation s. In equation (8), we model each output as $w^{T}\phi(x)$ plus Gaussian noise with standard deviation $\sigma$. This Gaussian has two purposes. First, it gives us a Gaussian likelihood function. Second, it mimics observational noise in the data.

In equation (9), we ignore the model evidence because it is a normalizer and is the same for all possible weights w. In equation (10), we calculate the likelihood as the product of the pdfs of the individual data points $(x_{n}, y_{n})$, just as we did in the toy problem, and multiply it by the prior.

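In equation (11), we simply write each normal density out using its pdf, described in equation (6). In equation (12), we drop the normalizing constants C and C' (for each likelihood term, $C = \frac{1}{\sigma\sqrt{2\pi}}$), since they do not depend on w, and combine the exponents. Next, we collect the terms in the exponent of equation (12) that contain w, dropping the factor of $-\frac{1}{2}$:

$$\sigma^{-2}\sum_{n=1}^{N}(y_{n} - w^{T}\phi(x_{n}))^{2} + s^{-2}w^{T}w \tag{13}$$
$$= \sigma^{-2}\sum_{n=1}^{N}\left(w^{T}\phi(x_{n})\phi(x_{n})^{T}w - 2y_{n}w^{T}\phi(x_{n})\right) + s^{-2}w^{T}w + \text{const} \tag{14}$$
$$= w^{T}\left(\sigma^{-2}\sum_{n=1}^{N}\phi(x_{n})\phi(x_{n})^{T}\right)w - 2w^{T}\left(\sigma^{-2}\sum_{n=1}^{N}y_{n}\phi(x_{n})\right) + w^{T}\left(s^{-2}I_{K}\right)w + \text{const} \tag{15}$$
$$= w^{T}\left(\sigma^{-2}\sum_{n=1}^{N}\phi(x_{n})\phi(x_{n})^{T} + s^{-2}I_{K}\right)w - 2w^{T}\left(\sigma^{-2}\sum_{n=1}^{N}y_{n}\phi(x_{n})\right) + \text{const} \tag{16}$$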

In equation (14), we expand the square and keep only the terms that contain w; the only term we drop is $\sigma^{-2}\sum_{n} y_{n}^{2}$, which does not depend on w and is absorbed into the constant. In equation (15), we expand the summation and write the result as a sum of three terms.

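Equation (16) is important. Here, we combine the three terms of equation (15) into two, in the form $w^{T}Aw - 2w^{T}B$, and look for the terms A and B. Why? Consider a normal distribution $P(w) = \mathcal{N}(w; \mu, \Sigma)$ with covariance matrix $\Sigma$ (similar to equation (7)). If we group its w terms as we did in equations (13) to (16), we get:

$$\mathcal{N}(w; \mu, \Sigma) \propto \exp\left(-\frac{1}{2}(w-\mu)^{T}\Sigma^{-1}(w-\mu)\right) \tag{17}$$
$$(w-\mu)^{T}\Sigma^{-1}(w-\mu) = w^{T}\Sigma^{-1}w - 2w^{T}\Sigma^{-1}\mu + \mu^{T}\Sigma^{-1}\mu \tag{18}$$
$$= w^{T}\Sigma^{-1}w - 2w^{T}\Sigma^{-1}\mu + \text{const} \tag{19}$$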

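Note: the prior in equation (7) has exactly this form with $\mu = 0$ and $\Sigma = s^{2}I_{K}$, so $\Sigma^{-1} = s^{-2}I_{K}$; this is where the $s^{-2}I_{K}$ term in equation (16) comes from. Looking for A and B in equation (19), we find $A = \Sigma^{-1}$ and $B = \Sigma^{-1}\mu$. In equation (16), we find $A = \sigma^{-2}\sum_{n=1}^{N}\phi(x_{n})\phi(x_{n})^{T} + s^{-2}I_{K}$ and $B = \sigma^{-2}\sum_{n=1}^{N}y_{n}\phi(x_{n})$. Since the posterior's exponent has the same form as a Gaussian's, the posterior is Gaussian, and matching terms gives its covariance and mean:

$$\Sigma' = A^{-1} = \left(\sigma^{-2}\sum_{n=1}^{N}\phi(x_{n})\phi(x_{n})^{T} + s^{-2}I_{K}\right)^{-1} \tag{20}$$
$$\mu' = A^{-1}B = \Sigma'\,\sigma^{-2}\sum_{n=1}^{N}y_{n}\phi(x_{n}) \tag{21}$$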
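The posterior's covariance $A^{-1}$ and mean $A^{-1}B$ translate directly into code. The sketch below is a minimal, dependency-free illustration that assumes a hypothetical two-dimensional feature map $\phi(x) = (1, x)$ (a bias plus the raw input), so A is a 2×2 matrix that can be inverted in closed form; for larger K you would use a linear-algebra library instead. The data points are made up for illustration.

```python
def phi(x):
    # Hypothetical frozen feature map for this sketch: a bias and the raw input (K = 2).
    return (1.0, x)

def posterior(X, Y, s, sigma):
    """Posterior mean and covariance of w under the Gaussian prior and likelihood.

    A = sigma^-2 * sum_n phi(x_n) phi(x_n)^T + s^-2 * I   (precision matrix)
    B = sigma^-2 * sum_n y_n * phi(x_n)
    Returns (mean, cov) = (A^-1 B, A^-1).
    """
    a11 = a12 = a22 = b1 = b2 = 0.0
    for x, y in zip(X, Y):
        f1, f2 = phi(x)
        # Accumulate sum_n phi phi^T (symmetric, so three entries) and sum_n y_n phi.
        a11 += f1 * f1
        a12 += f1 * f2
        a22 += f2 * f2
        b1 += y * f1
        b2 += y * f2
    # Scale by sigma^-2 and add the prior precision s^-2 to the diagonal.
    a11 = a11 / sigma**2 + 1.0 / s**2
    a22 = a22 / sigma**2 + 1.0 / s**2
    a12 = a12 / sigma**2
    b1, b2 = b1 / sigma**2, b2 / sigma**2
    # Closed-form inverse of the symmetric 2x2 matrix A gives the covariance.
    det = a11 * a22 - a12 * a12
    cov = ((a22 / det, -a12 / det), (-a12 / det, a11 / det))
    # Posterior mean = covariance times B.
    mean = (cov[0][0] * b1 + cov[0][1] * b2, cov[1][0] * b1 + cov[1][1] * b2)
    return mean, cov

# Illustrative (made-up) points lying exactly on y = 1 + 2x, with a very weak prior.
mean, cov = posterior([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0], s=1e3, sigma=0.1)
print(mean)
```

With a very weak prior (large s) and noise-free points on the line $y = 1 + 2x$, the posterior mean approaches the least-squares fit (1, 2); with no data at all, it falls back to the prior, with mean 0 and covariance $s^{2}I_{K}$.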

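Now, let's use these results to do analytic inference on a new point! We want to find $\mu^{*} = E_{P(y^{*}|x^{*},X,Y)}[y^{*}]$, the mean of the predictive distribution of $y^{*}$ for a test input $x^{*}$. In essence, we are averaging $y^{*}$ over all its possible values, weighted by their predictive probability:

$$\mu^{*} = \int y^{*}\,P(y^{*}|x^{*},X,Y)\,dy^{*} \tag{22}$$
$$= \int y^{*}\int P(y^{*}|x^{*},w)\,P(w|X,Y)\,dw\,dy^{*} \tag{23}$$
$$= \int\left(\int y^{*}\,P(y^{*}|x^{*},w)\,dy^{*}\right)P(w|X,Y)\,dw \tag{24}$$
$$= \int w^{T}\phi(x^{*})\,P(w|X,Y)\,dw \tag{25}$$
$$= \left(\int w\,P(w|X,Y)\,dw\right)^{T}\phi(x^{*}) = \mu'^{T}\phi(x^{*}) \tag{26}$$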

In equation (23), we replace the term $P(y^{*}|x^{*},X,Y)$ with the RHS of equation (5). In equation (24), we swap the order of integration and notice that the inner term $\int P(y^{*}|x^{*},w)\,y^{*}\,dy^{*}$ is simply the mean of the distribution described in equation (8), which is $w^{T}\phi(x^{*})$. In equation (26), we notice that $\int w^{T} P(w|X,Y)\,dw$ is the posterior mean $\mu'$ (described in equation (21)), transposed because it contains $w^{T}$ rather than $w$. The predictive mean is therefore $\mu'^{T}\phi(x^{*})$.
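One way to check this result is to sample from the predictive distribution and compare the sample mean with $\mu'^{T}\phi(x^{*})$. This Monte Carlo sketch assumes the simplest case, a single weight with feature map $\phi(x) = x$, and uses hypothetical posterior values rather than ones derived from data: it draws $w$ from the posterior and then $y^{*}$ from equation (8), which together produce a draw from equation (5).

```python
import random

def predictive_mean_mc(post_mean, post_var, x_star, sigma, n_samples=100_000, seed=0):
    """Monte Carlo estimate of E[y*] for a scalar weight w and feature phi(x) = x.

    Each iteration draws w ~ P(w|X,Y) = N(post_mean, post_var) and then
    y* ~ P(y*|x*, w) = N(w * x*, sigma^2); the resulting y* is a sample
    from the predictive distribution in equation (5).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        w = rng.gauss(post_mean, post_var ** 0.5)
        total += rng.gauss(w * x_star, sigma)
    return total / n_samples

# Hypothetical posterior N(2.0, 0.25) over w, test input x* = 3.0, noise std 0.5.
post_mean, post_var, x_star, sigma = 2.0, 0.25, 3.0, 0.5
closed_form = post_mean * x_star   # mu'^T phi(x*) for the scalar case
estimate = predictive_mean_mc(post_mean, post_var, x_star, sigma)
print(closed_form, round(estimate, 2))
```

Here the predictive variance is $x^{*2}\,\mathrm{Var}(w) + \sigma^{2} = 2.5$, so with 100,000 samples the estimate should land within about 0.005 (one standard error) of the closed form 6.0.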

The Intractability Problem

Why do we need approximate inference when, as shown here, exact and analytic inference work? It's important to understand that they only work in relatively simple cases, such as the toy experiment above. The bottleneck is calculating the model evidence. In our toy experiment, we only considered two models in the prior. In some situations, this number becomes exceedingly large: for example, if we don't know anything about the data distribution, we may have hundreds of normal distributions, each with a different mean and standard deviation. A continuous prior, such as the Gaussian prior over w in equation (7), already contains infinitely many models. The analytic section could handle it only because a Gaussian prior combined with a Gaussian likelihood gives a Gaussian posterior. For most models, including neural networks with nonlinear activations, the evidence integral has no closed form, and evaluating it numerically over many parameters is infeasible.
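To make the cost concrete, here is a sketch that computes the evidence by brute-force summation over a grid of candidate normal models for the toy heights. The grid of 100 means and 100 standard deviations, and the uniform prior over it, are hypothetical choices for illustration. Every grid point needs one density evaluation per data point, and with D parameters and M candidate values each, the grid holds $M^{D}$ models.

```python
import math
from statistics import NormalDist

heights = [5.7, 6.2, 6.8]  # toy data from the worked example (feet)

def grid_evidence(means, sigmas, data):
    """Model evidence P(Y) for a uniform prior over a grid of N(mu, sigma) models.

    Cost: len(means) * len(sigmas) * len(data) density evaluations.
    """
    prior = 1 / (len(means) * len(sigmas))
    evidence = 0.0
    for mu in means:
        for sigma in sigmas:
            model = NormalDist(mu, sigma)
            evidence += prior * math.prod(model.pdf(y) for y in data)
    return evidence

# Hypothetical grid: 100 candidate means times 100 candidate standard deviations.
means = [5.0 + 0.02 * i for i in range(100)]
sigmas = [0.1 + 0.01 * j for j in range(100)]
print(grid_evidence(means, sigmas, heights))  # 10,000 models, 30,000 pdf calls

# With M = 100 candidate values per parameter, a grid over D parameters has 100**D models.
for d in (2, 10, 100):
    print(f"{d} parameters -> 10^{2 * d} models")
```

With only the two models from the toy example, `grid_evidence([6.1, 6.3], [0.5], heights)` matches the two-model evidence from the worked problem; the same loop over the weights of even a small neural network would never finish.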

Next steps

In the next post, I will go over some techniques to approximate the posterior, and thus perform inference when calculating the posterior exactly is computationally infeasible.