diff --git a/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt
new file mode 100644
index 000000000..204934829
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt
@@ -0,0 +1,41 @@
+cmake_minimum_required(VERSION 3.8)
+project(lasr_speech_recognition_interfaces)
+
+if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+  add_compile_options(-Wall -Wextra -Wpedantic)
+endif()
+
+# find dependencies
+find_package(ament_cmake REQUIRED)
+find_package(action_msgs REQUIRED)
+find_package(builtin_interfaces REQUIRED)
+
+# uncomment the following section in order to fill in
+# further dependencies manually.
+# find_package(<dependency> REQUIRED)
+
+# For actions, messages, and services
+find_package(rosidl_default_generators REQUIRED)
+
+rosidl_generate_interfaces(${PROJECT_NAME}
+  "action/TranscribeSpeech.action"
+  "msg/Transcription.msg"
+  "srv/TranscribeAudio.srv"
+  DEPENDENCIES builtin_interfaces # Add packages that above messages depend on
+)
+
+ament_export_dependencies(rosidl_default_runtime)
+
+if(BUILD_TESTING)
+  find_package(ament_lint_auto REQUIRED)
+  # the following line skips the linter which checks for copyrights
+  # comment the line when a copyright and license is added to all source files
+  set(ament_cmake_copyright_FOUND TRUE)
+  # the following line skips cpplint (only works in a git repo)
+  # comment the line when this package is in a git repo and when
+  # a copyright and license is added to all source files
+  set(ament_cmake_cpplint_FOUND TRUE)
+  ament_lint_auto_find_test_dependencies()
+endif()
+
+ament_package()
diff --git a/common/speech/lasr_speech_recognition_interfaces/LICENSE b/common/speech/lasr_speech_recognition_interfaces/LICENSE
new file mode 100644
index 000000000..30e8e2ece
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
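For context, the `rosidl_generate_interfaces` call above generates language bindings for the three interfaces added by this package. A minimal sketch of how a Python client would import and fill them once the workspace is built and sourced (field values are illustrative only):

```python
from lasr_speech_recognition_interfaces.msg import Transcription
from lasr_speech_recognition_interfaces.srv import TranscribeAudio
from lasr_speech_recognition_interfaces.action import TranscribeSpeech

# A live transcription update, as published by the whisper nodes.
msg = Transcription(phrase="hello world", finished=True)

# An action goal; values of 0.0 keep the server-side defaults
# (see transcribe_microphone_server.py later in this diff).
goal = TranscribeSpeech.Goal()
goal.energy_threshold = 0.0
goal.max_phrase_limit = 0.0

# The service request is empty; the response carries only `phrase`.
request = TranscribeAudio.Request()
```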
diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md
new file mode 100644
index 000000000..30878a2f4
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/README.md
@@ -0,0 +1,66 @@
+# lasr_speech_recognition_interfaces
+
+Common messages used for speech recognition
+
+This package is maintained by:
+
+- [Maayan Armony](mailto:maayan.armony@gmail.com)
+
+## Prerequisites
+
+This package depends on the following ROS packages:
+
+- ament_cmake (buildtool)
+- rosidl_default_generators (build)
+- rosidl_default_runtime (exec)
+- action_msgs
+
+## Usage
+
+Ask the package maintainer to write a `doc/USAGE.md` for their package!
+
+## Example
+
+Ask the package maintainer to write a `doc/EXAMPLE.md` for their package!
+
+## Technical Overview
+
+Ask the package maintainer to write a `doc/TECHNICAL.md` for their package!
+
+## ROS Definitions
+
+### Launch Files
+
+This package has no launch files.
+
+### Messages
+
+#### `Transcription`
+
+| Field | Type | Description |
+|:--------:|:------:|-------------|
+| phrase | string | The transcribed phrase |
+| finished | bool | Whether the transcription has finished |
+
+### Services
+
+#### `TranscribeAudio`
+
+The request is empty; the response contains the transcription:
+
+| Field | Type | Description |
+|:------:|:------:|-------------|
+| phrase | string | The transcribed text |
+
+### Actions
+
+#### `TranscribeSpeech`
+
+Goal:
+
+| Field | Type | Description |
+|:----------------:|:-------:|-------------|
+| energy_threshold | float32 | Energy threshold for silence detection (0.0 keeps the server default) |
+| max_phrase_limit | float32 | Maximum phrase duration in seconds (0.0 keeps the server default) |
+
+The result and the feedback each contain a single string field, `sequence`, holding the transcription.
diff --git a/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
new file mode 100644
index 000000000..5cac9317e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
@@ -0,0 +1,11 @@
+# Energy threshold
+float32 energy_threshold
+
+# Max phrase duration
+float32 max_phrase_limit
+---
+# result definition
+string sequence
+---
+# feedback
+string sequence
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
new file mode 100644
index 000000000..9c7483636
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
@@ -0,0 +1,2 @@
+string phrase
+bool finished
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml
new file mode 100644
index 000000000..fd72011b7
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -0,0 +1,23 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>lasr_speech_recognition_interfaces</name>
+  <version>0.0.0</version>
+  <description>Common messages used for speech recognition</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
+
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <build_depend>rosidl_default_generators</build_depend>
+  <depend>action_msgs</depend>
+  <exec_depend>rosidl_default_runtime</exec_depend>
+  <member_of_group>rosidl_interface_packages</member_of_group>
+
+  <test_depend>ament_lint_auto</test_depend>
+  <test_depend>ament_lint_common</test_depend>
+
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
+</package>
diff --git a/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
new file mode 100644
index 000000000..f416a67c4
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
@@ -0,0 +1,2 @@
+---
+string phrase
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_whisper/LICENSE b/common/speech/lasr_speech_recognition_whisper/LICENSE
new file mode 100644
index 000000000..30e8e2ece
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md
new file mode 100644
index 000000000..0da406522
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/README.md
@@ -0,0 +1,108 @@
+# lasr_speech_recognition_whisper
+
+Speech recognition implemented using OpenAI Whisper
+
+This package is maintained by:
+
+- [Maayan Armony](mailto:maayan.armony@gmail.com)
+
+## Prerequisites
+
+This package depends on the following ROS packages:
+
+- colcon (buildtool)
+- lasr_speech_recognition_interfaces
+
+This package requires Python 3.10 to be present.
+
+It also has the following Python dependencies:
+
+- [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0
+- [openai-whisper](https://pypi.org/project/openai-whisper)==20231117
+- [PyAudio](https://pypi.org/project/PyAudio)==0.2.13
+- [PyYaml](https://pypi.org/project/PyYaml)==6.0.1
+- .. and sub-dependencies (see the [requirements file](requirements.txt))
+
+This package requires [ffmpeg](https://ffmpeg.org/) to be available at runtime.
+
+## Usage
+
+> **Warning**: this package is not complete and is subject to change.
+
+List available microphones:
+
+```bash
+ros2 run lasr_speech_recognition_whisper list_microphones.py
+```
+
+Start the example script:
+
+```bash
+ros2 run lasr_speech_recognition_whisper transcribe_microphone by-index
+ros2 run lasr_speech_recognition_whisper transcribe_microphone by-name
+```
+
+Then start listening to people:
+
+```bash
+ros2 service call /whisper/start_listening "{}"
+```
+
+You can now listen on `/transcription` for a live transcription.
+
+Stop listening at any time:
+
+```bash
+ros2 service call /whisper/stop_listening "{}"
+```
+
+## Example
+
+Ask the package maintainer to write a `doc/EXAMPLE.md` for their package!
+
+## Technical Overview
+
+This package does speech recognition in three parts (a condensed sketch follows this list):
+
+- Adjusting for background noise
+
+  We monitor the audio stream for a set period of time to determine what should be
+  ignored as ambient noise when collecting voice data.
+
+- Collecting appropriate voice data for phrases
+
+  We use the `SpeechRecognition` package to monitor the input audio stream and determine
+  when a person is actually speaking with enough energy that we would consider them to be
+  speaking to the robot.
+
+- Running inference on phrases
+
+  We continuously combine segments of the spoken phrase to form a sample until a certain
+  timeout or threshold, after which the phrase ends. This sample is sent to a local OpenAI
+  Whisper model to transcribe.
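+A condensed sketch of those three stages, using the same libraries as the node
+(the `timeout`/`phrase_time_limit` values mirror the package defaults; this is
+illustrative, not the node's exact code):
+
+```python
+import numpy as np
+import speech_recognition as sr
+import whisper
+
+recogniser = sr.Recognizer()
+model = whisper.load_model("medium.en")
+
+with sr.Microphone(sample_rate=16000) as source:
+    # 1. Adjust for background noise by sampling the ambient energy level.
+    recogniser.adjust_for_ambient_noise(source)
+
+    # 2. Collect a phrase: listen() returns once the speech energy drops for
+    #    `pause_threshold` seconds, or when the phrase time limit is reached.
+    audio = recogniser.listen(source, timeout=5, phrase_time_limit=10)
+
+# 3. Run inference: convert 16-bit PCM to normalised float32 and transcribe.
+pcm = np.frombuffer(audio.get_wav_data(), dtype=np.int16)
+text = model.transcribe(pcm.astype(np.float32) / 32768.0)["text"]
+print(text)
+```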
+
+The package can input from the following sources:
+
+- On-board or external microphone on device
+- Audio data from ROS topic (WORK IN PROGRESS)
+
+The package can output transcriptions to:
+
+- Standard output
+- A ROS topic
+
+## ROS Definitions
+
+### Launch Files
+
+This package has no launch files.
+
+### Messages
+
+This package has no messages.
+
+### Services
+
+This package has no services.
+
+### Actions
+
+This package has no actions.
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
new file mode 100644
index 000000000..000678d06
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+import os
+import sounddevice  # needed to remove ALSA error messages
+import argparse
+from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from timeit import default_timer as timer
+
+import numpy as np
+import torch
+
+import rclpy
+from rclpy.node import Node
+from rclpy.action.server import ActionServer, CancelResponse
+
+import speech_recognition as sr  # type: ignore
+from lasr_speech_recognition_interfaces.action import TranscribeSpeech  # type: ignore
+from rclpy.executors import ExternalShutdownException
+from std_msgs.msg import String  # type: ignore
+from src import ModelCache  # type: ignore
+
+# TODO: argparse -> ROS 2 params; test behaviour of preemption
+
+
+@dataclass
+class speech_model_params:
+    """Class for storing speech recognition model parameters.
+
+    Args:
+        model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en".
+            Must be a valid Whisper model name.
+        device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
+        start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0.
+        phrase_duration (Optional[float]): Max number of seconds of the phrase. Defaults to 10 seconds.
+        sample_rate (int): Sample rate of the microphone. Defaults to 16000Hz.
+        mic_device (Optional[str]): Microphone device index or name. Defaults to None.
+        timer_duration (Optional[int]): Duration of the timer for adjusting the microphone for ambient noise. Defaults to 20 seconds.
+        warmup (bool): Whether to warm up the model by running inference on a test file. Defaults to True.
+        energy_threshold (Optional[int]): Energy threshold for silence detection. Using this disables automatic adjustment. Defaults to None.
+        pause_threshold (Optional[float]): Seconds of non-speaking audio before a phrase is considered complete. Defaults to 2.0 seconds.
+ """ + + model_name: str = "medium.en" + device: str = "cuda" if torch.cuda.is_available() else "cpu" + start_timeout: float = 5.0 + phrase_duration: Optional[float] = 10 + sample_rate: int = 16000 + mic_device: Optional[str] = None + timer_duration: Optional[int] = 20 + warmup: bool = True + energy_threshold: Optional[int] = None + pause_threshold: Optional[float] = 2.0 + + +class TranscribeSpeechAction(Node): + # create messages that are used to publish feedback/result + _feedback = TranscribeSpeech.Feedback() + _result = TranscribeSpeech.Result() + + def __init__( + self, + action_name: str, + model_params: speech_model_params, + ) -> None: + """Starts an action server for transcribing speech. + + Args: + action_name (str): Name of the action server. + """ + Node.__init__(self, "transcribe_speech_action") + self._action_name = action_name + self._model_params = model_params + self._transcription_server = self.create_publisher( + String, "/live_speech_transcription", 10 + ) + + model_cache = ModelCache() + self._model = model_cache.load_model( + self._model_params.model_name, + self._model_params.device, + self._model_params.warmup, + ) + # Configure the speech recogniser object and adjust for ambient noise + self.recogniser = self._configure_recogniser() + + # Set up the action server and register execution callback + self._action_server = ActionServer( + self, + TranscribeSpeech, + self._action_name, + execute_callback=self.execute_cb, + cancel_callback=self.cancel_cb, + # auto_start=False, # not required in ROS2 ?? (cb is async) + ) + self._action_server.register_cancel_callback(self.cancel_cb) + self._listening = False + + # self._action_server.start() # not required in ROS2 + self.get_logger().info(f"Speech Action server {self._action_name} started") + + def _configure_microphone(self) -> sr.Microphone: + """Configures the microphone for listening to speech based on the + microphone device index or name. + + Returns: microphone object + """ + + if self._model_params.mic_device is None: + # If no microphone device is specified, use the system default microphone + return sr.Microphone(sample_rate=self._model_params.sample_rate) + elif self._model_params.mic_device.isdigit(): + return sr.Microphone( + device_index=int(self._model_params.mic_device), + sample_rate=self._model_params.sample_rate, + ) + else: + microphones = enumerate(sr.Microphone.list_microphone_names()) + for index, name in microphones: + if self._model_params.mic_device in name: + return sr.Microphone( + device_index=index, + sample_rate=self._model_params.sample_rate, + ) + raise ValueError( + f"Could not find microphone with name: {self._model_params.mic_device}" + ) + + def _configure_recogniser( + self, + energy_threshold: Optional[float] = None, + pause_threshold: Optional[float] = None, + ) -> sr.Recognizer: + """Configures the speech recogniser object. + + Args: + energy_threshold (float): Energy threshold for silence detection. Using this disables automatic adjustment. + pause_threshold (float): Seconds of non-speaking audio before a phrase is considered complete. + + Returns: + sr.Recognizer: speech recogniser object. 
+ """ + self._listening = True + recogniser = sr.Recognizer() + + if pause_threshold: + recogniser.pause_threshold = pause_threshold + + elif self._model_params.pause_threshold: + recogniser.pause_threshold = self._model_params.pause_threshold + + if energy_threshold: + recogniser.dynamic_energy_threshold = False + recogniser.energy_threshold = energy_threshold + return recogniser + + if self._model_params.energy_threshold: + recogniser.dynamic_energy_threshold = False + recogniser.energy_threshold = self._model_params.energy_threshold + return recogniser + + with self._configure_microphone() as source: + recogniser.adjust_for_ambient_noise(source) + self._listening = False + return recogniser + + def cancel_cb(self, goal_handle) -> CancelResponse: + """Callback for cancelling the action server. + Sets server to 'canceled' state. + """ + cancel_str = f"{self._action_name} has been cancelled" + self.get_logger().info(cancel_str) + self._result.sequence = cancel_str + + # self._action_server.set_preempted(result=self._result, text=cancel_str) + goal_handle.canceled() + + return CancelResponse.ACCEPT # TODO decide if always accept cancellation + + async def execute_cb(self, goal_handle) -> None: + """Callback for executing the action server. + + Checks for cancellation before listening and before and after transcribing, returning + if cancellation is requested. + + Args: + :param goal_handle: handles the goal request, and provides access to the goal parameters + """ + + goal = goal_handle.request + + self.get_logger().info("Request Received") + if goal_handle.is_cancel_requested: + return + + if goal.energy_threshold > 0.0 and goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + goal.energy_threshold, goal.max_phrase_limit + ) + elif goal.energy_threshold > 0.0: + self.recogniser = self._configure_recogniser(goal.energy_threshold) + elif goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + pause_threshold=goal.max_phrase_limit + ) + + with self._configure_microphone() as src: + self._listening = True + wav_data = self.recogniser.listen( + src, + timeout=self._model_params.start_timeout, + phrase_time_limit=self._model_params.phrase_duration, + ).get_wav_data() + # Magic number 32768.0 is the maximum value of a 16-bit signed integer + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + if goal_handle.is_cancel_requested(): + self._listening = False + self.get_logger().info("Goal was cancelled during execution.") + goal_handle.canceled() + return self._result + + self.get_logger().info(f"Transcribing phrase with Whisper...") + transcription_start_time = timer() + # Cast to fp16 if using GPU + phrase = self._model.transcribe( + float_data, + fp16=self._model_params.device == "cuda", + )["text"] + transcription_end_time = timer() + self.get_logger().info(f"Transcription finished!") + self.get_logger().info( + f"Time taken: {transcription_end_time - transcription_start_time:.2f}s" + ) + self._transcription_server.publish(phrase) + if goal_handle.is_cancel_requested(): + self._listening = False + return + + self._result.sequence = phrase + self.get_logger().info(f"Transcribed phrase: {phrase}") + self.get_logger().info(f"{self._action_name} has succeeded") + + goal_handle.succeed() + + # Have this at the very end to not disrupt the action server + self._listening = False + + return self._result + + +def parse_args() -> dict: + """Parses the command line arguments into a name: value dictinoary. 
+
+    Returns:
+        dict: Dictionary of name: value pairs of command line arguments.
+    """
+    parser = argparse.ArgumentParser(
+        description="Starts an action server for transcribing speech."
+    )
+
+    parser.add_argument(
+        "--action_name",
+        type=str,
+        default="transcribe_speech",
+        help="Name of the action server.",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="medium.en",
+        help="Name of the speech recognition model.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run the model on.",
+    )
+    parser.add_argument(
+        "--start_timeout",
+        type=float,
+        default=5.0,
+        help="Timeout for listening for the start of a phrase.",
+    )
+    parser.add_argument(
+        "--phrase_duration",
+        type=float,
+        default=10,
+        help="Maximum phrase duration after starting listening in seconds.",
+    )
+    parser.add_argument(
+        "--sample_rate",
+        type=int,
+        default=16000,
+        help="Sample rate of the microphone.",
+    )
+    parser.add_argument(
+        "--mic_device",
+        type=str,
+        default=None,
+        help="Microphone device index or name.",
+    )
+    parser.add_argument(
+        "--no_warmup",
+        action="store_true",
+        help="Disable warming up the model by running inference on a test file.",
+    )
+
+    # argparse's `type` must be a callable, so plain int is used here;
+    # the default of None preserves automatic adjustment.
+    parser.add_argument(
+        "--energy_threshold",
+        type=int,
+        default=None,
+        help="Energy threshold for silence detection. Using this disables automatic adjustment.",
+    )
+
+    parser.add_argument(
+        "--pause_threshold",
+        type=float,
+        default=2.0,
+        help="Seconds of non-speaking audio before a phrase is considered complete.",
+    )
+
+    args, unknown = parser.parse_known_args()
+    return vars(args)
+
+
+def configure_model_params(config: dict) -> speech_model_params:
+    """Configures the speech model parameters based on the provided
+    command line parameters.
+
+    Args:
+        config (dict): Command line parameters parsed in dictionary form.
+
+    Returns:
+        speech_model_params: dataclass containing the speech model parameters
+    """
+    model_params = speech_model_params()
+    if config["model_name"]:
+        model_params.model_name = config["model_name"]
+    if config["device"]:
+        model_params.device = config["device"]
+    if config["start_timeout"]:
+        model_params.start_timeout = config["start_timeout"]
+    if config["phrase_duration"]:
+        model_params.phrase_duration = config["phrase_duration"]
+    if config["sample_rate"]:
+        model_params.sample_rate = config["sample_rate"]
+    if config["mic_device"]:
+        model_params.mic_device = config["mic_device"]
+    if config["no_warmup"]:
+        model_params.warmup = False
+    if config["energy_threshold"]:
+        model_params.energy_threshold = config["energy_threshold"]
+    if config["pause_threshold"]:
+        model_params.pause_threshold = config["pause_threshold"]
+
+    return model_params
+
+
+def configure_whisper_cache() -> None:
+    """Configures the whisper cache directory."""
+    whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper")
+    os.makedirs(whisper_cache, exist_ok=True)
+    # Environment variable required to run whisper locally
+    os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
+
+
+def main(args=None):
+    rclpy.init(args=args)
+
+    configure_whisper_cache()
+    config = parse_args()
+
+    server = TranscribeSpeechAction(
+        config["action_name"], configure_model_params(config)
+    )
+
+    try:
+        rclpy.spin(server)
+    except (KeyboardInterrupt, ExternalShutdownException):
+        pass
diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml
new file mode 100644
index 000000000..1cac47617
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/package.xml
@@ -0,0 +1,23 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>lasr_speech_recognition_whisper</name>
+  <version>0.0.0</version>
+  <description>Speech recognition implemented using OpenAI Whisper</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
+
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
+
+  <depend>lasr_speech_recognition_interfaces</depend>
+  <exec_depend>rclpy</exec_depend>
+  <exec_depend>std_msgs</exec_depend>
+
+  <export>
+    <build_type>ament_python</build_type>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
+</package>
diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in
new file mode 100644
index 000000000..1d515543e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/requirements.in
@@ -0,0 +1,6 @@
+SpeechRecognition==3.10.0
+sounddevice==0.4.6
+openai-whisper==20231117
+PyAudio~=0.2.13
+PyYaml==6.0.1
+setuptools==60.0.1
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt
new file mode 100644
index 000000000..eade8e0a3
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt
@@ -0,0 +1,45 @@
+certifi==2024.2.2 # via requests
+cffi==1.16.0 # via sounddevice
+charset-normalizer==3.3.2 # via requests
+filelock==3.14.0 # via torch, triton
+fsspec==2024.3.1 # via torch
+idna==3.7 # via requests
+jinja2==3.1.4 # via torch
+llvmlite==0.42.0 # via numba
+markupsafe==2.1.5 # via jinja2
+more-itertools==10.2.0 # via openai-whisper
+mpmath==1.3.0 # via sympy
+networkx==3.3 # via torch
+numba==0.59.1 # via openai-whisper
+numpy==1.26.4 # via numba, openai-whisper
+nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105 # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105 # via torch
+nvidia-cuda-runtime-cu12==12.1.105 # via torch
+nvidia-cudnn-cu12==8.9.2.26 # via torch
+nvidia-cufft-cu12==11.0.2.54 # via torch
+nvidia-curand-cu12==10.3.2.106 # via torch
+nvidia-cusolver-cu12==11.4.5.107 # via torch
+nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.20.5 # via torch
+nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105 # via torch
+openai-whisper==20231117 # via -r requirements.in
+pyaudio==0.2.13 # via -r requirements.in
+pycparser==2.22 # via cffi
+pyyaml==6.0.1 # via -r requirements.in
+regex==2024.4.28 # via tiktoken
+requests==2.31.0 # via speechrecognition, tiktoken
+six==1.16.0 # via python-dateutil
+sounddevice==0.4.6 # via -r requirements.in
+speechrecognition==3.10.0 # via -r requirements.in
+sympy==1.12 # via torch
+tiktoken==0.6.0 # via openai-whisper
+torch==2.3.0 # via openai-whisper
+tqdm==4.66.4 # via openai-whisper
+triton==2.3.0 # via openai-whisper, torch
+typing-extensions==4.11.0 # via torch
+urllib3==2.2.1 # via requests
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools == 60.0.1
diff --git a/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper b/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py b/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
new file mode 100755
index 000000000..a3ce21904
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+import speech_recognition as sr
+import sounddevice  # needed to remove ALSA error messages
+
+
+def main():
+    microphones = enumerate(sr.Microphone.list_microphone_names())
+
+    print("\nAvailable microphones:")
+    for index, name in microphones:
+        print(f"[{index}] {name}")
+
+    # Uncomment for debugging, to see if sounddevice recognises the microphone as well
+    # print("Available microphone devices (sounddevice):")
+    # print(sounddevice.query_devices())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
new file mode 100755
index 000000000..026ab2875
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import torch
+import numpy as np
+from pathlib import Path
+import speech_recognition as sr
+from src import ModelCache  # type: ignore
+import sounddevice  # needed to remove ALSA error messages
+from typing import Dict
+import rclpy
+
+# TODO: argparse -> ROS params
+
+
+def parse_args() -> Dict:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--device_index", help="Microphone index", type=int, default=None
+    )
+    return vars(parser.parse_args())
+
+
+def configure_whisper_cache() -> None:
+    """Configures the whisper cache directory."""
+    whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper")
+    os.makedirs(whisper_cache, exist_ok=True)
+    # Environment variable required to run whisper locally
+    os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
+
+
+def main(args=None):
+    rclpy.init(args=args)  # Have to initialise rclpy for the ModelCache node
+
+    configure_whisper_cache()
+    args = parse_args()
+
+    recognizer = sr.Recognizer()
+    recognizer.pause_threshold = 2
+    microphone = sr.Microphone(device_index=args["device_index"], sample_rate=16000)
+    threshold = 100
+    recognizer.dynamic_energy_threshold = False
+    recognizer.energy_threshold = threshold
+    model_cache = ModelCache()
+    transcription_model = model_cache.load_model(
+        "medium.en", "cuda" if torch.cuda.is_available() else "cpu", True
+    )
+    transcription_result = "The quick brown fox jumps over the lazy dog."
+    while transcription_result != "":
+        print("Listening...")
+        with microphone as source:
+            wav_data = recognizer.listen(
+                source, phrase_time_limit=10, timeout=5
+            ).get_wav_data()
+        print("Processing...")
+        # Magic number 32768.0 is the maximum value of a 16-bit signed integer
+        float_data = (
+            np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+            / 32768.0
+        )
+
+        # Cast to fp16 if using GPU
+        transcription_result = transcription_model.transcribe(
+            float_data, fp16=torch.cuda.is_available()
+        )["text"]
+
+        print(
+            f"Transcription: {transcription_result} at energy threshold {recognizer.energy_threshold}"
+        )
+        threshold += 100
+        recognizer.energy_threshold = threshold
+
+
+if __name__ == "__main__":
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
new file mode 100755
index 000000000..d14144e21
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+import os
+import argparse
+import speech_recognition as sr
+import rclpy
+import sounddevice  # needed to remove ALSA error messages
+
+# TODO: argparse -> ROS params
+
+
+def parse_args() -> dict:
+    """Parse command line arguments into a dictionary.
+
+    Returns:
+        dict: name: value pairs of command line arguments
+    """
+
+    parser = argparse.ArgumentParser(description="Test microphones")
+    parser.add_argument(
+        "-m", "--microphone", type=int, help="Microphone index", default=None
+    )
+    parser.add_argument(
+        "-o", "--output_dir", type=str, help="Directory to save audio files"
+    )
+
+    args, _ = parser.parse_known_args()
+    return vars(args)
+
+
+def main(args: dict = None) -> None:
+    """Generate audio files from microphone input.
+
+    Args:
+        args (dict): dictionary of command line arguments.
+ """ + + # Adapted from https://github.com/Uberi/speech_recognition/blob/master/examples/write_audio.py + + rclpy.init(args=args) + + parser_args = parse_args() + + mic_index = parser_args["microphone"] + output_dir = parser_args["output_dir"] + + r = sr.Recognizer() + r.pause_threshold = 2 + microphone = sr.Microphone(device_index=mic_index, sample_rate=16000) + with microphone as source: + print("Say something!") + audio = r.listen(source, timeout=5, phrase_time_limit=10) + print("Finished listening") + + with open(os.path.join(output_dir, "microphone.raw"), "wb") as f: + f.write(audio.get_raw_data()) + + with open(os.path.join(output_dir, "microphone.wav"), "wb") as f: + f.write(audio.get_wav_data()) + + with open(os.path.join(output_dir, "microphone.flac"), "wb") as f: + f.write(audio.get_flac_data()) + + with open(os.path.join(output_dir, "microphone.aiff"), "wb") as f: + f.write(audio.get_aiff_data()) + + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py new file mode 100755 index 000000000..2448e73ec --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -0,0 +1,67 @@ +#!/usr/bin python3 +import rclpy +from rclpy.node import Node +from rclpy.action import ActionClient +from lasr_speech_recognition_interfaces.srv import TranscribeAudio # type: ignore +from lasr_speech_recognition_interfaces.action import TranscribeSpeech + +# https://docs.ros2.org/latest/api/rclpy/api/actions.html + + +class TestSpeechServerClient(Node): + def __init__(self): + Node.__init__(self, "listen_action_client") + + self.client = ActionClient(self, TranscribeSpeech, "transcribe_speech") + self.goal_future = None + self.result_future = None + + def send_goal(self, goal): + self.get_logger().info("Waiting for Whisper server...") + self.client.wait_for_server() + self.get_logger().info("Server activated, sending goal...") + + self.goal_future = self.client.send_goal_async( + goal, feedback_callback=self.feedback_cb + ) # Returns a Future instance when the goal request has been accepted or rejected. 
+        self.goal_future.add_done_callback(
+            self.response_cb
+        )  # When a response is received
+
+    def feedback_cb(self, msg):
+        self.get_logger().info(f"Received feedback: {msg.feedback}")
+
+    def response_cb(self, future):
+        handle = future.result()
+        if not handle.accepted:
+            self.get_logger().info("Goal was rejected")
+            return
+
+        self.get_logger().info("Goal was accepted")
+        self.result_future = (
+            handle.get_result_async()
+        )  # Not using get_result() in the callback, as it can cause a deadlock according to the docs
+        self.result_future.add_done_callback(self.result_cb)
+
+    def result_cb(self, future):
+        result = future.result().result
+        self.get_logger().info(f"Transcribed Speech: {result.sequence}")
+
+
+def main(args=None):
+    rclpy.init(args=args)
+    goal = TranscribeSpeech.Goal()
+    client = TestSpeechServerClient()
+    try:
+        client.send_goal(goal)
+        rclpy.spin(client)
+    except KeyboardInterrupt:
+        client.get_logger().info("Shutting down...")
+    finally:
+        client.destroy_node()
+        rclpy.shutdown()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg
new file mode 100644
index 000000000..1f6a54400
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir = $base/lib/lasr_speech_recognition_whisper
+[install]
+install_scripts = $base/lib/lasr_speech_recognition_whisper
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py
new file mode 100644
index 000000000..c6a801483
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/setup.py
@@ -0,0 +1,30 @@
+from setuptools import find_packages, setup
+
+package_name = "lasr_speech_recognition_whisper"
+
+setup(
+    name=package_name,
+    version="0.0.0",
+    packages=find_packages(exclude=["test"]),
+    data_files=[
+        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+        ("share/" + package_name, ["package.xml"]),
+    ],
+    install_requires=["setuptools"],
+    zip_safe=True,
+    maintainer="maayan",
+    maintainer_email="maayan.armony@gmail.com",
+    description="Speech recognition implemented using OpenAI Whisper",
+    license="MIT",
+    tests_require=["pytest"],
+    entry_points={
+        "console_scripts": [
+            "transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main",
+            "transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main",
+            "simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main",
+            "list_microphones = scripts.list_microphones:main",
+            "microphone_tuning_test = scripts.microphone_tuning_test:main",
+            "test_microphones = scripts.test_microphones:main",
+            "test_speech_server = scripts.test_speech_server:main",
+        ],
+    },
+)
diff --git a/common/speech/lasr_speech_recognition_whisper/src/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/__init__.py
new file mode 100644
index 000000000..ca8a17393
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/src/__init__.py
@@ -0,0 +1 @@
+from .lasr_speech_recognition_whisper.cache import ModelCache
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
new file mode 100644
index 000000000..f662b86a0
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
@@ -0,0 +1 @@
+from .cache import ModelCache
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
new file mode 100644
index 000000000..259ffffa5
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
@@ -0,0 +1,57 @@
+import os
+
+import whisper  # type: ignore
+from ament_index_python import packages
+from rclpy.node import Node
+
+# Keep all loaded models in memory
+MODEL_CACHE = {}
+
+
+class ModelCache(Node):
+    def __init__(self):
+        super().__init__("lasr_speech_recognition_whisper_cache")
+
+    def load_model(
+        self, name: str, device: str = "cpu", load_test_file: bool = False
+    ) -> whisper.Whisper:
+        """Loads a whisper model from disk, or from cache if it has already been loaded.
+
+        Args:
+            name (str): Name of the whisper model. Must be the name of an official whisper
+                model, or the path to a model checkpoint.
+            device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'.
+            load_test_file (bool, optional): Whether to run inference on a test audio file
+                after loading the model (if model is not in cache). Defaults to False. Test file
+                is assumed to be called "test.m4a" and be in the root of the package directory.
+
+        Returns:
+            whisper.Whisper: Whisper model instance
+        """
+        global MODEL_CACHE
+
+        if name not in MODEL_CACHE:
+            self.get_logger().info(f"Loading model {name}")
+            MODEL_CACHE[name] = whisper.load_model(name, device=device)
+            self.get_logger().info(f"Successfully loaded model {name} on {device}")
+            if load_test_file:
+                package_install = packages.get_package_prefix(
+                    "lasr_speech_recognition_whisper"
+                )
+                package_root = os.path.abspath(
+                    os.path.join(
+                        package_install,
+                        os.pardir,
+                        os.pardir,
+                        "lasr_speech_recognition_whisper",
+                    )
+                )
+                example_fp = os.path.join(package_root, "test.m4a")
+                self.get_logger().info(
+                    "Running transcription on example file to ensure model is loaded..."
+                )
+                # transcribe() returns a dict; the transcription itself is under "text"
+                test_result: str = MODEL_CACHE[name].transcribe(
+                    example_fp, fp16=device == "cuda"
+                )["text"]
+                self.get_logger().info(f"Transcription test result: {test_result}")
+
+        return MODEL_CACHE[name]
diff --git a/common/speech/lasr_speech_recognition_whisper/test.m4a b/common/speech/lasr_speech_recognition_whisper/test.m4a
new file mode 100644
index 000000000..1fbef3f08
Binary files /dev/null and b/common/speech/lasr_speech_recognition_whisper/test.m4a differ
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
new file mode 100644
index 000000000..ceffe896d
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
@@ -0,0 +1,27 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_copyright.main import main
+import pytest
+
+
+# Remove the `skip` decorator once the source file(s) have a copyright header
+@pytest.mark.skip(
+    reason="No copyright header has been placed in the generated source file."
+)
+@pytest.mark.copyright
+@pytest.mark.linter
+def test_copyright():
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found errors"
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
new file mode 100644
index 000000000..ee79f31ac
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
@@ -0,0 +1,25 @@
+# Copyright 2017 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_flake8.main import main_with_errors
+import pytest
+
+
+@pytest.mark.flake8
+@pytest.mark.linter
+def test_flake8():
+    rc, errors = main_with_errors(argv=[])
+    assert rc == 0, "Found %d code style errors / warnings:\n" % len(
+        errors
+    ) + "\n".join(errors)
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
new file mode 100644
index 000000000..a2c3deb8e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
@@ -0,0 +1,23 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_pep257.main import main
+import pytest
+
+
+@pytest.mark.linter
+@pytest.mark.pep257
+def test_pep257():
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found code style errors / warnings"
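A quick end-to-end smoke test of the new action server (assumes the workspace is built and sourced and a microphone is attached; goal fields of 0.0 keep the server-side defaults):

```bash
colcon build --packages-select lasr_speech_recognition_interfaces lasr_speech_recognition_whisper
source install/setup.bash

# Terminal 1: start the Whisper action server
ros2 run lasr_speech_recognition_whisper transcribe_microphone_server

# Terminal 2: send a goal and watch the live transcription topic
ros2 action send_goal /transcribe_speech lasr_speech_recognition_interfaces/action/TranscribeSpeech "{energy_threshold: 0.0, max_phrase_limit: 0.0}"
ros2 topic echo /live_speech_transcription
```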