赞
踩
说明:此代码通过 WebSocket 连接 AI 大模型(讯飞星火),前端页面同样通过 WebSocket 连接后台服务端。
使用前请自行在讯飞星火大模型平台注册,申请 APPID、APISecret、APIKey 及 tokens 额度。
代码:
Maven 依赖(pom.xml):
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-test</artifactId> <scope>test</scope> </dependency> <!-- https://mvnrepository.com/artifact/com.alibaba/fastjson --> <dependency> <groupId>com.alibaba</groupId> <artifactId>fastjson</artifactId> <version>1.2.67</version> </dependency> <!-- https://mvnrepository.com/artifact/com.google.code.gson/gson --> <dependency> <groupId>com.google.code.gson</groupId> <artifactId>gson</artifactId> <version>2.8.5</version> </dependency> <!-- https://mvnrepository.com/artifact/org.java-websocket/Java-WebSocket --> <!-- <dependency>--> <!-- <groupId>org.java-websocket</groupId>--> <!-- <artifactId>Java-WebSocket</artifactId>--> <!-- <version>1.3.8</version>--> <!-- </dependency>--> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-websocket</artifactId> </dependency> <!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp --> <dependency> <groupId>com.squareup.okhttp3</groupId> <artifactId>okhttp</artifactId> <version>4.10.0</version> </dependency> <!-- https://mvnrepository.com/artifact/com.squareup.okio/okio --> <dependency> <groupId>com.squareup.okio</groupId> <artifactId>okio</artifactId> <version>2.10.0</version> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot</artifactId> <version>2.7.8</version> </dependency> <dependency> <groupId>org.projectlombok</groupId> <artifactId>lombok</artifactId> </dependency> <!-- druid 连接池依赖 --> <dependency> <groupId>com.alibaba</groupId> <artifactId>druid</artifactId> <version>1.1.20</version> </dependency> <!-- hutool依赖--> <dependency> <groupId>cn.hutool</groupId> <artifactId>hutool-all</artifactId> <version>5.7.22</version> </dependency>
Java 调用 AI 大模型的代码:
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.google.gson.Gson; import okhttp3.*; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.text.SimpleDateFormat; import java.util.*; /** * @author hanyiming */ public class BigModelNew extends WebSocketListener { // 地址与鉴权信息 https://spark-api.xf-yun.com/v1.1/chat 1.5地址 domain参数为general // 地址与鉴权信息 https://spark-api.xf-yun.com/v2.1/chat 2.0地址 domain参数为generalv2 public static final String hostUrl = "https://spark-api.xf-yun.com/v3.5/chat"; // 以下参数替换为自己的身份认证信息 public static final String appid = "appid"; public static final String apiSecret = "apiSecret"; public static final String apiKey = "apiKey"; public static List<RoleContent> historyList = new ArrayList<>(); // 对话历史存储集合 public static String totalAnswer = ""; // 大模型的答案汇总 // 环境治理的重要性 环保 人口老龄化 我爱我的祖国 public static String NewQuestion = ""; public static final Gson gson = new Gson(); // 个性化参数 private String userId; private Boolean wsCloseFlag; private static Boolean totalFlag = true; // 控制提示用户是否输入 // 构造函数 public BigModelNew(String userId, Boolean wsCloseFlag) { this.userId = userId; this.wsCloseFlag = wsCloseFlag; } public static boolean canAddHistory() { // 由于历史记录最大上线1.2W左右,需要判断是能能加入历史 int history_length = 0; for (RoleContent temp : historyList) { history_length = history_length + temp.content.length(); } if (history_length > 12000) { historyList.remove(0); historyList.remove(1); historyList.remove(2); historyList.remove(3); historyList.remove(4); return false; } else { return true; } } // 线程来发送音频与参数 class MyThread extends Thread { private WebSocket webSocket; public MyThread(WebSocket webSocket) { this.webSocket = webSocket; } public void run() { try { JSONObject requestJson = new JSONObject(); JSONObject header = new JSONObject(); // header参数 header.put("app_id", appid); 
header.put("uid", UUID.randomUUID().toString().substring(0, 10)); JSONObject parameter = new JSONObject(); // parameter参数 JSONObject chat = new JSONObject(); chat.put("domain", "generalv2"); chat.put("temperature", 0.5); chat.put("max_tokens", 4096); parameter.put("chat", chat); JSONObject payload = new JSONObject(); // payload参数 JSONObject message = new JSONObject(); JSONArray text = new JSONArray(); // 历史问题获取 if (historyList.size() > 0) { for (RoleContent tempRoleContent : historyList) { text.add(JSON.toJSON(tempRoleContent)); } } // 最新问题 RoleContent roleContent = new RoleContent(); roleContent.role = "user"; roleContent.content = NewQuestion; text.add(JSON.toJSON(roleContent)); historyList.add(roleContent); message.put("text", text); payload.put("message", message); requestJson.put("header", header); requestJson.put("parameter", parameter); requestJson.put("payload", payload); // System.err.println(requestJson); // 可以打印看每次的传参明细 webSocket.send(requestJson.toString()); // 等待服务端返回完毕后关闭 while (true) { // System.err.println(wsCloseFlag + "---"); Thread.sleep(200); if (wsCloseFlag) { break; } } webSocket.close(1000, ""); } catch (Exception e) { e.printStackTrace(); } } } @Override public void onOpen(WebSocket webSocket, Response response) { super.onOpen(webSocket, response); System.out.print("大模型:"); MyThread myThread = new MyThread(webSocket); myThread.start(); } @Override public void onMessage(WebSocket webSocket, String text) { // System.out.println(userId + "用来区分那个用户的结果" + text); JsonParse myJsonParse = gson.fromJson(text, JsonParse.class); if (myJsonParse.header.code != 0) { System.out.println("发生错误,错误码为:" + myJsonParse.header.code); System.out.println("本次请求的sid为:" + myJsonParse.header.sid); webSocket.close(1000, ""); } List<Text> textList = myJsonParse.payload.choices.text; for (Text temp : textList) { // 在此处给前端页面发送回答信息,如有存储问答需求,请在此处存储回答信息 WebSocketClient.sendInfo(temp.content); System.out.print(temp.content); totalAnswer = totalAnswer + temp.content; } if 
(myJsonParse.header.status == 2) { // 可以关闭连接,释放资源 System.out.println(); System.out.println("*************************************************************************************"); if (canAddHistory()) { RoleContent roleContent = new RoleContent(); roleContent.setRole("assistant"); roleContent.setContent(totalAnswer); historyList.add(roleContent); } else { historyList.remove(0); RoleContent roleContent = new RoleContent(); roleContent.setRole("assistant"); roleContent.setContent(totalAnswer); historyList.add(roleContent); } wsCloseFlag = true; totalFlag = true; } } @Override public void onFailure(WebSocket webSocket, Throwable t, Response response) { super.onFailure(webSocket, t, response); try { if (null != response) { int code = response.code(); System.out.println("onFailure code:" + code); System.out.println("onFailure body:" + response.body().string()); if (101 != code) { System.out.println("connection failed"); System.exit(0); } } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } // 鉴权方法 public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception { URL url = new URL(hostUrl); // 时间 SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); format.setTimeZone(TimeZone.getTimeZone("GMT")); String date = format.format(new Date()); // 拼接 String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1"; // System.err.println(preStr); // SHA256加密 Mac mac = Mac.getInstance("hmacsha256"); SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256"); mac.init(spec); byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8)); // Base64加密 String sha = Base64.getEncoder().encodeToString(hexDigits); // System.err.println(sha); // 拼接 String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date 
request-line", sha); // 拼接地址 HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().// addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).// addQueryParameter("date", date).// addQueryParameter("host", url.getHost()).// build(); // System.err.println(httpUrl.toString()); return httpUrl.toString(); } //返回的json结果拆解 class JsonParse { Header header; Payload payload; } class Header { int code; int status; String sid; } class Payload { Choices choices; } class Choices { List<Text> text; } class Text { String role; String content; } class RoleContent { String role; String content; public String getRole() { return role; } public void setRole(String role) { this.role = role; } public String getContent() { return content; } public void setContent(String content) { this.content = content; } } }
编写供前端页面使用的 WebSocket 连接接口:
import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.WebSocket; import org.springframework.stereotype.Component; import javax.websocket.*; import javax.websocket.server.PathParam; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.util.concurrent.ConcurrentHashMap; /** * * @author HanYiMing * @date 2024/3/1 * @description websocket配置类 */ @ServerEndpoint(value = "/websocketClient/{userId}") @Component @Slf4j public class WebSocketClient { /** * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的 */ private static int onlineCount = 0; /** * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象 */ private static final ConcurrentHashMap<String, WebSocketClient> webSocketMap = new ConcurrentHashMap<>(); public ConcurrentHashMap<String, WebSocketClient> getWebSocketMap() { return webSocketMap; } /** * 与某个客户端的连接会话,需要通过它来给客户端发送数据 */ private Session session; /** * 用户id 唯一标识 */ private String userId; /** * 连接建立成功调用的方法 */ @OnOpen public void onOpen(Session session, @PathParam("userId") String userId) { this.session = session; this.userId = userId; //加入map webSocketMap.put(userId, this); //在线数加1 addOnlineCount(); log.info("WebSocket客户端{}连接成功,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount()); sendMessage("用户" + userId + "连接成功!"); } /** * 连接关闭调用的方法 */ @OnClose public void onClose() { //从map中删除 webSocketMap.remove(userId); //在线数减1 subOnlineCount(); log.info("WebSocket客户端{}连接断开,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount()); } /** * 收到客户端消息后调用的方法 * * @param message 客户端发送过来的消息 */ @OnMessage public void onMessage(String message, Session session) throws Exception { // 心跳检测响应 if (StringUtils.equalsIgnoreCase("ping", message)) { sendMessage("pong"); log.info("WebSocket服务端已回复客户端{}的心跳检测:pong", session.getId()); return; } BigModelNew.NewQuestion = message; // 构建鉴权url String authUrl = BigModelNew.getAuthUrl(BigModelNew.hostUrl, 
BigModelNew.apiKey, BigModelNew.apiSecret); OkHttpClient client = new OkHttpClient.Builder().build(); String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://"); Request request = new Request.Builder().url(url).build(); for (int i = 0; i < 1; i++) { BigModelNew.totalAnswer=""; WebSocket webSocket = client.newWebSocket(request, new BigModelNew(i + "", false)); } log.info("收到客户端{}的消息:{}", session.getId(), message); } /** * 发生错误时调用 */ @OnError public void onError(Session session, Throwable error) { log.error("发生错误{}", session.getId(), error); error.printStackTrace(); } /** * 向客户端发送消息 */ public void sendMessage(String message) { try { this.session.getBasicRemote().sendText(message); } catch (IOException e) { e.printStackTrace(); log.error("客户端{}发送消息{{}}失败", session.getId(), message); } } /** * 通过userId向客户端发送消息 */ public static void sendMessageByUserId(String userId, String message) throws IOException { log.info("给用户{}发送{}信息", userId, message); if (StrUtil.isNotBlank(userId) && webSocketMap.containsKey(userId)) { webSocketMap.get(userId).sendMessage(message); } } /** * 关闭WebSocket * * @param userId 用户id */ public static void closeWebSocket(String userId) { if (StrUtil.isNotBlank(userId) && webSocketMap.containsKey(userId)) { webSocketMap.get(userId).onClose(); } } /** * 群发自定义消息 */ public static void sendInfo(String message) { for (String item : webSocketMap.keySet()) { webSocketMap.get(item).sendMessage(message); } } /** * 获取在线人数 * * @return 在线人数 */ public static synchronized int getOnlineCount() { return onlineCount; } /** * 添加一位在线人数 */ public static synchronized void addOnlineCount() { WebSocketClient.onlineCount++; } /** * 减少一位在线人数 */ public static synchronized void subOnlineCount() { WebSocketClient.onlineCount--; } }
使用 Postman 进行测试:
创建一个 WebSocket 连接测试用例;
输入连接地址 ws://服务地址:端口号/websocketClient/{userId}(端口号为服务实际端口,userId 为任意用户标识),点击 Connect 连接。
返回连接成功信息:
发送文字后收到 AI 的回答,测试成功。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。