当前位置:   article > 正文

提升 5-7 倍速,使用 Mac M1 芯片加速 Pytorch 完全指南_mac cuda

mac cuda

2022年5月,PyTorch官方宣布已正式支持在M1芯片版本的Mac上进行模型加速。官方对比数据显示,和CPU相比,M1上炼丹速度平均可加速7倍。

哇哦,不用单独配个GPU也能加速这么多,我迫不及待地搞到一个M1芯片的MacBook后试水了一番,并把我认为相关重要的信息梳理成了本文。

一,加速原理

  • Question1,Mac M1芯片 为什么可以用来加速 pytorch?

因为 Mac M1芯片不是一个单纯的一个CPU芯片,而是包括了CPU(中央处理器),GPU(图形处理器),NPU(神经网络引擎),以及统一内存单元等众多组件的一块集成芯片。由于Mac M1芯片集成了GPU组件,所以可以用来加速pytorch.

  • Question2,Mac M1芯片 上GPU的的显存有多大?

Mac M1芯片的CPU和GPU使用统一的内存单元。所以Mac M1芯片的能使用的显存大小就是 Mac 电脑的内存大小。

  • Question3,使用Mac M1芯片加速 pytorch 需要安装 cuda后端吗?

不需要,cuda是适配nvidia的GPU的,Mac M1芯片中的GPU适配的加速后端是mps,在Mac对应操作系统中已经具备,无需单独安装。只需要安装适配的pytorch即可。

  • Question4,为什么有些可以在Mac Intel芯片电脑安装的软件不能在Mac M1芯片电脑上安装?

Mac M1芯片为了追求高性能和节能,在底层设计上使用的是一种叫做arm架构的精简指令集,不同于Intel等常用CPU芯片采用的x86架构完整指令集。所以有些基于x86指令集开发的软件不能直接在Mac M1芯片电脑上使用。

二,环境配置

0,检查mac型号

点击桌面左上角mac图标——>关于本机——>概览,确定是m1芯片,了解内存大小(最好有16G以上,8G可能不太够用)。

1,下载 miniforge3 (miniforge3可以理解成 miniconda/annoconda 的社区版,提供了更稳定的对M1芯片的支持)

https://github.com/conda-forge/miniforge/#download

备注: annoconda 在 2022年5月开始也发布了对 mac m1芯片的官方支持,但还是推荐社区发布的miniforge3,开源且更加稳定。

2,安装 miniforge3

  1. chmod +x ~/Downloads/Miniforge3-MacOSX-arm64.sh
  2. sh ~/Downloads/Miniforge3-MacOSX-arm64.sh
  3. source ~/miniforge3/bin/activate

3,安装 pytorch (v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端。)

  1. pip install torch>=1.12 -i https://pypi.tuna.tsinghua.edu.cn/simple 

4,测试环境

  1. import torch 
  2. print(torch.backends.mps.is_available()) 
  3. print(torch.backends.mps.is_built())

如果输出都是True的话,那么恭喜你配置成功了。

三,范例代码

下面以mnist手写数字识别为例,演示使用mac M1芯片GPU的mps后端来加速pytorch的完整流程。

核心操作非常简单,和使用cuda类似,训练前把模型和数据都移动到torch.device("mps")就可以了。

  1. import torch 
  2. from torch import nn 
  3. import torchvision 
  4. from torchvision import transforms 
  5. import torch.nn.functional as F 
  6. import os,sys,time
  7. import numpy as np
  8. import pandas as pd
  9. import datetime 
  10. from tqdm import tqdm 
  11. from copy import deepcopy
  12. from torchmetrics import Accuracy
  13. def printlog(info):
  14.     nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  15.     print("\n"+"=========="*8 + "%s"%nowtime)
  16.     print(str(info)+"\n")
  17.     
  18.     
  19. #================================================================================
  20. # 一,准备数据
  21. #================================================================================
  22. transform = transforms.Compose([transforms.ToTensor()])
  23. ds_train = torchvision.datasets.MNIST(root="mnist/",train=True,download=True,transform=transform)
  24. ds_val = torchvision.datasets.MNIST(root="mnist/",train=False,download=True,transform=transform)
  25. dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=2)
  26. dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2)
  27. #======================================================================&
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/822807
推荐阅读
相关标签
  

闽ICP备14008679号