Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[👶 PR] Improve Launcher to handle SLURM_TMPDIR #228

Merged
merged 5 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions LAUNCH.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ usage: launch.py [-h] [--help-md] [--job_name JOB_NAME] [--outdir OUTDIR]
[--cpus_per_task CPUS_PER_TASK] [--mem MEM] [--gres GRES]
[--partition PARTITION] [--modules MODULES]
[--conda_env CONDA_ENV] [--venv VENV] [--template TEMPLATE]
[--code_dir CODE_DIR] [--jobs JOBS] [--dry-run] [--verbose]
[--force]
[--code_dir CODE_DIR] [--git_checkout GIT_CHECKOUT]
[--jobs JOBS] [--dry-run] [--verbose] [--force]

optional arguments:
-h, --help show this help message and exit
Expand All @@ -35,6 +35,11 @@ optional arguments:
$root/mila/sbatch/template-conda.sh
--code_dir CODE_DIR cd before running main.py (defaults to here). Defaults
to $root
--git_checkout GIT_CHECKOUT
Branch or commit to checkout before running the code.
This is only used if --code_dir='$SLURM_TMPDIR'. If
not specified, the current branch is used. Defaults to
None
--jobs JOBS jobs (nested) file name in external/jobs (with or
without .yaml). Or an absolute path to a yaml file
anywhere Defaults to None
Expand All @@ -54,6 +59,7 @@ conda_env : gflownet
cpus_per_task : 2
dry-run : False
force : False
git_checkout : None
gres : gpu:1
job_name : gflownet
jobs : None
Expand Down
69 changes: 67 additions & 2 deletions mila/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from os.path import expandvars
from pathlib import Path
from textwrap import dedent
from git import Repo

from yaml import safe_load

ROOT = Path(__file__).resolve().parent.parent

GIT_WARNING = True

HELP = dedent(
"""
## 🥳 User guide
Expand Down Expand Up @@ -337,13 +340,63 @@ def print_md_help(parser, defaults):
print(HELP, end="")


def ssh_to_https(url):
"""
Converts a ssh git url to https.
Eg:
"""
if "https://" in url:
return url
if "git@" in url:
path = url.split(":")[1]
return f"https://github.com/{path}"
raise ValueError(f"Could not convert {url} to https")


def code_dir_for_slurm_tmp_dir_checkout(git_checkout):
global GIT_WARNING

repo = Repo(ROOT)
if git_checkout is None:
git_checkout = repo.active_branch.name
if GIT_WARNING:
print("💥 Git warnings:")
print(
f" • `git_checkout` not provided. Using current branch: {git_checkout}"
)
# warn for uncommitted changes
if repo.is_dirty() and GIT_WARNING:
print(
" • Your repo contains uncommitted changes. "
+ "They will *not* be available when cloning happens within the job."
)
if GIT_WARNING and "y" not in input("Continue anyway? [y/N] ").lower():
print("🛑 Aborted")
sys.exit(0)
GIT_WARNING = False

return dedent(
"""\
$SLURM_TMPDIR
git clone {git_url} tpm-gflownet
cd tpm-gflownet
{git_checkout}
echo "Current commit: $(git rev-parse HEAD)"
"""
).format(
git_url=ssh_to_https(repo.remotes.origin.url),
git_checkout=f"git checkout {git_checkout}" if git_checkout else "",
)


if __name__ == "__main__":
defaults = {
"code_dir": "$root",
"conda_env": "gflownet",
"cpus_per_task": 2,
"dry-run": False,
"force": False,
"git_checkout": None,
"gres": "gpu:1",
"job_name": "gflownet",
"jobs": None,
Expand Down Expand Up @@ -428,6 +481,14 @@ def print_md_help(parser, defaults):
help="cd before running main.py (defaults to here)."
+ f" Defaults to {defaults['code_dir']}",
)
parser.add_argument(
"--git_checkout",
type=str,
help="Branch or commit to checkout before running the code."
+ " This is only used if --code_dir='$SLURM_TMPDIR'. If not specified, "
+ " the current branch is used."
+ f" Defaults to {defaults['git_checkout']}",
)
parser.add_argument(
"--jobs",
type=str,
Expand Down Expand Up @@ -510,7 +571,11 @@ def print_md_help(parser, defaults):
job_args = deep_update(job_args, job_dict)
job_args = deep_update(job_args, args)

job_args["code_dir"] = str(resolve(job_args["code_dir"]))
job_args["code_dir"] = (
str(resolve(job_args["code_dir"]))
if "SLURM_TMPDIR" not in job_args["code_dir"]
else code_dir_for_slurm_tmp_dir_checkout(job_args.get("git_checkout"))
)
job_args["outdir"] = str(resolve(job_args["outdir"]))
job_args["venv"] = str(resolve(job_args["venv"]))
job_args["main_args"] = script_dict_to_main_args_str(job_args.get("script", {}))
Expand Down Expand Up @@ -542,7 +607,7 @@ def print_md_help(parser, defaults):
sbatch_path.parent.mkdir(parents=True, exist_ok=True)
# write template
sbatch_path.write_text(templated)
print(f" 🏷 Created ./{sbatch_path.relative_to(Path.cwd())}")
print(f"\n 🏷 Created ./{sbatch_path.relative_to(Path.cwd())}")
# Submit job to SLURM
out = popen(f"sbatch {sbatch_path}").read()
# Identify printed-out job id
Expand Down