赞
踩
from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list import tensorflow as tf from transformers import TFBertForSequenceClassification bert_model = "bert-base-chinese" model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32) model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits) model.summary() result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1) print(result.history) # 保存模型(模型保存的本质就是保存训练的参数,而对于深度学习而言还保存神经网络结构) model.save_weights('../data/model.h5') model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32) model.load_weights('../data/model.h5') result = model.predict(val_train_list[:12]) # 预测值 print(result) result = tf.nn.sigmoid(result) print(result) result = tf.cast(tf.greater_equal(result, 0.5), tf.float32) print(result)
这段代码的目的是利用TensorFlow和transformers库来进行文本序列的分类任务。下面是整体流程的概述和逐步计划:
导入必要的库和数据:
pypro.chapters03.demo03_数据获取与处理
的模块中导入了四个列表:train_list
, label_list
, val_train_list
, val_label_list
。这些列表分别包含训练数据、训练标签、验证数据和验证标签。初始化预训练的BERT模型:
bert-base-chinese
模型初始化一个用于序列分类的BERT模型。编译模型:
模型摘要:
训练模型:
输出训练结果:
保存模型权重:
model.h5
。加载模型权重:
模型预测:
激活函数处理:
转换预测结果:
下面逐行解释上述代码:
from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list
这行代码从demo03_数据获取与处理
模块中导入四个列表。这些列表包含训练数据和标签(train_list
, label_list
),以及验证数据和标签(val_train_list
, val_label_list
)。这是数据准备步骤的一部分。
import tensorflow as tf
这行代码导入了TensorFlow库,它是一个广泛用于机器学习和深度学习任务的开源库。
from transformers import TFBertForSequenceClassification
这里导入了transformers
库中的TFBertForSequenceClassification
类。transformers
库包含了许多预训练模型,用于NLP任务,这里特别导入的是适用于TensorFlow的BERT模型,用于序列分类任务。
bert_model = "bert-base-chinese"
定义一个字符串变量bert_model
,它保存了预训练模型的名称。在这里,我们将使用中文BERT基础模型。
model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
使用bert-base-chinese
模型和TFBertForSequenceClassification
类创建一个新的序列分类模型实例。num_labels=32
表明有32个不同的类别用于分类。
model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits)
编译模型,设置度量为准确度(accuracy
),并使用sigmoid_cross_entropy_with_logits
作为损失函数,这通常用于二分类问题,但在这里,由于是多标签分类(32个类别),可能是对每个标签进行二分类。
model.summary()
输出模型的摘要信息,包括模型中的层,每层的输出形状和参数数量等详细信息。
result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1)
开始训练模型,仅使用前24个样本作为训练数据和标签。批处理大小设置为12,意味着每次梯度更新将基于12个样本。epochs=1
表示整个数据集只通过模型训练一次。
print(result.history)
打印出训练过程中的历史数据,如损失和准确度。
model.save_weights('../data/model.h5')
保存训练好的模型权重到本地文件model.h5
。
model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
再次初始化一个模型,用于演示如何从头加载一个模型。
model.load_weights('../data/model.h5')
加载先前保存的模型权重。
result = model.predict(val_train_list[:12]) # 预测值
使用验证数据集中的前12个样本进行预测,得到模型的输出。
print(result)
打印出预测结果。
result = tf.nn.sigmoid(result)
将模型的原始输出通过sigmoid函数转换,得到一个在0到1之间的值,表示属于每个类别的概率。
print(result)
再次打印经过sigmoid激活函数处理后的预测结果。
result = tf.cast(tf.greater_equal(result, 0.5), tf.float32)
将sigmoid输出的概率转换为二分类结果。对于每个标签,如果概率大于或等于0.5,则认为该样本属于该标签(转换为1),否则不属于(转换为0)。
print(result)
最后,打印出转换后的分类结果。
整体而言,这段代码展示了使用预训练的BERT模型在一个多标签文本分类任务上的训练、保存、加载和预测的完整过程。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。