概要
Pytorch でモデルをファイルに保存する方法について紹介します。
torch.save、torch.load、load_state_dict
モデルの保存及び読み込みに関して、次の3つの関数があります。
- torch.save(): オブジェクトを直列化してディスクに保存します。直列化には Python の pickle 機能を利用します。
- torch.load(): ディスクから直列化されたオブジェクトを読み込みます。
- load_state_dict(): 後述する state_dict を読み込みます。
state_dict
state_dict は、モデルのパラメータやオプティマイザの状態を表す dict です。
torch.nn.Module
:Module
の state_dict は、モデルの学習可能なパラメータ(重みやバイアス)が格納されています。Module.state_dict()
で取得できます。- torch.optim.Optimizer:
Optimizer
のstate_dict
は、オプティマイザの状態や使用されているハイパーパラメータに関する情報が格納されています。Optimizer.state_dict()
で取得できます。
state_dict は、dict なので、保存、更新、変更、復元が簡単にできます。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# モデルを作成する。
model = Model()
# オプティマイザを作成する。
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# モデルの state_dict を表示する。
print("Model's state_dict:")
for key, param in model.state_dict().items():
print(key, "\t", param.size())
# オプティマイザの state_dict を表示する。
print("Optimizer's state_dict:")
for key, param in optimizer.state_dict().items():
print(key, "\t", param)
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139665281919328, 139665281919488, 139665281919568, 139665281919648, 139665281919728, 139665281919808, 139665281922384, 139665281919408, 139665281919888, 139665281919968]}]
torch.save、torch.load でモデル全体を保存する (非推奨)
torch.save() でモデル全体を直列化して、ファイルに保存できます。保存したファイルは、torch.load() で読み込めます。
# モデルを保存する。
torch.save(model, "model.pth")
# 保存したモデルを読み込む。
model = torch.load("model.pth")
この方法は保存、読み込みを簡単に行えますが、ディレクトリ構造や使用した GPU など保存時の環境に依存した情報を含み、他の環境では保存したファイルが読み込めない可能性があります。そのため、次に紹介する state_dict を使用した保存方法を推奨します。
state_dict でモデルのパラメータを保存する (推奨)
モデルの学習可能なパラメータを表す state_dict を保存する方法です。
model.state_dict()
で state_dict を取得して、これを torch.save()
を使って保存します。
読み込む際は torch.load()
で読み込んだ state_dict を Module.load_state_dict()
でモデルに復元します。保存するファイル名は自由ですが、慣例として、拡張子は .pt
または .pth
にします。
# モデルを保存する。
torch.save(model.state_dict(), "model.pth")
# 保存したモデルを読み込む。
model.load_state_dict(torch.load("model.pth"))
<All keys matched successfully>
学習途中の状態を保存する
オプティマイザも状態や使用されているハイパーパラメータに関する情報が格納された state_dict を持ちます。torch.save()
は任意のオブジェクトを保存できるので、学習途中の状態を dict などで格納して保存します。
# 学習途中の状態
epoch = 10
# 学習途中の状態を保存する。
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
"model.tar",
)
# 学習途中の状態を読み込む。
checkpoint = torch.load("model.tar")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
モデルを部分的に読み込む
転移学習などで出力層など一部の層を変更して、ファイルに格納されたパラメータの一部を読み込みたい場合があります。そのときは、load_state_dict() に strict=False
を指定し、state_dict のマッチしたキーのみを読み込むようにします。
# 保存したモデルを読み込む。
model.load_state_dict(torch.load("model.pth"), strict=False)
コメント