当前位置:   article > 正文

西工大 ASLP 实验室在 WeNet 中开源基于 CPPN 的神经网络热词增强语音识别方案_contextualized end-to-end speech recognition with

contextualized end-to-end speech recognition with contextual phrase predicti

语境偏置(Contextual biasing)旨在将语境知识集成到语音识别(ASR)系统中,以提高在相关领域词汇(俗称“热词”)上的识别准确率。在许多ASR场景中,待识别语音中可能会包含训练数据中数量很少或完全没出现的短语,例如一些领域专有名词、用户通讯录中的人名等,这些短语的识别准确程度对用户体验或下游任务的影响很大,但对于在通用数据上训练ASR系统来说又难以完全正确识别。因此语境偏置方法具有重要价值,旨在提升在这些“热词”上的识别准确率。

西工大音频与语音处理研究组(ASLP@NPU)近期在语音研究顶级会议 INTERSPEECH2023发表了论文题目为“Contextualized End-to-End Speech Recognition with Contextual Phrase Prediction Network”。该论文提出了一种基于热词短语预测网络的深度热词增强方法,该网络利用热词编码来预测语音中的热词,并通过计算偏置损失来辅助深度偏置模型的训练。所提出的方法能够应用在多种主流端到端ASR模型上,显著改进模型在热词数据上的识别准确率。

该方案的代码主要由该论文的一作黄凯勋同学完成,目前已开源在WeNet社区,详见:https://github.com/wenet-e2e/wenet/pull/1982

基于规则的热词增强

WeNet 之前实现的热词方案为基于规则的热词增强,即在解码过程中匹配到热词列表中的词时,给予该解码路径一定的奖励,这样含热词的搜索路径概率更大,更容易被识别出来。详见:

WeNet 热词增强 2.0 强势来袭

WeNet 更新:支持热词增强

虽然该方案在一般情况下也能有不错的处理能力,但对于训练中少见或没有见过的词即 rare word 的识别效果却非常很差。本质上因为该方法为浅融合(shadow fusion,或者后融合)的热词方案,在训练中没有见过,很难预测出来,在解码候选路径中根本没有出现,所以再怎么奖励也无济于事。而下文介绍的神经网络的热词则为深融合(deep fusion,或者前融合)则可以非常有效的解决 rare word 的问题,同时在常见热词识别也有更好的效果。

神经网络热词增强

基于神经网络的热词增强在端到端ASR模型中引入了一组独立的偏置模块来建模、集成热词信息,将热词增强的过程放到了神经网络模型推理的过程中。相比于传统的基于解码图的热词增强方案,神经网络热词增强使用上十分灵活,对于较小的热词列表可以取得优于解码图增强方案的效果,并且对于训练集中从未出现的罕见词也有着良好的增强效果。

近年来,神经网络热词增强方法得到了快速的发展。在2018年,谷歌首先提出了 CLAS 模型 [2],在 LAS 模型上通过全神经网络的方式集成热词信息。CLAS 的模型结构如下图所示,其中包含的两个偏置模块为:

  • 热词编码器(Bias Encoder):通过 RNN 对热词信息进行编码,将不等长的热词编码为相同维度的热词嵌入。

  • 热词偏置层(Attention):包含一个多头注意力层,使用音频编码作为 Query,热词编码作为 Key 和 Value,查询与当前音频相关的热词,提取热词感知编码。

图片

后续的神经网络热词增强研究基本沿袭自 CLAS,分为“编码热词、查询音频相关热词并集成信息”两步。在其他 ASR 模型上,提出了适用于 RNNT [3] 以及 CIF [4] 模型的神经网络热词增强模型,在神经网络增强方法上也产生了许多变体,如 TCPGen [5] 和 NAM [6] 等。

CPPN 论文介绍

之前的神经网络热词增强研究大多探索在如 RNNT 等 ASR 模型上实现或提升热词增强的效果,缺少在 AED 模型上有效的纯神经网络热词增强方法。因此,我们最初的动机就是寻找 AED 模型上有效的神经网络热词增强方案,并且由于我们使用 WeNet 框架进行实验,基于 attention rescore 解码依靠 CTC 后验的特点,我们希望能够在 Encoder 部分就能进行神经网络热词增强。

在实验最初阶段,我们尝试将一个基础的包含热词编码器以及热词偏置层的热词增强模块加入 AED 模型的 Encoder 中进行训练,但是模型却完全没有学习到热词增强的能力,推理过程中加入热词信息几乎不会对结果产生任何影响。我们分析认为这很可能是由于对 AED 模型而言,热词增强模块更难学习到热词编码与音频编码之间的联系。如果仅依靠 ASR 的损失函数进行训练,在没有对热词增强任务给出明确的监督的情况下,模型就会倾向于直接放弃热词增强模块的输出。因此我们研究的重心便转移到了寻找一个对于热词增强任务有效的监督损失上。

最直观的想法就是让热词增强模块直接预测出语句中包含的热词,但热词模块的输入输出长度与音频编码长度相同,这就存在编码长度与语句中包含的热词短语长度之间的不匹配。我们实验发现 CTC 损失可以很好的解决这个问题,并且由于 WeNet AED 模型的实现中 Encoder 本身也依靠 CTC 损失进行辅助训练,在预测热词任务上同样使用 CTC 损失可以让热词增强模块更好的契合 Encoder 的输出。在此基础上,我们提出了基于热词短语预测网络(Contextual Phrase Prediction Network, CPPN)的热词语音识别方案。热词增强模型结构如下图所示。

图片

图(c)是我们在 AED 模型基础上实现的热词增强模型结构,由于仅对模型 Encoder 部分进行了修改,因此没有画出 Decoder 部分的结构,这样的热词增强方案同样适用于仅包含一个 Encoder 的 CTC 模型。在 AED 热词增强模型中,热词增强相关的模块有:

  • 热词编码器(Context Encoder):我们使用 BLSTM 作为热词编码器对热词短语进行编码。

  • 热词偏置层(Biasing Layer):如图(b)所示,查询与当前音频相关的热词,提取热词感知编码。

  • 组合器(Combiner):拼接音频编码与热词感知编码,并通过一个线性层集成信息。

  • 热词短语预测网络(CPP Network):如图(a)所示,包含两个线性层,其中第二个线性层与 Encoder 的 CTC 线性层共享参数。热词短语预测网络使用热词感知编码作为输入,对语句中包含的热词短语进行预测,并使用 CTC 损失监督训练。例如对于语句标签为“我想去西工大学习语音识别”,如果“西工大”和“语音识别”是热词,则热词短语预测网络的标签为“西工大语音识别”。

图(d)是在 RNNT 模型基础上实现的热词增强模型结构,我们在 CATT [7] 热词增强方案的基础上添加了热词短语预测网络,相比于 AED 模型,在标签编码器(Label Encoder)后也添加了热词偏置层。

基于神经网络的热词增强方案还存在对于热词短语列表的大小较为敏感的问题,由于模型依靠注意力机制查找与音频关联的热词短语,当给出的热词列表过大时,注意力就会较为分散,影响热词增强的效果。因此,我们引入了一种两阶段的热词短语筛选算法[8],首先在不计算热词增强模块的情况下计算模型encoder部分的输出,使用 CTC 后验分别计算后验和置信度(PSC)和序列顺序置信度(SOC)以筛选出最可能出现在音频中的热词,然后构建成更小的热词列表再进行神经网络热词增强。

更多方案细节可参考推文:

Interspeech2023 | 基于热词短语预测网络的热词语音识别

模型实现

  1. 大部分热词增强模块的实现都位于 transformer/context_module.py 中,我们在 ContextModule 类中分别声明了热词编码器、热词偏置层、组合器、热词短语预测网络,以及增强任务的 CTC 损失。相比于论文中的实现,在此处我们将组合器替换为了更加简单的通过一个线性层后直接与音频编码相加的形式,这样能使 ASR 模型与热词模块更加解耦,并且经测试效果基本没有区别。并且我们还额外设置了一个热词增强权重用于控制热词增强的程度,热词感知编码先会与该权重相乘后再与音频编码相加。

  1. class ContextModule(torch.nn.Module):
  2. """Context module, Using context information for deep contextual bias
  3. During the training process, the original parameters of the ASR model
  4. are frozen, and only the parameters of context module are trained.
  5. Args:
  6. vocab_size (int): vocabulary size
  7. embedding_size (int): number of ASR encoder projection units
  8. encoder_layers (int): number of context encoder layers
  9. attention_heads (int): number of heads in the biasing layer
  10. """
  11. def __init__(
  12. self,
  13. vocab_size: int,
  14. embedding_size: int,
  15. encoder_layers: int = 2,
  16. attention_heads: int = 4,
  17. dropout_rate: float = 0.0,
  18. ):
  19. super().__init__()
  20. self.embedding_size = embedding_size
  21. self.encoder_layers = encoder_layers
  22. self.vocab_size = vocab_size
  23. self.attention_heads = attention_heads
  24. self.dropout_rate = dropout_rate
  25. # 热词编码器
  26. self.context_extractor = BLSTM(self.vocab_size, self.embedding_size,
  27. self.encoder_layers)
  28. self.context_encoder = nn.Sequential(
  29. nn.Linear(self.embedding_size * 4, self.embedding_size),
  30. nn.LayerNorm(self.embedding_size)
  31. )
  32. # 热词偏置层
  33. self.biasing_layer = MultiHeadedAttention(
  34. n_head=self.attention_heads,
  35. n_feat=self.embedding_size,s
  36. dropout_rate=self.dropout_rate
  37. )
  38. # 组合器
  39. self.combiner = nn.Linear(self.embedding_size, self.embedding_size)
  40. self.norm_aft_combiner = nn.LayerNorm(self.embedding_size)
  41. # 热词短语预测网络
  42. self.context_decoder = nn.Sequential(
  43. nn.Linear(self.embedding_size, self.embedding_size),
  44. nn.LayerNorm(self.embedding_size),
  45. nn.ReLU(inplace=True),
  46. )
  47. self.context_decoder_ctc_linear = nn.Linear(self.embedding_size,
  48. self.vocab_size)
  49. # 热词 CTC 损失
  50. self.bias_loss = torch.nn.CTCLoss(reduction="sum", zero_infinity=True)
  1. 我们将热词短语筛选部分的代码放到了 /utils/context_graph.py 中实现,该文件原本负责实现基于解码图的热词增强,包含有解码时所用的热词列表的信息,因此适合在此处进行热词短语筛选并将 tensor 形式的热词列表传递给热词增强模块。

  1. class ContextGraph:
  2. def get_context_list_tensor(self, context_list: List[List[int]]):
  3. """Add 0 as no-bias in the context list and obtain the tensor
  4. form of the context list
  5. """
  6. def two_stage_filtering(self,
  7. context_list: List[List[int]],
  8. ctc_posterior: torch.Tensor,
  9. filter_window_size: int = 64):
  10. """Calculate PSC and SOC for context phrase filtering,
  11. refer to: https://arxiv.org/abs/2301.06735
  12. """
  1. 训练时热词采样的代码位于 /dataset/processor.py 的 context_sampling 与 context_label_generate 方法中。在训练时,我们从每条话语中随机抽取数个长度随机的短语作为该条话语的热词,每个 batch 抽取出的热词短语再加上一些干扰短语(来自之前 batch 的热词)共同构成该 batch 的热词列表。

  1. def context_sampling(data,
  2. symbol_table,
  3. len_min,
  4. len_max,
  5. utt_num_context,
  6. batch_num_context,
  7. ):
  8. """Perform context sampling by randomly selecting context phrases from the
  9. utterance to obtain a context list for the entire batch
  10. Args:
  11. data: Iterable[List[{key, feat, label}]]
  12. Returns:
  13. Iterable[List[{key, feat, label, context_list}]]
  14. """
  15. def context_label_generate(label, context_list):
  16. """ Generate context labels corresponding to the utterances based on
  17. the context list
  18. """

实验结果

我们在 Librispeech test other 测试集上进行测试,使用 [9] 中给出的热词列表,热词列表可以从该链接获取:https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias

我们使用 WeNet 开源的 Librispeech 预训练 AED 模型进行实验,冻结原本 ASR 模型的参数,仅对热词增强模块训练 30 epoch 后取 3 个 epoch 平均得到的模型作为结果。在训练时,需要关闭 use_dynamic_chunk 流式训练选项,开启该选项会使得热词增强效果变差。经过测试非流式训练得到的模型对非流式和流式解码具有基本一致的热词增强性能。训练好的热词增强模型可以从该链接下载:https://huggingface.co/kxhuang/Wenet_Librispeech_deep_biasing/tree/main

除了词错误率 (WER) 之外,我们分别使用 U-WER 和 B-WER 来评估非热词部分与热词部分的字错误率,实验的结果如下。

非流式推理:

MethodList sizeGraph scoreBiasing scoreWERU-WERB-WER
baseline///8.775.5836.84
context graph38383.0/7.755.8324.62
CPPN3838/1.57.935.9225.64
context graph + CPPN38382.01.07.666.0821.48
context graph1003.0/7.325.4523.70
CPPN100/2.07.085.3322.41
context graph + CPPN1002.51.56.555.3317.27

表中 context graph 表示 WeNet 中基于解码图的热词增强方法。在使用较大热词列表(3838)的情况下,基于解码图的方法效果较好,而使用较小热词列表(100)的情况下,CPPN 的增强效果则更好。而无论热词列表大小,以合理的增强权重和分数同时使用两种方法时都能够取得比单一方法更好的效果。

热词列表(3838)的情况下,基于解码图的方法效果较好,而使用较小热词列表(100)的情况下,CPPN 的增强效果则更好。而无论热词列表大小,以合理的增强权重和分数同时使用两种方法时都能够取得比单一方法更好的效果。

小热词列表实验同时还表现出了非热词部分 WER 下降的情况,这是由于经过热词增强,模型在波束搜索时能够选择到包含正确热词并且非热词部分更加准确的路径。

流式推理 (chunk 16):

MethodList sizeGraph scoreBiasing scoreWERU-WERB-WER
baseline///10.477.0740.30
context graph1003.0/9.066.9927.21
CPPN100/2.08.866.8726.28
context graph + CPPN1002.51.58.176.8519.72

在流式推理的结果中,使用CPPN也有显著的 B-WER 改进,与非流式推理热词增强效果基本一致。

参考文献

[1] Kaixun Huang, Ao Zhang, Zhanheng Yang, Pengcheng Guo, Bingshen Mu, Tianyi Xu, Lei Xie: Contextualized End-to-End Speech Recognition with Contextual Phrase Prediction Network. INTERSPEECH 2023

[2] Golan Pundak, Tara N. Sainath, Rohit Prabhavalkar, Anjuli Kannan, Ding Zhao: Deep Context: End-to-end Contextual Speech Recognition. SLT 2018

[3] Mahaveer Jain, Gil Keren, Jay Mahadeokar, Geoffrey Zweig, Florian Metze, Yatharth Saraf: Contextual RNN-T for Open Domain ASR. INTERSPEECH 2020

[4] Minglun Han, Linhao Dong, Shiyu Zhou, Bo Xu: Cif-Based Collaborative Decoding for End-to-End Contextual Speech Recognition. ICASSP 2021

[5] Guangzhi Sun, Chao Zhang, Philip C. Woodland: Tree-Constrained Pointer Generator for End-to-End Contextual Speech Recognition. ASRU 2021

[6] Tsendsuren Munkhdalai, Khe Chai Sim, Angad Chandorkar, Fan Gao, Mason Chua, Trevor Strohman, Françoise Beaufays: Fast Contextual Adaptation with Neural Associative Memory for On-Device Personalized Speech Recognition. ICASSP 2022

[7] Feng-Ju Chang, Jing Liu, Martin Radfar, Athanasios Mouchtaris, Maurizio Omologo, Ariya Rastrow, Siegfried Kunzmann:

Context-Aware Transformer Transducer for Speech Recognition. ASRU 2021 [8] Zhanheng Yang, Sining Sun, Xiong Wang, Yike Zhang, Long Ma, Lei Xie: Two Stage Contextual Word Filtering for Context bias in Unified Streaming and Non-streaming Transducer. INTERSPEECH 2023

[9] Duc Le, Mahaveer Jain, Gil Keren, Suyoun Kim, Yangyang Shi, Jay Mahadeokar, Julian Chan, Yuan Shangguan, Christian Fuegen, Ozlem Kalinli, Yatharth Saraf, Michael L. Seltzer: Contextualized Streaming End-to-End Speech Recognition with Trie-Based Deep Biasing and Shallow Fusion. INTERSPEECH 2021

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/362203
推荐阅读
相关标签
  

闽ICP备14008679号