当前位置:   article > 正文

【深度学习笔记】打印模型的FLOPS,使用python的thop库_如何安装thop库

如何安装thop库

常用的库是 thop,它是一个用于计算PyTorch模型的FLOPs和参数数量的库。

1.安装 thop 库:

pip install thop

2 导入需要计算 FLOPs 的模型和 profile 函数:.

  1. import torch
  2. from torchvision.models import resnet18
  3. from thop import profile

3. 创建模型并准备输入数据:

  1. model = resnet18()
  2. input_data = torch.randn(1, 28, 256, 256)

我这里是跟随训练一起写的,要调用GPU,所以是如下代码:

  1. # 创建模型并将其放在GPU上
  2. model = resnet18().cuda()
  3. # 创建输入数据并将其放在GPU上
  4. input_data = torch.randn(1, 28, 256, 256).cuda()

4.使用 profile 函数计算 FLOPs:

  1. flops, params = profile(model, inputs=(input_data,))
  2. print(f"FLOPs: {flops}, Params: {params}")

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/974319
推荐阅读
相关标签
  

闽ICP备14008679号