Files
100-Days-Of-ML-Code/Code/Day 3_Multiple_Linear_Regression.ipynb
2021-01-18 19:31:58 +08:00

309 lines
19 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 机器学习100天——第3天多元线性回归Multiple Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 第1步数据预处理"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**导入库**"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**导入数据集**"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"X:\n[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']]\nY:\n[192261.83 191792.06 191050.39 182901.99 166187.94 156991.12 156122.51\n 155752.6 152211.77 149759.96 146121.95 144259.4 141585.52 134307.35\n 132602.65 129917.04 126992.93 125370.37 124266.9 122776.86 118474.03\n 111313.02 110352.25 108733.99 108552.04 107404.34 105733.54 105008.31\n 103282.38 101004.64 99937.59 97483.56 97427.84 96778.92 96712.8\n 96479.51 90708.19 89949.14 81229.06 81005.76 78239.91 77798.83\n 71498.49 69758.98 65200.33 64926.08 49490.75 42559.73 35673.41\n 14681.4 ]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" R&D Spend Administration Marketing Spend State Profit\n",
"0 165349.20 136897.80 471784.10 New York 192261.83\n",
"1 162597.70 151377.59 443898.53 California 191792.06\n",
"2 153441.51 101145.55 407934.54 Florida 191050.39\n",
"3 144372.41 118671.85 383199.62 New York 182901.99\n",
"4 142107.34 91391.77 366168.42 Florida 166187.94"
],
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>R&amp;D Spend</th>\n <th>Administration</th>\n <th>Marketing Spend</th>\n <th>State</th>\n <th>Profit</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>165349.20</td>\n <td>136897.80</td>\n <td>471784.10</td>\n <td>New York</td>\n <td>192261.83</td>\n </tr>\n <tr>\n <th>1</th>\n <td>162597.70</td>\n <td>151377.59</td>\n <td>443898.53</td>\n <td>California</td>\n <td>191792.06</td>\n </tr>\n <tr>\n <th>2</th>\n <td>153441.51</td>\n <td>101145.55</td>\n <td>407934.54</td>\n <td>Florida</td>\n <td>191050.39</td>\n </tr>\n <tr>\n <th>3</th>\n <td>144372.41</td>\n <td>118671.85</td>\n <td>383199.62</td>\n <td>New York</td>\n <td>182901.99</td>\n </tr>\n <tr>\n <th>4</th>\n <td>142107.34</td>\n <td>91391.77</td>\n <td>366168.42</td>\n <td>Florida</td>\n <td>166187.94</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {},
"execution_count": 57
}
],
"source": [
"dataset = pd.read_csv('../datasets/50_Startups.csv')\n",
"X = dataset.iloc[ : , :-1].values\n",
"Y = dataset.iloc[ : , 4 ].values\n",
"Z = dataset.iloc[ : , 0 ].values\n",
"print(\"X:\")\n",
"print(X[:10])\n",
"print(\"Y:\")\n",
"print(Y)\n",
"dataset.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']\n [101913.08 110594.11 229160.95 'Florida']\n [100671.96 91790.61 249744.55 'California']\n [93863.75 127320.38 249839.44 'Florida']\n [91992.39 135495.07 252664.93 'California']\n [119943.24 156547.42 256512.92 'Florida']\n [114523.61 122616.84 261776.23 'New York']\n [78013.11 121597.55 264346.06 'California']\n [94657.16 145077.58 282574.31 'New York']\n [91749.16 114175.79 294919.57 'Florida']\n [86419.7 153514.11 224494.78489361703 'New York']\n [76253.86 113867.3 298664.47 'California']\n [78389.47 153773.43 299737.29 'New York']\n [73994.56 122782.75 303319.26 'Florida']\n [67532.53 105751.03 304768.73 'Florida']\n [77044.01 99281.34 140574.81 'New York']\n [64664.71 139553.16 137962.62 'California']\n [75328.87 144135.98 134050.07 'Florida']\n [72107.6 127864.55 353183.81 'New York']\n [66051.52 182645.56 118148.2 'Florida']\n [65605.48 153032.06 107138.38 'New York']\n [61994.48 115641.28 91131.24 'Florida']\n [61136.38 152701.92 88218.23 'New York']\n [63408.86 129219.61 46085.25 'California']\n [55493.95 103057.49 214634.81 'Florida']\n [46426.07 157693.92 210797.67 'California']\n [46014.02 85047.44 205517.64 'New York']\n [28663.76 127056.21 201126.82 'Florida']\n [44069.95 51283.14 197029.42 'California']\n [20229.59 65947.93 185265.1 'New York']\n [38558.51 82982.09 174999.3 'California']\n [28754.33 118546.05 172795.67 'California']\n [27892.92 84710.77 164470.71 'Florida']\n [23640.93 96189.63 148001.11 'California']\n [15505.73 127382.3 35534.17 'New York']\n [22177.74 154806.14 28334.72 'California']\n [1000.23 124153.04 1903.93 'New York']\n [1315.46 115816.21 297114.46 'Florida']\n [76793.34958333334 135426.92 224494.78489361703 'California']\n [542.05 51743.15 224494.78489361703 'New York']\n [76793.34958333334 116983.8 45173.06 'California']]\n"
]
}
],
"source": [
"from sklearn.impute import SimpleImputer\n",
"imputer = SimpleImputer(missing_values=0.0, strategy=\"mean\")\n",
"imputer = imputer.fit(X[ : , 0:3])\n",
"X[ : , 0:3] = imputer.transform(X[ : , 0:3])\n",
"print(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**将类别数据数字化**"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"original:\n[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']]\nlabelencoder:\n[[165349.2 136897.8 471784.1 2]\n [162597.7 151377.59 443898.53 0]\n [153441.51 101145.55 407934.54 1]\n [144372.41 118671.85 383199.62 2]\n [142107.34 91391.77 366168.42 1]\n [131876.9 99814.71 362861.36 2]\n [134615.46 147198.87 127716.82 0]\n [130298.13 145530.06 323876.68 1]\n [120542.52 148718.95 311613.29 2]\n [123334.88 108679.17 304981.62 0]]\nonehot:\n[[0.0 0.0 1.0 165349.2 136897.8 471784.1]\n [1.0 0.0 0.0 162597.7 151377.59 443898.53]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 144372.41 118671.85 383199.62]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 0.0 1.0 131876.9 99814.71 362861.36]\n [1.0 0.0 0.0 134615.46 147198.87 127716.82]\n [0.0 1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 0.0 1.0 120542.52 148718.95 311613.29]\n [1.0 0.0 0.0 123334.88 108679.17 304981.62]]\n"
]
}
],
"source": [
"from sklearn.preprocessing import LabelEncoder, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer \n",
"labelencoder = LabelEncoder()\n",
"print(\"original:\")\n",
"print(X[:10])\n",
"#print(X[: , 3])\n",
"X[: , 3] = labelencoder.fit_transform(X[ : , 3])\n",
"#print(X[: , 3])\n",
"print(\"labelencoder:\")\n",
"print(X[:10])\n",
"ct = ColumnTransformer([( \"encoder\", OneHotEncoder(), [3])], remainder = 'passthrough')\n",
"X = ct.fit_transform(X)\n",
"#onehotencoder = OneHotEncoder(categorical_features = [3])\n",
"#X = onehotencoder.fit_transform(X).toarray()\n",
"print(\"onehot:\")\n",
"print(X[:10])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**躲避虚拟变量陷阱**\n",
"\n",
"在回归预测中我们需要所有的数据都是numeric的但是会有一些非numeric的数据比如国家部门性别。这时候我们需要设置虚拟变量Dummy variable。做法是将此变量中的每一个值衍生成为新的变量是设为1否设为0.举个例子,“性别”这个变量,我们可以虚拟出“男”和”女”两虚拟变量男性的话“男”值为1”女”值为0女性的话“男”值为0”女”值为1。\n",
"\n",
"但是要注意,这时候虚拟变量陷阱就出现了。就拿性别来说,其实一个虚拟变量就够了,比如 1 的时候是“男”, 0 的时候是”非男”,即为女。如果设置两个虚拟变量“男”和“女”,语义上来说没有问题,可以理解,但是在回归预测中会多出一个变量,多出的这个变量将会对回归预测结果产生影响。一般来说,如果虚拟变量要比实际变量的种类少一个。 \n",
"\n",
"在多重线性回归中变量不是越多越好而是选择适合的变量。这样才会对结果准确预测。如果category类的特征都放进去拟合的时候所有权重的计算都可以有两种方法实现一种是提高某个category的w一种是降低其他category的w这两种效果是等效的也就是发生了共线性,虚拟变量系数相加和为1出现完全共线陷阱。\n",
"\n",
"**但是下面测试尽然和想法不一致。。。**"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"X1 = X[: , 1:]"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[0.0 1.0 165349.2 136897.8 471784.1]\n [0.0 0.0 162597.7 151377.59 443898.53]\n [1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 1.0 144372.41 118671.85 383199.62]\n [1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 1.0 131876.9 99814.71 362861.36]\n [0.0 0.0 134615.46 147198.87 127716.82]\n [1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 1.0 120542.52 148718.95 311613.29]\n [0.0 0.0 123334.88 108679.17 304981.62]\n [1.0 0.0 101913.08 110594.11 229160.95]\n [0.0 0.0 100671.96 91790.61 249744.55]\n [1.0 0.0 93863.75 127320.38 249839.44]\n [0.0 0.0 91992.39 135495.07 252664.93]\n [1.0 0.0 119943.24 156547.42 256512.92]\n [0.0 1.0 114523.61 122616.84 261776.23]\n [0.0 0.0 78013.11 121597.55 264346.06]\n [0.0 1.0 94657.16 145077.58 282574.31]\n [1.0 0.0 91749.16 114175.79 294919.57]\n [0.0 1.0 86419.7 153514.11 224494.78489361703]\n [0.0 0.0 76253.86 113867.3 298664.47]\n [0.0 1.0 78389.47 153773.43 299737.29]\n [1.0 0.0 73994.56 122782.75 303319.26]\n [1.0 0.0 67532.53 105751.03 304768.73]\n [0.0 1.0 77044.01 99281.34 140574.81]\n [0.0 0.0 64664.71 139553.16 137962.62]\n [1.0 0.0 75328.87 144135.98 134050.07]\n [0.0 1.0 72107.6 127864.55 353183.81]\n [1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 1.0 65605.48 153032.06 107138.38]\n [1.0 0.0 61994.48 115641.28 91131.24]\n [0.0 1.0 61136.38 152701.92 88218.23]\n [0.0 0.0 63408.86 129219.61 46085.25]\n [1.0 0.0 55493.95 103057.49 214634.81]\n [0.0 0.0 46426.07 157693.92 210797.67]\n [0.0 1.0 46014.02 85047.44 205517.64]\n [1.0 0.0 28663.76 127056.21 201126.82]\n [0.0 0.0 44069.95 51283.14 197029.42]\n [0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 0.0 38558.51 82982.09 174999.3]\n [0.0 0.0 28754.33 118546.05 172795.67]\n [1.0 0.0 27892.92 84710.77 164470.71]\n [0.0 0.0 23640.93 96189.63 148001.11]\n [0.0 1.0 15505.73 127382.3 35534.17]\n [0.0 0.0 22177.74 154806.14 28334.72]\n [0.0 1.0 1000.23 124153.04 1903.93]\n [1.0 0.0 1315.46 115816.21 297114.46]\n [0.0 0.0 76793.34958333334 135426.92 224494.78489361703]\n [0.0 1.0 542.05 51743.15 224494.78489361703]\n [0.0 0.0 76793.34958333334 116983.8 45173.06]]\n[[0.0 0.0 1.0 165349.2 136897.8 471784.1]\n [1.0 0.0 0.0 162597.7 151377.59 443898.53]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 144372.41 118671.85 383199.62]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 0.0 1.0 131876.9 99814.71 362861.36]\n [1.0 0.0 0.0 134615.46 147198.87 127716.82]\n [0.0 1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 0.0 1.0 120542.52 148718.95 311613.29]\n [1.0 0.0 0.0 123334.88 108679.17 304981.62]\n [0.0 1.0 0.0 101913.08 110594.11 229160.95]\n [1.0 0.0 0.0 100671.96 91790.61 249744.55]\n [0.0 1.0 0.0 93863.75 127320.38 249839.44]\n [1.0 0.0 0.0 91992.39 135495.07 252664.93]\n [0.0 1.0 0.0 119943.24 156547.42 256512.92]\n [0.0 0.0 1.0 114523.61 122616.84 261776.23]\n [1.0 0.0 0.0 78013.11 121597.55 264346.06]\n [0.0 0.0 1.0 94657.16 145077.58 282574.31]\n [0.0 1.0 0.0 91749.16 114175.79 294919.57]\n [0.0 0.0 1.0 86419.7 153514.11 224494.78489361703]\n [1.0 0.0 0.0 76253.86 113867.3 298664.47]\n [0.0 0.0 1.0 78389.47 153773.43 299737.29]\n [0.0 1.0 0.0 73994.56 122782.75 303319.26]\n [0.0 1.0 0.0 67532.53 105751.03 304768.73]\n [0.0 0.0 1.0 77044.01 99281.34 140574.81]\n [1.0 0.0 0.0 64664.71 139553.16 137962.62]\n [0.0 1.0 0.0 75328.87 144135.98 134050.07]\n [0.0 0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 0.0 1.0 65605.48 153032.06 107138.38]\n [0.0 1.0 0.0 61994.48 115641.28 91131.24]\n [0.0 0.0 1.0 61136.38 152701.92 88218.23]\n [1.0 0.0 0.0 63408.86 129219.61 46085.25]\n [0.0 1.0 0.0 55493.95 103057.49 214634.81]\n [1.0 0.0 0.0 46426.07 157693.92 210797.67]\n [0.0 0.0 1.0 46014.02 85047.44 205517.64]\n [0.0 1.0 0.0 28663.76 127056.21 201126.82]\n [1.0 0.0 0.0 44069.95 51283.14 197029.42]\n [0.0 0.0 1.0 20229.59 65947.93 185265.1]\n [1.0 0.0 0.0 38558.51 82982.09 174999.3]\n [1.0 0.0 0.0 28754.33 118546.05 172795.67]\n [0.0 1.0 0.0 27892.92 84710.77 164470.71]\n [1.0 0.0 0.0 23640.93 96189.63 148001.11]\n [0.0 0.0 1.0 15505.73 127382.3 35534.17]\n [1.0 0.0 0.0 22177.74 154806.14 28334.72]\n [0.0 0.0 1.0 1000.23 124153.04 1903.93]\n [0.0 1.0 0.0 1315.46 115816.21 297114.46]\n [1.0 0.0 0.0 76793.34958333334 135426.92 224494.78489361703]\n [0.0 0.0 1.0 542.05 51743.15 224494.78489361703]\n [1.0 0.0 0.0 76793.34958333334 116983.8 45173.06]]\n"
]
}
],
"source": [
"print(X1)\n",
"print(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**拆分数据集为训练集和测试集**"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[0.0 1.0 0.0 66051.52 182645.56 118148.2]\n [1.0 0.0 0.0 100671.96 91790.61 249744.55]\n [0.0 1.0 0.0 101913.08 110594.11 229160.95]\n [0.0 1.0 0.0 27892.92 84710.77 164470.71]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 0.0 1.0 61136.38 152701.92 88218.23]\n [0.0 1.0 0.0 73994.56 122782.75 303319.26]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]]\n[103282.38 144259.4 146121.95 77798.83 191050.39 105008.31 81229.06\n 97483.56 110352.25 166187.94]\n[[1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 0.0 100671.96 91790.61 249744.55]\n [1.0 0.0 101913.08 110594.11 229160.95]\n [1.0 0.0 27892.92 84710.77 164470.71]\n [1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 1.0 61136.38 152701.92 88218.23]\n [1.0 0.0 73994.56 122782.75 303319.26]\n [1.0 0.0 142107.34 91391.77 366168.42]]\n[103282.38 144259.4 146121.95 77798.83 191050.39 105008.31 81229.06\n 97483.56 110352.25 166187.94]\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 0)\n",
"X1_train, X1_test, Y1_train, Y1_test = train_test_split(X1, Y, test_size = 0.2, random_state = 0)\n",
"print(X_test)\n",
"print(Y_test)\n",
"print(X1_test)\n",
"print(Y1_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 第2步在训练集上训练多元线性回归模型"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LinearRegression()"
]
},
"metadata": {},
"execution_count": 64
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"regressor = LinearRegression()\n",
"regressor.fit(X_train, Y_train)\n",
"regressor1 = LinearRegression()\n",
"regressor1.fit(X1_train, Y1_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 第3步在测试集上预测结果¶"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"y_pred = regressor.predict(X_test)\n",
"y1_pred = regressor1.predict(X1_test)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[102388.94113041 121465.72713517 127340.57708619 71709.47538912\n 174211.0848 121771.65061494 68393.54360668 95588.5313349\n 116596.3467699 162514.07218551]\n[102388.94113046 121465.72713518 127340.57708619 71709.47538916\n 174211.08479987 121771.65061482 68393.5436067 95588.53133498\n 116596.34676982 162514.07218541]\n"
]
}
],
"source": [
"print(y_pred)\n",
"print(y1_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**完整的项目请前往Github项目100-Days-Of-ML-Code查看。有任何的建议或者意见欢迎在issue中提出~**"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3-final"
}
},
"nbformat": 4,
"nbformat_minor": 2
}