当前位置:   article > 正文

Tensorflow2.0笔记 - 自定义Layer和Model

Tensorflow2.0笔记 - 自定义Layer和Model

        本笔记主要记录如何在tensorflow中实现自定的Layer和Model。详细内容请参考代码中的链接。

  1. import time
  2. import tensorflow as tf
  3. from tensorflow import keras
  4. from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
  5. tf.__version__
  6. #关于自定义layer和自定义Model的相关介绍,参考下面的链接:
  7. #https://tf.wiki/zh_hans/basic/models.html
  8. #https://blog.csdn.net/lzs781/article/details/104741958
  9. #自定义Dense层,继承自Layer
  10. class MyDense(layers.Layer):
  11. #需要实现__init__和call方法
  12. def __init__(self, input_dim, output_dim):
  13. super(MyDense, self).__init__()
  14. self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
  15. self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
  16. def call(self, inputs, training=None):
  17. out = inputs@self.kernel + self.bias
  18. return out
  19. #自定义Model,继承自Model
  20. class MyModel(keras.Model):
  21. #需要实现__init__和call方法
  22. def __init__(self):
  23. super(MyModel, self).__init__()
  24. #自定义5层MyDense自定义Layer
  25. self.fc1 = MyDense(28*28, 256)
  26. self.fc2 = MyDense(256, 128)
  27. self.fc3 = MyDense(128, 64)
  28. self.fc4 = MyDense(64, 32)
  29. self.fc5 = MyDense(32, 10)
  30. def call(self, inputs, trainning=None):
  31. x = self.fc1(inputs) #会调用MyDense的call方法
  32. x = tf.nn.relu(x)
  33. x = self.fc2(x)
  34. x = tf.nn.relu(x)
  35. x = self.fc3(x)
  36. x = tf.nn.relu(x)
  37. x = self.fc4(x)
  38. x = tf.nn.relu(x)
  39. x = self.fc5(x)
  40. return x
  41. customModel = MyModel()
  42. customModel.build(input_shape=[None, 28*28])
  43. customModel.summary()

运行结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/366376
推荐阅读
相关标签
  

闽ICP备14008679号