当前位置:   article > 正文

使用`amp`进行GPU运算优化的学习笔记(会出现nan)_runtimeerror: torch.nn.functional.binary_cross_ent

runtimeerror: torch.nn.functional.binary_cross_entropy and torch.nn.bceloss

1 前言:不推荐,容易出现nan

经过一段时间的探索,我们认为:amp训练是不推荐的,容易出现nan
在debug过程中,我们发现在网络的前向运算中出现了nan,这一点已经在PyTorchForums提交了issue_nan_amp,目前还没有得到maintainer的回复;
该问题具体的表现:在前向运算中出现nan,但是1. 当前算子的参数不包含nan;2.并且输入数据不含有nan,目前无法判断出现nan的原因;
测试代码如下:

...
# forward pass
with torch.cuda.amp.autocast(self.use_amp):
    x = self.backbone(img)
    x = self.neck(x)

    if self.qa.check_nan:
    	# 检查输入x是否含有nan
        assert not x.isnan().any()
        # 检查dict_feats["layer2"]是否含有nan
        assert not self.dict_feats["layer2"].isnan().any()
        for p in self.decoder2.parameters():
            assert not p.isnan().any()

    if isinstance(self.decoder2, FuseDecoder):
        encode_data = self.decoder2(x, self.dict_feats["layer2"])
    else:
        raise NotImplementedError

if self.qa.check_nan:
    assert not encode_data.isnan().any()
...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

经过实验,出现以下的提示信息:

in train_epochs
output = self.model(batch_img)
File “/home/user/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.10/site-packages/torch/nn/modules/module.py”, line 1110, in _call_impl
return forward_call(*input, **kwargs)
File “/home/…net.py”, line 331, in forward
assert not encode_data.isnan().any()
AssertionError

可以看到,encode_data中出现nan,说明数值计算过程不是很稳定;

调试记录

“x = self.bn(x)”: BatchNorm2d

提示信息:
在这里插入图片描述
可以看到是在146行出现了错误,
代码截图:
在这里插入图片描述
是在BN层出现了错误,这里感觉BN层容易出现数值溢出(目前暂时没有什么比较好的解决方案);

2 容易出现nan的算子BatchNorm2d

这里我们来记录一下容易出现nan的算子:nn.BatchNorm2d

3 AMP训练:torch.cuda.amp

示例代码:

# amp依赖Tensor core架构,所以model参数必须是cuda tensor类型
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# GradScaler对象用来自动做梯度缩放
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # 在autocast enable 区域运行forward
        with autocast():
            # model做一个FP16的副本,forward
            output = model(input)
            loss = loss_fn(output, target)
        # 用scaler,scale loss(FP16),backward得到scaled的梯度(FP16)
        scaler.scale(loss).backward()
        # scaler 更新参数,会先自动unscale梯度
        # 如果有nan或inf,自动跳过
        scaler.step(optimizer)
        # scaler factor更新
        scaler.update()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

4 使用装饰器指定不使用amp的模块:keep_forward_float_()

def keep_forward_float_(m):
    def float_forward(self, x, forward):
        assert isinstance(self, nn.Module)
        with autocast(enabled=False):
            return forward(x.float())  
            # x.float()指将输入数据转换为fp32类型

    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

5 使用amp和GradientAccumulation联合进行优化

scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / iters_to_accumulate # 看看这个是否可以省略

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

3 Troubleshooting

3.1 RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.

运行时出现错误提示:

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are safe to autocast.

由提示信息可知,torch规定无法在autocast作用域中使用nn.BCELoss(reduction="none"),于是,需要在代码中单独声明在计算BCE损失时不使用autocast,示例代码如下:

with autocast(enabled=False):
    bce = self.bce_loss(output_map.float(), target_map.float())
  • 1
  • 2

Note:
autocast(enabled=False)作用域中引用的tensor需要使用其float()版本;请参考torch官方示例amp_force_float32

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/245657
推荐阅读
相关标签
  

闽ICP备14008679号