{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 100 Days of Machine Learning, Day 11: K-Nearest Neighbors (K-NN)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Import the libraries"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Load the dataset"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" User ID Gender Age EstimatedSalary Purchased\n",
"0 15624510 Male 19 19000 0\n",
"1 15810944 Male 35 20000 0\n",
"2 15668575 Female 26 43000 0\n",
"3 15603246 Female 27 57000 0\n",
"4 15804002 Male 19 76000 0\n",
"5 15728773 Male 27 58000 0\n",
"6 15598044 Female 27 84000 0\n",
"7 15694829 Female 32 150000 1\n",
"8 15600575 Male 25 33000 0\n",
"9 15727311 Female 35 65000 0\n",
"10 15570769 Female 26 80000 0\n",
"11 15606274 Female 26 52000 0\n",
"12 15746139 Male 20 86000 0\n",
"13 15704987 Male 32 18000 0\n",
"14 15628972 Male 18 82000 0\n",
"15 15697686 Male 29 80000 0\n",
"16 15733883 Male 47 25000 1\n",
"17 15617482 Male 45 26000 1\n",
"18 15704583 Male 46 28000 1\n",
"19 15621083 Female 48 29000 1\n",
"20 15649487 Male 45 22000 1\n",
"21 15736760 Female 47 49000 1\n",
"22 15714658 Male 48 41000 1\n",
"23 15599081 Female 45 22000 1\n",
"24 15705113 Male 46 23000 1\n",
"25 15631159 Male 47 20000 1\n",
"26 15792818 Male 49 28000 1\n",
"27 15633531 Female 47 30000 1\n",
"28 15744529 Male 29 43000 0\n",
"29 15669656 Male 31 18000 0\n",
".. ... ... ... ... ...\n",
"370 15611430 Female 60 46000 1\n",
"371 15774744 Male 60 83000 1\n",
"372 15629885 Female 39 73000 0\n",
"373 15708791 Male 59 130000 1\n",
"374 15793890 Female 37 80000 0\n",
"375 15646091 Female 46 32000 1\n",
"376 15596984 Female 46 74000 0\n",
"377 15800215 Female 42 53000 0\n",
"378 15577806 Male 41 87000 1\n",
"379 15749381 Female 58 23000 1\n",
"380 15683758 Male 42 64000 0\n",
"381 15670615 Male 48 33000 1\n",
"382 15715622 Female 44 139000 1\n",
"383 15707634 Male 49 28000 1\n",
"384 15806901 Female 57 33000 1\n",
"385 15775335 Male 56 60000 1\n",
"386 15724150 Female 49 39000 1\n",
"387 15627220 Male 39 71000 0\n",
"388 15672330 Male 47 34000 1\n",
"389 15668521 Female 48 35000 1\n",
"390 15807837 Male 48 33000 1\n",
"391 15592570 Male 47 23000 1\n",
"392 15748589 Female 45 45000 1\n",
"393 15635893 Male 60 42000 1\n",
"394 15757632 Female 39 59000 0\n",
"395 15691863 Female 46 41000 1\n",
"396 15706071 Male 51 23000 1\n",
"397 15654296 Female 50 20000 1\n",
"398 15755018 Male 36 33000 0\n",
"399 15594041 Female 49 36000 1\n",
"\n",
"[400 rows x 5 columns]\n"
]
}
],
"source": [
"dataset = pd.read_csv('../datasets/Social_Network_Ads.csv')\n",
"print(dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep things easy to follow, we use only `Age` and `EstimatedSalary` as features."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Columns 2 and 3 are Age and EstimatedSalary; column 4 is the Purchased label\n",
"X = dataset.iloc[:, [2, 3]].values\n",
"y = dataset.iloc[:, 4].values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Split the data into a training set and a test set"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"# Hold out 25% of the rows as the test set; random_state fixes the shuffle\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Feature scaling"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\anaconda\\lib\\site-packages\\sklearn\\utils\\validation.py:429: DataConversionWarning: Data with input dtype int64 was converted to float64 by StandardScaler.\n",
" warnings.warn(msg, _DataConversionWarning)\n"
]
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"sc = StandardScaler()\n",
"# Fit the scaler on the training set only, then reuse its statistics on the test set\n",
"X_train = sc.fit_transform(X_train)\n",
"X_test = sc.transform(X_test)"
]
},
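{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what the scaler does, the sketch below standardizes a small made-up array by hand (subtract each column's training mean, divide by its training standard deviation) and checks the result against `StandardScaler`. The toy values and the names `toy_train`, `toy_test`, and `toy_scaler` are assumptions for illustration only and are not part of the dataset. K-NN relies on distances, so without scaling a feature with a large range such as `EstimatedSalary` would tend to dominate `Age`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Hypothetical toy data: the two columns mimic Age and EstimatedSalary\n",
"toy_train = np.array([[20.0, 20000.0], [30.0, 60000.0], [40.0, 100000.0]])\n",
"toy_test = np.array([[25.0, 40000.0]])\n",
"\n",
"# The statistics come from the training data only\n",
"mean = toy_train.mean(axis=0)\n",
"std = toy_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses\n",
"\n",
"manual_train = (toy_train - mean) / std\n",
"manual_test = (toy_test - mean) / std\n",
"\n",
"# The manual result should match scikit-learn's scaler\n",
"toy_scaler = StandardScaler()\n",
"assert np.allclose(manual_train, toy_scaler.fit_transform(toy_train))\n",
"assert np.allclose(manual_test, toy_scaler.transform(toy_test))\n",
"print(manual_test)"
]
},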
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Step 5: Fit K-NN to the training set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the `KNeighborsClassifier` estimator from sklearn's `neighbors` module."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier"
]
},
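{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the relevant parameters:\n",
"\n",
"- `n_neighbors = 5`: the value of K (5 is the default).\n",
"- `metric = 'minkowski'`: the distance metric; Minkowski distance is the default.\n",
"- `p = 2`: the power parameter of the metric, used only by the Minkowski and weighted Minkowski distances. `p = 1` gives the Manhattan distance and `p = 2` the Euclidean distance. The default is 2."
]
},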
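{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the role of `p` concrete, here is a minimal sketch of the Minkowski distance written with NumPy. The helper name `minkowski_distance` and the two sample points are assumptions for illustration; scikit-learn computes these distances internally. For the points (0, 0) and (3, 4), this gives 7 with `p = 1` and 5 with `p = 2`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def minkowski_distance(a, b, p=2):\n",
"    # Minkowski distance: (sum_i |a_i - b_i|^p)^(1/p)\n",
"    a = np.asarray(a, dtype=float)\n",
"    b = np.asarray(b, dtype=float)\n",
"    return float(np.sum(np.abs(a - b) ** p) ** (1.0 / p))\n",
"\n",
"point_a = [0, 0]\n",
"point_b = [3, 4]\n",
"print(minkowski_distance(point_a, point_b, p=1))  # Manhattan distance\n",
"print(minkowski_distance(point_a, point_b, p=2))  # Euclidean distance"
]
},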
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier = KNeighborsClassifier(n_neighbors=5, metric='minkowski', p=2)\n",
"classifier.fit(X_train, y_train)"
]
},
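{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what happens at prediction time, the following is a simplified sketch of K-NN with Euclidean distance and a uniform majority vote. The function `knn_predict` and the toy points are hypothetical. The real `KNeighborsClassifier` can use tree-based neighbor search and supports options (weights, other metrics) that this sketch ignores, so it should not be read as the library's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def knn_predict(train_X, train_y, query_X, k=5):\n",
"    # Pairwise Euclidean distances, shape (n_query, n_train)\n",
"    diffs = query_X[:, None, :] - train_X[None, :, :]\n",
"    dists = np.sqrt((diffs ** 2).sum(axis=2))\n",
"    # Indices of the k closest training points for each query point\n",
"    nearest = np.argsort(dists, axis=1)[:, :k]\n",
"    # Majority vote among the neighbors' labels (non-negative integer labels assumed)\n",
"    votes = train_y[nearest]\n",
"    return np.array([np.bincount(row).argmax() for row in votes])\n",
"\n",
"# Hypothetical toy data: two well-separated clusters\n",
"toy_X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [1.0, 1.0], [0.9, 1.1], [1.1, 0.9]])\n",
"toy_y = np.array([0, 0, 0, 1, 1, 1])\n",
"toy_queries = np.array([[0.05, 0.05], [1.0, 0.95]])\n",
"print(knn_predict(toy_X, toy_y, toy_queries, k=3))"
]
},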
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Predict on the test set"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 0 1 0 0 1 0 1 0 1 0 0 0 0 0 0 1 0 0 0 0\n",
" 0 0 1 0 0 0 0 1 0 0 1 0 1 1 0 0 1 1 1 0 0 1 0 0 1 0 1 0 1 0 0 0 0 1 0 0 1\n",
" 0 0 0 0 1 1 1 1 0 0 1 0 0 1 1 0 0 1 0 0 0 0 0 1 1 1]\n"
]
}
],
"source": [
"y_pred = classifier.predict(X_test)\n",
"print(y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 7: Build the confusion matrix"
]
},
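{
"cell_type": "markdown",
"metadata": {},
"source": [
"A confusion matrix summarizes a classifier's performance. Many metrics can be derived from it, such as accuracy, precision, and recall; each matrix also corresponds to a single point on the ROC curve."
]
},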
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[64 4]\n",
" [ 3 29]]\n"
]
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"cm = confusion_matrix(y_test, y_pred)\n",
"print(cm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| | Predicted 0 | Predicted 1 |\n",
"|---|---|---|\n",
"| **Actual 0** | 64 | 4 |\n",
"| **Actual 1** | 3 | 29 |\n",
"\n",
"The test set contains 68 actual 0s (64 + 4) and 32 actual 1s (3 + 29).\n",
"\n",
"K-NN predicted 67 zeros (64 + 3), of which 3 are actually 1s (false negatives).\n",
"\n",
"It also predicted 33 ones (4 + 29), of which 4 are actually 0s (false positives)."
]
}
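,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sketch of how such metrics follow from the matrix, the cell below recomputes accuracy, precision, and recall from the four counts printed in Step 7. The variable names are chosen here for illustration, and the counts are copied from that output rather than recomputed from the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Counts copied from the confusion matrix printed in Step 7\n",
"cm_values = np.array([[64, 4], [3, 29]])\n",
"# Rows are actual classes, columns are predicted classes\n",
"tn, fp, fn, tp = cm_values.ravel()\n",
"\n",
"accuracy = (tp + tn) / cm_values.sum()  # 93 / 100\n",
"precision = tp / (tp + fp)              # 29 / 33\n",
"recall = tp / (tp + fn)                 # 29 / 32\n",
"print(accuracy, precision, recall)"
]
}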
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}