交叉熵导数,第2部分:对偏置的导数设置

发布: (2026年2月3日 GMT+8 03:57)
3 min read
原文: Dev.to

Source: Dev.to

Introduction

在上一篇文章中,我们回顾了使用交叉熵导数所需的关键概念。在本篇文章中,我们将一步步搭建导数的推导过程。

Predicted probability for Setosa

当我们把预测概率代入交叉熵公式时,公式的形式取决于观测到的物种。因为观测到的物种是 Setosa,所以对应的预测概率就是 Setosa 的预测概率。

使用 softmax 函数,Setosa 的预测概率为

[ p_{\text{Setosa}} = \frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}} ]

其中 (z_1, z_2, z_3) 分别是 Setosa、Versicolor 和 Virginica 的原始输出值(logits)。

把它代入交叉熵损失得到

[ L = -\log!\left(\frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}}\right) ]

如果观测到的物种是 Virginica,则 Virginica 的 softmax 方程会给出相应的预测概率

[ p_{\text{Virginica}} = \frac{e^{z_3}}{e^{z_1}+e^{z_2}+e^{z_3}} ]

每种情况都会产生稍有不同的损失表达式,进而导致相对于偏置项 (b_3) 的交叉熵导数不同。

Summary of derivatives

相对于偏置 (b_3) 的交叉熵损失的导数可以概括如下:

[ \frac{\partial L}{\partial b_3}= \begin{cases} \displaystyle \frac{\partial L}{\partial p_{\text{Setosa}}}, \frac{\partial p_{\text{Setosa}}}{\partial z_1}, \frac{\partial z_1}{\partial b_3}, & \text{if observed = Setosa}\[10pt] \displaystyle \frac{\partial L}{\partial p_{\text{Virginica}}}, \frac{\partial p_{\text{Virginica}}}{\partial z_3}, \frac{\partial z_3}{\partial b_3}, & \text{if observed = Virginica}\[10pt] \text{(similar terms for Versicolor)} & \dots \end{cases} ]

Derivative for Setosa with respect to (b_3)

Cross‑entropy loss definition

真实类别为 Setosa 的单个样本的交叉熵损失为

[ L = -\log(p_{\text{Setosa}}) ]

Softmax prediction

[ p_{\text{Setosa}} = \frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}} ]

softmax 的输入是原始输出值(logits)(z_1, z_2, z_3)。只有 Setosa 对应的 logits((z_1))会直接受到偏置 (b_3) 的影响(通过网络结构将 (b_3) 加到 Setosa 的原始输出上)。

Applying the chain rule

要使用梯度下降优化 (b_3),我们需要

[ \frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial p_{\text{Setosa}}}, \frac{\partial p_{\text{Setosa}}}{\partial z_1}, \frac{\partial z_1}{\partial b_3} ]

  • (\displaystyle \frac{\partial L}{\partial p_{\text{Setosa}}}= -\frac{1}{p_{\text{Setosa}}})
  • (\displaystyle \frac{\partial p_{\text{Setosa}}}{\partial z_1}= p_{\text{Setosa}}(1-p_{\text{Setosa}})) (softmax 导数)
  • (\displaystyle \frac{\partial z_1}{\partial b_3}=1) 若 (b_3) 直接加到 (z_1) 上;否则取决于网络的连线方式。

将这些项相乘即可得到针对 Setosa 观测的损失相对于 (b_3) 的梯度。

Next steps

在下一篇文章中,我们将显式计算这些项,并展示它们如何组合成用于梯度下降更新的最终梯度。

Back to Blog

相关文章

阅读更多 »

理解梯度爆炸问题

为什么神经网络会爆炸——一个帮助训练的简单修复 一些神经网络,尤其是RNN,在训练时可能感觉像在风暴中驾驶船只,因为微小的…