Adapted continuous isolation checks to distributed PyTorch on AMD
IlyasMoutawwakil committed Nov 23, 2023
1 parent bba1199 commit 2a57ebf
Showing 2 changed files with 36 additions and 27 deletions.
59 changes: 35 additions & 24 deletions optimum_benchmark/backends/isolation_utils.py
@@ -13,9 +13,9 @@
 LOGGER = getLogger("isolation")


-def check_cuda_isolation(isolated_devices: List[int], isolated_pid: int) -> None:
+def check_cuda_isolation(isolated_devices: List[int], permitted_pids: List[int]) -> None:
     """
-    Raises a RuntimeError if any process other than the benchmark process is running on the specified CUDA devices.
+    Raises a RuntimeError if any process other than the permitted ones is running on the specified CUDA devices.
     """
     pids: Dict[int, set] = {}
     for device_id in isolated_devices:
@@ -34,9 +34,7 @@ def check_cuda_isolation(isolated_devices: List[int], isolated_pid: int) -> None
             device_handle = nvml.nvmlDeviceGetHandleByIndex(device_id)
             device_processes = nvml.nvmlDeviceGetComputeRunningProcesses(device_handle)
             for device_process in device_processes:
-                if device_process.pid == os.getpid():
-                    continue
-                if device_process.pid != isolated_pid:
+                if device_process.pid not in permitted_pids:
                     LOGGER.warning(f"Found unexpected process {device_process.pid} on device {device_id}.")
                     LOGGER.warning(f"Process info: {device_process}")

@@ -65,9 +63,7 @@ def check_cuda_isolation(isolated_devices: List[int], isolated_pid: int) -> None
                     info = amdsmi.amdsmi_get_gpu_process_info(device_handle, process_handle)
                     if info["memory_usage"]["vram_mem"] == 4096:
                         continue
-                    if info["pid"] == os.getpid():
-                        continue
-                    if info["pid"] != isolated_pid:
+                    if info["pid"] not in permitted_pids:
                         LOGGER.warning(f"Found unexpected process {info['pid']} on device {device_id}.")
                         LOGGER.warning(f"Process info: {info}")

@@ -81,9 +77,7 @@ def check_cuda_isolation(isolated_devices: List[int], isolated_pid: int) -> None
                     info = amdsmi.amdsmi_get_process_info(device_handle, process_handle)
                     if info["memory_usage"]["vram_mem"] == 4096:
                         continue
-                    if info["pid"] == os.getpid():
-                        continue
-                    if info["pid"] != isolated_pid:
+                    if info["pid"] not in permitted_pids:
                         LOGGER.warning(f"Found unexpected process {info['pid']} on device {device_id}.")
                         LOGGER.warning(f"Process info: {info}")

@@ -96,37 +90,54 @@ def check_cuda_isolation(isolated_devices: List[int], isolated_pid: int) -> None
     all_pids = set()
     for device_id in isolated_devices:
         all_pids |= pids[device_id]
-    other_pids = all_pids - {isolated_pid}
+    other_pids = all_pids - set(permitted_pids)

     if len(other_pids) > 0:
         error_message = (
-            f"Expected only process {isolated_pid} on device(s) {isolated_devices}, but found {other_pids}."
+            f"Expected only process(es) {permitted_pids} on device(s) {isolated_devices}, but found {other_pids}."
         )
         raise RuntimeError(error_message)


 def check_cuda_continuous_isolation(isolated_pid: int) -> None:
     """
-    Kills the benchmark process if any other process is running on the specified CUDA devices.
+    Kills the isolated process if any process other than the permitted ones is running on the specified CUDA devices.
     """

     hydra_conf = OmegaConf.load(".hydra/hydra.yaml")
     logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True))

-    if len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")) == 1:
-        isolated_devices = [int(os.environ.get("CUDA_VISIBLE_DEVICES", "0"))]
-    elif os.environ.get("LOCAL_RANK", None) is not None:
-        local_rank = int(os.environ["LOCAL_RANK"])
-        available_devices = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
-        isolated_devices = [available_devices[local_rank]]
+    # distributed setting is tricky
+    if os.environ.get("LOCAL_WORLD_SIZE", None) is not None:
+        from torch.distributed import TCPStore
+
+        local_rank = os.environ["LOCAL_RANK"]
+        all_isolated_keys = [f"isolated_{other_rank}" for other_rank in range(int(os.environ["LOCAL_WORLD_SIZE"]))]
+        all_isolators_keys = [f"isolator_{other_rank}" for other_rank in range(int(os.environ["LOCAL_WORLD_SIZE"]))]
+
+        store = TCPStore(host_name=os.environ["MASTER_ADDR"], port=int(os.environ["MASTER_PORT"]))
+
+        store.add(f"isolator_{local_rank}", os.getpid())
+        store.add(f"isolated_{local_rank}", isolated_pid)
+        store.wait(all_isolated_keys + all_isolators_keys)
+
+        all_isolated_pids = [int(store.get(name)) for name in all_isolated_keys]
+        all_isolators_pids = [int(store.get(name)) for name in all_isolators_keys]
+        permitted_pids = all_isolated_pids + all_isolators_pids
+        assert len(permitted_pids) == len(set(permitted_pids)), "Found duplicated pids in the distributed setting"
     else:
-        isolated_devices = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
+        isolator_pid = os.getpid()
+        permitted_pids = [isolator_pid, isolated_pid]
+
+    isolated_devices = [int(device) for device in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
+
+    LOGGER.info(
+        f"Continuously checking only process(es) {permitted_pids} is/are running on device(s) {isolated_devices}"
+    )

-    LOGGER.info(f"Continuously checking only process {isolated_pid} is running on device(s) {isolated_devices}")
-    print(f"Continuously checking only process {isolated_pid} is running on device(s) {isolated_devices}")
     while True:
         try:
-            check_cuda_isolation(isolated_devices, isolated_pid)
+            check_cuda_isolation(isolated_devices, permitted_pids)
             time.sleep(0.1)
         except Exception as e:
             LOGGER.error(f"Error while checking CUDA isolation: {e}")
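The core of the change is this PID rendezvous: under torchrun each local rank appears to spawn its own isolation watchdog, so every watchdog has to whitelist the worker and watchdog PIDs of all local ranks rather than a single benchmark PID. Below is a minimal, self-contained sketch of that exchange (not part of the commit). It assumes torchrun has already exported MASTER_ADDR, MASTER_PORT, LOCAL_RANK and LOCAL_WORLD_SIZE, that a TCPStore server is already reachable at MASTER_ADDR:MASTER_PORT, and the helper name gather_permitted_pids is hypothetical.

import os

from torch.distributed import TCPStore


def gather_permitted_pids(isolated_pid: int) -> list:
    """Collect the watchdog and worker PIDs of every local rank via a shared TCPStore."""
    local_rank = os.environ["LOCAL_RANK"]
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

    # one key per rank for the watchdog ("isolator") and the benchmark worker ("isolated")
    isolated_keys = [f"isolated_{rank}" for rank in range(local_world_size)]
    isolator_keys = [f"isolator_{rank}" for rank in range(local_world_size)]

    # connect as a client to the store already created by the launcher's rendezvous
    store = TCPStore(host_name=os.environ["MASTER_ADDR"], port=int(os.environ["MASTER_PORT"]))

    # publish this rank's pair of PIDs, then block until every rank has published its own
    store.add(f"isolator_{local_rank}", os.getpid())
    store.add(f"isolated_{local_rank}", isolated_pid)
    store.wait(isolated_keys + isolator_keys)

    # every rank now sees the full whitelist of PIDs allowed on the isolated devices
    return [int(store.get(key)) for key in isolated_keys + isolator_keys]

Using store.add with the PID as the increment mirrors the diff: each counter starts at zero, so a single add both publishes the PID and makes the key visible to store.wait on the other ranks.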
4 changes: 1 addition & 3 deletions optimum_benchmark/launchers/torchrun/launcher.py
@@ -23,8 +23,6 @@ def __init__(self) -> None:
     def configure(self, config: "TorchrunConfig") -> None:
         super().configure(config)

-        LOGGER.info(f"Running {self.config.nproc_per_node} processes per node")
-
     def launch(self, worker: Callable, *worker_args):
         launch_config = LaunchConfig(
             min_nodes=self.config.min_nodes,
@@ -62,7 +60,7 @@ def entrypoint(fn, *args):
     This is a picklable function that correctly sets up the logging configuration
     """

-    if os.environ.get("LOCAL_RANK", "0") == "0":
+    if os.environ["LOCAL_RANK"] == "0":
         hydra_conf = OmegaConf.load(".hydra/hydra.yaml")
         logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True))
     else:
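For context, the entrypoint change keeps the hydra logging configuration on local rank 0 only, while the other ranks fall into the else branch that the diff truncates. A small sketch of that rank-gated pattern follows (an illustration only; the function name setup_rank_logging and the body of the else branch are assumptions, since the commit does not show them).

import logging
import logging.config
import os

from omegaconf import OmegaConf


def setup_rank_logging() -> None:
    # torchrun always sets LOCAL_RANK, so direct indexing (as in the new code) is safe
    if os.environ["LOCAL_RANK"] == "0":
        # rank 0 re-applies the hydra job logging configuration of the current run
        hydra_conf = OmegaConf.load(".hydra/hydra.yaml")
        logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True))
    else:
        # assumption: non-zero ranks silence INFO-and-below records to avoid duplicated output
        logging.disable(logging.INFO)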
