Bayesian Neural Networks Under Covariate Shift: When Theory Fails Practice
Source: Dev.to
The Surprising Failure of Bayesian Robustness
If you’ve been following Bayesian deep learning literature, you’ve likely encountered the standard narrative: Bayesian methods provide principled uncertainty quantification, which should make them more robust to distribution shifts. The theory sounds compelling—when faced with out-of-distribution data, Bayesian Model Averaging (BMA) should account for multiple plausible explanations, leading to calibrated uncertainty and better generalization.
But what if this narrative is fundamentally flawed? What if, in practice, Bayesian Neural Networks (BNNs) with exact inference are actually less robust to distribution shift than their classical counterparts?
This is exactly what Izmailov et al. discovered in their NeurIPS 2021 paper, “Dangers of Bayesian Model Averaging under Covariate Shift.” Their findings challenge core assumptions about Bayesian methods and have significant implications for real‑world applications.
The Counterintuitive Result
Let’s start with the most striking finding:
![Bayesian neural networks under covariate shift. (a): Performance of a ResNet‑20 on the pixelate corruption in CIFAR‑10‑C. For the highest degree of corruption, a Bayesian model average underperforms a MAP solution by 25 % (44 % against 69 %) accuracy. See Izmailov et al. [2021] for details. (b): Visualization of the weights in the first layer of a Bayesian fully‑connected network on MNIST sampled via HMC. (c): The corresponding MAP weights. We visualize the weights connecting the input pixels to a neuron in the hidden layer as a 28 × 28 image, where each weight is shown in the location of the input pixel it interacts with.](https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj4d7zv1q4uhzywc64yqj.jpg)
Yes, you read that correctly. On severely corrupted CIFAR‑10‑C data, a Bayesian Neural Network using Hamiltonian Monte Carlo (HMC) achieves only 44 % accuracy, while a simple maximum a posteriori (MAP) estimate achieves 69 % accuracy – a 25 percentage‑point gap in favor of the simpler method!
This is particularly surprising because on clean, in‑distribution data, the BNN actually outperforms MAP by about 5 %. So we have a method that does better on the standard benchmark yet fails catastrophically under distribution shift.
Why Does This Happen? The “Dead Pixels” Analogy
The authors provide an elegant explanation through what they call the “dead pixels” phenomenon. Consider MNIST digits—they always have black pixels in the corners (intensity = 0). These are “dead pixels” that never activate during training.
The Bayesian Problem
- Weights connected to dead pixels don’t affect the training loss (always multiplied by zero).
- Therefore, the posterior equals the prior for these weights (they are not updated).
- At test time, noise may activate dead pixels.
- Random weights drawn from the prior then multiply non‑zero values, propagating noise through the network and yielding poor predictions.
[ p(w_{kj}^1 \mid \mathcal{D}) = p(w_{kj}^1) \quad \text{if } x_i^k = 0 \ \forall i ]
The MAP Solution
- Weights connected to dead pixels are pushed toward zero by the regularizer.
- At test time, even if dead pixels activate, zero weights ignore them.
- Noise does not propagate, leading to robust predictions.
Lemma 1
If the (k)-th feature satisfies (x_i^k = 0) for every training example (i) and the prior over the first‑layer weights factorizes, then
[ p(w_{kj}^1 \mid \mathcal{D}) = p(w_{kj}^1) \quad \text{for every hidden unit } j, ]
i.e., the posterior over the weights connecting feature (k) to the hidden layer equals the prior, and these weights remain random.
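To make the lemma concrete, here is a minimal NumPy sketch (my own illustration, not code from the paper) with a single hidden unit and one dead input feature. Under BMA the dead‑feature weight is redrawn from the prior in every posterior sample, while under MAP it is pulled to zero, so test‑time noise on that feature only perturbs the Bayesian pre‑activation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: three input features feeding one hidden unit; feature 2 is "dead"
# (always zero during training), so the data never constrains its weight.
prior_std = 1.0
w_trained = np.array([0.8, -0.5, 0.0])   # weights determined by the live features

# BMA: the dead-feature weight keeps its prior distribution, so every posterior
# sample draws it fresh from N(0, prior_std^2).
posterior_samples = np.tile(w_trained, (1000, 1))
posterior_samples[:, 2] = rng.normal(0.0, prior_std, size=1000)

# MAP: the zero-mean Gaussian prior (weight decay) pulls the unconstrained
# weight to zero, so the MAP weights are just w_trained with a zero in position 2.
w_map = w_trained.copy()

# Corrupted test input: noise activates the formerly dead feature.
x_test = np.array([0.3, 1.2, 0.0]) + rng.normal(0.0, 0.5, size=3)

pre_act_bma = posterior_samples @ x_test   # one pre-activation per posterior sample
pre_act_map = w_map @ x_test

print(f"MAP pre-activation:      {pre_act_map:+.3f}")
print(f"BMA pre-activation mean: {pre_act_bma.mean():+.3f}")
print(f"BMA pre-activation std:  {pre_act_bma.std():.3f}  (noise amplified by the prior-random weight)")
```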
The General Problem: Linear Dependencies
The dead‑pixel example is a special case of a broader issue: any linear dependency in the training data can cause the same failure mode.
Proposition 2 (Izmailov et al.) states that if the training data lie in an affine subspace
[ \sum_{k=1}^m c_k x_i^k = c_0 \quad \forall i, ]
then:
- The posterior of the weight projection (w_j^c = \sum_{k=1}^m c_k w_{kj}^1 - c_0 b_j^1) equals the prior for every hidden unit (j).
- MAP sets (w_j^c = 0).
- BMA predictions become highly sensitive to test data outside the subspace.
This explains why certain corruptions hurt BNNs more than others: corruptions that break the linear dependencies in the training data (for example, noise that activates pixels that are essentially constant across the training set) push test inputs into exactly the directions where the posterior is still just the prior.
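How common are such dependencies in practice? One quick check is to look for (near‑)zero singular values of the training inputs augmented with a constant column. The sketch below is a hypothetical illustration on synthetic data (not the authors' code); every constraint it finds corresponds to a projection of the first‑layer weights whose posterior never moves away from the prior.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: 100 features, but the last 20 are always zero (dead pixels)
# and feature 1 is an exact copy of feature 0 (a linear dependency).
n, d = 500, 100
X = rng.normal(size=(n, d))
X[:, 80:] = 0.0
X[:, 1] = X[:, 0]

# Append a constant column so affine constraints sum_k c_k x^k = c_0 appear as
# null-space directions of the augmented data matrix.
X_aug = np.hstack([X, np.ones((n, 1))])

# Every (near-)zero singular value corresponds to a direction c with X_aug @ c ~ 0,
# i.e. a constraint satisfied by all training points.
svals = np.linalg.svd(X_aug, compute_uv=False)
n_constraints = int(np.sum(svals < 1e-8))

print(f"features: {d}, affine constraints found: {n_constraints}")
# 21 here: 20 dead features plus the duplicated-feature constraint. Along each of
# these directions, the corresponding projection of the first-layer weights keeps
# its prior distribution under exact Bayesian inference.
```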
The Brilliant Solution: EmpCov Prior
The authors propose a simple, elegant fix: align the prior with the data covariance structure.
Empirical Covariance (EmpCov) Prior
For first‑layer weights:
[ p(w^1) = \mathcal{N}\!\left(0, \; \alpha \Sigma + \epsilon I\right), \qquad \Sigma = \frac{1}{n-1}\sum_{i=1}^{n} x_i x_i^\top, ]
where (\Sigma) is the empirical data covariance, (\alpha) a scaling factor, and (\epsilon) a small jitter term.
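In code, the prior covariance takes only a few lines to build. The sketch below is my reconstruction of the formula above, not the authors' released implementation; the `alpha` and `eps` values and the synthetic data are placeholders. It forms (\alpha \Sigma + \epsilon I) from the training inputs and uses it as the covariance of a zero‑mean Gaussian over each hidden unit's incoming first‑layer weights.

```python
import numpy as np

def empcov_prior_covariance(X, alpha=1.0, eps=1e-2):
    """Prior covariance alpha * Sigma + eps * I from training inputs X of shape (n, d)."""
    n, d = X.shape
    Sigma = (X.T @ X) / (n - 1)          # empirical covariance, as in the formula above
    return alpha * Sigma + eps * np.eye(d)

# Example with synthetic data containing zero-variance (dead) features.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 50))
X[:, 40:] = 0.0                          # dead features

cov = empcov_prior_covariance(X, alpha=1.0, eps=1e-2)

# Sample the incoming weights of one hidden unit from the EmpCov prior.
w = rng.multivariate_normal(np.zeros(50), cov)
print("prior std, live feature:", round(float(np.sqrt(cov[0, 0])), 3))
print("prior std, dead feature:", round(float(np.sqrt(cov[45, 45])), 3))   # ~ sqrt(eps) = 0.1
print("sampled |w| on dead features:", np.abs(w[40:]).max().round(3))      # stays small under EmpCov
```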

How It Works
- Eigenvectors of the prior covariance = principal components of the data.
- Prior variance along principal component (p_i): (\alpha \sigma_i^2 + \epsilon), where (\sigma_i^2) is the data variance along (p_i).
- For a zero‑variance direction ((\sigma_i^2 = 0)): variance reduces to (\epsilon) (tiny).
- Result: The BNN cannot sample large random weights along directions in which the data never varies, preventing noise amplification (see the sketch below).
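To see the mechanism numerically, the following sketch (again an illustration on synthetic data, with placeholder (\alpha) and (\epsilon)) projects weight samples drawn from an isotropic Gaussian prior and from the EmpCov prior onto the data's principal components: the EmpCov prior leaves only (\epsilon)-sized variance along the zero‑variance directions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic training data with 10 dead (zero-variance) features.
X = rng.normal(size=(500, 50))
X[:, 40:] = 0.0
Sigma = (X.T @ X) / (X.shape[0] - 1)
alpha, eps = 1.0, 1e-2
cov_empcov = alpha * Sigma + eps * np.eye(50)

# Principal components of the data = eigenvectors of Sigma (eigh sorts eigenvalues ascending).
_, pcs = np.linalg.eigh(Sigma)
dead_dir = pcs[:, 0]     # a zero-variance direction
live_dir = pcs[:, -1]    # the largest-variance direction

# Sample first-layer weight vectors under both priors.
w_iso = rng.normal(0.0, 1.0, size=(5000, 50))                          # isotropic Gaussian prior
w_emp = rng.multivariate_normal(np.zeros(50), cov_empcov, size=5000)   # EmpCov prior

for name, w in [("isotropic", w_iso), ("EmpCov", w_emp)]:
    print(f"{name:>9} prior: var along dead PC = {np.var(w @ dead_dir):.3f}, "
          f"var along live PC = {np.var(w @ live_dir):.3f}")
# The isotropic prior puts ~unit variance in every direction; the EmpCov prior
# leaves only ~eps variance along directions the training data never uses, so
# posterior samples cannot place large random weights there.
```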
Empirical Improvements
| Corruption / Shift | BNN accuracy (Gaussian prior) | BNN accuracy (EmpCov prior) | Improvement |
|---|---|---|---|
| Gaussian noise | 21.3 % | 52.8 % | +31.5 pp |
| Shot noise | 24.1 % | 54.2 % | +30.1 pp |
| MNIST → SVHN | 31.2 % | 45.8 % | +14.6 pp |
