目次
概要
MNIST データセットについて解説します。
MNIST データセット
MNIST は Modified National Institute of Standards and Technology の訳で、アメリカ国立標準技術研究所 (NIST) が保有している手書き文字画像のデータを修正して作成したデータセットです。
- データセットの種類: クラス分類
- 画像枚数: 学習60000、テスト10000
- 画像: 28×28 のグレースケール画像
Pytorch の MNIST データセット
torchvision.datasets.MNIST
で提供されています。このデータセットは (28, 28, 1) の画像及び正解ラベル (0 ~ 9) を返します。
In [1]:
import torch
import torchvision
import torchvision.transforms as T
from IPython.display import display
def imshow(img):
img = T.functional.to_pil_image(img)
display(img)
train_data = torchvision.datasets.MNIST(root="/data", train=True, transform=T.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64)
# 1バッチ取得する。
imgs, _ = next(iter(train_dataloader))
# グリッド上に並べて1枚の画像にする。
img = torchvision.utils.make_grid(imgs)
imshow(img)
クラスごとの枚数を計算する
In [2]:
import torch
import torchvision
import torchvision.transforms as T
train_data = torchvision.datasets.MNIST(root="/data", train=True, transform=T.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64)
n_imgs = torch.bincount(train_data.targets)
for label, n_label_imgs in zip(train_data.classes, n_imgs):
print(f"{label}: {n_label_imgs}")
0 - zero: 5923 1 - one: 6742 2 - two: 5958 3 - three: 6131 4 - four: 5842 5 - five: 5421 6 - six: 5918 7 - seven: 6265 8 - eight: 5851 9 - nine: 5949
コメント