当前位置:   article > 正文

torchmetrics,一个无敌的 Python 库!

torchmetrics

21380339fdd5986086353c6c0e81b880.png

更多Python学习内容:ipengtao.com

大家好,今天为大家分享一个无敌的 Python 库 - torchmetrics。

Github地址:https://github.com/Lightning-AI/torchmetrics


在深度学习和机器学习项目中,模型评估是一个至关重要的环节。为了准确地评估模型的性能,开发者通常需要计算各种指标(metrics),如准确率、精确率、召回率、F1 分数等。torchmetrics 是一个用于 PyTorch 的开源库,提供了一组方便且高效的评估指标计算工具。本文将详细介绍 torchmetrics 库,包括其安装方法、主要特性、基本和高级功能,以及实际应用场景,帮助全面了解并掌握该库的使用。

安装

要使用 torchmetrics 库,首先需要安装它。可以通过 pip 工具方便地进行安装。

以下是安装步骤:

pip install torchmetrics

安装完成后,可以通过导入 torchmetrics 库来验证是否安装成功:

  1. import torchmetrics
  2. print("torchmetrics 库安装成功!")

特性

  1. 广泛的指标支持:提供多种评估指标,包括分类、回归、图像处理和生成模型等领域的常用指标。

  2. 模块化设计:指标可以像模块一样轻松集成到 PyTorch Lightning 或任何 PyTorch 项目中。

  3. GPU 加速:支持 GPU 加速,能够高效处理大规模数据。

  4. 易于扩展:用户可以自定义指标并轻松集成到现有项目中。

  5. 高效计算:优化的计算方法,确保在训练过程中实时计算指标,性能开销最小。

基本功能

计算准确率

使用 torchmetrics 库,可以方便地计算分类任务的准确率。

  1. import torch
  2. import torchmetrics
  3. # 创建 Accuracy 指标
  4. accuracy = torchmetrics.Accuracy()
  5. # 模拟预测和真实标签
  6. preds = torch.tensor([0213])
  7. target = torch.tensor([0123])
  8. # 计算准确率
  9. acc = accuracy(preds, target)
  10. print(f"准确率:{acc}")

计算精确率和召回率

torchmetrics 库可以计算分类任务的精确率和召回率。

  1. import torch
  2. import torchmetrics
  3. # 创建 Precision 和 Recall 指标
  4. precision = torchmetrics.Precision(num_classes=4)
  5. recall = torchmetrics.Recall(num_classes=4)
  6. # 模拟预测和真实标签
  7. preds = torch.tensor([0213])
  8. target = torch.tensor([0123])
  9. # 计算精确率和召回率
  10. prec = precision(preds, target)
  11. rec = recall(preds, target)
  12. print(f"精确率:{prec}")
  13. print(f"召回率:{rec}")

计算 F1 分数

torchmetrics 库还可以计算分类任务的 F1 分数。

  1. import torch
  2. import torchmetrics
  3. # 创建 F1 指标
  4. f1 = torchmetrics.F1(num_classes=4)
  5. # 模拟预测和真实标签
  6. preds = torch.tensor([0213])
  7. target = torch.tensor([0123])
  8. # 计算 F1 分数
  9. f1_score = f1(preds, target)
  10. print(f"F1 分数:{f1_score}")

高级功能

自定义指标

torchmetrics 库允许用户自定义指标,以满足特定需求。

  1. import torch
  2. import torchmetrics
  3. class CustomMetric(torchmetrics.Metric):
  4.     def __init__(self):
  5.         super().__init__()
  6.         self.add_state("sum"default=torch.tensor(0), dist_reduce_fx="sum")
  7.         self.add_state("count"default=torch.tensor(0), dist_reduce_fx="sum")
  8.     def update(self, preds: torch.Tensor, target: torch.Tensor):
  9.         self.sum += torch.sum(preds == target)
  10.         self.count += target.numel()
  11.     def compute(self):
  12.         return self.sum.float() / self.count
  13. # 创建自定义指标
  14. custom_metric = CustomMetric()
  15. # 模拟预测和真实标签
  16. preds = torch.tensor([0213])
  17. target = torch.tensor([0123])
  18. # 计算自定义指标
  19. result = custom_metric(preds, target)
  20. print(f"自定义指标结果:{result}")

与 PyTorch Lightning 集成

torchmetrics 库可以无缝集成到 PyTorch Lightning 中,简化指标计算流程。

  1. import torch
  2. import torchmetrics
  3. import pytorch_lightning as pl
  4. from torch import nn
  5. class LitModel(pl.LightningModule):
  6.     def __init__(self):
  7.         super().__init__()
  8.         self.model = nn.Linear(104)
  9.         self.accuracy = torchmetrics.Accuracy()
  10.     def forward(self, x):
  11.         return self.model(x)
  12.     def training_step(self, batch, batch_idx):
  13.         x, y = batch
  14.         preds = self(x)
  15.         loss = nn.functional.cross_entropy(preds, y)
  16.         acc = self.accuracy(preds, y)
  17.         self.log('train_acc', acc)
  18.         return loss
  19.     def configure_optimizers(self):
  20.         return torch.optim.Adam(self.parameters(), lr=0.001)
  21. # 示例数据
  22. train_data = torch.utils.data.TensorDataset(torch.randn(10010), torch.randint(04, (100,)))
  23. train_loader = torch.utils.data.DataLoader(train_data, batch_size=32)
  24. # 训练模型
  25. model = LitModel()
  26. trainer = pl.Trainer(max_epochs=5)
  27. trainer.fit(model, train_loader)

GPU 加速

torchmetrics 库支持 GPU 加速,可以在 GPU 上高效地计算指标。

  1. import torch
  2. import torchmetrics
  3. # 创建 Accuracy 指标并移动到 GPU
  4. accuracy = torchmetrics.Accuracy().cuda()
  5. # 模拟预测和真实标签并移动到 GPU
  6. preds = torch.tensor([0213]).cuda()
  7. target = torch.tensor([0123]).cuda()
  8. # 计算准确率
  9. acc = accuracy(preds, target)
  10. print(f"准确率:{acc}")

实际应用场景

图像分类任务中的指标计算

在图像分类任务中,需要计算各种评估指标,如准确率、精确率、召回率等。

  1. import torch
  2. import torchmetrics
  3. import torchvision.models as models
  4. import torchvision.transforms as transforms
  5. from torchvision.datasets import CIFAR10
  6. from torch.utils.data import DataLoader
  7. # 加载数据
  8. transform = transforms.Compose([transforms.ToTensor()])
  9. train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
  10. train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
  11. # 创建模型和指标
  12. model = models.resnet18(num_classes=10)
  13. accuracy = torchmetrics.Accuracy()
  14. # 训练模型并计算准确率
  15. for inputs, targets in train_loader:
  16.     outputs = model(inputs)
  17.     acc = accuracy(outputs, targets)
  18.     print(f"批次准确率:{acc}")

文本分类任务中的指标计算

在文本分类任务中,需要计算评估指标,如 F1 分数。

  1. import torch
  2. import torchmetrics
  3. from transformers import BertTokenizer, BertForSequenceClassification
  4. # 加载模型和分词器
  5. tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  6. model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
  7. # 示例数据
  8. texts = ["I love this!""This is bad."]
  9. labels = torch.tensor([10])
  10. # 预处理数据
  11. inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
  12. outputs = model(**inputs)
  13. # 创建 F1 指标
  14. f1 = torchmetrics.F1(num_classes=2)
  15. # 计算 F1 分数
  16. preds = torch.argmax(outputs.logits, dim=1)
  17. f1_score = f1(preds, labels)
  18. print(f"F1 分数:{f1_score}")

生成对抗网络(GAN)中的指标计算

在生成对抗网络(GAN)的训练中,需要计算生成图片的质量指标,如 Frechet Inception Distance(FID)。

  1. import torch
  2. import torchmetrics
  3. from torchvision.models import inception_v3
  4. from torchvision.transforms import transforms
  5. from torch.utils.data import DataLoader, TensorDataset
  6. # 创建生成对抗网络(GAN)的生成器模型
  7. class Generator(torch.nn.Module):
  8.     def __init__(self):
  9.         super(Generator, self).__init__()
  10.         self.fc = torch.nn.Linear(100128 * 7 * 7)
  11.         self.deconv = torch.nn.Sequential(
  12.             torch.nn.ConvTranspose2d(128644, stride=2, padding=1),
  13.             torch.nn.BatchNorm2d(64),
  14.             torch.nn.ReLU(True),
  15.             torch.nn.ConvTranspose2d(6414, stride=2, padding=1),
  16.             torch.nn.Tanh()
  17.         )
  18.     def forward(self, x):
  19.         x = self.fc(x).view(-112877)
  20.         return self.deconv(x)
  21. # 创建生成器模型
  22. generator = Generator()
  23. # 创建 FID 指标
  24. fid = torchmetrics.image.fid.FrechetInceptionDistance(feature=64)
  25. # 模拟生成图片和真实图片
  26. latent_vectors = torch.randn(100100)
  27. generated_images = generator(latent_vectors)
  28. real_images = torch.randn(10012828)
  29. # 转换图片为 Inception V3 输入格式
  30. transform = transforms.Compose([
  31.     transforms.Resize((299299)),
  32.     transforms.Normalize(mean=[0.5], std=[0.5])
  33. ])
  34. generated_images = transform(generated_images)
  35. real_images = transform(real_images)
  36. # 创建 DataLoader
  37. generated_loader = DataLoader(TensorDataset(generated_images), batch_size=32)
  38. real_loader = DataLoader(TensorDataset(real_images), batch_size=32)
  39. # 计算 FID
  40. for gen_batch, real_batch in zip(generated_loader, real_loader):
  41.     fid.update(real_batch[0], gen_batch[0])
  42. fid_value = fid.compute()
  43. print(f"FID 分数:{fid_value}")

总结

torchmetrics 库是一个功能强大且易于使用的评估指标计算工具,能够帮助开发者在深度学习和机器学习项目中高效地计算各种评估指标。通过支持广泛的指标、多种计算模式、GPU 加速和自定义扩展,torchmetrics 库能够满足各种复杂的评估需求。本文详细介绍了 torchmetrics 库的安装方法、主要特性、基本和高级功能,以及实际应用场景。希望本文能帮助大家全面掌握 torchmetrics 库的使用,并在实际项目中发挥其优势。

如果你觉得文章还不错,请大家 点赞、分享、留言 下,因为这将是我持续输出更多优质文章的最强动力!


如果想要系统学习Python、Python问题咨询,或者考虑做一些工作以外的副业,都可以扫描二维码添加微信,围观朋友圈一起交流学习。

b6e9de34fef3fa667b741f36ce203968.gif

我们还为大家准备了Python资料和副业项目合集,感兴趣的小伙伴快来找我领取一起交流学习哦!

a49ac0374b3473b26fe3180dd7267e58.jpeg

往期推荐

历时一个月整理的 Python 爬虫学习手册全集PDF(免费开放下载)

Python基础学习常见的100个问题.pdf(附答案)

学习 数据结构与算法,这是我见过最友好的教程!(PDF免费下载)

Python办公自动化完全指南(免费PDF)

Python Web 开发常见的100个问题.PDF

肝了一周,整理了Python 从0到1学习路线(附思维导图和PDF下载)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/974265
推荐阅读
相关标签
  

闽ICP备14008679号