概要
Pytorch で自作のデータセットを扱うには、Dataset クラスを継承したクラスを作成する必要があります。本記事では、そのやり方について説明します。
Dataset
Dataset クラスでは、画像や csv ファイルといったリソースで構成されるデータセットからデータを取得する方法について定義します。基本的にはインデックス index
のサンプルが要求されたときに返す __getitem__(self, index)
とデータセットのサンプル数が要求されたときに返す __len__(self)
の2つを実装します。
指定したディレクトリから画像を読み込む Dataset
指定したディレクトリ以下の画像を読み込むデータセットの例です。画像のクラス分類モデルを使って、ディレクトリ内の画像に対して推論を行いたい場合に使います。
torch.Size([3, 3, 256, 256]) torch.Size([2, 3, 256, 256])
指定した動画ファイルからフレームを読み込む Dataset
指定した動画ファイルからフレームを読み込むデータセットの例です。画像のクラス分類モデルを使って、動画ファイルのフレームに対して推論を行いたい場合に使います。
torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([27, 3, 256, 341])
指定した csv ファイルからデータを読み込む Dataset
指定した csv ファイルからデータを読み込むデータセットの例です。クラス分類モデルを使って、数値データに対して学習を行いたい場合に使います。
サンプルとして wine.csv を使います。この CSV ファイルは14列あり、1列目がラベル、2~14列目が特徴量となっています。
torch.Size([64, 13]) torch.Size([64]) torch.Size([64, 13]) torch.Size([64]) torch.Size([50, 13]) torch.Size([50])
ImageFolder – 画像のクラス分類の学習用のデータセット
画像のクラス分類の学習を行う際にデータセットがクラスごとにディレクトリに分けられている構造の場合は、ImageFolder を利用できます。このデータセットは、サンプルが要求されると、データである画像及びラベルであるクラス ID を返します。
データセットのディレクトリ構造の例
サブディレクトリの名前がクラス名となります。クラス ID はクラス名を辞書順ソートして、0, 1, … と整数が割り振られます。上記の例では、以下のようになります。
クラス ID | クラス名 |
---|---|
0 | class1 |
1 | class2 |
2 | class3 |
ImageFolder.class_to_idx
属性でクラス名とクラス ID の対応関係を取得できます。
- 引数
- root (str) – データセットのルートディレクトリ
- transform (callable) – データ用の Transform
- target_transform (callable) – ラベル用の Transform
- loader (callable) – 画像を読み込む関数
- is_valid_file (callable) – 画像が破損していないかどうかをチェックする関数
{'class1': 0, 'class2': 1, 'class3': 2} torch.Size([3, 3, 256, 352]) torch.Size([3]) torch.Size([3, 3, 256, 352]) torch.Size([3]) torch.Size([3, 3, 256, 352]) torch.Size([3])
DatasetFolder – 画像以外のクラス分類の学習用のデータセット
画像以外のクラス分類の学習を行う際にデータセットがクラスごとにディレクトリに分けられている構造の場合は、DatasetFolder を利用できます。使い方は ImageFolder と同じです。
データセットのディレクトリ構造の例
- 引数
- root (str) – データセットのルートディレクトリ
- loader (callable) – データを読み込む関数
- extensions (tuple of strings) – 読み込むファイルの拡張子一覧 (extensions と is_valid_file のどちらか一方のみを指定)
- transform (callable) – データ用の Transform
- target_transform (callable) – ラベル用の Transform
- is_valid_file (callable) – 読み込むファイルかどうかをチェックする関数 (extensions と is_valid_file のどちらか一方のみを指定)
コメント
コメント一覧 (0件)
全然わからないのですが、パスはどうやって指定すればよいのですか?
コメントありがとうございます。
例えば、「指定したディレクトリから画像を読み込む Dataset」の項で紹介しているコードの場合、ImageFolder() の第一引数に画像があるディレクトリのパスを相対または絶対パスで指定することを想定しています。
“`
dataset = ImageFolder(<画像があるディレクトリのパス>, transform)
“`
もしよろしければ、やりたいことをコメントしていただけたら、コード例など具体的なアドバイスができるかもしれません。