概要
Pytorch の自動微分について解説します。
自動微分
次の合成関数を考えます。
f ( x 1 , x 2 ) = log ( x 1 x 2 ) × sin ( x 1 x 2 )
f(x_1, x_2) = \log(x_1 x_2) \times \sin(x_1 x_2)
f ( x 1 , x 2 ) = log ( x 1 x 2 ) × sin ( x 1 x 2 ) 合成関数の出力を計算する場合、入力から出力に向けて順番に演算を行っていきます。これを順伝搬 (Forward Propagation) といいます。
y 1 = x 1 × x 2 y 2 = log ( y 1 ) y 3 = sin ( y 1 ) z = y 2 × y 3
\begin{aligned}
y_1 &= x_1 \times x_2 \\
y_2 &= \log(y_1) \\
y_3 &= \sin(y_1) \\
z &= y_2 \times y_3 \\
\end{aligned}
y 1 y 2 y 3 z = x 1 × x 2 = log ( y 1 ) = sin ( y 1 ) = y 2 × y 3 例: x 1 = 1 , x 2 = 2 x_1 = 1, x_2 = 2 x 1 = 1 , x 2 = 2 のとき、
y 1 = 1 × 2 y 2 = log ( 2 ) y 3 = sin ( 2 ) z = log ( 2 ) × sin ( 2 )
\begin{aligned}
y_1 &= 1 \times 2 \\
y_2 &= \log(2) \\
y_3 &= \sin(2) \\
z &= \log(2) \times \sin(2) \\
\end{aligned}
y 1 y 2 y 3 z = 1 × 2 = log ( 2 ) = sin ( 2 ) = log ( 2 ) × sin ( 2 ) この計算過程を演算をノード、値 (入力や計算途中の出力) をエッジで表現したグラフで表します。これを計算グラフ (Computation Graph) といいます。
計算グラフ
この合成関数の微分は、連鎖律より合成関数を構成する関数の微分の組み合わせで計算できます。
フォーワードモードの自動微分
∂ z ∂ x 1 \frac{\partial z}{\partial x_1} ∂ x 1 ∂ z の偏微分係数を求めることを考えます。フォーワードモードの自動微分 (Forward Mode Automatic Differentiation) では、以下のように入力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 y 1 , y 2 , y 3 y_1, y_2, y_3 y 1 , y 2 , y 3 を使用するので、記録しておく必要があります。
∂ y 1 ∂ x 1 = ∂ ∂ x 1 ( x 1 x 2 ) = x 2 ∂ y 2 ∂ x 1 = ∂ y 2 ∂ y 1 ∂ y 1 ∂ x 1 = ∂ ∂ y 1 log y 1 ⋅ x 2 = x 2 y 1 ∂ y 3 ∂ x 1 = ∂ y 3 ∂ y 1 ∂ y 1 ∂ x 1 = ∂ ∂ y 1 sin y 1 ⋅ x 2 = cos y 1 ⋅ x 2 ∂ z ∂ x 1 = ∂ z ∂ y 2 ∂ y 2 ∂ x 1 + ∂ z ∂ y 3 ∂ y 3 ∂ x 1 = ∂ ∂ y 2 ( y 2 y 3 ) ⋅ ∂ y 2 ∂ x 1 + ∂ ∂ y 3 ( y 2 y 3 ) ⋅ ∂ y 3 ∂ x 1 = y 3 ⋅ x 2 y 1 + y 2 ⋅ cos y 1 ⋅ x 2
\begin{aligned}
\frac{\partial y_1}{\partial x_1}
&= \frac{\partial}{\partial x_1} (x_1 x_2) = x_2 \\
\frac{\partial y_2}{\partial x_1}
&= \frac{\partial y_2}{\partial y_1}\frac{\partial y_1}{\partial x_1}
= \frac{\partial}{\partial y_1}\log y_1 \cdot x_2
= \frac{x_2}{y_1} \\
\frac{\partial y_3}{\partial x_1}
&= \frac{\partial y_3}{\partial y_1}\frac{\partial y_1}{\partial x_1}
= \frac{\partial}{\partial y_1} \sin y_1 \cdot x_2 = \cos y_1 \cdot x_2 \\
\frac{\partial z}{\partial x_1} &=
\frac{\partial z}{\partial y_2} \frac{\partial y_2}{\partial x_1}
+ \frac{\partial z}{\partial y_3} \frac{\partial y_3}{\partial x_1} \\
&= \frac{\partial}{\partial y_2}(y_2 y_3) \cdot \frac{\partial y_2}{\partial x_1}
+ \frac{\partial}{\partial y_3}(y_2 y_3) \cdot \frac{\partial y_3}{\partial x_1} \\
&= y_3 \cdot \frac{x_2}{y_1}
+ y_2 \cdot \cos y_1 \cdot x_2 \\
\end{aligned}
∂ x 1 ∂ y 1 ∂ x 1 ∂ y 2 ∂ x 1 ∂ y 3 ∂ x 1 ∂ z = ∂ x 1 ∂ ( x 1 x 2 ) = x 2 = ∂ y 1 ∂ y 2 ∂ x 1 ∂ y 1 = ∂ y 1 ∂ log y 1 ⋅ x 2 = y 1 x 2 = ∂ y 1 ∂ y 3 ∂ x 1 ∂ y 1 = ∂ y 1 ∂ sin y 1 ⋅ x 2 = cos y 1 ⋅ x 2 = ∂ y 2 ∂ z ∂ x 1 ∂ y 2 + ∂ y 3 ∂ z ∂ x 1 ∂ y 3 = ∂ y 2 ∂ ( y 2 y 3 ) ⋅ ∂ x 1 ∂ y 2 + ∂ y 3 ∂ ( y 2 y 3 ) ⋅ ∂ x 1 ∂ y 3 = y 3 ⋅ y 1 x 2 + y 2 ⋅ cos y 1 ⋅ x 2
∂ z ∂ x 1 \frac{\partial z}{\partial x_1} ∂ x 1 ∂ z を計算しましたが、∂ z ∂ x 2 \frac{\partial z}{\partial x_2} ∂ x 2 ∂ z を求めたい場合、同様の計算をもう一度行う必要があります。
バックワードモードの自動微分
バックワードモードの自動微分 (Backward Mode Automatic Differentiation) では、以下のように出力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 y 1 , y 2 , y 3 y_1, y_2, y_3 y 1 , y 2 , y 3 を使用するので、記録しておく必要があります。
∂ z ∂ y 3 = ∂ ∂ y 3 ( y 2 y 3 ) = y 2 ∂ z ∂ y 2 = ∂ ∂ y 3 ( y 2 y 3 ) = y 3 ∂ z ∂ y 1 = ∂ z ∂ y 2 ∂ y 2 ∂ y 1 + ∂ z ∂ y 3 ∂ y 3 ∂ y 1 = ∂ z ∂ y 2 ⋅ ∂ ∂ y 1 log y 1 + ∂ z ∂ y 3 ⋅ ∂ ∂ y 1 sin y 1 = ∂ z ∂ y 2 ⋅ 1 y 1 + ∂ z ∂ y 3 ⋅ cos y 1 = y 3 ⋅ 1 y 1 + y 2 ⋅ cos y 1 ∂ z ∂ x 1 = ∂ z ∂ y 1 ∂ y 1 ∂ x 1 = ( y 3 ⋅ 1 y 1 + y 2 ⋅ cos y 1 ) x 2 ∂ z ∂ x 2 = ∂ z ∂ y 1 ∂ y 1 ∂ x 2 = ( y 3 ⋅ 1 y 1 + y 2 ⋅ cos y 1 ) x 1
\begin{aligned}
\frac{\partial z}{\partial y_3} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_2 \\
\frac{\partial z}{\partial y_2} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_3 \\
\frac{\partial z}{\partial y_1} &=
\frac{\partial z}{\partial y_2}\frac{\partial y_2}{\partial y_1}
+ \frac{\partial z}{\partial y_3}\frac{\partial y_3}{\partial y_1} \\
&= \frac{\partial z}{\partial y_2} \cdot \frac{\partial}{\partial y_1} \log y_1
+ \frac{\partial z}{\partial y_3} \cdot \frac{\partial}{\partial y_1} \sin y_1 \\
&= \frac{\partial z}{\partial y_2} \cdot \frac{1}{y_1}
+ \frac{\partial z}{\partial y_3} \cdot \cos y_1 \\
&= y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \\
\frac{\partial z}{\partial x_1} &=
\frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_1} \\
&= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_2 \\
\frac{\partial z}{\partial x_2} &=
\frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_2} \\
&= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_1 \\
\end{aligned}
∂ y 3 ∂ z ∂ y 2 ∂ z ∂ y 1 ∂ z ∂ x 1 ∂ z ∂ x 2 ∂ z = ∂ y 3 ∂ ( y 2 y 3 ) = y 2 = ∂ y 3 ∂ ( y 2 y 3 ) = y 3 = ∂ y 2 ∂ z ∂ y 1 ∂ y 2 + ∂ y 3 ∂ z ∂ y 1 ∂ y 3 = ∂ y 2 ∂ z ⋅ ∂ y 1 ∂ log y 1 + ∂ y 3 ∂ z ⋅ ∂ y 1 ∂ sin y 1 = ∂ y 2 ∂ z ⋅ y 1 1 + ∂ y 3 ∂ z ⋅ cos y 1 = y 3 ⋅ y 1 1 + y 2 ⋅ cos y 1 = ∂ y 1 ∂ z ∂ x 1 ∂ y 1 = ( y 3 ⋅ y 1 1 + y 2 ⋅ cos y 1 ) x 2 = ∂ y 1 ∂ z ∂ x 2 ∂ y 1 = ( y 3 ⋅ y 1 1 + y 2 ⋅ cos y 1 ) x 1
フォーワードモードと違い、1度のバックワードで ∂ z ∂ x 1 \frac{\partial z}{\partial x_1} ∂ x 1 ∂ z 以外の出力に対するすべての係数の偏微分係数が求まります。ディープラーニングではモデルのすべての係数の偏微分係数を求める必要があるため、バックワードモードの自動微分が使われます。
Pytorch の自動微分
Tensor.backward()、Tensor.grad 属性
Pytorch で上記の計算グラフを x 1 = 1 , x 2 = 2 x_1 = 1, x_2 = 2 x 1 = 1 , x 2 = 2 として作成します。計算グラフを作成したあと、微分対象のテンソル (例: z) の Tensor.backward()
を呼び出します。すると、逆伝搬が行われ、Tensor.grad
属性に計算した微分係数が記録されます。
Tensor.requires_grad
逆伝搬を行うには、途中の出力値を記録しておく必要があり、その分のメモリが必要になります。Pytorch では、演算の入力のテンソルの Tensor.requires_grad
属性が True の場合のみ、演算の出力のテンソルの値が記録されるようになっています。そのため、テンソル x1, x2
を作成するときに requires_grad=True
引数を指定し、このテンソルの微分係数を計算する必要があることを設定しています。これを設定しない場合、微分係数が計算できないため、以下のエラーが発生します。
dz/dx1=0.33239537477493286
dz/dx2=0.16619768738746643
順伝搬、逆伝搬を先程の式に当てはめて計算すると、Pytorch の計算結果と一致することがわかります。
y 1 = 2 y 2 = log ( 2 ) y 3 = sin ( 2 ) z = log ( 2 ) sin ( 2 )
\begin{aligned}
y_1 &= 2 \\
y_2 &= \log(2) \\
y_3 &= \sin(2) \\
z &= \log(2)\sin(2) \\
\end{aligned}
y 1 y 2 y 3 z = 2 = log ( 2 ) = sin ( 2 ) = log ( 2 ) sin ( 2 ) ∂ z ∂ x 1 = ( sin 2 ⋅ 1 2 + log 2 ⋅ cos 2 ) 2 = 0.3323954139224974 ∂ z ∂ x 2 = ( sin 2 ⋅ 1 2 + log 2 ⋅ cos 2 ) = 0.1661977069612487
\begin{aligned}
\frac{\partial z}{\partial x_1} &=
\left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) 2 = 0.3323954139224974\\
\frac{\partial z}{\partial x_2} &=
\left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) = 0.1661977069612487 \\
\end{aligned}
∂ x 1 ∂ z ∂ x 2 ∂ z = ( sin 2 ⋅ 2 1 + log 2 ⋅ cos 2 ) 2 = 0.3323954139224974 = ( sin 2 ⋅ 2 1 + log 2 ⋅ cos 2 ) = 0.1661977069612487
作成したテンソルをあとから Tensor.requires_grad == True
に設定したい場合、次の2通りの方法があります。
Tensor.requires_grad = True
とする。
Tensor.requires_grad_()
を呼び出す。
torch.no_grad()
推論時など、あとで逆伝搬を行う予定がない場合、順伝搬時に途中の出力を記録しておくことはメモリの無駄遣いです。途中の出力を記録しないようにするには、すべてのテンソルに Tensor.requires_grad = False
を設定するか、コンテキストマネージャー torch.no_grad()
を使用します。このコンテキスト内の計算は、Tensor.requires_grad == True
であっても、途中の出力が記録されません。
Tensor.retain_grad()
デフォルトでは、逆伝搬時に微分係数が記録されるのは出力側から見て末端の入力の Tensor だけなので、注意してください。それ以外のテンソルの grad
属性は値が None になっており、アクセスしようとすると、警告が発生します。
逆伝搬時に末端以外のテンソルの微分係数の記録するようにしたい場合は Tensor.backward()
前に Tensor.retain_grad()
を呼び出します。末端のテンソルかどうかは Tensor.is_leaf
属性で確認できます。現在の状態は Tensor.retains_grad
属性で確認できます。
dz/dx1=0.33239537477493286, is_leaf=True
dz/dx2=0.16619768738746643, is_leaf=True
dz/y1=0.16619768738746643, is_leaf=False
dz/y2=0.9092974066734314, is_leaf=False
dz/y3=0.6931471824645996, is_leaf=False
Torch.detach()
Torch.detach()
は呼び出したテンソルを計算グラフから切り離したテンソルを作成します。shallow copy なので、切り離したテンソルと元のテンソルは値自体は共有しています。
Torch.detach()
dz/dx1=-0.5769020318984985
dz/dx2=-0.28845101594924927
Tensor の numpy への変換
以上をまとめると、あるテンソルを numpy 配列に変換したい場合、以下のようにします。
計算グラフを描画する
計算グラフを可視化するための torchviz というライブラリを紹介します。このソフトは Pytorch で作成した計算グラフを Graphviz を使って可視化できます。
torchviz
のほか、Graphviz を別途インストールする必要があります。
[blogcard url=”https://pystyle.info/how-to-install-graphviz-on-windows-and-ubuntu”]
%3
140080637344960
()
140080638265040
MulBackward0
140080638265040->140080637344960
140080638264128
LogBackward0
140080638264128->140080638265040
140080638264224
MulBackward0
140080638264224->140080638264128
140080638264272
SinBackward0
140080638264224->140080638264272
140080638264416
AccumulateGrad
140080638264416->140080638264224
140080637343920
()
140080637343920->140080638264416
140080638264368
AccumulateGrad
140080638264368->140080638264224
140080637902736
()
140080637902736->140080638264368
140080638264272->140080638265040
%3
140080637350112
()
140080638065680
MulBackward0
140080638065680->140080637350112
140080638064864
SinBackward0
140080638064864->140080638065680
140080638264608
MulBackward0
140080638264608->140080638064864
140080638265328
AccumulateGrad
140080638265328->140080638264608
140080637345680
()
140080637345680->140080638265328
140080638264464
AccumulateGrad
140080638264464->140080638264608
140080637298544
()
140080637298544->140080638264464
参考
コメント