当前位置:   article > 正文

模型手把手系列开篇 之 python、spark 和java 生成TFrecord_spark-tensorflow-connector_2.11-1.14.0.jar

spark-tensorflow-connector_2.11-1.14.0.jar

模型手把手系列开篇 之 python、spark 和java 生成TFrecord


文章源码下载地址:点我下载http://inf.zhihang.info/resources/pay/7692.html

书接上文,我们的 图算法十篇 之 图机器学习系列文章总结 已经完结, 接下来 我们 将开始 从零开始 一点一点 的用 tensorflow 实现 一些 经典的 模型,除了 和大家 一起学习 之外,也是为了 可以帮助自己 对 过去学习过的 知识 做一些系统化 的 总结与回顾 ,进行 查漏补缺

接下来,文章里的很多代码,我们会使用 notebook远程访问pyspark集群, 算法工具神器重磅推荐 文章里介绍的 notebook 工具 进行介绍,而 部分 java 与 scala spark 代码 则 均是目前 我在自己 mac 上搭建的 单机版环境 编写的,如有 任何环境问题,欢迎 在 算法全栈之路 的公众号上 和我联系 ~

模型手把手系列 计划 主要 围着 tensorflow 实现模型 的 流程展开,计划中会涵盖 训练数据的生成数据的读入特征的处理模型结构的搭建损失函数的设计序列建模经典模型的实现 等模块展开,中间 很多内容 可能我也会去 查找 很多资料与源码 ,希望 能够 真正起到 总结自己学习过 的知识、对 各位算法工程师 们在 工作学习面试 等过程中 有所帮助 的 作用吧 !

闲言少叙,本文主要先从 模型训练 的 上游 数据生成开始讲起,主要介绍 使用 python 、spark( pyspark/scala spark ) 、Java 、tfrecorder 等这 4 种方式 生成 tfrecord 的过程 以及使用 python 解码 tfrecord文件的 过程,下面 让 我们 开始正文 吧~ go go go !!!


(1) tensorflow 模型训练数据来源简介

书接上文,我们知道:tensorflow 训练 所 需要的 上游数据,在 数据量比较小 的时候,我们可以用 python 的 pandas 或则 numpy 等方法 直接 在单机PC上 读取数据 然后 喂给模型 ,这种 模型 的 文件类型 可以是 本地的 txt 或则 csv 等格式。当 数据量比较大 的时候,我们通常 将 数据放在 集群 hdfs 上,也可以 保存成 csv 或 txt 的格式,然后 训练的时候去 进行 分布式并行读取

TFRecords 是 TensorFlow 官方推荐 和 支持的二进制文件格式,其对于 tensorflow 非常友好,其对于 特征列多 的数据 存储占用空间 更小。当 数据量特别大,且 io 读取数据成为 模型训练速度 的 瓶颈、甚至 有时候 gpu 的 利用率时高时低 的时候,这个时候 我们可以 将我们 的 数据 保存成 tfrecord 的格式。 这同时也 对应着 tfrecord 的一些优点: 读取速度快、占用空间少、支持并行读取等。这里 我们 就不再对 tfrecord 文件 生成的理论进行 展开说明了,感兴趣 的 同学可以下去 自己搜索资料 哈 ~

虽然本文 是 介绍 tfrecord 的 数据格式,但是我们 选用模型训练数据的数据格式 的时候,也 不一定非要运用 tfrecord 。很多时候 我们 训练模型数据量 不是很大,并且 单机内存 完全可以 hold住 所有 的 数据,而 我们 对 模型 的 训练速度 也 没有那么高要求,这个时候 普通的 csv 和 txt 等格式 简单直接,便于 查看 数据, 也可以 作为 我们的 首选 ~

本文 主要是 介绍 多种方式生成tfrecord 格式的数据,本身就是 偏向于工程的 ,理论性 没那么强,我们 直接开始 看 代码 吧 ~


(2)代码时光

本文我们主要介绍 使用 python 、spark( pyspark/scala spark ) 、Java 、tfrecorder 等 这 4 种方式 生成 tfrecord 的过程 以及 使用python 解码 tfrecord文件 的 过程 ,下面就让我们逐一开始介绍吧 ~

因为 本文的 代码 涉及 多种语言 ,这里我们 对各个模块 分别导包, 可能有冗余的地方,读者 可以 自行进行 区分,对于 代码的可读性 应该 无影响。

(2.1) 数据准备

看代码吧~

@ 欢迎关注作者公众号 算法全栈之路

import pandas as pd
raw_df = pd.DataFrame([[28,12.1,'male',"1#2",1], [30,8.7, 'female',"3#4#5",0], [32,24.6,'female',"6#7#8#9#10",1]], columns=['age', 'price','sex','click_list','label'])
print(raw_df)
raw_df.to_csv("./raw_df.csv",sep='\t',index=False,header=None)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

这里 数据类型 我们分别选择了 搜广推算法 用的最多的 int 型、float型、categroy 型、seq 序列类型特征 以及 label 这几列 数据 用来生成 tfrecord,如果有 其他类型 的 特征 同理 可得。


(2.2) python生成 tfrecord 数据
@ 欢迎关注作者公众号 算法全栈之路

# 文件路径 
intput_csv_file = "./raw_df.csv" 
intput_csv_file = "./py_tf_record" 

# 生成整数型的属性 
def _int64_feature(value): 
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
# 生成浮点数类型的属性 
def _float_feature(value): 
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
# 生成字符串型的属性 
def _bytes_feature(value): 
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
# 生成序列类型的特征
def _int64list_feature(value_list):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))


def generate_tf_records(intput_file_path,out_file_path):
    
    with codecs.open("./raw_df.csv", "r", "utf-8") as raw_file:
        line_list=raw_file.readlines()
        
    print("line_list_len:",len(line_list))
    
    writer = tf.compat.v1.python_io.TFRecordWriter(out_file_path)
    for line in tqdm.tqdm(line_list):
        age = int(line.split("\t")[0])
        price = (float)(line.split("\t")[1])
        gender = line.split("\t")[2]
        click_list =  list(map(int, line.split("\t")[3].split("#")))
        label = int(line.split("\t")[4])
            
        example = tf.train.Example(features=tf.train.Features(
            feature={
                "age": _int64_feature(int(age)),
                "price": _float_feature(float(price)),
                "gender": _bytes_feature(gender.encode()),
                "click_list": _int64list_feature(click_list),
                "label": _int64_feature(label)
            }))
        # 写入一条tfrecord
        writer.write(example.SerializeToString())
    writer.close()

generate_tf_records(intput_csv_file,intput_csv_file)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

这里,我们选择了4种 极具典型 的、搜广推算法常用的特征类型 来进行说明。tfrecord 里有 examplefeature 的概念: example 是 protocol buffer 数据标准的实现,我们 可以认为 每个example 可以 是一条样本(当然也可以有多条样本)。一个 example 消息体 中包含了 一系列的 features ,而 features 里又包含有 featuer, 每一个feature 是一个 dict 形式 的 数据结构。

其中 要注意 的是: click_list 这个表示的是 用户的点击序列特征 ,长度 对于 每个用户 可能不同。我们可以在 这里 传入一个 列表封装到 tfrecord 对象 里去,然后 让 tensorflow 直接读取序列特征 。当然,我们 也可以 这里把 列表拼接成字符串 ,然后 tensorflow 读入进去 之后 再去split 得到序列,只是 模型 会 更耗费时间 而已。

另外需要注意的是 value= 后面 接的是数组,也可以是单个元素。如果 你写的 代码有报 数据格式问题 的话,这里 可能需要重点看下 然后 作出 调整。

这里要 推荐一下 codecs 这个 python 包,其对于 python读写文件 格式的编码转换 非常友好,当读写 数据格式 兼容 会出现 bug 的时候,强烈推荐 codecs 哦。


(2.3) spark 生成 tfrecord 数据( scala spark + pyspark)

书接上文,在 很多时候 数据量比较少 的话,我们可以 用 上面介绍 的 单机版 python 来生成 tfrecord 文件,但是 我们上面 也介绍了: 数据量小的时候,内存足够,用啥tfrecocrd啊,直接上 csv等不香吗? 数据量大 的时候,就得靠 我们 这里 介绍 的 spark 来生成 tfrecord 了,亲测速度快了十数倍不止!

那 我们上面 介绍的 python 单机版生成tfrecord 就 无用武之地 了吗? 当然 不是,天生它才 必有用! 我们 可以 在 开发代码 并 进行流程测试 的时候用 单机版python 去 生成测试 ,保证 整个开发 流程 的 流畅,最后 要大规模跑数 进行 实验 的 时候,改用 本小节介绍 的 spark版本 的 方法 来 提高效率 , 两者结合 简直 perfert !!!

(2.3.1) scala spark 生成 tfrecord

因为 scala 和 java 均是跑在 虚拟机jvm 的 语言,在 maven 工程里 是 可以 混合编译 互相调用 的。 要想使用 spark 直接生成 tfrecord ,需要用到 google 提供的 spark 和 tensorflow 交互的包

pom.xml 里导入这个包就可以

@ 欢迎关注作者公众号 算法全栈之路

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>spark-tensorflow-connector_2.11</artifactId>
            <version>1.15.0</version>
        </dependency>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

然后下面是我提供了一个 基于scala spark 生成 tfrecord 的demo ,中间 的 环境 是我 单机版的spark ,可能你 用的时候 这里 需要微调,非常简单,自己 去适配 下吧~

@ 欢迎关注作者公众号 算法全栈之路

package zmt_demo.model_sbs

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object Demo {

  def main(args: Array[String]) {
    val sparkConf = new SparkConf()
      // .registerKryoClasses(Array(classOf[XgbScoreRow]))
      // 调节长数据本地化时间
      .setMaster("local[*]")
      .set("spark.locality.wait", "10")
      .set("spark.sql.orc.enabled", "false")

    val sparkSession = SparkSession.builder()
      .appName("scala spark generate tfrecord")
      .config(sparkConf)
      .config("spark.kryoserializer.buffer.max", "1024m")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("hive.exec.dynamic.partition.mode", "nonstrict")
      .enableHiveSupport()
      .getOrCreate()

    val demo_df = sparkSession.read
      .option("inferSchema", "false") //是否自动推到内容的类型
      .option("delimiter","\t")  //分隔符,默认为 ,
      .csv("/Users/dhl/Desktop/notebook_all/模型手把手系列/raw_df.csv")
      .toDF("age","price","sex","click_list","label")
      .withColumn("click_list",split(col("click_list"),"#"))

    demo_df.printSchema()
    demo_df.show(2,false)

    val savedPath = "/Users/dhl/Desktop/notebook_all/模型手把手系列/scala_spark_tfcord"

    demo_df.write
      .mode("overwrite")
      .format("tfrecords")
      .option("recordType", "Example")
      .save(savedPath)
  }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

使用单机版的 spark,我们 在自己 mac 就能进行 业务流程代码 的 调试哈,不用 在 链接spark集群 就可以 完成 spark 代码的调试,当然 数据 是 需要 我们 本地伪造 的 ~

对于一些 使用spark RDD 接口 较多 的同学,可以 先将 RDD 转 dataframe ,然后 在 生成tfrecord 哦 !


(2.3.2) pyspark 生成 tfrecord

目前 在 国内大厂,还是有 很多公司 的 算法团队 使用 pyspark 非常频繁 ,这里 我们 也提供下 pyspark 版本生成 tfrecord 的代码吧。

中间在用 spark-submit 提交 pyspark 脚本任务的时候,需要在最后参数列表里加上

--jars /Users/dhl/Desktop/notebook_all//spark-tensorflow-connector_2.11-1.15.0.jar

其实 作用和 maven 类似 ,和 上面一样 引入我们 的 Jar 包 。导完包后,就可以 写代码 提spark job 任务 了。

@ 欢迎关注作者公众号 算法全栈之路

import os
import sys
import findspark
findspark.init()
import os.path as path
import importlib

from pyspark import StorageLevel
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from py4j.protocol import Py4JJavaError
from pyspark.sql import functions as fun
from pyspark.sql.functions import col
from pyspark.sql import HiveContext
from pyspark.sql.functions import *
from pyspark.sql.functions import lit

import warnings 
warnings.filterwarnings("ignore")
# spark config setup
spark = SparkSession.builder.appName("pyspark-app") \
    .config("spark.submit.deployMode", "client")\
    .config('spark.yarn.queue', 'idm-prod')\
    .config("spark.kryoserializer.buffer.max", "1024m") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("hive.exec.dynamic.partition.mode", "nonstrict") \
    .enableHiveSupport()\
    .getOrCreate()

path="./pyspark_tfrecord"

pdf_values=raw_df.values.tolist()
pdf_columns=raw_df.columns.tolist()
spark_df = spark.createDataFrame(pdf_values,pdf_columns).persist(StorageLevel.MEMORY_AND_DISK)

spark_df.write.format("tfrecords").option("recordType", "Example").save(path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

从代码里可以看到,我们这里是使用 python 的pandas dataframe直接转的 pyspark 的 dataframe ,然后由 spark 的 dataframe 直接保存成 tfrecord 的格式。

其中需要注意的一点是: option("recordType", "Example") 这个地方的参数。 当然,对于 序列特征,我们也可以使用 SequenceExample 这个参数来生成。

但是 对于 序列特征,我们只要在 特征列 的 位置放入 列表元素 就 可以,tensorflow 读入 list 数据之后 再去转 序列特征 处理 也是可以的。

这里和上面一样,这里的 pyspark 方法 也可以和 上面 小节介绍 的 python方法 相互结合 使用,达到 pyspark + python 包来 生成 tfrecord 的目的,非常优秀!!!

这里 我就不在去 具体实现 了哈,但是 pyspark + python 自定义函数 与 scala + java 自定义函数的联合使用,可以说 是 灵活开发 的 典范之作 了 !


(2.4) java 生成 tfrecord 数据

书接上文,我们说了 Java 和 scala spark 代码可以 混合编译,然后进行 互相灵活调用 的,我们 这里 介绍的 Java 版本 的 生成tfrecord 的 函数与方法 ,也是可以 结合上面 介绍的 scala spark 方法,在 spark 的 map算子 调用 这里介绍 的 方法,达到 spark + java相结合的方法来生成 tfrecord 格式文件,对于 广大的 javaer 们,算是 非常友好 了。

要想用 Java 生成 tfrecord 数据,需要导入下面这 两个jar 包 ,其中一个 和 上面要用到 的 重复。

@ 欢迎关注作者公众号 算法全栈之路

       <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>spark-tensorflow-connector_2.11</artifactId>
            <version>1.15.0</version>
        </dependency>

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.5.0</version>
        </dependency>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

Java代码嘛,没说的,就是包多!! 导入就是了 。

@ 欢迎关注作者公众号 算法全栈之路

package demo;

import java.io.*;
import java.util.*;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.Int64List;
import org.tensorflow.example.*;
import org.tensorflow.spark.shaded.com.google.protobuf.ByteString;
import org.tensorflow.spark.shaded.org.tensorflow.hadoop.util.TFRecordWriter;


public class Generate_TFrecord_Demo {

    public static void main(String[] args) throws IOException {

        TFRecordWriter tf_write = new TFRecordWriter(new DataOutputStream(new FileOutputStream("/Users/dhl/Desktop/notebook_all/模型手把手系列/java_tfcord")));

        Map<String, Object> featureMap = new HashMap<>();
        featureMap.put("age", "20");
        featureMap.put("price", "15.5");
        featureMap.put("sex", "male");
        featureMap.put("click_list", Arrays.asList("1", "2", "3"));
        featureMap.put("label", "1");

        Map<String, Feature> inputFeatureMap = new HashMap<String, Feature>();

        for (String key : featureMap.keySet()) {
            Feature feature = null;

            if (key.equals("sex")) {
                BytesList.Builder byteListBuilder = BytesList.newBuilder();
                ByteString bytes = ByteString.copyFromUtf8((String) featureMap.get(key));
                byteListBuilder.addValue(bytes);
                feature = Feature.newBuilder().setBytesList(byteListBuilder.build()).build();
            } else if (key.equals("age")) {
                Int64List.Builder int64ListBuilder = Int64List.newBuilder();
                int64ListBuilder.addValue(Integer.parseInt(featureMap.get(key).toString()));
                feature = Feature.newBuilder().setInt64List(int64ListBuilder.build()).build();
            } else if (key.equals("price")) {
                FloatList.Builder floatListBuilder = FloatList.newBuilder();
                floatListBuilder.addValue(Float.parseFloat(featureMap.get(key).toString()));
                feature = Feature.newBuilder().setFloatList(floatListBuilder.build()).build();
            } else if (key.equals("click_list")) {
                List<String> stringList = (List<String>) featureMap.get(key);
                List<ByteString> byteStrings = new ArrayList<ByteString>();
                for (String s : stringList) {
                    byteStrings.add(ByteString.copyFromUtf8(s));
                }
                BytesList.Builder byteListBuilder = BytesList.newBuilder();
                byteListBuilder.addAllValue(byteStrings);
                feature = Feature.newBuilder().setBytesList(byteListBuilder.build()).build();
            }

            if (feature != null) {
                inputFeatureMap.put(key, feature);
            }
        }

        Features features = Features.newBuilder().putAllFeature(inputFeatureMap).build();
        Example example = Example.newBuilder().setFeatures(features).build();
        System.out.println(example.getFeatures());

        // java 版本 tfrecord 生成与写入
        tf_write.write(example.toByteArray());
    }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

这里,我们 把文件 写入了 我 自己本机 的 路径,也选择了 几个常用 的 特征类型 来使用 Java生成tfrecord 文件,自己 去 按需求更改 吧。


(2.5)python 的 tfrecorder 生成tfrecord

在我 最后 开始写小作文 做总结 的时候,偶然发现了 这个python 包 : tfrecorder ,我们 可以使用pip install tfrecorder来 进行安装。

虽然也 是python 单机版 的包,但是这个包可以 不用写代码 显式的 打开 csv 文件 进行 文件转换,非常强大了!

下面的 两种方式 均是 使用google 开源的tfrecorder 这个包工具的。

闲言少叙,看代码吧~

(2.5.1) csv 直接转tfrecord

实现的功能如题,单机版python神器啊!

@ 欢迎关注作者公众号 算法全栈之路

import tfrecorder

tfrecorder.create_tfrecords(
    input_data='./raw_df.csv',
    output_dir='./csv_tfrecord')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

(2.5.2) pandas dataframe 直接转 tfrecord
@ 欢迎关注作者公众号 算法全栈之路

import pandas as pd
import tfrecorder

raw_df.tensorflow.to_tfr(output_dir='./pd_tfrecord')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

这个 工具 有一个 ,就是 安装的时候 依赖 比较多,会 出现 包冲突 的情况,很难缠。看说明 好像 google 已经 放弃维护 这个包 了,最后 更新时间 在2020年 ?

不管了,上面 介绍的方法足够多 ,总有一种 姿势 可以满足 你。


(2.6) 解码 tfrecord 文件

这里 要 重点介绍 下: 因为 tfrecord 是 二进制文件 ,我们 生成了之后 如何 查看里面数据结构 呢?

简单! 用下面的方法就可以了 ,看代码 ~

@ 欢迎关注作者公众号 算法全栈之路

import tensorflow.compat.v1 as tf

def getTFRecordFormat(files):
    with tf.Session() as sess:
        # 加载TFRecord数据
        ds = tf.data.TFRecordDataset(files)
        ds = ds.batch(1)
        ds = ds.prefetch(buffer_size=2)
        iterator = ds.make_one_shot_iterator()
        # 为了加快速度,仅仅简单拿一组数据看下结构
        batch_data = iterator.get_next()
        while True:
                res = sess.run(batch_data)
                for serialized_example in res:
                    example_proto = tf.train.Example.FromString(serialized_example)
                    features = example_proto.features

                    for key in features.feature:
                        feature = features.feature[key]
                        if len(feature.bytes_list.value) > 0:
                            ftype = 'bytes_list'
                            fvalue = feature.bytes_list.value

                        if len(feature.float_list.value) > 0:
                            ftype = 'float_list'
                            fvalue = feature.float_list.value

                        if len(feature.int64_list.value) > 0:
                            ftype = 'int64_list'
                            fvalue = feature.int64_list.value
                        result = '{0} : {1} {2} {3}'.format(key, ftype, len(fvalue),fvalue)
                        print(result)
                    break
                    print("*"*20)
                break

# getTFRecordFormat('./pd_tfrecord')
getTFRecordFormat('./py_tf_record')
# getTFRecordFormat('./pyspark_tfrecord/part-r-00007')
# getTFRecordFormat('./scala_spark_tfcord/part-r-00000')
# getTFRecordFormat('./java_tfcord')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

注意,这里 我们 使用 的是 tensorflow 1.x 的 版本 ~

最后 tfrecord文件解析 出来 在 我们 的 demo 式例中 长这个样子:

到这里,模型手把手系列开篇 之 python、spark 和java 生成TFrecord 的 全文就写完了。 在本文里,我们 提供了 众多生成 tfrecord 的 方法与工具,代码均可以完美跑成功,总有一款适合你,希望可以对你有参考作用 ~


码字不易,觉得有收获就动动小手转载一下吧,你的支持是我写下去的最大动力 ~

更多更全更新内容 : 算法全栈之路

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/123777
推荐阅读
相关标签
  

闽ICP备14008679号