赞
踩
功能:主要用于提取中间变量,同时也可以做修改等操作。
常用的相关函数方法
# 提取数据的梯度
torch.Tensor.register_hook(hook)
# 提取模型的中间特征数据
torch.nn.Module.register_forward_hook(hook)
# 提取网络层中的梯度
torch.nn.Module.register_full_backward_hook(hook)
PyTorch在每一次运算结束后都会释放中间变量,从而节省内存空间,例如释放模型中间得到的特征数据、反向传播过程中的梯度等等,因此就有了hook方法,可以操作中间变量,如保存梯度、保存中间特征数据,也可以对中间变量做修改,如增大梯度、限制梯度范围等等,核心在于hook
函数的定义。
定义hook:
# register_hook
hook(grad) -> Tensor or None
# register_forward_hook
hook(module, input, output) -> None or modified output
# register_full_backward_hook
# 一般只利用grad_output,提取模块输出元素的梯度
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
利用torch.Tensor.register_hook(hook)
方法实现,计算数据在做反向传播时的梯度
以下面的公式为例:
z
=
1
4
∑
i
=
1
4
y
i
,
y
i
=
x
i
2
z=\frac14\sum_{i=1}^4y_i,\quad y_i=x_i^2
z=41i=1∑4yi,yi=xi2
import torch
def grad_hook_x(grad):
# 只传入梯度这一个变量
x_grad.append(grad)
def grad_hook_y(grad):
y_grad.append(grad)
torch.manual_seed(0)
y_grad = []
x_grad = []
x = torch.rand(4, requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
y.register_hook(grad_hook_y)
x.register_hook(grad_hook_x)
z.backward()
print(x)
print("x grad: ", x_grad[0])
print("y grad: ", y_grad[0])
输出,相当于对x和y上的梯度做了保存
# 输入x
tensor([0.4963, 0.7682, 0.0885, 0.1320], requires_grad=True)
# x上的梯度
x grad: tensor([0.2481, 0.3841, 0.0442, 0.0660])
# y上的梯度
y grad: tensor([0.2500, 0.2500, 0.2500, 0.2500])
注:将z.backward()
改为z.backward(retain_graph=True)
也可以实现储存梯度的功能
如果想要修改梯度,则只需要修改hook函数,如下面案例,此时y上的梯度是原来的两倍,将会影响x的参数更新(更新幅度变大)
def grad_hook_y(grad):
return grad * 2
利用torch.nn.Module.register_forward_hook(hook)
方法实现,实现提取特征数据的功能。
注:尽量不要在这里修改特征数据,容易出问题,最好直接去网络结构里面改。
import torch
from torch import nn
class Net(nn.Module):
def __init__(self, input_size, out_size, middle_size=None):
super().__init__()
if not middle_size:
middle_size = input_size // 2
self.conv1 = nn.Conv2d(input_size, middle_size, 3)
self.conv2 = nn.Conv2d(middle_size, middle_size, 3)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(middle_size, out_size)
self.middle_size = middle_size
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.pool(x2).view(-1, self.middle_size)
x4 = self.fc(x3)
return x4
def forward_hook(module, inputs, outputs):
# 传入模块、模块输入、模块输出三种参数
feature_map_inputs.append(inputs)
feature_map_outputs.append(outputs)
torch.manual_seed(0)
feature_map_inputs = []
feature_map_outputs = []
net = Net(4, 2, 3)
net.conv1.register_forward_hook(forward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)
out = net(data)
out1 = out[:, 0]
net.zero_grad()
out1.backward(retain_graph=True)
输出
利用torch.nn.Module.register_full_backward_hook(hook)
方法实现,实现提取特征数据的梯度功能、也可以修改梯度。
注:
.detach()
方法,切断梯度,防止后续操作对网络反向传播有影响;backward_hook()
中的inputs
和outputs
均为元组类型。网络结构还是之前定义的结构
import torch
from torch import nn
def backward_hook(module, inputs, outputs):
# 元组类型,常利用[0]提取梯度数据
grad_inputs.append(inputs[0].detach())
grad_outputs.append(outputs[0].detach())
torch.manual_seed(0)
grad_inputs = []
grad_outputs = []
net = Net(4, 2, 3)
net.conv2.register_backward_hook(backward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)
out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True表明在反向传播时保存梯度
out1.backward(retain_graph=True)
输出
注:
import torch
from torch import nn
class Net(nn.Module):
def __init__(self, input_size, out_size, middle_size=None):
super().__init__()
if not middle_size:
middle_size = input_size // 2
self.fc1 = nn.Linear(input_size, middle_size)
self.fc2 = nn.Linear(middle_size, out_size)
def forward(self, x):
x1 = self.fc1(x)
x2 = self.fc2(x1)
return x2
def backward_hook(module, inputs, outputs):
grad_inputs.append(inputs[0].detach())
grad_outputs.append(outputs[0].detach())
torch.manual_seed(0)
grad_inputs = []
grad_outputs = []
net = Net(6, 2, 3)
net.fc2.register_backward_hook(backward_hook)
data = torch.rand((1, 6), dtype=torch.float32)
out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True,目的保留梯度
out1.backward(retain_graph=True)
print(grad_inputs, grad_outputs)
输出
注:
register_hook:https://pytorch.org/docs/1.2.0/tensors.html#torch.Tensor.register_hook
register_forward_hook:https://pytorch.org/docs/1.2.0/nn.html#torch.nn.Module.register_forward_hook
register_full_backward_hook:https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_full_backward_hook#torch.nn.Module.register_full_backward_hook
注:以上内容仅是笔者个人见解,若有错误,欢迎指正。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。