fix: Add JAX Container #664

Draft
wants to merge 16 commits into
base: main
58 changes: 58 additions & 0 deletions modules/jax/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Testcontainers : JAX

## Docker Containers for JAX with GPU Support

1. **Official JAX Docker Container**
   - **Container**: `jax/jax:cuda-12.0`
   - **Documentation**: [JAX Docker](https://github.com/google/jax/blob/main/docker/README.md)

2. **NVIDIA Docker Container**
   - **Container**: `nvidia/cuda:12.0-cudnn8-devel-ubuntu20.04`
   - **Documentation**: [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda)

## Benefits of Having This Container

1. **Optimized Performance**: JAX uses XLA to compile and run NumPy programs on GPUs, which can significantly speed up numerical computations and machine learning tasks. A container specifically optimized for JAX with CUDA ensures that the environment is configured to leverage GPU acceleration fully.

2. **Reproducibility**: Containers encapsulate all dependencies, libraries, and configurations needed to run JAX, ensuring that the environment is consistent across different systems. This is crucial for reproducible research and development.

3. **Ease of Use**: Users can easily pull and run the container without worrying about the complex setup required for GPU support and JAX configuration. This reduces the barrier to entry for new users and accelerates development workflows.

4. **Isolation and Security**: Containers provide an isolated environment, which limits the impact of potential vulnerabilities and avoids conflicts with other software on the host system.

## Troubleshooting

**Ensure Docker is configured to use the NVIDIA runtime**:
- You need to install the NVIDIA Container Toolkit. Follow the instructions for your operating system: [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
- Update your Docker daemon configuration to include the NVIDIA runtime. Edit the Docker daemon configuration file, typically located at `/etc/docker/daemon.json`, to include the following:

```json
{
    "runtimes": {
        "nvidia": {
            "path": "nvidia-container-runtime",
            "runtimeArgs": []
        }
    }
}
```

- Restart the Docker daemon to apply the changes:
```sh
sudo systemctl restart docker
```
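
The `daemon.json` change described in this section can also be scripted. The following is a minimal sketch and not part of this PR: `add_nvidia_runtime` is a hypothetical helper that merges the runtime entry into an existing configuration while keeping any other settings. It operates on a JSON string, so it can be tried without touching `/etc/docker/daemon.json`.

```python
import json


def add_nvidia_runtime(daemon_json: str) -> str:
    """Return daemon.json content with the NVIDIA runtime registered.

    Hypothetical helper for illustration; existing keys and runtimes are kept.
    """
    # An empty or missing file is treated as an empty configuration.
    config = json.loads(daemon_json) if daemon_json.strip() else {}
    runtimes = config.setdefault("runtimes", {})
    # Only add the entry if no "nvidia" runtime is configured yet.
    runtimes.setdefault(
        "nvidia",
        {"path": "nvidia-container-runtime", "runtimeArgs": []},
    )
    return json.dumps(config, indent=4)


if __name__ == "__main__":
    existing = '{"log-driver": "json-file"}'
    print(add_nvidia_runtime(existing))
```

Writing the result back to disk and restarting the daemon are left out on purpose, since both require root privileges.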

## Relevant Reading Material

1. **JAX Documentation**
   - [JAX Quickstart](https://github.com/google/jax#quickstart)
   - [JAX Transformations](https://github.com/google/jax#transformations)
   - [JAX Installation Guide](https://github.com/google/jax#installation)

2. **NVIDIA Docker Documentation**
   - [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda)
   - [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker)

3. **Docker Best Practices**
   - [Docker Documentation](https://docs.docker.com/get-started/)
   - [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/)
93 changes: 93 additions & 0 deletions modules/jax/testcontainers/jax_amd/__init__.py
@@ -0,0 +1,93 @@
import logging
import urllib.request
from urllib.error import URLError

from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs


class JAXContainer(DockerContainer):
    """
    JAX container for GPU-accelerated numerical computing and machine learning.

    Example:

        .. doctest::

            >>> import jax
            >>> from testcontainers.jax_amd import JAXContainer

            >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
            ...     # Connect to the container
            ...     jax_container.connect()
            ...
            ...     # Run a simple JAX computation
            ...     result = jax.numpy.add(1, 1)
            ...     assert result == 2

    .. autoclass:: JAXContainer
        :members:
        :undoc-members:
        :show-inheritance:
    """

    def __init__(self, image="huggingface/transformers-jax-light:latest", **kwargs):
        super().__init__(image, **kwargs)
        self.with_exposed_ports(8888)  # Expose Jupyter notebook port
        # Expose every GPU to the container. CUDA_VISIBLE_DEVICES is deliberately not set:
        # it expects device indices or UUIDs, and an invalid value such as "all" hides the GPUs.
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support
        self.start_timeout = 600  # 10 minutes

    @wait_container_is_ready(URLError)
    def _connect(self):
        url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
        res = urllib.request.urlopen(url, timeout=self.start_timeout)
        if res.status != 200:
            raise Exception(f"Failed to connect to JAX container. Status: {res.status}")

    def connect(self):
        """
        Connect to the JAX container and ensure it's ready.
        """
        self._connect()
        logging.info("Successfully connected to JAX container")

    def get_jupyter_url(self):
        """
        Get the URL for accessing the Jupyter notebook server.
        """
        return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"

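    def run_jax_command(self, command):
        """
        Run a JAX command inside the container.
        """
        # Pass an argument list rather than a formatted string so that quotes
        # inside `command` are not re-tokenized.
        exec_result = self.exec(["python", "-c", command])
        return exec_result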

    def _wait_for_container_to_be_ready(self):
        wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout)

    def start(self):
        """
        Start the JAX container and wait for it to be ready.
        """
        super().start()
        self._wait_for_container_to_be_ready()
        logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}")
        return self

    def stop(self, force=True):
        """
        Stop the JAX container.
        """
        super().stop(force)
        logging.info("JAX container stopped.")

    @property
    def timeout(self):
        """
        Get the container start timeout.
        """
        return self.start_timeout
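
The decorator on `_connect` in this module retries the probe while it raises `URLError`. A rough standalone sketch of that retry pattern, using only the standard library, is shown here for reference. `wait_until_ready` and its parameters are hypothetical names, and the timing and limits of the real decorator may differ.

```python
import time


def wait_until_ready(probe, timeout=600.0, interval=1.0, transient=(OSError,)):
    """Call `probe` until it succeeds or `timeout` seconds have elapsed.

    Hypothetical stand-in for illustration: exceptions listed in `transient`
    are retried, anything else propagates immediately. URLError is a subclass
    of OSError, so the default covers refused connections.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            return probe()
        except transient as exc:
            if time.monotonic() >= deadline:
                raise TimeoutError(f"not ready after {timeout} seconds") from exc
            time.sleep(interval)


if __name__ == "__main__":
    attempts = []

    def flaky_probe():
        # Fails twice, then succeeds, imitating a server that is still booting.
        attempts.append(1)
        if len(attempts) < 3:
            raise ConnectionRefusedError("server not up yet")
        return 200

    print(wait_until_ready(flaky_probe, timeout=5.0, interval=0.01))
```

Keeping the set of retried exceptions narrow matters: a non-200 response that raises a plain `Exception` is not retried under this scheme and fails on the first attempt.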
97 changes: 97 additions & 0 deletions modules/jax/testcontainers/jax_cuda/__init__.py
@@ -0,0 +1,97 @@
import logging
from urllib.error import URLError

from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs


class JAXContainer(DockerContainer):
    """
    JAX container for GPU-accelerated numerical computing and machine learning.

    Example:

        .. doctest::

            >>> from testcontainers.jax_cuda import JAXContainer

            >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
            ...     # Connect to the container
            ...     jax_container.connect()
            ...
            ...     # Run a simple JAX computation
            ...     result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
            ...     assert "2" in result.output.decode()

    .. autoclass:: JAXContainer
        :members:
        :undoc-members:
        :show-inheritance:
    """

    def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):
        super().__init__(image, **kwargs)
        # Expose every GPU to the container. CUDA_VISIBLE_DEVICES is deliberately not set:
        # it expects device indices or UUIDs, and an invalid value such as "all" hides the GPUs.
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support
        self.start_timeout = 600  # 10 minutes
        self.connection_retries = 5
        self.connection_retry_delay = 10  # seconds

    @wait_container_is_ready(URLError)
    def _connect(self):
        # Check if JAX is properly installed and functioning
        result = self.run_jax_command(
            "import jax; import jaxlib; "
            "print(f'JAX version: {jax.__version__}'); "
            "print(f'JAXlib version: {jaxlib.__version__}'); "
            "print(f'Available devices: {jax.devices()}'); "
            "print(jax.numpy.add(1, 1))"
        )

        # The exec result carries raw bytes, so decode before matching text
        output = result.output.decode()
        if "JAX version" in output and "Available devices" in output:
            logging.info(f"JAX environment verified:\n{output}")
        else:
            raise Exception("JAX environment check failed")

    def connect(self):
        """
        Connect to the JAX container and ensure it's ready.
        This method verifies that JAX is properly installed and functioning.
        It also checks for available devices, including GPUs if applicable.
        """
        self._connect()
        logging.info("Successfully connected to JAX container and verified the environment")

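    def run_jax_command(self, command):
        """
        Run a JAX command inside the container.
        """
        # Pass an argument list rather than a formatted string so that quotes
        # inside `command` (such as the f-strings used in _connect) are not re-tokenized.
        exec_result = self.exec(["python", "-c", command])
        return exec_result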

    def _wait_for_container_to_be_ready(self):
        wait_for_logs(self, "JAX is ready", timeout=self.start_timeout)

    def start(self):
        """
        Start the JAX container and wait for it to be ready.
        """
        super().start()
        self._wait_for_container_to_be_ready()
        logging.info("JAX container started and ready.")
        return self

    def stop(self, force=True, delete_volume=True) -> None:
        """
        Stop the JAX container.
        """
        super().stop(force, delete_volume)
        logging.info("JAX container stopped.")

    @property
    def timeout(self):
        """
        Get the container start timeout.
        """
        return self.start_timeout
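
Building the command as `python -c '<snippet>'` is fragile when the snippet itself contains single quotes, as the f-strings in `_connect` do. The sketch below uses `shlex.split` as a stand-in for shell-style tokenization. That is an assumption about how a command string is split before execution, not a statement about this library's internals. It compares the string form with an argument list.

```python
import shlex

# A snippet with single quotes of its own, similar to the one in _connect.
command = "import jax; print(f'JAX version: {jax.__version__}')"

# Wrapping the snippet in single quotes: the inner quotes end the outer ones
# early, so the snippet is split into several mangled arguments.
as_string = f"python -c '{command}'"
print(shlex.split(as_string))

# An argument list needs no quoting: the snippet stays one argument.
as_list = ["python", "-c", command]
print(as_list)

# If a single string is required, shlex.quote escapes the snippet safely.
quoted = f"python -c {shlex.quote(command)}"
print(shlex.split(quoted) == as_list)
```

The list form is the simpler choice when the exec call accepts one, because no escaping is involved at all.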