赞
踩
目录
2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可
3、tf.saved_model:保存得到一个pb文件和一个variables文件夹。
保存时主要有四个文件:
1)checkpoint:指示当前目录有哪些模型文件以及最新的模型文件
内容举例:
- model_checkpoint_path: "model.ckpt-2625"
- all_model_checkpoint_paths: "model.ckpt-2000"
- all_model_checkpoint_paths: "model.ckpt-2625"
2)model.ckpt-2625.data-00000-of-00001
包含训练变量的文件。在BERT训练过程中该文件约1.2 GB,这是由于除了记录每个变量的值,还记录了Adam优化器为每个变量维护的一阶矩和二阶矩,即Adam当中的m和v。
3)model.ckpt-2625.index
描述变量的key和value的对应关系。
4)model.ckpt-2625.meta
描述整个网络的结构。
保存:ckpt文件可以由两种方式产生:
1).tf.train.Saver
-
- saver=tf.train.Saver(max_to_keep=5) #max_to_keep=5意思就是保存最近的5个模型
- saver.save(sess,'path',global_step=epoch)
2) estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
加载:
- 1)saver=tf.train.import_meta_graph('model/model.meta') #恢复计算图结构
- saver.restore(sess, tf.train.latest_checkpoint("model/")) #恢复所有变量信息
- 2)estimator
-
def prepare(self):
    """Build the objects the prediction methods rely on.

    Checks the casing flag against the initial checkpoint, loads the BERT
    config, then creates the tokenizer, the data processor with its training
    examples, the module-level ``label_list``, the estimator ``RunConfig`` and
    a predictor loaded from the ``pb_save_test`` SavedModel directory.

    Raises:
        ValueError: if ``arg_dic['max_seq_length']`` is larger than
            ``self.config.max_position_embeddings``.
    """
    global label_list

    tokenization.validate_case_matches_checkpoint(
        arg_dic['do_lower_case'],
        arg_dic['init_checkpoint'],
    )
    self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])

    requested_len = arg_dic['max_seq_length']
    supported_len = self.config.max_position_embeddings
    if requested_len > supported_len:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (requested_len, supported_len))

    self.tokenizer = tokenization.FullTokenizer(
        vocab_file=arg_dic['vocab_file'],
        do_lower_case=arg_dic['do_lower_case'],
    )

    self.processor = SelfProcessor()
    self.train_examples = self.processor.get_train_examples(arg_dic['data_dir'])
    label_list = self.processor.get_labels()

    # Checkpoint housekeeping for the estimator-based code paths.
    run_config_kwargs = dict(
        model_dir=arg_dic['output_dir'],
        save_checkpoints_steps=arg_dic['save_checkpoints_steps'],
        tf_random_seed=None,
        save_summary_steps=100,
        session_config=None,
        keep_checkpoint_max=5,
        keep_checkpoint_every_n_hours=10000,
        log_step_count_steps=100,
    )
    self.run_config = tf.estimator.RunConfig(**run_config_kwargs)

    # Predictor backed by the exported SavedModel directory.
    self.predict_fn = tf.contrib.predictor.from_saved_model("pb_save_test")
-
def predict_on_ckpt(self, sentence):
    """Classify one sentence with an Estimator that restores from checkpoints.

    The Estimator is created on first use and cached on ``self.ckpt_tool``.

    Args:
        sentence: raw text handed to ``self.processor.one_example``.

    Returns:
        The entry of ``label_list`` with the highest predicted probability.
    """
    if not self.ckpt_tool:
        # Same step arithmetic as training so model_fn gets matching values.
        total_steps = int(
            len(self.train_examples) / arg_dic['train_batch_size'] * arg_dic['num_train_epochs'])
        warmup_steps = int(total_steps * arg_dic['warmup_proportion'])

        model_fn = model_fn_builder(
            bert_config=self.config,
            num_labels=len(label_list),
            init_checkpoint=arg_dic['init_checkpoint'],
            learning_rate=arg_dic['learning_rate'],
            num_train=total_steps,
            num_warmup=warmup_steps,
        )
        self.ckpt_tool = tf.estimator.Estimator(model_fn=model_fn, config=self.run_config)

    # Turn the sentence into a single model-ready feature.
    example = self.processor.one_example(sentence)
    feature = convert_single_example(
        0, example, label_list, arg_dic['max_seq_length'], self.tokenizer)

    predict_input_fn = input_fn_builder(
        features=[feature],
        seq_length=arg_dic['max_seq_length'],
        is_training=False,
        drop_remainder=False,
    )

    # predict() yields lazily; materialise it and take the only prediction.
    outputs = list(self.ckpt_tool.predict(input_fn=predict_input_fn))
    probs = outputs[0]["probabilities"].tolist()
    best = probs.index(max(probs))
    return label_list[best]
1)保存:
- pb_file = os.path.join(arg_dic['pb_model_dir'], 'classification_model.pb')
- graph = tf.Graph()
- with graph.as_default():
- input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
- input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
- bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
- loss, per_example_loss, logits, probabilities = create_classification_model(
- bert_config=bert_config, is_training=False,
- input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
-
- probabilities = tf.identity(probabilities, 'pred_prob')
- saver = tf.train.Saver()
-
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
- saver.restore(sess, latest_checkpoint)
- from tensorflow.python.framework import graph_util
- tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
- # 存储二进制模型到文件中
- with tf.gfile.GFile(pb_file, 'wb') as f:
- f.write(tmp_g.SerializeToString())
- return pb_file
- except Exception as e:
- print('fail to optimize the graph! %s', e)
2)加载:
def classification_model_fn(self, features, mode):
    """Estimator ``model_fn`` that serves predictions from a frozen graph file.

    Reads the serialized GraphDef at ``self.graph_path``, wires the incoming
    ``input_ids`` / ``input_mask`` features onto the graph's inputs of the
    same names, and exposes the ``pred_prob:0`` tensor.

    Args:
        features: mapping that provides ``"input_ids"`` and ``"input_mask"``.
        mode: estimator mode, passed through to the returned spec.

    Returns:
        An ``EstimatorSpec`` whose predictions hold ``'encodes'`` (argmax over
        the last axis) and ``'score'`` (max over the last axis).
    """
    with tf.gfile.GFile(self.graph_path, 'rb') as f:
        serialized_graph = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized_graph)

    input_map = {
        "input_ids": features["input_ids"],
        "input_mask": features["input_mask"],
    }
    (pred_prob,) = tf.import_graph_def(
        graph_def,
        name='',
        input_map=input_map,
        return_elements=['pred_prob:0'],
    )

    predictions = {
        'encodes': tf.argmax(pred_prob, axis=-1),
        'score': tf.reduce_max(pred_prob, axis=-1),
    }
    return EstimatorSpec(mode=mode, predictions=predictions)
-
-
def predict_on_pb(self, sentence):
    """Classify one sentence with the frozen-graph Estimator.

    The Estimator wrapping ``self.classification_model_fn`` is created on
    first use and cached on ``self.pbTool``.

    Args:
        sentence: raw text handed to ``self.processor.one_example``.

    Returns:
        The predicted entry of ``label_list``; the label and its score are
        also printed.
    """
    if not self.pbTool:
        self.pbTool = tf.estimator.Estimator(
            model_fn=self.classification_model_fn, config=self.run_config)

    # Turn the sentence into a single model-ready feature.
    example = self.processor.one_example(sentence)
    feature = convert_single_example(
        0, example, label_list, arg_dic['max_seq_length'], self.tokenizer)
    predict_input_fn = input_fn_builder(
        features=[feature],
        seq_length=arg_dic['max_seq_length'],
        is_training=False,
        drop_remainder=False,
    )

    # predict() yields lazily; materialise it and take the only prediction.
    first = list(self.pbTool.predict(input_fn=predict_input_fn))[0]
    label = label_list[first['encodes']]
    print('类别:{},置信度:{:.3f}'.format(label, first['score']))
    return label
保存有两种方式:
# Approach 1: rebuild the inference graph by hand, restore the latest
# checkpoint into it and export it with SavedModelBuilder.
graph = tf.Graph()
with graph.as_default():
    input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
    bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
    # Only the probabilities are exported; the other outputs are not needed.
    _, _, _, probabilities = create_classification_model(
        bert_config=bert_config, is_training=False,
        input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)

    probabilities = tf.identity(probabilities, 'pred_prob')
    saver = tf.train.Saver()

    # The session is opened inside graph.as_default() so that it runs the
    # graph that was just built.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
        saver.restore(sess, latest_checkpoint)

        path_pb_model = "pb_save_test"
        # Builder that writes the SavedModel directory.
        builder = tf.saved_model.builder.SavedModelBuilder(path_pb_model)

        # TensorInfo protobufs for the tensors exposed through the signature.
        input_ids_info = tf.saved_model.utils.build_tensor_info(input_ids)
        input_mask_info = tf.saved_model.utils.build_tensor_info(input_mask)
        probabilities_info = tf.saved_model.utils.build_tensor_info(probabilities)

        # SignatureDef protobuf. The predict method name is used because
        # serving only recognises the predefined method names.
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_ids': input_ids_info, 'input_mask': input_mask_info},
            outputs={'probabilities': probabilities_info},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Write the graph and variables into a MetaGraphDef protobuf.
        # The signature is registered under the default serving key: loaders
        # such as tf.contrib.predictor.from_saved_model look that key up when
        # no signature_def_key is given. CLASSIFY_INPUTS names an input
        # tensor, not a signature, so it must not be used as the map key.
        builder.add_meta_graph_and_variables(
            sess,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def,
            })

        # Write the MetaGraphDef to disk.
        builder.save()
- 2)estimator
def serving_input_receiver_fn():
    """Describe how the exported model receives data at serving time.

    Returns:
        A ``ServingInputReceiver`` that takes a batch of serialized
        ``tf.Example`` strings under the key ``'examples'`` and parses them
        into ``input_ids``, ``input_mask`` and ``segment_ids`` features.
    """
    seq_len = arg_dic['max_seq_length']
    feature_spec = {
        name: tf.FixedLenFeature([seq_len], tf.int64)
        for name in ("input_ids", "input_mask", "segment_ids")
    }
    serialized_examples = tf.placeholder(
        dtype=tf.string, shape=[None], name='input_example_tensor')
    parsed_features = tf.parse_example(serialized_examples, feature_spec)
    return tf.estimator.export.ServingInputReceiver(
        parsed_features, {'examples': serialized_examples})


# Export the estimator as a SavedModel when prediction is requested.
if arg_dic['do_predict']:
    estimator._export_to_tpu = False
    estimator.export_savedmodel("pb_save_test", serving_input_receiver_fn)
模型的加载:
# Loader set-up: validate the options, build the tokenizer and processor,
# then load the exported SavedModel as a predictor.
tokenization.validate_case_matches_checkpoint(
    arg_dic['do_lower_case'],
    arg_dic['init_checkpoint'],
)
self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])

requested_len = arg_dic['max_seq_length']
supported_len = self.config.max_position_embeddings
if requested_len > supported_len:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (requested_len, supported_len))

self.tokenizer = tokenization.FullTokenizer(
    vocab_file=arg_dic['vocab_file'],
    do_lower_case=arg_dic['do_lower_case'],
)

self.processor = SelfProcessor()
global label_list
label_list = self.processor.get_labels()

# Predictor backed by the exported SavedModel directory.
self.predict_fn = tf.contrib.predictor.from_saved_model(
    "/home/hadoop-health-alg/TextClassify_with_BERT/pb_save_test")
-
-
def predict_on_pb(self, sentence):
    """Classify one sentence with the SavedModel predictor in ``self.predict_fn``.

    The sentence is converted to a feature, packed into a serialized
    ``tf.Example`` and sent to the predictor under the ``'examples'`` key.

    Args:
        sentence: raw text handed to ``self.processor.one_example``.

    Returns:
        The entry of ``label_list`` with the highest predicted probability.
    """
    def _int64_feature(values):
        # Wrap a sequence of ints as an int64 tf.train.Feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    exam = self.processor.one_example(sentence)
    feature = convert_single_example(
        0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)

    example = tf.train.Example(features=tf.train.Features(feature={
        'input_ids': _int64_feature(feature.input_ids),
        'input_mask': _int64_feature(feature.input_mask),
        'segment_ids': _int64_feature(feature.segment_ids),
    }))

    predictions = self.predict_fn({'examples': [example.SerializeToString()]})
    # Only one example was sent, so only the first row is relevant.
    probabilities = predictions['probabilities'].tolist()[0]
    pos = probabilities.index(max(probabilities))
    return label_list[pos]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。