当前位置:   article > 正文

java模拟GPT流式问答_realeventsource

realeventsource

以流式方式请求 GPT,并将答案流式推送到前端页面

1)java流式获取gpt答案

1、读取文件流的方式

使用 POST 请求数据。由于 GPT 以 EventSource(SSE,text/event-stream)的格式返回数据,每一行形如 `data: {json}`,并以 `data: [DONE]` 结束,所以需要先去掉 `data:` 前缀再解析 JSON,同时跳过事件之间的空行。

  1. /**
  2. org.apache.http.client.methods
  3. **/
  4. @SneakyThrows
  5. private void chatStream(List<ChatParamMessagesBO> messagesBOList) {
  6. CloseableHttpClient httpclient = HttpClients.createDefault();
  7. HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");
  8. httpPost.setHeader("Authorization","xxxxxxxxxxxx");
  9. httpPost.setHeader("Content-Type","application/json; charset=UTF-8");
  10. ChatParamBO build = ChatParamBO.builder()
  11. .temperature(0.7)
  12. .model("gpt-3.5-turbo")
  13. .messages(messagesBOList)
  14. .stream(true)
  15. .build();
  16. System.out.println(JsonUtils.toJson(build));
  17. httpPost.setEntity(new StringEntity(JsonUtils.toJson(build),"utf-8"));
  18. CloseableHttpResponse response = httpclient.execute(httpPost);
  19. try {
  20. HttpEntity entity = response.getEntity();
  21. if (entity != null) {
  22. InputStream inputStream = entity.getContent();
  23. BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
  24. String line;
  25. while ((line = reader.readLine()) != null) {
  26. // 处理 event stream 数据
  27. try {
  28. // System.out.println(line);
  29. ChatResultBO chatResultBO = JsonUtils.toObject(line.replace("data:", ""), ChatResultBO.class);
  30. String content = chatResultBO.getChoices().get(0).getDelta().getContent();
  31. log.info(content);
  32. // System.out.println(chatResultBO.getChoices().get(0).getMessage().getContent());
  33. } catch (Exception e) {
  34. // e.printStackTrace();
  35. }
  36. }
  37. }
  38. } finally {
  39. response.close();
  40. }
  41. }

2、通过 SSE 连接的方式获取数据

用到了 OkHttp 及其 okhttp-sse 模块

需要先引入相关的 Maven 依赖:

  <!-- okhttp-sse must resolve to the same version as okhttp. The explicit <version> is required
       when no parent POM/BOM manages these artifacts; keep both values in sync. -->
  <dependency>
      <groupId>com.squareup.okhttp3</groupId>
      <artifactId>okhttp</artifactId>
      <version>4.12.0</version>
  </dependency>
  <dependency>
      <groupId>com.squareup.okhttp3</groupId>
      <artifactId>okhttp-sse</artifactId>
      <version>4.12.0</version>
  </dependency>
  1. // 定义see接口
  2. Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions")
  3. .header("Authorization","xxx")
  4. .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString()))
  5. .build();
  6. OkHttpClient okHttpClient = new OkHttpClient.Builder()
  7. .connectTimeout(10, TimeUnit.MINUTES)
  8. .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
  9. .build();
  10. // 实例化EventSource,注册EventSource监听器
  11. RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {
  12. @Override
  13. public void onOpen(EventSource eventSource, Response response) {
  14. log.info("onOpen");
  15. }
  16. @SneakyThrows
  17. @Override
  18. public void onEvent(EventSource eventSource, String id, String type, String data) {
  19. // log.info("onEvent");
  20. log.info(data);//请求到的数据
  21. }
  22. @Override
  23. public void onClosed(EventSource eventSource) {
  24. log.info("onClosed");
  25. // emitter.complete();
  26. }
  27. @Override
  28. public void onFailure(EventSource eventSource, Throwable t, Response response) {
  29. log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
  30. // emitter.complete();
  31. }
  32. });
  33. realEventSource.connect(okHttpClient);//真正开始请求的一步

2)流式推送答案

方法一:通过订阅式SSE/WebSocket

原理是先建立连接,之后服务端通过这条连接持续向客户端推送消息即可。

1、websocket

创建相关配置:

  1. import javax.websocket.Session;
  2. import lombok.Data;
  3. /**
  4. * @description WebSocket客户端连接
  5. */
  6. @Data
  7. public class WebSocketClient {
  8. // 与某个客户端的连接会话,需要通过它来给客户端发送数据
  9. private Session session;
  10. //连接的uri
  11. private String uri;
  12. }
  1. import org.springframework.context.annotation.Bean;
  2. import org.springframework.context.annotation.Configuration;
  3. import org.springframework.web.socket.server.standard.ServerEndpointExporter;
  4. @Configuration
  5. public class WebSocketConfig {
  6. @Bean
  7. public ServerEndpointExporter serverEndpointExporter() {
  8. return new ServerEndpointExporter();
  9. }
  10. }
编写 WebSocket 服务端点(service):
  1. @Slf4j
  2. @Component
  3. @ServerEndpoint("/websocket/chat/{chatId}")
  4. public class ChatWebsocketService {
  5. static final ConcurrentHashMap<String, List<WebSocketClient>> webSocketClientMap= new ConcurrentHashMap<>();
  6. private String chatId;
  7. /**
  8. * 连接建立成功时触发,绑定参数
  9. * @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据
  10. * @param chatId 商户ID
  11. */
  12. @OnOpen
  13. public void onOpen(Session session, @PathParam("chatId") String chatId){
  14. WebSocketClient client = new WebSocketClient();
  15. client.setSession(session);
  16. client.setUri(session.getRequestURI().toString());
  17. List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);
  18. if(webSocketClientList == null){
  19. webSocketClientList = new ArrayList<>();
  20. }
  21. webSocketClientList.add(client);
  22. webSocketClientMap.put(chatId, webSocketClientList);
  23. this.chatId = chatId;
  24. }
  25. /**
  26. * 收到客户端消息后调用的方法
  27. *
  28. * @param message 客户端发送过来的消息
  29. */
  30. @OnMessage
  31. public void onMessage(String message) {
  32. log.info("chatId = {},message = {}",chatId,message);
  33. // 回复消息
  34. this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
  35. // this.sendMessage(chatId,message+"233");
  36. }
  37. /**
  38. * 连接关闭时触发,注意不能向客户端发送消息了
  39. * @param chatId
  40. */
  41. @OnClose
  42. public void onClose(@PathParam("chatId") String chatId){
  43. webSocketClientMap.remove(chatId);
  44. }
  45. /**
  46. * 通信发生错误时触发
  47. * @param session
  48. * @param error
  49. */
  50. @OnError
  51. public void onError(Session session, Throwable error) {
  52. System.out.println("发生错误");
  53. error.printStackTrace();
  54. }
  55. /**
  56. * 向客户端发送消息
  57. * @param chatId
  58. * @param message
  59. */
  60. public void sendMessage(String chatId,String message){
  61. try {
  62. List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);
  63. if(webSocketClientList!=null){
  64. for(WebSocketClient webSocketServer:webSocketClientList){
  65. webSocketServer.getSession().getBasicRemote().sendText(message);
  66. }
  67. }
  68. } catch (IOException e) {
  69. e.printStackTrace();
  70. throw new RuntimeException(e.getMessage());
  71. }
  72. }
  73. /**
  74. * 流式调用查询gpt
  75. * @param messagesBOList
  76. * @throws IOException
  77. */
  78. @SneakyThrows
  79. private void chatStream(List<ChatParamMessagesBO> messagesBOList) {
  80. // TODO 和GPT的访问请求
  81. }
  82. }
测试:使用 Postman 连接 ws://{host}/websocket/chat/{chatId} 建立 WebSocket 连接,然后发送消息即可触发问答。

2、SSE

本质上同样是订阅-推送的方式:客户端先建立连接,服务端再通过这条连接推送消息。

前端:
  <!DOCTYPE html>
  <html lang="en">
  <head>
  <meta charset="UTF-8">
  <title>SseEmitter</title>
  </head>
  <body>
  <button onclick="closeSse()">关闭连接</button>
  <div id="message"></div>
  <script>
      let source = null;
      // A fixed conversation id; a timestamp could be used instead to simulate distinct users:
      //const id = new Date().getTime();
      const id = '7829083B42464C5B9C445A087E873C7D';
      const baseUrl = 'http://172.28.54.27:8902/api/sse';
      if (window.EventSource) {
          // Open the connection.
          source = new EventSource(baseUrl + '/connect?conversationId=' + encodeURIComponent(id));
          setMessageInnerHTML("连接用户=" + id);
          // Fired once the connection is established (alternative: source.onopen = ...).
          source.addEventListener('open', function (e) {
              setMessageInnerHTML("建立连接。。。");
          }, false);
          // Fired for every message from the server (alternative: source.onmessage = ...).
          source.addEventListener('message', function (e) {
              setMessageInnerHTML(e.data);
          });
          // Custom "close" event, in case the server announces that it is closing the stream.
          source.addEventListener("close", function (event) {
              console.log("Server closed the connection");
              source.close();
          });
          // Fired on communication errors such as a dropped connection (alternative: source.onerror = ...).
          source.addEventListener('error', function (e) {
              console.log(e);
              // The Event object has no readyState; the connection state lives on the EventSource.
              if (source.readyState === EventSource.CLOSED) {
                  setMessageInnerHTML("连接关闭");
              }
          }, false);
      } else {
          setMessageInnerHTML("你的浏览器不支持SSE");
      }
      // When the page is closed, the SSE connection could be closed proactively. If the server
      // never expires connections, its state must then be cleaned up manually.
      window.onbeforeunload = function () {
          //closeSse();
      };
      // Closes the SSE connection and asks the server to drop it.
      function closeSse() {
          if (source) {
              source.close();
          }
          const httpRequest = new XMLHttpRequest();
          httpRequest.open('GET', baseUrl + '/disconnection?conversationId=' + encodeURIComponent(id), true);
          httpRequest.send();
          console.log("close");
      }
      // Appends one line of text. A text node is used so server data is never parsed as HTML (XSS).
      function setMessageInnerHTML(text) {
          const container = document.getElementById('message');
          container.appendChild(document.createTextNode(text));
          container.appendChild(document.createElement('br'));
      }
  </script>
  </body>
  </html>
后端:
controller
  1. import org.springframework.cloud.context.config.annotation.RefreshScope;
  2. import org.springframework.validation.annotation.Validated;
  3. import org.springframework.web.bind.annotation.GetMapping;
  4. import org.springframework.web.bind.annotation.RequestMapping;
  5. import org.springframework.web.bind.annotation.RestController;
  6. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  7. import java.util.Set;
  8. import java.util.function.Consumer;
  9. import javax.annotation.Resource;
  10. import lombok.SneakyThrows;
  11. import lombok.extern.slf4j.Slf4j;
  12. @Validated
  13. @RestController
  14. @RequestMapping("/api/sse")
  15. @Slf4j
  16. @RefreshScope // 会监听变化实时变化值
  17. public class SseController {
  18. @Resource
  19. private SseBizService sseBizService;
  20. /**
  21. * 创建用户连接并返回 SseEmitter
  22. *
  23. * @param conversationId 用户ID
  24. * @return SseEmitter
  25. */
  26. @SneakyThrows
  27. @GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")
  28. public SseEmitter connect(String conversationId) {
  29. // 设置超时时间,0表示不过期。默认30秒,超过时间未完成会抛出异常:AsyncRequestTimeoutException
  30. SseEmitter sseEmitter = new SseEmitter(0L);
  31. // 注册回调
  32. sseEmitter.onCompletion(completionCallBack(conversationId));
  33. sseEmitter.onError(errorCallBack(conversationId));
  34. sseEmitter.onTimeout(timeoutCallBack(conversationId));
  35. log.info("创建新的sse连接,当前用户:{}", conversationId);
  36. sseBizService.addConnect(conversationId,sseEmitter);
  37. sseBizService.sendMsg(conversationId,"链接成功");
  38. // sseCache.get(conversationId).send(SseEmitter.event().reconnectTime(10000).data("链接成功"),MediaType.TEXT_EVENT_STREAM);
  39. return sseEmitter;
  40. }
  41. /**
  42. * 给指定用户发送信息 -- 单播
  43. */
  44. @GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")
  45. public void sendMessage(String conversationId, String msg) {
  46. sseBizService.sendMsg(conversationId,msg);
  47. }
  48. /**
  49. * 移除用户连接
  50. */
  51. @GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")
  52. public void removeUser(String conversationId) {
  53. log.info("移除用户:{}", conversationId);
  54. sseBizService.deleteConnect(conversationId);
  55. }
  56. /**
  57. * 向多人发布消息 -- 组播
  58. * @param groupId 开头标识
  59. * @param message 消息内容
  60. */
  61. public void groupSendMessage(String groupId, String message) {
  62. /* if (!BaseUtil.isNullOrEmpty(sseCache)) {
  63. *//*Set<String> ids = sseEmitterMap.keySet().stream().filter(m -> m.startsWith(groupId)).collect(Collectors.toSet());
  64. batchSendMessage(message, ids);*//*
  65. sseCache.forEach((k, v) -> {
  66. try {
  67. if (k.startsWith(groupId)) {
  68. v.send(message, MediaType.APPLICATION_JSON);
  69. }
  70. } catch (IOException e) {
  71. log.error("用户[{}]推送异常:{}", k, e.getMessage());
  72. removeUser(k);
  73. }
  74. });
  75. }*/
  76. }
  77. /**
  78. * 群发所有人 -- 广播
  79. */
  80. public void batchSendMessage(String message) {
  81. /*sseCache.forEach((k, v) -> {
  82. try {
  83. v.send(message, MediaType.APPLICATION_JSON);
  84. } catch (IOException e) {
  85. log.error("用户[{}]推送异常:{}", k, e.getMessage());
  86. removeUser(k);
  87. }
  88. });*/
  89. }
  90. /**
  91. * 群发消息
  92. */
  93. public void batchSendMessage(String message, Set<String> ids) {
  94. ids.forEach(userId -> sendMessage(userId, message));
  95. }
  96. /**
  97. * 获取当前连接信息
  98. */
  99. // public List<String> getIds() {
  100. // return new ArrayList<>(sseCache.keySet());
  101. // }
  102. /**
  103. * 获取当前连接数量
  104. */
  105. // public int getUserCount() {
  106. // return count.intValue();
  107. // }
  108. private Runnable completionCallBack(String userId) {
  109. return () -> {
  110. log.info("结束连接:{}", userId);
  111. removeUser(userId);
  112. };
  113. }
  114. private Runnable timeoutCallBack(String userId) {
  115. return () -> {
  116. log.info("连接超时:{}", userId);
  117. removeUser(userId);
  118. };
  119. }
  120. private Consumer<Throwable> errorCallBack(String userId) {
  121. return throwable -> {
  122. log.info("连接异常:{}", userId);
  123. removeUser(userId);
  124. };
  125. }
  126. }
service
  1. import org.springframework.cloud.context.config.annotation.RefreshScope;
  2. import org.springframework.http.MediaType;
  3. import org.springframework.stereotype.Component;
  4. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  5. import java.util.Map;
  6. import java.util.concurrent.ConcurrentHashMap;
  7. import java.util.concurrent.atomic.AtomicInteger;
  8. import lombok.SneakyThrows;
  9. import lombok.extern.slf4j.Slf4j;
  10. @Component
  11. @Slf4j
  12. @RefreshScope // 会监听变化实时变化值
  13. public class SseBizService {
  14. /**
  15. *
  16. * 当前连接数
  17. */
  18. private AtomicInteger count = new AtomicInteger(0);
  19. /**
  20. * 使用map对象,便于根据userId来获取对应的SseEmitter,或者放redis里面
  21. */
  22. private Map<String, SseEmitter> sseCache = new ConcurrentHashMap<>();
  23. /**
  24. * 添加用户
  25. * @author pengbin <pengbin>
  26. * @date 2023/9/11 11:37
  27. * @param
  28. * @return
  29. */
  30. public void addConnect(String id,SseEmitter sseEmitter){
  31. sseCache.put(id, sseEmitter);
  32. // 数量+1
  33. count.getAndIncrement();
  34. }
  35. /**
  36. * 删除用户
  37. * @author pengbin <pengbin>
  38. * @date 2023/9/11 11:37
  39. * @param
  40. * @return
  41. */
  42. public void deleteConnect(String id){
  43. sseCache.remove(id);
  44. // 数量+1
  45. count.getAndDecrement();
  46. }
  47. /**
  48. * 发送消息
  49. * @author pengbin <pengbin>
  50. * @date 2023/9/11 11:38
  51. * @param
  52. * @return
  53. */
  54. @SneakyThrows
  55. public void sendMsg(String id, String msg){
  56. if(sseCache.containsKey(id)){
  57. sseCache.get(id).send(msg, MediaType.TEXT_EVENT_STREAM);
  58. }
  59. }
  60. }

方法二:每次提问时建立一个 EventSource,回答结束后立即销毁

前端:收到结束标识 [DONE] 后立即关闭 EventSource。否则服务端结束响应后,浏览器会自动重连,并再次发起同一个请求。

  /**
   * Handles each message of a per-question stream (alternative: source.onmessage = ...).
   * The stream ends with the "[DONE]" marker. The EventSource is closed at once, because
   * otherwise the browser would reconnect automatically and ask the question again.
   * Content chunks arrive as JSON like {"content": "..."}; other data is shown as-is.
   */
  source.addEventListener('message', function (e) {
      if (e.data === '[DONE]') {
          source.close();
          return;
      }
      let text = e.data;
      try {
          const chunk = JSON.parse(e.data);
          if (chunk && typeof chunk.content === 'string') {
              text = chunk.content;
          }
      } catch (ignored) {
          // Not JSON: display the raw data.
      }
      setMessageInnerHTML(text);
  });

后端:
 

  1. @SneakyThrows
  2. @GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
  3. public SseEmitter completionsStream(@RequestParam String conversationId){
  4. //
  5. List<ChatParamMessagesBO> messagesBOList =new ArrayList();
  6. // 获取内容信息
  7. ChatParamBO build = ChatParamBO.builder()
  8. .temperature(0.7)
  9. .stream(true)
  10. .model("xxxx")
  11. .messages(messagesBOList)
  12. .build();
  13. SseEmitter emitter = new SseEmitter();
  14. // 定义see接口
  15. Request request = new Request.Builder().url("xxx")
  16. .header("Authorization","xxxx")
  17. .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),JsonUtils.toJson(build)))
  18. .build();
  19. OkHttpClient okHttpClient = new OkHttpClient.Builder()
  20. .connectTimeout(10, TimeUnit.MINUTES)
  21. .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
  22. .build();
  23. StringBuffer sb = new StringBuffer("");
  24. // 实例化EventSource,注册EventSource监听器
  25. RealEventSource realEventSource = null;
  26. realEventSource = new RealEventSource(request, new EventSourceListener() {
  27. @Override
  28. public void onOpen(EventSource eventSource, Response response) {
  29. log.info("onOpen");
  30. }
  31. @SneakyThrows
  32. @Override
  33. public void onEvent(EventSource eventSource, String id, String type, String data) {
  34. log.info(data);//请求到的数据
  35. try {
  36. ChatResultBO chatResultBO = JsonUtils.toObject(data.replace("data:", ""), ChatResultBO.class);
  37. String content = chatResultBO.getChoices().get(0).getDelta().getContent();
  38. sb.append(content);
  39. emitter.send(SseEmitter.event().data(JsonUtils.toJson(ChatContentBO.builder().content(content).build())));
  40. } catch (Exception e) {
  41. // e.printStackTrace();
  42. }
  43. if("[DONE]".equals(data)){
  44. emitter.send(SseEmitter.event().data(data));
  45. emitter.complete();
  46. log.info("result={}",sb);
  47. }
  48. }
  49. @Override
  50. public void onClosed(EventSource eventSource) {
  51. log.info("onClosed,eventSource={}",eventSource);//这边可以监听并重新打开
  52. // emitter.complete();
  53. }
  54. @Override
  55. public void onFailure(EventSource eventSource, Throwable t, Response response) {
  56. log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
  57. // emitter.complete();
  58. }
  59. });
  60. realEventSource.connect(okHttpClient);//真正开始请求的一步
  61. return emitter;
  62. }

3)踩坑

nginx 配置:

在反向代理到后端的 location 中需要添加以下配置。需要关闭代理缓冲,否则 nginx 会先缓冲后端响应再转发,前端就看不到逐字输出的效果:

 # Disable proxy buffering so streamed GPT (SSE) chunks are passed to the client as soon as they arrive.
  proxy_buffering off;

  location / {
      proxy_pass http://backend;
      proxy_redirect default;
      proxy_connect_timeout 90;
      # Max wait between two reads from the backend. Streamed answers can pause for a while, and the
      # subscription SSE endpoint stays open while idle, so allow more than 90s (the backend's
      # OkHttp read timeout is 10 minutes).
      proxy_read_timeout 600;
      proxy_send_timeout 90;
      # Pass streamed GPT (SSE) chunks to the client immediately instead of buffering/caching them.
      proxy_buffering off;
      proxy_cache off;
      # Talk HTTP/1.1 to the backend so streamed (chunked) responses are proxied as-is.
      # NOTE: WebSocket endpoints additionally need the Upgrade/Connection headers forwarded.
      proxy_http_version 1.1;
      #root html;
      #root /opt/project/;
      index index.html index.htm;
      client_max_body_size 1024m;
      # Forward the real client address and host to the backend.
      proxy_set_header Host $host;
      proxy_set_header X-Real-IP $remote_addr;
      proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
  }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/962848
推荐阅读
相关标签
  

闽ICP备14008679号