-
Notifications
You must be signed in to change notification settings - Fork 630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JAX zero copy #5703
JAX zero copy #5703
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@@ -21,10 +21,12 @@ | |||
|
|||
from utils import iterator_function_def | |||
|
|||
import nvidia.dali.plugin.jax as dax |
Check notice
Code scanning / CodeQL
Module is imported with 'import' and 'import from' Note test
Module 'plugin.jax' is imported with both 'import' and 'import from'.
CI MESSAGE: [20171555]: BUILD STARTED |
CI MESSAGE: [20171555]: BUILD FAILED |
CI MESSAGE: [20208927]: BUILD STARTED |
CI MESSAGE: [20208927]: BUILD FAILED |
218732d
to
e7029b3
Compare
…AX is too old. Signed-off-by: Michal Zientkiewicz <[email protected]>
Signed-off-by: Michal Zientkiewicz <[email protected]>
Signed-off-by: Michal Zientkiewicz <[email protected]>
Signed-off-by: Michal Zientkiewicz <[email protected]>
8cd7d0f
to
b3cc6aa
Compare
!build |
CI MESSAGE: [20407117]: BUILD STARTED |
CI MESSAGE: [20407117]: BUILD PASSED |
@@ -82,7 +82,7 @@ class DALIGenericPeekableIterator(DALIGenericIterator): | |||
is called internally automatically. | |||
last_batch_policy: optional, default = LastBatchPolicy.FILL | |||
What to do with the last batch when there are not enough samples in the epoch | |||
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy` | |||
to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a test that compares docstrings and it failed before. Apparently it's not run in CI (?).
CI MESSAGE: [20432462]: BUILD STARTED |
def _jax_device(jax_array): | ||
devs = jax_array.devices() | ||
if len(devs) != 1: | ||
raise RuntimeError("The must be associated with exactly one device") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The what?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The array.... I'll re-run build stage only.
Signed-off-by: Michał Zientkiewicz <[email protected]>
CI MESSAGE: [20436842]: BUILD STARTED |
|
||
from packaging.version import Version | ||
def _jax_device(jax_array): | ||
return jax_array.device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't find the documentation for old JAX releases, but looking at the code, the method version used to raise an error for multi-dev array, while this property (according to the docs) returns sharding. So this call works fine with the multi-dev arrays while the other two variants don't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's kind of the problem with jax: it keeps changing and the documentation is hard to find.
Category:
New feature (non-breaking change which adds functionality)
Description:
This PR omits the copy when the new executor is used with JAX >= 0.4.16.
When JAX is older or the pipeline doesn't use the dynamic_executor, the copy is still performed.
Additional information:
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-4117