diff --git a/src/delphi/train/config/training_config.py b/src/delphi/train/config/training_config.py index 2de91cce..10076de3 100644 --- a/src/delphi/train/config/training_config.py +++ b/src/delphi/train/config/training_config.py @@ -92,6 +92,10 @@ class TrainingConfig: # third party wandb: Optional[WandbConfig] = None out_repo_id: str + readme_path: str = field( + default="", + metadata={"help": "for HF model card"}, + ) # debug debug_config: DebugConfig = field(default_factory=DebugConfig) diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 93b467ef..d394583f 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -2,6 +2,7 @@ import logging import math import os +import shutil import time from collections.abc import Generator from dataclasses import asdict, dataclass, field @@ -275,6 +276,8 @@ def save_results( run_context_dict = asdict(run_context) run_context_dict["device"] = str(run_context.device) json.dump(run_context_dict, file, indent=2) + if config.readme_path: + shutil.copy(config.readme_path, os.path.join(results_path, "README.md")) if config.out_repo_id: api = HfApi() api.create_repo(config.out_repo_id, exist_ok=True)