[ROCm] Enable ROCm support for inductor's dynamic_rblock_scaling (pytorch#129663)

As of ROCm 6.1, [hipDeviceProp_t::regsPerMultiprocessor](https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/structhip_device_prop__t.html#a7390d5b180d63978c81aa971060270b4) is now available, allowing us to enable this attribute on ROCm.
```
>>> torch.cuda.get_device_properties(0)
_CudaDeviceProperties(name='AMD Instinct MI250X/MI250', major=9, minor=0, gcnArchName='gfx90a:sramecc+:xnack-', total_memory=65520MB, multi_processor_count=104)
>>> torch.cuda.get_device_properties(0).regs_per_multiprocessor
65536
```

With https://github.com/triton-lang/triton/pull/3962 we can extract n_regs and n_spills from a Triton binary built with the AMD backend, allowing us to enable inductor's dynamic_rblock_scaling (initially implemented in pytorch#115094) on ROCm.
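
For reference, a minimal sketch of how those counts can be read off a compiled kernel once the backend reports them; `compiled_binary` is only a placeholder for whatever object Triton's compiler returns (the exact compile call is omitted since its signature varies across Triton versions), and the defensive `getattr` mirrors the guard style used in inductor rather than any exact API.
```
# Illustrative sketch only: `compiled_binary` stands in for the object
# returned by Triton's compiler for an AMD target.
def register_usage(compiled_binary):
    """Return (n_regs, n_spills) if the backend exposes them, else (None, None)."""
    n_regs = getattr(compiled_binary, "n_regs", None)
    n_spills = getattr(compiled_binary, "n_spills", None)
    return n_regs, n_spills
```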

Leaving this in draft until the following PRs have landed:
- pytorch#129361 to bump the Triton commit pin
- pytorch#128449 to allow us to grab warp_size from device properties instead of hard-coding 64 on ROCm (see the sketch below)
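
As a rough illustration of that last point (assuming pytorch#128449 is in place), the wavefront width can be read from device properties rather than assumed:
```
# Illustrative only: read the warp/wavefront size from device properties
# instead of hard-coding 64 for ROCm (32 for CUDA). The getattr fallback is
# just defensive for builds that predate the warp_size binding.
import torch

props = torch.cuda.get_device_properties(0)
warp_size = getattr(props, "warp_size", 64 if torch.version.hip else 32)
print(warp_size)  # 64 on gfx90a (MI250X), 32 on NVIDIA GPUs
```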

Pull Request resolved: pytorch#129663
Approved by: https://github.com/jansel, https://github.com/shunting314
jataylo authored and pytorchmergebot committed Sep 13, 2024
1 parent 564d00f commit a157745
Showing 3 changed files with 18 additions and 10 deletions.
torch/_inductor/runtime/hints.py (14 changes: 11 additions & 3 deletions)

@@ -104,24 +104,32 @@ class DeviceProperties(typing.NamedTuple):
     regs_per_multiprocessor: Optional[int] = None
     max_threads_per_multi_processor: Optional[int] = None
     multi_processor_count: Optional[int] = None
+    warp_size: Optional[int] = None
 
     @classmethod
     def create(cls, device):
         import torch
         from torch._dynamo.device_interface import get_interface_for_device
 
-        device_type = device.type if torch.version.hip is None else "hip"
+        device_type = device.type
+
+        if torch.version.hip and device_type == "cuda":
+            device_type = "hip"
+
         device_interface = get_interface_for_device(device)
-        if device_type == "cuda":
+        if device_type in ["cuda", "hip"]:
             props = device_interface.get_device_properties(device)
             return cls(
                 type=device_type,
                 index=device.index,
                 cc=device_interface.get_compute_capability(device),
                 major=props.major,
-                regs_per_multiprocessor=props.regs_per_multiprocessor,
+                regs_per_multiprocessor=props.regs_per_multiprocessor
+                if hasattr(props, "regs_per_multiprocessor")
+                else None,
                 max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                 multi_processor_count=props.multi_processor_count,
+                warp_size=props.warp_size,
             )
         return cls(
             type=device_type,
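
A condensed sketch of what the hunk above gives us at runtime: `regs_per_multiprocessor` is populated only when the binding exposes it (CUDA, or ROCm builds new enough to carry the property), and older builds fall back to `None` so callers can opt out cleanly. The printed values are the MI250X numbers from the commit message, shown for illustration.
```
# Sketch (assumes a ROCm 6.1+ or CUDA build of PyTorch):
import torch

props = torch.cuda.get_device_properties(0)
regs = (
    props.regs_per_multiprocessor
    if hasattr(props, "regs_per_multiprocessor")
    else None
)
print(regs, props.warp_size)  # e.g. 65536 64 on an MI250X with ROCm 6.1
```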
torch/_inductor/runtime/triton_heuristics.py (9 changes: 5 additions & 4 deletions)

@@ -265,10 +265,11 @@ def precompile(self, warm_cache_only=False):
                 self.inductor_meta.get("dynamic_scale_rblock", True)
                 and self.heuristic_type == HeuristicType.REDUCTION
                 and self.size_hints is not None
-                # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
-                and device_prop.type == "cuda"
+                # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
+                and device_prop.type in ["cuda", "hip"]
                 and device_prop.major
-                and device_prop.major >= 8
+                and (device_prop.major >= 8 or torch.version.hip)
+                and device_prop.regs_per_multiprocessor is not None
             ):
                 assert device_prop.regs_per_multiprocessor
                 assert device_prop.max_threads_per_multi_processor
@@ -305,7 +306,7 @@
                     ):
                         continue
 
-                    nreg_per_warp = nreg * 32
+                    nreg_per_warp = nreg * device_prop.warp_size
                     nreg_per_block = nreg_per_warp * triton_config.num_warps
 
                     # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
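
To make the effect of switching from the hard-coded 32 to `device_prop.warp_size` concrete, here is the occupancy arithmetic worked through with the MI250X numbers from the commit message; the register and warp counts are made-up inputs and the final formula is only a paraphrase of the heuristic, not a copy of it.
```
# Worked example with illustrative inputs (not taken from a real kernel):
n_regs = 40          # registers per thread reported by the compiled binary
warp_size = 64       # device_prop.warp_size on gfx90a (previously hard-coded 32)
num_warps = 8        # from the triton config being benchmarked

nreg_per_warp = n_regs * warp_size           # 2560
nreg_per_block = nreg_per_warp * num_warps   # 20480

regs_per_multiprocessor = 65536              # from the device properties above
max_blocks_per_sm = max(regs_per_multiprocessor // nreg_per_block, 1)  # 3
```
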
torch/csrc/cuda/Module.cpp (5 changes: 2 additions & 3 deletions)

@@ -989,11 +989,10 @@ static void registerCudaDeviceProperties(PyObject* module) {
           "max_threads_per_multi_processor",
           &cudaDeviceProp::maxThreadsPerMultiProcessor)
       .def_readonly("warp_size", &cudaDeviceProp::warpSize)
-#if !USE_ROCM
-      // NVIDA only property
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60100) || !USE_ROCM
       .def_readonly(
           "regs_per_multiprocessor", &cudaDeviceProp::regsPerMultiprocessor)
-#endif // USE_ROCM
+#endif
       // HIP-only property; reuse name attribute for CUDA builds
       .def_readonly(
           "gcnArchName",
