Skip to content

Commit

Permalink
add test for jaxcontainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephrp committed Aug 5, 2024
1 parent 54f842d commit 8dabef0
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions modules/jax/tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
import pytest
from modules.jax.testcontainers.jax_cuda import JAXContainer

def test_jax_container():
with JAXContainer() as jax_container:
jax_container.connect()

# Test running a simple JAX computation
result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
assert "2" in result.output.decode()
@pytest.fixture(scope="module")
def jax_container():
with JAXContainer() as container:
container.connect()
yield container

def test_jax_container_gpu_support():
with JAXContainer() as jax_container:
jax_container.connect()

# Test GPU availability
result = jax_container.run_jax_command(
"import jax; print(jax.devices())"
)
assert "gpu" in result.output.decode().lower()
def test_jax_container_basic_computation(jax_container):
result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
assert "2" in result.output.decode(), "Basic JAX computation failed"

def test_jax_container_jupyter():
with JAXContainer() as jax_container:
jax_container.connect()

jupyter_url = jax_container.get_jupyter_url()
assert jupyter_url.startswith("http://")
assert ":8888" in jupyter_url
def test_jax_container_version(jax_container):
result = jax_container.run_jax_command("import jax; print(jax.__version__)")
assert result.exit_code == 0, "Failed to get JAX version"
assert result.output.decode().strip(), "JAX version is empty"

def test_jax_container_gpu_support(jax_container):
result = jax_container.run_jax_command(
"import jax; devices = jax.devices(); "
"print(any(dev.platform == 'gpu' for dev in devices))"
)
assert "True" in result.output.decode(), "No GPU device found"

def test_jax_container_matrix_multiplication(jax_container):
command = """
import jax
import jax.numpy as jnp
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])
result = jnp.dot(x, y)
print(result)
"""
result = jax_container.run_jax_command(command)
assert "[[19 22]\n [43 50]]" in result.output.decode(), "Matrix multiplication failed"

def test_jax_container_custom_image():
custom_image = "nvcr.io/nvidia/jax:23.09-py3"
with JAXContainer(image=custom_image) as container:
container.connect()
result = container.run_jax_command("import jax; print(jax.__version__)")
assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}"

0 comments on commit 8dabef0

Please sign in to comment.