当前位置:   article > 正文

【PyTorch】PyTorch中张量(Tensor)拼接和拆分操作_pytorch 张亮拼接

pytorch 张亮拼接

PyTorch深度学习总结

第四章 PyTorch中张量(Tensor)拼接和拆分操作



前言

上文介绍了PyTorch中张量(Tensor)的切片操作,本文主要介绍张量的拆分拼接操作。


一、张量拼接

函数描述
torch.cat()将张量按照指定维度关系进行拼接
torch.stack()将张量按照指定维度关系进行拼接(用法同cat相同
# 引入库
import torch

# 创建张量
A = torch.arange(9).reshape(1, 3, 3)
print(A)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

输出结果为:
tensor(
[[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])


1、按照维度1进行拼接:

B0 = torch.cat((A, A), dim=0)
print(B0)
  • 1
  • 2

输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])


1、按照维度2()进行拼接:

B1 = torch.cat((A, A), dim=2)
print(B1)
  • 1
  • 2

输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])


1、按照维度3()进行拼接:

B2 = torch.cat((A, A), dim=2)
print(B2)
  • 1
  • 2

输出结果为:
tensor([[[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5],
[6, 7, 8, 6, 7, 8]]])

二、张量拆分

函数描述
torch.chunk()将张量分割为特定数量的块(当张量对应维度元素数量不足以拆分时会按照可以拆分数量进行拆分,且会出现不均等拆分情况)
torch.split()将张量分割为特定数量的块,可以指定块的大小

注意:
torch.chunk():当张量对应维度元素数量不足以拆分时,会按照可以拆分的最大数量进行拆分,且会出现不均等拆分情况,且最后一个块最小

下文使用B0进行示例

B0 = tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]],
        [[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

1、torch.chunk()按照维度1进行拆分:

C1, C2 = torch.chunk(B0, 2, dim=1) # 维度1只有三组元素,所以会按照2:1的比例进行拆分
print(C1, C2)
  • 1
  • 2

输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])


1、torch.chunk()按照维度2进行拆分:

D1, D2 = torch.chunk(B0, 2, dim=1) # 3表示指定拆分数,但由于不足以拆分,所以只会拆分两组
print(D1, D2)
  • 1
  • 2

输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]])
tensor([[[6, 7, 8]],
[[6, 7, 8]]])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/127557
推荐阅读
相关标签
  

闽ICP备14008679号