赞
踩
六步法
imoprt
train_data, test_data
逐层搭建网络结构 model=tf.keras.model.Sequential
在 model.compile()中配置训练方法,选择训练时使用的优化器、损失函数和最终评价指标。
在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)。
使用 model.summary()打印网络结构,统计参数数目。
Sequential 函数是一个容器, 描述神经网络的网络结构,Sequential函数的输入参数中描述从输入层到输出层的网络结构。
拉直层: tf.keras.layers.Flatten()
拉直层可以变换张量的尺寸,把输入特征拉直为一维数组,是不含计算参数的层。
全连接层: tf.keras.layers.Dense( 神经元个数,activation=”激活函数”,kernel_regularizer=”正则化方式”)
activation(字符串给出)可选 relu、 softmax、 sigmoid、 tanh 等
kernel_regularizer 可选 tf.keras.regularizers.l1()、tf.keras.regularizers.l2();
卷积层:tf.keras.layers.Conv2D( filter = 卷积核个数,kernel_size = 卷积核尺寸,strides = 卷积步长,padding = “valid” or “same”)
LSTM 层: tf.keras.layers.LSTM()。
Compile 用于配置神经网络的训练方法,告知训练时使用的优化器、损失函数和准确率评测标准。
optimizer可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数。
可选项包括:
‘sgd’or tf.optimizers.SGD( lr=学习率, decay=学习率衰减率, momentum=动量参数)
‘adagrad’or tf.keras.optimizers.Adagrad(lr=学习率, decay=学习率衰减率)
‘adadelta’or tf.keras.optimizers.Adadelta(lr=学习率, decay=学习率衰减率)
‘adam’or tf.keras.optimizers.Adam (lr=学习率, decay=学习率衰减率)
Loss 可以是字符串形式给出的损失函数的名字,也可以是函数形式。
可选项包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy
or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
损失函数常需要经过 softmax 等函数将输出转化为概率分布的形式。from_logits 则用来标注该损失函数是否需要转换为概率的形式, 取 False 时表示转化为概率分布,取 True 时表示没有转化为概率分布,直接输出。
Metrics 标注网络评测指标。
可选项包括:
‘accuracy’: y_和 y 都是数值,如 y_=[1] y=[1]。
‘categorical_accuracy’: y_和 y 都是以独热码和概率分布表示。
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’: y_是以数值形式给出, y 是以独热码形式给出。
如 y_=[1],y=[0.256, 0.695, 0.048]。
model.fit(训练集的输入特征, 训练集的标签, batch_size, epochs,
validation_data = (测试集的输入特征,测试集的标签),
validataion_split = 从测试集划分多少比例给训练集,
validation_freq = 测试的 epoch 间隔次数)
上图是 model.summary()对鸢尾花分类网络的网络结构和参数统计,对于一个输入为 4 输出为 3 的全连接网络,共有 15 个参数。
from bert4keras.tokenizers import Tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
sentence = "雀巢裁员4000人:时代抛弃你时,连招呼都不会打!"
tokens = tokenizer.tokenize(sentence)
print(tokens)
encode = tokenizer.encode(tokens, maxlen=128)
print(encode)
['[CLS]', '雀', '巢', '裁', '员', '4000', '人', ':', '时', '代', '抛', '弃', '你', '时', ',', '连', '招', '呼', '都', '不', '会', '打', '!', '[SEP]']
([101, 7411, 2338, 6161, 1447, 8442, 782, 8038, 3198, 807, 2837, 2461, 872, 3198, 8024, 6825, 2875, 1461, 6963, 679, 833, 2802, 8013, 102], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。