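I've recently been working on the Boardroom competition hosted by Rotman Commerce at the University of Toronto, a business case competition. This year's case is, roughly, about a law firm that wants to bring in AI to improve its efficiency. The moment I saw the topic I felt we had this one in the bag. After a careful analysis, I wrote a 2,400-word first draft of the proposal on the first evening. At a high level, if a law firm is going to use AI, NLP is really the only sensible choice. Consider the alternatives: lawyers have no need for CV, OCR is already a commodity (who can't use a scanning app these days?), and number-crunching models are even further from their daily work. With the theme settled, the next question was where NLP could be applied. We split the project into two parts, internal and external. The internal part optimizes the firm's own workflow, for example document organization and a lawyer's assistant; the external part faces clients, the simplest example being an LLM that clients can consult. The detailed plan is: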
Document management system (internal)
Smart contracts and client service (external)
Compliance review (external)
Lawyer's assistant (external)
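This is an excerpt from my first draft; the final version kept essentially the same framework. As you can see, almost every service is built on NLP. So we decided to build a demo, and I started with the LLM that serves external clients.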
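The model itself is not complicated to build: in short, we take an existing large model and fine-tune it on legal datasets. We chose llama3-8B as the base model, first because it was the strongest model in its parameter class at the time of writing, and second because, compared with the alternatives, Llama has a larger and more complete community, with more (and more mature) fine-tuning solutions. We ended up using the LLaMA-Factory project for fine-tuning.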
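The next step was choosing datasets. My original idea was to scrape Q&A pairs from legal forums, but after looking around I found that most legal forums in the US and Canada consist of users answering from personal experience; leaving correctness aside, the wording is not rigorous enough. In the end I looked at three datasets on HuggingFace: dzunggg/legal-qa-v1, ibunescu/qa_legal_dataset_train and coastalcph/lex_glue, and used the first two. I used the first dataset in full because it consists of very professional Q&A: the questions are genuine (realistic, the kind that come up in everyday life) and the answers are careful (backed by references and precisely worded). The second dataset is larger but of lower quality; my guess is that it was scraped from some legal forum. The third dataset has the highest quality of the three and is split into several parts: case_hold cites the relevant legal provisions for each case description and states the charges rigorously, scotus consists of US Supreme Court documents, and unfair_tos breaks the terms of service of 50 software products into individual sentences and labels the sentences that contain unfair clauses. I did not use this dataset because it has no everyday-style Q&A; everything is in the form of extremely formal legal documents. Still, it is undeniably one of the top-quality legal NLP datasets. Download the data with the datasets library, convert it to the format LLaMA-Factory expects, and training can begin.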
import json
import re

from datasets import load_dataset

# Load the dataset
dataset = load_dataset("dzunggg/legal-qa-v1")


# Strip the leading "Q:" and "A:" prefixes
def remove_prefix(example):
    if example['question'].startswith('Q:'):
        example['question'] = example['question'][2:].strip()
    if example['answer'].startswith('A:'):
        example['answer'] = example['answer'][2:].strip()
    return example


# Remove line breaks and links
def clean_text(text):
    # Replace newline and carriage-return characters with spaces
    text = text.replace('\n', ' ').replace('\r', ' ')
    # Remove links
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    return text


# Process the dataset
dataset = dataset.map(remove_prefix)


# Generate one generic instruction shared by all samples
def generate_instruction():
    return "Please provide detailed answers to the following legal questions."


# Convert a sample to the instruction/input/output format
def convert_to_new_format(example):
    question = clean_text(example['question'])
    answer = clean_text(example['answer'])
    return {
        "instruction": generate_instruction(),
        "input": question,
        "output": answer
    }


# Apply the conversion
new_dataset = {}
new_dataset['train'] = [convert_to_new_format(example) for example in dataset['train']]

# Save the dataset as a JSON file (explicit UTF-8, since ensure_ascii=False keeps non-ASCII characters)
with open('legal_qa_v1_train.json', 'w', encoding='utf-8') as f:
    json.dump(new_dataset['train'], f, indent=4, ensure_ascii=False)

print("Dataset has been saved to legal_qa_v1_train.json.")
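Before handing the JSON file to LLaMA-Factory, it can be worth checking that every record really carries the three fields produced above. The following is only a minimal sketch and not part of the original notebook: the helper name `find_bad_records` is hypothetical, and the rule it applies (non-empty `input` and `output`) is an assumption about what counts as a usable sample.

```python
from typing import Dict, List

REQUIRED_KEYS = ("instruction", "input", "output")


def find_bad_records(records: List[Dict[str, str]]) -> List[int]:
    """Return indices of records that miss a key or have an empty question/answer."""
    bad = []
    for i, rec in enumerate(records):
        missing = any(key not in rec for key in REQUIRED_KEYS)
        empty = not missing and (not rec["input"].strip() or not rec["output"].strip())
        if missing or empty:
            bad.append(i)
    return bad


# Made-up records, only to exercise the check.
sample = [
    {"instruction": "Answer.", "input": "Can my landlord raise the rent?", "output": "It depends on ..."},
    {"instruction": "Answer.", "input": "", "output": "Empty question."},
    {"instruction": "Answer.", "input": "No answer key"},
]
print(find_bad_records(sample))
```

In the notebook, the same function could be called on `new_dataset['train']` right before the file is written.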
# Format required by LLaMA-Factory
import json

%cd /content/LLaMA-Factory/

NAME = "Llama-3"
AUTHOR = "LLaMA Factory"

with open("data/legal_qa_v1_train.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Fill in the {{name}} and {{author}} placeholders, if any
for sample in dataset:
    sample["output"] = sample["output"].replace("{{" + "name" + "}}", NAME).replace("{{" + "author" + "}}", AUTHOR)

with open("data/legal_qa_v1_train.json", "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=2, ensure_ascii=False)
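The training config refers to the dataset by the name `legal_qa_v1_train`, which LLaMA-Factory looks up in `data/dataset_info.json`. The post does not show that registration step, so the sketch below is an assumption about how it could be done. `register_dataset` is a hypothetical helper, and the column mapping follows the alpaca-style layout described in LLaMA-Factory's data README; the demo writes to a temporary directory so it can be run anywhere.

```python
import json
import os
import tempfile


def register_dataset(info_path: str, name: str, file_name: str) -> dict:
    """Add (or overwrite) one alpaca-style entry in a dataset_info.json file."""
    info = {}
    if os.path.exists(info_path):
        with open(info_path, "r", encoding="utf-8") as f:
            info = json.load(f)
    info[name] = {
        "file_name": file_name,
        # Map LLaMA-Factory's column roles to the keys used in our JSON records.
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
    with open(info_path, "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    return info


# Demo against a throwaway file instead of the real data/dataset_info.json.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "dataset_info.json")
    demo_info = register_dataset(demo_path, "legal_qa_v1_train", "legal_qa_v1_train.json")
    print(json.dumps(demo_info, indent=2))
```

In the notebook, `info_path` would presumably be `data/dataset_info.json` under `/content/LLaMA-Factory/`, next to the JSON file itself.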
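Thanks to LLaMA-Factory, fine-tuning, the most complex and central step, turned out to be the simplest and quickest one. Just take the example code, adjust a few parameters and run it in one go; the training time depends on the size of the dataset. With only the first dataset mentioned above, 30 epochs take less than half an hour.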
# Use the WebUI
%cd /content/LLaMA-Factory/
!GRADIO_SHARE=1 llamafactory-cli webui
import json

args = dict(
    stage="sft",  # do supervised fine-tuning
    do_train=True,
    model_name_or_path="mattshumer/Llama-3-8B-16K",  # if HuggingFace is hard to reach (e.g. from mainland China), download the model and use its local path instead
    dataset="legal_qa_v1_train",  # the legal Q&A dataset prepared above
    template="llama3",  # use llama3 prompt template
    finetuning_type="lora",  # use LoRA adapters to save memory
    lora_target="all",  # attach LoRA adapters to all linear layers
    output_dir="llama3_lora",  # the path to save LoRA adapters
    per_device_train_batch_size=8,  # the batch size; in my tests an L20 with 48 GB could go as high as 48
    gradient_accumulation_steps=6,  # the gradient accumulation steps
    lr_scheduler_type="cosine",  # use cosine learning rate scheduler
    logging_steps=10,  # log every 10 steps
    warmup_ratio=0.1,  # use warmup scheduler
    save_steps=1000,  # save checkpoint every 1000 steps
    learning_rate=1e-4,  # the learning rate
    num_train_epochs=10.0,  # the epochs of training
    max_samples=500,  # use 500 examples in each dataset
    max_grad_norm=1.0,  # clip gradient norm to 1.0
    quantization_bit=8,  # use 8-bit quantization (set to 4 for 4-bit QLoRA)
    loraplus_lr_ratio=16.0,  # use LoRA+ algorithm with lambda=16.0
    use_unsloth=True,  # use UnslothAI's LoRA optimization for 2x faster training; set to False to disable (the original listed this key twice, which is a syntax error)
    fp16=True,  # use float16 mixed precision training
    overwrite_output_dir=True,
)

json.dump(args, open("train_llama3.json", "w", encoding="utf-8"), indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli train train_llama3.json
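To get a feel for what these numbers mean together, the schedule implied by the config can be estimated with a few lines of arithmetic. This is a rough sketch under stated assumptions: a single GPU, one dataset capped at `max_samples=500`, and ceiling rounding (the trainer's exact rounding may differ by a step or so per epoch). The function name `training_schedule` is made up for illustration.

```python
import math


def training_schedule(num_samples, per_device_batch, grad_accum, epochs, num_gpus=1):
    """Approximate the optimizer-step schedule implied by the training config."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    total_steps = steps_per_epoch * int(epochs)
    return effective_batch, steps_per_epoch, total_steps


effective, per_epoch, total = training_schedule(
    num_samples=500, per_device_batch=8, grad_accum=6, epochs=10
)
print(f"effective batch size: {effective}")
print(f"optimizer steps per epoch (approx.): {per_epoch}")
print(f"total optimizer steps (approx.): {total}")
print(f"reaches save_steps=1000 before finishing: {total >= 1000}")
```

Under these assumptions the run is far shorter than `save_steps=1000`, so no intermediate checkpoint would be written and only the final adapter ends up in `output_dir`. Lowering `save_steps` would be needed if resuming from a checkpoint matters.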
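The result is a safetensors file of about 80 MB, which you can think of as a "patch" for the large model. LLaMA-Factory's infer feature gives a quick look at how the model performs. If it isn't good enough, go back to the previous step and retrain or resume training; if it looks OK, move on to the next step, merge.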
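The size of the "patch" can be sanity-checked with a back-of-the-envelope count of the LoRA parameters. The sketch below assumes the published Llama-3-8B dimensions, LoRA on all seven linear layers of each decoder block (matching `lora_target="all"`), a LoRA rank of 8 and 32-bit storage. The rank and the storage precision are not stated in the post, so treat them as assumptions.

```python
# Llama-3-8B dimensions (from the published model config).
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 1024  # 8 key/value heads x 128 dims (grouped-query attention)
NUM_LAYERS = 32

# (in_features, out_features) of each linear layer in one decoder block.
LINEAR_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(rank: int) -> int:
    """LoRA adds A (rank x in) and B (out x rank) to every targeted linear layer."""
    per_block = sum(rank * (d_in + d_out) for d_in, d_out in LINEAR_SHAPES.values())
    return per_block * NUM_LAYERS


params = lora_param_count(rank=8)  # rank 8 is an assumption, not stated in the post
size_mib = params * 4 / 2**20      # assuming the adapter is stored as 32-bit floats
print(f"{params:,} LoRA parameters, about {size_mib:.0f} MiB")
```

Under these assumptions the estimate lands at roughly the 80 MB reported above, which is tiny next to the multi-gigabyte base model and explains why the adapter is cheap to store and share.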
Inference code:
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

%cd /content/LLaMA-Factory/

args = dict(
    model_name_or_path="mattshumer/Llama-3-8B-16K",
    adapter_name_or_path="llama3_lora",  # load the saved LoRA adapters, i.e. the path of the "patch" just generated
    template="llama3",  # same to the one in training
    finetuning_type="lora",  # same to the one in training
    quantization_bit=8,  # load the 8-bit quantized model (4 or 8)
    use_unsloth=True,  # use UnslothAI's LoRA optimization for 2x faster generation
)
chat_model = ChatModel(args)

background_prompt = """
Your custom background prompt
"""

messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    # Combine the user input with the background prompt
    combined_query = background_prompt + query
    messages.append({"role": "user", "content": combined_query})
    print("\n\nAssistant: ", end="", flush=True)

    response = ""
    # The original left these values blank ("custom parameters"); 0.7 and 300 are placeholders, adjust as needed
    for new_text in chat_model.stream_chat(messages, temperature=0.7, max_new_tokens=300):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

torch_gc()
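Merge, as the name suggests, fuses the patch with the original large model to produce a new standalone model. This step again uses LLaMA-Factory's merge feature and takes about 2 minutes.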
import json

%cd /content/LLaMA-Factory

args = dict(
    model_name_or_path="nvidia/Llama3-ChatQA-1.5-8B",  # the non-quantized base model; it should be the same model the adapter was trained on
    adapter_name_or_path="/content/LLaMA-Factory/llama3_lora_11",  # path of the generated "patch"
    template="llama3",  # same to the one in training
    finetuning_type="lora",  # same to the one in training
    export_dir="llama3_lora_merged",  # the path to save the merged model (output directory)
    export_size=2,  # the file shard size (in GB) of the merged model
    export_device="cuda",  # the device used in export, can be chosen from `cpu` and `cuda`
    # export_hub_model_id="your_id/your_model",  # the Hugging Face hub ID to upload model
)

json.dump(args, open("merge_llama3.json", "w", encoding="utf-8"), indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli export merge_llama3.json
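What merging does to each adapted weight matrix can be shown on a toy example: the low-rank update is folded into the base weight as W' = W + (alpha / r) * B * A, after which the adapter is no longer needed at inference time. This is an illustrative sketch in plain Python with made-up numbers, not LLaMA-Factory's implementation; `matmul` and `merge_lora` are hypothetical helpers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A, the weight a merged model would store."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


# Toy 2x3 layer with a rank-1 adapter (all numbers are made up).
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
A = [[0.5, 0.0, 1.0]]      # rank x in_features
B = [[2.0], [4.0]]         # out_features x rank
x = [[1.0], [2.0], [3.0]]  # one input column vector
ALPHA, RANK = 2, 1

merged = merge_lora(W, A, B, alpha=ALPHA, rank=RANK)

# Output with the adapter kept separate: W x + scale * B (A x)
scale = ALPHA / RANK
separate = [[base[0] + scale * update[0]]
            for base, update in zip(matmul(W, x), matmul(B, matmul(A, x)))]

# The merged weight reproduces the adapter's output exactly on this example.
assert matmul(merged, x) == separate
print(merged)
```

The merged model therefore behaves like base model plus adapter, but can be loaded and served as a single set of weights.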
That completes the whole process. The code for uploading the model to HuggingFace is included below as well.
import os

from huggingface_hub import HfApi, HfFolder

api = HfApi()

# token = HfFolder.get_token()
token = 'HuggingFaceToken'  # can be commented out if you have already run `huggingface-cli login` or set the token in an environment variable

model_dir = "/content/LLaMA-Factory/llama3_lora_merged"
repo_id_base = "StevenChen16/llama3-8b-Lawyer"  # the repository to upload to

# Create the main repository
# Note: if you commented out `token` above, remove the token argument here as well
api.create_repo(repo_id=repo_id_base, token=token, private=False, exist_ok=True)

# Walk the model folder and upload each file
for model_name in os.listdir(model_dir):
    model_path = os.path.join(model_dir, model_name)
    if os.path.isfile(model_path):
        path_in_repo = f"{model_name}"
        api.upload_file(
            path_or_fileobj=model_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id_base,
            token=token
        )
The real hero here, the LLaMA-Factory repository:
https://github.com/hiyouga/LLaMA-Factory.git
The finished Space: https://huggingface.co/spaces/StevenChen16/llama3-8b-Lawyer
Model: StevenChen16/llama3-8b-Lawyer
If HuggingFace is hard to access from mainland China, you can use our team's website instead (a small plug for the site we built): https://wealthwizards.org/static/ai_lawyer/
GitHub: https://github.com/StevenChen16/lawyer-llama3-8b.git
Training code: Colab
TODO:
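Use the unfair_tos subset of coastalcph/lex_glue to train a model that identifies unfair clauses and performs compliance review.
Make the model support Chinese and fine-tune it on Chinese legal documents.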
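Update, 2024/06/06
Finished building and deploying the compliance review model.
I used LLaMA-Factory again, but this time the base model is "princeton-nlp/Llama-3-Instruct-8B-SimPO", which is claimed to outperform GPT-4, trained on the unfair_tos subset of the lex_glue dataset. I also added about 1,000 labeled clauses of my own. The dataset is on HuggingFace as StevenChen16/unfair_tos, and the trained model (an adapter) is on HuggingFace as StevenChen16/llama3-8b-compliance-review-adapter. Note that this is only an adapter, so the base model still has to be loaded when you use it. Below is some example code using Gradio: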
import re

import gradio as gr
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc


def split_into_sentences(text):
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
    sentences = sentence_endings.split(text)
    return [sentence.strip() for sentence in sentences if sentence]


def process_paragraph(paragraph, progress=gr.Progress()):
    sentences = split_into_sentences(paragraph)
    results = []
    total_sentences = len(sentences)
    for i, sentence in enumerate(sentences):
        progress((i + 1) / total_sentences)
        messages.append({"role": "user", "content": sentence})
        sentence_response = ""
        for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
            # Keep the streamed chunks intact; stripping each chunk would glue words together
            sentence_response += new_text
        # Normalize the answer into a label key such as "choice_of_law"
        category = sentence_response.strip().lower().replace(' ', '_')
        results.append((sentence, category))
        messages.append({"role": "assistant", "content": sentence_response})
    torch_gc()
    return results


args = dict(
    model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO",  # the base model
    adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter",  # load the saved LoRA adapter
    template="llama3",  # same template as in training
    finetuning_type="lora",  # same fine-tuning type as in training
    quantization_bit=8,  # load the 8-bit quantized model
    use_unsloth=True,  # use UnslothAI's LoRA optimization for faster generation
)
chat_model = ChatModel(args)
messages = []

# Map each label to a highlight color
label_to_color = {
    "fair": "green",
    "limitation_of_liability": "red",
    "unilateral_termination": "orange",
    "unilateral_change": "yellow",
    "content_removal": "purple",
    "contract_by_using": "blue",
    "choice_of_law": "cyan",
    "jurisdiction": "magenta",
    "arbitration": "brown",
}

with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    progress = gr.Progress()

    def on_click(paragraph):
        results = process_paragraph(paragraph, progress=progress)
        return results

    btn.click(on_click, inputs=input_text, outputs=[output])

demo.launch(share=True)
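The post does not show how the unfair_tos sentences were turned into training samples for this adapter. One plausible way, sketched here purely as an assumption, is to map each sentence to the same instruction/input/output layout used for the Q&A model, with the label name (or "fair") as the expected output. The label order below is taken from the `label_to_color` map in the demo and is assumed to match the dataset's integer label ids. The instruction text and the rule of keeping only the first label of a multi-label sentence are hypothetical choices.

```python
# Label names in the order used by the label_to_color map in the demo above;
# assumed to match the integer label ids of lex_glue's unfair_tos subset.
LABEL_NAMES = [
    "limitation_of_liability", "unilateral_termination", "unilateral_change",
    "content_removal", "contract_by_using", "choice_of_law", "jurisdiction", "arbitration",
]
INSTRUCTION = "Classify the following terms-of-service sentence."  # hypothetical prompt


def to_training_record(example):
    """Map {'text': str, 'labels': [int, ...]} to an instruction/input/output record."""
    labels = example["labels"]
    # Sentences without any label are treated as fair; for multi-label
    # sentences this sketch keeps only the first label.
    category = LABEL_NAMES[labels[0]] if labels else "fair"
    return {"instruction": INSTRUCTION, "input": example["text"], "output": category}


# Made-up sentences in the shape of unfair_tos rows.
examples = [
    {"text": "We may terminate your account at any time without notice.", "labels": [1]},
    {"text": "You can contact support by email.", "labels": []},
]
records = [to_training_record(example) for example in examples]
for record in records:
    print(record["output"], "<-", record["input"])
```

With outputs in this form, the label string produced by the model can be looked up directly in a color map like the one used by the Gradio demo.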