赞
踩
由于上一篇博客,我把torchAudio中的wav2vec2样例加上自己理解发了出来,这次我们就来看看torchaudio中,wav2vec2.0的模型是怎么创建的。
博主也在边写博客边看这源码学习,理解得不一定对,有错希望大佬指出~
观前提示:
首先我们看回这个样例的一行和创建模型很相关的代码。
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
我在理解中,这句话就等于创建一个wav2vec2语音识别的模型工厂了。我们看看里面是什么。
可以看到引用的bundle在pipelines初始化文件中。
这个文件主要引用了文件夹_wav2vec2中的文件impl.py的数据类。
可以看出WAV2VEC2_ASR_960H
等于创建了一个Wav2Vec2ASRBundle
数据类。这个数据类里的数据,看着就是特征提取和分类的模型构造参数。然后还有最终的字母标签获取_labels
和采样率设置_sample_rate
。
我们继续递归,看看这个Wav2Vec2ASRBundle
数据类的结构:
就这点东西。Wav2Vec2ASRBundle
数据类继承了Wav2Vec2Bundle
数据类。然后自己新增了标签属性_labels
,和一个不知什么东西的位置_remove_aux_axis
。里面还定义了get_labels方法和_get_state_dict方法
然后我们再看看被Wav2Vec2ASRBundle
类继承的Wav2Vec2Bundle
类是什么:
里面有模型路径_path
、模型参数_param
、采样率_sample_rate
,以及获取采样率的方法sample_rate、获取状态字典方法_get_state_dict、获取模型方法get_model。
(博主表示不知道状态字典的作用是什么,后面看到了再进行修改,或者有大佬指点也可以,感谢)
在我写了样例博客的单元块三中有一行代码,如是写道:
model = bundle.get_model().to(device)
可以看出,他就是调用了Wav2Vec2Bundle
类中的get_model方法。通过get_model方法能直接建立好我们的wav2vec2模型。因此,我们仔细看看get_model方法的代码。
我们点进wav2vec2_model方法看看
它跳到了models文件夹中的model.py文件。
我们发现传参就是我们的模型结构参数。
然后做了特征提取工作feature_extractor
和预训练transformer模型创建工作encoder
,以及下面的线性转换aux
。最后返回一个Wav2Vec2Models回去。
而aux_num_out代表什么,博主不知道,博主表示压根不知道aux代表什么东西。因此希望各位好学者推测推测或者大牛给个答案?
今天就先水到这了,明天再接着看。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。