Skip to content

Commit

Permalink
Fix typo in extra dependencies
Browse files Browse the repository at this point in the history
Handle streaming socket retry if too early
Implement streaming audio responses
  • Loading branch information
NeonDaniel committed Oct 4, 2024
1 parent 270479b commit 2970e1d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 30 deletions.
13 changes: 11 additions & 2 deletions neon_hana/app/routers/node_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from asyncio import Event
from signal import signal, SIGINT
from time import sleep
from typing import Optional, Union

from fastapi import APIRouter, WebSocket, HTTPException, Request
Expand Down Expand Up @@ -75,16 +76,24 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str):
must first establish a connection to the `/v1` endpoint.
"""
client_id = client_manager.get_client_id(token)

# Handle problem clients that don't explicitly wait for the Node WS to
# connect before starting a stream
retries = 0
while not socket_api.get_session(client_id) and retries < 3:
sleep(1)
retries += 1
if not socket_api.get_session(client_id):
raise HTTPException(status_code=401,
detail=f"Client not known ({client_id})")

await websocket.accept()
disconnect_event = Event()

socket_api.new_stream(websocket, client_id)
while not disconnect_event.is_set():
try:
client_in: bytes = await websocket.receive_bytes()
socket_api.handle_audio_stream(client_in, client_id)
socket_api.handle_audio_input_stream(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()

Expand Down
54 changes: 33 additions & 21 deletions neon_hana/mq_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ def new_connection(self, ws: WebSocket, session_id: str):
"socket": ws,
"user": self.user_config}

def new_stream(self, ws: WebSocket, session_id: str):
    """
    Establish a new streaming connection, associated with an existing session.
    If the session already has a stream handler (e.g. the client re-opened
    its streaming socket), the existing handler is kept and re-pointed at
    the new socket so response audio is not sent to a stale connection.
    @param ws: Client WebSocket that handles byte audio
    @param session_id: Session ID the websocket is associated with
    @raises RuntimeError: if `session_id` is not a known session
    """
    if session_id not in self._sessions:
        raise RuntimeError(f"Stream cannot be established for {session_id}")
    # Imported at call time, as in the original; presumably because the
    # streaming dependencies are an optional extra -- TODO confirm
    from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone
    session = self._sessions[session_id]
    existing_stream = session.get('stream')
    if existing_stream:
        # The handler thread is reused; only the output socket changes
        existing_stream.client_socket = ws
        return
    LOG.info(f"starting stream for session {session_id}")
    audio_queue = Queue()
    stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id,
                                 input_audio_callback=self.handle_client_input,
                                 ww_callback=self.handle_ww_detected,
                                 client_socket=ws)
    session['stream'] = stream
    try:
        stream.start()
    except RuntimeError as e:
        # Still best-effort (a failed start must not break the socket
        # handler), but no longer silent
        LOG.exception(e)

def end_session(self, session_id: str):
"""
End a client connection upon WS disconnection
Expand Down Expand Up @@ -130,20 +152,7 @@ def _update_session_data(self, message: Message):
if user_config:
self._sessions[session_id]['user'] = user_config

def handle_audio_stream(self, audio: bytes, session_id: str):
from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone
if not self._sessions[session_id].get('stream'):
LOG.info(f"starting stream for session {session_id}")
audio_queue = Queue()
stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id,
audio_callback=self.handle_client_input,
ww_callback=self.handle_ww_detected)
self._sessions[session_id]['stream'] = stream
try:
stream.start()
except RuntimeError:
pass

def handle_audio_input_stream(self, audio: bytes, session_id: str):
    """
    Queue a chunk of client audio on the session's stream microphone.
    @param audio: Audio bytes received from the client streaming socket
    @param session_id: Session ID the audio is associated with
    @raises KeyError: if the session is unknown or has no stream established
    """
    stream = self._sessions.get(session_id, {}).get('stream')
    if not stream:
        # Same exception type as the bare dict lookups would raise, but with
        # a message that identifies the session instead of just a key name
        raise KeyError(f"No active stream for session {session_id}")
    stream.mic.queue.put(audio)

def handle_ww_detected(self, ww_context: dict, session_id: str):
Expand Down Expand Up @@ -172,13 +181,16 @@ def handle_klat_response(self, message: Message):
Handle a Neon text+audio response to a user input.
@param message: `klat.response` message from Neon
"""
self._update_session_data(message)
run(self.send_to_client(message))
session_id = message.context.get('session', {}).get('session_id')
if self._sessions.get(session_id, {}).get('stream'):
# TODO: stream response audio to streaming socket
pass
LOG.debug(message.context.get("timing"))
try:
self._update_session_data(message)
run(self.send_to_client(message))
session_id = message.context.get('session', {}).get('session_id')
if stream := self._sessions.get(session_id, {}).get('stream'):
LOG.info("Stream response audio")
stream.on_response_audio(message.data)
LOG.debug(message.context.get("timing"))
except Exception as e:
LOG.exception(e)

def handle_complete_intent_failure(self, message: Message):
"""
Expand Down
30 changes: 24 additions & 6 deletions neon_hana/streaming_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from base64 import b64encode
import io
from asyncio import run
from base64 import b64encode, b64decode
from typing import Optional, Callable
from mock.mock import Mock
from threading import Thread
Expand All @@ -12,6 +14,7 @@
from ovos_utils.fakebus import FakeBus
from speech_recognition import AudioData
from ovos_utils import LOG
from starlette.websockets import WebSocket


class StreamMicrophone(Microphone):
Expand All @@ -30,12 +33,14 @@ def read_chunk(self) -> Optional[bytes]:

class RemoteStreamHandler(Thread):
def __init__(self, mic: StreamMicrophone, session_id: str,
audio_callback: Callable,
input_audio_callback: Callable,
client_socket: WebSocket,
ww_callback: Callable, lang: str = "en-us"):
Thread.__init__(self)
self.session_id = session_id
self.ww_callback = ww_callback
self.audio_callback = audio_callback
self.input_audio_callback = input_audio_callback
self.client_socket = client_socket
self.bus = FakeBus()
self.mic = mic
self.lang = lang
Expand All @@ -49,7 +54,7 @@ def __init__(self, mic: StreamMicrophone, session_id: str,
hotword_audio_callback=self.on_hotword,
stopword_audio_callback=self.on_hotword,
wakeupword_audio_callback=self.on_hotword,
stt_audio_callback=self.on_audio,
stt_audio_callback=self.on_input_audio,
stt=Mock(transcribe=Mock(return_value=[])),
fallback_stt=Mock(transcribe=Mock(return_value=[])),
transformers=MockTransformers(),
Expand All @@ -67,14 +72,27 @@ def on_hotword(self, audio_bytes: bytes, context: dict):
LOG.info(f"Hotword: {context}")
self.ww_callback(context, self.session_id)

def on_audio(self, audio_bytes: bytes, context: dict):
def on_input_audio(self, audio_bytes: bytes, context: dict):
LOG.info(f"Audio: {context}")
audio_data = AudioData(audio_bytes, self.mic.sample_rate,
self.mic.sample_width).get_wav_data()
audio_data = b64encode(audio_data).decode("utf-8")
callback_data = {"type": "neon.audio_input",
"data": {"audio_data": audio_data, "lang": self.lang}}
self.audio_callback(callback_data, self.session_id)
self.input_audio_callback(callback_data, self.session_id)

def on_response_audio(self, data: dict):
    """
    Send every audio clip in a response to the client streaming socket.
    Reads `data['responses']`, a dict whose values are dicts with an
    optional `audio` dict of base64-encoded audio (outer keys are
    presumably languages -- TODO confirm). Each clip is decoded and sent as
    one binary message, in iteration order.
    @param data: Response message data; a missing or empty `responses`
        or `audio` entry results in nothing being sent
    """
    responses = (data or {}).get('responses') or {}
    encoded_clips = [encoded_audio
                     for lang_response in responses.values()
                     for encoded_audio in
                     (lang_response.get('audio') or {}).values()]

    async def _send_all():
        # Decode lazily so clips before a malformed one are still sent
        for encoded_audio in encoded_clips:
            wav_audio_bytes = b64decode(encoded_audio)
            LOG.info(f"Sending {len(wav_audio_bytes)} bytes of audio")
            await self.client_socket.send_bytes(wav_audio_bytes)

    if encoded_clips:
        # One event loop for the whole response rather than one per clip.
        # NOTE(review): `run` starts a fresh loop in the calling thread,
        # while the socket was accepted elsewhere -- confirm that sending
        # from a different loop is safe for this server.
        run(_send_all())
    LOG.info(f"Sent {len(encoded_clips)} binary audio response")

def on_chunk(self, chunk: ChunkInfo):
LOG.debug(f"Chunk: {chunk}")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_requirements(requirements_filename: str):
packages=find_packages(),
install_requires=get_requirements("requirements.txt"),
extras_require={"websocket": get_requirements("websocket.txt"),
"steaming": get_requirements("streaming.txt")},
"streaming": get_requirements("streaming.txt")},
zip_safe=True,
classifiers=[
'Intended Audience :: Developers',
Expand Down

0 comments on commit 2970e1d

Please sign in to comment.