
TensorFlow Serving Deployment: Calling the Service from Python and Java Clients


Code for this project

I. Preparation
1. Load the dataset
import tensorflow as tf
import numpy as np 
class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        # Normalize and add a channel dimension -> [60000, 28, 28, 1]
        self.x_train = np.expand_dims(self.x_train.astype(np.float32) / 255.0, axis=-1)  
        # [10000, 28, 28, 1]
        self.x_test = np.expand_dims(self.x_test.astype(np.float32) / 255.0, axis=-1)
        # Convert the labels to integers
        self.y_train = self.y_train.astype(np.int32)
        self.y_test = self.y_test.astype(np.int32)
        # Record the number of training and test samples
        self.x_train_count, self.x_test_count = self.x_train.shape[0], self.x_test.shape[0]
        
    def get_batch(self, batch_size):
        # Randomly pick batch_size indices from [0, 60000)
        index = np.random.randint(0, np.shape(self.x_train)[0], batch_size)
        return self.x_train[index, :], self.y_train[index]
2. Build the model
class MLP(tf.keras.Model):
    def __init__(self):
        super(MLP, self).__init__()
        # Flatten all dimensions except the batch dimension
        self.flatten = tf.keras.layers.Flatten()
        # units is the dimensionality of the output tensor
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)
    
    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])
    def call(self, inputs):  # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        output = tf.nn.softmax(x)
        return output
3. Model hyperparameters
epochs = 5
batch_size = 50
learning_rate = 0.001
4. Instantiate the model
model = MLP()
data_loader = MNISTLoader()
# Instantiate the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
1. Draw a random batch of data from data_loader
2. Feed the batch to the model to obtain predictions
3. Compare the predictions with the ground-truth labels and compute the loss
4. Compute the gradient of the loss with respect to the model's variables
5. Pass the gradients to the optimizer, which updates the model parameters to minimize the loss
num_batches = int(data_loader.x_train_count // batch_size * epochs)
for batch_index in range(num_batches):
    X, Y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        # Compare the predictions with the labels and compute the loss
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=Y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    
    # Compute the gradients
    grads = tape.gradient(loss, model.variables)
    # Apply the gradients to update the model parameters
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
5. Evaluate the model
# Instantiate the evaluation metric
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.x_test_count // batch_size)
for batch_index in range(num_batches):
    # Start and end indices of the current batch
    start_index, end_index = batch_index * batch_size,(batch_index+1) * batch_size
    y_pred = model.predict(data_loader.x_test[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.y_test[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())
6. Save the model
tf.saved_model.save(model, "D:/file/model/")

assets          # external files the model depends on, e.g. a vocabulary
saved_model.pb  # the model's graph; it takes tensors as input and produces tensors as output
# saved_model.pb (or saved_model.pbtxt) is the SavedModel protocol buffer. It holds the graph definition as a MetaGraphDef: a dataflow graph plus its associated variables, assets and signatures. MetaGraphDef is the protocol-buffer representation of a MetaGraph.
variables       # the model's parameters (weights)

(Screenshot: the SavedModel directory structure)
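Note that TensorFlow Serving expects the model base directory to contain numbered version subdirectories (e.g. .../1/saved_model.pb), so it is convenient to export straight into a versioned path. A minimal sketch reusing the model above; the export path and version number are placeholders:

# Export into <base_dir>/<version>/ so TensorFlow Serving can load it directly.
# "D:/file/model/1" is a placeholder path; "1" is the model version number.
tf.saved_model.save(model, "D:/file/model/1")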

7. Inspect the saved model
saved_model_cli show --dir model_dir_path --all
The output lists the SavedModel's SignatureDefs: the signature name (here serving_default) and the input/output tensor names, which the gRPC client later relies on.
II. Deploy the model
1. Install Docker and pull TensorFlow Serving
# Remove old Docker packages (optional; only needed if Docker was installed before)
yum remove docker \
                  docker-client \
                  docker-client-latest \
                  docker-common \
                  docker-latest \
                  docker-latest-logrotate \
                  docker-logrotate \
                  docker-engine
# Install required utilities
yum install -y yum-utils
# Configure the Aliyun Docker package repository
yum-config-manager \
    --add-repo \
    https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
# Refresh the yum package index
yum makecache fast
# Install Docker CE (Community Edition)
yum install docker-ce docker-ce-cli containerd.io

# Pull the TensorFlow Serving image
docker pull tensorflow/serving
2. Upload the model
Copy the exported SavedModel (including its version subdirectory) to the server, e.g. to /home/linjie/model, which is the directory mounted into the container below.
3. Deploy with Docker (single model)
# RESTful API (port 8501)
docker run -p 8501:8501 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &

# gRPC (port 8500)
docker run -p 8500:8500 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &
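To confirm the container is actually serving the model, you can query the model-status REST endpoint; a minimal Python sketch, assuming the host IP used later in this article:

import requests

# TensorFlow Serving's model status endpoint (REST, port 8501).
# 192.168.110.100 is the example host used in this article; substitute your own.
resp = requests.get("http://192.168.110.100:8501/v1/models/saved_model")
print(resp.json())  # should report a model version in state "AVAILABLE"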
Note: for multi-model deployment it is recommended to use a config file.
1. Write the config file first
model_config_list:{
	config:{
		name:"z_model",  # 名字随意
		base_path:"/models/ble/z_model",  # 一定要用/models/XXXX
		model_platform:"tensorflow"
	},

	config:{
		name:"xy_model",
		base_path:"/models/ble/xy_model",
		model_platform:"tensorflow"
	}
}
2. Run the multi-model deployment
docker run -p 8500:8500 -p 8501:8501  --mount type=bind,source=/home/ble/,target=/models/ble -t tensorflow/serving  --model_config_file=/models/ble/model.config
# The model_config_file path must also be under /models/. Port 8500 is for gRPC calls, port 8501 for the RESTful API.
3. After multi-model deployment, the request URL changes slightly
Old URL: http://192.168.110.100:8501/v1/models/saved_model:predict
New URL: http://192.168.110.100:8501/v1/models/<name from the config>:predict
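Calling one of the models from the config over REST then looks like this (a sketch; z_model and the host IP are the example values used above, and the dummy input must be replaced with your model's real input shape):

import json
import requests

# Predict against the "z_model" entry defined in model.config.
url = "http://192.168.110.100:8501/v1/models/z_model:predict"
payload = json.dumps({"instances": [[0.0] * 10]})  # dummy instance; use your model's actual input shape
resp = requests.post(url, data=payload, headers={"content-type": "application/json"})
print(resp.json())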
All of the above was done on a CentOS 7 virtual machine; the detailed command history is not given here.
4. Useful Docker commands
// List running containers
docker ps
// Stop a container
docker stop <container ID>
// Check the status of the Docker service
service docker status
III. Python client
1. Load the dataset
import tensorflow as tf 
import numpy as np 
class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        # Normalize and add a channel dimension -> [60000, 28, 28, 1]
        self.x_train = np.expand_dims(self.x_train.astype(np.float32) / 255.0, axis=-1)  
        # [10000, 28, 28, 1]
        self.x_test = np.expand_dims(self.x_test.astype(np.float32) / 255.0, axis=-1)
        # Convert the labels to integers
        self.y_train = self.y_train.astype(np.int32)
        self.y_test = self.y_test.astype(np.int32)
        # Record the number of training and test samples
        self.x_train_count, self.x_test_count = self.x_train.shape[0], self.x_test.shape[0]
        
    def get_batch(self, batch_size):
        # Randomly pick batch_size indices from [0, 60000)
        index = np.random.randint(0, np.shape(self.x_train)[0], batch_size)
        return self.x_train[index, :], self.y_train[index]
2. Send a prediction request (RESTful API)
import json
import requests

data_loader = MNISTLoader()
# Send the first 10 test images to the REST predict endpoint
data = json.dumps({"instances": data_loader.x_test[0:10].tolist()})
headers = {"content-type": "application/json"}
json_response = requests.post('http://192.168.110.100:8501/v1/models/saved_model:predict', data=data, headers=headers)
pre = np.array(json.loads(json_response.text)['predictions'])
print(np.argmax(pre, axis=-1))   # predicted labels
print(data_loader.y_test[0:10])  # ground-truth labels
requests.post(url, data=None, json=None, **kwargs)
# returns a requests.Response object
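If the signature names are unknown (they are needed for the gRPC client in the next section), the REST metadata endpoint reports them; a quick sketch against the same host:

import requests

# /metadata returns the SavedModel's signature_def map:
# signature name, input/output tensor names, dtypes and shapes.
resp = requests.get("http://192.168.110.100:8501/v1/models/saved_model/metadata")
print(resp.json())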
IV. Java client
(gRPC) If Java is used as the client, the .proto files must be compiled first.
References:
1、https://github.com/junwan01/tensorflow-serve-client
2、https://www.cnblogs.com/ustcwx/p/12768463.html
1. Get the protobuf files
// Mind the versions: the Java classes compiled from the .proto files depend on the tensorflow jar and may be incompatible with it
[Linux]
export SRC=~/Documents/source_code/
mkdir -p $SRC

cd $SRC
git clone git@github.com:tensorflow/serving.git
cd serving
git checkout tags/2.1.0

cd $SRC
git clone git@github.com:tensorflow/tensorflow.git
cd tensorflow
git checkout tags/v2.1.0

[Windows]
// Create the working directory
mkdir D:/file/source_code
cd D:/file/source_code
// Clone serving
git clone https://github.com/tensorflow/serving
cd serving
git checkout tags/2.1.0

cd D:/file/source_code
// Clone tensorflow
git clone https://github.com/tensorflow/tensorflow
cd tensorflow
git checkout tags/v2.1.0
2. Copy the required .proto files into the Java project
$ mkdir -p $PROJECT_ROOT/src/main/proto/
$ rsync -arv  --prune-empty-dirs --include="*/" --include='*.proto' --exclude='*' $SRC/serving/tensorflow_serving  $PROJECT_ROOT/src/main/proto/
$ rsync -arv  --prune-empty-dirs --include="*/" --include="tensorflow/core/lib/core/*.proto" --include='tensorflow/core/framework/*.proto' --include="tensorflow/core/example/*.proto" --include="tensorflow/core/protobuf/*.proto" --exclude='*' $SRC/tensorflow/tensorflow  $PROJECT_ROOT/src/main/proto/

// rsync was not installed here, so the ready-made proto files from the repository below were copied into the Java project instead
Reference: https://github.com/junwan01/tensorflow-serve-client/tree/master/src/main/proto
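If rsync is unavailable (e.g. on Windows), a small script can do the same selective copy; a sketch that mirrors the rsync filters above, assuming the repositories were cloned as in step 1 and with placeholder paths to adjust:

import shutil
from pathlib import Path

# Placeholder paths -- point SRC at the cloned repos and DEST at the Java project.
SRC = Path("D:/file/source_code")
DEST = Path("D:/file/JavaClient/src/main/proto")

# (clone root, subdirectories whose *.proto files are needed)
jobs = [
    (SRC / "serving", ["tensorflow_serving"]),
    (SRC / "tensorflow", ["tensorflow/core/lib/core", "tensorflow/core/framework",
                          "tensorflow/core/example", "tensorflow/core/protobuf"]),
]

for root, subdirs in jobs:
    for sub in subdirs:
        for proto in (root / sub).rglob("*.proto"):
            target = DEST / proto.relative_to(root)   # keep the directory layout
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(proto, target)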
3. Generate the Java sources
<!-- Add the dependencies to the Maven project -->
    <properties>
        <grpc.version>1.20.0</grpc.version>
    </properties>
    
    <dependencies>
        <!-- gRPC protobuf client -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-protobuf</artifactId>
            <version>${grpc.version}</version>
        </dependency>
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-stub</artifactId>
            <version>${grpc.version}</version>
        </dependency>
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-netty-shaded</artifactId>
            <version>${grpc.version}</version>
        </dependency>
        <dependency>
            <groupId>com.google.api.grpc</groupId>
            <artifactId>proto-google-common-protos</artifactId>
            <version>1.0.0</version>
        </dependency>
    </dependencies>
4. Install the protoc tool
brew install protobuf   (does not work on Windows, which has no brew command; on Windows you can download a prebuilt protoc binary from the protobuf GitHub releases page instead)
5.1 Compile the proto files with the Maven plugin
<!-- Add the build extension and compile plugin via Maven (mind the versions) -->
<build>
        <extensions>
            <extension>
                <groupId>kr.motd.maven</groupId>
                <artifactId>os-maven-plugin</artifactId>
                <version>1.6.2</version>
            </extension>
        </extensions>
        <plugins>
            <plugin>
                <groupId>org.xolstice.maven.plugins</groupId>
                <artifactId>protobuf-maven-plugin</artifactId>
                <version>0.6.1</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>compile-custom</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <checkStaleness>true</checkStaleness>
                    <protocArtifact>com.google.protobuf:protoc:3.6.1:exe:${os.detected.classifier}</protocArtifact>
                    <pluginId>grpc-java</pluginId>
                    <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
                </configuration>
            </plugin>
        </plugins>
</build>
// Run this command from the project root
mvn protobuf:compile   (this failed here with a version conflict that was not resolved)
After compilation, a new_old folder appears under $PROJECT_ROOT/src/main/resources; move the ./org/tensorflow and ./tensorflow/serving folders inside it into $PROJECT_ROOT/src/main/java.
Since the compile step failed here, the previously generated files from the repository below were copied directly into the project instead.
Reference: https://github.com/junwan01/tensorflow-serve-client/tree/master/target/generated-sources/protobuf

(Screenshot: the copied proto files)

5.2 Compile the proto files manually
Manual compilation is more troublesome than the plugin approach, but it produces static sources that can be checked into the project instead of being regenerated on every build (not tried here).
// grpc-java repository: https://github.com/grpc/grpc-java
$ cd $SRC
$ git clone https://github.com/grpc/grpc-java.git
Cloning into 'grpc-java'...
remote: Enumerating objects: 166, done.
remote: Counting objects: 100% (166/166), done.
remote: Compressing objects: 100% (121/121), done.
remote: Total 84096 (delta 66), reused 92 (delta 25), pack-reused 83930
Receiving objects: 100% (84096/84096), 31.18 MiB | 23.14 MiB/s, done.
Resolving deltas: 100% (38843/38843), done.

$ cd grpc-java/compiler/
$ ../gradlew java_pluginExecutable
$ ls -l build/exe/java_plugin/protoc-gen-grpc-java
// Run this shell script to compile the protobuf files
export SRC=~/code/TFS_source/
export PROJECT_ROOT=~/java/JavaClient/
cd $PROJECT_ROOT/src/main/proto/
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/example/*.proto
# append by wangxiao
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/core/logging.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/stream_executor/dnn.proto

protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/apis/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/config/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/util/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow_serving/sources/storage_path/*.proto

# the following 3 cmds will generate extra *Grpc.java stub source files in addition to the regular protobuf Java source files.
# The output grpc-java files are put in the same directory as the regular java source files.
# note the --plugin option uses the grpc-java plugin file we created in step 1.
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/protobuf/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/lib/core/*.proto
protoc --java_out $PROJECT_ROOT/src/main/java --proto_path ./ ./tensorflow/core/framework/*.proto

protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/prediction_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/model_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
protoc --grpc-java_out $PROJECT_ROOT/src/main/java --java_out $PROJECT_ROOT/src/main/java --proto_path ./ tensorflow_serving/apis/session_service.proto --plugin=protoc-gen-grpc-java=$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java
If everything runs correctly, two new folders, /org/tensorflow and /tensorflow/serving, should appear under $PROJECT_ROOT/src/main/java/, and the compilation step is done.
6. JavaClient

1. Reference implementation

package client;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import tensorflow.serving.Model;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import java.util.ArrayList;
import java.util.List;

public class FastTextTFSClient {

    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        String host = "127.0.0.1";
        int port = 8500;
        // the model's name.
        String modelName = "fastText";
        int seqLen = 50;

        // assume this model takes input of free text, and make some sentiment prediction.
        List<Integer> intData = new ArrayList<Integer>();
        for(int i=0; i < seqLen; i++){
            intData.add(i);
        }

        // create a channel for gRPC
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

        // create a modelspec
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setSignatureName("fastText_sig_def");

        Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
        builder.setModelSpec(modelSpecBuilder);

        // create the input TensorProto and request
        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_INT32);
        for (Integer intDatum : intData) {
            tensorProtoBuilder.addIntVal(intDatum);
        }
        // build input TensorProto shape
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(seqLen));
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
      
        TensorProto tp = tensorProtoBuilder.build();
        builder.putInputs("input_x", tp);
        Predict.PredictRequest request = builder.build();
        
        // get response
        Predict.PredictResponse response = stub.predict(request);
    }
}

2. My code

1. Use the MNIST dataset as an example and call the service from the Java client
2. Write code to load (read) the MNIST dataset
<!-- Add this dependency to the pom file -->
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.7.0</version>
</dependency>
// Mnist.java

package data;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
 * @Title:Mnist
 * @Package:com.linjie.client
 * @Description:
 * @author:done
 * @date:2021/8/12 21:36
 */

public class Mnist {

    public static class Data {
        public byte[] data;
        public int label;
        public float[] input;
        public float[] output;
    }

    public static void main(String[] args) throws Exception {
        Mnist mnist = new Mnist();
        mnist.load();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(60000);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            File output = new File(i + "_" + d.label + ".png");
            if (!output.exists()) {
                output.createNewFile();
            }
            ImageIO.write(img, "png", output);
        }
    }

    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

    public void shuffle() {
        Random rand = new Random();
        for (int i = 0; i < trainingSet.length; i++) {
            int x = rand.nextInt(trainingSet.length);
            Data d = trainingSet[i];
            trainingSet[i] = trainingSet[x];
            trainingSet[x] = d;  // complete the swap
        }
    }

    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    public Data getTestData(int idx) {
        return testSet[idx];
    }

    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    public void load() {
        trainingSet = load("D:\\dowl\\mnist_dataset\\mnist_dataset\\train-images-idx3-ubyte.gz", "D:\\dowl\\mnist_dataset\\mnist_dataset\\train-labels-idx1-ubyte.gz");
        testSet = load("D:\\dowl\\mnist_dataset\\mnist_dataset\\t10k-images-idx3-ubyte.gz", "D:\\dowl\\mnist_dataset\\mnist_dataset\\t10k-labels-idx1-ubyte.gz");
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    private float[] labelToOutput(byte label) {
        float[] o = new float[10];
        o[label] = 1;
        return o;
    }

    private float[] dataToInput(byte[] b) {
        float[] d = new float[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.f;
        }
        return d;
    }

    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)));) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)));) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}
// TestClient.java

package client;
import data.Mnist;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

/**
 * @Title:TestClient
 * @Package:client
 * @Description:
 * @author:done
 * @date:2021/8/16 21:49
 */
public class TestClient {

    public static void main(String[] args) throws Exception {
        String host = "192.168.110.100";
        int port = 8500;
        // the model's name.
        // Inspect the model with: saved_model_cli show --dir model_dir_path --all
        String modelName = "saved_model";
        
        // Instantiate the MNIST loader
        Mnist mnist = new Mnist();
        mnist.load();

        /************************** Single-image classification *******************************/
        
        Mnist.Data testData = mnist.getTestData(0);  // take the first test image
        float[] x = testData.input;  // flattened pixel values of a single image
        int seqLen = 784;  // input size (28 * 28)
        System.out.println("Ground-truth label of data[0]: " + testData.label);

        /************************** Single-image classification *******************************/
        /************************** Batch (multiple images) classification *******************************/

//        ArrayList<float[]> X = new ArrayList<float[]>();
//        Mnist.Data[] data = mnist.getTestSlice(0, 10);
//        int seqLen = 784 * data.length;
//        System.out.print("Ground-truth labels of data[0-10]: ");
//        for (int i=0; i<data.length; i++){
//            X.add(data[i].input);
//            System.out.print(data[i].label + " ");
//        }
        
        /************************** Batch (multiple images) classification *******************************/

        // create a channel for gRPC
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

        // create a modelspec
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setSignatureName("serving_default");

        Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
        builder.setModelSpec(modelSpecBuilder);

        // create the input TensorProto and request
        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);

        /************************** Single-image classification *******************************/
        
        for (Float intDatum : x) {
            // add the input values
            tensorProtoBuilder.addFloatVal(intDatum);
        }
        
        /************************** Single-image classification *******************************/
        /************************** Batch (multiple images) classification *******************************/
        
//        for (float[] temp: X) {
//            float[] input = temp;
//            for (Float intDatum : input) {
//                tensorProtoBuilder.addFloatVal(intDatum);
//            }
//        }
        
        /************************** Batch (multiple images) classification *******************************/

        // build input TensorProto shape
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(seqLen));
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        TensorProto tp = tensorProtoBuilder.build();
        builder.putInputs("args_0", tp);  // input name args_0 (from the SavedModel signature)
        Predict.PredictRequest request = builder.build();

        System.out.println("******************* 请求响应 *******************");
        // get response
        Predict.PredictResponse response = stub.predict(request);
        // class-probability list (output name output_1 from the SavedModel signature)
        List<Float> pro = response.getOutputsMap().get("output_1").getFloatValList();
        // argmax over the probabilities gives the predicted class
        int pre_y = pro.indexOf(pro.stream().max((o1, o2) -> o1.compareTo(o2)).get());
        System.out.println("Predicted class of data[0]: " + pre_y);
    }

    static private byte[] loadTensorflowModel(String path){
        try {
            return Files.readAllBytes(Paths.get(path));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    static private Tensor<Double> covertArrayToTensor(double inputs[]){
        return Tensors.create(inputs);
    }
}
