This post uses the FastAPI framework to take in the user's voice input asynchronously and stream the large-model output back, implementing a voice-call feature.
First, define a WebSocket endpoint:
@router.websocket("/chat/voice_call")
async def voice_chat(ws: WebSocket, db=Depends(get_db), redis=Depends(get_redis)):
    await ws.accept()
    await voice_call_handler(ws, db, redis)
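The endpoint itself only accepts the connection and delegates everything to voice_call_handler, which is defined further below; get_db and get_redis are ordinary FastAPI dependencies that inject the database session and the Redis client.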
Next, define the data-frame format exchanged over the interface. The request frame sent by the client looks like this:
data = {
    "audio": [str],           # base64-encoded audio data slice
    "meta_info": {
        "session_id": [str],  # session id
        "encoding": [str]     # compression type; only "raw" for now
    },
    "is_close": [bool]        # True when the client is about to close the connection, False during a normal exchange
}
For example:
ws_data = {
    "audio": "/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA...",
    "meta_info": {
        "session_id": "28445e6d-e8c1-46a6-b980-fbf39b918def",
        "encoding": "raw"
    },
    "is_close": False
}
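For illustration, here is a minimal client-side sketch that streams audio in this frame format. The `websockets` package, the server address, and the chunk size are assumptions for the example, not part of the service itself.

import asyncio
import base64
import json

import websockets  # any WebSocket client library would do


async def send_audio(pcm_bytes: bytes, session_id: str):
    uri = "ws://localhost:8000/chat/voice_call"  # hypothetical host/port
    async with websockets.connect(uri) as ws:
        chunk_size = 960  # 30 ms of 16 kHz / 16-bit mono PCM per frame
        for i in range(0, len(pcm_bytes), chunk_size):
            frame = {
                "audio": base64.b64encode(pcm_bytes[i:i + chunk_size]).decode(),
                "meta_info": {"session_id": session_id, "encoding": "raw"},
                "is_close": False,
            }
            await ws.send(json.dumps(frame))
        # tell the server the input is finished
        await ws.send(json.dumps({
            "audio": "",
            "meta_info": {"session_id": session_id, "encoding": "raw"},
            "is_close": True,
        }))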
The server returns data in two forms: JSON text frames and binary audio frames. Text frames carry the following fields:

Field | Type | Description
---|---|---
type | string | Frame type: "text" for a normal text frame, "error" when an error occurred
code | int | 200 for a normal response, 500 for an error
msg | string | Frame payload: the sentence text for a text frame, or the error message
{"type": "error", "code": 500, "msg": "wrong frame"}
# Fetch the session content, preferring the Redis cache and falling back to the database
def get_session_content(session_id, redis, db):
    session_content_str = ""
    if redis.exists(session_id):
        session_content_str = redis.get(session_id)
    else:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)

# Parse one chunk of the LLM's streaming (SSE) response
def parseChunkDelta(chunk):
    decoded_data = chunk.decode('utf-8')
    parsed_data = json.loads(decoded_data[6:])  # strip the leading "data: " prefix
    if 'delta' in parsed_data['choices'][0]:
        delta_content = parsed_data['choices'][0]['delta']
        return delta_content['content']
    else:
        return ""

# Sentence splitter: the first sentence is cut on any punctuation, later ones on end-of-sentence punctuation
def split_string_with_punctuation(current_sentence, text, is_first):
    result = []
    for char in text:
        current_sentence += char
        if is_first and char in ',.?!,。?!':
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
        elif char in '。?!':
            result.append(current_sentence)
            current_sentence = ''
    return result, current_sentence, is_first

# VAD preprocessing: voice activity detection expects frames of exactly 1280 base64 characters
def vad_preprocess(audio):
    if len(audio) < 1280:
        # pad a short tail to a full frame ('A' is base64 for zero bytes) and leave no remainder
        return audio + 'A' * (1280 - len(audio)), ''
    return audio[:1280], audio[1280:]
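A quick, made-up example of how the splitter behaves on streamed deltas:

current, first = "", True
sentences = []
for delta in ["你好,", "我是语", "音助手。", "有什么可以", "帮你?"]:
    done, current, first = split_string_with_punctuation(current, delta, first)
    sentences.extend(done)
# sentences == ['你好,', '我是语音助手。', '有什么可以帮你?']

The first sentence is cut on any punctuation, including commas, so TTS can start speaking as early as possible; later sentences wait for end-of-sentence punctuation. Voice activity detection itself is a thin wrapper around webrtcvad: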
import webrtcvad
import base64

class VAD():
    def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000, **kwargs):
        self.RATE = RATE
        self.vad = webrtcvad.Vad(vad_sensitivity)
        self.vad_buffer_size = vad_buffer_size
        self.vad_chunk_size = int(self.RATE * frame_duration / 1000)
        self.min_act_time = min_act_time  # minimum activity time, in seconds

    def is_speech(self, data):
        byte_data = base64.b64decode(data)
        return self.vad.is_speech(byte_data, self.RATE)
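webrtcvad only accepts frames of 10, 20 or 30 ms. A frame of 1280 base64 characters decodes to 960 bytes, i.e. 480 samples of 16-bit mono audio, which is exactly 30 ms at 16 kHz; that is why vad_preprocess cuts the incoming stream into 1280-character slices. A small standalone sanity check:

import base64, webrtcvad

vad = webrtcvad.Vad(1)
frame_b64 = "A" * 1280                       # 1280 base64 chars -> 960 bytes -> 480 samples -> 30 ms
pcm = base64.b64decode(frame_b64)
print(len(pcm), vad.is_speech(pcm, 16000))   # 960, and pure silence should come back False

The rest of the pipeline is stitched together with asyncio queues and events shared by a group of coroutines: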
audio_q = asyncio.Queue()               # audio queue
asr_result_q = asyncio.Queue()          # speech-recognition result queue
llm_response_q = asyncio.Queue()        # LLM response queue
split_result_q = asyncio.Queue()        # sentence-split result queue
input_finished_event = asyncio.Event()  # user input finished
asr_finished_event = asyncio.Event()    # speech recognition finished
llm_finished_event = asyncio.Event()    # LLM finished
split_finished_event = asyncio.Event()  # sentence splitting finished
voice_call_end_event = asyncio.Event()  # voice call terminated
future = asyncio.Future()               # used to obtain the session_id sent by the client
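Together these form a producer/consumer pipeline: ws → audio_q → ASR → asr_result_q → LLM → llm_response_q → sentence splitter → split_result_q → TTS → ws. Each stage exits once the upstream "finished" event is set and its input queue has been drained. The producer reads frames off the WebSocket: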
# Audio-frame producer: read WebSocket frames and push fixed-size VAD frames into the audio queue
async def voice_call_audio_producer(ws, audio_queue, future, input_finished_event):
    logger.debug("audio producer started")
    is_future_done = False
    audio_data = ""
    try:
        while not input_finished_event.is_set():
            voice_call_data_json = json.loads(await ws.receive_text())
            if not is_future_done:  # read the session_id on the first iteration
                future.set_result(voice_call_data_json['meta_info']['session_id'])
                is_future_done = True
            if voice_call_data_json["is_close"]:
                input_finished_event.set()
                break
            else:
                audio_data += voice_call_data_json["audio"]
                while len(audio_data) > 1280:
                    vad_frame, audio_data = vad_preprocess(audio_data)
                    await audio_queue.put(vad_frame)  # push the frame into the audio queue
    except KeyError as ke:
        logger.info("received a heartbeat frame")
# Audio-frame consumer: run VAD and streaming ASR, push recognized utterances into asr_result_q
async def voice_call_audio_consumer(audio_q, asr_result_q, input_finished_event, asr_finished_event):
    logger.debug("audio consumer started")
    vad = VAD()
    current_message = ""
    vad_count = 0
    while not (input_finished_event.is_set() and audio_q.empty()):
        audio_data = await audio_q.get()
        if vad.is_speech(audio_data):
            if vad_count > 0:
                vad_count -= 1
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        else:
            vad_count += 1
            if vad_count >= 25:  # 25 consecutive frames without speech: treat the utterance as finished
                asr_result = asr.streaming_recognize(audio_data, is_end=True)
                if current_message:
                    logger.debug(f"silence detected, user said: {current_message}")
                    await asr_result_q.put(current_message)
                    current_message = ""
                vad_count = 0
    asr_finished_event.set()
# ASR-result consumer / LLM-response producer: call the LLM with streaming enabled
async def voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, asr_finished_event, llm_finished_event):
    logger.debug("LLM handler started")
    while not (asr_finished_event.is_set() and asr_result_q.empty()):
        session_content = get_session_content(session_id, redis, db)
        messages = json.loads(session_content["messages"])
        current_message = await asr_result_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens": 10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data = parseChunkDelta(chunk)
                    llm_frame = {'message': chunk_data, 'is_end': False}
                    await llm_response_q.put(llm_frame)
            llm_frame = {'message': "", 'is_end': True}
            await llm_response_q.put(llm_frame)
    llm_finished_event.set()
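One caveat: requests is synchronous, so the streaming call above blocks the event loop while the response is being read, which stalls the other coroutines. A sketch of one way around it, pulling the blocking iterator through a worker thread with asyncio.to_thread (an async HTTP client such as httpx would work as well); the function name is illustrative and not part of the original code:

async def stream_llm(url, headers, payload, llm_response_q):
    # run the blocking request in a worker thread so the event loop stays responsive
    response = await asyncio.to_thread(
        requests.post, url, headers=headers, data=payload, stream=True
    )
    if response.status_code == 200:
        lines = response.iter_lines()
        while True:
            # pull one SSE line at a time off the blocking iterator
            chunk = await asyncio.to_thread(next, lines, None)
            if chunk is None:
                break
            if chunk:
                await llm_response_q.put({'message': parseChunkDelta(chunk), 'is_end': False})
    await llm_response_q.put({'message': "", 'is_end': True})

The sentence-splitting consumer (voice_call_llm_response_consumer) is not walked through separately; it appears in the complete listing at the end. The last stage synthesizes each sentence and sends it back: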
# Speech synthesis and return: synthesize each sentence and send it back over the WebSocket
async def voice_call_tts_handler(ws, tts_info, split_result_q, split_finished_event, voice_call_end_event):
    logger.debug("TTS handler started")
    while not (split_finished_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # send the binary audio stream
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text frame
        logger.debug(f"websocket sent: {sentence}")
        await asyncio.sleep(0.5)
    await ws.close()
    voice_call_end_event.set()
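Each sentence goes back as a binary audio frame immediately followed by the matching text frame, so the client can pair audio with its transcript; the short sleep between sentences presumably paces delivery. When the splitter has finished and the queue is drained, the handler closes the WebSocket and sets voice_call_end_event, which lets the orchestrating coroutine return. That orchestrator is the voice_call_handler used by the endpoint: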
async def voice_call_handler(ws, db, redis):
    logger.debug("voice_call websocket connection established")
    audio_q = asyncio.Queue()               # audio queue
    asr_result_q = asyncio.Queue()          # speech-recognition result queue
    llm_response_q = asyncio.Queue()        # LLM response queue
    split_result_q = asyncio.Queue()        # sentence-split result queue
    input_finished_event = asyncio.Event()  # user input finished
    asr_finished_event = asyncio.Event()    # speech recognition finished
    llm_finished_event = asyncio.Event()    # LLM finished
    split_finished_event = asyncio.Event()  # sentence splitting finished
    voice_call_end_event = asyncio.Event()  # voice call terminated
    future = asyncio.Future()               # used to obtain the session_id sent by the client
    asyncio.create_task(voice_call_audio_producer(ws, audio_q, future, input_finished_event))                        # audio producer
    asyncio.create_task(voice_call_audio_consumer(audio_q, asr_result_q, input_finished_event, asr_finished_event))  # audio consumer
    # Fetch the session content
    session_id = await future  # obtain the session_id
    tts_info = json.loads(get_session_content(session_id, redis, db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id, redis, db)["llm_info"])
    asyncio.create_task(voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, asr_finished_event, llm_finished_event))       # LLM handler
    asyncio.create_task(voice_call_llm_response_consumer(session_id, redis, db, llm_response_q, split_result_q, llm_finished_event, split_finished_event))   # sentence splitter
    asyncio.create_task(voice_call_tts_handler(ws, tts_info, split_result_q, split_finished_event, voice_call_end_event))                                    # TTS handler
    while not voice_call_end_event.is_set():
        await asyncio.sleep(3)
    await ws.close()
    logger.debug("voice_call websocket connection closed")
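Two small refinements worth considering (not in the original): awaiting the end event directly instead of polling it every three seconds, and keeping references to the created tasks (or using asyncio.gather) so they are not garbage-collected mid-call and their exceptions are not silently dropped.

    # alternative ending for voice_call_handler
    await voice_call_end_event.wait()
    await ws.close()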
Finally, here is the complete code in one piece.
Note: it cannot be run as-is; it is only meant to illustrate the approach.
import asyncio
import base64
import json

import requests
import webrtcvad
from fastapi import HTTPException, status

# Note: asr, tts, Session, Config and logger are project-specific objects that are
# not defined in this listing; it is meant to illustrate the overall flow only.


class VAD():
    def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000, **kwargs):
        self.RATE = RATE
        self.vad = webrtcvad.Vad(vad_sensitivity)
        self.vad_buffer_size = vad_buffer_size
        self.vad_chunk_size = int(self.RATE * frame_duration / 1000)
        self.min_act_time = min_act_time  # minimum activity time, in seconds

    def is_speech(self, data):
        byte_data = base64.b64decode(data)
        return self.vad.is_speech(byte_data, self.RATE)


# Fetch the session content, preferring the Redis cache and falling back to the database
def get_session_content(session_id, redis, db):
    session_content_str = ""
    if redis.exists(session_id):
        session_content_str = redis.get(session_id)
    else:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)


# Parse one chunk of the LLM's streaming (SSE) response
def parseChunkDelta(chunk):
    decoded_data = chunk.decode('utf-8')
    parsed_data = json.loads(decoded_data[6:])  # strip the leading "data: " prefix
    if 'delta' in parsed_data['choices'][0]:
        delta_content = parsed_data['choices'][0]['delta']
        return delta_content['content']
    else:
        return ""


# Sentence splitter: the first sentence is cut on any punctuation, later ones on end-of-sentence punctuation
def split_string_with_punctuation(current_sentence, text, is_first):
    result = []
    for char in text:
        current_sentence += char
        if is_first and char in ',.?!,。?!':
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
        elif char in '。?!':
            result.append(current_sentence)
            current_sentence = ''
    return result, current_sentence, is_first


# VAD preprocessing: the detector expects frames of exactly 1280 base64 characters
def vad_preprocess(audio):
    if len(audio) < 1280:
        # pad a short tail to a full frame ('A' is base64 for zero bytes) and leave no remainder
        return audio + 'A' * (1280 - len(audio)), ''
    return audio[:1280], audio[1280:]


# Audio-frame producer
async def voice_call_audio_producer(ws, audio_q, future, input_finished_event):
    logger.debug("audio producer started")
    is_future_done = False
    audio_data = ""
    try:
        while not input_finished_event.is_set():
            voice_call_data_json = json.loads(await ws.receive_text())
            if not is_future_done:  # read the session_id on the first iteration
                future.set_result(voice_call_data_json['meta_info']['session_id'])
                is_future_done = True
            if voice_call_data_json["is_close"]:
                input_finished_event.set()
                break
            else:
                audio_data += voice_call_data_json["audio"]
                while len(audio_data) > 1280:
                    vad_frame, audio_data = vad_preprocess(audio_data)
                    await audio_q.put(vad_frame)  # push the frame into audio_q
    except KeyError as ke:
        logger.info("received a heartbeat frame")


# Audio-frame consumer
async def voice_call_audio_consumer(audio_q, asr_result_q, input_finished_event, asr_finished_event):
    logger.debug("audio consumer started")
    vad = VAD()
    current_message = ""
    vad_count = 0
    while not (input_finished_event.is_set() and audio_q.empty()):
        audio_data = await audio_q.get()
        if vad.is_speech(audio_data):
            if vad_count > 0:
                vad_count -= 1
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        else:
            vad_count += 1
            if vad_count >= 25:  # 25 consecutive frames without speech: treat the utterance as finished
                asr_result = asr.streaming_recognize(audio_data, is_end=True)
                if current_message:
                    logger.debug(f"silence detected, user said: {current_message}")
                    await asr_result_q.put(current_message)
                    current_message = ""
                vad_count = 0
    asr_finished_event.set()


# ASR-result consumer / LLM-response producer
async def voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, asr_finished_event, llm_finished_event):
    logger.debug("LLM handler started")
    while not (asr_finished_event.is_set() and asr_result_q.empty()):
        session_content = get_session_content(session_id, redis, db)
        messages = json.loads(session_content["messages"])
        current_message = await asr_result_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens": 10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data = parseChunkDelta(chunk)
                    llm_frame = {'message': chunk_data, 'is_end': False}
                    await llm_response_q.put(llm_frame)
            llm_frame = {'message': "", 'is_end': True}
            await llm_response_q.put(llm_frame)
    llm_finished_event.set()


# LLM-response consumer: split streamed deltas into sentences and update the session
async def voice_call_llm_response_consumer(session_id, redis, db, llm_response_q, split_result_q, llm_finished_event, split_finished_event):
    logger.debug("LLM response consumer started")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (llm_finished_event.is_set() and llm_response_q.empty()):
        llm_frame = await llm_response_q.get()
        llm_response += llm_frame['message']
        sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_frame['message'], is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
        if llm_frame['is_end']:
            is_first = True
            session_content = get_session_content(session_id, redis, db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'assistant', "content": llm_response})
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)   # update the conversation
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))   # update the session
            logger.debug(f"LLM response: {llm_response}")
            llm_response = ""
            current_sentence = ""
    split_finished_event.set()


# Speech synthesis and return
async def voice_call_tts_handler(ws, tts_info, split_result_q, split_finished_event, voice_call_end_event):
    logger.debug("TTS handler started")
    while not (split_finished_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # send the binary audio stream
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text frame
        logger.debug(f"websocket sent: {sentence}")
        await asyncio.sleep(0.5)
    await ws.close()
    voice_call_end_event.set()


async def voice_call_handler(ws, db, redis):
    logger.debug("voice_call websocket connection established")
    audio_q = asyncio.Queue()               # audio queue
    asr_result_q = asyncio.Queue()          # speech-recognition result queue
    llm_response_q = asyncio.Queue()        # LLM response queue
    split_result_q = asyncio.Queue()        # sentence-split result queue
    input_finished_event = asyncio.Event()  # user input finished
    asr_finished_event = asyncio.Event()    # speech recognition finished
    llm_finished_event = asyncio.Event()    # LLM finished
    split_finished_event = asyncio.Event()  # sentence splitting finished
    voice_call_end_event = asyncio.Event()  # voice call terminated
    future = asyncio.Future()               # used to obtain the session_id sent by the client
    asyncio.create_task(voice_call_audio_producer(ws, audio_q, future, input_finished_event))                        # audio producer
    asyncio.create_task(voice_call_audio_consumer(audio_q, asr_result_q, input_finished_event, asr_finished_event))  # audio consumer
    # Fetch the session content
    session_id = await future  # obtain the session_id
    tts_info = json.loads(get_session_content(session_id, redis, db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id, redis, db)["llm_info"])
    asyncio.create_task(voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, asr_finished_event, llm_finished_event))       # LLM handler
    asyncio.create_task(voice_call_llm_response_consumer(session_id, redis, db, llm_response_q, split_result_q, llm_finished_event, split_finished_event))   # sentence splitter
    asyncio.create_task(voice_call_tts_handler(ws, tts_info, split_result_q, split_finished_event, voice_call_end_event))                                    # TTS handler
    while not voice_call_end_event.is_set():
        await asyncio.sleep(3)
    await ws.close()
    logger.debug("voice_call websocket connection closed")