
LoRA Fine-Tuning of the RWKV Model: Speeding Up Training with accelerate and DeepSpeed


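Contents

I. Introduction to the RWKV Model

II. How LoRA Works

III. LoRA Fine-Tuning of RWKV

1. Data Preparation

2. Environment Setup

a. Writing the Dockerfile

b. Building the Image

c. Starting the Container

3. Modifying the Training Code

IV. Model Inference

1. Model Inference

2. Merging the LoRA Weights

3. Inference Web Service

V. Summary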


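        The ChatGLM model used in our business was too expensive to serve, and we wanted to lower inference cost, so we evaluated the RWKV 1.5B model and validated it on our business domain. To get results quickly, we trained with LoRA + accelerate + DeepSpeed. The fine-tuning gave me a deeper understanding of RWKV and made me more familiar with building Docker-based training environments. This post shares some of that practice, mainly how to get the training pipeline working end to end and some conclusions about RWKV in our business domain.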

I. Introduction to the RWKV Model

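                RWKV is an excellent model developed in China. It replaces the attention-based Transformer architecture that dominates today with an RNN architecture: during inference it carries a fixed-size state from token to token instead of attending over the whole context, so time and memory costs drop substantially, yet it still achieves very good results and appears on various leaderboards.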

      The RWKV language-modeling architecture drops the attention mechanism entirely and uses time-mix and channel-mix blocks instead.
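To illustrate why no attention matrix is needed, here is a minimal pure-Python sketch of the WKV computation inside the time-mix block for a single channel. It assumes the RWKV-4 formulation (v5 uses a multi-head, matrix-valued state instead), and it omits the running-maximum trick that real CUDA kernels use for numerical stability. The two functions compute the same quantity: once as an RNN with a two-number state, once as an explicit weighted average over all earlier tokens.

```python
import math


def wkv_recurrent(k, v, w, u):
    """RWKV-4-style WKV for one channel, evaluated as an RNN.

    k, v: per-timestep key and value scalars
    w:    positive decay rate (older tokens are down-weighted by exp(-w) per step)
    u:    "bonus" added to the current token's key
    """
    a, b = 0.0, 0.0  # state: decayed weighted sum of values, and of weights
    out = []
    for kt, vt in zip(k, v):
        cur = math.exp(u + kt)
        out.append((a + cur * vt) / (b + cur))
        # decay the old state, then fold in the current token
        a = math.exp(-w) * a + math.exp(kt) * vt
        b = math.exp(-w) * b + math.exp(kt)
    return out


def wkv_direct(k, v, w, u):
    """Same quantity as an explicit weighted average over all earlier tokens (O(T^2))."""
    out = []
    for t in range(len(k)):
        num = math.exp(u + k[t]) * v[t]
        den = math.exp(u + k[t])
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + k[i])
            num += weight * v[i]
            den += weight
        out.append(num / den)
    return out


k = [0.1, -0.3, 0.5, 0.2]
v = [1.0, 2.0, -1.0, 0.5]
print(wkv_recurrent(k, v, w=0.5, u=0.3))
print(wkv_direct(k, v, w=0.5, u=0.3))  # same values, computed without a recurrent state
```

The recurrent form is what makes RWKV cheap at inference time: each new token costs a constant amount of work and memory, no matter how long the context already is.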

II. How LoRA Works

      The paper LoRA: Low-Rank Adaptation of Large Language Models proposes a method designed specifically to reduce GPU memory when fine-tuning large models.

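For a weight matrix W, LoRA does not update W directly during fine-tuning. Instead, W is frozen and the weight update is represented as the product of two small matrices, ΔW = BA, where B is d_out × r, A is r × d_in, and the rank r is much smaller than either dimension; only B and A are trained. This sharply reduces the number of trainable parameters (and therefore gradient and optimizer-state memory), while the fine-tuned model performs comparably to full-parameter fine-tuning. In the implementation, the term BAx is multiplied by a scaling factor, usually lora_alpha/lora_rank, so the layer computes h = Wx + (lora_alpha/lora_rank)·BAx. Note that a larger lora_rank means more trainable parameters and more GPU memory.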
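A minimal pure-Python sketch of the LoRA forward pass described above (not the peft implementation; the toy sizes are arbitrary). B starts at zero, so before any training the adapted layer returns exactly Wx. The last line shows the parameter saving for a 2048 × 2048 layer, the size of the attention projections at the n_embd=2048 used later in the training script.

```python
import random


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, lora_alpha, r):
    """h = W x + (lora_alpha / r) * B (A x). W stays frozen; only A and B are trained."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))  # goes through the rank-r bottleneck; B @ A is never formed
    scale = lora_alpha / r
    return [h + scale * d for h, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters of one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


random.seed(0)
d_in, d_out, r, alpha = 3, 4, 2, 4
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]  # small random init
B = [[0.0] * r for _ in range(d_out)]  # zero init: training starts exactly from W
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True while B is still zero
print(lora_param_count(2048, 2048, 16), 2048 * 2048)  # 65536 vs 4194304 (about 1.6%)
```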

In practice, the peft library is used to attach LoRA adapters to the model's linear layers. Usage:

```python
import torch
import loralib as lora
from peft import get_peft_model, LoraConfig, TaskType

# initialize the model
......
peft_config = LoraConfig(
    peft_type="LORA",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    target_modules=args.target_modules.split(","),
)
model = get_peft_model(model, peft_config)
......
# train the model, then save only the LoRA weights
model_state_dict = lora.lora_state_dict(model)
torch.save(model_state_dict, path)  # torch.save(obj, f): the object comes first, then the path
```
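Why the saved file is small: loralib's lora.lora_state_dict, with its default bias='none', keeps only the state-dict entries whose key contains "lora_". peft names its adapter weights lora_A / lora_B, so the same filter also works on a peft-wrapped model. A sketch of that filtering, using placeholder strings instead of tensors; the key names are illustrative only, since the exact prefixes depend on the model and the peft version.

```python
def lora_only(state_dict):
    """Keep only LoRA adapter entries, i.e. keys containing 'lora_'."""
    return {k: v for k, v in state_dict.items() if "lora_" in k}


# Illustrative key names (hypothetical); values stand in for tensors
state_dict = {
    "base_model.model.blocks.0.att.key.weight": "frozen 2048x2048",
    "base_model.model.blocks.0.att.key.lora_A.default.weight": "16x2048",
    "base_model.model.blocks.0.att.key.lora_B.default.weight": "2048x16",
    "base_model.model.blocks.0.ln1.weight": "frozen",
}
print(sorted(lora_only(state_dict)))
```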

III. LoRA Fine-Tuning of RWKV

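        The main work in fine-tuning RWKV is preparing the data (converting it into a format the model can train on), setting up the training environment, modifying the training code, and finally evaluating the model; how to tune the fine-tuning for the best results is outside the scope of this post. RWKV supports two data formats: one concatenates question + answer, the other concatenates instruction + input + response. For the 1.5B size, RWKV has released both v4 and v5 weights, so we ran four experiments (2 weight versions × 2 prompt formats), building the training and test sets from the same business data.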

1. Data Preparation

QA format: suited to question answering

{"text": "Question: <question>\n\nAnswer: <answer>"}

IIR (Instruction/Input/Response) format: suited to reading-comprehension QA

{"text": "Instruction:<question requiring domain knowledge>\n\nInput:<domain background material>\n\nResponse:<expert answer based on the above>"}

Here Instruction is the instruction, Input is the data to operate on (Input may be empty), and Response is the answer.
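A small sketch for producing training files in the two formats above as JSONL (one JSON object per line). The helper names and sample texts are hypothetical; keeping the Input section when it is empty is an assumption, and the DataReader in the training script below splits samples with its own rule, so adapt both sides to your data.

```python
import json


def qa_record(question, answer):
    """QA format: Question / Answer separated by a blank line."""
    return {"text": f"Question: {question}\n\nAnswer: {answer}"}


def iir_record(instruction, input_text, response):
    """Instruction / Input / Response format; input_text may be an empty string."""
    return {"text": f"Instruction:{instruction}\n\nInput:{input_text}\n\nResponse:{response}"}


records = [
    qa_record("What is LoRA?", "A low-rank adapter method for fine-tuning."),
    iir_record("Summarize the passage.", "RWKV is an RNN-style language model.", "RWKV is an RNN LM."),
]
# One JSON object per line; ensure_ascii=False keeps Chinese text readable in the file
jsonl = "\n".join(json.dumps(r, ensure_ascii=False) for r in records)
print(jsonl)
```

To write a file instead of printing, open it with open(path, "w", encoding="utf-8") and write jsonl plus a trailing newline.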

2. Environment Setup

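        The environment specified by the official repository can be installed directly, but a few things need care. The GPU driver must be new enough for the CUDA version you install, and that CUDA version must support the GPU's compute capability (an RTX 4090 has compute capability 8.9 and needs CUDA 11.8 or later; in most cases a GPU does support a range of somewhat older CUDA and torch versions). The torch build must match the CUDA version: for example, on an RTX 4090 whose driver reports CUDA 12.0 and with CUDA 11.8 installed, install the cu118 build of torch 2.0 (torch2.0_cu118). RWKV has its own CUDA kernels, which Python compiles with C++ and nvcc into a torch extension, so the versions must match strictly; otherwise you get errors such as "GPU compute capability too high", torch version mismatch, or CUDA/torch version mismatch. The C++ build also needs a complete set of system shared libraries (.so files). The machine I used is shared by several people, so upgrading its system libraries was not a good option (a mistake could break the Linux system). To be safe, I built the training environment with Docker and trained inside the container: install Docker on the host, write a Dockerfile, build the image, start a container, and train.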
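A quick sanity check you can run on the host or inside the container before compiling RWKV's kernel. This is a sketch, assuming a CUDA build of PyTorch; it reports whether nvcc is on PATH, which CUDA version torch was built against, and whether the torch build ships native kernels for each GPU's compute capability (an RTX 4090 reports (8, 9), i.e. sm_89). It deliberately ignores PTX forward compatibility, so a "NOT in this build" result is a warning rather than proof of failure.

```python
import importlib.util
import shutil


def capability_to_arch(major, minor):
    """(8, 9) -> 'sm_89', the naming used by torch.cuda.get_arch_list()."""
    return f"sm_{major}{minor}"


def arch_in_build(capability, arch_list):
    """True if the build lists native kernels (SASS) for this GPU architecture.
    Simplification: compute_XX (PTX) entries are not considered."""
    return capability_to_arch(*capability) in arch_list


if __name__ == "__main__":
    print("nvcc:", shutil.which("nvcc") or "not found (needed to compile RWKV's CUDA kernel)")
    if importlib.util.find_spec("torch") is None:
        print("torch is not installed")
    else:
        import torch

        print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                cap = torch.cuda.get_device_capability(i)
                ok = arch_in_build(cap, torch.cuda.get_arch_list())
                status = "supported" if ok else "NOT in this build"
                print(torch.cuda.get_device_name(i), capability_to_arch(*cap), status)
        else:
            print("torch cannot see a GPU (check the driver and the CUDA build of torch)")
```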

a. Writing the Dockerfile
```dockerfile
## Build the image
# docker build -t images_name(images_name:tag) -f ./Dockerfile .
## Run the container: --gpus all makes every host GPU available to the container;
## --ipc host shares the host's IPC namespace (shared memory /dev/shm, semaphores, message queues),
## so processes can exchange data through shared memory without Docker's small default /dev/shm
## --network host: the container uses the host's IP and ports
# docker run -d -it --name my_container --gpus all --network host --ipc host images_name(id)
# runtime image: the shared CUDA libraries, i.e. the minimal set needed to run (dynamic libraries etc.), but no nvcc
#FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
# devel image: runtime plus the compiler toolchain, debugging tools, headers and static libraries for building CUDA code from source; includes nvcc
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
WORKDIR /rwkv
# Set up time zone.
ENV TZ=Asia/Shanghai
# Avoid interactive prompts (e.g. from tzdata) during apt installs
ENV DEBIAN_FRONTEND=noninteractive
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
ENV STAGE_DIR=/tmp
RUN mkdir -p ${STAGE_DIR}
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        software-properties-common build-essential autotools-dev \
        nfs-common pdsh \
        cmake g++ gcc \
        curl wget vim tmux emacs less unzip \
        htop iftop iotop ca-certificates openssh-client openssh-server \
        rsync iputils-ping net-tools
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        libsndfile-dev \
        libcupti-dev \
        libjpeg-dev \
        libpng-dev \
        screen \
        libaio-dev
# Build and install Python from source
RUN apt-get install -y unzip wget build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libsqlite3-dev libreadline-dev libffi-dev curl libbz2-dev pkg-config make
RUN apt-get install -y liblzma-dev
#RUN wget https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tar.xz
COPY Python-3.10.10.tar.xz ./
RUN tar xf Python-3.10.10.tar.xz
RUN cd Python-3.10.10 && ./configure --enable-optimizations && make altinstall && cd .. && rm -fr *
RUN python3.10 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118
WORKDIR /rwkv
COPY requirements.txt ./
#RUN python3.10 -m pip install -r requirements.txt
#RUN python3.10 -m pip install --upgrade pip && python3.10 -m pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
RUN python3.10 -m pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
# Copy all remaining project files (code, data, etc.)
COPY . ./
```

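        Note that the Python source tarball can be downloaded in advance and uploaded to the server before building the image. The CUDA base image must be the devel variant: the runtime variant is slimmed down and does not include nvcc or other build tools, and some third-party Python packages (as well as RWKV's own CUDA kernel) need nvcc to compile. Everything else is an ordinary Dockerfile.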

b. Building the Image
docker build -t  images_name(images_name:tag) -f ./Dockerfile .

This takes quite a while: pulling the base image, installing the libraries, and copying the data and code.

c. Starting the Container
docker run -d -it --name my_container --gpus all --network host --ipc host  images_name(id)

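        Things to note: --gpus all exposes all of the host's GPUs to the container; --network host makes the container use the host's IP and ports, so services inside the container can be reached at the host's IP; --ipc host shares the host's IPC namespace (shared memory under /dev/shm, semaphores, message queues) with the container. The last option is essential for distributed training: the worker processes exchange gradients and other data through shared memory (NCCL and PyTorch DataLoader workers both rely on it), and Docker's default 64 MB /dev/shm is too small for that.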
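To confirm the IPC setting from inside a running container, a minimal sketch that reports the size of the /dev/shm mount. The only assumption is the threshold: Docker's default shared-memory size is 64 MB when neither --ipc host nor --shm-size is given.

```python
import os
import shutil


def shm_size_mb(path="/dev/shm"):
    """Size of the shared-memory mount in MB, or None if the path does not exist."""
    if not os.path.exists(path):
        return None
    return shutil.disk_usage(path).total // (1024 * 1024)


if __name__ == "__main__":
    size = shm_size_mb()
    if size is None:
        print("/dev/shm not found")
    elif size <= 64:
        print(f"/dev/shm is only {size} MB; start the container with --ipc host (or a larger --shm-size)")
    else:
        print(f"/dev/shm: {size} MB")
```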

3. Modifying the Training Code

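        The original training code supports neither LoRA nor accelerate, so we modified it to support both. Because training is distributed, DeepSpeed can be used, and accelerate supports DeepSpeed as a plugin. (This differs slightly from using DeepSpeed directly: running DeepSpeed on its own is much stricter about the system's shared libraries, and I never got it to work that way.) The main structure of the code is: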

```python
from accelerate import Accelerator, DeepSpeedPlugin
from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora
from tqdm import tqdm

# Initialize the distributed environment
accumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step,
                          deepspeed_plugin=deepspeed_plugin)
device = accelerator.device
......
......
model = RWKV(args)
# LoRA settings: which of the model's modules get LoRA adapters, plus the other LoRA hyperparameters
peft_config = LoraConfig(
    peft_type="LORA",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    target_modules=args.target_modules.split(","),
)
model = get_peft_model(model, peft_config)
......
# Wrap the model, optimizer, data loader and LR scheduler with accelerate
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler)
......
for epoch in range(int(args.epoch_count)):
    for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):
        model(batch)
        ......
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```

Besides initializing the distributed environment, the LoRA hyperparameters have to be set. For RWKV we used:

```
lora_rank=16
lora_alpha=32
lora_dropout=0.1
target_modules=emb,key,value,receptance,output,head
```
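To see which layers these target_modules hit and how many trainable parameters rank 16 adds, here is a sketch under two assumptions: (1) peft matches a target name when it equals the full module name or its trailing dotted component, and (2) the layer names and shapes follow RWKV-v4neo's src/model.py for the 1.5B configuration in the script defaults (n_layer=24, n_embd=2048, dim_ffn=4*n_embd, vocab_size=65536). Note that key, value and receptance match both the time-mix (att) and channel-mix (ffn) layers.

```python
def matches(module_name, target_modules):
    """Assumed matching rule: exact name, or the name ends with '.<target>'."""
    return any(module_name == t or module_name.endswith("." + t) for t in target_modules)


def rwkv4_layers(n_layer=24, n_embd=2048, dim_ffn=8192, vocab_size=65536):
    """(module name, in_features, out_features) for emb, head and every block's
    att.* (time mix) and ffn.* (channel mix) projections; naming is assumed."""
    layers = [("emb", vocab_size, n_embd), ("head", n_embd, vocab_size)]
    for i in range(n_layer):
        p = f"blocks.{i}."
        layers += [
            (p + "att.key", n_embd, n_embd),
            (p + "att.value", n_embd, n_embd),
            (p + "att.receptance", n_embd, n_embd),
            (p + "att.output", n_embd, n_embd),
            (p + "ffn.key", n_embd, dim_ffn),
            (p + "ffn.receptance", n_embd, n_embd),
            (p + "ffn.value", dim_ffn, n_embd),
        ]
    return layers


def lora_trainable_params(target_modules, r=16, **shape_kwargs):
    total = 0
    for name, d_in, d_out in rwkv4_layers(**shape_kwargs):
        if matches(name, target_modules):
            total += r * (d_in + d_out)  # A: r x d_in, B: d_out x r
    return total


targets = "emb,key,value,receptance,output,head".split(",")
total = lora_trainable_params(targets, r=16)
print(f"{total:,} trainable LoRA parameters")  # 17,891,328 with these assumed shapes
```

Under these assumptions the adapters add roughly 17.9M trainable parameters, a little over 1% of the 1.5B model, which is why optimizer state and gradients stay small.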

The complete training code follows (fill in the remaining pieces yourself; it is adapted from RWKV-v4neo in RWKV-LM):

```python
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, warnings, math, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import logging
from transformers import get_linear_schedule_with_warmup
from argparse import ArgumentParser

logging.basicConfig(level=logging.INFO)
sys.path.append(os.getcwd())


# Replace torch.jit.script / script_method with no-ops so the RWKV model code runs without TorchScript
def script_method(fn, _rcb=None):
    return fn


def script(obj, optimize=True, _frames_up=0, _rcb=None):
    return obj


import torch.jit

script_method1 = torch.jit.script_method
script1 = torch.jit.script
torch.jit.script_method = script_method
torch.jit.script = script

from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import gc
import psutil
import traceback
from tqdm import tqdm
from accelerate import Accelerator, DeepSpeedPlugin
from torch.utils.data import Dataset, IterableDataset
import random
import json
from collections import defaultdict
import threading
from tokenizer import build_tokenizer
from datetime import datetime
from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora

# Distributed environment: DeepSpeed ZeRO stage 2 through the accelerate plugin, bf16, 4-step gradient accumulation
accumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step,
                          deepspeed_plugin=deepspeed_plugin)
device = accelerator.device


def b2mb(x):
    return int(x / 2 ** 20)


class TorchTracemalloc:
    """Track GPU and CPU memory usage (delta and peak, in MB) inside a `with` block."""

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()
        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


def collate_fn(batch):
    # pad inputs with 0 and labels with -100 (ignored by CrossEntropyLoss)
    tokens, labels, domains = zip(*batch)
    input_ids = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    domains = torch.stack(domains)
    return {"input_ids": input_ids, "labels": labels, "domains": domains}


idx2domain = {}
domain2idx = {}


# Load all data into memory; samples are drawn per batch
class DataReader(Dataset):
    def __init__(self, tokenizer, file_list, sample_ratios, domain_names, max_token, args):
        self.args = args
        self.tokenizer = tokenizer
        file_list = file_list.split(",")
        sample_ratios = list(map(float, sample_ratios.split(",")))
        domain_names = domain_names.split(",")
        assert len(file_list) == len(sample_ratios) and len(file_list) == len(domain_names)
        self.file_list = file_list
        self.domain_names = domain_names
        self.max_token = max_token
        self.sample_ratios = sample_ratios
        self.sum_ratio = sum(sample_ratios)
        print("self.sum_ratio: ", self.sum_ratio)
        assert self.sum_ratio <= 1.0
        self.cum_ratios = [sum(sample_ratios[:i + 1]) for i in range(len(sample_ratios))]
        print("file_list: {}, sample_ratios: {} cum_ratios:{}".format(file_list, sample_ratios, self.cum_ratios))
        self.domain2num = defaultdict(int)
        self.common_datas = {}
        for i in range(len(file_list)):
            domain2idx[domain_names[i]] = i
            idx2domain[i] = domain_names[i]
            self.common_datas[domain_names[i]] = self.loaddata_convert_token_to_ids(domain_names[i], file_list[i])
            print(file_list[i], len(self.common_datas[domain_names[i]]))
        print("domain2num:{}".format(self.domain2num))
        self.train_data = []
        self.index = 0
        self.epoch = 0
        self.train_length = 4000  # samples per "epoch" (what __len__ reports)
        self.train_step = 1000    # __getitem__ calls before the training pool is rebuilt

    def loaddata_convert_token_to_ids(self, domain_name, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        domain_idx = domain2idx[domain_name]
        all_datas = []
        for line in tqdm(lines[0:], desc=f"read{file_path}", ncols=100):
            text = json.loads(line)["text"]
            text = text.split('\n\n')
            q = '\n\n'.join(text[0:3]) + "Answer:"
            a = '\n\n'.join(text[3:])
            a = a.replace('Answer:', "")
            q_ids = self.tokenizer.tokenize(q)
            a_ids = self.tokenizer.tokenize(a)
            ids = q_ids + a_ids
            ids.append(self.tokenizer.eod)
            if len(ids) > 2:
                if len(ids) > self.max_token:
                    # drop samples longer than the maximum length
                    continue
                else:
                    # prompt tokens are masked with -100; only the answer and EOD are learned
                    labels = [-100] * len(q_ids) + a_ids + [self.tokenizer.eod]
                    assert len(ids) == len(labels), " len(ids) != len(labels)"
                    input_ids = torch.as_tensor(ids[:-1], dtype=torch.long)
                    labels = torch.as_tensor(labels[1:], dtype=torch.long)
                    domain_idx = torch.as_tensor(domain_idx, dtype=torch.long)
                    all_datas.append((input_ids, labels, domain_idx))
        print(f"{file_path}--{len(all_datas)}")
        self.domain2num[domain_name] += 1
        return all_datas

    def _build_train_data(self):
        # Domains in this list are used in full; every other domain contributes a rotating 1/20 slice per epoch
        self.train_data = []
        for k, v in self.common_datas.items():
            if k in ['friso', 'kongtiao', 'qa', 'other']:
                self.train_data.extend(v)
            else:
                split_count = len(v) // 20
                epoch = self.epoch % 20
                temp = v[epoch * split_count:(epoch + 1) * split_count]
                # temp = random.choices(v, k=split_count)
                self.train_data.extend(temp)
        print(f"len(self.train_data) {len(self.train_data)} epoch {self.epoch}")

    def __getitem__(self, item):
        if len(self.train_data) == 0:
            time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
            print("=============={}==============".format(time_str))
            self._build_train_data()
        if self.index >= self.train_step:
            # after train_step samples, move on to the next data slice
            self.epoch += 1
            self.index = 0
            self._build_train_data()
        self.index += 1
        if item >= len(self.train_data):
            item = random.randint(0, len(self.train_data) - 1)
        input_ids, labels, domain_idx = self.train_data[item]
        return input_ids, labels, domain_idx

    def __len__(self):
        # return 910000
        return self.train_length


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--file_list", default="", type=str)
    parser.add_argument("--sample_ratios", default="utf-8", type=str)
    parser.add_argument("--domain_names", default="", type=str)
    parser.add_argument("--use_owndatareader", default="1", type=str)
    parser.add_argument("--logdir", default="", type=str)
    parser.add_argument("--datadir", default="", type=str)
    parser.add_argument("--save_step", default=50000, type=int)
    # lora
    parser.add_argument("--lora_rank", default=16, type=int)
    parser.add_argument("--lora_alpha", default=32, type=int)
    parser.add_argument("--lora_dropout", default=0.1, type=float)
    parser.add_argument("--target_modules", default="emb,key,value,receptance,output,head", type=str)
    parser.add_argument("--load_model", default="/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)
    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=65536, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
    parser.add_argument("--ctx_len", default=2560, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"
    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=24, type=int)
    parser.add_argument("--n_embd", default=2048, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=-1, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--dropout", default=0, type=float)  # try 0.01 / 0.02 / 0.05 / 0.1
    parser.add_argument("--weight_decay", default=0, type=float)  # try 0.1 / 0.01 / 0.001
    parser.add_argument("--weight_decay_final", default=-1, type=float)
    parser.add_argument("--my_pile_version", default=1, type=int)  # my special pile version
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--head_size_a", default=64, type=int)  # can try larger values for larger models
    parser.add_argument("--head_size_divisor", default=8, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_random_steps", default=0, type=int)
    parser.add_argument("--my_testing", default='', type=str)
    parser.add_argument("--my_exit", default=99999999, type=int)
    parser.add_argument("--my_exit_tokens", default=0, type=int)
    args = parser.parse_args()
    summary_writer = SummaryWriter(args.logdir)
    print(args)
    ########################################################################################################
    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"
    args.my_timestamp = datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        if 'r3' in args.my_testing:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        else:
            args.dim_ffn = args.n_embd * 4
    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if accelerator.is_main_process and not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)
    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime
        if args.my_pile_version == 1:
            if args.ctx_len == 1024:
                args.magic_prime = 324331313
            elif args.ctx_len == 2048:
                args.magic_prime = 162165671
            elif args.ctx_len == 4096:
                args.magic_prime = 81082817
            elif args.ctx_len == 8192:
                args.magic_prime = 40541399
        else:
            if args.ctx_len == 1024:
                args.magic_prime = 1670239709
            elif args.ctx_len == 2048:
                args.magic_prime = 835119767
            elif args.ctx_len == 4096:
                args.magic_prime = 417559889
            elif args.ctx_len == 6144:
                args.magic_prime = 278373239
            elif args.ctx_len == 8192:
                args.magic_prime = 208779911
        if args.my_pile_shift < 0:
            args.my_pile_shift = 0
        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak
        if args.my_qa_mask == 2:
            args.epoch_count = 2 * args.magic_prime // 40320
        else:
            args.epoch_count = args.magic_prime // 40320
        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        # if args.my_pile_stage == 2:
        #     assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p != "final":
                        if p == "init":
                            p = -1
                        else:
                            p = int(p)
                        list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            if args.warmup_steps < 0:
                if args.my_pile_stage == 2:
                    args.warmup_steps = 10
                else:
                    args.warmup_steps = 30
            args.epoch_begin = max_p + 1
    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
    args.precision = "bf16"
    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    # os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_JIT_ON"] = "0"
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
    args.precision = "bf16"
    if args.data_type == 'wds_img':
        from src.model_img import RWKV_IMG
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV
        model = RWKV(args)
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
        load_keys = list(load_dict.keys())
        for k in load_keys:
            if k.startswith('_forward_module.'):
                load_dict[k.replace('_forward_module.', '')] = load_dict[k]
                del load_dict[k]
    except Exception:
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            load_dict = torch.load(args.load_model, map_location="cpu")
    model.load_state_dict(load_dict)
    peft_config = LoraConfig(
        peft_type="LORA",
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules.split(","),
    )
    model = get_peft_model(model, peft_config)
    # after get_peft_model only the LoRA parameters require gradients
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr_init)
    tokenizer_type = "RWKVTokenizer"
    vocab_file = "./json2binidx/rwkv_vocab_v20230424.txt"
    tokenizer = build_tokenizer(tokenizer_type, vocab_file)
    train_data = DataReader(tokenizer, args.file_list, args.sample_ratios, args.domain_names, args.ctx_len, args)
    # train_data = DataReader( tokenizer, args.ctx_len, args.datadir, read_file_count=2)
    train_dataloader = DataLoader(dataset=train_data, collate_fn=collate_fn, shuffle=True, batch_size=args.micro_bsz)
    print(f"Data loaded: {len(train_dataloader)} batches")
    warm_up_ratio = 0.1
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=int(len(train_dataloader) / accumulate_step * warm_up_ratio),
        num_training_steps=(int(len(train_dataloader) / accumulate_step) * args.epoch_count),
    )
    # Passing lr_scheduler to prepare() lets the DeepSpeed engine step it once per optimizer update
    # (every accumulate_step micro-batches); the lr_scheduler.step() call below then becomes a no-op
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler)
    print(f"Data loaded: {len(train_dataloader)} batches per process")
    loss_fct = nn.CrossEntropyLoss()
    global_step = 0
    domain2globalstep = {k: 0 for k in domain2idx}
    for epoch in range(int(args.epoch_count)):
        name2loss = {k: 0 for k in domain2idx}
        domain2step = {k: 0 for k in domain2idx}
        print("name2loss", name2loss)
        total_loss = 0
        mean_loss = 0
        domain2num = {k: 0 for k in domain2idx}
        with TorchTracemalloc() as tracemalloc:
            model.to(device).train()
            i = 0
            for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):
                try:
                    i += 1
                    if accelerator.is_main_process and i % args.save_step == 0:
                        # save only the LoRA weights
                        model_state_dict = lora.lora_state_dict(accelerator.unwrap_model(model))
                        save_path = os.path.join(args.proj_dir, f"rwkv-epoch{epoch}_step{i}_lora.pt")
                        accelerator.save(model_state_dict, save_path)
                    labels = batch['labels']
                    domains = batch['domains']
                    input_ids = batch['input_ids']
                    lm_logits = model(input_ids)
                    # labels were already shifted by one position in DataReader
                    shift_logits = lm_logits.contiguous()
                    shift_labels = labels.contiguous()
                    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    if i % 50 == 0:
                        torch.cuda.empty_cache()
                    loss_detach = loss.detach().cpu().float()
                    total_loss += loss_detach
                    time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
                    des_train = f"{time_str} shape:{input_ids.shape[1]} loss: {loss_detach}"
                    # per-domain loss, for monitoring only
                    for domian_name, domian_idx in domain2idx.items():
                        select_idx = domains == domian_idx
                        select_shift_logits = shift_logits[select_idx]
                        select_shift_labels = shift_labels[select_idx]
                        loss_domain = 0
                        if len(select_shift_labels) > 0:
                            domain2num[domian_name] += len(select_shift_labels)
                            loss_domain = loss_fct(select_shift_logits.view(-1, select_shift_logits.size(-1)),
                                                   select_shift_labels.view(-1)).detach().cpu().float()
                            domain2globalstep[domian_name] += 1
                            domain2step[domian_name] += 1
                            name2loss[domian_name] += loss_domain
                            summary_writer.add_scalar(f"train_step/{domian_name}", loss_domain, domain2globalstep[domian_name])
                        des_train += f" {domian_name}: {loss_domain}"
                        # domain2loss_detach[domian_name] = loss_domain
                    t.set_description(des_train)
                    # t.set_postfix(des_train)
                    if accelerator.is_main_process:
                        summary_writer.add_scalar(f"train_step/total_loss", loss_detach, global_step)
                    global_step += 1
                except Exception as e:
                    print(str(e))
                    print(traceback.format_exc())
                    print("oom", batch['input_ids'].shape)
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
        mean_loss = total_loss / (step + 1)
        for k in name2loss:
            name2loss[k] = name2loss[k] / (domain2step[k] + 1)
            if accelerator.is_main_process:
                summary_writer.add_scalar(f"train/{k}", name2loss[k], epoch)
        s = ""
        s_num = ""
        for k, v in name2loss.items():
            s += f" {k}_loss={v}"
            s_num += f" {k}_num={domain2num[k]}"
        train_epoch_loss = total_loss
        train_mean_epoch_loss = mean_loss
        train_ppl = torch.exp(train_mean_epoch_loss)  # perplexity from the mean loss, not the summed loss
        time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
        accelerator.print(
            f"{time_str} epoch={epoch}: train_ppl={train_ppl} train_epoch_loss={train_epoch_loss} train_mean_epoch_loss={train_mean_epoch_loss}")
        accelerator.print(s)
        accelerator.print(s_num)
        accelerator.wait_for_everyone()
```
Launching accelerate together with DeepSpeed requires a config file:

```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 4   # keep equal to accumulate_step in the training script
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16   # matches mixed_precision='bf16' in the training script
num_machines: 1
num_processes: 2        # must equal the number of GPUs used
rdzv_backend: static
same_network: true
use_cpu: false          # train on the GPUs
main_process_port: 20667
```

The setting to watch is num_processes, which must equal the number of GPUs used (one process per GPU).

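The training launch script works as follows. CUDA_VISIBLE_DEVICES selects which GPUs on the machine are used, nohup keeps the job running in the background, and accelerate launch starts the accelerate launcher. --config_file points to the accelerate/DeepSpeed config file above.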

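```bash
# num_processes in the config must match the number of GPUs listed in CUDA_VISIBLE_DEVICES
# (the config above has num_processes: 2, while four GPUs are listed here).
# The config file name mentions ZeRO-3 CPU offload, whereas the config shown above is ZeRO stage 2 without offload.
CUDA_VISIBLE_DEVICES=1,2,4,5 nohup accelerate launch --config_file accelerate_ds_zero3_cpu_offload_config.yaml train_accelerator_deepspeed_lora_v1.py \
--load_model /AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth
......
......
```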
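Because num_processes has to match the GPUs exposed through CUDA_VISIBLE_DEVICES, a quick pre-launch check can catch a mismatch such as four visible GPUs combined with `num_processes: 2`. The sketch below is a hypothetical helper and is not part of accelerate. It only counts the comma-separated device IDs and does not query the hardware:

```python
import os


def visible_gpu_count(env_value: str) -> int:
    """Number of device IDs in a CUDA_VISIBLE_DEVICES value such as "1,2,4,5"."""
    return len([d for d in env_value.split(",") if d.strip()])


def check_num_processes(num_processes: int, env_value=None) -> None:
    """Raise if num_processes differs from the GPUs listed in CUDA_VISIBLE_DEVICES."""
    if env_value is None:
        env_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not env_value:
        return  # unset: all GPUs are visible, so there is no explicit list to compare against
    n_gpus = visible_gpu_count(env_value)
    if n_gpus != num_processes:
        raise ValueError(
            f"num_processes={num_processes}, but CUDA_VISIBLE_DEVICES lists {n_gpus} GPU(s)")


check_num_processes(2, "1,2")  # consistent: passes silently
try:
    check_num_processes(2, "1,2,4,5")  # the combination used in the launch script above
except ValueError as err:
    print(err)
```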

With LoRA and two RTX 4090 GPUs, one epoch trains in just a few minutes, and GPU memory usage stays modest.

IV. Model Inference

1. Model inference

Inference uses the third-party rwkv package. The core logic is as follows:

```python
import torch
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# ./rwkv.pth: base weights with the LoRA deltas already merged in
model = RWKV(model='./rwkv.pth', strategy='cuda bf16')
model.eval()
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")


@torch.no_grad()
def generate(ctx: str, max_length: int = 2560) -> str:
    out_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}  # token id -> number of times it has been generated
    state = None
    token = None
    for i in range(max_length):
        # First step feeds the whole prompt; afterwards one token at a time (the RNN state carries context)
        tokens = pipeline.encode(ctx) if i == 0 else [token]
        out, state = pipeline.model.forward(tokens, state)
        for n in occurrence:
            out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penalty
        token = pipeline.sample_logits(out, temperature=1.0, top_p=0.0)
        if token == 0:
            break  # exit when 'endoftext'
        out_tokens += [token]
        occurrence[token] = 1 + occurrence.get(token, 0)
        tmp = pipeline.decode(out_tokens[out_last:])
        # Only append once the bytes decode to complete characters (no U+FFFD) and no trailing newline
        if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):
            out_str += tmp
            out_last = i + 1
    return out_str
```
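The loop above applies two decoding tweaks. The first is a repetition penalty that grows with how often a token has already been produced. The second is `top_p=0.0`: as far as we can tell from the rwkv package's `sample_logits`, this keeps only the most probable token, so decoding is effectively greedy (apart from ties). The following plain-Python sketch reproduces one step of that logic on a toy logits list. The function names are illustrative and are not part of the rwkv package:

```python
def apply_repetition_penalty(logits, occurrence, base=0.4, per_use=0.4):
    """Return a copy of logits with base + count * per_use subtracted for each generated token.

    occurrence maps token id -> how many times it has been generated so far,
    mirroring `out[n] -= (0.4 + occurrence[n] * 0.4)` in the loop above.
    """
    penalized = list(logits)
    for token, count in occurrence.items():
        penalized[token] -= base + count * per_use
    return penalized


def greedy_pick(logits):
    """Index of the largest logit, which is what top_p=0.0 effectively reduces to."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [0.1, 2.0, 1.9, 0.5]   # toy 4-token vocabulary
occurrence = {1: 1}             # token 1 has already been generated once
penalized = apply_repetition_penalty(logits, occurrence)
print(greedy_pick(logits), greedy_pick(penalized))
```

Without the penalty, token 1 would be picked again. Subtracting 0.8 from its logit lets the runner-up win, which is how the penalty discourages loops.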

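Because the model was fine-tuned with LoRA, the LoRA weights must first be merged into the original base weights. Only then can the model be loaded and run as shown above.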

2. Merging the LoRA weights

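The LoRA weights are merged into the original weights by applying the LoRA formula directly, W' = W + scaling · BA. The scaling factor is 2 in this setup and should equal the lora_alpha / lora_rank used in training: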

```python
import torch
from tqdm import tqdm

# Substrings of RWKV-4 parameter names that were LoRA targets in this setup
TARGET_KEYS = ("emb", "key", "value", "receptance", "output", "head")


def merge_lora_weights(rwkv_path="RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth",
                       lora_path="./lora.pt",
                       out_path="./rwkv.pth",
                       scaling=2.0):  # scaling = lora_alpha / lora_rank used in training
    print("lora_path: ", lora_path)
    model_weight = torch.load(rwkv_path, map_location='cpu')
    lora_model = torch.load(lora_path, map_location='cpu')
    for k, v in tqdm(model_weight.items(), desc="model_weight", ncols=100):
        if not any(t in k for t in TARGET_KEYS) or not k.endswith(".weight"):
            continue  # not a LoRA target: keep the original tensor
        prefix = "base_model.model." + k[:-len(".weight")]
        if "emb" in k:
            # peft Embedding LoRA: A is (r, vocab), B is (dim, r), delta = (B @ A).T
            lora_a = prefix + ".lora_embedding_A.default"
            lora_b = prefix + ".lora_embedding_B.default"
        else:
            # peft Linear LoRA: A is (r, in), B is (out, r), delta = B @ A
            lora_a = prefix + ".lora_A.default.weight"
            lora_b = prefix + ".lora_B.default.weight"
        if lora_a not in lora_model or lora_b not in lora_model:
            continue  # this layer was not in target_modules
        w_a = lora_model[lora_a].float()
        w_b = lora_model[lora_b].float()
        delta = torch.mm(w_a.T, w_b.T) if "emb" in k else torch.mm(w_b, w_a)
        # Merge in fp32, then cast back to the checkpoint dtype (e.g. bf16)
        model_weight[k] = (v.float() + scaling * delta).to(v.dtype)
    torch.save(model_weight, out_path)
    print("merge_lora_weights finished!")
```
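The merge relies on the identity W'x = Wx + s·B(Ax), where s = lora_alpha / lora_rank. The check below is a small pure-Python sketch that uses hypothetical toy matrices. It confirms that the merged weight produces the same output as the base layer plus the LoRA branch. It also checks the transpose identity behind the embedding branch: peft stores `lora_embedding_A` as (r, vocab) and `lora_embedding_B` as (dim, r), so the delta is (BA)^T = A^T B^T:

```python
def matmul(X, Y):
    """Plain-Python matrix product of X (n x k) and Y (k x m)."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def merge_linear(W, A, B, scaling):
    """W' = W + scaling * (B @ A); W is (out, in), B is (out, r), A is (r, in)."""
    delta = matmul(B, A)
    return [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


# Hypothetical toy sizes: out=2, in=3, rank r=1
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
A = [[0.5, -1.0, 1.0]]   # (r, in)
B = [[2.0], [1.0]]       # (out, r)
x = [1.0, 2.0, 3.0]
scaling = 2.0            # lora_alpha / lora_rank, the factor 2 in the merge code

merged_out = matvec(merge_linear(W, A, B, scaling), x)
lora_out = [wx + scaling * bax for wx, bax in zip(matvec(W, x), matvec(B, matvec(A, x)))]
print(merged_out, lora_out)

# Embedding branch: (B @ A).T equals A.T @ B.T, which is what torch.mm(w_a.T, w_b.T) computes
print(transpose(matmul(B, A)) == matmul(transpose(A), transpose(B)))
```

Merging once ahead of time means inference pays no extra cost for the LoRA branch, which is why the rwkv package can load the merged checkpoint like any ordinary model.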

3. Inference web service

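In practice a web interface is usually required. An asynchronous web API is built with aiohttp, and the inference and LoRA-merging logic above are integrated into the service: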

```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set before torch initializes CUDA
import asyncio
import json
import logging.handlers
import socket
import time

import aiohttp
import torch
from aiohttp import web
from tqdm import tqdm

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# The rwkv package reads these flags at import time, so set them before importing it
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'  # build and use RWKV's custom CUDA kernel (requires nvcc)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# logger: console + rotating file log/server.log (10 MB x 5 backups)
log_level = logging.DEBUG
logger = logging.getLogger(__name__)
logger.setLevel(log_level)
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)s %(message)s')
stream_handler = logging.StreamHandler()
stream_handler.setLevel(log_level)
stream_handler.setFormatter(formatter)
os.makedirs('./log', exist_ok=True)
file_handler = logging.handlers.RotatingFileHandler(filename='log/server.log', maxBytes=10 << 20, backupCount=5, encoding='utf8')
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.addHandler(file_handler)

# Service names registered with the heartbeat (service-registry) server
NODE_NAME = 'general.rwkv.loratest_20231010'
NODE_NAME_2 = 'general.chat.hydiversity_20231010'
print(NODE_NAME)
print(NODE_NAME_2)
NUS = '<heartbeat_ip>:<port>'  # placeholder: address of the heartbeat server


async def heart_beat(ip, port):
    """Register both node names with the heartbeat server once per second."""
    data_dic = {
        'method': 'heartbeat',
        'params': {
            'data': [
                {'nodename': NODE_NAME, 'addrip': ip + ':' + str(port), 'type': 'transparent'},
                {'nodename': NODE_NAME_2, 'addrip': ip + ':' + str(port), 'type': 'transparent'},
            ]
        }
    }
    send_data = json.dumps(data_dic)
    async with aiohttp.ClientSession() as client:
        while True:
            try:
                # "async with" releases each response back to the connection pool
                async with client.post(f'http://{NUS}/heartbeat', data=send_data):
                    pass
            except Exception as e:
                logger.error(f'send heartbeat fail: {e}')
            await asyncio.sleep(1)


class TimeMeasure:
    """Context manager that logs how long a block took."""

    def __init__(self, desc=''):
        self.start = 0
        self.desc = desc

    def __enter__(self):
        self.start = time.time()
        logger.info(f'{self.desc} start')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        cost_s = time.time() - self.start
        if cost_s > 10:
            logger.info(f'{self.desc} end, cost : {round(cost_s, 2)}s')
        else:
            logger.info(f'{self.desc} end, cost : {round(cost_s * 1000, 2)}ms')


def build_fail_resp(id_: int, code: int, msg: str):
    return web.json_response({
        'id': id_,
        'jsonrpc': '2.0',
        'ret': code,
        'result': {'error_info': msg},
    })


def build_success_resp(id_, result):
    data = {
        'id': id_,
        'jsonrpc': '2.0',
        'ret': 0,
        'result': {'chatInfo': {'answer': result, 'elements': []}},
    }
    # Split "tag:value\n\ntag:value" output into structured elements
    for ele in result.split('\n\n'):
        parts = ele.split(':', 1)  # split on the first colon only; values may contain ':'
        if len(parts) == 2:
            data['result']['chatInfo']['elements'].append({'tag': parts[0], 'value': parts[1]})
        else:
            logger.warning(f'cannot parse element: {ele!r}')
    send_data = json.dumps(data, ensure_ascii=False)
    return web.json_response(text=send_data)


class Server:
    def __init__(self):
        self.lock = asyncio.Semaphore(20)  # at most 20 generations run concurrently
        self.model = RWKV(model='./rwkv.pth', strategy='cuda bf16')  # merged weights
        # self.model = RWKV(model='./rwkv.pth', strategy='cuda fp16')
        self.model.eval()
        self.pipeline = PIPELINE(self.model, "rwkv_vocab_v20230424")
        # Warm-up request; the Chinese prompt means "Hi, who are you?"
        out_str = self.chat("Question:你好呀,你是谁?\n\nAnswer:")
        logger.info(f'out_str——{out_str}')
        logger.info('Server __init__ finished!')

    @torch.no_grad()
    def chat(self, ctx: str):
        out_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        state = None
        token = None
        for i in range(2560):
            tokens = self.pipeline.encode(ctx) if i == 0 else [token]
            out, state = self.pipeline.model.forward(tokens, state)
            for n in occurrence:
                out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penalty
            token = self.pipeline.sample_logits(out, temperature=1.0, top_p=0.0)
            if token == 0:
                break  # exit when 'endoftext'
            out_tokens += [token]
            occurrence[token] = 1 + occurrence.get(token, 0)
            tmp = self.pipeline.decode(out_tokens[out_last:])
            if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):
                out_str += tmp
                out_last = i + 1
        return out_str

    async def inference(self, request: web.Request):
        id_ = 0
        try:
            req = await request.json()  # malformed JSON is reported as a parameter error
            id_ = req['id']
            content = req['params']['data']['content']
            if not isinstance(content, str):
                raise RuntimeError('parameter type error')
        except Exception as e:
            logger.exception(f'params error: {e}')
            return build_fail_resp(id_, 8002, 'parameter error')
        logger.info(f'id: {id_}\nreq content:\n{content}')
        prompt = f'Question:{content}\n\nAnswer:'  # QA format
        # IIR format. The Chinese instruction means: "This is a traffic-accident emergency call. You are an
        # expert in element extraction; the element to extract is 'incident address'. Give the extraction result."
        # prompt = f"Instruction:这是一通交通事故报警的通话, 你是要素抽取方面的专家,需要提取的要素名为“案发地址”\n请给出要素抽取结果\n\nInput:{content}\n\nResponse:"
        logger.info(f'id: {id_}\nreq prompt:\n{prompt}')
        with TimeMeasure(f'id: {id_} infer'):
            try:
                async with self.lock:
                    # Run the blocking generation loop in a worker thread (Python 3.9+)
                    result = await asyncio.to_thread(self.chat, prompt)
            except Exception as e:
                logger.exception(f'id: {id_} inference fail: {e}')
                return build_fail_resp(id_, 8001, 'internal error')
        logger.info(f'id: {id_}, resp: {result}')
        return build_success_resp(id_, result)


def get_local_ip(ip, port):
    """Return the local interface IP used to reach (ip, port), i.e. the address to advertise."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect((ip, port))
        return conn.getsockname()[0]


async def main(ip, port):
    server = Server()
    app = web.Application()
    app.add_routes([web.post('/nlp', server.inference)])
    # Keep a reference so the heartbeat task is not garbage-collected
    server.heartbeat_task = asyncio.create_task(heart_beat(ip, port))
    return app


# Substrings of RWKV-4 parameter names that were LoRA targets in this setup
TARGET_KEYS = ("emb", "key", "value", "receptance", "output", "head")


def merge_lora_weights(rwkv_path="/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth",
                       lora_path="./output/20231016_kongtiao_v1/rwkv-epoch5_step1000_lora.pt",
                       out_path="./rwkv.pth",
                       scaling=2.0):  # scaling = lora_alpha / lora_rank used in training
    print("lora_path: ", lora_path)
    model_weight = torch.load(rwkv_path, map_location='cpu')
    lora_model = torch.load(lora_path, map_location='cpu')
    for k, v in tqdm(model_weight.items(), desc="model_weight", ncols=100):
        if not any(t in k for t in TARGET_KEYS) or not k.endswith(".weight"):
            continue  # not a LoRA target: keep the original tensor
        prefix = "base_model.model." + k[:-len(".weight")]
        if "emb" in k:
            lora_a = prefix + ".lora_embedding_A.default"
            lora_b = prefix + ".lora_embedding_B.default"
        else:
            lora_a = prefix + ".lora_A.default.weight"
            lora_b = prefix + ".lora_B.default.weight"
        if lora_a not in lora_model or lora_b not in lora_model:
            continue  # this layer was not in target_modules
        w_a = lora_model[lora_a].float()
        w_b = lora_model[lora_b].float()
        delta = torch.mm(w_a.T, w_b.T) if "emb" in k else torch.mm(w_b, w_a)
        model_weight[k] = (v.float() + scaling * delta).to(v.dtype)
    torch.save(model_weight, out_path)
    print("merge_lora_weights finished!")


if __name__ == '__main__':
    merge_lora_weights()  # writes ./rwkv.pth, which Server() then loads
    bind_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)
    heartbeat_ip, heartbeat_port = NUS.split(':')
    local_ip = get_local_ip(heartbeat_ip, int(heartbeat_port))
    bind_socket.bind(('0.0.0.0', 0))  # port 0: the OS picks a free port
    web.run_app(main(local_ip, bind_socket.getsockname()[1]), sock=bind_socket)
```
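The service keeps the aiohttp event loop responsive by running the blocking `chat` call in a worker thread with `asyncio.to_thread` (Python 3.9+), and an `asyncio.Semaphore` caps how many generations run at once. The following self-contained sketch shows that pattern. Here `slow_chat` is a hypothetical stand-in for the model call, and the limit of 2 is chosen only for illustration:

```python
import asyncio
import threading
import time

MAX_CONCURRENT = 2   # hypothetical limit; the service above uses asyncio.Semaphore(20)
stats = {"active": 0, "peak": 0}
stats_lock = threading.Lock()


def slow_chat(prompt: str) -> str:
    """Stand-in for Server.chat: blocking work that must not run on the event loop."""
    with stats_lock:
        stats["active"] += 1
        stats["peak"] = max(stats["peak"], stats["active"])
    time.sleep(0.05)  # pretend to generate tokens
    with stats_lock:
        stats["active"] -= 1
    return f"echo:{prompt}"


async def handle(prompt: str, limit: asyncio.Semaphore) -> str:
    async with limit:  # wait here while MAX_CONCURRENT generations are already running
        return await asyncio.to_thread(slow_chat, prompt)


async def run_requests(n: int):
    limit = asyncio.Semaphore(MAX_CONCURRENT)  # created inside the running event loop
    return await asyncio.gather(*(handle(f"q{i}", limit) for i in range(n)))


results = asyncio.run(run_requests(5))
print(results, "peak concurrency:", stats["peak"])
```

`asyncio.gather` returns the results in request order even though the threads finish in arbitrary order, and the semaphore bounds how much GPU work is in flight.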

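Startup log of the web service. The Chinese warm-up reply means "I am an AI robot based on the GPT-3.5 API". The model then continues with another "Question: Hi, who are you?" / "Answer: ..." turn: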

```
2023-11-02 06:21:12,812 [INFO] rwkv_chat_lora_iir.py:147 out_str——我是一个基于GPT-3.5接口的AI机器人。
Question: 你好呀,你是谁?
Answer: 我是一个基于GPT-3.5接口的AI机器人
2023-11-02 06:21:12,838 [INFO] rwkv_chat_lora_iir.py:148 Server __init__ finished!
======== Running on http://0.0.0.0:45149 ========
(Press CTRL+C to quit)
```

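Requests can be sent through the heartbeat (service-registry) address, or directly to the physical machine at `<host IP>:45149/nlp`. Because the socket binds to port 0, the OS picks the port at every startup, so 45149 is simply the port from this particular run.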
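A minimal client sketch for the `/nlp` endpoint, using only the standard library. The request fields (`id`, `params.data.content`) and the response fields (`ret`, `result.chatInfo.answer`, `result.error_info`) come from the handler above. The address is a placeholder, and the helper names `build_nlp_request` and `send` are hypothetical. The request is built without being sent, so it can be inspected offline:

```python
import json
import urllib.request


def build_nlp_request(content: str, req_id: int, url: str) -> urllib.request.Request:
    """Build a POST for the /nlp handler, which reads req['id'] and req['params']['data']['content']."""
    body = {"id": req_id, "params": {"data": {"content": content}}}
    return urllib.request.Request(
        url,
        data=json.dumps(body, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json"},  # declare the JSON body
        method="POST",
    )


def send(req: urllib.request.Request, timeout: float = 60.0) -> str:
    """Send the request and return result.chatInfo.answer, raising if ret != 0."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        payload = json.loads(resp.read().decode("utf-8"))
    if payload.get("ret") != 0:
        raise RuntimeError(payload.get("result", {}).get("error_info", "unknown error"))
    return payload["result"]["chatInfo"]["answer"]


# Placeholder address: use the host's IP and the port printed at startup
req = build_nlp_request("你好呀,你是谁?", 1, "http://127.0.0.1:45149/nlp")  # "Hi, who are you?"
print(req.get_method(), req.full_url)
# answer = send(req)  # requires the service to be running
```

Errors come back as HTTP 200 with a non-zero `ret` (8001 or 8002), so the client checks `ret` rather than relying on the HTTP status.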

V. Summary

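Results:

1. rwkv_v4 (today's run): 55% in-set (49 epochs), 15% out-of-set (1,191 samples)
2. rwkv_v5 (yesterday's run): best in-set 34% (9 epochs), 24% out-of-set (1,191 samples, 4 epochs)

Conclusions:
a. rwkv_v5 generalizes to out-of-set data much better than rwkv_v4 (24% vs. 15%).
b. Both fall far short of ChatGLM-6B distilled into a 1.5B ChatGLM (92% out-of-set), although that uses a completely different training approach at a much higher cost.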

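        rwkv 1.5B performed poorly in our business domain. It generalized weakly and its generation was unstable, which is related both to the difficulty of our task and to the size of the training data. Its inference, however, is genuinely fast, in our experience faster than any other model of the same parameter scale, so it would be very attractive if its accuracy could be raised. LoRA is highly efficient for quickly validating a model's baseline performance. For single-machine multi-GPU training, accelerate combined with DeepSpeed is an excellent tool and also saves GPU memory. Finally, on a machine shared by several people, do not casually upgrade system libraries; build a Docker environment for the task instead.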


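Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site bears no legal responsibility for it. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/Monodyee/article/detail/331936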