当前位置:   article > 正文

炼丹5至7倍速,使用Mac M1 芯片加速pytorch完全指南_mac炼丹

mac炼丹

2022年5月,PyTorch官方宣布已正式支持在M1芯片版本的Mac上进行模型加速。官方对比数据显示,和CPU相比,M1上炼丹速度平均可加速7倍。

哇哦,不用单独配个GPU也能加速这么多,我迫不及待地搞到一个M1芯片的MacBook后试水了一番,并把我认为相关重要的信息梳理成了本文。

公众号后台回复关键词:M1,可获取本文jupyter notebook源代码。

一,加速原理

  • Question1,Mac M1芯片 为什么可以用来加速 pytorch?

因为 Mac M1芯片不是一个单纯的一个CPU芯片,而是包括了CPU(中央处理器),GPU(图形处理器),NPU(神经网络引擎),以及统一内存单元等众多组件的一块集成芯片。由于Mac M1芯片集成了GPU组件,所以可以用来加速pytorch.

  • Question2,Mac M1芯片 上GPU的的显存有多大?

Mac M1芯片的CPU和GPU使用统一的内存单元。所以Mac M1芯片的能使用的显存大小就是 Mac 电脑的内存大小。

  • Question3,使用Mac M1芯片加速 pytorch 需要安装 cuda后端吗?

不需要,cuda是适配nvidia的GPU的,Mac M1芯片中的GPU适配的加速后端是mps,在Mac对应操作系统中已经具备,无需单独安装。只需要安装适配的pytorch即可。

  • Question4,为什么有些可以在Mac Intel芯片电脑安装的软件不能在Mac M1芯片电脑上安装?

Mac M1芯片为了追求高性能和节能,在底层设计上使用的是一种叫做arm架构的精简指令集,不同于Intel等常用CPU芯片采用的x86架构完整指令集。所以有些基于x86指令集开发的软件不能直接在Mac M1芯片电脑上使用。

0840849789142d559d4ba2b384649a36.jpeg

二,环境配置

0,检查mac型号

点击桌面左上角mac图标——>关于本机——>概览,确定是m1芯片,了解内存大小(最好有16G以上,8G可能不太够用)。

a6b9131995a40865fd3366c4d6025feb.jpeg

1,下载 miniforge3 (miniforge3可以理解成 miniconda/annoconda 的社区版,提供了更稳定的对M1芯片的支持)

https://github.com/conda-forge/miniforge/#download

ca7fd2933d6be2c21e2e647ac6c122dd.jpeg

备注: annoconda 在 2022年5月开始也发布了对 mac m1芯片的官方支持,但还是推荐社区发布的miniforge3,开源且更加稳定。

2,安装 miniforge3

  1. chmod +x ~/Downloads/Miniforge3-MacOSX-arm64.sh
  2. sh ~/Downloads/Miniforge3-MacOSX-arm64.sh
  3. source ~/miniforge3/bin/activate

3,安装 pytorch (v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端。)

pip install torch>=1.12 -i https://pypi.tuna.tsinghua.edu.cn/simple

4,测试环境

  1. import torch 
  2. print(torch.backends.mps.is_available()) 
  3. print(torch.backends.mps.is_built())

如果输出都是True的话,那么恭喜你配置成功了。

三,范例代码

下面以mnist手写数字识别为例,演示使用mac M1芯片GPU的mps后端来加速pytorch的完整流程。

核心操作非常简单,和使用cuda类似,训练前把模型和数据都移动到torch.device("mps")就可以了。

  1. import torch 
  2. from torch import nn 
  3. import torchvision 
  4. from torchvision import transforms 
  5. import torch.nn.functional as F 
  6. import os,sys,time
  7. import numpy as np
  8. import pandas as pd
  9. import datetime 
  10. from tqdm import tqdm 
  11. from copy import deepcopy
  12. from torchmetrics import Accuracy
  13. def printlog(info):
  14.     nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  15.     print("\n"+"=========="*8 + "%s"%nowtime)
  16.     print(str(info)+"\n")
  17.     
  18.     
  19. #================================================================================
  20. # 一,准备数据
  21. #================================================================================
  22. transform = transforms.Compose([transforms.ToTensor()])
  23. ds_train = torchvision.datasets.MNIST(root="mnist/",train=True,download=True,transform=transform)
  24. ds_val = torchvision.datasets.MNIST(root="mnist/",train=False,download=True,transform=transform)
  25. dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=2)
  26. dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2)
  27. #================================================================================
  28. # 二,定义模型
  29. #================================================================================
  30. def create_net():
  31.     net = nn.Sequential()
  32.     net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=64,kernel_size = 3))
  33.     net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
  34.     net.add_module("conv2",nn.Conv2d(in_channels=64,out_channels=512,kernel_size = 3))
  35.     net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
  36.     net.add_module("dropout",nn.Dropout2d(p = 0.1))
  37.     net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
  38.     net.add_module("flatten",nn.Flatten())
  39.     net.add_module("linear1",nn.Linear(512,1024))
  40.     net.add_module("relu",nn.ReLU())
  41.     net.add_module("linear2",nn.Linear(1024,10))
  42.     return net
  43. net = create_net()
  44. print(net)
  45. # 评估指标
  46. class Accuracy(nn.Module):
  47.     def __init__(self):
  48.         super().__init__()
  49.         self.correct = nn.Parameter(torch.tensor(0.0),requires_grad=False)
  50.         self.total = nn.Parameter(torch.tensor(0.0),requires_grad=False)
  51.     def forward(self, preds: torch.Tensor, targets: torch.Tensor):
  52.         preds = preds.argmax(dim=-1)
  53.         m = (preds == targets).sum()
  54.         n = targets.shape[0
  55.         self.correct += m 
  56.         self.total += n
  57.         
  58.         return m/n
  59.     def compute(self):
  60.         return self.correct.float() / self.total 
  61.     
  62.     def reset(self):
  63.         self.correct -= self.correct
  64.         self.total -= self.total
  65.         
  66. #================================================================================
  67. # 三,训练模型
  68. #================================================================================     
  69. loss_fn = nn.CrossEntropyLoss()
  70. optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   
  71. metrics_dict = nn.ModuleDict({"acc":Accuracy()})
  72. # =========================移动模型到mps上==============================
  73. device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
  74. net.to(device)
  75. loss_fn.to(device)
  76. metrics_dict.to(device)
  77. # ====================================================================
  78. epochs = 20 
  79. ckpt_path='checkpoint.pt'
  80. #early_stopping相关设置
  81. monitor="val_acc"
  82. patience=5
  83. mode="max"
  84. history = {}
  85. for epoch in range(1, epochs+1):
  86.     printlog("Epoch {0} / {1}".format(epoch, epochs))
  87.     # 1,train -------------------------------------------------  
  88.     net.train()
  89.     
  90.     total_loss,step = 0,0
  91.     
  92.     loop = tqdm(enumerate(dl_train), total =len(dl_train),ncols=100)
  93.     train_metrics_dict = deepcopy(metrics_dict) 
  94.     
  95.     for i, batch in loop: 
  96.         
  97.         features,labels = batch
  98.         
  99.         # =========================移动数据到mps上==============================
  100.         features = features.to(device)
  101.         labels = labels.to(device)
  102.         # ====================================================================
  103.         
  104.         #forward
  105.         preds = net(features)
  106.         loss = loss_fn(preds,labels)
  107.         
  108.         #backward
  109.         loss.backward()
  110.         optimizer.step()
  111.         optimizer.zero_grad()
  112.             
  113.         #metrics
  114.         step_metrics = {"train_"+name:metric_fn(preds, labels).item() 
  115.                         for name,metric_fn in train_metrics_dict.items()}
  116.         
  117.         step_log = dict({"train_loss":loss.item()},**step_metrics)
  118.         total_loss += loss.item()
  119.         
  120.         step+=1
  121.         if i!=len(dl_train)-1:
  122.             loop.set_postfix(**step_log)
  123.         else:
  124.             epoch_loss = total_loss/step
  125.             epoch_metrics = {"train_"+name:metric_fn.compute().item() 
  126.                              for name,metric_fn in train_metrics_dict.items()}
  127.             epoch_log = dict({"train_loss":epoch_loss},**epoch_metrics)
  128.             loop.set_postfix(**epoch_log)
  129.             for name,metric_fn in train_metrics_dict.items():
  130.                 metric_fn.reset()
  131.                 
  132.     for name, metric in epoch_log.items():
  133.         history[name] = history.get(name, []) + [metric]
  134.         
  135.     # 2,validate -------------------------------------------------
  136.     net.eval()
  137.     
  138.     total_loss,step = 0,0
  139.     loop = tqdm(enumerate(dl_val), total =len(dl_val),ncols=100)
  140.     
  141.     val_metrics_dict = deepcopy(metrics_dict) 
  142.     
  143.     with torch.no_grad():
  144.         for i, batch in loop: 
  145.             features,labels = batch
  146.             
  147.             # =========================移动数据到mps上==============================
  148.             features = features.to(device)
  149.             labels = labels.to(device)
  150.             # ====================================================================
  151.             
  152.             #forward
  153.             preds = net(features)
  154.             loss = loss_fn(preds,labels)
  155.             #metrics
  156.             step_metrics = {"val_"+name:metric_fn(preds, labels).item() 
  157.                             for name,metric_fn in val_metrics_dict.items()}
  158.             step_log = dict({"val_loss":loss.item()},**step_metrics)
  159.             total_loss += loss.item()
  160.             step+=1
  161.             if i!=len(dl_val)-1:
  162.                 loop.set_postfix(**step_log)
  163.             else:
  164.                 epoch_loss = (total_loss/step)
  165.                 epoch_metrics = {"val_"+name:metric_fn.compute().item() 
  166.                                  for name,metric_fn in val_metrics_dict.items()}
  167.                 epoch_log = dict({"val_loss":epoch_loss},**epoch_metrics)
  168.                 loop.set_postfix(**epoch_log)
  169.                 for name,metric_fn in val_metrics_dict.items():
  170.                     metric_fn.reset()
  171.                     
  172.     epoch_log["epoch"] = epoch           
  173.     for name, metric in epoch_log.items():
  174.         history[name] = history.get(name, []) + [metric]
  175.     # 3,early-stopping -------------------------------------------------
  176.     arr_scores = history[monitor]
  177.     best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
  178.     if best_score_idx==len(arr_scores)-1:
  179.         torch.save(net.state_dict(),ckpt_path)
  180.         print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
  181.              arr_scores[best_score_idx]),file=sys.stderr)
  182.     if len(arr_scores)-best_score_idx>patience:
  183.         print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
  184.             monitor,patience),file=sys.stderr)
  185.         break 
  186.     net.load_state_dict(torch.load(ckpt_path))
  187.     
  188. dfhistory = pd.DataFrame(history)

四,使用torchkeras支持Mac M1芯片加速

我在最新的3.3.0的torchkeras版本中引入了对 mac m1芯片的支持,当存在可用的 mac m1芯片/ GPU 时,会默认使用它们进行加速,无需做任何配置。

使用范例如下。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/376469
推荐阅读
相关标签