バックプロパゲーションと連鎖律 (乗算ノード)
バックプロパゲーション(誤差逆伝播法)は、ニューラルネットワークなどの計算グラフにおいて、最終的な損失 \(L\) に対する各パラメータや中間変数の勾配を効率的に計算する手法です。この計算の基本原理は数学の連鎖律(Chain Rule)です。
計算グラフの観点から見ると、あるノード(ここでは乗算関数 \(f\))の出力 \(f\) に関する勾配(上流からの勾配 \(\frac{\partial L}{\partial f}\)が与えられたとき、そのノードの入力(ここでは $x,y$)に関する勾配(下流への勾配 \(\frac{\partial L}{\partial x}\) , \(\frac{\partial L}{\partial y}\) を計算することが目標です。
連鎖律によれば、下流への勾配は、上流からの勾配に「ローカル勾配(そのノード自体の出力の、入力に関する偏微分)」を乗じることで求められます。
$\frac{\partial L}{\partial \text{入力}} = \frac{\partial L}{\partial \text{出力}} \times \frac{\partial \text{出力}}{\partial \text{入力}}$
今回の問題における計算 (\(\frac{\partial L}{\partial x}\) の計算)
入力 \(x\) に関する勾配 \(\frac{\partial L}{\partial x}\) を計算します。
- 上流からの勾配: \(\frac{\partial L}{\partial f} = 0.5\)
- 出力: \(f\)
- 入力: \(x\)
連鎖律の式は以下のようになります。
$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial f} \times \frac{\partial f}{\partial x}$
ステップ1: ローカル勾配 (\(\frac{\partial f}{\partial x}\)) の計算
関数 \(f(x, y) = x \times y\) を \(x\) で偏微分します(\(y\) は定数として扱います)。
$\frac{\partial f}{\partial x} = \frac{\partial}{\partial x}(x \times y) = 1 \times y = y$
(同様に、\(\frac{\partial f}{\partial y} = x\) となります。)
ステップ2: 連鎖律の適用
求めたローカル勾配 \(\frac{\partial f}{\partial x} = y\) と、与えられた上流からの勾配 \(\frac{\partial L}{\partial f} = 0.5\) を用いて、\(\frac{\partial L}{\partial x}\) を計算します。計算には現在の入力値 \(y = -2\) を使用します。
$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial f} \times \frac{\partial f}{\partial x} = 0.5 \times y = 0.5 \times (-2) = -1.0$
したがって、入力 \(x\) に関する損失 \(L\) の勾配は -1.0 です。
乗算ノードのバックプロパゲーション
乗算ノード \(f = x \times y\) のバックプロパゲーションでは、上流からの勾配 \(\frac{\partial L}{\partial f}\) が与えられると、下流の各入力への勾配は以下のように計算されます。
- \(\frac{\partial L}{\partial x} = \frac{\partial L}{\partial f} \times y\) (上流勾配に、もう一方の入力値を掛ける)
- \(\frac{\partial L}{\partial y} = \frac{\partial L}{\partial f} \times x\) (上流勾配に、もう一方の入力値を掛ける)
これは「勾配をスイッチして渡す」ような振る舞いと覚えられます。