Warning: Undefined variable $position in /home/pystyles/pystyle.info/public_html/wp/wp-content/themes/lionblog/functions.php on line 4897

Pytorch – モデルをファイルに保存する方法

Pytorch – モデルをファイルに保存する方法

概要

Pytorch でモデルをファイルに保存する方法について紹介します。

Advertisement

torch.save、torch.load、load_state_dict

モデルの保存及び読み込みに関して、次の3つの関数があります。

  • torch.save(): オブジェクトを直列化してディスクに保存します。直列化には Python の pickle 機能を利用します。
  • torch.load(): ディスクから直列化されたオブジェクトを読み込みます。
  • load_state_dict(): 後述する state_dict を読み込みます。

state_dict

state_dict は、モデルのパラメータやオプティマイザの状態を表す dict です。

  • torch.nn.Module: Module の state_dict は、モデルの学習可能なパラメータ(重みやバイアス)が格納されています。Module.state_dict() で取得できます。
  • torch.optim.Optimizer: Optimizerstate_dict は、オプティマイザの状態や使用されているハイパーパラメータに関する情報が格納されています。Optimizer.state_dict() で取得できます。

state_dict は、dict なので、保存、更新、変更、復元が簡単にできます。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


# モデルを作成する。
model = Model()

# オプティマイザを作成する。
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# モデルの state_dict を表示する。
print("Model's state_dict:")
for key, param in model.state_dict().items():
    print(key, "\t", param.size())

# オプティマイザの state_dict を表示する。
print("Optimizer's state_dict:")
for key, param in optimizer.state_dict().items():
    print(key, "\t", param)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139665281919328, 139665281919488, 139665281919568, 139665281919648, 139665281919728, 139665281919808, 139665281922384, 139665281919408, 139665281919888, 139665281919968]}]

torch.save、torch.load でモデル全体を保存する (非推奨)

torch.save() でモデル全体を直列化して、ファイルに保存できます。保存したファイルは、torch.load() で読み込めます。

In [2]:
# モデルを保存する。
torch.save(model, "model.pth")

# 保存したモデルを読み込む。
model = torch.load("model.pth")

この方法は保存、読み込みを簡単に行えますが、ディレクトリ構造や使用した GPU など保存時の環境に依存した情報を含み、他の環境では保存したファイルが読み込めない可能性があります。そのため、次に紹介する state_dict を使用した保存方法を推奨します。

state_dict でモデルのパラメータを保存する (推奨)

モデルの学習可能なパラメータを表す state_dict を保存する方法です。 model.state_dict() で state_dict を取得して、これを torch.save() を使って保存します。 読み込む際は torch.load() で読み込んだ state_dict を Module.load_state_dict() でモデルに復元します。保存するファイル名は自由ですが、慣例として、拡張子は .pt または .pth にします。

In [3]:
# モデルを保存する。
torch.save(model.state_dict(), "model.pth")

# 保存したモデルを読み込む。
model.load_state_dict(torch.load("model.pth"))
<All keys matched successfully>
Advertisement

学習途中の状態を保存する

オプティマイザも状態や使用されているハイパーパラメータに関する情報が格納された state_dict を持ちます。torch.save() は任意のオブジェクトを保存できるので、学習途中の状態を dict などで格納して保存します。

In [4]:
# 学習途中の状態
epoch = 10

# 学習途中の状態を保存する。
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "model.tar",
)

# 学習途中の状態を読み込む。
checkpoint = torch.load("model.tar")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]

モデルを部分的に読み込む

転移学習などで出力層など一部の層を変更して、ファイルに格納されたパラメータの一部を読み込みたい場合があります。そのときは、load_state_dict() に strict=False を指定し、state_dict のマッチしたキーのみを読み込むようにします。

# 保存したモデルを読み込む。
model.load_state_dict(torch.load("model.pth"), strict=False)