当前位置:   article > 正文

人工智能AI编程基础(九)_ai人工智能编程代码

ai人工智能编程代码

tensor切片的方法在实践中大量运用,其中涉及到多维度的切片操作,有时还是挺让人头晕的。

tf.gather()的下标取值和切片的方法:
  1. import tensorflow as tf
  2. from datetime import datetime
  3. import numpy as np
  4. def pprint(*args, **kwargs):
  5. print(datetime.now(), *args, **kwargs, end='\n' + '*' * 50 + '\n')
  6. params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
  7. pprint(params[3].numpy()) # 获取第4个
  8. pprint(tf.gather(params, 3).numpy()) # 获取第4个
  9. pprint(tf.gather(params, indices=[2, 0, 2, 5]).numpy()) # 分别获取第3个,第1个,第3个,第6个
  10. pprint(tf.gather(params, [[2, 0], [2, 5]]).numpy()) # 分别取下标的值,然后生成一个2 * 2的数组
  11. params = tf.constant([[0, 1.0, 2.0],
  12. [10.0, 11.0, 12.0],
  13. [20.0, 21.0, 22.0],
  14. [30.0, 31.0, 32.0]])
  15. pprint(tf.gather(params, indices=[3, 1])) # 第4个下标和第1个
  16. # 如果axis=0,则沿着纵轴进行操作;
  17. # 如果axis=1,则沿着横轴进行操作
  18. pprint(tf.gather(params, indices=[2, 1], axis=1).numpy())
  19. # 多维度下标取值
  20. params = tf.constant([
  21. [0, 0, 1, 0, 2],
  22. [3, 0, 0, 0, 4],
  23. [0, 5, 0, 6, 0]])
  24. indices = tf.constant([
  25. [2, 4],
  26. [0, 4],
  27. [1, 3]])
  28. pprint(tf.gather(params, indices, axis=1, batch_dims=1).numpy())
  29. #################################################################
  30. a = tf.random.normal([4, 35, 8])
  31. pprint(tf.gather(a, axis=1, indices=[2, 3, 7, 9, 16]).shape) # axis=1就是第2个维度的变化
  32. pprint(tf.gather(a, axis=2, indices=[2, 3, 7]).shape) # axis=2就是最里面的维度,所以是[4,35,3]
  33. #################################################################
  34. # array([[b'c0', b'd0'],
  35. # [b'a1', b'b1']], dtype=object)
  36. result = tf.gather_nd(indices=[[0, 1], [1, 0]],
  37. params=[[['a0', 'b0'], ['c0', 'd0']],
  38. [['a1', 'b1'], ['c1', 'd1']]]).numpy()
  39. pprint(result)
  40. # array([b'b0', b'b1'], dtype=object),由外向内
  41. pprint(tf.gather_nd(indices=[[0, 0, 1], [1, 0, 1]],
  42. params=[[['a0', 'b0'], ['c0', 'd0']],
  43. [['a1', 'b1'], ['c1', 'd1']]]).numpy())
  44. # array([[[[b'a1', b'b1'],
  45. # [b'c1', b'd1']]],
  46. # [[[b'a0', b'b0'],
  47. # [b'c0', b'd0']]]], dtype=object)
  48. pprint(tf.gather_nd(indices=[[[1]], [[0]]],
  49. params=[[['a0', 'b0'], ['c0', 'd0']],
  50. [['a1', 'b1'], ['c1', 'd1']]]).numpy())
tf.boolean_mask()数据过滤的方法:
  1. # 根据布尔值筛选值
  2. tensor = [0, 1, 2, 3]
  3. mask = np.array([True, False, True, False]) # 位置对应
  4. pprint(tf.boolean_mask(tensor, mask))
  5. #################################
  6. tensor = [[1, 2], [3, 4], [5, 6]]
  7. mask = np.array([True, False, True])
  8. pprint(tf.boolean_mask(tensor, mask)) # [[1,2],[5,6]]
  9. tensor = tf.random.normal([3, 4])
  10. mask = np.array([True, False, True])
  11. pprint('-----1', tf.boolean_mask(tensor, mask).shape) # shape=(2,4)
  12. tensor = tf.random.normal([4, 28, 28, 3])
  13. mask = np.array([True, True, False, False]) # 4维中取前组数据,所以输同是(2,28,28,3)
  14. pprint('-----2', tf.boolean_mask(tensor, mask).shape)
  15. # 在axis=3这个轴进行取值,[4,28,28,2]
  16. pprint('-----3', tf.boolean_mask(tensor, mask=[True, True, False], axis=3).shape)
  17. # 生成的数据是(3,4)
  18. pprint('-----4', tf.boolean_mask(tf.ones([2, 3, 4]), mask=[[True, False, False], [False, True, True]]).shape)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/332393
推荐阅读
相关标签
  

闽ICP备14008679号