From d10a227c99f3a2c6bba25508bf959bf3db1b5f48 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 6 Nov 2024 16:55:33 +0100 Subject: [PATCH 1/5] Add zero-copy to jax iterator. Use capsule directly (and copy) when JAX is too old. Signed-off-by: Michal Zientkiewicz --- dali/python/nvidia/dali/plugin/jax/integration.py | 15 ++++++++------- dali/python/nvidia/dali/plugin/jax/iterator.py | 5 ++++- dali/test/python/jax_plugin/jax_server.py | 6 +++--- dali/test/python/jax_plugin/test_integration.py | 6 +++--- dali/test/python/jax_plugin/test_multigpu.py | 10 +++++----- dali/test/python/jax_plugin/utils.py | 4 ++-- qa/TL3_JAX_multiprocess/jax_server.py | 4 ++-- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/dali/python/nvidia/dali/plugin/jax/integration.py b/dali/python/nvidia/dali/plugin/jax/integration.py index a29405d4eea..3ac3f644866 100644 --- a/dali/python/nvidia/dali/plugin/jax/integration.py +++ b/dali/python/nvidia/dali/plugin/jax/integration.py @@ -16,6 +16,7 @@ import jax.dlpack from nvidia.dali.backend import TensorGPU +from packaging.version import Version _jax_version_pre_0_4_16 = None @@ -32,15 +33,15 @@ def _jax_has_old_dlpack(): return _jax_version_pre_0_4_16 -def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array: +def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array: """Converts input DALI tensor to JAX array. Args: dali_tensor (TensorGPU): DALI GPU tensor to be converted to JAX array. Note: - This function performs deep copy of the underlying data. That will change in - future releases. + This function may perform a copy of the data even if `copy==False` when JAX version is + insufficient (<0.4.16) Warning: As private this API may change without notice. @@ -50,11 +51,11 @@ def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array: input DALI tensor. """ if _jax_has_old_dlpack(): + copy = True jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None)) else: jax_array = jax.dlpack.from_dlpack(dali_tensor) - # For now we need this copy to make sure that underlying memory is available. - # One solution is to implement full DLPack contract in DALI. - # TODO(awolant): Remove this copy. - return jax_array.copy() + if copy: + jax_array = jax_array.copy() + return jax_array diff --git a/dali/python/nvidia/dali/plugin/jax/iterator.py b/dali/python/nvidia/dali/plugin/jax/iterator.py index 93fc5184276..0dcf9bf22c8 100644 --- a/dali/python/nvidia/dali/plugin/jax/iterator.py +++ b/dali/python/nvidia/dali/plugin/jax/iterator.py @@ -193,7 +193,10 @@ def _gather_outputs_for_category(self, pipelines_outputs, category_id): for pipeline_id in range(self._num_gpus): category_outputs.append( - _to_jax_array(pipelines_outputs[pipeline_id][category_id].as_tensor()) + _to_jax_array( + pipelines_outputs[pipeline_id][category_id].as_tensor(), + not self._pipes[pipeline_id].exec_dynamic, + ) ) return category_outputs diff --git a/dali/test/python/jax_plugin/jax_server.py b/dali/test/python/jax_plugin/jax_server.py index db4694ba8b0..f20b3e18d02 100644 --- a/dali/test/python/jax_plugin/jax_server.py +++ b/dali/test/python/jax_plugin/jax_server.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -47,7 +47,7 @@ def print_devices_details(devices_list, process_id): def test_lax_workflow(process_id): - array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32)) + array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32), False) assert ( array_from_dali.device() == jax.local_devices()[0] @@ -64,7 +64,7 @@ def test_lax_workflow(process_id): def run_distributed_sharing_test(sharding, process_id): dali_local_shard = dax.integration._to_jax_array( - get_dali_tensor_gpu(process_id, (1), np.int32, 0) + get_dali_tensor_gpu(process_id, (1), np.int32, 0), False ) # Note: we pass only one local shard but the array virtually diff --git a/dali/test/python/jax_plugin/test_integration.py b/dali/test/python/jax_plugin/test_integration.py index 8eff162ebb4..1ff999bf38f 100644 --- a/dali/test/python/jax_plugin/test_integration.py +++ b/dali/test/python/jax_plugin/test_integration.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): dali_tensor_gpu = get_dali_tensor_gpu(value=value, shape=shape, dtype=dtype) # when - jax_array = dax.integration._to_jax_array(dali_tensor_gpu) + jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False) # then assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype)) @@ -56,7 +56,7 @@ def test_dali_sequential_tensors_to_jax_array(): dali_tensor_gpu = pipe.run()[0].as_tensor() # when - jax_array = dax.integration._to_jax_array(dali_tensor_gpu) + jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False) # then assert jax_array.device() == jax.devices()[0] diff --git a/dali/test/python/jax_plugin/test_multigpu.py b/dali/test/python/jax_plugin/test_multigpu.py index 4bc1877dd6c..8f6ffe02cfa 100644 --- a/dali/test/python/jax_plugin/test_multigpu.py +++ b/dali/test/python/jax_plugin/test_multigpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -99,8 +99,8 @@ def test_dali_sequential_sharded_tensors_to_jax_sharded_array_manuall(): dali_tensor_gpu_0 = pipe_0.run()[0].as_tensor() dali_tensor_gpu_1 = pipe_1.run()[0].as_tensor() - jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0) - jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1) + jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0, False) + jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1, False) assert jax_shard_0.device() == jax.devices()[0] assert jax_shard_1.device() == jax.devices()[1] @@ -224,8 +224,8 @@ def run_sharding_test(sharding): dali_shard_1 = get_dali_tensor_gpu(1, (1), np.int32, 1) shards = [ - dax.integration._to_jax_array(dali_shard_0), - dax.integration._to_jax_array(dali_shard_1), + dax.integration._to_jax_array(dali_shard_0, False), + dax.integration._to_jax_array(dali_shard_1, False), ] assert shards[0].device() == jax.devices()[0] diff --git a/dali/test/python/jax_plugin/utils.py b/dali/test/python/jax_plugin/utils.py index e2b2a0a08d7..6333f882fd0 100644 --- a/dali/test/python/jax_plugin/utils.py +++ b/dali/test/python/jax_plugin/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU: with provided value. """ - @pipeline_def(num_threads=1, batch_size=1) + @pipeline_def(num_threads=1, batch_size=1, exec_dynamic=True) def dali_pipeline(): values = types.Constant(value=np.full(shape, value, dtype), device="gpu") diff --git a/qa/TL3_JAX_multiprocess/jax_server.py b/qa/TL3_JAX_multiprocess/jax_server.py index 3d834ad25c7..90dbfbe9eba 100644 --- a/qa/TL3_JAX_multiprocess/jax_server.py +++ b/qa/TL3_JAX_multiprocess/jax_server.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -82,7 +82,7 @@ def run_distributed_sharing_test(sharding, process_id): dali_local_shards = [] for id, device in enumerate(jax.local_devices()): current_shard = dax.integration._to_jax_array( - get_dali_tensor_gpu(process_id, (1), np.int32, id) + get_dali_tensor_gpu(process_id, (1), np.int32, id), False ) assert current_shard.device() == device From c5d6d2e9ecb233e90b6f3cc424e338dbf82c9ffc Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Thu, 7 Nov 2024 13:13:50 +0100 Subject: [PATCH 2/5] Don't copy the JAX array when using exec_dynamic. 
Signed-off-by: Michal Zientkiewicz --- dali/python/nvidia/dali/plugin/jax/clu.py | 2 +- .../nvidia/dali/plugin/jax/integration.py | 25 +++++++++------- .../python/nvidia/dali/plugin/jax/iterator.py | 2 +- dali/test/python/jax_plugin/jax_server.py | 8 +++-- .../python/jax_plugin/test_integration.py | 4 +-- dali/test/python/jax_plugin/test_iterator.py | 13 ++++++--- .../jax_plugin/test_iterator_decorator.py | 29 ++++++++++++------- qa/TL3_JAX_multiprocess/jax_server.py | 4 +-- 8 files changed, 53 insertions(+), 34 deletions(-) diff --git a/dali/python/nvidia/dali/plugin/jax/clu.py b/dali/python/nvidia/dali/plugin/jax/clu.py index 1c3a5a73e1c..23808f531b6 100644 --- a/dali/python/nvidia/dali/plugin/jax/clu.py +++ b/dali/python/nvidia/dali/plugin/jax/clu.py @@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator): is called internally automatically. last_batch_policy: optional, default = LastBatchPolicy.FILL What to do with the last batch when there are not enough samples in the epoch - to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy` + to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`. JAX iterator does not support LastBatchPolicy.PARTIAL last_batch_padded : bool, optional, default = False Whether the last batch provided by DALI is padded with the last sample diff --git a/dali/python/nvidia/dali/plugin/jax/integration.py b/dali/python/nvidia/dali/plugin/jax/integration.py index 3ac3f644866..f4e3f003438 100644 --- a/dali/python/nvidia/dali/plugin/jax/integration.py +++ b/dali/python/nvidia/dali/plugin/jax/integration.py @@ -19,25 +19,30 @@ from packaging.version import Version -_jax_version_pre_0_4_16 = None +_jax_has_old_dlpack = Version(jax.__version__) < Version("0.4.16") -def _jax_has_old_dlpack(): - global _jax_version_pre_0_4_16 - if _jax_version_pre_0_4_16 is not None: - return _jax_version_pre_0_4_16 +if Version(jax.__version__) >= Version("0.4.26"): - from packaging.version import Version + def _jax_device(jax_array): + return jax_array.device - _jax_version_pre_0_4_16 = Version(jax.__version__) < Version("0.4.16") - return _jax_version_pre_0_4_16 +else: + + def _jax_device(jax_array): + return jax_array.device() def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array: """Converts input DALI tensor to JAX array. Args: - dali_tensor (TensorGPU): DALI GPU tensor to be converted to JAX array. + dali_tensor (TensorGPU): + DALI GPU tensor to be converted to JAX array. + + copy (bool): + If True, the output is copied; + if False, the output may wrap DLPack capsule obtained from `dali_tensor`. Note: This function may perform a copy of the data even if `copy==False` when JAX version is @@ -50,7 +55,7 @@ def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array: jax.Array: JAX array with the same values and backing device as input DALI tensor. """ - if _jax_has_old_dlpack(): + if _jax_has_old_dlpack: copy = True jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None)) else: diff --git a/dali/python/nvidia/dali/plugin/jax/iterator.py b/dali/python/nvidia/dali/plugin/jax/iterator.py index 0dcf9bf22c8..3d9632d2b94 100644 --- a/dali/python/nvidia/dali/plugin/jax/iterator.py +++ b/dali/python/nvidia/dali/plugin/jax/iterator.py @@ -66,7 +66,7 @@ class DALIGenericIterator(_DaliBaseIterator): is called internally automatically. last_batch_policy: optional, default = LastBatchPolicy.FILL What to do with the last batch when there are not enough samples in the epoch - to fully fill it. 
See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy` + to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`. JAX iterator does not support LastBatchPolicy.PARTIAL last_batch_padded : bool, optional, default = False Whether the last batch provided by DALI is padded with the last sample diff --git a/dali/test/python/jax_plugin/jax_server.py b/dali/test/python/jax_plugin/jax_server.py index f20b3e18d02..d6a3fe07fe2 100644 --- a/dali/test/python/jax_plugin/jax_server.py +++ b/dali/test/python/jax_plugin/jax_server.py @@ -50,7 +50,7 @@ def test_lax_workflow(process_id): array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32), False) assert ( - array_from_dali.device() == jax.local_devices()[0] + dax.integration._jax_device(array_from_dali) == jax.local_devices()[0] ), "Array should be backed by the device local to current process." sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(array_from_dali) @@ -77,8 +77,10 @@ def run_distributed_sharing_test(sharding, process_id): # local part of the data. This buffer should be on the local device. assert len(dali_sharded_array.device_buffers) == 1 assert dali_sharded_array.device_buffer == jnp.array([process_id]) - assert dali_sharded_array.device_buffer.device() == jax.local_devices()[0] - assert dali_sharded_array.device_buffer.device() == jax.devices()[process_id] + assert dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.local_devices()[0] + assert ( + dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.devices()[process_id] + ) def test_positional_sharding_workflow(process_id): diff --git a/dali/test/python/jax_plugin/test_integration.py b/dali/test/python/jax_plugin/test_integration.py index 1ff999bf38f..0af217d1963 100644 --- a/dali/test/python/jax_plugin/test_integration.py +++ b/dali/test/python/jax_plugin/test_integration.py @@ -41,7 +41,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype)) # Make sure JAX array is backed by the GPU - assert jax_array.device() == jax.devices()[0] + assert dax.integration._jax_device(jax_array) == jax.devices()[0] def test_dali_sequential_tensors_to_jax_array(): @@ -59,7 +59,7 @@ def test_dali_sequential_tensors_to_jax_array(): jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False) # then - assert jax_array.device() == jax.devices()[0] + assert dax.integration._jax_device(jax_array) == jax.devices()[0] for i in range(batch_size): assert jax.numpy.array_equal( diff --git a/dali/test/python/jax_plugin/test_iterator.py b/dali/test/python/jax_plugin/test_iterator.py index 1fe30622352..3d6b671255c 100644 --- a/dali/test/python/jax_plugin/test_iterator.py +++ b/dali/test/python/jax_plugin/test_iterator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,10 +21,12 @@ from utils import iterator_function_def +import nvidia.dali.plugin.jax as dax from nvidia.dali.plugin.jax import DALIGenericIterator from nvidia.dali.pipeline import pipeline_def from nvidia.dali.plugin.base_iterator import LastBatchPolicy from nose_utils import raises +from nose2.tools import params import itertools @@ -39,7 +41,7 @@ def run_and_assert_sequential_iterator(iter, num_iters=4): jax_array = data["data"] # then - assert jax_array.device() == jax.devices()[0] + assert dax.integration._jax_device(jax_array) == jax.devices()[0] for i in range(batch_size): assert jax.numpy.array_equal( @@ -49,9 +51,12 @@ def run_and_assert_sequential_iterator(iter, num_iters=4): assert batch_id == num_iters - 1 -def test_dali_sequential_iterator(): +@params((False,), (True,)) +def test_dali_sequential_iterator(exec_dynamic): # given - pipe = pipeline_def(iterator_function_def)(batch_size=batch_size, num_threads=4, device_id=0) + pipe = pipeline_def(iterator_function_def)( + batch_size=batch_size, num_threads=4, device_id=0, exec_dynamic=exec_dynamic + ) iter = DALIGenericIterator([pipe], ["data"], reader_name="reader") # then diff --git a/dali/test/python/jax_plugin/test_iterator_decorator.py b/dali/test/python/jax_plugin/test_iterator_decorator.py index 98a4c3c1124..a6f2cf46b17 100644 --- a/dali/test/python/jax_plugin/test_iterator_decorator.py +++ b/dali/test/python/jax_plugin/test_iterator_decorator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ from nvidia.dali.plugin.jax import DALIGenericIterator, data_iterator from test_iterator import run_and_assert_sequential_iterator +from nose2.tools import params import inspect @@ -46,25 +47,29 @@ def iterator_function(): run_and_assert_sequential_iterator(iter) -def test_dali_iterator_decorator_declarative_with_default_args(): +@params((False,), (True,)) +def test_dali_iterator_decorator_declarative_with_default_args(exec_dynamic): # given @data_iterator(output_map=["data"], reader_name="reader") def iterator_function(): return iterator_function_def() - iter = iterator_function(batch_size=batch_size) + iter = iterator_function(batch_size=batch_size, exec_dynamic=exec_dynamic) # then run_and_assert_sequential_iterator(iter) -def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument(): +@params((False,), (True,)) +def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument(exec_dynamic): # given @data_iterator(output_map=["data"], reader_name="reader") def iterator_function(num_shards): return iterator_function_def(num_shards=num_shards) - iter = iterator_function(num_shards=2, num_threads=4, device_id=0, batch_size=batch_size) + iter = iterator_function( + num_shards=2, num_threads=4, device_id=0, batch_size=batch_size, exec_dynamic=exec_dynamic + ) # then run_and_assert_sequential_iterator(iter) @@ -91,9 +96,10 @@ def test_iterator_decorator_api_match_iterator_init(): iterator_decorator_args.remove("devices") # then - assert ( - iterator_decorator_args == iterator_init_args - ), "Arguments for the iterator decorator and the iterator __init__ method do not match" + assert iterator_decorator_args == iterator_init_args, ( + f"Arguments for the iterator decorator and the iterator __init__ method do not match:" + 
f"\n------\n{iterator_decorator_args}\n-- vs --\n{iterator_init_args}\n------" + ) # Get docs for the decorator "Parameters" section # Skip the first argument, which differs (pipelines vs. pipeline_fn) @@ -107,6 +113,7 @@ def test_iterator_decorator_api_match_iterator_init(): iterator_init_docs = iterator_init_docs.split("output_map")[1] iterator_init_docs = iterator_init_docs.split("sharding")[0] - assert ( - iterator_decorator_docs == iterator_init_docs - ), "Documentation for the iterator decorator and the iterator __init__ method does not match" + assert iterator_decorator_docs == iterator_init_docs, ( + "Documentation for the iterator decorator and the iterator __init__ method does not match:" + f"\n------\n{iterator_decorator_docs}\n-- vs --\n{iterator_init_docs}\n------" + ) diff --git a/qa/TL3_JAX_multiprocess/jax_server.py b/qa/TL3_JAX_multiprocess/jax_server.py index 90dbfbe9eba..c6a0f2028bc 100644 --- a/qa/TL3_JAX_multiprocess/jax_server.py +++ b/qa/TL3_JAX_multiprocess/jax_server.py @@ -85,7 +85,7 @@ def run_distributed_sharing_test(sharding, process_id): get_dali_tensor_gpu(process_id, (1), np.int32, id), False ) - assert current_shard.device() == device + assert dax.integration._jax_device(current_shard) == device dali_local_shards.append(current_shard) @@ -97,7 +97,7 @@ def run_distributed_sharing_test(sharding, process_id): for id, buffer in enumerate(dali_sharded_array.device_buffers): assert buffer == jnp.array([process_id]) - assert buffer.device() == jax.local_devices()[id] + assert dax.integration._jax_device(buffer) == jax.local_devices()[id] def test_positional_sharding_workflow(process_id): From 050a4aadf2d8c92cce8a46114bbdfd918676e9ed Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 13 Nov 2024 14:49:58 +0100 Subject: [PATCH 3/5] Adjust to new jax. Signed-off-by: Michal Zientkiewicz --- dali/test/python/jax_plugin/jax_server.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/dali/test/python/jax_plugin/jax_server.py b/dali/test/python/jax_plugin/jax_server.py index d6a3fe07fe2..2e64fee983b 100644 --- a/dali/test/python/jax_plugin/jax_server.py +++ b/dali/test/python/jax_plugin/jax_server.py @@ -73,13 +73,19 @@ def run_distributed_sharing_test(sharding, process_id): shape=(2,), sharding=sharding, arrays=[dali_local_shard] ) - # This array should be backed only by one device buffer that holds - # local part of the data. This buffer should be on the local device. - assert len(dali_sharded_array.device_buffers) == 1 - assert dali_sharded_array.device_buffer == jnp.array([process_id]) - assert dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.local_devices()[0] + # device_buffers has been removed + if hasattr(dali_sharded_array, "device_buffers"): + # This array should be backed only by one device buffer that holds + # local part of the data. This buffer should be on the local device. 
+        assert len(dali_sharded_array.device_buffers) == 1
+    assert dali_sharded_array.addressable_data(0) == jnp.array([process_id])
     assert (
-        dax.integration._jax_device(dali_sharded_array.device_buffer) == jax.devices()[process_id]
+        dax.integration._jax_device(dali_sharded_array.addressable_data(0))
+        == jax.local_devices()[0]
+    )
+    assert (
+        dax.integration._jax_device(dali_sharded_array.addressable_data(0))
+        == jax.devices()[process_id]
     )


From b3cc6aab5554f9dd0c05f97c99a2b7b4b0142118 Mon Sep 17 00:00:00 2001
From: Michal Zientkiewicz
Date: Thu, 14 Nov 2024 16:22:09 +0100
Subject: [PATCH 4/5] Fix JAX array device handling for various versions of
 JAX.

Signed-off-by: Michal Zientkiewicz
---
 dali/python/nvidia/dali/plugin/jax/integration.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/dali/python/nvidia/dali/plugin/jax/integration.py b/dali/python/nvidia/dali/plugin/jax/integration.py
index f4e3f003438..0750a19747b 100644
--- a/dali/python/nvidia/dali/plugin/jax/integration.py
+++ b/dali/python/nvidia/dali/plugin/jax/integration.py
@@ -22,11 +22,20 @@

 _jax_has_old_dlpack = Version(jax.__version__) < Version("0.4.16")

-if Version(jax.__version__) >= Version("0.4.26"):
+if Version(jax.__version__) >= Version("0.4.31"):

     def _jax_device(jax_array):
         return jax_array.device

+elif Version(jax.__version__) >= Version("0.4.27"):
+
+    def _jax_device(jax_array):
+        devs = jax_array.devices()
+        if len(devs) != 1:
+            raise RuntimeError("The must be associated with exactly one device")
+        for d in devs:
+            return d
+
 else:

     def _jax_device(jax_array):

From 80a6b429bb5f70f2182d7ab1d60edd8266295514 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?=
Date: Fri, 15 Nov 2024 14:59:07 +0100
Subject: [PATCH 5/5] Fix error message.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Michał Zientkiewicz
---
 dali/python/nvidia/dali/plugin/jax/integration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dali/python/nvidia/dali/plugin/jax/integration.py b/dali/python/nvidia/dali/plugin/jax/integration.py
index 0750a19747b..72f537b96b6 100644
--- a/dali/python/nvidia/dali/plugin/jax/integration.py
+++ b/dali/python/nvidia/dali/plugin/jax/integration.py
@@ -32,7 +32,7 @@ def _jax_device(jax_array):
     def _jax_device(jax_array):
         devs = jax_array.devices()
         if len(devs) != 1:
-            raise RuntimeError("The must be associated with exactly one device")
+            raise RuntimeError("The array must be associated with exactly one device")
         for d in devs:
             return d
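
Illustration (not part of the patch series): the sketch below shows how the zero-copy path introduced above is meant to be exercised. A pipeline built with exec_dynamic=True can hand its GPU output to JAX via DLPack without the defensive copy, while copy=True (or JAX older than 0.4.16) still forces one; the iterator does this automatically by passing `not pipe.exec_dynamic` as the copy flag. The snippet mirrors the test helpers in the diffs; the pipeline and its name are made up for the example, and a CUDA-capable machine with DALI and JAX >= 0.4.16 is assumed.

# Illustrative sketch only -- not part of the patches above. Assumes a CUDA-capable
# machine with DALI and JAX >= 0.4.16; constant_pipeline is a made-up example pipeline.
import numpy as np
import jax
import nvidia.dali.plugin.jax as dax
import nvidia.dali.types as types
from nvidia.dali.pipeline import pipeline_def


@pipeline_def(batch_size=1, num_threads=1, device_id=0, exec_dynamic=True)
def constant_pipeline():
    # A GPU constant; with exec_dynamic=True the resulting buffer can be shared
    # with JAX through DLPack instead of being copied.
    return types.Constant(value=np.full((2, 4), 42, np.int32), device="gpu")


pipe = constant_pipeline()
pipe.build()
dali_tensor_gpu = pipe.run()[0].as_tensor()

# copy=False requests zero-copy; _to_jax_array still copies on JAX < 0.4.16.
jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)

# _jax_device hides the device() -> device -> devices() API differences across JAX versions.
assert dax.integration._jax_device(jax_array) == jax.devices()[0]
assert jax.numpy.array_equal(jax_array, jax.numpy.full((2, 4), 42, np.int32))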