赞
踩
讯飞星火大模型通过 WebSocket 通信:客户端按协议要求发送请求报文,服务端以流式方式分片返回,客户端需要把各分片拼接成完整的响应内容;header.status=2 表示这是最后一条消息。由于 WebSocket 的响应是异步的,如果需要同步拿到结果,可以使用 CountDownLatch 让调用线程等待最后一条消息处理完成后再继续往下执行。星火不同版本之间略有差异,具体以官网提供的 demo 为准。
https://console.xfyun.cn/services/bm3
点应用名称进去查看详情
https://www.xfyun.cn/doc/spark/Web.html
下面仅仅是一个示例,具体代码请以官网最新文档为准。注意 domain 参数在 1.5、2、3 版本中分别传 general、generalv2、generalv3,传错会报错 10404。另外,如果想同步返回结果,需要自己使用 CountDownLatch 让主线程等待,例如 countDownLatch.await(30000, TimeUnit.MILLISECONDS);,这样主线程最多等待指定的超时时间。
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.google.gson.Gson; import okhttp3.*; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.text.SimpleDateFormat; import java.util.*; /** * 带历史会话的示例 */ public class BigModelNew extends WebSocketListener { // 地址与鉴权信息 https://spark-api.xf-yun.com/v1.1/chat 1.5地址 domain参数为general // 地址与鉴权信息 https://spark-api.xf-yun.com/v2.1/chat 2.0地址 domain参数为generalv2 public static final String hostUrl = "https://spark-api.xf-yun.com/v1/chat"; public static final String appid = "1234"; public static final String apiSecret = "xxxxx"; public static final String apiKey = "xxxxxxxxxx"; public static final String domain = "general"; //1.5,2,3版本分别传general,generalv2,generalv3,传错了会报错10404 public static List<RoleContent> historyList=new ArrayList<>(); // 对话历史存储集合 public static String totalAnswer=""; // 大模型的答案汇总 // 环境治理的重要性 环保 人口老龄化 我爱我的祖国 public static String NewQuestion = ""; public static final Gson gson = new Gson(); // 个性化参数 private String userId; private Boolean wsCloseFlag; private static Boolean totalFlag=true; // 控制提示用户是否输入 // 构造函数 public BigModelNew(String userId, Boolean wsCloseFlag) { this.userId = userId; this.wsCloseFlag = wsCloseFlag; } // 主函数 public static void main(String[] args) throws Exception { // 个性化参数入口,如果是并发使用,可以在这里模拟 while (true){ if(totalFlag){ Scanner scanner=new Scanner(System.in); System.out.print("我:"); totalFlag=false; NewQuestion=scanner.nextLine(); // 构建鉴权url String authUrl = getAuthUrl(hostUrl, apiKey, apiSecret); OkHttpClient client = new OkHttpClient.Builder().build(); String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://"); Request request = new Request.Builder().url(url).build(); for (int i = 0; i < 1; i++) { totalAnswer=""; WebSocket webSocket = client.newWebSocket(request, new 
BigModelNew(i + "", false)); } }else{ Thread.sleep(200); } } } public static boolean canAddHistory(){ // 由于历史记录最大上线1.2W左右,需要判断是能能加入历史 int history_length=0; for(RoleContent temp:historyList){ history_length=history_length+temp.content.length(); } if(history_length>12000){ historyList.remove(0); historyList.remove(1); historyList.remove(2); historyList.remove(3); historyList.remove(4); return false; }else{ return true; } } // 线程来发送音频与参数 class MyThread extends Thread { private WebSocket webSocket; public MyThread(WebSocket webSocket) { this.webSocket = webSocket; } public void run() { try { JSONObject requestJson=new JSONObject(); JSONObject header=new JSONObject(); // header参数 header.put("app_id",appid); header.put("uid",UUID.randomUUID().toString().substring(0, 10)); JSONObject parameter=new JSONObject(); // parameter参数 JSONObject chat=new JSONObject(); chat.put("domain", domain); chat.put("temperature",0.5); chat.put("max_tokens",4096); parameter.put("chat",chat); JSONObject payload=new JSONObject(); // payload参数 JSONObject message=new JSONObject(); JSONArray text=new JSONArray(); // 历史问题获取 if(historyList.size()>0){ for(RoleContent tempRoleContent:historyList){ text.add(JSON.toJSON(tempRoleContent)); } } // 最新问题 RoleContent roleContent=new RoleContent(); roleContent.role="user"; roleContent.content=NewQuestion; text.add(JSON.toJSON(roleContent)); historyList.add(roleContent); message.put("text",text); payload.put("message",message); requestJson.put("header",header); requestJson.put("parameter",parameter); requestJson.put("payload",payload); // System.err.println(requestJson); // 可以打印看每次的传参明细 webSocket.send(requestJson.toString()); // 等待服务端返回完毕后关闭 while (true) { // System.err.println(wsCloseFlag + "---"); Thread.sleep(200); if (wsCloseFlag) { break; } } webSocket.close(1000, ""); } catch (Exception e) { e.printStackTrace(); } } } @Override public void onOpen(WebSocket webSocket, Response response) { super.onOpen(webSocket, response); System.out.print("大模型:"); 
MyThread myThread = new MyThread(webSocket); myThread.start(); } @Override public void onMessage(WebSocket webSocket, String text) { // System.out.println(userId + "用来区分那个用户的结果" + text); JsonParse myJsonParse = gson.fromJson(text, JsonParse.class); if (myJsonParse.header.code != 0) { System.out.println("发生错误,错误码为:" + myJsonParse.header.code); System.out.println("本次请求的sid为:" + myJsonParse.header.sid); webSocket.close(1000, ""); } List<Text> textList = myJsonParse.payload.choices.text; for (Text temp : textList) { System.out.print(temp.content); totalAnswer=totalAnswer+temp.content; } if (myJsonParse.header.status == 2) { // 可以关闭连接,释放资源 System.out.println(); System.out.println("*************************************************************************************"); if(canAddHistory()){ RoleContent roleContent=new RoleContent(); roleContent.setRole("assistant"); roleContent.setContent(totalAnswer); historyList.add(roleContent); }else{ historyList.remove(0); RoleContent roleContent=new RoleContent(); roleContent.setRole("assistant"); roleContent.setContent(totalAnswer); historyList.add(roleContent); } wsCloseFlag = true; totalFlag=true; } } @Override public void onFailure(WebSocket webSocket, Throwable t, Response response) { super.onFailure(webSocket, t, response); try { if (null != response) { int code = response.code(); System.out.println("onFailure code:" + code); System.out.println("onFailure body:" + response.body().string()); if (101 != code) { System.out.println("connection failed"); System.exit(0); } } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } // 鉴权方法 public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception { URL url = new URL(hostUrl); // 时间 SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); format.setTimeZone(TimeZone.getTimeZone("GMT")); String date = format.format(new Date()); // 拼接 String preStr = "host: " + url.getHost() + "\n" + 
"date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1"; // System.err.println(preStr); // SHA256加密 Mac mac = Mac.getInstance("hmacsha256"); SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256"); mac.init(spec); byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8)); // Base64加密 String sha = Base64.getEncoder().encodeToString(hexDigits); // System.err.println(sha); // 拼接 String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha); // 拼接地址 HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().// addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).// addQueryParameter("date", date).// addQueryParameter("host", url.getHost()).// build(); // System.err.println(httpUrl.toString()); return httpUrl.toString(); } //返回的json结果拆解 class JsonParse { Header header; Payload payload; } class Header { int code; int status; String sid; } class Payload { Choices choices; } class Choices { List<Text> text; } class Text { String role; String content; } class RoleContent{ String role; String content; public String getRole() { return role; } public void setRole(String role) { this.role = role; } public String getContent() { return content; } public void setContent(String content) { this.content = content; } } }
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>big_model</artifactId>
    <version>1.0-SNAPSHOT</version>

    <build>
        <plugins>
            <!-- Compile for Java 8 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <properties>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <!-- https://mvnrepository.com/artifact/com.alibaba/fastjson -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.67</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.google.code.gson/gson -->
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
            <version>2.8.5</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.java-websocket/Java-WebSocket -->
        <dependency>
            <groupId>org.java-websocket</groupId>
            <artifactId>Java-WebSocket</artifactId>
            <version>1.3.8</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.10.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.squareup.okio/okio -->
        <dependency>
            <groupId>com.squareup.okio</groupId>
            <artifactId>okio</artifactId>
            <version>2.10.0</version>
        </dependency>
        <!-- Pinned to a fixed version: the RELEASE meta-version is deprecated and makes the build non-reproducible -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
            <scope>compile</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>1.7.25</version>
        </dependency>
    </dependencies>
</project>
# Global logging configuration
log4j.rootLogger=DEBUG, stdout
# Console output...
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。