diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 6badae3..65838d7 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -26,6 +26,7 @@ from asyncio import Event from signal import signal, SIGINT +from time import sleep from typing import Optional, Union from fastapi import APIRouter, WebSocket, HTTPException, Request @@ -75,16 +76,24 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): must first establish a connection to the `/v1` endpoint. """ client_id = client_manager.get_client_id(token) + + # Handle problem clients that don't explicitly wait for the Node WS to + # connect before starting a stream + retries = 0 + while not socket_api.get_session(client_id) and retries < 3: + sleep(1) + retries += 1 if not socket_api.get_session(client_id): raise HTTPException(status_code=401, detail=f"Client not known ({client_id})") + await websocket.accept() disconnect_event = Event() - + socket_api.new_stream(websocket, client_id) while not disconnect_event.is_set(): try: client_in: bytes = await websocket.receive_bytes() - socket_api.handle_audio_stream(client_in, client_id) + socket_api.handle_audio_input_stream(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 41d4beb..c46907a 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -60,6 +60,28 @@ def new_connection(self, ws: WebSocket, session_id: str): "socket": ws, "user": self.user_config} + def new_stream(self, ws: WebSocket, session_id: str): + """ + Establish a new streaming connection, associated with an existing session. + @param ws: Client WebSocket that handles byte audio + @param session_id: Session ID the websocket is associated with + """ + if session_id not in self._sessions: + raise RuntimeError(f"Stream cannot be established for {session_id}") + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone + if not self._sessions[session_id].get('stream'): + LOG.info(f"starting stream for session {session_id}") + audio_queue = Queue() + stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, + input_audio_callback=self.handle_client_input, + ww_callback=self.handle_ww_detected, + client_socket=ws) + self._sessions[session_id]['stream'] = stream + try: + stream.start() + except RuntimeError: + pass + def end_session(self, session_id: str): """ End a client connection upon WS disconnection @@ -130,20 +152,7 @@ def _update_session_data(self, message: Message): if user_config: self._sessions[session_id]['user'] = user_config - def handle_audio_stream(self, audio: bytes, session_id: str): - from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone - if not self._sessions[session_id].get('stream'): - LOG.info(f"starting stream for session {session_id}") - audio_queue = Queue() - stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, - audio_callback=self.handle_client_input, - ww_callback=self.handle_ww_detected) - self._sessions[session_id]['stream'] = stream - try: - stream.start() - except RuntimeError: - pass - + def handle_audio_input_stream(self, audio: bytes, session_id: str): self._sessions[session_id]['stream'].mic.queue.put(audio) def handle_ww_detected(self, ww_context: dict, session_id: str): @@ -172,13 +181,16 @@ def handle_klat_response(self, message: Message): Handle a Neon text+audio response to a user input. @param message: `klat.response` message from Neon """ - self._update_session_data(message) - run(self.send_to_client(message)) - session_id = message.context.get('session', {}).get('session_id') - if self._sessions.get(session_id, {}).get('stream'): - # TODO: stream response audio to streaming socket - pass - LOG.debug(message.context.get("timing")) + try: + self._update_session_data(message) + run(self.send_to_client(message)) + session_id = message.context.get('session', {}).get('session_id') + if stream := self._sessions.get(session_id, {}).get('stream'): + LOG.info("Stream response audio") + stream.on_response_audio(message.data) + LOG.debug(message.context.get("timing")) + except Exception as e: + LOG.exception(e) def handle_complete_intent_failure(self, message: Message): """ diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py index 67a5b7f..536c9d6 100644 --- a/neon_hana/streaming_client.py +++ b/neon_hana/streaming_client.py @@ -1,4 +1,6 @@ -from base64 import b64encode +import io +from asyncio import run +from base64 import b64encode, b64decode from typing import Optional, Callable from mock.mock import Mock from threading import Thread @@ -12,6 +14,7 @@ from ovos_utils.fakebus import FakeBus from speech_recognition import AudioData from ovos_utils import LOG +from starlette.websockets import WebSocket class StreamMicrophone(Microphone): @@ -30,12 +33,14 @@ def read_chunk(self) -> Optional[bytes]: class RemoteStreamHandler(Thread): def __init__(self, mic: StreamMicrophone, session_id: str, - audio_callback: Callable, + input_audio_callback: Callable, + client_socket: WebSocket, ww_callback: Callable, lang: str = "en-us"): Thread.__init__(self) self.session_id = session_id self.ww_callback = ww_callback - self.audio_callback = audio_callback + self.input_audio_callback = input_audio_callback + self.client_socket = client_socket self.bus = FakeBus() self.mic = mic self.lang = lang @@ -49,7 +54,7 @@ def __init__(self, mic: StreamMicrophone, session_id: str, hotword_audio_callback=self.on_hotword, stopword_audio_callback=self.on_hotword, wakeupword_audio_callback=self.on_hotword, - stt_audio_callback=self.on_audio, + stt_audio_callback=self.on_input_audio, stt=Mock(transcribe=Mock(return_value=[])), fallback_stt=Mock(transcribe=Mock(return_value=[])), transformers=MockTransformers(), @@ -67,14 +72,27 @@ def on_hotword(self, audio_bytes: bytes, context: dict): LOG.info(f"Hotword: {context}") self.ww_callback(context, self.session_id) - def on_audio(self, audio_bytes: bytes, context: dict): + def on_input_audio(self, audio_bytes: bytes, context: dict): LOG.info(f"Audio: {context}") audio_data = AudioData(audio_bytes, self.mic.sample_rate, self.mic.sample_width).get_wav_data() audio_data = b64encode(audio_data).decode("utf-8") callback_data = {"type": "neon.audio_input", "data": {"audio_data": audio_data, "lang": self.lang}} - self.audio_callback(callback_data, self.session_id) + self.input_audio_callback(callback_data, self.session_id) + + def on_response_audio(self, data: dict): + async def _send_bytes(audio_bytes: bytes): + await self.client_socket.send_bytes(audio_bytes) + + i = 0 + for lang_response in data.get('responses', {}).values(): + for encoded_audio in lang_response.get('audio', {}).values(): + i += 1 + wav_audio_bytes = b64decode(encoded_audio) + LOG.info(f"Sending {len(wav_audio_bytes)} bytes of audio") + run(_send_bytes(wav_audio_bytes)) + LOG.info(f"Sent {i} binary audio response") def on_chunk(self, chunk: ChunkInfo): LOG.debug(f"Chunk: {chunk}") diff --git a/setup.py b/setup.py index 3777817..85ddbef 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ def get_requirements(requirements_filename: str): packages=find_packages(), install_requires=get_requirements("requirements.txt"), extras_require={"websocket": get_requirements("websocket.txt"), - "steaming": get_requirements("streaming.txt")}, + "streaming": get_requirements("streaming.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers',