Skip to content

Commit

Permalink
feat/neon_transformers (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Apr 13, 2023
1 parent 0e84acf commit c0b319b
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 2 deletions.
167 changes: 167 additions & 0 deletions ovos_plugin_manager/templates/transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import List

from ovos_config.config import Configuration
from ovos_utils.messagebus import get_mycroft_bus

from ovos_plugin_manager.utils import ReadWriteStream


class MetadataTransformer:
    """Base class for plugins that transform Message context (metadata).

    Subclasses are expected to override :meth:`transform`; every other hook
    defined here is a no-op by default.
    """

    def __init__(self, name, priority=50, config=None):
        """
        :param name: plugin name; also the key looked up under
                     "metadata_transformers" in the core configuration
        :param priority: stored unchanged on the instance (default 50)
        :param config: plugin configuration dict; when falsy the plugin
                       section of the core configuration is used instead
        """
        self.name = name
        self.bus = None
        self.priority = priority
        if not config:
            config_core = dict(Configuration())
            # `or {}` also covers a section that is present but set to None
            section = config_core.get("metadata_transformers") or {}
            config = section.get(self.name)
        self.config = config or {}

    def bind(self, bus=None):
        """ attach messagebus

        :param bus: messagebus client; when falsy one is obtained from
                    get_mycroft_bus()
        """
        self.bus = bus or get_mycroft_bus()

    def initialize(self):
        """ perform any initialization actions """
        pass

    def transform(self, context: dict = None) -> dict:
        """
        Optionally transform passed context
        eg. inject default values or convert metadata format
        :param context: existing Message context from all previous transformers
        :returns: dict of possibly modified or additional context
                  (the passed dict itself when it is non-empty, else a new
                  empty dict)
        """
        context = context or {}
        return context

    def default_shutdown(self):
        """ perform any shutdown actions """
        pass


class UtteranceTransformer:
    """Base class for plugins that rewrite utterances and/or add context.

    Subclasses are expected to override :meth:`transform`; the other hooks
    do nothing by default.
    """

    def __init__(self, name, priority=50, config=None):
        """
        :param name: plugin name; also the key looked up under
                     "utterance_transformers" in the core configuration
        :param priority: stored unchanged on the instance (default 50)
        :param config: plugin configuration dict; when falsy the plugin
                       section of the core configuration is used instead
        """
        self.name = name
        self.priority = priority
        self.bus = None
        if config:
            self.config = config
        else:
            # no explicit config given, fall back to the core configuration
            core = dict(Configuration())
            section = core.get("utterance_transformers", {})
            self.config = section.get(self.name) or {}

    def bind(self, bus=None):
        """ attach messagebus (obtained from get_mycroft_bus() if falsy) """
        if not bus:
            bus = get_mycroft_bus()
        self.bus = bus

    def initialize(self):
        """ hook for initialization actions, does nothing by default """

    def transform(self, utterances: List[str],
                  context: dict = None) -> (list, dict):
        """
        Optionally rewrite the utterances and/or provide extra context

        The default implementation hands the utterances back untouched
        together with an empty context dict.

        :param utterances: List of str utterances to parse
        :param context: existing Message context associated with utterances
        :returns: tuple of (possibly modified utterances, additional context)
        """
        extra_context = {}
        return utterances, extra_context

    def default_shutdown(self):
        """ hook for shutdown actions, does nothing by default """


class AudioTransformer:
    """process audio data and optionally transform it before STT stage"""

    # listener settings inherited by a plugin that does not set them itself
    _LISTENER_KEYS = ("sample_width", "sample_rate", "channels")

    def __init__(self, name, priority=50, config=None):
        """
        :param name: plugin name; also the key looked up under
                     "audio_transformers" in the core configuration
        :param priority: stored unchanged on the instance (default 50)
        :param config: plugin configuration dict; when falsy the
                       configuration is read via _read_mycroft_conf()
        """
        self.name = name
        self.bus = None
        self.priority = priority
        self.config = config or self._read_mycroft_conf()

        # listener config
        self.sample_width = self.config.get("sample_width", 2)
        self.channels = self.config.get("channels", 1)
        self.sample_rate = self.config.get("sample_rate", 16000)

        # buffers with audio chunks to be used in predictions
        # always cleared before STT stage
        self.noise_feed = ReadWriteStream()
        self.hotword_feed = ReadWriteStream()
        self.speech_feed = ReadWriteStream()

    @staticmethod
    def _merge_listener_defaults(config, listener_config):
        """ return a copy of config completed with listener audio settings

        :param config: plugin configuration dict (left unmodified)
        :param listener_config: "listener" section of the core configuration
        :returns: new dict; keys already present in config take precedence
        """
        merged = dict(config)
        for k in AudioTransformer._LISTENER_KEYS:
            if k not in merged and k in listener_config:
                merged[k] = listener_config[k]
        return merged

    def _read_mycroft_conf(self):
        """ build the plugin config from the core configuration

        The plugin section is merged into a copy: dict(Configuration()) is
        only a shallow copy, so writing listener defaults into the nested
        plugin dict would modify the loaded configuration itself.
        """
        config_core = dict(Configuration())
        # `or {}` also covers sections that are present but set to None
        section = config_core.get("audio_transformers") or {}
        config = section.get(self.name) or {}
        listener_config = config_core.get("listener") or {}
        return self._merge_listener_defaults(config, listener_config)

    def bind(self, bus=None):
        """ attach messagebus (obtained from get_mycroft_bus() if falsy) """
        self.bus = bus or get_mycroft_bus()

    def feed_audio_chunk(self, chunk):
        """ pass chunk through on_audio and append the result to noise_feed """
        chunk = self.on_audio(chunk)
        self.noise_feed.write(chunk)

    def feed_hotword_chunk(self, chunk):
        """ pass chunk through on_hotword and append the result to
        hotword_feed
        """
        chunk = self.on_hotword(chunk)
        self.hotword_feed.write(chunk)

    def feed_speech_chunk(self, chunk):
        """ pass chunk through on_speech and append the result to speech_feed
        """
        chunk = self.on_speech(chunk)
        self.speech_feed.write(chunk)

    def feed_speech_utterance(self, chunk):
        """ pass the full speech audio through on_speech_end and return the
        result (nothing is buffered here)
        """
        return self.on_speech_end(chunk)

    def reset(self):
        """ clear the speech, hotword and noise buffers """
        # end of prediction, reset buffers
        self.speech_feed.clear()
        self.hotword_feed.clear()
        self.noise_feed.clear()

    def initialize(self):
        """ perform any initialization actions """
        pass

    def on_audio(self, audio_data):
        """ Take any action you want, audio_data is a non-speech chunk
        :returns: the chunk to be written to noise_feed
        """
        return audio_data

    def on_hotword(self, audio_data):
        """ Take any action you want, audio_data is a full wake/hotword
        Common action would be to prepare to receive speech chunks
        NOTE: this might be a hotword or a wakeword, listening is not assured
        :returns: the chunk to be written to hotword_feed
        """
        return audio_data

    def on_speech(self, audio_data):
        """ Take any action you want, audio_data is a speech chunk (NOT a
        full utterance) during recording
        :returns: the chunk to be written to speech_feed
        """
        return audio_data

    def on_speech_end(self, audio_data):
        """ Take any action you want, audio_data is the full speech audio
        :returns: the (possibly modified) full speech audio
        """
        return audio_data

    def transform(self, audio_data):
        """ return the (possibly modified) audio together with any additional
        message context to be passed in recognize_loop:utterance message,
        usually a streaming prediction
        Optionally make the prediction here with saved chunks from other handlers
        :returns: tuple of (audio_data, additional context dict)
        """
        return audio_data, {}

    def default_shutdown(self):
        """ perform any shutdown actions """
        pass
49 changes: 47 additions & 2 deletions ovos_plugin_manager/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
# limitations under the License.
#
"""Common functions for loading plugins."""
import pkg_resources
import time
from enum import Enum
from ovos_utils.log import LOG
from threading import Event

import pkg_resources
from langcodes import standardize_tag as _normalize_lang
from ovos_utils.log import LOG


class PluginTypes(str, Enum):
Expand Down Expand Up @@ -137,3 +140,45 @@ def normalize_lang(lang):
pass
return lang


class ReadWriteStream:
    """
    Class used to support writing binary audio data at any pace,
    optionally chopping when the buffer gets too large
    """

    def __init__(self, s=b'', chop_samples=-1):
        """
        :param s: initial buffer content
        :param chop_samples: when > 0, read() trims a buffer longer than
                             this value before reading; -1 disables chopping
        """
        self.buffer = s
        self.write_event = Event()
        self.chop_samples = chop_samples

    def __len__(self):
        """ number of items currently held in the buffer """
        return len(self.buffer)

    def read(self, n=-1, timeout=None):
        """ remove and return the first n items of the buffer

        Blocks until at least n items are available.

        :param n: amount to read; -1 means the buffer length at call time
        :param timeout: max seconds to wait for enough data,
                        None waits indefinitely
        :returns: the data read, or b'' if the timeout expired first
        """
        if n == -1:
            n = len(self.buffer)
        if 0 < self.chop_samples < len(self.buffer):
            samples_left = len(self.buffer) % self.chop_samples
            # NOTE(review): when the length is an exact multiple of
            # chop_samples this slice is [-0:], i.e. nothing is dropped;
            # kept as-is since the intended behaviour is unclear - confirm
            self.buffer = self.buffer[-samples_left:]
        # no deadline at all for timeout=None: a huge finite wait can exceed
        # threading.TIMEOUT_MAX and raise OverflowError on some platforms
        deadline = None if timeout is None else timeout + time.time()
        while len(self.buffer) < n:
            self.write_event.clear()
            # re-check after clearing: a write() that landed between the
            # length check and clear() had its notification erased
            if len(self.buffer) >= n:
                break
            remaining = None if deadline is None else deadline - time.time()
            if not self.write_event.wait(remaining):
                return b''
        chunk = self.buffer[:n]
        self.buffer = self.buffer[n:]
        return chunk

    def write(self, s):
        """ append s to the buffer and wake up a pending read() """
        self.buffer += s
        self.write_event.set()

    def flush(self):
        """Makes compatible with sys.stdout"""
        pass

    def clear(self):
        """ discard all buffered data """
        self.buffer = b''

0 comments on commit c0b319b

Please sign in to comment.