過学習~多項式回帰の場合

概要

過学習(over fitting)の例として、多項式の係数を線形回帰で予測した場合の挙動をまとめてみた。

複数の点(xi, yi)に対して、以下の線形式の項数を変化させて、Pythonのパッケージ、scikit-learnにあるLinearRegressionでフィッティングさせてみる。

(1)    \begin{equation*} \hat{y} = w_0 + \sum_{j=1}\m w_j x^j \end{equation*}

データ数が少ない場合

以下の例は、[-3, 1]の間で等間隔な4つの値を発生させ、(x, ex)となる4つの点を準備、これらのデータセットに対して、多項式の項数(すなわちxの次数)を1~6まで変化させてフィッティングした結果。たとえばn_terms=3の場合はy = w_0 + w_1 x + w_2 x^2 + w_3 x^3の4つの係数を決定することになる。

  • n_terms=1の場合は単純な線形関数で、データセットの曲線関係を表しているとは言えない。
  • n_terms=2になるとかなり各点にフィットしているが、x < −1の範囲で本来の関数の値と離れていく。
  • n_terms=3はデータ数より項数(特徴量の数)が1つ少ない。各点にほぼぴったり合っていて、最も「それらしい」(ただしデータセットの外側の範囲でも合っているとは限らない/指数関数に対してxの有限の多項式ではどこかで乖離していく)
  • n_terms=4はデータ数と項数(特徴量の数)が等しい。予測曲線がすべての点を通っているが、無理矢理合わせている感があり、データセットの左側で関数形が跳ね上がっている。
  • n_terms=5はデータ数より特徴量数の方が多くなる。予測曲線は全ての点を通っているが、1番目の点と2番目の点の間で若干曲線が歪んでいる
  • n_terms=6になると歪が大きくなる

上記の実行コードは以下の通り。

  • 7~8行目は、切片・係数のセットとxの値を与えて多項式の値を計算する関数。
  • 19行目でn_data=4個のxの値を発生させ、20行目で指数関数の値を計算している。後のために乱数でばらつかせる準備をしているが、ここではばらつかせていない
  • 23~24行目でxnの特徴量を生成している
  • 35行目で線形回帰モデルのフィッティングを行っている。n_termsで指定した項数(=次数)までをフィッティングに使っている。
  • 36行目で、フィッティングの結果予測された切片と係数を使って、予測曲線の値を計算している。

異常値がある場合

上記の整然とした指数関数のデータに1つだけ飛び離れた異常値を入れてみる。

先の例に比べて不安定性=曲線の振動の度合いが大きくなっている。

データ数を多くした場合

点の数を10個とし、乱数で擾乱を与えてみる(乱数系列も変えている)。

n_terms=5あたりから、全ての点に何とかフィットさせようと曲線が揺れ始め、特徴量数がデータ数と同じ値となる前後から振動が大きくなっている。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です