当前位置:   article > 正文

与量化相关的pytoch准备工作_深度学习中地hook函数 需要remove吗

深度学习中地hook函数 需要remove吗

hook函数

为了节省显存(内存),PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除(remove),以避免每次都运行钩子增加运行负载。

这里总结一下并给出实际用法和注意点。hook方法有4种:

1、Tensor.register_hook()
2、torch.nn.Module.register_forward_hook()
3、torch.nn.Module.register_backward_hook()
4、torch.nn.Module.register_forward_pre_hook()
在这里插入图片描述

1.Tensor.register_hook(hook):对于单个张量,可以使用register_hook()方法注册一个hook。该方法将一个函数(即hook)注册到张量上,在张量被计算时调用该函数。这个函数可以用来获取张量的梯度或值,或者对张量进行其他操作。例如,以下代码演示了如何使用register_hook()方法获取张量的梯度:

import torch

x = torch.randn(2, 2, requires_grad=True)

def print_grad(grad):
    print(grad)

hook_handle = x.register_hook(print_grad)

y = x.sum()

y.backward()

hook_handle.remove()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

在这个例子中,我们创建了一个包含梯度的张量x,并使用register_hook()方法注册了一个打印梯度的hook函数print_grad()。在计算y的梯度时,hook函数被调用并打印梯度。最后,我们使用hook_handle.remove()方法从张量中删除hook函数。

在这里插入图片描述

2.torch.nn.Module.register_forward_hook(hook):对于模型中的每个层,可以使用register_forward_hook()方法注册一个hook函数。这个函数将在模型的前向传递中被调用,并可以用来获取层的输出。例如,以下代码演示了如何使用register_forward_hook()方法获取模型的某一层的输出:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        return x

def hook(module, input, output):
    print(output.shape)

model = MyModel()
handle = model.conv2.register_forward_hook(hook)

x = torch.randn(1, 3, 224, 224)
output = model(x)

handle.remove()


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
      
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
   		 pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87

在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_forward_hook()方法在第二个卷积层上注册了一个hook函数。当模型进行前向传递时,hook函数将被调用并打印输出张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。

在这里插入图片描述

torch.nn.Module.register_backward_hook(hook):对于模型中的每个层,可以使用register_backward_hook()方法注册一个hook函数。这个函数将在模型的反向传递中被调用,并可以用来获取梯度或其他信息。例如,以下代码演示了如何使用register_backward_hook()方法获取某一层的梯度

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        return x

def hook(module, grad_input, grad_output):
    print(grad_input[0].shape, grad_output[0].shape)

model = MyModel()
handle = model.conv2.register_backward_hook(hook)

x = torch.randn(1, 3, 224,
output = model(x)
output.sum().backward()

handle.remove()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_backward_hook()方法在第二个卷积层上注册了一个hook函数。当模型进行反向传递时,hook函数将被调用并打印输入梯度和输出梯度张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。

在这里插入图片描述

  1. torch.nn.Module.register_forward_pre_hook(hook):对于模型中的每个层,可以使用register_forward_pre_hook()方法注册一个hook函数。这个函数将在模型的前向传递之前被调用,并可以用来获取输入张量或其他信息。例如,以下代码演示了如何使用register_forward_pre_hook()方法获取模型的输入张量:
import torch.nn as nn

class MyModel(nn.Module):
def init(self):
super(MyModel, self).init()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

def forward(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.relu(x)
    x = self.avgpool(x)
    [return x](poe://www.poe.com/_api/key_phrase?phrase=return%20x&prompt=Tell%20me%20more%20about%20return%20x.)
def hook(module, input):
print(input[0].shape)

model = MyModel()
handle = model.conv1.register_forward_pre_hook(hook)

x = torch.randn(1, 3, 224, 224)
output = model(x)

handle.remove()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

在这个例子中,我们定义了一个包含两个卷积层和一个全局平均池化层的模型,并使用register_forward_pre_hook()方法在第一个卷积层上注册了一个hook函数。当模型进行前向传递时,hook函数将被调用并打印输入张量的形状。最后,我们使用handle.remove()方法从模型中删除hook函数。

需要注意的是,hook函数应该尽可能快地执行,以避免对模型的计算时间造成过多的影响。此外,如果注册了太多的hook函数,会导致额外的内存占用和计算负担。因此,应该仔细考虑何时需要注册hook函数,并在使用后及时删除它们。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/750842
推荐阅读
  

闽ICP备14008679号