損失関数の勾配計算 (連鎖律の適用)
損失関数 \(L\) のパラメータ \(w\) に関する勾配 \(\frac{\partial L}{\partial w}\) を計算するには、連鎖律を用います。損失 \(L\) は予測値 \(\hat{y}\) の関数であり、予測値 \(\hat{y}\) は重み \(w\) の関数であるため、以下のように分解できます。
$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial \hat{y}} \times \frac{\partial \hat{y}}{\partial w}$
ステップ1: \(L\) の \(\hat{y}\) に関する偏微分の計算
損失関数は \(L = \frac{1}{2}(\hat{y} - y)^2\) です。これを \(\hat{y}\) で偏微分します(\(y\) は定数として扱います)。
$\frac{\partial L}{\partial \hat{y}} = \frac{\partial}{\partial \hat{y}} \left( \frac{1}{2}(\hat{y} - y)^2 \right)\\= \frac{1}{2} \times 2(\hat{y} - y) \times \frac{\partial}{\partial \hat{y}}(\hat{y} - y) \quad \\
= (\hat{y} - y) \times 1 \
= \hat{y} - y$
ステップ2: \(\hat{y}\) の \(w\) に関する偏微分の計算
モデルの予測式は \(\hat{y} = wx\) です。これを \(w\) で偏微分します(\(x\) は定数として扱います)。
$ \frac{\partial \hat{y}}{\partial w} = \frac{\partial}{\partial w}(wx) = x$
ステップ3: 連鎖律による勾配の結合
ステップ1とステップ2の結果を連鎖律の式に代入します。
$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial \hat{y}} \times \frac{\partial \hat{y}}{\partial w} = (\hat{y} - y) \times x$
ステップ4: 具体的な値の代入
与えられた値 \(x=2, w=1.5, y=4\) を使って、まず予測値 \(\hat{y}\) を計算します。
$\hat{y} = w x = 1.5 \times 2 = 3$
次に、これらの値を勾配の式に代入します。
$\frac{\partial L}{\partial w} = (\hat{y} - y) \times x = (3 - 4) \times 2 = (-1) \times 2 = -2$
したがって、重み \(w\) に関する損失の勾配は -2.0 です。
勾配の意味と利用
- 勾配 \(\frac{\partial L}{\partial w} = -2\) は、現在の点 (w=1.5) において、重み \(w\) をわずかに増加させると損失 \(L\) が減少することを示しています(勾配が負のため)。
- 最急降下法では、この勾配の逆方向(この場合は正の方向)に重みを更新することで、損失を最小化しようとします。更新量は \(-\eta \frac{\partial L}{\partial w} = -\eta (-2) = 2\eta\) となります。