当前位置:   article > 正文

Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法_torch detach

torch detach
detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来
在这里插入图片描述
需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度
在这里插入图片描述

import torch as t

a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

那么这个函数有什么作用?
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()

b = B(a)
loss = criterion(b, target)
loss.backward()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True

x = x.detach()   #分离之后
x.requires_grad   #False

y = x+y      	  #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行

z = t.pow(y, 2)
z.backward()    #反向传播

y.grad        #tensor([4.])
x.grad        #None
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None


既然谈到了修改模型的权重问题,那么还有一种情况是:
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
	param.requires_grad = False

a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/160115
推荐阅读
相关标签
  

闽ICP备14008679号