概要
ニューラルネットワークによる生成モデル GAN (Generative Adversarial Nets) の理論的背景について解説します。
生成モデル
観測されたデータはある確率分布 pd に従う確率変数によって得られたものと仮定します。この確率分布をモデル化した pg を生成モデル (Generative Models) といいます。真の確率分布 pd に近い生成モデル pg を作成できれば、サンプリングにより、実際のデータに近いサンプルを生成することが可能となります。
例として、手書き画像データセットである MNIST について考えます。1つのサンプルは、画素数が 28×28=784、画素値が [0,255] の範囲の整数であるグレースケール画像です。I={0,1,⋯,255} としたとき、MNIST データセットのサンプルは I784 の空間上に、ある確率分布 pd に従って分布していると考えます。
GAN
ニューラルネットワークを使用した生成モデルを深層生成モデル (Deep Generative Models, DGM) といいます。
今回紹介する GAN (Generative Adversarial Nets) も深層生成モデルの一種です。生成モデル pg を高い表現力を持つニューラルネットワークでモデル化したことにより、画像のような高次元空間上に分布する複雑なデータを生成することが可能となりました。スタイル変換やイラストの自動着色など幅広い応用例があり、活発に研究されている分野の1つです。
GAN の構造
GANでは、Generator (生成モデル) と Discriminator (識別モデル) という2つのモデルを用意します。
- Generator はノイズ z を入力として、偽物のデータ x を出力するモデルです。
- Discriminator は本物または偽物のデータ x を入力として、そのデータが本物である確率を出力するモデルです。
Discriminator は本物かどうかを正しく判定することを目標とする一方、Generator は Discriminator に偽物と判定されない本物に近いデータを生成することを目標として学習します。この2つのネットワークが互いに競うように、交互に学習を行うことで、本物に近いデータを生成できる Generator を作ります。
GAN の目標関数
論文では、GAN の学習を次式の最適化問題として設定しています。
GmaxDmaxV(D,G)=GmaxDmaxEx∼pd [logD(x)]+Ez∼pz [log(1–D(G(z)))]
- pd: 本物のデータが従う分布
- pz: ノイズが従う分布
- pg: Generator が生成した偽物のデータが従う分布
- G(z): ノイズ z を入力とし、データ x を出力する Generator
- D(x): データ x を入力とし、x が本物のデータである確率を出力する Discriminator
この最適化問題を解くことが、Generator 及び Discriminator の目標とどのように関係があるのかを以下で説明します。この最適化問題の理論的な解釈は以下のリンクを参照してください。
- Generative Adversarial Nets: GAN の論文
- An Annotated Proof of Generative Adversarial Networks with Implementation Notes: GAN 論文内の証明の詳解
Discriminator の学習
第1項 Ex∼pd [log(D(x))] について考えます。D(x) の出力は [0,1] の範囲の確率値になりますが、このとき log(D(x)) の値は以下のようになります。
D(x) は pd の確率分布から得られた本物のデータに対して1に近い値を出力できるようにしたいので、期待値 Ex∼pd [log(D(x))] が大きくなる D を探します。
第2項 Ez∼pz [log(1–D(G(z)))] について考えます。D(x) の出力は [0,1] の範囲の確率値になりますが、このとき log(1–D(x))) の値は以下のようになります。
D(x) は pg の確率分布から得られた偽物のデータに対して0に近い値を出力できるようにしたいので、期待値 Ez∼pz [log(1–D(G(z)))] が大きくなる D を探します。
したがって、入力データが本物かどうかを判別するという Discriminator の目標を達成するには、以下の最大化問題を考えればよいことがわかります。
DmaxV(D,G)=DmaxEx∼pd [logD(x)]+Ez∼pz [log(1–D(G(z)))]この解を D∗ とします。
Generator の学習
Discriminator が D∗ に固定されたとき、次に Generator の学習について考えます。
V(D∗,G)=Ex∼pd [logD∗(x)]+Ez∼pz [log(1–D∗(G(z)))]第1項は定数なので無視し、第2項 Ez∼pz [log(1–D∗(G(z)))] について考えます。
Generator は pz の確率分布から得られたノイズを入力として生成した偽物のデータ G(z) に対して、Discriminator に1に近い値を出力できるようにしたいので、期待値 Ez∼pz [log(1–D∗(G(z)))] が小さくなる G を探します。
よって、以下の最小化問題を考えればよいことがわかります。
GminV(D,G)=GminEx∼pd [logD∗(x)]+Ez∼pz [log(1–D∗(G(z)))]この解 G∗ が最終的に求めたい Generator です。
実装上の差異
実装上は、上記の式をそのまま実行するのではなく、以下の変更を加えます。
期待値の計算
任意のデータやノイズに関する期待値は計算できないので、標本平均に置き換えます。
Ez∼pz [log(1–D∗(G(z)))] は、確率分布 pz からサンプリングした n 個のノイズによる標本平均 n1∑i=1nlog(1–D(G(zi))) で計算します。
また、Ex∼pd [logD∗(x)] は、学習データから n 個を選択し、n1∑i=1nlogD(x) で計算します。
V(D,G)=n1i=1∑n(logD(xi)+log(1–D(G(zi))))
損失関数に binary cross entropy を使う
正解ラベルを ti,(i=1,2,⋯,n,ti∈{0,1})、出力を yi,(i=1,2,⋯,n,yi∈[0,1]) としたとき、binary cross entropy は次式になります。
BCELoss=–i=1∑n(tilogyi+(1–ti)log(1–yi))Discriminator について考えます。
xi が本物であるとき、出力 D(xi) に対する正解ラベル ti は1なので、
BCELossd=–i=1∑nlogD(xi)G(zi) は偽物なので、出力 D(xi) に対する正解ラベル ti は0なので、
BCELossg=–i=1∑nlog(1–D(G(zi)))よって、2つを足すと
BCELossd+BCELossg=–i=1∑n(logD(xi)+log(1–D(zi)))=−nV(D,G)よって、これの最小化は元の V(D,G) の D に関する最小化と同じことがわかります。
Dminimize−nV(D,G)⇔DmaximizeV(D,G)Generator について考えます。
G(zi) を Discriminator に本物であると認識させたいので、正解ラベル ti は1となり、
BCELoss=–i=1∑nlogD(xi)よって、これの最小化は元の V(D,G) の G に関する最小化と同じことがわかります。
Gminimize–i=1∑nlogD(xi)⇔Gmaximize–i=1∑nlog(1–D(xi))⇔Gminimizei=1∑nlog(1–D(x))⇔GminimizeV(D,G)
最適化の順序
論文の式だと V(G,D) を D について最適化してから、G について最適化しますが、実装上は、モデル D を1反復分パラメータを更新し、モデル G を1反復分パラメータを更新することを交互に繰り返して最適化を行います。
Pytorch の実装例
今回は MNIST データセットを利用して、手書き数字画像を生成できる Generator を作ります。
モデル構造は論文では特に規定されていないので、シンプルな全結合ニューラルネットワークで作成します。
モジュールを import する
デバイスを選択する
Transform、Dataset、DataLoader を作成する
入力データは値の範囲を [−1,1] に正規化します。
ToTensor
で PIL Image 画像を値の範囲が [0,1] のテンソルに変換します
Normalize
でを値の範囲を [−1,1] に変換します
標準化とは、stdx–mean であるため、mean=0.5,std=0.5 とすると、値の範囲が [−1,1] になります。
Generator を作成する
Generator はノイズ z を入力とし、データ x を出力するモデルになります。
Discriminator を作成する
Discriminator はデータ x を入力とし、データが本物である確率を出力するモデルになります。
[0,1] の範囲の値を出力できるように、出力層の活性化関数はシグモイド関数を使用します。
GAN を作成する
MNIST データセットの各サンプルは大きさが (28, 28) のグレースケール画像なので、データの次元数は 28×28=784 になります。
損失関数とオプティマイザを作成する
損失関数は先に説明した通り、Binary Cross Entropy なので、BCELoss を使用します。
Discriminator を学習する関数を作成する。
Generator を学習する関数を作成する。
Generator で画像を生成する関数を作成する。
データを生成するときは、勾配情報を不要なので、torch.no_grad()
コンテキストで実行します。
GAN の学習を実行する関数を作成する
損失関数の値の推移を描画する
生成される画像の推移を gif 動画で保存する
pillow を使用して、各エポックの生成画像を gif 動画にして保存します。
学習が進むにつれ、段々とはっきりした手書き数字画像が生成される様子が確認できます。
各エポックの生成画像
学習が終了した段階の Generator が生成する画像を表示します。
参考文献
- Generative model – Wikipedia: 生成モデル
- Generative Adversarial Nets: GAN の論文
- An Annotated Proof of Generative Adversarial Networks with Implementation Notes: GAN 論文内の証明の詳解
- GANと損失関数の計算についてまとめた – Qiita
- eriklindernoren/PyTorch-GAN: PyTorch implementations of Generative Adversarial Networks.: いろいろな GAN の Pytorch での実装例
- GANs from Scratch 1: A deep introduction. With code in PyTorch and TensorFlow
- The GAN objective, from practice to theory and back again
コメント
コメント一覧 (0件)
分かりやすい解説ありがとうございます。
細かい点かもしれませんが、Discriminatorの最大化問題の括弧の数が合わない気がします(Generatorも間違っているかもです)
追記です。
目標関数のGの箇所はminではないでしょうか
コメント及びご指摘ありがとうございます。
誤記があり、すみません。
後ほど修正いたします。
遅くなりすみません。
ご指摘いただいた箇所を修正しました。
ありがとうございました。