概要
CNN の推論結果を解釈するには、入力画像に対する CNN の反応を可視化した顕著性マップ (saliency map) を見ることが有用です。 本記事では、Pytorch を使用して顕著性マップを作成する方法について解説します。
顕著性マップ (saliency map)
顕著性マップ (saliency map) は入力画像の各ピクセルに対して、CNN がどの程度反応したかを表すヒートマップです。モデルが強く反応した領域は、重要な情報を含んでいることを意味します。顕著性マップを作成することで、入力画像のどの領域が出力結果に影響したのかを視覚的に確認することができます。 顕著性マップ の作成方法は、Grad-CAM などいくつかの手法がありますが、本記事では通常の逆伝搬 (vanilla backpropagation) を利用した方法について、Pytorch の実装例を交えて解説します。
Pytorch による実装
必要なモジュールを import する
以下の外部ライブラリを使用します。ライブラリがない場合は、インストールしてください。
- OpenCV:
pip install opencv-contrib-python
- matplotlib:
pip install matplotlib
- NumPy:
pip install numpy
- Pytorch:
pip install pytorch
- torchvision:
pip install torchvision
- opencv_transforms:
pip install opencv_transforms
デバイスを選択する
使用するデバイスを表す device オブジェクトを返す関数を作成します。CUDA が利用可能かどうかは torch.cuda.is_available() で判定できます。use_gpu=True
かつ CUDA が利用可能な場合は GPU デバイス、そうでない場合は CPU デバイスを返します。GPU デバイスを使用する場合、デフォルトでは、逆伝搬で計算した勾配の値が実行するたびに異なる現象が発生します。再現性を担保するために torch.backends.cudnn.deterministic = True
に設定します。
モデルを作成する
torchvision で利用可能なモデルの一覧を返す関数を作成します。モデルを作成する関数は models モジュール以下にあります。
dir()
で models モジュール以下の属性名の一覧を取得する。getattr()
で属性にする。- isfunction() でその属性が関数かどうかを判定する。関数の場合は、
dict
にキーが関数名、値が関数という形で追加する。
指定した名前のモデルを作成する関数を作成します。
_get_avaable_models()
で利用可能なモデルの一覧を取得する。available_models[name]
で名前name
のモデルを作成する関数を取得する。- 学習済みモデルを使用するので、
pretrained=True
を指定してモデルを作成する。 nn.Module.to()
でモデルを使用するデバイスに転送する。nn.Module.eval()
でモデルを推論モードに設定する。
今回、モデルは AlexNet を使用します。
DataLoader を作成する
まず、ディレクトリ内から指定した拡張子のファイルを探し、そのパスの一覧を返す関数を作成します。
次に指定したディレクトリから画像を読み込む Dataset を作成します。 今回、ImageNet で学習済みのモデルを使用するため、そのモデルを学習したときと同じ以下の前処理を行う必要があります。
(224, 224)
にリサイズする。- RGB の値を平均
[0.485, 0.456, 0.406]
、分散[0.229, 0.224, 0.225]
で標準化する。
init()
_get_img_paths(img_dir)
で指定したディレクトリにある画像ファイルのパス一覧を取得する。- numpy 配列を Tensor に変換する ToTensor と標準化する Normalize を Compose で結合し、Transformer を作成する。
getitem(self, index)
self.img_paths[index]
でインデックスindex
のファイルパスを取得する。cv2.imread(path)
で画像ファイルを読み込む。- OpenCV のチャンネル順は BGR、学習済みのモデルの入力画像のチャンネル順は RGB のため、
cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)
でチャンネル順を BGR から RGB にする。 cv2.resize(raw_img, ImageFolder.IMAGENET_SIZE)
でリサイズする。self.transform(raw_img)
で numpy 配列のテンソル化及び標準化を行う。- 標準化前の画像、標準化後の画像、ファイルパスを返す。
先程作成した Dataset オブジェクトからミニバッチを生成する DataLoader を作成します。
ImageNet の日本語のクラス名を取得する関数を作成します。
- クラス一覧を記載したファイルが存在しない場合、download_url() でダウンロードする。
- クラス一覧を記載した json ファイルを読み込む。
Vanilla Backpropagation を計算する処理
先にコード全体を記載してから、解説します。
順伝搬
- 入力画像
inputs
をto()
で使用するデバイスに転送する。入力画像は学習するパラメータを持たないので、デフォルトでは勾配の計算は行われい。そのため、requiresgrad() で勾配の計算を行うように設定する。 - zero_grad() で勾配を初期化する。
self.model(inputs)
で順伝搬を行う。- モデルには出力層の softmax は含まれていないので、
F.softmax(logits, dim=1)
で softmax を計算する。 class_ids = probs.sort(dim=1, descending=True)[1]
で確率が大きい順にソートし、対応するクラス ID の一覧を返す。
逆伝搬
- 指定したクラス ID の箇所を1、それ以外の箇所を0とした逆伝搬の入力を作成する。
to()
で使用するデバイスに転送する。logits.backward(gradient=onehot, retain_graph=True)
で逆伝搬を行う。逆伝搬後も勾配の情報は保持しておきたいので、retain_graph=True
を指定する。
勾配を取得する
self.inputs.grad
で入力画像の勾配を取得する。- clone() でディープコピーする。
- モデルの入力画像の勾配は、
self.inputs.grad.zero_()
で初期化する。
テンソルを画像化する
テンソルを画像に変換します。
normalize(x)
でテンソルを に正規化する。x.cpu().numpy()
でテンソルを numpy 配列に変換する。transpose(0, 2, 3, 1)
で軸を(B, C, H, W)
から(B, H, W, C)
に並び替える。- に正規化されているので、255 を乗算し、
astype(np.uint8)
でキャストすることで、[0, 255] の非負の整数に変換する。 x.shape[3] == 1
、つまり、グレースケール画像の場合、x.squeeze(axis=3)
で(B, H, W, 1)
を(B, H, W)
に変換する。
顕著正マップをグレースケール画像にする。
x.abs().sum(dim=1, keepdims=True)
で各チャンネルの値を足し合わせて、(B, C, H, W)
を (B, 1, H, W)
にします。
顕著正マップで値が正のピクセルだけ強調する。
torch.nn.functional.relu(x)
で負の値は0としたテンソルを作成します。
顕著正マップで値が負のピクセルだけ強調する。
torch.nn.functional.relu(-x)
で正の値は0としたテンソルを作成します。
メインの処理
for batch in dataloader
でループさせて、ミニバッチbatch
を取得する。vanilla_backprop.forward(batch["image"])
を呼び、確率が大きい順にソートされたクラス ID の一覧を取得する。- 確率が高い上位3クラスについて、見ていく。
ith_class_ids = class_ids[:, i]
でi
番目にスコアが高いクラス ID の一覧を取得します。vanilla_backprop.backward(ith_class_ids)
で逆伝搬を行う。vanilla_backprop.generate()
で入力画像の勾配を取得する。- 4通りの方法で、する。
- 結果は dict に記録する。
可視化する
作成した顕著性マップを matplotlib で可視化します。
![](/wp/wp-content/uploads/2021/01/pytorch-vanilla-backpropagation_01.jpg)
![](/wp/wp-content/uploads/2021/01/pytorch-vanilla-backpropagation_02.jpg)
![](/wp/wp-content/uploads/2021/01/pytorch-vanilla-backpropagation_03.jpg)
コメント