概要
確率的勾配降下法 (Stochastic Gradient Decent, SGD)、重み減衰 (weight decay)、Momentum、Nesterov’s Momentum について解説します。
確率的勾配降下法 (Stochastic Gradient Decent, SGD)
勾配降下法 (Gradient Decent) は、各ステップ t でその時点でのパラメータ θt–1 の目標関数 (objective function) の勾配 (gradient) ∇θf(θt–1) を計算します。目標関数とは損失関数 (loss function) など最小化したい関数のことです。勾配は、その地点で関数の値が最も増加する方向を表しているので、勾配と反対の方向は関数の値が最も減少する方向を表しています。どのくらい動くかは学習率 (learning rate, lr) というハイパーパラメータによって決めます。
勾配降下法のアルゴリズム
input:γ(lr),θ(params),f(θ) (objective)θ0←initial valuefort=1to…dogt←∇θft(θt–1)θt←θt–1–γgtreturnθt目標関数は入力データにも依存するので、∇θf(θt–1;x) と表記しますが、この x に学習データのうち、どのくらいのデータを使用するかによって次の種類があります。
- 学習データの中の1個のサンプルを使用する: 確率的勾配降下法 (Stochastic Gradient Decent, SGD)
- 学習データの一部のサンプルを使用する: ミニバッチ勾配降下法 (Minibatch Gradient Descent)
- 学習データのすべてのサンプルを使用する: バッチ勾配降下法 (Batch Gradient Descent)
バッチ (batch) とは学習データのすべてのサンプルのことをいうのに対し、ミニバッチ (minibatch) は一部のサンプルのことをいいます。ディープラーニングの文脈では、上記3つをあわせて確率的勾配降下法といい、使用するサンプルの数をバッチサイズ (batch size)といいます。
重み減衰 (weight decay)
目標関数を f(θ)、パラメータを θ としたとき、最小化対象を f(θ)+λ∥θ∥ と変更することを重み減衰 (weight decay)といいます。λ∥θ∥ は正則化項といい、λ はどのくらい正則化を強くするかを制御するハイパーパラメータです。正則化項が加わったことにより、∥θ∥ の値も小さくする制約を考慮しつつ、目標関数 f(θ) を最小化することができます。制約の追加は、ネットワークは大量のパラメータがあり自由度が高いため、過学習を防ぎ、汎化能力を高めるために行います。
通常、2ノルム ∥⋅∥2 を使用する L2 正則化を使用します。2ノルムは微分した際に ∇θ∥θ∥2=21θ となるので、ステップ t の勾配を計算する際は、実装上は目標関数の勾配に λθ を加えます。
gt←∇θf(θt–1)+λθt–1
Momentum
Momentum が有効の場合、今回の勾配に過去の勾配を加えます。これにより、勾配が振動して学習が不安定になる問題を防ぎます。μ の値が大きいほど過去の勾配の影響が大きくなります。
Momentum
input:γ (lr),θ (params),f(θ) (objective),μ (momentum),θ0←initial valuev0←0fort=1to…dogt←∇θft(θt–1)vt←μvt–1+gtθt←θt–1–γvtreturnθtvt の漸化式を展開すると、
vt=μtg1+μt–1g2+⋯+gt=i=1∑tμt–i+1giであるから、Momentum は通常の SGD の勾配をこれまでの勾配の指数移動平均に置き換えたアルゴリズムであると言えます。
※ 上記で紹介した Pytorch の実装とは異なり、元の論文 (Sutskever et. al.) では学習率を乗算する部分に Momentum の項 μvt–1 を含まないので、その点の差異があります。
input:γ (lr),θ (params),f(θ) (objective),μ (momentum),θ0←initial valuev0←0fort=1to…dogt←∇θft(θt–1)vt←μvt–1+gtθt←θt–1–μvt–1–γgtreturnθt
Nesterov’s Momentum
Nesterov’s Momentum またはネステロフの加速勾配法 (Nesterov’s Accelerated Gradient method, NAG) は、Momentum のアルゴリズムにおいて、勾配を計算する位置を θt–1 から −γμvt–1 だけ移動した位置 θt–1–γμvt–1 に変更したものです。少し進んだ先で勾配の方向が変わる場合にその事を考慮に入れて、次の移動する方向を決めれるため、より効率的に移動できます。
input:γ (lr),θ (params),f(θ) (objective),μ (momentum),θ0←initial valuev0←0fort=1to…dogt←∇θft(θt–1–γμvt–1)vt←μvt–1+gtθt←θt–1–γvtreturnθt
Nesterov’s Momentum
以下の式変形を行うと、Pytorch の SGD に記載されているアルゴリズムと一致します。
θ’t=θt–γμvt とおくと、
θ’t=θt–γμvt=θt–1–γvt–γμvt=θ’t–1+γμvt–1–γvt–γμvt∵θ’t–1=θt–1–γμvt–1=θ’t–1–γ(vt–μvt–1+μvt)=θ’t–1–γ(∇θ’f(θ’t–1)+μvt)∵vt–μvt–1=∇θ’f(θ’t–1)よって
vtθ’t=μvt–1+∇θ’f(θ’t–1)=θ’t–1–γ(∇θ’f(θ’t–1)+μvt)Nesterov’s Momentum のアルゴリズム (Pytorch 版)
input:γ (lr),θ (params),f(θ) (objective),μ (momentum),θ0←initial valuev0←0fort=1to…dogt←∇θf(θt–1)vt←μvt–1+gtθt←θt–1–γ(gt+μvt)returnθt※ 公式ドキュメントの記載において、gt←gt–1+μbt と記載されていますが、ソースコードを見ると、gt←gt+μbt が正しいと思われます。
Pytorch で SGD を使用する
確率的勾配降下法は、SGD で実装されています。
dampening
は Momentum の値を更新する部分で vt←μvt–1+(1–dampening)gt として、加算される現在の勾配の値の影響を小さくするパラメータです。dampening=0
で前の項で紹介したアルゴリズムと同じになります。
SGD を使用して関数の最小値を探す
f(x,y)=x2+y2+xy という関数の最小値を SGD を使用して探索してみます。
この関数は次のような形状をしています。
torch.tensor(init, dtype=torch.float32, requires_grad=True)
で変数を定義します。requires_grad=True
とした場合、勾配の計算を行えるテンソルになります。
optimizer = torch.optim.SGD([x], lr)
で SGD の Optimizer を定義します。第1引数には最適化を行う変数をリストで渡します。
y = f(x)
で x
に対する関数値を計算します。
y.backward()
で x
に対する f(x)
の勾配を計算します。計算前に optimizer.zero_grad()
を呼び出し、前回のステップの勾配情報をクリアする必要があります。勾配をクリアしない場合、前回の勾配に今回の勾配が加算される形になります。
optimizer.step()
で計算された勾配を元に、SGD で変数 x
を更新します。
- 変数 x の移動量 ∥xold–x∥ を計算し、この値が十分小さい場合は収束したと判断して、最大のステップ数に達していない場合でも途中で抜けます。
- 上記のステップを指定回数繰り返します。
変数 x
の推移を等高線及びグラフに描画し、確認します。
学習率の影響
学習率がパラメータの更新にどう影響するかを可視化して考察します。
学習率が大きい場合は、移動方向で最も関数が最小となる点を通り過ぎてしまうため、下記のように振動しながら収束します。
逆に学習率が小さい場合は、一回に移動する量が少ないため、収束までに多くのステップ数が必要です。
局所解、鞍点
凸関数でない場合、極小点が大域解であることが保証されませんが、ニューラルネットワークの損失関数は通常、凸関数でありません。勾配降下法は局所的に関数の値を小さくする方向に逐次移動するアルゴリズムであるため、学習率によっては局所解や鞍点に嵌ると抜け出せなくなります。
下記の例は f(x,y)=x3+y−3+3x2–3y2–8 という関数の例です。この関数は (0,0),(0,2),(−2,0),(−2,2) という4点で勾配が ∇xf(x)=0 になります。SGD を実行すると、(0,2) の局所解に嵌ってしまい更新が停止します。
Momentum の影響
勾配が急なところでは、Momentum により1回の移動量が多くなるため、解に早くたどり着けます。
Nesterov’s Momentum の場合、少し進んだ先での勾配の方向を考慮に入れて移動できるため、Momentum にあった行き過ぎて折り返すということがなくなるため、より少ないステップで収束しています。
パラメータごとに各種係数を変更したい場合
torch.optim.SGD()
の第1引数に dict の list を指定する方式で、パラメータごとに学習率などの係数を指定できます。
このようにした場合、パラメータAだけ学習率は 0.001 になり、個別に指定していないパラメータB、パラメータCの学習率は0.01になります。
参考
コメント