CozyRats Notebook

Pythonや機械学習、データの可視化について書きます。

分類の損失関数によく使われる交差エントロピーを深ぼってみる

損失関数を計算する際、回帰分析であれば最小二乗法が使われますが、分類問題であれば最尤推定法や交差エントロピー(cross entropy)を活用するシーンが多いです。

何となく使っていることが多かったので、実例を交えて整理してみました。

TL;DL

結論としては、損失関数として交差エントロピーを活用するメリットとして、以下が大きいようでした。

勾配降下法で学習を進める際に、

  1. 損失関数が大きかった時の修正値が大きくなり、学習効率が良い
  2. 勾配降下法を利用する際に微分計算をするが、 \exp などが数式にあるため計算しやすい

もう少し詳しく見ていきたいと思います。

交差エントロピーの式を眺める

改めてですが、交差エントロピー  H(p,q) は以下のように定義されます。

 H(p, q) = - \sum_{x} p(x) \log(q(x))

ここで、

  •  p(x) :真の確率分布
  •  q(x) :予測した確率分布

となります。式の形が物理学の統計力学に出てくるエントロピーの定義とほぼ同じため、エントロピーという名前が使われたようです。

損失関数のため、真の確率分布と予測した確率分布が一致する方が  H(p,q) は小さくなります。

実例

簡単に一つ例をあげてみたいと思います。 今回、「うどん」の画像を見せた時に、それが「そば」「うどん」「そうめん」のどれに該当するかを予測するとします。

正解の確率分布  p(x) は当然 (そば, うどん, そうめん) = (0, 1, 0) となります。 それに対してとある学習モデルによって得られた予測値  q(x) では (そば, うどん, そうめん) = (0.2, 0.5, 0.3) となったとします。

この場合の交差エントロピーを計算すると、

 \begin{align}
H(p, q) = - 0 \times  \log(0.2) - 1 \times \log(0.5) - 0 \times \log(0.3) = 0.69
\end{align}

となります。(そばの計算、うどんの計算、そうめんの計算をそれぞれしている)

今度はモデルを改善して、先ほどよりも精度が良くなり (そば, うどん, そうめん) = (0.1, 0.8, 0.1) となったとします。 同じように交差エントロピーは、

 \begin{align}
H(p, q) = - 0 \times  \log(0.1) - 1 \times \log(0.8) - 0 \times \log(0.1) = 0.22
\end{align}

となりました。損失関数なので精度が低くなればより妥当なモデルと言えるので、精度が良くなったことを表現できています。

上記2つの計算を見ると分かりますが、定義式で  \sum をとっていても、結局は正解データ(うどんの項)のところしか値が残りません。 結局うどんという正解 1 に対して、2つ目の例であれば 0.8 まで予測したということになります。

うどん以外の誤ったクラスに関する項は交差エントロピーの計算には一切使われず、分類問題で使われる交差エントロピーは、正解データに対してどれくらい予測したのかのみに影響します。

ここで、横軸をうどんの項(正解データの項)にかけた予測期待、縦軸を交差エントロピーにした時のグラフを出すと、以下のようになります。

当たり前ですが、うどんへの予測期待を大きくすればするほど損失関数は減少します。

損失関数に交差エントロピーを使う利点

上記で出てきたグラフの形を見てもらうと分かりやすいですが、機械学習ではモデルを調整する(=学ぶ)過程で、勾配降下法という手法がよく使われます。

これは、正解と予測の違いを損失関数という形で定義し、損失関数が減少するようにパラメーターをチューニングします。

交差エントロピーのグラフを見てもらうを分かるように、正解データへの誤りが大きければ大きいほど、損失関数の計算結果が大きくなります。そのため、大きく外しているな ということを学ぶことができ、より大胆な修正をしやすくなります。

そのため、少ない計算コストで大きな効果が得やすくなります。

ソース

import numpy as np
import seaborn as sns
import japanize_matplotlib


# 交差エントロピー
def cross_ent(p, q):
    h = -p * np.log(q)
    return h

# グラフ描画
x = np.linspace(0.0001, 1, 100)
y = cross_ent(1, x)
sns.set(rc = {'figure.figsize':(15,8)})
sns.lineplot(x, y).set(xlabel='q (うどん)', ylabel='H (1,q (x))')