diff --git a/modules/jax/README.rst b/modules/jax/README.rst new file mode 100644 index 00000000..2ec2bd73 --- /dev/null +++ b/modules/jax/README.rst @@ -0,0 +1,58 @@ +# Testcontainers : JAX + +## Docker Containers for JAX with GPU Support + +1. **Official JAX Docker Container** + - **Container**: `jax/jax:cuda-12.0` + - **Documentation**: [JAX Docker](https://github.com/google/jax/blob/main/docker/README.md) + +2. **NVIDIA Docker Container** + - **Container**: `nvidia/cuda:12.0-cudnn8-devel-ubuntu20.04` + - **Documentation**: [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda) + +## Benefits of Having This Container + +1. **Optimized Performance**: JAX uses XLA to compile and run NumPy programs on GPUs, which can significantly speed up numerical computations and machine learning tasks. A container specifically optimized for JAX with CUDA ensures that the environment is configured to leverage GPU acceleration fully. + +2. **Reproducibility**: Containers encapsulate all dependencies, libraries, and configurations needed to run JAX, ensuring that the environment is consistent across different systems. This is crucial for reproducible research and development. + +3. **Ease of Use**: Users can easily pull and run the container without worrying about the complex setup required for GPU support and JAX configuration. This reduces the barrier to entry for new users and accelerates development workflows. + +4. **Isolation and Security**: Containers provide an isolated environment, which enhances security by limiting the impact of potential vulnerabilities. It also avoids conflicts with other software on the host system. + +## Troubleshooting + +**Ensure Docker is configured to use the NVIDIA runtime**: + - You need to install the NVIDIA Container Toolkit. Follow the instructions for your operating system: [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). + - Update your Docker daemon configuration to include the NVIDIA runtime. Edit the Docker daemon configuration file, typically located at `/etc/docker/daemon.json`, to include the following: + + ```json + { + "runtimes": { + "nvidia": { + "path": "nvidia-container-runtime", + "runtimeArgs": [] + } + } + } + ``` + + - Restart the Docker daemon to apply the changes: + ```sh + sudo systemctl restart docker + ``` + +## Relevant Reading Material + +1. **JAX Documentation** + - [JAX Quickstart](https://github.com/google/jax#quickstart) + - [JAX Transformations](https://github.com/google/jax#transformations) + - [JAX Installation Guide](https://github.com/google/jax#installation) + +2. **NVIDIA Docker Documentation** + - [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda) + - [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) + +3. **Docker Best Practices** + - [Docker Documentation](https://docs.docker.com/get-started/) + - [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/) diff --git a/modules/jax/testcontainers/jax_amd/__init__.py b/modules/jax/testcontainers/jax_amd/__init__.py new file mode 100644 index 00000000..63c07368 --- /dev/null +++ b/modules/jax/testcontainers/jax_amd/__init__.py @@ -0,0 +1,93 @@ +import logging +import urllib.request +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs + + +class JAXContainer(DockerContainer): + """ + JAX container for GPU-accelerated numerical computing and machine learning. + + Example: + + .. doctest:: + + >>> import jax + >>> from testcontainers.jax import JAXContainer + + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: + ... # Connect to the container + ... jax_container.connect() + ... + ... # Run a simple JAX computation + ... result = jax.numpy.add(1, 1) + ... assert result == 2 + + .. auto-class:: JAXContainer + :members: + :undoc-members: + :show-inheritance: + """ + + def __init__(self, image="huggingface/transformers-jax-light:latest", **kwargs): + super().__init__(image, **kwargs) + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + + @wait_container_is_ready(URLError) + def _connect(self): + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url, timeout=self.start_timeout) + if res.status != 200: + raise Exception(f"Failed to connect to JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to JAX container") + + def get_jupyter_url(self): + """ + Get the URL for accessing the Jupyter notebook server. + """ + return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + + def run_jax_command(self, command): + """ + Run a JAX command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def _wait_for_container_to_be_ready(self): + wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout) + + def start(self): + """ + Start the JAX container and wait for it to be ready. + """ + super().start() + self._wait_for_container_to_be_ready() + logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}") + return self + + def stop(self, force=True): + """ + Stop the JAX container. + """ + super().stop(force) + logging.info("JAX container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. + """ + return self.start_timeout diff --git a/modules/jax/testcontainers/jax_cuda/__init__.py b/modules/jax/testcontainers/jax_cuda/__init__.py new file mode 100644 index 00000000..33cd1801 --- /dev/null +++ b/modules/jax/testcontainers/jax_cuda/__init__.py @@ -0,0 +1,97 @@ +import logging +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs + + +class JAXContainer(DockerContainer): + """ + JAX container for GPU-accelerated numerical computing and machine learning. + + Example: + + .. doctest:: + + >>> from testcontainers.jax import JAXContainer + + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: + ... # Connect to the container + ... jax_container.connect() + ... + ... # Run a simple JAX computation + ... result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") + ... assert "2" in result.output + + .. auto-class:: JAXContainer + :members: + :undoc-members: + :show-inheritance: + """ + + def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): + super().__init__(image, **kwargs) + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + self.connection_retries = 5 + self.connection_retry_delay = 10 # seconds + + @wait_container_is_ready(URLError) + def _connect(self): + # Check if JAX is properly installed and functioning + result = self.run_jax_command( + "import jax; import jaxlib; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'JAXlib version: {jaxlib.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output and "Available devices" in result.output: + logging.info(f"JAX environment verified:\n{result.output}") + else: + raise Exception("JAX environment check failed") + + def connect(self): + """ + Connect to the JAX container and ensure it's ready. + This method verifies that JAX is properly installed and functioning. + It also checks for available devices, including GPUs if applicable. + """ + self._connect() + logging.info("Successfully connected to JAX container and verified the environment") + + def run_jax_command(self, command): + """ + Run a JAX command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def _wait_for_container_to_be_ready(self): + wait_for_logs(self, "JAX is ready", timeout=self.start_timeout) + + def start(self): + """ + Start the JAX container and wait for it to be ready. + """ + super().start() + self._wait_for_container_to_be_ready() + logging.info("JAX container started and ready.") + return self + + def stop(self, force=True, delete_volume=True) -> None: + """ + Stop the JAX container. + """ + super().stop(force, delete_volume) + logging.info("JAX container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. + """ + return self.start_timeout diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py new file mode 100644 index 00000000..6a28010c --- /dev/null +++ b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py @@ -0,0 +1,287 @@ +import logging +from typing import Optional +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready + + +class JAXWhisperDiarizationContainer(DockerContainer): + """ + JAX-Whisper-Diarization container for fast speech recognition, transcription, and speaker diarization. + + Example: + + .. doctest:: + + >>> logging.basicConfig(level=logging.INFO) + + ... # You need to provide your Hugging Face token to use the pyannote.audio models + >>> hf_token = "your_huggingface_token_here" + + >>> with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + ... whisper_diarization.connect() + ... + ... # Example: Transcribe and diarize an audio file + ... result = whisper_diarization.transcribe_and_diarize_file("/path/to/audio/file.wav") + ... print(f"Transcription and Diarization: {result}") + ... + ... # Example: Transcribe and diarize a YouTube video + ... result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + ... print(f"YouTube Transcription and Diarization: {result}") + """ + + def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Optional[str] = None, **kwargs): + super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) + self.model_name = model_name + self.hf_token = hf_token + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + self.connection_retries = 5 + self.connection_retry_delay = 10 # seconds + + # Install required dependencies + self.with_command( + "sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" + "'" + ) + + @wait_container_is_ready(URLError) + def _connect(self): + # Check if JAX and other required libraries are properly installed and functioning + result = self.run_command( + "import jax; import whisper_jax; import pyannote.audio; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'Whisper-JAX version: {whisper_jax.__version__}'); " + "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output.decode() and "Available devices" in result.output.decode(): + logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}") + else: + raise Exception("JAX-Whisper-Diarization environment check failed") + + def connect(self): + """ + Connect to the JAX-Whisper-Diarization container and ensure it's ready. + This method verifies that JAX, Whisper-JAX, and Pyannote Audio are properly installed and functioning. + It also checks for available devices, including GPUs if applicable. + """ + self._connect() + logging.info("Successfully connected to JAX-Whisper-Diarization container and verified the environment") + + def run_command(self, command: str): + """ + Run a Python command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def transcribe_and_diarize_file( + self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True + ): + """ + Transcribe and diarize an audio file using Whisper-JAX and pyannote. + """ + command = f""" +import soundfile as sf +import torch +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp +from pyannote.audio import Pipeline +import numpy as np + +def align(transcription, segments, group_by_speaker=True): + transcription_split = transcription.split("\\n") + transcript = [] + for chunk in transcription_split: + start_end, text = chunk[1:].split("] ") + start, end = start_end.split("->") + start, end = float(start), float(end) + transcript.append({{"timestamp": (start, end), "text": text}}) + + new_segments = [] + prev_segment = segments[0] + for i in range(1, len(segments)): + cur_segment = segments[i] + if cur_segment["label"] != prev_segment["label"]: + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}}, + "speaker": prev_segment["label"] + }}) + prev_segment = segments[i] + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}}, + "speaker": prev_segment["label"] + }}) + + end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript]) + segmented_preds = [] + + for segment in new_segments: + end_time = segment["segment"]["end"] + upto_idx = np.argmin(np.abs(end_timestamps - end_time)) + + if group_by_speaker: + segmented_preds.append({{ + "speaker": segment["speaker"], + "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]), + "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]) + }}) + else: + for i in range(upto_idx + 1): + segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}}) + + transcript = transcript[upto_idx + 1 :] + end_timestamps = end_timestamps[upto_idx + 1 :] + + return segmented_preds + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}") + +audio, sr = sf.read("{file_path}") +inputs = {{"array": audio, "sampling_rate": sr}} + +# Transcribe +result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps}) + +# Diarize +diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}}) +segments = diarization.for_json()["content"] + +# Align transcription and diarization +aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker}) +print(aligned_result) +""" + return self.run_command(command) + + def transcribe_and_diarize_youtube( + self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True + ): + """ + Transcribe and diarize a YouTube video using Whisper-JAX and pyannote. + """ + command = f""" +import tempfile +import youtube_dl +import soundfile as sf +import torch +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp +from pyannote.audio import Pipeline +import numpy as np + +def download_youtube_audio(youtube_url, output_file): + ydl_opts = {{ + 'format': 'bestaudio/best', + 'postprocessors': [{{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'wav', + 'preferredquality': '192', + }}], + 'outtmpl': output_file, + }} + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + ydl.download([youtube_url]) + +def align(transcription, segments, group_by_speaker=True): + transcription_split = transcription.split("\\n") + transcript = [] + for chunk in transcription_split: + start_end, text = chunk[1:].split("] ") + start, end = start_end.split("->") + start, end = float(start), float(end) + transcript.append({{"timestamp": (start, end), "text": text}}) + + new_segments = [] + prev_segment = segments[0] + for i in range(1, len(segments)): + cur_segment = segments[i] + if cur_segment["label"] != prev_segment["label"]: + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}}, + "speaker": prev_segment["label"] + }}) + prev_segment = segments[i] + new_segments.append({{ + "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}}, + "speaker": prev_segment["label"] + }}) + + end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript]) + segmented_preds = [] + + for segment in new_segments: + end_time = segment["segment"]["end"] + upto_idx = np.argmin(np.abs(end_timestamps - end_time)) + + if group_by_speaker: + segmented_preds.append({{ + "speaker": segment["speaker"], + "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]), + "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]) + }}) + else: + for i in range(upto_idx + 1): + segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}}) + + transcript = transcript[upto_idx + 1 :] + end_timestamps = end_timestamps[upto_idx + 1 :] + + return segmented_preds + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}") + +with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + download_youtube_audio("{youtube_url}", temp_file.name) + audio, sr = sf.read(temp_file.name) + inputs = {{"array": audio, "sampling_rate": sr}} + + # Transcribe + result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps}) + + # Diarize + diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}}) + segments = diarization.for_json()["content"] + + # Align transcription and diarization + aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker}) + print(aligned_result) +""" + return self.run_command(command) + + def start(self): + """ + Start the JAX-Whisper-Diarization container and wait for it to be ready. + """ + super().start() + self._wait_for_container_to_be_ready() + logging.info("JAX-Whisper-Diarization container started and ready.") + return self + + def _wait_for_container_to_be_ready(self): + # Wait for a specific log message that indicates the container is ready + self.wait_for_logs("Installation completed") + + def stop(self, force=True): + """ + Stop the JAX-Whisper-Diarization container. + """ + super().stop(force) + logging.info("JAX-Whisper-Diarization container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. + """ + return self.start_timeout diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py new file mode 100644 index 00000000..da31fb1e --- /dev/null +++ b/modules/jax/testcontainers/whisper_cuda/whisper_transcription/__init__.py @@ -0,0 +1,130 @@ +import logging +from urllib.error import URLError + +from core.testcontainers.core.container import DockerContainer +from core.testcontainers.core.waiting_utils import wait_container_is_ready + + +class WhisperJAXContainer(DockerContainer): + """ + Whisper-JAX container for fast speech recognition and transcription. + + Example: + + .. doctest:: + + >>> from testcontainers.whisper_jax import WhisperJAXContainer + + >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper: + ... # Connect to the container + ... whisper.connect() + ... + ... # Transcribe an audio file + ... result = whisper.transcribe_file("path/to/audio/file.wav") + ... print(result['text']) + ... + ... # Transcribe a YouTube video + ... result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + ... print(result['text']) + """ + + def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs): + super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) + self.model_name = model_name + self.with_exposed_ports(8888) # Expose Jupyter notebook port + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") + self.with_env("CUDA_VISIBLE_DEVICES", "all") + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + + # Install required dependencies + self.with_command( + "sh -c '" + "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " + "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && " + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " + "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "'" + ) + + @wait_container_is_ready(URLError) + def _connect(self): + import urllib.request + + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + res = urllib.request.urlopen(url) + if res.status != 200: + raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}") + + def connect(self): + """ + Connect to the Whisper-JAX container and ensure it's ready. + """ + self._connect() + logging.info("Successfully connected to Whisper-JAX container") + + def run_command(self, command: str): + """ + Run a Python command inside the container. + """ + exec_result = self.exec(f"python -c '{command}'") + return exec_result + + def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe an audio file using Whisper-JAX. + """ + command = f""" +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) +audio, sr = sf.read("{file_path}") +result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) +print(result) +""" + return self.run_command(command) + + def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False): + """ + Transcribe a YouTube video using Whisper-JAX. + """ + command = f""" +import tempfile +import youtube_dl +import soundfile as sf +from whisper_jax import FlaxWhisperPipline +import jax.numpy as jnp + +def download_youtube_audio(youtube_url, output_file): + ydl_opts = {{ + 'format': 'bestaudio/best', + 'postprocessors': [{{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'wav', + 'preferredquality': '192', + }}], + 'outtmpl': output_file, + }} + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + ydl.download([youtube_url]) + +pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16) + +with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + download_youtube_audio("{youtube_url}", temp_file.name) + audio, sr = sf.read(temp_file.name) + result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps}) + print(result) +""" + return self.run_command(command) + + def start(self): + """ + Start the Whisper-JAX container. + """ + super().start() + logging.info( + f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" + ) + return self diff --git a/modules/jax/tests/test_jax.py b/modules/jax/tests/test_jax.py new file mode 100644 index 00000000..9ca2c212 --- /dev/null +++ b/modules/jax/tests/test_jax.py @@ -0,0 +1,48 @@ +import pytest +from modules.jax.testcontainers.jax_cuda import JAXContainer + + +@pytest.fixture(scope="module") +def jax_container(): + with JAXContainer() as container: + container.connect() + yield container + + +def test_jax_container_basic_computation(jax_container): + result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))") + assert "2" in result.output.decode(), "Basic JAX computation failed" + + +def test_jax_container_version(jax_container): + result = jax_container.run_jax_command("import jax; print(jax.__version__)") + assert result.exit_code == 0, "Failed to get JAX version" + assert result.output.decode().strip(), "JAX version is empty" + + +def test_jax_container_gpu_support(jax_container): + result = jax_container.run_jax_command( + "import jax; devices = jax.devices(); " "print(any(dev.platform == 'gpu' for dev in devices))" + ) + assert "True" in result.output.decode(), "No GPU device found" + + +def test_jax_container_matrix_multiplication(jax_container): + command = """ +import jax +import jax.numpy as jnp +x = jnp.array([[1, 2], [3, 4]]) +y = jnp.array([[5, 6], [7, 8]]) +result = jnp.dot(x, y) +print(result) + """ + result = jax_container.run_jax_command(command) + assert "[[19 22]\n [43 50]]" in result.output.decode(), "Matrix multiplication failed" + + +def test_jax_container_custom_image(): + custom_image = "nvcr.io/nvidia/jax:23.09-py3" + with JAXContainer(image=custom_image) as container: + container.connect() + result = container.run_jax_command("import jax; print(jax.__version__)") + assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}" diff --git a/modules/jax/tests/test_whisper_diarization.py b/modules/jax/tests/test_whisper_diarization.py new file mode 100644 index 00000000..ebbb4def --- /dev/null +++ b/modules/jax/tests/test_whisper_diarization.py @@ -0,0 +1,34 @@ +import pytest +from modules.jax.testcontainers.whisper_cuda.whisper_diarization import JAXWhisperDiarizationContainer + + +@pytest.fixture(scope="module") +def hf_token(): + return "your_huggingface_token_here" # Replace with a valid token or use an environment variable + + +def test_jax_whisper_diarization_container(hf_token): + with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + whisper_diarization.connect() + + # Test file transcription and diarization + result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav") + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all("speaker" in item and "text" in item and "timestamp" in item for item in result) + + # Test YouTube transcription and diarization + result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all("speaker" in item and "text" in item and "timestamp" in item for item in result) + + +def test_jax_whisper_diarization_container_without_grouping(hf_token): + with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization: + whisper_diarization.connect() + + result = whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav", group_by_speaker=False) + assert isinstance(result, list) + assert all(isinstance(item, dict) for item in result) + assert all("speaker" in item and "text" in item and "timestamp" in item for item in result) diff --git a/modules/jax/tests/test_whisper_jax.py b/modules/jax/tests/test_whisper_jax.py new file mode 100644 index 00000000..4fa91bc7 --- /dev/null +++ b/modules/jax/tests/test_whisper_jax.py @@ -0,0 +1,32 @@ +import pytest +from modules.jax.testcontainers.whisper_cuda.whisper_transcription import WhisperJAXContainer + + +@pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"]) +def test_whisper_jax_container(model_name): + with WhisperJAXContainer(model_name) as whisper: + whisper.connect() + + # Test file transcription + result = whisper.transcribe_file("/path/to/test/audio.wav") + assert isinstance(result, dict) + assert "text" in result + assert isinstance(result["text"], str) + + # Test YouTube transcription + result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + assert isinstance(result, dict) + assert "text" in result + assert isinstance(result["text"], str) + + +def test_whisper_jax_container_with_timestamps(): + with WhisperJAXContainer() as whisper: + whisper.connect() + + result = whisper.transcribe_file("/path/to/test/audio.wav", return_timestamps=True) + assert isinstance(result, dict) + assert "text" in result + assert "chunks" in result + assert isinstance(result["chunks"], list) + assert all("timestamp" in chunk for chunk in result["chunks"]) diff --git a/poetry.lock b/poetry.lock index 228c9c48..1b094e71 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1968,7 +1968,6 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.7-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:c828190118b104b05b8c8e0b5a4147811c86b54b8fb67bc2e726ad10fc0b544e"}, {file = "milvus_lite-2.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e1537633c39879714fb15082be56a4b97f74c905a6e98e302ec01320561081af"}, - {file = "milvus_lite-2.4.7-py3-none-manylinux2014_aarch64.whl", hash = "sha256:fcb909d38c83f21478ca9cb500c84264f988c69f62715ae9462e966767fb76dd"}, {file = "milvus_lite-2.4.7-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f016474d663045787dddf1c3aad13b7d8b61fd329220318f858184918143dcbf"}, ] @@ -4643,6 +4642,7 @@ elasticsearch = [] generic = ["httpx"] google = ["google-cloud-datastore", "google-cloud-pubsub"] influxdb = ["influxdb", "influxdb-client"] +jax = [] k3s = ["kubernetes", "pyyaml"] kafka = [] keycloak = ["python-keycloak"] @@ -4677,4 +4677,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "18a5763385d12114513ef5d65268de3ea6567e79b21049b6d58d1803f4257306" +content-hash = "3d381b82f4484c2fff23b22a08d7750f9eed2dc525a7cdf361346b81560283fb" diff --git a/pyproject.toml b/pyproject.toml index 3bccf880..44c5f23e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ packages = [ { include = "testcontainers", from = "modules/test_module_import"}, { include = "testcontainers", from = "modules/google" }, { include = "testcontainers", from = "modules/influxdb" }, + { include = "testcontainers", from = "modules/jax" }, { include = "testcontainers", from = "modules/k3s" }, { include = "testcontainers", from = "modules/kafka" }, { include = "testcontainers", from = "modules/keycloak" }, @@ -148,6 +149,7 @@ neo4j = ["neo4j"] nginx = [] opensearch = ["opensearch-py"] ollama = [] +jax = ["jax"] oracle = ["sqlalchemy", "oracledb"] oracle-free = ["sqlalchemy", "oracledb"] postgres = []