赞
踩
import tensorflow as tf
import numpy as np


class MNISTLoader:
    """In-memory MNIST dataset with normalized images and integer labels.

    Attributes set by ``__init__``:
        x_train, x_test: float32 images divided by 255 with a trailing
            channel axis ([60000, 28, 28, 1] and [10000, 28, 28, 1]).
        y_train, y_test: int32 digit labels.
        x_train_count, x_test_count: number of training / test samples.
    """

    def __init__(self):
        (raw_x_train, raw_y_train), (raw_x_test, raw_y_test) = (
            tf.keras.datasets.mnist.load_data()
        )
        self.x_train = self._normalize_images(raw_x_train)
        self.x_test = self._normalize_images(raw_x_test)
        self.y_train = raw_y_train.astype(np.int32)
        self.y_test = raw_y_test.astype(np.int32)
        self.x_train_count = self.x_train.shape[0]
        self.x_test_count = self.x_test.shape[0]

    @staticmethod
    def _normalize_images(images):
        """Divide pixel values by 255 as float32 and append a channel axis."""
        return np.expand_dims(images.astype(np.float32) / 255.0, axis=-1)

    def get_batch(self, batch_size):
        """Draw ``batch_size`` random training samples (indices may repeat).

        Returns:
            An ``(images, labels)`` tuple indexed by the same random positions.
        """
        positions = np.random.randint(0, len(self.x_train), batch_size)
        return self.x_train[positions, :], self.y_train[positions]
class MLP(tf.keras.Model):
    """Two-layer fully connected classifier.

    Pipeline: Flatten -> Dense(100, ReLU) -> Dense(10) -> softmax, so ``call``
    returns per-class probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): attribute names and creation order are kept as-is on
        # purpose; presumably saved checkpoints / SavedModel variable names
        # depend on them.
        self.flatten = tf.keras.layers.Flatten()  # flattens every dim but the first
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)  # one unit per class

    # Fixed input signature: float32 tensors of shape [batch, 28, 28, 1].
    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])
    def call(self, inputs):
        """Return softmax probabilities of shape [batch, 10] for ``inputs``."""
        hidden = self.dense1(self.flatten(inputs))
        return tf.nn.softmax(self.dense2(hidden))
# Training hyper-parameters.
epochs, batch_size, learning_rate = 5, 50, 0.001

model = MLP()
data_loader = MNISTLoader()
# Adam optimizer that applies the gradients computed in the training loop.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
1、从data_loader中随机取一批数据
2、将这批数据送入模型,计算出模型的预测值
3、预测值与真实值比较,计算损失函数
4、计算损失函数关于模型变量的导数
5、将求出的导数值传入优化器中,使用优化器更新模型参数以最小化损失函数
# Custom training loop. Each step draws a random batch, runs the forward pass,
# computes the mean sparse categorical cross-entropy and applies one optimizer
# update. Total steps = (training samples // batch_size) * epochs.
num_batches = int(data_loader.x_train_count // batch_size * epochs)
for batch_index in range(num_batches):
    X, Y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        # MLP.call already applies softmax, so the loss receives probabilities
        # (the default from_logits=False is the right setting here).
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=Y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    # Only trainable weights are differentiated and updated; model.variables
    # would also include any non-trainable state.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
# Evaluate on the test set batch by batch, accumulating accuracy.
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
# Ceiling division so a trailing partial batch is evaluated too; the model's
# input signature accepts any batch size.
num_batches = (data_loader.x_test_count + batch_size - 1) // batch_size
for batch_index in range(num_batches):
    # Start/end positions of this batch (slicing clips the last one).
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    # Call the model directly: Model.predict() sets up a full inference loop on
    # every call, which is needless overhead for one small in-memory batch.
    y_pred = model(data_loader.x_test[start_index:end_index])
    sparse_categorical_accuracy.update_state(
        y_true=data_loader.y_test[start_index:end_index], y_pred=y_pred
    )
print("test accuracy: %f" % sparse_categorical_accuracy.result())

# Export as a SavedModel. NOTE: TensorFlow Serving loads models from numeric
# version sub-directories (<base_path>/<version>/saved_model.pb), so these files
# must be placed under e.g. /home/linjie/model/1 on the serving host.
tf.saved_model.save(model, "D:/file/model/")
assets #模型依赖的外部文件,比如vocab
saved_model.pb #模型的网络结构,可以接受tensor输入,计算完后输出tensor
# saved_model.pb或saved_model.pbtxt是SavedModel协议缓冲区文件,其中以MetaGraphDef协议缓冲区的形式保存计算图定义。MetaGraph是一个数据流图,加上与之相关的变量、assets和签名;MetaGraphDef是MetaGraph的Protocol Buffer表示。
variables #模型的参数
saved_model_cli show --dir model_dir_path --all
# (Optional) remove an old Docker installation, if one was installed before
yum remove docker \
    docker-client \
    docker-client-latest \
    docker-common \
    docker-latest \
    docker-latest-logrotate \
    docker-logrotate \
    docker-engine
# Install yum-utils (provides yum-config-manager)
yum install -y yum-utils
# Use the Aliyun mirror of the Docker CE repository
yum-config-manager \
    --add-repo \
    https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
# Refresh the yum package index (CentOS 7 syntax)
yum makecache fast
# Install Docker (community edition)
yum install docker-ce docker-ce-cli containerd.io
# Start the Docker daemon (and enable it at boot); `docker pull` needs it running
systemctl start docker
systemctl enable docker
# Pull the TensorFlow Serving image
docker pull tensorflow/serving
# Serve the SavedModel over the RESTful API (port 8501), in the background.
# The label is a comment here: text such as "(RESTful API)" written after `&`
# would be run by the shell as a subshell command.
docker run -p 8501:8501 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &
# Serve the same model over gRPC (port 8500), in the background
docker run -p 8500:8500 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &
1、先配置config文件 model_config_list:{ config:{ name:"z_model", # 名字随意 base_path:"/models/ble/z_model", # 一定要用/models/XXXX model_platform:"tensorflow" }, config:{ name:"xy_model", base_path:"/models/ble/xy_model", model_platform:"tensorflow" } } 2、进行多模型部署 docker run -p 8500:8500 -p 8501:8501 --mount type=bind,source=/home/ble/,target=/models/ble -t tensorflow/serving --model_config_file=/models/ble/model.config # 其中,model_config_file路径也要用/models/XXX,端口8500为gRPC方式调用,端口8501为RESTful API方式调用 3、多模型部署后,请求地址也有稍微不同 原地址:http://192.168.110.100:8501/v1/models/saved_model:predict 现地址:http://192.168.110.100:8501/v1/models/(config中模型的name):predict
以上均在CentOS7虚拟机上进行,所用详细命令暂不给出
# List running containers
docker ps
# Stop a container (replace CONTAINER_ID with an ID listed by `docker ps`)
docker stop CONTAINER_ID
# Show the status of the Docker service
service docker status
import tensorflow as tf
import numpy as np


class MNISTLoader:
    """In-memory MNIST dataset with normalized images and integer labels.

    Attributes set by ``__init__``:
        x_train, x_test: float32 images divided by 255 with a trailing
            channel axis ([60000, 28, 28, 1] and [10000, 28, 28, 1]).
        y_train, y_test: int32 digit labels.
        x_train_count, x_test_count: number of training / test samples.
    """

    def __init__(self):
        (raw_x_train, raw_y_train), (raw_x_test, raw_y_test) = (
            tf.keras.datasets.mnist.load_data()
        )
        self.x_train = self._normalize_images(raw_x_train)
        self.x_test = self._normalize_images(raw_x_test)
        self.y_train = raw_y_train.astype(np.int32)
        self.y_test = raw_y_test.astype(np.int32)
        self.x_train_count = self.x_train.shape[0]
        self.x_test_count = self.x_test.shape[0]

    @staticmethod
    def _normalize_images(images):
        """Divide pixel values by 255 as float32 and append a channel axis."""
        return np.expand_dims(images.astype(np.float32) / 255.0, axis=-1)

    def get_batch(self, batch_size):
        """Draw ``batch_size`` random training samples (indices may repeat).

        Returns:
            An ``(images, labels)`` tuple indexed by the same random positions.
        """
        positions = np.random.randint(0, len(self.x_train), batch_size)
        return self.x_train[positions, :], self.y_train[positions]
import json
import requests
# Query the TF Serving RESTful endpoint with the first 10 test images and print
# the predicted digits next to the true labels.
data_loader = MNISTLoader()
payload = {"instances": data_loader.x_test[0:10].tolist()}
json_response = requests.post(
    "http://192.168.110.100:8501/v1/models/saved_model:predict",
    json=payload,  # serializes the body and sets Content-Type: application/json
    timeout=30,  # fail instead of hanging forever if the server is unreachable
)
# Surface HTTP errors (e.g. unknown model name) instead of a confusing KeyError.
json_response.raise_for_status()
pre = np.array(json_response.json()["predictions"])
print(np.argmax(pre, axis=-1))
print(data_loader.y_test[0:10])
requests.post(url, data=None, json=None, **kwargs)
# 返回响应对象
(gRPC调用)若使用Java作为客户端,则需要编译proto文件
参考地址:
1、https://github.com/junwan01/tensorflow-serve-client
2、https://www.cnblogs.com/ustcwx/p/12768463.html
# NOTE: mind the versions - the Java classes generated from the .proto files
# depend on the TensorFlow jar, and mismatched versions may be incompatible.

# [Linux]
export SRC=~/Documents/source_code/
mkdir -p $SRC
cd $SRC
git clone git@github.com:tensorflow/serving.git
cd serving
git checkout tags/2.1.0
# Back to $SRC (previously `cd $RSC`: RSC is never set, so cd jumped to $HOME
# and tensorflow was cloned outside $SRC)
cd $SRC
git clone git@github.com:tensorflow/tensorflow.git
cd tensorflow
git checkout tags/v2.1.0

# [Windows] (run in a shell that accepts '#' comments, e.g. Git Bash or PowerShell)
# Create the working folder
mkdir D:/file/source_code
cd D:/file/source_code
# Download the serving sources with git
git clone https://github.com/tensorflow/serving
cd serving
git checkout tags/2.1.0
cd D:/file/source_code
# Download the tensorflow sources with git
git clone https://github.com/tensorflow/tensorflow
cd tensorflow
git checkout tags/v2.1.0
$ mkdir -p $PROJECT_ROOT/src/main/proto/
$ rsync -arv --prune-empty-dirs --include="*/" --include='*.proto' --exclude='*' $SRC/serving/tensorflow_serving $PROJECT_ROOT/src/main/proto/
$ rsync -arv --prune-empty-dirs --include="*/" --include="tensorflow/core/lib/core/*.proto" --include='tensorflow/core/framework/*.proto' --include="tensorflow/core/example/*.proto" --include="tensorflow/core/protobuf/*.proto" --exclude='*' $SRC/tensorflow/tensorflow $PROJECT_ROOT/src/main/proto/
// 因未安装rsync,所以直接拷贝前人准备好的proto文件放置java工程中
参考地址:https://github.com/junwan01/tensorflow-serve-client/tree/master/src/main/proto
<!-- 在maven项目中添加依赖 --> <properties> <grpc.version>1.20.0</grpc.version> </properties> <dependencies> <!-- gRPC protobuf client --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-protobuf</artifactId> <version>${grpc.version}</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-stub</artifactId> <version>${grpc.version}</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty-shaded</artifactId> <version>${grpc.version}</version> </dependency> <dependency> <groupId>com.google.api.grpc</groupId> <artifactId>proto-google-common-protos</artifactId> <version>1.0.0</version> </dependency> </dependencies>
brew install protobuf(Windows实测不行,没有brew命令)
<!-- 通过maven添加编译插件(注意版本) --> <build> <extensions> <extension> <groupId>kr.motd.maven</groupId> <artifactId>os-maven-plugin</artifactId> <version>1.6.2</version> </extension> </extensions> <plugins> <plugin> <groupId>org.xolstice.maven.plugins</groupId> <artifactId>protobuf-maven-plugin</artifactId> <version>0.6.1</version> <executions> <execution> <goals> <goal>compile</goal> <goal>compile-custom</goal> </goals> </execution> </executions> <configuration> <checkStaleness>true</checkStaleness> <protocArtifact>com.google.protobuf:protoc:3.6.1:exe:${os.detected.classifier}</protocArtifact> <pluginId>grpc-java</pluginId> <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact> </configuration> </plugin> </plugins> </build>
// 在当前工程根路径下执行命令
mvn protobuf:compile(执行报错,版本问题未解决)
编译完成之后,在$PROJECT_ROOT/src/main/resources下会增加一个new_old的文件夹,将里面的./org/tensorflow 和 ./tensorflow/serving 两个文件夹移动至$PROJECT_ROOT/src/main/java下即可
执行失败,所以直接拷贝前人的文件至工程路径。
参考链接:https://github.com/junwan01/tensorflow-serve-client/tree/master/target/generated-sources/protobuf
手动编译相较前者麻烦些,但是可以编译出静态的代码集成至工程中,而不是每次运行都动态生成(未尝试)
// grpc-java repo代码地址:https://github.com/grpc/grpc-java
$ cd $SRC
$ git clone https://github.com/grpc/grpc-java.git
Cloning into 'grpc-java'...
remote: Enumerating objects: 166, done.
remote: Counting objects: 100% (166/166), done.
remote: Compressing objects: 100% (121/121), done.
remote: Total 84096 (delta 66), reused 92 (delta 25), pack-reused 83930
Receiving objects: 100% (84096/84096), 31.18 MiB | 23.14 MiB/s, done.
Resolving deltas: 100% (38843/38843), done.
$ cd grpc-java/compiler/
$ ../gradlew java_pluginExecutable
$ ls -l build/exe/java_plugin/protoc-gen-grpc-java
# Run this shell script to compile the .proto files into Java sources.
export SRC=~/code/TFS_source/
export PROJECT_ROOT=~/java/JavaClient/
cd $PROJECT_ROOT/src/main/proto/
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/example/*.proto
# append by wangxiao
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/core/logging.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/stream_executor/dnn.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/apis/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/config/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/util/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/sources/storage_path/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/protobuf/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/lib/core/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/framework/*.proto
# The following 3 commands generate extra *Grpc.java stub source files in
# addition to the regular protobuf Java source files; both go to the same
# output directory. --plugin points at the grpc-java plugin built with
# `../gradlew java_pluginExecutable` in grpc-java/compiler.
protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/prediction_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/model_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/session_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
运行正常的情况下,$PROJECT_ROOT/src/main/java/ 文件夹里应该新增了/org/tensorflow 和 /tensorflow/serving 两个文件夹,至此,编译结束!
1、参考源码
package client; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import tensorflow.serving.Model; import org.tensorflow.framework.DataType; import org.tensorflow.framework.TensorProto; import org.tensorflow.framework.TensorShapeProto; import tensorflow.serving.Predict; import tensorflow.serving.PredictionServiceGrpc; import java.util.ArrayList; import java.util.List; public class FastTextTFSClient { /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { String host = "127.0.0.1"; int port = 8500; // the model's name. String modelName = "fastText"; int seqLen = 50; // assume this model takes input of free text, and make some sentiment prediction. List<Integer> intData = new ArrayList<Integer>(); for(int i=0; i < seqLen; i++){ intData.add(i); } // create a channel for gRPC ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel); // create a modelspec Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder(); modelSpecBuilder.setName(modelName); modelSpecBuilder.setSignatureName("fastText_sig_def"); Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder(); builder.setModelSpec(modelSpecBuilder); // create the input TensorProto and request TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder(); tensorProtoBuilder.setDtype(DataType.DT_INT32); for (Integer intDatum : intData) { tensorProtoBuilder.addIntVal(intDatum); } // build input TensorProto shape TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1)); tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(seqLen)); tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build()); TensorProto tp = tensorProtoBuilder.build(); builder.putInputs("input_x", tp); 
Predict.PredictRequest request = builder.build(); // get response Predict.PredictResponse response = stub.predict(request); } }
2、我的代码
1、以MNIST数据集为例,在java客户端进行调用
2、需要编写load(读取)MNIST数据集的代码
<!-- 需要在pom文件中添加依赖 -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.7.0</version>
</dependency>
// Mnist.java package data; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; import java.util.Random; import java.util.zip.GZIPInputStream; /** * @Title:Mnist * @Package:com.linjie.client * @Description: * @author:done * @date:2021/8/12 21:36 */ public class Mnist { public static class Data { public byte[] data; public int label; public float[] input; public float[] output; } public static void main(String[] args) throws Exception { Mnist mnist = new Mnist(); mnist.load(); System.out.println("Data loaded."); Random rand = new Random(System.nanoTime()); for (int i = 0; i < 20; i++) { int idx = rand.nextInt(60000); Data d = mnist.getTrainingData(idx); BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB); for (int x = 0; x < 28; x++) { for (int y = 0; y < 28; y++) { img.setRGB(x, y, toRgb(d.data[y * 28 + x])); } } File output = new File(i + "_" + d.label + ".png"); if (!output.exists()) { output.createNewFile(); } ImageIO.write(img, "png", output); } } static int toRgb(byte bb) { int b = (255 - (0xff & bb)); return (b << 16 | b << 8 | b) & 0xffffff; } Data[] trainingSet; Data[] testSet; public void shuffle() { Random rand = new Random(); for (int i = 0; i < trainingSet.length; i++) { int x = rand.nextInt(trainingSet.length); Data d = trainingSet[i]; trainingSet[i] = trainingSet[x]; trainingSet[x] = trainingSet[i]; } } public Data getTrainingData(int idx) { return trainingSet[idx]; } public Data[] getTrainingSlice(int start, int count) { Data[] ret = new Data[count]; System.arraycopy(trainingSet, start, ret, 0, count); return ret; } public Data getTestData(int idx) { return testSet[idx]; } public Data[] getTestSlice(int start, int count) { Data[] ret = new Data[count]; System.arraycopy(testSet, start, ret, 0, count); return ret; } public void load() { trainingSet = load("D:\\dowl\\mnist_dataset\\mnist_dataset\\train-images-idx3-ubyte.gz", 
"D:\\dowl\\mnist_dataset\\mnist_dataset\\train-labels-idx1-ubyte.gz"); testSet = load("D:\\dowl\\mnist_dataset\\mnist_dataset\\t10k-images-idx3-ubyte.gz", "D:\\dowl\\mnist_dataset\\mnist_dataset\\t10k-labels-idx1-ubyte.gz"); if (trainingSet.length != 60000 || testSet.length != 10000) { throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length); } } private Data[] load(String imgFile, String labelFile) { byte[][] images = loadImages(imgFile); byte[] labels = loadLabels(labelFile); if (images.length != labels.length) { throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile); } int len = images.length; Data[] data = new Data[len]; for (int i = 0; i < len; i++) { data[i] = new Data(); data[i].data = images[i]; data[i].label = 0xff & labels[i]; data[i].input = dataToInput(images[i]); data[i].output = labelToOutput(labels[i]); } return data; } private float[] labelToOutput(byte label) { float[] o = new float[10]; o[label] = 1; return o; } private float[] dataToInput(byte[] b) { float[] d = new float[b.length]; for (int i = 0; i < b.length; i++) { d[i] = (b[i] & 0xff) / 255.f; } return d; } private byte[][] loadImages(String imgFile) { try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)));) { int magic = in.readInt(); if (magic != 0x00000803) { throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic)); } int count = in.readInt(); int rows = in.readInt(); int cols = in.readInt(); if (rows != 28 || cols != 28) { throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols); } byte[][] data = new byte[count][rows * cols]; for (int i = 0; i < count; i++) { in.readFully(data[i]); } return data; } catch (Exception ex) { throw new RuntimeException("Failed to read file: " + imgFile, ex); } } private byte[] loadLabels(String labelFile) { try (DataInputStream in = new DataInputStream(new GZIPInputStream(new 
FileInputStream(labelFile)));) { int magic = in.readInt(); if (magic != 0x00000801) { throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic)); } int count = in.readInt(); byte[] data = new byte[count]; in.readFully(data); return data; } catch (Exception ex) { throw new RuntimeException("Failed to read file: " + labelFile, ex); } } }
// TestClient.java package client; import data.Mnist; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import org.tensorflow.Tensor; import org.tensorflow.Tensors; import org.tensorflow.framework.DataType; import org.tensorflow.framework.TensorProto; import org.tensorflow.framework.TensorShapeProto; import tensorflow.serving.Model; import tensorflow.serving.Predict; import tensorflow.serving.PredictionServiceGrpc; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.List; /** * @Title:TestClient * @Package:client * @Description: * @author:done * @date:2021/8/16 21:49 */ public class TestClient { public static void main(String[] args) throws Exception { String host = "192.168.110.100"; int port = 8500; // the model's name. // 使用命令查看模型 saved_model_cli show --dir model_dir_path --all String modelName = "saved_model"; // Mnist实例化 Mnist mnist = new Mnist(); mnist.load(); /************************** 单张图片分类 *******************************/ Mnist.Data testData = mnist.getTestData(0); // 获取第一张图片 float[] x = testData.input; // 获取单张图片的输入张量 int seqLen = 784; // 输入大小 System.out.println("data[0]的真实标签为:" + testData.label); /************************** 单张图片分类 *******************************/ /************************** 多张图片分类 *******************************/ // ArrayList<float[]> X = new ArrayList<float[]>(); // Mnist.Data[] data = mnist.getTestSlice(0, 10); // int seqLen = 784 * data.length; // System.out.print("data[0-10]真实标签为:"); // for (int i=0; i<data.length; i++){ // X.add(data[i].input); // System.out.print(data[i].label + " "); // } /************************** 多张图片分类 *******************************/ // create a channel for gRPC ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel); // create a modelspec Model.ModelSpec.Builder modelSpecBuilder = 
Model.ModelSpec.newBuilder(); modelSpecBuilder.setName(modelName); modelSpecBuilder.setSignatureName("serving_default"); Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder(); builder.setModelSpec(modelSpecBuilder); // create the input TensorProto and request TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder(); tensorProtoBuilder.setDtype(DataType.DT_FLOAT); /************************** 单张图片分类 *******************************/ for (Float intDatum : x) { // 添加输入 tensorProtoBuilder.addFloatVal(intDatum); } /************************** 单张图片分类 *******************************/ /************************** 多张图片分类 *******************************/ // for (float[] temp: X) { // float[] input = temp; // for (Float intDatum : input) { // tensorProtoBuilder.addFloatVal(intDatum); // } // } /************************** 多张图片分类 *******************************/ // build input TensorProto shape TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1)); tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(seqLen)); tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build()); TensorProto tp = tensorProtoBuilder.build(); builder.putInputs("args_0", tp); // 输入签名 args_0 Predict.PredictRequest request = builder.build(); System.out.println("******************* 请求响应 *******************"); // get response Predict.PredictResponse response = stub.predict(request); // 获取分类概率列表 List<Float> pro = response.getOutputsMap().get("output_1").getFloatValList(); // 输出签名output_1 // 获取分类结果 int pre_y = pro.indexOf(pro.stream().max((o1, o2) -> o1.compareTo(o2)).get()); System.out.println("data[0]的分类结果为:" + pre_y); } static private byte[] loadTensorflowModel(String path){ try { return Files.readAllBytes(Paths.get(path)); } catch (IOException e) { e.printStackTrace(); } return null; } static private Tensor<Double> covertArrayToTensor(double inputs[]){ return 
Tensors.create(inputs); } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。