当前位置:   article > 正文

python 决策树分类 泰坦尼克生存预测_泰坦尼克号生存预测python

泰坦尼克号生存预测python

一、项目简介

官方链接:Titanic - Machine Learning from Disaster

1.1 项目背景

  • 1、泰坦尼克号: 英国白星航运公司下辖的一艘奥林匹克级邮轮,于1909年3月31日在爱尔兰贝尔法斯特港的哈兰德与沃尔夫造船厂动工建造,1911年5月31日下水,1912年4月2日完工试航。
  • 2、首航时间: 1912年4月10日
  • 3、航线: 从英国南安普敦出发,途经法国瑟堡-奥克特维尔以及爱尔兰昆士敦,驶向美国纽约。
  • 4、沉船: 1912年4月15日(1912年4月14日23时40分左右撞击冰山)
    船员+乘客人数:2224
  • 5、遇难人数: 1502(67.5%)

1.2 目标问题

  • 根据训练集中各位乘客的特征及是否获救标志的对应关系训练模型,预测测试集中的乘客是否获救。(二元分类问题

1.3 字段描述

在这里插入图片描述

二、训练集(train)建模

2.1 导入相关库

import numpy as np
import pandas as pd
from scipy import stats

# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

# 可视化相关库
import seaborn as sns
import matplotlib.pyplot as plt

# 解决mac 系统画图中文不显示问题
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] 

# # 解决win 系统中文不显示问题
# from pylab import mpl
# mpl.rcParams['font.sans-serif']=['SimHei']

# 不显示警告
import warnings
warnings.filterwarnings('ignore')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

2.2 自定义函数

def PieChart(df):
    '''
        绘制环形饼图
    '''
    plt.figure(
        figsize = (4,4), # 设置图片大小
        dpi = 100        # 精度
    )
    df.value_counts().plot( 
        kind = 'pie',               # 设置绘图类型为饼图
        wedgeprops = {'width':0.4}, # 设置空心比例
        autopct = "%.1f%%"          # 显示百分比
    )

def BarPlot(df,ColumnsName):
    '''
        绘制不同 ColumnsName 的存活人数柱形图
    '''
    ColumnsDf = df.groupby(['Survived',ColumnsName]).count()[['PassengerId']].reset_index()\
            .rename(columns={"PassengerId":"Count"})
    plt.figure(figsize=(4,3),dpi=150)
    sns.barplot(
        data=ColumnsDf,
        x=ColumnsName,
        y="Count",
        hue="Survived"
    )
    plt.title('Survived Count Of {}'.format(ColumnsName))
    
def OneHot(x):
    '''
        功能:one-hot 编码
        传入:需要编码的分类变量
        返回:返回编码后的结果,形式为 dataframe
    '''
    # 通过 LabelEncoder 将分类变量打上数值标签 
    lb = LabelEncoder()                             # 初始化
    x_pre = lb.fit_transform(x)                     # 模型拟合
    x_dict = dict([[i,j] for i,j in zip(x,x_pre)])  # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}
    x_num = [[x_dict[i]] for i in x]                # 通过 x_dict 将分类变量转为数值型
    
    # 进行one-hot编码
    enc = OneHotEncoder()                        # 初始化
    enc.fit(x_num)                               # 模型拟合
    array_data = enc.transform(x_num).toarray()  # one-hot 编码后的结果,二维数组形式
    
    # 转成 dataframe 形式
    df = pd.DataFrame(array_data)
    inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值
    # columns 重命名
    if type(x) == pd.Series:
        firs_name = x.name
    else:
        firs_name = ""
    df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]           
    
    return df
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

2.3 特征工程

2.3.1 数据导入

train = pd.read_csv("train.csv")
train.head(5)
  • 1
  • 2

在这里插入图片描述

2.3.2 数据初探

(1)特征信息
train.info()
  • 1
  • 可以看出训练集共有891个样本,且有三个字段(Age、Cabin、Embarked)存在缺失值。
    在这里插入图片描述
(2)特征缺失值比例统计
train.isnull().sum()/len(train)
  • 1
  • 可以看出,字段Cabin缺失比例较大,达到77%。
    在这里插入图片描述
(3)数值特征描述统计
train.describe()
  • 1
  • 可以看出,票价(Fare)最低为0,估计是船上的员工。
    在这里插入图片描述

2.3.3 单特征可视化分析与处理

(1)Survived 是否存活
########################## 1、Survived 是否存活 ##########################
# Y标签,{0:不存活,1:存活}
# 有无缺失值:无
# 数据处理:不处理
# 从图中可以看出,死亡人数与存活人数占比差异不大

PieChart(train['Survived'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述

(2)Pclass 乘客等级
########################## 2、Pclass 乘客等级 ##########################
# 无缺失值,等级变量
# 用柱状图查看各乘客等级的存活情况
# 可以看出 Pclass=3 的乘客中,存活人数远低于死亡人数
BarPlot(train,"Pclass")

# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
train['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in train['Pclass']]

# 查看不同 PclassType 的存活情况
BarPlot(train,"PclassType")

# 再对 PclassType 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['PclassType']),
    left_index=True,
    right_index=True
)    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

在这里插入图片描述

在这里插入图片描述

(3)Name 乘客姓名
########################## 3、Name 乘客姓名 ##########################
# 字符串变量
# 有无缺失值:无

# 从乘客姓名中获取头街
# 姓名中头街字符串与定义头街类别之间的关系
#     Officer: 政府官员,
#     RoyaIty: 王室(皇室),
#     Mr:      已婚男士,
#     Mrs:     已婚女士,
#     Miss:    年轻未婚女子,
#     Master:  有技能的人/教师 
# 新建字段 Title_Dict 
Title_Dict = {
    'Mr':'Mr',
    'Mrs':'Mrs', 
    'Miss':'Miss',
    'Master': 'Master', 
    'Don':'Royalty',
    'Rev':'Officer',
    'Dr':')fficer', 
    'Mme':'Mrs',
    'Ms':'Mrs',
    'Major':'Officer', 
    'Lady': 'Royalty',
    'Sir': 'Royalty',
    'Mlle':'Miss', 
    'Col': 'Officer',
    'Capt':'Officer',
    'the Countess': 'Royalty',
    'Jonkheer': 'Royalty',
    'Dona': 'Royalty'
}
train['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in train['Name']] # 对Name进行分类
# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为Mr(已婚男士)中,死亡人数远远大于存活人数;
#        乘客为Mrs(已婚女士)、Miss(年轻未婚女子)中,死亡人数远远低于存活人数;
BarPlot(train,"NameType")

# 数据进一步处理:将 NameType 分成三类
# Mr(已婚男士)
# Mrs(已婚女士)、Miss(年轻未婚女子)
# 其他
train['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \
                      for i in train['NameType']]

# 查看不同 NameType2 的存活情况
BarPlot(train,"NameType2")

# 再对 NameType2 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['NameType2']),
    left_index=True,
    right_index=True
)     
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

在这里插入图片描述

在这里插入图片描述

(4)Sex 性别
########################## 4、Sex 性别 ##########################
# 分类变量
# 有无缺失值:无

# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为男性中,死亡人数远远大于存活人数
BarPlot(train,"Sex")

# 对 Sex 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['Sex']),left_index=True,right_index=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在这里插入图片描述

(5)Age 年龄
########################## 5、Age 年龄 ##########################
# 连续变量
# 有无缺失值:有,缺失比例19.9%

# 缺失值用均值填充
train['Age'] = train['Age'].fillna(train['Age'].mean())

# 用直方图查看各 Age 的存活情况
# 可以看出 可以看出小于5岁的小孩存活率很高
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(train[train['Survived']==0]['Age'],color="red",kde=False)
sns.distplot(train[train['Survived']==1]['Age'],color="blue",kde=False)

# 数据处理:将 Age 分成两类,Age<=5、Age>5
train['AgeType'] = ["Age<=5" if i <= 5 else "Age>5"  for i in train['Age']]

# 查看不同 AgeType 的存活情况
BarPlot(train,"AgeType")

# 再对 AgeType 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['AgeType']),
    left_index=True,
    right_index=True
)   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

在这里插入图片描述
在这里插入图片描述

(6)SibSp 堂兄弟妹个数
########################## 6、SibSp 堂兄弟妹个数 ##########################
# 无缺失值,等级变量
# 用柱状图查看各堂兄弟妹个数的存活情况
# 可以看出 SibSp=0 的乘客中,死亡人数较多
BarPlot(train,"SibSp")

# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
train['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in train['SibSp']]

# 查看不同 SibSpType 的存活情况
BarPlot(train,"SibSpType")

# 再对 SibSpType 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['SibSpType']),
    left_index=True,
    right_index=True
)    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

在这里插入图片描述

在这里插入图片描述

(7)Parch 父母与小孩的个数
########################## 7、Parch 父母与小孩的个数 ##########################
# 连续变量
# 有无缺失值:无

# 用柱状图查看父母与小孩的个数的存活情况
# 可以看出 Parch=0 的乘客中,死亡人数较多
BarPlot(train,"Parch")

# 数据处理:将 Parch 分成两类,Parch=0、Parch>0
train['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in train['Parch']]

# 查看不同 ParchType 的存活情况
BarPlot(train,"ParchType")

# 再对 ParchType 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['ParchType']),
    left_index=True,
    right_index=True
)    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

在这里插入图片描述
在这里插入图片描述

(8)Ticket 船票信息
  • 字符变量
  • 有无缺失值:无
  • 数据处理:这里直接删去(下文会删)
(9)Fare 票价
########################## 9、Fare 票价 ##########################
# 连续变量
# 有无缺失值:无

# 查看Fare(票价)= 0 的生存情况
Fare0 = train[train['Fare']==0]
Fare0Survived = Fare0.groupby(['Survived']).count()[['PassengerId']].reset_index().rename(columns={"PassengerId":"Count"})
plt.figure(figsize=(4,3),dpi=150)
sns.barplot(
    data=Fare0Survived,
    x="Survived",
    y="Count"
)
plt.title('Survived Count Of Fare=0')

# 查看Fare(票价)!= 0 的生存情况
Fare1 = train[train['Fare']!=0]
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(Fare1[Fare1['Survived']==0]['Fare'],color="red",kde=False)
sns.distplot(Fare1[Fare1['Survived']==1]['Fare'],color="blue",kde=False)
plt.title('Survived Count Of Fare!=0')

# 对 Fare 分成三类
# Fare = 0
# Fare <=50
# Fare > 50
train['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in train['Fare']]

# 用柱状图查看不同 FareType 的存活情况
# 可以看出 Fare=0 的乘客中,乘客几乎都死亡
#        Fare <=50 的乘客中,死亡人数大于存活人数
#        Fare > 50 的乘客中,存活人数大于死亡人数
BarPlot(train,"FareType")


# 再对 FareType 进行One-Hot编码处理
train = pd.merge(
    train,
    OneHot(train['FareType']),
    left_index=True,
    right_index=True
)    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

(10)Cabin 船舱
  • 离散变量
  • 有无缺失值:有,缺失值比例高达77%
  • 数据处理:缺失值比例较大,直接删去(下文会删)
(11)Embarked 登船的港口
########################## 11、Embarked 登船的港口 ##########################
# 离散变量
# 有无缺失值:有,缺失值比例很低

# 用柱状图查看各登船的港口的存活情况
# 可以看出 Embarked=S 的乘客中,死亡人数较多
BarPlot(train,"Embarked")

# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
mode = stats.mode(train['Embarked'])[0][0] # 众数
train['Embarked'] = train['Embarked'].fillna(mode)

train = pd.merge(train,OneHot(train['Embarked']),left_index=True,right_index=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

在这里插入图片描述

2.3.4 衍生特征可视化分析与处理

FamilyNumbers 家庭人数
########################## FamilyNumbers 家庭人数 ##########################
# 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
train['FamilyNumbers'] = train['SibSp'] + train['Parch'] + 1

# 用柱状图查看各家庭人数的存活情况
# 可以看出 家庭人数=1 的乘客中,死亡人数较多
#        家庭人数>=5 的乘客中,存活人数较多
BarPlot(train,"FamilyNumbers")

# 新增 FamilyType 字段
# 1 : 单身(Single)        
# 2-4:小家庭(Family_Small)
# >4: 大家庭(Family_Large)
train['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in train['FamilyNumbers']]

# 查看不同 FamilyType 的存活情况
BarPlot(train,"FamilyType")

# 对 FamilyType 进行One-hot编码处理
train = pd.merge(train,OneHot(train['FamilyType']),left_index=True,right_index=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

在这里插入图片描述

在这里插入图片描述

2.3.5 删除冗余字段

drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\
                'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\
                'FamilyNumbers','FamilyType']
train.drop(drop_columns,axis=1,inplace=True)
  • 1
  • 2
  • 3
  • 4

2.3.6 相关性矩阵可视化

  • 采用斯皮尔曼相关系数
corr_df = train.corr(method="spearman")[['Survived']].sort_values(by="Survived",ascending=False)
plt.figure(figsize=(1,8),dpi=100)
sns.heatmap(
    corr_df,
    cmap='Blues',
    center=0,
    vmax=1,
    vmin=-1,
    annot=True,
    annot_kws={'size':10,'weight':'bold', 'color':'red'}
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

在这里插入图片描述

2.4 决策树模型训练

2.4.1 数据标准化(Z-score)

def ZscoreNormalization(x):
    '''
        Z-score 标准化
    '''
    return (x - np.mean(x)) / np.std(x)
    
data = train.drop("Survived",axis=1).agg(ZscoreNormalization)
data['Lable'] = train['Survived']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2.4.2 划分训练集、测试集

  • 按7:3比例划分
x_train, x_test, y_train, y_test = train_test_split(
    data.drop("Lable",axis=1), 
    data['Lable'], 
    test_size = 0.3, 
    random_state = 0
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.4.3 网格寻参与交叉验证

param_grid = {
    'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”
    'splitter' : ['best','random'],   # 划分方式:{“best”, “random”}, default=”best”
    'max_depth' : range(1,6),         # 最大深度
    'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数
    'min_samples_leaf' : range(1,6),  # 叶节点所需的最小样本数
}
clf = DecisionTreeClassifier()               # 初始化
gs = GridSearchCV(clf,param_grid,cv=5)       # 网格搜索与交叉验证
gs.fit(x_train,y_train)                      # 模型训练
print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
print("Best Score: ",gs.best_score_)         # 打印最好分数
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

在这里插入图片描述
注意: 每次运行的结果输出会存在差别。

2.4.4 模型评价

print("\n---------- 模型评价 ----------")
y_pred = gs.predict(x_test)                         # 预测
cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
df_cm = pd.DataFrame(cm)                            # 构建DataFrame
print('Accuracy score:', accuracy_score(y_test, y_pred))                       # 准确率
print('Recall:', recall_score(y_test, y_pred, average='weighted'))             # 召回率
print('F1-score:', f1_score(y_test, y_pred, average='weighted'))               # F1分数
print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在这里插入图片描述

2.4.5 混淆矩阵可视化

plt.figure(dpi=150)

heatmap = sns.heatmap(df_cm, annot=True, fmt='.0f', cmap='Blues')
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=0, ha='right')

plt.title('DecisionTreeClassifier Model Results')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在这里插入图片描述

2.4.6 ROC曲线

y_pred_proba = gs.predict_proba(np.array(x_test))[:,1]
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)

sns.set()

plt.figure(figsize=(5,4),dpi=150)
plt.plot(fpr, tpr)
plt.plot(fpr, fpr, linestyle = '-' , color = 'k')

plt.xlabel('False positive rate')
plt.ylabel('True positive rate')

AU = np.round(roc_auc_score(y_test, y_pred_proba), 2)

plt.title(f'AU: {AU}');

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

在这里插入图片描述

三、完整代码(含对test预测)

  • 含预测,不含可视化
import numpy as np
import pandas as pd
from scipy import stats
 
# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

# 不显示红色警告
import warnings
warnings.filterwarnings('ignore')


def OneHot(x):
    '''
        功能:one-hot 编码
        传入:需要编码的分类变量
        返回:返回编码后的结果,形式为 dataframe
    '''
    # 通过 LabelEncoder 将分类变量打上数值标签 
    lb = LabelEncoder()                             # 初始化
    x_pre = lb.fit_transform(x)                     # 模型拟合
    x_dict = dict([[i,j] for i,j in zip(x,x_pre)])  # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}
    x_num = [[x_dict[i]] for i in x]                # 通过 x_dict 将分类变量转为数值型
    
    # 进行one-hot编码
    enc = OneHotEncoder()                        # 初始化
    enc.fit(x_num)                               # 模型拟合
    array_data = enc.transform(x_num).toarray()  # one-hot 编码后的结果,二维数组形式
    
    # 转成 dataframe 形式
    df = pd.DataFrame(array_data)
    inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值
    # columns 重命名
    if type(x) == pd.Series:
        firs_name = x.name
    else:
        firs_name = ""
    df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]           
    
    return df

def ZscoreNormalization(x):
    '''
        Z-score 标准化
    '''
    return (x - np.mean(x)) / np.std(x)


def DataClean(df,Lable=True):
    '''
        数据预处理函数
    '''
    ########################## 1、Pclass 乘客等级 ##########################
    # 无缺失值,等级变量
    # 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
    df['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in df['Pclass']]
    # 再对 PclassType 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['PclassType']),left_index=True,right_index=True)    
    
    ########################## 2、Name 乘客姓名 ##########################
    # 字符串变量
    # 有无缺失值:无
    # 从乘客姓名中获取头街
    # 姓名中头街字符串与定义头街类别之间的关系
    #     Officer: 政府官员,
    #     RoyaIty: 王室(皇室),
    #     Mr:      已婚男士,
    #     Mrs:     已婚女士,
    #     Miss:    年轻未婚女子,
    #     Master:  有技能的人/教师 
    # 新建字段 Title_Dict 
    Title_Dict = {
        'Mr':'Mr',
        'Mrs':'Mrs', 
        'Miss':'Miss',
        'Master': 'Master', 
        'Don':'Royalty',
        'Rev':'Officer',
        'Dr':')fficer', 
        'Mme':'Mrs',
        'Ms':'Mrs',
        'Major':'Officer', 
        'Lady': 'Royalty',
        'Sir': 'Royalty',
        'Mlle':'Miss', 
        'Col': 'Officer',
        'Capt':'Officer',
        'the Countess': 'Royalty',
        'Jonkheer': 'Royalty',
        'Dona': 'Royalty'
    }
    df['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in df['Name']] # 对Name进行分类
    # 数据进一步处理:将 NameType 分成三类
    # Mr(已婚男士)
    # Mrs(已婚女士)、Miss(年轻未婚女子)
    # 其他
    df['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \
                          for i in df['NameType']]
    # 再对 NameType2 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['NameType2']),left_index=True,right_index=True)   
    
    ########################## 3、Sex 性别 ##########################
    # 分类变量
    # 有无缺失值:无
    # 对 Sex 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['Sex']),left_index=True,right_index=True)
    
    ########################## 4、Age 年龄 ##########################
    # 连续变量
    # 有无缺失值:有,缺失比例19.9%
    # 缺失值用均值填充
    df['Age'] = df['Age'].fillna(df['Age'].mean())
    # 数据处理:将 Age 分成两类,Age<=5、Age>5
    df['AgeType'] = ["Age<=5" if i <= 5 else "Age>5"  for i in df['Age']]
    # 再对 AgeType 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['AgeType']),left_index=True,right_index=True)   

    ########################## 5、SibSp 堂兄弟妹个数 ##########################
    # 无缺失值,等级变量
    # 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
    df['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in df['SibSp']]
    # 再对 SibSpType 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['SibSpType']),left_index=True,right_index=True)    
    
    ########################## 6、Parch 父母与小孩的个数 ##########################
    # 连续变量
    # 有无缺失值:无
    # 数据处理:将 Parch 分成两类,Parch=0、Parch>0
    df['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in df['Parch']]
    # 再对 ParchType 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['ParchType']),left_index=True,right_index=True)    
    
    ########################## 8、Fare 票价 ##########################
    # 连续变量
    # 有无缺失值:无
    # 对 Fare 分成三类
    # Fare = 0
    # Fare <=50
    # Fare > 50
    df['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in df['Fare']]
    # 再对 FareType 进行One-Hot编码处理
    df = pd.merge(df,OneHot(df['FareType']),left_index=True,right_index=True)    
    
    
    ########################## 10、Embarked 登船的港口 ##########################
    # 离散变量
    # 有无缺失值:有,缺失值比例很低
    # 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
    mode = stats.mode(df['Embarked'])[0][0] # 众数
    df['Embarked'] = df['Embarked'].fillna(mode)

    df = pd.merge(df,OneHot(df['Embarked']),left_index=True,right_index=True)
    
    ########################## 11、FamilyNumbers 家庭人数 ##########################
    # 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
    df['FamilyNumbers'] = df['SibSp'] + df['Parch'] + 1
    # 新增 FamilyType 字段
    # 1 : 单身(Single)        
    # 2-4:小家庭(Family_Small)
    # >4: 大家庭(Family_Large)
    df['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in df['FamilyNumbers']]
    # 对 FamilyType 进行One-hot编码处理
    df = pd.merge(df,OneHot(df['FamilyType']),left_index=True,right_index=True)
    
    
    ########################## 删除冗余变量 ##########################
    drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\
                'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\
                'FamilyNumbers','FamilyType']
    df.drop(drop_columns,axis=1,inplace=True)

    ########################## 数据标准化 ##########################
    if Lable == True: # 判断是否是测试集(测试集不含标签)
        data = df.drop("Survived",axis=1).agg(ZscoreNormalization)
        data['Lable'] = df['Survived']
    else:
        data = df.agg(ZscoreNormalization)
    
    return data

def sklearn_DecisionTreeClassifier(data):
    '''
        决策树二分类
    '''
    # 划分训练集、测试集
    x_train, x_test, y_train, y_test = train_test_split(
        data.drop("Lable",axis=1), 
        data['Lable'], 
        test_size = 0.3, 
        random_state = 0
    )
    
    print("\n---------- 模型训练 ----------")
    # 网格寻参
    param_grid = {
        'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”
        'splitter' : ['best','random'],   # 划分方式:{“best”, “random”}, default=”best”
        'max_depth' : range(1,6),         # 最大深度
        'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数
        'min_samples_leaf' : range(1,6),  # 叶节点所需的最小样本数
    }
    clf = DecisionTreeClassifier()               # 初始化
    gs = GridSearchCV(clf,param_grid,cv=5)       # 网格搜索与交叉验证
    gs.fit(x_train,y_train)                      # 模型训练
    print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
    print("Best Score: ",gs.best_score_)         # 打印最好分数
    
    # 模型预测
    print("\n---------- 模型评价 ----------")
    y_pred = gs.predict(x_test)                         # 预测
    cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
    df_cm = pd.DataFrame(cm)                            # 构建DataFrame
    print('Accuracy score:', accuracy_score(y_test, y_pred))                       # 准确率
    print('Recall:', recall_score(y_test, y_pred, average='weighted'))             # 召回率
    print('F1-score:', f1_score(y_test, y_pred, average='weighted'))               # F1分数
    print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度
    
    return gs.best_estimator_ # 返回最好的训练模型


if __name__ == "__main__":
    train = pd.read_csv("train.csv")
    test  = pd.read_csv("test.csv")
    
    print("\n---------- 数据预处理 ----------")
    train_data = DataClean(train)           
    test_data = DataClean(test,Lable=False) 
    
    # 决策树二分类
    best_estimator = sklearn_DecisionTreeClassifier(train_data) 
    
    # 预测
    y_pred = best_estimator.predict(test_data) 
    # 输出预测结果
    result = test[['PassengerId']]
    result['Survived'] = y_pred
    result.to_csv("Titanic Results.csv",index=False)
    
    print("\n程序运行完成")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244

在这里插入图片描述

四、Kaggle 得分

  • 得分:0.77511
  • 排名:7651
    在这里插入图片描述

参考
1、Kaggle泰坦尼克号比赛项目详解
2、机器学习实战——kaggle 泰坦尼克号生存预测——六种算法模型实现与比较

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/166254
推荐阅读
相关标签
  

闽ICP备14008679号