-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Support for Array Api namespace as allocator #1771
base: main
Are you sure you want to change the base?
Conversation
src/gt4py/next/allocators.py
Outdated
allocator: FieldBufferAllocationUtil = actual_allocator, | ||
device: core_defs.Device = device, | ||
) -> core_defs.NDArrayObject: | ||
# TODO check how to get from FieldBufferAllocationUtil to FieldBufferAllocatorProtocol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- think about long names -> maybe rename to
NDArray...
__copy__
,__deepcopy__
as_ndarray(allocator: ConcreteAllocator, copy: Optional[bool])
- can
TensorBuffer
be removed?
|
self.ndarray, | ||
domain=self.domain, | ||
dtype=self.dtype, | ||
copy=True, # aligned_index??? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually this is a missing piece in the current implementation. We should probably provide this as a default in all functions.
def get_array_allocation_namespace( | ||
allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace], | ||
device: Optional[core_defs.Device] = None, | ||
) -> GTArrayAllocationNamespace: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe put aligned_index here and add it as a default to the construction functions.
Same for device.
Allow Array API namespaces as allocators in
gtx.constructors
. This allows e.g. to construct jax fields in non-hacky way.Additional:
__copy__
,__deepcopy__
to NDArrayField (with same memory layout as source Field)