7  Introduction to Inference

Author

Mark Fuge

Published

October 22, 2025

Learning Objectives

By the end of this notebook, you will:

  1. Understand Bayesian inference for regression and how it quantifies model uncertainty
  2. Visualize the posterior predictive distribution for linear models
  3. Recognize the limitations of linear models on nonlinear data
  4. Learn how neural networks provide flexible nonlinear models
  5. Understand why exact Bayesian inference becomes intractable for neural networks
  6. Implement Variational Inference (VI) as an approximate solution
  7. Compare VI to MCMC and understand the speed vs. accuracy tradeoff
  8. Visualize predictive uncertainty in function space

The Big Picture

Imagine you have sparse sensor measurements from a mechanical system. You want to:

  1. Fit a model that captures the underlying relationship
  2. Quantify uncertainty in predictions (critical for safety-critical systems)
  3. Make predictions with confidence intervals

This notebook shows you how to do this using Bayesian inference, starting with simple linear models and progressing to flexible neural networks where variational inference becomes essential.

7.1 Motivation and Toy Problem Definition

Let’s start with a fundamental question: when you fit a model to data, how confident should you be in its predictions?

In traditional machine learning, we find a single “best” set of parameters. But in engineering, we often need to know:

  • How uncertain are we about the model?
  • Where in the input space are predictions reliable?
  • What would happen if we had slightly different data?

Bayesian inference answers these questions by maintaining a distribution over models rather than picking a single one.

7.1.1 The Dataset: Nonlinear Sensor Measurements

We’ll use a simple nonlinear function to represent sensor data: \[y = \sin(x) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 0.1^2)\]

This represents, for example:

  • Cyclic behavior (vibrations, temperatures, etc.)
  • Noisy measurements
  • Sparse observations (we won’t measure everywhere)

We are interested in understanding how well the model can capture the uncertainty of its predictions.

Show Code
# Setup and imports
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from typing import Tuple
import ipywidgets as widgets

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
Show Code
# Generate nonlinear dataset
def generate_nonlinear_data(n_samples: int = 20, 
                           noise_std: float = 0.1,
                           x_range: Tuple[float, float] = (-3, 3)) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate sparse noisy samples from a sinusoidal function.
    
    This simulates sensor measurements from a cyclic process.
    
    Args:
        n_samples: Number of observations
        noise_std: Standard deviation of measurement noise
        x_range: Input domain (min, max)
        
    Returns:
        x: Input locations (n_samples,)
        y: Noisy measurements (n_samples,)
    """
    x_min, x_max = x_range
    
    # Sample input locations (sparse, irregular spacing)
    x = np.random.uniform(x_min, x_max, n_samples)
    
    # True underlying function (unknown to the model)
    y_true = np.sin(x)
    
    # Add measurement noise
    y = y_true + noise_std * np.random.randn(n_samples)
    
    # Sort for easier visualization
    sort_idx = np.argsort(x)
    x = x[sort_idx]
    y = y[sort_idx]
    
    return x, y

# Generate training data
n_train = 20
noise_std = 0.1
x_train, y_train = generate_nonlinear_data(n_train, noise_std)

# Create dense grid for visualization (the "true" function)
x_test = np.linspace(-3, 3, 200)
y_test_true = np.sin(x_test)

# Visualize the dataset
fig, ax = plt.subplots(figsize=(12, 6))

# Plot true function
ax.plot(x_test, y_test_true, 'k-', linewidth=2, label=r'True function: $y = \sin(x)$', alpha=0.7)

# Plot noisy training observations
ax.scatter(x_train, y_train, s=80, c='red', edgecolors='black', 
           linewidth=1.5, zorder=5, label=f'Training data (n={n_train})')

# Shade noise region
ax.fill_between(x_test, y_test_true - 2*noise_std, y_test_true + 2*noise_std,
                alpha=0.2, color='gray', label=f'±2sigma noise band (sigma={noise_std})')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Nonlinear Sensor Measurements', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-1.5, 1.5])

plt.tight_layout()
plt.show()

7.2 Linear Bayesian Regression

Let’s start with the simplest approach: linear regression with Bayesian inference.

7.2.1 The Model

We assume a linear relationship: \[y = w_0 + w_1 x + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2)\]

In Bayesian inference, we place a prior over the weights: \[p(\mathbf{w}) = \mathcal{N}(\mathbf{0}, \alpha^{-1} \mathbf{I})\]

After seeing data \((X, Y)\), the posterior over weights is: \[p(\mathbf{w} | X, Y) = \mathcal{N}(\mathbf{m}_N, \mathbf{S}_N)\]

where (for linear-Gaussian models, we can derive it exactly): \[\mathbf{S}_N = (\alpha \mathbf{I} + \beta \mathbf{\Phi}^T \mathbf{\Phi})^{-1}\] \[\mathbf{m}_N = \beta \mathbf{S}_N \mathbf{\Phi}^T \mathbf{y}\]

where \(\beta = 1/\sigma^2\) is the noise precision and \(\mathbf{\Phi}\) is the design matrix.

7.2.2 The Posterior Predictive Distribution

For a new input \(x_*\), the predictive distribution integrates over all possible weights: \[p(y_* | x_*, X, Y) = \int p(y_* | x_*, \mathbf{w}) p(\mathbf{w} | X, Y) d\mathbf{w}\]

For linear-Gaussian models, this is also Gaussian: \[p(y_* | x_*, X, Y) = \mathcal{N}(\mathbf{m}_N^T \phi(x_*), \sigma_N^2(x_*))\]

The variance \(\sigma_N^2(x_*)\) tells us how uncertain we are about predictions at \(x_*\).

Let’s implement this and see what happens when it is applied to our nonlinear data.

class BayesianLinearRegression:
    """
    Bayesian linear regression with Gaussian prior and likelihood.
    
    Closed-form posterior over weights for polynomial basis functions.
    """
    
    def __init__(self, degree: int = 1, alpha: float = 1.0, beta: float = 100.0):
        """
        Args:
            degree: Degree of polynomial basis (1 = linear)
            alpha: Prior precision (inverse variance)
            beta: Noise precision (1/sigma^2)
        """
        self.degree = degree
        self.alpha = alpha
        self.beta = beta
        self.m_N = None  # Posterior mean
        self.S_N = None  # Posterior covariance
        
    def _design_matrix(self, x: np.ndarray) -> np.ndarray:
        """Compute polynomial design matrix."""
        x = x.reshape(-1, 1)
        Phi = np.concatenate([x**i for i in range(self.degree + 1)], axis=1)
        return Phi
    
    def fit(self, X: np.ndarray, y: np.ndarray):
        """Compute posterior over weights given training data."""
        Phi = self._design_matrix(X)
        n_features = Phi.shape[1]
        
        # Posterior covariance
        self.S_N = np.linalg.inv(self.alpha * np.eye(n_features) + 
                                  self.beta * Phi.T @ Phi)
        
        # Posterior mean
        self.m_N = self.beta * self.S_N @ Phi.T @ y
        
        return self
    
    def predict(self, X_test: np.ndarray, return_std: bool = True):
        """
        Posterior predictive distribution.
        
        Returns:
            mean: Predictive mean
            std: Predictive standard deviation (if return_std=True)
        """
        Phi_test = self._design_matrix(X_test)
        
        # Predictive mean
        y_mean = Phi_test @ self.m_N
        
        if return_std:
            # Predictive variance (includes noise and weight uncertainty)
            y_var = 1.0/self.beta + np.sum(Phi_test @ self.S_N * Phi_test, axis=1)
            y_std = np.sqrt(y_var)
            return y_mean, y_std
        
        return y_mean
    
    def sample_weights(self, n_samples: int = 100) -> np.ndarray:
        """Sample weight vectors from posterior."""
        return np.random.multivariate_normal(self.m_N, self.S_N, size=n_samples)

# Fit linear Bayesian regression
blr = BayesianLinearRegression(degree=1, alpha=1.0, beta=1.0/noise_std**2)
blr.fit(x_train, y_train)

# Posterior predictive distribution
y_pred_mean, y_pred_std = blr.predict(x_test)

print("Linear Bayesian Regression Fitted!")
print(f"Posterior mean weights: {blr.m_N}")
print(f"Posterior std of weights: {np.sqrt(np.diag(blr.S_N))}")
Linear Bayesian Regression Fitted!
Posterior mean weights: [-0.07553835  0.33272507]
Posterior std of weights: [0.0225762  0.01242588]
Show Code
# Visualize linear Bayesian regression predictions
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Posterior predictive with uncertainty bands
ax = axes[0]

# True function
ax.plot(x_test, y_test_true, 'k-', linewidth=2, label=r'True: $y = \sin(x)$', alpha=0.7)

# Training data
ax.scatter(x_train, y_train, s=80, c='red', edgecolors='black', 
           linewidth=1.5, zorder=5, label='Training data')

# Predictive mean
ax.plot(x_test, y_pred_mean, 'b-', linewidth=2, label='Linear prediction (posterior mean)')

# Uncertainty bands (±1sigma and ±2sigma)
ax.fill_between(x_test, y_pred_mean - y_pred_std, y_pred_mean + y_pred_std,
                alpha=0.3, color='blue', label='±1sigma predictive')
ax.fill_between(x_test, y_pred_mean - 2*y_pred_std, y_pred_mean + 2*y_pred_std,
                alpha=0.2, color='blue', label='±2sigma predictive')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Linear Model: Cannot Capture Nonlinearity', fontsize=14, fontweight='bold')
ax.legend(fontsize=10, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-2, 2])

# Plot 2: Function samples from posterior
ax = axes[1]

# Sample weight vectors and plot corresponding functions
w_samples = blr.sample_weights(n_samples=50)
Phi_test = blr._design_matrix(x_test)

for i in range(50):
    y_sample = Phi_test @ w_samples[i]
    ax.plot(x_test, y_sample, 'b-', alpha=0.1, linewidth=1)

# True function and data
ax.plot(x_test, y_test_true, 'k-', linewidth=2, label='True function', alpha=0.7)
ax.scatter(x_train, y_train, s=80, c='red', edgecolors='black', 
           linewidth=1.5, zorder=5, label='Training data')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Posterior Function Samples (50 draws)', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-2, 2])

plt.tight_layout()
plt.show()

7.3 A Small Neural Network (Nonlinear Model)

The linear model fails because the true function is nonlinear. What if we try to use a simple neural network instead?

7.3.1 Architecture

We’ll use a simple 1-hidden-layer network: \[f(x; \mathbf{w}) = \mathbf{w}_2^T \sigma(\mathbf{w}_1 x + \mathbf{b}_1) + b_2\]

where \(\sigma\) is the ReLU activation. With just 5 hidden units, this can approximate smooth nonlinear functions.

7.3.2 The New Challenge: Intractable Posterior

Unlike linear regression, the posterior \(p(\mathbf{w} | X, Y)\) is no longer Gaussian! We cannot compute it analytically.

However, we can still:

  1. Define a prior \(p(\mathbf{w})\) (e.g., Gaussian on all weights)
  2. Sample functions from the prior to see what kinds of functions are admissible under the model
  3. Later, use VI or MCMC to approximate the posterior

Let’s first see what the prior predictive distribution looks like.

class SmallNN(nn.Module):
    """
    Simple 1-hidden-layer neural network for regression.
    """
    
    def __init__(self, hidden_size: int = 5):
        super().__init__()
        self.fc1 = nn.Linear(1, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)
        
    def forward(self, x):
        x = x.view(-1, 1)
        h = F.relu(self.fc1(x))
        y = self.fc2(h)
        return y.squeeze()
    
    def get_weights_flat(self) -> torch.Tensor:
        """Get all weights as a flat vector."""
        return torch.cat([p.flatten() for p in self.parameters()])
    
    def set_weights_flat(self, w_flat: torch.Tensor):
        """Set all weights from a flat vector."""
        offset = 0
        for p in self.parameters():
            n_params = p.numel()
            p.data = w_flat[offset:offset+n_params].view(p.shape)
            offset += n_params

# Create network and count parameters
hidden_size = 5
model = SmallNN(hidden_size=hidden_size)
n_params = sum(p.numel() for p in model.parameters())

print("Neural Network Architecture:")
print(f"  Input → {hidden_size} hidden (ReLU) → 1 output")
print(f"  Total parameters: {n_params}")
print("Parameter breakdown:")
for name, param in model.named_parameters():
    print(f"  {name}: shape {param.shape}, count {param.numel()}")
Neural Network Architecture:
  Input → 5 hidden (ReLU) → 1 output
  Total parameters: 16
Parameter breakdown:
  fc1.weight: shape torch.Size([5, 1]), count 5
  fc1.bias: shape torch.Size([5]), count 5
  fc2.weight: shape torch.Size([1, 5]), count 5
  fc2.bias: shape torch.Size([1]), count 1
Show Code
# Sample functions from the PRIOR (before seeing any data)
def sample_prior_functions(model, x_test, n_samples=20, prior_std=1.0):
    """
    Sample function predictions from prior p(w) = N(0, prior_std^2 * I).
    """
    x_test_tensor = torch.FloatTensor(x_test)
    predictions = []
    
    with torch.no_grad():
        for _ in range(n_samples):
            # Sample weights from prior
            for param in model.parameters():
                param.data.normal_(0, prior_std)
            
            # Forward pass
            y_pred = model(x_test_tensor).numpy()
            predictions.append(y_pred)
    
    return np.array(predictions)

# Generate prior predictive samples
prior_std = 1.0
n_prior_samples = 50
prior_predictions = sample_prior_functions(model, x_test, n_prior_samples, prior_std)

# Visualize prior predictive distribution
fig, ax = plt.subplots(figsize=(12, 7))

# Plot prior function samples
for i in range(n_prior_samples):
    ax.plot(x_test, prior_predictions[i], 'purple', alpha=0.3, linewidth=1)

# Plot true function and data
ax.plot(x_test, y_test_true, 'k-', linewidth=3, label=r'True function: $y = \sin(x)$', zorder=10)
ax.scatter(x_train, y_train, s=100, c='red', edgecolors='black', 
           linewidth=2, zorder=15, label='Training data (not yet used!)')

# Add dummy line for legend
ax.plot([], [], 'purple', alpha=0.6, linewidth=2, 
        label=f'Prior samples (n={n_prior_samples})')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Prior Predictive: Neural Network Before Training\n'
             rf'(Prior: $w \sim \mathcal{{N}}(0, {prior_std}^2 I)$)',
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-3, 3])

plt.tight_layout()
plt.show()

We can see that, just by sampling functions from the prior, the neural network produces far more diverse functions than those we could sample from the linear model. But how do we find out which of these weights are (probabilistically) likely?

7.4 MCMC Approximation

The neural network posterior \(p(\mathbf{w} | X, Y)\) is intractable, but we can sample from it using Markov Chain Monte Carlo (MCMC). Before we dive into how to apply this to neural networks, let’s first gain some intuition on how and why MCMC works on a simple 2D example.

7.4.1 Why can’t we just compute the posterior using regular Monte Carlo Integration?

A classical Monte Carlo approach would involve:

  1. Drawing many samples \(\mathbf{w}^{(i)}\) from the prior \(p(\mathbf{w})\), or (as is more common) from a proposal distribution \(q(\mathbf{w})\) that we hope covers the posterior well
  2. Weighting each sample by its likelihood \(p(Y | X, \mathbf{w}^{(i)})\)
  3. Normalizing empirically to get an approximate posterior

This is called importance sampling, and while it is conceptually simple, it fails for many complex models or distributions because the proposal distribution needs to be well matched to the true posterior. If the posterior is concentrated in a small region of parameter space, most samples will have negligible weight, making our “Effective Sample Size” (ESS) very small: we do lots of sampling but ultimately throw most of those samples away, which is inefficient. For low-dimensional, fairly simple distributions, importance sampling can work well, but for high-dimensional models like neural networks, or for very complex posterior distributions, it typically fails (as we will see below).

To address this, we can use Markov Chain Monte Carlo (MCMC) methods, which adaptively explore the posterior distribution. The key idea is to construct a Markov chain whose stationary distribution is the target posterior. By running the chain for a long time, we can obtain samples that are approximately distributed according to the posterior. Its main disadvantage is that, because the samples are correlated, we need many more of them to get a good approximation of the posterior than with independent sampling methods; this downside is offset by the fact that MCMC can explore complex distributions effectively even with a simple proposal distribution.

7.4.2 What is MCMC?

MCMC methods (like Metropolis-Hastings, Hamiltonian Monte Carlo, and others) generate samples \(\mathbf{w}^{(1)}, \mathbf{w}^{(2)}, \ldots\) from the posterior by:

  1. Starting at a random initial weight vector
  2. Proposing updated samples that respect the posterior density
  3. Eventually converging to samples from \(p(\mathbf{w} | X, Y)\)

So, in a nutshell, it allows you to generate plausible samples from \(p(\mathbf{w} | X, Y)\) without ever having to compute the posterior explicitly.

Pros: Asymptotically exact (given infinite samples)
Cons: Slow, especially for high-dimensional models

For our small network, MCMC is feasible. We will use it as a reference to compare against variational inference at the end of this notebook.

7.4.3 Note on Implementation

For this notebook, we’ll use a simple Metropolis-Hastings sampler, although many more advanced algorithms exist.


7.4.3.1 The Metropolis-Hastings Algorithm

The algorithm works as follows:

  1. Initialize: Start with random weights \(\mathbf{w}^{(0)}\)
  2. Propose: Generate a candidate \(\mathbf{w}^* \sim q(\mathbf{w}^* | \mathbf{w}^{(t)})\) (typically a Gaussian centered at current weights)
  3. Accept/Reject: Compute the acceptance probability: \[\alpha = \min\left(1, \frac{p(\mathbf{w}^* | X, Y)}{p(\mathbf{w}^{(t)} | X, Y)}\right) = \min\left(1, \frac{p(Y | X, \mathbf{w}^*) p(\mathbf{w}^*)}{p(Y | X, \mathbf{w}^{(t)}) p(\mathbf{w}^{(t)})}\right)\] Accept \(\mathbf{w}^{(t+1)} = \mathbf{w}^*\) with probability \(\alpha\), otherwise keep \(\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)}\)
  4. Repeat: Continue for many iterations until the chain converges

The useful thing about this sampler (and indeed MCMC approaches in general) is that we only ever need to evaluate the unnormalized posterior (likelihood × prior), so we can avoid computing the intractable normalizing constant.

Key caveats (with Metropolis-Hastings, or MCMC methods in general):

  • You cannot simply treat individual samples as independent draws from the posterior, since each proposal starts from the location of the previous sample; the samples are correlated. In practice, you discard an initial stretch of the chain (called “burn-in”) and often keep only every K-th sample (called “thinning”) so that the retained samples are less correlated with one another and behave more like true independent samples from the posterior.
  • While the proposal distribution can be simple (e.g., Gaussian), its parameters (like the standard deviation) need to be carefully tuned: if it is too large we will have low acceptance rates, but if it is too small we will explore the space very slowly, requiring many more samples to converge.
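A minimal sketch of this post-processing (illustrative only; it assumes samples is an (N, d) NumPy array of raw draws from a sampler like the ones below, and the variable names are hypothetical):

# Hypothetical clean-up of a raw MCMC chain: burn-in, then thinning
burn_in, thin = 1_000, 10
chain = samples[burn_in::thin]  # drop the first 1,000 draws, then keep every 10th
print(f"Kept {len(chain)} of {len(samples)} correlated draws")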

For the 2D example we will do next, MCMC works very well and will give us a good approximation of the posterior. For the 16-parameter network that we return to later in the notebook, MCMC sampling is still tractable. For modern deep networks with millions of parameters, however, MCMC becomes prohibitively slow; this is where Variational Inference comes to the rescue, as we will see in the second part of the notebook.

7.4.4 Toy Example: 2D Posterior Sampling from the Banana Distribution

We will start with a simple 2D example to illustrate how MCMC works, using the famous “Banana Distribution”, a common test case for MCMC methods due to its curved, non-Gaussian shape, as we can see below.

Show Code
# Local RNG with a fixed seed for reproducibility
rng = np.random.default_rng(42)

# Unnormalized log-density of banana target:
# log p~(x,y) = -0.5 * ( x^2/9 + (y - 0.1 x^2)^2 )

def log_p_tilde(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    y = np.asarray(y)
    return -0.5 * ( (x**2) / 9.0 + (y - 0.1 * x**2)**2 )

# Simple isotropic Gaussian proposal q(x,y) = N(0, I)
# log q(x,y) = -0.5 * (x^2 + y^2) + const; constants cancel in ratios, so we can drop them

def log_q(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    y = np.asarray(y)
    return -0.5 * (x**2 + y**2)

print("Defined banana target log-density and Gaussian proposal.")

# Now plot the target distribution for visualization
x_vals = np.linspace(-10, 10, 300)
y_vals = np.linspace(-5, 15, 300)
X, Y = np.meshgrid(x_vals, y_vals)
Z = np.exp(log_p_tilde(X, Y))
fig, ax = plt.subplots(figsize=(10, 6))
contour = ax.contourf(X, Y, Z, levels=50, cmap='viridis')
plt.colorbar(contour, ax=ax, label='Density')
ax.set_title('Banana-shaped Target Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('x', fontsize=13)
ax.set_ylabel('y', fontsize=13)
plt.show()
Defined banana target log-density and Gaussian proposal.

We can attempt to do Monte Carlo integration on this distribution, but because of its curved shape, simple importance sampling is unlikely to cover this distribution well, as we can see below:

# Importance Sampling on the banana target
N_is = 10_000
proposal_std = 0.3  # Proposal q = N(0, proposal_std^2 I); try changing this

# Draw from q
x_is = rng.normal(size=N_is) * proposal_std
y_is = rng.normal(size=N_is) * proposal_std

# Unnormalized log weights: log w_i = log p~(x_i,y_i) - log q(x_i,y_i)
# Note: log q must match the proposal we actually sampled from; rescaling the
# inputs to log_q accounts for proposal_std (constants cancel on normalization)
log_w = log_p_tilde(x_is, y_is) - log_q(x_is / proposal_std, y_is / proposal_std)
# Stabilize weights by subtracting max log_w
log_w_shift = log_w - np.max(log_w)
w = np.exp(log_w_shift)
w_norm = w / np.sum(w)

# Effective sample size
ESS = 1.0 / np.sum(w_norm**2)

# Weighted means
mean_x_is = np.sum(w_norm * x_is)
mean_y_is = np.sum(w_norm * y_is)

print(f"Importance Sampling: N={N_is}, ESS={ESS:.1f} ({100*ESS/N_is:.1f}% of N)")
print(f"Weighted means: E[x]={mean_x_is:.3f}, E[y]={mean_y_is:.3f}")
Importance Sampling: N=10000, ESS=9959.7 (99.6% of N)
Weighted means: E[x]=-0.003, E[y]=0.007
Show Code
# Visualize weighted samples
fig, ax = plt.subplots(figsize=(6, 5))
sc = ax.scatter(x_is, y_is, s=8, alpha=0.3)
#cb = plt.colorbar(sc, ax=ax)
#cb.set_label('Normalized weight')
ax.set_title(f'Importance Sampling\nESS={ESS:.0f} of N={N_is}')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.grid(True, alpha=0.3)
# Overlay contour of target distribution
x_vals = np.linspace(-10, 10, 300)
y_vals = np.linspace(-4, 5, 300)
X, Y = np.meshgrid(x_vals, y_vals)
Z = np.exp(log_p_tilde(X, Y))
contour = ax.contour(X, Y, Z, levels=10, colors='black', alpha=0.3)
ax.set_aspect('equal', adjustable='box')
plt.tight_layout()
plt.show()

Experiment: Modifying the Importance Sampling Proposal Distribution
  • What happens if you change the proposal distribution to have a larger or smaller variance? Does this improve the coverage of the posterior?
  • Compare the Effective Sample Size (ESS) for different proposal variances – does this align with your intuition when you compare the coverage of the sample points with the posterior?

Now that we have seen the limitations of importance sampling, let’s see how MCMC can help us better sample from this complex posterior. We will implement Metropolis-Hastings below:

# Metropolis–Hastings MCMC on the banana target
def mcmc_banana(N=10_000, sigma=0.5, burn_in=1_000, start=(0.0, 0.0), seed=42):
    """Run a simple MH sampler with isotropic Gaussian proposals."""
    rng_local = np.random.default_rng(seed)
    x_curr, y_curr = float(start[0]), float(start[1])
    logp_curr = log_p_tilde(x_curr, y_curr)

    samples = np.zeros((N, 2), dtype=float)
    accepts = 0

    for i in range(N):
        # Propose from q(x'|x) = N(x, sigma^2 I)
        dx, dy = rng_local.normal(scale=sigma, size=2)
        x_prop, y_prop = x_curr + dx, y_curr + dy
        logp_prop = log_p_tilde(x_prop, y_prop)

        # Symmetric proposal -> Hastings ratio reduces to p~(x')/p~(x)
        log_alpha = logp_prop - logp_curr
        if np.log(rng_local.uniform()) < min(0.0, log_alpha):
            x_curr, y_curr = x_prop, y_prop
            logp_curr = logp_prop
            accepts += 1
        samples[i] = (x_curr, y_curr)

    acc_rate = accepts / N
    chain = samples[burn_in:]
    mean_x_mcmc = float(np.mean(chain[:, 0]))
    mean_y_mcmc = float(np.mean(chain[:, 1]))
    return chain, acc_rate, mean_x_mcmc, mean_y_mcmc
Show Code
# Run MCMC
## Try changing these and see what happens to the behavior of the MCMC samples
# (e.g., no burn-in, different starting point, low vs high sigma)
N_mcmc = 100_000
burn_in = 1_000
sigma = .1
start = (0.0, 0.0)
##
chain, acc_rate, mean_x_mcmc, mean_y_mcmc = mcmc_banana(N=N_mcmc, sigma=sigma, burn_in=burn_in, start=start, seed=7)

print(f"MCMC: N={N_mcmc}, burn-in={burn_in}, sigma={sigma}, acceptance={acc_rate:.2f}")
print(f"MCMC means: E[x]={mean_x_mcmc:.3f}, E[y]={mean_y_mcmc:.3f}")
MCMC: N=100000, burn-in=1000, sigma=0.1, acceptance=0.96
MCMC means: E[x]=1.061, E[y]=0.729
Show Code
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Left: MCMC (plot scatter and path)
ax = axes[0]
ax.scatter(chain[:,0], chain[:,1], s=5, c='tab:orange', alpha=0.3, label='MCMC samples')
# Plot a thin path for the first few thousand steps to show exploration
path_len = min(2000, len(chain))
ax.plot(chain[:path_len,0], chain[:path_len,1], color='tab:red', linewidth=1, alpha=0.5, label='Trace (first 2k)')
ax.set_title(f'MCMC (MH)\naccept.={acc_rate:.2f}, N={N_mcmc}, burn-in={burn_in}, sigma={sigma}')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)
# overlay contour of target distribution
x_vals = np.linspace(-10, 10, 300)
y_vals = np.linspace(-4, 5, 300)
X, Y = np.meshgrid(x_vals, y_vals)
Z = np.exp(log_p_tilde(X, Y))
contour = ax.contour(X, Y, Z, levels=10, colors='black', alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Right: Line plot of MCMC trace
ax = axes[1]
ax.plot(chain[:,0], label='x', color='tab:blue', alpha=0.7)
ax.plot(chain[:,1], label='y', color='tab:green', alpha=0.7)
ax.set_title('MCMC Trace Plot')
ax.set_xlabel('Iteration')
ax.set_ylabel('Value')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Summary diagnostics:")
print(f"  Importance Sampling: ESS={ESS:.1f}/{N_is} ({100*ESS/N_is:.1f}%), E[x]={mean_x_is:.3f}, E[y]={mean_y_is:.3f}")
print(f"  MCMC: acceptance={acc_rate:.2f}, E[x]={mean_x_mcmc:.3f}, E[y]={mean_y_mcmc:.3f}")

Summary diagnostics:
  Importance Sampling: ESS=9959.7/10000 (99.6%), E[x]=-0.003, E[y]=0.007
  MCMC: acceptance=0.96, E[x]=1.061, E[y]=0.729
Experiment: Metropolis-Hastings sampling parameters

Modify the number of samples, proposal variance, burn-in period, and the starting point.

  • What happens if the proposal variance is very small? Is the chain able to explore the posterior well? How is this behavior reflected in the acceptance rate or the trace plot?
  • What happens if the starting point is very far from the high-density region of the posterior? How long does it take for the chain to converge to the target distribution? What does this imply for the burn-in period, and how is this reflected in the trace plot?
  • How do the various factors (proposal variance, starting point, burn-in) interact to affect the acceptance rate and the quality of the samples? Is a higher or lower acceptance rate always better? Why do you think this is?

7.4.5 Why IS struggles and MCMC succeeds

  • Importance Sampling (IS) relies on a global reweighting of proposal samples. When \(q\) poorly overlaps with the curved high-density region of the banana target, almost all weights are near zero. This leads to a tiny effective sample size (ESS) and high estimator variance.
  • MCMC with local proposals uses acceptance ratios that only depend on pairwise density ratios. Even with an unnormalized target \(\tilde p\), the normalizing constant cancels and the chain can move along the curved ridge, accumulating many effective samples.
  • Both methods can use unnormalized \(p\); the practical difference is normalization: IS requires a global normalization of weights, while MCMC only needs local ratios. As dimensionality and curvature increase, IS becomes very brittle unless \(q\) is very well matched to \(p\), whereas MCMC typically remains viable with reasonable proposals.

7.4.6 Returning to Neural Networks (now with MCMC sampling)

Now that we have a good understanding of how MCMC works on a simple 2D example, let’s return to our small neural network and use MCMC to sample from its posterior distribution over weights. This will allow us to approximate the posterior predictive distribution for our nonlinear regression problem. We will add a few tricks that we didn’t have in our simple 2D example:

  • We’ll thin the chain to reduce correlation between samples
  • We’ll discard a burn-in period to allow the chain to converge before collecting samples (as we did before)
  • Before starting the MCMC chain, we’ll pretrain the network using MAP estimation to find a good starting point for the chain.

# Simple MCMC sampler for neural network weights
def log_likelihood(model, x, y, noise_std=0.1):
    """Compute log p(y|x,w) for current weights."""
    with torch.no_grad():
        x_tensor = torch.FloatTensor(x)
        y_pred = model(x_tensor).numpy()
        log_lik = -0.5 * np.sum(((y - y_pred) / noise_std) ** 2)
        log_lik -= len(y) * np.log(noise_std * np.sqrt(2 * np.pi))
    return log_lik

def log_prior(model, prior_std=1.0):
    """Compute log p(w) for current weights."""
    log_p = 0
    for param in model.parameters():
        log_p -= 0.5 * torch.sum((param / prior_std) ** 2).item()
    return log_p

def metropolis_hastings_nn(x_train, y_train, n_samples=5000, proposal_std=0.01,
                           prior_std=1.0, noise_std=0.1, thin=10):
    """
    Simple Metropolis-Hastings MCMC for neural network weights.
    
    Args:
        thin: Only keep every 'thin'-th sample to reduce autocorrelation
    
    Returns:
        weight_samples: List of weight vectors
        acceptance_rate: Fraction of proposals accepted
    """
    model = SmallNN(hidden_size=5)
    
    # Initialize at MAP estimate (rough approximation: just train briefly)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    x_tensor = torch.FloatTensor(x_train).view(-1, 1)
    y_tensor = torch.FloatTensor(y_train)
    
    print("Finding initial weights (quick MAP estimate)...")
    for _ in range(200):
        optimizer.zero_grad()
        loss = F.mse_loss(model(x_tensor.squeeze()), y_tensor)
        loss.backward()
        optimizer.step()
    
    print(f"Starting MCMC sampling ({n_samples} iterations)...")
    
    # Current state
    current_log_posterior = log_likelihood(model, x_train, y_train, noise_std) + \
                            log_prior(model, prior_std)
    
    weight_samples = []
    n_accepted = 0
    
    for i in range(n_samples):
        # Propose new weights
        proposed_model = SmallNN(hidden_size=5)
        with torch.no_grad():
            for p_curr, p_prop in zip(model.parameters(), proposed_model.parameters()):
                p_prop.data = p_curr.data + proposal_std * torch.randn_like(p_curr)
        
        # Compute proposed log posterior
        proposed_log_posterior = log_likelihood(proposed_model, x_train, y_train, noise_std) + \
                                  log_prior(proposed_model, prior_std)
        
        # Accept/reject
        log_alpha = proposed_log_posterior - current_log_posterior
        if np.log(np.random.rand()) < log_alpha:
            # Accept
            model = proposed_model
            current_log_posterior = proposed_log_posterior
            n_accepted += 1
        
        # Store sample (with thinning)
        if i % thin == 0:
            weight_samples.append(model.get_weights_flat().clone().detach().numpy())
        
        if (i + 1) % 2000 == 0:
            acc_rate = n_accepted / (i + 1)
            print(f"  Iteration {i+1}/{n_samples}, acceptance rate: {acc_rate:.3f}")
    
    acceptance_rate = n_accepted / n_samples
    print(f"MCMC complete! Final acceptance rate: {acceptance_rate:.3f}")
    
    return weight_samples, acceptance_rate

# Run MCMC (this will take a minute or two)
print("Running MCMC to sample from posterior p(w|X,Y)...")
print("This may take 1-2 minutes for accurate results.")
mcmc_samples, acc_rate = metropolis_hastings_nn(
    x_train, y_train, 
    n_samples=10_000, 
    proposal_std=0.01,
    prior_std=1.0,
    noise_std=noise_std,
    thin=10
)

print(f"Collected {len(mcmc_samples)} posterior samples (after thinning)")
print(f"Acceptance rate: {acc_rate:.1%}")
Running MCMC to sample from posterior p(w|X,Y)...
This may take 1-2 minutes for accurate results.
Finding initial weights (quick MAP estimate)...
Starting MCMC sampling (10000 iterations)...
  Iteration 2000/10000, acceptance rate: 0.460
  Iteration 4000/10000, acceptance rate: 0.458
  Iteration 6000/10000, acceptance rate: 0.462
  Iteration 8000/10000, acceptance rate: 0.466
  Iteration 10000/10000, acceptance rate: 0.455
MCMC complete! Final acceptance rate: 0.455
Collected 1000 posterior samples (after thinning)
Acceptance rate: 45.5%
Show Code
# Plot the trace plots for all of the weights as a quick diagnostic using figure subplots:
mcmc_samples_array = np.array(mcmc_samples)  # shape (n_samples, n_params)
n_params = mcmc_samples_array.shape[1]
fig, axes = plt.subplots(n_params, 1, figsize=(5, n_params), sharex=True)
for i in range(n_params):
    axes[i].plot(mcmc_samples_array[:, i], color='tab:blue', alpha=0.7)
    axes[i].set_ylabel(f'W[{i}]', fontsize=10)
    axes[i].grid(True, alpha=0.3)
axes[-1].set_xlabel('MCMC Sample Index', fontsize=12)
fig.suptitle('MCMC Trace Plots for Neural Network Weights', fontsize=14, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

Show Code
# Visualize MCMC posterior predictive
def predict_with_weight_samples(weight_samples, x_test, n_samples=None):
    """Generate predictions using sampled weights."""
    if n_samples is None:
        n_samples = len(weight_samples)
    
    model_temp = SmallNN(hidden_size=5)
    x_test_tensor = torch.FloatTensor(x_test)
    predictions = []
    
    # Randomly select samples to plot
    sample_indices = np.random.choice(len(weight_samples), size=min(n_samples, len(weight_samples)), replace=False)
    
    with torch.no_grad():
        for idx in sample_indices:
            model_temp.set_weights_flat(torch.FloatTensor(weight_samples[idx]))
            y_pred = model_temp(x_test_tensor).numpy()
            predictions.append(y_pred)
    
    return np.array(predictions)

# Generate MCMC posterior predictions
mcmc_predictions = predict_with_weight_samples(mcmc_samples, x_test, n_samples=100)

# Visualize
fig, ax = plt.subplots(figsize=(12, 7))

# Plot MCMC posterior samples
for i in range(len(mcmc_predictions)):
    ax.plot(x_test, mcmc_predictions[i], 'orange', alpha=0.15, linewidth=1)

# True function and data
ax.plot(x_test, y_test_true, 'k-', linewidth=3, label='True function', zorder=10)
ax.scatter(x_train, y_train, s=100, c='red', edgecolors='black', 
           linewidth=2, zorder=15, label='Training data')

# Posterior mean
mcmc_mean = mcmc_predictions.mean(axis=0)
ax.plot(x_test, mcmc_mean, 'orange', linewidth=3, label='MCMC posterior mean', linestyle='--', zorder=5)

# Dummy line for legend
ax.plot([], [], 'orange', alpha=0.6, linewidth=2, label='MCMC posterior samples')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('MCMC Posterior Predictive Distribution\n(Reference "Ground Truth")', 
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-1.5, 1.5])

plt.tight_layout()
plt.show()

7.4.7 Advanced Sampling: Hamiltonian Monte Carlo (HMC)

There have been many substantial improvements to MCMC sampling methods over the past several decades. One of the main advances came from the recognition that we need not sample using only evaluations of \(q\) and \(p\): we can also leverage the gradients of \(p\). This was previously very limited or difficult to do, but it became much more practical with the widespread availability of Automatic Differentiation libraries.

With widespread access to gradients of \(p\), we can design sampling processes that leverage the slope and curvature of the likelihood to guide us toward better samples. One of the mainstream ways to do this is Hamiltonian Monte Carlo (HMC), which treats sampling as simulating a dynamical system. Specifically, we treat the samples/parameters as positions, introduce an auxiliary momentum variable, and simulate Hamiltonian dynamics that preserve the joint density, typically using a symplectic integrator such as the leapfrog integrator. Because the leapfrog integrator follows smooth trajectories, HMC avoids the random-walk behaviour that slowed down our Metropolis–Hastings chains above. HMC still fundamentally accepts or rejects samples as MH did, except that its method of producing new proposals is significantly more likely to yield acceptable samples, and thus it can be far more sample efficient.

It does this through a few basic concepts:

  • It uses the negative log posterior as a Potential Energy function \(U(\theta) = -\log p(\theta \mid \text{data})\), which pulls particle trajectories toward higher posterior density.
  • Each particle has Kinetic Energy \(K(\mathbf{p}) = \tfrac{1}{2} \mathbf{p}^\top M^{-1} \mathbf{p}\), which comes from a tunable mass matrix \(M\) that can re-scale parameters.
  • A symplectic integrator (leapfrog integration) alternates half-steps in momentum with full steps in position and nearly conserves energy.

The algorithm itself is fairly straightforward in that you:

  1. Sample a fresh momentum from \(\mathcal{N}(0, M)\) to randomize the next trajectory.
  2. Integrate the Hamiltonian dynamics for \(L\) leapfrog steps with step size \(\varepsilon\).
  3. Apply a Metropolis accept/reject step using the Hamiltonian.

Some peculiarities of HMC include:

  • A smaller step size \(\varepsilon\) improves accuracy but increases computation. If the step size is too large, the integrator can overshoot and cause many sample rejections.
  • The number of leapfrog steps \(L\) controls how far each proposal travels across the posterior ridge. If it is too small, the samples become auto-correlated since each proposal cannot move far from the current sample.
  • A well-chosen mass matrix and initial adaptation help HMC cope with areas of strong curvature in the posterior.
  • HMC relies on gradients, so we need target distributions that are differentiable and numerically stable.

Below we implement a minimal HMC sampler for the banana distribution so you can compare its behaviour with the random-walk MH sampler above.

# Hamiltonian Monte Carlo on the banana target
def grad_log_p_tilde(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Compute the gradient of the banana log-density with respect to (x, y)."""
    x = np.asarray(x)
    y = np.asarray(y)
    grad_x = -x / 9.0 + 0.2 * x * (y - 0.1 * x**2)
    grad_y = -(y - 0.1 * x**2)
    return np.stack([grad_x, grad_y], axis=-1)

def hmc_banana(num_samples: int = 8_000, burn_in: int = 1_000, step_size: float = 0.25,
                num_leapfrog: int = 20, mass: np.ndarray | float = 1.0,
                start: tuple[float, float] = (0.0, 0.0), seed: int = 2024) -> tuple[np.ndarray, float]:
    """Run a basic HMC sampler with (scalar or diagonal) mass matrix for the banana target."""
    rng_local = np.random.default_rng(seed)
    q_current = np.array(start, dtype=float)
    if q_current.shape != (2,):
        raise ValueError("start must be a length-2 tuple or array.")

    mass_array = np.asarray(mass, dtype=float)
    if mass_array.ndim == 0:
        mass_vec = np.full(2, mass_array.item())
    elif mass_array.shape == (2,):
        mass_vec = mass_array
    else:
        raise ValueError("mass must be a scalar or length-2 array for this example.")
    if np.any(mass_vec <= 0):
        raise ValueError("mass components must be positive.")
    inv_mass = 1.0 / mass_vec

    logp_current = log_p_tilde(q_current[0], q_current[1])
    samples = np.zeros((num_samples, 2), dtype=float)
    accepts = 0

    for i in range(num_samples):
        # Sample fresh momentum from N(0, M) with diagonal mass matrix.
        p_current = rng_local.normal(scale=np.sqrt(mass_vec), size=2)
        q = q_current.copy()
        p = p_current.copy()

        # Leapfrog integration
        grad = grad_log_p_tilde(q[0], q[1])
        p += 0.5 * step_size * grad
        for lf_step in range(num_leapfrog):
            q += step_size * p * inv_mass
            grad = grad_log_p_tilde(q[0], q[1])
            if lf_step != num_leapfrog - 1:
                p += step_size * grad
        p += 0.5 * step_size * grad
        p = -p  # Reverse momentum for detailed balance

        logp_prop = log_p_tilde(q[0], q[1])
        current_H = -logp_current + 0.5 * np.sum((p_current**2) * inv_mass)
        proposed_H = -logp_prop + 0.5 * np.sum((p**2) * inv_mass)
        log_alpha = current_H - proposed_H

        if np.log(rng_local.uniform()) < min(0.0, log_alpha):
            q_current = q
            logp_current = logp_prop
            accepts += 1

        samples[i] = q_current

    if burn_in >= num_samples:
        raise ValueError("burn_in must be smaller than num_samples.")
    chain = samples[burn_in:]
    acc_rate = accepts / num_samples
    return chain, acc_rate

# Run HMC and summarize diagnostics
N_hmc = 8_000
burn_in_hmc = 1_000
step_size = 0.25
num_leapfrog = 20
mass_vec = np.array([0.7, 1.3])  # Slight re-scaling along x and y
hmc_chain, hmc_acc_rate = hmc_banana(num_samples=N_hmc, burn_in=burn_in_hmc,
                                     step_size=step_size, num_leapfrog=num_leapfrog,
                                     mass=mass_vec, start=(10.0, 10.0), seed=11)
mean_x_hmc = float(np.mean(hmc_chain[:, 0]))
mean_y_hmc = float(np.mean(hmc_chain[:, 1]))

print(f"HMC: N={N_hmc}, burn-in={burn_in_hmc}, step size={step_size}, L={num_leapfrog}, acceptance={hmc_acc_rate:.2f}")
print(f"HMC means: E[x]={mean_x_hmc:.3f}, E[y]={mean_y_hmc:.3f}")
HMC: N=8000, burn-in=1000, step size=0.25, L=20, acceptance=1.00
HMC means: E[x]=0.046, E[y]=0.860

Show Code
# Visualize samples and trace using the same style as earlier
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

ax = axes[0]
ax.scatter(hmc_chain[:, 0], hmc_chain[:, 1], s=5, c='tab:purple', alpha=0.3, label='HMC samples')
path_len = min(2_000, len(hmc_chain))
ax.plot(hmc_chain[:path_len, 0], hmc_chain[:path_len, 1], color='tab:purple', linewidth=1, alpha=0.5, label='Trace (first 2k)')
ax.set_title(f'HMC\naccept.={hmc_acc_rate:.2f}, N={N_hmc}, burn-in={burn_in_hmc}, step={step_size}, L={num_leapfrog}')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)
x_vals = np.linspace(-10, 10, 300)
y_vals = np.linspace(-4, 5, 300)
X, Y = np.meshgrid(x_vals, y_vals)
Z = np.exp(log_p_tilde(X, Y))
ax.contour(X, Y, Z, levels=10, colors='black', alpha=0.3)
ax.set_aspect('equal', adjustable='box')

ax = axes[1]
ax.plot(hmc_chain[:, 0], label='x', color='tab:blue', alpha=0.7)
ax.plot(hmc_chain[:, 1], label='y', color='tab:green', alpha=0.7)
ax.set_title('HMC Trace Plot')
ax.set_xlabel('Iteration')
ax.set_ylabel('Value')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Experiment: Tuning HMC Hyperparameters
  • Decrease the step size while increasing the leapfrog count so the trajectory length stays similar. What happens to acceptance and exploration?
  • Try a very large step size or very few leapfrog steps. Do you see the chain reverting to random-walk behaviour?
  • Replace the scalar mass with a small diagonal matrix (e.g., different masses for \(x\) and \(y\)). Does this help the sampler hug the banana ridge more quickly?

7.5 Variational Inference

MCMC works but is slow, since it essentially has to generate thousands of good samples. In contrast, Variational Inference turns the integration problem into an optimization problem. Instead of sampling from \(p(\mathbf{w} | X, Y)\), we:

  1. Choose a tractable family \(q_\phi(\mathbf{w})\) (e.g., Gaussian)
  2. Find \(\phi\) that makes \(q_\phi\) close to the true posterior
  3. Measure “closeness” using the ELBO (Evidence Lower Bound)

7.5.1 Choosing a tractable family \(q_\phi(\mathbf{w})\)

There are many options for this, as we will see in later more advanced notebooks, but for now let’s use something simple: a mean-field Gaussian approximation: \[q_\phi(\mathbf{w}) = \mathcal{N}(\boldsymbol{\mu}_\phi, \text{diag}(\boldsymbol{\sigma}_\phi^2))\]

We will look at two examples below. The first is a simple 2D Gaussian with a 2D mean vector and two \(\sigma\)s corresponding to a diagonal covariance. Later, we extend this concept to a neural network (as we did with MCMC above), where each weight in the network gets its own mean and variance; for our network with n_params parameters, we then optimize both for all weights, leading to \(2 \times n_{params}\) variational parameters.
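To make this concrete, here is a minimal sketch of the variational parameters for such a mean-field Gaussian in PyTorch (the names mu, log_sigma, and the value of n_params are illustrative assumptions, not the implementation used later):

# Mean-field Gaussian q_phi(w): one mean and one (log) std-dev per weight
n_params = 16                                          # e.g., our SmallNN above
mu = torch.zeros(n_params, requires_grad=True)         # variational means
log_sigma = torch.zeros(n_params, requires_grad=True)  # log std-devs (keeps sigma > 0)

# Draw one weight vector w ~ q_phi(w) = N(mu, diag(sigma^2))
w_sample = mu + torch.exp(log_sigma) * torch.randn(n_params)

Storing log standard deviations (rather than standard deviations directly) lets us optimize the variational parameters without positivity constraints.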

But, in either case, what should we be optimizing, exactly?

7.5.2 1. Marginal Likelihood (Evidence)

The goal in Bayesian inference is to compute the marginal likelihood (also called the evidence):

\[ p(Y|X) = \int p(Y|X, \mathbf{w})\, p(\mathbf{w})\, d\mathbf{w} \]

This sounds fine in theory, but the integral is usually intractable for neural networks or other complex models we might be interested in. How can we solve this?

7.5.3 2. Introducing a Variational Distribution

We introduce a tractable distribution \(q_\phi(\mathbf{w})\) to approximate the true posterior \(p(\mathbf{w}|X,Y)\).

We can rewrite the log evidence using \(q_\phi\):

\[ \log p(Y|X) = \log \int p(Y|X, \mathbf{w})\, p(\mathbf{w})\, d\mathbf{w} \]

\[ = \log \int q_\phi(\mathbf{w})\, \frac{p(Y|X, \mathbf{w})\, p(\mathbf{w})}{q_\phi(\mathbf{w})}\, d\mathbf{w} \]

So far, this mathematical maneuver hasn’t really bought us much, but we can apply an approximation to help us separate out some terms.

7.5.4 3. Applying Jensen’s Inequality

Jensen’s Inequality states that, for a convex function \(f\):

\[ f(\mathbb{E}[x]) \le \mathbb{E}[f(x)] \]

and if we apply Jensen’s inequality to the expression above with \(f = \log\) and the expectation taken with respect to \(q_\phi\) (note that we flip the direction of the inequality since \(\log\) is concave, not convex), we get:

\[ \log \mathbb{E}_{q_\phi}\left[\frac{p(Y|X, \mathbf{w})\, p(\mathbf{w})}{q_\phi(\mathbf{w})}\right] \geq \mathbb{E}_{q_\phi}\left[\log \frac{p(Y|X, \mathbf{w})\, p(\mathbf{w})}{q_\phi(\mathbf{w})}\right] \]

Now we can bring the log inside to help break out some of the terms and rearrange them.

7.5.5 4. The ELBO Expression

This gives us the Evidence Lower Bound (ELBO):

\[ \log p(Y|X) \geq \mathbb{E}_{q_\phi}\left[\log p(Y|X, \mathbf{w})\right] + \mathbb{E}_{q_\phi}\left[\log p(\mathbf{w})\right] - \mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{w})\right] \]

Or, by grouping \(\mathbb{E}_{q_\phi}\left[\log p(\mathbf{w})\right]\) and \(\mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{w})\right]\) into the (negative) KL Divergence:

\[ \mathcal{L}(\phi) = \mathbb{E}_{q_\phi}\left[\log p(Y|X, \mathbf{w})\right] - \mathrm{KL}\left(q_\phi(\mathbf{w})\,\|\,p(\mathbf{w})\right) \]

We arrive at the common form of the ELBO, where:

  • The first term encourages \(q_\phi\) to place mass on weights that explain the data well. This is the expected log-likelihood, and it measures how well \(q_\phi\) explains the data.
  • The second term (the KL divergence) regularizes \(q_\phi\) to stay close to the prior.

This is the objective we optimize when we perform variational inference. We will first apply it to a simple 2D example so that we can understand what is going on, and then move to the more complex, higher-dimensional example of a neural network.
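As a concrete illustration of this objective, here is a minimal sketch of a Monte Carlo estimate of the ELBO for a generic model with a standard-normal prior (illustrative assumptions only: log_lik is a hypothetical stand-in for the model’s log-likelihood, and mu and log_sigma are the variational parameters from the sketch above). Note that this only estimates the ELBO’s value; getting gradients through it needs extra care, as discussed in Section 7.5.7. The 2D example below instead evaluates the expectations analytically.

# Monte Carlo estimate of the ELBO:
# E_q[log p(Y|X,w)] + E_q[log p(w)] - E_q[log q(w)]
def elbo_estimate(mu, log_sigma, log_lik, n_mc=8):
    q = torch.distributions.Normal(mu, torch.exp(log_sigma))  # mean-field q_phi
    prior = torch.distributions.Normal(0.0, 1.0)              # p(w) = N(0, I)
    total = 0.0
    for _ in range(n_mc):
        w = q.sample()  # w ~ q_phi(w); .sample() is detached from the graph
        total += log_lik(w) + prior.log_prob(w).sum() - q.log_prob(w).sum()
    return total / n_mc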

7.5.6 Example of using VI for the Banana Distribution

Let’s use our concrete example from earlier, the banana distribution, to illustrate how Variational Inference works. Here we will use a very simple variational distribution as \(q_\phi(\mathbf{w})\): a Gaussian with a mean and diagonal covariance, often called the mean-field Gaussian approximation. We will then optimize the ELBO to approximate the posterior. That is, we will optimize:

\[ \begin{align} \mathcal{L}(\phi) &= \mathbb{E}_{q_\phi}\left[\log p(Y|X, \mathbf{w})\right] - \mathrm{KL}\left(q_\phi(\mathbf{w})\,\|\,p(\mathbf{w})\right)\\ & = \mathbb{E}_{q_\phi}\left[\log p(Y|X, \mathbf{w})\right] + \mathbb{E}_{q_\phi}\left[\log p(\mathbf{w})\right] - \mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{w})\right]\\ & = \mathbb{E}_{q_\phi}\left[\log p(Y,w|X)\right] - \mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{w})\right] \end{align} \]

The reason we only compute \(\mathbb{E}_{q_\phi}\left[\log p(Y,\mathbf{w}|X)\right]\) in this case is that we are using a simple toy distribution for which we evaluate the joint distribution \(p(Y,\mathbf{w}|X)\) directly, rather than having a separate likelihood and prior as we would in a more complex model like a neural network.

# Variational Inference on the banana target distribution with interactive visualization
rng_vi = np.random.default_rng(0)
LOG_2PI = math.log(2.0 * math.pi)

def mf_elbo(mu_vec: torch.Tensor, raw_diag: torch.Tensor, num_samples: int = None) -> torch.Tensor:
    """Analytic ELBO for a mean-field (diagonal) Gaussian q.
    
    Assumes raw_diag stores log(sigma) for each dimension.
    Computes E_q[log p] analytically for the banana target:
        log p(x,y) = -0.5 * ( x^2/9 + (y - 0.1 x^2)^2 )
    and uses the closed-form moments of a Gaussian.
    """
    # diagonal std and variances
    sigma = torch.exp(raw_diag)
    var = sigma**2

    mu_x, mu_y = mu_vec[0], mu_vec[1]
    var_x, var_y = var[0], var[1]

    # moments for x
    E_x2 = var_x + mu_x**2
    # Var(x^2) for Gaussian x ~ N(mu_x, var_x)
    Var_x2 = 2.0 * var_x**2 + 4.0 * mu_x**2 * var_x

    # First term in the ELBO
    # E_q[ (y - 0.1 x^2)^2 ] = Var(y) + (E[y] - 0.1 E[x^2])^2 + 0.01 Var(x^2)
    E_y_term = var_y + (mu_y - 0.1 * E_x2) ** 2 + 0.01 * Var_x2

    # Second term in the ELBO
    # E_q[log p]
    E_log_p = -0.5 * (E_x2 / 9.0 + E_y_term)

    # Third term in the ELBO
    # E_q[log q] for diagonal Gaussian: -0.5*(D + D*log(2π) + 2 * sum(log sigma))
    D = mu_vec.numel()
    sum_log_sigma = torch.sum(raw_diag)
    E_log_q = -0.5 * (D + D * LOG_2PI + 2.0 * sum_log_sigma)

    return E_log_p - E_log_q

def run_vi_optimization(num_steps: int = 800, lr: float = 0.08, record_points: int = 60):
    """Optimize the ELBO using the mean-field (diagonal) analytic ELBO (mf_elbo).
    
    Records intermediate means, covariances (diagonal), and ELBO values for visualization.
    """
    mu = torch.tensor([0.0, 0.0], requires_grad=True)
    raw_diag = torch.log(torch.ones(2) * 2.5)
    raw_diag.requires_grad_(True)

    optimizer = Adam([mu, raw_diag], lr=lr)
    record_every = max(1, num_steps // record_points)
    history = {"steps": [], "mu": [], "cov": [], "elbo": []}

    for step in range(num_steps):
        optimizer.zero_grad()
        elbo = mf_elbo(mu, raw_diag)
        loss = -elbo
        loss.backward()
        optimizer.step()

        if (step % record_every == 0) or (step == num_steps - 1):
            # build diagonal cov_diagonal and covariance for recording/plotting
            cov_diagonal = np.diag(np.exp(raw_diag.detach().cpu().numpy()))
            cov_np = cov_diagonal @ cov_diagonal.T
            history["steps"].append(step)
            history["mu"].append(mu.detach().cpu().numpy())
            history["cov"].append(cov_np)
            history["elbo"].append(elbo.detach().cpu().item())

    return history

# Run VI once and cache the trajectory for the interactive plot
vi_history = run_vi_optimization(num_steps=150, record_points=150)
Show Code
history_steps = np.array(vi_history["steps"])
history_mu = np.stack(vi_history["mu"])
history_cov = np.stack(vi_history["cov"])
history_elbo = np.array(vi_history["elbo"])

x_grid = np.linspace(-10, 10, 240)
y_grid = np.linspace(-5, 10, 220)
X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
target_density_grid = np.exp(log_p_tilde(X_grid, Y_grid))

def gaussian_contour_density(X: np.ndarray, Y: np.ndarray, mu: np.ndarray, cov: np.ndarray) -> np.ndarray:
    """Evaluate a full-covariance Gaussian density on a meshgrid."""
    diff_x = X - mu[0]
    diff_y = Y - mu[1]
    inv_cov = np.linalg.inv(cov)
    det_cov = float(np.maximum(np.linalg.det(cov), 1e-12))
    exponent = (inv_cov[0, 0] * diff_x**2 +
                2.0 * inv_cov[0, 1] * diff_x * diff_y +
                inv_cov[1, 1] * diff_y**2)
    norm_const = 1.0 / (2.0 * math.pi * np.sqrt(det_cov))
    return norm_const * np.exp(-0.5 * exponent)

def plot_vi_state(frame_idx: int) -> None:
    """Render the variational approximation and ELBO trace for the given optimization step."""
    mu = history_mu[frame_idx]
    cov = history_cov[frame_idx]
    elbo_val = history_elbo[frame_idx]
    vi_density = gaussian_contour_density(X_grid, Y_grid, mu, cov)

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    ax = axes[0]
    ax.contour(X_grid, Y_grid, target_density_grid, levels=14, colors='#b3b3b3', linewidths=1.0, alpha=0.8)
    contourf = ax.contourf(X_grid, Y_grid, vi_density, levels=14, cmap='Blues', alpha=0.35)
    ax.contour(X_grid, Y_grid, vi_density, levels=14, colors='#1f77b4', linewidths=1.1, alpha=0.9)
    ax.scatter(mu[0], mu[1], color='tab:blue', s=40, label='VI mean')
    ax.set_title(f"VI approximation (step {history_steps[frame_idx]})")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.25)
    ax.legend(loc='upper right', fontsize=9)

    ax_elbo = axes[1]
    ax_elbo.plot(history_steps, history_elbo, color='tab:blue', linewidth=2)
    ax_elbo.axvline(history_steps[frame_idx], color='tab:purple', linestyle='--', alpha=0.7)
    ax_elbo.scatter(history_steps[frame_idx], elbo_val, color='tab:purple', s=40)
    ax_elbo.set_title('ELBO during VI optimization')
    ax_elbo.set_xlabel('Optimization step')
    ax_elbo.set_ylabel('ELBO')
    ax_elbo.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

step_slider = widgets.IntSlider(value=len(history_steps) - 1, min=0, max=len(history_steps) - 1, step=1, description='Frame')
play_ctrl = widgets.Play(interval=120, value=0, min=0, max=len(history_steps) - 1, step=1, description='Press play')
widgets.jslink((play_ctrl, 'value'), (step_slider, 'value'))
controls = widgets.HBox([play_ctrl, step_slider])
interactive_plot = widgets.interactive_output(plot_vi_state, {'frame_idx': step_slider})

display(controls, interactive_plot)

What do we notice about the VI approximation compared to MCMC? In what ways is it better or worse? Comparing the ELBO plot versus the MCMC trace plot, what differences do you notice about the convergence behavior of the two methods?

7.5.7 Monte Carlo Variational Inference

By assuming a certain form of the variational distribution \(q_\phi(\mathbf{w})\), we were able to compute the KL divergence term in closed form and perform VI analytically. However, in more complex models (like neural networks), we often cannot compute these expectations exactly. In such cases, we can resort to Monte Carlo Variational Inference to approximate the ELBO and its gradients.

Instead of integrating analytically, we draw samples from our Gaussian variational family and use the score-function identity below to update the parameters. This reliance on Monte Carlo introduces sampling noise, so we will see below how the ELBO estimate bounces around as the mean and covariance of \(q_\phi\) drift across the banana distribution. The only change from the closed-form solution above is how we estimate expectations; the gradients now come from the “score function” of the Gaussian.2

2 There are some tricks to reduce the variance by subtracting a moving “baseline” average from the advantage term (implemented in the black-box VI example later in this notebook), but we will not go into the details of this here.

The score-function identity states that for any distribution \(q_\phi(\mathbf{w})\) parameterized by \(\phi\): \[\nabla_\phi \mathbb{E}_{q_\phi} [f(\mathbf{w})] = \mathbb{E}_{q_\phi} [f(\mathbf{w}) \nabla_\phi \log q_\phi(\mathbf{w})]\]

where in our case, \(f(\mathbf{w}) = \log p(Y|X, \mathbf{w}) + \log p(\mathbf{w}) - \log q_\phi(\mathbf{w}) = \log p(Y,\mathbf{w}|X) - \log q_\phi(\mathbf{w})\). (This allows us to compute gradients of the ELBO with respect to \(\phi\) using samples from \(q_\phi\).)

This score-function identity is also sometimes called the “REINFORCE” gradient estimator in the reinforcement learning literature, as it was first popularized there and uses something called the “log-derivative trick”3 to move the gradient operator inside the expectation: \[ \begin{align*} \nabla_\phi \mathbb{E}_{q_\phi} [f(\mathbf{w})] &= \nabla_\phi \int q_\phi(\mathbf{w}) f(\mathbf{w}) d\mathbf{w}\\ &= \int f(\mathbf{w}) \nabla_\phi q_\phi(\mathbf{w}) d\mathbf{w} \\ &= \int f(\mathbf{w}) q_\phi(\mathbf{w}) \nabla_\phi \log q_\phi(\mathbf{w}) d\mathbf{w} \\ &= \mathbb{E}_{q_\phi} [f(\mathbf{w}) \nabla_\phi \log q_\phi(\mathbf{w})]\\ \end{align*} \]

3 Recall that \(\nabla_\phi \log(z) = \frac{\nabla_\phi z}{z}\).
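To convince yourself the identity holds, here is a minimal numerical check (a toy sketch, separate from the notebook’s model): for \(q = \mathcal{N}(\mu, \sigma^2)\) and \(f(w) = w^2\) we know \(\mathbb{E}_q[w^2] = \mu^2 + \sigma^2\), so the exact gradient with respect to \(\mu\) is \(2\mu\), and the score-function estimator should recover it from samples alone.

# Numerical check of the score-function identity on a 1D Gaussian.
# f(w) = w^2 with q = N(mu, sigma^2), so d/dmu E_q[f(w)] = 2*mu exactly.
import torch

mu, sigma = 1.5, 0.7
w = mu + sigma * torch.randn(500_000)   # samples from q
score_mu = (w - mu) / sigma**2          # d/dmu log q(w) for a Gaussian
grad_est = (w**2 * score_mu).mean()     # E_q[f(w) * score]

print(f"score-function estimate: {grad_est:.3f}")   # close to 3.0
print(f"exact gradient 2*mu:     {2 * mu:.3f}")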

So putting this all together for our toy example case here, we can estimate the gradient of the ELBO as: \[ \begin{align*} \nabla_\phi \mathcal{L}(\phi) &= \nabla_\phi \mathbb{E}_{q_\phi}\left[\log p(Y,\mathbf{w}|X)\right] - \nabla_\phi \mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{w})\right]\\ &= \mathbb{E}_{q_\phi}\left[\left(\log p(Y,\mathbf{w}|X) - \log q_\phi(\mathbf{w})\right) \nabla_\phi \log q_\phi(\mathbf{w})\right]\\ &\approx \frac{1}{S} \sum_{s=1}^S \left(\log p(Y,\mathbf{w}^{(s)}|X) - \log q_\phi(\mathbf{w}^{(s)})\right) \nabla_\phi \log q_\phi(\mathbf{w}^{(s)}) \quad \text{where } \mathbf{w}^{(s)} \sim q_\phi(\mathbf{w}) \end{align*} \]

So we can see that this will break down into the following steps:

  1. Sample \(S\) weight vectors \(\mathbf{w}^{(s)}\) from the variational distribution \(q_\phi(\mathbf{w})\)
  2. For each sample, compute the log joint probability \(\log p(Y,\mathbf{w}^{(s)}|X)\) and the log variational probability \(\log q_\phi(\mathbf{w}^{(s)})\). Their difference is commonly called the “advantage” term: it measures how well the sample explains the data relative to how probable it is under the variational distribution, and is akin to a log “likelihood ratio”.
  3. Compute the score function \(\nabla_\phi \log q_\phi(\mathbf{w}^{(s)})\) for each sample, which is intuitively the direction (in the variational parameters \(\phi\)) that increases the probability of that sample under the variational distribution.
  4. Multiply each score by the advantage term of that sample from step 2. Intuitively, this weights the gradient direction by how much better the sample explains the data than its probability under \(q_\phi\) suggests: if a sample explains the data well but is improbable under \(q_\phi\), we move \(q_\phi\) in that direction to increase its probability.
  5. Average these weighted scores over all \(S\) samples to get an estimate of \(\nabla_\phi \mathcal{L}(\phi)\). The code below implements exactly these steps for the banana distribution.
# Monte Carlo VI using score-function gradients for a diagonal Gaussian

def torch_log_p(z: torch.Tensor) -> torch.Tensor:
    x = z[:, 0]
    y = z[:, 1]
    return -0.5 * ((x**2) / 9.0 + (y - 0.1 * x**2) ** 2)

def log_q_diag(z: torch.Tensor, mu: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
    diff = z - mu
    inv_var = torch.exp(-2.0 * log_sigma)
    quad = torch.sum(diff**2 * inv_var, dim=1)
    log_det = 2.0 * torch.sum(log_sigma)
    return -0.5 * (quad + log_det + 2.0 * LOG_2PI)

steps = 600
batch_size = 256
lr = 0.08
record_every = 1

mu = torch.tensor([0.0, 0.0])
log_sigma = torch.log(torch.ones(2) * 2.5)

mc_history = {
    "steps": [],
    "mu": [],
    "cov": [],
    "elbo": [],
    "grad_var_mu": [],
    "grad_var_log_sigma": []
}

for step in range(steps):

    # Sample from q(z; mu, sigma)
    sigma = torch.exp(log_sigma)
    eps = torch.randn(batch_size, 2)

    z_samples = (mu + sigma * eps).detach()

    # Compute score-function gradients under samples from q
    log_p = torch_log_p(z_samples)
    log_q = log_q_diag(z_samples, mu, log_sigma)

    # Compute the advantage term (log p - log q); its mean is the MC ELBO estimate
    advantage = log_p - log_q

    # Compute score functions for q at the samples
    score_mu = (z_samples - mu) / (sigma**2)
    score_log_sigma = -1.0 + (z_samples - mu) ** 2 / (sigma**2)

    # Multiply the advantage with the score functions to get gradient estimates
    grad_mu_samples = advantage.unsqueeze(1) * score_mu
    grad_log_sigma_samples = advantage.unsqueeze(1) * score_log_sigma

    # Take the average over the MC samples
    grad_mu = grad_mu_samples.mean(dim=0)
    grad_log_sigma = grad_log_sigma_samples.mean(dim=0)

    # Now we can do a simple gradient ascent step
    mu = mu + lr * grad_mu
    log_sigma = log_sigma + lr * grad_log_sigma

    if step % record_every == 0 or step == steps - 1:
        cov = np.diag(torch.exp(2.0 * log_sigma).cpu().numpy())
        mc_history["steps"].append(step)
        mc_history["mu"].append(mu.cpu().numpy())
        mc_history["cov"].append(cov)
        mc_history["elbo"].append(advantage.mean().item())
        mc_history["grad_var_mu"].append(grad_mu_samples.var(dim=0, unbiased=False).mean().item())
        mc_history["grad_var_log_sigma"].append(grad_log_sigma_samples.var(dim=0, unbiased=False).mean().item())
Show Code
mc_history["steps"] = np.array(mc_history["steps"])
mc_history["mu"] = np.stack(mc_history["mu"])
mc_history["cov"] = np.stack(mc_history["cov"])
mc_history["elbo"] = np.array(mc_history["elbo"])
mc_history["grad_var_mu"] = np.array(mc_history["grad_var_mu"])
mc_history["grad_var_log_sigma"] = np.array(mc_history["grad_var_log_sigma"])

def plot_mc_state(frame_idx: int) -> None:
    mu_frame = mc_history["mu"][frame_idx]
    cov_frame = mc_history["cov"][frame_idx]
    step_val = mc_history["steps"][frame_idx]
    elbo_val = mc_history["elbo"][frame_idx]

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # adapt variable names
    mu = mu_frame
    history_steps = mc_history["steps"]

    # compute VI density on the plotting grid (assumes X_grid, Y_grid are defined as numpy meshgrid)
    mu_x, mu_y = float(mu[0]), float(mu[1])
    var_x = float(cov_frame[0, 0])
    var_y = float(cov_frame[1, 1])
    denom = 2.0 * np.pi * np.sqrt(var_x * var_y)
    exponent = -0.5 * (((X_grid - mu_x) ** 2) / var_x + ((Y_grid - mu_y) ** 2) / var_y)
    vi_density = np.exp(exponent) / denom

    ax = axes[0]
    ax.contour(X_grid, Y_grid, target_density_grid, levels=14, colors='#b3b3b3', linewidths=1.0, alpha=0.8)
    contourf = ax.contourf(X_grid, Y_grid, vi_density, levels=14, cmap='Blues', alpha=0.35)
    ax.contour(X_grid, Y_grid, vi_density, levels=14, colors='#1f77b4', linewidths=1.1, alpha=0.9)
    ax.scatter(mu[0], mu[1], color='tab:blue', s=40, label='VI mean')
    ax.set_title(f"MC VI (score-function) — step {history_steps[frame_idx]})")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.25)
    ax.legend(loc='upper right', fontsize=9)

    ax = axes[1]
    ax.plot(mc_history["steps"], mc_history["elbo"], color='tab:blue', linewidth=2, label='ELBO estimate')
    ax.axvline(step_val, color='tab:purple', linestyle='--', alpha=0.7)
    ax.scatter(step_val, elbo_val, color='tab:purple', s=40)
    ax.set_title('ELBO during MC VI optimization')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('ELBO (MC estimate)')
    ax.grid(True, alpha=0.3)

    ax2 = ax.twinx()
    ax2.plot(mc_history["steps"], mc_history["grad_var_mu"], color='tab:red', linestyle='--', linewidth=1.6, label='Var[∇μ]')
    ax2.set_ylabel('Gradient variance (μ)')

    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc='lower right', fontsize=9)

    plt.tight_layout()
    plt.show()

mc_slider = widgets.IntSlider(value=len(mc_history["steps"]) - 1, min=0, max=len(mc_history["steps"]) - 1, step=1, description='Frame')
mc_play = widgets.Play(interval=120, value=0, min=0, max=len(mc_history["steps"]) - 1, step=1, description='Press play')
widgets.jslink((mc_play, 'value'), (mc_slider, 'value'))
mc_controls = widgets.HBox([mc_play, mc_slider])
mc_plot = widgets.interactive_output(plot_mc_state, {'frame_idx': mc_slider})

display(mc_controls, mc_plot)

print(f"Final mean: {mc_history['mu'][-1]}")
print(f"Final marginal std: {np.sqrt(np.diag(mc_history['cov'][-1]))}")
print(f"Average gradient variance (μ): {mc_history['grad_var_mu'].mean():.3f}")
Final mean: [0.05686871 0.33599067]
Final marginal std: [1.9444891 1.0056936]
Average gradient variance (μ): 5.516

7.5.8 The Reparameterization Trick

From the above, we saw that we can estimate VI updates even without a closed-form solution, provided that we can compute the score function (essentially the gradient of the log likelihood of \(q\)), which we could in principle get through something like automatic differentiation. However, there is a catch: the expectation is taken over samples \(\mathbf{w} \sim q_\phi(\mathbf{w})\), and the distribution we sample from itself depends on \(\phi\). The score-function estimator resolves this, but it only ever uses the value of the integrand at each sample, never its gradient, so all information about how to improve \(\phi\) arrives through noisy, advantage-weighted nudges. The result is stochastic, often high-variance gradient estimates.

To get around this, we can use something called the reparameterization trick, whose key idea is to move the stochasticity/randomness outside of the gradient computation. This is done by expressing samples from \(q_\phi(\mathbf{w})\) as a deterministic function of \(\phi\) and some noise variable \(\boldsymbol{\epsilon}\) that is independent of \(\phi\). For example, if \(q_\phi(\mathbf{w})\) is a Gaussian with mean \(\boldsymbol{\mu}_\phi\) and standard deviation \(\boldsymbol{\sigma}_\phi\), we can write: \[\mathbf{w} = \boldsymbol{\mu}_\phi + \boldsymbol{\sigma}_\phi \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, I)\]

Where the expectation is not over \(\mathbf{w} \sim q_\phi(\mathbf{w})\) anymore, but rather over \(\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)\).

Note that \(\mathbf{w}\) is still a random variable, but the randomness now comes solely from \(\boldsymbol{\epsilon}\), which is independent of \(\phi\); \(\mathbf{w}\) becomes a deterministic transformation of \(\boldsymbol{\epsilon}\), parameterized by \(\phi\). This lets us move the gradient inside the expectation and differentiate the integrand directly, with no score-function weighting needed: \[ \begin{align*} \nabla_\phi \mathcal{L}(\phi) &= \nabla_\phi \mathbb{E}_{q_\phi}\left[\log p(Y,\mathbf{w}|X) - \log q_\phi(\mathbf{w})\right]\\ &= \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)}\left[\nabla_\phi \left(\log p(Y,\mathbf{w}(\boldsymbol{\epsilon}, \phi)|X) - \log q_\phi(\mathbf{w}(\boldsymbol{\epsilon}, \phi))\right)\right]\\ &\approx \frac{1}{S} \sum_{s=1}^S \nabla_\phi \left(\log p(Y,\mathbf{w}(\boldsymbol{\epsilon}^{(s)}, \phi)|X) - \log q_\phi(\mathbf{w}(\boldsymbol{\epsilon}^{(s)}, \phi))\right) \quad \text{where } \boldsymbol{\epsilon}^{(s)} \sim \mathcal{N}(0, I) \end{align*} \]

This subtle change has a few important implications:

  1. The randomness is now isolated in \(\boldsymbol{\epsilon}\), which does not depend on \(\phi\). This means that when we compute gradients w.r.t. \(\phi\), we can treat \(\boldsymbol{\epsilon}\) as a constant. This reduces the variance of our gradient estimates significantly.
  2. The gradients of both \(\log p(Y,\mathbf{w}(\boldsymbol{\epsilon}, \phi)|X)\) and \(\log q_\phi(\mathbf{w}(\boldsymbol{\epsilon}, \phi))\) can now be computed using standard automatic differentiation techniques, as \(\mathbf{w}\) is a deterministic function of \(\phi\) and \(\boldsymbol{\epsilon}\).
  3. This trick is particularly useful when \(q_\phi(\mathbf{w})\) is a continuous distribution like a Gaussian, but it can be extended to non-continuous distributions as well, which we may cover in a later notebook.

This approach is sometimes called the “pathwise derivative estimator” because gradients are computed by differentiating along the deterministic path from \((\phi, \boldsymbol{\epsilon})\) to \(\mathbf{w}\), in contrast to “score-function” estimators, which only weight samples by the score and never differentiate through the integrand. The short sketch below illustrates the variance difference on a toy problem before we rerun the banana example.
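To see the variance gap concretely, here is a minimal toy sketch (reusing the 1D Gaussian check from earlier in this section; the function \(f(w) = w^2\) and all constants are chosen purely for illustration) that estimates \(\nabla_\mu \mathbb{E}_{q}[w^2]\) for \(q = \mathcal{N}(\mu, \sigma^2)\) with both estimators:

# Toy comparison of per-sample gradient variance: score-function vs pathwise.
# f(w) = w^2 with q = N(mu, sigma^2); the exact gradient w.r.t. mu is 2*mu.
import torch

mu, sigma, n = 1.5, 0.7, 100_000
eps = torch.randn(n)
w = mu + sigma * eps

g_score = w**2 * (w - mu) / sigma**2   # f(w) * d/dmu log q(w)
g_reparam = 2.0 * w                    # d/dmu f(mu + sigma*eps), eps held fixed

print(f"estimates: score {g_score.mean():.3f} | reparam {g_reparam.mean():.3f}")  # both near 2*mu = 3.0
print(f"variances: score {g_score.var():.2f}  | reparam {g_reparam.var():.2f}")   # score is far larger

Both estimators are unbiased, but the pathwise estimator uses the gradient of \(f\) itself and is dramatically less noisy, which is why the reparameterized run below converges more smoothly than the score-function run above.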

# Reparameterized VI with automatic differentiation
mu_param = torch.nn.Parameter(torch.tensor([0.0, 0.0]))
log_sigma_param = torch.nn.Parameter(torch.log(torch.ones(2) * 2.5))
optimizer = Adam([mu_param, log_sigma_param], lr=0.05)

steps = 600
batch_size = 256
record_every = 1

reparam_history = {
    "steps": [],
    "mu": [],
    "cov": [],
    "elbo": [],
    "grad_var_mu": [],
    "grad_var_log_sigma": [],
    "grad_norm": []
}

for step in range(steps):
    optimizer.zero_grad()

    # Sample from q(z; mu, sigma) using reparameterization
    sigma = torch.exp(log_sigma_param)
    eps = torch.randn(batch_size, 2)
    z = mu_param + sigma * eps

    # Compute ELBO using reparameterization trick
    log_p = torch_log_p(z)
    log_q = -0.5 * ((eps**2).sum(dim=1) + 2.0 * torch.sum(log_sigma_param) + 2.0 * LOG_2PI)

    # Compute ELBO (aka the advantage)
    elbo = (log_p - log_q).mean()

    # Call AD backward to compute gradients directly on the ELBO
    (-elbo).backward()

    grad_norm = torch.sqrt(sum(param.grad.detach().pow(2).sum() for param in [mu_param, log_sigma_param])).item()
    optimizer.step()

    if step % record_every == 0 or step == steps - 1:
        with torch.no_grad():
            sigma_det = torch.exp(log_sigma_param)
            z_det = (mu_param + sigma_det * eps).detach()
            grad_log_p = torch.stack([
                -(z_det[:, 0] / 9.0) + 0.2 * z_det[:, 0] * (z_det[:, 1] - 0.1 * z_det[:, 0] ** 2),
                -(z_det[:, 1] - 0.1 * z_det[:, 0] ** 2)
            ], dim=1)
            grad_mu_samples = grad_log_p
            grad_log_sigma_samples = grad_log_p * (sigma_det * eps) + 1.0

            cov = np.diag(torch.exp(2.0 * log_sigma_param).cpu().numpy())
            reparam_history["steps"].append(step)
            reparam_history["mu"].append(mu_param.detach().cpu().numpy())
            reparam_history["cov"].append(cov)
            reparam_history["elbo"].append(elbo.item())
            reparam_history["grad_var_mu"].append(grad_mu_samples.var(dim=0, unbiased=False).mean().item())
            reparam_history["grad_var_log_sigma"].append(grad_log_sigma_samples.var(dim=0, unbiased=False).mean().item())
            reparam_history["grad_norm"].append(grad_norm)
Show Code
reparam_history["steps"] = np.array(reparam_history["steps"])
reparam_history["mu"] = np.stack(reparam_history["mu"])
reparam_history["cov"] = np.stack(reparam_history["cov"])
reparam_history["elbo"] = np.array(reparam_history["elbo"])
reparam_history["grad_var_mu"] = np.array(reparam_history["grad_var_mu"])
reparam_history["grad_var_log_sigma"] = np.array(reparam_history["grad_var_log_sigma"])
reparam_history["grad_norm"] = np.array(reparam_history["grad_norm"])

def plot_reparam_state(frame_idx: int) -> None:
    mu_frame = reparam_history["mu"][frame_idx]
    cov_frame = reparam_history["cov"][frame_idx]
    step_val = reparam_history["steps"][frame_idx]
    elbo_val = reparam_history["elbo"][frame_idx]

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # adapt variable names
    mu = mu_frame
    history_steps = reparam_history["steps"]

    # compute VI density on the plotting grid (assumes X_grid, Y_grid are defined as numpy meshgrid)
    mu_x, mu_y = float(mu[0]), float(mu[1])
    var_x = float(cov_frame[0, 0])
    var_y = float(cov_frame[1, 1])
    denom = 2.0 * np.pi * np.sqrt(var_x * var_y)
    exponent = -0.5 * (((X_grid - mu_x) ** 2) / var_x + ((Y_grid - mu_y) ** 2) / var_y)
    vi_density = np.exp(exponent) / denom

    ax = axes[0]
    ax.contour(X_grid, Y_grid, target_density_grid, levels=14, colors='#b3b3b3', linewidths=1.0, alpha=0.8)
    contourf = ax.contourf(X_grid, Y_grid, vi_density, levels=14, cmap='Blues', alpha=0.35)
    ax.contour(X_grid, Y_grid, vi_density, levels=14, colors='#1f77b4', linewidths=1.1, alpha=0.9)
    ax.scatter(mu[0], mu[1], color='tab:blue', s=40, label='VI mean')
    ax.set_title(f"MC VI (score-function) — step {history_steps[frame_idx]})")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.25)
    ax.legend(loc='upper right', fontsize=9)

    ax = axes[1]
    ax.plot(reparam_history["steps"], reparam_history["elbo"], color='tab:blue', linewidth=2, label='ELBO estimate')
    ax.axvline(step_val, color='tab:purple', linestyle='--', alpha=0.7)
    ax.scatter(step_val, elbo_val, color='tab:purple', s=40)
    ax.set_title('ELBO during reparameterized VI optimization')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('ELBO (MC estimate)')
    ax.grid(True, alpha=0.3)

    ax2 = ax.twinx()
    ax2.plot(reparam_history["steps"], reparam_history["grad_var_mu"], color='tab:red', linestyle='--', linewidth=1.6, label='Var[∇μ]')
    ax2.set_ylabel('Gradient variance (μ)')

    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc='lower right', fontsize=9)

    plt.tight_layout()
    plt.show()

reparam_slider = widgets.IntSlider(value=len(reparam_history["steps"]) - 1, min=0, max=len(reparam_history["steps"]) - 1, step=1, description='Frame')
reparam_play = widgets.Play(interval=120, value=0, min=0, max=len(reparam_history["steps"]) - 1, step=1, description='Press play')
widgets.jslink((reparam_play, 'value'), (reparam_slider, 'value'))
reparam_controls = widgets.HBox([reparam_play, reparam_slider])
reparam_plot = widgets.interactive_output(plot_reparam_state, {'frame_idx': reparam_slider})

display(reparam_controls, reparam_plot)

print(f"Final mean: {reparam_history['mu'][-1]}")
print(f"Final marginal std: {np.sqrt(np.diag(reparam_history['cov'][-1]))}")
print(f"Average gradient variance (μ): {reparam_history['grad_var_mu'].mean():.3f}")
Final mean: [0.03177835 0.3603806 ]
Final marginal std: [1.9253347 1.0171496]
Average gradient variance (μ): 1.408

7.5.9 Black-Box Variational Inference

We saw above how to use the reparameterization trick to reduce variance when estimating the VI update in cases where \(q_\phi(\mathbf{w})\) is a continuous distribution, such as a Gaussian (possibly with parameters produced by a simple feedforward neural network). In many cases (as we will see next with VAEs), this is sufficient and we are good to go.

However, there are certain cases where basic reparameterization will not work. For example, if we pick a variational family \(q_\phi(\mathbf{w})\) whose samples can no longer be written as a smooth, deterministic function of the parameters \(\phi\), we will not be able to reparameterize. This is common in cases where the desired \(q_\phi(\mathbf{w})\) has discrete variables, such as clusters or categories.

In such cases, we actually have three options:

  1. We can go back to score-function estimators like we used earlier in the notebook. We will suffer from higher-variance updates, but this at least allows us to train the model even with discrete or non-continuous samples.4 This is the approach we will take below.
  2. We can try to “relax” discontinuities: approximate the “hard” or discrete objects we wish to study with variants that can be differentiated through. The price we pay is that the optimal solution to the relaxed version of the model will be slightly biased away from the optimal solution to the original hard problem, but often this approximation error and bias are acceptable. Classical examples include the Gumbel-Softmax (or Concrete) distribution for relaxing categorical variables, or relaxations over the Birkhoff polytope for permutation matrices; see the short sketch after this list. These extensions go beyond what I want to cover in this book, but are fascinating topics to read more about.
  3. There are also classes of “Straight-Through Estimators”, which essentially detach certain gradients from the computational graph so that one can draw hard/discrete samples but still access smooth gradients when calling automatic differentiation. These inherit problems similar to the relaxed distributions above: the gradients you compute with straight-through estimators only approximate the true gradients in a heuristic sense.

4 There are some tricks to reduce the variance by subtracting a moving “baseline” average from the advantage term, which is implemented below, but we will not go into the details of this here.
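As a quick, self-contained illustration of options 2 and 3 (a sketch only; the logits and the toy objective below are made up for illustration), PyTorch ships a Gumbel-Softmax sampler: with hard=False the sample is a differentiable point on the probability simplex, and with hard=True the forward pass is one-hot while gradients flow through the relaxation, i.e., a straight-through estimator.

# Sketch: Gumbel-Softmax relaxation of a 3-way categorical (option 2),
# plus its straight-through variant (option 3). Values are illustrative.
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 0.5, -0.5], requires_grad=True)
values = torch.tensor([1.0, 2.0, 3.0])  # a made-up downstream objective

soft_sample = F.gumbel_softmax(logits, tau=0.5, hard=False)  # point on the simplex
hard_sample = F.gumbel_softmax(logits, tau=0.5, hard=True)   # one-hot forward, soft backward

loss = (soft_sample * values).sum()
loss.backward()  # gradients reach the logits through the relaxed sample

print("soft sample:", soft_sample.detach().numpy())
print("hard sample:", hard_sample.detach().numpy())
print("logit grads:", logits.grad.numpy())

Lowering the temperature tau pushes samples closer to one-hot (and closer to the true categorical distribution) at the cost of noisier gradients, which is exactly the bias-variance price mentioned in option 2.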

We will demonstrate below an example where reparameterization will not be effective, by fitting a three-component Gaussian Mixture Model as our \(q_\phi(\mathbf{w})\) and resorting to score-function estimators. Note that the example below takes significantly longer to fit/train than the examples we have used so far.

# Black-box VI with a Gaussian mixture and REINFORCE gradients
try:
    from scipy.stats import gaussian_kde
except ImportError:
    gaussian_kde = None

def log_mixture_prob(z: torch.Tensor, locs: torch.Tensor, log_scales: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    diff = z.unsqueeze(1) - locs.unsqueeze(0)
    inv_var = torch.exp(-2.0 * log_scales).unsqueeze(0)
    quad = torch.sum(diff**2 * inv_var, dim=2)
    log_det = 2.0 * log_scales.sum(dim=1)
    log_weights = F.log_softmax(logits, dim=0)
    component_log_prob = -0.5 * (quad + log_det.unsqueeze(0) + 2.0 * LOG_2PI)
    return torch.logsumexp(log_weights.unsqueeze(0) + component_log_prob, dim=1)

def sample_mixture(locs: torch.Tensor, log_scales: torch.Tensor, logits: torch.Tensor, num_samples: int) -> torch.Tensor:
    weights = F.softmax(logits, dim=0)
    cat = torch.distributions.Categorical(probs=weights)
    indices = cat.sample((num_samples,))
    sigma = torch.exp(log_scales)
    eps = torch.randn(num_samples, 2)
    samples = locs[indices] + sigma[indices] * eps
    return samples.detach()

def radial_density_estimate(samples: np.ndarray, X: np.ndarray, Y: np.ndarray, bandwidth: float = 0.7, max_centers: int = 240) -> np.ndarray:
    rng_local = np.random.default_rng(0)
    if len(samples) > max_centers:
        idx = rng_local.choice(len(samples), max_centers, replace=False)
        centers = samples[idx]
    else:
        centers = samples
    diff_x = X[None, :, :] - centers[:, 0][:, None, None]
    diff_y = Y[None, :, :] - centers[:, 1][:, None, None]
    density = np.exp(-(diff_x**2 + diff_y**2) / (2.0 * bandwidth**2)).sum(axis=0)
    norm = (2.0 * math.pi * bandwidth**2) * centers.shape[0]
    return density / norm

num_components = 3
logits = torch.nn.Parameter(torch.zeros(num_components))
locs = torch.nn.Parameter(torch.tensor([[6.0, 6.0], [0.0, 0.0], [-6.0, 6.0]]))
log_scales = torch.nn.Parameter(torch.log(torch.ones(num_components, 2) * 2.5))
optimizer = Adam([logits, locs, log_scales], lr=0.05)

steps = 250
batch_size = 512
eval_samples = 6000
record_every = 1
baseline = 0.0
baseline_beta = 0.05

bbvi_history = {
    "steps": [],
    "elbo": [],
    "grad_norm": [],
    "baseline": [],
    "mean": [],
    "cov": [],
    "density": []
}

for step in range(steps):
    optimizer.zero_grad()
    sigma = torch.exp(log_scales)

    weights = F.softmax(logits, dim=0)
    cat = torch.distributions.Categorical(probs=weights)
    indices = cat.sample((batch_size,))
    eps = torch.randn(batch_size, 2)
    samples = (locs[indices] + sigma[indices] * eps).detach()

    log_q = log_mixture_prob(samples, locs, log_scales, logits)
    log_p = torch_log_p(samples)
    elbo_estimate = (log_p - log_q).mean().item()

    if step == 0:
        baseline = elbo_estimate
    baseline = (1.0 - baseline_beta) * baseline + baseline_beta * elbo_estimate
    advantage = (log_p - log_q - baseline).detach()

    loss = -(advantage * log_q).mean()
    loss.backward()
    grad_norm = torch.sqrt(sum(param.grad.detach().pow(2).sum() for param in [logits, locs, log_scales])).item()
    optimizer.step()

    if step % record_every == 0 or step == steps - 1:
        with torch.no_grad():
            eval_samples_tensor = sample_mixture(locs, log_scales, logits, eval_samples)
            eval_np = eval_samples_tensor.cpu().numpy()
            mean_np = eval_np.mean(axis=0)
            cov_np = np.cov(eval_np, rowvar=False)
            if gaussian_kde is not None:
                kde = gaussian_kde(eval_np.T)
                density = kde(np.vstack([X_grid.ravel(), Y_grid.ravel()])).reshape(X_grid.shape)
            else:
                density = radial_density_estimate(eval_np, X_grid, Y_grid)
        bbvi_history["steps"].append(step)
        bbvi_history["elbo"].append(elbo_estimate)
        bbvi_history["grad_norm"].append(grad_norm)
        bbvi_history["baseline"].append(baseline)
        bbvi_history["mean"].append(mean_np)
        bbvi_history["cov"].append(cov_np)
        bbvi_history["density"].append(density)
Show Code
bbvi_history["steps"] = np.array(bbvi_history["steps"])
bbvi_history["elbo"] = np.array(bbvi_history["elbo"])
bbvi_history["grad_norm"] = np.array(bbvi_history["grad_norm"])
bbvi_history["baseline"] = np.array(bbvi_history["baseline"])
bbvi_history["mean"] = np.stack(bbvi_history["mean"])
bbvi_history["cov"] = np.stack(bbvi_history["cov"])
bbvi_history["density"] = np.stack(bbvi_history["density"])

final_weights = F.softmax(logits.detach(), dim=0).cpu().numpy()
final_locs = locs.detach().cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax = axes[0]
ax.contour(X_grid, Y_grid, target_density_grid, levels=14, colors='#b3b3b3', linewidths=1.0, alpha=0.8)
ax.contourf(X_grid, Y_grid, bbvi_history["density"][-1], levels=14, cmap='Blues', alpha=0.35)
ax.contour(X_grid, Y_grid, bbvi_history["density"][-1], levels=14, colors='#1f77b4', linewidths=1.1, alpha=0.9)
scatter_sizes = 300 * final_weights
ax.scatter(final_locs[:, 0], final_locs[:, 1], s=scatter_sizes, c='tab:orange', edgecolors='black', linewidths=1.0, label='Mixture means')
ax.set_title(r'Gaussian mixture $q_{\phi}$ (final snapshot)')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=9)

ax = axes[1]
ax.plot(bbvi_history["steps"], bbvi_history["elbo"], color='tab:blue', linewidth=2, label='ELBO estimate')
ax.set_title('Score-function optimization trace')
ax.set_xlabel('Iteration')
ax.set_ylabel('ELBO')
ax.grid(True, alpha=0.3)

ax2 = ax.twinx()
ax2.plot(bbvi_history["steps"], bbvi_history["grad_norm"], color='tab:red', linestyle='--', linewidth=1.6, label='Gradient norm')
ax2.set_ylabel('Gradient norm')

lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2, loc='lower right', fontsize=9)

plt.tight_layout()
plt.show()

print(f"Final mixture weights: {np.round(final_weights, 3)}")
print(f"Final mixture means: {np.round(final_locs, 3)}")
print(f"Average ELBO estimate: {bbvi_history['elbo'].mean():.3f}")

Final mixture weights: [0.043 0.906 0.051]
Final mixture means: [[ 5.449  3.545]
 [ 0.021  0.365]
 [-5.234  3.275]]
Average ELBO estimate: 2.372

With this, you now have a high-level overview of the major approaches to performing Variational Inference, as well as their pros and cons. We can now move on to applying these ideas to the more complex neural network example from earlier.

7.5.10 Example of using VI for Neural Networks

Since in this case the neural network is a continuous and differentiable function of its weights, we will leverage the reparameterization trick, which lets us automatically differentiate backward through the network to perform the VI updates.

# Variational Inference for Bayesian Neural Network
class BayesianNNVI(nn.Module):
    """
    Bayesian Neural Network with Variational Inference.
    
    Each weight has a variational posterior q(w) = N(mu, sigma^2).
    """
    
    def __init__(self, hidden_size=5):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Variational parameters for layer 1: input -> hidden
        self.fc1_mu_weight = nn.Parameter(torch.randn(hidden_size, 1) * 0.1)
        self.fc1_mu_bias = nn.Parameter(torch.randn(hidden_size) * 0.1)
        self.fc1_logsigma_weight = nn.Parameter(torch.ones(hidden_size, 1) * -3)  # Start with low variance
        self.fc1_logsigma_bias = nn.Parameter(torch.ones(hidden_size) * -3)
        
        # Variational parameters for layer 2: hidden -> output
        self.fc2_mu_weight = nn.Parameter(torch.randn(1, hidden_size) * 0.1)
        self.fc2_mu_bias = nn.Parameter(torch.randn(1) * 0.1)
        self.fc2_logsigma_weight = nn.Parameter(torch.ones(1, hidden_size) * -3)
        self.fc2_logsigma_bias = nn.Parameter(torch.ones(1) * -3)
        
    def sample_weights(self):
        """Sample weights using reparameterization trick."""
        # Layer 1
        fc1_weight = self.fc1_mu_weight + torch.exp(self.fc1_logsigma_weight) * torch.randn_like(self.fc1_mu_weight)
        fc1_bias = self.fc1_mu_bias + torch.exp(self.fc1_logsigma_bias) * torch.randn_like(self.fc1_mu_bias)
        
        # Layer 2
        fc2_weight = self.fc2_mu_weight + torch.exp(self.fc2_logsigma_weight) * torch.randn_like(self.fc2_mu_weight)
        fc2_bias = self.fc2_mu_bias + torch.exp(self.fc2_logsigma_bias) * torch.randn_like(self.fc2_mu_bias)
        
        return (fc1_weight, fc1_bias, fc2_weight, fc2_bias)
    
    def forward(self, x, sample=True):
        """Forward pass with sampled or mean weights."""
        x = x.view(-1, 1)
        
        if sample:
            fc1_w, fc1_b, fc2_w, fc2_b = self.sample_weights()
        else:
            fc1_w, fc1_b = self.fc1_mu_weight, self.fc1_mu_bias
            fc2_w, fc2_b = self.fc2_mu_weight, self.fc2_mu_bias
        
        # Forward propagation
        h = F.relu(F.linear(x, fc1_w, fc1_b))
        y = F.linear(h, fc2_w, fc2_b)
        
        return y.squeeze()
    
    def kl_divergence(self, prior_std=1.0):
        """Compute KL(q||p) where p is N(0, prior_std^2)."""
        kl = 0
        
        # KL for each parameter: KL(N(mu, sigma^2) || N(0, prior_std^2))
        for mu_param, logsigma_param in [
            (self.fc1_mu_weight, self.fc1_logsigma_weight),
            (self.fc1_mu_bias, self.fc1_logsigma_bias),
            (self.fc2_mu_weight, self.fc2_logsigma_weight),
            (self.fc2_mu_bias, self.fc2_logsigma_bias)
        ]:
            sigma = torch.exp(logsigma_param)
            kl += 0.5 * torch.sum(
                (mu_param**2 + sigma**2) / prior_std**2 - 1 - 2*logsigma_param + 2*np.log(prior_std)
            )
        
        return kl
    
    def elbo(self, x, y, n_samples=5, noise_std=0.1, prior_std=1.0):
        """
        Compute the Evidence Lower Bound.
        
        ELBO = E_q[log p(y|x,w)] - KL(q||p)
        """
        x = x.view(-1, 1)
        
        # Expected log-likelihood (Monte Carlo estimate)
        log_lik = 0
        for _ in range(n_samples):
            y_pred = self.forward(x, sample=True)
            log_lik += -0.5 * torch.sum(((y - y_pred) / noise_std) ** 2)
        log_lik = log_lik / n_samples
        log_lik -= len(y) * np.log(noise_std * np.sqrt(2 * np.pi))
        
        # KL divergence
        kl = self.kl_divergence(prior_std)
        
        # ELBO
        elbo = log_lik - kl
        
        return elbo, log_lik, kl

# Initialize variational model
vi_model = BayesianNNVI(hidden_size=5)

# Uncomment below if you would like a listing of all parameters and counts:
# print("Variational parameters:")
# for name, param in vi_model.named_parameters():
#     print(f"  {name}: shape {param.shape}")
# n_var_params = sum(p.numel() for p in vi_model.parameters())
# print(f"Total variational parameters: {n_var_params}")
# print("(2x the number of network weights for mean + variance)")
Show Code
# Train with Variational Inference
vi_model = BayesianNNVI(hidden_size=5)
optimizer = torch.optim.Adam(vi_model.parameters(), lr=0.01)

x_train_tensor = torch.FloatTensor(x_train)
y_train_tensor = torch.FloatTensor(y_train)

n_epochs = 3000
n_mc_samples = 5  # Monte Carlo samples per ELBO evaluation

history = {
    'elbo': [],
    'log_lik': [],
    'kl': []
}

print("Training Bayesian NN with Variational Inference...")
print(f"Epochs: {n_epochs}, MC samples: {n_mc_samples}")

for epoch in range(n_epochs):
    optimizer.zero_grad()
    
    # Compute ELBO (negative for minimization)
    elbo, log_lik, kl = vi_model.elbo(x_train_tensor, y_train_tensor, 
                                       n_samples=n_mc_samples, 
                                       noise_std=noise_std, 
                                       prior_std=1.0)
    loss = -elbo  # Maximize ELBO = minimize negative ELBO
    
    # Backprop and update
    loss.backward()
    optimizer.step()
    
    # Track metrics
    history['elbo'].append(elbo.item())
    history['log_lik'].append(log_lik.item())
    history['kl'].append(kl.item())
    
    if (epoch + 1) % 500 == 0:
        print(f"Epoch {epoch+1:4d} | ELBO: {elbo.item():8.2f} | "
              f"Log-lik: {log_lik.item():8.2f} | KL: {kl.item():6.2f}")
print("Training complete!")
Training Bayesian NN with Variational Inference...
Epochs: 3000, MC samples: 5
Epoch  500 | ELBO:  -123.28 | Log-lik:   -86.77 | KL:  36.51
Epoch 1000 | ELBO:  -113.35 | Log-lik:   -80.70 | KL:  32.65
Epoch 1500 | ELBO:   -35.02 | Log-lik:     7.50 | KL:  42.52
Epoch 2000 | ELBO:   -32.42 | Log-lik:    12.57 | KL:  44.99
Epoch 2500 | ELBO:   -25.91 | Log-lik:    16.31 | KL:  42.22
Epoch 3000 | ELBO:   -27.61 | Log-lik:    14.09 | KL:  41.70
Training complete!
Show Code
# Visualize VI training progress
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

epochs_arr = np.arange(1, len(history['elbo']) + 1)

# Plot 1: ELBO over time
ax = axes[0]
ax.plot(epochs_arr, history['elbo'], linewidth=2, color='blue')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('ELBO', fontsize=12)
ax.set_title('ELBO During VI Training', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Plot 2: ELBO components
ax = axes[1]
ax.plot(epochs_arr, history['log_lik'], linewidth=2, label='Expected log-likelihood', color='green')
ax.plot(epochs_arr, history['kl'], linewidth=2, label='KL divergence', color='red')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Value', fontsize=12)
ax.set_title('ELBO Components', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Final ELBO breakdown:")
print(f"  ELBO: {history['elbo'][-1]:.2f}")
print(f"  Expected log-likelihood: {history['log_lik'][-1]:.2f}")
print(f"  KL divergence: {history['kl'][-1]:.2f}")

Final ELBO breakdown:
  ELBO: -27.61
  Expected log-likelihood: 14.09
  KL divergence: 41.70
Show Code
# Generate VI posterior predictive samples
x_test_tensor = torch.FloatTensor(x_test)
vi_predictions = []

with torch.no_grad():
    for _ in range(100):
        y_pred = vi_model(x_test_tensor, sample=True).numpy()
        vi_predictions.append(y_pred)

vi_predictions = np.array(vi_predictions)
vi_mean = vi_predictions.mean(axis=0)
vi_std = vi_predictions.std(axis=0)

# Compare VI vs MCMC
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: VI Posterior Predictive
ax = axes[0]

# VI samples
for i in range(len(vi_predictions)):
    ax.plot(x_test, vi_predictions[i], 'blue', alpha=0.15, linewidth=1)

# True function and data
ax.plot(x_test, y_test_true, 'k-', linewidth=3, label='True function', zorder=10)
ax.scatter(x_train, y_train, s=100, c='red', edgecolors='black', 
           linewidth=2, zorder=15, label='Training data')

# VI mean
ax.plot(x_test, vi_mean, 'blue', linewidth=3, label='VI posterior mean', linestyle='--', zorder=5)

# Dummy line for legend
ax.plot([], [], 'blue', alpha=0.6, linewidth=2, label='VI posterior samples')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Variational Inference: Posterior Predictive', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-1.5, 1.5])

# Plot 2: Direct comparison VI vs MCMC
ax = axes[1]

# MCMC samples (lighter)
for i in range(min(50, len(mcmc_predictions))):
    ax.plot(x_test, mcmc_predictions[i], 'orange', alpha=0.1, linewidth=1)

# VI samples (darker)
for i in range(50):
    ax.plot(x_test, vi_predictions[i], 'blue', alpha=0.1, linewidth=1)

# Means
ax.plot(x_test, mcmc_mean, 'orange', linewidth=3, label='MCMC mean', linestyle='--')
ax.plot(x_test, vi_mean, 'blue', linewidth=3, label='VI mean', linestyle='--')

# True function and data
ax.plot(x_test, y_test_true, 'k-', linewidth=3, label='True function', zorder=10)
ax.scatter(x_train, y_train, s=100, c='red', edgecolors='black', 
           linewidth=2, zorder=15, label='Training data')

ax.set_xlabel('Input $x$', fontsize=13)
ax.set_ylabel('Output $y$', fontsize=13)
ax.set_title('Comparison: VI vs MCMC', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([-3.2, 3.2])
ax.set_ylim([-1.5, 1.5])

plt.tight_layout()
plt.show()

Note the difference in training times between the MCMC method and the VI method. The main thing we give up for this increase in speed is accuracy: VI can only ever recover the best approximation within the chosen variational family, and the mean-field Gaussian we used here tends to underestimate the posterior uncertainty relative to MCMC's asymptotically exact samples.
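If you want to check this numerically (a quick sketch, assuming the mcmc_predictions array from the earlier MCMC cells is still in scope alongside vi_std from above), compare the average predictive standard deviation of the two methods over the test inputs; the mean-field VI value will typically be the smaller of the two:

# Quick check: average predictive uncertainty of MCMC vs VI on the test grid.
# Assumes mcmc_predictions and vi_std from the cells above are in scope.
import numpy as np

mcmc_std = np.asarray(mcmc_predictions).std(axis=0)

print(f"Mean predictive std (MCMC): {mcmc_std.mean():.3f}")
print(f"Mean predictive std (VI):   {vi_std.mean():.3f}")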

7.6 Possible Additional Experiments

Now that you understand the full pipeline, you can explore below how some of the fundamental parameters of the model might affect the results:

Tip: Experiment 1 - Vary Hidden Layer Size

Train VI models with hidden_size = 2, 5, 10, 20. Observe:

  • How does model flexibility change?
  • Does uncertainty increase or decrease?
  • Is there a risk of overfitting with larger networks?

Tip: Experiment 2 - Effect of Training Data Size

Generate datasets with n_train = 5, 10, 20, 50. For each:

  • Train the VI model
  • Plot posterior predictive distributions
  • Compare the model uncertainty in data-rich vs data-sparse regions

What do you notice?

Tip: Experiment 3 - Prior Sensitivity

Try different prior standard deviations: prior_std = 0.1, 1.0, 10.0. Observe:

  • How does this affect the learned posterior?
  • Does a strong prior (small std) regularize more?
  • What happens with a weak prior (large std)?

We can do Experiment 1 together as an example, and leave the remainder for you to do independently:

Show Code
# Experiment 1: Vary hidden layer size
hidden_sizes_to_test = [2, 5, 10, 20]
results_by_size = {}

print("Experiment 1: Effect of Hidden Layer Size")
print("Training VI models with different architectures...")

for h_size in hidden_sizes_to_test:
    print(f"Training with hidden_size = {h_size}...")
    
    # Initialize and train
    model_exp = BayesianNNVI(hidden_size=h_size)
    optimizer_exp = torch.optim.Adam(model_exp.parameters(), lr=0.01)
    
    for epoch in range(2000):
        optimizer_exp.zero_grad()
        elbo, _, _ = model_exp.elbo(x_train_tensor, y_train_tensor, 
                                     n_samples=5, noise_std=noise_std, prior_std=1.0)
        loss = -elbo
        loss.backward()
        optimizer_exp.step()
    
    # Generate predictions
    predictions = []
    with torch.no_grad():
        for _ in range(100):
            y_pred = model_exp(x_test_tensor, sample=True).numpy()
            predictions.append(y_pred)
    
    predictions = np.array(predictions)
    results_by_size[h_size] = {
        'predictions': predictions,
        'mean': predictions.mean(axis=0),
        'std': predictions.std(axis=0)
    }
    
    print(f"  Final ELBO: {elbo.item():.2f}")

print("Training complete!")
Experiment 1: Effect of Hidden Layer Size
Training VI models with different architectures...
Training with hidden_size = 2...
  Final ELBO: -21.87
Training with hidden_size = 5...
  Final ELBO: -30.33
Training with hidden_size = 10...
  Final ELBO: -45.41
Training with hidden_size = 20...
  Final ELBO: -82.38
Training complete!
Show Code
# Visualize results from Experiment 1
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()

for idx, h_size in enumerate(hidden_sizes_to_test):
    ax = axes[idx]
    
    result = results_by_size[h_size]
    
    # Plot posterior samples
    for i in range(min(50, len(result['predictions']))):
        ax.plot(x_test, result['predictions'][i], 'blue', alpha=0.1, linewidth=1)
    
    # Plot mean and uncertainty bands
    ax.plot(x_test, result['mean'], 'blue', linewidth=3, label='VI posterior mean', linestyle='--')
    ax.fill_between(x_test, 
                     result['mean'] - result['std'], 
                     result['mean'] + result['std'],
                     alpha=0.3, color='blue', label=r'$\pm 1\sigma$')
    
    # True function and data
    ax.plot(x_test, y_test_true, 'k-', linewidth=2, label='True function', zorder=10)
    ax.scatter(x_train, y_train, s=80, c='red', edgecolors='black', 
               linewidth=1.5, zorder=15, label='Training data')
    
    ax.set_xlabel('Input $x$', fontsize=12)
    ax.set_ylabel('Output $y$', fontsize=12)
    # 2*h weights/biases in layer 1 plus h+1 in layer 2 -> 3h+1 network weights
    ax.set_title(f'Hidden Size = {h_size} ({3 * h_size + 1} network weights)', 
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim([-3.2, 3.2])
    ax.set_ylim([-1.5, 1.5])

plt.suptitle('Experiment 1: Effect of Network Capacity', fontsize=15, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

7.7 Summary and the bridge to VAEs

We saw a few key takeaways in this notebook:

Bayesian Inference Quantifies Uncertainty

  • Instead of a single “best” model, we maintain a distribution over models
  • The posterior predictive distribution \(p(y_* | x_*, X, Y)\) gives us predictions with uncertainty

Linear Models Have Limitations

  • Bayesian linear regression has exact, closed-form posteriors (Gaussian)
  • This allows us to compute predictive mean and variance analytically
  • However, linear models cannot capture nonlinear relationships, as we saw in this example

Neural Networks Provide Flexibility, at the Cost of an Intractable Posterior

  • Even a small 1-hidden-layer network can approximate complex functions
  • The prior predictive shows diverse possible functions before seeing data
  • Challenge: the posterior \(p(\mathbf{w} | X, Y)\) is no longer tractable

MCMC: Accurate but Slow

  • Markov Chain Monte Carlo samples from the true posterior
  • Pros: asymptotically exact, unbiased
  • Cons: computationally expensive, slow for high-dimensional models
  • Useful as a reference for validation

Variational Inference: Fast Approximation

  • Approximate \(p(\mathbf{w} | X, Y)\) with a tractable family \(q_\phi(\mathbf{w})\)
  • Optimize \(\phi\) to maximize the ELBO: \[\mathcal{L}(\phi) = \mathbb{E}_{q_\phi}\left[\log p(Y | X, \mathbf{w})\right] - \text{KL}(q_\phi(\mathbf{w}) || p(\mathbf{w}))\]
  • The reparameterization trick enables low-variance gradient estimates
  • Result: a posterior approximation in seconds instead of minutes

Function Space vs Weight Space

  • We don’t visualize high-dimensional weight distributions directly
  • Instead, we visualize functions sampled from the posterior
  • The posterior predictive is what matters for making predictions!

7.7.1 Looking forward to VAEs (Variational Autoencoders)

The VI approach you learned here is the same mathematical foundation for VAEs. The key differences are:

  • Here (Bayesian NN): Posterior over weights \(p(\mathbf{w} | X, Y)\)
  • VAEs: Posterior over latent variables \(p(\mathbf{z} | \mathbf{x})\)

The ELBO, reparameterization trick, and optimization strategy are identical.

Moreover, in VAEs, instead of optimizing \(q_\phi(z)\) separately for each datapoint \(x\), we learn an encoder network \(q_\phi(z | x)\) that works for all \(x\). This is called amortized inference.
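As a preview (a minimal sketch, with all layer sizes chosen arbitrarily for illustration), an amortized encoder is just a network that maps each \(\mathbf{x}\) to its own variational parameters \((\boldsymbol{\mu}_\phi(\mathbf{x}), \boldsymbol{\sigma}_\phi(\mathbf{x}))\), after which sampling uses exactly the reparameterization trick from this notebook:

# Minimal sketch of an amortized Gaussian encoder q(z|x) for a future VAE.
# All sizes are illustrative; the sample uses the same reparameterization trick as above.
import torch
import torch.nn as nn

class AmortizedEncoder(nn.Module):
    def __init__(self, x_dim=784, hidden=128, z_dim=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, z_dim)         # mu_phi(x)
        self.log_sigma_head = nn.Linear(hidden, z_dim)  # log sigma_phi(x)

    def forward(self, x):
        h = self.body(x)
        mu, log_sigma = self.mu_head(h), self.log_sigma_head(h)
        z = mu + torch.exp(log_sigma) * torch.randn_like(mu)  # reparameterized sample
        return z, mu, log_sigma

z, mu, log_sigma = AmortizedEncoder()(torch.randn(4, 784))  # one q(z|x) per datapoint
print(z.shape)  # torch.Size([4, 2])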

Questions to Reflect on

Think about these before moving to the next notebook:

  1. Why does VI underestimate uncertainty?
    Hint: Is the mean-field approximation, where we assume independence between weights, reasonable? When might it not be?

  2. When would MCMC be essential despite being slow?

  3. Could we use a more expressive variational family?
    We used a Mean-field Gaussian here, but are we limited to this?