Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/field_storage_interface' into fi…
Browse files Browse the repository at this point in the history
…eld_storage_interface_gpu
  • Loading branch information
havogt committed Nov 14, 2023
2 parents 93cefcd + 7756a8e commit e57afa5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/gt4py/next/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def __gt_allocate__(
) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]:
shape = domain.shape
layout_map = self.layout_mapper(domain.dims)
assert aligned_index is None # TODO
# TODO(egparedes): add support for non-empty aligned index values
assert aligned_index is None

return core_allocators.NDArrayBufferAllocator(self.device_type, self.array_ns).allocate(
shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index
Expand Down
18 changes: 16 additions & 2 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,22 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)


gtfn_gpu_executor = otf_compile_executor.OTFCompileExecutor(
name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW
)
run_gtfn_gpu = otf_compile_executor.OTFBackend(
executor=gtfn_cached_executor,
executor=gtfn_gpu_executor,
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
)


gtfn_gpu_cached_executor = otf_compile_executor.CachedOTFCompileExecutor(
name="run_gtfn_gpu_cached",
otf_workflow=workflow.CachedStep(
step=gtfn_gpu_executor.otf_workflow, hash_function=compilation_hash
),
)
run_gtfn_gpu_cached = otf_compile_executor.OTFBackend(
executor=gtfn_gpu_cached_executor,
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)
4 changes: 3 additions & 1 deletion tests/next_tests/unit_tests/test_allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def test_get_allocator():

# Test with an invalid object and no default allocator
invalid_obj = "not an allocator"
assert next_allocators.get_allocator(invalid_obj) is None

with pytest.raises(
TypeError,
match=f"Object {invalid_obj} is neither a field allocator nor a field allocator factory",
):
next_allocators.get_allocator(invalid_obj)
next_allocators.get_allocator(invalid_obj, strict=True)


def test_horizontal_first_layout_mapper():
Expand Down

0 comments on commit e57afa5

Please sign in to comment.