概要
「パターン認識と機械学習 上 (PRML)」 の「1.1 例: 多項式曲線フィッティング」に記載されている内容を Python で再現したコードになります。 書籍に記載されている説明は省略しているので、PRML と合わせて読むことが前提の記事です。
問題設定
の多項式曲線 (Polynomial Curve) による近似 (fitting) を考えます。
訓練集合を作成する
- N: サンプル数
- 入力データ集合 (input data set):
- 目標データ集合 (target data set):
今回、訓練データ集合、テストデータ集合を以下のように作成します。
- 訓練データ集合
- 入力データ集合は、 から等間隔に10個選ぶ。
- 目標データ集合は、入力データ集合の各点の の値に正規分布に従うノイズを加えた値とする。
- テストデータ集合
- 入力データ集合は、 から等間隔に100個選ぶ。
- 目標データ集合は、入力データ集合の各点の の値に正規分布に従うノイズを加えた値とする。

多項式曲線フィッティング
次の多項式曲線で近似します。
- は係数が である の関数であることを意味します。
- : 多項式曲線の次数 (degree)
- : 多項式曲線の係数
損失関数は二乗誤差とします。 は微分した際に式を簡潔にするためについています。
この損失関数を最小化する を求めて、 を係数とした多項式曲線 で近似します。
のときの のグラフを描画してみます。

- の場合は、訓練データへの当てはまりが悪く、 の近似には表現力が足りていません
- の場合は、訓練データ、テストデータともに当てはまりがよく、 を上手く近似できています
- の場合は、訓練データへの当てはまりは良いが、テストデータへの当てはまりは悪く、過学習しています
平均平方二乗誤差 (RMSE) を確認する
どのくらい近似できているかを数値的にも評価しましょう。 評価には、平均平方二乗誤差 (RMSE) を使用します。
先程の二乗誤差と比べると、以下の利点があります。
- 平均をとっているため、平均平方二乗誤差の値がサンプル数 に左右されない
- 平方根をとっているため、目標変数 と単位が同じ

- は訓練誤差が大きく、 の近似には表現力が足りていません
- は訓練誤差、テスト誤差ともに低く、 を上手く近似できています
- は訓練誤差は低いが、テスト誤差は大きくなり、過学習しています
係数を確認する
各次数における学習によって得られた係数 の値を確認します。
0.15 | ||||||||||
0.98 | -1.66 | |||||||||
0.97 | -1.64 | -0.02 | ||||||||
0.18 | 11.39 | -34.35 | 22.89 | |||||||
0.24 | 9.31 | -23.83 | 6.01 | 8.44 | ||||||
0.34 | -0.38 | 56.92 | -222.22 | 270.05 | -104.64 | |||||
0.36 | -6.09 | 127.41 | -525.66 | 856.42 | -626.23 | 173.86 | ||||
0.36 | 3.13 | -24.81 | 377.03 | -1695.99 | 3095.93 | -2527.33 | 771.77 | |||
0.35 | 33.28 | -632.79 | 4962.76 | -18995.39 | 39111.60 | -44545.62 | 26497.25 | -6431.37 | ||
0.35 | -90.27 | 2220.86 | -20741.11 | 101757.97 | -289865.39 | 494134.29 | -495988.80 | 270001.76 | -61429.58 |
次数が大きいほど、係数が大きな値をとる傾向があることがわかります。
サンプル数を変えたときの挙動を確認する
次数を に固定して、サンプル数を変化させたときの挙動を確認します。

次数が多くてもサンプル数が十分大きい場合は過学習しないことがわかります。
正規化項の導入
最小化対象の二乗誤差に正規化項 を追加します。 は係数ベクトル のノルムで、各係数が大きいほど、ノルムも大きくなります。 を最小化することは、二乗誤差を小さくしつつ、係数も小さくなるような を探すことになります。
は正規化項が課す制約の強さを表すパラメータです。大きいほど制約が強くなり、小さいほど制約が弱くなります。 とすると、正則化項がない二乗誤差になります。
を の範囲で変えたときの挙動を確認します。

- では、制約が弱すぎるため、過学習の傾向が見られます。
- ではそれなりに近似できています。
- では、制約が強すぎるため、 を表現できなくなっています。
0.35 | -90.49 | 2226.13 | -20788.42 | 101979.48 | -290466.89 | 495116.24 | -496938.63 | 270503.06 | -61540.74 | |
0.35 | -90.49 | 2226.13 | -20788.42 | 101979.48 | -290466.89 | 495116.24 | -496938.63 | 270503.06 | -61540.74 | |
0.35 | -30.28 | 839.32 | -8336.90 | 43671.44 | -132098.74 | 236528.98 | -246766.87 | 138448.10 | -32255.31 | |
0.36 | -0.75 | 41.34 | -40.01 | -385.48 | 802.71 | -197.68 | -606.72 | 460.80 | -74.48 | |
0.27 | 6.93 | -8.90 | -16.09 | 2.98 | 12.66 | 10.40 | 2.98 | -3.87 | -7.29 | |
0.52 | -0.41 | -0.47 | -0.33 | -0.18 | -0.05 | 0.05 | 0.12 | 0.18 | 0.23 |
制約を強くするほど、係数が小さくなっていることが確認できました。 を変化させたときの平均平方二乗誤差の値も確認します。

- では、訓練誤差は低いが、テスト誤差は大きくなり、過学習しています
- ではそれなりに近似できています
- では、訓練誤差が大きく、 の近似には表現力が足りていません
コメント