-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Audio transcription support #781
Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
e5a87ae
Updates, docs stubs
andrewfrench b5e7997
Audio Loader dependencies
andrewfrench 2c20378
Remove task, prefer ToolTask with Client?
andrewfrench 5e73eee
Docs and tests
andrewfrench c055935
Fix typo
andrewfrench 55cbf44
poetry lock --no-update
andrewfrench a9162cb
Tasks, tests, docs
andrewfrench af4d323
TranscriptionClient tests
andrewfrench 3d306db
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench 6715875
Add audio transcription task tests
andrewfrench 47ebb6a
Fix docs
andrewfrench d719edb
poetry run ruff format
andrewfrench ae93421
Update changelog
andrewfrench 91c97ba
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench 211eeae
Evaluate callable input at runtime
andrewfrench dd00a89
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench 34108eb
Fix docs example
andrewfrench 2906a9b
Update changelog
andrewfrench 949792c
Naming
andrewfrench 0e944d0
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench 09b9e24
poetry lock --no-update
andrewfrench 2c0e261
attr -> attrs
andrewfrench 40692b6
Merge branch 'dev' into french/240514/transcription--transcription
andrewfrench 83e8408
Fix integration test
andrewfrench File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
docs/griptape-framework/drivers/audio-transcription-drivers.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
## Overview | ||
|
||
[Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md) extract text from spoken audio. | ||
|
||
This driver acts as a critical bridge between audio transcription Engines and the underlying models, facilitating the construction and execution of API calls that transform speech into editable and searchable text. Utilized predominantly in applications that support the input of verbal communications, the Audio Transcription Driver effectively extracts and interprets speech, rendering it into a textual format that can be easily integrated into data systems and Workflows. | ||
|
||
This capability is essential for enhancing accessibility, improving content discoverability, and automating tasks that traditionally relied on manual transcription, thereby streamlining operations and enhancing efficiency across various industries. | ||
|
||
### OpenAI | ||
|
||
The [OpenAI Audio Transcription Driver](../../reference/griptape/drivers/audio_transcription/openai_audio_transcription_driver.md) utilizes OpenAI's sophisticated `whisper` model to accurately transcribe spoken audio into text. This model supports multiple languages, ensuring precise transcription across a wide range of dialects. | ||
|
||
```python | ||
from griptape.drivers import OpenAiAudioTranscriptionDriver | ||
from griptape.engines import AudioTranscriptionEngine | ||
from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient | ||
from griptape.structures import Agent | ||
|
||
|
||
driver = OpenAiAudioTranscriptionDriver( | ||
model="whisper-1" | ||
) | ||
|
||
tool = AudioTranscriptionClient( | ||
off_prompt=False, | ||
engine=AudioTranscriptionEngine( | ||
audio_transcription_driver=driver, | ||
), | ||
) | ||
|
||
Agent(tools=[tool]).run("Transcribe the following audio file: tests/resources/sentences.wav") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
docs/griptape-tools/official-tools/audio-transcription-client.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# AudioTranscriptionClient | ||
|
||
This Tool enables [Agents](../../griptape-framework/structures/agents.md) to transcribe speech from text using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md). | ||
|
||
```python | ||
from griptape.drivers import OpenAiAudioTranscriptionDriver | ||
from griptape.engines import AudioTranscriptionEngine | ||
from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient | ||
from griptape.structures import Agent | ||
|
||
|
||
driver = OpenAiAudioTranscriptionDriver( | ||
model="whisper-1" | ||
) | ||
|
||
tool = AudioTranscriptionClient( | ||
off_prompt=False, | ||
engine=AudioTranscriptionEngine( | ||
audio_transcription_driver=driver, | ||
), | ||
) | ||
|
||
Agent(tools=[tool]).run("Transcribe the following audio file: /Users/andrew/code/griptape/tests/resources/sentences2.wav") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
42 changes: 42 additions & 0 deletions
42
griptape/drivers/audio_transcription/base_audio_transcription_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import TextArtifact, AudioArtifact | ||
from griptape.events import StartAudioTranscriptionEvent, FinishAudioTranscriptionEvent | ||
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin | ||
|
||
if TYPE_CHECKING: | ||
from griptape.structures import Structure | ||
|
||
|
||
@define | ||
class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): | ||
model: str = field(kw_only=True, metadata={"serializable": True}) | ||
structure: Optional[Structure] = field(default=None, kw_only=True) | ||
|
||
def before_run(self) -> None: | ||
if self.structure: | ||
self.structure.publish_event(StartAudioTranscriptionEvent()) | ||
|
||
def after_run(self) -> None: | ||
if self.structure: | ||
self.structure.publish_event(FinishAudioTranscriptionEvent()) | ||
|
||
def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: | ||
for attempt in self.retrying(): | ||
with attempt: | ||
self.before_run() | ||
result = self.try_run(audio, prompts) | ||
self.after_run() | ||
|
||
return result | ||
|
||
else: | ||
raise Exception("Failed to run audio transcription") | ||
|
||
@abstractmethod | ||
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: ... |
14 changes: 14 additions & 0 deletions
14
griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Optional | ||
|
||
from attrs import define, field | ||
from griptape.artifacts import AudioArtifact, TextArtifact | ||
from griptape.drivers import BaseAudioTranscriptionDriver | ||
from griptape.exceptions import DummyException | ||
|
||
|
||
@define | ||
class DummyAudioTranscriptionDriver(BaseAudioTranscriptionDriver): | ||
model: str = field(init=False) | ||
|
||
def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact: | ||
raise DummyException(__class__.__name__, "try_transcription") | ||
43 changes: 43 additions & 0 deletions
43
griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
import io | ||
from typing import Optional | ||
|
||
import openai | ||
from attrs import field, Factory, define | ||
|
||
from griptape.artifacts import AudioArtifact, TextArtifact | ||
from griptape.drivers import BaseAudioTranscriptionDriver | ||
|
||
|
||
@define | ||
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver): | ||
api_type: str = field(default=openai.api_type, kw_only=True) | ||
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) | ||
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) | ||
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) | ||
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) | ||
client: openai.OpenAI = field( | ||
default=Factory( | ||
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), | ||
takes_self=True, | ||
) | ||
) | ||
|
||
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: | ||
additional_params = {} | ||
|
||
if prompts is not None: | ||
additional_params["prompt"] = ", ".join(prompts) | ||
|
||
transcription = self.client.audio.transcriptions.create( | ||
# Even though we're not actually providing a file to the client, the API still requires that we send a file | ||
# name. We set the file name to use the same format as the audio file so that the API can reject | ||
# it if the format is unsupported. | ||
model=self.model, | ||
file=(f"a.{audio.format}", io.BytesIO(audio.value)), | ||
response_format="json", | ||
**additional_params, | ||
) | ||
|
||
return TextArtifact(value=transcription.text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from attrs import define, field | ||
|
||
from griptape.artifacts import AudioArtifact, TextArtifact | ||
from griptape.drivers import BaseAudioTranscriptionDriver | ||
|
||
|
||
@define | ||
class AudioTranscriptionEngine: | ||
audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True) | ||
|
||
def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: | ||
return self.audio_transcription_driver.try_run(audio) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we name these Drivers
BaseSpeechToTextDriver
for consistency with the inverse Drivers?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I've steered a bit too far in the direction of naming drivers based on their artifact interfaces. I think the specificity of this name is helpful, what do you think about adding a similarly specific name to the
BaseTextToSpeechDriver
?BaseSpeechGenerationDriver
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah thinking about it more I think this current convention is the more "correct" one.
Down to do a rename of
BaseTextToSpeechDriver
(though maybe in a separate PR), what do you think aboutBaseAudioGenerationDriver
? This may extend beyond speech in the future?