当前位置:   article > 正文

java如何对接cahtgpt API(简单记录)_java对接chatgpt

java对接chatgpt

java如何对接cahtgpt API(简单记录)

技术选型

springboot+mybatis-plus

实现效果

  • 通过java调用chatgpt API实现对话,将chatgpt生成的内容通过与前端建立websocket及时发送给前端,为加快chatgpt响应速度,采用exent/stream的方式进行,实现了逐字出现的效果

实现过程

java对接chatgpt API
  • 使用java原生的网络请求方式完成
    • 在发送网络请求时,将"stream"设置为 true,代表使用event stream的方式进行返回数据
  String url = "https://api.openai.com/v1/chat/completions";
        HashMap<String, Object> bodymap = new HashMap<>();

        bodymap.put("model", "gpt-3.5-turbo");
        bodymap.put("temperature", 0.7);
//        bodymap.put("stream",true);
        bodymap.put("messages", messagelist);
        bodymap.put("stream", true);
        Gson gson = new Gson();
        String s = gson.toJson(bodymap);
//        System.out.println(s);
        URL url1 = new URL(url);
        HttpURLConnection conn = (HttpURLConnection) url1.openConnection(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port)));
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Authorization", "Bearer " + ApiKey);
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setRequestProperty("stream", "true");
        conn.setDoOutput(true);
//    写入请求参数
        OutputStream os = conn.getOutputStream();
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(os, Charset.forName("UTF-8")));
        writer.write(s);
        writer.close();
        os.close();

       
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 读取返回值

     InputStream inputStream = conn.getInputStream();
    
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
            String line = null;
    //        System.out.println("开始回答");
            StringBuffer answoer = new StringBuffer();
            while ((line = bufferedReader.readLine()) != null) {
    
                line = line.replace("data:", "");
                JsonElement jsonElement = JsonParser.parseString(line);
                if (!jsonElement.isJsonObject()) {
    
                    continue;
                }
                JsonObject asJsonObject = jsonElement.getAsJsonObject();
                JsonArray choices = asJsonObject.get("choices").getAsJsonArray();
                if (choices.size() > 0) {
                    JsonObject choice = choices.get(0).getAsJsonObject();
                    JsonObject delta = choice.get("delta").getAsJsonObject();
                    if (delta != null) {
    //                    System.out.println(delta);
                        if (delta.has("content")) {
    //                        发送消息
                            String content = delta.get("content").getAsString();
                            BaseResponse<String> success = ResultUtils.success(content);
                            WebSocket webSocket = new WebSocket();
    
                            webSocket.sendMessageByUserId(conversionid, gson.toJson(success));
                            answoer.append(content);
    //                        webSocket.sendOneMessage(userid, success);
    //                        webSocket.sendOneMessage(userid, success);
    //                      打印在控制台中
                            System.out.print(content);
                        }
                    }
                }
    
            }
            String context = answoer.toString();
            //        将chatgpt返回的结果保存到数据库中
            Chat entity = new Chat();
            entity.setContext(context);
            entity.setRole("assistant");
            entity.setConversionid(conversionid);
            boolean save = chatService.save(entity);
    
    
    //        String s1 = stringRedisTemplate.opsForValue().get("web:" + userid);
    //        List<ChatModel> json = (List<ChatModel>) gson.fromJson(s1, new TypeToken<List<ChatModel>>() {
    //        }.getType());
    //        ChatModel chatModel = new ChatModel("assistant",answoer.toString());
    //        json.add(chatModel);
    //        stringRedisTemplate.opsForValue().set("web:" + userid,gson.toJson(json),1, TimeUnit.DAYS);
    
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
实现websocket与前端建立连接
@ServerEndpoint(value = "/websocket/{ConversionId}")
@Component
public class WebSocket {

    private static ChatGptUntil chatGptUntil;

    private static ChatService chatService;

    private static ConversionService conversionService;

    @Resource
    public void setConversionService(ConversionService conversionService) {
        WebSocket.conversionService = conversionService;
    }

    @Resource
    public void setChatService(ChatService chatService) {
        WebSocket.chatService = chatService;
    }

    @Resource
    public void setChatGptUntil(ChatGptUntil chatGptUntil) {
        WebSocket.chatGptUntil = chatGptUntil;
    }

    private final static Logger logger = LogManager.getLogger(WebSocket.class);

    /**
     * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
     */

    private static int onlineCount = 0;

    /**
     * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
     */
    private static ConcurrentHashMap<String, WebSocket> webSocketMap = new ConcurrentHashMap<>();

    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */

    private Session session;
    private Long ConversionId;


    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("ConversionId") Long ConversionId) {
        this.session = session;
        this.ConversionId = ConversionId;
        //加入map
        webSocketMap.put(ConversionId.toString(), this);
        addOnlineCount();           //在线数加1
        logger.info("对话{}连接成功,当前在线人数为{}", ConversionId, getOnlineCount());
        try {
            sendMessage(String.valueOf(this.session.getQueryString()));
        } catch (IOException e) {
            logger.error("IO异常");
        }
    }


    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        //从map中删除
        webSocketMap.remove(ConversionId.toString());
        subOnlineCount();           //在线数减1
        logger.info("对话{}关闭连接!当前在线人数为{}", ConversionId, getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        logger.info("来自客户端对话:{} 消息:{}", ConversionId, message);


        Gson gson = new Gson();

//        ChatMessage chatMessage = gson.fromJson(message, ChatMessage.class);

        System.out.println(message);

//        Long conversionid = chatMessage.getConversionid();
//        if (conversionid == null) {
//            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是哪个对话");
//            String s = gson.toJson(baseResponse);
//            session.getBasicRemote().sendText(s);
//        }

        if (message == null) {
            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是该对话的用途");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }
//        将对话保存到数据库中
        Chat entity = new Chat();
        entity.setContext(message);
        entity.setConversionid(this.ConversionId);
        entity.setRole("user");
        boolean save = chatService.save(entity);

        if (!save) {
            BaseResponse baseResponse = ResultUtils.error(500, "数据库出现错误");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }


//        查询出身份
        Conversion byId = conversionService.getById(this.ConversionId);
        String instructions = byId.getInstructions();// 指令
//     给予chatgot身份
        ArrayList<ChatModel> chatModels = new ArrayList<>();
//        ChatModel scene = new ChatModel("user", instructions);
//        chatModels.add(scene);

        LambdaQueryWrapper<Chat> queryWrapper = new LambdaQueryWrapper<>();
        // 按照修改时间进行升序排序
        queryWrapper.eq(Chat::getConversionid, byId.getId()).orderByDesc(Chat::getUpdatedtime);
        List<Chat> list = chatService.list(queryWrapper);

//        查询之前的对话记录
        List<ChatModel> collect = list.stream().map(chat -> {
            ChatModel chatModel = new ChatModel();
            chatModel.setRole(chat.getRole());
            chatModel.setContent(chat.getContext());
//            BeanUtils.copyProperties(chat, chatModel);
            return chatModel;
        }).collect(Collectors.toList());
        chatModels.addAll(collect);


        chatGptUntil.getRespost(this.ConversionId, chatModels);
//        if (chatGptUntil==null){
//            System.out.println("chatuntil是空");
//        }
//
//        if (stringRedisTemplate==null){
//            System.out.println("缓存是空");
//        }


        //群发消息
        /*for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }*/
    }

    /**
     * 发生错误时调用
     *
     * @OnError
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("对话错误:" + this.ConversionId + ",原因:" + error.getMessage());
        error.printStackTrace();
    }

    /**
     * 向客户端发送消息
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
        //this.session.getAsyncRemote().sendText(message);
    }

    /**
     * 通过userId向客户端发送消息
     */
    public void sendMessageByUserId(Long ConversionId, String message) throws IOException {
        logger.info("服务端发送消息到{},消息:{}", ConversionId, message);
        if (StrUtil.isNotBlank(ConversionId.toString()) && webSocketMap.containsKey(ConversionId.toString())) {
            webSocketMap.get(ConversionId.toString()).sendMessage(message);
        } else {
            logger.error("{}不在线", ConversionId);
        }

    }

    /**
     * 群发自定义消息
     */
    public static void sendInfo(String message) {
        for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocket.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocket.onlineCount--;
    }

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 在本项目中通过对话id标识用户的每次与cahtgpt的交互,并且将该对话下的所有内容保存在数据库中实现了对话的长久保存
gitee地址:

https://gitee.com/li-manxiang/chatgptservice.git

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/372384
推荐阅读
相关标签
  

闽ICP备14008679号