# Generate VI posterior predictive samples.
# Each forward pass with sample=True presumably draws a fresh set of weights
# from the variational posterior (NOTE(review): confirm against vi_model), so
# repeating the pass N_VI_SAMPLES times yields a Monte Carlo approximation of
# the posterior predictive over x_test.
N_VI_SAMPLES = 100  # number of stochastic forward passes

x_test_tensor = torch.as_tensor(x_test, dtype=torch.float32)
# no_grad: pure inference, so skip building autograd graphs (and lets
# .numpy() be called on the outputs directly).
with torch.no_grad():
    # Stack into shape (N_VI_SAMPLES, *model_output_shape); np.stack also
    # fails loudly if two passes ever return differently-shaped outputs.
    vi_predictions = np.stack(
        [vi_model(x_test_tensor, sample=True).numpy() for _ in range(N_VI_SAMPLES)]
    )

# Pointwise summary across samples (axis 0 is the sample axis).
vi_mean = vi_predictions.mean(axis=0)
vi_std = vi_predictions.std(axis=0)
# Compare VI vs MCMC posterior predictives side by side.
# Left panel: every VI sample plus the VI mean.
# Right panel: a bounded subset of samples from each method, plus both means.
# Relies on x_test, y_test_true, x_train, y_train, vi_predictions, vi_mean,
# mcmc_predictions and mcmc_mean being defined earlier in the script.

N_OVERLAY_SAMPLES = 50  # max samples per method drawn in the comparison panel


def _plot_truth_and_data(ax):
    """Draw the true function and the training points on *ax*."""
    ax.plot(x_test, y_test_true, 'k-', linewidth=3, label='True function', zorder=10)
    ax.scatter(x_train, y_train, s=100, c='red', edgecolors='black',
               linewidth=2, zorder=15, label='Training data')


def _style_axis(ax, title):
    """Apply the labels, title, legend, grid and limits shared by both panels."""
    ax.set_xlabel('Input $x$', fontsize=13)
    ax.set_ylabel('Output $y$', fontsize=13)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim([-3.2, 3.2])
    ax.set_ylim([-1.5, 1.5])


fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: VI Posterior Predictive
ax = axes[0]
# Individual VI samples (unlabelled, so they don't flood the legend)
for sample in vi_predictions:
    ax.plot(x_test, sample, 'blue', alpha=0.15, linewidth=1)
_plot_truth_and_data(ax)
ax.plot(x_test, vi_mean, 'blue', linewidth=3, label='VI posterior mean', linestyle='--', zorder=5)
# Proxy artist: gives the unlabelled sample lines a single legend entry
ax.plot([], [], 'blue', alpha=0.6, linewidth=2, label='VI posterior samples')
_style_axis(ax, 'Variational Inference: Posterior Predictive')

# Plot 2: Direct comparison VI vs MCMC
ax = axes[1]
# Both overlays are bounded the same way, so neither can index past the end
# of its sample array (slicing simply returns fewer samples if there are less).
for sample in mcmc_predictions[:N_OVERLAY_SAMPLES]:  # MCMC samples (orange)
    ax.plot(x_test, sample, 'orange', alpha=0.1, linewidth=1)
for sample in vi_predictions[:N_OVERLAY_SAMPLES]:  # VI samples (blue)
    ax.plot(x_test, sample, 'blue', alpha=0.1, linewidth=1)
# Means
ax.plot(x_test, mcmc_mean, 'orange', linewidth=3, label='MCMC mean', linestyle='--')
ax.plot(x_test, vi_mean, 'blue', linewidth=3, label='VI mean', linestyle='--')
_plot_truth_and_data(ax)
_style_axis(ax, 'Comparison: VI vs MCMC')

plt.tight_layout()
plt.show()