Pytorch – torchvision の make_grid で複数の画像を並べて表示する方法

目次

概要

複数の画像から、それらをグリッド上に並べた画像を作成できる torchvision.utils.make_grid() の使い方について解説します。GAN や AutoEncoder などの生成系モデルにおいて、学習過程の画像を確認したい場合に便利です。

torchvision.utils.make_grid

torchvision.utils.make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor
  • tensor: (B, C, H, W) のテンソルまたは (C, H, W) のテンソルにリスト
  • nrow: 行数
  • padding: 画像に追加するパディング幅 (px)
  • normalize: [0, 1] の範囲に正規化するかどうか
  • value_range: normalize=True の場合に、スケールする際の最小値、最大値として使用されます。指定しない場合は、テンソルの値から自動でスケールする際の最小値、最大値が計算されます。
  • scale_each: normalize=True の場合に、scale_each=True を指定すると、スケールする際の最小値、最大値がバッチ単位ではなく、画像単位で計算されます。
  • pad_value: 画像の間のマージン (px)
In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from IPython.display import display


def imshow(img):
    """ndarray 配列をインラインで Notebook 上に表示する。"""
    img = transforms.functional.to_pil_image(img)
    display(img)


# FashionMNIST データセットから画像を読み込む DataLoader を作成する。
dataset = torchvision.datasets.FashionMNIST(
    root="/data", train=True, transform=transforms.ToTensor(), download=True
)

data_loader = torch.utils.data.DataLoader(dataset, batch_size=64)

# 1バッチ (64枚) の画像を読み込む。
imgs, _ = next(iter(data_loader))
print(imgs.shape)

# グリッド上に並べて1枚の画像にする。
img = torchvision.utils.make_grid(imgs)
imshow(img)
torch.Size([64, 1, 28, 28])

nrow で1行に配置する画像枚数を変更できます。

In [2]:
img = torchvision.utils.make_grid(imgs, nrow=12)
imshow(img)

コメント

コメントする

目次