From a15774563b8d2179bcb2347cf1217c6223234c44 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Fri, 13 Sep 2024 16:45:39 +0000
Subject: [PATCH] [ROCm] Enable ROCm support for inductor's
 dynamic_rblock_scaling (#129663)

As of ROCm 6.1, [hipDeviceProp_t::regsPerMultiprocessor](https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/structhip_device_prop__t.html#a7390d5b180d63978c81aa971060270b4) is available, allowing us to expose this attribute on ROCm:

```
>>> torch.cuda.get_device_properties(0)
_CudaDeviceProperties(name='AMD Instinct MI250X/MI250', major=9, minor=0, gcnArchName='gfx90a:sramecc+:xnack-', total_memory=65520MB, multi_processor_count=104)
>>> torch.cuda.get_device_properties(0).regs_per_multiprocessor
65536
```

With https://github.com/triton-lang/triton/pull/3962 we can extract n_regs and n_spills from a Triton binary built with the AMD backend, allowing us to enable inductor's dynamic_rblock_scaling (initially implemented in https://github.com/pytorch/pytorch/pull/115094) on ROCm.

Leaving this in draft until the following PRs have landed:
- https://github.com/pytorch/pytorch/pull/129361 to bump the triton commit pin
- https://github.com/pytorch/pytorch/pull/128449 to allow us to grab warp_size from device properties instead of hard-coding 64 on ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129663
Approved by: https://github.com/jansel, https://github.com/shunting314
---
 torch/_inductor/runtime/hints.py             | 14 +++++++++++---
 torch/_inductor/runtime/triton_heuristics.py |  9 +++++----
 torch/csrc/cuda/Module.cpp                   |  5 ++---
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index d3c9f560fd2a1e..0f1495e49972c7 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -104,24 +104,32 @@ class DeviceProperties(typing.NamedTuple):
     regs_per_multiprocessor: Optional[int] = None
     max_threads_per_multi_processor: Optional[int] = None
     multi_processor_count: Optional[int] = None
+    warp_size: Optional[int] = None
 
     @classmethod
     def create(cls, device):
         import torch
         from torch._dynamo.device_interface import get_interface_for_device
 
-        device_type = device.type if torch.version.hip is None else "hip"
+        device_type = device.type
+
+        if torch.version.hip and device_type == "cuda":
+            device_type = "hip"
+
         device_interface = get_interface_for_device(device)
-        if device_type == "cuda":
+        if device_type in ["cuda", "hip"]:
             props = device_interface.get_device_properties(device)
             return cls(
                 type=device_type,
                 index=device.index,
                 cc=device_interface.get_compute_capability(device),
                 major=props.major,
-                regs_per_multiprocessor=props.regs_per_multiprocessor,
+                regs_per_multiprocessor=props.regs_per_multiprocessor
+                if hasattr(props, "regs_per_multiprocessor")
+                else None,
                 max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                 multi_processor_count=props.multi_processor_count,
+                warp_size=props.warp_size,
             )
         return cls(
             type=device_type,
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index c0b0e20e906d5e..89c47838da43f3 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -265,10 +265,11 @@ def precompile(self, warm_cache_only=False):
             self.inductor_meta.get("dynamic_scale_rblock", True)
             and self.heuristic_type == HeuristicType.REDUCTION
             and self.size_hints is not None
-            # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
-            and device_prop.type == "cuda"
+            # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
+            and device_prop.type in ["cuda", "hip"]
             and device_prop.major
-            and device_prop.major >= 8
+            and (device_prop.major >= 8 or torch.version.hip)
+            and device_prop.regs_per_multiprocessor is not None
         ):
             assert device_prop.regs_per_multiprocessor
             assert device_prop.max_threads_per_multi_processor
@@ -305,7 +306,7 @@ def precompile(self, warm_cache_only=False):
                 ):
                     continue
 
-                nreg_per_warp = nreg * 32
+                nreg_per_warp = nreg * device_prop.warp_size
                 nreg_per_block = nreg_per_warp * triton_config.num_warps
 
                 # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 22cbadc42fc09f..461a23e651924b 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -989,11 +989,10 @@ static void registerCudaDeviceProperties(PyObject* module) {
           "max_threads_per_multi_processor",
           &cudaDeviceProp::maxThreadsPerMultiProcessor)
       .def_readonly("warp_size", &cudaDeviceProp::warpSize)
-#if !USE_ROCM
-      // NVIDA only property
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60100) || !USE_ROCM
       .def_readonly(
           "regs_per_multiprocessor", &cudaDeviceProp::regsPerMultiprocessor)
-#endif // USE_ROCM
+#endif
       // HIP-only property; reuse name attribute for CUDA builds
       .def_readonly(
          "gcnArchName",
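
For context, here is a minimal, self-contained sketch of the occupancy math that dynamic_rblock_scaling performs with the properties this patch plumbs through. The names `DeviceProps` and `should_scale_down_rblock` are hypothetical, for illustration only; the real logic lives in `torch/_inductor/runtime/triton_heuristics.py` and differs in detail.

```python
# Sketch (not PyTorch's actual code) of the register-pressure occupancy
# check behind dynamic_rblock_scaling.
from typing import NamedTuple


class DeviceProps(NamedTuple):
    regs_per_multiprocessor: int          # size of the SM/CU register file
    max_threads_per_multi_processor: int  # thread-occupancy ceiling
    warp_size: int                        # 32 on NVIDIA, 64 on most AMD GPUs


def should_scale_down_rblock(
    n_regs: int, n_spills: int, num_warps: int, props: DeviceProps
) -> bool:
    """Return True if register pressure caps occupancy below the
    thread-imposed limit, i.e. a smaller RBLOCK is worth trying."""
    if n_spills > 0:
        # Assumption for this sketch: any register spilling means the
        # current config is already too large.
        return True
    nreg_per_warp = n_regs * props.warp_size
    nreg_per_block = nreg_per_warp * num_warps
    # Blocks that fit on one SM/CU when limited by the register file.
    reg_limited_blocks = max(props.regs_per_multiprocessor // nreg_per_block, 1)
    # Blocks that fit when limited only by the thread ceiling.
    thread_limited_blocks = props.max_threads_per_multi_processor // (
        props.warp_size * num_warps
    )
    return reg_limited_blocks < thread_limited_blocks


# regs_per_multiprocessor matches the commit message above; the other
# values are plausible gfx90a (MI250X) figures.
mi250x = DeviceProps(
    regs_per_multiprocessor=65536,
    max_threads_per_multi_processor=2048,
    warp_size=64,
)
# 128 registers/thread at 8 warps fills the register file -> True
print(should_scale_down_rblock(n_regs=128, n_spills=0, num_warps=8, props=mi250x))
```

Note how warp_size is read from device properties rather than hard-coded, which is what the `nreg * device_prop.warp_size` change above relies on: a ROCm warp (wavefront) is 64 lanes wide on gfx90a, so hard-coding 32 would halve the estimated per-block register footprint.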