Pytorch – DenseNet の仕組みと実装について解説

目次

概要

ディープラーニングの画像認識モデルである DenseNet を解説し、Pytorch の実装例を紹介します。

DenseNet

DenseNet について、論文 Densely Connected Convolutional Networks に基づいて解説します。

Dense Layer

Dense Layer は、複数の畳み込み層、Batch Normalization、ReLU から構成されます。DenseNet では、ReLU -> Batch Normalization -> Conv の順番で適用します。最初に 1×1 の畳み込みで次元数を削減したあと、3×3 の畳み込みを適用します。 各畳み込み層の出力数は以下のようになります。ただし、$k$ は後で紹介する growth rate を表します。

出力の形状
入力 (N, C, H, W)
ReLU (N, C, H, W)
Batch Normalization (N, C, H, W)
1×1 Conv (N, k * 4, H, W)
ReLU (N, k * 4, H, W)
Batch Normalization (N, k * 4, H, W)
3×3 Conv (N, k, H, W)

Dense Block

Dense Block は連続した複数の Dense Layer から構成されるブロックです。Dense Block の入力を $x_0$、$l$ 個目の Dense Layer を $H_l(x)$、$l$ 個目の Dense Layer の出力を $x_l$ としたとき、$H_l(x)$ は Dense Block の入力 $x_0$ 及びそれより前のDense Layer $H_1, H_2, \cdots, H_{l – 1}$ の出力 $x_1, x_2, \cdots, x_{l – 1}$ をチャンネル方向に結合したものが入力となります。

$$ H(\text{concat}(x_0, x_1, x_2, \cdots, x_{l -1})) = x_l $$

Dense Block

$l$ 個目の Dense Layer の入力のチャンネル数は、Dense Block の入力 $x_0$ のチャンネル数を $k_0$、各 Dense Layer の出力のチャンネル数を $k$ としたとき、

$$ x_l の入力チャンネル数 = k_0 + k (l – 1) $$

で計算できますが、この値 $k$ を growth rate といいます。

Dense Block の出力も、Dense Block の入力及びすべての Dense Layer の出力をチャンネル方向に結合した

$$ DenseBlock ブロックの出力 = \text{concat}(x_0, x_1, \cdots, x_L) $$

になります。ただし、$L$ は Dense Layer の数です。

Transition Layer

DenseNet は DenseBlock を複数重ねた構造になっています。Dense Block の間には、1×1 の畳み込みでチャンネル数を削減し、平均プーリングで大きさを削減する Transition Layer が挿入されています。

DenseNet

DenseNet

DenseNet の最初には kernel_size=7, stride=2, out_channels=2 * grouth_rate の畳み込みがあります。 最後の Dense Block の後には、Global Average Pooling 層及び分類用の全結合層があります。

Pytorch の実装

Pytorch では、以下の DenseNet が提供されています。以下に Pytroch の実装例を紹介します。実用上は torchvision.models の DenseNet を使用してください。

モデル名 関数名 パラメータ数 Top-1 エラー率 Top-5 エラー率
Densenet-121 densenet121() 7978856 25.35 7.83
Densenet-161 densenet161() 28681000 24 7
Densenet-169 densenet169() 14149480 22.8 6.43
Densenet-201 densenet201() 20013928 22.35 6.2

Dense Layer を作成する

Dense Layer を作成します。forward()x には、Dense Block の入力及びそれより前の Dense Layer の出力がリストで渡されるので、torch.cat(x, 1) でチャンネル方向に結合しています。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        growth_rate,
        drop_rate,
    ):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, growth_rate * 4, kernel_size=1, bias=False)
        self.norm2 = nn.BatchNorm2d(growth_rate * 4)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(growth_rate * 4, growth_rate, kernel_size=3, padding=1, bias=False)
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, x):
        x = torch.cat(x, 1)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv2(x)
        x = self.dropout(x)

        return x

Dense Block を作成する

Dense Block を作成します。forward() 中にリストに Dense Layer の出力を追加していってます。また、Dense Block の出力は、Dense Block の入力及びすべての Dense Layer の出力をチャンネル方向に結合したものなので、torch.cat(x, 1) で結合を行っています。

In [2]:
class DenseBlock(nn.ModuleDict):
    def __init__(
        self,
        num_layers,
        in_channels,
        growth_rate,
        drop_rate,
    ):
        super().__init__()
        for i in range(num_layers):
            layer = DenseLayer(
                in_channels + i * growth_rate,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
            )
            self.add_module(f"denselayer{i + 1}", layer)

    def forward(self, x0):
        x = [x0]
        for name, layer in self.items():
            out = layer(x)
            x.append(out)

        return torch.cat(x, 1)

Transition Layer を作成する

Transition Layer を作成します。1×1 の畳み込みでチャンネル数を削減したあと、平均プーリングで大きさを削減します。

In [3]:
class TransitionLayer(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.add_module("norm", nn.BatchNorm2d(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

DenseNet

以下のパラメータに従い、DenseNet 全体を作成します。

DenseNet のパラメータ

In [4]:
class DenseNet(nn.Module):
    def __init__(
        self,
        growth_rate,
        block_config,
        drop_rate=0,
        num_classes=1000,
    ):
        super().__init__()

        # 最初の畳み込み層を追加する。
        self.features = nn.Sequential()
        self.features.add_module(
            "conv0", nn.Conv2d(3, 2 * growth_rate, kernel_size=7, stride=2, padding=3, bias=False)
        )
        self.features.add_module("norm0", nn.BatchNorm2d(2 * growth_rate))
        self.features.add_module("relu0", nn.ReLU(inplace=True))
        self.features.add_module("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Dense Block 及び Transition Layer を作成する。
        in_channels = 2 * growth_rate
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(
                num_layers=num_layers,
                in_channels=in_channels,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
            )
            self.features.add_module(f"denseblock{i + 1}", block)

            in_channels = in_channels + num_layers * growth_rate
            if i != len(block_config) - 1:
                # 最後の Dense Block でない場合は、Transition Layer を追加する。
                trans = TransitionLayer(in_channels=in_channels, out_channels=in_channels // 2)
                self.features.add_module(f"transition{i + 1}", trans)
                in_channels = in_channels // 2

        self.features.add_module("norm5", nn.BatchNorm2d(in_channels))
        self.features.add_module("relu5", nn.ReLU(inplace=True))
        self.features.add_module("pool5", nn.AdaptiveAvgPool2d((1, 1)))

        self.classifier = nn.Linear(in_channels, num_classes)

        # 重みを初期化する。
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x


def densenet121():
    return DenseNet(growth_rate=32, block_config=(6, 12, 24, 16))


def densenet169():
    return DenseNet(growth_rate=32, block_config=(6, 12, 32, 32))


def densenet201():
    return DenseNet(growth_rate=32, block_config=(6, 12, 48, 32))


def densenet264():
    return DenseNet(growth_rate=48, block_config=(6, 12, 64, 48))

コメント

コメントする

目次