当前位置:   article > 正文

tf.clip_by_global_norm使用解析

tf.clip_by_global_norm

找了好几篇觉得都写得或者说翻译得不能让我很好地理解,所以自己找来官方文档翻译并记录了一下,以便以后自己查阅

官方文档可见这里,可能需要梯子

说明:该API在v1、v2的Tensoflow中用法一致

接下来进入正文。

  • 作用:简单来说,就是利用梯度裁剪的方式避免梯度爆炸,“梯度爆炸”自己可查阅相关资料理解。
  • 原型
tf.clip_by_global_norm(
    t_list,
    clip_norm,
    use_norm=None,
    name=None
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

给定张量t_list的元组或列表以及裁剪率clip_norm,此操作将返回裁剪后的list_clipped的张量列表以及t_list中所有张量的全局范数(global_norm)。 或者,如果您已经为t_list计算了全局范数,则可以使用use_norm指定全局范数。

  • 参数说明:
    • t_list:梯度张量,tuple或者mixed tensor、IndexedSlices或None的列表。
    • clip_norm: 标量,表示梯度裁剪的比例因子,在裁剪中,裁剪之后的梯度符合如下公式:
      t _ l i s t [ i ] = t _ l i s t [ i ] ∗ c l i p _ n o r m m a x ( g l o b a l _ n o r m , c l i p _ n o r m ) t\_list[i] = t\_list[i] * \frac {clip\_norm } {max(global\_norm, clip\_norm)} t_list[i]=t_list[i]max(global_norm,clip_norm)clip_norm
      其中,
      g l o b a l _ n o r m = s u m ( [ l 2 _ n o r m ( t ) 2   f o r   t   i n   t _ l i s t ] ) global\_norm = \sqrt{sum([l2\_norm(t)^2\ for\ t\ in\ t\_list])} global_norm=sum([l2_norm(t)2 for t in t_list])
      l2_norm代表L2范数。如果clip_norm> global_norm,则t_list中的梯度将保持原样,否则它们将被裁剪并按比率缩小;如果global_norm == infinity,则t_list中的所有条目均设置为NaN以表示发生了错误;否则,t_list将保持不变。
    • use_norm:可选参数,float类型的0-D(标量)张量,如果自己计算了global_norm,则可以利用该参数以提供。 默认为None,则代表Tensorflow将自己计算global_norm()范数。
    • name:可选参数,设定名字。
  • 返回参数:
    • list_clipped: 与list_t类型相同的张量列表,是裁剪后的梯度。
    • global_norm:一个表示全局范数的0-D(标量)张量。
  • 可能出现的ERROR:
    • TypeError:如果t_list不是序列。

注:官方文档中提了这样一句话:但是,它比clip_by_norm()慢,因为在执行剪切操作之前必须准备好所有参数。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/311524
推荐阅读
相关标签
  

闽ICP备14008679号