{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 100 Days of ML Code, Day 3: Multiple Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Data preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Importing the libraries**"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Importing the dataset**"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"X:\n[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']]\nY:\n[192261.83 191792.06 191050.39 182901.99 166187.94 156991.12 156122.51\n 155752.6 152211.77 149759.96 146121.95 144259.4 141585.52 134307.35\n 132602.65 129917.04 126992.93 125370.37 124266.9 122776.86 118474.03\n 111313.02 110352.25 108733.99 108552.04 107404.34 105733.54 105008.31\n 103282.38 101004.64 99937.59 97483.56 97427.84 96778.92 96712.8\n 96479.51 90708.19 89949.14 81229.06 81005.76 78239.91 77798.83\n 71498.49 69758.98 65200.33 64926.08 49490.75 42559.73 35673.41\n 14681.4 ]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" R&D Spend Administration Marketing Spend State Profit\n",
"0 165349.20 136897.80 471784.10 New York 192261.83\n",
"1 162597.70 151377.59 443898.53 California 191792.06\n",
"2 153441.51 101145.55 407934.54 Florida 191050.39\n",
"3 144372.41 118671.85 383199.62 New York 182901.99\n",
"4 142107.34 91391.77 366168.42 Florida 166187.94"
],
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>R&D Spend</th>\n <th>Administration</th>\n <th>Marketing Spend</th>\n <th>State</th>\n <th>Profit</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>165349.20</td>\n <td>136897.80</td>\n <td>471784.10</td>\n <td>New York</td>\n <td>192261.83</td>\n </tr>\n <tr>\n <th>1</th>\n <td>162597.70</td>\n <td>151377.59</td>\n <td>443898.53</td>\n <td>California</td>\n <td>191792.06</td>\n </tr>\n <tr>\n <th>2</th>\n <td>153441.51</td>\n <td>101145.55</td>\n <td>407934.54</td>\n <td>Florida</td>\n <td>191050.39</td>\n </tr>\n <tr>\n <th>3</th>\n <td>144372.41</td>\n <td>118671.85</td>\n <td>383199.62</td>\n <td>New York</td>\n <td>182901.99</td>\n </tr>\n <tr>\n <th>4</th>\n <td>142107.34</td>\n <td>91391.77</td>\n <td>366168.42</td>\n <td>Florida</td>\n <td>166187.94</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {},
"execution_count": 57
}
],
"source": [
"# load the dataset; X = all feature columns, Y = Profit (column 4)\n",
"dataset = pd.read_csv('../datasets/50_Startups.csv')\n",
"X = dataset.iloc[ : , :-1].values\n",
"Y = dataset.iloc[ : , 4 ].values\n",
"print(\"X:\")\n",
"print(X[:10])\n",
"print(\"Y:\")\n",
"print(Y)\n",
"dataset.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']\n [101913.08 110594.11 229160.95 'Florida']\n [100671.96 91790.61 249744.55 'California']\n [93863.75 127320.38 249839.44 'Florida']\n [91992.39 135495.07 252664.93 'California']\n [119943.24 156547.42 256512.92 'Florida']\n [114523.61 122616.84 261776.23 'New York']\n [78013.11 121597.55 264346.06 'California']\n [94657.16 145077.58 282574.31 'New York']\n [91749.16 114175.79 294919.57 'Florida']\n [86419.7 153514.11 224494.78489361703 'New York']\n [76253.86 113867.3 298664.47 'California']\n [78389.47 153773.43 299737.29 'New York']\n [73994.56 122782.75 303319.26 'Florida']\n [67532.53 105751.03 304768.73 'Florida']\n [77044.01 99281.34 140574.81 'New York']\n [64664.71 139553.16 137962.62 'California']\n [75328.87 144135.98 134050.07 'Florida']\n [72107.6 127864.55 353183.81 'New York']\n [66051.52 182645.56 118148.2 'Florida']\n [65605.48 153032.06 107138.38 'New York']\n [61994.48 115641.28 91131.24 'Florida']\n [61136.38 152701.92 88218.23 'New York']\n [63408.86 129219.61 46085.25 'California']\n [55493.95 103057.49 214634.81 'Florida']\n [46426.07 157693.92 210797.67 'California']\n [46014.02 85047.44 205517.64 'New York']\n [28663.76 127056.21 201126.82 'Florida']\n [44069.95 51283.14 197029.42 'California']\n [20229.59 65947.93 185265.1 'New York']\n [38558.51 82982.09 174999.3 'California']\n [28754.33 118546.05 172795.67 'California']\n [27892.92 84710.77 164470.71 'Florida']\n [23640.93 96189.63 148001.11 'California']\n [15505.73 127382.3 35534.17 'New York']\n [22177.74 154806.14 28334.72 'California']\n [1000.23 124153.04 1903.93 'New York']\n [1315.46 115816.21 297114.46 'Florida']\n [76793.34958333334 135426.92 224494.78489361703 'California']\n [542.05 51743.15 224494.78489361703 'New York']\n [76793.34958333334 116983.8 45173.06 'California']]\n"
]
}
],
"source": [
"# treat 0.0 entries in the numeric columns as missing and replace them with the column mean\n",
"from sklearn.impute import SimpleImputer\n",
"imputer = SimpleImputer(missing_values=0.0, strategy=\"mean\")\n",
"imputer = imputer.fit(X[ : , 0:3])\n",
"X[ : , 0:3] = imputer.transform(X[ : , 0:3])\n",
"print(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Encoding categorical data**"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"original:\n[[165349.2 136897.8 471784.1 'New York']\n [162597.7 151377.59 443898.53 'California']\n [153441.51 101145.55 407934.54 'Florida']\n [144372.41 118671.85 383199.62 'New York']\n [142107.34 91391.77 366168.42 'Florida']\n [131876.9 99814.71 362861.36 'New York']\n [134615.46 147198.87 127716.82 'California']\n [130298.13 145530.06 323876.68 'Florida']\n [120542.52 148718.95 311613.29 'New York']\n [123334.88 108679.17 304981.62 'California']]\nlabelencoder:\n[[165349.2 136897.8 471784.1 2]\n [162597.7 151377.59 443898.53 0]\n [153441.51 101145.55 407934.54 1]\n [144372.41 118671.85 383199.62 2]\n [142107.34 91391.77 366168.42 1]\n [131876.9 99814.71 362861.36 2]\n [134615.46 147198.87 127716.82 0]\n [130298.13 145530.06 323876.68 1]\n [120542.52 148718.95 311613.29 2]\n [123334.88 108679.17 304981.62 0]]\nonehot:\n[[0.0 0.0 1.0 165349.2 136897.8 471784.1]\n [1.0 0.0 0.0 162597.7 151377.59 443898.53]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 144372.41 118671.85 383199.62]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 0.0 1.0 131876.9 99814.71 362861.36]\n [1.0 0.0 0.0 134615.46 147198.87 127716.82]\n [0.0 1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 0.0 1.0 120542.52 148718.95 311613.29]\n [1.0 0.0 0.0 123334.88 108679.17 304981.62]]\n"
]
}
],
"source": [
"from sklearn.preprocessing import LabelEncoder, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"labelencoder = LabelEncoder()\n",
"print(\"original:\")\n",
"print(X[:10])\n",
"# encode the State column as integer labels\n",
"X[: , 3] = labelencoder.fit_transform(X[ : , 3])\n",
"print(\"labelencoder:\")\n",
"print(X[:10])\n",
"# one-hot encode the State column; the encoded columns are placed first\n",
"ct = ColumnTransformer([(\"encoder\", OneHotEncoder(), [3])], remainder = 'passthrough')\n",
"X = ct.fit_transform(X)\n",
"print(\"onehot:\")\n",
"print(X[:10])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Avoiding the dummy variable trap**\n",
"\n",
"Regression needs all features to be numeric, but some features are categorical, such as country, province, department, or gender. In that case we introduce dummy variables: each distinct value of the categorical feature becomes a new binary variable that is 1 when the sample takes that value and 0 otherwise. For example, for the feature \"gender\" we can create the two dummy variables \"male\" and \"female\"; a male sample gets male = 1 and female = 0, while a female sample gets male = 0 and female = 1.\n",
"\n",
"This is where the dummy variable trap appears. For gender, one dummy variable is actually enough: 1 means \"male\" and 0 means \"not male\", i.e. female. Using both \"male\" and \"female\" is fine semantically, but it adds a redundant variable to the regression, and that redundancy affects the fit. As a rule, the number of dummy variables should be one fewer than the number of categories.\n",
"\n",
"In multiple linear regression, more variables are not automatically better; choosing the right variables is what gives accurate predictions. If all of the category columns are included, the same fit can be reached in two equivalent ways, by raising the weight of one category or by lowering the weights of the others, because the dummy columns are perfectly collinear (for every row they sum to 1). This is the perfect-multicollinearity trap.\n",
"\n",
"**However, the test below surprisingly does not match this expectation...**"
]
},
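{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the manual slice in the next cell, here is a minimal sketch (not part of the original notebook) of how the trap can also be avoided at encoding time, assuming a scikit-learn version (0.21+) whose OneHotEncoder supports drop='first'; the names X_raw, ct_drop and X_no_trap are only illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: re-encode the raw State column, letting the encoder drop one dummy level\n",
"# (X_raw / ct_drop / X_no_trap are illustrative names; dataset comes from the cell above)\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"X_raw = dataset.iloc[ : , :-1].values\n",
"ct_drop = ColumnTransformer([(\"encoder\", OneHotEncoder(drop='first'), [3])], remainder = 'passthrough')\n",
"X_no_trap = ct_drop.fit_transform(X_raw)\n",
"print(X_no_trap[:5])"
]
},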
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"# drop the first dummy column manually to avoid the dummy variable trap\n",
"X1 = X[: , 1:]"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[0.0 1.0 165349.2 136897.8 471784.1]\n [0.0 0.0 162597.7 151377.59 443898.53]\n [1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 1.0 144372.41 118671.85 383199.62]\n [1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 1.0 131876.9 99814.71 362861.36]\n [0.0 0.0 134615.46 147198.87 127716.82]\n [1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 1.0 120542.52 148718.95 311613.29]\n [0.0 0.0 123334.88 108679.17 304981.62]\n [1.0 0.0 101913.08 110594.11 229160.95]\n [0.0 0.0 100671.96 91790.61 249744.55]\n [1.0 0.0 93863.75 127320.38 249839.44]\n [0.0 0.0 91992.39 135495.07 252664.93]\n [1.0 0.0 119943.24 156547.42 256512.92]\n [0.0 1.0 114523.61 122616.84 261776.23]\n [0.0 0.0 78013.11 121597.55 264346.06]\n [0.0 1.0 94657.16 145077.58 282574.31]\n [1.0 0.0 91749.16 114175.79 294919.57]\n [0.0 1.0 86419.7 153514.11 224494.78489361703]\n [0.0 0.0 76253.86 113867.3 298664.47]\n [0.0 1.0 78389.47 153773.43 299737.29]\n [1.0 0.0 73994.56 122782.75 303319.26]\n [1.0 0.0 67532.53 105751.03 304768.73]\n [0.0 1.0 77044.01 99281.34 140574.81]\n [0.0 0.0 64664.71 139553.16 137962.62]\n [1.0 0.0 75328.87 144135.98 134050.07]\n [0.0 1.0 72107.6 127864.55 353183.81]\n [1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 1.0 65605.48 153032.06 107138.38]\n [1.0 0.0 61994.48 115641.28 91131.24]\n [0.0 1.0 61136.38 152701.92 88218.23]\n [0.0 0.0 63408.86 129219.61 46085.25]\n [1.0 0.0 55493.95 103057.49 214634.81]\n [0.0 0.0 46426.07 157693.92 210797.67]\n [0.0 1.0 46014.02 85047.44 205517.64]\n [1.0 0.0 28663.76 127056.21 201126.82]\n [0.0 0.0 44069.95 51283.14 197029.42]\n [0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 0.0 38558.51 82982.09 174999.3]\n [0.0 0.0 28754.33 118546.05 172795.67]\n [1.0 0.0 27892.92 84710.77 164470.71]\n [0.0 0.0 23640.93 96189.63 148001.11]\n [0.0 1.0 15505.73 127382.3 35534.17]\n [0.0 0.0 22177.74 154806.14 28334.72]\n [0.0 1.0 1000.23 124153.04 1903.93]\n [1.0 0.0 1315.46 115816.21 297114.46]\n [0.0 0.0 76793.34958333334 135426.92 224494.78489361703]\n [0.0 1.0 542.05 51743.15 224494.78489361703]\n [0.0 0.0 76793.34958333334 116983.8 45173.06]]\n[[0.0 0.0 1.0 165349.2 136897.8 471784.1]\n [1.0 0.0 0.0 162597.7 151377.59 443898.53]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 144372.41 118671.85 383199.62]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]\n [0.0 0.0 1.0 131876.9 99814.71 362861.36]\n [1.0 0.0 0.0 134615.46 147198.87 127716.82]\n [0.0 1.0 0.0 130298.13 145530.06 323876.68]\n [0.0 0.0 1.0 120542.52 148718.95 311613.29]\n [1.0 0.0 0.0 123334.88 108679.17 304981.62]\n [0.0 1.0 0.0 101913.08 110594.11 229160.95]\n [1.0 0.0 0.0 100671.96 91790.61 249744.55]\n [0.0 1.0 0.0 93863.75 127320.38 249839.44]\n [1.0 0.0 0.0 91992.39 135495.07 252664.93]\n [0.0 1.0 0.0 119943.24 156547.42 256512.92]\n [0.0 0.0 1.0 114523.61 122616.84 261776.23]\n [1.0 0.0 0.0 78013.11 121597.55 264346.06]\n [0.0 0.0 1.0 94657.16 145077.58 282574.31]\n [0.0 1.0 0.0 91749.16 114175.79 294919.57]\n [0.0 0.0 1.0 86419.7 153514.11 224494.78489361703]\n [1.0 0.0 0.0 76253.86 113867.3 298664.47]\n [0.0 0.0 1.0 78389.47 153773.43 299737.29]\n [0.0 1.0 0.0 73994.56 122782.75 303319.26]\n [0.0 1.0 0.0 67532.53 105751.03 304768.73]\n [0.0 0.0 1.0 77044.01 99281.34 140574.81]\n [1.0 0.0 0.0 64664.71 139553.16 137962.62]\n [0.0 1.0 0.0 75328.87 144135.98 134050.07]\n [0.0 0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 0.0 1.0 65605.48 153032.06 107138.38]\n [0.0 1.0 0.0 61994.48 115641.28 91131.24]\n [0.0 0.0 1.0 61136.38 152701.92 88218.23]\n [1.0 0.0 0.0 63408.86 
129219.61 46085.25]\n [0.0 1.0 0.0 55493.95 103057.49 214634.81]\n [1.0 0.0 0.0 46426.07 157693.92 210797.67]\n [0.0 0.0 1.0 46014.02 85047.44 205517.64]\n [0.0 1.0 0.0 28663.76 127056.21 201126.82]\n [1.0 0.0 0.0 44069.95 51283.14 197029.42]\n [0.0 0.0 1.0 20229.59 65947.93 185265.1]\n [1.0 0.0 0.0 38558.51 82982.09 174999.3]\n [1.0 0.0 0.0 28754.33 118546.05 172795.67]\n [0.0 1.0 0.0 27892.92 84710.77 164470.71]\n [1.0 0.0 0.0 23640.93 96189.63 148001.11]\n [0.0 0.0 1.0 15505.73 127382.3 35534.17]\n [1.0 0.0 0.0 22177.74 154806.14 28334.72]\n [0.0 0.0 1.0 1000.23 124153.04 1903.93]\n [0.0 1.0 0.0 1315.46 115816.21 297114.46]\n [1.0 0.0 0.0 76793.34958333334 135426.92 224494.78489361703]\n [0.0 0.0 1.0 542.05 51743.15 224494.78489361703]\n [1.0 0.0 0.0 76793.34958333334 116983.8 45173.06]]\n"
]
}
],
"source": [
"print(X1)\n",
"print(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Splitting the dataset into training and test sets**"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[0.0 1.0 0.0 66051.52 182645.56 118148.2]\n [1.0 0.0 0.0 100671.96 91790.61 249744.55]\n [0.0 1.0 0.0 101913.08 110594.11 229160.95]\n [0.0 1.0 0.0 27892.92 84710.77 164470.71]\n [0.0 1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 0.0 1.0 61136.38 152701.92 88218.23]\n [0.0 1.0 0.0 73994.56 122782.75 303319.26]\n [0.0 1.0 0.0 142107.34 91391.77 366168.42]]\n[103282.38 144259.4 146121.95 77798.83 191050.39 105008.31 81229.06\n 97483.56 110352.25 166187.94]\n[[1.0 0.0 66051.52 182645.56 118148.2]\n [0.0 0.0 100671.96 91790.61 249744.55]\n [1.0 0.0 101913.08 110594.11 229160.95]\n [1.0 0.0 27892.92 84710.77 164470.71]\n [1.0 0.0 153441.51 101145.55 407934.54]\n [0.0 1.0 72107.6 127864.55 353183.81]\n [0.0 1.0 20229.59 65947.93 185265.1]\n [0.0 1.0 61136.38 152701.92 88218.23]\n [1.0 0.0 73994.56 122782.75 303319.26]\n [1.0 0.0 142107.34 91391.77 366168.42]]\n[103282.38 144259.4 146121.95 77798.83 191050.39 105008.31 81229.06\n 97483.56 110352.25 166187.94]\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"# use the same split (random_state = 0) for both feature matrices\n",
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 0)\n",
"X1_train, X1_test, Y1_train, Y1_test = train_test_split(X1, Y, test_size = 0.2, random_state = 0)\n",
"print(X_test)\n",
"print(Y_test)\n",
"print(X1_test)\n",
"print(Y1_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Training the multiple linear regression model on the training set"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LinearRegression()"
]
},
"metadata": {},
"execution_count": 64
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"# fit one model on the full dummy encoding and one on the encoding with a column dropped\n",
"regressor = LinearRegression()\n",
"regressor.fit(X_train, Y_train)\n",
"regressor1 = LinearRegression()\n",
"regressor1.fit(X1_train, Y1_train)"
]
},
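{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small aside (not part of the original notebook), the coef_ and intercept_ attributes of the two fitted models can be inspected; with the redundant dummy column kept, the individual coefficients shift even though, as the predictions in Step 3 show, the fitted values barely change:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: inspect the fitted parameters of both models (assumes the cell above was run)\n",
"print(\"coef_ with all dummy columns kept:   \", regressor.coef_)\n",
"print(\"intercept_:\", regressor.intercept_)\n",
"print(\"coef_ with one dummy column dropped: \", regressor1.coef_)\n",
"print(\"intercept_:\", regressor1.intercept_)"
]
},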
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Predicting the test set results"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"y_pred = regressor.predict(X_test)\n",
"y1_pred = regressor1.predict(X1_test)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[102388.94113041 121465.72713517 127340.57708619 71709.47538912\n 174211.0848 121771.65061494 68393.54360668 95588.5313349\n 116596.3467699 162514.07218551]\n[102388.94113046 121465.72713518 127340.57708619 71709.47538916\n 174211.08479987 121771.65061482 68393.5436067 95588.53133498\n 116596.34676982 162514.07218541]\n"
]
}
],
"source": [
"print(y_pred)\n",
"print(y1_pred)"
]
},
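{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two prediction vectors printed above agree to several decimal places: with an intercept in the model, the redundant dummy encoding changes the coefficients but leaves the least-squares predictions essentially unchanged. As a minimal sketch (not part of the original notebook), the agreement can be quantified with r2_score from sklearn.metrics:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: quantify how close the two models are (assumes the cells above were run)\n",
"import numpy as np\n",
"from sklearn.metrics import r2_score\n",
"print(\"R2, all dummy columns kept:  \", r2_score(Y_test, y_pred))\n",
"print(\"R2, one dummy column dropped:\", r2_score(Y1_test, y1_pred))\n",
"print(\"max absolute difference:     \", np.max(np.abs(y_pred - y1_pred)))"
]
},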
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**For the complete project, please see the GitHub project 100-Days-Of-ML-Code. Suggestions and feedback are welcome in the issues.**"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3-final"
}
},
"nbformat": 4,
"nbformat_minor": 2
}