Skip to content

Commit

Permalink
fix(dist_utils): fix port conflict in setup_distribution (#178)
Browse files Browse the repository at this point in the history
* fix(dist_utils): fix port conflict in setup_distribution

* fix(dist_utils): fix port conflict in setup_distribution

* fix(dist_utils): fix port conflict in setup_distribution

* style: polish some code

---------

Co-authored-by: Kai Lv <[email protected]>
  • Loading branch information
gyt1145028706 and KaiLv69 authored May 9, 2024
1 parent 53de415 commit 95dd8a2
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions collie/utils/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import subprocess
import socket

import deepspeed
import torch
Expand Down Expand Up @@ -167,6 +168,23 @@ def _decompose_slurm_nodes(s):
return results


def is_port_occupied(host: str, port: int) -> bool:
"""检查端口是否被占用"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((host, port))
return False # 如果绑定成功,返回 False, 表示端口未被占用
except socket.error as e:
return True # 如果绑定失败,返回 True,表示端口被占用


def find_free_port(host: str) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((host, 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def setup_distribution(config) -> None:
"""设置分布式环境。
Expand Down Expand Up @@ -227,8 +245,12 @@ def setup_distribution(config) -> None:
if "MASTER_PORT" in os.environ.keys():
master_port = os.environ["MASTER_PORT"]
else:
master_port = 27002
os.environ["MASTER_PORT"] = f"{master_port}"
master_port = "27002"
os.environ["MASTER_PORT"] = master_port
if is_port_occupied(master_addr, int(master_port)):
free_port = find_free_port(master_addr)
raise RuntimeError(f"Port {master_port} is already in use, "
f"please switch to port {free_port} by `export MASTER_PORT={free_port}` in terminal.")
os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
os.environ["RANK"] = os.environ["SLURM_PROCID"]
os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
Expand Down

0 comments on commit 95dd8a2

Please sign in to comment.