概要
Pytorch で事前学習済みモデルを使ってクラス分類モデルを学習する方法について解説します。
事前学習済みモデル
昨今の CNN モデルは数千万~数億のパラメータで構成されるため、このモデルのパラメータを1から調整するには、大規模なデータセットと膨大な計算リソースが要求されます。そのため、用意したデータセットのサンプルが少ない場合や潤沢な計算リソースを利用できない場合は、精度のよいモデルを作成することができません。この問題を解決するテクニックとして、転移学習 (Transfer Learning) があります。転移学習では、事前に大規模なデータセットで学習したモデルを使い、用意したデータセットでその重みを調整します。これにより、小規模なデータセットでも精度のよいモデルを手早く作成することができるようになります。
torchvision では、ImageNet で事前学習済みのモデルが提供されています。使用できるモデルの一覧は以下の記事を参照してください。
Pytorch – 学習済みモデルで画像分類を行う方法 – pystyle
転移学習の方法
基本的な CNN の構造は、画像から特徴量を抽出するための特徴抽出器 (extractor と抽出された特徴量を元に分類を行う分類器 (classifier) の2つからなります。転移学習の場合、分類器の全結合層をデータセットのクラス数に合わせて変更し、特徴抽出器は事前学習済みモデルの重みで初期化します。
学習方法については、特徴抽出器と分類器の両方のパラメータを調整する方法 (Finetuning) と特徴抽出器のパラメータは固定して、分類器のパラメータのみ調整する2通りの方法があります。
それぞれ次のような特徴があります。
- 特徴抽出器と分類器の両方のパラメータを調整する (Finetuning)
- データセットは中規模以上
- 計算リソースがかかる
- データセットの規模がそれなりにあるなら、精度は分類器のパラメータのみ調整する場合より上がりやすい
- 分類器のパラメータのみ調整する
- データセットは小規模でも可
- 計算リソースが少ない
実装方法
公式チュートリアル「Transfer Learning for Computer Vision Tutorial」を参考にして進めていきます。
必要なモジュールを import する
データセットを用意する
題材として、蟻と鉢の2クラスの画像で構成される hymenoptera データセットを使用します。 データセットは こちら からダウンロードできます。

自作のデータセットを使う場合は、クラスごとに画像がディレクトリで分けられている以下のディレクトリ構造を用意してください。
用意したら、上記 [3]
の部分を代わりに以下に変更してください。
Transform を作成する
ImageNet の事前学習済みモデルで学習または推論を行う際に以下の前処理が必要です。
- 入力の大きさを (224, 224) にする
- 入力を RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する
また、学習時はランダムな切り抜き、左右反転によるデータオーグメンテーションを行います。
- 学習時
- ランダムに大きさ (224, 224) で切り抜く
- ランダムに左右反転を行う
- PIL Image をテンソルにする
- RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する
- 推論時
- 大きさ (256, 256) にリサイズする
- 大きさ (224, 224) で画像の中心を切り抜く
- PIL Image をテンソルにする
- RGB チャンネルごとに平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化する
これらの処理を行う Transform を作成します。
Dataset を作成する
hymenoptera データセットを読み込む Dataset を作成します。 先ほどダウンロードして解凍したデータセットは以下のディレクトリ構造になっています。
ディレクトリがこのような構造になっている場合、ImageFolder が利用できます。
Pytorch – 自作のデータセットを扱う Dataset クラスを作る方法 – pystyle
ImageFolder はフォルダ名からクラス名及びクラス ID を作成します。
クラス名の一覧は ImageFolder.classes
で取得できます。
['ants', 'bees']
DataLoader を作成する
Pytorch – Transforms、Dataset、DataLoader について解説 – pystyle
デバイスを作成する
Pytorch – 計算を行うデバイスを指定する方法について – pystyle
学習用のヘルパー関数を作成する
指定したエポック数だけ学習を行う train()
関数とその内部で呼び出す1エポックだけ学習する train_on_epoch()
を作成します。
Pytorch – Fashion-MNIST で CNN モデルによる画像分類を行う – pystyle
損失関数と精度の履歴を描画するヘルパー関数を作成する
損失関数と精度の履歴を描画するヘルパー関数を作成します。
Finetuning
モデルを作成する
今回は小規模なデータセットなので、ResNet-18 を利用します。 まず、モデルの全パラメータを学習する場合を記載します。
学習する
epoch 1 [train] loss: 0.129269, accuracy: 72% [test] loss: 0.074556, accuracy: 90% epoch 2 [train] loss: 0.112779, accuracy: 82% [test] loss: 0.065001, accuracy: 91% epoch 3 [train] loss: 0.110108, accuracy: 84% [test] loss: 0.110336, accuracy: 85% epoch 4 [train] loss: 0.118470, accuracy: 83% [test] loss: 0.104664, accuracy: 82% epoch 5 [train] loss: 0.089549, accuracy: 82% [test] loss: 0.090263, accuracy: 85% epoch 6 [train] loss: 0.091574, accuracy: 84% [test] loss: 0.073524, accuracy: 90% epoch 7 [train] loss: 0.127873, accuracy: 77% [test] loss: 0.223082, accuracy: 76% epoch 8 [train] loss: 0.081925, accuracy: 86% [test] loss: 0.072420, accuracy: 88% epoch 9 [train] loss: 0.094426, accuracy: 84% [test] loss: 0.059911, accuracy: 91% epoch 10 [train] loss: 0.061152, accuracy: 89% [test] loss: 0.059724, accuracy: 90%

推論結果を表示する
いくつかのサンプル画像に対して推論を行い、結果を表示します。

bees

ants

bees

bees
分類器のパラメータのみ調整する
モデルを作成する
次に分類器のパラメータのみ調整する場合のコードを記載します。
特徴抽出器のパラメータは調整しないので、各パラメータの勾配を計算するかどうかを決める属性 requires_grad
をすべて False
に設定します。
それ以外の部分は Finetuning の場合と同様です。
学習する
epoch 1 [train] loss: 0.158743, accuracy: 69% [test] loss: 0.050216, accuracy: 94% epoch 2 [train] loss: 0.125644, accuracy: 77% [test] loss: 0.140428, accuracy: 76% epoch 3 [train] loss: 0.154087, accuracy: 75% [test] loss: 0.044722, accuracy: 93% epoch 4 [train] loss: 0.149899, accuracy: 76% [test] loss: 0.047211, accuracy: 93% epoch 5 [train] loss: 0.119348, accuracy: 79% [test] loss: 0.043895, accuracy: 95% epoch 6 [train] loss: 0.144165, accuracy: 75% [test] loss: 0.074479, accuracy: 91% epoch 7 [train] loss: 0.122264, accuracy: 78% [test] loss: 0.060641, accuracy: 92% epoch 8 [train] loss: 0.105848, accuracy: 81% [test] loss: 0.049391, accuracy: 94% epoch 9 [train] loss: 0.106663, accuracy: 85% [test] loss: 0.048351, accuracy: 93% epoch 10 [train] loss: 0.104302, accuracy: 85% [test] loss: 0.059005, accuracy: 94%

コメント