概要
画像のクラス分類問題の学習を Pytorch Lightning を使用して行う方法について解説します。Pytorch で行う場合のコードは以下の記事で解説していますが、Pytorch Lightning を使用することで Pytorch の冗長なコードを大幅に減らすことができ、コードの見通しがよくなります。
[blogcard url=”https://pystyle.info/pytorch-train-classification-problem-using-a-pretrained-model”]
準備
Pytorch、torchvision、Pytorch Lightning は pip でインストールできます。
手順
1. 必要なモジュールを import する
2. LightningDataModule を作成する
LightningDataModule はデータを処理するために必要なすべてのステップをカプセル化した、共有・再利用可能なクラスです。基本的に以下の関数をオーバーライドして、処理を記述します。
題材として、蟻と鉢の2クラスの画像で構成される hymenoptera データセットを使用します。

3. LightningModule を作成する
LightningModuleは、以下の5つのステップをカプセル化したクラスです。
- モデルの作成
- 学習のステップ
- 検証のステップ
- テストのステップ
- Optimizer の設定
4. Trainer を作成する
LightningDataModule 及び LightningModule を作成したら、その2つを引数に Trainer を作成します。
5. 学習する
学習は Trainer.fit() を呼び出すと行えます。
Using downloaded and verified file: /data/hymenoptera_data/hymenoptera_data.zip Extracting /data/hymenoptera_data/hymenoptera_data.zip to /data/hymenoptera_data
5. 学習結果を確認する
デフォルトでは、学習のログは TensorBoard 形式で lightning_logs/version_${バージョン}
というディレクトリに出力されます。
TensorBoard 形式のログは JupyterLab の場合、tensorboard
の拡張でインラインで表示できます。
- epoch: エポック
- train_acc: 学習データに対するイテレーションごとの精度
- train_acc_epoch: 学習データに対するエポックごとに集計した精度
- train_loss: 学習データに対するイテレーションごとの損失
- train_loss_epoch: 学習データに対するエポックごとに集計した損失
- val_acc: 検証データに対するイテレーションごとの精度
- val_acc_epoch: 検証データに対するエポックごとに集計した精度
- val_loss: 検証データに対するイテレーションごとの損失
- val_loss_epoch: 検証データに対するエポックごとに集計した損失

テストする
Trainer.test() でテストが行えます。test_step()
内の log()
で出力している値を集計した結果が表示されます。
-------------------------------------------------------------------------------- DATALOADER:0 TEST RESULTS {'test_acc': 0.9477124214172363, 'test_loss': 0.16836263239383698} --------------------------------------------------------------------------------
コメント