Skip to content

Commit

Permalink
feat:binary handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Oct 30, 2024
1 parent 9a4bb0f commit 4505792
Showing 1 changed file with 63 additions and 12 deletions.
75 changes: 63 additions & 12 deletions hivemind_listener/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message

from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_bus_client.serialization import HiveMindBinaryPayloadType
from hivemind_bus_client.message import HiveMessage, HiveMessageType, HiveMindBinaryPayloadType
from hivemind_core.protocol import HiveMindListenerProtocol, HiveMindClientConnection
from hivemind_core.service import HiveMindService
from ovos_plugin_manager.stt import OVOSSTTFactory
Expand Down Expand Up @@ -153,12 +152,17 @@ def handle_client_disconnected(self, client: HiveMindClientConnection):
self.stop_listener(client)

@classmethod
def get_b64_tts(cls, message: Message = None) -> str:
def get_tts(cls, message: Message = None) -> str:
utterance = message.data['utterance']
ctxt = cls.tts._get_ctxt({"message": message})
wav, _ = cls.tts.synth(utterance, ctxt)
return str(wav)

@classmethod
def get_b64_tts(cls, message: Message = None) -> str:
wav = cls.get_tts(message)
# cast to str() to get a path, as it is a AudioFile object from tts cache
with open(str(wav), "rb") as f:
with open(wav, "rb") as f:
audio = f.read()
return base64.b64encode(audio).decode("utf-8")

Expand All @@ -171,20 +175,67 @@ def transcribe_b64_audio(cls, message: Message = None) -> List[Tuple[str, float]
utterances = cls.stt.transcribe(audio, lang)
return utterances

def handle_binary_message(self, message: HiveMessage, client: HiveMindClientConnection):
assert message.msg_type == HiveMessageType.BINARY
if message.bin_type == HiveMindBinaryPayloadType.RAW_AUDIO:
bin_data = message.payload
if client.peer in self.listeners:
# LOG.debug(f"Got {len(bin_data)} bytes of audio data from {client.peer}")
m: FakeMicrophone = self.listeners[client.peer].mic
def handle_microphone_input(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
client: HiveMindClientConnection):
if client.peer in self.listeners:
m: FakeMicrophone = self.listeners[client.peer].mic
if m.sample_rate != sample_rate or m.sample_width != sample_width:
LOG.debug(f"Got {len(bin_data)} bytes of audio data from {client.peer}")
LOG.error(f"sample_rate/sample_width mismatch! "
f"got: ({sample_rate}, {sample_width}) "
f"expected: ({m.sample_rate}, {m.sample_width})")
# TODO - convert sample_rate if needed
else:
m.queue.put(bin_data)

def handle_stt_transcribe_request(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
lang: str,
client: HiveMindClientConnection):
LOG.debug(f"Received binary STT input: {len(bin_data)} bytes")
audio = sr.AudioData(bin_data, sample_rate, sample_width)
tx = self.stt.transcribe(audio, lang)
m = Message("recognizer_loop:transcribe.response", {"transcriptions": tx, "lang": lang})
client.send(HiveMessage(HiveMessageType.BUS, payload=m))

def handle_stt_handle_request(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
lang: str,
client: HiveMindClientConnection):
LOG.debug(f"Received binary STT input: {len(bin_data)} bytes")
audio = sr.AudioData(bin_data, sample_rate, sample_width)
tx = self.stt.transcribe(audio, lang)
if tx:
utts = [t[0].rstrip(" '\"").lstrip(" '\"") for t in tx]
m = Message("recognizer_loop:utterance",
{"utterances": utts, "lang": lang})
self.handle_inject_mycroft_msg(m, client)
else:
LOG.info(f"STT transcription error for client: {client.peer}")
m = Message("recognizer_loop:speech.recognition.unknown")
client.send(HiveMessage(HiveMessageType.BUS, payload=m))

def handle_inject_mycroft_msg(self, message: Message, client: HiveMindClientConnection):
"""
message (Message): mycroft bus message object
"""
if message.msg_type == "speak:b64_audio":
if message.msg_type == "speak:synth":
wav = self.get_tts(message)
with open(wav, "rb") as f:
bin_data = f.read()
payload = HiveMessage(HiveMessageType.BINARY,
payload=bin_data,
metadata={"lang": message.data["lang"],
"file_name": wav.split("/")[-1],
"utterance": message.data["utterance"]},
bin_type=HiveMindBinaryPayloadType.TTS_AUDIO)
client.send(payload)
return
elif message.msg_type == "speak:b64_audio":
msg: Message = message.reply("speak:b64_audio.response", message.data)
msg.data["audio"] = self.get_b64_tts(message)
if msg.context.get("destination") is None:
Expand Down

0 comments on commit 4505792

Please sign in to comment.