当前位置:   article > 正文

tf.keras快速入门——自定义损失函数(一)_keras自定义模型损失函数

keras自定义模型损失函数

官网的自定义损失函数here
在官网中,我们可以看见给出了两种方法来使用Keras提供自定义损失。

1. 简单的均方误差

def custom_mean_squared_error(y_true, y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
    
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)
  • 1
  • 2
  • 3
  • 4

这种情况下,y_truey_pred的值会自动传入该损失函数中,且仅有实际输出、预测输出的两个默认参数,所以这种方式仅仅适合于简单的损失函数。如果您需要一个使用除y_truey_pred之外的其他参数的损失函数,则可以将tf.keras.losses.Loss类子类化。

1.1 简单案例

这里还是以鸢尾花分类为例,我们简单的使用序列模型来定义所需要的模型,然后在compile中指定我们自定义的损失函数。由于我们在序列模型中使用的损失函数是sparse_categorical_crossentropy,这里先简单记录下这两个的区别:

1.1.1 sparse_categorical_crossentropy & categorical_crossentropy
  • 如果yone-hot encoding格式(即独热编码的向量格式),使用sparse_categorical_crossentropy
  • 如果y是整数,非one-hot encoding格式,使用categorical_crossentropy

简单抄下,在鸢尾花分类中的代码,即:

from sklearn.datasets import load_iris
x_data = load_iris().data  # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类

import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(4,), activation='relu'))
model.add(tf.keras.layers.Dense(3, input_shape=(4,), activation='softmax'))
model.compile(
    optimizer="adam", 
    loss="sparse_categorical_crossentropy",  # 三分类的结果,已经需要使用独热码来表示。故而不能使用categorical_crossentropy
    metrics=['accuracy']
)
model.fit(x_data, y_data, epochs=100)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

由于鸢尾花是三种类别,也就是三分类的结果。而不是非01的结果,故而这里需要使用sparse_categorical_crossentropy。但其本质都是交叉熵。我们都知道交叉熵的公式可以表示为:
H ( p , q ) = − ∑ i y ( x i ) l o g ( y ^ ( x i ) ) H(p, q) = -\sum_i y(x_i) log(\hat y(x_i)) H(p,q)=iy(xi)log(y^(xi))
y ^ ( x i ) \hat y(x_i) y^(xi)表示 x i x_i xi的预测分布,而 y ( x i ) y(x_i) y(xi)表示在训练数据中的类别的概率分布。
但是,由于如果需要使用交叉熵,我需要获取到最终的类别的下标值,而对于y_truey_pred这两个均是tf.Tensor类型的对象,却没办法获取到其对应的值,故而这里宣告失败了。转而,还是使用平方损失函数来解决:

from sklearn.datasets import load_iris
x_data = load_iris().data  # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类
# 转化为独热编码
y_data_one_hot = tf.one_hot(y_data, depth=3) # 3分类

import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(4,), activation='relu'))
model.add(tf.keras.layers.Dense(3, input_shape=(4,), activation='softmax'))


def custom_mean_squared_error(y_true, y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))


model.compile(
    optimizer="adam", 
    loss=custom_mean_squared_error,
    metrics=['acc']
)
history = model.fit(x_data, y_data_one_hot, epochs=300)


for key in history.history.keys():
    plt.plot(history.epoch, history.history[key])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

在这里插入图片描述
当然,还又另一种方式,也就是直接调用tensorflow提供的交叉熵损失函数,即:

# 转化为独热编码
y_data_one_hot = tf.one_hot(y_data, depth=3) # 3分类

import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(4,), activation='relu'))
model.add(tf.keras.layers.Dense(3, input_shape=(4,), activation='softmax'))


def custom_mean_squared_error(y_true, y_pred):
    return tf.losses.categorical_crossentropy(y_true, y_pred)


model.compile(
    optimizer="adam", 
    loss=custom_mean_squared_error,
    metrics=['acc']
)
history = model.fit(x_data, y_data_one_hot, epochs=300)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

or

from sklearn.datasets import load_iris
x_data = load_iris().data  # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类
# 转化为独热编码
#y_data_one_hot = tf.one_hot(y_data, depth=3) # 3分类

import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(4,), activation='relu'))
model.add(tf.keras.layers.Dense(3, input_shape=(4,), activation='softmax'))


def custom_mean_squared_error(y_true, y_pred):
    return tf.losses.sparse_categorical_crossentropy(y_true, y_pred)


model.compile(
    optimizer="adam", 
    loss=custom_mean_squared_error,
    metrics=['acc']
)
history = model.fit(x_data, y_data, epochs=300)


for key in history.history.keys():
    plt.plot(history.epoch, history.history[key])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

2. 实例化Loss类

类似的,该类需要继承自tf.keras.losses.Loss,然后需要复写下面的两个方法:

  • __init__(self):接受要在调用损失函数期间传递的参数;
  • call(self, y_true, y_pred):使用目标 y_true 和模型预测 y_pred 来计算模型的损失;

如,官网案例:mse, 存在一个会抑制预测值远离 0.5

class CustomMSE(keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor
        
model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

由于比较简单,就不再套鸢尾花分类案例了。但从上面的官网的案例中也可以看见一个问题就是:
这里的简单的实例化子类来定义损失函数,其实和上面的直接定义一个function来作为其损失函数类似。因为传入call中的参数还是只有y_truey_pred两个值,然后需要在上面做文章,也是比较困难的。
在下篇中,将讨论一个更加灵活的自定义损失函数的方式,其实在前面的博客中,也提及过,即:复杂度&学习率&损失函数,如:

import tensorflow as tf
x = tf.random.normal([20, 2], mean=2, stddev=1, dtype=tf.float32)
y = [item1 + 2 * item2 for item1, item2 in x]
w = tf.Variable(tf.random.normal([2, 1], mean=0, stddev=1))

epoch = 5000
lr = 0.002
for epoch in range(epoch):
    with tf.GradientTape() as tape:
        y_hat = tf.matmul(x, w)
        loss = tf.reduce_mean(tf.where(tf.greater(y_hat, y), 3*(y_hat - y), y-y_hat))
    w_grad = tape.gradient(loss, w)
    w.assign_sub(lr * w_grad)

print(w.numpy().T) # [[0.73728406 0.83368826]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/133205
推荐阅读
相关标签
  

闽ICP备14008679号