scikit-learn – LinearRegression

概要

scikit-learnLinearRegressionは、最も単純な多重線形回帰モデルを提供する。

モデルの利用方法の概要は以下の手順。

  1. LinearRegressionのクラスをインポートする
  2. モデルのインスタンスを生成する
  3. fit()メソッドに訓練データを与えて学習させる

学習済みのモデルの利用方法は以下の通り。

  • score()メソッドにテストデータを与えて適合度を計算する
  • predict()メソッドに説明変数を与えてターゲットを予測
  • モデルインスタンスのプロパティーからモデルのパラメーターを利用
    • 切片はintercept_、重み係数はcoef_(末尾のアンダースコアに注意)

利用例

配列による場合

以下はscikit-learnのBoston hose pricesデータのうち、2つの特徴量RM(1戸あたり部屋数)とLSTAT(下位層の人口比率)を取り出して、線形回帰のモデルを適用している。

特徴量の一部をとりだすのに、ファンシー・インデックスでリストの要素に2つの変数のインデックスを指定している。また、特徴量データXとターゲットデータyをtrain_test_split()を使って訓練データとテストデータに分けている。

DataFrameによる場合

以下の例では、データセットの本体(data)をpandasのDataFrameとして構成し、2つの特徴量RMとLSTATを指定して取り出している。

利用方法

モデルクラスのインポート

scikit-learn.linear_modelパッケージからLinearRegressionクラスをインポートする。

モデルのインスタンスの生成

LinearRegressionの場合、ハイパーパラメーターの指定はない。

fit_intercept
切片を計算しない場合Falseを指定。デフォルトはTrueで切片も計算されるが、原点を通るべき場合にはFalseを指定する。
normalize
Trueを指定すると、特徴量Xが学習の前に正規化(normalize)される(平均を引いてL2ノルムで割る)。デフォルトはFalsefit_intercept=Falseにセットされた場合は無視される。説明変数を標準化(standardize)する場合はこの引数をFalseにしてsklearn.preprocessing.StandardScalerを使う。
copy_X
Trueを指定するとXはコピーされ、Falseの場合は上書きされる。デフォルトはTrue
n_jobs
計算のジョブの数を指定する。デフォルトはNoneで1に相当。n_targets > 1のときのみ適用される。

 モデルの学習

fit()メソッドに特徴量とターゲットの訓練データを与えてモデルに学習させる(回帰係数を決定する)。

X
特徴量の配列。2次元配列で、各列が各々の説明変数に対応し、行数はデータ数を想定している。変数が1つで1次元配列の時はreshape(-1, 1)かスライス([:, n:n+1])を使って1列の列ベクトルに変換する必要がある。
y
ターゲットの配列で、通常は1変数で1次元配列。

3つ目の引数sample_weightは省略。

適合度の計算

score()メソッドに特徴量とターゲットを与えて適合度を計算する。

戻り値は適合度を示す実数で、回帰計算の決定係数R2で計算される。

(1)    \begin{equation*} R^2 = 1 - \frac{RSS}{TSS} = 1 - \frac{\sum (y_i - \hat{y}_i)^2}{\sum (y_i - \overline{y})^2} \end{equation*}

モデルによる予測

predict()メソッドに特徴量を与えて、ターゲットの予測結果を得る。

ここで特徴量Xは複数のデータセットの2次元配列を想定しており、1組のデータの場合でも2次元配列とする必要がある。

また、結果は複数のデータセットに対する1次元配列で返されるため、ターゲットが1つの場合でも要素数1の1次元配列となる。

切片・係数の利用

fit()メソッドによる学習後、モデルの学習結果として切片と特徴量に対する重み係数を得ることができる。

各々モデル・インスタンスのプロパティーとして保持されており、切片はintercept_で1つの実数、重み係数はcoeff_で特徴量の数と同じ要素数の1次元配列となる(特徴量が1つの場合も要素数1の1次元配列)。

末尾のアンダースコアに注意。

実行例

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です