概要
Pytorch の自動微分について解説します。
自動微分
次の合成関数を考えます。
$$ f(x_1, x_2) = \log(x_1 x_2) \times \sin(x_1 x_2) $$合成関数の出力を計算する場合、入力から出力に向けて順番に演算を行っていきます。これを順伝搬 (Forward Propagation)といいます。
$$ \begin{aligned} y_1 &= x_1 \times x_2 \\ y_2 &= \log(y_1) \\ y_3 &= \sin(y_1) \\ z &= y_2 \times y_3 \\ \end{aligned} $$例: $x_1 = 1, x_2 = 2$ のとき、
$$ \begin{aligned} y_1 &= 1 \times 2 \\ y_2 &= \log(2) \\ y_3 &= \sin(2) \\ z &= \log(2) \times \sin(2) \\ \end{aligned} $$この計算過程を演算をノード、値 (入力や計算途中の出力) をエッジで表現したグラフで表します。これを計算グラフ (Computation Graph) といいます。
この合成関数の微分は、連鎖律より合成関数を構成する関数の微分の組み合わせで計算できます。
フォーワードモードの自動微分
$\frac{\partial z}{\partial x_1}$ の偏微分係数を求めることを考えます。フォーワードモードの自動微分 (Forward Mode Automatic Differentiation)では、以下のように入力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 $y_1, y_2, y_3$ を使用するので、記録しておく必要があります。
$$ \begin{aligned} \frac{\partial y_1}{\partial x_1} &= \frac{\partial}{\partial x_1} (x_1 x_2) = x_2 \\ \frac{\partial y_2}{\partial x_1} &= \frac{\partial y_2}{\partial y_1}\frac{\partial y_1}{\partial x_1} = \frac{\partial}{\partial y_1}\log y_1 \cdot x_2 = \frac{x_2}{y_1} \\ \frac{\partial y_3}{\partial x_1} &= \frac{\partial y_3}{\partial y_1}\frac{\partial y_1}{\partial x_1} = \frac{\partial}{\partial y_1} \sin y_1 \cdot x_2 = \cos y_1 \cdot x_2 \\ \frac{\partial z}{\partial x_1} &= \frac{\partial z}{\partial y_2} \frac{\partial y_2}{\partial x_1} + \frac{\partial z}{\partial y_3} \frac{\partial y_3}{\partial x_1} \\ &= \frac{\partial}{\partial y_2}(y_2 y_3) \cdot \frac{\partial y_2}{\partial x_1} + \frac{\partial}{\partial y_3}(y_2 y_3) \cdot \frac{\partial y_3}{\partial x_1} \\ &= y_3 \cdot \frac{x_2}{y_1} + y_2 \cdot \cos y_1 \cdot x_2 \\ \end{aligned} $$
$\frac{\partial z}{\partial x_1}$ を計算しましたが、$\frac{\partial z}{\partial x_2}$ を求めたい場合、同様の計算をもう一度行う必要があります。
バックワードモードの自動微分
バックワードモードの自動微分 (Backward Mode Automatic Differentiation)では、以下のように出力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 $y_1, y_2, y_3$ を使用するので、記録しておく必要があります。
$$ \begin{aligned} \frac{\partial z}{\partial y_3} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_2 \\ \frac{\partial z}{\partial y_2} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_3 \\ \frac{\partial z}{\partial y_1} &= \frac{\partial z}{\partial y_2}\frac{\partial y_2}{\partial y_1} + \frac{\partial z}{\partial y_3}\frac{\partial y_3}{\partial y_1} \\ &= \frac{\partial z}{\partial y_2} \cdot \frac{\partial}{\partial y_1} \log y_1 + \frac{\partial z}{\partial y_3} \cdot \frac{\partial}{\partial y_1} \sin y_1 \\ &= \frac{\partial z}{\partial y_2} \cdot \frac{1}{y_1} + \frac{\partial z}{\partial y_3} \cdot \cos y_1 \\ &= y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \\ \frac{\partial z}{\partial x_1} &= \frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_1} \\ &= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_2 \\ \frac{\partial z}{\partial x_2} &= \frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_2} \\ &= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_1 \\ \end{aligned} $$
フォーワードモードと違い、1度のバックワードで $\frac{\partial z}{\partial x_1}$ 以外の出力に対するすべての係数の偏微分係数が求まります。ディープラーニングではモデルのすべての係数の偏微分係数を求める必要があるため、バックワードモードの自動微分が使われます。
Pytorch の自動微分
Tensor.backward()、Tensor.grad 属性
Pytorch で上記の計算グラフを $x_1 = 1, x_2 = 2$ として作成します。計算グラフを作成したあと、微分対象のテンソル (例: z) の Tensor.backward()
を呼び出します。すると、逆伝搬が行われ、Tensor.grad
属性に計算した微分係数が記録されます。
Tensor.requires_grad
逆伝搬を行うには、途中の出力値を記録しておく必要があり、その分のメモリが必要になります。Pytorch では、演算の入力のテンソルの Tensor.requires_grad
属性が True の場合のみ、演算の出力のテンソルの値が記録されるようになっています。そのため、テンソル x1, x2
を作成するときに requires_grad=True
引数を指定し、このテンソルの微分係数を計算する必要があることを設定しています。これを設定しない場合、微分係数が計算できないため、以下のエラーが発生します。
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
import torch
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3
# 逆伝搬を行う。
z.backward()
print(f"dz/dx1={x1.grad}")
print(f"dz/dx2={x2.grad}")
dz/dx1=0.33239537477493286 dz/dx2=0.16619768738746643
順伝搬、逆伝搬を先程の式に当てはめて計算すると、Pytorch の計算結果と一致することがわかります。
$$ \begin{aligned} y_1 &= 2 \\ y_2 &= \log(2) \\ y_3 &= \sin(2) \\ z &= \log(2)\sin(2) \\ \end{aligned} $$$$ \begin{aligned} \frac{\partial z}{\partial x_1} &= \left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) 2 = 0.3323954139224974\\ \frac{\partial z}{\partial x_2} &= \left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) = 0.1661977069612487 \\ \end{aligned} $$作成したテンソルをあとから Tensor.requires_grad == True
に設定したい場合、次の2通りの方法があります。
Tensor.requires_grad = True
とする。Tensor.requires_grad_()
を呼び出す。
torch.no_grad()
推論時など、あとで逆伝搬を行う予定がない場合、順伝搬時に途中の出力を記録しておくことはメモリの無駄遣いです。途中の出力を記録しないようにするには、すべてのテンソルに Tensor.requires_grad = False
を設定するか、コンテキストマネージャー torch.no_grad()
を使用します。このコンテキスト内の計算は、Tensor.requires_grad == True
であっても、途中の出力が記録されません。
import torch
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
with torch.no_grad():
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3
print(z)
tensor(0.6303)
Tensor.retain_grad()
デフォルトでは、逆伝搬時に微分係数が記録されるのは出力側から見て末端の入力の Tensor だけなので、注意してください。それ以外のテンソルの grad
属性は値が None になっており、アクセスしようとすると、警告が発生します。
print(f"dz/y1={y1.grad}")
print(f"dz/y2={y2.grad}")
print(f"dz/y3={y3.grad}")
# dz/y1=None
# dz/y2=None
# dz/y3=None
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
return self._grad
逆伝搬時に末端以外のテンソルの微分係数の記録するようにしたい場合は Tensor.backward()
前に Tensor.retain_grad()
を呼び出します。末端のテンソルかどうかは Tensor.is_leaf
属性で確認できます。現在の状態は Tensor.retains_grad
属性で確認できます。
import torch
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3
y1.retain_grad()
y2.retain_grad()
y3.retain_grad()
# 逆伝搬を行う。
z.backward()
print(f"dz/dx1={x1.grad}, is_leaf={x1.is_leaf}")
print(f"dz/dx2={x2.grad}, is_leaf={x2.is_leaf}")
print(f"dz/y1={y1.grad}, is_leaf={y1.is_leaf}")
print(f"dz/y2={y2.grad}, is_leaf={y2.is_leaf}")
print(f"dz/y3={y3.grad}, is_leaf={y3.is_leaf}")
dz/dx1=0.33239537477493286, is_leaf=True dz/dx2=0.16619768738746643, is_leaf=True dz/y1=0.16619768738746643, is_leaf=False dz/y2=0.9092974066734314, is_leaf=False dz/y3=0.6931471824645996, is_leaf=False
Torch.detach()
Torch.detach()
は呼び出したテンソルを計算グラフから切り離したテンソルを作成します。shallow copy なので、切り離したテンソルと元のテンソルは値自体は共有しています。
import torch
from torchviz import make_dot
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y1_detach = y1.detach()
y2 = torch.log(y1_detach)
y3 = torch.sin(y1)
z = y2 * y3
# 逆伝搬を行う。
z.backward()
print(f"dz/dx1={x1.grad}")
print(f"dz/dx2={x2.grad}")
dz/dx1=-0.5769020318984985 dz/dx2=-0.28845101594924927
Tensor の numpy への変換
Tensor.numpy()
で Tensor を numpy 配列に変換できます。ただし、値は共有しているため、ディープコピーとして作成したい場合、Tensor.clone()
で予めコピーしておきます。- numpy 配列に変換できる Tensor は計算デバイスが CPU のものだけです。計算デバイスが GPU のテンソルは予め
Tensor.cpu()
で計算デバイスを CPU に変更します。 Tensor.requires_grad==True
の Tensor は numpy 配列に変換できないため、予めTensor.detach()
でTensor.requires_grad=False
にします。RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
以上をまとめると、あるテンソルを numpy 配列に変換したい場合、以下のようにします。
y = x.detach().cpu().clone().numpy()
import numpy as np
import torch
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3
arr = z.detach().cpu().clone().numpy()
計算グラフを描画する
計算グラフを可視化するための torchviz というライブラリを紹介します。このソフトは Pytorch で作成した計算グラフを Graphviz を使って可視化できます。
torchviz
のほか、Graphviz を別途インストールする必要があります。
pip install torchviz
[blogcard url=”https://pystyle.info/how-to-install-graphviz-on-windows-and-ubuntu”]
import torch
from torchviz import make_dot
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3
dot = make_dot(z)
dot
import torch
from torchviz import make_dot
# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y1_detach = y1.detach()
y2 = torch.log(y1_detach)
y3 = torch.sin(y1)
z = y2 * y3
dot = make_dot(z)
dot
コメント