当前位置:   article > 正文

微调ChatGLM2-6b模型解决文本二分类任务代码解读_chatglm2-6b调用语句

chatglm2-6b调用语句

微调ChatGLM2-6b模型解决文本二分类任务代码解读

1. 导入包

import pandas as pd
  • 1

导入pandas,用于后续读取和处理CSV数据。pandas为处理表格和时间序列数据提供了高效的数据结构。

2. 加载数据

train_df = pd.read_csv('./csv_data/train.csv')
test_df = pd.read_csv('./csv_data/test.csv')

print(train_df.info())
  • 1
  • 2
  • 3
  • 4
  • pd.read_csv()读取CSV文件,返回DataFrame格式的数据。
  • train_df和test_df分别存储训练集和测试集。
  • print(train_df.info()) 打印数据概览,包括列名、类型、非空值个数等信息,对数据有个直观了解。

3. 制作数据集

res = []

for i in range(len(train_df)):
  # 构造每一项
  tmp = {
      "instruction": "判断是否医学论文",
      "input": "标题:"+title+" 摘要:"+abstract,  
      "output": label
  }
  
  res.append(tmp)

# 保存到json文件
import json
with open('paper_label.json', 'w') as f:
  json.dump(res, f)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 利用循环,遍历训练集每一行,根据标题和摘要构造prompt格式的训练样本。
  • prompt包含三项:instruction(提示词),input(输入文本),output(目标输出)。
  • 将训练样本保存在JSON文件中,方便后续训练。

4. 微调模型

from peft import PeftModel
from transformers import AutoTokenizer, AutoModel 

model_path = "./chatglm2-6b"

model = AutoModel.from_pretrained(model_path) 
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = PeftModel.from_pretrained(model, 'output_dir')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 从huggingface加载预训练语言模型ChatGLM2-6B和相应的tokenizer。
  • 使用peft工具进行微调,生成适用于当前任务的fine-tuned模型。
  • 微调需要准备prompt格式的数据。

5. 定义预测函数

def predict(text):

  input = f"判断是否医学论文:{text}"
  
  response = model.chat(tokenizer, input)

  return response
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 封装预测逻辑在一个函数中。
  • 根据输入文本构造prompt输入,调用模型生成响应。
  • 将模型输出返回。

6. 预测测试集

predictions = []

for i in range(len(test_df)):
  
  text = 获取测试集论文信息
  
  pred = predict(text)
  
  predictions.append(pred)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 循环遍历测试集每一项。
  • 构造输入文本,调用预测函数生成结果。
  • 将预测结果保存到predictions列表中。

7. 生成提交文件

submit = test_df[['id', 'keywords', 'predictions']]

submit.to_csv('submit.csv', index=False) 
  • 1
  • 2
  • 3
  • 从测试集中取出需要的列,和预测一起构造提交文件格式。
  • to_csv()可将DataFrame写入csv文件。

总结

ChatGLM2-6B是一个基于Transformer架构预训练的巨大语言模型,它通过在大规模文本语料上进行自监督学习,获得了强大的语言理解和生成能力。

要利用ChatGLM2-6B进行文本二分类,主要分以下几个步骤:

  • 数据准备
    收集文本分类任务需要的训练集和验证集/测试集,并转换成Prompt格式。prompt包含分类说明、文本输入和期望输出。
  • 模型微调
    在预训练ChatGLM2-6B模型基础上,使用prompt格式的数据进一步训练,使模型适应文本二分类任务。这一步会稍微调节模型参数。
  • Prompt设计
    为模型设计合适的prompt模板,指导模型对新输入的文本进行判断和分类。
  • 模型推理
    对新输入的文本,先转换为prompt输入,然后输入到fine-tuned的ChatGLM2-6B中,从生成的响应中解析出分类结果。
  • 模型输出解析
    解析模型生成的响应,提取分类结果,转换为标准化的0/1输出。

通过微调+Prompt设计,ChatGLM2-6B可以高效地进行文本二分类,完全利用了其强大的语言理解能力,并优于传统的分类模型。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/350309
推荐阅读
相关标签
  

闽ICP备14008679号