From 8ec0fb8fb04821a2a0ac8cd3e7ff5fabf9725008 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 00:06:51 +0200 Subject: [PATCH 01/16] adding files and jax folder --- modules/jax/README.rst | 0 modules/jax/testcontainers/__init__.py | 0 modules/jax/tests/test_jax.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 modules/jax/README.rst create mode 100644 modules/jax/testcontainers/__init__.py create mode 100644 modules/jax/tests/test_jax.py diff --git a/modules/jax/README.rst b/modules/jax/README.rst new file mode 100644 index 000000000..e69de29bb diff --git a/modules/jax/testcontainers/__init__.py b/modules/jax/testcontainers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py new file mode 100644 index 000000000..e69de29bb From ea831a85a253047d0fe39db6691c9cf20c045fb6 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 11:21:39 +0200 Subject: [PATCH 02/16] add jax testcontainer and whisperjax folder --- modules/jax/testcontainers/jax/__init__.py | 68 +++++++++++++++++++ .../{ => whisper-jax}/__init__.py | 0 2 files changed, 68 insertions(+) create mode 100644 modules/jax/testcontainers/jax/__init__.py rename modules/jax/testcontainers/{ => whisper-jax}/__init__.py (100%) diff --git a/modules/jax/testcontainers/jax/__init__.py b/modules/jax/testcontainers/jax/__init__.py new file mode 100644 index 000000000..8ed2a576d --- /dev/null +++ b/modules/jax/testcontainers/jax/__init__.py @@ -0,0 +1,68 @@ +import logging +import urllib.request +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready + +class JAXContainer(DockerContainer): + """ + JAX container for GPU-accelerated numerical computing and machine learning. + + Example: + + .. doctest:: + + >>> import jax + >>> from testcontainers.jax import JAXContainer + + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: + ... # Connect to the container + ... jax_container.connect() + ... + ... # Run a simple JAX computation + ... result = jax.numpy.add(1, 1) + ... assert result == 2 + """ + + def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): + super().__init__(image, **kwargs) + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url) + if res.status != 200: + raise Exception(f"Failed to connect to JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to JAX container") + + def get_jupyter_url(self): + """ + Get the URL for accessing the Jupyter notebook server. + """ + return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + + def run_jax_command(self, command): + """ + Run a JAX command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def start(self): + """ + Start the JAX container. + """ + super().start() + logging.info(f"JAX container started. 
Jupyter URL: {self.get_jupyter_url()}") + return self diff --git a/modules/jax/testcontainers/__init__.py b/modules/jax/testcontainers/whisper-jax/__init__.py similarity index 100% rename from modules/jax/testcontainers/__init__.py rename to modules/jax/testcontainers/whisper-jax/__init__.py From 3d9af5070973d236e25081c779b5535b6e7e0d35 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 12:07:08 +0200 Subject: [PATCH 03/16] add whispe-jax --- .../testcontainers/whisper-jax/__init__.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/modules/jax/testcontainers/whisper-jax/__init__.py b/modules/jax/testcontainers/whisper-jax/__init__.py index e69de29bb..eda1c988e 100644 --- a/modules/jax/testcontainers/whisper-jax/__init__.py +++ b/modules/jax/testcontainers/whisper-jax/__init__.py @@ -0,0 +1,126 @@ +import logging +import tempfile +import time +from typing import Optional + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready +from urllib.error import URLError + +class WhisperJAXContainer(DockerContainer): + """ + Whisper-JAX container for fast speech recognition and transcription. + + Example: + + .. doctest:: + + >>> from testcontainers.whisper_jax import WhisperJAXContainer + + >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper: + ... # Connect to the container + ... whisper.connect() + ... + ... # Transcribe an audio file + ... result = whisper.transcribe_file("path/to/audio/file.wav") + ... print(result['text']) + ... + ... # Transcribe a YouTube video + ... result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + ... print(result['text']) + """ + + def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs): + super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) + self.model_name = model_name + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + + # Install required dependencies + self.with_command("sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " + "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "'") + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url) + if res.status != 200: + raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the Whisper-JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to Whisper-JAX container") + + def run_command(self, command: str): + """ + Run a Python command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe an audio file using Whisper-JAX. 
+ """ + command = f""" +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +audio, sr = sf.read("{file_path}") +result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) +print(result) +""" + return self.run_command(command) + + def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe a YouTube video using Whisper-JAX. + """ + command = f""" +import tempfile +import youtube_dl +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +def download_youtube_audio(youtube_url, output_file): + ydl_opts = {{ + 'format': 'bestaudio/best', + 'postprocessors': [{{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'wav', + 'preferredquality': '192', + }}], + 'outtmpl': output_file, + }} + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + ydl.download([youtube_url]) + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) + +with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + download_youtube_audio("{youtube_url}", temp_file.name) + audio, sr = sf.read(temp_file.name) + result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) + print(result) +""" + return self.run_command(command) + + def start(self): + """ + Start the Whisper-JAX container. + """ + super().start() + logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + return self From 75276244e521293fc604e419acee00d2dd8ba508 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 12:15:46 +0200 Subject: [PATCH 04/16] add diarization --- .../whisper-diarization/__init__.py | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 modules/jax/testcontainers/whisper-diarization/__init__.py diff --git a/modules/jax/testcontainers/whisper-diarization/__init__.py b/modules/jax/testcontainers/whisper-diarization/__init__.py new file mode 100644 index 000000000..cabc3a37a --- /dev/null +++ b/modules/jax/testcontainers/whisper-diarization/__init__.py @@ -0,0 +1,249 @@ +import logging +import tempfile +from typing import Optional + +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_container_is_ready +from urllib.error import URLError + +class JAXWhisperDiarizationContainer(DockerContainer): + """ + JAX-Whisper-Diarization container for fast speech recognition, transcription, and speaker diarization. + + Example: + + .. doctest:: + + >>> logging.basicConfig(level=logging.INFO) + + ... # You need to provide your Hugging Face token to use the pyannote.audio models + >>> hf_token = "your_huggingface_token_here" + + >>> with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + ... whisper_diarization.connect() + ... + ... # Example: Transcribe and diarize an audio file + ... result = whisper_diarization.transcribe_and_diarize_file("/path/to/audio/file.wav") + ... print(f"Transcription and Diarization: {result}") + ... + ... # Example: Transcribe and diarize a YouTube video + ... result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + ... 
print(f"YouTube Transcription and Diarization: {result}") + """ + + def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Optional[str] = None, **kwargs): + super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) + self.model_name = model_name + self.hf_token = hf_token + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + + # Install required dependencies + self.with_command("sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " + "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "'") + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url) + if res.status != 200: + raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}") + + def connect(self): + """ + Connect to the JAX-Whisper-Diarization container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to JAX-Whisper-Diarization container") + + def run_command(self, command: str): + """ + Run a Python command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def transcribe_and_diarize_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True): + """ + Transcribe and diarize an audio file using Whisper-JAX and pyannote. 
+ """ + command = f""" +import soundfile as sf +import torch +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp +from pyannote.audio import Pipeline +import numpy as np + +def align(transcription, segments, group_by_speaker=True): + transcription_split = transcription.split("\\n") + transcript = [] + for chunk in transcription_split: + start_end, text = chunk[1:].split("] ") + start, end = start_end.split("->") + start, end = float(start), float(end) + transcript.append({{"timestamp": (start, end), "text": text}}) + + new_segments = [] + prev_segment = segments[0] + for i in range(1, len(segments)): + cur_segment = segments[i] + if cur_segment["label"] != prev_segment["label"]: + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}}, + "speaker": prev_segment["label"] + }}) + prev_segment = segments[i] + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}}, + "speaker": prev_segment["label"] + }}) + + end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript]) + segmented_preds = [] + + for segment in new_segments: + end_time = segment["segment"]["end"] + upto_idx = np.argmin(np.abs(end_timestamps - end_time)) + + if group_by_speaker: + segmented_preds.append({{ + "speaker": segment["speaker"], + "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]), + "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]) + }}) + else: + for i in range(upto_idx + 1): + segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}}) + + transcript = transcript[upto_idx + 1 :] + end_timestamps = end_timestamps[upto_idx + 1 :] + + return segmented_preds + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}") + +audio, sr = sf.read("{file_path}") +inputs = {{"array": audio, "sampling_rate": sr}} + +# Transcribe +result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps}) + +# Diarize +diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}}) +segments = diarization.for_json()["content"] + +# Align transcription and diarization +aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker}) +print(aligned_result) +""" + return self.run_command(command) + + def transcribe_and_diarize_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True): + """ + Transcribe and diarize a YouTube video using Whisper-JAX and pyannote. 
+ """ + command = f""" +import tempfile +import youtube_dl +import soundfile as sf +import torch +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp +from pyannote.audio import Pipeline +import numpy as np + +def download_youtube_audio(youtube_url, output_file): + ydl_opts = {{ + 'format': 'bestaudio/best', + 'postprocessors': [{{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'wav', + 'preferredquality': '192', + }}], + 'outtmpl': output_file, + }} + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + ydl.download([youtube_url]) + +def align(transcription, segments, group_by_speaker=True): + transcription_split = transcription.split("\\n") + transcript = [] + for chunk in transcription_split: + start_end, text = chunk[1:].split("] ") + start, end = start_end.split("->") + start, end = float(start), float(end) + transcript.append({{"timestamp": (start, end), "text": text}}) + + new_segments = [] + prev_segment = segments[0] + for i in range(1, len(segments)): + cur_segment = segments[i] + if cur_segment["label"] != prev_segment["label"]: + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}}, + "speaker": prev_segment["label"] + }}) + prev_segment = segments[i] + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}}, + "speaker": prev_segment["label"] + }}) + + end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript]) + segmented_preds = [] + + for segment in new_segments: + end_time = segment["segment"]["end"] + upto_idx = np.argmin(np.abs(end_timestamps - end_time)) + + if group_by_speaker: + segmented_preds.append({{ + "speaker": segment["speaker"], + "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]), + "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]) + }}) + else: + for i in range(upto_idx + 1): + segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}}) + + transcript = transcript[upto_idx + 1 :] + end_timestamps = end_timestamps[upto_idx + 1 :] + + return segmented_preds + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}") + +with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + download_youtube_audio("{youtube_url}", temp_file.name) + audio, sr = sf.read(temp_file.name) + inputs = {{"array": audio, "sampling_rate": sr}} + + # Transcribe + result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps}) + + # Diarize + diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}}) + segments = diarization.for_json()["content"] + + # Align transcription and diarization + aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker}) + print(aligned_result) +""" + return self.run_command(command) + + def start(self): + """ + Start the JAX-Whisper-Diarization container. + """ + super().start() + logging.info(f"JAX-Whisper-Diarization container started. 
Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + return self From e1b978fb6509276a33164861b33285b5745484aa Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 12:28:14 +0200 Subject: [PATCH 05/16] add tests --- .../whisper-diarization/__init__.py | 4 +-- modules/jax/tests/test_jax.py | 28 +++++++++++++++++ modules/jax/tests/test_whisper.py | 30 ++++++++++++++++++ modules/jax/tests/test_whisper_diarization.py | 31 +++++++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 modules/jax/tests/test_whisper.py create mode 100644 modules/jax/tests/test_whisper_diarization.py diff --git a/modules/jax/testcontainers/whisper-diarization/__init__.py b/modules/jax/testcontainers/whisper-diarization/__init__.py index cabc3a37a..05a897eea 100644 --- a/modules/jax/testcontainers/whisper-diarization/__init__.py +++ b/modules/jax/testcontainers/whisper-diarization/__init__.py @@ -2,8 +2,8 @@ import tempfile from typing import Optional -from testcontainers.core.container import DockerContainer -from testcontainers.core.waiting_utils import wait_container_is_ready +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready from urllib.error import URLError class JAXWhisperDiarizationContainer(DockerContainer): diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py index e69de29bb..635a37ca6 100644 --- a/modules/jax/tests/test_jax.py +++ b/modules/jax/tests/test_jax.py @@ -0,0 +1,28 @@ +import pytest +from testcontainers.jax import JAXContainer + +def test_jax_container(): + with JAXContainer() as jax_container: + jax_container.connect() + + # Test running a simple JAX computation + result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") + assert "2" in result.output.decode() + +def test_jax_container_gpu_support(): + with JAXContainer() as jax_container: + jax_container.connect() + + # Test GPU availability + result = jax_container.run_jax_command( + "import jax; print(jax.devices())" + ) + assert "gpu" in result.output.decode().lower() + +def test_jax_container_jupyter(): + with JAXContainer() as jax_container: + jax_container.connect() + + jupyter_url = jax_container.get_jupyter_url() + assert jupyter_url.startswith("http://") + assert ":8888" in jupyter_url \ No newline at end of file diff --git a/modules/jax/tests/test_whisper.py b/modules/jax/tests/test_whisper.py new file mode 100644 index 000000000..7e9d320b3 --- /dev/null +++ b/modules/jax/tests/test_whisper.py @@ -0,0 +1,30 @@ +import pytest +from testcontainers.whisper_jax import WhisperJAXContainer + +@pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"]) +def test_whisper_jax_container(model_name): + with WhisperJAXContainer(model_name) as whisper: + whisper.connect() + + # Test file transcription + result = whisper.transcribe_file("/path/to/test/audio.wav") + assert isinstance(result, dict) + assert 'text' in result + assert isinstance(result['text'], str) + + # Test YouTube transcription + result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + assert isinstance(result, dict) + assert 'text' in result + assert isinstance(result['text'], str) + +def test_whisper_jax_container_with_timestamps(): + with WhisperJAXContainer() as whisper: + whisper.connect() + + result = whisper.transcribe_file("/path/to/test/audio.wav", return_timestamps=True) + assert isinstance(result, dict) + assert 'text' 
in result + assert 'chunks' in result + assert isinstance(result['chunks'], list) + assert all('timestamp' in chunk for chunk in result['chunks']) \ No newline at end of file diff --git a/modules/jax/tests/test_whisper_diarization.py b/modules/jax/tests/test_whisper_diarization.py new file mode 100644 index 000000000..36924ffb2 --- /dev/null +++ b/modules/jax/tests/test_whisper_diarization.py @@ -0,0 +1,31 @@ +import pytest +from testcontainers.jax_whisper_diarization import JAXWhisperDiarizationContainer + +@pytest.fixture(scope="module") +def hf_token(): + return "your_huggingface_token_here" # Replace with a valid token or use an environment variable + +def test_jax_whisper_diarization_container(hf_token): + with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + whisper_diarization.connect() + + # Test file transcription and diarization + result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav") + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) + + # Test YouTube transcription and diarization + result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) + +def test_jax_whisper_diarization_container_without_grouping(hf_token): + with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + whisper_diarization.connect() + + result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav", group_by_speaker=False) + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) \ No newline at end of file From e60645957590f2934297d84047b39ebfd63a8506 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 12:38:17 +0200 Subject: [PATCH 06/16] fix folder name --- modules/jax/tests/{test_whisper.py => test_whisper_jax.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/jax/tests/{test_whisper.py => test_whisper_jax.py} (100%) diff --git a/modules/jax/tests/test_whisper.py b/modules/jax/tests/test_whisper_jax.py similarity index 100% rename from modules/jax/tests/test_whisper.py rename to modules/jax/tests/test_whisper_jax.py From 2a650d2fe8dded6edadfd031c2e3f5dcec2e0352 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 17:27:39 +0200 Subject: [PATCH 07/16] add readme draft --- modules/jax/README.rst | 36 +++++ .../testcontainers/whisper-jax/__init__.py | 127 +----------------- pyproject.toml | 1 + 3 files changed, 38 insertions(+), 126 deletions(-) diff --git a/modules/jax/README.rst b/modules/jax/README.rst index e69de29bb..b7ece8e4d 100644 --- a/modules/jax/README.rst +++ b/modules/jax/README.rst @@ -0,0 +1,36 @@ +# Testcontainers : JAX + +## Docker Containers for JAX with GPU Support + +1. **Official JAX Docker Container** + - **Container**: `jax/jax:cuda-12.0` + - **Documentation**: [JAX Docker](https://github.com/google/jax/blob/main/docker/README.md) + +2. **NVIDIA Docker Container** + - **Container**: `nvidia/cuda:12.0-cudnn8-devel-ubuntu20.04` + - **Documentation**: [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda) + +## Benefits of Having This Container + +1. 
**Optimized Performance**: JAX uses XLA to compile and run NumPy programs on GPUs, which can significantly speed up numerical computations and machine learning tasks. A container specifically optimized for JAX with CUDA ensures that the environment is configured to leverage GPU acceleration fully. + +2. **Reproducibility**: Containers encapsulate all dependencies, libraries, and configurations needed to run JAX, ensuring that the environment is consistent across different systems. This is crucial for reproducible research and development. + +3. **Ease of Use**: Users can easily pull and run the container without worrying about the complex setup required for GPU support and JAX configuration. This reduces the barrier to entry for new users and accelerates development workflows. + +4. **Isolation and Security**: Containers provide an isolated environment, which enhances security by limiting the impact of potential vulnerabilities. It also avoids conflicts with other software on the host system. + +## Relevant Reading Material + +1. **JAX Documentation** + - [JAX Quickstart](https://github.com/google/jax#quickstart) + - [JAX Transformations](https://github.com/google/jax#transformations) + - [JAX Installation Guide](https://github.com/google/jax#installation) + +2. **NVIDIA Docker Documentation** + - [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda) + - [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) + +3. **Docker Best Practices** + - [Docker Documentation](https://docs.docker.com/get-started/) + - [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/) \ No newline at end of file diff --git a/modules/jax/testcontainers/whisper-jax/__init__.py b/modules/jax/testcontainers/whisper-jax/__init__.py index eda1c988e..0519ecba6 100644 --- a/modules/jax/testcontainers/whisper-jax/__init__.py +++ b/modules/jax/testcontainers/whisper-jax/__init__.py @@ -1,126 +1 @@ -import logging -import tempfile -import time -from typing import Optional - -from core.testcontainers.core.container import DockerContainer -from core.testcontainers.core.waiting_utils import wait_container_is_ready -from urllib.error import URLError - -class WhisperJAXContainer(DockerContainer): - """ - Whisper-JAX container for fast speech recognition and transcription. - - Example: - - .. doctest:: - - >>> from testcontainers.whisper_jax import WhisperJAXContainer - - >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper: - ... # Connect to the container - ... whisper.connect() - ... - ... # Transcribe an audio file - ... result = whisper.transcribe_file("path/to/audio/file.wav") - ... print(result['text']) - ... - ... # Transcribe a YouTube video - ... result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") - ... 
print(result['text']) - """ - - def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs): - super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) - self.model_name = model_name - self.with_exposed_ports(8888) # Expose Jupyter notebook port - self.with_env("NVIDIA_VISIBLE_DEVICES", "all") - self.with_env("CUDA_VISIBLE_DEVICES", "all") - self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support - - # Install required dependencies - self.with_command("sh -c '" - "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " - "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " - "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " - "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" - "'") - - @wait_container_is_ready(URLError) - def _connect(self): - url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" - res = urllib.request.urlopen(url) - if res.status != 200: - raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}") - - def connect(self): - """ - Connect to the Whisper-JAX container and ensure it's ready. - """ - self._connect() - logging.info("Successfully connected to Whisper-JAX container") - - def run_command(self, command: str): - """ - Run a Python command inside the container. - """ - exec_result = self.exec(f"python -c '{command}'") - return exec_result - - def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False): - """ - Transcribe an audio file using Whisper-JAX. - """ - command = f""" -import soundfile as sf -from whisper_jax import FlaxWhisperPipline -import jax.numpy as jnp - -pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) -audio, sr = sf.read("{file_path}") -result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) -print(result) -""" - return self.run_command(command) - - def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False): - """ - Transcribe a YouTube video using Whisper-JAX. - """ - command = f""" -import tempfile -import youtube_dl -import soundfile as sf -from whisper_jax import FlaxWhisperPipline -import jax.numpy as jnp - -def download_youtube_audio(youtube_url, output_file): - ydl_opts = {{ - 'format': 'bestaudio/best', - 'postprocessors': [{{ - 'key': 'FFmpegExtractAudio', - 'preferredcodec': 'wav', - 'preferredquality': '192', - }}], - 'outtmpl': output_file, - }} - with youtube_dl.YoutubeDL(ydl_opts) as ydl: - ydl.download([youtube_url]) - -pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) - -with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: - download_youtube_audio("{youtube_url}", temp_file.name) - audio, sr = sf.read(temp_file.name) - result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) - print(result) -""" - return self.run_command(command) - - def start(self): - """ - Start the Whisper-JAX container. - """ - super().start() - logging.info(f"Whisper-JAX container started. 
Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") - return self + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3bccf8800..ddc9eb96f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ packages = [ { include = "testcontainers", from = "modules/test_module_import"}, { include = "testcontainers", from = "modules/google" }, { include = "testcontainers", from = "modules/influxdb" }, + { include = "testcontainers", from = "modules/jax" }, { include = "testcontainers", from = "modules/k3s" }, { include = "testcontainers", from = "modules/kafka" }, { include = "testcontainers", from = "modules/keycloak" }, From b701be0e46b8b6d521b89e2cd50cd4355fc454e3 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 18:24:37 +0200 Subject: [PATCH 08/16] add installation instructions --- modules/jax/README.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/modules/jax/README.rst b/modules/jax/README.rst index b7ece8e4d..c1894c5ba 100644 --- a/modules/jax/README.rst +++ b/modules/jax/README.rst @@ -20,6 +20,28 @@ 4. **Isolation and Security**: Containers provide an isolated environment, which enhances security by limiting the impact of potential vulnerabilities. It also avoids conflicts with other software on the host system. +## Troubleshooting + +**Ensure Docker is configured to use the NVIDIA runtime**: + - You need to install the NVIDIA Container Toolkit. Follow the instructions for your operating system: [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). + - Update your Docker daemon configuration to include the NVIDIA runtime. Edit the Docker daemon configuration file, typically located at `/etc/docker/daemon.json`, to include the following: + + ```json + { + "runtimes": { + "nvidia": { + "path": "nvidia-container-runtime", + "runtimeArgs": [] + } + } + } + ``` + + - Restart the Docker daemon to apply the changes: + ```sh + sudo systemctl restart docker + ``` + ## Relevant Reading Material 1. **JAX Documentation** From 0cdddf8f9c41a29f2d324820eb09aaca6139094d Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 18:30:27 +0200 Subject: [PATCH 09/16] add jax to pyproject.toml --- modules/jax/testcontainers/jax/__init__.py | 25 +++++++++++++--------- pyproject.toml | 1 + 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/modules/jax/testcontainers/jax/__init__.py b/modules/jax/testcontainers/jax/__init__.py index 8ed2a576d..e9844336a 100644 --- a/modules/jax/testcontainers/jax/__init__.py +++ b/modules/jax/testcontainers/jax/__init__.py @@ -11,18 +11,23 @@ class JAXContainer(DockerContainer): Example: - .. doctest:: + .. doctest:: - >>> import jax - >>> from testcontainers.jax import JAXContainer + >>> import jax + >>> from testcontainers.jax import JAXContainer - >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: - ... # Connect to the container - ... jax_container.connect() - ... - ... # Run a simple JAX computation - ... result = jax.numpy.add(1, 1) - ... assert result == 2 + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: + ... # Connect to the container + ... jax_container.connect() + ... + ... # Run a simple JAX computation + ... result = jax.numpy.add(1, 1) + ... assert result == 2 + + .. 
auto-class:: JAXContainer + :members: + :undoc-members: + :show-inheritance: """ def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): diff --git a/pyproject.toml b/pyproject.toml index ddc9eb96f..44c5f23e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ neo4j = ["neo4j"] nginx = [] opensearch = ["opensearch-py"] ollama = [] +jax = ["jax"] oracle = ["sqlalchemy", "oracledb"] oracle-free = ["sqlalchemy", "oracledb"] postgres = [] From 06e7c150fae14aeec0e004d9abb700463d180028 Mon Sep 17 00:00:00 2001 From: Tonic Date: Sun, 4 Aug 2024 23:25:04 +0200 Subject: [PATCH 10/16] increase timeout and use get_logs, refactor whisper-cuda into folder --- .../{jax => jax_cuda}/__init__.py | 25 +++- .../whisper-diarization/__init__.py | 0 .../whisper-transcription/__init__.py | 126 ++++++++++++++++++ .../testcontainers/whisper-jax/__init__.py | 1 - modules/jax/tests/test_jax.py | 2 +- 5 files changed, 150 insertions(+), 4 deletions(-) rename modules/jax/testcontainers/{jax => jax_cuda}/__init__.py (74%) rename modules/jax/testcontainers/{ => whisper-cuda}/whisper-diarization/__init__.py (100%) create mode 100644 modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py delete mode 100644 modules/jax/testcontainers/whisper-jax/__init__.py diff --git a/modules/jax/testcontainers/jax/__init__.py b/modules/jax/testcontainers/jax_cuda/__init__.py similarity index 74% rename from modules/jax/testcontainers/jax/__init__.py rename to modules/jax/testcontainers/jax_cuda/__init__.py index e9844336a..6b01ac96d 100644 --- a/modules/jax/testcontainers/jax/__init__.py +++ b/modules/jax/testcontainers/jax_cuda/__init__.py @@ -4,6 +4,8 @@ from core.testcontainers.core.container import DockerContainer from core.testcontainers.core.waiting_utils import wait_container_is_ready +from core.testcontainers.core.config import testcontainers_config +from core.testcontainers.core.waiting_utils import wait_for_logs class JAXContainer(DockerContainer): """ @@ -36,11 +38,12 @@ def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): self.with_env("NVIDIA_VISIBLE_DEVICES", "all") self.with_env("CUDA_VISIBLE_DEVICES", "all") self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes @wait_container_is_ready(URLError) def _connect(self): url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" - res = urllib.request.urlopen(url) + res = urllib.request.urlopen(url, timeout=self.start_timeout) if res.status != 200: raise Exception(f"Failed to connect to JAX container. Status: {res.status}") @@ -64,10 +67,28 @@ def run_jax_command(self, command): exec_result = self.exec(f"python -c '{command}'") return exec_result + def _wait_for_container_to_be_ready(self): + wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout) + def start(self): """ - Start the JAX container. + Start the JAX container and wait for it to be ready. """ super().start() + self._wait_for_container_to_be_ready() logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}") return self + + def stop(self, force=True): + """ + Stop the JAX container. + """ + super().stop(force) + logging.info("JAX container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. 
+ """ + return self.start_timeout \ No newline at end of file diff --git a/modules/jax/testcontainers/whisper-diarization/__init__.py b/modules/jax/testcontainers/whisper-cuda/whisper-diarization/__init__.py similarity index 100% rename from modules/jax/testcontainers/whisper-diarization/__init__.py rename to modules/jax/testcontainers/whisper-cuda/whisper-diarization/__init__.py diff --git a/modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py b/modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py new file mode 100644 index 000000000..eda1c988e --- /dev/null +++ b/modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py @@ -0,0 +1,126 @@ +import logging +import tempfile +import time +from typing import Optional + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready +from urllib.error import URLError + +class WhisperJAXContainer(DockerContainer): + """ + Whisper-JAX container for fast speech recognition and transcription. + + Example: + + .. doctest:: + + >>> from testcontainers.whisper_jax import WhisperJAXContainer + + >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper: + ... # Connect to the container + ... whisper.connect() + ... + ... # Transcribe an audio file + ... result = whisper.transcribe_file("path/to/audio/file.wav") + ... print(result['text']) + ... + ... # Transcribe a YouTube video + ... result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + ... print(result['text']) + """ + + def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs): + super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) + self.model_name = model_name + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + + # Install required dependencies + self.with_command("sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " + "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "'") + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url) + if res.status != 200: + raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the Whisper-JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to Whisper-JAX container") + + def run_command(self, command: str): + """ + Run a Python command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe an audio file using Whisper-JAX. 
+ """ + command = f""" +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +audio, sr = sf.read("{file_path}") +result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) +print(result) +""" + return self.run_command(command) + + def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe a YouTube video using Whisper-JAX. + """ + command = f""" +import tempfile +import youtube_dl +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +def download_youtube_audio(youtube_url, output_file): + ydl_opts = {{ + 'format': 'bestaudio/best', + 'postprocessors': [{{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'wav', + 'preferredquality': '192', + }}], + 'outtmpl': output_file, + }} + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + ydl.download([youtube_url]) + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) + +with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + download_youtube_audio("{youtube_url}", temp_file.name) + audio, sr = sf.read(temp_file.name) + result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) + print(result) +""" + return self.run_command(command) + + def start(self): + """ + Start the Whisper-JAX container. + """ + super().start() + logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + return self diff --git a/modules/jax/testcontainers/whisper-jax/__init__.py b/modules/jax/testcontainers/whisper-jax/__init__.py deleted file mode 100644 index 0519ecba6..000000000 --- a/modules/jax/testcontainers/whisper-jax/__init__.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py index 635a37ca6..90329259f 100644 --- a/modules/jax/tests/test_jax.py +++ b/modules/jax/tests/test_jax.py @@ -1,5 +1,5 @@ import pytest -from testcontainers.jax import JAXContainer +from modules.jax.testcontainers.jax_cuda import JAXContainer def test_jax_container(): with JAXContainer() as jax_container: From 0b4033ec064b921c8ec4a7e79f9ed6eda45bac6a Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 5 Aug 2024 10:06:36 +0200 Subject: [PATCH 11/16] add huggingface amd jax --- .../jax/testcontainers/jax_amd/__init__.py | 94 +++++++++++++++++++ .../whisper_diarization}/__init__.py | 0 .../whisper_transcription}/__init__.py | 0 modules/jax/tests/test_whisper_diarization.py | 2 +- modules/jax/tests/test_whisper_jax.py | 2 +- 5 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 modules/jax/testcontainers/jax_amd/__init__.py rename modules/jax/testcontainers/{whisper-cuda/whisper-diarization => whisper_cuda/whisper_diarization}/__init__.py (100%) rename modules/jax/testcontainers/{whisper-cuda/whisper-transcription => whisper_cuda/whisper_transcription}/__init__.py (100%) diff --git a/modules/jax/testcontainers/jax_amd/__init__.py b/modules/jax/testcontainers/jax_amd/__init__.py new file mode 100644 index 000000000..9929cd6bf --- /dev/null +++ b/modules/jax/testcontainers/jax_amd/__init__.py @@ -0,0 +1,94 @@ +import logging +import urllib.request +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from 
core.testcontainers.core.waiting_utils import wait_container_is_ready +from core.testcontainers.core.config import testcontainers_config +from core.testcontainers.core.waiting_utils import wait_for_logs + +class JAXContainer(DockerContainer): + """ + JAX container for GPU-accelerated numerical computing and machine learning. + + Example: + + .. doctest:: + + >>> import jax + >>> from testcontainers.jax import JAXContainer + + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: + ... # Connect to the container + ... jax_container.connect() + ... + ... # Run a simple JAX computation + ... result = jax.numpy.add(1, 1) + ... assert result == 2 + + .. auto-class:: JAXContainer + :members: + :undoc-members: + :show-inheritance: + """ + + def __init__(self, image="huggingface/transformers-jax-light:latest", **kwargs): + super().__init__(image, **kwargs) + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url, timeout=self.start_timeout) + if res.status != 200: + raise Exception(f"Failed to connect to JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to JAX container") + + def get_jupyter_url(self): + """ + Get the URL for accessing the Jupyter notebook server. + """ + return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + + def run_jax_command(self, command): + """ + Run a JAX command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def _wait_for_container_to_be_ready(self): + wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout) + + def start(self): + """ + Start the JAX container and wait for it to be ready. + """ + super().start() + self._wait_for_container_to_be_ready() + logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}") + return self + + def stop(self, force=True): + """ + Stop the JAX container. + """ + super().stop(force) + logging.info("JAX container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. 
+ """ + return self.start_timeout \ No newline at end of file diff --git a/modules/jax/testcontainers/whisper-cuda/whisper-diarization/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py similarity index 100% rename from modules/jax/testcontainers/whisper-cuda/whisper-diarization/__init__.py rename to modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py diff --git a/modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py similarity index 100% rename from modules/jax/testcontainers/whisper-cuda/whisper-transcription/__init__.py rename to modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py diff --git a/modules/jax/tests/test_whisper_diarization.py b/modules/jax/tests/test_whisper_diarization.py index 36924ffb2..f31ce5f3b 100644 --- a/modules/jax/tests/test_whisper_diarization.py +++ b/modules/jax/tests/test_whisper_diarization.py @@ -1,5 +1,5 @@ import pytest -from testcontainers.jax_whisper_diarization import JAXWhisperDiarizationContainer +from modules.jax.testcontainers.whisper_cuda.whisper_diarization import JAXWhisperDiarizationContainer @pytest.fixture(scope="module") def hf_token(): diff --git a/modules/jax/tests/test_whisper_jax.py b/modules/jax/tests/test_whisper_jax.py index 7e9d320b3..12534cce2 100644 --- a/modules/jax/tests/test_whisper_jax.py +++ b/modules/jax/tests/test_whisper_jax.py @@ -1,5 +1,5 @@ import pytest -from testcontainers.whisper_jax import WhisperJAXContainer +from modules.jax.testcontainers.whisper_cuda.whisper_transcription import WhisperJAXContainer @pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"]) def test_whisper_jax_container(model_name): From 54f842d54b6cb4cdda2c896fe23e3744163d9b33 Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 5 Aug 2024 20:59:48 +0200 Subject: [PATCH 12/16] add connect method , remove jupyter port connect --- .../jax/testcontainers/jax_cuda/__init__.py | 59 ++++++++++++------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/modules/jax/testcontainers/jax_cuda/__init__.py b/modules/jax/testcontainers/jax_cuda/__init__.py index 6b01ac96d..c28de41c2 100644 --- a/modules/jax/testcontainers/jax_cuda/__init__.py +++ b/modules/jax/testcontainers/jax_cuda/__init__.py @@ -1,11 +1,9 @@ import logging -import urllib.request +import time from urllib.error import URLError from core.testcontainers.core.container import DockerContainer -from core.testcontainers.core.waiting_utils import wait_container_is_ready -from core.testcontainers.core.config import testcontainers_config -from core.testcontainers.core.waiting_utils import wait_for_logs +from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs class JAXContainer(DockerContainer): """ @@ -15,7 +13,6 @@ class JAXContainer(DockerContainer): .. doctest:: - >>> import jax >>> from testcontainers.jax import JAXContainer >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: @@ -23,8 +20,8 @@ class JAXContainer(DockerContainer): ... jax_container.connect() ... ... # Run a simple JAX computation - ... result = jax.numpy.add(1, 1) - ... assert result == 2 + ... result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") + ... assert "2" in result.output .. 
auto-class:: JAXContainer :members: @@ -34,31 +31,49 @@ class JAXContainer(DockerContainer): def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): super().__init__(image, **kwargs) - self.with_exposed_ports(8888) # Expose Jupyter notebook port self.with_env("NVIDIA_VISIBLE_DEVICES", "all") self.with_env("CUDA_VISIBLE_DEVICES", "all") self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support - self.start_timeout = 600 # 10 minutes + self.start_timeout = 600 # 10 minutes + self.connection_retries = 5 + self.connection_retry_delay = 10 # seconds @wait_container_is_ready(URLError) def _connect(self): - url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" - res = urllib.request.urlopen(url, timeout=self.start_timeout) - if res.status != 200: - raise Exception(f"Failed to connect to JAX container. Status: {res.status}") + for attempt in range(self.connection_retries): + try: + # Check if JAX is properly installed and functioning + result = self.run_jax_command( + "import jax; import jaxlib; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'JAXlib version: {jaxlib.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output and "Available devices" in result.output: + logging.info(f"JAX environment verified:\n{result.output}") + return True + else: + raise Exception("JAX environment check failed") + + except Exception as e: + if attempt < self.connection_retries - 1: + logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...") + time.sleep(self.connection_retry_delay) + else: + raise Exception(f"Failed to connect to JAX container after {self.connection_retries} attempts: {str(e)}") + + return False def connect(self): """ Connect to the JAX container and ensure it's ready. + This method verifies that JAX is properly installed and functioning. + It also checks for available devices, including GPUs if applicable. """ self._connect() - logging.info("Successfully connected to JAX container") - - def get_jupyter_url(self): - """ - Get the URL for accessing the Jupyter notebook server. - """ - return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + logging.info("Successfully connected to JAX container and verified the environment") def run_jax_command(self, command): """ @@ -68,7 +83,7 @@ def run_jax_command(self, command): return exec_result def _wait_for_container_to_be_ready(self): - wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout) + wait_for_logs(self, "JAX is ready", timeout=self.start_timeout) def start(self): """ @@ -76,7 +91,7 @@ def start(self): """ super().start() self._wait_for_container_to_be_ready() - logging.info(f"JAX container started. 
Jupyter URL: {self.get_jupyter_url()}") + logging.info("JAX container started and ready.") return self def stop(self, force=True): From 8dabef01a5a6c56c3a0913def24f505cdc6aa336 Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 5 Aug 2024 21:26:53 +0200 Subject: [PATCH 13/16] add test for jaxcontainer --- modules/jax/tests/test_jax.py | 61 ++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py index 90329259f..c2138e8b7 100644 --- a/modules/jax/tests/test_jax.py +++ b/modules/jax/tests/test_jax.py @@ -1,28 +1,43 @@ import pytest from modules.jax.testcontainers.jax_cuda import JAXContainer -def test_jax_container(): - with JAXContainer() as jax_container: - jax_container.connect() - - # Test running a simple JAX computation - result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") - assert "2" in result.output.decode() +@pytest.fixture(scope="module") +def jax_container(): + with JAXContainer() as container: + container.connect() + yield container -def test_jax_container_gpu_support(): - with JAXContainer() as jax_container: - jax_container.connect() - - # Test GPU availability - result = jax_container.run_jax_command( - "import jax; print(jax.devices())" - ) - assert "gpu" in result.output.decode().lower() +def test_jax_container_basic_computation(jax_container): + result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") + assert "2" in result.output.decode(), "Basic JAX computation failed" -def test_jax_container_jupyter(): - with JAXContainer() as jax_container: - jax_container.connect() - - jupyter_url = jax_container.get_jupyter_url() - assert jupyter_url.startswith("http://") - assert ":8888" in jupyter_url \ No newline at end of file +def test_jax_container_version(jax_container): + result = jax_container.run_jax_command("import jax; print(jax.__version__)") + assert result.exit_code == 0, "Failed to get JAX version" + assert result.output.decode().strip(), "JAX version is empty" + +def test_jax_container_gpu_support(jax_container): + result = jax_container.run_jax_command( + "import jax; devices = jax.devices(); " + "print(any(dev.platform == 'gpu' for dev in devices))" + ) + assert "True" in result.output.decode(), "No GPU device found" + +def test_jax_container_matrix_multiplication(jax_container): + command = """ +import jax +import jax.numpy as jnp +x = jnp.array([[1, 2], [3, 4]]) +y = jnp.array([[5, 6], [7, 8]]) +result = jnp.dot(x, y) +print(result) + """ + result = jax_container.run_jax_command(command) + assert "[[19 22]\n [43 50]]" in result.output.decode(), "Matrix multiplication failed" + +def test_jax_container_custom_image(): + custom_image = "nvcr.io/nvidia/jax:23.09-py3" + with JAXContainer(image=custom_image) as container: + container.connect() + result = container.run_jax_command("import jax; print(jax.__version__)") + assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}" \ No newline at end of file From 62443464332c05833b8c2de3a7ec0e5028ca98ca Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 5 Aug 2024 21:52:25 +0200 Subject: [PATCH 14/16] improve diarization object remove jupyter --- .../whisper_diarization/__init__.py | 64 ++++++++++++++++--- 1 file changed, 54 insertions(+), 10 deletions(-) diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py index 05a897eea..2d3628128 100644 --- 
a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py +++ b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py @@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) self.model_name = model_name self.hf_token = hf_token - self.with_exposed_ports(8888) # Expose Jupyter notebook port self.with_env("NVIDIA_VISIBLE_DEVICES", "all") self.with_env("CUDA_VISIBLE_DEVICES", "all") self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + self.connection_retries = 5 + self.connection_retry_delay = 10 # seconds # Install required dependencies self.with_command("sh -c '" "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " - "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " - "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" "'") @wait_container_is_ready(URLError) def _connect(self): - url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" - res = urllib.request.urlopen(url) - if res.status != 200: - raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}") + for attempt in range(self.connection_retries): + try: + # Check if JAX and other required libraries are properly installed and functioning + result = self.run_command( + "import jax; import whisper_jax; import pyannote.audio; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'Whisper-JAX version: {whisper_jax.__version__}'); " + "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output.decode() and "Available devices" in result.output.decode(): + logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}") + return True + else: + raise Exception("JAX-Whisper-Diarization environment check failed") + + except Exception as e: + if attempt < self.connection_retries - 1: + logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...") + time.sleep(self.connection_retry_delay) + else: + raise Exception(f"Failed to connect to JAX-Whisper-Diarization container after {self.connection_retries} attempts: {str(e)}") + + return False def connect(self): """ Connect to the JAX-Whisper-Diarization container and ensure it's ready. + This method verifies that JAX, Whisper-JAX, and Pyannote Audio are properly installed and functioning. + It also checks for available devices, including GPUs if applicable. """ self._connect() - logging.info("Successfully connected to JAX-Whisper-Diarization container") + logging.info("Successfully connected to JAX-Whisper-Diarization container and verified the environment") def run_command(self, command: str): """ @@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True): def start(self): """ - Start the JAX-Whisper-Diarization container. + Start the JAX-Whisper-Diarization container and wait for it to be ready. 
""" super().start() - logging.info(f"JAX-Whisper-Diarization container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + self._wait_for_container_to_be_ready() + logging.info("JAX-Whisper-Diarization container started and ready.") return self + + def _wait_for_container_to_be_ready(self): + # Wait for a specific log message that indicates the container is ready + self.wait_for_logs("Installation completed") + + def stop(self, force=True): + """ + Stop the JAX-Whisper-Diarization container. + """ + super().stop(force) + logging.info("JAX-Whisper-Diarization container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. + """ + return self.start_timeout From fd0a0e1c0788a7d46d3d6b45d66e0b344e55895c Mon Sep 17 00:00:00 2001 From: David Ankin Date: Tue, 13 Aug 2024 09:07:17 -0400 Subject: [PATCH 15/16] linting --- modules/jax/README.rst | 6 +- .../jax/testcontainers/jax_amd/__init__.py | 13 ++-- .../jax/testcontainers/jax_cuda/__init__.py | 50 +++++-------- .../whisper_diarization/__init__.py | 72 +++++++++---------- .../whisper_transcription/__init__.py | 28 ++++---- modules/jax/tests/test_jax.py | 11 ++- modules/jax/tests/test_whisper_diarization.py | 15 ++-- modules/jax/tests/test_whisper_jax.py | 24 ++++--- 8 files changed, 107 insertions(+), 112 deletions(-) diff --git a/modules/jax/README.rst b/modules/jax/README.rst index c1894c5ba..2ec2bd738 100644 --- a/modules/jax/README.rst +++ b/modules/jax/README.rst @@ -5,7 +5,7 @@ 1. **Official JAX Docker Container** - **Container**: `jax/jax:cuda-12.0` - **Documentation**: [JAX Docker](https://github.com/google/jax/blob/main/docker/README.md) - + 2. **NVIDIA Docker Container** - **Container**: `nvidia/cuda:12.0-cudnn8-devel-ubuntu20.04` - **Documentation**: [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda) @@ -25,7 +25,7 @@ **Ensure Docker is configured to use the NVIDIA runtime**: - You need to install the NVIDIA Container Toolkit. Follow the instructions for your operating system: [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). - Update your Docker daemon configuration to include the NVIDIA runtime. Edit the Docker daemon configuration file, typically located at `/etc/docker/daemon.json`, to include the following: - + ```json { "runtimes": { @@ -55,4 +55,4 @@ 3. 
**Docker Best Practices** - [Docker Documentation](https://docs.docker.com/get-started/) - - [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/) \ No newline at end of file + - [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/) diff --git a/modules/jax/testcontainers/jax_amd/__init__.py b/modules/jax/testcontainers/jax_amd/__init__.py index 9929cd6bf..63c07368c 100644 --- a/modules/jax/testcontainers/jax_amd/__init__.py +++ b/modules/jax/testcontainers/jax_amd/__init__.py @@ -3,9 +3,8 @@ from urllib.error import URLError from core.testcontainers.core.container import DockerContainer -from core.testcontainers.core.waiting_utils import wait_container_is_ready -from core.testcontainers.core.config import testcontainers_config -from core.testcontainers.core.waiting_utils import wait_for_logs +from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs + class JAXContainer(DockerContainer): """ @@ -21,7 +20,7 @@ class JAXContainer(DockerContainer): >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: ... # Connect to the container ... jax_container.connect() - ... + ... ... # Run a simple JAX computation ... result = jax.numpy.add(1, 1) ... assert result == 2 @@ -38,7 +37,7 @@ def __init__(self, image="huggingface/transformers-jax-light:latest", **kwargs): self.with_env("NVIDIA_VISIBLE_DEVICES", "all") self.with_env("CUDA_VISIBLE_DEVICES", "all") self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support - self.start_timeout = 600 # 10 minutes + self.start_timeout = 600 # 10 minutes @wait_container_is_ready(URLError) def _connect(self): @@ -78,7 +77,7 @@ def start(self): self._wait_for_container_to_be_ready() logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}") return self - + def stop(self, force=True): """ Stop the JAX container. @@ -91,4 +90,4 @@ def timeout(self): """ Get the container start timeout. """ - return self.start_timeout \ No newline at end of file + return self.start_timeout diff --git a/modules/jax/testcontainers/jax_cuda/__init__.py b/modules/jax/testcontainers/jax_cuda/__init__.py index c28de41c2..33cd18015 100644 --- a/modules/jax/testcontainers/jax_cuda/__init__.py +++ b/modules/jax/testcontainers/jax_cuda/__init__.py @@ -1,10 +1,10 @@ import logging -import time from urllib.error import URLError from core.testcontainers.core.container import DockerContainer from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs + class JAXContainer(DockerContainer): """ JAX container for GPU-accelerated numerical computing and machine learning. @@ -18,7 +18,7 @@ class JAXContainer(DockerContainer): >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: ... # Connect to the container ... jax_container.connect() - ... + ... ... # Run a simple JAX computation ... result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") ... 
assert "2" in result.output @@ -40,31 +40,19 @@ def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): @wait_container_is_ready(URLError) def _connect(self): - for attempt in range(self.connection_retries): - try: - # Check if JAX is properly installed and functioning - result = self.run_jax_command( - "import jax; import jaxlib; " - "print(f'JAX version: {jax.__version__}'); " - "print(f'JAXlib version: {jaxlib.__version__}'); " - "print(f'Available devices: {jax.devices()}'); " - "print(jax.numpy.add(1, 1))" - ) - - if "JAX version" in result.output and "Available devices" in result.output: - logging.info(f"JAX environment verified:\n{result.output}") - return True - else: - raise Exception("JAX environment check failed") - - except Exception as e: - if attempt < self.connection_retries - 1: - logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...") - time.sleep(self.connection_retry_delay) - else: - raise Exception(f"Failed to connect to JAX container after {self.connection_retries} attempts: {str(e)}") - - return False + # Check if JAX is properly installed and functioning + result = self.run_jax_command( + "import jax; import jaxlib; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'JAXlib version: {jaxlib.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output and "Available devices" in result.output: + logging.info(f"JAX environment verified:\n{result.output}") + else: + raise Exception("JAX environment check failed") def connect(self): """ @@ -93,12 +81,12 @@ def start(self): self._wait_for_container_to_be_ready() logging.info("JAX container started and ready.") return self - - def stop(self, force=True): + + def stop(self, force=True, delete_volume=True) -> None: """ Stop the JAX container. """ - super().stop(force) + super().stop(force, delete_volume) logging.info("JAX container stopped.") @property @@ -106,4 +94,4 @@ def timeout(self): """ Get the container start timeout. """ - return self.start_timeout \ No newline at end of file + return self.start_timeout diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py index 2d3628128..6a28010cb 100644 --- a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py +++ b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py @@ -1,10 +1,10 @@ import logging -import tempfile from typing import Optional +from urllib.error import URLError from core.testcontainers.core.container import DockerContainer from core.testcontainers.core.waiting_utils import wait_container_is_ready -from urllib.error import URLError + class JAXWhisperDiarizationContainer(DockerContainer): """ @@ -15,17 +15,17 @@ class JAXWhisperDiarizationContainer(DockerContainer): .. doctest:: >>> logging.basicConfig(level=logging.INFO) - + ... # You need to provide your Hugging Face token to use the pyannote.audio models >>> hf_token = "your_huggingface_token_here" - + >>> with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: ... whisper_diarization.connect() - ... + ... ... # Example: Transcribe and diarize an audio file ... result = whisper_diarization.transcribe_and_diarize_file("/path/to/audio/file.wav") ... print(f"Transcription and Diarization: {result}") - ... + ... ... # Example: Transcribe and diarize a YouTube video ... 
result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") ... print(f"YouTube Transcription and Diarization: {result}") @@ -43,40 +43,30 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option self.connection_retry_delay = 10 # seconds # Install required dependencies - self.with_command("sh -c '" - "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " - "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " - "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" - "'") + self.with_command( + "sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" + "'" + ) @wait_container_is_ready(URLError) def _connect(self): - for attempt in range(self.connection_retries): - try: - # Check if JAX and other required libraries are properly installed and functioning - result = self.run_command( - "import jax; import whisper_jax; import pyannote.audio; " - "print(f'JAX version: {jax.__version__}'); " - "print(f'Whisper-JAX version: {whisper_jax.__version__}'); " - "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); " - "print(f'Available devices: {jax.devices()}'); " - "print(jax.numpy.add(1, 1))" - ) - - if "JAX version" in result.output.decode() and "Available devices" in result.output.decode(): - logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}") - return True - else: - raise Exception("JAX-Whisper-Diarization environment check failed") - - except Exception as e: - if attempt < self.connection_retries - 1: - logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...") - time.sleep(self.connection_retry_delay) - else: - raise Exception(f"Failed to connect to JAX-Whisper-Diarization container after {self.connection_retries} attempts: {str(e)}") - - return False + # Check if JAX and other required libraries are properly installed and functioning + result = self.run_command( + "import jax; import whisper_jax; import pyannote.audio; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'Whisper-JAX version: {whisper_jax.__version__}'); " + "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output.decode() and "Available devices" in result.output.decode(): + logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}") + else: + raise Exception("JAX-Whisper-Diarization environment check failed") def connect(self): """ @@ -94,7 +84,9 @@ def run_command(self, command: str): exec_result = self.exec(f"python -c '{command}'") return exec_result - def transcribe_and_diarize_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True): + def transcribe_and_diarize_file( + self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True + ): """ Transcribe and diarize an audio file using Whisper-JAX and pyannote. 
""" @@ -171,7 +163,9 @@ def align(transcription, segments, group_by_speaker=True): """ return self.run_command(command) - def transcribe_and_diarize_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True): + def transcribe_and_diarize_youtube( + self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True + ): """ Transcribe and diarize a YouTube video using Whisper-JAX and pyannote. """ diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py index eda1c988e..da31fb1e0 100644 --- a/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py +++ b/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py @@ -1,11 +1,9 @@ import logging -import tempfile -import time -from typing import Optional +from urllib.error import URLError from core.testcontainers.core.container import DockerContainer from core.testcontainers.core.waiting_utils import wait_container_is_ready -from urllib.error import URLError + class WhisperJAXContainer(DockerContainer): """ @@ -20,7 +18,7 @@ class WhisperJAXContainer(DockerContainer): >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper: ... # Connect to the container ... whisper.connect() - ... + ... ... # Transcribe an audio file ... result = whisper.transcribe_file("path/to/audio/file.wav") ... print(result['text']) @@ -39,15 +37,19 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs): self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support # Install required dependencies - self.with_command("sh -c '" - "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " - "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " - "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " - "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" - "'") + self.with_command( + "sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " + "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "'" + ) @wait_container_is_ready(URLError) def _connect(self): + import urllib.request + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" res = urllib.request.urlopen(url) if res.status != 200: @@ -122,5 +124,7 @@ def start(self): Start the Whisper-JAX container. """ super().start() - logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + logging.info( + f"Whisper-JAX container started. 
Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + ) return self diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py index c2138e8b7..9ca2c2129 100644 --- a/modules/jax/tests/test_jax.py +++ b/modules/jax/tests/test_jax.py @@ -1,28 +1,32 @@ import pytest from modules.jax.testcontainers.jax_cuda import JAXContainer + @pytest.fixture(scope="module") def jax_container(): with JAXContainer() as container: container.connect() yield container + def test_jax_container_basic_computation(jax_container): result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") assert "2" in result.output.decode(), "Basic JAX computation failed" + def test_jax_container_version(jax_container): result = jax_container.run_jax_command("import jax; print(jax.__version__)") assert result.exit_code == 0, "Failed to get JAX version" assert result.output.decode().strip(), "JAX version is empty" + def test_jax_container_gpu_support(jax_container): result = jax_container.run_jax_command( - "import jax; devices = jax.devices(); " - "print(any(dev.platform == 'gpu' for dev in devices))" + "import jax; devices = jax.devices(); " "print(any(dev.platform == 'gpu' for dev in devices))" ) assert "True" in result.output.decode(), "No GPU device found" + def test_jax_container_matrix_multiplication(jax_container): command = """ import jax @@ -35,9 +39,10 @@ def test_jax_container_matrix_multiplication(jax_container): result = jax_container.run_jax_command(command) assert "[[19 22]\n [43 50]]" in result.output.decode(), "Matrix multiplication failed" + def test_jax_container_custom_image(): custom_image = "nvcr.io/nvidia/jax:23.09-py3" with JAXContainer(image=custom_image) as container: container.connect() result = container.run_jax_command("import jax; print(jax.__version__)") - assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}" \ No newline at end of file + assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}" diff --git a/modules/jax/tests/test_whisper_diarization.py b/modules/jax/tests/test_whisper_diarization.py index f31ce5f3b..ebbb4def4 100644 --- a/modules/jax/tests/test_whisper_diarization.py +++ b/modules/jax/tests/test_whisper_diarization.py @@ -1,31 +1,34 @@ import pytest from modules.jax.testcontainers.whisper_cuda.whisper_diarization import JAXWhisperDiarizationContainer + @pytest.fixture(scope="module") def hf_token(): return "your_huggingface_token_here" # Replace with a valid token or use an environment variable + def test_jax_whisper_diarization_container(hf_token): with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: whisper_diarization.connect() - + # Test file transcription and diarization result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav") assert isinstance(result, list) assert all(isinstance(item, dict) for item in result) - assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) - + assert all("speaker" in item and "text" in item and "timestamp" in item for item in result) + # Test YouTube transcription and diarization result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") assert isinstance(result, list) assert all(isinstance(item, dict) for item in result) - assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) + assert all("speaker" in item and "text" in item and "timestamp" in item for 
item in result) + def test_jax_whisper_diarization_container_without_grouping(hf_token): with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: whisper_diarization.connect() - + result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav", group_by_speaker=False) assert isinstance(result, list) assert all(isinstance(item, dict) for item in result) - assert all('speaker' in item and 'text' in item and 'timestamp' in item for item in result) \ No newline at end of file + assert all("speaker" in item and "text" in item and "timestamp" in item for item in result) diff --git a/modules/jax/tests/test_whisper_jax.py b/modules/jax/tests/test_whisper_jax.py index 12534cce2..4fa91bc76 100644 --- a/modules/jax/tests/test_whisper_jax.py +++ b/modules/jax/tests/test_whisper_jax.py @@ -1,30 +1,32 @@ import pytest from modules.jax.testcontainers.whisper_cuda.whisper_transcription import WhisperJAXContainer + @pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"]) def test_whisper_jax_container(model_name): with WhisperJAXContainer(model_name) as whisper: whisper.connect() - + # Test file transcription result = whisper.transcribe_file("/path/to/test/audio.wav") assert isinstance(result, dict) - assert 'text' in result - assert isinstance(result['text'], str) - + assert "text" in result + assert isinstance(result["text"], str) + # Test YouTube transcription result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") assert isinstance(result, dict) - assert 'text' in result - assert isinstance(result['text'], str) + assert "text" in result + assert isinstance(result["text"], str) + def test_whisper_jax_container_with_timestamps(): with WhisperJAXContainer() as whisper: whisper.connect() - + result = whisper.transcribe_file("/path/to/test/audio.wav", return_timestamps=True) assert isinstance(result, dict) - assert 'text' in result - assert 'chunks' in result - assert isinstance(result['chunks'], list) - assert all('timestamp' in chunk for chunk in result['chunks']) \ No newline at end of file + assert "text" in result + assert "chunks" in result + assert isinstance(result["chunks"], list) + assert all("timestamp" in chunk for chunk in result["chunks"]) From ca24227ea7bc0599c6b00d8dbfad3ed5d1606f92 Mon Sep 17 00:00:00 2001 From: David Ankin Date: Tue, 13 Aug 2024 09:09:28 -0400 Subject: [PATCH 16/16] update poetry lock file --- poetry.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 228c9c483..1b094e71c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1968,7 +1968,6 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.7-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:c828190118b104b05b8c8e0b5a4147811c86b54b8fb67bc2e726ad10fc0b544e"}, {file = "milvus_lite-2.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e1537633c39879714fb15082be56a4b97f74c905a6e98e302ec01320561081af"}, - {file = "milvus_lite-2.4.7-py3-none-manylinux2014_aarch64.whl", hash = "sha256:fcb909d38c83f21478ca9cb500c84264f988c69f62715ae9462e966767fb76dd"}, {file = "milvus_lite-2.4.7-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f016474d663045787dddf1c3aad13b7d8b61fd329220318f858184918143dcbf"}, ] @@ -4643,6 +4642,7 @@ elasticsearch = [] generic = ["httpx"] google = ["google-cloud-datastore", "google-cloud-pubsub"] influxdb = ["influxdb", "influxdb-client"] +jax = [] k3s = ["kubernetes", "pyyaml"] kafka = [] keycloak = ["python-keycloak"] @@ -4677,4 +4677,4 @@ weaviate = 
["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "18a5763385d12114513ef5d65268de3ea6567e79b21049b6d58d1803f4257306" +content-hash = "3d381b82f4484c2fff23b22a08d7750f9eed2dc525a7cdf361346b81560283fb"