赞
踩
首先,反向传播或许被称为“反向求导”更加合适,因为它只是个求导的过程,即计算中间参数的梯度。在PyTorch中,通过loss.backward()
进行反向求导,关于loss.backward()
有两点需要注意:【1】loss是标量(零维张量),只有标量才能直接使用 backward()
;【2】loss.backward()
的完整写法是loss.backward(retain_graph=False)
,其中的形参retain_graph
的意义在于是否保留计算图,默认为False,即在反向传播后,自动释放当前计算图,节省资源的同时,为下一次反向传播做好准备。
一般情况下是每次迭代,只需一次 forward()
和一次 backward()
。但是不排除,由于自定义loss
等有多个,网络需要计算多个不同loss
的backward()
产生的梯度,来更新参数。于是,如果在当前loss.backward()
后,还需要执行其他loss
的backward()
,那么就需要在当前的loss.backward()
时,指定保留计算图,即loss.backward(retain_graph=True)
。
需要特别注意的是,反向传播求导与网络参数更新是两个不同的过程。必须要先反向求导,再进行参数更新,其代码逻辑大致如下:
loss.backward() # 先在构建好的计算图进行反向传播,计算中间变量的梯度,计算完后,立即释放计算图
optimizer.step() # 根据计算好的梯度,进行网络的参数更新
计算图可以说是输入变量一直到输出变量的逻辑运算关系,是模型前向forward()
和后向求梯度backward()
的流程参照。这里需要注意的是,能获取回传梯度(grad)的只有计算图的叶节点(即输入节点,在loss.backward()后)。中间节点的梯度在计算求取并回传之后就会被释放掉,没办法获取。想要获取中间节点梯度,可以使用 register_hook (钩子)函数工具。当然, register_hook 不仅仅只有这个作用。
如上所述,,能获取回传梯度(grad)的只有计算图的叶节点(即输入节点,在loss.backward()后)。中间节点的梯度在计算求取并回传之后就会被释放掉,没办法获取。想要获取中间节点梯度,可以使用 register_hook (钩子)函数工具。
在编写Pytorch的训练代码时,下面一段代码是非常常见的:
for i in range(batch): # 在每个Batch都执行如下操作
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
在这里明确一下zero_grad()
函数的作用以及为什么需要在每个batch都执行该操作:根据pytorch中backward()
函数的计算,当进行求导时,梯度是累积计算而不是被替换,但在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()
将当前可变参数的梯度置0。
当然,我们也可以不选择每个batch都清除一次梯度,比如两次或多次再清除一次,这样相当于提高了batch_size,对GPU的内存需要更高。
detach()
,如果 x
为中间输出,x_1 = x.detach
表示创建一个与 x 相同,但requires_grad==False
的新Tensor (相当于是把x_1
以前的计算图 grad_fn 都消除了),x_1
也就成了叶节点(输入节点)。原先反向传播时,回传到x
时还会继续,而现在回到x_1
处后,就结束了,不继续回传求导了,在x之前的网络参数也就不再进行更新了。另外值得注意的是detach_()
表示不创建新张量,而是直接修改 x
本身。
如果对于张量 x
,如果 x.requires_grad == True
, 则表示它可以参与求导,也可以从它向后求导。默认情况下,一个新的Tensor的 requires_grad
为False
。
可以向后求导的意思是说,requires_grad == True
具有传递性,如果:
x.requires_grad == True
y.requires_grad == False
z = f(x,y)
则z.requires_grad == True
,注意requires_grad == False
则不具有传递性,在PyTorch中,凡是参与运算的变量(包括输入、输出、中间输出、网络权重参数等),都可以设置requires_grad
。但是一般来说,输入是否含requires_grad=True
是无所谓的,因为需要更新的是网络权重的参数,如果通过nn方法调用卷积层等,会默认其中的参数的requires_grad=True
。当然,设置输入的requires_grad=True
会更保险,但是可能会带来一些计算和内存上的代价。
volatile
在PyTorch1.0
版本后已经被移除了。实际上,volatile==True
就等价于 requires_grad==False
。 volatile==True
同样具有传递性。一般只用在推理过程中。若是从某个中间输出x
张量 开始都只需做推理,而不需反传梯度的话,那么只需设置x.volatile=True
,那么 x
以后的运算过程得到的输出张量均为 volatile==Tru
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。