当前位置:   article > 正文

基于Azure实现Java访问OpenAI_azure openai java

azure openai java

1、前言

        之前使用了Java代码访问OpenAI(见《OpenAI注册以及Java代码调用_雨欲语的博客-CSDN博客》),但是需要vpn才能访问。现在可以基于微软的Azure访问OpenAI,不再需要vpn。官方文档:《快速入门 - 开始通过 Azure OpenAI 服务使用 ChatGPT 和 GPT-4 - Azure OpenAI Service | Microsoft Learn》。官方对Python和C#进行了SDK封装,Java没有,但是可以通过直接请求REST接口(URL)的方式进行访问。

        Azure申请:什么是 Azure OpenAI 服务? - Azure Cognitive Services | Microsoft Learn

2、返回结果封装

       首先根据返回结果可以封装一些java类:

        AzureAIChatResponse类:

  1. public class AzureAIChatResponse {
  2. private String id;
  3. private String object;
  4. private String created;
  5. private String model;
  6. private AzureAIUsage usage;
  7. private List<AzureAIChoice> choices;
  8. public String getId() {
  9. return id;
  10. }
  11. public void setId(String id) {
  12. this.id = id;
  13. }
  14. public String getObject() {
  15. return object;
  16. }
  17. public void setObject(String object) {
  18. this.object = object;
  19. }
  20. public String getCreated() {
  21. return created;
  22. }
  23. public void setCreated(String created) {
  24. this.created = created;
  25. }
  26. public String getModel() {
  27. return model;
  28. }
  29. public void setModel(String model) {
  30. this.model = model;
  31. }
  32. public AzureAIUsage getUsage() {
  33. return usage;
  34. }
  35. public void setUsage(AzureAIUsage usage) {
  36. this.usage = usage;
  37. }
  38. public List<AzureAIChoice> getChoices() {
  39. return choices;
  40. }
  41. public void setChoices(List<AzureAIChoice> choices) {
  42. this.choices = choices;
  43. }
  44. }

         AzureAICompletionsResult类:

  1. import java.util.ArrayList;
  2. import lombok.Data;
  3. @Data
  4. public class AzureAICompletionsResult {
  5. private String id;
  6. private String object;
  7. private String created;
  8. private String model;
  9. private ArrayList<AzureAIChoice> choices;
  10. }

        AzureAICompletionRequest类:

  1. import com.google.gson.annotations.SerializedName;
  2. import lombok.Data;
  3. import java.util.HashMap;
  4. @Data
  5. public class AzureAICompletionRequest {
  6. private String prompt;
  7. @SerializedName("max_tokens")
  8. private int maxTokens;
  9. private int temperature = 1;
  10. @SerializedName("top_p")
  11. private int topProbibility = 1;
  12. @SerializedName("logit_bias")
  13. private HashMap<String, Integer> logitBiasMap;
  14. private String user;
  15. @SerializedName("n")
  16. private int choices = 1;
  17. private boolean stream = false;
  18. private String suffix;
  19. private boolean echo;
  20. private String stop;
  21. @SerializedName("presence_penalty")
  22. private int presencePenalty = 0;
  23. @SerializedName("frequency_penalty")
  24. private int frequencyPenalty = 0;
  25. @SerializedName("best_of")
  26. private int bestOf;
  27. }

         AzureAIUsage类:

  1. public class AzureAIUsage {
  2. /*
  3. "prompt_tokens": 10,
  4. "completion_tokens": 9,
  5. "total_tokens": 19
  6. */
  7. @SerializedName("prompt_tokens")
  8. private int promptTokens;
  9. @SerializedName("completion_tokens")
  10. private int completionTokens;
  11. @SerializedName("total_tokens")
  12. private int totalTokens;
  13. public int getPromptTokens() {
  14. return promptTokens;
  15. }
  16. public void setPromptTokens(int promptTokens) {
  17. this.promptTokens = promptTokens;
  18. }
  19. public int getCompletionTokens() {
  20. return completionTokens;
  21. }
  22. public void setCompletionTokens(int completionTokens) {
  23. this.completionTokens = completionTokens;
  24. }
  25. public int getTotalTokens() {
  26. return totalTokens;
  27. }
  28. public void setTotalTokens(int totalTokens) {
  29. this.totalTokens = totalTokens;
  30. }
  31. }

        AzureAIChoice类:

  1. public class AzureAIChoice {
  2. private AzureAIMessage azureAIMessage;
  3. }

        AzureAIMessage类:

  /**
   * A single message consisting of a role and its text content.
   *
   * <p>The published listing declared only two private fields, which left
   * callers with no way to read or set them; standard accessors are added.
   */
  public class AzureAIMessage {

    private String role;
    private String content;

    /** @return the role string, or {@code null} if none was set */
    public String getRole() {
      return role;
    }

    public void setRole(String role) {
      this.role = role;
    }

    /** @return the message text, or {@code null} if none was set */
    public String getContent() {
      return content;
    }

    public void setContent(String content) {
      this.content = content;
    }
  }

3、参数封装

        根据参数封装类:

        AzureAIChatRequest类:

  1. public class AzureAIChatRequest {
  2. private List<AzureAIMessage> messages;
  3. private Double temperature;
  4. @SerializedName("n")
  5. private Integer choices;
  6. private boolean stream;
  7. private String stop;
  8. @SerializedName("max_tokens")
  9. private Integer maxTokens;
  10. @SerializedName("presence_penalty")
  11. private Integer presencePenalty;
  12. @SerializedName("frequency_penalty")
  13. private Integer frequencyPenalty;
  14. private String user;
  15. public List<AzureAIMessage> getMessages() {
  16. return messages;
  17. }
  18. public void setMessages(List<AzureAIMessage> messages) {
  19. this.messages = messages;
  20. }
  21. public void addMessage(AzureAIMessage message) {
  22. if (this.messages == null) {
  23. this.messages = new ArrayList<>();
  24. }
  25. this.messages.add(message);
  26. }
  27. public Double getTemperature() {
  28. return temperature;
  29. }
  30. public void setTemperature(Double temperature) {
  31. this.temperature = temperature;
  32. }
  33. public int getChoices() {
  34. return choices;
  35. }
  36. public void setChoices(int choices) {
  37. this.choices = choices;
  38. }
  39. public boolean isStream() {
  40. return stream;
  41. }
  42. public void setStream(boolean stream) {
  43. this.stream = stream;
  44. }
  45. public String isStop() {
  46. return stop;
  47. }
  48. public void setStop(String stop) {
  49. this.stop = stop;
  50. }
  51. public void setStop(boolean stop) {
  52. if (stop) {
  53. this.stop = "true";
  54. } else {
  55. this.stop = "false";
  56. }
  57. }
  58. public int getMaxTokens() {
  59. return maxTokens;
  60. }
  61. public void setMaxTokens(int maxTokens) {
  62. this.maxTokens = maxTokens;
  63. }
  64. public int getPresencePenalty() {
  65. return presencePenalty;
  66. }
  67. public void setPresencePenalty(int presencePenalty) {
  68. this.presencePenalty = presencePenalty;
  69. }
  70. public int getFrequencyPenalty() {
  71. return frequencyPenalty;
  72. }
  73. public void setFrequencyPenalty(int frequencyPenalty) {
  74. this.frequencyPenalty = frequencyPenalty;
  75. }
  76. public String getUser() {
  77. return user;
  78. }
  79. public void setUser(String user) {
  80. this.user = user;
  81. }
  82. }

         AzureAIMessage类(请求参数中使用的版本,role为AzureAIRole枚举):

  1. public class AzureAIMessage {
  2. private AzureAIRole role;
  3. private String content;
  4. public AzureAIMessage() {
  5. }
  6. public AzureAIMessage(String content, AzureAIRole role) {
  7. this.content = content;
  8. this.role = role;
  9. }
  10. public AzureAIRole getRole() {
  11. return role;
  12. }
  13. public void setRole(AzureAIRole role) {
  14. this.role = role;
  15. }
  16. public String getContent() {
  17. return content;
  18. }
  19. public void setContent(String content) {
  20. this.content = content;
  21. }
  22. }

        AzureAIRole类:

  1. public enum AzureAIRole {
  2. @SerializedName("assistant")
  3. ASSISTANT("assistant"),
  4. @SerializedName("system")
  5. SYSTEM("system"),
  6. @SerializedName("user")
  7. USER("user"),
  8. ;
  9. private final String text;
  10. private AzureAIRole(final String text) {
  11. this.text = text;
  12. }
  13. @Override
  14. public String toString() {
  15. return text;
  16. }
  17. }

4、客户端访问

        客户端访问类:

  1. import cn.hutool.core.date.BetweenFormatter;
  2. import cn.hutool.core.date.DateUnit;
  3. import cn.hutool.core.date.DateUtil;
  4. import cn.hutool.http.HttpRequest;
  5. import cn.hutool.json.JSONUtil;
  6. import com.google.gson.Gson;
  7. import lombok.extern.slf4j.Slf4j;
  8. import org.asynchttpclient.*;
  9. import java.io.Closeable;
  10. import java.io.IOException;
  11. import java.util.Date;
  12. import java.util.concurrent.Future;
  13. @Slf4j
  14. public class AzureAIClient implements Closeable {
  15. private static final String JSON = "application/json; charset=UTF-8";
  16. private final boolean closeClient;
  17. private final AsyncHttpClient client;
  18. private final String deploymentName;
  19. private final String url;
  20. private final String token;
  21. private static final Version version = new Version();
  22. private final String apiVersion;
  23. private boolean closed = false;
  24. Gson gson = new Gson();
  25. public AzureAIClient(String url, String apiKey, String deploymentName, String apiVersion) throws Exception {
  26. this.client = new DefaultAsyncHttpClient();
  27. this.url = url + "/openai/deployments/" + deploymentName + "/";
  28. this.token = apiKey;
  29. this.deploymentName = deploymentName;
  30. this.apiVersion = apiVersion;
  31. closeClient = true;
  32. }
  33. public boolean isClosed() {
  34. return closed || client.isClosed();
  35. }
  36. @Override
  37. public void close() {
  38. if (closeClient && !client.isClosed()) {
  39. try {
  40. client.close();
  41. } catch (IOException ex) {
  42. }
  43. }
  44. closed = true;
  45. }
  46. public static String getVersion() {
  47. return version.getBuildNumber();
  48. }
  49. public static String getBuildName() {
  50. return version.getBuildName();
  51. }
  52. public AzureAICompletionsResult getCompletion(AzureAICompletionRequest completion) throws Exception {
  53. //chat/completions
  54. Future<Response> f = client.executeRequest(buildRequest("POST", "completions?api-version=" + apiVersion, gson.toJson(completion)));
  55. Response r = f.get();
  56. if (r.getStatusCode() != 200) {
  57. throw new Exception("Could not get competion result");
  58. } else {
  59. return gson.fromJson(r.getResponseBody(), AzureAICompletionsResult.class);
  60. }
  61. }
  62. public AzureAICreateEmbedingResponse createEmbedding(AzureAIEmbedding embedding) throws Exception {
  63. Future<Response> f = client.executeRequest(buildRequest("POST", "embeddings?api-version=" + apiVersion, gson.toJson(embedding)));
  64. Response r = f.get();
  65. if (r.getStatusCode() != 200) {
  66. throw new Exception("Could not create embedding");
  67. } else {
  68. AzureAICreateEmbedingResponse azureAICreateEmbedingResponse = JSONUtil.toBean(r.getResponseBody(), AzureAICreateEmbedingResponse.class);
  69. return azureAICreateEmbedingResponse;
  70. }
  71. }
  72. public AzureAIChatResponse sendMyChatRequest(AzureAIChatRequest chatRequest) throws Exception {
  73. Date startDateOne = DateUtil.date();
  74. String f = buildMyRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest));
  75. Date endDateOne = DateUtil.date();
  76. // 获取开始时间和结束时间的时间差
  77. long betweenDateOne = DateUtil.between(startDateOne, endDateOne, DateUnit.MS);
  78. // 格式化时间
  79. String formatBetweenOne = DateUtil.formatBetween(betweenDateOne, BetweenFormatter.Level.MILLISECOND);
  80. log.info(String.format("请求数据耗时(毫秒):%s",formatBetweenOne));
  81. Date startDate = DateUtil.date();
  82. System.err.println(f);
  83. AzureAIChatResponse azureAIChatResponse = gson.fromJson(f, AzureAIChatResponse.class);
  84. Date endDate = DateUtil.date();
  85. // 获取开始时间和结束时间的时间差
  86. long betweenDate = DateUtil.between(startDate, endDate, DateUnit.MS);
  87. // 格式化时间
  88. String formatBetween = DateUtil.formatBetween(betweenDate, BetweenFormatter.Level.MILLISECOND);
  89. log.info(String.format("格式化数据耗时(毫秒):%s",formatBetween));
  90. return azureAIChatResponse;
  91. }
  92. private String buildMyRequest(String type, String subUrl, String requestBody) {
  93. // RestTemplate restTemplate = new RestTemplate();
  94. // HttpHeaders httpHeaders = new HttpHeaders();
  95. // // 设置contentType
  96. httpHeaders.setContentType(MediaType.APPLICATION_JSON_UTF8);
  97. // // 给请求header中添加一些数据
  98. // httpHeaders.add("Accept", JSON);
  99. // httpHeaders.add("Content-Type", JSON);
  100. // httpHeaders.add("api-key", this.token);
  101. //
  102. //
  103. // HttpEntity<String> httpEntity = new HttpEntity<String>(requestBody, httpHeaders);
  104. // ResponseEntity<String> exchange = restTemplate.postForEntity(this.url + subUrl, httpEntity, String.class);
  105. //
  106. // String resultRemote = exchange.getBody();//得到返回的值
  107. String accept = HttpRequest.post(this.url + subUrl)
  108. .header("Accept", JSON)
  109. .header("Content-Type", "application/json")
  110. .header("api-key", this.token)
  111. .setReadTimeout(30000)
  112. .body(requestBody)
  113. .execute()
  114. .body();
  115. return accept;
  116. }
  117. public AzureAIChatResponse sendChatRequest(AzureAIChatRequest chatRequest) throws Exception {
  118. Date startDateOne = DateUtil.date();
  119. Future<Response> f = client.executeRequest(buildRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest)));
  120. // Request r = buildRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest));
  121. Response r = f.get();
  122. Date endDateOne = DateUtil.date();
  123. // 获取开始时间和结束时间的时间差
  124. long betweenDateOne = DateUtil.between(startDateOne, endDateOne, DateUnit.MS);
  125. // 格式化时间
  126. String formatBetweenOne = DateUtil.formatBetween(betweenDateOne, BetweenFormatter.Level.MILLISECOND);
  127. log.info(String.format("请求数据耗时(毫秒):%s",formatBetweenOne));
  128. if (r.getStatusCode() != 200) {
  129. log.info("Could not create chat request - server resposne was " + r.getStatusCode() + " to url: " + url + "chat/completions?api-version=2023-03-15-preview");
  130. return null;
  131. } else {
  132. Date startDate = DateUtil.date();
  133. // System.err.println(r.getResponseBody());
  134. AzureAIChatResponse azureAIChatResponse = JSONUtil.toBean(r.getResponseBody(), AzureAIChatResponse.class);
  135. // AzureAIChatResponse azureAIChatResponse = gson.fromJson(r.getResponseBody(), AzureAIChatResponse.class);
  136. Date endDate = DateUtil.date();
  137. // 获取开始时间和结束时间的时间差
  138. long betweenDate = DateUtil.between(startDate, endDate, DateUnit.MS);
  139. // 格式化时间
  140. String formatBetween = DateUtil.formatBetween(betweenDate, BetweenFormatter.Level.MILLISECOND);
  141. log.info(String.format("格式化数据耗时(毫秒):%s",formatBetween));
  142. return azureAIChatResponse;
  143. }
  144. }
  145. private Request buildRequest(String type, String subUrl) {
  146. RequestBuilder builder = new RequestBuilder(type);
  147. Request request = builder.setUrl(this.url + subUrl)
  148. .addHeader("Accept", JSON)
  149. .addHeader("Content-Type", JSON)
  150. .addHeader("Authorization", "Bearer " + this.token)
  151. .build();
  152. return request;
  153. }
  154. private Request buildRequest(String type, String subUrl, String requestBody) {
  155. RequestBuilder builder = new RequestBuilder(type);
  156. Request request = builder.setUrl(this.url + subUrl)
  157. .addHeader("Accept", JSON)
  158. .addHeader("Content-Type", JSON)
  159. .addHeader("api-key", this.token)
  160. .setBody(requestBody)
  161. .build();
  162. return request;
  163. }
  164. }

5、调用测试

        调用测试:

  1. public static void main(String[] args) {
  2. // 装配请求集合
  3. List<AzureAIMessage> azureAiMessageList = new ArrayList<>();
  4. AzureAIChatRequest azureAiChatRequest = new AzureAIChatRequest();
  5. AzureAIMessage azureAIMessage0 = new AzureAIMessage();
  6. azureAIMessage0.setRole(AzureAIRole.SYSTEM);
  7. azureAIMessage0.setContent("你是一个AI机器人,请根据提问进行回答");
  8. azureAiMessageList.add(azureAIMessage0);
  9. AzureAIMessage azureAIMessage1 = new AzureAIMessage();
  10. azureAIMessage1.setRole(AzureAIRole.USER);
  11. azureAIMessage1.setContent("请解释一下java的gc");
  12. azureAiMessageList.add(azureAIMessage1);
  13. azureAiChatRequest.setMessages(azureAiMessageList);
  14. azureAiChatRequest.setMaxTokens(maxTokens);
  15. azureAiChatRequest.setTemperature(temperature);
  16. // 是否进行留式返回
  17. // azureAiChatRequest.setStream(true);
  18. azureAiChatRequest.setPresencePenalty(0);
  19. azureAiChatRequest.setFrequencyPenalty(0);
  20. azureAiChatRequest.setStop(null);
  21. AzureAIClient azureAIClient = new AzureAIClient("申请的azure地址", "zaure的apikey",
  22. "模型(gpt-35-turbo)", "api版本:(023-03-15-preview)");
  23. AzureAIChatResponse azureAIChatResponse = azureAIClient.sendChatRequest(azureAIChatRequest);
  24. }

6、依赖

maven依赖:

  <dependencies>
    <!-- Async HTTP client used by AzureAIClient (org.asynchttpclient.*). -->
    <dependency>
      <groupId>org.asynchttpclient</groupId>
      <artifactId>async-http-client</artifactId>
      <version>2.12.3</version>
      <type>jar</type>
    </dependency>
    <!-- JSON (de)serialization and @SerializedName. -->
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.10.1</version>
    </dependency>
    <!-- The client code also imports cn.hutool.* (DateUtil, HttpRequest, JSONUtil);
         this entry was missing from the published listing. -->
    <dependency>
      <groupId>cn.hutool</groupId>
      <artifactId>hutool-all</artifactId>
      <version>5.8.18</version>
    </dependency>
    <!-- @Data and @Slf4j are Lombok annotations; compile-time only. -->
    <dependency>
      <groupId>org.projectlombok</groupId>
      <artifactId>lombok</artifactId>
      <version>1.18.26</version>
      <scope>provided</scope>
    </dependency>
  </dependencies>

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/738071
推荐阅读
相关标签
  

闽ICP备14008679号