From f4dc346f06f74cf48337446d37e27bd2fa3390a5 Mon Sep 17 00:00:00 2001 From: Oleg Kachur Date: Tue, 10 Sep 2024 10:48:54 +0000 Subject: [PATCH] Fix gcp text to speech uri fetch - Fix acces to the uri attribute, if it's provided via the RecognitionAudio model. --- .../google/cloud/operators/speech_to_text.py | 17 ++++++++--------- .../google/cloud/operators/translate_speech.py | 15 ++++++++------- .../cloud/operators/test_speech_to_text.py | 6 +++--- .../cloud/operators/test_translate_speech.py | 18 ++++++++++-------- 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/airflow/providers/google/cloud/operators/speech_to_text.py b/airflow/providers/google/cloud/operators/speech_to_text.py index de26a8ba8216f..f8c3e4703f9e0 100644 --- a/airflow/providers/google/cloud/operators/speech_to_text.py +++ b/airflow/providers/google/cloud/operators/speech_to_text.py @@ -113,15 +113,14 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - - FileDetailsLink.persist( - context=context, - task_instance=self, - # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}" - uri=self.audio["uri"][5:], - project_id=self.project_id or hook.project_id, - ) - + if self.audio.uri: + FileDetailsLink.persist( + context=context, + task_instance=self, + # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}" + uri=self.audio.uri[5:], + project_id=self.project_id or hook.project_id, + ) response = hook.recognize_speech( config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout ) diff --git a/airflow/providers/google/cloud/operators/translate_speech.py b/airflow/providers/google/cloud/operators/translate_speech.py index fb3bdccb1abee..b0b540c31e086 100644 --- a/airflow/providers/google/cloud/operators/translate_speech.py +++ b/airflow/providers/google/cloud/operators/translate_speech.py @@ -169,7 +169,14 @@ def execute(self, context: Context) -> dict: raise AirflowException( f"Wrong response '{recognize_dict}' returned - it should contain {key} field" ) - + if self.audio.uri: + FileDetailsLink.persist( + context=context, + task_instance=self, + # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}" + uri=self.audio.uri[5:], + project_id=self.project_id or translate_hook.project_id, + ) try: translation = translate_hook.translate( values=transcript, @@ -179,12 +186,6 @@ def execute(self, context: Context) -> dict: model=self.model, ) self.log.info("Translated output: %s", translation) - FileDetailsLink.persist( - context=context, - task_instance=self, - uri=self.audio["uri"][5:], - project_id=self.project_id or translate_hook.project_id, - ) return translation except ValueError as e: self.log.error("An error has been thrown from translate speech method:") diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py index 51dd6dd8db7c0..7e329500fce81 100644 --- a/tests/providers/google/cloud/operators/test_speech_to_text.py +++ b/tests/providers/google/cloud/operators/test_speech_to_text.py @@ -21,7 +21,7 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.speech_v1 import RecognizeResponse +from google.cloud.speech_v1 import RecognitionAudio, RecognitionConfig, RecognizeResponse from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator @@ -29,8 +29,8 @@ PROJECT_ID = "project-id" GCP_CONN_ID = "gcp-conn-id" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] -CONFIG = {"encoding": "LINEAR16"} -AUDIO = {"uri": "gs://bucket/object"} +CONFIG = RecognitionConfig({"encoding": "LINEAR16"}) +AUDIO = RecognitionAudio({"uri": "gs://bucket/object"}) class TestCloudSpeechToTextRecognizeSpeechOperator: diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py index 6dd000504cef5..f6f2735992e90 100644 --- a/tests/providers/google/cloud/operators/test_translate_speech.py +++ b/tests/providers/google/cloud/operators/test_translate_speech.py @@ -21,6 +21,8 @@ import pytest from google.cloud.speech_v1 import ( + RecognitionAudio, + RecognitionConfig, RecognizeResponse, SpeechRecognitionAlternative, SpeechRecognitionResult, @@ -54,8 +56,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook): ] op = CloudTranslateSpeechOperator( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio=RecognitionAudio({"uri": "gs://bucket/object"}), + config=RecognitionConfig({"encoding": "LINEAR16"}), target_language="pl", format_="text", source_language=None, @@ -77,8 +79,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook): ) mock_speech_hook.return_value.recognize_speech.assert_called_once_with( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio=RecognitionAudio({"uri": "gs://bucket/object"}), + config=RecognitionConfig({"encoding": "LINEAR16"}), ) mock_translate_hook.return_value.translate.assert_called_once_with( @@ -104,8 +106,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook): results=[SpeechRecognitionResult()] ) op = CloudTranslateSpeechOperator( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio=RecognitionAudio({"uri": "gs://bucket/object"}), + config=RecognitionConfig({"encoding": "LINEAR16"}), target_language="pl", format_="text", source_language=None, @@ -128,8 +130,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook): ) mock_speech_hook.return_value.recognize_speech.assert_called_once_with( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio=RecognitionAudio({"uri": "gs://bucket/object"}), + config=RecognitionConfig({"encoding": "LINEAR16"}), ) mock_translate_hook.return_value.translate.assert_not_called()