当前位置:   article > 正文

60分钟吃掉ChatGLM2-6b微调范例~

chatglm2微调epoch

干货预警:这可能是你能够找到的最容易懂的最完整的,适用于各种NLP任务的开源LLM的finetune教程~

ChatGLM2-6b是清华开源的小尺寸LLM,只需要一块普通的显卡(32G较稳妥)即可推理和微调,是目前社区非常活跃的一个开源LLM。

本范例使用非常简单的,外卖评论数据集来实施微调,让ChatGLM2-6b来对一段外卖评论区分是好评还是差评。

可以发现,经过微调后的模型,相比直接 3-shot-prompt 可以取得明显更好的效果。

值得注意的是,尽管我们以文本分类任务为例,实际上,任何NLP任务,例如,命名实体识别,翻译,聊天对话等等,都可以通过加上合适的上下文,转换成一个对话问题,并针对我们的使用场景,设计出合适的数据集来微调开源LLM.

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源代码,以及waimai数据集下载链接~

〇,预训练模型

我们需要从 https://huggingface.co/THUDM/chatglm2-6b 下载chatglm2的模型。

国内可能速度会比较慢,总共有14多个G,网速不太好的话,大概可能需要一两个小时。

如果网络不稳定,也可以手动从这个页面一个一个下载全部文件然后放置到 一个文件夹中例如 'chatglm2-6b' 以便读取。

  1. from transformers import  AutoModel,AutoTokenizer
  2. model_name = "chatglm2-6b" #或者远程 “THUDM/chatglm2-6b”
  3. tokenizer = AutoTokenizer.from_pretrained(
  4.     model_name, trust_remote_code=True)
  5. model = AutoModel.from_pretrained(model_name,trust_remote_code=True).half().cuda()
Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]
  1. prompt = """文本分类任务:将一段用户给外卖服务的评论进行分类,分成好评或者差评。
  2. 下面是一些范例:
  3. 味道真不错 -> 好评
  4. 太辣了,吃不下都  -> 差评
  5. 请对下述评论进行分类。返回'好评'或者'差评',无需其它说明和解释。
  6. xxxxxx ->
  7. """
  8. def get_prompt(text):
  9.     return prompt.replace('xxxxxx',text)
  1. response, his = model.chat(tokenizer, get_prompt('味道不错,下次还来'), history=[])
  2. print(response)
好评
  1. #增加4个范例
  2. his.append(("太贵了 -> ","差评"))
  3. his.append(("非常快,味道好 -> ","好评"))
  4. his.append(("这么咸真的是醉了 -> ","差评"))
  5. his.append(("价格感人 优惠多多 -> ","好评"))

我们来测试一下

  1. response, history = model.chat(tokenizer, "一言难尽啊 -> ", history=his)
  2. print(response) 
  3. response, history = model.chat(tokenizer, "还凑合一般般 -> ", history=his)
  4. print(response) 
  5. response, history = model.chat(tokenizer, "我家狗狗爱吃的 -> ", history=his)
  6. print(response)
  1. 差评
  2. 差评
  3. 好评
  1. #封装成一个函数吧~
  2. def predict(text):
  3.     response, history = model.chat(tokenizer, f"{text} ->", history=his,
  4.     temperature=0.01)
  5.     return response 
  6. predict('死鬼,咋弄得这么有滋味呢') #在我们精心设计的一个评论下,ChatGLM2-6b终于预测错误了~
'差评'

我们拿外卖数据集测试一下未经微调,纯粹的 6-shot prompt 的准确率。

  1. import pandas as pd 
  2. import numpy as np 
  3. import datasets 
  4. df = pd.read_csv("data/waimai_10k.csv")
  5. df['tag'] = df['label'].map({0:'差评',1:'好评'})
  6. df = df.rename({'review':'text'},axis = 1)
  7. dfgood = df.query('tag=="好评"')
  8. dfbad = df.query('tag=="差评"').head(len(dfgood)) #采样部分差评,让好评差评平衡
  9. df = pd.concat([dfgood,dfbad])
  10. print(df['tag'].value_counts())
  1. 好评 4000
  2. 差评 4000
  1. ds_dic = datasets.Dataset.from_pandas(df).train_test_split(
  2.     test_size = 2000,shuffle=True, seed = 43)
  3. dftrain = ds_dic['train'].to_pandas()
  4. dftest = ds_dic['test'].to_pandas()
  5. dftrain.to_parquet('data/dftrain.parquet')
  6. dftest.to_parquet('data/dftest.parquet')
preds = ['' for x in dftest['tag']]
  1. from tqdm import tqdm 
  2. for i in tqdm(range(len(dftest))):
  3.     text = dftest['text'].loc[i]
  4.     preds[i] = predict(text)
dftest['pred'] = preds
dftest.pivot_table(index='tag',columns = 'pred',values='text',aggfunc='count')

c7236a41c098875e7b9e7d2796e0fcd2.png

acc = len(dftest.query('tag==pred'))/len(dftest)
print('acc=',acc)
acc= 0.878

可以看到,微调之前,我们的模型准确率为87.8%,下面我们通过6000条左右数据的微调,看看能否把acc打上去~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/81400
推荐阅读
相关标签