Pytorch – 学習済みモデルで画像分類を行う方法

Pytorch – 学習済みモデルで画像分類を行う方法

概要

torchvision で提供されている学習済みのモデルを紹介し、推論を行う方法について解説します。

Advertisement

学習済みのモデル

torchvision では、以下のモデルが提供されています。 これらのモデルでは、ImageNet の1000クラス分類問題を学習した重みが使えるようになっており、転移学習や fine-tuning に利用できます。

モデル名 関数名 パラメータ数 Top-1 エラー率 Top-5 エラー率 出典
AlexNet alexnet() 61100840 43.45 20.91 One weird trick for parallelizing convolutional neural networks
VGG-11 vgg11() 132863336 30.98 11.37 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-13 vgg13() 133047848 30.07 10.75 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-16 vgg16() 138357544 28.41 9.62 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-19 vgg19() 143667240 27.62 9.12 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-11 with batch normalization vgg11_bn() 132868840 29.62 10.19 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-13 with batch normalization vgg13_bn() 133053736 28.45 9.63 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-16 with batch normalization vgg16_bn() 138365992 26.63 8.5 Very Deep Convolutional Networks for Large-Scale Image Recognition
VGG-19 with batch normalization vgg19_bn() 143678248 25.76 8.15 Very Deep Convolutional Networks for Large-Scale Image Recognition
ResNet-18 resnet18() 11689512 30.24 10.92 Deep Residual Learning for Image Recognition
ResNet-34 resnet34() 21797672 26.7 8.58 Deep Residual Learning for Image Recognition
ResNet-50 resnet50() 25557032 23.85 7.13 Deep Residual Learning for Image Recognition
ResNet-101 resnet101() 44549160 22.63 6.44 Deep Residual Learning for Image Recognition
ResNet-152 resnet152() 60192808 21.69 5.94 Deep Residual Learning for Image Recognition
SqueezeNet 1.0 squeezenet1_0() 1248424 41.9 19.58 SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
SqueezeNet 1.1 squeezenet1_1() 1235496 41.81 19.38 SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
Densenet-121 densenet121() 7978856 25.35 7.83 Densely Connected Convolutional Networks
Densenet-161 densenet161() 28681000 24 7 Densely Connected Convolutional Networks
Densenet-169 densenet169() 14149480 22.8 6.43 Densely Connected Convolutional Networks
Densenet-201 densenet201() 20013928 22.35 6.2 Densely Connected Convolutional Networks
Inception v3 inception_v3() 27161264 22.55 6.44 Rethinking the Inception Architecture for Computer Vision
GoogleNet googlenet() 13004888 30.22 10.47 Going Deeper with Convolutions
ShuffleNet v2 shufflenet_v2_x0_5() 1366792 30.64 11.68 ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
ShuffleNet v3 shufflenet_v2_x1_0() 2278604 ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
ShuffleNet v4 shufflenet_v2_x1_5() 3503624 ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
ShuffleNet v5 shufflenet_v2_x2_0() 7393996 ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
MobileNet v2 mobilenet_v2() 3504872 28.12 9.71 MobileNetV2: Inverted Residuals and Linear Bottlenecks
ResNeXt-50-32x4d resnext50_32x4d() 25028904 22.38 6.3 Aggregated Residual Transformations for Deep Neural Networks
ResNeXt-101-32x8d resnext101_32x8d() 88791336 20.69 5.47 Aggregated Residual Transformations for Deep Neural Networks
Wide ResNet-50-2 wide_resnet50_2() 68883240 21.49 5.91 Wide Residual Networks
Wide ResNet-101-2 wide_resnet101_2() 126886696 21.16 5.72 Wide Residual Networks
MNASNet mnasnet0_5() 2220824 26.49 8.456 MnasNet: Platform-Aware Neural Architecture Search for Mobile
MNASNet mnasnet0_75() 3170656 MnasNet: Platform-Aware Neural Architecture Search for Mobile
MNASNet mnasnet1_0() 4383312 MnasNet: Platform-Aware Neural Architecture Search for Mobile
MNASNet mnasnet1_3() 6279432 MnasNet: Platform-Aware Neural Architecture Search for Mobile
  • パラメータ数: モデルを構成するパラメータ数を表す。パラメータが多いほど、モデルの表現力が上がるため、精度はよくなる傾向があるが、一方で計算量が増る
  • Top-1 エラー率: ImageNet データセットでの確率が一番高い予測ラベルが正解ラベルと一致していない割合
  • Top-5 エラー率: ImageNet データセットでの確率が高い上位5個の予測ラベルに正解ラベルが含まれていない割合

一般に、Top-k エラー率は、ImageNet データセットでの確率が高い上位 $k$ 個の予測ラベルに正解ラベルが含まれていない割合を表します。Top-k エラー率が低いほど、精度がよいモデルといえます。

パラメータ数と Top-1 エラー率、Top-5 エラー率の関係をそれぞれ描画すると、以下のようになります。 左下のモデルほど、パラメータが少なく、精度がいいモデルということになります。

学習済みモデルで推論する

モデルを作成する際に pretrained=True を指定すると、ImageNet の1000クラス分類問題を学習した重みでモデルが初期化されます。ResNet-50 の学習済みモデルを使い、画像の推論を行う例を以下で紹介します。

必要なモジュールを import する

In [1]:
import json
from pathlib import Path

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets.utils import download_url

デバイスを作成する

In [2]:
def get_device(use_gpu):
    if use_gpu and torch.cuda.is_available():
        # これを有効にしないと、計算した勾配が毎回異なり、再現性が担保できない。
        torch.backends.cudnn.deterministic = True
        return torch.device("cuda")
    else:
        return torch.device("cpu")


# デバイスを選択する。
device = get_device(use_gpu=True)
Advertisement

モデルを作成する

resnet50(pretrained=True) で学習済みの重みを使用した ResNet-50 を作成します。作成後、to(device) で計算を行うデバイスに転送します。

In [3]:
model = torchvision.models.resnet50(pretrained=True).to(device)

Transforms を作成する

ImageNet の学習済みモデルで推論を行う際は以下の前処理が必要となります。

  1. (256, 256) にリサイズする
  2. 画像の中心に合わせて、(224, 224) で切り抜く
  3. RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する

これらの処理を行う Transforms を作成します。

In [4]:
transform = transforms.Compose(
    [
        transforms.Resize(256),  # (256, 256) で切り抜く。
        transforms.CenterCrop(224),  # 画像の中心に合わせて、(224, 224) で切り抜く
        transforms.ToTensor(),  # テンソルにする。
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # 標準化する。
    ]
)

画像を読み込む

以下の手順でモデルに流せる状態にします。

  1. Transforms は Pillow の画像形式が対応しているので、PIL.Image.open() で読み込む。
  2. Transforms で変換し、テンソルにする。
  3. unsqueeze(0) でバッチ次元を追加する。形状を (C, H, W) から (1, C, H, W) にする。
  4. to(device) で計算を行うデバイスに転送する。
In [5]:
img = Image.open("sample.jpg")
inputs = transform(img)
inputs = inputs.unsqueeze(0).to(device)

推論する

eval() でモデルを推論モードに設定したら、順伝搬を行います。

In [6]:
model.eval()
outputs = model(inputs)
Advertisement

推論結果を解釈する

モデルの出力結果を解釈します。torchvision のモデルは softmax をとる前の結果なので、softmax(outputs, dim=1) で softmax を計算します。その後、sort(dim=1, descending=True) で確率が高い順にソートし、確率及び対応するクラス ID の一覧を取得します。

In [7]:
batch_probs = F.softmax(outputs, dim=1)
batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)

クラス ID だと何のクラスを表しているかわからないので、クラス名の一覧が記載されたファイルを Web 上から取得します。

In [8]:
def get_classes():
    if not Path("data/imagenet_class_index.json").exists():
        # ファイルが存在しない場合はダウンロードする。
        download_url("https://git.io/JebAs", "data", "imagenet_class_index.json")

    # クラス一覧を読み込む。
    with open("data/imagenet_class_index.json") as f:
        data = json.load(f)
        class_names = [x["ja"] for x in data]

    return class_names


# クラス名一覧を取得する。
class_names = get_classes()

確率が高い上位3クラスの名前及び確率を出力します。

In [9]:
for probs, indices in zip(batch_probs, batch_indices):
    for k in range(3):
        print(f"Top-{k + 1} {class_names[indices[k]]} {probs[k]:.2%}")
Top-1 ポメラニアン 97.72%
Top-2 パピヨン 0.42%
Top-3 キースホンド 0.40%

DataLoader を使って推論する

先程は1枚の画像を PIL.Image.open() で読み込み、推論を行いました。今度は DataLoader を利用して複数枚の画像をミニバッチ単位で一度に推論する方法を紹介します。

Dataset を作成する

まずは、指定したディレクトリ (data とします) 内にある画像一覧を読み込む Dataset を作成します。この Dataset を使って、Dataloader を作成します。

data
├── apple.jpg
├── cat.jpg
├── dog.jpg
├── imagenet_class_index.json
├── sea_turtle.jpg
└── traffic_light.jpg
In [10]:
def _get_img_paths(img_dir):
    img_dir = Path(img_dir)
    img_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    img_paths = [str(p) for p in img_dir.iterdir() if p.suffix in img_extensions]
    img_paths.sort()

    return img_paths


class ImageFolder(Dataset):
    def __init__(self, img_dir, transform):
        # 画像ファイルのパス一覧を取得する。
        self.img_paths = _get_img_paths(img_dir)
        self.transform = transform

    def __getitem__(self, index):
        path = self.img_paths[index]
        img = Image.open(path)
        inputs = self.transform(img)

        return {"image": inputs, "path": path}

    def __len__(self):
        return len(self.img_paths)


# Dataset を作成する。
dataset = ImageFolder("data", transform)
# DataLoader を作成する。
dataloader = DataLoader(dataset, batch_size=8)

各画像を推論し、結果を表示します。(Jupyter Notebook 上で実行する)

In [11]:
from IPython import display

for batch in dataloader:
    inputs = batch["image"].to(device)
    outputs = model(inputs)

    batch_probs = F.softmax(outputs, dim=1)

    batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)

    for probs, indices, path in zip(batch_probs, batch_indices, batch["path"]):
        display.display(display.Image(path, width=224))
        print(f"path: {path}")
        for k in range(3):
            print(f"Top-{k + 1} {probs[k]:.2%} {class_names[indices[k]]}")
path: data/apple.jpg
Top-1 17.94% ザクロ
Top-2 16.04% リンゴ
Top-3 11.18% 口紅
path: data/cat.jpg
Top-1 99.13% エジプトの猫
Top-2 0.66% タビー
Top-3 0.15% 虎猫
path: data/dog.jpg
Top-1 34.79% ラブラドル・レトリーバー犬
Top-2 13.69% ゴールデンレトリバー
Top-3 13.57% ローデシアン・リッジバック
path: data/sea_turtle.jpg
Top-1 75.74% とんちき
Top-2 23.19% オサガメ
Top-3 0.90% テラピン
path: data/traffic_light.jpg
Top-1 99.95% 交通信号灯
Top-2 0.01% 道路標識
Top-3 0.00% スポットライト

上手く推論できていることが確認できました。torchvision の学習済みモデルを使って正しく推論できるのは、ImageNet の1000クラスに含まれるクラスだけになります。1000クラスに含まれないクラスの分類を行いたい場合は、fine-tuning または転移学習を行う必要があります。