当前位置:   article > 正文

BERT 三种模型保存方式及调用方法总结(ckpt、单文件 pb、TF Serving 使用的 pb)

bert.ckpt

目录

1、在训练的过程中保存的ckpt文件:

2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可

3、tf.saved_model 保存,得到一个 pb 文件和一个 variables 文件夹。


1、在训练的过程中保存的ckpt文件:

保存时主要有四个文件:

1)checkpoint:指示当前目录有哪些模型文件以及最新的模型文件

内容举例:

  model_checkpoint_path: "model.ckpt-2625"
  all_model_checkpoint_paths: "model.ckpt-2000"
  all_model_checkpoint_paths: "model.ckpt-2625"

2)model.ckpt-2625.data-00000-of-00001

包含训练变量值的文件。在 BERT 训练过程中该文件约 1.2 GB,这是因为除了每个变量的值之外,还保存了 Adam 优化器为每个变量维护的一阶矩和二阶矩(即 Adam 中的 m 和 v)。

3)model.ckpt-2625.index

索引文件,记录每个变量名(key)与该变量的值在 data 文件中存储位置(value)之间的对应关系。

4)model.ckpt-2625.meta

描述整个网络的结构。

保存:ckpt 文件可以通过以下两种方式产生:

1)tf.train.Saver

  # Keep only the 5 most recent checkpoints on disk; as new ones are written,
  # older ones are deleted.
  ckpt_saver = tf.train.Saver(max_to_keep=5)
  # Writes path-<epoch>.index / .data-* / .meta and updates the `checkpoint` file.
  ckpt_saver.save(sess=sess, save_path='path', global_step=epoch)

2)使用 Estimator 训练时自动保存:estimator.train(input_fn=train_input_fn, max_steps=num_train_steps),ckpt 会写入 RunConfig 中指定的 model_dir。

加载:

  # Option 1: tf.train.Saver
  # latest_checkpoint() reads the `checkpoint` index file and returns the newest
  # checkpoint prefix (e.g. "model/model.ckpt-2625"). Deriving the .meta path from
  # that prefix guarantees the graph and the weights come from the same checkpoint.
  latest_ckpt = tf.train.latest_checkpoint("model/")
  if latest_ckpt is None:
      raise ValueError("no checkpoint found in model/")
  saver = tf.train.import_meta_graph(latest_ckpt + ".meta")  # rebuild the graph structure
  saver.restore(sess, latest_ckpt)  # restore all variable values

  # Option 2: tf.estimator
  def prepare(self):
      """Set up the tokenizer, data processor and RunConfig used for prediction.

      Also assigns the module-level ``label_list`` and loads the SavedModel found
      in ``pb_save_test`` into ``self.predict_fn``.

      Raises:
          ValueError: if ``max_seq_length`` exceeds the BERT config's
              ``max_position_embeddings``.
      """
      tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])

      self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
      requested_len = arg_dic['max_seq_length']
      trained_len = self.config.max_position_embeddings
      if requested_len > trained_len:
          raise ValueError(
              "Cannot use sequence length %d because the BERT model "
              "was only trained up to sequence length %d" % (requested_len, trained_len))

      self.tokenizer = tokenization.FullTokenizer(
          vocab_file=arg_dic['vocab_file'],
          do_lower_case=arg_dic['do_lower_case'],
      )
      self.processor = SelfProcessor()
      self.train_examples = self.processor.get_train_examples(arg_dic['data_dir'])

      global label_list
      label_list = self.processor.get_labels()

      self.run_config = tf.estimator.RunConfig(
          model_dir=arg_dic['output_dir'],
          save_checkpoints_steps=arg_dic['save_checkpoints_steps'],
          tf_random_seed=None,
          save_summary_steps=100,
          session_config=None,
          keep_checkpoint_max=5,
          keep_checkpoint_every_n_hours=10000,
          log_step_count_steps=100,
      )
      self.predict_fn = tf.contrib.predictor.from_saved_model("pb_save_test")
  def predict_on_ckpt(self, sentence):
      """Classify one sentence with an Estimator that restores from checkpoints.

      The Estimator is built lazily on the first call and cached in
      ``self.ckpt_tool``. It uses ``self.run_config``; when no ``checkpoint_path``
      is given, ``Estimator.predict`` restores the latest checkpoint in that
      config's model_dir.

      Args:
          sentence: raw text to classify.

      Returns:
          The entry of ``label_list`` with the highest predicted probability.
      """
      # The attribute may never have been initialised (prepare() does not set it),
      # so don't assume it exists.
      if not getattr(self, 'ckpt_tool', None):
          num_train_steps = int(len(self.train_examples) / arg_dic['train_batch_size'] * arg_dic['num_train_epochs'])
          num_warmup_steps = int(num_train_steps * arg_dic['warmup_proportion'])
          model_fn = model_fn_builder(bert_config=self.config, num_labels=len(label_list),
                                      init_checkpoint=arg_dic['init_checkpoint'],
                                      learning_rate=arg_dic['learning_rate'],
                                      num_train=num_train_steps, num_warmup=num_warmup_steps)
          self.ckpt_tool = tf.estimator.Estimator(model_fn=model_fn, config=self.run_config)

      exam = self.processor.one_example(sentence)  # the single example to predict
      feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)
      predict_input_fn = input_fn_builder(features=[feature],
                                          seq_length=arg_dic['max_seq_length'],
                                          is_training=False,
                                          drop_remainder=False)
      # predict() returns a generator; exactly one example was fed in.
      prediction = next(self.ckpt_tool.predict(input_fn=predict_input_fn))
      probabilities = prediction["probabilities"].tolist()
      best = probabilities.index(max(probabilities))  # index of the highest probability
      return label_list[best]

2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可

1)保存:

  def optimize_graph(num_labels):
      """Freeze the latest checkpoint in ``output_dir`` into one self-contained .pb file.

      The inference graph is rebuilt with ``input_ids`` / ``input_mask``
      placeholders and the checkpoint weights are restored into it. Every variable
      is then folded into a constant, so the written GraphDef needs no separate
      variables. The output node is the identity op ``pred_prob``.

      Args:
          num_labels: number of classes of the classification head.

      Returns:
          Path of the written ``classification_model.pb``, or ``None`` if freezing
          failed. On failure the error is printed.
      """
      from tensorflow.python.framework import graph_util

      try:
          pb_file = os.path.join(arg_dic['pb_model_dir'], 'classification_model.pb')
          graph = tf.Graph()
          with graph.as_default():
              input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
              input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
              bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
              loss, per_example_loss, logits, probabilities = create_classification_model(
                  bert_config=bert_config, is_training=False,
                  input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None,
                  num_labels=num_labels)
              # Give the output a stable, known name so it can be fetched after import.
              probabilities = tf.identity(probabilities, 'pred_prob')
              saver = tf.train.Saver()
              with tf.Session() as sess:
                  latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
                  if latest_checkpoint is None:
                      raise ValueError('no checkpoint found in %s' % arg_dic['output_dir'])
                  # restore() itself initializes every variable the saver tracks,
                  # so running global_variables_initializer() first is wasted work.
                  saver.restore(sess, latest_checkpoint)
                  frozen_graph_def = graph_util.convert_variables_to_constants(
                      sess, graph.as_graph_def(), ['pred_prob'])
          # Store the frozen GraphDef as a binary protobuf; GFile cannot create
          # missing parent directories, so make sure the target dir exists.
          tf.gfile.MakeDirs(arg_dic['pb_model_dir'])
          with tf.gfile.GFile(pb_file, 'wb') as f:
              f.write(frozen_graph_def.SerializeToString())
          return pb_file
      except Exception as e:
          # Best-effort: report the failure and signal it to the caller with None.
          print('fail to optimize the graph! %s' % e)
          return None

2)加载:

  def classification_model_fn(self, features, mode):
      """Estimator ``model_fn`` that serves predictions from the frozen .pb graph.

      The GraphDef stored at ``self.graph_path`` is imported with its
      ``input_ids`` / ``input_mask`` placeholders remapped to the incoming feature
      tensors. Its ``pred_prob:0`` output is reduced to a predicted class index
      (``encodes``) and that class's probability (``score``).
      """
      graph_def = tf.GraphDef()
      with tf.gfile.GFile(self.graph_path, 'rb') as pb_handle:
          graph_def.ParseFromString(pb_handle.read())

      # Wire the estimator's input tensors into the imported graph's placeholders.
      remap = {name: features[name] for name in ("input_ids", "input_mask")}
      (prob_tensor,) = tf.import_graph_def(graph_def, name='', input_map=remap,
                                           return_elements=['pred_prob:0'])
      return EstimatorSpec(mode=mode, predictions={
          'encodes': tf.argmax(prob_tensor, axis=-1),
          'score': tf.reduce_max(prob_tensor, axis=-1),
      })
  def predict_on_pb(self, sentence):
      """Classify one sentence with an Estimator wrapping the frozen .pb graph.

      The Estimator is built lazily on the first call and cached in
      ``self.pbTool``. The predicted label and its confidence are printed.

      Args:
          sentence: raw text to classify.

      Returns:
          The predicted entry of ``label_list``.
      """
      # The attribute may never have been initialised, so don't assume it exists.
      if not getattr(self, 'pbTool', None):
          self.pbTool = tf.estimator.Estimator(model_fn=self.classification_model_fn, config=self.run_config)
      exam = self.processor.one_example(sentence)  # the single example to predict
      feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)
      predict_input_fn = input_fn_builder(features=[feature],
                                          seq_length=arg_dic['max_seq_length'],
                                          is_training=False,
                                          drop_remainder=False)
      # predict() returns a generator; exactly one example was fed in.
      ele = next(self.pbTool.predict(input_fn=predict_input_fn))
      label = label_list[ele['encodes']]
      print('类别:{},置信度:{:.3f}'.format(label, ele['score']))
      return label

3、tf.saved_model 保存,得到一个 pb 文件和一个 variables 文件夹。

保存有两种方式:

  # Option 1: rebuild the inference graph by hand and export it with SavedModelBuilder.
  graph = tf.Graph()
  with graph.as_default():
      input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
      input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
      bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
      loss, per_example_loss, logits, probabilities = create_classification_model(
          bert_config=bert_config, is_training=False,
          input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None,
          num_labels=num_labels)
      probabilities = tf.identity(probabilities, 'pred_prob')
      saver = tf.train.Saver()
      with tf.Session() as sess:
          latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
          if latest_checkpoint is None:
              raise ValueError('no checkpoint found in %s' % arg_dic['output_dir'])
          # restore() itself initializes every variable it loads; no initializer run needed.
          saver.restore(sess, latest_checkpoint)

          path_pb_model = "pb_save_test"
          # NOTE: SavedModelBuilder refuses to write into an already existing directory.
          builder = tf.saved_model.builder.SavedModelBuilder(path_pb_model)

          # Build a *predict* signature and register it under the default serving key.
          # TF Serving's Predict API can then use it directly, and so can
          # tf.contrib.predictor.from_saved_model, which looks up
          # DEFAULT_SERVING_SIGNATURE_DEF_KEY when no signature key is passed.
          signature_def = tf.saved_model.signature_def_utils.predict_signature_def(
              inputs={'input_ids': input_ids, 'input_mask': input_mask},
              outputs={'probabilities': probabilities})
          builder.add_meta_graph_and_variables(
              sess,
              tags=[tf.saved_model.tag_constants.SERVING],
              signature_def_map={
                  tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def,
              })
          # Write saved_model.pb and the variables/ directory to disk.
          builder.save()
  # Option 2: export through the Estimator.
  def serving_input_receiver_fn():
      """Serving input function used when exporting the Estimator.

      At serving time the model receives a batch of serialized ``tf.train.Example``
      protos under the key ``examples``. They are parsed back into fixed-length
      int64 ``input_ids`` / ``input_mask`` / ``segment_ids`` features.

      :return: a ``tf.estimator.export.ServingInputReceiver``.
      """
      seq_len = arg_dic['max_seq_length']
      feature_spec = {
          key: tf.FixedLenFeature([seq_len], tf.int64)
          for key in ("input_ids", "input_mask", "segment_ids")
      }
      examples_ph = tf.placeholder(dtype=tf.string,
                                   shape=[None],
                                   name='input_example_tensor')
      parsed_features = tf.parse_example(examples_ph, feature_spec)
      return tf.estimator.export.ServingInputReceiver(parsed_features, {'examples': examples_ph})
  if arg_dic['do_predict']:
      # `_export_to_tpu` is a TPUEstimator attribute; setting it to False
      # presumably skips exporting a TPU-specific graph. Confirm for your TF version.
      estimator._export_to_tpu = False
      # The SavedModel is written into a new timestamped sub-directory of
      # export_dir_base, not into export_dir_base itself.
      estimator.export_savedmodel(
          export_dir_base="pb_save_test",
          serving_input_receiver_fn=serving_input_receiver_fn,
      )

模型的加载(对应上面第 2 种 estimator 导出方式:输入为序列化后的 tf.train.Example,通过 'examples' 键传入):

  import os

  tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])
  self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
  if arg_dic['max_seq_length'] > self.config.max_position_embeddings:
      raise ValueError(
          "Cannot use sequence length %d because the BERT model "
          "was only trained up to sequence length %d" %
          (arg_dic['max_seq_length'], self.config.max_position_embeddings))
  self.tokenizer = tokenization.FullTokenizer(vocab_file=arg_dic['vocab_file'],
                                              do_lower_case=arg_dic['do_lower_case'])
  self.processor = SelfProcessor()
  global label_list
  label_list = self.processor.get_labels()

  # Estimator.export_savedmodel() writes each export into a timestamped
  # sub-directory of its base dir, so the base dir itself contains no
  # saved_model.pb. Load the newest timestamped export. If there is none, fall
  # back to the base dir, for models written straight into it (e.g. by
  # SavedModelBuilder).
  export_base = "/home/hadoop-health-alg/TextClassify_with_BERT/pb_save_test"
  timestamped = [name for name in os.listdir(export_base) if name.isdigit()]
  if timestamped:
      saved_model_dir = os.path.join(export_base, max(timestamped, key=int))
  else:
      saved_model_dir = export_base
  self.predict_fn = tf.contrib.predictor.from_saved_model(saved_model_dir)
  def predict_on_pb(self, sentence):
      """Classify one sentence with the exported SavedModel in ``self.predict_fn``.

      The sentence is converted to a BERT feature and packed into a serialized
      ``tf.train.Example``, which is fed to the predictor under the ``examples``
      key.

      Args:
          sentence: raw text to classify.

      Returns:
          The entry of ``label_list`` with the highest predicted probability.
      """
      exam = self.processor.one_example(sentence)  # the single example to predict
      feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)

      def _int64_feature(values):
          # Wrap a sequence of ints as a tf.train.Feature holding an Int64List.
          return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

      example = tf.train.Example(features=tf.train.Features(feature={
          'input_ids': _int64_feature(feature.input_ids),
          'input_mask': _int64_feature(feature.input_mask),
          'segment_ids': _int64_feature(feature.segment_ids),
      }))
      predictions = self.predict_fn({'examples': [example.SerializeToString()]})
      probabilities = predictions['probabilities'].tolist()[0]  # row of the only example
      pos = probabilities.index(max(probabilities))  # index of the highest probability
      return label_list[pos]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/348799
推荐阅读
相关标签
  

闽ICP备14008679号