当前位置:   article > 正文

Ubuntu22.04安装Whisper-jax_whisper jax

whisper jax

1、安装jax

1.1、前提条件

已经安装好了NVIDIA显卡驱动和CUDA。如果你还没安装,那么你可以参考我的这篇文章

jax是谷歌推出的深度学习框架

这里安装的是GPU版本的jax。

1.2、安装

源码地址:

https://github.com/google/jax
  • 1

官方安装教程

更新pip:

pip install --upgrade pip
  • 1

安装jax:

cuda 11

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  • 1

cuda 12

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  • 1

测试是否安装成功,可以参考这篇文章

import jax.numpy as np

from jax import random
import time

x = random.uniform(random.PRNGKey(0),[5000,5000])
st = time.time()
try:
	y=np.matmul(x,x)

except Exception:

	print("error")

print(time.time()-st)

print(y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

如果使用国外的源无法下载时,可以更换为国内的镜像,这里推荐更换为清华大学的镜像。

2、安装cuDNN

点击这里下载cuDNN

下载完成之后上传到服务器,然后解压cuDNN。

tar -xvf xxx.tar
  • 1

然后进入到解压后的目录。
然后复制到CUDA-12.1目录。

sudo cp include/* /usr/local/cuda-12.1/include
sudo cp lib/libcudnn* /usr/local/cuda-12.1/lib64
sudo chmod a+r /usr/local/cuda-12.1/include/cudnn*
sudo chmod a+r /usr/local/cuda-12.1/lib64/libcudnn*
  • 1
  • 2
  • 3
  • 4

查看cuDNN版本。

cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
  • 1

3、安装whisper-jax

源码:

https://github.com/sanchit-gandhi/whisper-jax
  • 1

安装:这里推荐在Anaconda的虚拟环境中安装。如何安装Anaconda,可以去参考我的这篇文章

pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
  • 1

注意不能使用国内pip源安装。

修改pip配置文件。

vim ~/.pip/pip.conf
  • 1

把国内镜像源注释掉。

验证jax是否使用GPU

import jax
print(jax.devices()[0])
  • 1
  • 2

如果返回gpu:0说明可以使用GPU。

4、whisper-jax的使用

whisper-jax没有提供命令行方式运行。

from whisper_jax import FlaxWhisperPipline

# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-small")

# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")

# used cached function thereafter - super fast!!
text = pipeline("audio.mp3")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

个人感觉whisper-jax更吃显卡性能,经过我在Tesla T4 15GB显卡上的测试,无法跑large模型。

转录的速度要比whisper更快。

4.1、调整精度

可以修改精度,加快转录速度。

对于A100显卡和TPU:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16)
  • 1
  • 2
  • 3
  • 4
  • 5

对于非A100显卡:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16)
  • 1
  • 2
  • 3
  • 4
  • 5

4.2、批处理

from whisper_jax import FlaxWhisperPipline

# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-small", batch_size=16)
  • 1
  • 2
  • 3
  • 4

4.3、输出带时间戳的文件

from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-small")

text = pipeline("audio.mp3", return_timestamps=True)

chunks = text["chunks"]

with open('output', 'w') as f:    
	for item in chunks:
      	f.write(str(item) + "\n")
	f.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

更多内容欢迎访问个人博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/288868
推荐阅读
相关标签
  

闽ICP备14008679号