Skip to content

Commit

Permalink
improve diarization object remove jupyter
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephrp committed Aug 5, 2024
1 parent 59705dc commit 6fed7e9
Showing 1 changed file with 54 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option
super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
self.model_name = model_name
self.hf_token = hf_token
self.with_exposed_ports(8888) # Expose Jupyter notebook port
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
self.with_env("CUDA_VISIBLE_DEVICES", "all")
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
self.start_timeout = 600 # 10 minutes
self.connection_retries = 5
self.connection_retry_delay = 10 # seconds

# Install required dependencies
self.with_command("sh -c '"
"pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
"pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
"python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
"jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
"python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
"'")

@wait_container_is_ready(URLError)
def _connect(self):
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
res = urllib.request.urlopen(url)
if res.status != 200:
raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}")
for attempt in range(self.connection_retries):
try:
# Check if JAX and other required libraries are properly installed and functioning
result = self.run_command(
"import jax; import whisper_jax; import pyannote.audio; "
"print(f'JAX version: {jax.__version__}'); "
"print(f'Whisper-JAX version: {whisper_jax.__version__}'); "
"print(f'Pyannote Audio version: {pyannote.audio.__version__}'); "
"print(f'Available devices: {jax.devices()}'); "
"print(jax.numpy.add(1, 1))"
)

if "JAX version" in result.output.decode() and "Available devices" in result.output.decode():
logging.info(f"JAX-Whisper-Diarization environment verified:\n{result.output.decode()}")
return True
else:
raise Exception("JAX-Whisper-Diarization environment check failed")

except Exception as e:
if attempt < self.connection_retries - 1:
logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...")
time.sleep(self.connection_retry_delay)
else:
raise Exception(f"Failed to connect to JAX-Whisper-Diarization container after {self.connection_retries} attempts: {str(e)}")

return False

def connect(self):
"""
Connect to the JAX-Whisper-Diarization container and ensure it's ready.
This method verifies that JAX, Whisper-JAX, and Pyannote Audio are properly installed and functioning.
It also checks for available devices, including GPUs if applicable.
"""
self._connect()
logging.info("Successfully connected to JAX-Whisper-Diarization container")
logging.info("Successfully connected to JAX-Whisper-Diarization container and verified the environment")

def run_command(self, command: str):
"""
Expand Down Expand Up @@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True):

def start(self):
"""
Start the JAX-Whisper-Diarization container.
Start the JAX-Whisper-Diarization container and wait for it to be ready.
"""
super().start()
logging.info(f"JAX-Whisper-Diarization container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
self._wait_for_container_to_be_ready()
logging.info("JAX-Whisper-Diarization container started and ready.")
return self

def _wait_for_container_to_be_ready(self):
# Wait for a specific log message that indicates the container is ready
self.wait_for_logs("Installation completed")

def stop(self, force=True):
"""
Stop the JAX-Whisper-Diarization container.
"""
super().stop(force)
logging.info("JAX-Whisper-Diarization container stopped.")

@property
def timeout(self):
"""
Get the container start timeout.
"""
return self.start_timeout

0 comments on commit 6fed7e9

Please sign in to comment.