概要
torch.utils.data,DataLoader
DataLoader は、Dataset からサンプルを取得して、ミニバッチを作成するクラスです。基本的には、サンプルを取得する Dataset とバッチサイズを指定して作成します。DataLoader は、iterate するとミニバッチを返すようになっています。
- dataset: データセット
- batch_size: バッチサイズ
- shuffle: シャッフルするかどうか
- sampler: サンプラー
- batch_sampler: サンプラー
- num_workers: 並列実行数
- collate_fn:ミニバッチ作成前に使用されるコールバック関数
- pin_memory: pinned memory を使用するかどうか
- drop_last: 最後の余りのミニバッチは切り捨てるかどうか
- timeout: ミニバッチを作成する時間制限
- worker_init_fn: ミニバッチ作成前に呼ばれるコールバック関数
Dataset – データセット
Dataset は、データセットを表すクラスで、サンプルを要求されたときに返す処理を定義します。Dataset は次の2種類があります。
- map-style Dataset: キー (通常はインデックス) を渡して、それに対応するデータを返す Dataset です。キーに対応するサンプルを返す
__getitem__()
及びデータセットのサンプル数を返す__len__()
が実装されている必要があります。 - iterable-style Dataset: 逐次データを返す iterable な Dataset です。
__iter__()
が実装されている必要があります。
map-style Dataset を自作する
map-style Dataset を自作する場合、Dataset を継承したクラスを作成し、キーが渡されたときにそれに対応するサンプルを返す __getitem__()
及びデータセットのサンプル数を返す __len__()
を実装します。
指定されたディレクトリから画像を読み込んで返す Dataset を作成してみます。
これを使って data
ディレクトリに画像ファイルがあるとき、以下のようにサンプルを取得できます。
<class 'torch.Tensor'> torch.Size([3, 224, 224]) <class 'torch.Tensor'> torch.Size([3, 224, 224]) <class 'torch.Tensor'> torch.Size([3, 224, 224]) <class 'torch.Tensor'> torch.Size([3, 224, 224]) <class 'torch.Tensor'> torch.Size([3, 224, 224])
batchsize
DataLoader が返すミニバッチのサイズを設定します。
batchsize=None
とした場合、ミニバッチの代わりにサンプル1つを返します。この場合、バッチ次元はありません。
batchsize
に1以上の整数を指定した場合、複数のサンプルから作成したミニバッチを返します。
batchsize=None torch.Size([100, 100, 3]) tensor(0) batchsize=int torch.Size([3, 100, 100, 3]) torch.Size([3])
shuffle – シャッフルするかどうか
shuffle=True
の場合、シャッフルした順番で返します。
shuffle=False tensor([0, 1, 2, 3, 4]) tensor([5, 6, 7, 8, 9]) shuffle=True tensor([7, 0, 5, 4, 1]) tensor([2, 8, 9, 6, 3])
sampler – 次に読み込むサンプルのキーを返す
Sampler は、map-style データセットの場合に、サンプルのキーを返す iterable なクラスで、データセットから読み込む順序を規定します。デフォルトの Sampler がいくつか用意されており、例えば、DataLoader()
で shuffle=False
を指定した場合、データセットから順番に読み込む SequentialSampler、shuffle=True
を指定した場合、重複なしでランダムな順番で読み込む RandomSampler が使用されます。
大抵の場合、上記2つで事足りるので自作する必要性はあまりないですが、torch.utils.data.SequentialSampler の実装を見てみます。
0 1 2 3 4 5 6 7 8 9
BatchSampler – ミニバッチ作成に使用するサンプルのキー一覧を返す
BatchSampler は、map-style データセットの場合に、ミニバッチに使用するサンプルのキー一覧を返す iterable なクラスです。torch.utils.data.BatchSampler の実装を見てみます。
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9]
num_workers – ミニバッチを作成する際の並列実行数
num_workers
でミニバッチを作成する際の並列実行数を指定できます。
最大で CPU の論理スレッド数分の高速化が期待できます。
以下は、num_workers
を変化させたとき、読み込み1枚あたりにかかる時間をグラフ化したものです。
使用した Core i7-6700K は、論理スレッドが8なので、8までは増やすほど読み込みが高速化することが確認できました。
- サンプル数: 12894
- バッチサイズ: 128
- CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz
![](/wp/wp-content/uploads/2020/07/pytorch-dataloader_01.jpg)
collate_fn – サンプルのリストを集計してミニバッチを作成する関数
collate_fn
は Dataset から取得した複数のサンプルを結合して、1つのミニバッチを作成する処理を行う関数です。
デフォルトでは、default_collate
が使用されます。default_collate() では渡される batch の各要素が以下の場合に対応しています。
- Tensor
- ndarray (dtype = “S”, “a”, “U”, “O” を除く)
- float
- int
- str
- bytes
- Mapping (dict など)
- namedtuple
- Sequence (list, tuple など)
default_collate の実装を見ています。
各要素が numpy 配列の場合
elem_type.__name__: ndarray elem_type.__name__: Tensor output.shape: torch.Size([5, 10, 10]), output.dtype: torch.float64 elem_type.__name__: ndarray elem_type.__name__: Tensor output.shape: torch.Size([5]), output.dtype: torch.int64
int, float の場合
elem_type.__name__: int output.shape: torch.Size([5]), output.dtype: torch.int64 elem_type.__name__: float output.shape: torch.Size([5]), output.dtype: torch.float64
str の場合
elem_type.__name__: str ['hoge', 'hoge', 'hoge', 'hoge', 'hoge']
Mapping の場合 (dict など)
elem_type.__name__: dict elem_type.__name__: ndarray elem_type.__name__: Tensor elem_type.__name__: str elem_type.__name__: int key: A, value.shape: torch.Size([5, 10, 10]), value.dtype: torch.float64 key: B, len(value): 5, type(value[0]): <class 'str'> key: C, value.shape: torch.Size([5]), value.dtype: torch.int64
namedtuple の場合
elem_type.__name__: sample elem_type.__name__: ndarray elem_type.__name__: Tensor elem_type.__name__: str elem_type.__name__: int key: A, value.shape: torch.Size([5, 10, 10]), value.dtype: torch.float64 key: B, len(value): 5, type(value[0]): <class 'str'> key: C, value.shape: torch.Size([5]), value.dtype: torch.int64
Sequence の場合
elem_type.__name__: list elem_type.__name__: str elem_type.__name__: int elem_type.__name__: ndarray elem_type.__name__: Tensor len(value): 5, type(value[0]): <class 'str'> value.shape: torch.Size([5]), value.dtype: torch.int64 value.shape: torch.Size([5, 2, 2]), value.dtype: torch.float64 elem_type.__name__: tuple elem_type.__name__: str elem_type.__name__: int elem_type.__name__: ndarray elem_type.__name__: Tensor len(value): 5, type(value[0]): <class 'str'> value.shape: torch.Size([5]), value.dtype: torch.int64 value.shape: torch.Size([5, 2, 2]), value.dtype: torch.float64
drop_last – 最後の余りのミニバッチは切り捨てる
drop_last=True
を指定した場合、端数となってしまった最後のミニバッチは切り捨てます。
例えば、500 サンプルをバッチサイズ 128 で読み込んだ場合、128, 128, 128, 116 と最後だけ端数となってしまいますが、
drop_last=True
の場合はこの最後の 116 は切り捨てるようになります。
number of samples: 501 drop_last=False torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([117, 3, 224, 224]) torch.Size([117]) drop_last=True torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([128, 3, 224, 224]) torch.Size([128])
コメント