当前位置:   article > 正文

AGI|教你用一部电影的时间训练个人专属Agent_ai agent 如何训练

ai agent 如何训练

目录

一、Agent如何工作?

二、Function Call 原理

三、开源模型工具调用微调

//Chat模型微调

//训练过程日志

//测试结果

//测试Tools

四、预训练模型微调

五、总结


Agent是一个超越简单文本生成的人工智能系统。它使用大型语言模型(LLM)作为其中央计算引擎,使其能够进行对话、执行任务、推理并显示一定程度的自主权。

一、Agent如何工作?

1、当用户给出一个任务task之后可以从memory中查询记录(可选),查询出的结果(如果有)给AgentLLM进行判断是否可复用,这里指的复用是针对时效性没那么高的任务,例如对过去时的数据“中国19-22年的出生及死亡人口数据”,但如果查询股票数据,天气这种对时效性有很高要求的任务则不适合复用。


2、Agent对任务实现的方式有很多,可以拆解任务、使用lCOT或REACT框架、SOP(Standard Operating Procedure)标准作业规程等等。其目的都是将一个复杂的任务分成n个可在one step内即可完成的子任务。


3、对于子任务,是否需要调用工具,如果无需调用工具则只需要进行一次推理即可;对于需要调用工具的子任务AgentLLM会根据任务描述调用一个或多个工具,根据工具返回结果判断是否可以更改任务状态。待所有的子任务都完成状态变更之后AgentLLM会对结果进行评估反思,判断当前任务是否已经完成。如果某些子任务因为种种原因无法完成,AgentLLM会采取别的方法完成此任务,重复以上步骤直到可以给出结果为止,当然这里的Loop需要设置最大重试次数避免死循环。


4、当AgentLLM判断可以完成任务后可以进行历史任务存储(可选)。长期记忆是将数据存储在数据库中,以便下次查询,短期记忆则保存在内存或缓存中,程序结束时释放。

二、Function Call 原理

在一些任务中我们希望LLM返回我们格式化的数据如json、xml等,function call则需要LLM返回特定的json格式,以OpenAI为例,需要提供工具的描述信息。

  1. from openai import OpenAI
  2. import json
  3. client = OpenAI()
  4. def get_current_weather(location, unit="fahrenheit"):
  5. """Get the current weather in a given location"""
  6. if "tokyo" in location.lower():
  7. return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
  8. elif "san francisco" in location.lower():
  9. return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
  10. elif "paris" in location.lower():
  11. return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
  12. else:
  13. return json.dumps({"location": location, "temperature": "unknown"})
  14. def run_conversation():
  15. messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
  16. tools = [
  17. {
  18. "type": "function",
  19. "function": {
  20. "name": "get_current_weather",
  21. "description": "Get the current weather in a given location",
  22. "parameters": {
  23. "type": "object",
  24. "properties": {
  25. "location": {
  26. "type": "string",
  27. "description": "The city and state, e.g. San Francisco, CA",
  28. },
  29. "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
  30. },
  31. "required": ["location"],
  32. },
  33. },
  34. }
  35. ]
  36. response = client.chat.completions.create(
  37. model="gpt-3.5-turbo-1106",
  38. messages=messages,
  39. tools=tools,
  40. tool_choice="auto",
  41. )
  42. response_message = response.choices[0].message
  43. tool_calls = response_message.tool_calls
  44. if tool_calls:
  45. available_functions = {
  46. "get_current_weather": get_current_weather,
  47. }
  48. messages.append(response_message)
  49. for tool_call in tool_calls:
  50. function_name = tool_call.function.name
  51. function_to_call = available_functions[function_name]
  52. function_args = json.loads(tool_call.function.arguments)
  53. function_response = function_to_call(
  54. location=function_args.get("location"),
  55. unit=function_args.get("unit"),
  56. )
  57. messages.append(
  58. {
  59. "tool_call_id": tool_call.id,
  60. "role": "tool",
  61. "name": function_name,
  62. "content": function_response,
  63. }
  64. )
  65. second_response = client.chat.completions.create(
  66. model="gpt-3.5-turbo-1106",
  67. messages=messages,
  68. )
  69. return second_response
  70. print(run_conversation())

在推理结果中可以拿到类似{"name": "get_current_weather", "params": {"location": "北京", "unit": "celsius"}}这样的json数据,这里有需要调用的工具名称以及参数信息,接下来只需要编写代码实现工具调用,将工具返回的结果构造成message加入到与LLM对话的上下文中即可实现工具调用。这里的难点在于对一个开源模型来说,如何根据任务以及提供的工具描述给出正确的工具名称以及正确的参数。

三、开源模型工具调用微调

以下为复现实验数据过程记录

//Chat模型微调

模型Yi-6B-Chat硬件信息NVIDIA A100-SXM4-80GB
sft超参

  1. CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  2. --stage sft \
  3. --do_train \
  4. --model_name_or_path /mnt/models/Yi-6B-Chat \
  5. --dataset glaive_toolcall \
  6. --template yi \
  7. --finetuning_type lora \
  8. --lora_target q_proj,v_proj \
  9. --output_dir yi_agent_checkopint \
  10. --lora_target all \
  11. --overwrite_cache \
  12. --per_device_train_batch_size 4 \
  13. --gradient_accumulation_steps 4 \
  14. --lr_scheduler_type cosine \
  15. --logging_steps 10 \
  16. --save_steps 1000 \
  17. --learning_rate 5e-4 \
  18. --num_train_epochs 3 \
  19. --plot_loss \
  20. --fp16

export model

  1. python src/export_model.py \
  2. --model_name_or_path /mnt/models/Yi-6B-Chat \
  3. --adapter_name_or_path yi_agent_checkopint \
  4. --template yi \
  5. --finetuning_type lora \
  6. --export_dir Yi-Agent-6b-Chat \
  7. --export_size 2 \
  8. --export_legacy_format False

web demo


python src/web_demo.py --model_name_or_path Yi-Agent-6b-Chat --template yi

//训练过程日志

  1. {'train_runtime': 7735.6787, 'train_samples_per_second': 3.878, 'train_steps_per_second': 0.242, 'train_loss': 0.3381453339894613, 'epoch': 3.0}
  2. 100%|████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [2:08:55<00:00, 4.13s/it]
  3. [INFO|trainer.py:2889] 2024-01-25 13:39:49,599 >> Saving model checkpoint to yi_agent_checkopint
  4. [INFO|tokenization_utils_base.py:2432] 2024-01-25 13:39:49,709 >> tokenizer config file saved in yi_agent_checkopint/tokenizer_config.json
  5. [INFO|tokenization_utils_base.py:2441] 2024-01-25 13:39:49,709 >> Special tokens file saved in yi_agent_checkopint/special_tokens_map.json
  6. ***** train metrics *****
  7. epoch = 3.0
  8. train_loss = 0.3381
  9. train_runtime = 2:08:55.67
  10. train_samples_per_second = 3.878
  11. train_steps_per_second = 0.242
  12. Figure saved: yi_agent_checkopint/training_loss.png
  13. 01/25/2024 13:39:49 - WARNING - llmtuner.extras.ploting - No metric eval_loss to plot.
  14. [INFO|modelcard.py:452] 2024-01-25 13:39:49,848 >> Dropping the following result as it does not have all the necessary fields:
  15. {'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}

//测试结果

//测试Tools

  1. [
  2. {
  3. "name": "get_province_list",
  4. "description": "获取省份ID",
  5. "parameters": {
  6. "type": "object",
  7. "properties": {}
  8. }
  9. },
  10. {
  11. "name": "get_cities_list",
  12. "description": "根据省份ID查询城市地区ID",
  13. "parameters": {
  14. "type": "object",
  15. "properties": {
  16. "province_id": {
  17. "type": "string",
  18. "description": "省份ID,可以通过调用get_province_list获取省份ID"
  19. }
  20. },
  21. "required": [
  22. "province_id"
  23. ]
  24. }
  25. },
  26. {
  27. "name": "get_history_weather",
  28. "description": "根据城市ID和日期查询历史天气信息,日期支持从2011-01-01开始。注:个别地区个别日期数据记录可能会不存在",
  29. "parameters": {
  30. "type": "object",
  31. "properties": {
  32. "city_id": {
  33. "type": "string",
  34. "description": "城市地区ID,可以通过调用get_cities_list获取城市地区ID"
  35. },
  36. "weather_date": {
  37. "type": "string",
  38. "description": "日期,格式:2017-07-15,日期不能大于等于今日日期"
  39. }
  40. },
  41. "required": [
  42. "city_id",
  43. "weather_date"
  44. ]
  45. }
  46. },
  47. {
  48. "name": "get_river_environment",
  49. "description": "查询地表水水质",
  50. "parameters": {
  51. "type": "object",
  52. "properties": {
  53. "page": {
  54. "type": "integer",
  55. "description": "第几页(默认1)"
  56. },
  57. "province": {
  58. "type": "string",
  59. "description": "省份,例:江苏省"
  60. },
  61. "river": {
  62. "type": "string",
  63. "description": "流域,例:海河流域"
  64. },
  65. "section": {
  66. "type": "string",
  67. "description": "断面名称,例:鼓楼外大街"
  68. }
  69. },
  70. "required": []
  71. }
  72. },
  73. {
  74. "name": "get_environment_air_pm",
  75. "description": "查询的城市PM2.5数据",
  76. "parameters": {
  77. "type": "object",
  78. "properties": {
  79. "city": {
  80. "type": "string",
  81. "description": "城市名称的中文名称或拼音,如:上海 或 shanghai"
  82. }
  83. },
  84. "required": [
  85. "city"
  86. ]
  87. }
  88. },
  89. {
  90. "name": "get_toutiao_news",
  91. "description": "新闻列表查询",
  92. "parameters": {
  93. "type": "object",
  94. "properties": {
  95. "type": {
  96. "type": "string",
  97. "description": "支持类型 top(推荐,默认) guonei(国内) guoji(国际) yule(娱乐) tiyu(体育) junshi(军事) keji(科技) caijing(财经) youxi(游戏) qiche(汽车) jiankang(健康)"
  98. },
  99. "page": {
  100. "type": "string",
  101. "description": "当前页数, 默认1, 最大50"
  102. },
  103. "page_size": {
  104. "type": "string",
  105. "description": "每页返回条数, 默认30 , 最大30"
  106. },
  107. "is_filter": {
  108. "type": "string",
  109. "description": "是否只返回有内容详情的新闻, 1:是, 默认0"
  110. }
  111. },
  112. "required": []
  113. }
  114. },
  115. {
  116. "name": "chejian_query",
  117. "description": "根据车辆注册日期及类型,计算车辆的下次上线检验时间。本计算结果仅供参考。",
  118. "parameters": {
  119. "type": "object",
  120. "properties": {
  121. "type": {
  122. "type": "string",
  123. "description": "车辆类型, 3:9座(含)以下非营运小微型载客汽车(面包车除外) 4:摩托车 7:非营运大型轿车 1:营运车辆 2:货车、大中型客车 6:面包车 5:其他机动车"
  124. },
  125. "reg_date": {
  126. "type": "string",
  127. "description": "注册登记日期,格式:2022-11-02"
  128. },
  129. "iis_sg": {
  130. "type": "integer",
  131. "description": "事故情况(是否发生过致人伤亡事故或存在非法改装被依法处罚的交通违法),如是传1"
  132. }
  133. },
  134. "required": [
  135. "type",
  136. "reg_date"
  137. ]
  138. }
  139. },
  140. {
  141. "name": "loan_calc_query",
  142. "description": "公积金贷款计算器用于计算用户在申请公积金贷款时,选择等额本金和等额本息两种不同的还款方式后,每一期需偿还公积金贷款的月供,以及利息总额和还款总额。",
  143. "parameters": {
  144. "type": "object",
  145. "properties": {
  146. "money": {
  147. "type": "integer",
  148. "description": "贷款金额(0 < money <= 500),单位(万),如70表示70万;"
  149. },
  150. "year": {
  151. "type": "integer",
  152. "description": "贷款年限,单位(年),仅限输入 5、10、15、20、25、30"
  153. },
  154. "active": {
  155. "type": "string",
  156. "description": "贷款利率,默认3.25"
  157. }
  158. },
  159. "required": [
  160. "money",
  161. "year"
  162. ]
  163. }
  164. },
  165. {
  166. "name": "icp_query",
  167. "description": "网站icp备案查询",
  168. "parameters": {
  169. "type": "object",
  170. "properties": {
  171. "domainName": {
  172. "type": "string",
  173. "description": "获取的域名,如:juhe.cn"
  174. }
  175. },
  176. "required": [
  177. "domainName"
  178. ]
  179. }
  180. },
  181. {
  182. "name": "airport_query",
  183. "description": "获取全球机场三字码",
  184. "parameters": {
  185. "type": "object",
  186. "properties": {
  187. "airport": {
  188. "type": "string",
  189. "description": "关键词(可匹配城市机场的中英文名称、机场三字码)"
  190. },
  191. "page": {
  192. "type": "integer",
  193. "description": "页码(默认为1)"
  194. },
  195. "per_page": {
  196. "type": "integer",
  197. "description": "每页显示数量(默认为20,最大为100)"
  198. }
  199. },
  200. "required": [
  201. "airport"
  202. ]
  203. }
  204. },
  205. {
  206. "name": "aptabnormal_query",
  207. "description": "根据机场三字码查询国内机场不正常航班列表",
  208. "parameters": {
  209. "type": "object",
  210. "properties": {
  211. "airport": {
  212. "type": "string",
  213. "description": "机场三字码,字母大写(如:PEK),可通过airport_query获取三字码"
  214. }
  215. },
  216. "required": [
  217. "airport"
  218. ]
  219. }
  220. }
  221. ]

测试问题

请参考工具调用能力测试中的场景列(https://www.yuque.com/mrbun/sgr5h5/hsnz17g1a1wr6k2t#KmgD

四、预训练模型微调

硬件信息NVIDIA-4090 24G 单卡
sft超参

  1. CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  2. --stage sft \
  3. --do_train \
  4. --model_name_or_path /data/models/Yi-6B \
  5. --dataset glaive_toolcall,alpaca_gpt4_en,alpaca_gpt4_zh,oaast_sft_zh \
  6. --max_samples 8000 \
  7. --template default \
  8. --finetuning_type lora \
  9. --lora_target q_proj,v_proj \
  10. --output_dir yi_agent_checkopint \
  11. --lora_target all \
  12. --overwrite_cache \
  13. --per_device_train_batch_size 1 \
  14. --gradient_accumulation_steps 4 \
  15. --lr_scheduler_type cosine \
  16. --logging_steps 10 \
  17. --save_steps 1000 \
  18. --learning_rate 5e-5 \
  19. --num_train_epochs 2 \
  20. --plot_loss \
  21. --fp16 \
  22. --flash_attn

export model

  1. python src/export_model.py \
  2. --model_name_or_path /data/models/Yi-6B \
  3. --adapter_name_or_path /data/projects/LLaMA-Factory/yi_agent_checkopint \
  4. --template default \
  5. --finetuning_type lora \
  6. --export_dir Yi-Agent-6B-Chat \
  7. --export_size 2 \
  8. --export_legacy_format False

web demo

  1. python src/web_demo.py \
  2. --model_name_or_path Yi-Agent-6B-Chat \
  3. --template default

测试结果不再赘述。

五、总结

通过SFT微调后可以让原本不具备工具调用能力的模型实现工具调用。通过测试结果可以看出对于复杂场景的效果不是很好,单工具的场景正确率很高,测试的场景是中文场景,训练集中是英文,泛化效果也很不错,我正在准备以下类型数据集,如果有类似的数据集可以在下面贴出连接。

  • API参数描述中需要调用另外一个接口拿到的场景,例如天气查询中的城市id需要调用获取城市idAPI拿到。
  • 对于问题中参数信息不完整,主动抛出问题获取更详细参数信息的场景。
  • 多工具场景。

模型已发布到modelscope Yi-Agent-6B-Chat

作者:徐辉| 后端开发工程师

更多AI小知识欢迎关注“神州数码云基地”公众号,回复“AI与数字化转型”进入社群交流

版权声明:文章由神州数码武汉云基地团队实践整理输出,转载请注明出处。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号