Bayesian Neural Networks Under Covariate Shift: When Theory Fails Practice

Published: December 14, 2025 at 02:25 PM EST
5 min read
Source: Dev.to

The Surprising Failure of Bayesian Robustness

If you’ve been following Bayesian deep learning literature, you’ve likely encountered the standard narrative: Bayesian methods provide principled uncertainty quantification, which should make them more robust to distribution shifts. The theory sounds compelling—when faced with out-of-distribution data, Bayesian Model Averaging (BMA) should account for multiple plausible explanations, leading to calibrated uncertainty and better generalization.

But what if this narrative is fundamentally flawed? What if, in practice, Bayesian Neural Networks (BNNs) with exact inference are actually less robust to distribution shift than their classical counterparts?

This is exactly what Izmailov et al. discovered in their NeurIPS 2021 paper, “Dangers of Bayesian Model Averaging under Covariate Shift.” Their findings challenge core assumptions about Bayesian methods and have significant implications for real‑world applications.

The Counterintuitive Result

Let’s start with the most striking finding:

Figure: Bayesian neural networks under covariate shift. (a) Performance of a ResNet‑20 on the pixelate corruption in CIFAR‑10‑C; at the highest corruption intensity, the Bayesian model average underperforms a MAP solution by 25 percentage points (44 % vs. 69 % accuracy). See Izmailov et al. [2021] for details. (b) Weights in the first layer of a Bayesian fully‑connected network on MNIST, sampled via HMC. (c) The corresponding MAP weights. In (b) and (c), the weights connecting the input pixels to a hidden‑layer neuron are visualized as a 28 × 28 image, with each weight shown at the location of the input pixel it interacts with.

Yes, you read that correctly. On severely corrupted CIFAR‑10‑C data, a Bayesian Neural Network using Hamiltonian Monte Carlo (HMC) achieves only 44 % accuracy, while a simple maximum a posteriori (MAP) estimate achieves 69 % – a 25 percentage‑point gap in favor of the simpler method!

This is particularly surprising because on clean, in‑distribution data the BNN actually outperforms MAP by about 5 %. So we have a method that is better on standard benchmarks but fails catastrophically under distribution shift.

Why Does This Happen? The “Dead Pixels” Analogy

The authors provide an elegant explanation through what they call the “dead pixels” phenomenon. Consider MNIST digits—they always have black pixels in the corners (intensity = 0). These are “dead pixels” that never activate during training.
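To make this concrete, here is a minimal sketch that counts MNIST's dead pixels, assuming torchvision is available and allowed to download the dataset:

```python
import torch
from torchvision.datasets import MNIST

# Load the raw MNIST training images as a (60000, 28, 28) uint8 tensor.
train = MNIST(root="./data", train=True, download=True)
images = train.data  # pixel intensities in [0, 255]

# A "dead" pixel is exactly zero in every single training image.
dead_mask = (images == 0).all(dim=0)  # (28, 28) boolean mask
print(f"Dead pixels: {dead_mask.sum().item()} of {dead_mask.numel()}")
# The dead pixels cluster in the corners and along the borders of the grid.
```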

The Bayesian Problem

  • Weights connected to dead pixels don’t affect the training loss (always multiplied by zero).
  • Therefore, the posterior equals the prior for these weights (they are not updated).
  • At test time, noise may activate dead pixels.
  • Random weights drawn from the prior then multiply non‑zero values, propagating noise through the network and yielding poor predictions.
Formally, for the weights (w_{kj}^1) attached to a dead feature (k):

[ p(w_{kj}^1 \mid \mathcal{D}) = p(w_{kj}^1) \quad \text{if } x_k^i = 0 \ \forall i. ]
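You can verify the first two bullets directly: because a dead feature contributes zero to every pre‑activation, the training loss has zero gradient with respect to its weights, so the likelihood carries no information about them. A toy PyTorch check (our own example, not the paper's code):

```python
import torch

torch.manual_seed(0)

# Toy training set: feature 0 is "dead" (always zero), feature 1 is live.
X = torch.tensor([[0.0, 1.0], [0.0, -0.5], [0.0, 2.0]])
y = torch.tensor([[1.0], [-0.5], [2.0]])

W = torch.randn(2, 1, requires_grad=True)  # first-layer weights
loss = ((X @ W - y) ** 2).mean()
loss.backward()

print(W.grad)  # the gradient for the dead feature's weight W[0] is exactly 0
```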

The MAP Solution

  • Weights connected to dead pixels are pushed toward zero by the regularizer.
  • At test time, even if dead pixels activate, zero weights ignore them.
  • Noise does not propagate, leading to robust predictions.
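The contrast between the two behaviors fits in a few lines. In this sketch (a toy linear model under a standard normal prior, our own construction), test‑time noise activates the previously dead feature; a weight drawn from the prior amplifies the noise, while the MAP‑style zero weight ignores it:

```python
import numpy as np

rng = np.random.default_rng(42)

# At test time, corruption activates the feature that was always 0 in training.
x_test = np.array([5.0, 1.0])  # [dead feature now nonzero, live feature]

w_live = 1.0  # weight learned from the live feature

# BMA: the dead feature's weight is still an N(0, 1) draw from the prior.
w_dead_bma = rng.normal(0.0, 1.0)
# MAP: the Gaussian prior acts as a regularizer and zeroes the unused weight.
w_dead_map = 0.0

print("BMA:", w_dead_bma * x_test[0] + w_live * x_test[1])  # noise amplified
print("MAP:", w_dead_map * x_test[0] + w_live * x_test[1])  # noise ignored
```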

Lemma 1

If feature (k) satisfies (x_k^i = 0) for every training example (i) and the prior over first‑layer weights factorizes, then
[ p(w_{kj}^1 \mid \mathcal{D}) = p(w_{kj}^1) \quad \forall j, ]
i.e., the posterior equals the prior, and these weights remain random.

The General Problem: Linear Dependencies

The dead‑pixel example is a special case of a broader issue: any linear dependency in the training data can cause the same failure mode.

Proposition 2 (Izmailov et al.) states that if the training data lie in an affine subspace

[ \sum_{k=1}^m c_k x_k^i = c_0 \quad \forall i, ]

then:

  • The posterior of the weight projection (w_j^c = \sum_{k=1}^m c_k w_{kj}^1 - c_0 b_j^1) equals the prior.
  • MAP sets (w_j^c = 0).
  • BMA predictions become highly sensitive to test data outside the subspace.
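The constraint vector (c) need not be known in advance: any (near‑)zero‑variance direction of the centered training data defines one. A numpy sketch of our own toy example, under the assumption of an exact linear dependency in the features:

```python
import numpy as np

rng = np.random.default_rng(0)

# Training data confined to an affine subspace: x3 = 2*x1 + x2 + 1.
X = rng.normal(size=(100, 2))
X = np.column_stack([X, 2 * X[:, 0] + X[:, 1] + 1])  # shape (n, m) with m = 3

# Zero-variance directions of the centered data reveal the constraint.
Xc = X - X.mean(axis=0)
_, S, Vt = np.linalg.svd(Xc, full_matrices=False)
c = Vt[-1]           # right singular vector with (near-)zero singular value
c0 = (X @ c).mean()  # the constant: sum_k c_k x_k^i = c0 for every example i

print(S)              # the last singular value is ~0
print(np.ptp(X @ c))  # ~0: the affine constraint holds on all examples

# The projection w_j^c = c @ W[:, j] - c0 * b[j] of any first-layer weight
# column is then invisible to the likelihood, exactly as in Proposition 2.
```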

This explains why certain corruptions hurt BNNs more than others:

Figure: Robustness on MNIST. Accuracy of deep ensembles, MAP, and Bayesian neural networks trained on MNIST under covariate shift. Top: fully‑connected network; bottom: convolutional network. While BNNs are competitive on the original MNIST test set, they underperform deep ensembles on most corruptions. With the CNN architecture, all BNN variants trail MAP by almost 20 % when evaluated on SVHN.

Figure: Robustness on CIFAR‑10. Accuracy of deep ensembles, MAP, and Bayesian neural networks using a CNN architecture trained on CIFAR‑10 under covariate shift. For the CIFAR‑10‑C corruptions, results are reported at corruption intensity 4. While BNNs with both Laplace and Gaussian priors outperform deep ensembles in in‑distribution accuracy, they underperform even a single MAP solution on most corruptions.

The Brilliant Solution: EmpCov Prior

The authors propose a simple, elegant fix: align the prior with the data covariance structure.

Empirical Covariance (EmpCov) Prior

For first‑layer weights:

[ p(w^1) = \mathcal{N}\left(0, \alpha \Sigma + \epsilon I\right), \qquad \Sigma = \frac{1}{n-1}\sum_{i=1}^{n} x_i x_i^\top, ]

where (\Sigma) is the empirical data covariance, (\alpha) a scaling factor, and (\epsilon) a small jitter term.
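A minimal sketch of constructing and sampling such a prior (the function name and toy data are ours; the paper's released code may differ):

```python
import numpy as np

def empcov_prior_cov(X, alpha=1.0, eps=1e-4):
    """EmpCov prior covariance alpha * Sigma + eps * I for inputs X of shape (n, m)."""
    Sigma = X.T @ X / (len(X) - 1)  # empirical covariance, as in the formula above
    return alpha * Sigma + eps * np.eye(X.shape[1])

rng = np.random.default_rng(0)
# Toy inputs whose last feature is dead (zero variance).
X = rng.normal(size=(500, 10)) * np.linspace(1.0, 0.0, 10)
cov = empcov_prior_cov(X)

# Draw first-layer weight vectors (one per hidden unit) from N(0, cov).
w = rng.multivariate_normal(np.zeros(10), cov, size=256)
print(w.std(axis=0))  # std along the dead feature is ~sqrt(eps), i.e. tiny
```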

Figure: Bayesian inference samples weights along low‑variance principal components from the prior, while MAP sets these weights to zero. (a) Distribution (mean ± 2 std) of projections of first‑layer weights onto PCA directions, for BNN samples and MAP solutions (MLP and CNN) with different prior scales: MAP zeroes out low‑variance components, while BNN samples retain them. (b) Accuracy of BNN and MAP on the MNIST test set with Gaussian noise applied along the 50 highest‑ and 50 lowest‑variance PCA components: MAP is robust to noise along low‑variance directions while BMA is not; both are similarly robust along high‑variance components.

How It Works

  • Eigenvectors of the prior = principal components of the data.
  • Prior variance along PC (p_i): (\alpha \sigma_i^2 + \epsilon), where (\sigma_i^2) is the data variance along (p_i).
  • For a zero‑variance direction ((\sigma_i^2 = 0)): variance reduces to (\epsilon) (tiny).
  • Result: The BNN cannot sample large random weights along unimportant directions, preventing noise amplification.
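A quick numerical check of these bullets, diagonalizing the prior covariance for toy data (our own construction):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, eps = 1.0, 1e-4

# Toy data with strongly decreasing variance per feature; the last is dead.
X = rng.normal(size=(2000, 5)) * np.array([2.0, 1.0, 0.5, 0.1, 0.0])
Sigma = X.T @ X / (len(X) - 1)
cov = alpha * Sigma + eps * np.eye(5)

# Eigenvectors of cov are the data's principal components, and the prior
# variance along component p_i is alpha * sigma_i^2 + eps.
eigvals, _ = np.linalg.eigh(cov)
print(eigvals)  # smallest ~eps: no large random weights along dead directions
```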

Empirical Improvements

| Corruption / Shift | BNN (Gaussian) | BNN (EmpCov) | Improvement |
|---|---|---|---|
| Gaussian noise | 21.3 % | 52.8 % | +31.5 pp |
| Shot noise | 24.1 % | 54.2 % | +30.1 pp |
| MNIST → SVHN | 31.2 % | 45.8 % | +14.6 pp |

Figure: EmpCov prior improves robustness. Test accuracy under covariate shift for deep ensembles, MAP optimization with SGD, and BNNs with Gaussian and EmpCov priors. Left: MLP architecture trained on MNIST. Right: CNN architecture trained on CIFAR‑10. The EmpCov prior provides a consistent improvement over the standard Gaussian prior, especially on noise corruptions and domain‑shift experiments (SVHN, STL‑10).
