目次
概要
Pytorch でコードを書く際によく使う Tips を公式チュートリアルなどを参考にまとめました。
使用頻度の高いモジュール
torch.nn.functional
は F
という名前で import するなど、いくつかの慣例があるようです。
名前 | import | 内容 |
---|---|---|
torch | import torch | Pytorch |
torch.nn | import torch.nn as nn | レイヤークラスなど |
torch.nn.functional | import torch.nn.functional as F | 活性化関数など |
torch.nn.init | import torch.nn.init as init | 初期化 |
torch.optim | import torch.optim as optim | 最適化 |
torch.utils.data | import torch.utils.data as data | Dataset、DadaLoader |
torchvision | import torchvision | torchvision |
torchvision.datasets | import torchvision.datasets as datasets | Dataset |
torchvision.models | import torchvision.models as models | 事前学習済みモデル |
torchvision.transforms | import torchvision.transforms as transforms | Transform |
In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
デバイスを選択する
デバイスを選択するためのヘルパー関数です。
In [2]:
def get_device(gpu_id=-1):
"""gpu_id を指定した場合、GPU のデバイスを作成します。
"""
if gpu_id >= 0 and torch.cuda.is_available():
return torch.device("cuda", gpu_id)
else:
return torch.device("cpu")
device = get_device()
print(device) # cpu
device = get_device(gpu_id=0)
print(device) # cuda:0
cpu cuda:0
事前学習済みモデルを使う場合の Transform
torchvision で提供されている事前学習済みモデルを使う場合の Transform の例を紹介します。
Pytorch – 学習済みモデルで画像分類を行う方法 – pystyle
ImageNet の事前学習済みモデルで学習または推論を行う際に以下の前処理が必要です。
- 入力の大きさを (224, 224) にする
- 入力を RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する
学習時のオーグメンテーションは、適宜追加してください。
Pytorch – torchvision で使える Transform まとめ – pystyle
In [ ]:
data_transforms = {
# 学習時の Transform
"train": transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
),
# 推論時の Transform
"val": transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
),
}
コメント