当前位置:   article > 正文

Pytorch学习(3):Tensor合并、分割与基本运算

tensor合并


前言

Pytorch学习笔记第三篇,关于Tensor的合并(cat/stack)、分割(split/chunk)与基本运算。


一、合并Cat/Stack

1.Cat

Tensor中cat是contract的缩写,代表着两个张量Tensor在制定维度上进行合并,这就要求这两个张量Tensor在其余维度的长度一致。

代码如下(示例):

import torch
#1 cat
a=torch.rand(5,32,8)
b=torch.rand(4,32,8)
c=torch.cat([a,b],dim=0) #将a,b在0维上合并 ->[9.32.8]
  • 1
  • 2
  • 3
  • 4
  • 5

2.Stack

stack也可以用于Tensor的合并,但区别于cat,stack会在指定索引上创造一个新维度,因此stack要求原来的两个张量Tensor必须维度形状完全一致。
代码如下(示例):

#2 stack 创造一个新的维度
a=torch.rand(32,8)
b=torch.rand(32,8)
c=torch.stack([a,b],dim=0) #在0维增加新的维度->[2,32,8],其中c[0]=a,c[1]=b
#dim!=0时则是后续维度与前面Tensor的后续部分按维度匹配。
  • 1
  • 2
  • 3
  • 4
  • 5

二、分割Split/Chunk

1.Split

split对张量Tensor在指定维度上进行划分,split的划分是按照长度进行划分的,因此输入的参数为分割后各张量的长度。

代码如下(示例):

#3 split 按长度拆分
a=torch.rand(4,2,2)
aa,bb=a.split(2,dim=0) #dim=0上拆分为2个长度为2的张量,aa=a[:2],bb=a[2:]
aa,bb,cc=a.split([2,1,1],dim=0) #dim=0上拆分3个,aa=a[:2],bb=a[2],cc=a[3]
  • 1
  • 2
  • 3
  • 4

2.Chunk

chunk与split不同的在于,chunk需要指定的是分解后张量的个数,而非结果张量的长度
代码如下(示例):

#4 chunk 按个数拆分
a=torch.rand(8,2,2)
aa,bb=a.chunk(2,dim=0) #aa,bb->[4,2,2]
  • 1
  • 2
  • 3

三、基本运算

1.加减乘除

对应于torch中add、sub、mul、div,且已经重载为+、-、*、/,具有广播机制。

代码如下(示例):

#1 基本运算符+-*/ 对应元素运算,拥有广播机制
a=torch.rand(5,3)
b=torch.rand(3)
a+b
a-b
a*b
a/b
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2.矩阵乘法mm/@/matmul

矩阵乘法的运算分为mm、@、matmul。
其中mm只能作用于2维矩阵。
matmul可以作用于dim>=2的矩阵,其机理为最后两维做矩阵乘法,前面的维数保持不变或广播。
@是matmul的运算符重载,使用方便。
代码如下(示例):

#2 矩阵乘法 mm/matmul/@
c=torch.rand(5,3)
d=torch.rand(3,4)
e=torch.rand(784,3)

torch.mm(c,d) #只适用于dim=2,不建议使用
torch.matmul(c,d) #适用于任一情况
c@e.t() #运算符重载@为矩阵乘法

#dim>2时,最后两维做矩阵乘法@,前面的维数保持不变或广播
e=torch.rand(4,3,7,8)
f=torch.rand(4,1,8,9)
e@f #[4,3,7,9]
torch.matmul(e,f)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

3.幂运算**

幂运算可以采用pow、sqrt、rsqrt进行,也可以采用重载运算符**进行。
pow可以指定幂指数
sqrt求平方根
rsqrt求平方根的倒数
代码如下(示例):

#3 幂运算**/pow
a**2 #平方
a**(0.5) #开方
a.pow(2)
a.sqrt()
a.rsqrt()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

4.指数exp/对数log

exp求以e为底的指数结果。
log求自然对数ln运算结果。
log10、log2等为底数不同的对数运算
代码如下(示例):

#4 exp/log
torch.exp(a) #以e为底求幂
torch.log(a) #取对数ln
torch.log10(a) #取对数lg
  • 1
  • 2
  • 3
  • 4

5.近似floor/ceil/round/trunc/frac

对浮点Tensor进行近似。
floor:向下取整
ceil:向上取整
round:四舍五入
trunc:取整数部分
frac:取小数部分
代码如下(示例):

#5 近似floor/ceil/round/trunc/frac
g=torch.tensor(3.1415)
g.floor() #向下取整3.
g.ceil() #向上取整4.
g.trunc() #取整数部分3.(浮点)
g.frac() #取小数部分0.1415
g.roung() #四舍五入
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

6.裁剪(归化)clamp

clamp作用于张量每一个元素,将会指定范围,并将超出范围[min,max]的数据规范到min、max
代码如下(示例):

#6 clamp裁剪(将超出范围的数规范到min/max)
h=torch.rand(3,3)*10  #0-10之间随机浮点
h1=h.clamp(5)   #小于5的归化为5
h2=h.clamp(6,7) #小于6的归化为6,大于7的归化为7
  • 1
  • 2
  • 3
  • 4

总结

以上是Tensor的合并、分割与基本运算,下一篇计划为Tensor统计操作与高级操作。
2021.2.18

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/134014
推荐阅读
相关标签
  

闽ICP备14008679号