当前位置:   article > 正文

Pytorch学习笔记:hook操作——提取特征、梯度等信息_pytorch hook 提取特征

pytorch hook 提取特征

介绍

功能:主要用于提取中间变量,同时也可以做修改等操作。

常用的相关函数方法

# 提取数据的梯度
torch.Tensor.register_hook(hook)
# 提取模型的中间特征数据
torch.nn.Module.register_forward_hook(hook)
# 提取网络层中的梯度
torch.nn.Module.register_full_backward_hook(hook)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

  PyTorch在每一次运算结束后都会释放中间变量,从而节省内存空间,例如释放模型中间得到的特征数据、反向传播过程中的梯度等等,因此就有了hook方法,可以操作中间变量,如保存梯度、保存中间特征数据,也可以对中间变量做修改,如增大梯度、限制梯度范围等等,核心在于hook函数的定义。

定义hook:

# register_hook
hook(grad) -> Tensor or None
# register_forward_hook
hook(module, input, output) -> None or modified output
# register_full_backward_hook
# 一般只利用grad_output,提取模块输出元素的梯度
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

数据梯度

  利用torch.Tensor.register_hook(hook)方法实现,计算数据在做反向传播时的梯度

代码案例

以下面的公式为例:
z = 1 4 ∑ i = 1 4 y i , y i = x i 2 z=\frac14\sum_{i=1}^4y_i,\quad y_i=x_i^2 z=41i=14yi,yi=xi2

import torch


def grad_hook_x(grad):
    # 只传入梯度这一个变量
    x_grad.append(grad)


def grad_hook_y(grad):
    y_grad.append(grad)


torch.manual_seed(0)
y_grad = []
x_grad = []
x = torch.rand(4, requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
y.register_hook(grad_hook_y)
x.register_hook(grad_hook_x)
z.backward()
print(x)
print("x grad: ", x_grad[0])
print("y grad: ", y_grad[0])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

输出,相当于对x和y上的梯度做了保存

# 输入x
tensor([0.4963, 0.7682, 0.0885, 0.1320], requires_grad=True)
# x上的梯度
x grad:  tensor([0.2481, 0.3841, 0.0442, 0.0660])
# y上的梯度
y grad:  tensor([0.2500, 0.2500, 0.2500, 0.2500])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

注:将z.backward()改为z.backward(retain_graph=True)也可以实现储存梯度的功能

修改梯度

  如果想要修改梯度,则只需要修改hook函数,如下面案例,此时y上的梯度是原来的两倍,将会影响x的参数更新(更新幅度变大)

def grad_hook_y(grad):
    return grad * 2
  • 1
  • 2

网络中间特征

  利用torch.nn.Module.register_forward_hook(hook)方法实现,实现提取特征数据的功能。

注:尽量不要在这里修改特征数据,容易出问题,最好直接去网络结构里面改。

代码案例

网络结构

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self, input_size, out_size, middle_size=None):
        super().__init__()
        if not middle_size:
            middle_size = input_size // 2
        self.conv1 = nn.Conv2d(input_size, middle_size, 3)
        self.conv2 = nn.Conv2d(middle_size, middle_size, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(middle_size, out_size)
        self.middle_size = middle_size

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.pool(x2).view(-1, self.middle_size)
        x4 = self.fc(x3)
        return x4
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

调用方法

def forward_hook(module, inputs, outputs):
    # 传入模块、模块输入、模块输出三种参数
    feature_map_inputs.append(inputs)
    feature_map_outputs.append(outputs)


torch.manual_seed(0)
feature_map_inputs = []
feature_map_outputs = []

net = Net(4, 2, 3)
net.conv1.register_forward_hook(forward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
out1.backward(retain_graph=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

输出
在这里插入图片描述

网络梯度

  利用torch.nn.Module.register_full_backward_hook(hook)方法实现,实现提取特征数据的梯度功能、也可以修改梯度。

注:

  • 在提取梯度时,最好加一个.detach()方法,切断梯度,防止后续操作对网络反向传播有影响;
  • 模块存在多个输入输出时,backward_hook()中的inputsoutputs均为元组类型。

代码案例

卷积模块

网络结构还是之前定义的结构

import torch
from torch import nn


def backward_hook(module, inputs, outputs):
    # 元组类型,常利用[0]提取梯度数据
    grad_inputs.append(inputs[0].detach())
    grad_outputs.append(outputs[0].detach())


torch.manual_seed(0)
grad_inputs = []
grad_outputs = []

net = Net(4, 2, 3)
net.conv2.register_backward_hook(backward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True表明在反向传播时保存梯度
out1.backward(retain_graph=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

输出

在这里插入图片描述

注:

  • grad_inputs表示模块输入参数的梯度,梯度尺寸和输入特征图的尺寸相同;
  • grad_outputs表示模块输出参数的梯度,同上,梯度尺寸和输出特征图的尺寸相同。

全连接模块

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self, input_size, out_size, middle_size=None):
        super().__init__()
        if not middle_size:
            middle_size = input_size // 2
        self.fc1 = nn.Linear(input_size, middle_size)
        self.fc2 = nn.Linear(middle_size, out_size)

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x1)

        return x2

    
def backward_hook(module, inputs, outputs):
    grad_inputs.append(inputs[0].detach())
    grad_outputs.append(outputs[0].detach())


torch.manual_seed(0)
grad_inputs = []
grad_outputs = []
net = Net(6, 2, 3)
net.fc2.register_backward_hook(backward_hook)
data = torch.rand((1, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True,目的保留梯度
out1.backward(retain_graph=True)
print(grad_inputs, grad_outputs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

输出

在这里插入图片描述

注:

  • grad_inputs表示模块输入参数的梯度,梯度尺寸和输入特征的尺寸相同;
  • grad_outputs表示模块输出参数的梯度,梯度尺寸和输出特征的尺寸相同。

官方文档

register_hook:https://pytorch.org/docs/1.2.0/tensors.html#torch.Tensor.register_hook

register_forward_hook:https://pytorch.org/docs/1.2.0/nn.html#torch.nn.Module.register_forward_hook

register_full_backward_hook:https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_full_backward_hook#torch.nn.Module.register_full_backward_hook

注:以上内容仅是笔者个人见解,若有错误,欢迎指正。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/365870
推荐阅读
相关标签
  

闽ICP备14008679号