赞
踩
tf.keras.layers.Layer 是所有 Keras 层的基类,它继承自 tf.Module。
定义call
您只需换出父项,然后将 __call__ 更改为 call 即可将模块转换为 Keras 层
定义build
定义输入的shape.build 仅被调用一次,而且是使用输入形状时调用的。它通常用于创建变量(权重)
- import tensorflow as tf
-
- class MyLayer(tf.keras.layers.Layer):
- # Note the added `**kwargs`, as Keras supports many arguments
- def __init__(self , out_features, **kwargs):
- super().__init__(**kwargs)
- self.out_features = out_features
-
- # 创建变量(权重)
- def build(self, input_shape):
- self.w = tf.Variable(
- tf.random.normal([input_shape[-1], self.out_features]),
- name='w'
- )
- self.b = tf.Variable(tf.zeros([self.out_features]),
- name='b')
-
- # 定义前项传递的计算
- def call(self, inputs):
- return tf.matmul(inputs, self.w) + self.b
-
- #
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。