赞
踩
Pytorch 中的主要的数据结构包括标量、向量、矩阵、张量,同时支持数据之间的运算。在 Pytorch 中有一个张量广播的概念,就是要把小的放大,最后在一起做计算,并不是所有的张量都可以计算,规则如下
不同维度:
# 3, 2, 2
tensor_a = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]]]) # Shape (3, 2, 2)
# 2,2
tensor_b = torch.tensor([[0.1, 0.2],
[0.3, 0.4]])
result = tensor_a + tensor_b
结果
[
[
[ 1.1000, 2.2000], [ 3.3000, 4.4000]
],
[
[ 5.1000, 6.2000], [ 7.3000, 8.4000]
],
[
[ 9.1000, 10.2000], [11.3000, 12.4000]
]
]
广播方便了对张量的操作,例如我们想生成一张绿色北京的图,初始一张图,把绿色通道变成 255。
import torch import matplotlib.pyplot as plt width, height = 256, 256 green_color = torch.tensor([0, 255, 0], dtype=torch.uint8) green_image = torch.zeros((height, width, 3), dtype=torch.uint8) + green_color green_image_np = green_image.numpy() plt.figure(figsize=(6, 6)) plt.imshow(green_image_np) plt.axis('off') plt.title('Green') plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。