Pytorch Lightning – DataModule の使い方について解説

目次

概要

Pytorch Lightning の LightningDataModule について解説します。

LightningDataModule の役割

LightningDataModule はデータを処理するために必要なすべてのステップをカプセル化したクラスです。LightningDataModule クラスを継承して、以下の関数のうち、必要なものをオーバーロードして実装します。

以下は MNIST データセットの LightningDataModule の例です。

In [1]:
import torch
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, root_dir="data", batch_size=32, num_workers=8):
        super().__init__()
        self.save_hyperparameters()

    def prepare_data(self) -> None:
        torchvision.datasets.MNIST(self.hparams.root_dir, download=True)

    def setup(self, stage):
        self.train_dataset = torchvision.datasets.MNIST(
            self.hparams.root_dir, train=True, transform=self.transform
        )
        self.test_dataset = torchvision.datasets.MNIST(
            self.hparams.root_dir, train=False, transform=self.transform
        )

    def train_dataloader(self):
        dataloader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.hparams.batch_size
        )
        return dataloader

    def test_dataloader(self):
        dataloader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.hparams.batch_size
        )
        return dataloader

    @property
    def transform(self):
        return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

__init__()

引数にバッチサイズなどのパラメータを渡し、save_hyperparameters() を呼ぶことで、引数が self.hpparams に登録されて参照できるようになります。

In [2]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, root_dir="data", batch_size=32, num_workers=8):
        super().__init__()
        self.save_hyperparameters()
        
        print(self.hparams)
        print(self.hparams.root_dir)
        print(self.hparams.batch_size)
        print(self.hparams.num_workers)

dm = MNISTDataModule()
"batch_size":  32
"num_workers": 8
"root_dir":    data
data
32
8

prepare_data()

ダウンロードなどデータセットの準備を行います。この関数はマルチ GPU 環境では、1つの GPU でしか呼ばれないため、self.dataset = Dataset() のようにクラスの状態を変更するような処理は行ってはいけません。

def prepare_data(self) -> None:
    torchvision.datasets.MNIST(self.hparams.root_dir, download=True)

setup(stage)

データを読み込んでデータセットを作成するなどの処理を行います。どの段階のデータセットを準備する必要があるのかが引数 stage で渡されるので、以下のように必要なデータセットだけ作成するようにすることもできます。

stage の値
学習 “fit”
検証 “validate”
テスト “test”
推論 “predict”
def setup(self, stage):
    if stage == "fit" or stage is None:
        self.train_dataset = torchvision.datasets.MNIST(
            self.hparams.root_dir, train=True, transform=self.transform
        )

    if stage == "test" or stage is None:
        self.test_dataset = torchvision.datasets.MNIST(
            self.hparams.root_dir, train=False, transform=self.transform
        )

train_dataloader()

学習中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の training_step() に渡されます。 DataLoader は複数返すこともでき、train_dataloader() の返り値と training_step(batch, batch_idx)batch に渡される値の関係は以下のようになります。

  • 返り値が DataLoader

    dl -> batch
  • 返り値が Sequence[DataLoader]

    dl1, dl2 -> [batch1, batch2]
  • 返り値が Sequence[Sequence[DataLoader]]

    [[dl1, dl2], [dl3, dl4]] -> [[batch1, batch2], [batch3, batch4]]
  • 返り値が Sequence[Dict[str, DataLoader]]

    {"dl1": dl1, "dl2": dl2} -> {'dl1': batch1, 'dl2': batch2}
  • 返り値が Dict[str, Dict[str, DataLoader]]

    {"dl_12": {"dl1": dl1, "dl2": dl2}, "dl_34": {"dl1": dl1, "dl2": dl2}}
      -> {"dl_12": {"dl1": batch1, "dl2": batch2}, "dl_34": {"dl1": batch3, "dl2": batch4}})
  • 返り値が Dict[str, Sequence[DataLoader]]

    {"dl_12": [dl1, dl2], "dl_34": [dl3, dl4]} -> {'dl_12': [batch1, batch2], 'dl_34': [batch3, batch4]}

val_dataloader()

検証中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の validation_step() に渡されます。 val_dataloader() の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。

  • 返り値が DataLoader
    • DataLoader が生成するバッチが val_step(self, batch, batch_idx)batch に渡されます。
  • 返り値が Sequence[DataLoader]
    • リストの先頭の DataLoader から順番に使われます。
    • val_step(self, batch, batch_idx, dataloader_id)dataloader_id にバッチを生成した DataLoader のインデックスが入っています。

test_dataloader()

テスト中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の test_step() に渡されます。 test_dataloader() の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。

  • 返り値が DataLoader
    • DataLoader が生成するバッチが test_step(self, batch, batch_idx)batch に渡されます。
  • 返り値が Sequence[DataLoader]
    • リストの先頭の DataLoader から順番に使われます。
    • test_step(self, batch, batch_idx, dataloader_id)dataloader_id にバッチを生成した DataLoader のインデックスが入っています。

predict_dataloader()

テスト中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の predict_step() に渡されます。 predict_dataloader() の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。

  • 返り値が DataLoader
    • DataLoader が生成するバッチが predict_step(self, batch, batch_idx)batch に渡されます。
  • 返り値が Sequence[DataLoader]
    • リストの先頭の DataLoader から順番に使われます。
    • predict_step(self, batch, batch_idx, dataloader_id)dataloader_id にバッチを生成した DataLoader のインデックスが入っています。

コメント

コメントする

目次