目次
概要
ディープラーニングの画像認識モデルである SqueezeNet を解説し、Pytorch の実装例を紹介します。
SqueezeNet
SqueezeNet について、論文 SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE に基づいて解説します。 従来のモデルと比較し、パラメータ数が大幅に削減した (Squeeze) モデルとなっています。SqueezeNet では、次のテクニックで構築されています。
- 多くの畳み込み層のカーネルサイズを 1×1 にしました。
- 1×1 の畳み込みで数を削減してから、3×3 の畳み込み層に入力するようにしました。
- プーリング層を畳み込み層のできるだけ後ろに持っていくことで、表現力が低下しにくいようにしました。
Fire Module
Fire Module では、まず 1×1 の畳み込みでチャンネル数を削減し、1×1 及び 3×3 の畳み込みに入力します。 1×1 及び 3×3 の畳み込みではチャンネル数を拡張し、最後にチャンネル方向に結合したものが Fire Module の出力となります。
SqueezeNet v1.0、v1.1
論文に記載されているパラメータのモデルは SqueezeNet v1.0 になります。論文作者の Caffe 実装 には、SqueezeNet v1.1 も公開されています。パラメータ数が AlexNet の 1/50 にも関わらず、AlexNet より高い精度になっています。
Pytorch の実装
以下に Pytroch の実装例を紹介します。Pytorch では、以下の SqueezeNet が提供されています。実用上は torchvision.models の SqueezeNet を使用してください。
モデル名 | 関数名 | パラメータ数 | Top-1 エラー率 | Top-5 エラー率 |
---|---|---|---|---|
SqueezeNet 1.0 | squeezenet1_0() | 1248424 | 41.9 | 19.58 |
SqueezeNet 1.1 | squeezenet1_1() | 1235496 | 41.81 | 19.38 |
Fire Module を作成する
Fire Module を作成します。
In [ ]:
import torch
import torch.nn as nn
class Fire(nn.Module):
def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand3x3_channels):
super().__init__()
self.squeeze = nn.Conv2d(in_channels, squeeze_channels, kernel_size=1)
self.squeeze_activation = nn.ReLU(inplace=True)
self.expand1x1 = nn.Conv2d(squeeze_channels, expand1x1_channels, kernel_size=1)
self.expand1x1_activation = nn.ReLU(inplace=True)
self.expand3x3 = nn.Conv2d(squeeze_channels, expand3x3_channels, kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.squeeze(x)
x = self.squeeze_activation(x)
branch1 = self.expand1x1(x)
branch1 = self.expand1x1_activation(branch1)
branch2 = self.expand3x3(x)
branch2 = self.expand3x3_activation(branch2)
return torch.cat([branch1, branch2], 1)
SqueezeNet を作成する
- 特徴抽出部分が v1.0 と v1.1 で異なっています。パラメータは forresti/SqueezeNet を参考にしています。
- Caffe 実装では、最後の畳み込み層だけ平均0、標準偏差0.01の正規分布で初期化しています。
In [1]:
class SqueezeNet(nn.Module):
def __init__(self, version="1_0", drop_rate=0.5, num_classes=1000):
super().__init__()
if version == "1_0":
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2), # (N, 96, 111, 111)
nn.ReLU(inplace=True), # (N, 96, 111, 111)
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # (N, 96, 55, 55)
Fire(96, 16, 64, 64), # (N, 128, 55, 55)
Fire(128, 16, 64, 64), # (N, 128, 55, 55)
Fire(128, 32, 128, 128), # (N, 256, 55, 55)
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # (N, 256, 27, 27)
Fire(256, 32, 128, 128), # (N, 256, 27, 27)
Fire(256, 48, 192, 192), # (N, 384, 27, 27)
Fire(384, 48, 192, 192), # (N, 384, 27, 27)
Fire(384, 64, 256, 256), # (N, 512, 27, 27)
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # (N, 512, 13, 13)
Fire(512, 64, 256, 256), # (N, 512, 13, 13)
nn.Dropout(p=drop_rate), # (N, 512, 13, 13)
)
elif version == "1_1":
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(64, 16, 64, 64),
Fire(128, 16, 64, 64),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(128, 32, 128, 128),
Fire(256, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256),
nn.Dropout(p=drop_rate),
)
final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
final_conv.stddev = 0.01
self.classifier = nn.Sequential(
nn.Dropout(p=0.5), # (N, 512, 13, 13)
final_conv, # (N, 1000, 13, 13)
nn.ReLU(inplace=True), # (N, 1000, 13, 13)
nn.AdaptiveAvgPool2d((1, 1)), # (N, 1000, 1, 1)
)
# 重みを初期化する。
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if hasattr(m, "stddev"):
torch.nn.init.normal_(m.weight, std=m.stddev)
else:
torch.nn.init.kaiming_uniform_(m.weight)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return torch.flatten(x, 1)
def squeezenet1_0():
return SqueezeNet("1_0")
def squeezenet1_1():
return SqueezeNet("1_1")
コメント