当前位置:   article > 正文

Keras load_model raise ValueError: Unknown layer: TokenEmbedding问题_unknown layer: tcn. please ensure this object is p

unknown layer: tcn. please ensure this object is passed to the `custom_objec

问题复现: 

训练的模型存储方式: 

model.save_weights(model_path)

模型加载方式

  1. from keras.models import load_model
  2. model = load_model(model_path)

异常:

ValueError: Unknown layer: TokenEmbedding

出现该错误是因为要保存的model中包含了自定义的层(Custom Layer),导致加载模型的时候无法解析该Layer
解决该问题的方法是在load_model函数中添加custom_objects参数,该参数接受一个字典,键值为自定义的层,当然也可以偷懒,直接使用custom_objects=get_custom_objects()

解决方案:

  1. from keras.models import load_model
  2. from keras_bert import get_custom_objects
  3. model = load_model(model_path, custom_objects=get_custom_objects())

keras_bert中提供的get_custom_objects() 中包含如下元素: 

{'LayerNormalization': <class 'keras_layer_normalization.layer_normalization.LayerNormalization'>, 'MultiHeadAttention': <class 'keras_multi_head.multi_head_attention.MultiHeadAttention'>, 'FeedForward': <class 'keras_position_wise_feed_forward.feed_forward.FeedForward'>, 'TrigPosEmbedding': <class 'keras_pos_embd.trig_pos_embd.TrigPosEmbedding'>, 'EmbeddingRet': <class 'keras_embed_sim.embeddings.EmbeddingRet'>, 'EmbeddingSim': <class 'keras_embed_sim.embeddings.EmbeddingSim'>, 'PositionEmbedding': <class 'keras_pos_embd.pos_embd.PositionEmbedding'>, 'TokenEmbedding': <class 'keras_bert.layers.embedding.TokenEmbedding'>, 'EmbeddingSimilarity': <class 'keras_bert.layers.embedding.EmbeddingSimilarity'>, 'Masked': <class 'keras_bert.layers.masked.Masked'>, 'Extract': <class 'keras_bert.layers.extract.Extract'>, 'gelu': <function gelu_tensorflow at 0x7fc1a0aed620>, 'gelu_tensorflow': <function gelu_tensorflow at 0x7fc1a0aed620>, 'gelu_fallback': <function gelu_fallback at 0x7fc1a0aed730>, 'AdamWarmup': <class 'keras_bert.optimizers.warmup.AdamWarmup'>}

当然,如果有些自定义的object不在上面这里面的话,那么需要将自定义的object放到里面,如下所示: 

  1. from keras.models import load_model
  2. custom_objects = get_custom_objects()
  3. my_objects = {'RAdam': RAdam}
  4. custom_objects.update(my_objects)
  5. model = load_model(model_path, custom_objects=custom_objects)

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/499797
推荐阅读
相关标签
  

闽ICP备14008679号