当前位置:   article > 正文

Tensorflow2.0笔记 - where,scatter_nd, meshgrid相关操作

Tensorflow2.0笔记 - where,scatter_nd, meshgrid相关操作

        本笔记记录tf.where进行元素位置查找,scatter_nd用于指派元素到tensor的特定位置,meshgrid用作绘图的相关操作。

  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. tf.__version__
  5. #where操作查找元素位置
  6. #输入的tensor是True,False组成的tensor
  7. tensor = tf.random.uniform([3,3], minval=-12, maxval=12, dtype=tf.int32)
  8. print(tensor.numpy())
  9. #获得大于0的值的mask
  10. mask = tensor > 0
  11. print(mask)
  12. #方式1:通过boolean_mask获得大于0的元素的值
  13. print("=====tf.boolean_mask(tensor, mask):\n", tf.boolean_mask(tensor, mask).numpy())
  14. #方式2:先通过where查询到大于0的元素位置,然后用gather_nd收集
  15. indices = tf.where(mask)
  16. print("=====indices for the ones greater than 0:\n", indices.numpy())
  17. print("=====tf.gather_nd(tensor, indices):\n", tf.gather_nd(tensor, indices))
  18. #where带条件选择元素
  19. #where(cond, tensor1, tensor2)
  20. #传入cond,如果cond对应位置为True,会收集tensor1对应位置的元素,否则收集tensor2对应位置的元素
  21. tensor1 = tf.random.uniform([3,3], minval=-12, maxval=12, dtype=tf.int32)
  22. tensor2 = tf.random.uniform([3,3], minval=-12, maxval=12, dtype=tf.int32)
  23. print(tensor1)
  24. print(tensor2)
  25. cond = tensor1 > 0
  26. print("=====Condition:\n", cond)
  27. print("=====where(cond, tensor1, tensor2):\n", tf.where(cond, tensor1, tensor2))
  28. #scatter_nd将元素放到对应位置,其他值为0
  29. #scatter_nd(indices, updates, shape)
  30. #indices指定要更新到的位置
  31. #updates指定更新的值
  32. #shape表示tensor的形状
  33. #1维tensor的例子
  34. indices = tf.constant([[4], [3], [1], [9]])
  35. updates = tf.constant([6, 7, 8, 9])
  36. shape = tf.constant([10])
  37. print("=====tf.scatter_nd(indices, updates, shape):\n", tf.scatter_nd(indices, updates, shape))
  38. #多维tensor的scatrer_nd
  39. # shape为5x4x4
  40. #将值更新到大维度的02处,实际对应一个4x4的tensor
  41. indices = tf.constant([[0], [2], [4]])
  42. updates = tf.constant([[
  43. [1, 1, 1, 1],
  44. [1, 1, 1, 1],
  45. [1, 1, 1, 1],
  46. [1, 1, 1, 1],
  47. ],
  48. [
  49. [2, 2, 2, 2],
  50. [2, 2, 2, 2],
  51. [2, 2, 2, 2],
  52. [2, 2, 2, 2],
  53. ],
  54. [
  55. [3, 3, 3, 3],
  56. [3, 3, 3, 3],
  57. [3, 3, 3, 3],
  58. [3, 3, 3, 3],
  59. ]])
  60. shape = tf.constant([5,4,4])
  61. print("=====tf.scatter_nd(indices, updates, shape):\n", tf.scatter_nd(indices, updates, shape))
  62. #meshgrid绘图
  63. #1. 设置x和y的linspace
  64. y = tf.linspace(-2., 2, 5)
  65. x = tf.linspace(-2., 2, 5)
  66. #获得坐标点tensor
  67. xPoints, yPoints = tf.meshgrid(x, y)
  68. print("X points:\n", xPoints)
  69. print("Y points:\n", yPoints)
  70. #通过tf.stack获得点的xy集合
  71. points = tf.stack([xPoints, yPoints], axis=2)
  72. print("Collection of XY points on plane:\n", points)
  73. #meshgrid实例,z = sin(x) +sin(y)
  74. x = tf.linspace(0., 2 * 3.14, 500)
  75. y = tf.linspace(0., 2 * 3.14, 500)
  76. xPoints, yPoints = tf.meshgrid(x, y)
  77. points = tf.stack([xPoints, yPoints], axis=2)
  78. z = tf.math.sin(points[..., 0]) + tf.math.sin(points[..., 1])
  79. #绘制z的值
  80. plt.figure('z = sin(x) + sin(y)')
  81. plt.imshow(z, origin='lower', interpolation='none')
  82. plt.colorbar()
  83. #绘制等高线
  84. plt.figure('plot contour')
  85. plt.contour(xPoints, yPoints, z)
  86. plt.colorbar()
  87. plt.show()

        运行结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/blog/article/detail/53390
推荐阅读
相关标签
  

闽ICP备14008679号