JAX zero copy #5703

Merged: 5 commits on Nov 15, 2024
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/jax/clu.py
@@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator):
is called internally automatically.
last_batch_policy: optional, default = LastBatchPolicy.FILL
What to do with the last batch when there are not enough samples in the epoch
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
Comment from Contributor Author:
There's a test that compares docstrings and it failed before. Apparently it's not run in CI (?).

JAX iterator does not support LastBatchPolicy.PARTIAL
last_batch_padded : bool, optional, default = False
Whether the last batch provided by DALI is padded with the last sample
49 changes: 32 additions & 17 deletions dali/python/nvidia/dali/plugin/jax/integration.py
@@ -16,31 +16,46 @@
 import jax.dlpack
 
 from nvidia.dali.backend import TensorGPU
+from packaging.version import Version
 
 
-_jax_version_pre_0_4_16 = None
+_jax_has_old_dlpack = Version(jax.__version__) < Version("0.4.16")
 
 
-def _jax_has_old_dlpack():
-    global _jax_version_pre_0_4_16
-    if _jax_version_pre_0_4_16 is not None:
-        return _jax_version_pre_0_4_16
-
-    from packaging.version import Version
-
-    _jax_version_pre_0_4_16 = Version(jax.__version__) < Version("0.4.16")
-    return _jax_version_pre_0_4_16
+if Version(jax.__version__) >= Version("0.4.31"):
+
+    def _jax_device(jax_array):
+        return jax_array.device

Comment from Member (on the `jax_array.device` property):
I couldn't find the documentation for old JAX releases, but looking at the code, the method version used to raise an error for multi-dev array, while this property (according to the docs) returns sharding. So this call works fine with the multi-dev arrays while the other two variants don't.

Comment from Contributor Author:
That's kind of the problem with jax: it keeps changing and the documentation is hard to find.

+elif Version(jax.__version__) >= Version("0.4.27"):
+
+    def _jax_device(jax_array):
+        devs = jax_array.devices()
+        if len(devs) != 1:
+            raise RuntimeError("The array must be associated with exactly one device")
+        for d in devs:
+            return d

Check notice (Code scanning / CodeQL) on `_jax_device`:
Explicit returns mixed with implicit (fall through) returns. Mixing implicit and explicit returns may indicate an error as implicit returns always return None.

+
+else:
+
+    def _jax_device(jax_array):
+        return jax_array.device()
 
 
-def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array:
+def _to_jax_array(dali_tensor: TensorGPU, copy: bool) -> jax.Array:
     """Converts input DALI tensor to JAX array.
 
     Args:
-        dali_tensor (TensorGPU): DALI GPU tensor to be converted to JAX array.
+        dali_tensor (TensorGPU):
+            DALI GPU tensor to be converted to JAX array.
+
+        copy (bool):
+            If True, the output is copied;
+            if False, the output may wrap DLPack capsule obtained from `dali_tensor`.
 
     Note:
-        This function performs deep copy of the underlying data. That will change in
-        future releases.
+        This function may perform a copy of the data even if `copy==False` when JAX version is
+        insufficient (<0.4.16)
 
     Warning:
         As private this API may change without notice.

@@ -49,12 +64,12 @@

         jax.Array: JAX array with the same values and backing device as
         input DALI tensor.
     """
-    if _jax_has_old_dlpack():
+    if _jax_has_old_dlpack:
+        copy = True
         jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None))
     else:
         jax_array = jax.dlpack.from_dlpack(dali_tensor)
 
-    # For now we need this copy to make sure that underlying memory is available.
-    # One solution is to implement full DLPack contract in DALI.
-    # TODO(awolant): Remove this copy.
-    return jax_array.copy()
+    if copy:
+        jax_array = jax_array.copy()
+    return jax_array
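
For orientation, the pieces above fit together roughly as follows. This is a minimal sketch rather than code from this PR: it assumes a recent JAX (>= 0.4.16, so the direct DLPack path is taken), and the toy pipeline is hypothetical.

# Sketch only; the constant pipeline below is a stand-in, not part of the diff.
import numpy as np
import jax
import nvidia.dali.types as types
import nvidia.dali.plugin.jax as dax
from nvidia.dali.pipeline import pipeline_def


@pipeline_def(batch_size=1, num_threads=1, device_id=0, exec_dynamic=True)
def constant_pipeline():
    # A single constant GPU sample, enough to obtain a TensorGPU.
    return types.Constant(value=np.full((2, 2), 7, np.int32), device="gpu")


pipe = constant_pipeline()
pipe.build()
dali_tensor = pipe.run()[0].as_tensor()

# copy=False: the returned jax.Array may wrap the DLPack capsule exported by
# DALI; on JAX < 0.4.16 the helper still falls back to copying.
jax_array = dax.integration._to_jax_array(dali_tensor, False)

# _jax_device hides the .device / .devices() / .device() differences
# between JAX versions.
assert dax.integration._jax_device(jax_array) == jax.devices()[0]

With copy=False and the dynamic executor, the resulting array may alias DALI's output buffer, which is what gives the PR its zero-copy behaviour.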
7 changes: 5 additions & 2 deletions dali/python/nvidia/dali/plugin/jax/iterator.py
@@ -66,7 +66,7 @@ class DALIGenericIterator(_DaliBaseIterator):
is called internally automatically.
last_batch_policy: optional, default = LastBatchPolicy.FILL
What to do with the last batch when there are not enough samples in the epoch
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
JAX iterator does not support LastBatchPolicy.PARTIAL
last_batch_padded : bool, optional, default = False
Whether the last batch provided by DALI is padded with the last sample
@@ -193,7 +193,10 @@ def _gather_outputs_for_category(self, pipelines_outputs, category_id):

for pipeline_id in range(self._num_gpus):
category_outputs.append(
_to_jax_array(pipelines_outputs[pipeline_id][category_id].as_tensor())
_to_jax_array(
pipelines_outputs[pipeline_id][category_id].as_tensor(),
not self._pipes[pipeline_id].exec_dynamic,
)
)

return category_outputs
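
The copy flag passed above is `not exec_dynamic`: when a pipeline runs with the dynamic executor, the iterator hands its outputs to JAX without the extra copy; otherwise it keeps the previous copying behaviour. A minimal usage sketch, not part of the diff (the toy pipeline and the size value are hypothetical):

# Sketch only: the zero-copy path is taken when the pipeline uses
# exec_dynamic=True and JAX is >= 0.4.16; otherwise the iterator copies.
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.jax import DALIGenericIterator


@pipeline_def(batch_size=8, num_threads=2, device_id=0, exec_dynamic=True)
def random_pipeline():
    # Hypothetical toy pipeline producing GPU data.
    return fn.random.uniform(range=[0.0, 1.0], shape=(3,), device="gpu")


it = DALIGenericIterator([random_pipeline()], ["data"], size=64)
batch = next(it)  # batch["data"] is a jax.Array backed by DALI's GPU output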
28 changes: 18 additions & 10 deletions dali/test/python/jax_plugin/jax_server.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,10 +47,10 @@ def print_devices_details(devices_list, process_id):


def test_lax_workflow(process_id):
array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32))
array_from_dali = dax.integration._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32), False)

assert (
array_from_dali.device() == jax.local_devices()[0]
dax.integration._jax_device(array_from_dali) == jax.local_devices()[0]
), "Array should be backed by the device local to current process."

sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(array_from_dali)
@@ -64,7 +64,7 @@ def test_lax_workflow(process_id):

def run_distributed_sharing_test(sharding, process_id):
dali_local_shard = dax.integration._to_jax_array(
get_dali_tensor_gpu(process_id, (1), np.int32, 0)
get_dali_tensor_gpu(process_id, (1), np.int32, 0), False
)

# Note: we pass only one local shard but the array virtually
@@ -73,12 +73,20 @@ def run_distributed_sharing_test(sharding, process_id):
shape=(2,), sharding=sharding, arrays=[dali_local_shard]
)

# This array should be backed only by one device buffer that holds
# local part of the data. This buffer should be on the local device.
assert len(dali_sharded_array.device_buffers) == 1
assert dali_sharded_array.device_buffer == jnp.array([process_id])
assert dali_sharded_array.device_buffer.device() == jax.local_devices()[0]
assert dali_sharded_array.device_buffer.device() == jax.devices()[process_id]
# device_buffers has been removed
if hasattr(dali_sharded_array, "device_buffers"):
# This array should be backed only by one device buffer that holds
# local part of the data. This buffer should be on the local device.
assert len(dali_sharded_array.device_buffers) == 1
assert dali_sharded_array.addressable_data(0) == jnp.array([process_id])
assert (
dax.integration._jax_device(dali_sharded_array.addressable_data(0))
== jax.local_devices()[0]
)
assert (
dax.integration._jax_device(dali_sharded_array.addressable_data(0))
== jax.devices()[process_id]
)
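
For context on the change above: `device_buffers` was removed from `jax.Array` in newer JAX releases, while `addressable_data(i)` returns the i-th shard owned by the current process in both the old and new versions exercised here. The same checks could be wrapped in a small helper; a minimal sketch (the helper is hypothetical, not part of the diff):

# Sketch only: `sharded` is assumed to be a jax.Array assembled with
# jax.make_array_from_single_device_arrays, as in the test above.
import jax
import nvidia.dali.plugin.jax as dax


def check_local_shard(sharded, expected_value, process_id):
    shard = sharded.addressable_data(0)  # local shard, portable across JAX versions
    assert jax.numpy.array_equal(shard, jax.numpy.array([expected_value]))
    assert dax.integration._jax_device(shard) == jax.local_devices()[0]
    assert dax.integration._jax_device(shard) == jax.devices()[process_id]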


def test_positional_sharding_workflow(process_id):
10 changes: 5 additions & 5 deletions dali/test/python/jax_plugin/test_integration.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,13 +35,13 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value):
dali_tensor_gpu = get_dali_tensor_gpu(value=value, shape=shape, dtype=dtype)

# when
jax_array = dax.integration._to_jax_array(dali_tensor_gpu)
jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)

# then
assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype))

# Make sure JAX array is backed by the GPU
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]


def test_dali_sequential_tensors_to_jax_array():
@@ -56,10 +56,10 @@ def test_dali_sequential_tensors_to_jax_array():
dali_tensor_gpu = pipe.run()[0].as_tensor()

# when
jax_array = dax.integration._to_jax_array(dali_tensor_gpu)
jax_array = dax.integration._to_jax_array(dali_tensor_gpu, False)

# then
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]

for i in range(batch_size):
assert jax.numpy.array_equal(
13 changes: 9 additions & 4 deletions dali/test/python/jax_plugin/test_iterator.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +21,12 @@

from utils import iterator_function_def

import nvidia.dali.plugin.jax as dax

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'.
Module 'nvidia.dali.plugin.jax' is imported with both 'import' and 'import from'.
from nvidia.dali.plugin.jax import DALIGenericIterator
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nose_utils import raises
from nose2.tools import params

import itertools

@@ -39,7 +41,7 @@
jax_array = data["data"]

# then
assert jax_array.device() == jax.devices()[0]
assert dax.integration._jax_device(jax_array) == jax.devices()[0]

for i in range(batch_size):
assert jax.numpy.array_equal(
@@ -49,9 +51,12 @@
assert batch_id == num_iters - 1


def test_dali_sequential_iterator():
@params((False,), (True,))
def test_dali_sequential_iterator(exec_dynamic):
# given
pipe = pipeline_def(iterator_function_def)(batch_size=batch_size, num_threads=4, device_id=0)
pipe = pipeline_def(iterator_function_def)(
batch_size=batch_size, num_threads=4, device_id=0, exec_dynamic=exec_dynamic
)
iter = DALIGenericIterator([pipe], ["data"], reader_name="reader")

# then
29 changes: 18 additions & 11 deletions dali/test/python/jax_plugin/test_iterator_decorator.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@

from nvidia.dali.plugin.jax import DALIGenericIterator, data_iterator
from test_iterator import run_and_assert_sequential_iterator
from nose2.tools import params

import inspect

@@ -46,25 +47,29 @@ def iterator_function():
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_with_default_args():
@params((False,), (True,))
def test_dali_iterator_decorator_declarative_with_default_args(exec_dynamic):
# given
@data_iterator(output_map=["data"], reader_name="reader")
def iterator_function():
return iterator_function_def()

iter = iterator_function(batch_size=batch_size)
iter = iterator_function(batch_size=batch_size, exec_dynamic=exec_dynamic)

# then
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument():
@params((False,), (True,))
def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument(exec_dynamic):
# given
@data_iterator(output_map=["data"], reader_name="reader")
def iterator_function(num_shards):
return iterator_function_def(num_shards=num_shards)

iter = iterator_function(num_shards=2, num_threads=4, device_id=0, batch_size=batch_size)
iter = iterator_function(
num_shards=2, num_threads=4, device_id=0, batch_size=batch_size, exec_dynamic=exec_dynamic
)

# then
run_and_assert_sequential_iterator(iter)
@@ -91,9 +96,10 @@ def test_iterator_decorator_api_match_iterator_init():
iterator_decorator_args.remove("devices")

# then
assert (
iterator_decorator_args == iterator_init_args
), "Arguments for the iterator decorator and the iterator __init__ method do not match"
assert iterator_decorator_args == iterator_init_args, (
f"Arguments for the iterator decorator and the iterator __init__ method do not match:"
f"\n------\n{iterator_decorator_args}\n-- vs --\n{iterator_init_args}\n------"
)

# Get docs for the decorator "Parameters" section
# Skip the first argument, which differs (pipelines vs. pipeline_fn)
@@ -107,6 +113,7 @@ def test_iterator_decorator_api_match_iterator_init():
iterator_init_docs = iterator_init_docs.split("output_map")[1]
iterator_init_docs = iterator_init_docs.split("sharding")[0]

assert (
iterator_decorator_docs == iterator_init_docs
), "Documentation for the iterator decorator and the iterator __init__ method does not match"
assert iterator_decorator_docs == iterator_init_docs, (
"Documentation for the iterator decorator and the iterator __init__ method does not match:"
f"\n------\n{iterator_decorator_docs}\n-- vs --\n{iterator_init_docs}\n------"
)
10 changes: 5 additions & 5 deletions dali/test/python/jax_plugin/test_multigpu.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -99,8 +99,8 @@ def test_dali_sequential_sharded_tensors_to_jax_sharded_array_manuall():
dali_tensor_gpu_0 = pipe_0.run()[0].as_tensor()
dali_tensor_gpu_1 = pipe_1.run()[0].as_tensor()

jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0)
jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1)
jax_shard_0 = dax.integration._to_jax_array(dali_tensor_gpu_0, False)
jax_shard_1 = dax.integration._to_jax_array(dali_tensor_gpu_1, False)

assert jax_shard_0.device() == jax.devices()[0]
assert jax_shard_1.device() == jax.devices()[1]
@@ -224,8 +224,8 @@ def run_sharding_test(sharding):
dali_shard_1 = get_dali_tensor_gpu(1, (1), np.int32, 1)

shards = [
dax.integration._to_jax_array(dali_shard_0),
dax.integration._to_jax_array(dali_shard_1),
dax.integration._to_jax_array(dali_shard_0, False),
dax.integration._to_jax_array(dali_shard_1, False),
]

assert shards[0].device() == jax.devices()[0]
4 changes: 2 additions & 2 deletions dali/test/python/jax_plugin/utils.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU:
with provided value.
"""

@pipeline_def(num_threads=1, batch_size=1)
@pipeline_def(num_threads=1, batch_size=1, exec_dynamic=True)
def dali_pipeline():
values = types.Constant(value=np.full(shape, value, dtype), device="gpu")

8 changes: 4 additions & 4 deletions qa/TL3_JAX_multiprocess/jax_server.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,10 +82,10 @@ def run_distributed_sharing_test(sharding, process_id):
dali_local_shards = []
for id, device in enumerate(jax.local_devices()):
current_shard = dax.integration._to_jax_array(
get_dali_tensor_gpu(process_id, (1), np.int32, id)
get_dali_tensor_gpu(process_id, (1), np.int32, id), False
)

assert current_shard.device() == device
assert dax.integration._jax_device(current_shard) == device

dali_local_shards.append(current_shard)

@@ -97,7 +97,7 @@ def run_distributed_sharing_test(sharding, process_id):

for id, buffer in enumerate(dali_sharded_array.device_buffers):
assert buffer == jnp.array([process_id])
assert buffer.device() == jax.local_devices()[id]
assert dax.integration._jax_device(buffer) == jax.local_devices()[id]


def test_positional_sharding_workflow(process_id):