Pytorch – 自動微分の仕組みと backward() の使い方を解説

目次

概要

Pytorch の自動微分について解説します。

自動微分

次の合成関数を考えます。

$$ f(x_1, x_2) = \log(x_1 x_2) \times \sin(x_1 x_2) $$

合成関数の出力を計算する場合、入力から出力に向けて順番に演算を行っていきます。これを順伝搬 (Forward Propagation)といいます。

$$ \begin{aligned} y_1 &= x_1 \times x_2 \\ y_2 &= \log(y_1) \\ y_3 &= \sin(y_1) \\ z &= y_2 \times y_3 \\ \end{aligned} $$

例: $x_1 = 1, x_2 = 2$ のとき、

$$ \begin{aligned} y_1 &= 1 \times 2 \\ y_2 &= \log(2) \\ y_3 &= \sin(2) \\ z &= \log(2) \times \sin(2) \\ \end{aligned} $$

この計算過程を演算をノード、値 (入力や計算途中の出力) をエッジで表現したグラフで表します。これを計算グラフ (Computation Graph) といいます。

計算グラフ

この合成関数の微分は、連鎖律より合成関数を構成する関数の微分の組み合わせで計算できます。

フォーワードモードの自動微分

$\frac{\partial z}{\partial x_1}$ の偏微分係数を求めることを考えます。フォーワードモードの自動微分 (Forward Mode Automatic Differentiation)では、以下のように入力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 $y_1, y_2, y_3$ を使用するので、記録しておく必要があります。

$$ \begin{aligned} \frac{\partial y_1}{\partial x_1} &= \frac{\partial}{\partial x_1} (x_1 x_2) = x_2 \\ \frac{\partial y_2}{\partial x_1} &= \frac{\partial y_2}{\partial y_1}\frac{\partial y_1}{\partial x_1} = \frac{\partial}{\partial y_1}\log y_1 \cdot x_2 = \frac{x_2}{y_1} \\ \frac{\partial y_3}{\partial x_1} &= \frac{\partial y_3}{\partial y_1}\frac{\partial y_1}{\partial x_1} = \frac{\partial}{\partial y_1} \sin y_1 \cdot x_2 = \cos y_1 \cdot x_2 \\ \frac{\partial z}{\partial x_1} &= \frac{\partial z}{\partial y_2} \frac{\partial y_2}{\partial x_1} + \frac{\partial z}{\partial y_3} \frac{\partial y_3}{\partial x_1} \\ &= \frac{\partial}{\partial y_2}(y_2 y_3) \cdot \frac{\partial y_2}{\partial x_1} + \frac{\partial}{\partial y_3}(y_2 y_3) \cdot \frac{\partial y_3}{\partial x_1} \\ &= y_3 \cdot \frac{x_2}{y_1} + y_2 \cdot \cos y_1 \cdot x_2 \\ \end{aligned} $$

$\frac{\partial z}{\partial x_1}$ を計算しましたが、$\frac{\partial z}{\partial x_2}$ を求めたい場合、同様の計算をもう一度行う必要があります。

バックワードモードの自動微分

バックワードモードの自動微分 (Backward Mode Automatic Differentiation)では、以下のように出力側から順番に計算していきます。式からわかるように順伝搬時の途中の出力 $y_1, y_2, y_3$ を使用するので、記録しておく必要があります。

$$ \begin{aligned} \frac{\partial z}{\partial y_3} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_2 \\ \frac{\partial z}{\partial y_2} &= \frac{\partial}{\partial y_3}(y_2 y_3) = y_3 \\ \frac{\partial z}{\partial y_1} &= \frac{\partial z}{\partial y_2}\frac{\partial y_2}{\partial y_1} + \frac{\partial z}{\partial y_3}\frac{\partial y_3}{\partial y_1} \\ &= \frac{\partial z}{\partial y_2} \cdot \frac{\partial}{\partial y_1} \log y_1 + \frac{\partial z}{\partial y_3} \cdot \frac{\partial}{\partial y_1} \sin y_1 \\ &= \frac{\partial z}{\partial y_2} \cdot \frac{1}{y_1} + \frac{\partial z}{\partial y_3} \cdot \cos y_1 \\ &= y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \\ \frac{\partial z}{\partial x_1} &= \frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_1} \\ &= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_2 \\ \frac{\partial z}{\partial x_2} &= \frac{\partial z}{\partial y_1}\frac{\partial y_1}{\partial x_2} \\ &= \left(y_3 \cdot \frac{1}{y_1} + y_2 \cdot \cos y_1 \right) x_1 \\ \end{aligned} $$

フォーワードモードと違い、1度のバックワードで $\frac{\partial z}{\partial x_1}$ 以外の出力に対するすべての係数の偏微分係数が求まります。ディープラーニングではモデルのすべての係数の偏微分係数を求める必要があるため、バックワードモードの自動微分が使われます。

Pytorch の自動微分

Tensor.backward()、Tensor.grad 属性

Pytorch で上記の計算グラフを $x_1 = 1, x_2 = 2$ として作成します。計算グラフを作成したあと、微分対象のテンソル (例: z) の Tensor.backward() を呼び出します。すると、逆伝搬が行われ、Tensor.grad 属性に計算した微分係数が記録されます。

Tensor.requires_grad

逆伝搬を行うには、途中の出力値を記録しておく必要があり、その分のメモリが必要になります。Pytorch では、演算の入力のテンソルの Tensor.requires_grad 属性が True の場合のみ、演算の出力のテンソルの値が記録されるようになっています。そのため、テンソル x1, x2 を作成するときに requires_grad=True 引数を指定し、このテンソルの微分係数を計算する必要があることを設定しています。これを設定しない場合、微分係数が計算できないため、以下のエラーが発生します。

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
In [1]:
import torch

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3

# 逆伝搬を行う。
z.backward()

print(f"dz/dx1={x1.grad}")
print(f"dz/dx2={x2.grad}")
dz/dx1=0.33239537477493286
dz/dx2=0.16619768738746643

順伝搬、逆伝搬を先程の式に当てはめて計算すると、Pytorch の計算結果と一致することがわかります。

$$ \begin{aligned} y_1 &= 2 \\ y_2 &= \log(2) \\ y_3 &= \sin(2) \\ z &= \log(2)\sin(2) \\ \end{aligned} $$$$ \begin{aligned} \frac{\partial z}{\partial x_1} &= \left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) 2 = 0.3323954139224974\\ \frac{\partial z}{\partial x_2} &= \left(\sin2 \cdot \frac{1}{2} + \log2 \cdot \cos 2 \right) = 0.1661977069612487 \\ \end{aligned} $$

作成したテンソルをあとから Tensor.requires_grad == True に設定したい場合、次の2通りの方法があります。

  1. Tensor.requires_grad = True とする。
  2. Tensor.requires_grad_() を呼び出す。

torch.no_grad()

推論時など、あとで逆伝搬を行う予定がない場合、順伝搬時に途中の出力を記録しておくことはメモリの無駄遣いです。途中の出力を記録しないようにするには、すべてのテンソルに Tensor.requires_grad = False を設定するか、コンテキストマネージャー torch.no_grad() を使用します。このコンテキスト内の計算は、Tensor.requires_grad == True であっても、途中の出力が記録されません。

In [2]:
import torch

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)

with torch.no_grad():
    y1 = x1 * x2
    y2 = torch.log(y1)
    y3 = torch.sin(y1)
    z = y2 * y3

print(z)
tensor(0.6303)

Tensor.retain_grad()

デフォルトでは、逆伝搬時に微分係数が記録されるのは出力側から見て末端の入力の Tensor だけなので、注意してください。それ以外のテンソルの grad 属性は値が None になっており、アクセスしようとすると、警告が発生します。

print(f"dz/y1={y1.grad}")
print(f"dz/y2={y2.grad}")
print(f"dz/y3={y3.grad}")
# dz/y1=None
# dz/y2=None
# dz/y3=None
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)
  return self._grad

逆伝搬時に末端以外のテンソルの微分係数の記録するようにしたい場合は Tensor.backward() 前に Tensor.retain_grad() を呼び出します。末端のテンソルかどうかは Tensor.is_leaf 属性で確認できます。現在の状態は Tensor.retains_grad 属性で確認できます。

In [3]:
import torch

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2
y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3

y1.retain_grad()
y2.retain_grad()
y3.retain_grad()

# 逆伝搬を行う。
z.backward()

print(f"dz/dx1={x1.grad}, is_leaf={x1.is_leaf}")
print(f"dz/dx2={x2.grad}, is_leaf={x2.is_leaf}")
print(f"dz/y1={y1.grad}, is_leaf={y1.is_leaf}")
print(f"dz/y2={y2.grad}, is_leaf={y2.is_leaf}")
print(f"dz/y3={y3.grad}, is_leaf={y3.is_leaf}")
dz/dx1=0.33239537477493286, is_leaf=True
dz/dx2=0.16619768738746643, is_leaf=True
dz/y1=0.16619768738746643, is_leaf=False
dz/y2=0.9092974066734314, is_leaf=False
dz/y3=0.6931471824645996, is_leaf=False

Torch.detach()

Torch.detach() は呼び出したテンソルを計算グラフから切り離したテンソルを作成します。shallow copy なので、切り離したテンソルと元のテンソルは値自体は共有しています。

Torch.detach()

In [4]:
import torch
from torchviz import make_dot

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2

y1_detach = y1.detach()
y2 = torch.log(y1_detach)
y3 = torch.sin(y1)
z = y2 * y3

# 逆伝搬を行う。
z.backward()

print(f"dz/dx1={x1.grad}")
print(f"dz/dx2={x2.grad}")
dz/dx1=-0.5769020318984985
dz/dx2=-0.28845101594924927

Tensor の numpy への変換

  • Tensor.numpy() で Tensor を numpy 配列に変換できます。ただし、値は共有しているため、ディープコピーとして作成したい場合、Tensor.clone() で予めコピーしておきます。
  • numpy 配列に変換できる Tensor は計算デバイスが CPU のものだけです。計算デバイスが GPU のテンソルは予め Tensor.cpu() で計算デバイスを CPU に変更します。
  • Tensor.requires_grad==True の Tensor は numpy 配列に変換できないため、予め Tensor.detach()Tensor.requires_grad=False にします。

    RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

以上をまとめると、あるテンソルを numpy 配列に変換したい場合、以下のようにします。

y = x.detach().cpu().clone().numpy()
In [5]:
import numpy as np
import torch

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2

y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3

arr = z.detach().cpu().clone().numpy()

計算グラフを描画する

計算グラフを可視化するための torchviz というライブラリを紹介します。このソフトは Pytorch で作成した計算グラフを Graphviz を使って可視化できます。

torchviz のほか、Graphviz を別途インストールする必要があります。

pip install torchviz

[blogcard url=”https://pystyle.info/how-to-install-graphviz-on-windows-and-ubuntu”]

In [6]:
import torch
from torchviz import make_dot

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2

y2 = torch.log(y1)
y3 = torch.sin(y1)
z = y2 * y3

dot = make_dot(z)
dot
%3 140080637344960 () 140080638265040 MulBackward0 140080638265040->140080637344960 140080638264128 LogBackward0 140080638264128->140080638265040 140080638264224 MulBackward0 140080638264224->140080638264128 140080638264272 SinBackward0 140080638264224->140080638264272 140080638264416 AccumulateGrad 140080638264416->140080638264224 140080637343920 () 140080637343920->140080638264416 140080638264368 AccumulateGrad 140080638264368->140080638264224 140080637902736 () 140080637902736->140080638264368 140080638264272->140080638265040
In [7]:
import torch
from torchviz import make_dot

# 計算グラフを作成する。
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y1 = x1 * x2

y1_detach = y1.detach()
y2 = torch.log(y1_detach)
y3 = torch.sin(y1)
z = y2 * y3

dot = make_dot(z)
dot
%3 140080637350112 () 140080638065680 MulBackward0 140080638065680->140080637350112 140080638064864 SinBackward0 140080638064864->140080638065680 140080638264608 MulBackward0 140080638264608->140080638064864 140080638265328 AccumulateGrad 140080638265328->140080638264608 140080637345680 () 140080637345680->140080638265328 140080638264464 AccumulateGrad 140080638264464->140080638264608 140080637298544 () 140080637298544->140080638264464

参考

コメント

コメントする

目次