当前位置:   article > 正文

Java调用Python大语言模型(LLM)遇到乱码_java llm

java llm

遇到问题,Java调用Python大语言模型(LLM)遇到乱码
解决方法:

Java代码

public class PythonRunner {
    public static String runPythonScript(String python, String pythonFile, String prompt) {
        try {
            // 打印Python解释器路径、Python脚本路径和输入
            // 用于调试,可以删除,可以在终端执行该命令行以检查Python脚本是否能正常运行
            System.out.println(python + " " + pythonFile + " " + prompt);
            // 构建进程命令
            ProcessBuilder processBuilder = new ProcessBuilder(python, pythonFile, prompt);
            processBuilder.redirectErrorStream(true);

            // 启动进程
            Process process = processBuilder.start();

            // 获取进程输出流
            InputStream inputStream = process.getInputStream();
            InputStreamReader inputStreamReader = new InputStreamReader(inputStream, StandardCharsets.UTF_8);
            BufferedReader bufferedReader = new BufferedReader(inputStreamReader);

            // 读取输出
            StringBuilder result = new StringBuilder();
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                result.append(line).append("\n");
            }

            // 等待进程结束
            process.waitFor();

            // 返回结果
            return result.toString();
        } catch (IOException | InterruptedException e) {
            e.printStackTrace();
            return null;
        }
    }
    
        public static void main(String[] args) {
        String python = "E:\\ProgramData\\Anaconda3\\envs\\LLM3.8\\python.exe";  // 你的Python解释器路径
        String pythonFile = "G:\\outSource\\project1\\project1-springboot\\src\\main\\python\\model.py";  // 你的Python脚本路径
        String prompt = "今天天气真好";

        String result = runPythonScript(python, pythonFile, prompt);
        System.out.println(result);
    }

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45

Python模型代码

import io
import sys

from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import logging

# 设置编码格式,防止乱码
# sys的默认编码虽然是UTF-8,但是sys.stdout的编码不是UTF-8,所以需要重新设置一下
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

logging.set_verbosity_error()

# 设置绝对路径,防止程序读不到文件
# 这里是模型的路径,需要根据自己的实际情况进行修改
# 不设置相对路径的原因是因为Java中可能存在多个不同位置的文件调用该python文件,所以需要设置绝对路径
path = "G:\\outSource\\project1\\project1-springboot\\src\\main\\python\\TinyLlama-1.1B-Chat-v1.0"

# Load model and tokenizer
model = LlamaForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

# sys.argv[1:]存放的是脚本的第一、二...个参数,这里是传入的对话
if len(sys.argv) > 1:
    prompt = (" ".join(sys.argv[1:]))
else:
    prompt = (
        "私は日本語を学んだことがなくて、簡単な語句で交流したいと思って、私の第1文はあなたが日本人をどのように思うのです"
    )
inputs = tokenizer(prompt, return_tensors="pt")

# Generate
# 生成对话
generate_ids = model.generate(inputs.input_ids, max_length=120)
res = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]

# java通过获取打印值来获取python的返回值
print(res)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

上面两个文件一个是在Java中的配置,另一个是在Python中的配置,遇到的问题其实就是Python打印时显示乱码,找了好久问题,最后发现是因为虽然sys库的默认编码是UTF-8,但是sys.stdout不是,需要重新设置UTF-8。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/362667
推荐阅读
相关标签
  

闽ICP备14008679号