赞
踩
MindSpore易点通·精讲系列–数据集加载之TFRecordDataset
本文开发环境
本文内容摘要
TFRecord
格式是TensorFlow
官方设计的一种数据格式。
TFRecord格式是一种用于存储二进制记录序列的简单格式,该格式能够更好地利用内存,内部包含多个tf.train.Example,在一个Example消息体中包含一系列的tf.train.Feature属性,而每一个Feature是一个key-value键值对,其中key是string类型,value的取值有三种:
bytes_list:可以存储string和byte两种数据类型
float_list:可以存储float(float32)和double(float64)两种数据类型
int64_list:可以存储bool, enum, int32, uint32, int64, uint64数据类型
上面简单介绍了TFRecord的知识,下面我们就要进入正题,来谈谈MindSpore中对TFRecord格式的支持。
老传统,先来看看官方对API的描述。
下面对主要参数做简单介绍:
本文使用的是
THUCNews
数据集,如果需要将该数据集用于商业用途,请联系数据集作者。
由于下文需要用到TFRecord
数据集来做加载,本节先来生成TFRecord
数据集。对TensorFlow
不了解的读者可以直接照搬代码即可。
生成TFRecord
代码如下:
"""Build TFRecord shards from a directory of per-class UTF-8 text files.

The input directory must contain one sub-directory per class, each holding
text files. Every sufficiently long file becomes one ``tf.train.Example``
with two int64 features: ``input`` (character ids, truncated or padded to a
fixed length) and ``class`` (the class index).
"""
import codecs
import os
import re
from collections import Counter

import six
import tensorflow as tf

# Ids 0 and 1 are reserved, so real vocabulary entries start at 2.
PAD_ID = 0
UNK_ID = 1
NUM_RESERVED_IDS = 2

# Compiled once; a raw string avoids the invalid "\s" escape warning.
_WHITESPACE_RE = re.compile(r"\s+")


def _as_list(values):
    """Return ``values`` unchanged if it is a tuple/list, else wrap it in a list."""
    if isinstance(values, (tuple, list)):
        return values
    return [values]


def _int64_feature(values):
    """Returns a TF-Feature of int64s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=_as_list(values)))


def _float32_feature(values):
    """Returns a TF-Feature of float32s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    return tf.train.Feature(float_list=tf.train.FloatList(value=_as_list(values)))


def _bytes_feature(values):
    """Returns a TF-Feature of bytes.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=_as_list(values)))


def convert_to_feature(values):
    """Convert to TF-Feature based on the type of the first element in values.

    Args:
        values: A scalar or a non-empty list/tuple of int, float or bytes.

    Returns:
        A TF-Feature.

    Raises:
        ValueError: If ``values`` is empty or its first element has an
            unsupported type.
    """
    values = _as_list(values)
    if not values:
        # Without a first element there is nothing to infer the type from.
        raise ValueError("cannot infer the feature type of an empty value list")

    first = values[0]
    if isinstance(first, int):
        return _int64_feature(values)
    if isinstance(first, float):
        return _float32_feature(values)
    if isinstance(first, bytes):
        return _bytes_feature(values)
    raise ValueError("feature type {0} is not supported now !".format(type(first)))


def dict_to_example(dictionary):
    """Converts a dictionary of feature name -> value(s) to a tf.Example."""
    features = {}
    for k, v in six.iteritems(dictionary):
        features[k] = convert_to_feature(values=v)
    return tf.train.Example(features=tf.train.Features(feature=features))


def get_txt_files(data_dir):
    """Collect the text files found one level below ``data_dir``.

    Args:
        data_dir: Directory whose immediate sub-directories are the classes.

    Returns:
        A tuple ``(cls_txt_dict, txt_file_list)``: a dict mapping each class
        (sub-directory) name to its file paths, and the flat list of all
        file paths. Classes and files are both visited in sorted order so
        that repeated runs produce the same result.

    Raises:
        ValueError: If ``data_dir`` is not an existing directory.
    """
    if not os.path.isdir(data_dir):
        # next(os.walk(...)) would otherwise leak a bare StopIteration.
        raise ValueError("data directory: {} not exists!".format(data_dir))

    cls_txt_dict = {}
    txt_file_list = []
    # get files list and class files list.
    sub_data_name_list = sorted(next(os.walk(data_dir))[1])
    for sub_data_name in sub_data_name_list:
        sub_data_dir = os.path.join(data_dir, sub_data_name)
        data_name_list = sorted(next(os.walk(sub_data_dir))[2])
        data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
        cls_txt_dict[sub_data_name] = data_file_list
        txt_file_list.extend(data_file_list)
        print("{}: {}".format(sub_data_name, len(data_file_list)), flush=True)
    print("total: {}".format(len(txt_file_list)), flush=True)
    return cls_txt_dict, txt_file_list


def get_txt_data(txt_file):
    """Read a UTF-8 text file and collapse every whitespace run to one space."""
    with codecs.open(txt_file, "r", "UTF8") as fp:
        txt_content = fp.read()
    return _WHITESPACE_RE.sub(" ", txt_content)


def build_vocab(txt_file_list, vocab_size=7000):
    """Build a character -> id mapping from the most frequent characters.

    Args:
        txt_file_list: Paths of the text files to count characters in.
        vocab_size: Upper bound on the vocabulary size, including the two
            reserved ids (pad and unk).

    Returns:
        A dict mapping each kept character to an id starting at 2, in
        descending order of frequency.
    """
    counter = Counter()
    for txt_file in txt_file_list:
        # Updating a Counter with a str counts its individual characters.
        counter.update(get_txt_data(txt_file))

    real_vocab_size = min(vocab_size, len(counter) + NUM_RESERVED_IDS)
    # pad_id is 0, unk_id is 1, so the first real character must get id 2;
    # starting at 1 would make the most frequent character collide with unk.
    most_common = counter.most_common(real_vocab_size - NUM_RESERVED_IDS)
    vocab_dict = {word: ix + NUM_RESERVED_IDS for ix, (word, _) in enumerate(most_common)}
    print("real vocab size: {}".format(real_vocab_size), flush=True)
    print("vocab dict:\n{}".format(vocab_dict), flush=True)
    return vocab_dict


def _shard_paths(tfrecord_dir, prefix, num_shards, start_fid):
    """Return the shard file paths ``<prefix>_<fid:04d>.tfrecord`` for one split."""
    return [
        os.path.join(tfrecord_dir, "{}_{:04d}.tfrecord".format(prefix, fid))
        for fid in range(start_fid, start_fid + num_shards)
    ]


def make_tfrecords(
        data_dir, tfrecord_dir, vocab_size=7000,
        min_seq_length=10, max_seq_length=800,
        num_train=8, num_test=2, start_fid=0):
    """Convert the text files under ``data_dir`` into train/test TFRecord shards.

    Args:
        data_dir: Directory with one sub-directory of text files per class.
        tfrecord_dir: Output directory; created if it does not exist.
        vocab_size: Maximum vocabulary size passed to ``build_vocab``.
        min_seq_length: Texts shorter than this are skipped.
        max_seq_length: Texts are truncated, or padded with the pad id, to
            exactly this many ids.
        num_train: Number of training shard files.
        num_test: Number of test shard files.
        start_fid: Index used in the name of the first shard file.
    """
    # get txt files
    cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
    # map word to id
    vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size=vocab_size)
    # map class to id
    class_dict = {class_name: ix for ix, class_name in enumerate(cls_txt_dict)}

    os.makedirs(tfrecord_dir, exist_ok=True)

    train_writers = []
    test_writers = []
    num_samples = 0
    num_train_samples = 0
    num_test_samples = 0
    try:
        for path in _shard_paths(tfrecord_dir, "train", num_train, start_fid):
            train_writers.append(tf.io.TFRecordWriter(path))
        for path in _shard_paths(tfrecord_dir, "test", num_test, start_fid):
            test_writers.append(tf.io.TFRecordWriter(path))

        for class_name, class_file_list in cls_txt_dict.items():
            class_id = class_dict[class_name]
            num_class_pass = 0
            for txt_file in class_file_list:
                txt_data = get_txt_data(txt_file=txt_file)
                if len(txt_data) < min_seq_length:
                    num_class_pass += 1
                    continue
                txt_data = txt_data[:max_seq_length]

                word_ids = [vocab_dict.get(word, UNK_ID) for word in txt_data]
                word_ids.extend([PAD_ID] * (max_seq_length - len(word_ids)))
                serialized = dict_to_example({"input": word_ids, "class": class_id}).SerializeToString()

                num_samples += 1
                # Every 10th kept sample goes to the test split; shards
                # within a split are filled round-robin.
                if num_samples % 10 == 0:
                    num_test_samples += 1
                    test_writers[num_test_samples % num_test].write(serialized)
                else:
                    num_train_samples += 1
                    train_writers[num_train_samples % num_train].write(serialized)
            print("{} pass: {}".format(class_name, num_class_pass), flush=True)
    finally:
        # Close every writer that was opened, even if conversion failed.
        for writer in train_writers + test_writers:
            writer.close()

    print("num samples: {}".format(num_samples), flush=True)
    print("num train samples: {}".format(num_train_samples), flush=True)
    print("num test samples: {}".format(num_test_samples), flush=True)


def main():
    """Run the conversion; replace both placeholder paths before use."""
    data_dir = "{your_data_dir}"
    tfrecord_dir = "{your_tfrecord_dir}"
    make_tfrecords(data_dir=data_dir, tfrecord_dir=tfrecord_dir)


if __name__ == "__main__":
    main()
将以上代码保存到文件make_tfrecord.py
,运行命令:
注意:需要替换
data_dir
和tfrecord_dir
为个人目录。
python3 make_tfrecord.py
使用tree
命令查看生成的TFRecord
数据目录,输出内容如下:
.
├── test_0000.tfrecord
├── test_0001.tfrecord
├── train_0000.tfrecord
├── train_0001.tfrecord
├── train_0002.tfrecord
├── train_0003.tfrecord
├── train_0004.tfrecord
├── train_0005.tfrecord
├── train_0006.tfrecord
└── train_0007.tfrecord
0 directories, 10 files
有了3
中的TFRecord
数据集,下面来介绍如何在MindSpore
中使用该数据集。
首先来看看对于参数schema
不指定,即采用默认值的情况下,能否正确读取数据。
代码如下:
"""Load TFRecord shards with MindSpore and print the first record."""
import os

# mstype and Schema are not used by this first variant; they are imported
# here because the later variants of load_tfrecord need them.
from mindspore.common import dtype as mstype
from mindspore.dataset import Schema
from mindspore.dataset import TFRecordDataset


def get_tfrecord_files(tfrecord_dir, file_suffix="tfrecord", is_train=True):
    """Find the TFRecord shard files of one split under ``tfrecord_dir``.

    Args:
        tfrecord_dir: Directory that is walked recursively.
        file_suffix: Only paths ending with this suffix are kept.
        is_train: Select files whose name starts with "train" when True,
            otherwise files whose name starts with "test".

    Returns:
        The matching file paths, sorted.

    Raises:
        ValueError: If ``tfrecord_dir`` does not exist.
    """
    if not os.path.exists(tfrecord_dir):
        raise ValueError("tfrecord directory: {} not exists!".format(tfrecord_dir))

    file_prefix = "train" if is_train else "test"

    data_sources = []
    for parent, _, filenames in os.walk(tfrecord_dir):
        for filename in filenames:
            if not filename.startswith(file_prefix):
                continue
            tmp_path = os.path.join(parent, filename)
            if tmp_path.endswith(file_suffix):
                data_sources.append(tmp_path)
    # os.walk yields names in arbitrary order; sort so that reading with
    # shuffle=False visits the shards in the same order on every run.
    return sorted(data_sources)


def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Print the first training record, read without an explicit schema.

    Args:
        tfrecord_dir: Directory containing the TFRecord shard files.
        tfrecord_json: Accepted but not used by this variant, which
            deliberately leaves the dataset schema at its default.
    """
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    dataset = TFRecordDataset(dataset_files=tfrecord_files, shuffle=False)
    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        # Only the first record is needed to inspect the parsed layout.
        break


def main():
    """Run the loader; replace the placeholder paths before use."""
    tfrecord_dir = "{your_tfrecord_dir}"
    # Not passed on in this variant (None is passed instead); kept so the
    # JSON-schema variant of the call can use it.
    tfrecord_json = "{your_tfrecord_json_file}"
    load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=None)


if __name__ == "__main__":
    main()
代码解读:
将上述代码保存到文件load_tfrecord_dataset.py
,运行如下命令:
python3 load_tfrecord_dataset.py
输出内容如下:
可以看出能正确解析出之前保存在TFRecord内的数据,数据类型和数据维度解析正确。
{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18, 742, 330, 385, 999, 837, 56, 529, 1000, 260, 3, 171, 45, 7, 65, 136, 869, 211, 215, 443, 541, 3, 91, 1719, 636, 2, 424, 291, 16, 86, 31, 12, 211, 215, 443, 541, 999, 322, 128, 916, 102, 743, 136, 121, 298, 454, 2, 234, 225, 1, 136, 121, 298, 454, 100, 49, 22, 152, 70, 677, 806, 31, 1719, 636, 100, 25, 237, 2, 424, 1, 39, 100, 39, 71, 228, 385, 999, 837, 1, 171, 45, 136, 869, 211, 215, 443, 541, 1, 35, 68, 20, 149, 304, 31, 70, 677, 1, 106, 9, 308, 487, 869, 153, 597, 2, 523, 262, 184, 145, 57, 36, 158, 13, 69, 41, 1, 35, 68, 511, 402, 152, 469, 41, 617, 761, 50, 36, 144, 281, 26, 308, 487, 869, 153, 597, 4, 23, 208, 17, 121, 428, 646, 71, 8, 10, 47, 40, 87, 32, 413, 133, 9, 641, 159, 74, 144, 281, 26, 308, 487, 869, 153, 597, 2, 197, 447, 1, 91, 549, 202, 208, 17, 121, 558, 123, 2, 113, 203, 1, 419, 1024, 200, 154, 80, 16, 147, 64, 111, 208, 219, 136, 25, 6, 153, 597, 4, 160, 134, 16, 1, 167, 229, 1719, 636, 514, 2, 9, 7, 65, 321, 136, 869, 211, 215, 443, 541, 1, 514, 2, 69, 33, 13, 88, 80, 94, 294, 2, 308, 487, 869, 153, 597, 1, 39, 69, 33, 197, 57, 310, 335, 50, 94, 294, 2, 308, 487, 869, 153, 597, 1, 221, 13, 74, 337, 56, 499, 117, 836, 621, 488, 26, 94, 294, 1, 10, 5, 7, 10, 21, 973, 124, 492, 69, 33, 514, 218, 168, 117, 1, 82, 285, 148, 697, 2, 982, 298, 1535, 119, 743, 201, 1187, 4, 3, 136, 121, 298, 454, 103, 752, 31, 12, 1496, 762, 164, 2, 609, 6, 175, 83, 170, 257, 454, 963, 1, 149, 57, 136, 121, 298, 454, 62, 52, 87, 110, 257, 12, 34, 39, 2, 677, 1151, 1, 136, 121, 298, 454, 100, 49, 22, 138, 55, 39, 1, 752, 744, 184, 36, 169, 11, 561, 9, 1, 13, 74, 39, 62, 9, 308, 487, 869, 250, 321, 4, 23, 211, 215, 443, 541, 9, 641, 32, 900, 2586, 83, 1157, 165, 978, 97, 694, 837, 301, 22, 97, 694, 837, 9, 124, 492, 2, 720, 1341, 1, 35, 68, 32, 2294, 216, 2, 1, 106, 9, 13, 9, 20, 12, 25, 26, 973, 124, 1, 91, 9, 20, 12, 344, 36, 4, 23, 3, 167, 
229, 211, 215, 443, 541, 2, 283, 683, 9, 1719, 636, 1, 292, 278, 9, 641, 103, 32, 283, 683, 976, 944, 511, 316, 30, 178, 223, 795, 136, 164, 301, 22, 25, 172, 26, 18, 1102, 69, 41, 136, 869, 1, 35, 68, 184, 344, 285, 74, 1, 178, 223, 795, 136, 164, 13, 9, 6, 26, 1152, 285, 1, 163, 20, 35, 68, 184, 344, 165, 894, 74, 521, 96, 39, 1, 976, 944, 511, 316, 62, 9, 167, 149, 1, 1024, 1405, 164, 271, 454, 102, 743, 62, 9, 25, 278, 100, 2, 4, 23, 3, 136, 121, 298, 454, 103, 752, 31, 12, 145, 57, 442, 32, 401, 665, 14, 2, 432, 848, 808, 49, 22, 432, 848, 808, 30, 35, 68, 2, 116, 39, 57, 896, 6, 237, 1, 112, 9, 508, 922, 2, 83, 479, 1, 106, 35, 13, 382, 203, 39, 9, 641, 32, 96, 168, 19, 59, 117, 1, 62, 13, 382, 203, 39, 351, 37, 309, 641, 32, 309, 51, 1, 35, 13, 9, 102, 406, 621, 4, 23, 136, 121, 298, 454, 103, 110, 177, 1, 145, 57, 2, 211, 215, 443, 541, 1, 1171, 736, 2, 9, 14, 37, 83, 170, 1, 22, 35, 68, 2, 14, 37, 9, 173, 45, 1652, 136, 6, 57, 1652, 516, 565, 1, 35, 68, 151, 1171, 736, 2, 9, 14, 37, 1, 62, 513, 755, 57, 1652, 1, 91, 253, 71, 15, 45, 655, 15, 57, 896, 1, 35, 68, 13, 318, 165, 894, 4, 23, 3, 136, 121, 298, 454, 103, 34, 145, 57, 211, 215, 443, 541, 2, 878, 503, 516, 565, 304, 31, 648, 208, 49, 22, 83, 117, 147, 64, 219, 246, 12, 152, 66, 1, 1290, 455, 164, 154, 234, 36, 12, 1, 1000, 316, 164, 15, 998, 812, 1289, 112, 36, 12, 1, 1426, 201, 119, 1078, 319, 512, 71, 8, 182, 124, 238, 230, 123, 901, 1, 184, 222, 6, 87, 435, 71, 60, 20, 211, 215, 443, 541, 2, 6, 170, 1, 16, 94, 294, 475, 419, 2450, 9, 571, 11, 1, 63, 8, 7, 5, 5, 122, 1080, 35, 68, 12, 4, 846, 337, 61, 301, 701, 297, 39, 6, 539, 27, 135, 979, 1, 35, 166, 181, 90, 143])}
下面介绍,如何使用mindspore.dataset.Schema
来指定读取模型策略。
修改load_tfrecord
代码如下:
def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Print the first training record, parsed with a Schema built in code.

    Args:
        tfrecord_dir: Directory containing the TFRecord shard files.
        tfrecord_json: Accepted but not used by this variant.
    """
    shard_files = get_tfrecord_files(tfrecord_dir)

    # Column name -> shape; both columns are declared as int64.
    column_shapes = (("input", [800]), ("class", [1]))
    data_schema = Schema()
    for column_name, column_shape in column_shapes:
        data_schema.add_column(name=column_name, de_type=mstype.int64, shape=column_shape)

    dataset = TFRecordDataset(
        dataset_files=shard_files,
        schema=data_schema,
        shuffle=False,
    )
    first_record = next(iter(dataset.create_dict_iterator()), None)
    if first_record is not None:
        print(first_record, flush=True)
代码解读:
创建了Schema对象,并且指定了列名、列的数据类型和数据维度。
保存并再次运行文件load_tfrecord_dataset.py,输出内容如下:
可以看出能正确解析出之前保存在TFRecord内的数据,数据类型和数据维度解析正确。
{'input': Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18, 742, 330, 385, 999, 837, 56, 529, 1000, 260, 3, 171, 45, 7, 65, 136, 869, 211, 215, 443, 541, 3, 91, 1719, 636, 2, 424, 291, 16, 86, 31, 12, 211, 215, 443, 541, 999, 322, 128, 916, 102, 743, 136, 121, 298, 454, 2, 234, 225, 1, 136, 121, 298, 454, 100, 49, 22, 152, 70, 677, 806, 31, 1719, 636, 100, 25, 237, 2, 424, 1, 39, 100, 39, 71, 228, 385, 999, 837, 1, 171, 45, 136, 869, 211, 215, 443, 541, 1, 35, 68, 20, 149, 304, 31, 70, 677, 1, 106, 9, 308, 487, 869, 153, 597, 2, 523, 262, 184, 145, 57, 36, 158, 13, 69, 41, 1, 35, 68, 511, 402, 152, 469, 41, 617, 761, 50, 36, 144, 281, 26, 308, 487, 869, 153, 597, 4, 23, 208, 17, 121, 428, 646, 71, 8, 10, 47, 40, 87, 32, 413, 133, 9, 641, 159, 74, 144, 281, 26, 308, 487, 869, 153, 597, 2, 197, 447, 1, 91, 549, 202, 208, 17, 121, 558, 123, 2, 113, 203, 1, 419, 1024, 200, 154, 80, 16, 147, 64, 111, 208, 219, 136, 25, 6, 153, 597, 4, 160, 134, 16, 1, 167, 229, 1719, 636, 514, 2, 9, 7, 65, 321, 136, 869, 211, 215, 443, 541, 1, 514, 2, 69, 33, 13, 88, 80, 94, 294, 2, 308, 487, 869, 153, 597, 1, 39, 69, 33, 197, 57, 310, 335, 50, 94, 294, 2, 308, 487, 869, 153, 597, 1, 221, 13, 74, 337, 56, 499, 117, 836, 621, 488, 26, 94, 294, 1, 10, 5, 7, 10, 21, 973, 124, 492, 69, 33, 514, 218, 168, 117, 1, 82, 285, 148, 697, 2, 982, 298, 1535, 119, 743, 201, 1187, 4, 3, 136, 121, 298, 454, 103, 752, 31, 12, 1496, 762, 164, 2, 609, 6, 175, 83, 170, 257, 454, 963, 1, 149, 57, 136, 121, 298, 454, 62, 52, 87, 110, 257, 12, 34, 39, 2, 677, 1151, 1, 136, 121, 298, 454, 100, 49, 22, 138, 55, 39, 1, 752, 744, 184, 36, 169, 11, 561, 9, 1, 13, 74, 39, 62, 9, 308, 487, 869, 250, 321, 4, 23, 211, 215, 443, 541, 9, 641, 32, 900, 2586, 83, 1157, 165, 978, 97, 694, 837, 301, 22, 97, 694, 837, 9, 124, 492, 2, 720, 1341, 1, 35, 68, 32, 2294, 216, 2, 1, 106, 9, 13, 9, 20, 12, 25, 26, 973, 124, 1, 91, 9, 20, 12, 344, 36, 4, 23, 3, 167, 229, 211, 215, 443, 541, 2, 283, 683, 9, 1719, 636, 1, 
292, 278, 9, 641, 103, 32, 283, 683, 976, 944, 511, 316, 30, 178, 223, 795, 136, 164, 301, 22, 25, 172, 26, 18, 1102, 69, 41, 136, 869, 1, 35, 68, 184, 344, 285, 74, 1, 178, 223, 795, 136, 164, 13, 9, 6, 26, 1152, 285, 1, 163, 20, 35, 68, 184, 344, 165, 894, 74, 521, 96, 39, 1, 976, 944, 511, 316, 62, 9, 167, 149, 1, 1024, 1405, 164, 271, 454, 102, 743, 62, 9, 25, 278, 100, 2, 4, 23, 3, 136, 121, 298, 454, 103, 752, 31, 12, 145, 57, 442, 32, 401, 665, 14, 2, 432, 848, 808, 49, 22, 432, 848, 808, 30, 35, 68, 2, 116, 39, 57, 896, 6, 237, 1, 112, 9, 508, 922, 2, 83, 479, 1, 106, 35, 13, 382, 203, 39, 9, 641, 32, 96, 168, 19, 59, 117, 1, 62, 13, 382, 203, 39, 351, 37, 309, 641, 32, 309, 51, 1, 35, 13, 9, 102, 406, 621, 4, 23, 136, 121, 298, 454, 103, 110, 177, 1, 145, 57, 2, 211, 215, 443, 541, 1, 1171, 736, 2, 9, 14, 37, 83, 170, 1, 22, 35, 68, 2, 14, 37, 9, 173, 45, 1652, 136, 6, 57, 1652, 516, 565, 1, 35, 68, 151, 1171, 736, 2, 9, 14, 37, 1, 62, 513, 755, 57, 1652, 1, 91, 253, 71, 15, 45, 655, 15, 57, 896, 1, 35, 68, 13, 318, 165, 894, 4, 23, 3, 136, 121, 298, 454, 103, 34, 145, 57, 211, 215, 443, 541, 2, 878, 503, 516, 565, 304, 31, 648, 208, 49, 22, 83, 117, 147, 64, 219, 246, 12, 152, 66, 1, 1290, 455, 164, 154, 234, 36, 12, 1, 1000, 316, 164, 15, 998, 812, 1289, 112, 36, 12, 1, 1426, 201, 119, 1078, 319, 512, 71, 8, 182, 124, 238, 230, 123, 901, 1, 184, 222, 6, 87, 435, 71, 60, 20, 211, 215, 443, 541, 2, 6, 170, 1, 16, 94, 294, 475, 419, 2450, 9, 571, 11, 1, 63, 8, 7, 5, 5, 122, 1080, 35, 68, 12, 4, 846, 337, 61, 301, 701, 297, 39, 6, 539, 27, 135, 979, 1, 35, 166, 181, 90, 143]), 'class': Tensor(shape=[1], dtype=Int64, value= [0])}
下面介绍,如何使用JSON
文件来指定读取模型策略。
新建tfrecord_sample.json
文件,在文件内写入如下内容:
numRows – 数据集的行数(样本数),而不是列数;当未指定num_samples且numRows大于0时,只会读取numRows行数据。
columns – 依次为每列的列名、数据类型、数据维数、数据维度。
{ "datasetType": "TF", "numRows": 2, "columns": { "input": { "type": "int64", "rank": 1, "shape": [800] }, "class" : { "type": "int64", "rank": 1, "shape": [1] } } }
有了相应的JSON
文件,下面来介绍如何使用该文件进行数据读取。
修改load_tfrecord
代码如下:
def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Print the first training record, using ``tfrecord_json`` as the schema.

    Args:
        tfrecord_dir: Directory containing the TFRecord shard files.
        tfrecord_json: Value handed to TFRecordDataset as ``schema``
            (here, the path of a JSON schema file).
    """
    reader_kwargs = {
        "dataset_files": get_tfrecord_files(tfrecord_dir),
        "schema": tfrecord_json,
        "shuffle": False,
    }
    records = TFRecordDataset(**reader_kwargs).create_dict_iterator()
    for record in records:
        print(record, flush=True)
        # One record is enough to check the parsed columns.
        break
同时修改main部分代码如下:
# Replaces the call inside main(): pass the JSON schema path through
# instead of None.
load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=tfrecord_json)
代码解读:将schema参数指定为JSON文件的路径。
保存并再次运行文件load_tfrecord_dataset.py,输出内容如下:
{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18, 742, 330, 385, 999, 837, 56, 529, 1000, 260, 3, 171, 45, 7, 65, 136, 869, 211, 215, 443, 541, 3, 91, 1719, 636, 2, 424, 291, 16, 86, 31, 12, 211, 215, 443, 541, 999, 322, 128, 916, 102, 743, 136, 121, 298, 454, 2, 234, 225, 1, 136, 121, 298, 454, 100, 49, 22, 152, 70, 677, 806, 31, 1719, 636, 100, 25, 237, 2, 424, 1, 39, 100, 39, 71, 228, 385, 999, 837, 1, 171, 45, 136, 869, 211, 215, 443, 541, 1, 35, 68, 20, 149, 304, 31, 70, 677, 1, 106, 9, 308, 487, 869, 153, 597, 2, 523, 262, 184, 145, 57, 36, 158, 13, 69, 41, 1, 35, 68, 511, 402, 152, 469, 41, 617, 761, 50, 36, 144, 281, 26, 308, 487, 869, 153, 597, 4, 23, 208, 17, 121, 428, 646, 71, 8, 10, 47, 40, 87, 32, 413, 133, 9, 641, 159, 74, 144, 281, 26, 308, 487, 869, 153, 597, 2, 197, 447, 1, 91, 549, 202, 208, 17, 121, 558, 123, 2, 113, 203, 1, 419, 1024, 200, 154, 80, 16, 147, 64, 111, 208, 219, 136, 25, 6, 153, 597, 4, 160, 134, 16, 1, 167, 229, 1719, 636, 514, 2, 9, 7, 65, 321, 136, 869, 211, 215, 443, 541, 1, 514, 2, 69, 33, 13, 88, 80, 94, 294, 2, 308, 487, 869, 153, 597, 1, 39, 69, 33, 197, 57, 310, 335, 50, 94, 294, 2, 308, 487, 869, 153, 597, 1, 221, 13, 74, 337, 56, 499, 117, 836, 621, 488, 26, 94, 294, 1, 10, 5, 7, 10, 21, 973, 124, 492, 69, 33, 514, 218, 168, 117, 1, 82, 285, 148, 697, 2, 982, 298, 1535, 119, 743, 201, 1187, 4, 3, 136, 121, 298, 454, 103, 752, 31, 12, 1496, 762, 164, 2, 609, 6, 175, 83, 170, 257, 454, 963, 1, 149, 57, 136, 121, 298, 454, 62, 52, 87, 110, 257, 12, 34, 39, 2, 677, 1151, 1, 136, 121, 298, 454, 100, 49, 22, 138, 55, 39, 1, 752, 744, 184, 36, 169, 11, 561, 9, 1, 13, 74, 39, 62, 9, 308, 487, 869, 250, 321, 4, 23, 211, 215, 443, 541, 9, 641, 32, 900, 2586, 83, 1157, 165, 978, 97, 694, 837, 301, 22, 97, 694, 837, 9, 124, 492, 2, 720, 1341, 1, 35, 68, 32, 2294, 216, 2, 1, 106, 9, 13, 9, 20, 12, 25, 26, 973, 124, 1, 91, 9, 20, 12, 344, 36, 4, 23, 3, 167, 
229, 211, 215, 443, 541, 2, 283, 683, 9, 1719, 636, 1, 292, 278, 9, 641, 103, 32, 283, 683, 976, 944, 511, 316, 30, 178, 223, 795, 136, 164, 301, 22, 25, 172, 26, 18, 1102, 69, 41, 136, 869, 1, 35, 68, 184, 344, 285, 74, 1, 178, 223, 795, 136, 164, 13, 9, 6, 26, 1152, 285, 1, 163, 20, 35, 68, 184, 344, 165, 894, 74, 521, 96, 39, 1, 976, 944, 511, 316, 62, 9, 167, 149, 1, 1024, 1405, 164, 271, 454, 102, 743, 62, 9, 25, 278, 100, 2, 4, 23, 3, 136, 121, 298, 454, 103, 752, 31, 12, 145, 57, 442, 32, 401, 665, 14, 2, 432, 848, 808, 49, 22, 432, 848, 808, 30, 35, 68, 2, 116, 39, 57, 896, 6, 237, 1, 112, 9, 508, 922, 2, 83, 479, 1, 106, 35, 13, 382, 203, 39, 9, 641, 32, 96, 168, 19, 59, 117, 1, 62, 13, 382, 203, 39, 351, 37, 309, 641, 32, 309, 51, 1, 35, 13, 9, 102, 406, 621, 4, 23, 136, 121, 298, 454, 103, 110, 177, 1, 145, 57, 2, 211, 215, 443, 541, 1, 1171, 736, 2, 9, 14, 37, 83, 170, 1, 22, 35, 68, 2, 14, 37, 9, 173, 45, 1652, 136, 6, 57, 1652, 516, 565, 1, 35, 68, 151, 1171, 736, 2, 9, 14, 37, 1, 62, 513, 755, 57, 1652, 1, 91, 253, 71, 15, 45, 655, 15, 57, 896, 1, 35, 68, 13, 318, 165, 894, 4, 23, 3, 136, 121, 298, 454, 103, 34, 145, 57, 211, 215, 443, 541, 2, 878, 503, 516, 565, 304, 31, 648, 208, 49, 22, 83, 117, 147, 64, 219, 246, 12, 152, 66, 1, 1290, 455, 164, 154, 234, 36, 12, 1, 1000, 316, 164, 15, 998, 812, 1289, 112, 36, 12, 1, 1426, 201, 119, 1078, 319, 512, 71, 8, 182, 124, 238, 230, 123, 901, 1, 184, 222, 6, 87, 435, 71, 60, 20, 211, 215, 443, 541, 2, 6, 170, 1, 16, 94, 294, 475, 419, 2450, 9, 571, 11, 1, 63, 8, 7, 5, 5, 122, 1080, 35, 68, 12, 4, 846, 337, 61, 301, 701, 297, 39, 6, 539, 27, 135, 979, 1, 35, 166, 181, 90, 143])}
在某些场景下,我们可能只需要某(几)列的数据,而非全部数据,这时候就可以通过指定columns_list来进行数据加载。
下面我们只读取class
列,来简单看看如何操作。
在4.1.2
基础上,修改load_tfrecord
代码如下:
def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Print the first training record, loading only the ``class`` column.

    Args:
        tfrecord_dir: Directory containing the TFRecord shard files.
        tfrecord_json: Accepted but not used by this variant.
    """
    shard_files = get_tfrecord_files(tfrecord_dir)

    # The schema still declares both columns; columns_list narrows the
    # columns that are actually requested.
    full_schema = Schema()
    full_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
    full_schema.add_column(name="class", de_type=mstype.int64, shape=[1])
    selected_columns = ["class"]

    dataset = TFRecordDataset(
        dataset_files=shard_files,
        schema=full_schema,
        columns_list=selected_columns,
        shuffle=False,
    )
    first_row = next(iter(dataset.create_dict_iterator()), None)
    if first_row is not None:
        print(first_row, flush=True)
保存并再次运行文件load_tfrecord_dataset.py
,输出内容如下:
可以看到只读取了我们指定的列,且数据加载正确。
{'class': Tensor(shape=[1], dtype=Int64, value= [0])}
本文介绍了在MindSpore
中如何加载TFRecord
数据集,并重点介绍了TFRecordDataset
中的schema
和columns_list
参数使用。
本文为原创文章,版权归作者所有,未经授权不得转载!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。