scikit-learn – LogisticRegression

概要

scikit-learnLogisticRegressionモデルはLogistic回帰のモデルを提供する。利用方法の概要は以下の手順で、LinearRegressionなど他の線形モデルとほぼ同じだが、モデルインスタンス生成時に与える正則化パラメーターCRidge/Lassoalphaと逆で、正則化の効果を強くするにはCを小さくする(Cを大きくすると正則化が弱まり、訓練データに対する精度は高まるが過学習の可能性が高くなる)。

また、正則化の方法をL1正則化、L2正則化、Elastic netから選択できる。

  1. LogisticRegressのクラスをインポートする
  2. ハイパーパラメーターC、正則化方法、solver(収束計算方法)などを指定し、モデルのインスタンスを生成する
  3. fit()メソッドに訓練データを与えて学習させる

学習済みのモデルの利用方法は以下の通り。

  • score()メソッドにテストデータを与えて適合度を計算する
  • predict()メソッドに説明変数を与えてターゲットを予測
  • モデルインスタンスのプロパティーからモデルのパラメーターを利用
    • 切片はintercept_、重み係数はcoef_(末尾のアンダースコアに注意)

利用例

以下は、breast_cancerデータセットに対してLogisticRegressionを適用した例。デフォルトのsolver'lbfgs'でデフォルトの最大収束回数(100)では収束しなかったため、max_iter=3000を指定している。

利用方法

LogisticRgressionの主な利用方法はLineaRegressionとほとんど同じで、以下は特有の設定を中心にまとめる。

モデルクラスのインポート

scikit-learn.linear_modelパッケージからLogisticRegressonクラスをインポートする。

モデルのインスタンスの生成

LogisticRegressionでは、ハイパーパラメーターCによって正則化の強さを指定する。このCRidge/Lassoalphaと異なり、正則化の効果を強めるためには値を小さくする。デフォルトはC=1.0

以下、RidgeLassoに特有のパラメーターのみ説明。LinearRegressionと共通のパラメーターはLinearRegressionを参照。

penalty
'l1''l2''elasticnet', 'none'で正則化項のノルムのタイプを指定する。ソルバーの'newton-cg','sag','lbfgs'はL2正則化のみサポートし、'elasticnet''saga'のみがサポートする。デフォルトは'none'で正則化は適用されない('liblinear''none'に対応しない)。
tol
収束計算の解の精度で、デフォルトは1e-4。
C
正則化の強さの逆数。正の整数で指定し、デフォルトは1.0。
solver
'newton-cg''lbfgs''liblinear''sag''saga'のうちから選択される。デフォルトは'lbfgs'。小さなデータセットには'liblnear'が適し、大きなデータセットに対しては'sag''saga'の計算が速い。複数クラスの問題には、'newton-cg''sag''saga''lbfgs'が対応し、'liblinear'は一対他しか対応しない。その他ノルムの種類とソルバーの対応。
max_iter
収束計算の制限回数を指定する。デフォルト値は100。
random_state
データをシャッフルする際のランダム・シードで、solver='sag'の際に用いる。
l1_ratio
Elastic-Netのパラメーター。[0, 1]の値で、penalty='elasticnet'の時のみ使われる。

 モデルの学習

fit()メソッドに特徴量とターゲットの訓練データを与えてモデルに学習させる(回帰係数を決定する)。

X
特徴量の配列。2次元配列で、各列が各々の説明変数に対応し、行数はデータ数を想定している。変数が1つで1次元配列の時はreshape(-1, 1)かスライス([:, n:n+1])を使って1列の列ベクトルに変換する必要がある。
y
ターゲットの配列で、通常は1変数で1次元配列。

3つ目の引数sample_weightは省略。

適合度の計算

score()メソッドに特徴量とターゲットを与えて適合度を計算する。

その他のメソッド

  • decision_function(X)
  • densiffy()
  • predict_proba(X)
  • predict_log_proba()
  • sparsify()

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です