Pytorch – Pytorch でコードを書く際によく使う Tips

目次

概要

Pytorch でコードを書く際によく使う Tips を公式チュートリアルなどを参考にまとめました。

使用頻度の高いモジュール

torch.nn.functionalF という名前で import するなど、いくつかの慣例があるようです。

名前 import 内容
torch import torch Pytorch
torch.nn import torch.nn as nn レイヤークラスなど
torch.nn.functional import torch.nn.functional as F 活性化関数など
torch.nn.init import torch.nn.init as init 初期化
torch.optim import torch.optim as optim 最適化
torch.utils.data import torch.utils.data as data Dataset、DadaLoader
torchvision import torchvision torchvision
torchvision.datasets import torchvision.datasets as datasets Dataset
torchvision.models import torchvision.models as models 事前学習済みモデル
torchvision.transforms import torchvision.transforms as transforms Transform
In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

デバイスを選択する

デバイスを選択するためのヘルパー関数です。

In [2]:
def get_device(gpu_id=-1):
    """gpu_id を指定した場合、GPU のデバイスを作成します。
    """
    if gpu_id >= 0 and torch.cuda.is_available():
        return torch.device("cuda", gpu_id)
    else:
        return torch.device("cpu")


device = get_device()
print(device)  # cpu

device = get_device(gpu_id=0)
print(device)  # cuda:0
cpu
cuda:0

事前学習済みモデルを使う場合の Transform

torchvision で提供されている事前学習済みモデルを使う場合の Transform の例を紹介します。

Pytorch – 学習済みモデルで画像分類を行う方法 – pystyle

ImageNet の事前学習済みモデルで学習または推論を行う際に以下の前処理が必要です。

  • 入力の大きさを (224, 224) にする
  • 入力を RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する

学習時のオーグメンテーションは、適宜追加してください。

Pytorch – torchvision で使える Transform まとめ – pystyle

In [ ]:
data_transforms = {
    # 学習時の Transform
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    ),
    # 推論時の Transform
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    ),
}

コメント

コメントする

目次