赞
踩
detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来
需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度
import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
那么这个函数有什么作用?
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法
a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
来看一个实际的例子:
import torch as t x = t.ones(1, requires_grad=True) x.requires_grad #True y = t.ones(1, requires_grad=True) y.requires_grad #True x = x.detach() #分离之后 x.requires_grad #False y = x+y #tensor([2.]) y.requires_grad #我还是True y.retain_grad() #y不是叶子张量,要加上这一行 z = t.pow(y, 2) z.backward() #反向传播 y.grad #tensor([4.]) x.grad #None
以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None
既然谈到了修改模型的权重问题,那么还有一种情况是:
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。