Pytorch Lightning – 事前学習モデルを使ってクラス分類モデルを学習する方法

目次

概要

画像のクラス分類問題の学習を Pytorch Lightning を使用して行う方法について解説します。Pytorch で行う場合のコードは以下の記事で解説していますが、Pytorch Lightning を使用することで Pytorch の冗長なコードを大幅に減らすことができ、コードの見通しがよくなります。

[blogcard url=”https://pystyle.info/pytorch-train-classification-problem-using-a-pretrained-model”]

準備

Pytorch、torchvision、Pytorch Lightning は pip でインストールできます。

pip install pytorch torchvision pytorch-lightning

手順

1. 必要なモジュールを import する

In [1]:
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

2. LightningDataModule を作成する

LightningDataModule はデータを処理するために必要なすべてのステップをカプセル化した、共有・再利用可能なクラスです。基本的に以下の関数をオーバーライドして、処理を記述します。

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        pass
        # データセットのダウンロードなどデータを準備する処理

    def setup(self, stage: Optional[str] = None):
        pass
        # データセットを使用するための準備
        # * データセットを学習用、検証用、テスト用に分ける
        # * クラス数を変える
        # * Transform を作成する

    def train_dataloader(self):
        pass
        # 学習用の DataLoader を返す

    def val_dataloader(self):
        pass
        # 検証用のDataLoader を返す

    def test_dataloader(self):
        pass
        # テスト用のDataLoader を返す

題材として、蟻と鉢の2クラスの画像で構成される hymenoptera データセットを使用します。

データセット

In [2]:
class HymenopteraDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int, data_dir: str = "./"):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.train_transform = self._get_transform(train=True)
        self.test_transform = self._get_transform(train=False)

    def prepare_data(self):
        # ダウンロードして、解凍する。
        url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
        torchvision.datasets.utils.download_and_extract_archive(url, self.data_dir)

    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            # train/ の画像は学習、検証用に使用する。
            dataset = torchvision.datasets.ImageFolder(
                self.data_dir / "hymenoptera_data/train"
            )

            # train/ の画像を学習用と検証用に分割する。
            n_train = int(len(dataset) * 0.8)
            n_val = len(dataset) - n_train
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                dataset, [n_train, n_val]
            )
            self.train_dataset.dataset.transform = self.train_transform
            self.val_dataset.dataset.transform = self.test_transform
            self.classes = self.train_dataset.dataset.classes

        if stage == "test" or stage is None:
            # val/ の画像はテスト用に使用する。
            self.test_dataset = torchvision.datasets.ImageFolder(
                self.data_dir / "hymenoptera_data/val"
            )
            self.test_dataset.transform = self.test_transform
            self.classes = self.test_dataset.classes

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=4
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=4
        )

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=4
        )

    def _get_transform(self, train: bool):
        lst = []
        if train:
            # 学習時に適用する Transform
            lst += [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            # 推論時に適用する Transform
            lst += [
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ]

        lst += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

        transform = transforms.Compose(lst)

        return transform

3. LightningModule を作成する

LightningModuleは、以下の5つのステップをカプセル化したクラスです。

  • モデルの作成
  • 学習のステップ
  • 検証のステップ
  • テストのステップ
  • Optimizer の設定
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # モデルの作成など初期化処理を記述する。

    def forward(self, x):
        pass
        # 純伝搬の処理を記述する。

    def training_step(self, batch, batch_idx):
        pass
        # 学習する処理を記述する。

    def configure_optimizers(self):
        pass
        # 使用する Optimizer や Scheduler を返す。
In [3]:
class LitModel(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        # ResNet-18 を作成する。
        self.model = torchvision.models.resnet18(pretrained=True)

        # 出力層の出力数を ImageNet の 1000 からこのデータセットのクラス数である 2 に置き換える。
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)

        self.log("train_loss", loss)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)

        self.log("val_loss", loss)
        self.log("val_acc", acc)

        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        y_hat = self.model(x)
        preds = torch.argmax(y_hat, dim=1)

        return preds

    def configure_optimizers(self):
        optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        # accuracy を計算する。
        preds = torch.argmax(y_hat, dim=1)
        acc = (y == preds).sum() / y.size(0)

        return loss, acc

4. Trainer を作成する

LightningDataModule 及び LightningModule を作成したら、その2つを引数に Trainer を作成します。

In [4]:
dm = HymenopteraDataModule(batch_size=8, data_dir="/data/hymenoptera_data/")
model = LitModel(num_classes=2)

trainer = pl.Trainer(max_epochs=10, log_every_n_steps=1, gpus=1)
GPU available: True, used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs

5. 学習する

学習は Trainer.fit() を呼び出すと行えます。

In [5]:
# 学習する。
# !rm -r lightning_logs
trainer.fit(model, dm)
LOCAL_RANK: 0 – CUDA_VISIBLE_DEVICES: [0]
Using downloaded and verified file: /data/hymenoptera_data/hymenoptera_data.zip
Extracting /data/hymenoptera_data/hymenoptera_data.zip to /data/hymenoptera_data
| Name | Type | Params ——————————— 0 | model | ResNet | 11.2 M ——————————— 11.2 M Trainable params 0 Non-trainable params 11.2 M Total params 44.710 Total estimated model params size (MB)

5. 学習結果を確認する

デフォルトでは、学習のログは TensorBoard 形式で lightning_logs/version_${バージョン} というディレクトリに出力されます。

lightning_logs
├── version_${バージョン}
│   ├── checkpoints: チェックポイントのファイルがあるディレクトリ
│   ├── events.out.tfevents.*: TensorBoard 形式のログ
│   └── hparams.yaml: ハイパーパラメータが記載されたファイル

TensorBoard 形式のログは JupyterLab の場合、tensorboard の拡張でインラインで表示できます。

  • epoch: エポック
  • train_acc: 学習データに対するイテレーションごとの精度
  • train_acc_epoch: 学習データに対するエポックごとに集計した精度
  • train_loss: 学習データに対するイテレーションごとの損失
  • train_loss_epoch: 学習データに対するエポックごとに集計した損失
  • val_acc: 検証データに対するイテレーションごとの精度
  • val_acc_epoch: 検証データに対するエポックごとに集計した精度
  • val_loss: 検証データに対するイテレーションごとの損失
  • val_loss_epoch: 検証データに対するエポックごとに集計した損失

TensorBoard

In [6]:
%reload_ext tensorboard
%tensorboard --logdir ./lightning_logs --bind_all --port 6006

テストする

Trainer.test() でテストが行えます。test_step() 内の log() で出力している値を集計した結果が表示されます。

In [7]:
ret = trainer.test(datamodule=dm, ckpt_path="best")
Restoring states from the checkpoint path at /data/notebook/pystyle/pytorch/complate/pytorch-lightning-image-classification-train-hymenoptera/lightning_logs/version_3/checkpoints/epoch=9-step=249.ckpt LOCAL_RANK: 0 – CUDA_VISIBLE_DEVICES: [0] Loaded model weights from checkpoint at /data/notebook/pystyle/pytorch/complate/pytorch-lightning-image-classification-train-hymenoptera/lightning_logs/version_3/checkpoints/epoch=9-step=249.ckpt
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9477124214172363, 'test_loss': 0.16836263239383698}
--------------------------------------------------------------------------------

参考

コメント

コメントする

目次