当前位置:   article > 正文

【ASR代码】基于pyannote和whisper的语音识别代码
from pyannote.core import Segment
import whisper
import pickle
import torch
import time
import os
from zhconv import convert
from pyannote.audio import Pipeline
from pyannote.core import Annotation
def get_text_with_timestamp(transcribe_res):
    """Turn a whisper transcription result into (Segment, text) pairs.

    Each whisper segment's start/end times become a pyannote ``Segment``;
    the segment text is kept exactly as whisper produced it.
    """
    # NOTE(review): a zhconv simplified-Chinese conversion of the text was
    # disabled in the original; the raw whisper text is used as-is.
    return [
        (Segment(seg["start"], seg["end"]), seg["text"])
        for seg in transcribe_res["segments"]
    ]


def add_speaker_info_to_text(timestamp_texts, ann):
    """Attach the dominant speaker label to each transcribed segment.

    Parameters
    ----------
    timestamp_texts : list of (Segment, str)
        Output of ``get_text_with_timestamp``.
    ann : pyannote.core.Annotation
        Speaker-diarization result for the same audio.

    Returns
    -------
    list of (Segment, speaker, str)
        (segment, speaker, text) triples. ``speaker`` comes from
        ``ann.crop(seg).argmax()`` — presumably ``None`` when the
        diarization does not overlap the segment; confirm with the
        installed pyannote version.
    """
    spk_text = []
    for seg, text in timestamp_texts:
        # Removed leftover debug print of ann.crop(seg).
        # argmax() selects the speaker with the most talk time inside seg.
        spk = ann.crop(seg).argmax()
        spk_text.append((seg, spk, text))
    return spk_text


def merge_cache(text_cache):
    """Collapse a run of same-speaker entries into one (Segment, spk, sentence).

    ``text_cache`` is a non-empty list of (Segment, speaker, text) triples;
    the merged segment spans from the first start to the last end, rounded
    to one decimal, and the texts are concatenated in order.
    """
    sentence = ''.join(text for _, _, text in text_cache)
    speaker = text_cache[0][1]
    span_start = round(text_cache[0][0].start, 1)
    span_end = round(text_cache[-1][0].end, 1)
    return Segment(span_start, span_end), speaker, sentence


# Sentence-ending punctuation (ASCII plus full-width CJK variants).
# NOTE(review): currently unused in this file — merge_sentence merges on
# speaker change only; presumably intended for punctuation-based splitting.
PUNC_SENT_END = ['.', '?', '!', "。", "?", "!"]


def merge_sentence(spk_text):
    """Merge consecutive segments by the same speaker into speaker turns.

    Parameters
    ----------
    spk_text : list of (Segment, speaker, str)
        Chronologically ordered output of ``add_speaker_info_to_text``.

    Returns
    -------
    list of (Segment, speaker, str)
        One entry per speaker turn, produced by ``merge_cache``; exact
        consecutive duplicate texts from the same speaker are dropped.
    """
    merged_spk_text = []
    pre_spk = None
    text_cache = []
    for seg, spk, text in spk_text:
        if spk != pre_spk and text_cache:
            # Speaker changed: flush the finished turn, start a new one.
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = [(seg, spk, text)]
            pre_spk = spk
        elif text_cache and spk == pre_spk and text == text_cache[-1][2]:
            # whisper occasionally emits the same text twice in a row for
            # one speaker; skip the duplicate. Removed debug prints here.
            continue
        else:
            # Same speaker (or very first segment): extend the current turn.
            # The `text_cache` guard on the elif above fixes an IndexError
            # that occurred when the first segment's speaker equalled the
            # initial pre_spk (both None) with an empty cache.
            text_cache.append((seg, spk, text))
            pre_spk = spk
    if text_cache:
        merged_spk_text.append(merge_cache(text_cache))
    return merged_spk_text


def diarize_text(transcribe_res, diarization_result):
    """Combine whisper ASR output with a pyannote diarization result.

    Returns a list of (Segment, speaker, sentence) triples, one per
    merged speaker turn.
    """
    with_times = get_text_with_timestamp(transcribe_res)
    with_speakers = add_speaker_info_to_text(with_times, diarization_result)
    return merge_sentence(with_speakers)

# def write_to_txt(spk_sent, file):
#     with open(file, 'w') as fp:
#         for seg, spk, sentence in spk_sent:
#             line = f'{seg.start:.2f} {seg.end:.2f} {spk} {sentence}\n'
#             fp.write(line)
# def format_time(seconds):
#     # 计算小时、分钟和秒数
#     hours = seconds // 3600
#     minutes = (seconds % 3600) // 60
#     seconds = seconds % 60
    
#     # 格式化输出
#     return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
if __name__=="__main__":
    sd_config_path="./models/speaker-diarization-3.1/config.yaml"
    asr_model=whisper.load_model("large")
    asr_model.to(torch.device("cuda"))

    speaker_diarization = Pipeline.from_pretrained(sd_config_path)
    speaker_diarization.to(torch.device("cuda"))
    files = os.listdir("./audios_wav")
    
    for file in files:
        start_time = time.time()
        print(file)
       
        dialogue_path="./audios_txt/"+file.split(".")[0]+".pkl"
        audio="./audios_wav/"+file
        asr_result = asr_model.transcribe(audio,initial_prompt="输入的音频是关于一个采访内容,接下来您将扮演一个优秀记录能力的听众,通过倾听语音内容,将语音信息通过文字的方式记录下来。请你首先要判断语音中讲话者的讲话内容和语气,根据内容和语气记录带有标点符号的文本信息。具体要求为:1、中文语音的文本字体为简体中文,其他类型语音根据语音中说话的语种类型记录;2、文本信息的标点符号和文本内容要准确,不能捏造信息,同一段语音不能重复识别,不能捏造语音的语种类型;示例输出格式:-就AI的研发和部署而言,为什么你觉得中国很快就能赶上、甚至赶超美国?-首要原因是AI已经完成了从探索阶段到应用阶段的转型。在探索阶段,最先取得探索成果的人拥有绝对优势;然而现在AI算法已为诸多业内实践人士所熟知。所以,现在的关键在于速度、执行、资本以及对海量数据的获取,而中国在以上每个层面都具有优势。")
        asr_time=time.time()
        print("ASR time:"+str(asr_time-start_time))

        diarization_result: Annotation = speaker_diarization(audio)
        final_result = diarize_text(asr_result, diarization_result)
        
        dialogue=[]
        for segment, spk, sent in final_result:
            content={'speaker':spk,'start':segment.start,'end': segment.end,'text':sent}
            dialogue.append(content)
            print("%s [%.2fs -> %.2fs] %s " % (spk,segment.start, segment.end, sent))
        with open(dialogue_path, 'wb') as f:
            pickle.dump(dialogue, f)
        end_time = time.time()
        print(file+" spend time:"+str(end_time-start.time))

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/288712
推荐阅读
相关标签
  

闽ICP备14008679号