scikit-learn
Cross-validation &
Classification

Lecture 13

Dr. Colin Rundel

Cross validation &
hyperparameter tuning

Ridge regression

One way to expand on the idea of least squares regression is to modify the loss function. Ridge regression is one such approach - it adds a scaled penalty for the sum of the squared \(\beta\)s to the least squares loss.

\[ \underset{\boldsymbol{\beta}}{\text{argmin}} \; \lVert \boldsymbol{y} - \boldsymbol{X} \boldsymbol{\beta} \rVert^2 + \lambda (\boldsymbol{\beta}^T\boldsymbol{\beta}) \]
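For a fixed \(\lambda\) this loss has a closed-form minimizer, \(\hat{\boldsymbol{\beta}} = (\boldsymbol{X}^T\boldsymbol{X} + \lambda I)^{-1}\boldsymbol{X}^T\boldsymbol{y}\). The sketch below checks this numerically against Ridge(fit_intercept=False), whose alpha plays the role of \(\lambda\). The data are simulated with NumPy, which is an assumption and not the lecture's data/ridge.csv.

```python
import numpy as np
from sklearn.linear_model import Ridge

# Simulated data (hypothetical): 100 rows, 3 features, known coefficients
rng = np.random.default_rng(1234)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, 2.0, -3.0]) + rng.normal(scale=0.1, size=100)

lam = 10.0
# Closed-form ridge estimate: solve (X^T X + lambda I) beta = X^T y
beta_hat = np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

# Ridge without an intercept minimizes the same loss, with alpha = lambda
rg = Ridge(alpha=lam, fit_intercept=False).fit(X, y)
print(beta_hat)
print(rg.coef_)
```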

d = pd.read_csv("data/ridge.csv"); d
            y        x1        x2        x3        x4 x5
0   -0.151710  0.353658  1.633932  0.553257  1.415731  A
1    3.579895  1.311354  1.457500  0.072879  0.330330  B
2    0.768329 -0.744034  0.710362 -0.246941  0.008825  B
3    7.788646  0.806624 -0.228695  0.408348 -2.481624  B
4    1.394327  0.837430 -1.091535 -0.860979 -0.810492  A
..        ...       ...       ...       ...       ... ..
495 -0.204932 -0.385814 -0.130371 -0.046242  0.004914  A
496  0.541988  0.845885  0.045291  0.171596  0.332869  A
497 -1.402627 -1.071672 -1.716487 -0.319496 -1.163740  C
498 -0.043645  1.744800 -0.010161  0.422594  0.772606  A
499 -1.550276  0.910775 -1.675396  1.921238 -0.232189  B

[500 rows x 6 columns]

Dummy coding

d = pd.get_dummies(d); d
            y        x1        x2        x3  ...   x5_A   x5_B   x5_C   x5_D
0   -0.151710  0.353658  1.633932  0.553257  ...   True  False  False  False
1    3.579895  1.311354  1.457500  0.072879  ...  False   True  False  False
2    0.768329 -0.744034  0.710362 -0.246941  ...  False   True  False  False
3    7.788646  0.806624 -0.228695  0.408348  ...  False   True  False  False
4    1.394327  0.837430 -1.091535 -0.860979  ...   True  False  False  False
..        ...       ...       ...       ...  ...    ...    ...    ...    ...
495 -0.204932 -0.385814 -0.130371 -0.046242  ...   True  False  False  False
496  0.541988  0.845885  0.045291  0.171596  ...   True  False  False  False
497 -1.402627 -1.071672 -1.716487 -0.319496  ...  False  False   True  False
498 -0.043645  1.744800 -0.010161  0.422594  ...   True  False  False  False
499 -1.550276  0.910775 -1.675396  1.921238  ...  False   True  False  False

[500 rows x 9 columns]

Fitting a ridge regression model

The linear_model submodule also contains Ridge, which fits a ridge regression model. Usage is identical to LinearRegression() except that Ridge() takes an alpha parameter, which specifies the regularization strength (\(\lambda\) in the loss above).

from sklearn.linear_model import Ridge, LinearRegression
from sklearn.metrics import root_mean_squared_error

X, y = d.drop(["y"], axis=1), d.y

lm = LinearRegression(fit_intercept=False).fit(X, y)
rg = Ridge(fit_intercept=False, alpha=10).fit(X, y)
lm.coef_
array([ 0.99505,  2.00762,  0.00232, -3.00088,  0.49329,  0.10193, -0.29413,  1.00856])
root_mean_squared_error(y, lm.predict(X))
0.0993601324682191
rg.coef_
array([ 0.97809,  1.96215,  0.00172, -2.94457,  0.45558,  0.09001, -0.28193,  0.79781])
root_mean_squared_error(y, rg.predict(X))
0.13820792795597317

Test-Train split

The most basic form of CV is to split the data into a testing and a training set; this can be achieved using train_test_split() from the model_selection submodule.

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
  X, y, test_size=0.2, random_state=1234
)
X.shape
(500, 8)
X_train.shape
(400, 8)
X_test.shape
(100, 8)
y.shape
(500,)
y_train.shape
(400,)
y_test.shape
(100,)

Train vs Test RMSE

alpha = np.logspace(-2,1, 100)
train_rmse = []
test_rmse = []

for a in alpha:
    rg = Ridge(alpha=a).fit(X_train, y_train)
    train_rmse.append(
        root_mean_squared_error(y_train, rg.predict(X_train))
    )
    test_rmse.append(
        root_mean_squared_error(y_test, rg.predict(X_test))
    )

res = pd.DataFrame({
    "alpha": alpha,
    "train": train_rmse,
    "test": test_rmse
})
res
        alpha     train      test
0    0.010000  0.097568  0.106985
1    0.010723  0.097568  0.106984
2    0.011498  0.097568  0.106984
3    0.012328  0.097568  0.106983
4    0.013219  0.097568  0.106983
..        ...       ...       ...
95   7.564633  0.126990  0.129414
96   8.111308  0.130591  0.132458
97   8.697490  0.134568  0.135838
98   9.326033  0.138950  0.139581
99  10.000000  0.143764  0.143715

[100 rows x 3 columns]

g = sns.relplot(
  x="alpha", y="rmse", hue="variable", data = pd.melt(res, id_vars=["alpha"],value_name="rmse")
).set(
  xscale="log"
)

Best alpha?

min_i = np.argmin(res.train)
min_i
np.int64(0)
res.iloc[[min_i],:]
   alpha     train      test
0   0.01  0.097568  0.106985
min_i = np.argmin(res.test)
min_i
np.int64(58)
res.iloc[[min_i],:]
       alpha     train    test
58  0.572237  0.097787  0.1068

k-fold cross validation

The previous approach was relatively straightforward, but it required a fair bit of bookkeeping to implement and we only examined a single test/train split. If we would like to perform k-fold cross validation, we can use cross_val_score() from the model_selection submodule.

from sklearn.model_selection import cross_val_score

cross_val_score(
  Ridge(alpha=0.59, fit_intercept=False), 
  X, y,
  cv=5, 
  scoring="neg_root_mean_squared_error"
)
array([-0.09364, -0.09995, -0.10474, -0.10273, -0.10597])

🚩🚩🚩 Note that the default k-fold cross validation used here does not shuffle the data, which can be massively problematic if the data are ordered 🚩🚩🚩
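To see why ordering matters, consider labels that are sorted by group. Without shuffling, each test fold is a contiguous block of rows, so most folds see only one group. A sketch with hypothetical sorted labels:

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical ordered labels: 50 zeros followed by 50 ones
y = np.array([0] * 50 + [1] * 50)
X = y.reshape(-1, 1)  # split() only needs something with 100 rows

# Without shuffling, each test fold is a contiguous block of 20 rows
no_shuffle = [float(y[test].mean()) for _, test in KFold(5).split(X)]
print(no_shuffle)  # [0.0, 0.0, 0.5, 1.0, 1.0]

# Shuffling mixes the two groups across the folds
cv = KFold(5, shuffle=True, random_state=1234)
shuffled = [float(y[test].mean()) for _, test in cv.split(X)]
print(shuffled)
```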

Controlling k-fold behavior

Rather than providing cv as an integer, it is better to specify a cross-validation scheme directly (with additional options). Here we will use the KFold class from the model_selection submodule.

from sklearn.model_selection import KFold

cross_val_score(
  Ridge(alpha=0.59, fit_intercept=False), 
  X, y, 
  cv = KFold(n_splits=5, shuffle=True, random_state=1234), 
  scoring="neg_root_mean_squared_error"
)
array([-0.10658, -0.104  , -0.1037 , -0.10125, -0.09228])

KFold object

KFold() returns a cross-validator object whose split() method is a generator yielding a tuple of training and test indices for each fold.

ex = pd.DataFrame(data = list(range(10)), columns=["x"])
cv = KFold(5)
for train, test in cv.split(ex):
  print(f'Train: {train} | Test: {test}')
Train: [2 3 4 5 6 7 8 9] | Test: [0 1]
Train: [0 1 4 5 6 7 8 9] | Test: [2 3]
Train: [0 1 2 3 6 7 8 9] | Test: [4 5]
Train: [0 1 2 3 4 5 8 9] | Test: [6 7]
Train: [0 1 2 3 4 5 6 7] | Test: [8 9]
cv = KFold(5, shuffle=True, random_state=1234)
for train, test in cv.split(ex):
  print(f'Train: {train} | Test: {test}')
Train: [0 1 3 4 5 6 8 9] | Test: [2 7]
Train: [0 2 3 4 5 6 7 8] | Test: [1 9]
Train: [1 2 3 4 5 6 7 9] | Test: [0 8]
Train: [0 1 2 3 6 7 8 9] | Test: [4 5]
Train: [0 1 2 4 5 7 8 9] | Test: [3 6]
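Whether or not the data are shuffled, the test sets produced by split() partition the rows: every row is used for testing exactly once. A quick check of that property (a sketch; the 10-row size matches the toy example above):

```python
import numpy as np
from sklearn.model_selection import KFold

n = 10
cv = KFold(5, shuffle=True, random_state=1234)

# Collect every test index across the folds
test_idx = np.concatenate([test for _, test in cv.split(np.zeros((n, 1)))])

# Each row lands in exactly one test fold, even with shuffle=True
print(np.sort(test_idx))  # [0 1 2 3 4 5 6 7 8 9]
print(cv.get_n_splits())  # 5
```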

scoring

For most of the cross validation functions, the scoring argument can be either a string or a callable with signature scorer(estimator, X, y).

The names of the possible metrics are available via sklearn.metrics.get_scorer_names().

import sklearn.metrics
np.array( sklearn.metrics.get_scorer_names() )
array(['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score',
       'average_precision', 'balanced_accuracy', 'completeness_score',
       'd2_absolute_error_score', 'd2_brier_score', 'd2_log_loss_score',
       'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted',
       'fowlkes_mallows_score', 'homogeneity_score', 'jaccard', 'jaccard_macro',
       'jaccard_micro', 'jaccard_samples', 'jaccard_weighted', 'matthews_corrcoef',
       'mutual_info_score', 'neg_brier_score', 'neg_log_loss', 'neg_max_error',
       'neg_mean_absolute_error', 'neg_mean_absolute_percentage_error',
       'neg_mean_gamma_deviance', 'neg_mean_poisson_deviance', 'neg_mean_squared_error',
       'neg_mean_squared_log_error', 'neg_median_absolute_error',
       'neg_negative_likelihood_ratio', 'neg_root_mean_squared_error',
       'neg_root_mean_squared_log_error', 'normalized_mutual_info_score',
       'positive_likelihood_ratio', 'precision', 'precision_macro', 'precision_micro',
       'precision_samples', 'precision_weighted', 'r2', 'rand_score', 'recall',
       'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc',
       'roc_auc_ovo', 'roc_auc_ovo_weighted', 'roc_auc_ovr', 'roc_auc_ovr_weighted',
       'top_k_accuracy', 'v_measure_score'], dtype='<U34')

Train vs Test RMSE (again)

alpha = np.logspace(-2,1, 30)
test_mean_rmse = []
test_rmse = []
cv = KFold(n_splits=5, shuffle=True, random_state=1234)

for a in alpha:
    rg = Ridge(fit_intercept=False, alpha=a)
    
    scores = -1 * cross_val_score(
      rg, X_train, y_train, 
      cv = cv, 
      scoring="neg_root_mean_squared_error"
    )
    test_mean_rmse.append(np.mean(scores))
    test_rmse.append(scores)

res = pd.DataFrame(
    data = np.c_[alpha, test_mean_rmse, test_rmse],
    columns = ["alpha", "mean_rmse"] + ["fold" + str(i) for i in range(1,6) ]
)

res
        alpha  mean_rmse     fold1     fold2     fold3     fold4     fold5
0    0.010000   0.099393  0.096577  0.091750  0.091573  0.104881  0.112186
1    0.012690   0.099393  0.096581  0.091743  0.091575  0.104882  0.112185
2    0.016103   0.099393  0.096585  0.091734  0.091577  0.104884  0.112185
3    0.020434   0.099392  0.096591  0.091722  0.091580  0.104885  0.112184
4    0.025929   0.099392  0.096599  0.091708  0.091583  0.104888  0.112183
5    0.032903   0.099392  0.096608  0.091690  0.091588  0.104891  0.112182
6    0.041753   0.099391  0.096621  0.091667  0.091594  0.104895  0.112181
7    0.052983   0.099392  0.096637  0.091639  0.091602  0.104900  0.112180
8    0.067234   0.099392  0.096657  0.091604  0.091612  0.104908  0.112179
9    0.085317   0.099394  0.096684  0.091561  0.091626  0.104919  0.112179
10   0.108264   0.099398  0.096720  0.091510  0.091644  0.104935  0.112179
11   0.137382   0.099405  0.096767  0.091448  0.091669  0.104958  0.112181
12   0.174333   0.099417  0.096829  0.091376  0.091704  0.104992  0.112186
13   0.221222   0.099439  0.096913  0.091294  0.091751  0.105042  0.112196
14   0.280722   0.099477  0.097028  0.091207  0.091819  0.105117  0.112215
15   0.356225   0.099540  0.097185  0.091121  0.091914  0.105231  0.112249
16   0.452035   0.099644  0.097403  0.091052  0.092052  0.105406  0.112309
17   0.573615   0.099816  0.097709  0.091030  0.092254  0.105675  0.112410
18   0.727895   0.100093  0.098143  0.091102  0.092551  0.106090  0.112580
19   0.923671   0.100541  0.098766  0.091354  0.092993  0.106733  0.112857
20   1.172102   0.101254  0.099664  0.091918  0.093654  0.107727  0.113309
21   1.487352   0.102382  0.100964  0.093008  0.094646  0.109258  0.114033
22   1.887392   0.104142  0.102847  0.094944  0.096135  0.111602  0.115181
23   2.395027   0.106848  0.105567  0.098184  0.098358  0.115153  0.116978
24   3.039195   0.110933  0.109465  0.103349  0.101654  0.120451  0.119747
25   3.856620   0.116963  0.114982  0.111214  0.106477  0.128202  0.123940
26   4.893901   0.125634  0.122661  0.122664  0.113415  0.139275  0.130153
27   6.210169   0.137756  0.133144  0.138641  0.123187  0.154674  0.139134
28   7.880463   0.154231  0.147158  0.160089  0.136635  0.175507  0.151764
29  10.000000   0.176041  0.165524  0.187970  0.154706  0.202968  0.169035

g = sns.relplot(
  x="alpha", y="rmse", hue="variable", data=res.melt(id_vars=["alpha"], value_name="rmse"), 
  marker="o", kind="line"
).set(
  xscale="log"
)

best_* attributes

A fitted GridSearchCV() object (gs) contains attributes with details on the “best” model, as determined by the chosen scoring metric.

gs.best_index_
np.int64(5)
gs.best_params_
{'alpha': np.float64(0.03290344562312668)}
gs.best_score_
np.float64(-0.1012561176745365)
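The construction of gs does not appear in this section. A sketch of how a comparable object could be built, assuming simulated data from make_regression, the same 30-value alpha grid that appears in cv_results_, and a shuffled 5-fold KFold (the exact CV settings behind the slides are an assumption). It also checks that best_score_ is the largest mean test score and that best_index_ points at the rank-1 candidate:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold

# Hypothetical stand-in data; the lecture's gs was fit on the ridge data
X, y = make_regression(n_samples=200, n_features=8, noise=1.0, random_state=1234)

gs = GridSearchCV(
    Ridge(fit_intercept=False),
    param_grid={"alpha": np.logspace(-2, 1, 30)},  # 30 alphas, as in cv_results_
    cv=KFold(5, shuffle=True, random_state=1234),  # assumed CV scheme
    scoring="neg_root_mean_squared_error",
).fit(X, y)

# best_* summarize the candidate with the highest mean test score
print(gs.best_index_, gs.best_params_, gs.best_score_)
```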

best_estimator_ attribute

If refit=True (the default) is used with GridSearchCV(), the best_estimator_ attribute is available and gives direct access to the “best” model or pipeline object. This model is constructed by taking the parameter(s) that achieved the best (i.e. highest) mean cross-validated score, which for a neg_* scorer corresponds to the smallest error, and refitting the model to the complete data set.

gs.best_estimator_
Ridge(alpha=np.float64(0.03290344562312668), fit_intercept=False)
gs.best_estimator_.coef_
array([ 0.99499,  2.00747,  0.00231, -3.0007 ,  0.49316,  0.10189, -0.29408,  1.00767])
gs.best_estimator_.predict(X)
array([ -0.12179,   3.34151,   0.76055,   7.89292,   1.56523,  -5.33575,  -4.37469,
         3.13003,  -0.16859,  -1.60087,  -1.89073,   1.44596,   3.99773,   4.70003,
        -6.45959,   4.11085,   3.60426,  -1.96548,   2.99039,   0.56796,  -5.26672,
         5.4966 ,   3.47247,  -2.66117,   3.35011,   0.64221,  -1.50238,   2.41562,
         3.11665,   1.11236,  -2.11839,   1.36006,  -0.53666,  -2.78112,   0.76008,
         5.49779,   2.6521 ,  -0.83127,   0.04167,  -1.92585,  -2.48865,   2.29127,
         3.62514,  -2.01226,  -0.69725,  -1.94514,  -0.47559,  -7.36557,  -3.20766,
         2.9218 ,  -0.8213 ,  -2.78598, -12.55143,   2.79189,  -1.89763,  -5.1769 ,
         1.87484,   2.18345,  -6.45358,   0.91006,   0.94792,   2.91799,   6.12323,
        -1.87654,   3.63259,  -0.53797,  -3.23506,  -2.23885,   1.04564,  -1.54843,
         0.76161,  -1.65495,   0.22378,  -0.68221,   0.12976,   2.58875,   2.54421,
        -3.69056,   3.73479,  -0.90278,   1.22394,  -3.22614,   7.16719,  -5.6168 ,
         3.3433 ,   0.36935,   0.87397,   9.22348,  -1.29078,   1.74347,  -1.55169,
        -0.69398,  -1.40445,   0.23072,   1.06277,   2.84797,   2.35596,  -1.93292,
         8.35129,  -2.98221,  -6.35071,  -5.15138,   1.70208,   7.15821,   3.96172,
         5.75363,  -4.50718,  -5.81785,  -2.47424,   1.19276,   2.57431,  -2.57053,
        -0.53682,  -1.65955,   1.99839,  -6.19607,  -1.73962,  -2.11993,  -2.29362,
         2.65413,  -0.67486,  -3.01324,   0.34118,  -3.83856,   0.33096,  -3.59485,
        -1.55578,   0.96765,   3.50934,  -0.31935,  -4.18323,   2.87843,  -1.64857,
        -3.68181,   2.24423,  -1.00244,  -2.65588,  -5.77111,  -1.20292,   2.66903,
        -1.11387,   3.05231,   6.34596,  -1.42886,  -2.29709,  -1.4573 ,  -2.46733,
         1.69685,   4.21699,   1.21569,   9.06269,  -3.62209,   1.94704,   1.14603,
        -3.35087,  -5.91052,  -1.23355,   2.8308 ,  -3.21438,   4.09019,  -5.95969,
        -0.98044,   2.06976,   0.58541,   1.83006,   8.11251,  -0.18073,  -4.80287,
         1.59881,   0.13323,   2.67859,   2.45406,  -2.28901,   1.1609 ,  -1.50239,
        -5.51199,   2.67089,   2.39878,   6.65249,   0.5551 ,   9.36975,   6.90333,
         0.48633,  -0.51877,   1.44203,  -5.95008,   5.99042,  -0.85644,   1.90162,
        -1.23686,   3.22403,   5.31725,   0.31415,   0.17128,  -1.53623,   1.73354,
        -1.93645,   4.68716,  -3.62658,   0.22032, -10.94667,   2.83953,  -8.13513,
         4.30062,  -0.67864,  -0.67348,   4.22499,   3.34704,  -1.44927,  -6.3229 ,
         4.83881,  -3.71184,   6.32207,   3.69622,  -1.02501, -12.91691,   1.85435,
        -0.43171,   4.77516,  -1.53529,  -1.65685,   5.69233,   6.28949,   5.37201,
        -0.63177,   2.88795,   4.01781,   7.03453,   1.76797,   5.86793,   1.57465,
         3.03172,   0.96769,  -3.0659 ,  -1.51918,  -2.89632,  -1.28436,   2.67186,
        -0.92299,  -4.85603,   4.18714,  -3.60775,  -2.31532,   1.27459,   0.37238,
        -1.21   ,   2.44074,  -1.52466,  -2.59175,  -1.83419,  -0.8865 ,   0.89346,
         2.70453,  -3.15098,  -4.43793,   0.8058 ,   0.23748,   1.13615,   0.63385,
        -0.2395 ,   6.07024,   0.85521,   0.18951,   3.27772,  -0.8963 ,  -5.84285,
         0.68905,  -0.30427,  -2.87087,  10.51629,  -3.96115,  -5.09138, -10.86754,
        -9.25489,   7.0615 ,   0.01263,   3.93274,   3.40325,  -1.57858,  -4.94508,
        -2.69779,   1.07372,  -3.95091,  -3.80321,  -1.91214,   0.14772,   3.70995,
         5.04094,  -0.02024,  -0.03725,  -1.15642,   8.92035,   2.63769,  -1.39664,
         1.62241,  -4.87487,  -2.49769,   1.39569,  -1.39193,   4.52569,   2.29201,
         1.57898,   0.69253,  -3.4654 ,   3.71004,   6.10037,  -4.41299,  -4.79775,
        -3.79204,  -3.61711,  -2.92489,   7.15104,  -3.24195,   3.03705,  -4.01473,
        -1.99391,  -4.64601,   4.40534,  -3.12028,  -0.1754 ,   2.52698,   0.49637,
        -1.0263 ,  10.77554,  -1.64465,  -2.13624,  -2.16392,   1.92049,  -2.47602,
        -4.34462,  -2.09427,  -0.32466,   2.56876,  -5.7397 ,  -2.94306,  -1.12118,
         4.16147,   2.5303 ,   3.38768,   7.96277,  -3.28827,  -5.73513,   4.76249,
        -1.24714,   0.08253,  -1.71446,   1.3742 ,   1.85738,  -6.37864,  -0.0773 ,
         0.73072,  -1.64713,  -3.65246,   1.57344,  -2.56019,  -1.09033,  -1.05099,
        -4.48298,  -0.28666,  -4.92509,   2.6523 ,  -4.59622,   3.09283,   3.50353,
        -6.1787 ,  -2.08203,  -2.72838,  -8.55473,   4.14717,   0.03483,  -2.07173,
        -1.22492,  -2.1331 ,  -3.24188,  -3.23348,  -1.43328,   3.09365,   2.85666,
         3.1452 ,  -0.60436,  -3.08445,   2.39221,   1.26373,   4.77618,  -1.78471,
        -6.19369,  -3.24321,  -0.76221,  -1.56433,   1.39877,   2.28802,   4.46115,
        -3.25751,  -2.51097,   1.19593,   1.12214,   2.0177 ,  -2.9301 ,  -5.70471,
         2.94404,  -9.62989,  -4.13055,  -0.30686,   5.41388,   3.36441,  -1.68838,
         3.18239,  -1.97929,   3.84279,   0.59629,   4.23805,  -8.3217 ,   4.71925,
         0.32863,   2.20721,   3.46358,   3.38237,  -2.65319,   2.32341,   0.31199,
         5.29292,   0.798  ,   2.17796,   5.74332,  -7.68979,   0.33166,  -1.84974,
         4.73811,   0.51179,  -1.18062,  -1.08818,   6.30818,  -2.88198,  -1.68064,
         1.76754,  -3.80955,  -5.03755,   3.41809,  -2.62689,   4.09036,  -4.51406,
         0.95089,  -1.0706 ,  -1.51755,  -1.83065,  -5.33533,  -2.15694,  -5.43987,
        -5.04878,  -5.62245,  -1.46875,  -0.60701,   0.20797,  -3.21649,   3.93528,
         1.14442,   1.93545,  -4.11887,  -0.39968,  -4.07461,   2.32534,  -0.26627,
        -2.45467,  -1.08026,   2.35466,   0.92026,  -1.41122,  -1.21825, -10.48345,
         3.18599,   0.08117,   4.24776,   4.47563,   6.52936,   4.06496,   0.61928,
        -4.96605,  -1.23884,  -3.06521,   2.4295 ,  -3.13812,  -0.51459,  -2.9222 ,
         0.72806,   4.4886 ,  -1.04944, -11.67098,   1.12496,   3.81906,  -6.76879,
        -3.90709,  -1.75508,   1.57104,   2.2711 ,   7.69569,  -0.16729,   0.42729,
        -1.31489,  -0.10855,  -1.65403])
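That best_estimator_ is simply a refit can be checked by fitting Ridge() again with best_params_ on all of the data. A sketch using simulated data and an assumed shuffled 5-fold KFold (not necessarily the settings behind gs above):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold

# Hypothetical stand-in data
X, y = make_regression(n_samples=200, n_features=8, noise=1.0, random_state=1234)

gs = GridSearchCV(
    Ridge(fit_intercept=False),
    {"alpha": np.logspace(-2, 1, 30)},
    cv=KFold(5, shuffle=True, random_state=1234),
    scoring="neg_root_mean_squared_error",
).fit(X, y)  # refit=True is the default

# Refit by hand: best hyperparameters, complete data set
manual = Ridge(fit_intercept=False, **gs.best_params_).fit(X, y)
print(np.allclose(gs.best_estimator_.coef_, manual.coef_))  # True
```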

cv_results_ attribute

Other useful details about the grid search process are stored in the cv_results_ attribute, a dictionary that includes things like mean test scores, fold-level test scores, test score ranks, and fit and scoring times.

gs.cv_results_.keys()
dict_keys(['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time', 'param_alpha', 'params', 'split0_test_score', 'split1_test_score', 'split2_test_score', 'split3_test_score', 'split4_test_score', 'mean_test_score', 'std_test_score', 'rank_test_score'])
gs.cv_results_["mean_test_score"]
array([-0.10126, -0.10126, -0.10126, -0.10126, -0.10126, -0.10126, -0.10126, -0.10126,
       -0.10126, -0.10126, -0.10126, -0.10127, -0.10128, -0.10129, -0.10132, -0.10136,
       -0.10143, -0.10154, -0.10173, -0.10203, -0.1025 , -0.10325, -0.10444, -0.10627,
       -0.10909, -0.11333, -0.11959, -0.12859, -0.14119, -0.15832])
gs.cv_results_["param_alpha"]
masked_array(data=[0.01, 0.01268961003167922, 0.01610262027560939, 0.020433597178569417,
                   0.02592943797404667, 0.03290344562312668, 0.041753189365604,
                   0.05298316906283707, 0.06723357536499334, 0.08531678524172806,
                   0.10826367338740546, 0.1373823795883263, 0.17433288221999882,
                   0.2212216291070449, 0.2807216203941177, 0.3562247890262442,
                   0.4520353656360243, 0.5736152510448679, 0.727895384398315,
                   0.9236708571873861, 1.1721022975334805, 1.4873521072935119,
                   1.8873918221350976, 2.395026619987486, 3.039195382313198,
                   3.856620421163472, 4.893900918477494, 6.2101694189156165,
                   7.880462815669913, 10.0],
             mask=[False, False, False, False, False, False, False, False, False, False,
                   False, False, False, False, False, False, False, False, False, False,
                   False, False, False, False, False, False, False, False, False, False],
       fill_value=1e+20)

alpha = np.array(gs.cv_results_["param_alpha"], dtype="float64")
score = -gs.cv_results_["mean_test_score"]
score_std = gs.cv_results_["std_test_score"]
n_folds = gs.cv.get_n_splits()

plt.figure(layout="constrained")
ax = sns.lineplot(x=alpha, y=score)
ax.set_xscale("log")
plt.fill_between(
  x = alpha,
  y1 = score + 1.96*score_std / np.sqrt(n_folds),
  y2 = score - 1.96*score_std / np.sqrt(n_folds),
  alpha = 0.2
)
plt.show()
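Since every entry of cv_results_ has one value per candidate, the dictionary converts directly into a data frame. The sketch below (simulated data; the CV settings are assumptions) also confirms that mean_test_score and std_test_score are the mean and standard deviation (with ddof=0) of the split*_test_score columns, the quantities used for the band in the plot above.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold

# Hypothetical stand-in data and an assumed CV scheme
X, y = make_regression(n_samples=200, n_features=8, noise=1.0, random_state=1234)
gs = GridSearchCV(
    Ridge(fit_intercept=False),
    {"alpha": np.logspace(-2, 1, 30)},
    cv=KFold(5, shuffle=True, random_state=1234),
    scoring="neg_root_mean_squared_error",
).fit(X, y)

# cv_results_ is a dict of equal-length arrays, so it converts to a data frame
res = pd.DataFrame(gs.cv_results_)
folds = res[[f"split{i}_test_score" for i in range(5)]].to_numpy()

# Row-wise mean and population sd of the fold scores
print(np.allclose(res.mean_test_score, folds.mean(axis=1)))
print(np.allclose(res.std_test_score, folds.std(axis=1)))
```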

Ridge traceplot

alpha = np.logspace(-1,5, 100)
betas = []

for a in alpha:
    rg = Ridge(alpha=a, fit_intercept=False).fit(X, y)
    betas.append(rg.coef_)

res = pd.DataFrame(
  data = betas, columns = rg.feature_names_in_
).assign(
  alpha = alpha  
)

g = sns.relplot(
  data = res.melt(id_vars="alpha", value_name="coef values", var_name="feature"),
  x = "alpha", y = "coef values", hue = "feature",
  kind = "line", aspect=2
).set(
  xscale="log"
)

Classification

OpenIntro - Spam

We will start by looking at a data set on spam emails from the OpenIntro project. A full data dictionary can be found here. To keep things simple this week, we will restrict our exploration to the following columns: spam, exclaim_mess, format, num_char, line_breaks, and number.

  • spam - Indicator for whether the email was spam.
  • exclaim_mess - The number of exclamation points in the email message.
  • format - Indicates whether the email was written using HTML (e.g. may have included bolding or active links).
  • num_char - The number of characters in the email, in thousands.
  • line_breaks - The number of line breaks in the email (does not count text wrapping).
  • number - Factor variable saying whether there was no number, a small number (under 1 million), or a big number.

As number is categorical, we will take care of the necessary dummy coding via pd.get_dummies(),

email = pd.read_csv('data/email.csv')[ 
  ['spam', 'exclaim_mess', 'format', 'num_char', 'line_breaks', 'number'] 
]
email_dc = pd.get_dummies(email)
email_dc
      spam  exclaim_mess  format  ...  number_big  number_none  number_small
0        0             0       1  ...        True        False         False
1        0             1       1  ...       False        False          True
2        0             6       1  ...       False        False          True
3        0            48       1  ...       False        False          True
4        0             1       0  ...       False         True         False
...    ...           ...     ...  ...         ...          ...           ...
3916     1             0       0  ...       False        False          True
3917     1             0       0  ...       False        False          True
3918     0             5       1  ...       False        False          True
3919     0             0       0  ...       False        False          True
3920     1             1       0  ...       False        False          True

[3921 rows x 8 columns]

g = sns.pairplot(email, hue='spam', corner=True, aspect=1.25)

Model fitting

from sklearn.linear_model import LogisticRegression

y = email_dc.spam
X = email_dc.drop('spam', axis=1)

m = LogisticRegression(fit_intercept = False).fit(X, y)
m.feature_names_in_
array(['exclaim_mess', 'format', 'num_char', 'line_breaks', 'number_big', 'number_none',
       'number_small'], dtype=object)
m.coef_
array([[ 0.00982, -0.61905,  0.05449, -0.00555, -1.21236, -0.69326, -1.92064]])

A quick comparison

R output

glm(spam~.-1, data=d, family=binomial) |>
  coef()
exclaim_mess       format     num_char  line_breaks    numberbig   numbernone 
 0.009586821 -0.604781649  0.054765496 -0.005480427 -1.264826746 -0.706842516 
 numbersmall 
-1.950440237 


sklearn output

m.feature_names_in_
array(['exclaim_mess', 'format', 'num_char', 'line_breaks', 'number_big', 'number_none',
       'number_small'], dtype=object)
m.coef_
array([[ 0.00982, -0.61905,  0.05449, -0.00555, -1.21236, -0.69326, -1.92064]])

sklearn.linear_model.LogisticRegression

From the documentation,

This class implements regularized logistic regression using the ‘liblinear’ library, ‘newton-cg’, ‘sag’, ‘saga’ and ‘lbfgs’ solvers. Note that regularization is applied by default. It can handle both dense and sparse input. Use C-ordered arrays or CSR matrices containing 64-bit floats for optimal performance; any other input format will be converted (and copied).

Penalty parameter

🚩🚩🚩

LogisticRegression() has a parameter called penalty that applies an "l1" (lasso), "l2" (ridge), "elasticnet", or no (None) penalty, with "l2" being the default. To make matters worse, the strength of the regularization is controlled by the parameter C, which defaults to 1. C here is the inverse regularization strength (i.e. the inverse of alpha for ridge and lasso models). In the objective below, \(\rho\) corresponds to the l1_ratio parameter, which is only used with "elasticnet".

🚩🚩🚩

\[ \min_{w} \sum_{i=1}^n \left(-y_i \log(\hat{p}(X_i)) - (1 - y_i) \log(1 - \hat{p}(X_i))\right) + \frac{1}{C}\, \left(\frac{1 - \rho}{2}w^T w + \rho \|w\|_1\right) \]

Another quick comparison

R output

glm(spam~.-1, data = d, family=binomial) |>
  coef()
exclaim_mess       format     num_char  line_breaks    numberbig   numbernone 
 0.009586821 -0.604781649  0.054765496 -0.005480427 -1.264826746 -0.706842516 
 numbersmall 
-1.950440237 

sklearn output (penalty None)

m = LogisticRegression(
  fit_intercept = False, penalty=None
).fit(X, y)
m.feature_names_in_
array(['exclaim_mess', 'format', 'num_char', 'line_breaks', 'number_big', 'number_none',
       'number_small'], dtype=object)
m.coef_
array([[ 0.00959, -0.60483,  0.05476, -0.00548, -1.26481, -0.70687, -1.95043]])

Solver parameter

It is also possible to specify the solver used to fit a logistic regression model. To complicate matters somewhat, the choice of solver depends on the penalty chosen:

  • newton-cg - ["l2", None]
  • lbfgs - ["l2", None]
  • liblinear - ["l1", "l2"]
  • sag - ["l2", None]
  • saga - ["elasticnet", "l1", "l2", None]

Also there can be issues with feature scales for some of these solvers:

Note: ‘sag’ and ‘saga’ fast convergence is only guaranteed on features with approximately the same scale. You can preprocess the data with a scaler from sklearn.preprocessing.
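Two practical consequences of the notes above, sketched on simulated data with deliberately mismatched feature scales (all data and settings here are illustrative assumptions): an unsupported solver/penalty pairing such as "lbfgs" with "l1" raises an error when fit, and putting StandardScaler in front of a "saga" fit lets it reach essentially the same L2-penalized solution as "lbfgs".

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(1)
# Features on very different scales (hypothetical)
X = rng.normal(size=(400, 3)) * np.array([1.0, 100.0, 0.01])
logits = X @ np.array([1.0, 0.01, 100.0])
y = rng.binomial(1, 1 / (1 + np.exp(-logits)))

# An unsupported solver/penalty pair is rejected at fit time
raised = False
try:
    LogisticRegression(solver="lbfgs", penalty="l1").fit(X, y)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Standardize first so saga works with features on a common scale
saga = make_pipeline(
    StandardScaler(), LogisticRegression(solver="saga", tol=1e-6, max_iter=10_000)
).fit(X, y)
lbfgs = make_pipeline(
    StandardScaler(), LogisticRegression(solver="lbfgs", max_iter=1000)
).fit(X, y)
print(np.round(saga[-1].coef_, 3))
print(np.round(lbfgs[-1].coef_, 3))
```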

Prediction

Classification models have multiple prediction methods depending on what type of output you would like,

m.predict(X)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0], shape=(3921,))
m.predict_proba(X)
array([[0.91325, 0.08675],
       [0.95595, 0.04405],
       [0.95788, 0.04212],
       [0.94085, 0.05915],
       [0.68757, 0.31243],
       [0.6845 , 0.3155 ],
       [0.93419, 0.06581],
       [0.96357, 0.03643],
       [0.89585, 0.10415],
       [0.94176, 0.05824],
       [0.93248, 0.06752],
       [0.89601, 0.10399],
       [0.91243, 0.08757],
       [0.97272, 0.02728],
       [0.92833, 0.07167],
       [0.98352, 0.01648],
       [0.96326, 0.03674],
       [0.95381, 0.04619],
       [0.88889, 0.11111],
       [0.80419, 0.19581],
       [0.89905, 0.10095],
       [0.95645, 0.04355],
       [0.99084, 0.00916],
       [0.88019, 0.11981],
       [0.80525, 0.19475],
       [0.8875 , 0.1125 ],
       [0.89733, 0.10267],
       [0.88684, 0.11316],
       [0.68521, 0.31479],
       [0.93696, 0.06304],
       ...,
       [0.88214, 0.11786],
       [0.99385, 0.00615],
       [0.93508, 0.06492],
       [0.68929, 0.31071],
       [0.87709, 0.12291],
       [0.79315, 0.20685],
       [0.78987, 0.21013],
       [0.6726 , 0.3274 ],
       [0.89342, 0.10658],
       [0.93273, 0.06727],
       [0.68929, 0.31071],
       [0.88446, 0.11554],
       [0.98186, 0.01814],
       [0.88949, 0.11051],
       [0.88358, 0.11642],
       [0.67276, 0.32724],
       [0.7904 , 0.2096 ],
       [0.67994, 0.32006],
       [0.68715, 0.31285],
       [0.70615, 0.29385],
       [0.93313, 0.06687],
       [0.93056, 0.06944],
       [0.88956, 0.11044],
       [0.78882, 0.21118],
       [0.91827, 0.08173],
       [0.88059, 0.11941],
       [0.88236, 0.11764],
       [0.95981, 0.04019],
       [0.89246, 0.10754],
       [0.898  , 0.102  ]], shape=(3921, 2))
m.predict_log_proba(X)
array([[-0.09074, -2.44476],
       [-0.04505, -3.12251],
       [-0.04303, -3.16731],
       [-0.06098, -2.8276 ],
       [-0.37459, -1.16338],
       [-0.37906, -1.1536 ],
       [-0.06808, -2.72095],
       [-0.03711, -3.31224],
       [-0.10999, -2.26189],
       [-0.06   , -2.84318],
       [-0.0699 , -2.6954 ],
       [-0.10981, -2.26344],
       [-0.09165, -2.43527],
       [-0.02766, -3.60153],
       [-0.07436, -2.63572],
       [-0.01662, -4.10549],
       [-0.03743, -3.3039 ],
       [-0.04729, -3.07497],
       [-0.11778, -2.19724],
       [-0.21791, -1.63063],
       [-0.10642, -2.29309],
       [-0.04453, -3.13385],
       [-0.0092 , -4.69303],
       [-0.12761, -2.12189],
       [-0.21661, -1.63602],
       [-0.11935, -2.18479],
       [-0.10833, -2.27627],
       [-0.12009, -2.17899],
       [-0.37804, -1.15584],
       [-0.06512, -2.76395],
       ...,
       [-0.12541, -2.13823],
       [-0.00617, -5.0908 ],
       [-0.06712, -2.73464],
       [-0.37209, -1.16891],
       [-0.13114, -2.09631],
       [-0.23174, -1.57577],
       [-0.23589, -1.56003],
       [-0.39661, -1.11657],
       [-0.1127 , -2.23886],
       [-0.06964, -2.69899],
       [-0.37209, -1.16891],
       [-0.12277, -2.15817],
       [-0.01831, -4.00951],
       [-0.1171 , -2.20268],
       [-0.12377, -2.15058],
       [-0.39637, -1.11705],
       [-0.23522, -1.56254],
       [-0.38574, -1.13926],
       [-0.3752 , -1.16204],
       [-0.34793, -1.22468],
       [-0.06921, -2.70507],
       [-0.07197, -2.66732],
       [-0.11702, -2.20333],
       [-0.23722, -1.55504],
       [-0.08526, -2.50436],
       [-0.12717, -2.12518],
       [-0.12516, -2.1401 ],
       [-0.04102, -3.21422],
       [-0.11378, -2.22985],
       [-0.10758, -2.2828 ]], shape=(3921, 2))
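For a fitted binary LogisticRegression these three methods are tied together: predict() returns the class with the highest predicted probability, predict_proba() has one column per entry of classes_ with rows summing to one, and predict_log_proba() is its logarithm. A sketch on simulated data (an assumption, not the email data):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Simulated binary data (hypothetical)
rng = np.random.default_rng(7)
X = rng.normal(size=(300, 2))
y = (X[:, 0] + rng.normal(size=300) > 0).astype(int)

m = LogisticRegression().fit(X, y)
proba = m.predict_proba(X)  # columns follow the order of m.classes_

# predict() picks the class with the highest probability
print(np.array_equal(m.predict(X), m.classes_[proba.argmax(axis=1)]))
# predict_log_proba() is log(predict_proba()); probability rows sum to 1
print(np.allclose(m.predict_log_proba(X), np.log(proba)), np.allclose(proba.sum(axis=1), 1))
```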

Scoring

All classifiers and regressors include a score() method which returns a default metric for the given model; in the case of classification models this is the accuracy,

m.score(X, y)
0.90640142820709

Other scoring options are available via the metrics submodule

from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix
accuracy_score(y, m.predict(X))
0.90640142820709
roc_auc_score(y, m.predict_proba(X)[:,1])
0.7606967779329887
f1_score(y, m.predict(X))
0.0
confusion_matrix(y, m.predict(X), labels=m.classes_)
array([[3554,    0],
       [ 367,    0]])
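The accuracy of about 0.906 is exactly the share of non-spam emails: the confusion matrix shows that the model predicts every email as non-spam, which is also why the F1 score is 0. A sketch that rebuilds label vectors with the same counts as the confusion matrix above (a reconstruction from those counts, not the original email data) makes the arithmetic explicit:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

# 3554 non-spam and 367 spam emails, all predicted as non-spam
y_true = np.array([0] * 3554 + [1] * 367)
y_pred = np.zeros_like(y_true)

cm = confusion_matrix(y_true, y_pred)
print(cm)
acc = np.trace(cm) / cm.sum()
print(acc, accuracy_score(y_true, y_pred))  # both 3554 / 3921
print(1 - y_true.mean())                    # proportion of non-spam emails
print(f1_score(y_true, y_pred, zero_division=0))  # no true positives, so F1 = 0
```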

Confusion matrix

from sklearn.metrics import ConfusionMatrixDisplay
cm = confusion_matrix(y, m.predict(X), labels=m.classes_)
disp = ConfusionMatrixDisplay(cm).plot()
plt.show()

ROC curve

from sklearn.metrics import auc, roc_curve, RocCurveDisplay
fpr, tpr, thresholds = roc_curve(y, m.predict_proba(X)[:,1])
roc_auc = auc(fpr, tpr)
disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                       estimator_name='Logistic Regression').plot()
plt.show()
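The area computed with auc() from the roc_curve() points agrees with roc_auc_score(), which is how the two numbers on these slides relate. A sketch on simulated data (an assumption):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Simulated binary data (hypothetical)
rng = np.random.default_rng(3)
X = rng.normal(size=(400, 2))
y = (X @ np.array([1.5, -1.0]) + rng.logistic(size=400) > 0).astype(int)

p = LogisticRegression().fit(X, y).predict_proba(X)[:, 1]

# roc_curve returns (fpr, tpr) pairs over a decreasing sequence of thresholds
fpr, tpr, thresholds = roc_curve(y, p)
# Trapezoidal area under those points matches roc_auc_score
print(auc(fpr, tpr), roc_auc_score(y, p))
```

RocCurveDisplay.from_predictions(y, p) is a shortcut that computes these quantities and draws the curve in one call.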

Precision Recall curve

from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay
precision, recall, _ = precision_recall_curve(y, m.predict_proba(X)[:,1])
disp = PrecisionRecallDisplay(precision=precision, recall=recall).plot()
plt.show()

StratifiedKFold

For classification problems, StratifiedKFold is preferred over KFold - it ensures each fold preserves the same class proportions as the full dataset, which is particularly important for imbalanced classes.

from sklearn.model_selection import StratifiedKFold
print(f'overall % spam = {y.mean():.3f}')
overall % spam = 0.094

KFold

cv = KFold(5, shuffle=True, random_state=1234)
for i, (train, test) in enumerate(cv.split(X, y)):
  print(f'fold {i+1}: % spam = {y.iloc[test].mean():.3f}')
fold 1: % spam = 0.078
fold 2: % spam = 0.107
fold 3: % spam = 0.097
fold 4: % spam = 0.085
fold 5: % spam = 0.101

StratifiedKFold

cv = StratifiedKFold(5, shuffle=True, random_state=1234)
for i, (train, test) in enumerate(cv.split(X, y)):
  print(f'fold {i+1}: % spam = {y.iloc[test].mean():.3f}')
fold 1: % spam = 0.094
fold 2: % spam = 0.093
fold 3: % spam = 0.093
fold 4: % spam = 0.093
fold 5: % spam = 0.094
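The idea behind stratification fits in a few lines of numpy. Shuffle the indices of each class separately, then deal them out to the folds in turn, so every fold gets (nearly) the same number of each class. This is a simplified illustration, not sklearn's actual StratifiedKFold algorithm. The 94/906 class split below is made up to resemble the spam proportion.

```python
import numpy as np

def stratified_fold_ids(y, n_splits, seed=0):
    """Assign each observation a fold id, dealing out each class separately."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    fold = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))  # shuffle within class
        fold[idx] = np.arange(len(idx)) % n_splits        # round-robin over folds
    return fold

# Made-up imbalanced labels: 94 positives out of 1000 (about 9.4%)
y_demo = np.array([1] * 94 + [0] * 906)
folds = stratified_fold_ids(y_demo, n_splits=5)
for k in range(5):
    print(f"fold {k+1}: % positive = {y_demo[folds == k].mean():.3f}")
```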

MNIST (sklearn)

MNIST (sklearn) handwritten digits

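These data are a small, cleaned-up relative of the original MNIST data set (sklearn's copy comes from the UCI ML hand-written digits data). Each observation is an 8x8 pixel image of a handwritten digit, stored as 64 grayscale intensities between 0 and 16.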

from sklearn.datasets import load_digits
digits = load_digits(as_frame=True)
X = digits.data
X
      pixel_0_0  pixel_0_1  pixel_0_2  ...  pixel_7_5  pixel_7_6  pixel_7_7
0           0.0        0.0        5.0  ...        0.0        0.0        0.0
1           0.0        0.0        0.0  ...       10.0        0.0        0.0
2           0.0        0.0        0.0  ...       16.0        9.0        0.0
3           0.0        0.0        7.0  ...        9.0        0.0        0.0
4           0.0        0.0        0.0  ...        4.0        0.0        0.0
...         ...        ...        ...  ...        ...        ...        ...
1792        0.0        0.0        4.0  ...        9.0        0.0        0.0
1793        0.0        0.0        6.0  ...        6.0        0.0        0.0
1794        0.0        0.0        1.0  ...        6.0        0.0        0.0
1795        0.0        0.0        2.0  ...       12.0        0.0        0.0
1796        0.0        0.0       10.0  ...       12.0        1.0        0.0

[1797 rows x 64 columns]
y = digits.target
y
0       0
1       1
2       2
3       3
4       4
       ..
1792    9
1793    0
1794    8
1795    9
1796    8
Name: target, Length: 1797, dtype: int64

Example digits
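Each row of X is a flattened 8x8 image (columns pixel_0_0 through pixel_7_7), so reshaping a row recovers the picture. The sketch below reshapes one row and prints it as text. It assumes only numpy and sklearn's bundled copy of the data, and it loads the data under the new name digits_np so the digits object above is left untouched.

```python
import numpy as np
from sklearn.datasets import load_digits

digits_np = load_digits()
X_arr = digits_np.data                  # shape (1797, 64)

# Row-major reshape recovers the 8x8 image stored in digits_np.images
img = X_arr[0].reshape(8, 8)
assert np.array_equal(img, digits_np.images[0])

# Crude text rendering: darker (larger) pixel values get denser characters
chars = " .:-=+*#%@"
for row in img:
    print("".join(chars[int(v * (len(chars) - 1) / 16)] for v in row))
print("label:", digits_np.target[0])
```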

Doing things properly - train/test split

To properly assess our modeling, we will split these data into a training set and a test set. Only the training data will be used to learn model coefficients or tune hyperparameters; the test data will be used only for final model scoring.

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, shuffle=True, stratify=y, random_state=1234
)
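Passing stratify=y makes the test set reproduce the class frequencies of the full data. The quick check below re-creates the same split with plain numpy arrays, so the block is self-contained.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X_d, y_d = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_d, y_d, test_size=0.33, shuffle=True, stratify=y_d, random_state=1234
)

# Per-class counts in the full data and in the test split
full_counts = np.bincount(y_d)
test_counts = np.bincount(y_te)
print(len(y_tr), len(y_te))
print(np.round(test_counts / full_counts, 3))   # each close to 0.33
```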

Multiclass logistic regression

In recent versions of sklearn, LogisticRegression fits a multinomial (softmax) model by default when there are more than two classes. We can use GridSearchCV with StratifiedKFold to compare several solvers (all of which support penalty=None).

mc_log_cv = GridSearchCV(
  LogisticRegression(penalty=None, max_iter = 5000),
  param_grid = {"solver": ["lbfgs", "newton-cg", "sag", "saga"]},
  cv = StratifiedKFold(10, shuffle=True, random_state=12345)
).fit(
  X_train, y_train
)
mc_log_cv.best_estimator_
LogisticRegression(max_iter=5000, penalty=None, solver='saga')
mc_log_cv.best_score_
np.float64(0.9617217630853994)
for param, score in zip(mc_log_cv.cv_results_["params"], mc_log_cv.cv_results_["mean_test_score"]):
  print(f"{param=}, {score=}")
param={'solver': 'lbfgs'}, score=np.float64(0.9542699724517906)
param={'solver': 'newton-cg'}, score=np.float64(0.959249311294766)
param={'solver': 'sag'}, score=np.float64(0.9609022038567494)
param={'solver': 'saga'}, score=np.float64(0.9617217630853994)

Model coefficients

pd.DataFrame(
  mc_log_cv.best_estimator_.coef_
)
    0         1         2         3   ...        60        61        62        63
0  0.0 -0.004725 -0.044285  0.032391  ... -0.033048 -0.047579 -0.049476 -0.017627
1  0.0 -0.000007 -0.149267  0.144655  ...  0.069315  0.191033  0.229132  0.140568
2  0.0  0.043262  0.030727  0.046395  ...  0.079775  0.210461  0.474337  0.171941
3  0.0  0.038958 -0.156024  0.062062  ...  0.115703  0.079442 -0.041634 -0.113690
4  0.0 -0.004442  0.021386 -0.315967  ... -0.181147 -0.237009 -0.098966 -0.004862
5  0.0  0.068641  0.474700  0.071645  ... -0.073288 -0.082737 -0.050102 -0.022247
6  0.0 -0.007990 -0.219742 -0.078264  ... -0.023428  0.098858  0.060984 -0.050235
7  0.0  0.028949  0.048150  0.085590  ... -0.224151 -0.343424 -0.134076 -0.011080
8  0.0 -0.004807  0.066576 -0.271286  ...  0.243447 -0.004030 -0.168261 -0.119749
9  0.0 -0.157838 -0.072221  0.222780  ...  0.026823  0.134983 -0.221938  0.026981

[10 rows x 64 columns]
mc_log_cv.best_estimator_.coef_.shape
(10, 64)
mc_log_cv.best_estimator_.intercept_
array([ 0.00638, -0.06455,  0.00183,  0.01208,  0.04303,  0.00836,  0.00025,  0.01013,
        0.0292 , -0.04671])
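In the multinomial model, each class \(k\) has its own intercept \(\beta_{k0}\) (an entry of intercept_) and coefficient vector \(\boldsymbol{\beta}_k\) (a row of coef_). The class probabilities are a softmax of the 10 linear scores:

\[ P(y = k \mid \boldsymbol{x}) = \frac{\exp(\beta_{k0} + \boldsymbol{x}^T\boldsymbol{\beta}_k)}{\sum_{j=0}^{9} \exp(\beta_{j0} + \boldsymbol{x}^T\boldsymbol{\beta}_j)} \]

The sketch below checks this against predict_proba(). It uses a freshly fit LogisticRegression with the default L2 penalty, and it rescales the pixels to [0, 1] for faster convergence. Both choices are assumptions made for speed, so the coefficients will differ from the unpenalized model above.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

X_d, y_d = load_digits(return_X_y=True)
X_s = X_d / 16                                  # pixels rescaled to [0, 1]
model = LogisticRegression(max_iter=1000).fit(X_s, y_d)

# Linear score for every class: x^T beta_k + beta_k0, shape (n, 10)
scores = X_s @ model.coef_.T + model.intercept_

# Softmax across classes (subtract the row max for numerical stability)
z = scores - scores.max(axis=1, keepdims=True)
probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)

print(np.abs(probs - model.predict_proba(X_s)).max())
```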

Confusion Matrix

Within sample

accuracy_score(
  y_train, 
  mc_log_cv.best_estimator_.predict(X_train)
)
1.0
confusion_matrix(
  y_train, 
  mc_log_cv.best_estimator_.predict(X_train)
)
array([[119,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0, 122,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0, 118,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0, 123,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0, 121,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0, 122,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0, 121,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0, 120,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0, 116,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0, 121]])

Out of sample

accuracy_score(
  y_test, 
  mc_log_cv.best_estimator_.predict(X_test)
)
0.968013468013468
confusion_matrix(
  y_test, 
  mc_log_cv.best_estimator_.predict(X_test),
  labels = digits.target_names
)
array([[59,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 58,  0,  1,  0,  0,  0,  0,  0,  1],
       [ 0,  0, 59,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  1, 56,  0,  2,  0,  0,  1,  0],
       [ 0,  1,  0,  0, 58,  0,  0,  0,  1,  0],
       [ 0,  0,  0,  0,  0, 56,  0,  0,  1,  3],
       [ 0,  0,  0,  0,  0,  0, 60,  0,  0,  0],
       [ 0,  0,  0,  0,  2,  0,  0, 56,  0,  1],
       [ 0,  1,  1,  0,  0,  0,  0,  0, 56,  0],
       [ 0,  0,  0,  0,  0,  1,  0,  0,  1, 57]])

Report

from sklearn.metrics import classification_report
print( classification_report(
  y_test, 
  mc_log_cv.best_estimator_.predict(X_test)
) )
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        59
           1       0.97      0.97      0.97        60
           2       0.97      1.00      0.98        59
           3       0.98      0.93      0.96        60
           4       0.97      0.97      0.97        60
           5       0.95      0.93      0.94        60
           6       1.00      1.00      1.00        60
           7       1.00      0.95      0.97        59
           8       0.93      0.97      0.95        58
           9       0.92      0.97      0.94        59

    accuracy                           0.97       594
   macro avg       0.97      0.97      0.97       594
weighted avg       0.97      0.97      0.97       594
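Every entry in this report can be recovered from the out-of-sample confusion matrix:

- recall divides the diagonal by the row totals (the true class);
- precision divides the diagonal by the column totals (the predicted class);
- F1 is the harmonic mean of the two.

The numpy sketch below applies these rules to the matrix printed above.

```python
import numpy as np

# Out-of-sample confusion matrix from above (rows = true digit, columns = predicted digit)
cm = np.array([
    [59,  0,  0,  0,  0,  0,  0,  0,  0,  0],
    [ 0, 58,  0,  1,  0,  0,  0,  0,  0,  1],
    [ 0,  0, 59,  0,  0,  0,  0,  0,  0,  0],
    [ 0,  0,  1, 56,  0,  2,  0,  0,  1,  0],
    [ 0,  1,  0,  0, 58,  0,  0,  0,  1,  0],
    [ 0,  0,  0,  0,  0, 56,  0,  0,  1,  3],
    [ 0,  0,  0,  0,  0,  0, 60,  0,  0,  0],
    [ 0,  0,  0,  0,  2,  0,  0, 56,  0,  1],
    [ 0,  1,  1,  0,  0,  0,  0,  0, 56,  0],
    [ 0,  0,  0,  0,  0,  1,  0,  0,  1, 57],
])

tp = np.diag(cm)
recall = tp / cm.sum(axis=1)       # correct / number of true cases of each digit
precision = tp / cm.sum(axis=0)    # correct / number of predictions of each digit
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm.sum()     # 575 / 594

print(np.round(precision, 2))
print(np.round(recall, 2))
print(np.round(f1, 2))
print(round(accuracy, 4))
```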

Prediction

mc_log_cv.best_estimator_.predict(X_test)
array([1, 6, 1, 9, 6, 7, 8, 7, 2, 4, 0, 7, 1, 7, 6, 6, 0,
       5, 0, 4, 3, 2, 3, 1, 7, 0, 8, 2, 2, 5, 1, 2, 7, 3,
       1, 2, 6, 2, 5, 8, 1, 5, 6, 7, 1, 9, 6, 7, 9, 4, 9,
       9, 6, 3, 5, 4, 1, 3, 8, 6, 0, 1, 6, 4, 8, 1, 2, 8,
       0, 3, 8, 3, 0, 0, 6, 1, 9, 1, 7, 0, 3, 1, 1, 6, 1,
       1, 9, 1, 4, 3, 9, 9, 8, 3, 3, 4, 4, 7, 5, 3, 9, 9,
       8, 1, 9, 6, 2, 8, 5, 8, 6, 4, 4, 1, 4, 1, 5, 0, 6,
       9, 2, 5, 5, 2, 8, 8, 5, 9, 0, 0, 8, 6, 4, 1, 3, 3,
       8, 9, 4, 4, 4, 4, 4, 0, 2, 0, 3, 4, 5, 2, 2, 8, 8,
       3, 7, 7, 1, 7, 2, 4, 9, 9, 9, 4, 2, 2, 9, 0, 0, 7,
       6, 5, 7, 3, 4, 2, 5, 5, 2, 0, 5, 9, 6, 2, 9, 1, 8,
       1, 2, 2, 5, 7, 4, 6, 4, 9, 9, 8, 5, 6, 7, 9, 4, 2,
       6, 1, 4, 7, 3, 2, 5, 6, 5, 5, 6, 0, 3, 8, 1, 1, 9,
       2, 7, 8, 0, 4, 0, 7, 0, 5, 2, 7, 6, 3, 4, 9, 5, 7,
       3, 8, 6, 8, 1, 5, 2, 8, 2, 8, 6, 6, 8, 2, 2, 7, 2,
       3, 6, 8, 9, 4, 2, 3, 8, 7, 3, 5, 4, 8, 3, 9, 6, 3,
       2, 0, 9, 7, 5, 0, 3, 2, 6, 6, 6, 5, 8, 8, 6, 7, 0,
       1, 3, 2, 5, 9, 7, 8, 8, 4, 6, 3, 4, 6, 1, 4, 2, 3,
       9, 2, 0, 2, 9, 8, 3, 6, 3, 9, 7, 8, 0, 2, 2, 0, 1,
       7, 2, 1, 2, 7, 6, 3, 9, 2, 4, 1, 2, 8, 7, 0, 4, 7,
       2, 4, 4, 8, 4, 4, 1, 8, 0, 7, 6, 9, 9, 0, 5, 7, 0,
       9, 9, 4, 5, 1, 5, 2, 3, 2, 1, 5, 9, 9, 8, 7, 9, 1,
       3, 6, 6, 9, 8, 8, 7, 7, 5, 3, 5, 7, 4, 8, 0, 6, 5,
       2, 0, 1, 2, 5, 8, 7, 8, 5, 7, 4, 3, 6, 2, 7, 8, 0,
       9, 9, 4, 3, 4, 5, 3, 4, 6, 0, 5, 0, 6, 2, 3, 0, 7,
       2, 4, 3, 0, 3, 0, 0, 4, 8, 7, 4, 6, 1, 8, 6, 3, 4,
       9, 6, 2, 6, 4, 5, 5, 8, 5, 5, 4, 1, 5, 8, 5, 7, 1,
       7, 5, 7, 3, 0, 9, 9, 7, 1, 0, 0, 5, 6, 5, 0, 6, 0,
       0, 3, 1, 1, 9, 9, 5, 3, 3, 2, 9, 4, 7, 5, 0, 0, 1,
       1, 9, 0, 7, 8, 5, 4, 1, 4, 9, 6, 5, 7, 1, 6, 8, 8,
       1, 9, 9, 5, 3, 2, 2, 9, 8, 7, 3, 2, 8, 8, 0, 4, 0,
       1, 0, 0, 3, 1, 9, 8, 4, 2, 5, 3, 9, 0, 3, 7, 5, 0,
       9, 4, 8, 7, 9, 8, 9, 4, 0, 2, 5, 2, 0, 6, 0, 1, 0,
       1, 3, 8, 1, 4, 6, 1, 3, 1, 6, 7, 6, 4, 6, 1, 8, 3,
       3, 5, 2, 5, 5, 6, 4, 1, 1, 7, 6, 3, 7, 9, 9, 6])
mc_log_cv.best_estimator_.predict_proba(X_test)
array([[0.     , 0.99329, 0.     , 0.     , 0.00666,
        0.     , 0.     , 0.     , 0.00005, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.99999, 0.     , 0.00001, 0.     ],
       [0.     , 0.99996, 0.     , 0.     , 0.     ,
        0.     , 0.00001, 0.     , 0.00003, 0.     ],
       [0.     , 0.     , 0.     , 0.00001, 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.99999],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 1.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.00048, 0.     ,
        0.     , 0.     , 0.99949, 0.00003, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 1.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 1.     , 0.     , 0.     ],
       [0.     , 0.     , 0.99998, 0.00002, 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 1.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [1.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 1.     , 0.     , 0.     ],
       [0.     , 1.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.00003,
        0.     , 0.     , 0.99997, 0.     , 0.     ],
       [0.     , 0.00002, 0.     , 0.     , 0.     ,
        0.     , 0.99998, 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 1.     , 0.     , 0.     , 0.     ],
       [1.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 0.     , 0.     , 0.     ],
       [1.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 1.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 1.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.81245, 0.18755, 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.99947, 0.     ,
        0.00052, 0.     , 0.     , 0.     , 0.00001],
       [0.     , 0.9993 , 0.     , 0.00068, 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.00002],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 1.     , 0.     , 0.     ],
       [0.95032, 0.     , 0.00004, 0.04952, 0.     ,
        0.00002, 0.     , 0.     , 0.0001 , 0.00001],
       [0.00012, 0.0118 , 0.     , 0.     , 0.00001,
        0.     , 0.00046, 0.23179, 0.75582, 0.     ],
       [0.     , 0.     , 1.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 1.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 0.     , 0.     , 0.     ],
       ...,
       [0.     , 1.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 1.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.99999, 0.     , 0.00001, 0.     ],
       [0.     , 0.97211, 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.02789, 0.     ],
       [0.     , 0.     , 0.     , 0.99896, 0.     ,
        0.00054, 0.     , 0.     , 0.00023, 0.00026],
       [0.     , 0.99565, 0.     , 0.     , 0.00435,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 1.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.00004, 0.     , 0.99993, 0.     , 0.00003],
       [0.00001, 0.     , 0.     , 0.     , 0.     ,
        0.00002, 0.99997, 0.     , 0.     , 0.     ],
       [0.     , 0.02386, 0.     , 0.     , 0.97613,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.04823, 0.     , 0.     , 0.     ,
        0.     , 0.94236, 0.     , 0.0094 , 0.     ],
       [0.     , 0.99987, 0.     , 0.     , 0.00011,
        0.     , 0.     , 0.     , 0.00002, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 1.     , 0.     ],
       [0.     , 0.     , 0.00002, 0.99938, 0.     ,
        0.     , 0.     , 0.     , 0.00059, 0.     ],
       [0.     , 0.     , 0.     , 0.99904, 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.00095],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 1.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.2237 , 0.     ,
        0.73021, 0.00001, 0.     , 0.     , 0.04608],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.99999, 0.     , 0.00001, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 1.     ,
        0.     , 0.     , 0.     , 0.     , 0.     ],
       [0.     , 0.99998, 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.00001, 0.00001, 0.     ],
       [0.     , 0.99145, 0.     , 0.     , 0.00633,
        0.     , 0.     , 0.00003, 0.00219, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 1.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 1.     , 0.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.98663, 0.     ,
        0.     , 0.     , 0.     , 0.01336, 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 1.     , 0.     , 0.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 1.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 1.     ],
       [0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 1.     , 0.     , 0.     , 0.     ]],
      shape=(594, 10))
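predict() returns the class with the largest predicted probability in each row, so the two outputs above are consistent. The sketch below checks that relationship. As before, it uses a freshly fit, default-penalty model on rescaled pixels (an assumption for speed) rather than the tuned model above.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X_d, y_d = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_d / 16, y_d, test_size=0.33, stratify=y_d, random_state=1234
)
model = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)

proba = model.predict_proba(X_te)                 # shape (594, 10), rows sum to 1
by_argmax = model.classes_[proba.argmax(axis=1)]  # most probable class per row

print(np.array_equal(by_argmax, model.predict(X_te)))
```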

Examining the coefs

coef_img = mc_log_cv.best_estimator_.coef_.reshape(10,8,8)

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 5), layout="constrained")
axes2 = [ax for row in axes for ax in row]

for ax, image, label in zip(axes2, coef_img, range(10)):
    ax.set_axis_off()
    img = ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    txt = ax.set_title(f"{label}")
    
plt.show()

Example 1 - DecisionTreeClassifier

Using these data, we will now fit a DecisionTreeClassifier and use GridSearchCV to tune some of its parameters (max_depth at a minimum). The full list of parameters is given in the sklearn documentation for DecisionTreeClassifier.

from sklearn.datasets import load_digits
digits = load_digits(as_frame=True)


X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, shuffle=True, stratify=y, random_state=1234
)

Example 1 - Fitting

from sklearn.tree import DecisionTreeClassifier

digits_tree = GridSearchCV(
  DecisionTreeClassifier(),
  param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": range(2,16)
  },
  cv = KFold(5, shuffle=True, random_state=12345)
).fit(
  X_train, y_train
)
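GridSearchCV fits every combination in param_grid on every fold, so the cost grows multiplicatively. sklearn's ParameterGrid, the class that enumerates these combinations, makes the count easy to check for the grid above.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": range(2, 16),
}
grid = ParameterGrid(param_grid)

n_candidates = len(grid)      # 2 criteria x 14 depths
n_folds = 5
# With the default refit=True there is also one final fit of the best candidate
print(n_candidates, "candidates,", n_candidates * n_folds, "cross-validation fits")
print(list(grid)[:2])
```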

Example 1 - Results

digits_tree.best_estimator_
DecisionTreeClassifier(criterion='entropy', max_depth=13)
digits_tree.best_score_
np.float64(0.86615836791148)
accuracy_score(y_test, digits_tree.best_estimator_.predict(X_test))
0.8434343434343434
confusion_matrix(
  y_test, digits_tree.best_estimator_.predict(X_test)
)
array([[54,  0,  0,  0,  3,  0,  0,  0,  2,  0],
       [ 0, 53,  0,  4,  1,  0,  1,  0,  0,  1],
       [ 0,  3, 50,  1,  0,  1,  0,  1,  3,  0],
       [ 0,  2,  1, 52,  0,  1,  0,  0,  1,  3],
       [ 0,  2,  0,  0, 50,  1,  2,  4,  0,  1],
       [ 1,  5,  1,  1,  2, 44,  4,  1,  0,  1],
       [ 0,  1,  0,  0,  1,  4, 54,  0,  0,  0],
       [ 0,  2,  0,  0,  3,  1,  0, 52,  0,  1],
       [ 1,  9,  0,  1,  1,  0,  0,  0, 44,  2],
       [ 1,  0,  0,  2,  1,  3,  1,  1,  2, 48]])

Example 2 - GridSearchCV w/ Multiple models
(Trees vs Forests)

Example 2 - Setup

To compare fundamentally different model types with GridSearchCV we can use a Pipeline where the estimator itself becomes a parameter in the grid search. The param_grid then takes a list of dicts, one per model family.

from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

pipe = Pipeline([("clf", DecisionTreeClassifier())])

Example 2 - Fitting

param_grid = [
  {
    "clf": [DecisionTreeClassifier()],
    "clf__criterion": ["gini", "entropy"],
    "clf__max_depth": range(2, 16)
  },
  {
    "clf": [RandomForestClassifier()],
    "clf__n_estimators": [50, 100, 200],
    "clf__max_depth": [None, 5, 10]
  }
]

trees_vs_forests = GridSearchCV(
  pipe,
  param_grid,
  cv = StratifiedKFold(5, shuffle=True, random_state=12345)
).fit(X_train, y_train)
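Roughly speaking, this works because GridSearchCV applies each candidate through the pipeline's set_params(). A key naming a step (clf) swaps in a new estimator, and step__param keys set parameters on whichever estimator occupies that step. The sketch below applies one candidate by hand, then counts the candidates in the list-of-dicts grid.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import ParameterGrid
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

pipe = Pipeline([("clf", DecisionTreeClassifier())])

# Apply one candidate by hand: replace the step, then set one of its parameters
pipe.set_params(clf=RandomForestClassifier(), clf__n_estimators=50)
print(pipe.named_steps["clf"])

param_grid = [
    {"clf": [DecisionTreeClassifier()],
     "clf__criterion": ["gini", "entropy"],
     "clf__max_depth": range(2, 16)},
    {"clf": [RandomForestClassifier()],
     "clf__n_estimators": [50, 100, 200],
     "clf__max_depth": [None, 5, 10]},
]
# Candidates from a list of dicts add up: 2*14 tree settings + 3*3 forest settings
print(len(ParameterGrid(param_grid)))
```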

Example 2 - Results

trees_vs_forests.best_estimator_
Pipeline(steps=[('clf',
                 RandomForestClassifier(max_depth=10, n_estimators=200))])
trees_vs_forests.best_params_
{'clf': RandomForestClassifier(), 'clf__max_depth': 10, 'clf__n_estimators': 200}
trees_vs_forests.best_score_
np.float64(0.9717254495159061)

Example 2 - Test Performance

accuracy_score(y_test, trees_vs_forests.best_estimator_.predict(X_test))
0.9764309764309764
confusion_matrix(
  y_test, trees_vs_forests.best_estimator_.predict(X_test)
)
array([[58,  0,  0,  0,  1,  0,  0,  0,  0,  0],
       [ 0, 60,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0, 59,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, 57,  0,  2,  0,  0,  1,  0],
       [ 0,  0,  0,  0, 58,  0,  0,  1,  1,  0],
       [ 0,  0,  0,  0,  0, 59,  0,  0,  0,  1],
       [ 0,  0,  0,  0,  0,  0, 60,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0, 59,  0,  0],
       [ 0,  0,  0,  1,  1,  0,  0,  1, 54,  1],
       [ 0,  0,  0,  1,  0,  1,  0,  0,  1, 56]])