赞
踩
通过参与“开源模型应用落地-业务整合系列篇”的学习,我们已经成功建立了基本的业务流程。然而,这只是迈出了万里长征的第一步。现在我们要对整个项目进行优化,以提高效率。我们计划利用线程池来加快处理速度,使用redis来实现排队需求,以及通过多级环境来减轻负载压力。这些优化措施将有助于我们进一步改进项目的性能和效果。
线程池是一种用于线程管理的技术,它包含一组预先创建的线程,用于执行任务。线程池维护着一个任务队列,当有任务到达时,线程池中的空闲线程会自动从队列中领取任务并执行。
线程池的主要目的是重用线程,避免频繁地创建和销毁线程带来的开销。通过使用线程池,可以在程序初始化时创建一组线程,并将任务提交给线程池进行处理,而不需要为每个任务都创建一个新的线程。这样可以有效地管理系统中的线程数量,控制并发度,提高系统的性能和资源利用率。
线程池通常包含以下几个关键组件:
对于每次交互的chat对话,都需要经过以下步骤,包括但不限于:对用户输入的内容进行自定义违规词检测、进行第三方在线违规词检测、将用户输入组装成Prompt、根据业务对Prompt进行增强,以及对history进行裁剪或总结。
特别是调用第三方在线违规词检测(例如某某云的内容安全审核服务)非常耗时,会阻塞正常线程的执行,导致吞吐量下降。
所以,我们需要对下面这部分处理逻辑进行调整,通过自定义线程池的方式来处理核心的Chat交互流程。
调整后:
- import io.netty.channel.ChannelHandlerContext;
- import lombok.extern.slf4j.Slf4j;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.stereotype.Component;
-
- import java.util.List;
- import java.util.concurrent.ExecutorService;
- import java.util.concurrent.Executors;
-
- @Component
- @Slf4j
- public class TaskUtils{
- private static ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
- @Autowired
- private AIChatUtils aiChatUtils;
-
- public void execute(AITaskReqMessage aiTaskReqMessage) {
-
- executorService.execute(() -> {
- Long userId = aiTaskReqMessage.getUserId();
-
- if (null == userId || (long) userId < 10000) {
- log.warn("用户身份标识有误!");
- return;
- }
-
- ChannelHandlerContext channelHandlerContext = AbstractBusinessLogicHandler.getContextByUserId(userId);
-
- if (channelHandlerContext != null) {
- try {
- aiChatUtils.chatStream(aiTaskReqMessage);
-
- } catch (Throwable exception) {
- exception.printStackTrace();
- }
- }
- });
- }
-
- public static void destory(){
- executorService.shutdownNow();
- executorService = null;
- }
-
- }
- import lombok.Builder;
- import lombok.Getter;
- import lombok.Setter;
-
- import java.util.List;
-
- /**
-  * Chat request handed to the task executor: one message together with the
-  * user it belongs to and the conversation history sent with it.
-  */
- @Getter
- @Setter
- @Builder
- public class AITaskReqMessage {
-
-     /** Identifier of this message. */
-     private String messageId;
-
-     /** Identifier of the user who sent the message. */
-     private Long userId;
-
-     /** Text entered by the user. */
-     private String contents;
-
-     /** Conversation history supplied with the request. */
-     private List<ChatContext> history;
- }
在线测试方式:WebSocket在线测试工具
服务端输出:
服务端输出
6.1. 可以使用jmeter进行websocket压测,以评估各项性能指标是否符合预期(下一篇)
6.2. BusinessHandler完整代码
- import com.alibaba.fastjson.JSON;
- import io.netty.channel.ChannelHandler;
- import lombok.extern.slf4j.Slf4j;
- import io.netty.channel.ChannelHandlerContext;
- import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
- import org.apache.commons.lang3.StringUtils;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.stereotype.Component;
-
- import java.util.List;
-
-
- /**
-  * Handles inbound WebSocket text frames: heartbeats, login (INIT) messages
-  * and chat requests.
-  */
- @Slf4j
- @ChannelHandler.Sharable
- @Component
- public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {
-     @Autowired
-     private TaskUtils taskExecuteUtils;
-
-     /** Logs the short channel id of a newly added client. */
-     @Override
-     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
-         String channelId = ctx.channel().id().asShortText();
-         log.info("add client,channelId:{}", channelId);
-     }
-
-     /** Logs the short channel id of a removed client. */
-     @Override
-     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
-         String channelId = ctx.channel().id().asShortText();
-         log.info("remove client,channelId:{}", channelId);
-     }
-
-     /**
-      * Parses one text frame as an {@code ApiReqMessage} and dispatches it by
-      * message type. Any exception raised while parsing or dispatching is
-      * logged and swallowed.
-      *
-      * @param channelHandlerContext context of the channel the frame arrived on
-      * @param textWebSocketFrame    the received frame
-      */
-     @Override
-     protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
-             throws Exception {
-         String content = textWebSocketFrame.text();
-
-         // Online test tools send a bare "PING" string instead of a JSON heartbeat.
-         if (StringUtils.equals(content, "PING")) {
-             replyHeartbeat(channelHandlerContext);
-             return;
-         }
-         log.info("接收到客户端发送的信息: {}", content);
-
-         try {
-             ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
-             if (apiReqMessage == null) {
-                 // fastjson yields null for empty text; report it explicitly
-                 // instead of relying on a NullPointerException below.
-                 log.warn("【BusinessHandler】请求内容为空或无法解析:{}", content);
-                 return;
-             }
-
-             String msgType = apiReqMessage.getMsgType();
-             Long userIdForReq = apiReqMessage.getUserId();
-
-             // Validate the user identifier before doing anything else.
-             if (userIdForReq == null || userIdForReq <= 10000) {
-                 ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
-                         .respTime(String.valueOf(System.currentTimeMillis()))
-                         .contents("用户身份标识有误!")
-                         .msgType(String.valueOf(MsgType.SYSTEM.getCode()))
-                         .build();
-                 buildResponseAndClose(channelHandlerContext, apiRespMessage);
-                 return;
-             }
-
-             if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
-                 submitChatTask(apiReqMessage, userIdForReq);
-             } else if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
-                 bindUserAndReply(channelHandlerContext, userIdForReq);
-             } else if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {
-                 replyHeartbeat(channelHandlerContext);
-             } else {
-                 log.info("用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
-             }
-         } catch (Exception e) {
-             log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
-         }
-     }
-
-     /** Sends the successful heartbeat response to the given channel. */
-     private void replyHeartbeat(ChannelHandlerContext channelHandlerContext) throws Exception {
-         buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
-                 .respTime(String.valueOf(System.currentTimeMillis()))
-                 .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
-                 .contents("心跳测试,很高兴收到你的心跳包")
-                 .build());
-     }
-
-     /** Wraps a chat request in an {@code AITaskReqMessage} and hands it to the task executor. */
-     private void submitChatTask(ApiReqMessage apiReqMessage, Long userId) {
-         // Placeholders for steps not implemented here:
-         //  - custom forbidden-word check on the user input
-         //  - third-party online forbidden-word check on the user input
-         //  - assembling the user input into a prompt
-         //  - business-specific prompt enrichment
-         //  - trimming or summarising the history when it exceeds the model's
-         //    context length (e.g. 8192 for qwen-7b)
-         String messageId = apiReqMessage.getMessageId();
-         String contents = apiReqMessage.getContents();
-         List<ChatContext> history = apiReqMessage.getHistory();
-         AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder()
-                 .messageId(messageId)
-                 .userId(userId)
-                 .contents(contents)
-                 .history(history)
-                 .build();
-         taskExecuteUtils.execute(aiTaskReqMessage);
-     }
-
-     /** Binds the user to the channel and confirms the login to the client. */
-     private void bindUserAndReply(ChannelHandlerContext channelHandlerContext, Long userId) throws Exception {
-         // Placeholders for checks not implemented here:
-         //  1. business blacklist (permanent lock after repeated violations)
-         //  2. account lock (temporary lock)
-         //  3. multi-device login
-         //  4. remaining conversation quota
-
-         // Checks passed: bind the user to the channel.
-         addChannel(channelHandlerContext, userId);
-         String respMessage = "用户标识: " + userId + " 登录成功";
-
-         buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
-                 .respTime(String.valueOf(System.currentTimeMillis()))
-                 .msgType(String.valueOf(MsgType.INIT.getCode()))
-                 .contents(respMessage)
-                 .build());
-     }
-
- }
6.3. AIChatUtils完整代码
- import com.alibaba.fastjson.JSON;
- import lombok.extern.slf4j.Slf4j;
- import okhttp3.MediaType;
- import okhttp3.Request;
- import okhttp3.RequestBody;
- import okhttp3.Response;
- import org.apache.commons.lang3.StringUtils;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.stereotype.Component;
-
- import java.io.ByteArrayOutputStream;
- import java.io.InputStream;
- import java.io.InputStreamReader;
- import java.io.Reader;
- import java.nio.charset.StandardCharsets;
- import java.security.MessageDigest;
- import java.util.List;
- import java.util.Objects;
-
- /**
-  * Calls the AI chat HTTP endpoint and relays the streamed reply to the user.
-  */
- @Slf4j
- @Component
- public class AIChatUtils {
-     /** Number of characters requested from the response stream per read. */
-     private static final int READ_BUFFER_CHARS = 1024;
-
-     @Autowired
-     private AIConfig aiConfig;
-
-     /**
-      * Builds the POST request carrying the JSON prompt.
-      *
-      * @param userId user the request is made for
-      * @param prompt JSON request body
-      * @return the request, targeting the URL from the AI configuration
-      */
-     private Request buildRequest(Long userId, String prompt) throws Exception {
-         MediaType mediaType = MediaType.parse("application/json");
-         RequestBody requestBody = RequestBody.create(mediaType, prompt);
-
-         return buildHeader(userId, new Request.Builder().post(requestBody))
-                 .url(aiConfig.getUrl())
-                 .build();
-     }
-
-     /**
-      * Adds the content type, user id and request secret headers.
-      *
-      * @param userId  user the request is made for
-      * @param builder builder to add the headers to
-      * @return the same builder, for chaining
-      */
-     private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
-         return builder
-                 .addHeader("Content-Type", "application/json")
-                 .addHeader("userId", String.valueOf(userId))
-                 .addHeader("secret", generateSecret(userId));
-     }
-
-     /**
-      * Generates the request secret: the lowercase hex SHA-256 digest of
-      * {@code serverKey + userId + serverKey}.
-      *
-      * @param userId user id
-      * @return 64-character hex digest
-      */
-     private String generateSecret(Long userId) throws Exception {
-         String key = aiConfig.getServerKey();
-         String content = key + userId + key;
-
-         MessageDigest digest = MessageDigest.getInstance("SHA-256");
-         byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
-
-         StringBuilder hexString = new StringBuilder(hash.length * 2);
-         for (byte b : hash) {
-             String hex = Integer.toHexString(0xff & b);
-             if (hex.length() == 1) {
-                 // Left-pad so every byte contributes exactly two hex digits.
-                 hexString.append('0');
-             }
-             hexString.append(hex);
-         }
-         return hexString.toString();
-     }
-
-     /**
-      * Sends the user's input to the AI endpoint and forwards the reply to the
-      * user piece by piece as it arrives.
-      *
-      * @param aiTaskReqMessage the chat request
-      * @return the complete reply text, or {@code null} when the input is blank,
-      *         the endpoint does not answer with HTTP 200, or the call fails
-      */
-     public String chatStream(AITaskReqMessage aiTaskReqMessage) throws Exception {
-
-         String messageId = aiTaskReqMessage.getMessageId();
-         Long userId = aiTaskReqMessage.getUserId();
-         String contents = aiTaskReqMessage.getContents();
-         List<ChatContext> history = aiTaskReqMessage.getHistory();
-
-         // isBlank already covers null and empty strings.
-         if (StringUtils.isBlank(contents)) {
-             log.warn("用户输入内容不能为空!");
-             return null;
-         }
-
-         String prompt = JSON.toJSONString(AIChatReqVO.init(contents, history));
-         log.info("【AIChatUtils】调用AI聊天,用户({}),prompt:{}", userId, prompt);
-
-         Request request = buildRequest(userId, prompt);
-
-         // try-with-resources closes the response (and its body) on every path.
-         try (Response response = OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute()) {
-             if (response.code() != 200) {
-                 String result = response.body().string();
-                 log.warn("处理异常,异常描述:{}", result);
-                 return null;
-             }
-
-             // Decode through a single reader so that a multi-byte UTF-8 character
-             // split across two network reads is not corrupted.
-             try (Reader reader = new InputStreamReader(response.body().byteStream(), StandardCharsets.UTF_8)) {
-                 return relayToUser(userId, reader);
-             }
-         } catch (Exception e) {
-             log.error("【AIChatUtils】消息({})调用AI聊天 chatStream 异常,异常消息:{}", messageId, e.getMessage(), e);
-             return null;
-         }
-     }
-
-     /**
-      * Reads the reply until end of stream, pushing each piece to the user as
-      * soon as it has been read.
-      *
-      * @param userId user to push the pieces to
-      * @param reader decoded reply stream
-      * @return everything that was read, concatenated
-      */
-     private String relayToUser(Long userId, Reader reader) throws Exception {
-         StringBuilder fullReply = new StringBuilder();
-         char[] buffer = new char[READ_BUFFER_CHARS];
-
-         int read;
-         while ((read = reader.read(buffer)) != -1) {
-             String piece = new String(buffer, 0, read);
-             fullReply.append(piece);
-             // Push before blocking on the next read so the user is never one piece behind.
-             AbstractBusinessLogicHandler.pushChatMessageForUser(userId, piece);
-         }
-         return fullReply.toString();
-     }
-
- }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。