当前位置:   article > 正文

tf.name_scope and tf.variable_scope_可以看到变量名自行变成了'var2_1',避免了和'var2'冲突

可以看到变量名自行变成了'var2_1',避免了和'var2'冲突
作者:C Li
链接:https://www.zhihu.com/question/54513728/answer/181819324
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

在 tf.name_scope下时,tf.get_variable()创建的变量名不受 name_scope 的影响,而且在未指定共享变量时,如果重名会报错,tf.Variable()会自动检测有没有变量重名,如果有则会自行处理。

  1. import tensorflow as tf
  2. with tf.name_scope('name_scope_x'):
  3. var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
  4. var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
  5. var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
  6. with tf.Session() as sess:
  7. sess.run(tf.global_variables_initializer())
  8. print(var1.name, sess.run(var1))
  9. print(var3.name, sess.run(var3))
  10. print(var4.name, sess.run(var4))
  11. # 输出结果:
  12. # var1:0 [-0.30036557] 可以看到前面不含有指定的'name_scope_x'
  13. # name_scope_x/var2:0 [ 2.]
  14. # name_scope_x/var2_1:0 [ 2.] 可以看到变量名自行变成了'var2_1',避免了和'var2'冲突

如果使用tf.get_variable()创建变量,且没有设置共享变量,重名时会报错

  1. import tensorflow as tf
  2. with tf.name_scope('name_scope_1'):
  3. var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
  4. var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
  5. with tf.Session() as sess:
  6. sess.run(tf.global_variables_initializer())
  7. print(var1.name, sess.run(var1))
  8. print(var2.name, sess.run(var2))
  9. # ValueError: Variable var1 already exists, disallowed. Did you mean
  10. # to set reuse=True in VarScope? Originally defined at:
  11. # var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)

所以要共享变量,需要使用tf.variable_scope()

  1. import tensorflow as tf
  2. with tf.variable_scope('variable_scope_y') as scope:
  3. var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
  4. scope.reuse_variables() # 设置共享变量
  5. var1_reuse = tf.get_variable(name='var1')
  6. var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
  7. var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
  8. with tf.Session() as sess:
  9. sess.run(tf.global_variables_initializer())
  10. print(var1.name, sess.run(var1))
  11. print(var1_reuse.name, sess.run(var1_reuse))
  12. print(var2.name, sess.run(var2))
  13. print(var2_reuse.name, sess.run(var2_reuse))
  14. # 输出结果:
  15. # variable_scope_y/var1:0 [-1.59682846]
  16. # variable_scope_y/var1:0 [-1.59682846] 可以看到变量var1_reuse重复使用了var1
  17. # variable_scope_y/var2:0 [ 2.]
  18. # variable_scope_y/var2_1:0 [ 2.]

也可以这样

  1. with tf.variable_scope('foo') as foo_scope:
  2. v = tf.get_variable('v', [1])
  3. with tf.variable_scope('foo', reuse=True):
  4. v1 = tf.get_variable('v')
  5. assert v1 == v

或者这样:

  1. with tf.variable_scope('foo') as foo_scope:
  2. v = tf.get_variable('v', [1])
  3. with tf.variable_scope(foo_scope, reuse=True):
  4. v1 = tf.get_variable('v')
  5. assert v1 == v
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/95976
推荐阅读
  

闽ICP备14008679号