当前位置:   article > 正文

Whisper语音识别 -- 自回归解码分析

Whisper语音识别 -- 自回归解码分析

前言

Whisper 是由 OpenAI 开发的一种先进语音识别系统。它采用深度学习技术,能够高效、准确地将语音转换为文本。Whisper 支持多种语言和口音,并且在处理背景噪音和语音变异方面表现出色。其广泛应用于语音助手、翻译服务、字幕生成等领域,为用户提供了更流畅的语音交互体验。作为一个开源项目,Whisper 鼓励开发者和研究人员进一步优化和创新。
在这里插入图片描述
作者将解码过程整理成 简单的python代码进行讲解

核心思想

whisper解码核心是 基于自回归解码的token游戏 ,换句话说他的参数读取是通过传入token id的形式,即采用大语言模型的prompt范式(whisper的解码器一定程度上也是个大语言模型,虽然语音训练样本token数远不及纯文本token数)
h
图中除了识别结果的框框大多数都是prompt工程, 常用的token id 如图:
在这里插入图片描述

自回归解码

在这里插入图片描述

详细解释放在代码中啦

def main():
    
    """
        解码器须构建Deocder的prompt,序列为【SOT,语种,任务】, 本文中是 model.sot_sequence
        其中SOT:50258
        语种:50332,50309,50333,50335,50273,...
        任务:transcribe 转写 50359, translate 翻译 50358
    """


    """
                加载whisper模型
    """
    encoder_onnx_file = './small-encoder.int8.onnx'
    decoder_onnx_file = './small-decoder.int8.onnx'
    tokenizer_file = './small-tokens.txt'
    model = OnnxModel(encoder_onnx_file, decoder_onnx_file)
    token_table = load_tokenizer(tokenizer_file) # token id to char 


    """
                提取MEL特征
    """
    wav_file = "output.wav"
    mel = compute_features(wav_file)


    """
                计算encoder的K/V编码 
    """
    # 交叉注意力 encoder:K/V, with decoder:Q
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
    # 自注意力 decoder:K/V, with decoder:Q
    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()


    """
                检测语种
    """
    lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
    model.sot_sequence[1] = lang


    """
                任务选择
    """
    # task = model.translate
    task = model.transcribe
    model.sot_sequence[2] = task
    
    
    """
                根据prompt进行首次解码
    """
    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    offset += len(model.sot_sequence)
    logits = logits[0, -1] # token 声学后验
    model.suppress_tokens(logits, is_initial=True) # 无效token后验抑制



    """
                自回归解码
    """
    max_token_id = logits.argmax(dim=-1) # 选择后验中最大输出的token【贪心解码】
    results = []
    sentence = {'start':0,'end':0,'text':b""} 
    sentences = []
    for i in range(model.n_text_ctx):

        # 打印token属性
        if max_token_id.item() == model.sot:
            print("iter:%8s docode token id:%8s [sot]"%(i,max_token_id.item()))
        elif max_token_id.item() == model.eot:
            print("iter:%8s docode token id:%8s [eot]"%(i,max_token_id.item()))
        elif max_token_id.item() >= model.timestamp_begin:
            print("iter:%8s docode token id:%8s [boundary]"%(i,max_token_id.item()))
        else:
            print("iter:%8s docode token id:%8s [char]"%(i,max_token_id.item()))
        
        # eot 结束
        if max_token_id.item() == model.eot:
            print("Finish !!")
            break

        # 检测到时间戳
        if max_token_id.item()>=model.timestamp_begin:
            timestamp = ((max_token_id.item()-model.timestamp_begin)*model.time_precision)
            # 遇到结束符
            if sentence['text']:
                sentence['end'] = timestamp
                sentence['text'] = sentence['text'].decode().strip()
                print(sentence)
                sentences.append(sentence)
                sentence = {'start':0,'end':0,'text':b""}
            # 遇到开始符
            else:
                sentence['start'] = timestamp
        else:
            decode_token = base64.b64decode(token_table[max_token_id.item()])
            sentence['text'] += decode_token


        results.append(max_token_id.item())
        tokens = torch.tensor([[results[-1]]])
        # deocder 单步解码
        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        offset += 1
        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1) # 贪心搜索
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127

没错连时间戳也是token形式~,下面是运行结果感受一下。我们在边界处对句子进行保存
在这里插入图片描述

以上就是whisper解码的基本原理,感兴趣的同学关注走一波

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/723930
推荐阅读
相关标签
  

闽ICP备14008679号