diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index d3c9f560fd2a1..0f1495e49972c 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -104,24 +104,32 @@ class DeviceProperties(typing.NamedTuple):
     regs_per_multiprocessor: Optional[int] = None
     max_threads_per_multi_processor: Optional[int] = None
     multi_processor_count: Optional[int] = None
+    warp_size: Optional[int] = None
 
     @classmethod
     def create(cls, device):
         import torch
         from torch._dynamo.device_interface import get_interface_for_device
 
-        device_type = device.type if torch.version.hip is None else "hip"
+        device_type = device.type
+
+        if torch.version.hip and device_type == "cuda":
+            device_type = "hip"
+
         device_interface = get_interface_for_device(device)
-        if device_type == "cuda":
+        if device_type in ["cuda", "hip"]:
             props = device_interface.get_device_properties(device)
             return cls(
                 type=device_type,
                 index=device.index,
                 cc=device_interface.get_compute_capability(device),
                 major=props.major,
-                regs_per_multiprocessor=props.regs_per_multiprocessor,
+                regs_per_multiprocessor=props.regs_per_multiprocessor
+                if hasattr(props, "regs_per_multiprocessor")
+                else None,
                 max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                 multi_processor_count=props.multi_processor_count,
+                warp_size=props.warp_size,
             )
         return cls(
             type=device_type,
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index c0b0e20e906d5..89c47838da43f 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -265,10 +265,11 @@ def precompile(self, warm_cache_only=False):
             self.inductor_meta.get("dynamic_scale_rblock", True)
             and self.heuristic_type == HeuristicType.REDUCTION
             and self.size_hints is not None
-            # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
-            and device_prop.type == "cuda"
+            # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
+            and device_prop.type in ["cuda", "hip"]
             and device_prop.major
-            and device_prop.major >= 8
+            and (device_prop.major >= 8 or torch.version.hip)
+            and device_prop.regs_per_multiprocessor is not None
         ):
             assert device_prop.regs_per_multiprocessor
             assert device_prop.max_threads_per_multi_processor
@@ -305,7 +306,7 @@ def precompile(self, warm_cache_only=False):
                 ):
                     continue
 
-                nreg_per_warp = nreg * 32
+                nreg_per_warp = nreg * device_prop.warp_size
                 nreg_per_block = nreg_per_warp * triton_config.num_warps
 
                 # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 22cbadc42fc09..461a23e651924 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -989,11 +989,10 @@ static void registerCudaDeviceProperties(PyObject* module) {
           "max_threads_per_multi_processor",
           &cudaDeviceProp::maxThreadsPerMultiProcessor)
       .def_readonly("warp_size", &cudaDeviceProp::warpSize)
-#if !USE_ROCM
-      // NVIDA only property
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60100) || !USE_ROCM
       .def_readonly(
           "regs_per_multiprocessor", &cudaDeviceProp::regsPerMultiprocessor)
-#endif // USE_ROCM
+#endif
       // HIP-only property; reuse name attribute for CUDA builds
       .def_readonly(
           "gcnArchName",
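
For reviewers, a minimal, hypothetical smoke test (not part of the patch) of the updated `DeviceProperties.create()` path: on ROCm builds the device type is now reported as `hip`, `warp_size` is copied from the device properties, and `regs_per_multiprocessor` falls back to `None` when the binding is not exposed (ROCm < 6.1).

```python
# Illustrative only; assumes a CUDA or ROCm device is present.
import torch
from torch._inductor.runtime.hints import DeviceProperties

if torch.cuda.is_available():
    props = DeviceProperties.create(torch.device("cuda", 0))
    # On ROCm, props.type is "hip" and warp_size is typically 64 (32 on NVIDIA).
    # regs_per_multiprocessor is None when the property is unavailable.
    print(props.type, props.warp_size, props.regs_per_multiprocessor)
```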