概要
画像のクラス分類問題の学習を Pytorch Lightning を使用して行う方法について解説します。Pytorch で行う場合のコードは以下の記事で解説していますが、Pytorch Lightning を使用することで Pytorch の冗長なコードを大幅に減らすことができ、コードの見通しがよくなります。
[blogcard url=”https://pystyle.info/pytorch-train-classification-problem-using-a-pretrained-model”]
準備
Pytorch、torchvision、Pytorch Lightning は pip でインストールできます。
pip install pytorch torchvision pytorch-lightning
手順
1. 必要なモジュールを import する
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
2. LightningDataModule を作成する
LightningDataModule はデータを処理するために必要なすべてのステップをカプセル化した、共有・再利用可能なクラスです。基本的に以下の関数をオーバーライドして、処理を記述します。
import pytorch_lightning as pl
class MyDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
def prepare_data(self):
pass
# データセットのダウンロードなどデータを準備する処理
def setup(self, stage: Optional[str] = None):
pass
# データセットを使用するための準備
# * データセットを学習用、検証用、テスト用に分ける
# * クラス数を変える
# * Transform を作成する
def train_dataloader(self):
pass
# 学習用の DataLoader を返す
def val_dataloader(self):
pass
# 検証用のDataLoader を返す
def test_dataloader(self):
pass
# テスト用のDataLoader を返す
題材として、蟻と鉢の2クラスの画像で構成される hymenoptera データセットを使用します。
class HymenopteraDataModule(pl.LightningDataModule):
def __init__(self, batch_size: int, data_dir: str = "./"):
super().__init__()
self.data_dir = Path(data_dir)
self.batch_size = batch_size
self.train_transform = self._get_transform(train=True)
self.test_transform = self._get_transform(train=False)
def prepare_data(self):
# ダウンロードして、解凍する。
url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
torchvision.datasets.utils.download_and_extract_archive(url, self.data_dir)
def setup(self, stage: Optional[str] = None):
if stage == "fit" or stage is None:
# train/ の画像は学習、検証用に使用する。
dataset = torchvision.datasets.ImageFolder(
self.data_dir / "hymenoptera_data/train"
)
# train/ の画像を学習用と検証用に分割する。
n_train = int(len(dataset) * 0.8)
n_val = len(dataset) - n_train
self.train_dataset, self.val_dataset = torch.utils.data.random_split(
dataset, [n_train, n_val]
)
self.train_dataset.dataset.transform = self.train_transform
self.val_dataset.dataset.transform = self.test_transform
self.classes = self.train_dataset.dataset.classes
if stage == "test" or stage is None:
# val/ の画像はテスト用に使用する。
self.test_dataset = torchvision.datasets.ImageFolder(
self.data_dir / "hymenoptera_data/val"
)
self.test_dataset.transform = self.test_transform
self.classes = self.test_dataset.classes
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self.val_dataset, batch_size=self.batch_size, num_workers=4
)
def test_dataloader(self):
return torch.utils.data.DataLoader(
self.test_dataset, batch_size=self.batch_size, num_workers=4
)
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self.test_dataset, batch_size=self.batch_size, num_workers=4
)
def _get_transform(self, train: bool):
lst = []
if train:
# 学習時に適用する Transform
lst += [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
]
else:
# 推論時に適用する Transform
lst += [
transforms.Resize(256),
transforms.CenterCrop(224),
]
lst += [
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(lst)
return transform
3. LightningModule を作成する
LightningModuleは、以下の5つのステップをカプセル化したクラスです。
- モデルの作成
- 学習のステップ
- 検証のステップ
- テストのステップ
- Optimizer の設定
import pytorch_lightning as pl
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
# モデルの作成など初期化処理を記述する。
def forward(self, x):
pass
# 純伝搬の処理を記述する。
def training_step(self, batch, batch_idx):
pass
# 学習する処理を記述する。
def configure_optimizers(self):
pass
# 使用する Optimizer や Scheduler を返す。
class LitModel(pl.LightningModule):
def __init__(self, num_classes):
super().__init__()
# ResNet-18 を作成する。
self.model = torchvision.models.resnet18(pretrained=True)
# 出力層の出力数を ImageNet の 1000 からこのデータセットのクラス数である 2 に置き換える。
self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
loss, acc = self._shared_step(batch)
self.log("train_loss", loss)
self.log("train_acc", acc)
return loss
def validation_step(self, batch, batch_idx):
loss, acc = self._shared_step(batch)
self.log("val_loss", loss)
self.log("val_acc", acc)
return loss
def test_step(self, batch, batch_idx):
loss, acc = self._shared_step(batch)
self.log("test_loss", loss, prog_bar=True)
self.log("test_acc", acc, prog_bar=True)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, _ = batch
y_hat = self.model(x)
preds = torch.argmax(y_hat, dim=1)
return preds
def configure_optimizers(self):
optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
return [optimizer], [scheduler]
def _shared_step(self, batch):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# accuracy を計算する。
preds = torch.argmax(y_hat, dim=1)
acc = (y == preds).sum() / y.size(0)
return loss, acc
4. Trainer を作成する
LightningDataModule 及び LightningModule を作成したら、その2つを引数に Trainer を作成します。
dm = HymenopteraDataModule(batch_size=8, data_dir="/data/hymenoptera_data/")
model = LitModel(num_classes=2)
trainer = pl.Trainer(max_epochs=10, log_every_n_steps=1, gpus=1)
5. 学習する
学習は Trainer.fit() を呼び出すと行えます。
# 学習する。
# !rm -r lightning_logs
trainer.fit(model, dm)
Using downloaded and verified file: /data/hymenoptera_data/hymenoptera_data.zip Extracting /data/hymenoptera_data/hymenoptera_data.zip to /data/hymenoptera_data
5. 学習結果を確認する
デフォルトでは、学習のログは TensorBoard 形式で lightning_logs/version_${バージョン}
というディレクトリに出力されます。
lightning_logs
├── version_${バージョン}
│ ├── checkpoints: チェックポイントのファイルがあるディレクトリ
│ ├── events.out.tfevents.*: TensorBoard 形式のログ
│ └── hparams.yaml: ハイパーパラメータが記載されたファイル
TensorBoard 形式のログは JupyterLab の場合、tensorboard
の拡張でインラインで表示できます。
- epoch: エポック
- train_acc: 学習データに対するイテレーションごとの精度
- train_acc_epoch: 学習データに対するエポックごとに集計した精度
- train_loss: 学習データに対するイテレーションごとの損失
- train_loss_epoch: 学習データに対するエポックごとに集計した損失
- val_acc: 検証データに対するイテレーションごとの精度
- val_acc_epoch: 検証データに対するエポックごとに集計した精度
- val_loss: 検証データに対するイテレーションごとの損失
- val_loss_epoch: 検証データに対するエポックごとに集計した損失
%reload_ext tensorboard
%tensorboard --logdir ./lightning_logs --bind_all --port 6006
テストする
Trainer.test() でテストが行えます。test_step()
内の log()
で出力している値を集計した結果が表示されます。
ret = trainer.test(datamodule=dm, ckpt_path="best")
-------------------------------------------------------------------------------- DATALOADER:0 TEST RESULTS {'test_acc': 0.9477124214172363, 'test_loss': 0.16836263239383698} --------------------------------------------------------------------------------
コメント