当前位置:   article > 正文

SpringBoot项目接入讯飞星火大模型Api_springboot 接入大模型

springboot 接入大模型

SpringBoot项目接入讯飞星火大模型Api

  1. 说明:后端使用 WebSocket 连接讯飞星火大模型,前端页面使用 WebSocket 连接后端服务,后端再把大模型返回的回答实时推送给前端页面。

  2. 自行注册讯飞星火大模型开放平台,申请 tokens,并获取 APPID、APISecret 和 APIKey,填入下文代码中对应的常量。

  3. 代码:

    1. maven依赖文件:

              <!-- Spring MVC / embedded web server -->
              <dependency>
                  <groupId>org.springframework.boot</groupId>
                  <artifactId>spring-boot-starter-web</artifactId>
              </dependency>
      
              <dependency>
                  <groupId>org.springframework.boot</groupId>
                  <artifactId>spring-boot-starter-test</artifactId>
                  <scope>test</scope>
              </dependency>
              <!-- https://mvnrepository.com/artifact/com.alibaba/fastjson -->
              <!-- Used to build the request JSON (JSONObject / JSONArray). -->
              <!-- NOTE(review): fastjson 1.2.x releases before 1.2.83 have published deserialization (autoType)
                   vulnerabilities; consider upgrading this version. -->
              <dependency>
                  <groupId>com.alibaba</groupId>
                  <artifactId>fastjson</artifactId>
                  <version>1.2.67</version>
              </dependency>
              <!-- https://mvnrepository.com/artifact/com.google.code.gson/gson -->
              <!-- Used to parse the model's response frames. -->
              <dependency>
                  <groupId>com.google.code.gson</groupId>
                  <artifactId>gson</artifactId>
                  <version>2.8.5</version>
              </dependency>
              <!-- https://mvnrepository.com/artifact/org.java-websocket/Java-WebSocket -->
              <!-- Not needed: the model client uses OkHttp's WebSocket support instead. -->
      <!--        <dependency>-->
      <!--            <groupId>org.java-websocket</groupId>-->
      <!--            <artifactId>Java-WebSocket</artifactId>-->
      <!--            <version>1.3.8</version>-->
      <!--        </dependency>-->
      
              <!-- Server-side WebSocket endpoint for the front-end page (@ServerEndpoint). -->
              <dependency>
                  <groupId>org.springframework.boot</groupId>
                  <artifactId>spring-boot-starter-websocket</artifactId>
              </dependency>
      
              <!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp -->
              <!-- WebSocket client used to talk to the Spark model API. -->
              <dependency>
                  <groupId>com.squareup.okhttp3</groupId>
                  <artifactId>okhttp</artifactId>
                  <version>4.10.0</version>
              </dependency>
              <!-- https://mvnrepository.com/artifact/com.squareup.okio/okio -->
              <!-- NOTE(review): OkHttp 4.10.0 is expected to bring its own Okio 3.x (artifact okio-jvm);
                   pinning com.squareup.okio:okio 2.10.0 here may downgrade it or put two Okio copies on
                   the classpath. Confirm this explicit dependency is really needed. -->
              <dependency>
                  <groupId>com.squareup.okio</groupId>
                  <artifactId>okio</artifactId>
                  <version>2.10.0</version>
              </dependency>
              <!-- NOTE(review): the other Spring Boot artifacts here carry no version (presumably managed by a
                   Spring Boot parent/BOM); an explicit 2.7.8 may disagree with that managed version. -->
              <dependency>
                  <groupId>org.springframework.boot</groupId>
                  <artifactId>spring-boot</artifactId>
                  <version>2.7.8</version>
              </dependency>
              <!-- Lombok (@Slf4j on WebSocketClient); version presumably managed by the Boot parent. -->
              <dependency>
                  <groupId>org.projectlombok</groupId>
                  <artifactId>lombok</artifactId>
              </dependency>
              <!-- Druid connection pool; the code below only uses com.alibaba.druid.util.StringUtils from it. -->
              <dependency>
                  <groupId>com.alibaba</groupId>
                  <artifactId>druid</artifactId>
                  <version>1.1.20</version>
              </dependency>
              <!-- Hutool utilities (StrUtil is used by WebSocketClient). -->
              <dependency>
                  <groupId>cn.hutool</groupId>
                  <artifactId>hutool-all</artifactId>
                  <version>5.7.22</version>
              </dependency>
      
    2. java调用 ai 大模型代码:

      import com.alibaba.fastjson.JSON;
      import com.alibaba.fastjson.JSONArray;
      import com.alibaba.fastjson.JSONObject;
      import com.google.gson.Gson;
      import okhttp3.*;
      
      import javax.crypto.Mac;
      import javax.crypto.spec.SecretKeySpec;
      import java.io.IOException;
      import java.net.URL;
      import java.nio.charset.StandardCharsets;
      import java.text.SimpleDateFormat;
      import java.util.*;
      import java.util.concurrent.CountDownLatch;
      
      /**
       * Client for the iFlytek Spark chat API, connected through an OkHttp WebSocket.
       *
       * <p>Each instance handles exactly one question/answer exchange. {@link #onOpen} starts a
       * sender thread. That thread posts the request, made of the shared {@link #historyList} plus
       * the question captured at construction. It then waits until one of these has been seen:
       * <ul>
       *   <li>the final frame (header status 2),</li>
       *   <li>an error frame,</li>
       *   <li>a transport failure,</li>
       *   <li>a server-initiated close.</li>
       * </ul>
       * Finally it closes the socket. Every answer fragment is forwarded to the browser clients
       * through {@link WebSocketClient#sendInfo(String)}.
       *
       * <p>NOTE(review): {@link #historyList}, {@link #totalAnswer} and {@link #NewQuestion} are
       * static. All users and all concurrent requests therefore share one conversation. Per-user
       * state would also require changes in the callers.
       *
       * @author hanyiming
       */
      public class BigModelNew extends WebSocketListener {
          // Endpoint URL and the matching "domain" request parameter; the two must agree:
          //   https://spark-api.xf-yun.com/v1.1/chat -> general     (Spark 1.5)
          //   https://spark-api.xf-yun.com/v2.1/chat -> generalv2   (Spark 2.0)
          //   https://spark-api.xf-yun.com/v3.5/chat -> generalv3.5 (Spark 3.5)
          public static final String hostUrl = "https://spark-api.xf-yun.com/v3.5/chat";
          /** Model domain sent as {@code parameter.chat.domain}; must correspond to {@link #hostUrl}. */
          public static final String domain = "generalv3.5";
          // Replace the following with your own credentials from the iFlytek open platform.
          public static final String appid = "appid";
          public static final String apiSecret = "apiSecret";
          public static final String apiKey = "apiKey";

          /** Character budget for the history sent with each request (the service limit is about 12k). */
          private static final int MAX_HISTORY_CHARS = 12000;
          /** How many of the oldest history entries are dropped when the budget is exceeded. */
          private static final int HISTORY_TRIM_COUNT = 5;
          /** Guards every read and write of {@link #historyList}; OkHttp callbacks run on several threads. */
          private static final Object HISTORY_LOCK = new Object();

          /** Conversation history, oldest entry first. */
          public static List<RoleContent> historyList = new ArrayList<>();

          /** Concatenation of the answer fragments received so far; the caller resets it per request. */
          public static String totalAnswer = "";

          /** Question for the next request; each instance takes a snapshot of it in its constructor. */
          public static String NewQuestion = "";

          public static final Gson gson = new Gson();

          /** Caller-supplied identifier of the requesting user (not sent to the model). */
          private final String userId;
          /** The question this instance sends, captured from {@link #NewQuestion} at construction. */
          private final String question;
          /** This exchange's answer only; appended to the history once the final frame arrives. */
          private final StringBuilder answer = new StringBuilder();
          /** Released once the exchange is over so the sender thread can close the socket. */
          private final CountDownLatch finished = new CountDownLatch(1);

          /** Whether the user may be prompted for the next question (currently only written). */
          private static Boolean totalFlag = true;

          /**
           * Creates a listener for one exchange.
           *
           * @param userId      identifier of the requesting user
           * @param wsCloseFlag when {@code true}, the socket is closed right after the request is sent
           *                    instead of waiting for the answer
           */
          public BigModelNew(String userId, Boolean wsCloseFlag) {
              this.userId = userId;
              this.question = NewQuestion;
              if (Boolean.TRUE.equals(wsCloseFlag)) {
                  finished.countDown();
              }
          }

          /** Marks the exchange as over and wakes the sender thread; safe to call more than once. */
          private void finish() {
              finished.countDown();
          }

          /**
           * Checks the history size against the ~12k character budget.
           *
           * <p>If the budget is exceeded, up to {@value #HISTORY_TRIM_COUNT} of the oldest entries are
           * removed and {@code false} is returned.
           *
           * @return {@code true} if the history was within budget and nothing was removed
           */
          public static boolean canAddHistory() {
              synchronized (HISTORY_LOCK) {
                  int historyLength = 0;
                  for (RoleContent entry : historyList) {
                      if (entry.content != null) {
                          historyLength += entry.content.length();
                      }
                  }
                  if (historyLength <= MAX_HISTORY_CHARS) {
                      return true;
                  }
                  // Drop the oldest entries in one go; removing index by index would shift the list.
                  int drop = Math.min(HISTORY_TRIM_COUNT, historyList.size());
                  historyList.subList(0, drop).clear();
                  return false;
              }
          }

          /** Sends the request, waits until the exchange is over, then closes the socket. */
          class MyThread extends Thread {
              private final WebSocket webSocket;

              public MyThread(WebSocket webSocket) {
                  this.webSocket = webSocket;
              }

              @Override
              public void run() {
                  try {
                      webSocket.send(buildRequest().toString());
                      // Released by onMessage (final or error frame), onClosing or onFailure.
                      finished.await();
                      webSocket.close(1000, "");
                  } catch (InterruptedException e) {
                      Thread.currentThread().interrupt();
                      webSocket.cancel();
                  } catch (Exception e) {
                      e.printStackTrace();
                  }
              }
          }

          /**
           * Builds the request JSON ({@code header}, {@code parameter}, {@code payload}).
           *
           * <p>The payload holds the current history followed by this instance's question. The
           * question is also appended to the history.
           */
          private JSONObject buildRequest() {
              JSONObject header = new JSONObject();
              header.put("app_id", appid);
              header.put("uid", UUID.randomUUID().toString().substring(0, 10));

              JSONObject chat = new JSONObject();
              chat.put("domain", domain);
              chat.put("temperature", 0.5);
              chat.put("max_tokens", 4096);
              JSONObject parameter = new JSONObject();
              parameter.put("chat", chat);

              RoleContent current = new RoleContent();
              current.role = "user";
              current.content = question;

              JSONArray text = new JSONArray();
              synchronized (HISTORY_LOCK) {
                  for (RoleContent past : historyList) {
                      text.add(JSON.toJSON(past));
                  }
                  historyList.add(current);
              }
              text.add(JSON.toJSON(current));

              JSONObject message = new JSONObject();
              message.put("text", text);
              JSONObject payload = new JSONObject();
              payload.put("message", message);

              JSONObject requestJson = new JSONObject();
              requestJson.put("header", header);
              requestJson.put("parameter", parameter);
              requestJson.put("payload", payload);
              return requestJson;
          }

          @Override
          public void onOpen(WebSocket webSocket, Response response) {
              super.onOpen(webSocket, response);
              System.out.print("大模型:");
              new MyThread(webSocket).start();
          }

          /**
           * Handles one response frame.
           *
           * <p>Each fragment is forwarded to the browsers and accumulated. On the final frame
           * (status 2) the answer is stored in the history. On an error frame the exchange ends
           * without touching the (absent) payload.
           */
          @Override
          public void onMessage(WebSocket webSocket, String text) {
              JsonParse myJsonParse = gson.fromJson(text, JsonParse.class);
              Header header = myJsonParse == null ? null : myJsonParse.header;
              if (header == null || header.code != 0) {
                  if (header != null) {
                      System.out.println("发生错误,错误码为:" + header.code);
                      System.out.println("本次请求的sid为:" + header.sid);
                  } else {
                      System.out.println("无法解析的响应:" + text);
                  }
                  finish();
                  webSocket.close(1000, "");
                  return;
              }
              Payload payload = myJsonParse.payload;
              if (payload != null && payload.choices != null && payload.choices.text != null) {
                  for (Text fragment : payload.choices.text) {
                      // Push the fragment to the front-end; persist answers here if needed.
                      WebSocketClient.sendInfo(fragment.content);
                      System.out.print(fragment.content);
                      totalAnswer = totalAnswer + fragment.content;
                      answer.append(fragment.content);
                  }
              }
              if (header.status == 2) {
                  // Final frame: the connection can be closed and resources released.
                  System.out.println();
                  System.out.println("*************************************************************************************");
                  RoleContent reply = new RoleContent();
                  reply.setRole("assistant");
                  reply.setContent(answer.toString());
                  synchronized (HISTORY_LOCK) {
                      if (!canAddHistory() && !historyList.isEmpty()) {
                          historyList.remove(0);
                      }
                      historyList.add(reply);
                  }
                  totalFlag = true;
                  finish();
              }
          }

          /**
           * Called when the server starts closing the connection.
           *
           * <p>Waking the sender thread lets it answer with {@code close}, which completes the
           * close handshake.
           */
          @Override
          public void onClosing(WebSocket webSocket, int code, String reason) {
              super.onClosing(webSocket, code, reason);
              finish();
          }

          /**
           * Logs a transport or handshake failure and ends the exchange.
           *
           * <p>This runs inside a server, so it must not terminate the JVM.
           */
          @Override
          public void onFailure(WebSocket webSocket, Throwable t, Response response) {
              super.onFailure(webSocket, t, response);
              try {
                  if (null != response) {
                      int code = response.code();
                      System.out.println("onFailure code:" + code);
                      System.out.println("onFailure body:" + (response.body() == null ? "" : response.body().string()));
                      if (101 != code) {
                          System.out.println("connection failed");
                      }
                  } else if (t != null) {
                      t.printStackTrace();
                  }
              } catch (IOException e) {
                  e.printStackTrace();
              } finally {
                  finish();
              }
          }

          /**
           * Builds the signed URL required by the Spark API handshake.
           *
           * <p>An HMAC-SHA256 signature is computed over the {@code host}, the {@code date} and the
           * request line. The signature is placed in the Base64-encoded {@code authorization} query
           * parameter.
           *
           * @param hostUrl   endpoint, e.g. {@link #hostUrl}
           * @param apiKey    API key
           * @param apiSecret API secret used as the HMAC key
           * @return the https URL with {@code authorization}, {@code date} and {@code host} query parameters
           * @throws Exception if the URL is malformed or HMAC-SHA256 is unavailable
           */
          public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
              URL url = new URL(hostUrl);
              // RFC 1123-style timestamp in GMT.
              SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
              format.setTimeZone(TimeZone.getTimeZone("GMT"));
              String date = format.format(new Date());
              // The string that gets signed.
              String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1";
              Mac mac = Mac.getInstance("hmacsha256");
              SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
              mac.init(spec);
              byte[] signatureBytes = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
              String signature = Base64.getEncoder().encodeToString(signatureBytes);
              String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", signature);
              HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder()
                      .addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
                      .addQueryParameter("date", date)
                      .addQueryParameter("host", url.getHost())
                      .build();
              return httpUrl.toString();
          }

          // ---- Response model, populated by Gson. Static so no enclosing instance is needed. ----

          /** One response frame. */
          static class JsonParse {
              Header header;
              Payload payload;
          }

          /** Frame header: {@code code} 0 means success; {@code status} 2 marks the final frame. */
          static class Header {
              int code;
              int status;
              String sid;
          }

          static class Payload {
              Choices choices;
          }

          static class Choices {
              List<Text> text;
          }

          /** One answer fragment. */
          static class Text {
              String role;
              String content;
          }

          /** One history entry ({@code role} is "user" or "assistant"), serialized via its getters. */
          static class RoleContent {
              String role;
              String content;

              public String getRole() {
                  return role;
              }

              public void setRole(String role) {
                  this.role = role;
              }

              public String getContent() {
                  return content;
              }

              public void setContent(String content) {
                  this.content = content;
              }
          }
      }
      
    3. 编写给前端页面使用的 websocket 连接接口

      import cn.hutool.core.util.StrUtil;
      import com.alibaba.druid.util.StringUtils;
      import lombok.extern.slf4j.Slf4j;
      import okhttp3.OkHttpClient;
      import okhttp3.Request;
      import okhttp3.WebSocket;
      import org.springframework.stereotype.Component;
      
      import javax.websocket.*;
      import javax.websocket.server.PathParam;
      import javax.websocket.server.ServerEndpoint;
      import java.io.IOException;
      import java.util.concurrent.ConcurrentHashMap;
      
      /**
       *
       * @author HanYiMing
       * @date 2024/3/1
       * @description websocket配置类
       */
      @ServerEndpoint(value = "/websocketClient/{userId}")
      @Component
      @Slf4j
      public class WebSocketClient {
      
      
          /**
           * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
           */
      
          private static int onlineCount = 0;
      
          /**
           * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
           */
          private static final ConcurrentHashMap<String, WebSocketClient> webSocketMap = new ConcurrentHashMap<>();
      
          public ConcurrentHashMap<String, WebSocketClient> getWebSocketMap() {
              return webSocketMap;
          }
      
          /**
           * 与某个客户端的连接会话,需要通过它来给客户端发送数据
           */
          private Session session;
      
          /**
           * 用户id 唯一标识
           */
          private String userId;
      
          /**
           * 连接建立成功调用的方法
           */
          @OnOpen
          public void onOpen(Session session, @PathParam("userId") String userId) {
              this.session = session;
              this.userId = userId;
              //加入map
              webSocketMap.put(userId, this);
              //在线数加1
              addOnlineCount();
              log.info("WebSocket客户端{}连接成功,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount());
              sendMessage("用户" + userId + "连接成功!");
          }
      
          /**
           * 连接关闭调用的方法
           */
          @OnClose
          public void onClose() {
              //从map中删除
              webSocketMap.remove(userId);
              //在线数减1
              subOnlineCount();
              log.info("WebSocket客户端{}连接断开,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount());
          }
      
          /**
           * 收到客户端消息后调用的方法
           *
           * @param message 客户端发送过来的消息
           */
          @OnMessage
          public void onMessage(String message, Session session) throws Exception {
              // 心跳检测响应
              if (StringUtils.equalsIgnoreCase("ping", message)) {
                  sendMessage("pong");
                  log.info("WebSocket服务端已回复客户端{}的心跳检测:pong", session.getId());
                  return;
              }
              BigModelNew.NewQuestion = message;
              // 构建鉴权url
              String authUrl = BigModelNew.getAuthUrl(BigModelNew.hostUrl, BigModelNew.apiKey, BigModelNew.apiSecret);
              OkHttpClient client = new OkHttpClient.Builder().build();
              String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://");
              Request request = new Request.Builder().url(url).build();
              for (int i = 0; i < 1; i++) {
                  BigModelNew.totalAnswer="";
                  WebSocket webSocket = client.newWebSocket(request, new BigModelNew(i + "",
                          false));
              }
              log.info("收到客户端{}的消息:{}", session.getId(), message);
          }
      
          /**
           * 发生错误时调用
           */
          @OnError
          public void onError(Session session, Throwable error) {
              log.error("发生错误{}", session.getId(), error);
              error.printStackTrace();
          }
      
          /**
           * 向客户端发送消息
           */
          public void sendMessage(String message) {
              try {
                  this.session.getBasicRemote().sendText(message);
              } catch (IOException e) {
                  e.printStackTrace();
                  log.error("客户端{}发送消息{{}}失败", session.getId(), message);
              }
          }
      
          /**
           * 通过userId向客户端发送消息
           */
          public static void sendMessageByUserId(String userId, String message) throws IOException {
              log.info("给用户{}发送{}信息", userId, message);
              if (StrUtil.isNotBlank(userId) && webSocketMap.containsKey(userId)) {
                  webSocketMap.get(userId).sendMessage(message);
              }
          }
      
          /**
           * 关闭WebSocket
           *
           * @param userId 用户id
           */
          public static void closeWebSocket(String userId) {
              if (StrUtil.isNotBlank(userId) && webSocketMap.containsKey(userId)) {
                  webSocketMap.get(userId).onClose();
              }
          }
      
          /**
           * 群发自定义消息
           */
          public static void sendInfo(String message) {
              for (String item : webSocketMap.keySet()) {
                  webSocketMap.get(item).sendMessage(message);
              }
          }
      
          /**
           * 获取在线人数
           *
           * @return 在线人数
           */
          public static synchronized int getOnlineCount() {
              return onlineCount;
          }
      
          /**
           * 添加一位在线人数
           */
          public static synchronized void addOnlineCount() {
              WebSocketClient.onlineCount++;
          }
      
          /**
           * 减少一位在线人数
           */
          public static synchronized void subOnlineCount() {
              WebSocketClient.onlineCount--;
          }
      
      }
      
      
    4. 使用 postman 进行测试:

      1. 创建一个websocket 连接测试案例

        image-20240301172900672

      2. 输入连接地址(例如 ws://localhost:端口号/websocketClient/{userId}),点击 Connect 连接

        image-20240301172959923

      3. 返回连接成功信息:

        image-20240301173036059

      4. 发送文字,ai回答,测试成功。

        image-20240301173937156

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/578880
推荐阅读
相关标签
  

闽ICP备14008679号