Skip to content

Commit

Permalink
Fix GCP speech-to-text URI fetch
Browse files Browse the repository at this point in the history
- Fix access to the uri attribute when the audio is provided via the RecognitionAudio model.
  • Loading branch information
Oleg Kachur committed Sep 17, 2024
1 parent b5576c0 commit f4dc346
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 27 deletions.
17 changes: 8 additions & 9 deletions airflow/providers/google/cloud/operators/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio["uri"][5:],
project_id=self.project_id or hook.project_id,
)

if self.audio.uri:
FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio.uri[5:],
project_id=self.project_id or hook.project_id,
)
response = hook.recognize_speech(
config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout
)
Expand Down
15 changes: 8 additions & 7 deletions airflow/providers/google/cloud/operators/translate_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,14 @@ def execute(self, context: Context) -> dict:
raise AirflowException(
f"Wrong response '{recognize_dict}' returned - it should contain {key} field"
)

if self.audio.uri:
FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio.uri[5:],
project_id=self.project_id or translate_hook.project_id,
)
try:
translation = translate_hook.translate(
values=transcript,
Expand All @@ -179,12 +186,6 @@ def execute(self, context: Context) -> dict:
model=self.model,
)
self.log.info("Translated output: %s", translation)
FileDetailsLink.persist(
context=context,
task_instance=self,
uri=self.audio["uri"][5:],
project_id=self.project_id or translate_hook.project_id,
)
return translation
except ValueError as e:
self.log.error("An error has been thrown from translate speech method:")
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/operators/test_speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.speech_v1 import RecognizeResponse
from google.cloud.speech_v1 import RecognitionAudio, RecognitionConfig, RecognizeResponse

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator

PROJECT_ID = "project-id"
GCP_CONN_ID = "gcp-conn-id"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
CONFIG = {"encoding": "LINEAR16"}
AUDIO = {"uri": "gs://bucket/object"}
CONFIG = RecognitionConfig({"encoding": "LINEAR16"})
AUDIO = RecognitionAudio({"uri": "gs://bucket/object"})


class TestCloudSpeechToTextRecognizeSpeechOperator:
Expand Down
18 changes: 10 additions & 8 deletions tests/providers/google/cloud/operators/test_translate_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import pytest
from google.cloud.speech_v1 import (
RecognitionAudio,
RecognitionConfig,
RecognizeResponse,
SpeechRecognitionAlternative,
SpeechRecognitionResult,
Expand Down Expand Up @@ -54,8 +56,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
]

op = CloudTranslateSpeechOperator(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
target_language="pl",
format_="text",
source_language=None,
Expand All @@ -77,8 +79,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
)

mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
)

mock_translate_hook.return_value.translate.assert_called_once_with(
Expand All @@ -104,8 +106,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
results=[SpeechRecognitionResult()]
)
op = CloudTranslateSpeechOperator(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
target_language="pl",
format_="text",
source_language=None,
Expand All @@ -128,8 +130,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
)

mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
)

mock_translate_hook.return_value.translate.assert_not_called()

0 comments on commit f4dc346

Please sign in to comment.