
Commit

Don't copy the JAX array when using exec_dynamic.
Signed-off-by: Michal Zientkiewicz <[email protected]>
mzient committed Nov 7, 2024
1 parent 097135a commit 81af254
Showing 8 changed files with 53 additions and 34 deletions.
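The change is exercised below by parametrizing the iterator tests over the exec_dynamic pipeline flag. As a rough usage sketch (not code from this patch; the operator choice, batch size, and the size argument are illustrative assumptions), a pipeline built with exec_dynamic=True lets the JAX iterator hand GPU outputs to JAX without the extra copy:

from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.jax import DALIGenericIterator


@pipeline_def(batch_size=8, num_threads=4, device_id=0, exec_dynamic=True)
def random_pipeline():
    # Any GPU-producing pipeline works; with exec_dynamic=True its outputs
    # can be wrapped by JAX via DLPack instead of being copied.
    data = fn.random.uniform(range=[0.0, 1.0], shape=[16], dtype=types.FLOAT)
    return data.gpu()


it = DALIGenericIterator([random_pipeline()], ["data"], size=32)
for batch in it:
    jax_array = batch["data"]  # jax.Array backed by the same GPU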
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/jax/clu.py
@@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator):
is called internally automatically.
last_batch_policy: optional, default = LastBatchPolicy.FILL
What to do with the last batch when there are not enough samples in the epoch
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
JAX iterator does not support LastBatchPolicy.PARTIAL
last_batch_padded : bool, optional, default = False
Whether the last batch provided by DALI is padded with the last sample
25 changes: 15 additions & 10 deletions dali/python/nvidia/dali/plugin/jax/integration.py
@@ -19,25 +19,30 @@
from packaging.version import Version


-_jax_version_pre_0_4_16 = None
-
-
-def _jax_has_old_dlpack():
-    global _jax_version_pre_0_4_16
-    if _jax_version_pre_0_4_16 is not None:
-        return _jax_version_pre_0_4_16
-
-    from packaging.version import Version
-
-    _jax_version_pre_0_4_16 = Version(jax.__version__) < Version("0.4.16")
-    return _jax_version_pre_0_4_16
+_jax_has_old_dlpack = Version(jax.__version__) < Version("0.4.16")
+
+
+if Version(jax.__version__) >= Version("0.4.26"):
+
+    def _jax_device(jax_array):
+        return jax_array.device
+
+else:
+
+    def _jax_device(jax_array):
+        return jax_array.device()


def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array:
"""Converts input DALI tensor to JAX array.
Args:
dali_tensor (TensorGPU): DALI GPU tensor to be converted to JAX array.
dali_tensor (TensorGPU):
DALI GPU tensor to be converted to JAX array.
copy (bool):
If True, the output is copied;
if False, the output may wrap DLPack capsule obtained from `dali_tensor`.
Note:
This function may perform a copy of the data even if `copy==False` when JAX version is
@@ -50,7 +55,7 @@ def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array:
jax.Array: JAX array with the same values and backing device as
input DALI tensor.
"""
if _jax_has_old_dlpack():
if _jax_has_old_dlpack:
copy = True
jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None))
else:
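A short usage sketch of the conversion helper documented above (the helper is private and its name comes from this diff; dali_tensor stands for any TensorGPU already obtained from a pipeline):

import nvidia.dali.plugin.jax as dax

# Request a zero-copy wrap of the DALI buffer. As the docstring notes, this is
# best-effort: on JAX older than 0.4.16 the data is copied regardless.
jax_array = dax.integration._to_jax_array(dali_tensor, False)

# An explicit copy decouples the JAX array from the lifetime of DALI's buffer.
jax_array_copy = dax.integration._to_jax_array(dali_tensor, True)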
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/jax/iterator.py
@@ -66,7 +66,7 @@ class DALIGenericIterator(_DaliBaseIterator):
is called internally automatically.
last_batch_policy: optional, default = LastBatchPolicy.FILL
What to do with the last batch when there are not enough samples in the epoch
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
JAX iterator does not support LastBatchPolicy.PARTIAL
last_batch_padded : bool, optional, default = False
Whether the last batch provided by DALI is padded with the last sample
8 changes: 5 additions & 3 deletions dali/test/python/jax_plugin/jax_server.py
@@ -50,7 +50,7 @@ def test_lax_workflow(process_id):
array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32), False)

assert (
array_from_dali.device() == jax.local_devices()[0]
dax.integration._jax_device(array_from_dali) == jax.local_devices()[0]
), "Array should be backed by the device local to current process."

sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(array_from_dali)
@@ -77,8 +77,10 @@ def run_distributed_sharing_test(sharding, process_id):
# local part of the data. This buffer should be on the local device.
assert len(dali_sharded_array.device_buffers) == 1
assert dali_sharded_array.device_buffer == jnp.array([process_id])
assert dali_sharded_array.device_buffer.device() == jax.local_devices()[0]
assert dali_sharded_array.device_buffer.device() == jax.devices()[process_id]
assert dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.local_devices()[0]
assert (
dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.devices()[process_id]
)


def test_positional_sharding_workflow(process_id):
4 changes: 2 additions & 2 deletions dali/test/python/jax_plugin/test_integration.py
@@ -41,7 +41,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value):
assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype))

# Make sure JAX array is backed by the GPU
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]


def test_dali_sequential_tensors_to_jax_array():
@@ -59,7 +59,7 @@ def test_dali_sequential_tensors_to_jax_array():
jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)

# then
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]

for i in range(batch_size):
assert jax.numpy.array_equal(
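These assertion updates all follow the same pattern: device placement is queried through the version-agnostic helper instead of calling .device(), which newer JAX releases turned into a property. A minimal illustration (the helper name comes from the diff; the array here is just a stand-in):

import jax
import jax.numpy as jnp
import nvidia.dali.plugin.jax as dax

x = jnp.zeros((2, 2))  # any jax.Array will do

# Equivalent to x.device() on old JAX and x.device on new JAX.
assert dax.integration._jax_device(x) == jax.devices()[0]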
13 changes: 9 additions & 4 deletions dali/test/python/jax_plugin/test_iterator.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +21,12 @@

from utils import iterator_function_def

import nvidia.dali.plugin.jax as dax

[CodeQL notice on the added import: module 'nvidia.dali.plugin.jax' is imported with both 'import' and 'import from'.]
from nvidia.dali.plugin.jax import DALIGenericIterator
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nose_utils import raises
from nose2.tools import params

import itertools

@@ -39,7 +41,7 @@ def run_and_assert_sequential_iterator(iter, num_iters=4):
jax_array = data["data"]

# then
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]

for i in range(batch_size):
assert jax.numpy.array_equal(
@@ -49,9 +51,12 @@ def run_and_assert_sequential_iterator(iter, num_iters=4):
assert batch_id == num_iters - 1


def test_dali_sequential_iterator():
@params((False,), (True,))
def test_dali_sequential_iterator(exec_dynamic):
# given
pipe = pipeline_def(iterator_function_def)(batch_size=batch_size, num_threads=4, device_id=0)
pipe = pipeline_def(iterator_function_def)(
batch_size=batch_size, num_threads=4, device_id=0, exec_dynamic=exec_dynamic
)
iter = DALIGenericIterator([pipe], ["data"], reader_name="reader")

# then
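For reference, the parametrization added above relies on nose2's @params: each tuple becomes a separate test invocation, so the iterator test now runs once with the classic executor and once with exec_dynamic. A condensed restatement (iterator_function_def and batch_size come from the test's utils module; the body is simplified here):

from nose2.tools import params
from nvidia.dali import pipeline_def
from nvidia.dali.plugin.jax import DALIGenericIterator
from utils import iterator_function_def  # test helper, as imported in the file above

batch_size = 4  # illustrative; the real value lives in the test utils


@params((False,), (True,))  # two cases: exec_dynamic=False and exec_dynamic=True
def test_iterator_with_either_executor(exec_dynamic):
    pipe = pipeline_def(iterator_function_def)(
        batch_size=batch_size, num_threads=4, device_id=0, exec_dynamic=exec_dynamic
    )
    it = DALIGenericIterator([pipe], ["data"], reader_name="reader")
    for batch in it:  # one epoch; each batch["data"] is a jax.Array
        pass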
29 changes: 18 additions & 11 deletions dali/test/python/jax_plugin/test_iterator_decorator.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@

from nvidia.dali.plugin.jax import DALIGenericIterator, data_iterator
from test_iterator import run_and_assert_sequential_iterator
from nose2.tools import params

import inspect

@@ -46,25 +47,29 @@ def iterator_function():
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_with_default_args():
@params((False,), (True,))
def test_dali_iterator_decorator_declarative_with_default_args(exec_dynamic):
# given
@data_iterator(output_map=["data"], reader_name="reader")
def iterator_function():
return iterator_function_def()

iter = iterator_function(batch_size=batch_size)
iter = iterator_function(batch_size=batch_size, exec_dynamic=exec_dynamic)

# then
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument():
@params((False,), (True,))
def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument(exec_dynamic):
# given
@data_iterator(output_map=["data"], reader_name="reader")
def iterator_function(num_shards):
return iterator_function_def(num_shards=num_shards)

iter = iterator_function(num_shards=2, num_threads=4, device_id=0, batch_size=batch_size)
iter = iterator_function(
num_shards=2, num_threads=4, device_id=0, batch_size=batch_size, exec_dynamic=exec_dynamic
)

# then
run_and_assert_sequential_iterator(iter)
@@ -91,9 +96,10 @@ def test_iterator_decorator_api_match_iterator_init():
iterator_decorator_args.remove("devices")

# then
assert (
iterator_decorator_args == iterator_init_args
), "Arguments for the iterator decorator and the iterator __init__ method do not match"
assert iterator_decorator_args == iterator_init_args, (
f"Arguments for the iterator decorator and the iterator __init__ method do not match:"
f"\n------\n{iterator_decorator_args}\n-- vs --\n{iterator_init_args}\n------"
)

# Get docs for the decorator "Parameters" section
# Skip the first argument, which differs (pipelines vs. pipeline_fn)
@@ -107,6 +113,7 @@ def test_iterator_decorator_api_match_iterator_init():
iterator_init_docs = iterator_init_docs.split("output_map")[1]
iterator_init_docs = iterator_init_docs.split("sharding")[0]

assert (
iterator_decorator_docs == iterator_init_docs
), "Documentation for the iterator decorator and the iterator __init__ method does not match"
assert iterator_decorator_docs == iterator_init_docs, (
"Documentation for the iterator decorator and the iterator __init__ method does not match:"
f"\n------\n{iterator_decorator_docs}\n-- vs --\n{iterator_init_docs}\n------"
)
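The reworked assertions above now print both sides when they fail. The underlying check compares the two APIs via inspect; a rough sketch of the idea (the argument filtering is simplified relative to the real test):

import inspect

from nvidia.dali.plugin.jax import DALIGenericIterator, data_iterator

decorator_args = list(inspect.signature(data_iterator).parameters)
init_args = list(inspect.signature(DALIGenericIterator.__init__).parameters)

# Drop arguments that legitimately differ between the two APIs
# (the exact exclusions are simplified here).
decorator_args = [a for a in decorator_args if a not in ("pipeline_fn", "devices")]
init_args = [a for a in init_args if a not in ("self", "pipelines")]

assert decorator_args == init_args, (
    "Arguments for the iterator decorator and the iterator __init__ method do not match:"
    f"\n------\n{decorator_args}\n-- vs --\n{init_args}\n------"
)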
4 changes: 2 additions & 2 deletions qa/TL3_JAX_multiprocess/jax_server.py
@@ -85,7 +85,7 @@ def run_distributed_sharing_test(sharding, process_id):
get_dali_tensor_gpu(process_id, (1), np.int32, id), False
)

assert current_shard.device() == device
assert dax.integration._jax_device(current_shard) == device

dali_local_shards.append(current_shard)

@@ -97,7 +97,7 @@ def run_distributed_sharing_test(sharding, process_id):

for id, buffer in enumerate(dali_sharded_array.device_buffers):
assert buffer == jnp.array([process_id])
assert buffer.device() == jax.local_devices()[id]
assert dax.integration._jax_device(buffer) == jax.local_devices()[id]


def test_positional_sharding_workflow(process_id):
