赞
踩
流式请求 GPT,并将结果流式推送到前端页面
使用 POST 请求数据。由于 GPT 以 EventSource(SSE)方式返回数据,每一行的格式是 `data: {...}`,中间以空行分隔,流结束时返回 `data: [DONE]`,因此需要手动去掉行首的 `data:` 前缀后再解析 JSON。
- /**
- org.apache.http.client.methods
- **/
- @SneakyThrows
- private void chatStream(List<ChatParamMessagesBO> messagesBOList) {
- CloseableHttpClient httpclient = HttpClients.createDefault();
- HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");
- httpPost.setHeader("Authorization","xxxxxxxxxxxx");
- httpPost.setHeader("Content-Type","application/json; charset=UTF-8");
-
- ChatParamBO build = ChatParamBO.builder()
- .temperature(0.7)
- .model("gpt-3.5-turbo")
- .messages(messagesBOList)
- .stream(true)
- .build();
- System.out.println(JsonUtils.toJson(build));
- httpPost.setEntity(new StringEntity(JsonUtils.toJson(build),"utf-8"));
- CloseableHttpResponse response = httpclient.execute(httpPost);
- try {
- HttpEntity entity = response.getEntity();
- if (entity != null) {
- InputStream inputStream = entity.getContent();
- BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
-
- String line;
- while ((line = reader.readLine()) != null) {
- // 处理 event stream 数据
-
- try {
- // System.out.println(line);
- ChatResultBO chatResultBO = JsonUtils.toObject(line.replace("data:", ""), ChatResultBO.class);
- String content = chatResultBO.getChoices().get(0).getDelta().getContent();
- log.info(content);
-
- // System.out.println(chatResultBO.getChoices().get(0).getMessage().getContent());
- } catch (Exception e) {
- // e.printStackTrace();
- }
- }
- }
- } finally {
- response.close();
- }
- }
下面的写法用到了 OkHttp 及其 SSE 扩展(okhttp-sse),由它负责解析 `data:` 行。
需要先引入相关 Maven 依赖:
- <!-- OkHttp core plus okhttp-sse, which provides EventSource / EventSourceListener used below. -->
- <!-- No version is given: this assumes a parent BOM (e.g. Spring Boot dependency management) manages -->
- <!-- com.squareup.okhttp3 versions. TODO confirm; otherwise add the same version to both artifacts. -->
- <dependency>
- <groupId>com.squareup.okhttp3</groupId>
- <artifactId>okhttp</artifactId>
- </dependency>
- <dependency>
- <groupId>com.squareup.okhttp3</groupId>
- <artifactId>okhttp-sse</artifactId>
- </dependency>
-
- // 定义see接口
- Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions")
- .header("Authorization","xxx")
- .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString()))
- .build();
- OkHttpClient okHttpClient = new OkHttpClient.Builder()
- .connectTimeout(10, TimeUnit.MINUTES)
- .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
- .build();
-
- // 实例化EventSource,注册EventSource监听器
- RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {
-
- @Override
- public void onOpen(EventSource eventSource, Response response) {
- log.info("onOpen");
- }
-
- @SneakyThrows
- @Override
- public void onEvent(EventSource eventSource, String id, String type, String data) {
- // log.info("onEvent");
- log.info(data);//请求到的数据
-
- }
-
- @Override
- public void onClosed(EventSource eventSource) {
- log.info("onClosed");
- // emitter.complete();
- }
-
- @Override
- public void onFailure(EventSource eventSource, Throwable t, Response response) {
- log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
- // emitter.complete();
- }
- });
- realEventSource.connect(okHttpClient);//真正开始请求的一步
原理是先建立 WebSocket 连接,之后服务端即可通过该连接不断向前端推送消息:
-
- import javax.websocket.Session;
-
- import lombok.Data;
-
- /**
- * A connected WebSocket client: its session and the URI it connected with.
- * Lombok @Data generates getters/setters, equals/hashCode and toString over both fields.
- */
- @Data
- public class WebSocketClient {
-
- // Session of this client connection; messages are sent to the client through it.
- private Session session;
-
- // Request URI the client connected with.
- private String uri;
-
- }
-
-
- import org.springframework.context.annotation.Bean;
- import org.springframework.context.annotation.Configuration;
- import org.springframework.web.socket.server.standard.ServerEndpointExporter;
-
- /**
- * Exposes a ServerEndpointExporter bean.
- * NOTE(review): per Spring's ServerEndpointExporter Javadoc, this bean registers @ServerEndpoint-annotated
- * classes (such as the chat endpoint) with the JSR-356 container when running an embedded server; when deployed
- * as a WAR to an external container, that container discovers endpoints itself and this bean should be dropped.
- */
- @Configuration
- public class WebSocketConfig {
- @Bean
- public ServerEndpointExporter serverEndpointExporter() {
- return new ServerEndpointExporter();
- }
- }
-
-
- @Slf4j
- @Component
- @ServerEndpoint("/websocket/chat/{chatId}")
- public class ChatWebsocketService {
-
- static final ConcurrentHashMap<String, List<WebSocketClient>> webSocketClientMap= new ConcurrentHashMap<>();
-
- private String chatId;
-
- /**
- * 连接建立成功时触发,绑定参数
- * @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据
- * @param chatId 商户ID
- */
- @OnOpen
- public void onOpen(Session session, @PathParam("chatId") String chatId){
-
- WebSocketClient client = new WebSocketClient();
- client.setSession(session);
- client.setUri(session.getRequestURI().toString());
-
- List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);
- if(webSocketClientList == null){
- webSocketClientList = new ArrayList<>();
- }
- webSocketClientList.add(client);
- webSocketClientMap.put(chatId, webSocketClientList);
- this.chatId = chatId;
- }
-
- /**
- * 收到客户端消息后调用的方法
- *
- * @param message 客户端发送过来的消息
- */
- @OnMessage
- public void onMessage(String message) {
- log.info("chatId = {},message = {}",chatId,message);
- // 回复消息
- this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
- // this.sendMessage(chatId,message+"233");
- }
-
- /**
- * 连接关闭时触发,注意不能向客户端发送消息了
- * @param chatId
- */
- @OnClose
- public void onClose(@PathParam("chatId") String chatId){
- webSocketClientMap.remove(chatId);
- }
-
- /**
- * 通信发生错误时触发
- * @param session
- * @param error
- */
- @OnError
- public void onError(Session session, Throwable error) {
- System.out.println("发生错误");
- error.printStackTrace();
- }
-
- /**
- * 向客户端发送消息
- * @param chatId
- * @param message
- */
- public void sendMessage(String chatId,String message){
- try {
- List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);
- if(webSocketClientList!=null){
- for(WebSocketClient webSocketServer:webSocketClientList){
- webSocketServer.getSession().getBasicRemote().sendText(message);
- }
- }
- } catch (IOException e) {
- e.printStackTrace();
- throw new RuntimeException(e.getMessage());
- }
- }
-
- /**
- * 流式调用查询gpt
- * @param messagesBOList
- * @throws IOException
- */
- @SneakyThrows
- private void chatStream(List<ChatParamMessagesBO> messagesBOList) {
- // TODO 和GPT的访问请求
- }
- }
SSE 方案本质上也是基于订阅/推送模式:前端通过 EventSource 订阅,后端通过 SseEmitter 推送。
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <title>SseEmitter</title>
</head>

<body>
<button onclick="closeSse()">关闭连接</button>
<div id="message"></div>
<script>
    // Base URL of the SseController endpoints (/connect, /disconnection).
    const SSE_BASE_URL = 'http://172.28.54.27:8902/api/sse';

    let source = null;

    // A fixed conversation id stands in for a logged-in user
    // (alternative: const id = new Date().getTime();).
    const id = '7829083B42464C5B9C445A087E873C7D';

    if (window.EventSource) {
        // Open the connection; the id is URL-encoded so any characters survive the query string.
        source = new EventSource(SSE_BASE_URL + '/connect?conversationId=' + encodeURIComponent(id));
        setMessageInnerHTML("连接用户=" + id);

        // Fired once the connection is established (alternative: source.onopen = function (event) {}).
        source.addEventListener('open', function (e) {
            setMessageInnerHTML("建立连接。。。");
        }, false);

        // Data pushed by the server (alternative: source.onmessage = function (event) {}).
        source.addEventListener('message', function (e) {
            setMessageInnerHTML(e.data);
        });

        // Only fires if the server sends an event named "close"; EventSource has no built-in close event.
        source.addEventListener("close", function (event) {
            console.log("Server closed the connection");
            source.close();
        });

        // Communication errors (e.g. connection dropped).
        // The error Event has no readyState property (the old e.readyState check was always undefined,
        // so "连接关闭" was never shown); the state lives on the EventSource itself.
        source.addEventListener('error', function (e) {
            console.log(e);
            if (source.readyState === EventSource.CLOSED) {
                setMessageInnerHTML("连接关闭");
            }
        }, false);
    } else {
        setMessageInnerHTML("你的浏览器不支持SSE");
    }

    // On window close, actively close the SSE connection; if the server never expires connections,
    // its state must be cleaned up manually after the browser closes.
    window.onbeforeunload = function () {
        //closeSse();
    };

    // Close the SSE connection and ask the server to drop its state for this id.
    function closeSse() {
        // source is null when the browser has no EventSource support.
        if (source) {
            source.close();
        }
        const httpRequest = new XMLHttpRequest();
        httpRequest.open('GET', SSE_BASE_URL + '/disconnection?conversationId=' + encodeURIComponent(id), true);
        httpRequest.send();
        console.log("close");
    }

    // Append a line of text to the page. Text nodes are used instead of innerHTML += so that
    // server/model output containing '<' or markup is shown literally rather than injected as HTML.
    function setMessageInnerHTML(text) {
        const container = document.getElementById('message');
        container.appendChild(document.createTextNode(text));
        container.appendChild(document.createElement('br'));
    }
</script>
</body>

</html>
-
-
- import org.springframework.cloud.context.config.annotation.RefreshScope;
- import org.springframework.validation.annotation.Validated;
- import org.springframework.web.bind.annotation.GetMapping;
- import org.springframework.web.bind.annotation.RequestMapping;
- import org.springframework.web.bind.annotation.RestController;
- import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-
- import java.util.Set;
- import java.util.function.Consumer;
-
- import javax.annotation.Resource;
-
- import lombok.SneakyThrows;
- import lombok.extern.slf4j.Slf4j;
-
- @Validated
- @RestController
- @RequestMapping("/api/sse")
- @Slf4j
- @RefreshScope // 会监听变化实时变化值
- public class SseController {
-
- @Resource
- private SseBizService sseBizService;
-
-
- /**
- * 创建用户连接并返回 SseEmitter
- *
- * @param conversationId 用户ID
- * @return SseEmitter
- */
- @SneakyThrows
- @GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")
- public SseEmitter connect(String conversationId) {
- // 设置超时时间,0表示不过期。默认30秒,超过时间未完成会抛出异常:AsyncRequestTimeoutException
- SseEmitter sseEmitter = new SseEmitter(0L);
- // 注册回调
- sseEmitter.onCompletion(completionCallBack(conversationId));
- sseEmitter.onError(errorCallBack(conversationId));
- sseEmitter.onTimeout(timeoutCallBack(conversationId));
- log.info("创建新的sse连接,当前用户:{}", conversationId);
- sseBizService.addConnect(conversationId,sseEmitter);
- sseBizService.sendMsg(conversationId,"链接成功");
- // sseCache.get(conversationId).send(SseEmitter.event().reconnectTime(10000).data("链接成功"),MediaType.TEXT_EVENT_STREAM);
- return sseEmitter;
- }
-
- /**
- * 给指定用户发送信息 -- 单播
- */
- @GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")
- public void sendMessage(String conversationId, String msg) {
- sseBizService.sendMsg(conversationId,msg);
- }
-
- /**
- * 移除用户连接
- */
- @GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")
- public void removeUser(String conversationId) {
- log.info("移除用户:{}", conversationId);
- sseBizService.deleteConnect(conversationId);
- }
-
- /**
- * 向多人发布消息 -- 组播
- * @param groupId 开头标识
- * @param message 消息内容
- */
- public void groupSendMessage(String groupId, String message) {
- /* if (!BaseUtil.isNullOrEmpty(sseCache)) {
- *//*Set<String> ids = sseEmitterMap.keySet().stream().filter(m -> m.startsWith(groupId)).collect(Collectors.toSet());
- batchSendMessage(message, ids);*//*
- sseCache.forEach((k, v) -> {
- try {
- if (k.startsWith(groupId)) {
- v.send(message, MediaType.APPLICATION_JSON);
- }
- } catch (IOException e) {
- log.error("用户[{}]推送异常:{}", k, e.getMessage());
- removeUser(k);
- }
- });
- }*/
- }
-
- /**
- * 群发所有人 -- 广播
- */
- public void batchSendMessage(String message) {
- /*sseCache.forEach((k, v) -> {
- try {
- v.send(message, MediaType.APPLICATION_JSON);
- } catch (IOException e) {
- log.error("用户[{}]推送异常:{}", k, e.getMessage());
- removeUser(k);
- }
- });*/
- }
-
- /**
- * 群发消息
- */
- public void batchSendMessage(String message, Set<String> ids) {
- ids.forEach(userId -> sendMessage(userId, message));
- }
-
-
- /**
- * 获取当前连接信息
- */
- // public List<String> getIds() {
- // return new ArrayList<>(sseCache.keySet());
- // }
-
- /**
- * 获取当前连接数量
- */
- // public int getUserCount() {
- // return count.intValue();
- // }
-
- private Runnable completionCallBack(String userId) {
- return () -> {
- log.info("结束连接:{}", userId);
- removeUser(userId);
- };
- }
-
- private Runnable timeoutCallBack(String userId) {
- return () -> {
- log.info("连接超时:{}", userId);
- removeUser(userId);
- };
- }
-
- private Consumer<Throwable> errorCallBack(String userId) {
- return throwable -> {
- log.info("连接异常:{}", userId);
- removeUser(userId);
- };
- }
- }
-
-
- import org.springframework.cloud.context.config.annotation.RefreshScope;
- import org.springframework.http.MediaType;
- import org.springframework.stereotype.Component;
- import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-
- import java.util.Map;
- import java.util.concurrent.ConcurrentHashMap;
- import java.util.concurrent.atomic.AtomicInteger;
-
- import lombok.SneakyThrows;
- import lombok.extern.slf4j.Slf4j;
-
- @Component
- @Slf4j
- @RefreshScope // 会监听变化实时变化值
- public class SseBizService {
- /**
- *
- * 当前连接数
- */
- private AtomicInteger count = new AtomicInteger(0);
-
- /**
- * 使用map对象,便于根据userId来获取对应的SseEmitter,或者放redis里面
- */
- private Map<String, SseEmitter> sseCache = new ConcurrentHashMap<>();
-
-
- /**
- * 添加用户
- * @author pengbin <pengbin>
- * @date 2023/9/11 11:37
- * @param
- * @return
- */
- public void addConnect(String id,SseEmitter sseEmitter){
- sseCache.put(id, sseEmitter);
- // 数量+1
- count.getAndIncrement();
- }
- /**
- * 删除用户
- * @author pengbin <pengbin>
- * @date 2023/9/11 11:37
- * @param
- * @return
- */
- public void deleteConnect(String id){
- sseCache.remove(id);
- // 数量+1
- count.getAndDecrement();
- }
-
- /**
- * 发送消息
- * @author pengbin <pengbin>
- * @date 2023/9/11 11:38
- * @param
- * @return
- */
- @SneakyThrows
- public void sendMsg(String id, String msg){
- if(sseCache.containsKey(id)){
- sseCache.get(id).send(msg, MediaType.TEXT_EVENT_STREAM);
- }
- }
-
- }
/**
 * Handles each message pushed by the server (alternative: source.onmessage = function (event) {}).
 * The backend forwards the literal string "[DONE]" as its final event; the page closes the
 * EventSource on it, presumably so the browser's automatic reconnect does not restart the request.
 */
source.addEventListener('message', function (event) {
    const payload = event.data;
    setMessageInnerHTML(payload);
    if (payload === '[DONE]') {
        source.close();
    }
});
后端:
- @SneakyThrows
- @GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
- public SseEmitter completionsStream(@RequestParam String conversationId){
- //
- List<ChatParamMessagesBO> messagesBOList =new ArrayList();
-
- // 获取内容信息
- ChatParamBO build = ChatParamBO.builder()
- .temperature(0.7)
- .stream(true)
- .model("xxxx")
- .messages(messagesBOList)
- .build();
-
- SseEmitter emitter = new SseEmitter();
-
- // 定义see接口
- Request request = new Request.Builder().url("xxx")
- .header("Authorization","xxxx")
- .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),JsonUtils.toJson(build)))
- .build();
- OkHttpClient okHttpClient = new OkHttpClient.Builder()
- .connectTimeout(10, TimeUnit.MINUTES)
- .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
- .build();
-
- StringBuffer sb = new StringBuffer("");
-
- // 实例化EventSource,注册EventSource监听器
- RealEventSource realEventSource = null;
- realEventSource = new RealEventSource(request, new EventSourceListener() {
-
- @Override
- public void onOpen(EventSource eventSource, Response response) {
- log.info("onOpen");
- }
-
- @SneakyThrows
- @Override
- public void onEvent(EventSource eventSource, String id, String type, String data) {
-
- log.info(data);//请求到的数据
- try {
-
- ChatResultBO chatResultBO = JsonUtils.toObject(data.replace("data:", ""), ChatResultBO.class);
- String content = chatResultBO.getChoices().get(0).getDelta().getContent();
- sb.append(content);
- emitter.send(SseEmitter.event().data(JsonUtils.toJson(ChatContentBO.builder().content(content).build())));
-
- } catch (Exception e) {
- // e.printStackTrace();
- }
- if("[DONE]".equals(data)){
- emitter.send(SseEmitter.event().data(data));
- emitter.complete();
- log.info("result={}",sb);
- }
- }
-
- @Override
- public void onClosed(EventSource eventSource) {
- log.info("onClosed,eventSource={}",eventSource);//这边可以监听并重新打开
- // emitter.complete();
- }
-
- @Override
- public void onFailure(EventSource eventSource, Throwable t, Response response) {
- log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
- // emitter.complete();
- }
- });
- realEventSource.connect(okHttpClient);//真正开始请求的一步
- return emitter;
- }
Nginx 反向代理配置需要添加(关闭代理缓冲,否则流式输出会被 Nginx 缓存后再一次性返回给前端):
# GPT streaming support: disable proxy buffering so streamed (SSE) output is forwarded to the client as it arrives
proxy_buffering off;
- # Reverse-proxy all requests to the backend upstream.
- location / {
- proxy_pass http://backend;
- proxy_redirect default;
- proxy_connect_timeout 90;
- # NOTE(review): this is the max gap between two reads from the backend; an SSE stream idle longer than 90s is cut off.
- proxy_read_timeout 90;
- proxy_send_timeout 90;
- # GPT streaming support: disable proxy buffering so streamed (SSE) output reaches the client immediately
- proxy_buffering off;
- #root html;
- #root /opt/project/;
- # NOTE(review): requests here are handled by proxy_pass, so this index directive likely has no effect.
- index index.html index.htm;
- client_max_body_size 1024m;
- # Forward the original host and the real client IP to the backend
- proxy_set_header Host $host;
- proxy_set_header X-Real-IP $remote_addr;
- proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
- }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。