概要
Pytorch で自作のデータセットを扱うには、Dataset クラスを継承したクラスを作成する必要があります。本記事では、そのやり方について説明します。
Dataset
Dataset クラスでは、画像や csv ファイルといったリソースで構成されるデータセットからデータを取得する方法について定義します。基本的にはインデックス index
のサンプルが要求されたときに返す __getitem__(self, index)
とデータセットのサンプル数が要求されたときに返す __len__(self)
の2つを実装します。
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __getitem__(self, index):
# インデックス index のサンプルが要求されたときに返す処理を実装
def __len__(self):
# データセットのサンプル数が要求されたときに返す処理を実装
指定したディレクトリから画像を読み込む Dataset
指定したディレクトリ以下の画像を読み込むデータセットの例です。画像のクラス分類モデルを使って、ディレクトリ内の画像に対して推論を行いたい場合に使います。
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
class ImageFolder(Dataset):
IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp"]
def __init__(self, img_dir, transform=None):
# 画像ファイルのパス一覧を取得する。
self.img_paths = self._get_img_paths(img_dir)
self.transform = transform
def __getitem__(self, index):
path = self.img_paths[index]
# 画像を読み込む。
img = Image.open(path)
if self.transform is not None:
# 前処理がある場合は行う。
img = self.transform(img)
return img
def _get_img_paths(self, img_dir):
"""指定したディレクトリ内の画像ファイルのパス一覧を取得する。
"""
img_dir = Path(img_dir)
img_paths = [
p for p in img_dir.iterdir() if p.suffix in ImageFolder.IMG_EXTENSIONS
]
return img_paths
def __len__(self):
"""ディレクトリ内の画像ファイルの数を返す。
"""
return len(self.img_paths)
# Transform を作成する。
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
# Dataset を作成する。
dataset = ImageFolder("data", transform)
# DataLoader を作成する。
dataloader = DataLoader(dataset, batch_size=3)
for batch in dataloader:
print(batch.shape)
torch.Size([3, 3, 256, 256]) torch.Size([2, 3, 256, 256])
指定した動画ファイルからフレームを読み込む Dataset
指定した動画ファイルからフレームを読み込むデータセットの例です。画像のクラス分類モデルを使って、動画ファイルのフレームに対して推論を行いたい場合に使います。
from pathlib import Path
import cv2
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class Video(Dataset):
def __init__(self, video_path, transform=None):
self.cap = cv2.VideoCapture(video_path)
self.transform = transform
def __getitem__(self, index):
"""動画のフレームを返す。
"""
# フレームを読み込む。
ret, img = self.cap.read()
# チャンネルの順番を変更する。 (B, G, R) -> (R, G, B)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# numpy 配列を PIL Image に変換する。
img = Image.fromarray(img)
if self.transform is not None:
# 前処理がある場合は行う。
img = self.transform(img)
return img
def __len__(self):
"""動画のフレーム数を返す。
"""
return int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Transform を作成する。
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
# Dataset を作成する。
dataset = Video("data/sample.avi", transform)
# DataLoader を作成する。
dataloader = DataLoader(dataset, batch_size=128)
for batch in dataloader:
print(batch.shape)
torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([128, 3, 256, 341]) torch.Size([27, 3, 256, 341])
指定した csv ファイルからデータを読み込む Dataset
指定した csv ファイルからデータを読み込むデータセットの例です。クラス分類モデルを使って、数値データに対して学習を行いたい場合に使います。
サンプルとして wine.csv を使います。この CSV ファイルは14列あり、1列目がラベル、2~14列目が特徴量となっています。
import pandas as pd
import torch
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, Dataset
class Wine(Dataset):
def __init__(self, csv_path):
# csv ファイルを読み込む。
df = pd.read_csv(csv_path)
data = df.iloc[:, 1:] # データ (2 ~ 14列目)
labels = df.iloc[:, 0] # ラベル (1列目)
# データを標準化する。
data = normalize(data)
# クラス ID を 0 始まりにする。[1, 2, 3] -> [0, 1, 2]
labels -= 1
self.data = data
self.labels = labels
def __getitem__(self, index):
"""サンプルを返す。
"""
return self.data[index], self.labels[index]
def __len__(self):
"""csv の行数を返す。
"""
return len(self.data)
# Dataset を作成する。
dataset = Wine("https://git.io/JfodD")
# DataLoader を作成する。
dataloader = DataLoader(dataset, batch_size=64)
for X_batch, y_batch in dataloader:
print(X_batch.shape, y_batch.shape)
torch.Size([64, 13]) torch.Size([64]) torch.Size([64, 13]) torch.Size([64]) torch.Size([50, 13]) torch.Size([50])
ImageFolder – 画像のクラス分類の学習用のデータセット
画像のクラス分類の学習を行う際にデータセットがクラスごとにディレクトリに分けられている構造の場合は、ImageFolder を利用できます。このデータセットは、サンプルが要求されると、データである画像及びラベルであるクラス ID を返します。
データセットのディレクトリ構造の例
dataset1
├── class1: クラス1の画像があるディレクトリ
│ ├── a.jpg
│ ├── b.jpg
│ └── c.jpg
├── class2: クラス2の画像があるディレクトリ
│ ├── a.jpg
│ ├── b.jpg
│ └── c.jpg
└── class3: クラス3の画像があるディレクトリ
├── a.jpg
├── b.jpg
└── c.jpg
サブディレクトリの名前がクラス名となります。クラス ID はクラス名を辞書順ソートして、0, 1, … と整数が割り振られます。上記の例では、以下のようになります。
クラス ID | クラス名 |
---|---|
0 | class1 |
1 | class2 |
2 | class3 |
ImageFolder.class_to_idx
属性でクラス名とクラス ID の対応関係を取得できます。
ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
- 引数
- root (str) – データセットのルートディレクトリ
- transform (callable) – データ用の Transform
- target_transform (callable) – ラベル用の Transform
- loader (callable) – 画像を読み込む関数
- is_valid_file (callable) – 画像が破損していないかどうかをチェックする関数
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
# Transform を作成する。
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
# Dataset を作成する。
dataset = ImageFolder("dataset1", transform)
print(dataset.class_to_idx)
# DataLoader を作成する。
dataloader = DataLoader(dataset, batch_size=3)
for X_batch, y_batch in dataloader:
print(X_batch.shape, y_batch.shape)
{'class1': 0, 'class2': 1, 'class3': 2} torch.Size([3, 3, 256, 352]) torch.Size([3]) torch.Size([3, 3, 256, 352]) torch.Size([3]) torch.Size([3, 3, 256, 352]) torch.Size([3])
DatasetFolder – 画像以外のクラス分類の学習用のデータセット
画像以外のクラス分類の学習を行う際にデータセットがクラスごとにディレクトリに分けられている構造の場合は、DatasetFolder を利用できます。使い方は ImageFolder と同じです。
データセットのディレクトリ構造の例
dataset1
├── class1: クラス1のデータがあるディレクトリ
│ ├── a.ext
│ ├── b.ext
│ └── c.ext
├── class2: クラス2のデータがあるディレクトリ
│ ├── a.ext
│ ├── b.ext
│ └── c.ext
└── class3: クラス3のデータがあるディレクトリ
├── a.ext
├── b.ext
└── c.ext
DatasetFolder(root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None)
- 引数
- root (str) – データセットのルートディレクトリ
- loader (callable) – データを読み込む関数
- extensions (tuple of strings) – 読み込むファイルの拡張子一覧 (extensions と is_valid_file のどちらか一方のみを指定)
- transform (callable) – データ用の Transform
- target_transform (callable) – ラベル用の Transform
- is_valid_file (callable) – 読み込むファイルかどうかをチェックする関数 (extensions と is_valid_file のどちらか一方のみを指定)
コメント
コメント一覧 (0件)
全然わからないのですが、パスはどうやって指定すればよいのですか?
コメントありがとうございます。
例えば、「指定したディレクトリから画像を読み込む Dataset」の項で紹介しているコードの場合、ImageFolder() の第一引数に画像があるディレクトリのパスを相対または絶対パスで指定することを想定しています。
“`
dataset = ImageFolder(<画像があるディレクトリのパス>, transform)
“`
もしよろしければ、やりたいことをコメントしていただけたら、コード例など具体的なアドバイスができるかもしれません。