赞
踩
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 768 entries, 0 to 767\n",
"Data columns (total 9 columns):\n",
"Pregnancies 768 non-null int64\n",
"Glucose 768 non-null int64\n",
"BloodPressure 768 non-null int64\n",
"SkinThickness 768 non-null int64\n",
"Insulin 768 non-null int64\n",
"BMI 768 non-null float64\n",
"DiabetesPedigreeFunction 768 non-null float64\n",
"Age 768 non-null int64\n",
"Outcome 768 non-null int64\n",
"dtypes: float64(2), int64(7)\n",
"memory usage: 54.1 KB\n"
]
},
{
"data": {
"text/html": [
"
"
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"\n",
"
" \n",
"
\n","
\n","
Pregnancies\n","
Glucose\n","
BloodPressure\n","
SkinThickness\n","
Insulin\n","
BMI\n","
DiabetesPedigreeFunction\n","
Age\n","
Outcome\n","
\n","
\n","
\n","
\n","
0\n","
6\n","
148\n","
72\n","
35\n","
0\n","
33.6\n","
0.627\n","
50\n","
1\n","
\n","
\n","
1\n","
1\n","
85\n","
66\n","
29\n","
0\n","
26.6\n","
0.351\n","
31\n","
0\n","
\n","
\n","
2\n","
8\n","
183\n","
64\n","
0\n","
0\n","
23.3\n","
0.672\n","
32\n","
1\n","
\n","
\n","
3\n","
1\n","
89\n","
66\n","
23\n","
94\n","
28.1\n","
0.167\n","
21\n","
0\n","
\n","
\n","
4\n","
0\n","
137\n","
40\n","
35\n","
168\n","
43.1\n","
2.288\n","
33\n","
1\n","
\n","
\n","
\n","
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 "
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import the required modules\n",
"import numpy as numpy\n",
"import pandas as pandas\n",
"import matplotlib.pyplot as matplot\n",
"import seaborn as seaborn\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# log loss is the evaluation metric used for this exercise\n",
"from sklearn.metrics import log_loss\n",
"%matplotlib inline\n",
"\n",
"# Read the dataset from the working directory\n",
"data = pandas.read_csv('diabetes.csv')\n",
"\n",
"# Column dtypes and non-null counts\n",
"data.info()\n",
"# First five rows\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"
"
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"\n",
"
" \n",
"
\n","
\n","
Pregnancies\n","
Glucose\n","
BloodPressure\n","
SkinThickness\n","
Insulin\n","
BMI\n","
DiabetesPedigreeFunction\n","
Age\n","
Outcome\n","
\n","
\n","
\n","
\n","
count\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
768.000000\n","
\n","
\n","
mean\n","
3.845052\n","
120.894531\n","
69.105469\n","
20.536458\n","
79.799479\n","
31.992578\n","
0.471876\n","
33.240885\n","
0.348958\n","
\n","
\n","
std\n","
3.369578\n","
31.972618\n","
19.355807\n","
15.952218\n","
115.244002\n","
7.884160\n","
0.331329\n","
11.760232\n","
0.476951\n","
\n","
\n","
min\n","
0.000000\n","
0.000000\n","
0.000000\n","
0.000000\n","
0.000000\n","
0.000000\n","
0.078000\n","
21.000000\n","
0.000000\n","
\n","
\n","
25%\n","
1.000000\n","
99.000000\n","
62.000000\n","
0.000000\n","
0.000000\n","
27.300000\n","
0.243750\n","
24.000000\n","
0.000000\n","
\n","
\n","
50%\n","
3.000000\n","
117.000000\n","
72.000000\n","
23.000000\n","
30.500000\n","
32.000000\n","
0.372500\n","
29.000000\n","
0.000000\n","
\n","
\n","
75%\n","
6.000000\n","
140.250000\n","
80.000000\n","
32.000000\n","
127.250000\n","
36.600000\n","
0.626250\n","
41.000000\n","
1.000000\n","
\n","
\n","
max\n","
17.000000\n","
199.000000\n","
122.000000\n","
99.000000\n","
846.000000\n","
67.100000\n","
2.420000\n","
81.000000\n","
1.000000\n","
\n","
\n","
\n","
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin \\\n",
"count 768.000000 768.000000 768.000000 768.000000 768.000000 \n",
"mean 3.845052 120.894531 69.105469 20.536458 79.799479 \n",
"std 3.369578 31.972618 19.355807 15.952218 115.244002 \n",
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 1.000000 99.000000 62.000000 0.000000 0.000000 \n",
"50% 3.000000 117.000000 72.000000 23.000000 30.500000 \n",
"75% 6.000000 140.250000 80.000000 32.000000 127.250000 \n",
"max 17.000000 199.000000 122.000000 99.000000 846.000000 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"count 768.000000 768.000000 768.000000 768.000000 \n",
"mean 31.992578 0.471876 33.240885 0.348958 \n",
"std 7.884160 0.331329 11.760232 0.476951 \n",
"min 0.000000 0.078000 21.000000 0.000000 \n",
"25% 27.300000 0.243750 24.000000 0.000000 \n",
"50% 32.000000 0.372500 29.000000 0.000000 \n",
"75% 36.600000 0.626250 41.000000 1.000000 \n",
"max 67.100000 2.420000 81.000000 1.000000 "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Summary statistics for every column\n",
"data.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 糖尿病人数直方图"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEKCAYAAAAIO8L1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAD49JREFUeJzt3XusZWV5x/HvD0a81AuXOVCcGRxTx1aMinRCaPmjFowB2zrUitGoTHGSaVLaqjSt1Da1l5hoq6KoIZkUZTBUpXgBDWlLRtR4QT1jcQCpnZFamAxlBkHUWmzBp3/s95TN8DKzuayzD7O/n2Rnr/Wsd6/zHHKYX9ZlvytVhSRJezto2g1IkpYmA0KS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkrmXTbuCRWL58ea1evXrabUjSY8rWrVtvr6q5/Y17TAfE6tWrmZ+fn3YbkvSYkuQ/JhnnKSZJUpcBIUnqMiAkSV0GhCSpy4CQJHUNGhBJvpvkuiTXJplvtcOTXJVke3s/rNWT5PwkO5JsS3L8kL1JkvZtMY4gfrWqjquqtW39XGBLVa0BtrR1gNOANe21EbhgEXqTJD2IaZxiWgdsbsubgdPH6hfXyDXAoUmOnkJ/kiSGD4gC/jnJ1iQbW+2oqroVoL0f2eorgFvGPruz1SRJUzD0N6lPqqpdSY4Erkryr/sYm06tHjBoFDQbAY455phH3OAv/tHFj3gfOvBs/dszp92CNHWDHkFU1a72vhv4JHACcNvCqaP2vrsN3wmsGvv4SmBXZ5+bqmptVa2dm9vvVCKSpIdpsIBI8jNJnrKwDLwEuB64Aljfhq0HLm/LVwBntruZTgTuWjgVJUlafEOeYjoK+GSShZ/z91X1j0m+DlyaZANwM3BGG38l8FJgB/Bj4KwBe5Mk7cdgAVFVNwEv6NS/B5zSqRdw9lD9SJIeGr9JLUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktRlQEiSugwISVKXASFJ6jIgJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSeoyICRJXQaEJKnLgJAkdRkQkqQuA0KS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktRlQEiSugwISVLX4AGR5OAk/5LkM239mUm+mmR7ko8lOaTVH9/Wd7Ttq4fuTZL04BbjCOINwI1j6+8AzquqNcCdwIZW3wDcWVXPAs5r4yRJUzJoQCRZCfwa8HdtPcDJwGVtyGbg9La8rq3Ttp/SxkuSpmDoI4j3AH8M/LStHwF8v6ruaes7gRVteQVwC0DbflcbL0magsECIsmvA7uraut4uTO0Jtg2vt+NSeaTzO/Zs+dR6FSS1DPkEcRJwMuSfBf4KKNTS+8BDk2yrI1ZCexqyzuBVQBt+9OAO/beaVVtqqq1VbV2bm5uwPYlabYNFhBV9SdVtbKqVgOvAj5bVa8BrgZe0YatBy5vy1e0ddr2z1bVA44gJEmLYxrfg3gzcE6SHYyuMVzY6hcCR7T6OcC5U+hNktQs2/+QR66qPgd8ri3fBJzQGXM3cMZi9CNJ2j+/SS1J6jIgJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSeoyICRJXQaEJKnLgJAkdRkQkqQuA0KS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktRlQEiSugwISVKXASFJ6jIgJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSeoaLCCSPCHJ15J8M8kNSf6y1Z+Z5KtJtif5WJJDWv3xbX1H2756qN4kSfs35BHET4CTq+oFwHHAqUlOBN4BnFdVa4A7gQ1t/Abgzqp6FnBeGydJmpLBAqJGftRWH9deBZwMXNbqm4HT2/K6tk7bfkqSDNWfJGnfBr0GkeTgJN
cCu4GrgO8A36+qe9qQncCKtrwCuAWgbb8LOGLI/iRJD27QgKiqe6vqOGAlcALwnN6w9t47Wqi9C0k2JplPMr9nz55Hr1lJ0v0syl1MVfV94HPAicChSZa1TSuBXW15J7AKoG1/GnBHZ1+bqmptVa2dm5sbunVJmllD3sU0l+TQtvxE4MXAjcDVwCvasPXA5W35irZO2/7ZqnrAEYQkaXFMFBBJtkxS28vRwNVJtgFfB66qqs8AbwbOSbKD0TWGC9v4C4EjWv0c4NzJfgVJ0hCW7WtjkicATwKWJzmM+64TPBV4+r4+W1XbgBd26jcxuh6xd/1u4IzJ2pYkDW2fAQH8DvBGRmGwlfsC4gfABwbsS5I0ZfsMiKp6L/DeJL9fVe9bpJ4kSUvA/o4gAKiq9yX5ZWD1+Geq6uKB+pIkTdlEAZHkw8DPAdcC97ZyAQaEJB2gJgoIYC1wrLedStLsmPR7ENcDPztkI5KkpWXSI4jlwLeSfI3RLK0AVNXLBulKkjR1kwbEXwzZhCRp6Zn0LqbPD92IpPu7+a+eN+0WtAQd8+fXLdrPmvQuph9y38yqhzB6tsN/VdVTh2pMkjRdkx5BPGV8PcnpdKbLkCQdOB7WbK5V9SlGT4aTJB2gJj3F9PKx1YMYfS/C70RI0gFs0ruYfmNs+R7gu4yeIS1JOkBNeg3irKEbkSQtLZM+MGhlkk8m2Z3ktiQfT7Jy6OYkSdMz6UXqDzF6JOjTgRXAp1tNknSAmjQg5qrqQ1V1T3tdBMwN2JckacomDYjbk7w2ycHt9Vrge0M2JkmarkkD4vXAK4H/BG4FXgF44VqSDmCT3ub618D6qroTIMnhwDsZBYck6QA06RHE8xfCAaCq7gBeOExLkqSlYNKAOCjJYQsr7Qhi0qMPSdJj0KT/yL8L+HKSyxhNsfFK4G2DdSVJmrpJv0l9cZJ5RhP0BXh5VX1r0M4kSVM18WmiFgiGgiTNiIc13bck6cBnQEiSugwISVKXASFJ6jIgJEldBoQkqWuwgEiyKsnVSW5MckOSN7T64UmuSrK9vR/W6klyfpIdSbYlOX6o3iRJ+zfkEcQ9wB9W1XOAE4GzkxwLnAtsqao1wJa2DnAasKa9NgIXDNibJGk/BguIqrq1qr7Rln8I3MjoaXTrgM1t2Gbg9La8Dri4Rq4BDk1y9FD9SZL2bVGuQSRZzWj2168CR1XVrTAKEeDINmwFcMvYx3a2miRpCgYPiCRPBj4OvLGqfrCvoZ1adfa3Mcl8kvk9e/Y8Wm1KkvYyaEAkeRyjcLikqj7RyrctnDpq77tbfSewauzjK4Fde++zqjZV1dqqWjs352OxJWkoQ97FFOBC4MaqevfYpiuA9W15PXD5WP3MdjfTicBdC6eiJEmLb8iH/pwEvA64Lsm1rfYW4O3ApUk2ADcDZ7RtVwIvBXYAP8ZnXkvSVA0WEFX1RfrXFQBO6Ywv4Oyh+pEkPTR+k1qS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktRlQEiSugwISVKXASFJ6jIgJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSeoyICRJXQaEJKnLgJAkdRkQkqQuA0KS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktQ1WEAk+WCS3UmuH6sdnuSqJNvb+2GtniTnJ9mRZFuS44fqS5I0mSGPIC4CTt2rdi6wparWAFvaOsBpwJr22ghcMGBfkqQJDBYQVfUF4I69yuuAzW15M3D6WP3iGrkGODTJ0UP1Jknav8W+BnFUVd0K0N6PbPUVwC1j43a22gMk2ZhkPsn8nj17Bm1WkmbZUrlInU6tegOralNVra2qtXNzcwO3JUmza7ED4raFU0ftfXer7wRWjY1bCexa5N4kSWMWOyCuANa35fXA5WP1M9vdTCcCdy2cipIkTceyoXac5CPAi4DlSXYCbwXeDlyaZANwM3
BGG34l8FJgB/Bj4Kyh+pIkTWawgKiqVz/IplM6Yws4e6heJEkP3VK5SC1JWmIMCElSlwEhSeoyICRJXQaEJKnLgJAkdRkQkqQuA0KS1GVASJK6DAhJUpcBIUnqMiAkSV0GhCSpy4CQJHUZEJKkLgNCktRlQEiSugwISVKXASFJ6jIgJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSeoyICRJXQaEJKnLgJAkdRkQkqQuA0KS1LWkAiLJqUm+nWRHknOn3Y8kzbIlExBJDgY+AJwGHAu8Osmx0+1KkmbXkgkI4ARgR1XdVFX/A3wUWDflniRpZi2lgFgB3DK2vrPVJElTsGzaDYxJp1YPGJRsBDa21R8l+fagXc2W5cDt025iKcg710+7Bd2ff5sL3tr7p/Ihe8Ykg5ZSQOwEVo2trwR27T2oqjYBmxarqVmSZL6q1k67D2lv/m1Ox1I6xfR1YE2SZyY5BHgVcMWUe5KkmbVkjiCq6p4kvwf8E3Aw8MGqumHKbUnSzFoyAQFQVVcCV067jxnmqTstVf5tTkGqHnAdWJKkJXUNQpK0hBgQcooTLVlJPphkd5Lrp93LLDIgZpxTnGiJuwg4ddpNzCoDQk5xoiWrqr4A3DHtPmaVASGnOJHUZUBooilOJM0eA0ITTXEiafYYEHKKE0ldBsSMq6p7gIUpTm4ELnWKEy0VST4CfAX4+SQ7k2yYdk+zxG9SS5K6PIKQJHUZEJKkLgNCktRlQEiSugwISVKXAaGZl2RlksuTbE/ynSTvbd8J2ddn3rJY/UnTYkBopiUJ8AngU1W1Bng28GTgbfv5qAGhA54BoVl3MnB3VX0IoKruBd4EvD7J7yZ5/8LAJJ9J8qIkbweemOTaJJe0bWcm2Zbkm0k+3GrPSLKl1bckOabVL0pyQZKrk9yU5Ffacw9uTHLR2M97SZKvJPlGkn9I8uRF+68iYUBIzwW2jheq6gfAzTzIM9ur6lzgv6vquKp6TZLnAn8KnFxVLwDe0Ia+H7i4qp4PXAKcP7abwxiF05uATwPntV6el+S4JMuBPwNeXFXHA/PAOY/GLyxNqvs/gDRDQn/22ger95wMXFZVtwNU1cLzC34JeHlb/jDwN2Of+XRVVZLrgNuq6jqAJDcAqxlNmngs8KXRWTAOYTTlhLRoDAjNuhuA3xovJHkqoxlu7+L+R9lPeJB9TBom42N+0t5/Ora8sL4MuBe4qqpePcF+pUF4ikmzbgvwpCRnwv8/gvVdjB51eRNwXJKDkqxi9PS9Bf+b5HFj+3hlkiPaPg5v9S8zmh0X4DXAFx9CX9cAJyV5Vtvnk5I8+6H+ctIjYUBoptVotsrfBM5Ish34N+BuRncpfQn4d+A64J3AN8Y+ugnYluSSNvvt24DPJ/km8O425g+As5JsA17HfdcmJulrD/DbwEfa568BfuHh/p7Sw+FsrpKkLo8gJEldBoQkqcuAkCR1GRCSpC4DQpLUZUBIkroMCElSlwEhSer6Pz8/pjgBxhB9AAAAAElFTkSuQmCC\n",
"text/plain": [
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Bar chart of the class counts of the Outcome label. Keyword arguments keep\n",
"# the call valid on seaborn releases that no longer accept a positional\n",
"# data vector.\n",
"seaborn.countplot(x='Outcome', data=data);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 选择合适的特征值,分割数据"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(614, 8)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Target is the Outcome column; every other column is used as a feature\n",
"Y_Feature = data['Outcome']\n",
"X_Feature = data.drop('Outcome', axis=1)\n",
"\n",
"# Hold out 20% of the rows as test data (fixed seed for a repeatable split)\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    X_Feature, Y_Feature, test_size=0.2, random_state=33\n",
")\n",
"x_train.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 标准化处理"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Learn the mean/std on the training split only, then apply that same\n",
"# transformation to the test split. Re-fitting the scaler on the test data\n",
"# would scale the two splits differently and leak test statistics into the\n",
"# evaluation.\n",
"ss_X = StandardScaler()\n",
"x_train = ss_X.fit_transform(x_train)\n",
"x_test = ss_X.transform(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 缺省的Logistic正则"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logloss of each fold is: [0.46129644 0.46510588 0.56426612 0.44377717 0.46617809]\n",
"cv logloss is: 0.48012474004701106\n"
]
},
{
"data": {
"text/plain": [
"0.7402597402597403"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"# sklearn.cross_validation was removed from scikit-learn; cross_val_score\n",
"# lives in sklearn.model_selection, which this notebook already uses\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"lr = LogisticRegression()\n",
"\n",
"# 5-fold cross-validated log loss on the training split. The scorer returns\n",
"# the negated loss, hence the sign flip before printing.\n",
"loss = cross_val_score(lr, x_train, y_train, cv=5, scoring='neg_log_loss')\n",
"print('logloss of each fold is: {}'.format(-loss))\n",
"print('cv logloss is: {}'.format(-loss.mean()))\n",
"\n",
"# Accuracy of the default model on the test split\n",
"lr.fit(x_train, y_train)\n",
"lr.score(x_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"采用5折交叉验证,logloss约为0.48,测试集准确度约为0.74,下面用GridSearchCV进行参数调优"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5, error_score='raise',\n",
" estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
" intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n",
" penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n",
" verbose=0, warm_start=False),\n",
" fit_params=None, iid=True, n_jobs=1,\n",
" param_grid={'penalty': ['l1', 'l2'], 'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
" scoring='neg_log_loss', verbose=0)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Search grid: penalty type x inverse regularisation strength C\n",
"penaltys = ['l1', 'l2']\n",
"Cs = [0.001, 0.01, 0.1, 1, 10, 100, 1000]\n",
"tuned_parameters = dict(penalty=penaltys, C=Cs)\n",
"\n",
"# liblinear supports both l1 and l2 penalties; naming it explicitly keeps the\n",
"# search working on scikit-learn releases whose default solver rejects l1\n",
"lr_penalty = LogisticRegression(solver='liblinear')\n",
"\n",
"# return_train_score=True guarantees that mean_train_score / std_train_score\n",
"# are present in cv_results_, which the plotting code reads\n",
"grid = GridSearchCV(lr_penalty, tuned_parameters, cv=5, scoring='neg_log_loss',\n",
"                    return_train_score=True)\n",
"grid.fit(x_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mean_fit_time': array([0.00282516, 0.00275798, 0.00244303, 0.00339441, 0.00315719,\n",
" 0.003087 , 0.00379777, 0.00329165, 0.00360203, 0.00362625,\n",
" 0.00277553, 0.0023922 , 0.00354481, 0.00334883]),\n",
" 'mean_score_time': array([0.00222125, 0.00185637, 0.00199313, 0.00197988, 0.00186162,\n",
" 0.00181241, 0.00183372, 0.00223641, 0.00215921, 0.00247059,\n",
" 0.0017406 , 0.001439 , 0.00229459, 0.00270238]),\n",
" 'mean_test_score': array([-0.69314718, -0.64214833, -0.6721329 , -0.52844007, -0.48658742,\n",
" -0.47999943, -0.48043706, -0.48017599, -0.48089965, -0.48086932,\n",
" -0.48095236, -0.48095123, -0.48095803, -0.48095956]),\n",
" 'mean_train_score': array([-0.69314718, -0.6412946 , -0.67079105, -0.52380684, -0.47502596,\n",
" -0.46674403, -0.46228807, -0.46214818, -0.46206767, -0.46206619,\n",
" -0.46206531, -0.4620653 , -0.46206529, -0.46206529]),\n",
" 'param_C': masked_array(data=[0.001, 0.001, 0.01, 0.01, 0.1, 0.1, 1, 1, 10, 10, 100,\n",
" 100, 1000, 1000],\n",
" mask=[False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False],\n",
" fill_value='?',\n",
" dtype=object),\n",
" 'param_penalty': masked_array(data=['l1', 'l2', 'l1', 'l2', 'l1', 'l2', 'l1', 'l2', 'l1',\n",
" 'l2', 'l1', 'l2', 'l1', 'l2'],\n",
" mask=[False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False],\n",
" fill_value='?',\n",
" dtype=object),\n",
" 'params': [{'C': 0.001, 'penalty': 'l1'},\n",
" {'C': 0.001, 'penalty': 'l2'},\n",
" {'C': 0.01, 'penalty': 'l1'},\n",
" {'C': 0.01, 'penalty': 'l2'},\n",
" {'C': 0.1, 'penalty': 'l1'},\n",
" {'C': 0.1, 'penalty': 'l2'},\n",
" {'C': 1, 'penalty': 'l1'},\n",
" {'C': 1, 'penalty': 'l2'},\n",
" {'C': 10, 'penalty': 'l1'},\n",
" {'C': 10, 'penalty': 'l2'},\n",
" {'C': 100, 'penalty': 'l1'},\n",
" {'C': 100, 'penalty': 'l2'},\n",
" {'C': 1000, 'penalty': 'l1'},\n",
" {'C': 1000, 'penalty': 'l2'}],\n",
" 'rank_test_score': array([14, 12, 13, 11, 10, 1, 3, 2, 5, 4, 7, 6, 8, 9],\n",
" dtype=int32),\n",
" 'split0_test_score': array([-0.69314718, -0.64371675, -0.66946739, -0.52816851, -0.48618155,\n",
" -0.4678555 , -0.46345904, -0.46129644, -0.46108568, -0.46088502,\n",
" -0.46086642, -0.4608487 , -0.46084313, -0.46084513]),\n",
" 'split0_train_score': array([-0.69314718, -0.64085132, -0.66189542, -0.52529099, -0.47823867,\n",
" -0.47085398, -0.46684285, -0.46669378, -0.46662379, -0.46662226,\n",
" -0.46662151, -0.46662149, -0.46662149, -0.46662148]),\n",
" 'split1_test_score': array([-0.69314718, -0.64113725, -0.67517494, -0.52518603, -0.47166225,\n",
" -0.47077815, -0.46426366, -0.46510588, -0.4646352 , -0.46473286,\n",
" -0.46468867, -0.46469927, -0.46469489, -0.46469595]),\n",
" 'split1_train_score': array([-0.69314718, -0.64188719, -0.67797602, -0.52532854, -0.47970513,\n",
" -0.47076994, -0.4669271 , -0.46678157, -0.46671611, -0.4667146 ,\n",
" -0.4667139 , -0.46671388, -0.46671388, -0.46671388]),\n",
" 'split2_test_score': array([-0.69314718, -0.64575466, -0.6680453 , -0.5512625 , -0.54752984,\n",
" -0.54301469, -0.56469495, -0.56426612, -0.56847801, -0.56834734,\n",
" -0.56879909, -0.56879096, -0.56883272, -0.5688357 ]),\n",
" 'split2_train_score': array([-0.69314718, -0.63908575, -0.66162164, -0.5135644 , -0.45564429,\n",
" -0.44783203, -0.44194643, -0.44184498, -0.44173178, -0.44173047,\n",
" -0.44172921, -0.4417292 , -0.44172919, -0.44172919]),\n",
" 'split3_test_score': array([-0.69314718, -0.63856226, -0.67581323, -0.50975335, -0.45503223,\n",
" -0.44706595, -0.44365688, -0.44377717, -0.44397598, -0.44400066,\n",
" -0.44403119, -0.4440332 , -0.44403654, -0.44403656]),\n",
" 'split3_train_score': array([-0.69314718, -0.64333601, -0.67874616, -0.5309552 , -0.48382671,\n",
" -0.47509784, -0.47066678, -0.47052445, -0.4704446 , -0.47044314,\n",
" -0.47044228, -0.47044226, -0.47044225, -0.47044225]),\n",
" 'split4_test_score': array([-0.69314718, -0.64152376, -0.67221592, -0.52767402, -0.47216067,\n",
" -0.47104103, -0.46583102, -0.46617809, -0.46606368, -0.46612355,\n",
" -0.46611895, -0.4661268 , -0.46612564, -0.46612722]),\n",
" 'split4_train_score': array([-0.69314718, -0.64131272, -0.67371601, -0.52389506, -0.47771499,\n",
" -0.46916638, -0.46505718, -0.46489609, -0.46482209, -0.46482048,\n",
" -0.46481966, -0.46481964, -0.46481963, -0.46481963]),\n",
" 'std_fit_time': array([0.00086323, 0.00059983, 0.00060702, 0.00071247, 0.00065826,\n",
" 0.00033766, 0.00033068, 0.00036014, 0.00111406, 0.00057291,\n",
" 0.00033485, 0.00048554, 0.0004031 , 0.00028537]),\n",
" 'std_score_time': array([9.43189981e-04, 4.28180928e-04, 3.54302375e-04, 4.90339327e-05,\n",
" 1.43874296e-04, 7.93332081e-05, 1.74334456e-04, 5.03879477e-04,\n",
" 4.15763520e-04, 8.80480163e-04, 1.15419279e-04, 4.23557745e-04,\n",
" 2.53882737e-04, 3.60673921e-04]),\n",
" 'std_test_score': array([0. , 0.00243715, 0.00305426, 0.0132657 , 0.03206037,\n",
" 0.03276814, 0.04294169, 0.04285084, 0.04453539, 0.04448689,\n",
" 0.04466504, 0.04466183, 0.04467864, 0.04467945]),\n",
" 'std_train_score': array([0. , 0.00138524, 0.00757197, 0.00566626, 0.00992521,\n",
" 0.00965834, 0.01033363, 0.01031565, 0.01033126, 0.01033118,\n",
" 0.01033136, 0.01033136, 0.01033136, 0.01033136])}"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the search results as a compact table, best configuration first,\n",
"# instead of dumping the raw cv_results_ dict\n",
"cv_summary = pandas.DataFrame(grid.cv_results_)[\n",
"    ['param_C', 'param_penalty', 'mean_test_score', 'std_test_score', 'rank_test_score']\n",
"]\n",
"cv_summary.sort_values('rank_test_score')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"results = grid.cv_results_\n",
"grid_shape = (len(Cs), len(penaltys))\n",
"\n",
"# cv_results_ is flat with the penalty varying fastest, so after reshaping\n",
"# each row is one C value and each column one penalty\n",
"test_scores = numpy.array(results['mean_test_score']).reshape(grid_shape)\n",
"test_stds = numpy.array(results['std_test_score']).reshape(grid_shape)\n",
"train_scores = numpy.array(results['mean_train_score']).reshape(grid_shape)\n",
"train_stds = numpy.array(results['std_train_score']).reshape(grid_shape)\n",
"\n",
"# One test curve and one train curve per penalty, with +/- one std error bars\n",
"x_axis = numpy.log10(Cs)\n",
"for column, penalty_name in enumerate(penaltys):\n",
"    matplot.errorbar(x_axis, test_scores[:, column], yerr=test_stds[:, column],\n",
"                     label=penalty_name + ' Test')\n",
"    matplot.errorbar(x_axis, train_scores[:, column], yerr=train_stds[:, column],\n",
"                     label=penalty_name + ' Train')\n",
"\n",
"matplot.legend()\n",
"matplot.xlabel('log(C)')\n",
"matplot.ylabel('neg-logloss')\n",
"matplot.savefig('LogisticGridSearchCV_C.png')\n",
"\n",
"matplot.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"从图表中可以看出,C在0.1到1附近时logloss最小(按交叉验证得分排名,C=0.1的L2正则第一,C=1的L2正则第二,两者差别很小)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7337662337662337, 0.7402597402597403)\n"
]
}
],
"source": [
"# Refit both penalties at C=1 and compare accuracy on the test split.\n",
"# liblinear is named explicitly because it handles l1 as well as l2; other\n",
"# default solvers reject penalty='l1'.\n",
"lr_l1 = LogisticRegression(penalty='l1', C=1, solver='liblinear')\n",
"lr_l1.fit(x_train, y_train)\n",
"accuracy_l1 = lr_l1.score(x_test, y_test)\n",
"\n",
"lr_l2 = LogisticRegression(penalty='l2', C=1, solver='liblinear')\n",
"lr_l2.fit(x_train, y_train)\n",
"accuracy_l2 = lr_l2.score(x_test, y_test)\n",
"\n",
"print(accuracy_l1,accuracy_l2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"发现带入参数后,L2正则比L1正则性能更好"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Default SVC"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVC\n",
"import sklearn.metrics as metrics\n",
"\n",
"# Linear SVM with default hyper-parameters, trained on the training split\n",
"SVC1 = LinearSVC()\n",
"SVC1 = SVC1.fit(x_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report for classifier LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n",
" intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n",
" multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n",
" verbose=0):\n",
" precision recall f1-score support\n",
"\n",
" 0 0.75 0.90 0.82 99\n",
" 1 0.72 0.47 0.57 55\n",
"\n",
"avg / total 0.74 0.75 0.73 154\n",
"\n",
"\n",
"Confusion matrix:\n",
"[[89 10]\n",
" [29 26]]\n"
]
}
],
"source": [
"# Evaluate the default linear SVM on the test split\n",
"y_predict = SVC1.predict(x_test)\n",
"\n",
"report = metrics.classification_report(y_test, y_predict)\n",
"confusion = metrics.confusion_matrix(y_test, y_predict)\n",
"\n",
"print(\"Classification report for classifier %s:\\n%s\\n\" % (SVC1, report))\n",
"print(\"Confusion matrix:\\n%s\" % confusion)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"def fit_grid_point_Linear(C, x_train, y_train, x_val, y_val):\n",
"    \"\"\"Fit a LinearSVC with the given C and score it on held-out data.\n",
"\n",
"    Trains on (x_train, y_train), prints the accuracy obtained on\n",
"    (x_val, y_val) and returns that accuracy.\n",
"    \"\"\"\n",
"    model = LinearSVC(C=C).fit(x_train, y_train)\n",
"    accuracy = model.score(x_val, y_val)\n",
"\n",
"    print(\"accuracy: {}\".format(accuracy))\n",
"    return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy: 0.727272727273\n",
"accuracy: 0.733766233766\n",
"accuracy: 0.74025974026\n",
"accuracy: 0.733766233766\n",
"accuracy: 0.753246753247\n",
"accuracy: 0.701298701299\n",
"accuracy: 0.701298701299\n"
]
}
],
"source": [
"# One LinearSVC per C on a log grid from 1e-3 to 1e3\n",
"# NOTE(review): the test split is passed as the validation data here, so C is\n",
"# compared on the same rows that are used to report accuracy - consider a\n",
"# separate validation split or cross-validation\n",
"C_s = numpy.logspace(-3, 3, 7)\n",
"accuracy_s = []\n",
"for oneC in C_s:\n",
"    accuracy_s.append(fit_grid_point_Linear(oneC, x_train, y_train, x_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_axis = numpy.log10(C_s)\n",
"\n",
"# The curve carries a label so that legend() has an entry to show; without\n",
"# any labelled artist the legend is empty and matplotlib emits a warning\n",
"matplot.plot(x_axis, numpy.array(accuracy_s), 'b-', label='Test accuracy')\n",
"\n",
"matplot.legend()\n",
"matplot.xlabel('log(C)')\n",
"matplot.ylabel('accuracy')\n",
"matplot.savefig('SVM_Otto.png')\n",
"\n",
"matplot.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"可以由图看出,C的取值为1时,准确度比较高\n",
"\n",
"### RBF核"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"# Imported in this cell so the helper does not depend on an import that is\n",
"# only executed in a later cell\n",
"from sklearn.svm import SVC\n",
"\n",
"\n",
"def fit_grid_point_RBF(C, gamma, x_train, y_train, x_val, y_val):\n",
"    \"\"\"Fit an RBF-kernel SVC with the given C and gamma and score it.\n",
"\n",
"    Trains on (x_train, y_train), prints the accuracy obtained on\n",
"    (x_val, y_val) and returns that accuracy.\n",
"    \"\"\"\n",
"    # Train the RBF-kernel SVC on the training data\n",
"    SVC3 = SVC(C=C, kernel='rbf', gamma=gamma)\n",
"    SVC3 = SVC3.fit(x_train, y_train)\n",
"\n",
"    # Accuracy on the validation data\n",
"    accuracy = SVC3.score(x_val, y_val)\n",
"\n",
"    print(\"accuracy: {}\".format(accuracy))\n",
"    return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.707792207792\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.733766233766\n",
"accuracy: 0.74025974026\n",
"accuracy: 0.707792207792\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.753246753247\n",
"accuracy: 0.727272727273\n",
"accuracy: 0.688311688312\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.746753246753\n",
"accuracy: 0.662337662338\n",
"accuracy: 0.688311688312\n",
"accuracy: 0.642857142857\n",
"accuracy: 0.642857142857\n"
]
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"# 5 x 5 log grid over C and gamma, both from 1e-2 to 1e2\n",
"# NOTE(review): the test split is passed as the validation data here, so the\n",
"# hyper-parameters are compared on the same rows used to report accuracy -\n",
"# consider a separate validation split or cross-validation\n",
"C_s = numpy.logspace(-2, 2, 5)\n",
"gamma_s = numpy.logspace(-2, 2, 5)\n",
"\n",
"# Outer loop over C, inner loop over gamma, so accuracy_s is ordered C-major\n",
"accuracy_s = []\n",
"for oneC in C_s:\n",
"    for gamma in gamma_s:\n",
"        accuracy_s.append(fit_grid_point_RBF(oneC, gamma, x_train, y_train, x_test, y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 从结果上看,gamma值不同准确度也不同,最高在0.75"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Rows index C and columns index gamma (accuracy_s was filled C-major)\n",
"accuracy_s1 = numpy.array(accuracy_s).reshape(len(C_s), len(gamma_s))\n",
"x_axis = numpy.log10(C_s)\n",
"\n",
"# One accuracy-vs-log(C) curve per gamma value\n",
"for gamma, gamma_column in zip(gamma_s, accuracy_s1.T):\n",
"    matplot.plot(x_axis, gamma_column, label=' Test - log(gamma)' + str(numpy.log10(gamma)))\n",
"\n",
"matplot.legend()\n",
"matplot.xlabel('log(C)')\n",
"matplot.ylabel('accuracy')\n",
"matplot.savefig('RBF_SVM_Otto.png')\n",
"\n",
"matplot.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 由RBF核看出,参数取值为1时测试结果相对比较稳定。如果c取值为10,则结果有时准确度会低于0.7,可能发生了过拟合\n",
"\n",
"并且根据结果准确度和负损失来看,SVM比Logistic回归更好。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
一键复制
编辑
Web IDE
原始数据
按行查看
历史
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。