Pytorch – DataLoader の使い方について解説

目次

概要

torch.utils.data,DataLoader

DataLoader は、Dataset からサンプルを取得して、ミニバッチを作成するクラスです。基本的には、サンプルを取得する Dataset とバッチサイズを指定して作成します。DataLoader は、iterate するとミニバッチを返すようになっています。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
  • dataset: データセット
  • batch_size: バッチサイズ
  • shuffle: シャッフルするかどうか
  • sampler: サンプラー
  • batch_sampler: サンプラー
  • num_workers: 並列実行数
  • collate_fn:ミニバッチ作成前に使用されるコールバック関数
  • pin_memory: pinned memory を使用するかどうか
  • drop_last: 最後の余りのミニバッチは切り捨てるかどうか
  • timeout: ミニバッチを作成する時間制限
  • worker_init_fn: ミニバッチ作成前に呼ばれるコールバック関数

Dataset – データセット

Dataset は、データセットを表すクラスで、サンプルを要求されたときに返す処理を定義します。Dataset は次の2種類があります。

  • map-style Dataset: キー (通常はインデックス) を渡して、それに対応するデータを返す Dataset です。キーに対応するサンプルを返す __getitem__() 及びデータセットのサンプル数を返す __len__() が実装されている必要があります。
  • iterable-style Dataset: 逐次データを返す iterable な Dataset です。__iter__() が実装されている必要があります。

map-style Dataset を自作する

map-style Dataset を自作する場合、Dataset を継承したクラスを作成し、キーが渡されたときにそれに対応するサンプルを返す __getitem__() 及びデータセットのサンプル数を返す __len__() を実装します。

from torch.utils.data import Dataset

class ImageFolder(Dataset):
    def __getitem__(self, index):
        # 1つのサンプルを返す処理を書く

    def __len__(self):
        # データセットのサンプル数を返す処理を書く

指定されたディレクトリから画像を読み込んで返す Dataset を作成してみます。

In [1]:
def _get_img_paths(img_dir):
    img_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    img_paths = [p for p in Path(img_dir).iterdir() if p.suffix in img_extensions]

    return img_paths


class ImageFolder(data.Dataset):
    def __init__(self, img_dir, transform=None):
        # 画像ファイルのパス一覧を取得する。
        self.img_paths = _get_img_paths(img_dir)
        self.transform = transform

    def __getitem__(self, index):
        # インデックスに対応するファイルパスを取得する。
        path = self.img_paths[index]
        # 画像を読み込む。
        img = Image.open(path)
        # Transforms で変換する。
        if transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        # ディレクトリ内の画像枚数を返す。
        return len(self.img_paths)

これを使って data ディレクトリに画像ファイルがあるとき、以下のようにサンプルを取得できます。

In [2]:
transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)

dataset = ImageFolder("data", transform)

# データを取得する。
for i in range(len(dataset)):
    img = dataset[i]
    print(type(img), img.shape)
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([3, 224, 224])
In [3]:
import numpy as np
from torch.utils import data as data


class MyDataset(data.Dataset):
    def __init__(self, n):
        self.data = np.random.rand(n, 100, 100, 3)
        self.labels = np.arange(n)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = MyDataset(10)

batchsize

DataLoader が返すミニバッチのサイズを設定します。 batchsize=None とした場合、ミニバッチの代わりにサンプル1つを返します。この場合、バッチ次元はありません。 batchsize に1以上の整数を指定した場合、複数のサンプルから作成したミニバッチを返します。

In [4]:
print("batchsize=None")
dataloader = data.DataLoader(dataset, batch_size=None)
X, y = next(iter(dataloader))
print(X.shape, y)

print("batchsize=int")
dataloader = data.DataLoader(dataset, batch_size=3)
X, y = next(iter(dataloader))
print(X.shape, y.shape)
batchsize=None
torch.Size([100, 100, 3]) tensor(0)
batchsize=int
torch.Size([3, 100, 100, 3]) torch.Size([3])

shuffle – シャッフルするかどうか

shuffle=True の場合、シャッフルした順番で返します。

In [5]:
print("shuffle=False")
dataloader = data.DataLoader(dataset, batch_size=5)
for X, y in dataloader:
    print(y, end=" ")

print("\nshuffle=True")
dataloader = data.DataLoader(dataset, batch_size=5, shuffle=True)
for X, y in dataloader:
    print(y, end=" ")
shuffle=False
tensor([0, 1, 2, 3, 4]) tensor([5, 6, 7, 8, 9]) 
shuffle=True
tensor([7, 0, 5, 4, 1]) tensor([2, 8, 9, 6, 3]) 

sampler – 次に読み込むサンプルのキーを返す

Sampler は、map-style データセットの場合に、サンプルのキーを返す iterable なクラスで、データセットから読み込む順序を規定します。デフォルトの Sampler がいくつか用意されており、例えば、DataLoader()shuffle=False を指定した場合、データセットから順番に読み込む SequentialSamplershuffle=True を指定した場合、重複なしでランダムな順番で読み込む RandomSampler が使用されます。 大抵の場合、上記2つで事足りるので自作する必要性はあまりないですが、torch.utils.data.SequentialSampler の実装を見てみます。

In [6]:
class SequentialSampler(data.Sampler):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(range(len(self.dataset)))


sampler = SequentialSampler(dataset)
for i in sampler:  # iterate すると、読み込むサンプルのキーを返す
    print(i, end=" ")
0 1 2 3 4 5 6 7 8 9 

BatchSampler – ミニバッチ作成に使用するサンプルのキー一覧を返す

BatchSampler は、map-style データセットの場合に、ミニバッチに使用するサンプルのキー一覧を返す iterable なクラスです。torch.utils.data.BatchSampler の実装を見てみます。

In [7]:
class BatchSampler(data.Sampler):
    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch


batch_sampler = BatchSampler(sampler, 3)
for batch_indices in batch_sampler:
    print(batch_indices)
[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]

num_workers – ミニバッチを作成する際の並列実行数

num_workers でミニバッチを作成する際の並列実行数を指定できます。 最大で CPU の論理スレッド数分の高速化が期待できます。

以下は、num_workers を変化させたとき、読み込み1枚あたりにかかる時間をグラフ化したものです。 使用した Core i7-6700K は、論理スレッドが8なので、8までは増やすほど読み込みが高速化することが確認できました。

  • サンプル数: 12894
  • バッチサイズ: 128
  • CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz

collate_fn – サンプルのリストを集計してミニバッチを作成する関数

collate_fn は Dataset から取得した複数のサンプルを結合して、1つのミニバッチを作成する処理を行う関数です。 デフォルトでは、default_collate が使用されます。default_collate() では渡される batch の各要素が以下の場合に対応しています。

  • Tensor
  • ndarray (dtype = “S”, “a”, “U”, “O” を除く)
  • float
  • int
  • str
  • bytes
  • Mapping (dict など)
  • namedtuple
  • Sequence (list, tuple など)

default_collate の実装を見ています。

In [8]:
import collections
import re

import numpy as np
import torch

np_str_obj_array_pattern = re.compile(r"[SaUO]")

default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def default_collate(batch):
    elem = batch[0]
    elem_type = type(elem)

    print(f"elem_type.__name__: {elem_type.__name__}")

    if isinstance(elem, torch.Tensor):
        # Tensor の場合はそのまま結合
        return torch.stack(batch, 0)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # numpy 配列の場合
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            # スカラー以外の場合、テンソルに変換して、default_collate() にもう一度渡す。
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():
            # スカラーの場合、そのままテンソルにする。
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        # float の場合、そのままテンソルにする。
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        # int の場合、そのままテンソルにする。
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # str, bytes の場合、そのまま返す。
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        # Mapping の場合 (dict など)、キーごとにテンソル化する。
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
        # namedtuple の場合
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # sequence の場合
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

各要素が numpy 配列の場合

In [9]:
# 各要素が (10, 10) の numpy 配列
batch = [np.random.rand(10, 10) for i in range(5)]
output = default_collate(batch)
print(f"output.shape: {output.shape}, output.dtype: {output.dtype}")

# 各要素が () の numpy 配列 (スカラー)
batch = [np.array(5) for i in range(5)]
output = default_collate(batch)
print(f"output.shape: {output.shape}, output.dtype: {output.dtype}")
elem_type.__name__: ndarray
elem_type.__name__: Tensor
output.shape: torch.Size([5, 10, 10]), output.dtype: torch.float64
elem_type.__name__: ndarray
elem_type.__name__: Tensor
output.shape: torch.Size([5]), output.dtype: torch.int64

int, float の場合

In [10]:
# 各要素が int
batch = [1 for i in range(5)]
output = default_collate(batch)
print(f"output.shape: {output.shape}, output.dtype: {output.dtype}")

# 各要素が () の numpy 配列 (スカラー)
batch = [1.1 for i in range(5)]
output = default_collate(batch)
print(f"output.shape: {output.shape}, output.dtype: {output.dtype}")
elem_type.__name__: int
output.shape: torch.Size([5]), output.dtype: torch.int64
elem_type.__name__: float
output.shape: torch.Size([5]), output.dtype: torch.float64

str の場合

In [11]:
# 各要素が str
batch = ["hoge" for i in range(5)]
output = default_collate(batch)
print(output)
elem_type.__name__: str
['hoge', 'hoge', 'hoge', 'hoge', 'hoge']

Mapping の場合 (dict など)

In [12]:
# 各要素が str
batch = [{"A": np.random.rand(10, 10), "B": "hoge", "C": 1} for i in range(5)]
output = default_collate(batch)

for key, value in output.items():
    if isinstance(value, list):
        print(f"key: {key}, len(value): {len(value)}, type(value[0]): {type(value[0])}")
    else:
        print(f"key: {key}, value.shape: {value.shape}, value.dtype: {value.dtype}")
elem_type.__name__: dict
elem_type.__name__: ndarray
elem_type.__name__: Tensor
elem_type.__name__: str
elem_type.__name__: int
key: A, value.shape: torch.Size([5, 10, 10]), value.dtype: torch.float64
key: B, len(value): 5, type(value[0]): <class 'str'>
key: C, value.shape: torch.Size([5]), value.dtype: torch.int64

namedtuple の場合

In [13]:
from collections import namedtuple

Sample = namedtuple("sample", ["A", "B", "C"])

# 各要素が str
batch = [Sample(np.random.rand(10, 10), "hoge", 1) for i in range(5)]
output = default_collate(batch)

for key, value in output._asdict().items():
    if isinstance(value, torch.Tensor):
        print(f"key: {key}, value.shape: {value.shape}, value.dtype: {value.dtype}")
    else:
        print(f"key: {key}, len(value): {len(value)}, type(value[0]): {type(value[0])}")
elem_type.__name__: sample
elem_type.__name__: ndarray
elem_type.__name__: Tensor
elem_type.__name__: str
elem_type.__name__: int
key: A, value.shape: torch.Size([5, 10, 10]), value.dtype: torch.float64
key: B, len(value): 5, type(value[0]): <class 'str'>
key: C, value.shape: torch.Size([5]), value.dtype: torch.int64

Sequence の場合

In [14]:
# 各要素が list
batch = [["hoge", 2, np.random.rand(2, 2)] for i in range(5)]
output = default_collate(batch)
for value in output:
    if isinstance(value, torch.Tensor):
        print(f"value.shape: {value.shape}, value.dtype: {value.dtype}")
    else:
        print(f"len(value): {len(value)}, type(value[0]): {type(value[0])}")

# 各要素が tuple
batch = [("hoge", 2, np.random.rand(2, 2)) for i in range(5)]
output = default_collate(batch)
for value in output:
    if isinstance(value, torch.Tensor):
        print(f"value.shape: {value.shape}, value.dtype: {value.dtype}")
    else:
        print(f"len(value): {len(value)}, type(value[0]): {type(value[0])}")
elem_type.__name__: list
elem_type.__name__: str
elem_type.__name__: int
elem_type.__name__: ndarray
elem_type.__name__: Tensor
len(value): 5, type(value[0]): <class 'str'>
value.shape: torch.Size([5]), value.dtype: torch.int64
value.shape: torch.Size([5, 2, 2]), value.dtype: torch.float64
elem_type.__name__: tuple
elem_type.__name__: str
elem_type.__name__: int
elem_type.__name__: ndarray
elem_type.__name__: Tensor
len(value): 5, type(value[0]): <class 'str'>
value.shape: torch.Size([5]), value.dtype: torch.int64
value.shape: torch.Size([5, 2, 2]), value.dtype: torch.float64

drop_last – 最後の余りのミニバッチは切り捨てる

drop_last=True を指定した場合、端数となってしまった最後のミニバッチは切り捨てます。 例えば、500 サンプルをバッチサイズ 128 で読み込んだ場合、128, 128, 128, 116 と最後だけ端数となってしまいますが、 drop_last=True の場合はこの最後の 116 は切り捨てるようになります。

In [15]:
print(f"number of samples: {len(dataset)}")

dataloader = data.DataLoader(dataset, batch_size=128, drop_last=False)
print("drop_last=False")
for X, y in dataloader:
    print(X.shape, y.shape)

dataloader = data.DataLoader(dataset, batch_size=128, drop_last=True)
print("drop_last=True")
for X, y in dataloader:
    print(X.shape, y.shape)
number of samples: 501
drop_last=False
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([117, 3, 224, 224]) torch.Size([117])
drop_last=True
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])

コメント

コメントする

目次