Pytorch – 中間層の出力を取得する方法

Pytorch – 中間層の出力を取得する方法

概要

Pytorch で中間層の出力を取得する方法について解説します。

Advertisement

手順

必要なモジュールを import する

In [1]:
import torch
import torchvision
from PIL import Image
from torchvision import transforms

画像を読み込む

以下の画像を使用します。

sample.jpg

In [2]:
def get_device(gpu_id=-1):
    if gpu_id >= 0 and torch.cuda.is_available():
        return torch.device("cuda", gpu_id)
    else:
        return torch.device("cpu")


# デバイスを選択する。
device = get_device(gpu_id=0)

# Transforms を作成する
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# 画像を読み込む
img = Image.open("sample.jpg")
inputs = transform(img)
inputs = inputs.unsqueeze(0).to(device)

モデルを作成する

例として、torchvision の VGG16 モデルを使用します。 Pytorch では、各層 (例: Conv2d、ReLU) や複数の層をまとめたもの (例: Sequential)、またモデル自体もモジュール (torch.nn.Module) として表されます。 モデルを構成するモジュール構成がどうなっているかは、定義方法によって異なるので、ソースコード (VGG16 の場合、vgg.py) で確認します。

今回は、2つ目の畳み込み層後に活性化関数 ReLU を適用したあとの出力を抽出してみます。 2つ目の畳み込み層後の活性化関数 ReLU を表すモジュールは model.features[3] です。

In [3]:
# モデルを作成する
model = torchvision.models.vgg16(pretrained=True).to(device)
print(model)

# 抽出対象の層
target_module = model.features[3]  # (3): ReLU(inplace=True)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
Advertisement

中間層の出力を取得する

まず、出力を取得したいモジュールに中間層の出力を記録するコールバック関数を登録します。 torch.nn.Module.register_forward_hook() でコールバック関数を登録しておくと、順伝搬にその関数が呼び出されます。コールバック関数は、モジュール module、そのモジュールの入力 inputs、そのモジュールの出力 outputs の3つの引数をとる関数になります。

def forward_hook(module, inputs, outputs):
    # 処理

コールバック関数を登録したら、順伝搬して、中間層の出力を取得します。

In [4]:
def extract(target, inputs):
    feature = None

    def forward_hook(module, inputs, outputs):
        # 順伝搬の出力を features というグローバル変数に記録する
        global features
        features = outputs.detach()

    # コールバック関数を登録する。
    handle = target.register_forward_hook(forward_hook)

    # 推論する
    model.eval()
    model(inputs)

    # コールバック関数を解除する。
    handle.remove()

    return features


features = extract(target_module, inputs)
print(features.shape)
# torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])

中間層の出力を可視化する

抽出した出力は、出力数が64の畳み込み層の出力に ReLU を適用したものなので、64枚の大きさが (244, 224) のグレースケール画像として可視化します。

In [5]:
def feature_to_img(feature):
    # (N, H, W) -> (N, C, H, W)
    feature = feature.unsqueeze(1)
    # 画像化して、格子状に並べる
    img = torchvision.utils.make_grid(feature.cpu(), nrow=4, normalize=True, pad_value=1)
    # テンソル -> PIL Image
    img = transforms.functional.to_pil_image(img)
    # リサイズする。
    new_w = 500
    new_h = int(new_w * img.height / img.width)
    img = img.resize((new_w, new_h))

    return img


img = feature_to_img(features[0])
img