From 13f6568c7f063db34251daf53cc67065c2f2c1ad Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Wed, 2 Oct 2024 17:43:48 +0100 Subject: [PATCH 01/14] Port initial speech interfaces --- .../CMakeLists.txt | 36 ++++ .../LICENSE | 17 ++ .../README.md | 51 +++++ .../action/TranscribeSpeech.action | 11 + .../msg/Transcription.msg | 2 + .../package.xml | 23 ++ .../srv/TranscribeAudio.srv | 2 + .../lasr_speech_recognition_whisper/LICENSE | 202 ++++++++++++++++++ 8 files changed, 344 insertions(+) create mode 100644 common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt create mode 100644 common/speech/lasr_speech_recognition_interfaces/LICENSE create mode 100644 common/speech/lasr_speech_recognition_interfaces/README.md create mode 100644 common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action create mode 100644 common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg create mode 100644 common/speech/lasr_speech_recognition_interfaces/package.xml create mode 100644 common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv create mode 100644 common/speech/lasr_speech_recognition_whisper/LICENSE diff --git a/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt new file mode 100644 index 000000000..ac4de94d7 --- /dev/null +++ b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.8) +project(lasr_speech_recognition_interfaces) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +# uncomment the following section in order to fill in +# further dependencies manually. +# find_package( REQUIRED) + +# For actions, messages, and services +find_package(rosidl_default_generators REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "action/TranscribeSpeech.action" + "msg/Transcription.msg" + "srv/TranscribeAudio.srv" + DEPENDENCIES builtin_interfaces # Add packages that above messages depend on +) + +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + # the following line skips the linter which checks for copyrights + # comment the line when a copyright and license is added to all source files + set(ament_cmake_copyright_FOUND TRUE) + # the following line skips cpplint (only works in a git repo) + # comment the line when this package is in a git repo and when + # a copyright and license is added to all source files + set(ament_cmake_cpplint_FOUND TRUE) + ament_lint_auto_find_test_dependencies() +endif() + +ament_package() diff --git a/common/speech/lasr_speech_recognition_interfaces/LICENSE b/common/speech/lasr_speech_recognition_interfaces/LICENSE new file mode 100644 index 000000000..30e8e2ece --- /dev/null +++ b/common/speech/lasr_speech_recognition_interfaces/LICENSE @@ -0,0 +1,17 @@ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md
new file mode 100644
index 000000000..1378bcdf4
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/README.md
@@ -0,0 +1,51 @@
+# lasr_speech_recognition_interfaces
+
+Common messages, services, and actions used for speech recognition.
+
+This package is maintained by:
+- [Maayan Armony](mailto:maayan.armony@gmail.com)
+- [Paul Makles](mailto:me@insrt.uk) (ROS1)
+
+## Prerequisites
+
+This package depends on the following ROS packages:
+- colcon (buildtool)
+- rosidl_default_generators (build)
+- rosidl_default_runtime (exec)
+
+## Usage
+
+Ask the package maintainer to write a `doc/USAGE.md` for their package!
+
+## Example
+
+Ask the package maintainer to write a `doc/EXAMPLE.md` for their package!
+
+## Technical Overview
+
+Ask the package maintainer to write a `doc/TECHNICAL.md` for their package!
+
+## ROS Definitions
+
+### Launch Files
+
+This package has no launch files.
+
+### Messages
+
+#### `Transcription`
+
+| Field | Type | Description |
+|:-:|:-:|---|
+| phrase | string | Transcribed text |
+| finished | bool | Whether the phrase is complete |
+
+### Services
+
+#### `TranscribeAudio`
+
+The request is empty; the response carries the transcription.
+
+| Field | Type | Description |
+|:-:|:-:|---|
+| phrase | string | Transcribed text (response) |
+
+### Actions
+
+#### `TranscribeSpeech`
+
+| Field | Type | Description |
+|:-:|:-:|---|
+| energy_threshold | float32 | Energy threshold for silence detection (goal) |
+| max_phrase_limit | float32 | Maximum phrase duration in seconds (goal) |
+| sequence | string | Transcribed text (result and feedback) |
diff --git a/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
new file mode 100644
index 000000000..5cac9317e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
@@ -0,0 +1,11 @@
+# Energy threshold
+float32 energy_threshold
+
+# Max phrase duration
+float32 max_phrase_limit
+---
+# result definition
+string sequence
+---
+# feedback
+string sequence
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
new file mode 100644
index 000000000..9c7483636
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
@@ -0,0 +1,2 @@
+string phrase
+bool finished
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml
new file mode 100644
index 000000000..b15638eb1
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -0,0 +1,23 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>lasr_speech_recognition_interfaces</name>
+  <version>0.0.0</version>
+  <description>Common messages used for speech recognition</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
+
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <build_depend>rosidl_default_generators</build_depend>
+  <depend>action_msgs</depend>
+  <exec_depend>rosidl_default_runtime</exec_depend>
+  <member_of_group>rosidl_interface_packages</member_of_group>
+
+  <test_depend>ament_lint_auto</test_depend>
+  <test_depend>ament_lint_common</test_depend>
+
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
+</package>
diff --git a/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
new file mode 100644
index 000000000..f416a67c4
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
@@ -0,0 +1,2 @@
+---
+string phrase
\ No newline at end of file
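For orientation, a minimal subscriber sketch for the `Transcription` message (assuming a topic named `transcription`, as used by the whisper nodes in the next patch, and that this interfaces package has been built; not part of the patch itself):

    import rclpy
    from rclpy.node import Node

    from lasr_speech_recognition_interfaces.msg import Transcription


    class TranscriptionListener(Node):
        def __init__(self) -> None:
            super().__init__("transcription_listener")
            self.create_subscription(
                Transcription, "transcription", self.on_transcription, 10
            )

        def on_transcription(self, msg: Transcription) -> None:
            # "finished" marks the end of a phrase; partial phrases may arrive before it
            state = "final" if msg.finished else "partial"
            self.get_logger().info(f"[{state}] {msg.phrase}")


    def main() -> None:
        rclpy.init()
        rclpy.spin(TranscriptionListener())


    if __name__ == "__main__":
        main()

diff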
--git a/common/speech/lasr_speech_recognition_whisper/LICENSE b/common/speech/lasr_speech_recognition_whisper/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From e30f395d9aaf1e84a3a9e8fd02208591e54a91bd Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Mon, 7 Oct 2024 17:01:37 +0100 Subject: [PATCH 02/14] Start porting whisper package setup and nodes --- .../lasr_speech_recognition_whisper/LICENSE | 219 +--------- .../__init__.py | 0 .../nodes/simple_transcribe_microphone | 80 ++++ .../nodes/transcribe_microphone | 80 ++++ .../nodes/transcribe_microphone_server | 374 ++++++++++++++++++ .../package.xml | 30 ++ .../requirements.in | 6 + .../requirements.txt | 51 +++ .../resource/lasr_speech_recognition_whisper | 0 .../scripts/list_microphones.py | 8 + .../scripts/microphone_tuning_test.py | 67 ++++ .../scripts/repeat_after_me.py | 58 +++ .../scripts/test_microphones.py | 57 +++ .../scripts/test_speech_server.py | 21 + .../lasr_speech_recognition_whisper/setup.cfg | 4 + .../lasr_speech_recognition_whisper/setup.py | 25 ++ .../__init__.py | 12 + .../bytesfifo.py | 137 +++++++ .../lasr_speech_recognition_whisper/cache.py | 43 ++ .../collector.py | 131 ++++++ .../lasr_speech_recognition_whisper/source.py | 57 +++ .../lasr_speech_recognition_whisper/worker.py | 203 ++++++++++ .../test/test_copyright.py | 25 ++ .../test/test_flake8.py | 25 ++ .../test/test_pep257.py | 23 ++ 25 files changed, 1534 insertions(+), 202 deletions(-) create mode 100644 common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py create mode 100644 common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone create mode 100644 common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone create mode 100644 common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server create mode 100644 common/speech/lasr_speech_recognition_whisper/package.xml create mode 100644 common/speech/lasr_speech_recognition_whisper/requirements.in create mode 100644 common/speech/lasr_speech_recognition_whisper/requirements.txt create mode 100644 common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py create mode 100644 common/speech/lasr_speech_recognition_whisper/setup.cfg create mode 100644 common/speech/lasr_speech_recognition_whisper/setup.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py create mode 100644 
common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py create mode 100644 common/speech/lasr_speech_recognition_whisper/test/test_copyright.py create mode 100644 common/speech/lasr_speech_recognition_whisper/test/test_flake8.py create mode 100644 common/speech/lasr_speech_recognition_whisper/test/test_pep257.py diff --git a/common/speech/lasr_speech_recognition_whisper/LICENSE b/common/speech/lasr_speech_recognition_whisper/LICENSE index d64569567..30e8e2ece 100644 --- a/common/speech/lasr_speech_recognition_whisper/LICENSE +++ b/common/speech/lasr_speech_recognition_whisper/LICENSE @@ -1,202 +1,17 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
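The nodes added below are still mid-port (several actionlib idioms remain and are flagged with TODOs). For reference, a caller-side sketch against a fully ported server might look like the following, assuming a goal/result layout matching `TranscribeSpeech.action` from the previous patch, an action named `transcribe_speech`, and the ROS 2 interfaces package name `lasr_speech_recognition_interfaces` (the work-in-progress code below still imports `lasr_speech_recognition_msgs`):

    import rclpy
    from rclpy.action import ActionClient
    from rclpy.node import Node

    from lasr_speech_recognition_interfaces.action import TranscribeSpeech


    class TranscribeSpeechClient(Node):
        def __init__(self) -> None:
            super().__init__("transcribe_speech_client")
            self._client = ActionClient(self, TranscribeSpeech, "transcribe_speech")

        def transcribe_once(self) -> str:
            # Zero-valued goal fields leave the server's current thresholds untouched
            goal = TranscribeSpeech.Goal(energy_threshold=0.0, max_phrase_limit=0.0)
            self._client.wait_for_server()
            send_future = self._client.send_goal_async(goal)
            rclpy.spin_until_future_complete(self, send_future)
            result_future = send_future.result().get_result_async()
            rclpy.spin_until_future_complete(self, result_future)
            return result_future.result().result.sequence


    def main() -> None:
        rclpy.init()
        node = TranscribeSpeechClient()
        print(f"Transcription: {node.transcribe_once()}")
        node.destroy_node()
        rclpy.shutdown()


    if __name__ == "__main__":
        main()
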
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone
new file mode 100644
index 000000000..62342afc9
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import rclpy
+import speech_recognition as sr
+from ament_index_python.packages import get_package_share_directory
+
+from lasr_speech_recognition_interfaces.srv import TranscribeAudio
+from lasr_speech_recognition_whisper import load_model
+
+MODEL = "medium.en"  # Whisper model
+TIMEOUT = 5.0  # Timeout for listening for the start of a phrase
+PHRASE_TIME_LIMIT = None  # Timeout for listening for the end of a phrase
+
+WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
+os.makedirs(WHISPER_CACHE, exist_ok=True)
+os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
+
+if len(sys.argv) < 3:
+    print("Usage:")
+    print("ros2 run lasr_speech_recognition_whisper simple_transcribe_microphone by-index <device_index>")
+    print("ros2 run lasr_speech_recognition_whisper simple_transcribe_microphone by-name <device_name>")
+    exit(1)
+
+matcher = sys.argv[1]
+device_index = None
+if matcher == "by-index":
+    device_index = int(sys.argv[2])
+elif matcher == "by-name":
+    target_name = sys.argv[2]
+    for index, name in enumerate(sr.Microphone.list_microphone_names()):
+        if target_name in name:
+            device_index = index
+            break
+
+    if device_index is None:
+        print("Could not find device!")
+        exit(1)
+else:
+    print("Invalid matcher")
+    exit(1)
+
+rclpy.init(args=sys.argv)
+node = rclpy.create_node("transcribe_mic")  # ROS 2 has no anonymous nodes
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = load_model(MODEL, device=device)
+
+# Try to run inference on the example file to ensure the model is loaded
+# (test.m4a must be installed to the package share directory)
+EXAMPLE_FILE = os.path.join(
+    get_package_share_directory("lasr_speech_recognition_whisper"), "test.m4a"
+)
+node.get_logger().info("Running transcription on example file to ensure model is loaded...")
+node.get_logger().info(str(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available())))
+
+microphone = sr.Microphone(device_index=device_index, sample_rate=16000)
+r = sr.Recognizer()
+with microphone as source:
+    r.adjust_for_ambient_noise(source)
+
+
+def handle_transcribe_audio(_, response):
+    with microphone as source:
+        wav_data = r.listen(source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT).get_wav_data()
+        float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") / 32768.0
+
+    response.phrase = model.transcribe(float_data, fp16=device == "cuda")["text"]
+    return response
+
+
+# rclpy signature: create_service(srv_type, srv_name, callback)
+node.create_service(TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio)
+
+node.get_logger().info("Whisper service ready")
+rclpy.spin(node)
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone
new file mode 100644
index 000000000..ae86ef310
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+import os
+import torch
+import rospkg
+from pathlib import Path
+
+WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
+os.makedirs(WHISPER_CACHE, exist_ok=True)
+os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
+
+import sys
+
+# TODO port node (still ROS 1 code below)
+
+if len(sys.argv) < 3:
+    print('Usage:')
+    print('rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>')
+    print('rosrun lasr_speech_recognition transcribe_microphone by-name <device_name>')
+    exit(1)
+else:
+    matcher = sys.argv[1]
+    device_index = None
+    if matcher == 'by-index':
+        device_index = int(sys.argv[2])
+    elif matcher == 'by-name':
+        import speech_recognition as sr
+        microphones = enumerate(sr.Microphone.list_microphone_names())
+
+        target_name = sys.argv[2]
+        for index, name in microphones:
+            if target_name in name:
+                device_index = index
+                break
+
+        if device_index is None:
+            print('Could not find device!')
+            exit(1)
+    else:
+        print('Invalid matcher')
+        exit(1)
+
+import rospy
+from std_srvs.srv import Empty, EmptyResponse
+rospy.init_node('transcribe_mic', anonymous=True)
+
+from lasr_speech_recognition_whisper import SpeechRecognitionToTopic, MicrophonePhraseCollector, load_model
+
+collector = MicrophonePhraseCollector(device_index=device_index)
+collector.adjust_for_noise()
+
+model = load_model("medium.en")
+
+# try to run inference on the example file
+r = rospkg.RosPack()
+EXAMPLE_FILE = r.get_path('lasr_speech_recognition_whisper') + "/test.m4a"
+rospy.loginfo("Running transcription on example file to ensure model is loaded...")
+rospy.loginfo(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available()))
+
+worker = SpeechRecognitionToTopic(collector, model, "transcription", infer_partial=False)
+
+def adjust_for_noise(_):
+    collector.adjust_for_noise()
+    return EmptyResponse()
+
+def start_listening(_):
+    worker.start()
+    return EmptyResponse()
+
+def stop_listening(_):
+    worker.stop()
+    return EmptyResponse()
+
+rospy.Service('/whisper/adjust_for_noise', Empty, adjust_for_noise)
+rospy.Service('/whisper/start_listening', Empty, start_listening)
+rospy.Service('/whisper/stop_listening', Empty, stop_listening)
+
+rospy.loginfo("Starting the Whisper worker!")
+rospy.spin()
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
new file mode 100644
index 000000000..680f97c5d
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -0,0 +1,374 @@
+#!/usr/bin/env python3
+import os
+import sounddevice  # needed to remove ALSA error messages
+import argparse
+from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from timeit import default_timer as timer
+
+import numpy as np
+import torch
+
+import rclpy
+from rclpy.node import Node
+from rclpy.action import ActionServer
+from rclpy.executors import ExternalShutdownException
+
+import speech_recognition as sr  # type: ignore
+import lasr_speech_recognition_msgs.msg  # type: ignore
+from std_msgs.msg import String  # type: ignore
+from lasr_speech_recognition_whisper import load_model  # type: ignore
+
+# TODO: argparse -> ROS2 params, behaviour of preemption, behaviour of rclpy.spin(node)
+
+@dataclass
+class speech_model_params:
+    """Class for storing speech recognition model parameters.
+
+    Args:
+        model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en".
+            Must be a valid Whisper model name.
+        device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
+        start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0.
+        phrase_duration (Optional[float]): Max number of seconds of the phrase. Defaults to 10 seconds.
+        sample_rate (int): Sample rate of the microphone. Defaults to 16000Hz.
+        mic_device (Optional[str]): Microphone device index or name. Defaults to None.
+        timer_duration (Optional[int]): Duration of the timer for adjusting the microphone for ambient noise. Defaults to 20 seconds.
+        warmup (bool): Whether to warmup the model by running inference on a test file. Defaults to True.
+        energy_threshold (Optional[int]): Energy threshold for silence detection. Using this disables automatic adjustment. Defaults to None.
+        pause_threshold (Optional[float]): Seconds of non-speaking audio before a phrase is considered complete. Defaults to 2.0 seconds.
+    """
+
+    model_name: str = "medium.en"
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    start_timeout: float = 5.0
+    phrase_duration: Optional[float] = 10
+    sample_rate: int = 16000
+    mic_device: Optional[str] = None
+    timer_duration: Optional[int] = 20
+    warmup: bool = True
+    energy_threshold: Optional[int] = None
+    pause_threshold: Optional[float] = 2.0
+
+
+class TranscribeSpeechAction(object):
+    # create messages that are used to publish feedback/result
+    _feedback = lasr_speech_recognition_msgs.msg.TranscribeSpeechFeedback()
+    _result = lasr_speech_recognition_msgs.msg.TranscribeSpeechResult()
+
+    def __init__(
+        self,
+        action_name: str,
+        model_params: speech_model_params,
+    ) -> None:
+        """Starts an action server for transcribing speech.
+
+        Args:
+            action_name (str): Name of the action server.
+            model_params (speech_model_params): Model and microphone configuration.
+        """
+
+        self._action_name = action_name
+        self._model_params = model_params
+        self._transcription_server = node.create_publisher(
+            String, "/live_speech_transcription", 10
+        )
+
+        self._model = load_model(
+            self._model_params.model_name,
+            self._model_params.device,
+            self._model_params.warmup,
+        )
+        # Configure the speech recogniser object and adjust for ambient noise
+        self.recogniser = self._configure_recogniser()
+        # Setup the action server and register execution callback
+        # TODO check behaviour of ActionServer
+        self._action_server = ActionServer(
+            self._action_name,
+            lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
+            execute_cb=self.execute_cb,
+            auto_start=False,
+        )
+        self._action_server.register_preempt_callback(self.preempt_cb)
+        self._listening = False
+
+        self._action_server.start()
+        rclpy.get_logger().info(f"Speech Action server {self._action_name} started")
+
+    def _configure_microphone(self) -> sr.Microphone:
+        """Configures the microphone for listening to speech based on the
+        microphone device index or name.
+
+        Returns: microphone object
+        """
+
+        if self._model_params.mic_device is None:
+            # If no microphone device is specified, use the system default microphone
+            return sr.Microphone(sample_rate=self._model_params.sample_rate)
+        elif self._model_params.mic_device.isdigit():
+            return sr.Microphone(
+                device_index=int(self._model_params.mic_device),
+                sample_rate=self._model_params.sample_rate,
+            )
+        else:
+            microphones = enumerate(sr.Microphone.list_microphone_names())
+            for index, name in microphones:
+                if self._model_params.mic_device in name:
+                    return sr.Microphone(
+                        device_index=index,
+                        sample_rate=self._model_params.sample_rate,
+                    )
+            raise ValueError(
+                f"Could not find microphone with name: {self._model_params.mic_device}"
+            )
+
+    def _configure_recogniser(
+        self,
+        energy_threshold: Optional[float] = None,
+        pause_threshold: Optional[float] = None,
+    ) -> sr.Recognizer:
+        """Configures the speech recogniser object.
+
+        Args:
+            energy_threshold (float): Energy threshold for silence detection. Using this disables automatic adjustment.
+            pause_threshold (float): Seconds of non-speaking audio before a phrase is considered complete.
+
+        Returns:
+            sr.Recognizer: speech recogniser object.
+        """
+        self._listening = True
+        recogniser = sr.Recognizer()
+
+        if pause_threshold:
+            recogniser.pause_threshold = pause_threshold
+        elif self._model_params.pause_threshold:
+            recogniser.pause_threshold = self._model_params.pause_threshold
+
+        if energy_threshold:
+            recogniser.dynamic_energy_threshold = False
+            recogniser.energy_threshold = energy_threshold
+            return recogniser
+
+        if self._model_params.energy_threshold:
+            recogniser.dynamic_energy_threshold = False
+            recogniser.energy_threshold = self._model_params.energy_threshold
+            return recogniser
+
+        with self._configure_microphone() as source:
+            recogniser.adjust_for_ambient_noise(source)
+        self._listening = False
+        return recogniser
+
+    def preempt_cb(self) -> None:
+        """Callback for preempting the action server.
+
+        Sets server to preempted state.
+        """
+        preempted_str = f"{self._action_name} has been preempted"
+        rclpy.get_logger().info(preempted_str)
+        self._result.sequence = preempted_str
+        self._action_server.set_preempted(result=self._result, text=preempted_str)
+
+    def execute_cb(self, goal) -> None:
+        """Callback for executing the action server.
+
+        Checks for preemption before listening and before and after transcribing, returning
+        if preemption is requested.
+
+        Args:
+            goal: UNUSED - actionlib requires a goal argument in the execute callback, but
+                this action server does not use a goal.
+ """ + rclpy.get_logger().info("Request Received") + if self._action_server.is_preempt_requested(): + return + + if goal.energy_threshold > 0.0 and goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + goal.energy_threshold, goal.max_phrase_limit + ) + elif goal.energy_threshold > 0.0: + self.recogniser = self._configure_recogniser(goal.energy_threshold) + elif goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + pause_threshold=goal.max_phrase_limit + ) + + with self._configure_microphone() as src: + self._listening = True + wav_data = self.recogniser.listen( + src, + timeout=self._model_params.start_timeout, + phrase_time_limit=self._model_params.phrase_duration, + ).get_wav_data() + # Magic number 32768.0 is the maximum value of a 16-bit signed integer + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + if self._action_server.is_preempt_requested(): + self._listening = False + return + + rclpy.get_logger().info(f"Transcribing phrase with Whisper...") + transcription_start_time = timer() + # Cast to fp16 if using GPU + phrase = self._model.transcribe( + float_data, + fp16=self._model_params.device == "cuda", + )["text"] + transcription_end_time = timer() + rclpy.get_logger().info(f"Transcription finished!") + rclpy.get_logger().info( + f"Time taken: {transcription_end_time - transcription_start_time:.2f}s" + ) + self._transcription_server.publish(phrase) + if self._action_server.is_preempt_requested(): + self._listening = False + return + + self._result.sequence = phrase + rclpy.get_logger().info(f"Transcribed phrase: {phrase}") + rclpy.get_logger().info(f"{self._action_name} has succeeded") + self._action_server.set_succeeded(self._result) + + # Have this at the very end to not disrupt the action server + self._listening = False + + +def parse_args() -> dict: + """Parses the command line arguments into a name: value dictinoary. + + Returns: + dict: Dictionary of name: value pairs of command line arguments. + """ + parser = argparse.ArgumentParser( + description="Starts an action server for transcribing speech." + ) + + # TODO change to ROS2 rosparams: + # port = node.declare_parameter('port', '/dev/ttyUSB0').value + # assert isinstance(port, str), 'port parameter must be a str' + + + parser.add_argument( + "--action_name", + type=str, + default="transcribe_speech", + help="Name of the action server.", + ) + parser.add_argument( + "--model_name", + type=str, + default="medium.en", + help="Name of the speech recognition model.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run the model on.", + ) + parser.add_argument( + "--start_timeout", + type=float, + default=5.0, + help="Timeout for listening for the start of a phrase.", + ) + parser.add_argument( + "--phrase_duration", + type=float, + default=10, + help="Maximum phrase duration after starting listening in seconds.", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=16000, + help="Sample rate of the microphone.", + ) + parser.add_argument( + "--mic_device", + type=str, + default=None, + help="Microphone device index or name", + ) + parser.add_argument( + "--no_warmup", + action="store_true", + help="Disable warming up the model by running inference on a test file.", + ) + + parser.add_argument( + "--energy_threshold", + type=Optional[int], + default=None, + help="Energy threshold for silence detection. 
Using this disables automatic adjustment", + ) + + parser.add_argument( + "--pause_threshold", + type=float, + default=2.0, + help="Seconds of non-speaking audio before a phrase is considered complete.", + ) + + args, unknown = parser.parse_known_args() + return vars(args) + + +def configure_model_params(config: dict) -> speech_model_params: + """Configures the speech model parameters based on the provided + command line parameters. + + Args: + config (dict): Command line parameters parsed in dictionary form. + + Returns: + speech_model_params: dataclass containing the speech model parameters + """ + model_params = speech_model_params() + if config["model_name"]: + model_params.model_name = config["model_name"] + if config["device"]: + model_params.device = config["device"] + if config["start_timeout"]: + model_params.start_timeout = config["start_timeout"] + if config["phrase_duration"]: + model_params.phrase_duration = config["phrase_duration"] + if config["sample_rate"]: + model_params.sample_rate = config["sample_rate"] + if config["mic_device"]: + model_params.mic_device = config["mic_device"] + if config["no_warmup"]: + model_params.warmup = False + # if config["energy_threshold"]: + # model_params.energy_threshold = config["energy_threshold"] + if config["pause_threshold"]: + model_params.pause_threshold = config["pause_threshold"] + + return model_params + + +def configure_whisper_cache() -> None: + """Configures the whisper cache directory.""" + whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper") + os.makedirs(whisper_cache, exist_ok=True) + # Environemntal variable required to run whisper locally + os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache + +def main(args=None): + configure_whisper_cache() + config = parse_args() + try: + rclpy.init(args=args) + whisper_node = rclpy.create_node("transcribe_speech_server") + + server = TranscribeSpeechAction("transcribe_speech", configure_model_params(config)) + rclpy.spin(whisper_node) # TODO check behaviour (was rospy.spin()) + except (KeyboardInterrupt, ExternalShutdownException): + pass + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml new file mode 100644 index 000000000..bb3215ee6 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/package.xml @@ -0,0 +1,30 @@ + + + + lasr_speech_recognition_whisper + 0.0.0 + Speech recognition implemented using OpenAI Whisper + maayan + MIT + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + + lasr_speech_recognition_msgs + actionlib + actionlib_msgs + actionlib + actionlib_msgs + + + ament_python + requirements.txt + + diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in new file mode 100644 index 000000000..da48c5086 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/requirements.in @@ -0,0 +1,6 @@ +SpeechRecognition==3.10.0 +sounddevice==0.4.6 +openai-whisper==20231117 +PyAudio==0.2.13 +PyYaml==6.0.1 +rospkg==1.5.0 diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt new file mode 100644 index 000000000..bc986de21 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -0,0 +1,51 @@ +catkin-pkg==1.0.0 # via rospkg +certifi==2024.2.2 # via requests +cffi==1.16.0 # via sounddevice +charset-normalizer==3.3.2 # via 
requests +distro==1.9.0 # via rospkg +docutils==0.21.2 # via catkin-pkg +filelock==3.14.0 # via torch, triton +fsspec==2024.3.1 # via torch +idna==3.7 # via requests +jinja2==3.1.4 # via torch +llvmlite==0.42.0 # via numba +markupsafe==2.1.5 # via jinja2 +more-itertools==10.2.0 # via openai-whisper +mpmath==1.3.0 # via sympy +networkx==3.3 # via torch +numba==0.59.1 # via openai-whisper +numpy==1.26.4 # via numba, openai-whisper +nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch +nvidia-cuda-cupti-cu12==12.1.105 # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 # via torch +nvidia-cuda-runtime-cu12==12.1.105 # via torch +nvidia-cudnn-cu12==8.9.2.26 # via torch +nvidia-cufft-cu12==11.0.2.54 # via torch +nvidia-curand-cu12==10.3.2.106 # via torch +nvidia-cusolver-cu12==11.4.5.107 # via torch +nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch +nvidia-nccl-cu12==2.20.5 # via torch +nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 # via torch +openai-whisper==20231117 # via -r requirements.in +pyaudio==0.2.13 # via -r requirements.in +pycparser==2.22 # via cffi +pyparsing==3.1.2 # via catkin-pkg +python-dateutil==2.9.0.post0 # via catkin-pkg +pyyaml==6.0.1 # via -r requirements.in, rospkg +regex==2024.4.28 # via tiktoken +requests==2.31.0 # via speechrecognition, tiktoken +rospkg==1.5.0 # via -r requirements.in +six==1.16.0 # via python-dateutil +sounddevice==0.4.6 # via -r requirements.in +speechrecognition==3.10.0 # via -r requirements.in +sympy==1.12 # via torch +tiktoken==0.6.0 # via openai-whisper +torch==2.3.0 # via openai-whisper +tqdm==4.66.4 # via openai-whisper +triton==2.3.0 # via openai-whisper, torch +typing-extensions==4.11.0 # via torch +urllib3==2.2.1 # via requests + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper b/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py new file mode 100644 index 000000000..c9681aab9 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py @@ -0,0 +1,8 @@ +#!/usr/bin python3 +import speech_recognition as sr + +microphones = enumerate(sr.Microphone.list_microphone_names()) + +print("\nAvailable microphones:") +for index, name in microphones: + print(f"[{index}] {name}") diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py new file mode 100644 index 000000000..a9a425df1 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py @@ -0,0 +1,67 @@ +#!/usr/bin python3 +import argparse +import os +import torch +import numpy as np +from pathlib import Path +import speech_recognition as sr +from lasr_speech_recognition_whisper import load_model # type: ignore +import sounddevice # needed to remove ALSA error messages +from typing import Dict + + +def parse_args() -> Dict: + parser = argparse.ArgumentParser() + parser.add_argument("--device_index", type=int, default=None) + return vars(parser.parse_args()) + + +def configure_whisper_cache() -> None: + """Configures the 
whisper cache directory.""" + whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper") + os.makedirs(whisper_cache, exist_ok=True) + # Environemntal variable required to run whisper locally + os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache + + +def main(): + args = parse_args() + + recognizer = sr.Recognizer() + recognizer.pause_threshold = 2 + microphone = sr.Microphone(device_index=args["device_index"], sample_rate=16000) + threshold = 100 + recognizer.dynamic_energy_threshold = False + recognizer.energy_threshold = threshold + transcription_model = load_model( + "medium.en", "cuda" if torch.cuda.is_available() else "cpu", True + ) + transcription_result = "The quick brown fox jumps over the lazy dog." + while transcription_result != "": + print(f"Listening...") + with microphone as source: + wav_data = recognizer.listen( + source, phrase_time_limit=10, timeout=5 + ).get_wav_data() + print(f"Processing...") + # Magic number 32768.0 is the maximum value of a 16-bit signed integer + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + # Cast to fp16 if using GPU + transcription_result = transcription_model.transcribe( + float_data, fp16=torch.cuda.is_available() + )["text"] + + print( + f"Transcription: {transcription_result} at energy threshold {recognizer.energy_threshold}" + ) + threshold += 100 + recognizer.energy_threshold = threshold + + +if __name__ == "__main__": + configure_whisper_cache() + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py new file mode 100644 index 000000000..d7cce0519 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py @@ -0,0 +1,58 @@ +#!/usr/bin python3 +import rclpy +import actionlib # TODO change to reg actions +from lasr_voice import Voice # type: ignore +from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore +from lasr_speech_recognition_msgs.msg import ( # type: ignore + TranscribeSpeechAction, + TranscribeSpeechGoal, +) + +rospy.init_node("repeat") + +USE_ACTIONLIB = True + +voice = Voice() + + +if USE_ACTIONLIB: + client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction) + rospy.loginfo("Waiting for server...") + client.wait_for_server() + repeating = False + rospy.loginfo("Done waiting") + while not rospy.is_shutdown(): + goal = TranscribeSpeechGoal() + client.send_goal(goal) + client.wait_for_result() + result = client.get_result() + text = result.sequence + print(text) + if "tiago" in text.lower().strip(): + if "repeat" in text.lower().strip(): + repeating = True + voice.sync_tts("Okay, I'll start repeating now.") + continue + elif "stop" in text.lower().strip(): + repeating = False + voice.sync_tts("Okay, I'll stop repeating now.") + break + if repeating: + voice.sync_tts(f"I heard {text}") +else: + transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio) + repeating = False + while not rospy.is_shutdown(): + text = transcribe().phrase + print(text) + if "tiago" in text.lower().strip(): + if "repeat" in text.lower().strip(): + repeating = True + voice.sync_tts("Okay, I'll start repeating now.") + continue + elif "stop" in text.lower().strip(): + repeating = False + voice.sync_tts("Okay, I'll stop repeating now.") + break + if repeating: + voice.sync_tts(f"I heard {text}") diff --git 
new file mode 100644
index 000000000..85616017c
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
@@ -0,0 +1,57 @@
+#!/usr/bin python3
+
+import os
+import argparse
+import speech_recognition as sr
+
+
+def parse_args() -> dict:
+    """Parse command line arguments into a dictionary.
+
+    Returns:
+        dict: name: value pairs of command line arguments
+    """
+
+    parser = argparse.ArgumentParser(description="Test microphones")
+    parser.add_argument("-m", "--microphone", type=int, help="Microphone index")
+    parser.add_argument(
+        "-o", "--output_dir", type=str, help="Directory to save audio files"
+    )
+
+    return vars(parser.parse_args())
+
+
+def main(args: dict) -> None:
+    """Generate audio files from microphone input.
+
+    Args:
+        args (dict): dictionary of command line arguments.
+    """
+
+    # Adapted from https://github.com/Uberi/speech_recognition/blob/master/examples/write_audio.py
+
+    mic_index = args["microphone"]
+    output_dir = args["output_dir"]
+
+    r = sr.Recognizer()
+    r.pause_threshold = 2
+    with sr.Microphone(device_index=mic_index, sample_rate=16000) as source:
+        print("Say something!")
+        audio = r.listen(source, timeout=5, phrase_time_limit=10)
+        print("Finished listening")
+
+    with open(os.path.join(output_dir, "microphone.raw"), "wb") as f:
+        f.write(audio.get_raw_data())
+
+    with open(os.path.join(output_dir, "microphone.wav"), "wb") as f:
+        f.write(audio.get_wav_data())
+
+    with open(os.path.join(output_dir, "microphone.flac"), "wb") as f:
+        f.write(audio.get_flac_data())
+
+    with open(os.path.join(output_dir, "microphone.aiff"), "wb") as f:
+        f.write(audio.get_aiff_data())
+
+
+if __name__ == "__main__":
+    main(parse_args())
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
new file mode 100644
index 000000000..9e1253988
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
@@ -0,0 +1,21 @@
+#!/usr/bin python3
+import rclpy
+import actionlib  # TODO change to reg actions
+from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse  # type: ignore
+from lasr_speech_recognition_msgs.msg import (  # type: ignore
+    TranscribeSpeechAction,
+    TranscribeSpeechGoal,
+)
+
+
+rospy.init_node("test_speech_server")
+client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+client.wait_for_server()
+rospy.loginfo("Done waiting")
+while not rospy.is_shutdown():
+    goal = TranscribeSpeechGoal()
+    client.send_goal(goal)
+    client.wait_for_result()
+    result = client.get_result()
+    text = result.sequence
+    print(f"Transcribed Speech: {text}")
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg
new file mode 100644
index 000000000..5ec86217a
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/lasr_speech_recognition_whisper
+[install]
+install_scripts=$base/lib/lasr_speech_recognition_whisper
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py
new file mode 100644
index 000000000..68c121e12
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/setup.py
@@ -0,0 +1,25 @@
+from setuptools
import find_packages, setup + +package_name = 'lasr_speech_recognition_whisper' + +setup( + name=package_name, + version='0.0.0', + packages=find_packages(exclude=['test']), + data_files=[ + ('share/ament_index/resource_index/packages', + ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + ], + install_requires=['setuptools'], + zip_safe=True, + maintainer='maayan', + maintainer_email='maayan@todo.todo', + description='TODO: Package description', + license='TODO: License declaration', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + ], + }, +) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py new file mode 100644 index 000000000..85b18fb9b --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -0,0 +1,12 @@ +# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector +from .collector import ( + AbstractPhraseCollector, + MicrophonePhraseCollector, + RecognizerPhraseCollector, +) +from .worker import ( + SpeechRecognitionWorker, + SpeechRecognitionToStdout, + SpeechRecognitionToTopic, +) +from .cache import load_model diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py new file mode 100644 index 000000000..1f86b7ffc --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py @@ -0,0 +1,137 @@ +import io + + +class BytesFIFO(object): + """ + A FIFO that can store a fixed number of bytes. + https://github.com/hbock/byte-fifo/blob/master/fifo.py + """ + + def __init__(self, init_size): + """Create a FIFO of ``init_size`` bytes.""" + self._buffer = io.BytesIO(b"\x00" * init_size) + self._size = init_size + self._filled = 0 + self._read_ptr = 0 + self._write_ptr = 0 + + def read(self, size=-1): + """ + Read at most ``size`` bytes from the FIFO. + + If less than ``size`` bytes are available, or ``size`` is negative, + return all remaining bytes. + """ + if size < 0: + size = self._filled + + # Go to read pointer + self._buffer.seek(self._read_ptr) + + # Figure out how many bytes we can really read + size = min(size, self._filled) + contig = self._size - self._read_ptr + contig_read = min(contig, size) + + ret = self._buffer.read(contig_read) + self._read_ptr += contig_read + if contig_read < size: + leftover_size = size - contig_read + self._buffer.seek(0) + ret += self._buffer.read(leftover_size) + self._read_ptr = leftover_size + + self._filled -= size + + return ret + + def write(self, data): + """ + Write as many bytes of ``data`` as are free in the FIFO. + + If less than ``len(data)`` bytes are free, write as many as can be written. + Returns the number of bytes written. 
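+
+        Example (illustrative; added for clarity, not part of the original
+        docstring), showing the wrap-around behaviour:
+
+            fifo = BytesFIFO(4)
+            fifo.write(b"abc")   # -> 3
+            fifo.read(2)         # -> b"ab"
+            fifo.write(b"def")   # -> 3; "ef" wraps to the start of the buffer
+            fifo.read()          # -> b"cdef"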
+ """ + free = self.free() + write_size = min(len(data), free) + + if write_size: + contig = self._size - self._write_ptr + contig_write = min(contig, write_size) + # TODO: avoid 0 write + # TODO: avoid copy + # TODO: test performance of above + self._buffer.seek(self._write_ptr) + self._buffer.write(data[:contig_write]) + self._write_ptr += contig_write + + if contig < write_size: + self._buffer.seek(0) + self._buffer.write(data[contig_write:write_size]) + # self._buffer.write(buffer(data, contig_write, write_size - contig_write)) + self._write_ptr = write_size - contig_write + + self._filled += write_size + + return write_size + + def flush(self): + """Flush all data from the FIFO.""" + self._filled = 0 + self._read_ptr = 0 + self._write_ptr = 0 + + def empty(self): + """Return ```True``` if FIFO is empty.""" + return self._filled == 0 + + def full(self): + """Return ``True`` if FIFO is full.""" + return self._filled == self._size + + def free(self): + """Return the number of bytes that can be written to the FIFO.""" + return self._size - self._filled + + def capacity(self): + """Return the total space allocated for this FIFO.""" + return self._size + + def __len__(self): + """Return the amount of data filled in FIFO""" + return self._filled + + def __nonzero__(self): + """Return ```True``` if the FIFO is not empty.""" + return self._filled > 0 + + def resize(self, new_size): + """ + Resize FIFO to contain ``new_size`` bytes. If FIFO currently has + more than ``new_size`` bytes filled, :exc:`ValueError` is raised. + If ``new_size`` is less than 1, :exc:`ValueError` is raised. + + If ``new_size`` is smaller than the current size, the internal + buffer is not contracted (yet). + """ + if new_size < 1: + raise ValueError("Cannot resize to zero or less bytes.") + + if new_size < self._filled: + raise ValueError( + "Cannot contract FIFO to less than {} bytes, " + "or data will be lost.".format(self._filled) + ) + + # original data is non-contiguous. we need to copy old data, + # re-write to the beginning of the buffer, and re-sync + # the read and write pointers. + if self._read_ptr >= self._write_ptr: + old_data = self.read(self._filled) + self._buffer.seek(0) + self._buffer.write(old_data) + self._filled = len(old_data) + self._read_ptr = 0 + self._write_ptr = self._filled + + self._size = new_size diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py new file mode 100644 index 000000000..42ec44785 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -0,0 +1,43 @@ +import os +import whisper # type: ignore +import rospkg # type: ignore +import rospy + +# Keep all loaded models in memory +MODEL_CACHE = {} + + +def load_model( + name: str, device: str = "cpu", load_test_file: bool = False +) -> whisper.Whisper: + """Loads a whisper model from disk, or from cache if it has already been loaded. + + Args: + name (str): Name of the whisper model. Must be the name of an official whisper + model, or the path to a model checkpoint. + device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'. + load_test_file (bool, optional): Whether to run inference on a test audio file + after loading the model (if model is not in cache). Defaults to False. Test file + is assumed to be called "test.m4a" and be in the root of the package directory. 
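+
+    Example (illustrative; added for clarity, not part of the original
+    docstring):
+
+        model = load_model("medium.en", device="cpu")
+        text = model.transcribe("some_audio.wav")["text"]  # placeholder file path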
+ + Returns: + whisper.Whisper: Whisper model instance + """ + global MODEL_CACHE + + if name not in MODEL_CACHE: + rospy.loginfo(f"Loading model {name}") + MODEL_CACHE[name] = whisper.load_model(name, device=device) + rospy.loginfo(f"Sucessfully loaded model {name} on {device}") + if load_test_file: + package_root = rospkg.RosPack().get_path("lasr_speech_recognition_whisper") + example_fp = os.path.join(package_root, "test.m4a") + rospy.loginfo( + "Running transcription on example file to ensure model is loaded..." + ) + test_result: str = MODEL_CACHE[name].transcribe( + example_fp, fp16=device == "cuda" + ) + rospy.loginfo(f"Transcription test result: {test_result}") + + return MODEL_CACHE[name] diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py new file mode 100644 index 000000000..ae15a5347 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -0,0 +1,131 @@ +import rospy + +import speech_recognition as sr + +from queue import Queue +from abc import ABC, abstractmethod + +# from .source import AudioTopic + + +class AbstractPhraseCollector(ABC): + """ + Supertype holding a queue of audio data representing a phrase + """ + + data: Queue[bytes] = Queue() + + @abstractmethod + def start(self): + """ + Start collecting phrases + """ + pass + + @abstractmethod + def stop(self): + """ + Stop collecting phrases + """ + pass + + @abstractmethod + def sample_rate(self): + """ + Sample rate of the data + """ + pass + + @abstractmethod + def sample_width(self): + """ + Sample width of the data + """ + pass + + +class RecognizerPhraseCollector(AbstractPhraseCollector): + """ + Collect phrases using a SoundRecognition Recognizer + + This will monitor energy levels on the input and only + capture when a certain threshold of activity is met. 
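+
+    Example (illustrative; added for clarity, assuming a working default
+    microphone):
+
+        collector = MicrophonePhraseCollector(energy_threshold=500)
+        collector.adjust_for_noise()
+        collector.start()
+        raw_phrase = collector.data.get()  # blocks until a phrase is captured
+        collector.stop()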
+ """ + + _recorder: sr.Recognizer + _phrase_time_limit: float + + def _record_callback(self, _, audio: sr.AudioData) -> None: + """ + Collect raw audio data from the microphone + """ + self.data.put(audio.get_raw_data()) + + def __init__( + self, energy_threshold: int = 500, phrase_time_limit: float = 2 + ) -> None: + super().__init__() + self._recorder = sr.Recognizer() + self._recorder.dynamic_energy_threshold = False + self._recorder.energy_threshold = energy_threshold + self._phrase_time_limit = phrase_time_limit + + @abstractmethod + def adjust_for_noise(self, source: sr.AudioSource): + rospy.loginfo("Adjusting for background noise...") + with source: + self._recorder.adjust_for_ambient_noise(source) + + @abstractmethod + def start(self, source: sr.AudioSource): + rospy.loginfo("Started source listen thread") + self._stopper = self._recorder.listen_in_background( + source, self._record_callback, phrase_time_limit=self._phrase_time_limit + ) + + def stop(self): + self._stopper() + + def sample_rate(self): + return self._source.SAMPLE_RATE + + def sample_width(self): + return self._source.SAMPLE_WIDTH + + +class MicrophonePhraseCollector(RecognizerPhraseCollector): + """ + Collect phrases from the default microphone + """ + + _source: sr.Microphone + + def __init__( + self, + energy_threshold: int = 500, + phrase_time_limit: float = 2, + device_index: int = None, + ) -> None: + self._source = sr.Microphone(device_index=device_index, sample_rate=16000) + super().__init__(energy_threshold, phrase_time_limit) + + def adjust_for_noise(self): + return super().adjust_for_noise(self._source) + + def start(self): + return super().start(self._source) + + +# class AudioTopicPhraseCollector(RecognizerPhraseCollector): +# ''' +# Collect phrases from an audio topic +# ''' + +# _source: AudioTopic + +# def __init__(self, topic: str, energy_threshold: int = 100, phrase_time_limit: float = 2) -> None: +# self._source = AudioTopic(topic) +# super().__init__(energy_threshold, phrase_time_limit) + +# def start(self): +# return super().start(self._source) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py new file mode 100644 index 000000000..dd235ceb2 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py @@ -0,0 +1,57 @@ +import rospy +import pyaudio +import speech_recognition as sr + +from audio_common_msgs.msg import AudioInfo, AudioData + +from .bytesfifo import BytesFIFO + + +class AudioTopic(sr.AudioSource): + """ + Use a ROS topic as an AudioSource + """ + + _topic: str + _sub: rospy.Subscriber + + def __init__(self, topic: str, chunk_size=1024) -> None: + self._topic = topic + + config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) + assert config.coding_format == "wave", "Expected Wave audio format" + assert config.sample_format == "S16LE", "Expected sample format S16LE" + rospy.loginfo(config) + + self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) + self.SAMPLE_RATE = config.sample_rate + + self.CHUNK = chunk_size + self.stream = None + + def __enter__(self): + """ + Start stream when entering with: block + """ + + assert ( + self.stream is None + ), "This audio source is already inside a context manager" + self.stream = BytesFIFO(1024 * 10) # 10 kB buffer + self._sub = rospy.Subscriber(f"{self._topic}/audio", AudioData, self._read) + return self + + def 
__exit__(self, exc_type, exc_value, traceback): + """ + Close out stream on exit + """ + + self.stream = None + self._sub.unregister() + + def _read(self, msg: AudioData) -> None: + """ + Forward raw audio data to queue + """ + + self.stream.write(msg.data) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py new file mode 100644 index 000000000..a560dc798 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py @@ -0,0 +1,203 @@ +import torch +import rospy +import whisper +import speech_recognition as sr + +from io import BytesIO +from time import sleep +from threading import Thread +from abc import ABC, abstractmethod +from tempfile import NamedTemporaryFile +from datetime import datetime, timedelta + +from .collector import AbstractPhraseCollector + +from lasr_speech_recognition_msgs.msg import Transcription + + +class SpeechRecognitionWorker(ABC): + """ + Collect and run inference on phrases to produce a transcription + """ + + _collector: AbstractPhraseCollector + _tmp_file: NamedTemporaryFile + _model: whisper.Whisper + _current_sample: bytes + _phrase_start: datetime + _maximum_phrase_length: timedelta | None + _infer_partial: bool + _stopped = True + + def __init__( + self, + collector: AbstractPhraseCollector, + model: whisper.Whisper, + maximum_phrase_length=timedelta(seconds=3), + infer_partial=True, + ) -> None: + self._collector = collector + self._tmp_file = NamedTemporaryFile().name + self._model = model + self._current_sample = bytes() + self._phrase_start = None + self._maximum_phrase_length = maximum_phrase_length + self._infer_partial = infer_partial + + @abstractmethod + def on_phrase(self, phrase: str, finished: bool) -> None: + """ + Handle a partial or complete transcription + """ + pass + + def _finish_phrase(self): + """ + Complete the current phrase and clear the sample + """ + + text = self._perform_inference() + if text is not None: + self.on_phrase(text, True) + + self._current_sample = bytes() + self._phrase_start = None + + def _perform_inference(self): + """ + Run inference on the current sample + """ + + rospy.loginfo("Processing sample") + audio_data = sr.AudioData( + self._current_sample, + self._collector.sample_rate(), + self._collector.sample_width(), + ) + wav_data = BytesIO(audio_data.get_wav_data()) + + with open(self._tmp_file, "w+b") as f: + f.write(wav_data.read()) + + rospy.loginfo("Running inference") + try: + result = self._model.transcribe( + self._tmp_file, fp16=torch.cuda.is_available() + ) + except RuntimeError: + return None + text = result["text"].strip() + + # Detect and drop garbage + if len(text) == 0 or text.lower() in [".", "you", "thanks for watching!"]: + self._phrase_start = None + self._current_sample = bytes() + rospy.loginfo("Skipping garbage...") + return None + + return text + + def _worker(self): + """ + Indefinitely perform inference on the given data + """ + + rospy.loginfo("Started inference worker") + + while not self._stopped: + try: + # Check whether the current phrase has timed out + now = datetime.utcnow() + if ( + self._phrase_start + and now - self._phrase_start > self._maximum_phrase_length + ): + rospy.loginfo("Reached timeout for phrase, ending now.") + self._finish_phrase() + + # Start / continue phrase if data is coming in + if not self._collector.data.empty(): + self._phrase_start = datetime.utcnow() + + # 
Concatenate new data with current sample + while not self._collector.data.empty(): + self._current_sample += self._collector.data.get() + + rospy.loginfo( + "Received and added more data to current audio sample." + ) + + # Run inference on partial sample if enabled + if self._infer_partial: + text = self._perform_inference() + + # Handle partial transcription + if text is not None: + self.on_phrase(text, False) + + sleep(0.2) + except KeyboardInterrupt: + self._stopped = True + + rospy.loginfo("Worker finished") + + def start(self): + """ + Start performing inference on incoming data + """ + + assert self._stopped, "Already running inference" + self._stopped = False + self._collector.start() + worker_thread = Thread(target=self._worker) + worker_thread.start() + + def stop(self): + """ + Stop the worker from running inference + """ + + assert not self._stopped, "Not currently running" + self._collector.stop() + self._stopped = True + + # clear next phrase + self._current_sample = bytes() + while not self._collector.data.empty(): + self._current_sample += self._collector.data.get() + + +class SpeechRecognitionToStdout(SpeechRecognitionWorker): + """ + Recognise speech and pass it through to standard output + """ + + def on_phrase(self, phrase: str, finished: bool) -> None: + rospy.loginfo("[" + ("x" if finished else " ") + "] " + phrase) + + +class SpeechRecognitionToTopic(SpeechRecognitionToStdout): + """ + Recognise speech and publish it to a topic + """ + + _pub: rospy.Publisher + + def __init__( + self, + collector: AbstractPhraseCollector, + model: whisper.Whisper, + topic: str, + maximum_phrase_length=timedelta(seconds=1), + infer_partial=True, + ) -> None: + super().__init__(collector, model, maximum_phrase_length, infer_partial) + rospy.loginfo(f"Will be publishing transcription to {topic}") + self._pub = rospy.Publisher(topic, Transcription, queue_size=5) + + def on_phrase(self, phrase: str, finished: bool) -> None: + super().on_phrase(phrase, finished) + msg = Transcription() + msg.phrase = phrase + msg.finished = finished + self._pub.publish(msg) diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py new file mode 100644 index 000000000..97a39196e --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py @@ -0,0 +1,25 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ament_copyright.main import main +import pytest + + +# Remove the `skip` decorator once the source file(s) have a copyright header +@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py new file mode 100644 index 000000000..27ee1078f --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py @@ -0,0 +1,25 @@ +# Copyright 2017 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py new file mode 100644 index 000000000..b234a3840 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py @@ -0,0 +1,23 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ament_pep257.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.pep257 +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings' From 28e3881308918271362cbe80557a86e2e9c94d2c Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Sun, 20 Oct 2024 21:13:12 +0100 Subject: [PATCH 03/14] Port preemption for whisper server node --- .../CMakeLists.txt | 5 + .../README.md | 2 +- .../nodes/simple_transcribe_microphone | 17 +- .../nodes/transcribe_microphone | 25 +-- .../nodes/transcribe_microphone_server | 117 ++++++++------ .../package.xml | 2 +- .../requirements.in | 3 +- .../requirements.txt | 10 +- .../scripts/microphone_tuning_test.py | 1 + .../scripts/repeat_after_me.py | 145 +++++++++++------- .../scripts/test_speech_server.py | 17 +- .../lasr_speech_recognition_whisper/setup.py | 11 +- .../__init__.py | 2 +- .../lasr_speech_recognition_whisper/cache.py | 69 +++++---- .../collector.py | 9 +- .../lasr_speech_recognition_whisper/source.py | 14 +- .../lasr_speech_recognition_whisper/worker.py | 28 ++-- 17 files changed, 272 insertions(+), 205 deletions(-) diff --git a/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt index ac4de94d7..204934829 100644 --- a/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt +++ b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt @@ -7,6 +7,9 @@ endif() # find dependencies find_package(ament_cmake REQUIRED) +find_package(rclpy REQUIRED) +find_package(action_msgs REQUIRED) + # uncomment the following section in order to fill in # further dependencies manually. # find_package( REQUIRED) @@ -21,6 +24,8 @@ rosidl_generate_interfaces(${PROJECT_NAME} DEPENDENCIES builtin_interfaces # Add packages that above messages depend on ) +ament_export_dependencies(rosidl_default_runtime) + if(BUILD_TESTING) find_package(ament_lint_auto REQUIRED) # the following line skips the linter which checks for copyrights diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md index 1378bcdf4..c96279cb6 100644 --- a/common/speech/lasr_speech_recognition_interfaces/README.md +++ b/common/speech/lasr_speech_recognition_interfaces/README.md @@ -1,4 +1,4 @@ -# lasr_speech_recognition_msgs +# lasr_speech_recognition_interfaces Common messages used for speech recognition diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone index 62342afc9..f03cf2757 100644 --- a/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone +++ b/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone @@ -1,14 +1,15 @@ #!/usr/bin python3 import os import torch -import rospkg # TODO check if change import rclpy +from ament_index_python import packages + import sys from pathlib import Path import speech_recognition as sr import numpy as np -from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse +from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse from lasr_speech_recognition_whisper import load_model # TODO rospkg @@ -23,8 +24,8 @@ os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE if len(sys.argv) < 3: print('Usage:') - print('rosrun lasr_speech_recognition transcribe_microphone by-index ') - print('rosrun 
lasr_speech_recognition transcribe_microphone by-name ') + print('ros2 run lasr_speech_recognition transcribe_microphone by-index ') + print('ros2 run lasr_speech_recognition transcribe_microphone by-name ') exit(1) else: matcher = sys.argv[1] @@ -57,8 +58,8 @@ model = load_model("medium.en", device=device) # try to run inference on the example file r = rospkg.RosPack() EXAMPLE_FILE = r.get_path('lasr_speech_recognition_whisper') + "/test.m4a" -rclpy.get_logger().info("Running transcription on example file to ensure model is loaded...") -rclpy.get_logger().info(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available())) +node.get_logger().info("Running transcription on example file to ensure model is loaded...") +node.get_logger().info(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available())) microphone = sr.Microphone(device_index=device_index, sample_rate=16000) r = sr.Recognizer() @@ -74,7 +75,7 @@ def handle_transcribe_audio(_): phrase = model.transcribe(float_data, fp16=device == "cuda")["text"] return TranscribeAudioResponse(phrase=phrase) -node.create_service('/whisper/transcribe_audio', TranscribeAudio, handle_transcribe_audio) +node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio) -rclpy.get_logger().info("Whisper service ready") +node.get_logger().info("Whisper service ready") rclpy.spin(node) \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone index ae86ef310..38a8d84ef 100644 --- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone +++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone @@ -1,7 +1,7 @@ #!/usr/bin python3 import os import torch -import rospkg +from ament_index_python import packages from pathlib import Path WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper') @@ -39,9 +39,11 @@ else: print('Invalid matcher') exit(1) -import rospy +import rclpy from std_srvs.srv import Empty, EmptyResponse -rospy.init_node('transcribe_mic', anonymous=True) + +with rclpy.init(args=None): + node = rclpy.create_node('transcribe_mic') # was anonymous in ROS1 from lasr_speech_recognition_whisper import SpeechRecognitionToTopic, MicrophonePhraseCollector, load_model @@ -53,10 +55,9 @@ collector.adjust_for_noise() model = load_model("medium.en") # try to run inference on the example file -r = rospkg.RosPack() -EXAMPLE_FILE = r.get_path('lasr_speech_recognition_whisper') + "/test.m4a" -rospy.loginfo("Running transcription on example file to ensure model is loaded...") -rospy.loginfo(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available())) +EXAMPLE_FILE = packages.get_package_share_path('lasr_speech_recognition_whisper') + "/test.m4a" +node.get_logger().info("Running transcription on example file to ensure model is loaded...") +node.get_logger().info(model.transcribe(EXAMPLE_FILE, fp16=torch.cuda.is_available())) worker = SpeechRecognitionToTopic(collector, model, "transcription", infer_partial = False) @@ -72,9 +73,9 @@ def stop_listening(_): worker.stop() return EmptyResponse() -rospy.Service('/whisper/adjust_for_noise', Empty, adjust_for_noise) -rospy.Service('/whisper/start_listening', Empty, start_listening) -rospy.Service('/whisper/stop_listening', Empty, stop_listening) +node.create_service(Empty, '/whisper/adjust_for_noise', adjust_for_noise) +node.create_service(Empty, '/whisper/start_listening', start_listening) 
+node.create_service(Empty, '/whisper/stop_listening', stop_listening)
 
-rospy.loginfo("Starting the Whisper worker!")
-rospy.spin()
+node.get_logger().info("Starting the Whisper worker!")
+rclpy.spin(node)
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
index 680f97c5d..7ff1f3efe 100644
--- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -12,14 +12,18 @@
 import torch
 import rclpy
 from rclpy.node import Node
-from rclpy.action import ActionServer
+from rclpy.action.server import ActionServer, CancelResponse
 import speech_recognition as sr  # type: ignore
-import lasr_speech_recognition_msgs.msg  # type: ignore
+import lasr_speech_recognition_interfaces.msg  # type: ignore
+from rclpy.executors import ExternalShutdownException
 from std_msgs.msg import String  # type: ignore
-from lasr_speech_recognition_whisper import load_model  # type: ignore
+from lasr_speech_recognition_whisper import ModelCache  # type: ignore
 
-# TODO: argpars -> ROS2 params, behaviour of preemption, behaviour of rclpy.spin(node)
+# from common.speech.lasr_speech_recognition_whisper.scripts.repeat_after_me import result
+
+
+# TODO: argparse -> ROS2 params, test behaviour of preemption
 
 @dataclass
 class speech_model_params:
@@ -51,10 +55,10 @@ class speech_model_params:
     pause_threshold: Optional[float] = 2.0
 
-class TranscribeSpeechAction(object):
+class TranscribeSpeechAction(Node):
     # create messages that are used to publish feedback/result
-    _feedback = lasr_speech_recognition_msgs.msg.TranscribeSpeechFeedback()
-    _result = lasr_speech_recognition_msgs.msg.TranscribeSpeechResult()
+    _feedback = lasr_speech_recognition_interfaces.msg.TranscribeSpeechFeedback()
+    _result = lasr_speech_recognition_interfaces.msg.TranscribeSpeechResult()
 
     def __init__(
         self,
@@ -66,33 +70,35 @@ class TranscribeSpeechAction(object):
         Args:
            action_name (str): Name of the action server.
         """
-
+        Node.__init__(self, "transcribe_speech_action")
         self._action_name = action_name
         self._model_params = model_params
-        self._transcription_server = node.create_publisher(
+        self._transcription_server = self.create_publisher(
            String, "/live_speech_transcription", 10
        )
 
-        self._model = load_model(
+        self._model = ModelCache().load_model(
            self._model_params.model_name,
            self._model_params.device,
            self._model_params.warmup,
        )
         # Configure the speech recogniser object and adjust for ambient noise
         self.recogniser = self._configure_recogniser()
-        # Setup the action server and register execution callback
-        # TODO check behaviour of ActionServer
+
+        # Set up the action server and register execution callback
         self._action_server = ActionServer(
+            self,
+            lasr_speech_recognition_interfaces.action.TranscribeSpeechAction,
             self._action_name,
-            lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
-            execute_cb=self.execute_cb,
-            auto_start=False,
+            execute_callback=self.execute_cb,
+            cancel_callback=self.cancel_cb,
+            # auto_start=False,  # not required in ROS2 ??
(cb is async) ) - self._action_server.register_preempt_callback(self.prempt_cb) + self._action_server.register_cancel_callback(self.cancel_cb) self._listening = False - self._action_server.start() - rclpy.get_logger().info(f"Speech Action server {self._action_name} started") + # self._action_server.start() # not required in ROS2 + self.get_logger().info(f"Speech Action server {self._action_name} started") def _configure_microphone(self) -> sr.Microphone: """Configures the microphone for listening to speech based on the @@ -159,28 +165,33 @@ class TranscribeSpeechAction(object): self._listening = False return recogniser - def prempt_cb(self) -> None: - """Callback for preempting the action server. - - Sets server to preempted state. + def cancel_cb(self, goal_handle) -> CancelResponse: + """Callback for cancelling the action server. + Sets server to 'canceled' state. """ - preempted_str = f"{self._action_name} has been preempted" - rclpy.get_logger().info(preempted_str) - self._result.sequence = preempted_str - self._action_server.set_preempted(result=self._result, text=preempted_str) + cancel_str = f"{self._action_name} has been cancelled" + self.get_logger().info(cancel_str) + self._result.sequence = cancel_str + + # self._action_server.set_preempted(result=self._result, text=cancel_str) + goal_handle.canceled() - def execute_cb(self, goal) -> None: + return CancelResponse.ACCEPT # TODO decide if always accept cancellation + + async def execute_cb(self, goal_handle) -> None: """Callback for executing the action server. - Checks for preemption before listening and before and after transcribing, returning - if preemption is requested. + Checks for cancellation before listening and before and after transcribing, returning + if cancellation is requested. Args: - goal: UNUSED - actionlib requires a goal argument in the execute callback, but - this action server does not use a goal. 
:param goal_handle: handles the goal request, and provides access to the goal parameters
         """
+
+        goal = goal_handle.request
+
+        self.get_logger().info("Request Received")
-        if self._action_server.is_preempt_requested():
+        if goal_handle.is_cancel_requested:
             return
 
         if goal.energy_threshold > 0.0 and goal.max_phrase_limit > 0.0:
@@ -207,11 +218,13 @@
             / 32768.0
         )
 
-        if self._action_server.is_preempt_requested():
+        if goal_handle.is_cancel_requested:
             self._listening = False
-            return
+            self.get_logger().info("Goal was cancelled during execution.")
+            goal_handle.canceled()
+            return self._result
 
-        rclpy.get_logger().info(f"Transcribing phrase with Whisper...")
+        self.get_logger().info("Transcribing phrase with Whisper...")
         transcription_start_time = timer()
         # Cast to fp16 if using GPU
         phrase = self._model.transcribe(
@@ -219,23 +232,26 @@
             fp16=self._model_params.device == "cuda",
         )["text"]
         transcription_end_time = timer()
-        rclpy.get_logger().info(f"Transcription finished!")
-        rclpy.get_logger().info(
+        self.get_logger().info("Transcription finished!")
+        self.get_logger().info(
             f"Time taken: {transcription_end_time - transcription_start_time:.2f}s"
         )
-        self._transcription_server.publish(phrase)
+        self._transcription_server.publish(String(data=phrase))
 
-        if self._action_server.is_preempt_requested():
+        if goal_handle.is_cancel_requested:
             self._listening = False
             return
 
         self._result.sequence = phrase
-        rclpy.get_logger().info(f"Transcribed phrase: {phrase}")
-        rclpy.get_logger().info(f"{self._action_name} has succeeded")
-        self._action_server.set_succeeded(self._result)
+        self.get_logger().info(f"Transcribed phrase: {phrase}")
+        self.get_logger().info(f"{self._action_name} has succeeded")
+
+        goal_handle.succeed()
 
         # Have this at the very end to not disrupt the action server
         self._listening = False
 
+        return self._result
+
 
 def parse_args() -> dict:
     """Parses the command line arguments into a name: value dictionary.
@@ -355,20 +371,21 @@ def configure_whisper_cache() -> None: """Configures the whisper cache directory.""" whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper") os.makedirs(whisper_cache, exist_ok=True) - # Environemntal variable required to run whisper locally + # Environmental variable required to run whisper locally os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache def main(args=None): + rclpy.init(args=args) + configure_whisper_cache() config = parse_args() - try: - rclpy.init(args=args) - whisper_node = rclpy.create_node("transcribe_speech_server") - server = TranscribeSpeechAction("transcribe_speech", configure_model_params(config)) - rclpy.spin(whisper_node) # TODO check behaviour (was rospy.spin()) + server = TranscribeSpeechAction("transcribe_speech", configure_model_params(config)) + + try: + rclpy.spin(server) except (KeyboardInterrupt, ExternalShutdownException): pass -if __name__ == "__main__": - main() +# if __name__ == "__main__": +# main() diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml index bb3215ee6..825aae036 100644 --- a/common/speech/lasr_speech_recognition_whisper/package.xml +++ b/common/speech/lasr_speech_recognition_whisper/package.xml @@ -17,7 +17,7 @@ catkin_virtualenv TODO fix virtualenv build --> - lasr_speech_recognition_msgs + lasr_speech_recognition_interfaces actionlib actionlib_msgs actionlib diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in index da48c5086..25fe69190 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.in +++ b/common/speech/lasr_speech_recognition_whisper/requirements.in @@ -2,5 +2,4 @@ SpeechRecognition==3.10.0 sounddevice==0.4.6 openai-whisper==20231117 PyAudio==0.2.13 -PyYaml==6.0.1 -rospkg==1.5.0 +PyYaml==6.0.1 \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt index bc986de21..87bab6318 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.txt +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -1,9 +1,6 @@ -catkin-pkg==1.0.0 # via rospkg certifi==2024.2.2 # via requests cffi==1.16.0 # via sounddevice charset-normalizer==3.3.2 # via requests -distro==1.9.0 # via rospkg -docutils==0.21.2 # via catkin-pkg filelock==3.14.0 # via torch, triton fsspec==2024.3.1 # via torch idna==3.7 # via requests @@ -28,14 +25,11 @@ nvidia-nccl-cu12==2.20.5 # via torch nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via torch openai-whisper==20231117 # via -r requirements.in -pyaudio==0.2.13 # via -r requirements.in +# pyaudio==0.2.13 # via -r requirements.in pycparser==2.22 # via cffi -pyparsing==3.1.2 # via catkin-pkg -python-dateutil==2.9.0.post0 # via catkin-pkg -pyyaml==6.0.1 # via -r requirements.in, rospkg +pyyaml==6.0.1 # via -r requirements.in regex==2024.4.28 # via tiktoken requests==2.31.0 # via speechrecognition, tiktoken -rospkg==1.5.0 # via -r requirements.in six==1.16.0 # via python-dateutil sounddevice==0.4.6 # via -r requirements.in speechrecognition==3.10.0 # via -r requirements.in diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py index a9a425df1..1a2b529ed 100644 --- 
a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py @@ -9,6 +9,7 @@ import sounddevice # needed to remove ALSA error messages from typing import Dict +# TODO argparse -> ROS params def parse_args() -> Dict: parser = argparse.ArgumentParser() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py index d7cce0519..6876d96b1 100644 --- a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py @@ -1,58 +1,87 @@ -#!/usr/bin python3 -import rclpy -import actionlib # TODO change to reg actions -from lasr_voice import Voice # type: ignore -from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore -from lasr_speech_recognition_msgs.msg import ( # type: ignore - TranscribeSpeechAction, - TranscribeSpeechGoal, -) - -rospy.init_node("repeat") - -USE_ACTIONLIB = True - -voice = Voice() - - -if USE_ACTIONLIB: - client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction) - rospy.loginfo("Waiting for server...") - client.wait_for_server() - repeating = False - rospy.loginfo("Done waiting") - while not rospy.is_shutdown(): - goal = TranscribeSpeechGoal() - client.send_goal(goal) - client.wait_for_result() - result = client.get_result() - text = result.sequence - print(text) - if "tiago" in text.lower().strip(): - if "repeat" in text.lower().strip(): - repeating = True - voice.sync_tts("Okay, I'll start repeating now.") - continue - elif "stop" in text.lower().strip(): - repeating = False - voice.sync_tts("Okay, I'll stop repeating now.") - break - if repeating: - voice.sync_tts(f"I heard {text}") -else: - transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio) - repeating = False - while not rospy.is_shutdown(): - text = transcribe().phrase - print(text) - if "tiago" in text.lower().strip(): - if "repeat" in text.lower().strip(): - repeating = True - voice.sync_tts("Okay, I'll start repeating now.") - continue - elif "stop" in text.lower().strip(): - repeating = False - voice.sync_tts("Okay, I'll stop repeating now.") - break - if repeating: - voice.sync_tts(f"I heard {text}") +# #!/usr/bin python3 +# import rclpy +# from rclpy.node import Node +# from rclpy.action import ActionClient +# from lasr_voice import Voice # type: ignore +# from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore +# from lasr_speech_recognition_interfaces.action import TranscribeSpeechAction, TranscribeSpeechGoal +# +# +# # TODO port file: action client, service proxy +# +# USE_ACTIONLIB = True +# +# class ServiceClientNode(Node): +# def __init__(self): +# super().__init__('service_client_node') +# self.client = None +# +# self.voice = Voice() +# +# # def call_service(self): +# # request = TranscribeAudio +# # +# # # Call the service synchronously +# # future = self.client.call_async(request) +# # rclpy.spin_until_future_complete(self, future) +# # +# # if future.result() is not None: +# # self.get_logger().info('Service call succeeded') +# # else: +# # self.get_logger().error('Service call failed') +# +# +# def call_service(self): +# request = TranscribeAudio +# else: +# # transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio) +# repeating = False +# while 
rclpy.ok(): +# future = self.client.call_async(request) +# rclpy.spin_until_future_complete(self, future) +# if future.done(): +# text = transcribe().phrase +# self.get_logger().info(text) +# if "tiago" in text.lower().strip(): +# if "repeat" in text.lower().strip(): +# repeating = True +# self.voice.sync_tts("Okay, I'll start repeating now.") +# continue +# elif "stop" in text.lower().strip(): +# repeating = False +# self.voice.sync_tts("Okay, I'll stop repeating now.") +# break +# if repeating: +# self.voice.sync_tts(f"I heard {text}") +# +# +# class ActionClientNode(Node): +# def __init__(self): +# super().__init__('action_client_node') +# self.client = None +# +# def goal_callback(self, future): +# self.client = self.create_client(TranscribeSpeechAction, "transcribe_speech") +# # Wait for the server to be available +# while not self.client.wait_for_service(timeout_sec=5.0): +# self.get_logger().info("Waiting for server...") +# repeating = False +# self.get_logger().info("Done waiting") +# while rclpy.ok(): +# goal = TranscribeSpeechGoal() +# self.client.send_goal(goal) +# self.client.wait_for_result() +# result = self.client.get_result() +# text = result.sequence +# self.get_logger().info(text) +# if "tiago" in text.lower().strip(): +# if "repeat" in text.lower().strip(): +# repeating = True +# self.voice.sync_tts("Okay, I'll start repeating now.") +# continue +# elif "stop" in text.lower().strip(): +# repeating = False +# self.voice.sync_tts("Okay, I'll stop repeating now.") +# break +# if repeating: +# self.voice.sync_tts(f"I heard {text}") \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py index 9e1253988..7456293c2 100644 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -1,17 +1,22 @@ #!/usr/bin python3 +from argparse import Action + import rclpy -import actionlib # TODO change to reg actions -from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore -from lasr_speech_recognition_msgs.msg import ( # type: ignore +from rclpy.action import ActionClient +from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore +from lasr_speech_recognition_interfaces.msg import ( # type: ignore TranscribeSpeechAction, TranscribeSpeechGoal, ) +# TODO port file: action client, is_shutdown + +with rclpy.init(args=None): + node = rclpy.create_node("test_speech_server") -rospy.init_node("test_speech_server") -client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction) +client = ActionClient("transcribe_speech", TranscribeSpeechAction) client.wait_for_server() -rospy.loginfo("Done waiting") +node.get_logger().info("Done waiting") while not rospy.is_shutdown(): goal = TranscribeSpeechGoal() client.send_goal(goal) diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py index 68c121e12..cc8613cb8 100644 --- a/common/speech/lasr_speech_recognition_whisper/setup.py +++ b/common/speech/lasr_speech_recognition_whisper/setup.py @@ -5,7 +5,9 @@ setup( name=package_name, version='0.0.0', - packages=find_packages(exclude=['test']), + # packages=find_packages(exclude=['test']), + packages=["lasr_speech_recognition_whisper"], + package_dir = {'': 'src'}, data_files=[ 
('share/ament_index/resource_index/packages', ['resource/' + package_name]), @@ -14,12 +16,13 @@ install_requires=['setuptools'], zip_safe=True, maintainer='maayan', - maintainer_email='maayan@todo.todo', - description='TODO: Package description', - license='TODO: License declaration', + maintainer_email='maayan.armony@gmail.com', + description='Speech recognition implemented using OpenAI Whisper', + license='MIT', tests_require=['pytest'], entry_points={ 'console_scripts': [ + 'transcribe_mic_server_node = lasr_speech_recognition_whisper.nodes.transcribe_microphone_server:main', ], }, ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py index 85b18fb9b..372e26477 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -9,4 +9,4 @@ SpeechRecognitionToStdout, SpeechRecognitionToTopic, ) -from .cache import load_model +from .cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py index 42ec44785..196d4a715 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -1,43 +1,46 @@ import os import whisper # type: ignore -import rospkg # type: ignore -import rospy - +from ament_index_python import packages +from rclpy.node import Node # Keep all loaded models in memory MODEL_CACHE = {} +class ModelCache(Node): + def __init__(self): + super().__init__('lasr_speech_recognition_whisper_cache') -def load_model( - name: str, device: str = "cpu", load_test_file: bool = False -) -> whisper.Whisper: - """Loads a whisper model from disk, or from cache if it has already been loaded. + def load_model( + self, + name: str, device: str = "cpu", load_test_file: bool = False + ) -> whisper.Whisper: + """Loads a whisper model from disk, or from cache if it has already been loaded. - Args: - name (str): Name of the whisper model. Must be the name of an official whisper - model, or the path to a model checkpoint. - device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'. - load_test_file (bool, optional): Whether to run inference on a test audio file - after loading the model (if model is not in cache). Defaults to False. Test file - is assumed to be called "test.m4a" and be in the root of the package directory. + Args: + name (str): Name of the whisper model. Must be the name of an official whisper + model, or the path to a model checkpoint. + device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'. + load_test_file (bool, optional): Whether to run inference on a test audio file + after loading the model (if model is not in cache). Defaults to False. Test file + is assumed to be called "test.m4a" and be in the root of the package directory. 
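+
+        Example (illustrative; added for clarity, assuming rclpy has already
+        been initialised):
+
+            cache = ModelCache()
+            model = cache.load_model("medium.en", device="cpu")
+            text = model.transcribe("some_audio.wav")["text"]  # placeholder file path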
- Returns: - whisper.Whisper: Whisper model instance - """ - global MODEL_CACHE + Returns: + whisper.Whisper: Whisper model instance + """ + global MODEL_CACHE - if name not in MODEL_CACHE: - rospy.loginfo(f"Loading model {name}") - MODEL_CACHE[name] = whisper.load_model(name, device=device) - rospy.loginfo(f"Sucessfully loaded model {name} on {device}") - if load_test_file: - package_root = rospkg.RosPack().get_path("lasr_speech_recognition_whisper") - example_fp = os.path.join(package_root, "test.m4a") - rospy.loginfo( - "Running transcription on example file to ensure model is loaded..." - ) - test_result: str = MODEL_CACHE[name].transcribe( - example_fp, fp16=device == "cuda" - ) - rospy.loginfo(f"Transcription test result: {test_result}") + if name not in MODEL_CACHE: + self.get_logger().info(f"Loading model {name}") + MODEL_CACHE[name] = whisper.load_model(name, device=device) + self.get_logger().info(f"Sucessfully loaded model {name} on {device}") + if load_test_file: + package_root = packages.get_package_share_path("lasr_speech_recognition_whisper") + example_fp = os.path.join(package_root, "test.m4a") + self.get_logger().info( + "Running transcription on example file to ensure model is loaded..." + ) + test_result: str = MODEL_CACHE[name].transcribe( + example_fp, fp16=device == "cuda" + ) + self.get_logger().info(f"Transcription test result: {test_result}") - return MODEL_CACHE[name] + return MODEL_CACHE[name] diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py index ae15a5347..bca4ee00f 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -1,4 +1,4 @@ -import rospy +import rclpy import speech_recognition as sr @@ -65,6 +65,9 @@ def __init__( self, energy_threshold: int = 500, phrase_time_limit: float = 2 ) -> None: super().__init__() + with rclpy.init(args=None): + self.node = rclpy.create_node('source') + self._recorder = sr.Recognizer() self._recorder.dynamic_energy_threshold = False self._recorder.energy_threshold = energy_threshold @@ -72,13 +75,13 @@ def __init__( @abstractmethod def adjust_for_noise(self, source: sr.AudioSource): - rospy.loginfo("Adjusting for background noise...") + self.node.get_logger().info("Adjusting for background noise...") with source: self._recorder.adjust_for_ambient_noise(source) @abstractmethod def start(self, source: sr.AudioSource): - rospy.loginfo("Started source listen thread") + self.node.get_logger().info("Started source listen thread") self._stopper = self._recorder.listen_in_background( source, self._record_callback, phrase_time_limit=self._phrase_time_limit ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py index dd235ceb2..53f8f103f 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py @@ -1,4 +1,4 @@ -import rospy +import rclpy import pyaudio import speech_recognition as sr @@ -6,6 +6,7 @@ from .bytesfifo import BytesFIFO +# TODO rospy.wait_for_message() class AudioTopic(sr.AudioSource): """ @@ -13,15 +14,18 @@ class 
AudioTopic(sr.AudioSource): """ _topic: str - _sub: rospy.Subscriber + # _sub: node.create_subscription TODO add type if possible def __init__(self, topic: str, chunk_size=1024) -> None: + with rclpy.init(args=None): + self.node = rclpy.create_node('source') + self._topic = topic config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) assert config.coding_format == "wave", "Expected Wave audio format" assert config.sample_format == "S16LE", "Expected sample format S16LE" - rospy.loginfo(config) + self.node.get_logger().info(config) self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) self.SAMPLE_RATE = config.sample_rate @@ -38,7 +42,7 @@ def __enter__(self): self.stream is None ), "This audio source is already inside a context manager" self.stream = BytesFIFO(1024 * 10) # 10 kB buffer - self._sub = rospy.Subscriber(f"{self._topic}/audio", AudioData, self._read) + self._sub = self.node.create_subscription(AudioData, f"{self._topic}/audio", self._read) return self def __exit__(self, exc_type, exc_value, traceback): @@ -47,7 +51,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ self.stream = None - self._sub.unregister() + self.node.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister() def _read(self, msg: AudioData) -> None: """ diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py index a560dc798..f2dcf0cbb 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py @@ -1,5 +1,5 @@ import torch -import rospy +import rclpy import whisper import speech_recognition as sr @@ -12,7 +12,7 @@ from .collector import AbstractPhraseCollector -from lasr_speech_recognition_msgs.msg import Transcription +from lasr_speech_recognition_interfaces.msg import Transcription class SpeechRecognitionWorker(ABC): @@ -36,6 +36,8 @@ def __init__( maximum_phrase_length=timedelta(seconds=3), infer_partial=True, ) -> None: + with rclpy.init(args=None): + self.node = rclpy.create_node('worker') self._collector = collector self._tmp_file = NamedTemporaryFile().name self._model = model @@ -68,7 +70,7 @@ def _perform_inference(self): Run inference on the current sample """ - rospy.loginfo("Processing sample") + self.node.get_logger().info("Processing sample") audio_data = sr.AudioData( self._current_sample, self._collector.sample_rate(), @@ -79,7 +81,7 @@ def _perform_inference(self): with open(self._tmp_file, "w+b") as f: f.write(wav_data.read()) - rospy.loginfo("Running inference") + self.node.get_logger().info("Running inference") try: result = self._model.transcribe( self._tmp_file, fp16=torch.cuda.is_available() @@ -92,7 +94,7 @@ def _perform_inference(self): if len(text) == 0 or text.lower() in [".", "you", "thanks for watching!"]: self._phrase_start = None self._current_sample = bytes() - rospy.loginfo("Skipping garbage...") + self.node.get_logger().info("Skipping garbage...") return None return text @@ -102,7 +104,7 @@ def _worker(self): Indefinitely perform inference on the given data """ - rospy.loginfo("Started inference worker") + self.node.get_logger().info("Started inference worker") while not self._stopped: try: @@ -112,7 +114,7 @@ def _worker(self): self._phrase_start and now - self._phrase_start > self._maximum_phrase_length ): - rospy.loginfo("Reached timeout 
for phrase, ending now.") + self.node.get_logger().info("Reached timeout for phrase, ending now.") self._finish_phrase() # Start / continue phrase if data is coming in @@ -123,7 +125,7 @@ def _worker(self): while not self._collector.data.empty(): self._current_sample += self._collector.data.get() - rospy.loginfo( + self.node.get_logger().info( "Received and added more data to current audio sample." ) @@ -139,7 +141,7 @@ def _worker(self): except KeyboardInterrupt: self._stopped = True - rospy.loginfo("Worker finished") + self.node.get_logger().info("Worker finished") def start(self): """ @@ -173,7 +175,7 @@ class SpeechRecognitionToStdout(SpeechRecognitionWorker): """ def on_phrase(self, phrase: str, finished: bool) -> None: - rospy.loginfo("[" + ("x" if finished else " ") + "] " + phrase) + self.node.get_logger().info("[" + ("x" if finished else " ") + "] " + phrase) class SpeechRecognitionToTopic(SpeechRecognitionToStdout): @@ -181,7 +183,7 @@ class SpeechRecognitionToTopic(SpeechRecognitionToStdout): Recognise speech and publish it to a topic """ - _pub: rospy.Publisher + # _pub: node.create_publisher() TODO add type if possible def __init__( self, @@ -192,8 +194,8 @@ def __init__( infer_partial=True, ) -> None: super().__init__(collector, model, maximum_phrase_length, infer_partial) - rospy.loginfo(f"Will be publishing transcription to {topic}") - self._pub = rospy.Publisher(topic, Transcription, queue_size=5) + self.node.get_logger().info(f"Will be publishing transcription to {topic}") + self._pub = self.node.create_publisher(Transcription, topic, 5) def on_phrase(self, phrase: str, finished: bool) -> None: super().on_phrase(phrase, finished) From 2f76afd541aedfd1afb7d8aa3881b8e2bd66a74e Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Mon, 4 Nov 2024 10:24:13 +0000 Subject: [PATCH 04/14] Fix structure issues with Nodes --- common/__init__.py | 0 common/speech/__init__.py | 0 .../lasr_speech_recognition_whisper/README.md | 103 ++++++++++++++++++ .../__init__.py | 0 .../simple_transcribe_microphone.py} | 0 .../transcribe_microphone.py} | 0 .../transcribe_microphone_server.py} | 19 ++-- .../scripts/list_microphones.py | 0 .../scripts/microphone_tuning_test.py | 0 .../scripts/repeat_after_me.py | 0 .../scripts/test_microphones.py | 0 .../scripts/test_speech_server.py | 0 .../lasr_speech_recognition_whisper/setup.py | 9 +- .../lasr_speech_recognition_whisper/cache.py | 1 + 14 files changed, 117 insertions(+), 15 deletions(-) create mode 100644 common/__init__.py create mode 100644 common/speech/__init__.py create mode 100644 common/speech/lasr_speech_recognition_whisper/README.md create mode 100644 common/speech/lasr_speech_recognition_whisper/__init__.py rename common/speech/lasr_speech_recognition_whisper/{nodes/simple_transcribe_microphone => lasr_speech_recognition_whisper/simple_transcribe_microphone.py} (100%) rename common/speech/lasr_speech_recognition_whisper/{nodes/transcribe_microphone => lasr_speech_recognition_whisper/transcribe_microphone.py} (100%) rename common/speech/lasr_speech_recognition_whisper/{nodes/transcribe_microphone_server => lasr_speech_recognition_whisper/transcribe_microphone_server.py} (96%) mode change 100644 => 100755 common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py mode change 100644 => 100755 common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py mode change 100644 => 100755 common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py mode change 100644 => 100755 
common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py mode change 100644 => 100755 common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/__init__.py b/common/speech/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md new file mode 100644 index 000000000..e386a720f --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/README.md @@ -0,0 +1,103 @@ +# lasr_speech_recognition_whisper + +Speech recognition implemented using OpenAI Whisper + +This package is maintained by: +- [Maayan Armony](mailto:maayan.armony@gmail.com) +- [Paul Makles](mailto:me@insrt.uk) (ROS1) + +## Prerequisites + +This package depends on the following ROS packages: +- colcon (buildtool) +- lasr_speech_recognition_interfaces + +This package requires Python 3.10 to be present. + +This package has 48 Python dependencies: +- [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0 +- [openai-whisper](https://pypi.org/project/openai-whisper)==20230314 +- [PyAudio](https://pypi.org/project/PyAudio)==0.2.13 +- [PyYaml](https://pypi.org/project/PyYaml)==6.0.1 +- .. and sub dependencies (see [requirements file](requirements.txt)) + +This package requires that [ffmpeg](https://ffmpeg.org/) is available during runtime. + +## Usage + +> **Warning**: this package is not complete and is subject to change. + +List available microphones: + +```bash +ros2 run lasr_speech_recognition_whisper list_microphones.py +``` + +Start the example script: + +```bash +ros2 run lasr_speech_recognition_whisper transcribe_microphone by-index +ros2 run lasr_speech_recognition_whisper transcribe_microphone by-name +``` + +Then start listening to people: + +```bash +rosservice call /whisper/start_listening "{}" +``` + +You can now listen on `/transcription` for a live transcription. + +Stop listening at any time: + +```bash +rosservice call /whisper/stop_listening "{}" +``` + +## Example + +Ask the package maintainer to write a `doc/EXAMPLE.md` for their package! + +## Technical Overview + +This package does speech recognition in three parts: + +- Adjusting for background noise + + We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice data. + +- Collecting appropriate voice data for phrases + + We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually speaking with enough energy that we would consider them to be speaking to the robot. + +- Running inference on phrases + + We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold is reached, after which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe. + +The package can input from the following sources: + +- On-board or external microphone on device +- Audio data from ROS topic (WORK IN PROGRESS) + +The package can output transcriptions to: + +- Standard output +- A ROS topic + +## ROS Definitions + +### Launch Files + +This package has no launch files. + +### Messages + +This package has no messages. + +### Services + +This package has no services. + +### Actions + +This package has no actions.
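The three-part pipeline described in the Technical Overview above can be exercised directly with the two libraries this package wraps. The following is a minimal sketch, not the package's actual worker code: the `base.en` model, the default 16 kHz microphone, and the `/tmp/phrase.wav` scratch path are assumptions made for illustration.

```python
import speech_recognition as sr
import whisper

recognizer = sr.Recognizer()
model = whisper.load_model("base.en")  # assumed model; the package may default to another size

with sr.Microphone(sample_rate=16000) as source:
    # Part 1: sample the ambient noise so quiet background audio is ignored
    recognizer.adjust_for_ambient_noise(source, duration=1.0)
    # Part 2: collect a single phrase once the energy threshold is exceeded
    audio = recognizer.listen(source, timeout=5, phrase_time_limit=10)

# Part 3: run local Whisper inference on the collected sample
with open("/tmp/phrase.wav", "wb") as f:
    f.write(audio.get_wav_data())
print(model.transcribe("/tmp/phrase.wav")["text"])
```

Note that calling `transcribe` on a file path relies on ffmpeg being available, matching the runtime requirement listed above.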
\ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py similarity index 100% rename from common/speech/lasr_speech_recognition_whisper/nodes/simple_transcribe_microphone rename to common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py similarity index 100% rename from common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone rename to common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py similarity index 96% rename from common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server rename to common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index 7ff1f3efe..a00cc5e54 100644 --- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -1,6 +1,6 @@ #!/usr/bin python3 import os -import sounddevice # needed to remove ALSA error messages +# import sounddevice # needed to remove ALSA error messages import argparse from typing import Optional from dataclasses import dataclass @@ -15,12 +15,11 @@ from rclpy.action.server import ActionServer, CancelResponse import speech_recognition as sr # type: ignore -import lasr_speech_recognition_interfaces.msg # type: ignore +from lasr_speech_recognition_interfaces.action import TranscribeSpeech # type: ignore from rclpy.executors import ExternalShutdownException from std_msgs.msg import String # type: ignore from lasr_speech_recognition_whisper import ModelCache # type: ignore -# from common.speech.lasr_speech_recognition_whisper.scripts.repeat_after_me import result # TODO: argpars -> ROS2 params, test behaviour of preemption @@ -55,10 +54,10 @@ class speech_model_params: pause_threshold: Optional[float] = 2.0 -class TranscribeSpeechAction(object, Node): +class TranscribeSpeechAction(Node): # create messages that are used to publish feedback/result - _feedback = lasr_speech_recognition_interfaces.msg.TranscribeSpeechFeedback() - _result = lasr_speech_recognition_interfaces.msg.TranscribeSpeechResult() + _feedback = TranscribeSpeech.Feedback() + _result = TranscribeSpeech.Result() def __init__( self, @@ -77,7 +76,8 @@ def __init__( String, "/live_speech_transcription", 10 ) - self._model = ModelCache.load_model( + model_cache = ModelCache() + self._model = model_cache.load_model( self._model_params.model_name, self._model_params.device, self._model_params.warmup, @@ -385,7 +385,4 @@ def main(args=None): try: rclpy.spin(server) except (KeyboardInterrupt, ExternalShutdownException): - pass - -# if __name__ == "__main__": -# main() + pass \ No newline at end of file diff --git 
a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py old mode 100644 new mode 100755 diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py old mode 100644 new mode 100755 diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py old mode 100644 new mode 100755 diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py old mode 100644 new mode 100755 diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py old mode 100644 new mode 100755 diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py index cc8613cb8..1e163e63f 100644 --- a/common/speech/lasr_speech_recognition_whisper/setup.py +++ b/common/speech/lasr_speech_recognition_whisper/setup.py @@ -5,9 +5,7 @@ setup( name=package_name, version='0.0.0', - # packages=find_packages(exclude=['test']), - packages=["lasr_speech_recognition_whisper"], - package_dir = {'': 'src'}, + packages=find_packages(exclude=['test']), data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + package_name]), @@ -22,7 +20,10 @@ tests_require=['pytest'], entry_points={ 'console_scripts': [ - 'transcribe_mic_server_node = lasr_speech_recognition_whisper.nodes.transcribe_microphone_server:main', + 'transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main', + 'transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main', + 'simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main', + 'list_microphones = lasr_speech_recognition_whisper.list_microphones:main', ], }, ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py index 196d4a715..fb4038509 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -2,6 +2,7 @@ import whisper # type: ignore from ament_index_python import packages from rclpy.node import Node + # Keep all loaded models in memory MODEL_CACHE = {} From 2d3091eb4fe3f6b31debdd6e3b84ea06e317e40e Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Mon, 4 Nov 2024 15:27:47 +0000 Subject: [PATCH 05/14] Fix src import --- .../transcribe_microphone_server.py | 4 +--- .../requirements.txt | 2 +- .../scripts/__init__.py | 12 ++++++++++++ .../lasr_speech_recognition_whisper/setup.py | 7 +++++++ .../src/__init__.py | 12 ++++++++++++ .../lasr_speech_recognition_whisper/__init__.py | 2 +- .../lasr_speech_recognition_whisper/cache.py | 2 +- .../lasr_speech_recognition_whisper/test.m4a | Bin 0 -> 14117 bytes 8 files changed, 35 insertions(+), 6 deletions(-) create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/__init__.py create mode 100644 common/speech/lasr_speech_recognition_whisper/src/__init__.py create mode 100644 
common/speech/lasr_speech_recognition_whisper/test.m4a diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index a00cc5e54..340e3328f 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -18,9 +18,7 @@ from lasr_speech_recognition_interfaces.action import TranscribeSpeech # type: ignore from rclpy.executors import ExternalShutdownException from std_msgs.msg import String # type: ignore -from lasr_speech_recognition_whisper import ModelCache # type: ignore - - +from src import ModelCache # type: ignore # TODO: argpars -> ROS2 params, test behaviour of preemption diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt index 87bab6318..3aac78679 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.txt +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -42,4 +42,4 @@ typing-extensions==4.11.0 # via torch urllib3==2.2.1 # via requests # The following packages are considered to be unsafe in a requirements file: -# setuptools +# setuptools == 60.0.1 diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py b/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py new file mode 100644 index 000000000..372e26477 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py @@ -0,0 +1,12 @@ +# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector +from .collector import ( + AbstractPhraseCollector, + MicrophonePhraseCollector, + RecognizerPhraseCollector, +) +from .worker import ( + SpeechRecognitionWorker, + SpeechRecognitionToStdout, + SpeechRecognitionToTopic, +) +from .cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py index 1e163e63f..d2489fc98 100644 --- a/common/speech/lasr_speech_recognition_whisper/setup.py +++ b/common/speech/lasr_speech_recognition_whisper/setup.py @@ -6,6 +6,13 @@ name=package_name, version='0.0.0', packages=find_packages(exclude=['test']), + # packages=[package_name, f"{package_name}.lasr_speech_recognition_whisper", f"{package_name}.src"], + # package_dir={ + # '': '.', + # package_name: os.path.join(package_name), + # f"{package_name}.whisper": os.path.join(package_name, 'whisper'), + # f"{package_name}.src": os.path.join(package_name, 'src'), + # }, data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + package_name]), diff --git a/common/speech/lasr_speech_recognition_whisper/src/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/__init__.py new file mode 100644 index 000000000..473b206b7 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/__init__.py @@ -0,0 +1,12 @@ +# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector +from .lasr_speech_recognition_whisper.collector import ( + AbstractPhraseCollector, + MicrophonePhraseCollector, + RecognizerPhraseCollector, +) +from .lasr_speech_recognition_whisper.worker import ( + 
SpeechRecognitionWorker, + SpeechRecognitionToStdout, + SpeechRecognitionToTopic, +) +from .lasr_speech_recognition_whisper.cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py index 372e26477..69327473c 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -9,4 +9,4 @@ SpeechRecognitionToStdout, SpeechRecognitionToTopic, ) -from .cache import ModelCache +from .cache import ModelCache \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py index fb4038509..19e7e8d6a 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -34,7 +34,7 @@ def load_model( MODEL_CACHE[name] = whisper.load_model(name, device=device) self.get_logger().info(f"Successfully loaded model {name} on {device}") if load_test_file: - package_root = packages.get_package_share_path("lasr_speech_recognition_whisper") + package_root = packages.get_package_prefix("lasr_speech_recognition_whisper") example_fp = os.path.join(package_root, "test.m4a") self.get_logger().info( "Running transcription on example file to ensure model is loaded..." ) diff --git a/common/speech/lasr_speech_recognition_whisper/test.m4a b/common/speech/lasr_speech_recognition_whisper/test.m4a new file mode 100644 index 0000000000000000000000000000000000000000..1fbef3f084479d426332bf973c4673e1dab0fa7f GIT binary patch literal 14117 [14,117 bytes of base85-encoded binary audio data omitted]
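The `MODEL_CACHE` dictionary shown in the cache.py hunks above implements a simple memoisation of `whisper.load_model`. A stripped-down sketch of that pattern follows; the package's real `ModelCache` additionally logs through a ROS node and can warm up on the bundled test file, and the standalone `load_model` helper here is illustrative only.

```python
import whisper

# Module-level cache: repeated requests for the same model name reuse the
# already-loaded instance instead of reloading the weights from disk.
MODEL_CACHE = {}

def load_model(name: str, device: str = "cpu"):
    """Return a cached Whisper model, loading it on first use."""
    if name not in MODEL_CACHE:
        MODEL_CACHE[name] = whisper.load_model(name, device=device)
    return MODEL_CACHE[name]

model = load_model("medium.en")        # loads and caches the weights
model_again = load_model("medium.en")  # returns the same cached instance
```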
From: Ma'ayan Armony Date: Mon, 4 Nov 2024 16:01:18 +0000 Subject: [PATCH 06/14] Finish testing server node --- .../transcribe_microphone_server.py | 4 ++-- .../speech/lasr_speech_recognition_whisper/requirements.txt | 2 +- .../src/lasr_speech_recognition_whisper/cache.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index 340e3328f..9ef4ea6a7 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -1,6 +1,6 @@ #!/usr/bin python3 import os -# import sounddevice # needed to remove ALSA error messages +import sounddevice # needed to remove ALSA error messages import argparse from typing import Optional from dataclasses import dataclass @@ -86,7 +86,7 @@ def __init__( # Set up the action server and register execution callback self._action_server = ActionServer( self, - lasr_speech_recognition_interfaces.action.TranscribeSpeechAction, + TranscribeSpeech, self._action_name, execute_callback=self.execute_cb, cancel_callback=self.cancel_cb, diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt index 3aac78679..eade8e0a3 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.txt +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -25,7 +25,7 @@ nvidia-nccl-cu12==2.20.5 # via torch nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via torch openai-whisper==20231117 # via -r requirements.in -# pyaudio==0.2.13 # via -r requirements.in +pyaudio==0.2.13 # via -r requirements.in pycparser==2.22 # via cffi pyyaml==6.0.1 # via -r requirements.in regex==2024.4.28 # via tiktoken diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py index 19e7e8d6a..7a86f38e6 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -34,7 +34,8 @@ def load_model( MODEL_CACHE[name] = whisper.load_model(name, device=device) self.get_logger().info(f"Successfully loaded model {name} on {device}") if load_test_file: - package_root = packages.get_package_prefix("lasr_speech_recognition_whisper") + package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") + package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir,
"lasr_speech_recognition_whisper")) example_fp = os.path.join(package_root, "test.m4a") self.get_logger().info( "Running transcription on example file to ensure model is loaded..." From d9da0beb013604f38ce39ce394e990fd92ce6c89 Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Fri, 22 Nov 2024 13:11:33 +0000 Subject: [PATCH 07/14] Changes for whisper on 24 laptop --- .../lasr_speech_recognition_whisper/README.md | 4 +- .../simple_transcribe_microphone.py | 9 +- .../transcribe_microphone.py | 7 +- .../log/COLCON_IGNORE | 0 .../requirements.in | 5 +- .../scripts/__init__.py | 12 --- .../scripts/list_microphones.py | 18 +++- .../scripts/microphone_tuning_test.py | 18 ++-- .../scripts/repeat_after_me.py | 87 ------------------- .../scripts/test_microphones.py | 19 ++-- .../scripts/test_speech_server.py | 44 ++++++---- .../lasr_speech_recognition_whisper/setup.py | 5 +- .../collector.py | 12 +-- .../lasr_speech_recognition_whisper/source.py | 30 ++++--- .../lasr_speech_recognition_whisper/worker.py | 32 +++---- 15 files changed, 126 insertions(+), 176 deletions(-) create mode 100644 common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE delete mode 100755 common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md index e386a720f..9954290ec 100644 --- a/common/speech/lasr_speech_recognition_whisper/README.md +++ b/common/speech/lasr_speech_recognition_whisper/README.md @@ -43,7 +43,7 @@ ros2 run lasr_speech_recognition_whisper transcribe_microphone by-name ROS params @@ -21,11 +22,14 @@ def configure_whisper_cache() -> None: """Configures the whisper cache directory.""" whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper") os.makedirs(whisper_cache, exist_ok=True) - # Environemntal variable required to run whisper locally + # Environmental variable required to run whisper locally os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache -def main(): +def main(args=None): + rclpy.init(args=args) # Have to initialise rclpy for the ModelCache + + configure_whisper_cache() args = parse_args() recognizer = sr.Recognizer() @@ -34,7 +38,8 @@ def main(): threshold = 100 recognizer.dynamic_energy_threshold = False recognizer.energy_threshold = threshold - transcription_model = load_model( + model_cache = ModelCache() + transcription_model = model_cache.load_model( "medium.en", "cuda" if torch.cuda.is_available() else "cpu", True ) transcription_result = "The quick brown fox jumps over the lazy dog." 
@@ -62,7 +67,6 @@ def main(): threshold += 100 recognizer.energy_threshold = threshold - if __name__ == "__main__": - configure_whisper_cache() - main() + + main() \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py deleted file mode 100755 index 6876d96b1..000000000 --- a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py +++ /dev/null @@ -1,87 +0,0 @@ -# #!/usr/bin python3 -# import rclpy -# from rclpy.node import Node -# from rclpy.action import ActionClient -# from lasr_voice import Voice # type: ignore -# from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore -# from lasr_speech_recognition_interfaces.action import TranscribeSpeechAction, TranscribeSpeechGoal -# -# -# # TODO port file: action client, service proxy -# -# USE_ACTIONLIB = True -# -# class ServiceClientNode(Node): -# def __init__(self): -# super().__init__('service_client_node') -# self.client = None -# -# self.voice = Voice() -# -# # def call_service(self): -# # request = TranscribeAudio -# # -# # # Call the service synchronously -# # future = self.client.call_async(request) -# # rclpy.spin_until_future_complete(self, future) -# # -# # if future.result() is not None: -# # self.get_logger().info('Service call succeeded') -# # else: -# # self.get_logger().error('Service call failed') -# -# -# def call_service(self): -# request = TranscribeAudio -# else: -# # transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio) -# repeating = False -# while rclpy.ok(): -# future = self.client.call_async(request) -# rclpy.spin_until_future_complete(self, future) -# if future.done(): -# text = transcribe().phrase -# self.get_logger().info(text) -# if "tiago" in text.lower().strip(): -# if "repeat" in text.lower().strip(): -# repeating = True -# self.voice.sync_tts("Okay, I'll start repeating now.") -# continue -# elif "stop" in text.lower().strip(): -# repeating = False -# self.voice.sync_tts("Okay, I'll stop repeating now.") -# break -# if repeating: -# self.voice.sync_tts(f"I heard {text}") -# -# -# class ActionClientNode(Node): -# def __init__(self): -# super().__init__('action_client_node') -# self.client = None -# -# def goal_callback(self, future): -# self.client = self.create_client(TranscribeSpeechAction, "transcribe_speech") -# # Wait for the server to be available -# while not self.client.wait_for_service(timeout_sec=5.0): -# self.get_logger().info("Waiting for server...") -# repeating = False -# self.get_logger().info("Done waiting") -# while rclpy.ok(): -# goal = TranscribeSpeechGoal() -# self.client.send_goal(goal) -# self.client.wait_for_result() -# result = self.client.get_result() -# text = result.sequence -# self.get_logger().info(text) -# if "tiago" in text.lower().strip(): -# if "repeat" in text.lower().strip(): -# repeating = True -# self.voice.sync_tts("Okay, I'll start repeating now.") -# continue -# elif "stop" in text.lower().strip(): -# repeating = False -# self.voice.sync_tts("Okay, I'll stop repeating now.") -# break -# if repeating: -# self.voice.sync_tts(f"I heard {text}") \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py index 85616017c..ed30691ab 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py +++ 
b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py @@ -3,7 +3,9 @@ import os import argparse import speech_recognition as sr +import rclpy +# TODO argparse -> ROS params def parse_args() -> dict: """Parse command line arguments into a dictionary. @@ -18,10 +20,12 @@ def parse_args() -> dict: "-o", "--output_dir", type=str, help="Directory to save audio files" ) - return vars(parser.parse_args()) + # return vars(parser.parse_args()) + args, _ = parser.parse_known_args() + return vars(args) -def main(args: dict) -> None: +def main(args: dict = None) -> None: """Generate audio files from microphone input. Args: @@ -30,8 +34,12 @@ def main(args: dict) -> None: # Adapted from https://github.com/Uberi/speech_recognition/blob/master/examples/write_audio.py - mic_index = args["microphone"] - output_dir = args["output_dir"] + rclpy.init(args=args) + + parser_args = parse_args() + + mic_index = parser_args["microphone"] + output_dir = parser_args["output_dir"] r = sr.Recognizer() r.pause_threshold = 2 @@ -52,6 +60,7 @@ def main(args: dict) -> None: with open(os.path.join(output_dir, "microphone.aiff"), "wb") as f: f.write(audio.get_aiff_data()) + rclpy.shutdown() if __name__ == "__main__": - main(parse_args()) + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py index 7456293c2..ff4c570b9 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -4,23 +4,35 @@ import rclpy from rclpy.action import ActionClient from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore -from lasr_speech_recognition_interfaces.msg import ( # type: ignore - TranscribeSpeechAction, - TranscribeSpeechGoal, -) +from lasr_speech_recognition_interfaces.action import TranscribeSpeech -# TODO port file: action client, is_shutdown +# TODO port file: action client -with rclpy.init(args=None): - node = rclpy.create_node("test_speech_server") +class TestSpeechServerClient: + def __init__(self): + self.node = rclpy.create_node("test_speech_server") + self.client = ActionClient(self.node, TranscribeSpeech, "transcribe_speech") -client = ActionClient("transcribe_speech", TranscribeSpeechAction) -client.wait_for_server() -node.get_logger().info("Done waiting") -while not rospy.is_shutdown(): - goal = TranscribeSpeechGoal() - client.send_goal(goal) - client.wait_for_result() - result = client.get_result() - text = result.sequence + def send_goal(self, msg): + goal = TranscribeSpeech.Goal() + goal.msg = msg + + self.client.wait_for_server() + self.client.send_goal(goal) # should be future and async? 
+ + # TODO add callback with future and handle result + +# client.wait_for_server() +# node.get_logger().info("Done waiting") +while rclpy.ok(): + rclpy.init() + # goal = TranscribeSpeech.Goal() + # client.send_goal(goal) + client = TestSpeechServerClient() + client.send_goal(10) + rclpy.spin(client.node) + # client.wait_for_result() + # result = client.get_result() + # text = result.sequence + text = "" print(f"Transcribed Speech: {text}") diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py index d2489fc98..3fbac464a 100644 --- a/common/speech/lasr_speech_recognition_whisper/setup.py +++ b/common/speech/lasr_speech_recognition_whisper/setup.py @@ -30,7 +30,10 @@ 'transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main', 'transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main', 'simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main', - 'list_microphones = lasr_speech_recognition_whisper.list_microphones:main', + 'list_microphones = scripts.list_microphones:main', + 'microphone_tuning_test = scripts.microphone_tuning_test:main', + 'test_microphones = scripts.test_microphones:main', + 'test_speech_server = scripts.test_speech_server:main', ], }, ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py index bca4ee00f..936896e77 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -1,5 +1,5 @@ import rclpy - +from rclpy.node import Node import speech_recognition as sr from queue import Queue @@ -7,6 +7,7 @@ # from .source import AudioTopic +# TODO test class AbstractPhraseCollector(ABC): """ @@ -44,7 +45,7 @@ def sample_width(self): pass -class RecognizerPhraseCollector(AbstractPhraseCollector): +class RecognizerPhraseCollector(AbstractPhraseCollector, Node): """ Collect phrases using a SoundRecognition Recognizer @@ -65,8 +66,7 @@ def __init__( self, energy_threshold: int = 500, phrase_time_limit: float = 2 ) -> None: super().__init__() - with rclpy.init(args=None): - self.node = rclpy.create_node('source') + Node.__init__(self, "collector") self._recorder = sr.Recognizer() self._recorder.dynamic_energy_threshold = False @@ -75,13 +75,13 @@ def __init__( @abstractmethod def adjust_for_noise(self, source: sr.AudioSource): - self.node.get_logger().info("Adjusting for background noise...") + self.get_logger().info("Adjusting for background noise...") with source: self._recorder.adjust_for_ambient_noise(source) @abstractmethod def start(self, source: sr.AudioSource): - self.node.get_logger().info("Started source listen thread") + self.get_logger().info("Started source listen thread") self._stopper = self._recorder.listen_in_background( source, self._record_callback, phrase_time_limit=self._phrase_time_limit ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py index 53f8f103f..abcd0fd1e 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py +++ 
b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py @@ -1,4 +1,5 @@ import rclpy +from rclpy.node import Node import pyaudio import speech_recognition as sr @@ -8,7 +9,7 @@ # TODO rospy.wait_for_message() -class AudioTopic(sr.AudioSource): +class AudioTopic(sr.AudioSource, Node): """ Use a ROS topic as an AudioSource """ @@ -17,21 +18,26 @@ class AudioTopic(sr.AudioSource): # _sub: node.create_subscription TODO add type if possible def __init__(self, topic: str, chunk_size=1024) -> None: - with rclpy.init(args=None): - self.node = rclpy.create_node('source') + Node.__init__(self, "source") self._topic = topic + self.subscription = self.create_subscription(AudioInfo, f"{topic}/audio_info", self.callback, 10) + # config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) + self.config = None # TODO test that this works + if self.config is not None: + assert self.config.coding_format == "wave", "Expected Wave audio format" + assert self.config.sample_format == "S16LE", "Expected sample format S16LE" + self.get_logger().info(self.config) - config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) - assert config.coding_format == "wave", "Expected Wave audio format" - assert config.sample_format == "S16LE", "Expected sample format S16LE" - self.node.get_logger().info(config) + self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) + self.SAMPLE_RATE = self.config.sample_rate - self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) - self.SAMPLE_RATE = config.sample_rate + self.CHUNK = chunk_size + self.stream = None - self.CHUNK = chunk_size - self.stream = None + def callback(self, msg): + self.get_logger().info("Message received") + self.config = msg def __enter__(self): """ @@ -51,7 +57,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ self.stream = None - self.node.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister() + self.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister() def _read(self, msg: AudioData) -> None: """ diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py index f2dcf0cbb..998475578 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py @@ -1,5 +1,8 @@ import torch -import rclpy + +from rclpy.node import Node +from rclpy.publisher import Publisher + import whisper import speech_recognition as sr @@ -15,7 +18,7 @@ from lasr_speech_recognition_interfaces.msg import Transcription -class SpeechRecognitionWorker(ABC): +class SpeechRecognitionWorker(ABC, Node): """ Collect and run inference on phrases to produce a transcription """ @@ -36,8 +39,7 @@ def __init__( maximum_phrase_length=timedelta(seconds=3), infer_partial=True, ) -> None: - with rclpy.init(args=None): - self.node = rclpy.create_node('worker') + Node.__init__(self, 'worker') self._collector = collector self._tmp_file = NamedTemporaryFile().name self._model = model @@ -70,7 +72,7 @@ def _perform_inference(self): Run inference on the current sample """ - self.node.get_logger().info("Processing sample") + self.get_logger().info("Processing sample") audio_data = sr.AudioData( self._current_sample, self._collector.sample_rate(), @@ -81,7 +83,7 @@ def _perform_inference(self): with 
open(self._tmp_file, "w+b") as f: f.write(wav_data.read()) - self.node.get_logger().info("Running inference") + self.get_logger().info("Running inference") try: result = self._model.transcribe( self._tmp_file, fp16=torch.cuda.is_available() @@ -94,7 +96,7 @@ def _perform_inference(self): if len(text) == 0 or text.lower() in [".", "you", "thanks for watching!"]: self._phrase_start = None self._current_sample = bytes() - self.node.get_logger().info("Skipping garbage...") + self.get_logger().info("Skipping garbage...") return None return text @@ -104,7 +106,7 @@ def _worker(self): Indefinitely perform inference on the given data """ - self.node.get_logger().info("Started inference worker") + self.get_logger().info("Started inference worker") while not self._stopped: try: @@ -114,7 +116,7 @@ def _worker(self): self._phrase_start and now - self._phrase_start > self._maximum_phrase_length ): - self.node.get_logger().info("Reached timeout for phrase, ending now.") + self.get_logger().info("Reached timeout for phrase, ending now.") self._finish_phrase() # Start / continue phrase if data is coming in @@ -125,7 +127,7 @@ def _worker(self): while not self._collector.data.empty(): self._current_sample += self._collector.data.get() - self.node.get_logger().info( + self.get_logger().info( "Received and added more data to current audio sample." ) @@ -141,7 +143,7 @@ def _worker(self): except KeyboardInterrupt: self._stopped = True - self.node.get_logger().info("Worker finished") + self.get_logger().info("Worker finished") def start(self): """ @@ -175,7 +177,7 @@ class SpeechRecognitionToStdout(SpeechRecognitionWorker): """ def on_phrase(self, phrase: str, finished: bool) -> None: - self.node.get_logger().info("[" + ("x" if finished else " ") + "] " + phrase) + self.get_logger().info("[" + ("x" if finished else " ") + "] " + phrase) class SpeechRecognitionToTopic(SpeechRecognitionToStdout): @@ -183,7 +185,7 @@ class SpeechRecognitionToTopic(SpeechRecognitionToStdout): Recognise speech and publish it to a topic """ - # _pub: node.create_publisher() TODO add type if possible + _pub: Publisher def __init__( self, @@ -194,8 +196,8 @@ def __init__( infer_partial=True, ) -> None: super().__init__(collector, model, maximum_phrase_length, infer_partial) - self.node.get_logger().info(f"Will be publishing transcription to {topic}") - self._pub = self.node.create_publisher(Transcription, topic, 5) + self.get_logger().info(f"Will be publishing transcription to {topic}") + self._pub = self.create_publisher(Transcription, topic, 5) def on_phrase(self, phrase: str, finished: bool) -> None: super().on_phrase(phrase, finished) From 802950c3d3efa4900141eda4336a21ee7c918e8a Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Fri, 22 Nov 2024 21:10:49 +0000 Subject: [PATCH 08/14] Fix test_microphones ALSA issue --- .../scripts/microphone_tuning_test.py | 2 +- .../scripts/test_microphones.py | 6 ++-- .../scripts/test_speech_server.py | 32 +++++++++++-------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py index 08338d6ea..c77396467 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py @@ -14,7 +14,7 @@ def parse_args() -> Dict: parser = argparse.ArgumentParser() - parser.add_argument("--device_index", type=int, default=None) + 
parser.add_argument("--device_index", help="Microphone index", type=int, default=None) return vars(parser.parse_args()) diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py index ed30691ab..e0c94c23e 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py @@ -4,6 +4,7 @@ import argparse import speech_recognition as sr import rclpy +import sounddevice # needed to remove ALSA error messages # TODO argparse -> ROS params @@ -15,7 +16,7 @@ def parse_args() -> dict: """ parser = argparse.ArgumentParser(description="Test microphones") - parser.add_argument("-m", "--microphone", type=int, help="Microphone index") + parser.add_argument("-m", "--microphone", type=int, help="Microphone index", default=None) parser.add_argument( "-o", "--output_dir", type=str, help="Directory to save audio files" ) @@ -43,7 +44,8 @@ def main(args: dict = None) -> None: r = sr.Recognizer() r.pause_threshold = 2 - with sr.Microphone(device_index=9, sample_rate=16000) as source: + microphone = sr.Microphone(device_index=mic_index, sample_rate=16000) + with microphone as source: print("Say something!") audio = r.listen(source, timeout=5, phrase_time_limit=10) print("Finished listening") diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py index ff4c570b9..b523687d5 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -1,9 +1,8 @@ #!/usr/bin python3 from argparse import Action - import rclpy from rclpy.action import ActionClient -from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore +from lasr_speech_recognition_interfaces.srv import TranscribeAudio # type: ignore from lasr_speech_recognition_interfaces.action import TranscribeSpeech # TODO port file: action client @@ -24,15 +23,20 @@ def send_goal(self, msg): # client.wait_for_server() # node.get_logger().info("Done waiting") -while rclpy.ok(): - rclpy.init() - # goal = TranscribeSpeech.Goal() - # client.send_goal(goal) - client = TestSpeechServerClient() - client.send_goal(10) - rclpy.spin(client.node) - # client.wait_for_result() - # result = client.get_result() - # text = result.sequence - text = "" - print(f"Transcribed Speech: {text}") + +def main(): + while rclpy.ok(): + rclpy.init() + # goal = TranscribeSpeech.Goal() + # client.send_goal(goal) + client = TestSpeechServerClient() + client.send_goal(10) + rclpy.spin(client.node) + # client.wait_for_result() + # result = client.get_result() + # text = result.sequence + text = "" + print(f"Transcribed Speech: {text}") + +if __name__ == '__main__': + main() From 36cebad18afb91bd81d9a49d439ee999ba68daa1 Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Fri, 22 Nov 2024 21:53:11 +0000 Subject: [PATCH 09/14] Test rest of srcipts exc test_speech_server --- .../simple_transcribe_microphone.py | 15 +- .../transcribe_microphone.py | 155 ++++++++++-------- .../collector.py | 8 +- 3 files changed, 94 insertions(+), 84 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py 
b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py index f1fc167f3..960bd5997 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py @@ -9,8 +9,9 @@ import speech_recognition as sr import numpy as np -from lasr_speech_recognition_interfaces.srv import TranscribeAudio, TranscribeAudioResponse -from lasr_speech_recognition_whisper import load_model +import sounddevice # needed to remove ALSA error messages +from lasr_speech_recognition_interfaces.srv import TranscribeAudio +from src import ModelCache # type: ignore MODEL = "medium.en" # Whisper model TIMEOUT = 5.0 # Timeout for listening for the start of a phrase @@ -48,17 +49,19 @@ exit(1) rclpy.init(args=sys.argv) -node = rclpy.create_node('transcribe_mic', anonymous=True) +node = rclpy.create_node('transcribe_mic') device = "cuda" if torch.cuda.is_available() else "cpu" -model = load_model("medium.en", device=device) +model_cache = ModelCache() +model = model_cache.load_model("medium.en", device=device) # try to run inference on the example file package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper")) example_fp = os.path.join(package_root, "test.m4a") node.get_logger().info("Running transcription on example file to ensure model is loaded...") -node.get_logger().info(model.transcribe(example_fp, fp16=torch.cuda.is_available())) +transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) +node.get_logger().info(str(transcription)) microphone = sr.Microphone(device_index=device_index, sample_rate=16000) r = sr.Recognizer() @@ -72,7 +75,7 @@ def handle_transcribe_audio(_): float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order='C') / 32768.0 phrase = model.transcribe(float_data, fp16=device == "cuda")["text"] - return TranscribeAudioResponse(phrase=phrase) + return TranscribeAudio.Response(phrase=phrase) node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py index 0bf7e6add..f8553b0b6 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py @@ -1,84 +1,95 @@ #!/usr/bin python3 import os +import sys import torch -from ament_index_python import packages from pathlib import Path +import rclpy +from rclpy.node import Node +from ament_index_python import packages +from std_srvs.srv import Empty +from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache + WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper') os.makedirs(WHISPER_CACHE, exist_ok=True) os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE -import sys - -# TODO port node - -if len(sys.argv) < 3: - print('Usage:') - print('rosrun lasr_speech_recognition transcribe_microphone by-index ') - print('rosrun lasr_speech_recognition transcribe_microphone by-name ') - exit(1) -else: - matcher = sys.argv[1] - device_index = 
None - if matcher == 'by-index': - device_index = int(sys.argv[2]) - elif matcher == 'by-name': - import speech_recognition as sr - microphones = enumerate(sr.Microphone.list_microphone_names()) - - target_name = sys.argv[2] - for index, name in microphones: - if target_name in name: - device_index = index - break - - if device_index is None: - print('Could not find device!') - exit(1) - else: - print('Invalid matcher') - exit(1) - -import rclpy -from std_srvs.srv import Empty, EmptyResponse - -with rclpy.init(args=None): - node = rclpy.create_node('transcribe_mic') # was anonymous in ROS1 +class TranscribeMicrophone(Node): + def __init__(self): + Node.__init__(self, 'transcribe_microphone') + self.worker = None + self.collector = None -from lasr_speech_recognition_whisper import SpeechRecognitionToTopic, MicrophonePhraseCollector, load_model + self.create_service(Empty, '/whisper/adjust_for_noise', self.adjust_for_noise) + self.create_service(Empty, '/whisper/start_listening', self.start_listening) + self.create_service(Empty, '/whisper/stop_listening', self.stop_listening) -collector = MicrophonePhraseCollector(device_index=device_index) -collector.adjust_for_noise() + self.get_logger().info("Starting the Whisper worker!") + self.run_transcription() -#model = load_model("base.en") - -model = load_model("medium.en") - -# try to run inference on the example file -package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") -package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper")) -example_fp = os.path.join(package_root, "test.m4a") - -node.get_logger().info("Running transcription on example file to ensure model is loaded...") -node.get_logger().info(model.transcribe(example_fp, fp16=torch.cuda.is_available())) - -worker = SpeechRecognitionToTopic(collector, model, "transcription", infer_partial = False) - -def adjust_for_noise(_): - collector.adjust_for_noise() - return EmptyResponse() - -def start_listening(_): - worker.start() - return EmptyResponse() - -def stop_listening(_): - worker.stop() - return EmptyResponse() - -node.create_service(Empty, '/whisper/adjust_for_noise', adjust_for_noise) -node.create_service(Empty, '/whisper/start_listening', start_listening) -node.create_service(Empty, '/whisper/stop_listening', stop_listening) - -node.get_logger().info("Starting the Whisper worker!") -rclpy.spin(node) + def run_transcription(self): + if len(sys.argv) < 3: + print('Usage:') + print('rosrun lasr_speech_recognition transcribe_microphone by-index ') + print('rosrun lasr_speech_recognition transcribe_microphone by-name ') + exit(1) + else: + matcher = sys.argv[1] + device_index = None + if matcher == 'by-index': + device_index = int(sys.argv[2]) + elif matcher == 'by-name': + import speech_recognition as sr + microphones = enumerate(sr.Microphone.list_microphone_names()) + + target_name = sys.argv[2] + for index, name in microphones: + if target_name in name: + device_index = index + break + + if device_index is None: + print('Could not find device!') + exit(1) + else: + print('Invalid matcher') + exit(1) + + + self.collector = MicrophonePhraseCollector(device_index=device_index) + self.collector.adjust_for_noise() + + model_cache = ModelCache() + model = model_cache.load_model("medium.en") + + # try to run inference on the example file + package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") + package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, 
"lasr_speech_recognition_whisper")) + example_fp = os.path.join(package_root, "test.m4a") + + self.get_logger().info("Running transcription on example file to ensure model is loaded...") + model_transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) + self.get_logger().info(str(model_transcription)) + + self.worker = SpeechRecognitionToTopic(self.collector, model, "transcription", infer_partial = False) + + def adjust_for_noise(self, request, response): + self.collector.adjust_for_noise() + return response + + def start_listening(self, request, response): + self.worker.start() + return response + + def stop_listening(self, request, response): + self.worker.stop() + return response + +def main(args=None): + rclpy.init(args=args) + transcribe_microphone = TranscribeMicrophone() + rclpy.spin(transcribe_microphone) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py index 936896e77..9edbc313b 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -5,10 +5,6 @@ from queue import Queue from abc import ABC, abstractmethod -# from .source import AudioTopic - -# TODO test - class AbstractPhraseCollector(ABC): """ Supertype holding a queue of audio data representing a phrase @@ -65,8 +61,8 @@ def _record_callback(self, _, audio: sr.AudioData) -> None: def __init__( self, energy_threshold: int = 500, phrase_time_limit: float = 2 ) -> None: - super().__init__() - Node.__init__(self, "collector") + super().__init__("collector") + # Node.__init__(self, "collector") self._recorder = sr.Recognizer() self._recorder.dynamic_energy_threshold = False From 44827106f5949495f25b3a1b4a8300464672e3ff Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Sun, 24 Nov 2024 17:18:27 +0000 Subject: [PATCH 10/14] Finish future callbacks for Action client --- .../transcribe_microphone_server.py | 2 +- .../scripts/test_speech_server.py | 66 ++++++++++++------- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index 9ef4ea6a7..e7c307f50 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -189,7 +189,7 @@ async def execute_cb(self, goal_handle) -> None: goal = goal_handle.request self.get_logger().info("Request Received") - if goal_handle.is_cancel_requested(): + if goal_handle.is_cancel_requested: return if goal.energy_threshold > 0.0 and goal.max_phrase_limit > 0.0: diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py index b523687d5..d0670845d 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -1,42 +1,58 @@ #!/usr/bin python3 -from argparse import Action import rclpy +from 
rclpy.node import Node from rclpy.action import ActionClient from lasr_speech_recognition_interfaces.srv import TranscribeAudio # type: ignore from lasr_speech_recognition_interfaces.action import TranscribeSpeech -# TODO port file: action client +# https://docs.ros2.org/latest/api/rclpy/api/actions.html -class TestSpeechServerClient: +class TestSpeechServerClient(Node): def __init__(self): - self.node = rclpy.create_node("test_speech_server") - self.client = ActionClient(self.node, TranscribeSpeech, "transcribe_speech") + Node.__init__(self, "listen_action_client") - def send_goal(self, msg): - goal = TranscribeSpeech.Goal() - goal.msg = msg + self.client = ActionClient(self, TranscribeSpeech, "transcribe_speech") + self.goal_future = None + self.result_future = None + def send_goal(self, goal): + self.get_logger().info("Waiting for Whisper server...") self.client.wait_for_server() - self.client.send_goal(goal) # should be future and async? + self.get_logger().info("Server activated, sending goal...") + + self.goal_future = self.client.send_goal_async(goal, feedback_callback=self.feedback_cb) # Returns a Future instance when the goal request has been accepted or rejected. + self.goal_future.add_done_callback(self.response_cb) # When received get response + + def feedback_cb(self, msg): + self.get_logger().info(f"Received feedback: {msg.feedback}") - # TODO add callback with future and handle result + def response_cb(self, future): + handle = future.result() + if not handle.accepted: + self.get_logger().info("Goal was rejected") + return -# client.wait_for_server() -# node.get_logger().info("Done waiting") + self.get_logger().info("Goal was accepted") + self.result_future = handle.get_result_async() # Not using get_result() in cb, as can cause deadlock according to docs + self.result_future.add_done_callback(self.result_cb) -def main(): + def result_cb(self, future): + result = future.result().result + self.get_logger().info(f"Transcribed Speech: {result.sequence}") + +def main(args=None): + rclpy.init(args=args) while rclpy.ok(): - rclpy.init() - # goal = TranscribeSpeech.Goal() - # client.send_goal(goal) + goal = TranscribeSpeech.Goal() client = TestSpeechServerClient() - client.send_goal(10) - rclpy.spin(client.node) - # client.wait_for_result() - # result = client.get_result() - # text = result.sequence - text = "" - print(f"Transcribed Speech: {text}") - -if __name__ == '__main__': + try: + client.send_goal(goal) + rclpy.spin(client) + except KeyboardInterrupt: + client.get_logger().info("Shutting down...") + finally: + client.destroy_node() + rclpy.shutdown() + +if __name__ == "__main__": main() From 7f0dd5548d9d12c8bab06ac5ac5a12cf45b99341 Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Sun, 24 Nov 2024 17:33:29 +0000 Subject: [PATCH 11/14] refactor: Black format --- .../README.md | 12 ++-- .../package.xml | 32 +++++----- .../lasr_speech_recognition_whisper/README.md | 12 +++- .../simple_transcribe_microphone.py | 58 ++++++++++++------- .../transcribe_microphone.py | 54 +++++++++++------ .../transcribe_microphone_server.py | 23 ++++---- .../package.xml | 46 +++++++-------- .../scripts/list_microphones.py | 5 +- .../scripts/microphone_tuning_test.py | 10 +++- .../scripts/test_microphones.py | 6 +- .../scripts/test_speech_server.py | 15 ++++- .../lasr_speech_recognition_whisper/setup.cfg | 4 +- .../lasr_speech_recognition_whisper/setup.py | 39 ++++++------- .../__init__.py | 2 +- .../lasr_speech_recognition_whisper/cache.py | 19 ++++-- .../collector.py | 1 + 
.../lasr_speech_recognition_whisper/source.py | 13 ++++- .../lasr_speech_recognition_whisper/worker.py | 2 +- .../test/test_copyright.py | 8 ++- .../test/test_flake8.py | 6 +- .../test/test_pep257.py | 4 +- 21 files changed, 224 insertions(+), 147 deletions(-) diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md index c96279cb6..8e7aab96f 100644 --- a/common/speech/lasr_speech_recognition_interfaces/README.md +++ b/common/speech/lasr_speech_recognition_interfaces/README.md @@ -3,17 +3,18 @@ Common messages used for speech recognition This package is maintained by: + - [Maayan Armony](mailto:maayan.armony@gmail.com) - [Paul Makles](mailto:me@insrt.uk) (ROS1) ## Prerequisites This package depends on the following ROS packages: + - colcon (buildtool) - message_generation (build) - message_runtime (exec) - ## Usage Ask the package maintainer to write a `doc/USAGE.md` for their package! @@ -36,11 +37,10 @@ This package has no launch files. #### `Transcription` -| Field | Type | Description | -|:-:|:-:|---| -| phrase | string | | -| finished | bool | | - +| Field | Type | Description | +|:--------:|:------:|-------------| +| phrase | string | | +| finished | bool | | ### Services diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml index b15638eb1..fd72011b7 100644 --- a/common/speech/lasr_speech_recognition_interfaces/package.xml +++ b/common/speech/lasr_speech_recognition_interfaces/package.xml @@ -1,23 +1,23 @@ - lasr_speech_recognition_interfaces - 0.0.0 - Common messages used for speech recognition - maayan - MIT + lasr_speech_recognition_interfaces + 0.0.0 + Common messages used for speech recognition + maayan + MIT - ament_cmake - - rosidl_default_generators - action_msgs - rosidl_default_runtime - rosidl_interface_packages + ament_cmake + + rosidl_default_generators + action_msgs + rosidl_default_runtime + rosidl_interface_packages - ament_lint_auto - ament_lint_common + ament_lint_auto + ament_lint_common - - ament_cmake - + + ament_cmake + diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md index 9954290ec..c9f58557e 100644 --- a/common/speech/lasr_speech_recognition_whisper/README.md +++ b/common/speech/lasr_speech_recognition_whisper/README.md @@ -3,18 +3,21 @@ Speech recognition implemented using OpenAI Whisper This package is maintained by: + - [Maayan Armony](mailto:maayan.armony@gmail.com) - [Paul Makles](mailto:me@insrt.uk) (ROS1) ## Prerequisites This package depends on the following ROS packages: + - colcon (buildtool) - lasr_speech_recognition_interfaces This packages requires Python 3.10 to be present. This package has 48 Python dependencies: + - [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0 - [openai-whisper](https://pypi.org/project/openai-whisper)==20230314 - [PyAudio](https://pypi.org/project/PyAudio)==0.2.13 @@ -64,15 +67,18 @@ This package does speech recognition in three parts: - Adjusting for background noise - We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice data. + We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice + data. 
- Collecting appropriate voice data for phrases - We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually speaking with enough energy that we would consider them to be speaking to the robot. + We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually + speaking with enough energy that we would consider them to be speaking to the robot. - Running inference on phrases - We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe. + We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after + which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe. The package can input from the following sources: diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py index 960bd5997..7b3b1f8a0 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py @@ -11,28 +11,31 @@ import sounddevice # needed to remove ALSA error messages from lasr_speech_recognition_interfaces.srv import TranscribeAudio -from src import ModelCache # type: ignore +from src import ModelCache # type: ignore -MODEL = "medium.en" # Whisper model -TIMEOUT = 5.0 # Timeout for listening for the start of a phrase -PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase +MODEL = "medium.en" # Whisper model +TIMEOUT = 5.0 # Timeout for listening for the start of a phrase +PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase -WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper') +WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") os.makedirs(WHISPER_CACHE, exist_ok=True) os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE if len(sys.argv) < 3: - print('Usage:') - print('ros2 run lasr_speech_recognition transcribe_microphone by-index ') - print('ros2 run lasr_speech_recognition transcribe_microphone by-name ') + print("Usage:") + print( + "ros2 run lasr_speech_recognition transcribe_microphone by-index " + ) + print("ros2 run lasr_speech_recognition transcribe_microphone by-name ") exit(1) else: matcher = sys.argv[1] device_index = None - if matcher == 'by-index': + if matcher == "by-index": device_index = int(sys.argv[2]) - elif matcher == 'by-name': + elif matcher == "by-name": import speech_recognition as sr + microphones = enumerate(sr.Microphone.list_microphone_names()) target_name = sys.argv[2] @@ -40,16 +43,16 @@ if target_name in name: device_index = index break - + if device_index is None: - print('Could not find device!') + print("Could not find device!") exit(1) else: - print('Invalid matcher') + print("Invalid matcher") exit(1) rclpy.init(args=sys.argv) -node = rclpy.create_node('transcribe_mic') +node = rclpy.create_node("transcribe_mic") device = "cuda" if torch.cuda.is_available() else "cpu" model_cache = ModelCache() @@ -57,9 +60,15 @@ # try to run inference on the example file package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") -package_root = 
os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper")) +package_root = os.path.abspath( + os.path.join( + package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" + ) +) example_fp = os.path.join(package_root, "test.m4a") -node.get_logger().info("Running transcription on example file to ensure model is loaded...") +node.get_logger().info( + "Running transcription on example file to ensure model is loaded..." +) transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) node.get_logger().info(str(transcription)) @@ -68,16 +77,25 @@ with microphone as source: r.adjust_for_ambient_noise(source) + def handle_transcribe_audio(_): with microphone as source: - wav_data = r.listen(source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT).get_wav_data() - float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order='C') / 32768.0 + wav_data = r.listen( + source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT + ).get_wav_data() + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) phrase = model.transcribe(float_data, fp16=device == "cuda")["text"] return TranscribeAudio.Response(phrase=phrase) -node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio) + +node.create_service( + TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio +) node.get_logger().info("Whisper service ready") -rclpy.spin(node) \ No newline at end of file +rclpy.spin(node) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py index f8553b0b6..3225072c3 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py @@ -10,36 +10,42 @@ from std_srvs.srv import Empty from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache -WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper') +WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") os.makedirs(WHISPER_CACHE, exist_ok=True) os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE + class TranscribeMicrophone(Node): def __init__(self): - Node.__init__(self, 'transcribe_microphone') + Node.__init__(self, "transcribe_microphone") self.worker = None self.collector = None - self.create_service(Empty, '/whisper/adjust_for_noise', self.adjust_for_noise) - self.create_service(Empty, '/whisper/start_listening', self.start_listening) - self.create_service(Empty, '/whisper/stop_listening', self.stop_listening) + self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise) + self.create_service(Empty, "/whisper/start_listening", self.start_listening) + self.create_service(Empty, "/whisper/stop_listening", self.stop_listening) self.get_logger().info("Starting the Whisper worker!") self.run_transcription() def run_transcription(self): if len(sys.argv) < 3: - print('Usage:') - print('rosrun lasr_speech_recognition transcribe_microphone by-index ') - print('rosrun lasr_speech_recognition transcribe_microphone by-name ') + print("Usage:") + print( + "rosrun lasr_speech_recognition transcribe_microphone by-index " + ) + print( + "rosrun lasr_speech_recognition transcribe_microphone by-name " + ) exit(1) else: 
matcher = sys.argv[1] device_index = None - if matcher == 'by-index': + if matcher == "by-index": device_index = int(sys.argv[2]) - elif matcher == 'by-name': + elif matcher == "by-name": import speech_recognition as sr + microphones = enumerate(sr.Microphone.list_microphone_names()) target_name = sys.argv[2] @@ -49,13 +55,12 @@ def run_transcription(self): break if device_index is None: - print('Could not find device!') + print("Could not find device!") exit(1) else: - print('Invalid matcher') + print("Invalid matcher") exit(1) - self.collector = MicrophonePhraseCollector(device_index=device_index) self.collector.adjust_for_noise() @@ -64,14 +69,24 @@ def run_transcription(self): # try to run inference on the example file package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") - package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper")) + package_root = os.path.abspath( + os.path.join( + package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" + ) + ) example_fp = os.path.join(package_root, "test.m4a") - self.get_logger().info("Running transcription on example file to ensure model is loaded...") - model_transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) + self.get_logger().info( + "Running transcription on example file to ensure model is loaded..." + ) + model_transcription = model.transcribe( + example_fp, fp16=torch.cuda.is_available() + ) self.get_logger().info(str(model_transcription)) - self.worker = SpeechRecognitionToTopic(self.collector, model, "transcription", infer_partial = False) + self.worker = SpeechRecognitionToTopic( + self.collector, model, "transcription", infer_partial=False + ) def adjust_for_noise(self, request, response): self.collector.adjust_for_noise() @@ -85,11 +100,12 @@ def stop_listening(self, request, response): self.worker.stop() return response + def main(args=None): rclpy.init(args=args) transcribe_microphone = TranscribeMicrophone() rclpy.spin(transcribe_microphone) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index e7c307f50..8adf3bd8c 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -18,10 +18,11 @@ from lasr_speech_recognition_interfaces.action import TranscribeSpeech # type: ignore from rclpy.executors import ExternalShutdownException from std_msgs.msg import String # type: ignore -from src import ModelCache # type: ignore +from src import ModelCache # type: ignore # TODO: argpars -> ROS2 params, test behaviour of preemption + @dataclass class speech_model_params: """Class for storing speech recognition model parameters. @@ -58,9 +59,9 @@ class TranscribeSpeechAction(Node): _result = TranscribeSpeech.Result() def __init__( - self, - action_name: str, - model_params: speech_model_params, + self, + action_name: str, + model_params: speech_model_params, ) -> None: """Starts an action server for transcribing speech. 
@@ -126,9 +127,9 @@ def _configure_microphone(self) -> sr.Microphone: ) def _configure_recogniser( - self, - energy_threshold: Optional[float] = None, - pause_threshold: Optional[float] = None, + self, + energy_threshold: Optional[float] = None, + pause_threshold: Optional[float] = None, ) -> sr.Recognizer: """Configures the speech recogniser object. @@ -212,8 +213,8 @@ async def execute_cb(self, goal_handle) -> None: ).get_wav_data() # Magic number 32768.0 is the maximum value of a 16-bit signed integer float_data = ( - np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") - / 32768.0 + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 ) if goal_handle.is_cancel_requested(): @@ -265,7 +266,6 @@ def parse_args() -> dict: # port = node.declare_parameter('port', '/dev/ttyUSB0').value # assert isinstance(port, str), 'port parameter must be a str' - parser.add_argument( "--action_name", type=str, @@ -372,6 +372,7 @@ def configure_whisper_cache() -> None: # Environmental variable required to run whisper locally os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache + def main(args=None): rclpy.init(args=args) @@ -383,4 +384,4 @@ def main(args=None): try: rclpy.spin(server) except (KeyboardInterrupt, ExternalShutdownException): - pass \ No newline at end of file + pass diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml index 825aae036..1cac47617 100644 --- a/common/speech/lasr_speech_recognition_whisper/package.xml +++ b/common/speech/lasr_speech_recognition_whisper/package.xml @@ -1,30 +1,30 @@ - lasr_speech_recognition_whisper - 0.0.0 - Speech recognition implemented using OpenAI Whisper - maayan - MIT + lasr_speech_recognition_whisper + 0.0.0 + Speech recognition implemented using OpenAI Whisper + maayan + MIT - ament_copyright - ament_flake8 - ament_pep257 - python3-pytest + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest - - - lasr_speech_recognition_interfaces - actionlib - actionlib_msgs - actionlib - actionlib_msgs + + + lasr_speech_recognition_interfaces + actionlib + actionlib_msgs + actionlib + actionlib_msgs - - ament_python - requirements.txt - + + ament_python + requirements.txt + diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py index 16ca35d25..a3ce21904 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py @@ -14,5 +14,6 @@ def main(): # print("Available microphone devices (sounddevice):") # print(sounddevice.query_devices()) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py index c77396467..026ab2875 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py @@ -5,16 +5,19 @@ import numpy as np from pathlib import Path import speech_recognition as sr -from src import ModelCache # type: ignore +from src import ModelCache # type: ignore import sounddevice # needed to remove ALSA error messages from typing import Dict import rclpy # TODO argparse -> ROS params + def parse_args() -> 
Dict: parser = argparse.ArgumentParser() - parser.add_argument("--device_index", help="Microphone index", type=int, default=None) + parser.add_argument( + "--device_index", help="Microphone index", type=int, default=None + ) return vars(parser.parse_args()) @@ -67,6 +70,7 @@ def main(args=None): threshold += 100 recognizer.energy_threshold = threshold + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py index e0c94c23e..d14144e21 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py @@ -8,6 +8,7 @@ # TODO argparse -> ROS params + def parse_args() -> dict: """Parse command line arguments into a dictionary. @@ -16,7 +17,9 @@ def parse_args() -> dict: """ parser = argparse.ArgumentParser(description="Test microphones") - parser.add_argument("-m", "--microphone", type=int, help="Microphone index", default=None) + parser.add_argument( + "-m", "--microphone", type=int, help="Microphone index", default=None + ) parser.add_argument( "-o", "--output_dir", type=str, help="Directory to save audio files" ) @@ -64,5 +67,6 @@ def main(args: dict = None) -> None: rclpy.shutdown() + if __name__ == "__main__": main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py index d0670845d..2448e73ec 100755 --- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -7,6 +7,7 @@ # https://docs.ros2.org/latest/api/rclpy/api/actions.html + class TestSpeechServerClient(Node): def __init__(self): Node.__init__(self, "listen_action_client") @@ -20,8 +21,12 @@ def send_goal(self, goal): self.client.wait_for_server() self.get_logger().info("Server activated, sending goal...") - self.goal_future = self.client.send_goal_async(goal, feedback_callback=self.feedback_cb) # Returns a Future instance when the goal request has been accepted or rejected. - self.goal_future.add_done_callback(self.response_cb) # When received get response + self.goal_future = self.client.send_goal_async( + goal, feedback_callback=self.feedback_cb + ) # Returns a Future instance when the goal request has been accepted or rejected. 
+        self.goal_future.add_done_callback(
+            self.response_cb
+        )  # When received get response
 
     def feedback_cb(self, msg):
         self.get_logger().info(f"Received feedback: {msg.feedback}")
@@ -33,13 +38,16 @@ def response_cb(self, future):
             return
 
         self.get_logger().info("Goal was accepted")
-        self.result_future = handle.get_result_async() # Not using get_result() in cb, as can cause deadlock according to docs
+        self.result_future = (
+            handle.get_result_async()
+        )  # Not using get_result() in cb, as can cause deadlock according to docs
         self.result_future.add_done_callback(self.result_cb)
 
     def result_cb(self, future):
         result = future.result().result
         self.get_logger().info(f"Transcribed Speech: {result.sequence}")
 
+
 def main(args=None):
     rclpy.init(args=args)
     while rclpy.ok():
@@ -54,5 +62,6 @@ def main(args=None):
         client.destroy_node()
         rclpy.shutdown()
 
+
 if __name__ == "__main__":
     main()
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg
index 5ec86217a..1f6a54400 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.cfg
+++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg
@@ -1,4 +1,4 @@
 [develop]
-script_dir=$base/lib/lasr_speech_recognition_whisper
+script_dir = $base/lib/lasr_speech_recognition_whisper
 [install]
-install_scripts=$base/lib/lasr_speech_recognition_whisper
+install_scripts = $base/lib/lasr_speech_recognition_whisper
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py
index 3fbac464a..c6a801483 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.py
+++ b/common/speech/lasr_speech_recognition_whisper/setup.py
@@ -1,11 +1,11 @@
 from setuptools import find_packages, setup
 
-package_name = 'lasr_speech_recognition_whisper'
+package_name = "lasr_speech_recognition_whisper"
 
 setup(
     name=package_name,
-    version='0.0.0',
-    packages=find_packages(exclude=['test']),
+    version="0.0.0",
+    packages=find_packages(exclude=["test"]),
     # packages=[package_name, f"{package_name}.lasr_speech_recognition_whisper", f"{package_name}.src"],
     # package_dir={
     #     '': '.',
@@ -14,26 +14,25 @@
 #     f"{package_name}.src": os.path.join(package_name, 'src'),
 #     },
     data_files=[
-        ('share/ament_index/resource_index/packages',
-            ['resource/' + package_name]),
-        ('share/' + package_name, ['package.xml']),
+        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+        ("share/" + package_name, ["package.xml"]),
     ],
-    install_requires=['setuptools'],
+    install_requires=["setuptools"],
     zip_safe=True,
-    maintainer='maayan',
-    maintainer_email='maayan.armony@gmail.com',
-    description='Speech recognition implemented using OpenAI Whisper',
-    license='MIT',
-    tests_require=['pytest'],
+    maintainer="maayan",
+    maintainer_email="maayan.armony@gmail.com",
+    description="Speech recognition implemented using OpenAI Whisper",
+    license="MIT",
+    tests_require=["pytest"],
     entry_points={
-        'console_scripts': [
-            'transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main',
-            'transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main',
-            'simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main',
-            'list_microphones = scripts.list_microphones:main',
-            'microphone_tuning_test = scripts.microphone_tuning_test:main',
-            'test_microphones = scripts.test_microphones:main',
-            'test_speech_server = scripts.test_speech_server:main',
+        "console_scripts": [
+            "transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main",
+            "transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main",
+            "simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main",
+            "list_microphones = scripts.list_microphones:main",
+            "microphone_tuning_test = scripts.microphone_tuning_test:main",
+            "test_microphones = scripts.test_microphones:main",
+            "test_speech_server = scripts.test_speech_server:main",
         ],
     },
 )
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
index 69327473c..372e26477 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
@@ -9,4 +9,4 @@
     SpeechRecognitionToStdout,
     SpeechRecognitionToTopic,
 )
-from .cache import ModelCache
\ No newline at end of file
+from .cache import ModelCache
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
index 7a86f38e6..259ffffa5 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
@@ -6,13 +6,13 @@
 # Keep all loaded models in memory
 MODEL_CACHE = {}
 
+
 class ModelCache(Node):
     def __init__(self):
-        super().__init__('lasr_speech_recognition_whisper_cache')
+        super().__init__("lasr_speech_recognition_whisper_cache")
 
     def load_model(
-        self,
-        name: str, device: str = "cpu", load_test_file: bool = False
+        self, name: str, device: str = "cpu", load_test_file: bool = False
     ) -> whisper.Whisper:
         """Loads a whisper model from disk, or from cache if it has already been loaded.
 
@@ -34,8 +34,17 @@ def load_model(
             MODEL_CACHE[name] = whisper.load_model(name, device=device)
             self.get_logger().info(f"Successfully loaded model {name} on {device}")
             if load_test_file:
-                package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
-                package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+                package_install = packages.get_package_prefix(
+                    "lasr_speech_recognition_whisper"
+                )
+                package_root = os.path.abspath(
+                    os.path.join(
+                        package_install,
+                        os.pardir,
+                        os.pardir,
+                        "lasr_speech_recognition_whisper",
+                    )
+                )
                 example_fp = os.path.join(package_root, "test.m4a")
                 self.get_logger().info(
                     "Running transcription on example file to ensure model is loaded..."
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py index 9edbc313b..d8c5fbea4 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -5,6 +5,7 @@ from queue import Queue from abc import ABC, abstractmethod + class AbstractPhraseCollector(ABC): """ Supertype holding a queue of audio data representing a phrase diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py index abcd0fd1e..e405ca8c0 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py @@ -9,6 +9,7 @@ # TODO rospy.wait_for_message() + class AudioTopic(sr.AudioSource, Node): """ Use a ROS topic as an AudioSource @@ -21,7 +22,9 @@ def __init__(self, topic: str, chunk_size=1024) -> None: Node.__init__(self, "source") self._topic = topic - self.subscription = self.create_subscription(AudioInfo, f"{topic}/audio_info", self.callback, 10) + self.subscription = self.create_subscription( + AudioInfo, f"{topic}/audio_info", self.callback, 10 + ) # config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) self.config = None # TODO test that this works if self.config is not None: @@ -48,7 +51,9 @@ def __enter__(self): self.stream is None ), "This audio source is already inside a context manager" self.stream = BytesFIFO(1024 * 10) # 10 kB buffer - self._sub = self.node.create_subscription(AudioData, f"{self._topic}/audio", self._read) + self._sub = self.node.create_subscription( + AudioData, f"{self._topic}/audio", self._read + ) return self def __exit__(self, exc_type, exc_value, traceback): @@ -57,7 +62,9 @@ def __exit__(self, exc_type, exc_value, traceback): """ self.stream = None - self.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister() + self.destroy_subscription( + self._sub + ) # TODO behaviour, was self._sub.unregister() def _read(self, msg: AudioData) -> None: """ diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py index 998475578..43eac780b 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py @@ -39,7 +39,7 @@ def __init__( maximum_phrase_length=timedelta(seconds=3), infer_partial=True, ) -> None: - Node.__init__(self, 'worker') + Node.__init__(self, "worker") self._collector = collector self._tmp_file = NamedTemporaryFile().name self._model = model diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py index 97a39196e..ceffe896d 100644 --- a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py +++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py @@ -17,9 +17,11 @@ # Remove the `skip` decorator once the source file(s) have a copyright header -@pytest.mark.skip(reason='No 
copyright header has been placed in the generated source file.') +@pytest.mark.skip( + reason="No copyright header has been placed in the generated source file." +) @pytest.mark.copyright @pytest.mark.linter def test_copyright(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found errors' + rc = main(argv=[".", "test"]) + assert rc == 0, "Found errors" diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py index 27ee1078f..ee79f31ac 100644 --- a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py +++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py @@ -20,6 +20,6 @@ @pytest.mark.linter def test_flake8(): rc, errors = main_with_errors(argv=[]) - assert rc == 0, \ - 'Found %d code style errors / warnings:\n' % len(errors) + \ - '\n'.join(errors) + assert rc == 0, "Found %d code style errors / warnings:\n" % len( + errors + ) + "\n".join(errors) diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py index b234a3840..a2c3deb8e 100644 --- a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py +++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py @@ -19,5 +19,5 @@ @pytest.mark.linter @pytest.mark.pep257 def test_pep257(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found code style errors / warnings' + rc = main(argv=[".", "test"]) + assert rc == 0, "Found code style errors / warnings" From 065ffb377eda45800a002a89f903b8c1b0e9a4ad Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony <99334379+maayan25@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:50:15 +0000 Subject: [PATCH 12/14] Remove todo Co-authored-by: Jared Swift --- .../transcribe_microphone_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index 8adf3bd8c..9b6cfc574 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -262,9 +262,6 @@ def parse_args() -> dict: description="Starts an action server for transcribing speech." 
) - # TODO change to ROS2 rosparams: - # port = node.declare_parameter('port', '/dev/ttyUSB0').value - # assert isinstance(port, str), 'port parameter must be a str' parser.add_argument( "--action_name", From 4adee4cc5e24f63f5910cfcdc8fe9bdb7de7a692 Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Tue, 3 Dec 2024 12:33:17 +0000 Subject: [PATCH 13/14] Update files according to review --- common/__init__.py | 0 common/speech/__init__.py | 0 .../README.md | 1 - .../lasr_speech_recognition_whisper/README.md | 1 - .../__init__.py | 0 .../simple_transcribe_microphone.py | 101 --------- .../transcribe_microphone.py | 111 ---------- .../log/COLCON_IGNORE | 0 .../src/__init__.py | 11 - .../__init__.py | 11 - .../bytesfifo.py | 137 ------------ .../collector.py | 131 ----------- .../lasr_speech_recognition_whisper/source.py | 74 ------- .../lasr_speech_recognition_whisper/worker.py | 207 ------------------ 14 files changed, 785 deletions(-) delete mode 100644 common/__init__.py delete mode 100644 common/speech/__init__.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/__init__.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE delete mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py delete mode 100644 common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py diff --git a/common/__init__.py b/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/common/speech/__init__.py b/common/speech/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md index 8e7aab96f..30878a2f4 100644 --- a/common/speech/lasr_speech_recognition_interfaces/README.md +++ b/common/speech/lasr_speech_recognition_interfaces/README.md @@ -5,7 +5,6 @@ Common messages used for speech recognition This package is maintained by: - [Maayan Armony](mailto:maayan.armony@gmail.com) -- [Paul Makles](mailto:me@insrt.uk) (ROS1) ## Prerequisites diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md index c9f58557e..0da406522 100644 --- a/common/speech/lasr_speech_recognition_whisper/README.md +++ b/common/speech/lasr_speech_recognition_whisper/README.md @@ -5,7 +5,6 @@ Speech recognition implemented using OpenAI Whisper This package is maintained by: - [Maayan Armony](mailto:maayan.armony@gmail.com) -- [Paul Makles](mailto:me@insrt.uk) (ROS1) ## Prerequisites diff --git a/common/speech/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py deleted file mode 100644 
index 7b3b1f8a0..000000000 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin python3 -import os -import torch -import rclpy -from ament_index_python import packages - -import sys -from pathlib import Path -import speech_recognition as sr -import numpy as np - -import sounddevice # needed to remove ALSA error messages -from lasr_speech_recognition_interfaces.srv import TranscribeAudio -from src import ModelCache # type: ignore - -MODEL = "medium.en" # Whisper model -TIMEOUT = 5.0 # Timeout for listening for the start of a phrase -PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase - -WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") -os.makedirs(WHISPER_CACHE, exist_ok=True) -os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE - -if len(sys.argv) < 3: - print("Usage:") - print( - "ros2 run lasr_speech_recognition transcribe_microphone by-index " - ) - print("ros2 run lasr_speech_recognition transcribe_microphone by-name ") - exit(1) -else: - matcher = sys.argv[1] - device_index = None - if matcher == "by-index": - device_index = int(sys.argv[2]) - elif matcher == "by-name": - import speech_recognition as sr - - microphones = enumerate(sr.Microphone.list_microphone_names()) - - target_name = sys.argv[2] - for index, name in microphones: - if target_name in name: - device_index = index - break - - if device_index is None: - print("Could not find device!") - exit(1) - else: - print("Invalid matcher") - exit(1) - -rclpy.init(args=sys.argv) -node = rclpy.create_node("transcribe_mic") - -device = "cuda" if torch.cuda.is_available() else "cpu" -model_cache = ModelCache() -model = model_cache.load_model("medium.en", device=device) - -# try to run inference on the example file -package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") -package_root = os.path.abspath( - os.path.join( - package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" - ) -) -example_fp = os.path.join(package_root, "test.m4a") -node.get_logger().info( - "Running transcription on example file to ensure model is loaded..." 
-) -transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) -node.get_logger().info(str(transcription)) - -microphone = sr.Microphone(device_index=device_index, sample_rate=16000) -r = sr.Recognizer() -with microphone as source: - r.adjust_for_ambient_noise(source) - - -def handle_transcribe_audio(_): - with microphone as source: - - wav_data = r.listen( - source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT - ).get_wav_data() - float_data = ( - np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") - / 32768.0 - ) - - phrase = model.transcribe(float_data, fp16=device == "cuda")["text"] - return TranscribeAudio.Response(phrase=phrase) - - -node.create_service( - TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio -) - -node.get_logger().info("Whisper service ready") -rclpy.spin(node) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py deleted file mode 100644 index 3225072c3..000000000 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin python3 -import os -import sys -import torch -from pathlib import Path - -import rclpy -from rclpy.node import Node -from ament_index_python import packages -from std_srvs.srv import Empty -from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache - -WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") -os.makedirs(WHISPER_CACHE, exist_ok=True) -os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE - - -class TranscribeMicrophone(Node): - def __init__(self): - Node.__init__(self, "transcribe_microphone") - self.worker = None - self.collector = None - - self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise) - self.create_service(Empty, "/whisper/start_listening", self.start_listening) - self.create_service(Empty, "/whisper/stop_listening", self.stop_listening) - - self.get_logger().info("Starting the Whisper worker!") - self.run_transcription() - - def run_transcription(self): - if len(sys.argv) < 3: - print("Usage:") - print( - "rosrun lasr_speech_recognition transcribe_microphone by-index " - ) - print( - "rosrun lasr_speech_recognition transcribe_microphone by-name " - ) - exit(1) - else: - matcher = sys.argv[1] - device_index = None - if matcher == "by-index": - device_index = int(sys.argv[2]) - elif matcher == "by-name": - import speech_recognition as sr - - microphones = enumerate(sr.Microphone.list_microphone_names()) - - target_name = sys.argv[2] - for index, name in microphones: - if target_name in name: - device_index = index - break - - if device_index is None: - print("Could not find device!") - exit(1) - else: - print("Invalid matcher") - exit(1) - - self.collector = MicrophonePhraseCollector(device_index=device_index) - self.collector.adjust_for_noise() - - model_cache = ModelCache() - model = model_cache.load_model("medium.en") - - # try to run inference on the example file - package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") - package_root = os.path.abspath( - os.path.join( - package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" - ) - ) - example_fp = os.path.join(package_root, "test.m4a") - - self.get_logger().info( - "Running transcription on example file to ensure model is loaded..." 
- ) - model_transcription = model.transcribe( - example_fp, fp16=torch.cuda.is_available() - ) - self.get_logger().info(str(model_transcription)) - - self.worker = SpeechRecognitionToTopic( - self.collector, model, "transcription", infer_partial=False - ) - - def adjust_for_noise(self, request, response): - self.collector.adjust_for_noise() - return response - - def start_listening(self, request, response): - self.worker.start() - return response - - def stop_listening(self, request, response): - self.worker.stop() - return response - - -def main(args=None): - rclpy.init(args=args) - transcribe_microphone = TranscribeMicrophone() - rclpy.spin(transcribe_microphone) - - -if __name__ == "__main__": - main() diff --git a/common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE b/common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE deleted file mode 100644 index e69de29bb..000000000 diff --git a/common/speech/lasr_speech_recognition_whisper/src/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/__init__.py index 473b206b7..ca8a17393 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/__init__.py +++ b/common/speech/lasr_speech_recognition_whisper/src/__init__.py @@ -1,12 +1 @@ -# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector -from .lasr_speech_recognition_whisper.collector import ( - AbstractPhraseCollector, - MicrophonePhraseCollector, - RecognizerPhraseCollector, -) -from .lasr_speech_recognition_whisper.worker import ( - SpeechRecognitionWorker, - SpeechRecognitionToStdout, - SpeechRecognitionToTopic, -) from .lasr_speech_recognition_whisper.cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py index 372e26477..f662b86a0 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -1,12 +1 @@ -# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector -from .collector import ( - AbstractPhraseCollector, - MicrophonePhraseCollector, - RecognizerPhraseCollector, -) -from .worker import ( - SpeechRecognitionWorker, - SpeechRecognitionToStdout, - SpeechRecognitionToTopic, -) from .cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py deleted file mode 100644 index 1f86b7ffc..000000000 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py +++ /dev/null @@ -1,137 +0,0 @@ -import io - - -class BytesFIFO(object): - """ - A FIFO that can store a fixed number of bytes. - https://github.com/hbock/byte-fifo/blob/master/fifo.py - """ - - def __init__(self, init_size): - """Create a FIFO of ``init_size`` bytes.""" - self._buffer = io.BytesIO(b"\x00" * init_size) - self._size = init_size - self._filled = 0 - self._read_ptr = 0 - self._write_ptr = 0 - - def read(self, size=-1): - """ - Read at most ``size`` bytes from the FIFO. - - If less than ``size`` bytes are available, or ``size`` is negative, - return all remaining bytes. 
- """ - if size < 0: - size = self._filled - - # Go to read pointer - self._buffer.seek(self._read_ptr) - - # Figure out how many bytes we can really read - size = min(size, self._filled) - contig = self._size - self._read_ptr - contig_read = min(contig, size) - - ret = self._buffer.read(contig_read) - self._read_ptr += contig_read - if contig_read < size: - leftover_size = size - contig_read - self._buffer.seek(0) - ret += self._buffer.read(leftover_size) - self._read_ptr = leftover_size - - self._filled -= size - - return ret - - def write(self, data): - """ - Write as many bytes of ``data`` as are free in the FIFO. - - If less than ``len(data)`` bytes are free, write as many as can be written. - Returns the number of bytes written. - """ - free = self.free() - write_size = min(len(data), free) - - if write_size: - contig = self._size - self._write_ptr - contig_write = min(contig, write_size) - # TODO: avoid 0 write - # TODO: avoid copy - # TODO: test performance of above - self._buffer.seek(self._write_ptr) - self._buffer.write(data[:contig_write]) - self._write_ptr += contig_write - - if contig < write_size: - self._buffer.seek(0) - self._buffer.write(data[contig_write:write_size]) - # self._buffer.write(buffer(data, contig_write, write_size - contig_write)) - self._write_ptr = write_size - contig_write - - self._filled += write_size - - return write_size - - def flush(self): - """Flush all data from the FIFO.""" - self._filled = 0 - self._read_ptr = 0 - self._write_ptr = 0 - - def empty(self): - """Return ```True``` if FIFO is empty.""" - return self._filled == 0 - - def full(self): - """Return ``True`` if FIFO is full.""" - return self._filled == self._size - - def free(self): - """Return the number of bytes that can be written to the FIFO.""" - return self._size - self._filled - - def capacity(self): - """Return the total space allocated for this FIFO.""" - return self._size - - def __len__(self): - """Return the amount of data filled in FIFO""" - return self._filled - - def __nonzero__(self): - """Return ```True``` if the FIFO is not empty.""" - return self._filled > 0 - - def resize(self, new_size): - """ - Resize FIFO to contain ``new_size`` bytes. If FIFO currently has - more than ``new_size`` bytes filled, :exc:`ValueError` is raised. - If ``new_size`` is less than 1, :exc:`ValueError` is raised. - - If ``new_size`` is smaller than the current size, the internal - buffer is not contracted (yet). - """ - if new_size < 1: - raise ValueError("Cannot resize to zero or less bytes.") - - if new_size < self._filled: - raise ValueError( - "Cannot contract FIFO to less than {} bytes, " - "or data will be lost.".format(self._filled) - ) - - # original data is non-contiguous. we need to copy old data, - # re-write to the beginning of the buffer, and re-sync - # the read and write pointers. 
- if self._read_ptr >= self._write_ptr: - old_data = self.read(self._filled) - self._buffer.seek(0) - self._buffer.write(old_data) - self._filled = len(old_data) - self._read_ptr = 0 - self._write_ptr = self._filled - - self._size = new_size diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py deleted file mode 100644 index d8c5fbea4..000000000 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py +++ /dev/null @@ -1,131 +0,0 @@ -import rclpy -from rclpy.node import Node -import speech_recognition as sr - -from queue import Queue -from abc import ABC, abstractmethod - - -class AbstractPhraseCollector(ABC): - """ - Supertype holding a queue of audio data representing a phrase - """ - - data: Queue[bytes] = Queue() - - @abstractmethod - def start(self): - """ - Start collecting phrases - """ - pass - - @abstractmethod - def stop(self): - """ - Stop collecting phrases - """ - pass - - @abstractmethod - def sample_rate(self): - """ - Sample rate of the data - """ - pass - - @abstractmethod - def sample_width(self): - """ - Sample width of the data - """ - pass - - -class RecognizerPhraseCollector(AbstractPhraseCollector, Node): - """ - Collect phrases using a SoundRecognition Recognizer - - This will monitor energy levels on the input and only - capture when a certain threshold of activity is met. - """ - - _recorder: sr.Recognizer - _phrase_time_limit: float - - def _record_callback(self, _, audio: sr.AudioData) -> None: - """ - Collect raw audio data from the microphone - """ - self.data.put(audio.get_raw_data()) - - def __init__( - self, energy_threshold: int = 500, phrase_time_limit: float = 2 - ) -> None: - super().__init__("collector") - # Node.__init__(self, "collector") - - self._recorder = sr.Recognizer() - self._recorder.dynamic_energy_threshold = False - self._recorder.energy_threshold = energy_threshold - self._phrase_time_limit = phrase_time_limit - - @abstractmethod - def adjust_for_noise(self, source: sr.AudioSource): - self.get_logger().info("Adjusting for background noise...") - with source: - self._recorder.adjust_for_ambient_noise(source) - - @abstractmethod - def start(self, source: sr.AudioSource): - self.get_logger().info("Started source listen thread") - self._stopper = self._recorder.listen_in_background( - source, self._record_callback, phrase_time_limit=self._phrase_time_limit - ) - - def stop(self): - self._stopper() - - def sample_rate(self): - return self._source.SAMPLE_RATE - - def sample_width(self): - return self._source.SAMPLE_WIDTH - - -class MicrophonePhraseCollector(RecognizerPhraseCollector): - """ - Collect phrases from the default microphone - """ - - _source: sr.Microphone - - def __init__( - self, - energy_threshold: int = 500, - phrase_time_limit: float = 2, - device_index: int = None, - ) -> None: - self._source = sr.Microphone(device_index=device_index, sample_rate=16000) - super().__init__(energy_threshold, phrase_time_limit) - - def adjust_for_noise(self): - return super().adjust_for_noise(self._source) - - def start(self): - return super().start(self._source) - - -# class AudioTopicPhraseCollector(RecognizerPhraseCollector): -# ''' -# Collect phrases from an audio topic -# ''' - -# _source: AudioTopic - -# def __init__(self, topic: str, energy_threshold: int = 100, phrase_time_limit: float = 2) -> None: -# self._source = AudioTopic(topic) -# 
super().__init__(energy_threshold, phrase_time_limit) - -# def start(self): -# return super().start(self._source) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py deleted file mode 100644 index e405ca8c0..000000000 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py +++ /dev/null @@ -1,74 +0,0 @@ -import rclpy -from rclpy.node import Node -import pyaudio -import speech_recognition as sr - -from audio_common_msgs.msg import AudioInfo, AudioData - -from .bytesfifo import BytesFIFO - -# TODO rospy.wait_for_message() - - -class AudioTopic(sr.AudioSource, Node): - """ - Use a ROS topic as an AudioSource - """ - - _topic: str - # _sub: node.create_subscription TODO add type if possible - - def __init__(self, topic: str, chunk_size=1024) -> None: - Node.__init__(self, "source") - - self._topic = topic - self.subscription = self.create_subscription( - AudioInfo, f"{topic}/audio_info", self.callback, 10 - ) - # config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) - self.config = None # TODO test that this works - if self.config is not None: - assert self.config.coding_format == "wave", "Expected Wave audio format" - assert self.config.sample_format == "S16LE", "Expected sample format S16LE" - self.get_logger().info(self.config) - - self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) - self.SAMPLE_RATE = self.config.sample_rate - - self.CHUNK = chunk_size - self.stream = None - - def callback(self, msg): - self.get_logger().info("Message received") - self.config = msg - - def __enter__(self): - """ - Start stream when entering with: block - """ - - assert ( - self.stream is None - ), "This audio source is already inside a context manager" - self.stream = BytesFIFO(1024 * 10) # 10 kB buffer - self._sub = self.node.create_subscription( - AudioData, f"{self._topic}/audio", self._read - ) - return self - - def __exit__(self, exc_type, exc_value, traceback): - """ - Close out stream on exit - """ - - self.stream = None - self.destroy_subscription( - self._sub - ) # TODO behaviour, was self._sub.unregister() - - def _read(self, msg: AudioData) -> None: - """ - Forward raw audio data to queue - """ - - self.stream.write(msg.data) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py deleted file mode 100644 index 43eac780b..000000000 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch - -from rclpy.node import Node -from rclpy.publisher import Publisher - -import whisper -import speech_recognition as sr - -from io import BytesIO -from time import sleep -from threading import Thread -from abc import ABC, abstractmethod -from tempfile import NamedTemporaryFile -from datetime import datetime, timedelta - -from .collector import AbstractPhraseCollector - -from lasr_speech_recognition_interfaces.msg import Transcription - - -class SpeechRecognitionWorker(ABC, Node): - """ - Collect and run inference on phrases to produce a transcription - """ - - _collector: AbstractPhraseCollector - _tmp_file: NamedTemporaryFile - _model: whisper.Whisper - _current_sample: bytes - _phrase_start: datetime - _maximum_phrase_length: timedelta | None - _infer_partial: bool 
- _stopped = True - - def __init__( - self, - collector: AbstractPhraseCollector, - model: whisper.Whisper, - maximum_phrase_length=timedelta(seconds=3), - infer_partial=True, - ) -> None: - Node.__init__(self, "worker") - self._collector = collector - self._tmp_file = NamedTemporaryFile().name - self._model = model - self._current_sample = bytes() - self._phrase_start = None - self._maximum_phrase_length = maximum_phrase_length - self._infer_partial = infer_partial - - @abstractmethod - def on_phrase(self, phrase: str, finished: bool) -> None: - """ - Handle a partial or complete transcription - """ - pass - - def _finish_phrase(self): - """ - Complete the current phrase and clear the sample - """ - - text = self._perform_inference() - if text is not None: - self.on_phrase(text, True) - - self._current_sample = bytes() - self._phrase_start = None - - def _perform_inference(self): - """ - Run inference on the current sample - """ - - self.get_logger().info("Processing sample") - audio_data = sr.AudioData( - self._current_sample, - self._collector.sample_rate(), - self._collector.sample_width(), - ) - wav_data = BytesIO(audio_data.get_wav_data()) - - with open(self._tmp_file, "w+b") as f: - f.write(wav_data.read()) - - self.get_logger().info("Running inference") - try: - result = self._model.transcribe( - self._tmp_file, fp16=torch.cuda.is_available() - ) - except RuntimeError: - return None - text = result["text"].strip() - - # Detect and drop garbage - if len(text) == 0 or text.lower() in [".", "you", "thanks for watching!"]: - self._phrase_start = None - self._current_sample = bytes() - self.get_logger().info("Skipping garbage...") - return None - - return text - - def _worker(self): - """ - Indefinitely perform inference on the given data - """ - - self.get_logger().info("Started inference worker") - - while not self._stopped: - try: - # Check whether the current phrase has timed out - now = datetime.utcnow() - if ( - self._phrase_start - and now - self._phrase_start > self._maximum_phrase_length - ): - self.get_logger().info("Reached timeout for phrase, ending now.") - self._finish_phrase() - - # Start / continue phrase if data is coming in - if not self._collector.data.empty(): - self._phrase_start = datetime.utcnow() - - # Concatenate new data with current sample - while not self._collector.data.empty(): - self._current_sample += self._collector.data.get() - - self.get_logger().info( - "Received and added more data to current audio sample." 
- ) - - # Run inference on partial sample if enabled - if self._infer_partial: - text = self._perform_inference() - - # Handle partial transcription - if text is not None: - self.on_phrase(text, False) - - sleep(0.2) - except KeyboardInterrupt: - self._stopped = True - - self.get_logger().info("Worker finished") - - def start(self): - """ - Start performing inference on incoming data - """ - - assert self._stopped, "Already running inference" - self._stopped = False - self._collector.start() - worker_thread = Thread(target=self._worker) - worker_thread.start() - - def stop(self): - """ - Stop the worker from running inference - """ - - assert not self._stopped, "Not currently running" - self._collector.stop() - self._stopped = True - - # clear next phrase - self._current_sample = bytes() - while not self._collector.data.empty(): - self._current_sample += self._collector.data.get() - - -class SpeechRecognitionToStdout(SpeechRecognitionWorker): - """ - Recognise speech and pass it through to standard output - """ - - def on_phrase(self, phrase: str, finished: bool) -> None: - self.get_logger().info("[" + ("x" if finished else " ") + "] " + phrase) - - -class SpeechRecognitionToTopic(SpeechRecognitionToStdout): - """ - Recognise speech and publish it to a topic - """ - - _pub: Publisher - - def __init__( - self, - collector: AbstractPhraseCollector, - model: whisper.Whisper, - topic: str, - maximum_phrase_length=timedelta(seconds=1), - infer_partial=True, - ) -> None: - super().__init__(collector, model, maximum_phrase_length, infer_partial) - self.get_logger().info(f"Will be publishing transcription to {topic}") - self._pub = self.create_publisher(Transcription, topic, 5) - - def on_phrase(self, phrase: str, finished: bool) -> None: - super().on_phrase(phrase, finished) - msg = Transcription() - msg.phrase = phrase - msg.finished = finished - self._pub.publish(msg) From c8b30c55b3b8c33e99635d6e5b681ba4a0dce67e Mon Sep 17 00:00:00 2001 From: Ma'ayan Armony Date: Tue, 3 Dec 2024 12:35:16 +0000 Subject: [PATCH 14/14] refactor: black format --- .../transcribe_microphone_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py index 9b6cfc574..000678d06 100644 --- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -262,7 +262,6 @@ def parse_args() -> dict: description="Starts an action server for transcribing speech." ) - parser.add_argument( "--action_name", type=str,
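
Porting note: the removed source.py blocked on rospy.wait_for_message() in
ROS 1 and left a TODO for the ROS 2 equivalent. Older rclpy distributions
have no built-in one-shot topic read (newer releases ship
rclpy.wait_for_message), so a small helper is one way to fill the gap. The
sketch below is illustrative rather than part of this package, and assumes
the node is not already being spun by another executor:

    import rclpy
    from rclpy.node import Node
    from rclpy.task import Future


    def wait_for_message(node: Node, msg_type, topic: str, timeout_sec: float = 5.0):
        """Spin until one message arrives on `topic`; return None on timeout."""
        future = Future()
        sub = node.create_subscription(
            msg_type, topic, lambda msg: future.set_result(msg), 10
        )
        rclpy.spin_until_future_complete(node, future, timeout_sec=timeout_sec)
        node.destroy_subscription(sub)  # one-shot read: drop the subscription
        return future.result() if future.done() else None

In AudioTopic.__init__ this would replace the config polling loop, e.g.
config = wait_for_message(self, AudioInfo, f"{topic}/audio_info").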
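A minimal usage sketch for the removed worker API, assuming the old module
layout still existed; the stub collector is hypothetical and only mirrors
the interface the worker actually calls (data, sample_rate, sample_width,
start, stop):

    from queue import Queue

    import rclpy
    import whisper

    # old ROS 1-era layout, removed by this patch series
    from lasr_speech_recognition_whisper.worker import SpeechRecognitionToTopic


    class StubCollector:  # stand-in for an AbstractPhraseCollector subclass
        def __init__(self):
            self.data = Queue()  # raw audio chunks, as bytes

        def sample_rate(self):
            return 16000

        def sample_width(self):
            return 2  # S16LE

        def start(self):
            pass

        def stop(self):
            pass


    rclpy.init()
    model = whisper.load_model("base.en")
    worker = SpeechRecognitionToTopic(StubCollector(), model, "/transcription")
    worker.start()  # spawns the background inference thread
    rclpy.spin(worker)
    worker.stop()
    rclpy.shutdown()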