当前位置:   article > 正文

开源模型应用落地-业务优化篇(一)_开源大模型调优

开源大模型调优

一、前言

    通过参与“开源模型应用落地-业务整合系列篇”的学习,我们已经成功建立了基本的业务流程。然而,这只是迈出了万里长征的第一步。现在我们要对整个项目进行优化,以提高效率。我们计划利用线程池来加快处理速度,使用redis来实现排队需求,以及通过多级环境来减轻负载压力。这些优化措施将有助于我们进一步改进项目的性能和效果。


二、术语

2.1. 线程池

    是一种用于线程管理的技术,它包含一组预先创建的线程,用于执行任务。线程池维护着一个任务队列,当有任务到达时,线程池中的线程会自动分配任务并执行。

    线程池的主要目的是重用线程,避免频繁地创建和销毁线程带来的开销。通过使用线程池,可以在程序初始化时创建一组线程,并将任务提交给线程池进行处理,而不需要为每个任务都创建一个新的线程。这样可以有效地管理系统中的线程数量,控制并发度,提高系统的性能和资源利用率。

线程池通常包含以下几个关键组件:

  1. 任务队列(Task Queue):用于存储待执行的任务,通常是一个队列结构。当有新的任务到达时,会被添加到任务队列中。
  2. 线程池管理器(Thread Pool Manager):负责管理线程池的创建、销毁和线程的调度。它会监视任务队列的状态,并根据需要动态地创建或回收线程。
  3. 工作线程(Worker Threads):线程池中的线程,用于执行任务。它们会从任务队列中获取任务,并执行任务的处理逻辑。

三、前置条件

3.1. 已搭建WebSocket与AI服务调用链路


四、技术实现

4.1. 调整业务逻辑处理类

     对于每次交互的chat对话,都需要经过以下步骤,包括但不限于:

  1.   对用户输入的内容进行自定义违规词检测
  2.   对用户输入的内容进行第三方在线违规词检测
  3.   对用户输入的内容进行组装成Prompt
  4.   对Prompt根据业务进行增强(完善prompt的内容)
  5.   对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)

特别是调用第三方在线违规词检测,例如:某某云的内容安全审核服务,是非常耗时,会阻塞正常线程的执行,导致吞吐量的下降。

所以,我们就要对下面这块的处理逻辑进行调整,通过自定义线程池的方式,去处理核心的Chat交互流程

调整后:

4.2. 新增线程处理类

  1. import io.netty.channel.ChannelHandlerContext;
  2. import lombok.extern.slf4j.Slf4j;
  3. import org.springframework.beans.factory.annotation.Autowired;
  4. import org.springframework.stereotype.Component;
  5. import java.util.List;
  6. import java.util.concurrent.ExecutorService;
  7. import java.util.concurrent.Executors;
  8. @Component
  9. @Slf4j
  10. public class TaskUtils{
  11. private static ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
  12. @Autowired
  13. private AIChatUtils aiChatUtils;
  14. public void execute(AITaskReqMessage aiTaskReqMessage) {
  15. executorService.execute(() -> {
  16. Long userId = aiTaskReqMessage.getUserId();
  17. if (null == userId || (long) userId < 10000) {
  18. log.warn("用户身份标识有误!");
  19. return;
  20. }
  21. ChannelHandlerContext channelHandlerContext = AbstractBusinessLogicHandler.getContextByUserId(userId);
  22. if (channelHandlerContext != null) {
  23. try {
  24. aiChatUtils.chatStream(aiTaskReqMessage);
  25. } catch (Throwable exception) {
  26. exception.printStackTrace();
  27. }
  28. }
  29. });
  30. }
  31. public static void destory(){
  32. executorService.shutdownNow();
  33. executorService = null;
  34. }
  35. }

4.3. 新增线程处理实体类

  1. import lombok.Builder;
  2. import lombok.Getter;
  3. import lombok.Setter;
  4. import java.util.List;
  5. @Builder
  6. @Setter
  7. @Getter
  8. public class AITaskReqMessage {
  9. private String messageId;
  10. private Long userId;
  11. private String contents;
  12. private List<ChatContext> history;
  13. }


五、测试

在线测试方式:WebSocket在线测试工具

5.1.  建立连接

5.2.  业务初始化

服务端输出:

5.3.  业务对话

服务端输出

5.4.  关闭连接


六、附带说明

6.1. 可以使用jmeter进行websocket压测,以评估各项性能指标是否符合预期(下一篇)

6.2. BusinessHandler完整代码

  1. import com.alibaba.fastjson.JSON;
  2. import io.netty.channel.ChannelHandler;
  3. import lombok.extern.slf4j.Slf4j;
  4. import io.netty.channel.ChannelHandlerContext;
  5. import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
  6. import org.apache.commons.lang3.StringUtils;
  7. import org.springframework.beans.factory.annotation.Autowired;
  8. import org.springframework.stereotype.Component;
  9. import java.util.List;
  10. /**
  11. * @Description: 处理消息的handler
  12. */
  13. @Slf4j
  14. @ChannelHandler.Sharable
  15. @Component
  16. public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {
  17. @Autowired
  18. private TaskUtils taskExecuteUtils;
  19. @Override
  20. public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
  21. String channelId = ctx.channel().id().asShortText();
  22. log.info("add client,channelId:{}", channelId);
  23. }
  24. @Override
  25. public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
  26. String channelId = ctx.channel().id().asShortText();
  27. log.info("remove client,channelId:{}", channelId);
  28. }
  29. @Override
  30. protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
  31. throws Exception {
  32. // 获取客户端传输过来的消息
  33. String content = textWebSocketFrame.text();
  34. // 兼容在线测试
  35. if (StringUtils.equals(content, "PING")) {
  36. buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
  37. .respTime(String.valueOf(System.currentTimeMillis()))
  38. .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
  39. .contents("心跳测试,很高兴收到你的心跳包")
  40. .build());
  41. return;
  42. }
  43. log.info("接收到客户端发送的信息: {}", content);
  44. Long userIdForReq;
  45. String msgType = "";
  46. String contents = "";
  47. try {
  48. ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
  49. msgType = apiReqMessage.getMsgType();
  50. contents = apiReqMessage.getContents();
  51. userIdForReq = apiReqMessage.getUserId();
  52. // 用户身份标识校验
  53. if (null == userIdForReq || (long) userIdForReq <= 10000) {
  54. ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
  55. .respTime(String.valueOf(System.currentTimeMillis()))
  56. .contents("用户身份标识有误!")
  57. .msgType(String.valueOf(MsgType.SYSTEM.getCode()))
  58. .build();
  59. buildResponseAndClose(channelHandlerContext, apiRespMessage);
  60. return;
  61. }
  62. if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
  63. // 对用户输入的内容进行自定义违规词检测
  64. // 对用户输入的内容进行第三方在线违规词检测
  65. // 对用户输入的内容进行组装成Prompt
  66. // 对Prompt根据业务进行增强(完善prompt的内容)
  67. // 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
  68. // ...
  69. String messageId = apiReqMessage.getMessageId();
  70. List<ChatContext> history = apiReqMessage.getHistory();
  71. AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
  72. taskExecuteUtils.execute(aiTaskReqMessage);
  73. } else if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
  74. //一、业务黑名单检测(多次违规,永久锁定)
  75. //二、账户锁定检测(临时锁定)
  76. //三、多设备登录检测
  77. //四、剩余对话次数检测
  78. //检测通过,绑定用户与channel之间关系
  79. addChannel(channelHandlerContext, userIdForReq);
  80. String respMessage = "用户标识: " + userIdForReq + " 登录成功";
  81. buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
  82. .respTime(String.valueOf(System.currentTimeMillis()))
  83. .msgType(String.valueOf(MsgType.INIT.getCode()))
  84. .contents(respMessage)
  85. .build());
  86. } else if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {
  87. buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
  88. .respTime(String.valueOf(System.currentTimeMillis()))
  89. .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
  90. .contents("心跳测试,很高兴收到你的心跳包")
  91. .build());
  92. }
  93. else {
  94. log.info("用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
  95. }
  96. } catch (Exception e) {
  97. log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
  98. // 异常返回
  99. return;
  100. }
  101. }
  102. }

6.3.  AIChatUtils完整代码

  1. import com.alibaba.fastjson.JSON;
  2. import lombok.extern.slf4j.Slf4j;
  3. import okhttp3.MediaType;
  4. import okhttp3.Request;
  5. import okhttp3.RequestBody;
  6. import okhttp3.Response;
  7. import org.apache.commons.lang3.StringUtils;
  8. import org.springframework.beans.factory.annotation.Autowired;
  9. import org.springframework.stereotype.Component;
  10. import java.io.ByteArrayOutputStream;
  11. import java.io.InputStream;
  12. import java.nio.charset.StandardCharsets;
  13. import java.security.MessageDigest;
  14. import java.util.List;
  15. import java.util.Objects;
  16. @Slf4j
  17. @Component
  18. public class AIChatUtils {
  19. @Autowired
  20. private AIConfig aiConfig;
  21. private Request buildRequest(Long userId, String prompt) throws Exception {
  22. //创建一个请求体对象(body)
  23. MediaType mediaType = MediaType.parse("application/json");
  24. RequestBody requestBody = RequestBody.create(mediaType, prompt);
  25. return buildHeader(userId, new Request.Builder().post(requestBody))
  26. .url(aiConfig.getUrl()).build();
  27. }
  28. private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
  29. return builder
  30. .addHeader("Content-Type", "application/json")
  31. .addHeader("userId", String.valueOf(userId))
  32. .addHeader("secret",generateSecret(userId))
  33. }
  34. /**
  35. * 生成请求密钥
  36. *
  37. * @param userId 用户ID
  38. * @return
  39. */
  40. private String generateSecret(Long userId) throws Exception {
  41. String key = aiConfig.getServerKey();
  42. String content = key + userId + key;
  43. MessageDigest digest = MessageDigest.getInstance("SHA-256");
  44. byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
  45. StringBuilder hexString = new StringBuilder();
  46. for (byte b : hash) {
  47. String hex = Integer.toHexString(0xff & b);
  48. if (hex.length() == 1) {
  49. hexString.append('0');
  50. }
  51. hexString.append(hex);
  52. }
  53. return hexString.toString();
  54. }
  55. public String chatStream(AITaskReqMessage aiTaskReqMessage) throws Exception {
  56. String messageId = aiTaskReqMessage.getMessageId();
  57. Long userId = aiTaskReqMessage.getUserId();
  58. String contents = aiTaskReqMessage.getContents();
  59. List<ChatContext> history = aiTaskReqMessage.getHistory();
  60. if(StringUtils.isEmpty(contents) || StringUtils.isBlank(contents)){
  61. log.warn("用户输入内容不能为空!");
  62. return null;
  63. }
  64. //定义请求的参数
  65. String prompt = JSON.toJSONString(AIChatReqVO.init(contents, history));
  66. log.info("【AIChatUtils】调用AI聊天,用户({}),prompt:{}", userId, prompt);
  67. //创建一个请求对象
  68. Request request = buildRequest(userId, prompt);
  69. InputStream is = null;
  70. try {
  71. // 从线程池获取http请求并执行
  72. Response response =OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute();
  73. // 响应结果
  74. StringBuffer resultBuff = new StringBuffer();
  75. //正常返回
  76. if (response.code() == 200) {
  77. //打印返回的字符数据
  78. is = response.body().byteStream();
  79. byte[] bytes = new byte[1024];
  80. int len = is.read(bytes);
  81. while (len != -1) {
  82. ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
  83. outputStream.write(bytes, 0, len);
  84. outputStream.flush();
  85. // 本轮读取到的数据
  86. String result = new String(outputStream.toByteArray(), StandardCharsets.UTF_8);
  87. resultBuff.append(result);
  88. len = is.read(bytes);
  89. // 将数据逐个传输给用户
  90. AbstractBusinessLogicHandler.pushChatMessageForUser(userId, result);
  91. }
  92. // 正常响应
  93. return resultBuff.toString();
  94. }
  95. else {
  96. String result = response.body().string();
  97. log.warn("处理异常,异常描述:{}",result);
  98. }
  99. } catch (Throwable e) {
  100. log.error("【AIChatUtils】消息({})调用AI聊天 chatStream 异常,异常消息:{}", messageId, e.getMessage(), e);
  101. } finally {
  102. if (!Objects.isNull(is)) {
  103. try {
  104. is.close();
  105. } catch (Exception e) {
  106. e.printStackTrace();
  107. }
  108. }
  109. }
  110. return null;
  111. }
  112. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/511059
推荐阅读
相关标签
  

闽ICP备14008679号