Pytorch – データセットを学習用、テスト用に分割する方法

目次

概要

Pytorch である Dataset を分割し、学習用、テスト用の Dataset を作成する方法について解説します。

Dataset の分割

以下のように学習用、テスト用で最初からデータが別れている場合はそれぞれ Dataset を作成すればよいですが、別れていない場合はコード上で学習用、テスト用にそれぞれ分割する必要があります。

Dataset
├── train
└── test

ランダムに2つのデータセットに分割する

以下のようなディレクトリ構成のクラス分類のデータセットを例に説明しますが、他のデータセットでも同様です。

my_dataset
├── A: 学習用のクラス A の画像があるディレクトリ
├── B: 学習用のクラス B の画像があるディレクトリ
└── C: 学習用のクラス C の画像があるディレクトリ

torch.utils.data.random_split() を使用すると、データセットを指定した数ごとに重複がないようにランダムに分割できます。

torch.utils.data.random_split(dataset, lengths)
In [1]:
import torch
import torchvision.datasets as datasets

# 元のデータセット
dataset_dir = "/data/hymenoptera_data/train"
full_dataset = datasets.ImageFolder(dataset_dir)

# 学習データ、検証データに 8:2 の割合で分割する。
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size]
)

print(f"full: {len(full_dataset)} -> train: {len(train_dataset)}, test: {len(val_dataset)}")
full: 244 -> train: 195, test: 49

インデックスを指定して2つのデータセットに分割する

torch.utils.data.Subset(dataset, indices) を使用すると、元のデータセットから指定したインデックスだけ使用するデータセットを作成できます。学習及びテストに使用するインデックスを予め作成しておくことで、データセットを分割できます。

torch.utils.data.Subset(dataset, indices)
In [2]:
import numpy as np
import torch
import torchvision.datasets as datasets

# 元のデータセット
dataset_dir = "/data/hymenoptera_data/train"
dataset = datasets.ImageFolder(dataset_dir)

# 学習データ、検証データに 8:2 の割合で分割する。
train_size = int(0.8 * len(full_dataset))
indices = np.arange(len(full_dataset))

train_dataset = torch.utils.data.Subset(dataset, indices[:train_size])
val_dataset = torch.utils.data.Subset(dataset, indices[train_size:])

print(f"full: {len(dataset)} -> train: {len(train_dataset)}, test: {len(val_dataset)}")
full: 244 -> train: 195, test: 49

使用する Transform を学習用とテスト用で別々にする

上記のやり方の場合、元のデータセットで指定されている Transform が適用されるため、学習用とテスト用で共通の Transform になってしまいます。学習用とテスト用で別の Transform を適用したい場合はコンストラクタ引数に transform を追加した Subset を自作します。

In [3]:
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

class MySubset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        img, label = self.dataset[self.indices[idx]]
        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.indices)


# 元のデータセット
dataset_dir = "/data/hymenoptera_data/train"
dataset = datasets.ImageFolder(dataset_dir)

# 学習データ、検証データに 8:2 の割合で分割する。
train_size = int(0.8 * len(full_dataset))
indices = np.arange(len(full_dataset))

train_dataset = MySubset(dataset, indices[:train_size], train_transform)
val_dataset = MySubset(dataset, indices[train_size:], val_transform)

print(f"full: {len(dataset)} -> train: {len(train_dataset)}, test: {len(val_dataset)}")
full: 244 -> train: 195, test: 49

コメント

コメントする

目次