贝叶斯神经网络在协变量偏移下:当理论在实践中失效
Source: Dev.to
贝叶斯鲁棒性的惊人失效
如果你一直在关注贝叶斯深度学习的文献,你很可能已经看到过标准的叙事:贝叶斯方法提供了原则性的不确定性量化,这应当使它们在分布漂移时更具鲁棒性。理论听起来很有说服力——面对分布外(out‑of‑distribution)数据时,贝叶斯模型平均(BMA)应该考虑多种可能的解释,从而得到校准的不确定性并实现更好的泛化。
但如果这个叙事本身就是根本错误的呢?如果在实际中,具有精确推断的贝叶斯神经网络(BNN)在分布漂移下的鲁棒性实际上比经典模型更差呢?
这正是 Izmailov 等人在 NeurIPS 2021 论文《Dangers of Bayesian Model Averaging under Covariate Shift》中发现的现象。他们的发现挑战了关于贝叶斯方法的核心假设,并对真实世界的应用具有重要意义。
反直觉的结果
先来看最引人注目的发现:
![Bayesian neural networks under covariate shift. (a): Performance of a ResNet‑20 on the pixelate corruption in CIFAR‑10‑C. For the highest degree of corruption, a Bayesian model average underperforms a MAP solution by 25 % (44 % against 69 %) accuracy. See Izmailov et al. [2021] for details. (b): Visualization of the weights in the first layer of a Bayesian fully‑connected network on MNIST sampled via HMC. (c): The corresponding MAP weights. We visualize the weights connecting the input pixels to a neuron in the hidden layer as a 28 × 28 image, where each weight is shown in the location of the input pixel it interacts with.](https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj4d7zv1q4uhzywc64yqj.jpg)
是的,你没有看错。在严重受损的 CIFAR‑10‑C 数据上,使用 Hamiltonian Monte Carlo(HMC)的贝叶斯神经网络仅 44 % 的准确率,而一个简单的最大后验(MAP)估计能够达到 69 % 的准确率——相差 25 个百分点,且优势在更简单的方法一方!
这尤其令人惊讶,因为在干净、分布内的数据上,BNN 实际上比 MAP 高出约 5 %。于是我们得到一个在标准基准上表现更好,却在分布漂移时灾难性失效的方法。
为什么会这样?“死像素”类比
作者通过他们所谓的 “死像素” 现象给出了优雅的解释。考虑 MNIST 手写数字——它们的四角始终是黑色像素(强度 = 0)。这些像素在训练期间从不被激活,因而称为“死像素”。
贝叶斯问题
- 与死像素相连的权重 不会影响训练损失(始终乘以零)。
- 因此,这些权重的 后验等于先验(没有被更新)。
- 在测试时,噪声可能会激活死像素。
- 从先验中随机抽取的权重 与非零值相乘,噪声被放大并在网络中传播,导致预测质量下降。
p(w_{ij}^1 \mid \mathcal{D}) = p(w_{ij}^1) \quad \text{if } x_k^i = 0 \ \forall i
MAP 解
- 与死像素相连的权重 在正则化项的作用下被推向零。
- 在测试时,即使死像素被激活,零权重也会忽略它们。
- 噪声不会被放大,从而得到更鲁棒的预测。
引理 1
若特征 (x_k^i = 0) 对所有训练样本成立且先验可因子化,则
[ p(w_{ij}^1 \mid \mathcal{D}) = p(w_{ij}^1) ]
即后验等于先验,这些权重保持随机。
更一般的问题:线性依赖
死像素例子只是更广泛问题的一个特例:训练数据中的任何线性依赖都可能导致相同的失效模式。
命题 2(Izmailov 等)指出,如果训练数据位于仿射子空间
[ \sum_{j=1}^m x_i^j c_j = c_0 \quad \forall i, ]
则:
- 权重投影 (w_j^c = \sum_{i=1}^m c_i w_{ij}^1 - c_0 b_j^1) 的后验等于先验。
- MAP 将 (w_j^c = 0)。
- BMA 的预测对超出该子空间的测试数据极其敏感。
这解释了为何某些腐蚀(corruption)对 BNN 的伤害比其他更大:


巧妙的解决方案:EmpCov 先验
作者提出了一个简单而优雅的修正:让先验与数据协方差结构保持一致。
经验协方差(EmpCov)先验
针对第一层权重:
[ p(w^{(1)}) = \mathcal{N}!\left(0,; \alpha \Sigma + \epsilon I\right), \qquad \Sigma = \frac{1}{n-1}\sum_{i=1}^{n} x_i x_i^\top, ]
其中 (\Sigma) 为经验数据协方差,(\alpha) 为缩放因子,(\epsilon) 为小的抖动项。

工作原理
- 先验的特征向量 = 数据的主成分(PCA)方向。
- 先验在第 (p_i) 个主成分上的方差 为 (\alpha \sigma_i^2 + \epsilon)。
- 对于 零方差方向((\sigma_i^2 = 0)),方差降至 (\epsilon)(极小)。
- 结果:BNN 在不重要的方向上无法抽取大幅随机权重,从而防止噪声放大。
实验改进
| 腐蚀 / 漂移 | BNN(Gaussian) | BNN(EmpCov) | 提升幅度 |
|---|---|---|---|
| 高斯噪声 | 21.3 % | 52.8 % | +31.5 pp |
| 斑点噪声 | 24.1 % | 54.2 % | +30.1 pp |
| MNIST → SVHN | 31.2 % | 45.8 % | +14.6 pp |
