diff --git a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py index 05a897ee..2d362812 100644 --- a/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py +++ b/modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py @@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs) self.model_name = model_name self.hf_token = hf_token - self.with_exposed_ports(8888) # Expose Jupyter notebook port self.with_env("NVIDIA_VISIBLE_DEVICES", "all") self.with_env("CUDA_VISIBLE_DEVICES", "all") self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support + self.start_timeout = 600 # 10 minutes + self.connection_retries = 5 + self.connection_retry_delay = 10 # seconds # Install required dependencies self.with_command("sh -c '" "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && " "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && " - "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && " - "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''" + "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" "'") @wait_container_is_ready(URLError) def _connect(self): - url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" - res = urllib.request.urlopen(url) - if res.status != 200: - raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}") + for attempt in range(self.connection_retries): + try: + # Check if JAX and other required libraries are properly installed and functioning + result = self.run_command( + "import jax; import whisper_jax; import pyannote.audio; " + "print(f'JAX version: {jax.__version__}'); " + "print(f'Whisper-JAX version: {whisper_jax.__version__}'); " + "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); " + "print(f'Available devices: {jax.devices()}'); " + "print(jax.numpy.add(1, 1))" + ) + + if "JAX version" in result.output.decode() and "Available devices" in result.output.decode(): + logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}") + return True + else: + raise Exception("JAX-Whisper-Diarization environment check failed") + + except Exception as e: + if attempt < self.connection_retries - 1: + logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...") + time.sleep(self.connection_retry_delay) + else: + raise Exception(f"Failed to connect to JAX-Whisper-Diarization container after {self.connection_retries} attempts: {str(e)}") + + return False def connect(self): """ Connect to the JAX-Whisper-Diarization container and ensure it's ready. + This method verifies that JAX, Whisper-JAX, and Pyannote Audio are properly installed and functioning. + It also checks for available devices, including GPUs if applicable. """ self._connect() - logging.info("Successfully connected to JAX-Whisper-Diarization container") + logging.info("Successfully connected to JAX-Whisper-Diarization container and verified the environment") def run_command(self, command: str): """ @@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True): def start(self): """ - Start the JAX-Whisper-Diarization container. + Start the JAX-Whisper-Diarization container and wait for it to be ready. """ super().start() - logging.info(f"JAX-Whisper-Diarization container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}") + self._wait_for_container_to_be_ready() + logging.info("JAX-Whisper-Diarization container started and ready.") return self + + def _wait_for_container_to_be_ready(self): + # Wait for a specific log message that indicates the container is ready + self.wait_for_logs("Installation completed") + + def stop(self, force=True): + """ + Stop the JAX-Whisper-Diarization container. + """ + super().stop(force) + logging.info("JAX-Whisper-Diarization container stopped.") + + @property + def timeout(self): + """ + Get the container start timeout. + """ + return self.start_timeout