{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 307: Week 05" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regularization" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.datasets import make_regression\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.linear_model import Lasso\n", "from sklearn.linear_model import Ridge\n", "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# create some data\n", "X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=42)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-0.92216532, 1.87679581, 0.75698862, 0.27996863, 0.72576662,\n", " 0.48100923, 1.35563786, -1.2446547 , 0.4134349 , 0.86960592,\n", " 0.65436566, -1.12548905, 2.44575198, 0.12922118, 0.22388402,\n", " 1.49604431, -0.7737892 , -0.05558467, 0.10939479, -1.77872025],\n", " [-0.08310557, -1.4575515 , -1.40631746, -0.1601328 , -0.79602586,\n", " 1.07600714, 0.76005596, -0.75215641, 0.08243975, -1.50472037,\n", " -1.87517247, 0.67134008, 0.21319663, -0.75196933, 0.02131165,\n", " 1.34045045, -0.30920908, 0.11502608, -0.31905394, 0.31917451]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view first two rows of the X data\n", "X[0:2]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 108.81742028, -250.567936 , 1.86765761, 127.84318259,\n", " 34.15127975])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view some examples from the y data\n", "y[0:5]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
012345678910111213141516171819
0-0.9221651.8767960.7569890.2799690.7257670.4810091.355638-1.2446550.4134350.8696060.654366-1.1254892.4457520.1292210.2238841.496044-0.773789-0.0555850.109395-1.778720
1-0.083106-1.457551-1.406317-0.160133-0.7960261.0760070.760056-0.7521560.082440-1.504720-1.8751720.6713400.213197-0.7519690.0213121.340450-0.3092090.115026-0.3190540.319175
20.810808-1.662492-0.134309-0.308034-0.209222-1.683438-1.7485321.1267051.3043400.793489-1.1057050.7796611.3103091.395684-0.805870-0.4108141.032546-0.214921-0.562168-1.090966
30.536653-0.756795-1.0469110.4558880.2685921.5284680.7189531.5013340.9960481.1857041.3281942.165002-0.6435180.9278400.507836-0.250833-1.4218110.5562300.057013-0.322680
41.532739-0.4012200.5193471.4511440.1833422.1898030.4017120.0125920.690144-0.1087600.0245100.9592712.153182-0.767348-0.808298-0.7730100.2240920.4979980.8723210.097676
...............................................................
95-0.114736-0.334501-0.7925212.122156-0.7076690.4438190.865755-0.653329-1.2002960.504987-1.2608841.032465-1.519370-0.4842340.7746340.404982-0.4749450.9178621.2669111.765454
96-0.5993750.622850-1.594428-1.5341140.1156751.1792970.046981-0.142379-0.4500650.0052440.7116151.2776770.332314-0.7484870.0675180.514439-1.067620-1.1246421.5511520.120296
97-0.152470-1.3312330.133541-0.006071-0.2902750.2673920.9567020.507991-0.7859890.7081090.3885790.8384910.081829-0.0988900.321698-2.152891-1.8362052.4930000.919076-1.103367
98-1.3796180.513085-0.9716571.188913-0.881875-0.1630670.8623930.5161780.953125-0.6267170.8004100.7083040.3514481.070150-0.7449030.4319230.7250960.754291-0.026521-0.641482
99-2.848543-1.1196700.7716990.076822-0.4281151.500760-1.7397141.160827-0.3624411.148766-0.046921-1.2829920.996267-0.4937570.8502220.346504-1.2946810.477041-1.556582-0.467701
\n", "

100 rows × 20 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 -0.922165 1.876796 0.756989 0.279969 0.725767 0.481009 1.355638 \n", "1 -0.083106 -1.457551 -1.406317 -0.160133 -0.796026 1.076007 0.760056 \n", "2 0.810808 -1.662492 -0.134309 -0.308034 -0.209222 -1.683438 -1.748532 \n", "3 0.536653 -0.756795 -1.046911 0.455888 0.268592 1.528468 0.718953 \n", "4 1.532739 -0.401220 0.519347 1.451144 0.183342 2.189803 0.401712 \n", ".. ... ... ... ... ... ... ... \n", "95 -0.114736 -0.334501 -0.792521 2.122156 -0.707669 0.443819 0.865755 \n", "96 -0.599375 0.622850 -1.594428 -1.534114 0.115675 1.179297 0.046981 \n", "97 -0.152470 -1.331233 0.133541 -0.006071 -0.290275 0.267392 0.956702 \n", "98 -1.379618 0.513085 -0.971657 1.188913 -0.881875 -0.163067 0.862393 \n", "99 -2.848543 -1.119670 0.771699 0.076822 -0.428115 1.500760 -1.739714 \n", "\n", " 7 8 9 10 11 12 13 \\\n", "0 -1.244655 0.413435 0.869606 0.654366 -1.125489 2.445752 0.129221 \n", "1 -0.752156 0.082440 -1.504720 -1.875172 0.671340 0.213197 -0.751969 \n", "2 1.126705 1.304340 0.793489 -1.105705 0.779661 1.310309 1.395684 \n", "3 1.501334 0.996048 1.185704 1.328194 2.165002 -0.643518 0.927840 \n", "4 0.012592 0.690144 -0.108760 0.024510 0.959271 2.153182 -0.767348 \n", ".. ... ... ... ... ... ... ... \n", "95 -0.653329 -1.200296 0.504987 -1.260884 1.032465 -1.519370 -0.484234 \n", "96 -0.142379 -0.450065 0.005244 0.711615 1.277677 0.332314 -0.748487 \n", "97 0.507991 -0.785989 0.708109 0.388579 0.838491 0.081829 -0.098890 \n", "98 0.516178 0.953125 -0.626717 0.800410 0.708304 0.351448 1.070150 \n", "99 1.160827 -0.362441 1.148766 -0.046921 -1.282992 0.996267 -0.493757 \n", "\n", " 14 15 16 17 18 19 \n", "0 0.223884 1.496044 -0.773789 -0.055585 0.109395 -1.778720 \n", "1 0.021312 1.340450 -0.309209 0.115026 -0.319054 0.319175 \n", "2 -0.805870 -0.410814 1.032546 -0.214921 -0.562168 -1.090966 \n", "3 0.507836 -0.250833 -1.421811 0.556230 0.057013 -0.322680 \n", "4 -0.808298 -0.773010 0.224092 0.497998 0.872321 0.097676 \n", ".. 
... ... ... ... ... ... \n", "95 0.774634 0.404982 -0.474945 0.917862 1.266911 1.765454 \n", "96 0.067518 0.514439 -1.067620 -1.124642 1.551152 0.120296 \n", "97 0.321698 -2.152891 -1.836205 2.493000 0.919076 -1.103367 \n", "98 -0.744903 0.431923 0.725096 0.754291 -0.026521 -0.641482 \n", "99 0.850222 0.346504 -1.294681 0.477041 -1.556582 -0.467701 \n", "\n", "[100 rows x 20 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get a better look at the X data by temporarily displaying it as a pandas data frame\n", "pd.DataFrame(X)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# create different linear models (some with regularization!)\n", "linear = LinearRegression()\n", "lasso = Lasso(alpha=0.1)\n", "ridge = Ridge(alpha=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Relevant documentation:\n", "\n", "- [`sklearn.linear_model.LinearRegression`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html)\n", "- [`sklearn.linear_model.Lasso`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html)\n", "- [`sklearn.linear_model.Ridge`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# fit the models\n", "_ = linear.fit(X, y)\n", "_ = lasso.fit(X, y)\n", "_ = ridge.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6.59316802e+00, 9.47572840e+01, 4.07080537e+01, -1.36954387e-03,\n", " 9.32945665e-03, -1.55281734e-02, 1.11188043e+01, 9.55121723e+01,\n", " 8.08254906e+01, 3.48859900e+01, 1.62428940e-02, 1.09659676e-02,\n", " 9.76125777e-03, 1.02690177e-02, -6.11540038e-03, 2.99366631e+01,\n", " 7.23584909e+00, 9.68705480e-03, -1.93391406e-03, 5.22797071e+01])" ] }, "execution_count": 8, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "linear.coef_" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6.59850269e+00, 9.46618846e+01, 4.06454457e+01, 4.37714102e-03,\n", " 6.54059913e-03, -1.06479513e-02, 1.11069896e+01, 9.53732188e+01,\n", " 8.07112840e+01, 3.48957311e+01, 4.45154929e-04, 3.57365603e-02,\n", " 3.00167738e-02, 5.40170250e-03, -1.88067055e-02, 2.98957936e+01,\n", " 7.23859203e+00, -2.66565706e-02, -3.45743780e-02, 5.21972230e+01])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ridge.coef_" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6.45386948, 94.67737779, 40.55794675, 0. , -0. ,\n", " -0. , 10.98839003, 95.40225582, 80.67610577, 34.83198456,\n", " -0. , 0. , 0. , -0. , -0. ,\n", " 29.82848069, 7.12581798, -0. , 0. , 52.17111821])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lasso.coef_" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.97887229, 0.99915673, 0.9963126 , -0. , -0. ,\n", " 0. , 0.98827084, 0.99884919, 0.99815176, 0.99845194,\n", " -0. , 0. , 0. , -0. , 0. ,\n", " 0.99638629, 0.98479362, -0. , -0. , 0.99792292])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lasso.coef_ / linear.coef_" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/sklearn/base.py:1351: UserWarning: With alpha=0, this algorithm does not converge well. 
You are advised to use the LinearRegression estimator\n", " return fit_method(estimator, *args, **kwargs)\n", "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/sklearn/linear_model/_coordinate_descent.py:678: UserWarning: Coordinate descent with no regularization may lead to unexpected results and is discouraged.\n", " model = cd_fast.enet_coordinate_descent(\n" ] } ], "source": [ "# don't do this in practice, instead use LinearRegression as the warning notes\n", "lasso_0 = Lasso(alpha=0)\n", "_ = lasso_0.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6.59220650e+00, 9.47578704e+01, 4.07086530e+01, 1.67942913e-03,\n", " 7.04365608e-03, -1.08366993e-02, 1.11190745e+01, 9.55172343e+01,\n", " 8.08290446e+01, 3.48823135e+01, 1.51450701e-02, 6.76300320e-03,\n", " 6.58339664e-03, 1.01400600e-02, -4.48418689e-03, 2.99366761e+01,\n", " 7.23587178e+00, 1.20404224e-02, 6.35862604e-04, 5.22790377e+01])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lasso_0.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Effect of the penalty strength\n", "\n", "A larger `alpha` means a stronger penalty. Compare the lasso coefficients for `alpha=0.1` and `alpha=5`: with the stronger penalty the surviving coefficients shrink further toward zero, and the first coefficient is driven exactly to zero." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# variable names encode alpha: lasso_01 -> alpha=0.1, lasso_5 -> alpha=5\n", "lasso_01 = Lasso(alpha=0.1)\n", "lasso_5 = Lasso(alpha=5)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "_ = lasso_01.fit(X, y)\n", "_ = lasso_5.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6.45386948, 94.67737779, 40.55794675, 0. , -0. ,\n", " -0. , 10.98839003, 95.40225582, 80.67610577, 34.83198456,\n", " -0. , 0. , 0. , -0. , -0. ,\n", " 29.82848069, 7.12581798, -0. , 0. , 52.17111821])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lasso_01.coef_" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0. 
, 90.82037065, 33.18439832, 0. , -0. ,\n", " -0. , 4.90698647, 90.05810943, 73.22163014, 32.39510376,\n", " -0. , 0. , 0. , -0. , 0. ,\n", " 24.45096119, 1.92062117, -0. , 0. , 46.96663538])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lasso_5.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Feature scaling\n", "\n", "The lasso and ridge penalties act on the size of the coefficients, so they are sensitive to the scale of each feature. Standardizing the features (mean 0, standard deviation 1) puts them on a common scale before the penalty is applied." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# learn each column's mean and standard deviation from X, then standardize X\n", "# (fit on all of X here for simplicity; with a train/test split, fit on the training data only)\n", "scaler = StandardScaler()\n", "scaler.fit(X)\n", "X_scaled = scaler.transform(X)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.06648265, -0.09066858, 0.14879335, -0.0092515 , 0.06688254,\n", " 0.13867757, -0.020263 , 0.1355665 , 0.04824566, 0.02152732,\n", " -0.06736594, 0.11621231, 0.18151189, -0.02125737, 0.14082977,\n", " 0.16464492, -0.08393592, 0.01139174, -0.01378138, -0.0325596 ])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(X, axis=0)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-4.44089210e-17, -4.21884749e-17, -2.55351296e-17, -1.84574578e-17,\n", " 9.99200722e-18, 3.99680289e-17, -5.10702591e-17, -4.66293670e-17,\n", " -8.38218384e-17, 2.22044605e-18, -3.92047506e-18, 4.44089210e-17,\n", " 3.66373598e-17, 3.21964677e-17, -3.44169138e-17, 2.35922393e-17,\n", " -4.44089210e-18, 2.55351296e-17, 5.77315973e-17, 3.38618023e-17])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(X_scaled, axis=0)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.02482769, 0.97262509, 0.92892846, 0.89105338, 0.92304624,\n", " 0.91902508, 0.97366298, 0.94131669, 0.89996273, 1.07352597,\n", " 0.84522696, 1.01952759, 1.04023857, 1.04776952, 0.91703382,\n", " 1.08312224, 1.05773266, 1.1802187 , 1.00197408, 0.88057435])" ] }, "execution_count": 21, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "np.std(X, axis=0)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", " 1., 1., 1.])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.std(X_scaled, axis=0)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "_ = lasso.fit(X_scaled, y)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 108.56929274, -250.27007327, 1.74711702, 127.74731283,\n", " 34.16484902, 6.73089519, -47.99671007, -323.75282966,\n", " 40.71830457, 12.39299331, -233.2758167 , -68.89474373,\n", " -95.14530018, 55.75266526, -25.81229695, 77.54586244,\n", " 88.98514129, -200.10765794, 49.19701855, 421.46438405,\n", " 157.23398526, 87.67924179, -36.15995542, 216.57697103,\n", " 340.00532386, -275.91290738, 257.12555842, -374.03964729,\n", " 357.20158352, -244.88762318, 93.67700784, -39.95057358,\n", " 61.98410136, 212.02109074, 265.26394524, 281.77244047,\n", " 318.53883286, 28.62306877, -22.37384227, 234.1955855 ,\n", " -238.8564661 , 280.59818363, -49.02400195, -89.68866067,\n", " -60.59043932, 34.04378327, 147.98817117, 243.56423325,\n", " -14.29619749, 79.54075017, -45.42211395, -36.5241835 ,\n", " 29.71181164, -176.91121525, 96.89249288, 52.16684761,\n", " 114.5794129 , -333.86642314, 226.29054183, -106.73890637,\n", " -6.6769494 , 162.63925058, -30.94403975, -21.76185466,\n", " -19.87450714, -227.84122694, 87.90333805, 11.65702773,\n", " 209.70645989, 214.75616587, -108.581903 , 199.07494731,\n", " -130.09302298, 33.6694303 , -58.72737906, -100.74761299,\n", " -254.15146038, -9.07234502, 157.97000597, -444.49241247,\n", " 6.22181382, -1.07204728, 59.76748702, -58.99329766,\n", " -71.9214518 , 207.05877426, -20.81839889, -265.38705602,\n", " 97.88748469, 
114.24112507, 144.72937082, -55.47105068,\n", " 86.01322255, 300.68411823, 22.14183285, -95.77237868,\n", " -44.71399397, -236.20063655, 98.68319317, -13.91650175])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# but the *real* solution is to use a pipeline!\n", "lasso.predict(X) # WRONG\n", "lasso.predict(X_scaled) # correct" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To add regularization for classification use `LogisticRegression` with the correct `penalty` argument!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 2 }