关于 torch.nn.CrossEntropyLoss 的计算
torch中计算损失函数时,会使用到名为 CrossEntropyLoss 的交叉熵损失函数,这个函数的公式为:
$$
\begin{aligned}
loss(x,class)&=−log\frac{∑_jexp(x[j])}{exp(x[class])}\
&= −x[class] + log(∑_jexp(x[j]))
\end{aligned}
$$
class 表示该样本的分类,x[j] 表示预测函数的第 j 个输出,关于此公式的解释如下:
假设我们的预测函数的输出如下
$$
[[0.0541, 0.1762, 0.9489 ], [−0.0288, −0.8072, 0.4909]]
$$
假设我们的应该的分类如下
$$
[0,2]
$$
即第一个样本为类别class=0,第二个样本为类别class=2
那么 loss 函数则为:
$$
\begin{aligned}
\frac{e^{0.0541}}{e^{0.0541}+e^{0.1762}+e^{0.9489}}&=0.2185\
\frac{e^{0.4909}}{e^{-0.0288}+e^{-0.8072}+e^{0.4909}}&=0.5354
\end{aligned}
$$
然后计算log之后的相反数:
$$
\begin{aligned}
−\log(0.2185)&=1.5210\
−\log(0.5354)&=0.6247
\end{aligned}
$$
取均值:
$$
\frac{1.5210+0.6247}{2}=1.073
$$
示例代码如下
1 | loss_function = nn.CrossEntropyLoss() |
All articles on this blog are licensed under CC BY-NC-SA 4.0 unless otherwise stated.
Comments