当前位置:   article > 正文

使用Java DeepJavaLibrary ONNXRuntime 运行TinyLlama进行文本生成_onnxruntime java

onnxruntime java

1. 转换模型格式为 ONNX:optimum-cli export onnx -m TinyLlama/TinyLlama-1.1B-Chat-v1.0 --monolith TinyLlama-onnx

2. 执行量化:optimum-cli onnxruntime quantize --onnx_model TinyLlama-onnx -o llama_quantize --avx2

  1. package com.yucl.demo.djl;
  2. import java.nio.FloatBuffer;
  3. import java.nio.LongBuffer;
  4. import java.nio.file.Paths;
  5. import java.util.ArrayList;
  6. import java.util.HashMap;
  7. import java.util.List;
  8. import java.util.Map;
  9. import ai.djl.huggingface.tokenizers.Encoding;
  10. import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
  11. import ai.djl.ndarray.NDArray;
  12. import ai.djl.ndarray.NDManager;
  13. import ai.djl.ndarray.types.Shape;
  14. import ai.onnxruntime.OnnxTensor;
  15. import ai.onnxruntime.OnnxValue;
  16. import ai.onnxruntime.OrtEnvironment;
  17. import ai.onnxruntime.OrtLoggingLevel;
  18. import ai.onnxruntime.OrtSession;
  19. public class OnnxTextGenerateDemo {
  20. public static void main(String[] args) throws Exception {
  21. String TOKENIZER_URI = "D:\\llm\\llama_quantize\\tokenizer.json";
  22. String MODEL_URI = "D:\\llm\\llama_quantize\\model_quantized.onnx";
  23. OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
  24. sessionOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR);
  25. sessionOptions.setMemoryPatternOptimization(true);
  26. try (OrtEnvironment env = OrtEnvironment.getEnvironment();
  27. OrtSession session = env.createSession(MODEL_URI, sessionOptions);) {
  28. String sentences = "How to learn ai program ? ";
  29. HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(TOKENIZER_URI), Map.of());
  30. Encoding encodings = tokenizer.encode(sentences);
  31. long[] input_ids = encodings.getIds();
  32. List<Long> generatedIds = new ArrayList<>();
  33. for (long id : input_ids) {
  34. generatedIds.add(id);
  35. }
  36. int totalLength = 100;
  37. while (generatedIds.size() < totalLength) {
  38. long[] currentInputIds = new long[generatedIds.size()];
  39. long[] currentPositionIds = new long[generatedIds.size()];
  40. long[] attentionMask = new long[generatedIds.size()];
  41. for (int i = 0; i < generatedIds.size(); i++) {
  42. currentInputIds[i] = generatedIds.get(i);
  43. currentPositionIds[i] = i;
  44. attentionMask[i] = 1;
  45. }
  46. try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(currentInputIds),
  47. new long[] { 1, currentInputIds.length });
  48. OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask),
  49. new long[] { 1, attentionMask.length });
  50. OnnxTensor positionTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(currentPositionIds),
  51. new long[] { 1, currentPositionIds.length });) {
  52. Map<String, OnnxTensor> inputs = new HashMap<>();
  53. inputs.put("input_ids", inputTensor);
  54. inputs.put("attention_mask", attentionMaskTensor);
  55. inputs.put("position_ids", positionTensor);
  56. try (OrtSession.Result results = session.run(inputs)) {
  57. OnnxValue lastHiddenState = results.get(0);
  58. float[][][] logits = (float[][][]) lastHiddenState.getValue();
  59. long nextTokenId = argmax(logits[0][logits[0].length - 1]);
  60. generatedIds.add(nextTokenId);
  61. System.out.print(tokenizer.decode(new long[] { nextTokenId }) + " ");
  62. }
  63. inputs.clear();
  64. }
  65. }
  66. long[] gen_tokens = new long[generatedIds.size()];
  67. for (int i = 0; i < generatedIds.size(); i++) {
  68. gen_tokens[i] = generatedIds.get(i);
  69. }
  70. String outputText = tokenizer.decode(gen_tokens);
  71. System.out.println("Generated text: " + outputText);
  72. }
  73. }
  74. public static NDArray create(float[][][] data, NDManager manager) {
  75. FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length * data[0][0].length);
  76. for (float[][] data2 : data) {
  77. for (float[] d : data2) {
  78. buffer.put(d);
  79. }
  80. }
  81. buffer.rewind();
  82. return manager.create(buffer, new Shape(data.length, data[0].length, data[0][0].length));
  83. }
  84. // 选择最大概率对应的 token ID (贪心搜索)
  85. private static int argmax(float[] array) {
  86. int maxIndex = 0;
  87. for (int i = 1; i < array.length; i++) {
  88. if (array[i] > array[maxIndex]) {
  89. maxIndex = i;
  90. }
  91. }
  92. return maxIndex;
  93. }
  94. }

源码地址:yucl80/ai-demo-java (github.com)

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/887552
推荐阅读
相关标签
  

闽ICP备14008679号