diff --git a/setup.py b/setup.py index c9d248bfbf..336861f0f1 100644 --- a/setup.py +++ b/setup.py @@ -420,10 +420,12 @@ def get_extensions(): "--ptxas-options=-O2", "--ptxas-options=-allow-expensive-optimizations=true", ] - elif torch.cuda.is_available() and torch.version.hip + elif ( + (torch.cuda.is_available() and torch.version.hip) or os.getenv("FORCE_ROCM", "0") == "1" - or os.getenv("PYTORCH_ROCM_ARCH", "") != "": - + or os.getenv("PYTORCH_ROCM_ARCH", "") != "" + ): + rename_cpp_cu(source_hip) source_hip_cu = [] for ff in source_hip: