Pytorch Lightning – Auto Encoder で MNIST の特徴表現を学習する

目次

概要

Auto Encoder について解説し、Pytorch Lightning を使用した実装例を紹介します。

Auto Encoder

オートエンコーダ (Auto Encoder) はラベル付けされていないデータの特徴量表現を学習するためのニューラルネットワークの一種です。

Auto Encoder は、入力を特徴量に変換したのち、その特徴量から再び、入力と同じ画像を生成できるように学習します。特徴量の次元は入力データより小さいので、学習が上手くいった場合は、もとのデータを表現可能な特徴量を学習できたことになります。

Auto Encoder

手順

1. 必要なモジュールを import する

In [1]:
from pathlib import Path
from typing import Optional, Union, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

2. LightningDataModule を作成する

今回は入力データに MNIST を使用します。

In [2]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir="./"):
        super().__init__()

        dataset = torchvision.datasets.MNIST(
            data_dir, train=True, download=True, transform=transforms.ToTensor()
        )
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            dataset, [55000, 5000]
        )
        self.test_dataset = torchvision.datasets.MNIST(
            data_dir, train=False, download=True, transform=transforms.ToTensor()
        )

        self.batch_size = batch_size
        self.data_dir = data_dir

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=8
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=8
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=8
        )

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=8
        )

3. LightningModule を作成する

LightningModule にモデルや各処理を定義します。 モデルの構造ですが、

  1. 784次元の入力層
  2. 64次元の全結合層
  3. ReLU
  4. 32次元の全結合層
  5. 64次元の全結合層
  6. ReLU
  7. 784次元の全結合層

となっており、入力画像と同じ画像を出力することが学習の目標なので、入力画像と出力画像の差異を mse_loss() で計算します。

In [3]:
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 32)
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)
        )

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def _prepare_batch(self, batch):
        x, _ = batch
        return x.view(x.size(0), -1)

    def _common_step(self, batch, batch_idx, stage):
        x = self._prepare_batch(batch)
        loss = F.mse_loss(x, self(x))
        self.log(f"{stage}_loss", loss, on_step=True)
        return loss

4. 学習途中の生成結果を画像化するコールバック関数を作成する

AutoEncoder の学習過程の生成結果を画像化するコールバック関数を作成します。

  1. 生成する枚数を torch.utils.data.DataLoader() のバッチサイズに指定し、検証用のデータセットから画像を取得します。
  2. (N, C, H, W) -> (N, C H W) に形状を変更します。
  3. Auto Encoder に入力し、画像を生成します。
  4. (N, C H W) -> (N, C, H, W) に形状を変更します。
  5. torchvision.utils.make_grid() で画像を結合します。
  6. torchvision.utils.save_image() で画像を保存します。

[blogcard url=”https://pystyle.info/pytorch-make-grid”]

In [4]:
class ImageSampler(pl.callbacks.Callback):
    def __init__(
        self, output_dir: str = "./generated", num_samples: int = 64, nrow: int = 8
    ):
        super().__init__()
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        self.output_dir = output_dir
        self.num_samples = num_samples
        self.nrow = nrow

    def on_epoch_end(self, trainer: pl.Trainer, model: pl.LightningModule) -> None:
        # 生成する枚数をバッチサイズに指定し、検証用のデータセットから取得する。
        dataloader = torch.utils.data.DataLoader(
            trainer.datamodule.val_dataset, batch_size=self.num_samples
        )
        img_in, _ = next(iter(dataloader))

        x = img_in.view(img_in.size(0), -1)  # (N, C, H, W) -> (N, C * H * W)

        with torch.no_grad():
            model.eval()
            y = model(x.to(model.device))
            model.train()

        img_out = y.reshape(img_in.shape)  # (N, C * H * W) -> (N, C, H, W)

        # 画像を結合する。
        img_grid = torchvision.utils.make_grid(img_out, nrow=self.nrow, normalize=True)

        # 画像を保存する。
        save_path = self.output_dir / f"generated_epoch{trainer.current_epoch}.png"
        torchvision.utils.save_image(img_grid, save_path)

5. Trainer を作成して、学習する

In [5]:
dm = MNISTDataModule(batch_size=32, data_dir="/data/MNIST/")
model = LitAutoEncoder()

trainer = pl.Trainer(
    max_epochs=10, log_every_n_steps=1, gpus=1, callbacks=[ImageSampler()]
)
trainer.fit(model, dm)
GPU available: True, used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs LOCAL_RANK: 0 – CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params ————————————— 0 | encoder | Sequential | 52.3 K 1 | decoder | Sequential | 53.1 K ————————————— 105 K Trainable params 0 Non-trainable params 105 K Total params 0.422 Total estimated model params size (MB)

genereated ディレクトリに学習途中の生成画像の例が保存されます。

生成された画像

Auto Encoder で次元削減

学習した Auto Encoder の Encoder で抽出した特徴量を使って、学習してみます。

  1. 学習した Auto Encoder の Encoder で (784,) -> (32,) に次元削減する
  2. SVM で学習する
  3. 精度を計算する
In [6]:
def extract(dataloader):
    X, y = [], []
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.view(x_batch.size(0), -1)
        feat = model.encoder(x_batch)

        X.append(feat)
        y.append(y_batch)

    X = torch.concat(X).detach().numpy()
    y = torch.concat(y).detach().numpy()

    return X, y


# 学習データ、テストデータを Encoder で特徴量に変換する
X_train, y_train = extract(dm.train_dataloader())
X_test, y_test = extract(dm.test_dataloader())
In [7]:
from sklearn import metrics, svm

clf = svm.SVC(kernel="linear", C=1, random_state=0)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

acc = metrics.accuracy_score(y_test, y_pred)
print(f"accuracy: {acc:.2%}")
accuracy: 92.70%

784次元の入力データを32次元に削減しても精度が出ているので、Encoder によって次元削減が上手く行えていることがわかりました。

コメント

コメントする

目次