概要
Pytorch である Dataset を分割し、学習用、テスト用の Dataset を作成する方法について解説します。
Dataset の分割
以下のように学習用、テスト用で最初からデータが別れている場合はそれぞれ Dataset を作成すればよいですが、別れていない場合はコード上で学習用、テスト用にそれぞれ分割する必要があります。
ランダムに2つのデータセットに分割する
以下のようなディレクトリ構成のクラス分類のデータセットを例に説明しますが、他のデータセットでも同様です。
torch.utils.data.random_split()
を使用すると、データセットを指定した数ごとに重複がないようにランダムに分割できます。
In [1]:
full: 244 -> train: 195, test: 49
インデックスを指定して2つのデータセットに分割する
torch.utils.data.Subset(dataset, indices)
を使用すると、元のデータセットから指定したインデックスだけ使用するデータセットを作成できます。学習及びテストに使用するインデックスを予め作成しておくことで、データセットを分割できます。
In [2]:
full: 244 -> train: 195, test: 49
使用する Transform を学習用とテスト用で別々にする
上記のやり方の場合、元のデータセットで指定されている Transform が適用されるため、学習用とテスト用で共通の Transform になってしまいます。学習用とテスト用で別の Transform を適用したい場合はコンストラクタ引数に transform を追加した Subset を自作します。
In [3]:
full: 244 -> train: 195, test: 49
コメント