概要
Pytorch Lightning の LightningDataModule について解説します。
LightningDataModule の役割
LightningDataModule はデータを処理するために必要なすべてのステップをカプセル化したクラスです。LightningDataModule クラスを継承して、以下の関数のうち、必要なものをオーバーロードして実装します。
以下は MNIST データセットの LightningDataModule の例です。
import torch
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, root_dir="data", batch_size=32, num_workers=8):
super().__init__()
self.save_hyperparameters()
def prepare_data(self) -> None:
torchvision.datasets.MNIST(self.hparams.root_dir, download=True)
def setup(self, stage):
self.train_dataset = torchvision.datasets.MNIST(
self.hparams.root_dir, train=True, transform=self.transform
)
self.test_dataset = torchvision.datasets.MNIST(
self.hparams.root_dir, train=False, transform=self.transform
)
def train_dataloader(self):
dataloader = torch.utils.data.DataLoader(
self.train_dataset, batch_size=self.hparams.batch_size
)
return dataloader
def test_dataloader(self):
dataloader = torch.utils.data.DataLoader(
self.test_dataset, batch_size=self.hparams.batch_size
)
return dataloader
@property
def transform(self):
return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
__init__()
引数にバッチサイズなどのパラメータを渡し、save_hyperparameters()
を呼ぶことで、引数が self.hpparams
に登録されて参照できるようになります。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, root_dir="data", batch_size=32, num_workers=8):
super().__init__()
self.save_hyperparameters()
print(self.hparams)
print(self.hparams.root_dir)
print(self.hparams.batch_size)
print(self.hparams.num_workers)
dm = MNISTDataModule()
"batch_size": 32 "num_workers": 8 "root_dir": data data 32 8
prepare_data()
ダウンロードなどデータセットの準備を行います。この関数はマルチ GPU 環境では、1つの GPU でしか呼ばれないため、self.dataset = Dataset()
のようにクラスの状態を変更するような処理は行ってはいけません。
def prepare_data(self) -> None:
torchvision.datasets.MNIST(self.hparams.root_dir, download=True)
setup(stage)
データを読み込んでデータセットを作成するなどの処理を行います。どの段階のデータセットを準備する必要があるのかが引数 stage
で渡されるので、以下のように必要なデータセットだけ作成するようにすることもできます。
stage の値 | |
---|---|
学習 | “fit” |
検証 | “validate” |
テスト | “test” |
推論 | “predict” |
def setup(self, stage):
if stage == "fit" or stage is None:
self.train_dataset = torchvision.datasets.MNIST(
self.hparams.root_dir, train=True, transform=self.transform
)
if stage == "test" or stage is None:
self.test_dataset = torchvision.datasets.MNIST(
self.hparams.root_dir, train=False, transform=self.transform
)
train_dataloader()
学習中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の training_step() に渡されます。
DataLoader は複数返すこともでき、train_dataloader()
の返り値と training_step(batch, batch_idx)
の batch
に渡される値の関係は以下のようになります。
返り値が DataLoader
dl -> batch
返り値が Sequence[DataLoader]
dl1, dl2 -> [batch1, batch2]
返り値が Sequence[Sequence[DataLoader]]
[[dl1, dl2], [dl3, dl4]] -> [[batch1, batch2], [batch3, batch4]]
返り値が Sequence[Dict[str, DataLoader]]
{"dl1": dl1, "dl2": dl2} -> {'dl1': batch1, 'dl2': batch2}
返り値が Dict[str, Dict[str, DataLoader]]
{"dl_12": {"dl1": dl1, "dl2": dl2}, "dl_34": {"dl1": dl1, "dl2": dl2}} -> {"dl_12": {"dl1": batch1, "dl2": batch2}, "dl_34": {"dl1": batch3, "dl2": batch4}})
返り値が Dict[str, Sequence[DataLoader]]
{"dl_12": [dl1, dl2], "dl_34": [dl3, dl4]} -> {'dl_12': [batch1, batch2], 'dl_34': [batch3, batch4]}
val_dataloader()
検証中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の validation_step() に渡されます。
val_dataloader()
の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。
- 返り値が DataLoader
- DataLoader が生成するバッチが
val_step(self, batch, batch_idx)
のbatch
に渡されます。
- DataLoader が生成するバッチが
- 返り値が Sequence[DataLoader]
- リストの先頭の DataLoader から順番に使われます。
val_step(self, batch, batch_idx, dataloader_id)
のdataloader_id
にバッチを生成した DataLoader のインデックスが入っています。
test_dataloader()
テスト中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の test_step() に渡されます。
test_dataloader()
の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。
- 返り値が DataLoader
- DataLoader が生成するバッチが
test_step(self, batch, batch_idx)
のbatch
に渡されます。
- DataLoader が生成するバッチが
- 返り値が Sequence[DataLoader]
- リストの先頭の DataLoader から順番に使われます。
test_step(self, batch, batch_idx, dataloader_id)
のdataloader_id
にバッチを生成した DataLoader のインデックスが入っています。
predict_dataloader()
テスト中に使用するバッチを生成する DataLoader を返すようにします。DataLoader が生成したバッチは学習中に LightningModule の predict_step()
に渡されます。
predict_dataloader()
の返り値が DataLoader が1つなのか、DataLoader のリストなのかによって挙動が変わります。
- 返り値が DataLoader
- DataLoader が生成するバッチが
predict_step(self, batch, batch_idx)
のbatch
に渡されます。
- DataLoader が生成するバッチが
- 返り値が Sequence[DataLoader]
- リストの先頭の DataLoader から順番に使われます。
predict_step(self, batch, batch_idx, dataloader_id)
のdataloader_id
にバッチを生成した DataLoader のインデックスが入っています。
コメント