贝叶斯神经网络在协变量偏移下:当理论在实践中失效

发布: (2025年12月15日 GMT+8 03:25)
8 min read
原文: Dev.to

Source: Dev.to

贝叶斯鲁棒性的惊人失效

如果你一直在关注贝叶斯深度学习的文献,你很可能已经看到过标准的叙事:贝叶斯方法提供了原则性的不确定性量化,这应当使它们在分布漂移时更具鲁棒性。理论听起来很有说服力——面对分布外(out‑of‑distribution)数据时,贝叶斯模型平均(BMA)应该考虑多种可能的解释,从而得到校准的不确定性并实现更好的泛化。

但如果这个叙事本身就是根本错误的呢?如果在实际中,具有精确推断的贝叶斯神经网络(BNN)在分布漂移下的鲁棒性实际上比经典模型更差呢?

这正是 Izmailov 等人在 NeurIPS 2021 论文《Dangers of Bayesian Model Averaging under Covariate Shift》中发现的现象。他们的发现挑战了关于贝叶斯方法的核心假设,并对真实世界的应用具有重要意义。

反直觉的结果

先来看最引人注目的发现:

Bayesian neural networks under covariate shift. (a): Performance of a ResNet‑20 on the pixelate corruption in CIFAR‑10‑C. For the highest degree of corruption, a Bayesian model average underperforms a MAP solution by 25 % (44 % against 69 %) accuracy. See Izmailov et al. [2021] for details. (b): Visualization of the weights in the first layer of a Bayesian fully‑connected network on MNIST sampled via HMC. (c): The corresponding MAP weights. We visualize the weights connecting the input pixels to a neuron in the hidden layer as a 28 × 28 image, where each weight is shown in the location of the input pixel it interacts with.

是的,你没有看错。在严重受损的 CIFAR‑10‑C 数据上,使用 Hamiltonian Monte Carlo(HMC)的贝叶斯神经网络仅 44 % 的准确率,而一个简单的最大后验(MAP)估计能够达到 69 % 的准确率——相差 25 个百分点,且优势在更简单的方法一方!

这尤其令人惊讶,因为在干净、分布内的数据上,BNN 实际上比 MAP 高出约 5 %。于是我们得到一个在标准基准上表现更好,却在分布漂移时灾难性失效的方法。

为什么会这样?“死像素”类比

作者通过他们所谓的 “死像素” 现象给出了优雅的解释。考虑 MNIST 手写数字——它们的四角始终是黑色像素(强度 = 0)。这些像素在训练期间从不被激活,因而称为“死像素”。

贝叶斯问题

  • 与死像素相连的权重 不会影响训练损失(始终乘以零)。
  • 因此,这些权重的 后验等于先验(没有被更新)。
  • 在测试时,噪声可能会激活死像素。
  • 从先验中随机抽取的权重 与非零值相乘,噪声被放大并在网络中传播,导致预测质量下降。
p(w_{ij}^1 \mid \mathcal{D}) = p(w_{ij}^1) \quad \text{if } x_k^i = 0 \ \forall i

MAP 解

  • 与死像素相连的权重 在正则化项的作用下被推向零
  • 在测试时,即使死像素被激活,零权重也会忽略它们。
  • 噪声不会被放大,从而得到更鲁棒的预测。

引理 1

若特征 (x_k^i = 0) 对所有训练样本成立且先验可因子化,则
[ p(w_{ij}^1 \mid \mathcal{D}) = p(w_{ij}^1) ]
即后验等于先验,这些权重保持随机。

更一般的问题:线性依赖

死像素例子只是更广泛问题的一个特例:训练数据中的任何线性依赖都可能导致相同的失效模式

命题 2(Izmailov 等)指出,如果训练数据位于仿射子空间

[ \sum_{j=1}^m x_i^j c_j = c_0 \quad \forall i, ]

则:

  • 权重投影 (w_j^c = \sum_{i=1}^m c_i w_{ij}^1 - c_0 b_j^1) 的后验等于先验。
  • MAP 将 (w_j^c = 0)。
  • BMA 的预测对超出该子空间的测试数据极其敏感。

这解释了为何某些腐蚀(corruption)对 BNN 的伤害比其他更大:

Robustness on MNIST. Accuracy for deep ensembles, MAP and Bayesian neural networks trained on MNIST under covariate shift. Top: Fully‑connected network; bottom: Convolutional network. While on the original MNIST test set BNNs provide competitive performance, they underperform deep ensembles on most corruptions. With the CNN architecture, all BNN variants lose to MAP when evaluated on SVHN by almost 20 %.

Robustness on CIFAR‑10. Accuracy for deep ensembles, MAP and Bayesian neural networks using a CNN architecture trained on CIFAR‑10 under covariate shift. For the corruptions from CIFAR‑10‑C, we report results for corruption intensity 4. While the BNNs with both Laplace and Gaussian priors outperform deep ensembles on the in‑distribution accuracy, they underperform even a single MAP solution on most corruptions.

巧妙的解决方案:EmpCov 先验

作者提出了一个简单而优雅的修正:让先验与数据协方差结构保持一致

经验协方差(EmpCov)先验

针对第一层权重:

[ p(w^{(1)}) = \mathcal{N}!\left(0,; \alpha \Sigma + \epsilon I\right), \qquad \Sigma = \frac{1}{n-1}\sum_{i=1}^{n} x_i x_i^\top, ]

其中 (\Sigma) 为经验数据协方差,(\alpha) 为缩放因子,(\epsilon) 为小的抖动项。

Bayesian inference samples weights along low‑variance principal components from the prior, while MAP sets these weights to zero. (a): Distribution (mean ± 2 std) of projections of first‑layer weights onto PCA directions for BNN samples and MAP solutions (MLP and CNN) with different prior scales. MAP zeros out low‑variance components; BNN samples retain them. (b): Accuracy of BNN and MAP on MNIST test set with Gaussian noise applied along the 50 highest and 50 lowest variance PCA components. MAP is robust to noise along low‑variance directions, while BMA is not; both are similarly robust along high‑variance components.

工作原理

  • 先验的特征向量 = 数据的主成分(PCA)方向。
  • 先验在第 (p_i) 个主成分上的方差 为 (\alpha \sigma_i^2 + \epsilon)。
  • 对于 零方差方向((\sigma_i^2 = 0)),方差降至 (\epsilon)(极小)。
  • 结果:BNN 在不重要的方向上无法抽取大幅随机权重,从而防止噪声放大。

实验改进

腐蚀 / 漂移BNN(Gaussian)BNN(EmpCov)提升幅度
高斯噪声21.3 %52.8 %+31.5 pp
斑点噪声24.1 %54.2 %+30.1 pp
MNIST → SVHN31.2 %45.8 %+14.6 pp

EmpCov prior improves robustness. Test accuracy under covariate shift for deep ensembles, MAP optimization with SGD, and BNN with Gaussian and EmpCov priors. Left: MLP architecture trained on MNIST. Right: CNN architecture trained on CIFAR‑10. The EmpCov prior provides consistent improvement over the standard Gaussian prior, especially on noise corruptions and domain‑shift experiments (SVHN, STL‑10).

Back to Blog

相关文章

阅读更多 »