当前位置:   article > 正文

机器学习实战:Kaggle泰坦尼克号生存预测 利用决策树进行预测_kaggle决策树案例

kaggle决策树案例

决策树分类的应用场景非常广泛,在各行各业都有应用,比如在金融行业可以用决策树做贷款风险评估,医疗行业可以用决策树生成辅助诊断,电商行业可以用决策树对销售额进行预测等。
我们利用 sklearn 工具中的决策树分类器解决一个实际的问题:泰坦尼克号乘客的生存预测。
问题描述
泰坦尼克海难是著名的十大灾难之一,究竟多少人遇难,各方统计的结果不一。项目全部内容可以到我的github下载:https://github.com/Richard88888/Titanic_competition
具体流程分为以下几个步骤:在这里插入图片描述

  1. 准备阶段:我们首先需要对训练集、测试集的数据进行探索,分析数据质量,并对数据进行清洗,然后通过特征选择对数据进行降维,方便后续分类运算;
  2. 分类阶段:首先通过训练集的特征矩阵、分类结果得到决策树分类器,然后将分类器应用于测试集。然后我们对决策树分类器的准确性进行分析,并对决策树模型进行可视化。
    接下来我们进行一一讲解。

首先加载数据

# 加载数据
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')
  • 1
  • 2
  • 3
  • 第一步:数据探索
import pandas as pd
print(train_data.info())  # 了解数据表的基本情况:行数、列数、每列的数据类型、数据完整度
print('-'*30)
print(train_data.describe())  # 了解数据表的统计情况:总数、平均值、标准差、最小值、最大值等
print('-'*30)
print(train_data.describe(include=['O']))  # 查看字符串类型(非数字)的整体情况
print('-'*30)
print(train_data.head())  # 查看前几行数据(默认是前5行)
print('-'*30)
print(train_data.tail())  # 查看后几行数据(默认是最后5行)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

运行结果


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.6+ KB
None
------------------------------
       PassengerId    Survived     ...           Parch        Fare
count   891.000000  891.000000     ...      891.000000  891.000000
mean    446.000000    0.383838     ...        0.381594   32.204208
std     257.353842    0.486592     ...        0.806057   49.693429
min       1.000000    0.000000     ...        0.000000    0.000000
25%     223.500000    0.000000     ...        0.000000    7.910400
50%     446.000000    0.000000     ...        0.000000   14.454200
75%     668.500000    1.000000     ...        0.000000   31.000000
max     891.000000    1.000000     ...        6.000000  512.329200

[8 rows x 7 columns]
------------------------------
                                          Name   Sex   ...       Cabin Embarked
count                                      891   891   ...         204      889
unique                                     891     2   ...         147        3
top     Peter, Mrs. Catherine (Catherine Rizk)  male   ...     B96 B98        S
freq                                         1   577   ...           4      644

[4 rows x 5 columns]
------------------------------
   PassengerId  Survived  Pclass    ...        Fare Cabin  Embarked
0            1         0       3    ...      7.2500   NaN         S
1            2         1       1    ...     71.2833   C85         C
2            3         1       3    ...      7.9250   NaN         S
3            4         1       1    ...     53.1000  C123         S
4            5         0       3    ...      8.0500   NaN         S

[5 rows x 12 columns]
------------------------------
     PassengerId  Survived  Pclass    ...      Fare Cabin  Embarked
886          887         0       2    ...     13.00   NaN         S
887          888         1       1    ...     30.00   B42         S
888          889         0       3    ...     23.45   NaN         S
889          890         1       1    ...     30.00  C148         C
890          891         0       3    ...      7.75   NaN         Q

[5 rows x 12 columns]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 第二步:数据清洗
# 使用平均年龄来填充年龄中的nan值
train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
test_data['Age'].fillna(test_data['Age'].mean(),inplace=True)
# 使用票价的均值填充票价中的nan值
train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)
test_data['Fare'].fillna(test_data['Fare'].mean(),inplace=True)
print('-'*30)
print(train_data['Embarked'].value_counts())
# 使用登陆最多的港口来填充港口的nan值
train_data['Embarked'].fillna('S',inplace=True)
test_data['Embarked'].fillna('S',inplace=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

运行结果


S    644
C    168
Q     77
  • 1
  • 2
  • 3
  • 4
  • 第三步:特征选择
from sklearn.feature_extraction import DictVectorizer
# PassengerId 为乘客编号,对分类没有作用,可以放弃;Name 为乘客姓名,对分类没有作用,可以放弃;Cabin 字段缺失值太多,可以放弃;Ticket 字段为船票号码,杂乱无章且无规律,可以放弃。
# 其余的字段包括:Pclass、Sex、Age、SibSp、Parch 和 Fare,这些属性分别表示了乘客的船票等级、性别、年龄、亲戚数量以及船票价格,可能会和乘客的生存预测分类有关系。
features = ['Pclass','Sex','Age','SibSp','Parch','Fare','Embarked']
train_features = train_data[features]
train_labels = train_data['Survived']
test_features = test_data[features]
# 将特征值中字符串表示为0.1表示
dvec = DictVectorizer(sparse=False)
# fit_transform() 将特征向量转化为特征值矩阵
train_features = dvec.fit_transform(train_features.to_dict(orient='record'))
# 查看dvec转化后的特征属性
print(dvec.feature_names_)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

运行结果

['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']
  • 1
  • 第四步:构建决策树模型
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(criterion='entropy')
# 决策树训练
clf.fit(train_features,train_labels)
  • 1
  • 2
  • 3
  • 4
  • 第五步:模型预测
test_features = dvec.transform(test_features.to_dict(orient='record'))
# 决策树预测
pred_labels = clf.predict(test_features).astype(np.int64)
print(pred_labels)
  • 1
  • 2
  • 3
  • 4

运行结果

[0 0 0 1 1 0 0 0 0 0 0 0 1 0 1 1 0 1 1 0 1 1 1 0 1 0 1 1 1 0 0 0 1 0 1 1 0
 0 0 1 0 0 0 1 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0
 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 1 1 1 1 1 0 0 0 1 0 0 0 1 0 0
 0 1 0 1 0 0 1 1 1 1 0 1 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0
 0 0 1 0 0 1 0 0 1 0 0 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 1 1 1 1 1 0 0 1 0 1
 0 1 0 0 0 0 0 1 1 1 0 1 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 0 1 0
 1 1 1 0 0 0 0 0 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 1
 0 0 0 1 1 0 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 0 0 1 1 1 0 0 0 0 0 0 1 1 0 1 0 0 0 1 0 0
 1 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 1 0 1 1 1 1 0 0 1 1 1
 0 1 0 0 1 1 0 0 0 0 0 0 0 1 0 1 0 0 1 0 1 1 1 0 0 1 0 1 0 0 1 0 1 0 0 1 0
 0 0 0 0 1 0 0 1 0 0 0]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 结果写入submission.csv文件,提交kaggle
id = test_data['PassengerId']
sub = {'PassengerId': id, 'Survived': pred_labels}
submission = pd.DataFrame(sub)
submission.to_csv("submission.csv", index=False)
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/674687
推荐阅读
相关标签
  

闽ICP备14008679号