YOLOv3 – 自作データセットで学習する方法について

YOLOv3 – 自作データセットで学習する方法について

概要

YOLOv3 で独自のデータセットを学習する方法について解説します。本記事では、例として金魚の物体検出を学習します。

Advertisement

YOLOv3 のスクリプトを準備する

YOLOv3 の Pytorch 実装である nekobean/pytorch_yolov3 を使用します。

まず、レポジトリをクローンします。

git clone https://github.com/nekobean/pytorch_yolov3.git
cd pytorch_yolov3

依存ライブラリをインストールします。

pip install -r requirements.txt

Darknet 53 の学習済みの重み darknet53.conv.74 をダウンロードします。物体検出の学習は転移学習の形で行うので、この重みが必要となります。

wget https://pjreddie.com/media/files/darknet53.conv.74

ダウンロードが完了したら、weights ディレクトリに配置してください。

weights/
|-- download_weights.sh
`-- darknet53.conv.74

データセット作成

画像収集

アノテーションするための画像を用意します。検出対象物が一般的なものでない場合は自分で対象物を撮影するところから始めます。逆にネット上でも十分な枚数の画像が入手可能な場合は、Google 画像検索などを活用して画像を収集するとよいでしょう。 必要な枚数ですが、1クラスあたり最低300ラベルはあったほうがよいでしょう。枚数でなくラベル数なので、例えば、ある物体が1枚に3個写っていたとすると、3ラベルとカウントします。

google-images-download を使って、Google 検索結果から画像を保存する方法 – pystyle

今回は google-images-download を使って、Web 上から金魚の画像を414枚収集しました。

収集した金魚の画像

アノテーション

物体検出用のアノテーションツールを使って、画像に対して物体がある位置の注釈をつけるアノテーションを行います。本記事では VOTT というツールの使用を前提として解説しますが、tzutalin/labelImg などいくつか種類があるので、使いやすいと思うものを使用してください。

物体検出のアノテーションツール VOTT の使い方 – pystyle

VOTT ですべての画像に対してアノテーションを行いました。414枚の画像に対して625ラベルのアノテーションを行い、作業時間は2時間でした。アノテーションは時間がかかる地道が作業ですが、精度を出すためにとても重要です。学習ではオーグメンテーションも行いますが、これにより機械的に増やせるバリエーションは照明や向きの変化などに限定されるので、多少の精度向上には寄与しますが、元々データ数が少ない場合はどうしようもありません。できるだけ沢山の画像を集めて、アノテーションをしましょう。

VOTT

Advertisement

学習する

設定ファイルの準備

  1. config/yolov3_custom.yamln_classes にクラス数を設定します。今回は1クラスなので n_classes: 1 としました。
  2. config/custom_classes.txt に1行に1つのクラスを記載します。

データセットの変換

スクリプトの都合上、VOTT でアノテーションしたデータセットを1枚の画像に対して、ラベルが記載された1つのテキストファイルが対応する以下の形式にデータセットを変換します。

python convert_vott_dataset.py <VOTT のデータセットのあるディレクトリ> <出力先のディレクトリ>

例:

python convert_vott_dataset.py F:\work\dataset\金魚 custom_dataset

変換が完了すると、custom_dataset ディレクトリに変換結果が出力されます。images に画像、labels に同じ名前で対応するラベルが配置されます。

custom_dataset
|-- images
|   |-- 000000.jpg
|   |-- 000001.jpg
    ...
`-- labels
    |-- 000000.txt
    |-- 000001.txt
    ...

学習する

学習前に上記作業を行った時点での自作データセットに関係するファイルを確認します。

pytorch_yolov3
|-- config
|   |-- custom_classes.txt ← クラスの一覧を設定
|   `-- yolov3_custom.yaml ← クラス数 n_classes を設定
|-- custom_dataset ← データセット
|   |-- images
|   |   |-- 000000.jpg
|   |   |-- 000001.jpg
|       ...
|   `-- labels
|       |-- 000000.txt
|       |-- 000001.txt
|       ...
`-- weights
    `-- darknet53.conv.74 ← ダウンロードした Darknet 53 の学習済みの重み

このようなディレクトリ構成になっていることが確認できたら、以下のコマンドで学習を開始します。

  • --dataset_dir: 上記データセットのディレクトリパス
  • --weights: 学習済みモデルのパス。(YOLOv3 の場合、weights/darknet53.conv.74 を指定する。)
  • --config: 設定ファイルのパス
python train_custom.py \
    --dataset_dir custom_dataset \
    --weights weights/darknet53.conv.74 \
    --config config/yolov3_custom.yaml

学習結果は train_output というディレクトリが作られ、その中に重みを含む学習ステータスが .pth ファイル、損失の履歴が history.csv に保存されます。

`-- train_output
    |-- yolov3_001000.pth
    |-- yolov3_002000.pth
    |-- ...
    |-- yolov3_final.pth
    `-- history.csv

学習完了後に損失関数の推移が記載された history.csv を pandas で読み込み、グラフ化してみます。

In [1]:
import pandas as pd
from matplotlib import pyplot as plt

df = pd.read_csv("train_output/history.csv")
df.plot(
    x="iter",
    y=["loss_total", "loss_xy", "loss_wh", "loss_obj", "loss_cls"],
    figsize=(8, 4),
    logy=True,
)
plt.show()
Advertisement

学習した重みで推論する

学習した重みを使用して、金魚の画像に対して推論してみます。出力結果は output ディレクトリ以下に出力されます。

  • --input: 推論する画像ファイルのパス
  • --output: 結果を出力するディレクトリのパス
  • --weights: 学習した重みファイルのパス (.pth ファイル)
  • --config: 設定ファイルのパス
python detect_image.py \
    --input goldfish.jpg \
    --output output \
    --weights weights/yolov3_final.pth \
    --config config/yolov3_custom.yaml

推論結果

画像にうつっている金魚が正しく検出できました。