JAX zero copy #5703

Merged
merged 5 commits into from
Nov 15, 2024

Conversation

mzient
Contributor

@mzient mzient commented Nov 7, 2024

Category:

New feature (non-breaking change which adds functionality)

Description:

This PR omits the copy when the new executor is used with JAX >= 0.4.16.
When JAX is older or the pipeline doesn't use the dynamic_executor, the copy is still performed.
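
The condition described above can be sketched as a small gate. This is only an illustration under stated assumptions: `can_skip_copy`, `parse_release` and `MIN_ZERO_COPY_JAX` are hypothetical names that do not come from the PR, and the plain tuple comparison stands in for the `packaging.version.Version` import that appears in the diff, purely to keep the sketch dependency-free.

```python
import re

# Minimum JAX release for which the PR description says the copy is omitted.
MIN_ZERO_COPY_JAX = (0, 4, 16)


def parse_release(version: str) -> tuple:
    """Parse the leading numeric components, e.g. '0.4.16' -> (0, 4, 16)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def can_skip_copy(jax_version: str, uses_dynamic_executor: bool) -> bool:
    """Hypothetical helper: True only when both stated conditions hold."""
    return uses_dynamic_executor and parse_release(jax_version) >= MIN_ZERO_COPY_JAX
```

Comparing parsed components rather than raw strings matters here: as strings, "0.4.9" sorts after "0.4.16", whereas numerically 0.4.9 is the older release and should still take the copying path.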

Additional information:

Affected modules and functionalities:

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-4117


@@ -21,10 +21,12 @@

from utils import iterator_function_def

import nvidia.dali.plugin.jax as dax

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note test

Module 'nvidia.dali.plugin.jax' is imported with both 'import' and 'import from'.
Module 'plugin.jax' is imported with both 'import' and 'import from'.
@dali-automaton
Collaborator

CI MESSAGE: [20171555]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [20171555]: BUILD FAILED

@mzient mzient marked this pull request as ready for review November 8, 2024 15:27
@dali-automaton
Collaborator

CI MESSAGE: [20208927]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [20208927]: BUILD FAILED

@mzient mzient force-pushed the jax_zero_copy branch 2 times, most recently from 218732d to e7029b3 Compare November 13, 2024 13:50
@mzient
Contributor Author

mzient commented Nov 14, 2024

!build

@dali-automaton
Collaborator

CI MESSAGE: [20407117]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [20407117]: BUILD PASSED

@@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator):
is called internally automatically.
last_batch_policy: optional, default = LastBatchPolicy.FILL
What to do with the last batch when there are not enough samples in the epoch
- to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
+ to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`.
Contributor Author

There's a test that compares docstrings and it failed before. Apparently it's not run in CI (?).

@dali-automaton
Collaborator

CI MESSAGE: [20432462]: BUILD STARTED

@NVIDIA NVIDIA deleted a comment from dali-automaton Nov 15, 2024
def _jax_device(jax_array):
    devs = jax_array.devices()
    if len(devs) != 1:
        raise RuntimeError("The must be associated with exactly one device")
Collaborator

The what?

Contributor Author

The array... I'll re-run the build stage only.

Signed-off-by: Michał Zientkiewicz <[email protected]>
@dali-automaton
Collaborator

CI MESSAGE: [20436842]: BUILD STARTED


def _jax_device(jax_array):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.

from packaging.version import Version
def _jax_device(jax_array):
    return jax_array.device
Member

I couldn't find the documentation for old JAX releases, but judging by the code, the method version used to raise an error for a multi-device array, while this property (according to the docs) returns a sharding. So this call works fine with multi-device arrays, while the other two variants don't.

Contributor Author

That's kind of the problem with JAX: it keeps changing and the documentation is hard to find.

@mzient mzient merged commit 2d07f81 into NVIDIA:main Nov 15, 2024
5 of 6 checks passed