Pytorch – MNIST データセットについて解説

目次

概要

MNIST データセットについて解説します。

MNIST データセット

MNIST は Modified National Institute of Standards and Technology の訳で、アメリカ国立標準技術研究所 (NIST) が保有している手書き文字画像のデータを修正して作成したデータセットです。

  • データセットの種類: クラス分類
  • 画像枚数: 学習60000、テスト10000
  • 画像: 28×28 のグレースケール画像

Pytorch の MNIST データセット

torchvision.datasets.MNIST で提供されています。このデータセットは (28, 28, 1) の画像及び正解ラベル (0 ~ 9) を返します。

In [1]:
import torch
import torchvision
import torchvision.transforms as T
from IPython.display import display


def imshow(img):
    img = T.functional.to_pil_image(img)
    display(img)


train_data = torchvision.datasets.MNIST(root="/data", train=True, transform=T.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64)

# 1バッチ取得する。
imgs, _ = next(iter(train_dataloader))

# グリッド上に並べて1枚の画像にする。
img = torchvision.utils.make_grid(imgs)
imshow(img)

クラスごとの枚数を計算する

In [2]:
import torch
import torchvision
import torchvision.transforms as T

train_data = torchvision.datasets.MNIST(root="/data", train=True, transform=T.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64)

n_imgs = torch.bincount(train_data.targets)
for label, n_label_imgs in zip(train_data.classes, n_imgs):
    print(f"{label}: {n_label_imgs}")
0 - zero: 5923
1 - one: 6742
2 - two: 5958
3 - three: 6131
4 - four: 5842
5 - five: 5421
6 - six: 5918
7 - seven: 6265
8 - eight: 5851
9 - nine: 5949

コメント

コメントする

目次