当前位置:   article > 正文

深度学习基础之《TensorFlow框架(5)—会话》

深度学习基础之《TensorFlow框架(5)—会话》

一、会话

2.x版本由于是即时执行模式,所以不需要会话。但是可以手工开启会话

1、什么是会话
一个运行TensorFlow operation的类。会话包含以下两种开启方式
(1)tf.compat.v1.Session:用于完整的程序当中
(2)tf.compat.v1.InteractiveSession:用于交互式上下文中的TensorFlow,比如想验证下自己的想法

2、InteractiveSession例子
在2.x版本中没有eval()函数了,用numpy()函数代替

  1. ipython
  2. Python 3.6.8 (default, Nov 14 2023, 16:29:52)
  3. Type 'copyright', 'credits' or 'license' for more information
  4. IPython 7.16.3 -- An enhanced Interactive Python. Type '?' for help.
  5. In [1]: import os
  6. In [2]: os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
  7. In [3]: import tensorflow as tf
  8. In [4]: tf.compat.v1.InteractiveSession()
  9. Out[4]: <tensorflow.python.client.session.InteractiveSession at 0x7fa7efc09e48>
  10. In [5]: a = tf.constant(3)
  11. In [6]: a
  12. Out[6]: <tf.Tensor: shape=(), dtype=int32, numpy=3>
  13. In [7]: a.eval()
  14. ---------------------------------------------------------------------------
  15. NotImplementedError Traceback (most recent call last)
  16. <ipython-input-7-90f91b557aeb> in <module>
  17. ----> 1 a.eval()
  18. /usr/local/lib64/python3.6/site-packages/tensorflow/python/framework/ops.py in eval(self, feed_dict, session)
  19. 1279 def eval(self, feed_dict=None, session=None):
  20. 1280 raise NotImplementedError(
  21. -> 1281 "eval is not supported when eager execution is enabled, "
  22. 1282 "is .numpy() what you're looking for?")
  23. 1283
  24. NotImplementedError: eval is not supported when eager execution is enabled, is .numpy() what you're looking for?
  25. In [8]: a.numpy()
  26. Out[8]: 3

3、2.0不用专门教程,主要改变就是不用定义session了,2.0采用动态图,一旦print就会立即返回数值

4、tf.Session.close函数
会话可能拥有资源,如tf.Variable,tf.queue.QueueBase,tf.compat.v1.ReaderBase
当这些资源不再需要时,释放这些资源非常重要。因此,需要在会话中调用tf.Session.close函数,或将会话用作上下文管理器

  1. # Using the `close()` method.
  2. sess = tf.compat.v1.Session()
  3. sess.run(...)
  4. sess.close()
  5. # Using the context manager.
  6. with tf.compat.v1.Session() as sess:
  7. sess.run(...)

5、tf.compat.v1.Session(target='', graph=None, config=None)
说明:
target:如果将此参数留空(默认设置),会话将仅使用本地计算机中的设备。可以指定grpc://网址,以便指定TensorFlow服务器的地址,这使得会话可以访问该服务器控制的计算机上的所有设备
graph:默认情况下,新的Session将绑定到当前的默认图
config:此参数允许您指定一个tf.compat.v1.ConfigProto以便控制会话的行为。例如ConfigProto协议用于打印设备使用信息

  1. # 运行会话并打印设备信息
  2. sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
  3. allow_soft_placement=True,
  4. log_device_placement=True))

6、run方法
run(fetches, feed_dict=None, options=None, run_metadata=None)
说明:
通过使用sess.run()来运行operation
fetches:单一的operation,或者列表、元组(其他不属于tensorflow的类型不行)
feed_dict:参数允许调用者覆盖图中张量的值(将图形元素映射到值的字典上),运行时赋值。与tf.compat.v1.placeholder搭配使用,则会检查图的形状是否与占位符兼容
placeholder:提供占位符,run时候通过feed_dict指定参数

  1. import os
  2. os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
  3. import tensorflow as tf
  4. def tensorflow_demo():
  5. """
  6. TensorFlow的基本结构
  7. """
  8. # TensorFlow实现加减法运算
  9. a_t = tf.constant(2)
  10. b_t = tf.constant(3)
  11. c_t = a_t + b_t
  12. print("TensorFlow加法运算结果:\n", c_t)
  13. print(c_t.numpy())
  14. # 2.0版本不需要开启会话,已经没有会话模块了
  15. return None
  16. def graph_demo():
  17. """
  18. 图的演示
  19. """
  20. # TensorFlow实现加减法运算
  21. a_t = tf.constant(2)
  22. b_t = tf.constant(3)
  23. c_t = a_t + b_t
  24. print("TensorFlow加法运算结果:\n", c_t)
  25. print(c_t.numpy())
  26. # 查看默认图
  27. # 方法1:调用方法
  28. default_g = tf.compat.v1.get_default_graph()
  29. print("default_g:\n", default_g)
  30. # 方法2:查看属性
  31. # print("a_t的图属性:\n", a_t.graph)
  32. # print("c_t的图属性:\n", c_t.graph)
  33. # 自定义图
  34. new_g = tf.Graph()
  35. # 在自己的图中定义数据和操作
  36. with new_g.as_default():
  37. a_new = tf.constant(20)
  38. b_new = tf.constant(30)
  39. c_new = a_new + b_new
  40. print("c_new:\n", c_new)
  41. print("a_new的图属性:\n", a_new.graph)
  42. print("b_new的图属性:\n", b_new.graph)
  43. # 开启new_g的会话
  44. with tf.compat.v1.Session(graph=new_g) as sess:
  45. c_new_value = sess.run(c_new)
  46. print("c_new_value:\n", c_new_value)
  47. print("我们自己创建的图为:\n", sess.graph)
  48. # 可视化自定义图
  49. # 1)创建一个writer
  50. writer = tf.summary.create_file_writer("./tmp/summary")
  51. # 2)将图写入
  52. with writer.as_default():
  53. tf.summary.graph(new_g)
  54. return None
  55. def session_run_demo():
  56. """
  57. feed操作
  58. """
  59. tf.compat.v1.disable_eager_execution()
  60. # 定义占位符
  61. a = tf.compat.v1.placeholder(tf.float32)
  62. b = tf.compat.v1.placeholder(tf.float32)
  63. sum_ab = tf.add(a, b)
  64. print("a:\n", a)
  65. print("b:\n", b)
  66. print("sum_ab:\n", sum_ab)
  67. # 开启会话
  68. with tf.compat.v1.Session() as sess:
  69. print("占位符的结果:\n", sess.run(sum_ab, feed_dict={a: 1.1, b: 2.2}))
  70. return None
  71. if __name__ == "__main__":
  72. # 代码1:TensorFlow的基本结构
  73. # tensorflow_demo()
  74. # 代码2:图的演示
  75. #graph_demo()
  76. # feed操作
  77. session_run_demo()
  1. python3 day01_deeplearning.py
  2. a:
  3. Tensor("Placeholder:0", dtype=float32)
  4. b:
  5. Tensor("Placeholder_1:0", dtype=float32)
  6. sum_ab:
  7. Tensor("Add:0", dtype=float32)
  8. 占位符的结果:
  9. 3.3000002

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/128256
推荐阅读
相关标签
  

闽ICP备14008679号