"You jump, I jump" is the famous line from the classic romance *Titanic*. As Rose stands at the bow about to throw herself into the sea, Jack, trying to save her, delivers the line "You jump, I jump." When a strange man is willing to die for a strange woman for no reason at all, the heroine, equally without reason, falls for him.
Of course, that has little to do with this tutorial. Here we will use AI to predict Jack's and Rose's chances of survival. I kept posting through the National Day holiday, which was no small feat; if you need the dataset, feel free to message me or join the study group. Thanks for your support!
import pandas as pd

# Load the Titanic dataset (reading an .xls file requires the xlrd package)
df = pd.read_excel('titanic3.xls')
df.describe()
 | pclass | survived | age | sibsp | parch | fare | body |
---|---|---|---|---|---|---|---|
count | 1309.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 | 121.000000 |
mean | 2.294882 | 0.381971 | 29.881135 | 0.498854 | 0.385027 | 33.295479 | 160.809917 |
std | 0.837836 | 0.486055 | 14.413500 | 1.041658 | 0.865560 | 51.758668 | 97.696922 |
min | 1.000000 | 0.000000 | 0.166700 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 2.000000 | 0.000000 | 21.000000 | 0.000000 | 0.000000 | 7.895800 | 72.000000 |
50% | 3.000000 | 0.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 | 155.000000 |
75% | 3.000000 | 1.000000 | 39.000000 | 1.000000 | 0.000000 | 31.275000 | 256.000000 |
max | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 | 328.000000 |
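By default, describe() only summarizes the numeric columns. If you also want counts, unique values, and top frequencies for the text columns (name, sex, ticket, cabin, and so on), pandas accepts an include='all' argument; this is an optional extra check, not part of the original flow:

# Optional: summarize all columns, including object (string) dtypes
df.describe(include='all')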
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   pclass     1309 non-null   int64
 1   survived   1309 non-null   int64
 2   name       1309 non-null   object
 3   sex        1309 non-null   object
 4   age        1046 non-null   float64
 5   sibsp      1309 non-null   int64
 6   parch      1309 non-null   int64
 7   ticket     1309 non-null   object
 8   fare       1308 non-null   float64
 9   cabin      295 non-null    object
 10  embarked   1307 non-null   object
 11  boat       486 non-null    object
 12  body       121 non-null    float64
 13  home.dest  745 non-null    object
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB
df.head()
 | pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
# Select the columns we need
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols]  # keep an unmodified copy; it is shuffled for training later
df = df[selected_cols]  # indexing a DataFrame with a list of names selects columns
df.head()
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | Allen, Miss. Elisabeth Walton | 1 | female | 29.0000 | 0 | 0 | 211.3375 | S |
1 | 1 | Allison, Master. Hudson Trevor | 1 | male | 0.9167 | 1 | 2 | 151.5500 | S |
2 | 0 | Allison, Miss. Helen Loraine | 1 | female | 2.0000 | 1 | 2 | 151.5500 | S |
3 | 0 | Allison, Mr. Hudson Joshua Creighton | 1 | male | 30.0000 | 1 | 2 | 151.5500 | S |
4 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | 1 | female | 25.0000 | 1 | 2 | 151.5500 | S |
# Find which columns contain null values
df.isnull().any()
survived False
name False
pclass False
sex False
age True
sibsp False
parch False
fare True
embarked True
dtype: bool
# Count how many nulls each column has
df.isnull().sum()
survived 0
name 0
pclass 0
sex 0
age 263
sibsp 0
parch 0
fare 1
embarked 2
dtype: int64
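Raw counts are easier to judge as proportions. Since the mean of a boolean mask is the fraction of True values, one extra line (my addition, not in the original) gives the missing-value ratio per column:

# Fraction of missing values per column (about 20% of age is missing)
df.isnull().mean()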
# Locate the missing values
df[df.isnull().values==True]
survived | name | pclass | sex | age | sibsp | parch | fare | embarked | |
---|---|---|---|---|---|---|---|---|---|
15 | 0 | Baumann, Mr. John D | 1 | male | NaN | 0 | 0 | 25.9250 | S |
37 | 1 | Bradley, Mr. George ("George Arthur Brayton") | 1 | male | NaN | 0 | 0 | 26.5500 | S |
40 | 0 | Brewe, Dr. Arthur Jackson | 1 | male | NaN | 0 | 0 | 39.6000 | C |
46 | 0 | Cairns, Mr. Alexander | 1 | male | NaN | 0 | 0 | 31.0000 | S |
59 | 1 | Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev... | 1 | female | NaN | 0 | 0 | 27.7208 | C |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1293 | 0 | Williams, Mr. Howard Hugh "Harry" | 3 | male | NaN | 0 | 0 | 8.0500 | S |
1297 | 0 | Wiseman, Mr. Phillippe | 3 | male | NaN | 0 | 0 | 7.2500 | S |
1302 | 0 | Yousif, Mr. Wazli | 3 | male | NaN | 0 | 0 | 7.2250 | C |
1303 | 0 | Yousseff, Mr. Gerious | 3 | male | NaN | 0 | 0 | 14.4583 | C |
1305 | 0 | Zabour, Miss. Thamine | 3 | female | NaN | 1 | 0 | 14.4542 | C |
266 rows × 9 columns
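Note that this boolean-array trick repeats a row once for every NaN it contains, which is why 266 rows come back for the 263 + 1 + 2 missing values counted above. A more idiomatic filter (a suggested alternative, not the original code) lists each incomplete row exactly once:

# Rows with at least one missing value, each listed once
df[df.isnull().any(axis=1)]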
# Replace missing age values with the mean
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any()  # returns True if any null value remains
False
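Filling with the overall mean is the simplest option. A common refinement, offered here only as a hedged alternative to the fillna(age_mean) step above, is to impute age from the median of each (pclass, sex) group, since first-class passengers tended to be older:

# Alternative: impute age with the median of its (pclass, sex) group
df['age'] = df.groupby(['pclass','sex'])['age'].transform(lambda s: s.fillna(s.median()))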
# Replace missing fare values with the mean
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)
# Fill the missing embarked records with 'S', the most common port
df['embarked'] = df['embarked'].fillna('S')
df.isnull().any()
survived False
name False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
# Convert encodings
# Convert sex from strings to numeric codes
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
# Convert embarked from port letters to numeric codes
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
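A side note: mapping the ports to 0/1/2 imposes an ordering that C, Q, and S do not actually have. One-hot encoding avoids this; here is a sketch of that alternative (the rest of this tutorial keeps the integer codes and the 7-feature layout they produce):

# Alternative: one-hot encode embarked instead of using ordinal codes
df = pd.get_dummies(df, columns=['embarked'], prefix='embarked')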
# Drop the name field
df = df.drop(['name'],axis=1)  # axis=0 means rows, axis=1 means columns
df.head()
 | survived | pclass | sex | age | sibsp | parch | fare | embarked |
---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 0 | 29.0000 | 0 | 0 | 211.3375 | 2 |
1 | 1 | 1 | 1 | 0.9167 | 1 | 2 | 151.5500 | 2 |
2 | 0 | 1 | 0 | 2.0000 | 1 | 2 | 151.5500 | 2 |
3 | 0 | 1 | 1 | 30.0000 | 1 | 2 | 151.5500 | 2 |
4 | 0 | 1 | 0 | 25.0000 | 1 | 2 | 151.5500 | 2 |
# Separate the features from the labels
data = df.values
# The last seven columns are the features
features = data[:,1:]  # slicing an ndarray indexes rows first; a DataFrame selects columns
# Column zero is the label
labels = data[:,0]
labels.shape
(1309,)
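Splitting by integer position works but silently depends on 'survived' being the first column. An equivalent, name-based split (a sketch of the alternative, not the original code):

# Equivalent split by column name rather than position
features = df.drop('survived', axis=1).values
labels = df['survived'].values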
def prepare_data(df):
    # Drop the name field
    df = df.drop(['name'],axis=1)
    # Replace missing age values with the mean
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)
    # Replace missing fare values with the mean
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)
    # Fill the missing embarked records with 'S'
    df['embarked'] = df['embarked'].fillna('S')
    # Convert sex from strings to numeric codes
    df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
    # Convert embarked from port letters to numeric codes
    df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
    print(df.isnull().any())
    # Separate the features from the labels
    data = df.values
    # The last seven columns are the features
    features = data[:,1:]
    # Column zero is the label
    labels = data[:,0]
    return features,labels
shuffle_df = df_selected.sample(frac=1)  # shuffle the rows before training; frac is the fraction to sample, and df_selected itself is left unchanged
x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
((1309, 7), (1309,))
shuffle_df.head()
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
---|---|---|---|---|---|---|---|---|---|
58 | 0 | Case, Mr. Howard Brown | 1 | male | 49.0 | 0 | 0 | 26.0000 | S |
666 | 0 | Barbara, Mrs. (Catherine David) | 3 | female | 45.0 | 0 | 1 | 14.4542 | C |
781 | 0 | Drazenoic, Mr. Jozef | 3 | male | 33.0 | 0 | 0 | 7.8958 | C |
480 | 0 | Laroche, Mr. Joseph Philippe Lemercier | 2 | male | 25.0 | 1 | 2 | 41.5792 | C |
459 | 0 | Jacobsohn, Mr. Sidney Samuel | 2 | male | 42.0 | 1 | 0 | 27.0000 | S |
test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
# Training set
x_train = x_data[:train_num]
y_train = y_data[:train_num]
# Test set
x_test = x_data[train_num:]
y_test = y_data[train_num:]
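Because the rows were already shuffled, this head/tail split is sound. The same thing can be done in one call with scikit-learn, shown here only as an equivalent alternative:

# Equivalent split with scikit-learn (shuffles internally by default)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2)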
from sklearn import preprocessing
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train)  # fit the scaler on the training features and normalize them
x_test = minmax_scale.transform(x_test)  # reuse the training min/max; re-fitting on the test set would scale the two sets inconsistently
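The fit/transform distinction is worth a tiny demonstration: fit_transform learns a new min and max from whatever it is given, while transform reuses what was already learned. The values below are illustrative, not from the dataset:

import numpy as np
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
train = np.array([[0.0], [10.0]])
test = np.array([[5.0], [20.0]])
scaler.fit_transform(train)  # learns min=0, max=10 -> [[0.0], [1.0]]
scaler.transform(test)       # same scale applied   -> [[0.5], [2.0]]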
import tensorflow as tf
tf.__version__
'2.6.0'
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=64,
                                use_bias=True,
                                activation='relu',
                                input_dim=7,  # equivalently input_shape=(7,)
                                bias_initializer='zeros',
                                kernel_initializer='normal'))
model.add(tf.keras.layers.Dropout(rate=0.2))  # dropout layer: rate is the fraction of the previous layer's units to drop, to curb overfitting
model.add(tf.keras.layers.Dense(units=32,
                                activation='sigmoid',
                                input_shape=(64,),  # equivalently input_dim=64; optional, since Keras infers it from the previous layer
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.add(tf.keras.layers.Dropout(rate=0.2))  # second dropout layer, same purpose
model.add(tf.keras.layers.Dense(units=1,
                                activation='sigmoid',
                                input_dim=32,  # equivalently input_shape=(32,)
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.summary()
Model: "sequential_23" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_68 (Dense) (None, 64) 512 _________________________________________________________________ dropout_6 (Dropout) (None, 64) 0 _________________________________________________________________ dense_69 (Dense) (None, 32) 2080 _________________________________________________________________ dropout_7 (Dropout) (None, 32) 0 _________________________________________________________________ dense_70 (Dense) (None, 1) 33 ================================================================= Total params: 2,625 Trainable params: 2,625 Non-trainable params: 0 _________________________________________________________________
# Define the training configuration
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),loss='binary_crossentropy',metrics=['accuracy'])
# Set the training parameters
train_epochs = 100
batch_size = 40
train_history = model.fit(x=x_train,  # training features
                          y=y_train,  # training labels
                          validation_split=0.2,  # fraction of the training data held out for validation
                          epochs=train_epochs,  # number of epochs
                          batch_size=batch_size,  # batch size
                          verbose=2)  # logging mode: one line per epoch
Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
...
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095
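Validation accuracy plateaus around 0.80 long before epoch 100, so most of the later epochs add little. Keras offers an EarlyStopping callback for exactly this; an optional refinement, not part of the original training run:

# Optional: stop once val_loss has not improved for 10 epochs and keep the best weights
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                              restore_best_weights=True)
train_history = model.fit(x=x_train, y=y_train, validation_split=0.2,
                          epochs=train_epochs, batch_size=batch_size,
                          verbose=2, callbacks=[early_stop])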
# Visualize the training history
import matplotlib.pyplot as plt
def show_train_history(train_history,train_metric,validation_metric):
    plt.plot(train_history[train_metric])
    plt.plot(train_history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train','validation'],loc='upper left')
    plt.show()
show_train_history(train_history.history,'loss','val_loss')
show_train_history(train_history.history,'accuracy','val_accuracy')
loss,acc = model.evaluate(x_test,y_test)
9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435
loss,acc
(0.3702643811702728, 0.8435114622116089)
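An overall accuracy of about 0.84 looks good, but only around 38% of the passengers survived, so per-class numbers are worth checking. A quick sketch (an extra check I am adding, not in the original) that thresholds the predicted probabilities at 0.5:

# Optional: per-class precision and recall on the test set
from sklearn.metrics import classification_report
y_test_pred = (model.predict(x_test) > 0.5).astype(int)
print(classification_report(y_test, y_test_pred))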
# Build records for Jack and Rose; the survived entries are placeholders that prepare_data expects
Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']
x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S |
1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S |
x_pre_features,y = prepare_data(x_pre)
from sklearn import preprocessing
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_pre_features = minmax_scale.fit_transform(x_pre_features)  # normalize the features; note that fitting on just these two rows scales them against each other, so reusing the scaler fitted on the training set would be more faithful
y_pre = model.predict(x_pre_features)
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
x_pre.insert(len(x_pre.columns),'surv_probability',y_pre)
x_pre
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked | surv_probability |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S | 0.058498 |
1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S | 0.975978 |
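The model outputs probabilities rather than classes. With the usual 0.5 threshold, Jack (about 0.06) is predicted not to survive and Rose (about 0.98) to survive, matching the film. A final sketch for turning the probabilities into hard 0/1 predictions:

# Convert survival probabilities into 0/1 predictions at a 0.5 threshold
x_pre['pred_survived'] = (y_pre > 0.5).astype(int).ravel()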