
TensorFlow from Beginner to Mastery (8): Predicting Titanic Passenger Survival


"You jump, I jump" is a famous line from the classic romance Titanic. As Rose is about to throw herself into the sea from the ship's railing, Jack, standing beside her, tries to save her with those now-classic words. When a complete stranger is willing to die for a woman he has never met, for no reason at all, she begins, just as inexplicably, to fall in love with him.
Of course, that has little to do with this tutorial. Here we will use a neural network to predict Jack's and Rose's chances of survival. Keeping the series going through the National Day holiday was not easy; if you need the dataset, message me or join the study group. Thanks for your support!

I. The Dataset

1. Reading the dataset

import pandas as pd

df = pd.read_excel('titanic3.xls')
df.describe()
        pclass       survived     age          sibsp        parch        fare         body
count   1309.000000  1309.000000  1046.000000  1309.000000  1309.000000  1308.000000  121.000000
mean    2.294882     0.381971     29.881135    0.498854     0.385027     33.295479    160.809917
std     0.837836     0.486055     14.413500    1.041658     0.865560     51.758668    97.696922
min     1.000000     0.000000     0.166700     0.000000     0.000000     0.000000     1.000000
25%     2.000000     0.000000     21.000000    0.000000     0.000000     7.895800     72.000000
50%     3.000000     0.000000     28.000000    0.000000     0.000000     14.454200    155.000000
75%     3.000000     1.000000     39.000000    1.000000     0.000000     31.275000    256.000000
max     3.000000     1.000000     80.000000    8.000000     9.000000     512.329200   328.000000
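A practical side note (version-dependent, not from the original post): with newer pandas releases, reading the legacy .xls format requires the optional xlrd package, so the call above may need it installed first. A minimal sketch:

# Hedged sketch: newer pandas versions need the optional xlrd package to read
# legacy .xls files (pip install xlrd); .xlsx files use openpyxl instead.
import pandas as pd
df = pd.read_excel('titanic3.xls')  # same call as above; the engine is inferred from the file extension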
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   pclass     1309 non-null   int64  
 1   survived   1309 non-null   int64  
 2   name       1309 non-null   object 
 3   sex        1309 non-null   object 
 4   age        1046 non-null   float64
 5   sibsp      1309 non-null   int64  
 6   parch      1309 non-null   int64  
 7   ticket     1309 non-null   object 
 8   fare       1308 non-null   float64
 9   cabin      295 non-null    object 
 10  embarked   1307 non-null   object 
 11  boat       486 non-null    object 
 12  body       121 non-null    float64
 13  home.dest  745 non-null    object 
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB
df.head()
   pclass  survived  name                                             sex     age      sibsp  parch  ticket  fare      cabin    embarked  boat  body   home.dest
0  1       1         Allen, Miss. Elisabeth Walton                    female  29.0000  0      0      24160   211.3375  B5       S         2     NaN    St Louis, MO
1  1       1         Allison, Master. Hudson Trevor                   male    0.9167   1      2      113781  151.5500  C22 C26  S         11    NaN    Montreal, PQ / Chesterville, ON
2  1       0         Allison, Miss. Helen Loraine                     female  2.0000   1      2      113781  151.5500  C22 C26  S         NaN   NaN    Montreal, PQ / Chesterville, ON
3  1       0         Allison, Mr. Hudson Joshua Creighton             male    30.0000  1      2      113781  151.5500  C22 C26  S         NaN   135.0  Montreal, PQ / Chesterville, ON
4  1       0         Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25.0000  1      2      113781  151.5500  C22 C26  S         NaN   NaN    Montreal, PQ / Chesterville, ON

2. Processing the dataset

  • Extract the columns we need
  • Handle missing values
  • Convert categorical values to numeric codes
  • Drop the name column
# Select the columns we need
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols]
df = df[selected_cols] # indexing a DataFrame with a list of names selects columns
df.head()
   survived  name                                             pclass  sex     age      sibsp  parch  fare      embarked
0  1         Allen, Miss. Elisabeth Walton                    1       female  29.0000  0      0      211.3375  S
1  1         Allison, Master. Hudson Trevor                   1       male    0.9167   1      2      151.5500  S
2  0         Allison, Miss. Helen Loraine                     1       female  2.0000   1      2      151.5500  S
3  0         Allison, Mr. Hudson Joshua Creighton             1       male    30.0000  1      2      151.5500  S
4  0         Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  1       female  25.0000  1      2      151.5500  S
# Find the columns that contain null values
df.isnull().any()
survived    False
name        False
pclass      False
sex         False
age          True
sibsp       False
parch       False
fare         True
embarked     True
dtype: bool
# Count how many null values each column has
df.isnull().sum()
survived      0
name          0
pclass        0
sex           0
age         263
sibsp         0
parch         0
fare          1
embarked      2
dtype: int64
# Locate the rows that contain missing values
df[df.isnull().values==True]
      survived  name                                               pclass  sex     age  sibsp  parch  fare     embarked
15    0         Baumann, Mr. John D                                1       male    NaN  0      0      25.9250  S
37    1         Bradley, Mr. George ("George Arthur Brayton")      1       male    NaN  0      0      26.5500  S
40    0         Brewe, Dr. Arthur Jackson                          1       male    NaN  0      0      39.6000  C
46    0         Cairns, Mr. Alexander                              1       male    NaN  0      0      31.0000  S
59    1         Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev...  1       female  NaN  0      0      27.7208  C
...   ...       ...                                                ...     ...     ...  ...    ...    ...      ...
1293  0         Williams, Mr. Howard Hugh "Harry"                  3       male    NaN  0      0      8.0500   S
1297  0         Wiseman, Mr. Phillippe                             3       male    NaN  0      0      7.2500   S
1302  0         Yousif, Mr. Wazli                                  3       male    NaN  0      0      7.2250   C
1303  0         Yousseff, Mr. Gerious                              3       male    NaN  0      0      14.4583  C
1305  0         Zabour, Miss. Thamine                              3       female  NaN  1      0      14.4542  C

266 rows × 9 columns

# Replace missing age values with the mean age
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any() # returns True if any value is still null
False
# Replace missing fare values with the mean fare
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)

# Fill the missing embarked records with 'S'
df['embarked'] = df['embarked'].fillna('S')
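As an aside, a hedged alternative sketch (not used in the rest of this tutorial): instead of one global mean, missing ages could be filled with the median age of passengers of the same class and sex, which is usually a slightly better estimate:

# Optional alternative sketch: fill missing ages with the per-(pclass, sex) median
# instead of the single global mean used above.
df['age'] = df.groupby(['pclass', 'sex'])['age'].transform(lambda s: s.fillna(s.median()))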
df.isnull().any()
survived    False
name        False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool
# Convert categorical values to numeric codes
# sex: encode strings as integers (female = 0, male = 1)
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
# embarked: encode the port letter as an integer (C = 0, Q = 1, S = 2)
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
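A design note with a hedged sketch: embarked is a nominal category, so mapping it to 0/1/2 implies an ordering between the ports that does not really exist. One-hot encoding avoids this, at the cost of adding columns and changing the network's input size, which is why this tutorial keeps the simpler integer codes:

# Optional alternative sketch: one-hot encode embarked (applied to the string column,
# i.e. before the .map() call above) instead of ordinal integer codes.
df_onehot = pd.get_dummies(df_selected, columns=['embarked'], prefix='embarked')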
# Drop the name column
df = df.drop(['name'],axis=1) # axis=0 is rows, axis=1 is columns
df.head()
   survived  pclass  sex  age      sibsp  parch  fare      embarked
0  1         1       0    29.0000  0      0      211.3375  2
1  1         1       1    0.9167   1      2      151.5500  2
2  0         1       0    2.0000   1      2      151.5500  2
3  0         1       1    30.0000  1      2      151.5500  2
4  0         1       0    25.0000  1      2      151.5500  2

3. Separating features and labels

# Separate features and labels
data = df.values

# The last seven columns are the features
features = data[:,1:] # a NumPy array indexes rows by default, a DataFrame indexes columns
# Column 0 is the label
labels = data[:,0]
labels.shape
(1309,)

4. Defining a data preprocessing function

def prepare_data(df):
  # Drop the name column
  df = df.drop(['name'],axis=1) 

  # Replace missing age values with the mean age
  age_mean = df['age'].mean()
  df['age'] = df['age'].fillna(age_mean)

  # Replace missing fare values with the mean fare
  fare_mean = df['fare'].mean()
  df['fare'] = df['fare'].fillna(fare_mean)

  # Fill the missing embarked records with 'S'
  df['embarked'] = df['embarked'].fillna('S')

  # sex: encode strings as integers (female = 0, male = 1)
  df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)

  # embarked: encode the port letter as an integer (C = 0, Q = 1, S = 2)
  df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
  print(df.isnull().any())
  # Separate features and labels
  data = df.values

  # The last seven columns are the features
  features = data[:,1:] # a NumPy array indexes rows by default, a DataFrame indexes columns
  # Column 0 is the label
  labels = data[:,0]

  return features,labels

5. Splitting the data into training and test sets

shuffle_df = df_selected.sample(frac=1) # shuffle the rows before training; frac is the fraction to sample (1 = all rows); df_selected itself is unchanged
x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape
survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool





((1309, 7), (1309,))
shuffle_df.head()
     survived  name                                    pclass  sex     age   sibsp  parch  fare     embarked
58   0         Case, Mr. Howard Brown                  1       male    49.0  0      0      26.0000  S
666  0         Barbara, Mrs. (Catherine David)         3       female  45.0  0      1      14.4542  C
781  0         Drazenoic, Mr. Jozef                    3       male    33.0  0      0      7.8958   C
480  0         Laroche, Mr. Joseph Philippe Lemercier  2       male    25.0  1      2      41.5792  C
459  0         Jacobsohn, Mr. Sidney Samuel            2       male    42.0  1      0      27.0000  S
test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
# Training set
x_train = x_data[:train_num]
y_train = y_data[:train_num]
# Test set
x_test = x_data[train_num:]
y_test = y_data[train_num:]
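For reference, a hedged alternative sketch (assumes scikit-learn, which the next step imports anyway): shuffling and splitting can be done in one call, and random_state makes the split reproducible. The variable names below are purely illustrative:

# Alternative sketch: shuffle and split in one call (illustrative variable names).
from sklearn.model_selection import train_test_split
x_tr, x_te, y_tr, y_te = train_test_split(x_data, y_data, test_size=0.2, random_state=42)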

6. Normalization

from sklearn import preprocessing

minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train) # fit on the training features and scale them to [0, 1]
x_test = minmax_scale.transform(x_test) # reuse the scaler fitted on the training set so the test data is scaled consistently
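A deployment-minded sketch, assuming the joblib package is available (it is scikit-learn's recommended persistence tool): keeping the scaler fitted on the training data lets you apply exactly the same transformation later, for example to Jack's and Rose's features in Part IV. The filename is hypothetical:

# Optional sketch: persist the fitted scaler for later reuse (hypothetical filename).
import joblib
joblib.dump(minmax_scale, 'titanic_minmax_scaler.pkl')
# later: minmax_scale = joblib.load('titanic_minmax_scaler.pkl')
#        x_new = minmax_scale.transform(x_new_features)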

II. The Model

import tensorflow as tf
tf.__version__
'2.6.0'

1. Building a sequential model

model = tf.keras.models.Sequential()

2. Adding the hidden layers

model.add(tf.keras.layers.Dense(units=64,
                use_bias=True,
                activation='relu',
                input_dim=7, # input_shape=(7,) works as well
                bias_initializer='zeros',
                kernel_initializer='normal'))
model.add(tf.keras.layers.Dropout(rate=0.2)) # dropout layer: rate is the fraction of the previous layer's outputs randomly dropped during training, to reduce overfitting
model.add(tf.keras.layers.Dense(units=32,
                activation='sigmoid',
                input_shape=(64,), # input_dim=64 works as well; Keras can also infer it from the previous layer
                bias_initializer='zeros',
                kernel_initializer='uniform'))
model.add(tf.keras.layers.Dropout(rate=0.2)) # dropout layer: rate is the fraction of the previous layer's outputs randomly dropped during training, to reduce overfitting

3. Adding the output layer

model.add(tf.keras.layers.Dense(units=1,
                activation='sigmoid', # sigmoid squashes the output to (0, 1), interpreted here as a survival probability
                input_dim=32, # input_shape=(32,) works as well
                bias_initializer='zeros',
                kernel_initializer='uniform'))
model.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_68 (Dense)             (None, 64)                512       
_________________________________________________________________
dropout_6 (Dropout)          (None, 64)                0         
_________________________________________________________________
dense_69 (Dense)             (None, 32)                2080      
_________________________________________________________________
dropout_7 (Dropout)          (None, 32)                0         
_________________________________________________________________
dense_70 (Dense)             (None, 1)                 33        
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
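The parameter counts in the summary can be checked by hand: the first Dense layer has 7 × 64 weights plus 64 biases = 512 parameters, the second has 64 × 32 + 32 = 2,080, and the output layer has 32 × 1 + 1 = 33, giving 2,625 trainable parameters in total; the Dropout layers add none.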

III. Training

1. Training

# Define the training configuration: Adam optimizer, binary cross-entropy loss, accuracy metric
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),loss='binary_crossentropy',metrics=['accuracy'])
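For intuition, a minimal sketch (plain NumPy, made-up values) of what binary_crossentropy computes: for a label y and predicted probability p, the per-sample loss is -[y·log(p) + (1-y)·log(1-p)], averaged over the batch.

# Illustrative sketch of binary cross-entropy with made-up values.
import numpy as np
y_true = np.array([1.0, 0.0, 1.0])   # hypothetical labels
y_prob = np.array([0.9, 0.2, 0.6])   # hypothetical predicted probabilities
bce = -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
print(bce)  # ~0.2798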
# Set the training hyperparameters
train_epochs = 100 
batch_size = 40
train_history = model.fit(x=x_train, # training features
              y=y_train, # training labels
              validation_split=0.2, # fraction of the training data held out for validation
              epochs=train_epochs, # number of training epochs
              batch_size=batch_size, # mini-batch size
              verbose=2) # log verbosity: one line per epoch
Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
Epoch 6/100
21/21 - 0s - loss: 0.5200 - accuracy: 0.7670 - val_loss: 0.4847 - val_accuracy: 0.8143
Epoch 7/100
21/21 - 0s - loss: 0.5118 - accuracy: 0.7718 - val_loss: 0.4771 - val_accuracy: 0.8143
Epoch 8/100
21/21 - 0s - loss: 0.5060 - accuracy: 0.7766 - val_loss: 0.4738 - val_accuracy: 0.8095
Epoch 9/100
21/21 - 0s - loss: 0.4934 - accuracy: 0.7861 - val_loss: 0.4670 - val_accuracy: 0.7952
Epoch 10/100
21/21 - 0s - loss: 0.4966 - accuracy: 0.7814 - val_loss: 0.4637 - val_accuracy: 0.8000
Epoch 11/100
21/21 - 0s - loss: 0.4928 - accuracy: 0.7766 - val_loss: 0.4635 - val_accuracy: 0.7905
Epoch 12/100
21/21 - 0s - loss: 0.4995 - accuracy: 0.7670 - val_loss: 0.4691 - val_accuracy: 0.7905
Epoch 13/100
21/21 - 0s - loss: 0.4886 - accuracy: 0.7957 - val_loss: 0.4620 - val_accuracy: 0.8095
Epoch 14/100
21/21 - 0s - loss: 0.4790 - accuracy: 0.7838 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 15/100
21/21 - 0s - loss: 0.4877 - accuracy: 0.7766 - val_loss: 0.4576 - val_accuracy: 0.8095
Epoch 16/100
21/21 - 0s - loss: 0.4839 - accuracy: 0.7897 - val_loss: 0.4560 - val_accuracy: 0.8095
Epoch 17/100
21/21 - 0s - loss: 0.4813 - accuracy: 0.7814 - val_loss: 0.4614 - val_accuracy: 0.8095
Epoch 18/100
21/21 - 0s - loss: 0.4812 - accuracy: 0.7742 - val_loss: 0.4553 - val_accuracy: 0.8095
Epoch 19/100
21/21 - 0s - loss: 0.4762 - accuracy: 0.7885 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 20/100
21/21 - 0s - loss: 0.4784 - accuracy: 0.7802 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 21/100
21/21 - 0s - loss: 0.4794 - accuracy: 0.7885 - val_loss: 0.4626 - val_accuracy: 0.8000
Epoch 22/100
21/21 - 0s - loss: 0.4824 - accuracy: 0.7838 - val_loss: 0.4567 - val_accuracy: 0.7857
Epoch 23/100
21/21 - 0s - loss: 0.4786 - accuracy: 0.7849 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 24/100
21/21 - 0s - loss: 0.4801 - accuracy: 0.7742 - val_loss: 0.4735 - val_accuracy: 0.7905
Epoch 25/100
21/21 - 0s - loss: 0.4752 - accuracy: 0.7849 - val_loss: 0.4571 - val_accuracy: 0.7905
Epoch 26/100
21/21 - 0s - loss: 0.4688 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8000
Epoch 27/100
21/21 - 0s - loss: 0.4624 - accuracy: 0.7873 - val_loss: 0.4577 - val_accuracy: 0.8048
Epoch 28/100
21/21 - 0s - loss: 0.4656 - accuracy: 0.7993 - val_loss: 0.4602 - val_accuracy: 0.8000
Epoch 29/100
21/21 - 0s - loss: 0.4649 - accuracy: 0.7969 - val_loss: 0.4546 - val_accuracy: 0.8000
Epoch 30/100
21/21 - 0s - loss: 0.4645 - accuracy: 0.7849 - val_loss: 0.4638 - val_accuracy: 0.8000
Epoch 31/100
21/21 - 0s - loss: 0.4635 - accuracy: 0.7921 - val_loss: 0.4603 - val_accuracy: 0.7952
Epoch 32/100
21/21 - 0s - loss: 0.4646 - accuracy: 0.7909 - val_loss: 0.4567 - val_accuracy: 0.7952
Epoch 33/100
21/21 - 0s - loss: 0.4664 - accuracy: 0.7909 - val_loss: 0.4583 - val_accuracy: 0.7952
Epoch 34/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7921 - val_loss: 0.4575 - val_accuracy: 0.8000
Epoch 35/100
21/21 - 0s - loss: 0.4660 - accuracy: 0.7838 - val_loss: 0.4582 - val_accuracy: 0.7952
Epoch 36/100
21/21 - 0s - loss: 0.4577 - accuracy: 0.8005 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 37/100
21/21 - 0s - loss: 0.4648 - accuracy: 0.7909 - val_loss: 0.4585 - val_accuracy: 0.7952
Epoch 38/100
21/21 - 0s - loss: 0.4613 - accuracy: 0.7921 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 39/100
21/21 - 0s - loss: 0.4643 - accuracy: 0.7921 - val_loss: 0.4687 - val_accuracy: 0.8000
Epoch 40/100
21/21 - 0s - loss: 0.4696 - accuracy: 0.7814 - val_loss: 0.4601 - val_accuracy: 0.8048
Epoch 41/100
21/21 - 0s - loss: 0.4589 - accuracy: 0.7933 - val_loss: 0.4562 - val_accuracy: 0.7952
Epoch 42/100
21/21 - 0s - loss: 0.4587 - accuracy: 0.7885 - val_loss: 0.4594 - val_accuracy: 0.8000
Epoch 43/100
21/21 - 0s - loss: 0.4601 - accuracy: 0.7981 - val_loss: 0.4563 - val_accuracy: 0.7905
Epoch 44/100
21/21 - 0s - loss: 0.4639 - accuracy: 0.7897 - val_loss: 0.4594 - val_accuracy: 0.8048
Epoch 45/100
21/21 - 0s - loss: 0.4569 - accuracy: 0.7957 - val_loss: 0.4587 - val_accuracy: 0.8000
Epoch 46/100
21/21 - 0s - loss: 0.4619 - accuracy: 0.7957 - val_loss: 0.4556 - val_accuracy: 0.8048
Epoch 47/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7861 - val_loss: 0.4563 - val_accuracy: 0.8000
Epoch 48/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7969 - val_loss: 0.4538 - val_accuracy: 0.8000
Epoch 49/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7873 - val_loss: 0.4572 - val_accuracy: 0.8048
Epoch 50/100
21/21 - 0s - loss: 0.4603 - accuracy: 0.7909 - val_loss: 0.4584 - val_accuracy: 0.8000
Epoch 51/100
21/21 - 0s - loss: 0.4575 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8095
Epoch 52/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.8029 - val_loss: 0.4584 - val_accuracy: 0.8048
Epoch 53/100
21/21 - 0s - loss: 0.4594 - accuracy: 0.7909 - val_loss: 0.4558 - val_accuracy: 0.8000
Epoch 54/100
21/21 - 0s - loss: 0.4588 - accuracy: 0.8065 - val_loss: 0.4523 - val_accuracy: 0.8000
Epoch 55/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.8029 - val_loss: 0.4593 - val_accuracy: 0.8048
Epoch 56/100
21/21 - 0s - loss: 0.4578 - accuracy: 0.8100 - val_loss: 0.4614 - val_accuracy: 0.8048
Epoch 57/100
21/21 - 0s - loss: 0.4549 - accuracy: 0.8041 - val_loss: 0.4580 - val_accuracy: 0.8095
Epoch 58/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8095
Epoch 59/100
21/21 - 0s - loss: 0.4567 - accuracy: 0.7981 - val_loss: 0.4532 - val_accuracy: 0.8095
Epoch 60/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.7993 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 61/100
21/21 - 0s - loss: 0.4543 - accuracy: 0.7969 - val_loss: 0.4555 - val_accuracy: 0.8000
Epoch 62/100
21/21 - 0s - loss: 0.4472 - accuracy: 0.8053 - val_loss: 0.4543 - val_accuracy: 0.8048
Epoch 63/100
21/21 - 0s - loss: 0.4458 - accuracy: 0.8100 - val_loss: 0.4534 - val_accuracy: 0.8095
Epoch 64/100
21/21 - 0s - loss: 0.4497 - accuracy: 0.8005 - val_loss: 0.4593 - val_accuracy: 0.8000
Epoch 65/100
21/21 - 0s - loss: 0.4511 - accuracy: 0.8053 - val_loss: 0.4522 - val_accuracy: 0.8095
Epoch 66/100
21/21 - 0s - loss: 0.4506 - accuracy: 0.8005 - val_loss: 0.4592 - val_accuracy: 0.7952
Epoch 67/100
21/21 - 0s - loss: 0.4533 - accuracy: 0.8005 - val_loss: 0.4545 - val_accuracy: 0.8000
Epoch 68/100
21/21 - 0s - loss: 0.4481 - accuracy: 0.7909 - val_loss: 0.4545 - val_accuracy: 0.7952
Epoch 69/100
21/21 - 0s - loss: 0.4555 - accuracy: 0.7981 - val_loss: 0.4551 - val_accuracy: 0.8000
Epoch 70/100
21/21 - 0s - loss: 0.4440 - accuracy: 0.8029 - val_loss: 0.4552 - val_accuracy: 0.7952
Epoch 71/100
21/21 - 0s - loss: 0.4584 - accuracy: 0.8029 - val_loss: 0.4530 - val_accuracy: 0.7952
Epoch 72/100
21/21 - 0s - loss: 0.4480 - accuracy: 0.7933 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 73/100
21/21 - 0s - loss: 0.4554 - accuracy: 0.7981 - val_loss: 0.4536 - val_accuracy: 0.7952
Epoch 74/100
21/21 - 0s - loss: 0.4438 - accuracy: 0.8029 - val_loss: 0.4532 - val_accuracy: 0.7952
Epoch 75/100
21/21 - 0s - loss: 0.4483 - accuracy: 0.8053 - val_loss: 0.4515 - val_accuracy: 0.8095
Epoch 76/100
21/21 - 0s - loss: 0.4408 - accuracy: 0.8041 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 77/100
21/21 - 0s - loss: 0.4470 - accuracy: 0.8017 - val_loss: 0.4531 - val_accuracy: 0.8000
Epoch 78/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.8053 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 79/100
21/21 - 0s - loss: 0.4456 - accuracy: 0.8053 - val_loss: 0.4526 - val_accuracy: 0.8048
Epoch 80/100
21/21 - 0s - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4573 - val_accuracy: 0.7952
Epoch 81/100
21/21 - 0s - loss: 0.4496 - accuracy: 0.7981 - val_loss: 0.4573 - val_accuracy: 0.8095
Epoch 82/100
21/21 - 0s - loss: 0.4515 - accuracy: 0.8053 - val_loss: 0.4502 - val_accuracy: 0.8095
Epoch 83/100
21/21 - 0s - loss: 0.4503 - accuracy: 0.8100 - val_loss: 0.4546 - val_accuracy: 0.7952
Epoch 84/100
21/21 - 0s - loss: 0.4386 - accuracy: 0.8065 - val_loss: 0.4540 - val_accuracy: 0.8048
Epoch 85/100
21/21 - 0s - loss: 0.4371 - accuracy: 0.8088 - val_loss: 0.4552 - val_accuracy: 0.8095
Epoch 86/100
21/21 - 0s - loss: 0.4420 - accuracy: 0.8053 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 87/100
21/21 - 0s - loss: 0.4437 - accuracy: 0.8112 - val_loss: 0.4550 - val_accuracy: 0.7952
Epoch 88/100
21/21 - 0s - loss: 0.4432 - accuracy: 0.7969 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 89/100
21/21 - 0s - loss: 0.4396 - accuracy: 0.8065 - val_loss: 0.4552 - val_accuracy: 0.8000
Epoch 90/100
21/21 - 0s - loss: 0.4477 - accuracy: 0.8088 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 91/100
21/21 - 0s - loss: 0.4412 - accuracy: 0.8017 - val_loss: 0.4507 - val_accuracy: 0.8048
Epoch 92/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8048
Epoch 93/100
21/21 - 0s - loss: 0.4433 - accuracy: 0.8017 - val_loss: 0.4519 - val_accuracy: 0.8048
Epoch 94/100
21/21 - 0s - loss: 0.4415 - accuracy: 0.7957 - val_loss: 0.4524 - val_accuracy: 0.8095
Epoch 95/100
21/21 - 0s - loss: 0.4399 - accuracy: 0.8065 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 96/100
21/21 - 0s - loss: 0.4387 - accuracy: 0.8065 - val_loss: 0.4546 - val_accuracy: 0.8095
Epoch 97/100
21/21 - 0s - loss: 0.4463 - accuracy: 0.7945 - val_loss: 0.4542 - val_accuracy: 0.8048
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095

2. Visualizing the training process

# Visualize the training history
import matplotlib.pyplot as plt

def show_train_history(train_history,train_metric,validation_metric):
  plt.plot(train_history[train_metric])
  plt.plot(train_history[validation_metric])
  plt.title('Train History')
  plt.ylabel(train_metric)
  plt.xlabel('epoch')
  plt.legend(['train','validation'],loc='upper left')
  plt.show()
show_train_history(train_history.history,'loss','val_loss')

(Figure: training vs. validation loss over the 100 epochs — output_46_0.png)

show_train_history(train_history.history,'accuracy','val_accuracy')

(Figure: training vs. validation accuracy over the 100 epochs — output_47_0.png)

3. Evaluating the model

loss,acc = model.evaluate(x_test,y_test)
9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435
loss,acc
(0.3702643811702728, 0.8435114622116089)
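For clarity, a short sketch of what the reported accuracy means: model.predict returns survival probabilities from the sigmoid output, and the accuracy metric counts a passenger as "survived" when that probability exceeds 0.5.

# Sketch: turn predicted probabilities into 0/1 classes and recompute accuracy by hand.
import numpy as np
y_prob = model.predict(x_test)
y_pred = (y_prob > 0.5).astype(int).flatten()
print(np.mean(y_pred == y_test))  # should be close to the accuracy reported above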

IV. Prediction

# Made-up passenger records for Jack and Rose, in the same column order as selected_cols
Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']
x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre
   survived  name  pclass  sex     age  sibsp  parch  fare   embarked
0  0         Jack  3       male    23   1      0      5.0    S
1  1         Rose  1       female  20   1      0      100.0  S
x_pre_features,y = prepare_data(x_pre)

# Scale Jack's and Rose's features with the scaler that was fitted on the training set,
# so they are transformed exactly like the training data
x_pre_features = minmax_scale.transform(x_pre_features)
y_pre = model.predict(x_pre_features)
survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool
x_pre.insert(len(x_pre.columns),'surv_probabilty',y_pre) # append the predicted survival probability as the last column
x_pre
   survived  name  pclass  sex     age  sibsp  parch  fare   embarked  surv_probabilty
0  0         Jack  3       male    23   1      0      5.0    S         0.058498
1  1         Rose  1       female  20   1      0      100.0  S         0.975978