赞
踩
常用的库是
thop
,它是一个用于计算PyTorch模型的FLOPs和参数数量的库。
1.安装 thop
库:
pip install thop
2 导入需要计算 FLOPs 的模型和 profile
函数:.
- import torch
- from torchvision.models import resnet18
- from thop import profile
3. 创建模型并准备输入数据:
- model = resnet18()
- input_data = torch.randn(1, 28, 256, 256)
我这里是跟随训练一起写的,要调用GPU,所以是如下代码:
- # 创建模型并将其放在GPU上
- model = resnet18().cuda()
- # 创建输入数据并将其放在GPU上
- input_data = torch.randn(1, 28, 256, 256).cuda()
4.使用 profile
函数计算 FLOPs:
- flops, params = profile(model, inputs=(input_data,))
- print(f"FLOPs: {flops}, Params: {params}")
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。