Decision tree classification is used very widely across industries: in finance it can be applied to loan risk assessment, in healthcare to support diagnosis, and in e-commerce to forecast sales, to name a few examples.
Here we use the decision tree classifier from sklearn to solve a practical problem: predicting the survival of Titanic passengers.
Problem description
The sinking of the Titanic is one of the ten best-known disasters in history, and estimates of how many people died vary from source to source. The complete project can be downloaded from my GitHub: https://github.com/Richard88888/Titanic_competition
The workflow consists of the following steps.
First, load the data.
import pandas as pd

# Load the data
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')

print(train_data.info())           # basic information: number of rows/columns, dtype and completeness of each column
print('-'*30)
print(train_data.describe())       # summary statistics: count, mean, std, min, max, etc.
print('-'*30)
print(train_data.describe(include=['O']))  # overview of the string (non-numeric) columns
print('-'*30)
print(train_data.head())           # first few rows (5 by default)
print('-'*30)
print(train_data.tail())           # last few rows (5 by default)
Output
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.6+ KB
None
------------------------------
       PassengerId    Survived  ...       Parch        Fare
count   891.000000  891.000000  ...  891.000000  891.000000
mean    446.000000    0.383838  ...    0.381594   32.204208
std     257.353842    0.486592  ...    0.806057   49.693429
min       1.000000    0.000000  ...    0.000000    0.000000
25%     223.500000    0.000000  ...    0.000000    7.910400
50%     446.000000    0.000000  ...    0.000000   14.454200
75%     668.500000    1.000000  ...    0.000000   31.000000
max     891.000000    1.000000  ...    6.000000  512.329200

[8 rows x 7 columns]
------------------------------
                                          Name   Sex  ...    Cabin Embarked
count                                      891   891  ...      204      889
unique                                     891     2  ...      147        3
top     Peter, Mrs. Catherine (Catherine Rizk)  male  ...  B96 B98        S
freq                                         1   577  ...        4      644

[4 rows x 5 columns]
------------------------------
   PassengerId  Survived  Pclass  ...     Fare Cabin Embarked
0            1         0       3  ...   7.2500   NaN        S
1            2         1       1  ...  71.2833   C85        C
2            3         1       3  ...   7.9250   NaN        S
3            4         1       1  ...  53.1000  C123        S
4            5         0       3  ...   8.0500   NaN        S

[5 rows x 12 columns]
------------------------------
     PassengerId  Survived  Pclass  ...   Fare Cabin Embarked
886          887         0       2  ...  13.00   NaN        S
887          888         1       1  ...  30.00   B42        S
888          889         0       3  ...  23.45   NaN        S
889          890         1       1  ...  30.00  C148        C
890          891         0       3  ...   7.75   NaN        Q

[5 rows x 12 columns]
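From the info() output you can already see that Age, Cabin and Embarked contain missing values. As a small supplementary check that is not in the original code, the missing-value ratio of each column can be computed directly from the same train_data frame:

# Supplementary sketch: ratio of missing values per column in train_data
missing_ratio = train_data.isnull().sum() / len(train_data)
print(missing_ratio.sort_values(ascending=False).head())
# Cabin (~77%), Age (~20%) and Embarked (~0.2%) contain NaNs; the other columns are complete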
# Fill missing Age values with the mean age
train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
test_data['Age'].fillna(test_data['Age'].mean(), inplace=True)
# Fill missing Fare values with the mean fare
train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)
test_data['Fare'].fillna(test_data['Fare'].mean(), inplace=True)
print('-'*30)
print(train_data['Embarked'].value_counts())
# Fill missing Embarked values with the most common port of embarkation
train_data['Embarked'].fillna('S', inplace=True)
test_data['Embarked'].fillna('S', inplace=True)
Output
S 644
C 168
Q 77
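As an alternative to hardcoding 'S', the fill value could be derived from the data itself. The sketch below is an illustrative variation on the code above, not part of the original script:

# Alternative sketch: look up the most frequent port instead of hardcoding 'S'
most_common_port = train_data['Embarked'].mode()[0]   # evaluates to 'S' for this dataset
train_data['Embarked'].fillna(most_common_port, inplace=True)
test_data['Embarked'].fillna(most_common_port, inplace=True)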
from sklearn.feature_extraction import DictVectorizer

# PassengerId is just a passenger index and does not help classification, so drop it;
# Name does not help classification, so drop it;
# Cabin has too many missing values, so drop it;
# Ticket numbers are unstructured with no obvious pattern, so drop them.
# The remaining fields -- Pclass, Sex, Age, SibSp, Parch, Fare and Embarked -- describe
# ticket class, sex, age, number of relatives aboard, fare and port of embarkation,
# and may be related to whether a passenger survived.
features = ['Pclass','Sex','Age','SibSp','Parch','Fare','Embarked']
train_features = train_data[features]
train_labels = train_data['Survived']
test_features = test_data[features]
# Convert string feature values into a 0/1 (one-hot) representation
dvec = DictVectorizer(sparse=False)
# fit_transform() turns the list of feature dicts into a feature-value matrix
train_features = dvec.fit_transform(train_features.to_dict(orient='records'))
# Inspect the feature names produced by dvec
print(dvec.feature_names_)
Output
['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']
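For comparison, the same one-hot encoding of the string columns (Sex and Embarked) could also be produced with pandas directly. This is only an illustrative alternative to DictVectorizer, reusing the features list defined above:

# Alternative sketch: one-hot encode with pandas instead of DictVectorizer
train_features_df = pd.get_dummies(train_data[features])
print(train_features_df.columns.tolist())
# Produces columns such as Sex_female, Sex_male, Embarked_C, Embarked_Q, Embarked_S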
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Build an ID3-style tree using information gain (entropy) as the split criterion
clf = DecisionTreeClassifier(criterion='entropy')
# Train the decision tree
clf.fit(train_features, train_labels)
test_features = dvec.transform(test_features.to_dict(orient='records'))
# Predict with the decision tree
pred_labels = clf.predict(test_features).astype(np.int64)
print(pred_labels)
Output
[0 0 0 1 1 0 0 0 0 0 0 0 1 0 1 1 0 1 1 0 1 1 1 0 1 0 1 1 1 0 0 0 1 0 1 1 0
0 0 1 0 0 0 1 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0
1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 1 1 1 1 1 0 0 0 1 0 0 0 1 0 0
0 1 0 1 0 0 1 1 1 1 0 1 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0
0 0 1 0 0 1 0 0 1 0 0 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 1 1 1 1 1 0 0 1 0 1
0 1 0 0 0 0 0 1 1 1 0 1 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 0 1 0
1 1 1 0 0 0 0 0 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 1
0 0 0 1 1 0 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 0 0 1 1 1 0 0 0 0 0 0 1 1 0 1 0 0 0 1 0 0
1 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 1 0 1 1 1 1 0 0 1 1 1
0 1 0 0 1 1 0 0 0 0 0 0 0 1 0 1 0 0 1 0 1 1 1 0 0 1 0 1 0 0 1 0 1 0 0 1 0
0 0 0 0 1 0 0 1 0 0 0]
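Since the Kaggle test set does not include Survived labels, these predictions cannot be scored locally. A common way to estimate the model's accuracy is K-fold cross-validation on the training data; the following is a supplementary sketch, not part of the original script:

import numpy as np
from sklearn.model_selection import cross_val_score

# Supplementary sketch: estimate accuracy with 10-fold cross-validation on the training set
scores = cross_val_score(clf, train_features, train_labels, cv=10)
print('cross-validation accuracy: %.4lf' % np.mean(scores))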
# Build the submission file in the format Kaggle expects
passenger_id = test_data['PassengerId']
sub = {'PassengerId': passenger_id, 'Survived': pred_labels}
submission = pd.DataFrame(sub)
submission.to_csv("submission.csv", index=False)
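Finally, if you want to inspect what the trained tree actually learned, sklearn (0.21 and later) can dump the decision rules as text. This supplementary sketch reuses the feature names from the DictVectorizer above; note that an unpruned entropy tree on this data can be quite deep:

from sklearn.tree import export_text

# Supplementary sketch: print the learned decision rules with readable feature names
rules = export_text(clf, feature_names=dvec.feature_names_)
print(rules)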