当前位置:   article > 正文

ChatGLM系列八:微调医疗问答系统_chatgml2-6b 中医

chatgml2-6b 中医

一、ChatGLM2-6B

ChatGLM2-6B 是 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,同时引入了许多新特性,如:更强大的性能、更长的上下文、更高效的推理、更开放的协议等。

二、P-tuning v2

P-tuning v2 微调技术利用 deep prompt tuning,即对预训练 Transformer 的每一层输入应用 continuous prompts 。deep prompt tuning 增加了 continuo us prompts 的能力,并缩小了跨各种设置进行微调的差距,特别是对于小型模型和困难任务。
在这里插入图片描述
上图左边为 P-Tuning,右边为P-Tuning v2。P-Tuning v2 层与层之间的 continuous prompt 是相互独立的。

三、ChatGLM2-6B 模型下载

huggingface 地址:https://huggingface.co/THUDM/chatglm2-6b/tree/main
  • 1

在这里插入图片描述

四、数据集下载

https://github.com/Toyhom/Chinese-medical-dialogue-data
  • 1

数据格式:
在这里插入图片描述

五、数据预处理

import json
import pandas as pd

data_path = [
    "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]

train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000


def doHandler():
    train_f = open(train_json_path, "a", encoding='utf-8')
    val_f = open(val_json_path, "a", encoding='utf-8')
    for path in data_path:
        data = pd.read_csv(path, encoding='ANSI')
        train_count = 0
        val_count = 0
        for index, row in data.iterrows():
            ask = row["ask"]
            answer = row["answer"]
            line = {
                "content": ask,
                "summary": answer
            }
            line = json.dumps(line, ensure_ascii=False)
            if train_count < train_size:
                train_f.write(line + "\n")
                train_count = train_count + 1
            elif val_count < val_size:
                val_f.write(line + "\n")
                val_count = val_count + 1
            else:
                break
    print("数据处理完毕!")
    train_f.close()
    val_f.close()


if __name__ == '__main__':
    doHandler()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

六、模型微调训练

git clone https://github.com/THUDM/ChatGLM2-6B
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple
  • 1
  • 2
  • 3

修改 ptuning 下的 train.sh 文件:

PRE_SEQ_LEN=300
LR=2e-2
NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
    --do_train \
    --train_file data/train.json \
    --validation_file data/val.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/chatglm2/chatglm-6b \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
–standalone` 以单机模式训练。
–nnodes` 节点数。这里只有一个节点,设置为 1。
–nproc-per-node` 每个节点上的进程数。
–do_train` 执行训练任务。
–train_file` 训练数据文件路径, 上面生成的 train.json 文件。
–validation_file` 验证数据文件路径, 上面生成的 val.json 文件。
–preprocessing_num_workers` 指定数据预处理时的 workers 数。
–prompt_column` 输入信息的字段名称。
–response_column` 输出信息的字段名称。
–overwrite_cache` 覆盖缓存文件。
–model_name_or_path` 预训练模型的名称或路径,注意这里我是用的下载后的模型存放地址,需要修改为你的。
–output_dir` 模型保存目录。
–overwrite_output_dir` 覆盖输出目录。
–max_source_length` 输入文本的最大长度。
–max_target_length` 输出文本的最大长度。
–per_device_train_batch_size` 训练时的批次大小。
–per_device_eval_batch_size` 验证时的批次大小。
–gradient_accumulation_steps` 累积多少个梯度之后再进行一次反向传播。
–predict_with_generate` 预测时使用生成模式。
–max_steps` 最大训练轮数。
–logging_steps` 多少轮打印一次日志。
–save_steps` 多少轮保存一次模型。
–learning_rate` 初始学习率。
–pre_seq_len` 预处理时选取的序列长度。
–quantization_bit` 量化位大小。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

执行脚本

bash train.sh
  • 1

七、模型测试

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import os


def main():
    pre_seq_len = 300
    # 训练权重地址
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"

    tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    # 量化
    model = model.quantize(4)
    model.eval()

    # 问题
    question = "突然感到了不适,去检查后竟然得了这个病,请问:宝宝白天爱磨牙会是哪些情况呢"

    response, history = model.chat(tokenizer,
                                   question,
                                   history=[],
                                   max_length=2048,
                                   top_p=0.7,
                                   temperature=0.95)

    print("回答:", response)

    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import os

app = FastAPI()

# 允许所有域的请求
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    return answer


if __name__ == '__main__':
    pre_seq_len = 300
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"

    tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    ## 量化
    model = model.quantize(4)
    model = model.cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/81494
推荐阅读
相关标签
  

闽ICP备14008679号