交叉熵导数,第2部分:对偏置的导数设置
Source: Dev.to
Introduction
在上一篇文章中,我们回顾了使用交叉熵导数所需的关键概念。在本篇文章中,我们将一步步搭建导数的推导过程。
Predicted probability for Setosa
当我们把预测概率代入交叉熵公式时,公式的形式取决于观测到的物种。因为观测到的物种是 Setosa,所以对应的预测概率就是 Setosa 的预测概率。
使用 softmax 函数,Setosa 的预测概率为
[ p_{\text{Setosa}} = \frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}} ]
其中 (z_1, z_2, z_3) 分别是 Setosa、Versicolor 和 Virginica 的原始输出值(logits)。
把它代入交叉熵损失得到
[ L = -\log!\left(\frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}}\right) ]
如果观测到的物种是 Virginica,则 Virginica 的 softmax 方程会给出相应的预测概率
[ p_{\text{Virginica}} = \frac{e^{z_3}}{e^{z_1}+e^{z_2}+e^{z_3}} ]
每种情况都会产生稍有不同的损失表达式,进而导致相对于偏置项 (b_3) 的交叉熵导数不同。
Summary of derivatives
相对于偏置 (b_3) 的交叉熵损失的导数可以概括如下:
[ \frac{\partial L}{\partial b_3}= \begin{cases} \displaystyle \frac{\partial L}{\partial p_{\text{Setosa}}}, \frac{\partial p_{\text{Setosa}}}{\partial z_1}, \frac{\partial z_1}{\partial b_3}, & \text{if observed = Setosa}\[10pt] \displaystyle \frac{\partial L}{\partial p_{\text{Virginica}}}, \frac{\partial p_{\text{Virginica}}}{\partial z_3}, \frac{\partial z_3}{\partial b_3}, & \text{if observed = Virginica}\[10pt] \text{(similar terms for Versicolor)} & \dots \end{cases} ]
Derivative for Setosa with respect to (b_3)
Cross‑entropy loss definition
真实类别为 Setosa 的单个样本的交叉熵损失为
[ L = -\log(p_{\text{Setosa}}) ]
Softmax prediction
[ p_{\text{Setosa}} = \frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}} ]
softmax 的输入是原始输出值(logits)(z_1, z_2, z_3)。只有 Setosa 对应的 logits((z_1))会直接受到偏置 (b_3) 的影响(通过网络结构将 (b_3) 加到 Setosa 的原始输出上)。
Applying the chain rule
要使用梯度下降优化 (b_3),我们需要
[ \frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial p_{\text{Setosa}}}, \frac{\partial p_{\text{Setosa}}}{\partial z_1}, \frac{\partial z_1}{\partial b_3} ]
- (\displaystyle \frac{\partial L}{\partial p_{\text{Setosa}}}= -\frac{1}{p_{\text{Setosa}}})
- (\displaystyle \frac{\partial p_{\text{Setosa}}}{\partial z_1}= p_{\text{Setosa}}(1-p_{\text{Setosa}})) (softmax 导数)
- (\displaystyle \frac{\partial z_1}{\partial b_3}=1) 若 (b_3) 直接加到 (z_1) 上;否则取决于网络的连线方式。
将这些项相乘即可得到针对 Setosa 观测的损失相对于 (b_3) 的梯度。
Next steps
在下一篇文章中,我们将显式计算这些项,并展示它们如何组合成用于梯度下降更新的最终梯度。