Commit: HF revamp

jettjaniak authored and jaidhyani committed Apr 27, 2024
1 parent ad98f61 commit 9feac6e
Showing 10 changed files with 31 additions and 59 deletions.
9 changes: 0 additions & 9 deletions scripts/inference.py

@@ -19,7 +19,6 @@ def main(
     batch_size: Int,
     dataset_name: str,
     username: str,
-    token: str,
     funct_test: bool = False,
 ):
     """
@@ -30,7 +29,6 @@ def main(
     - batch_size: The batch size for processing. 80 worked well in CPU.
     - dataset_name: The name of the dataset from which validation set will be loaded
     - username: Hugging Face API username
-    - token: Hugging Face API token
     """
     val_ds = load_validation_dataset(dataset_name)
 
@@ -63,7 +61,6 @@ def main(
         repo_id=repo_id,
         split="validation",
         private=False,
-        token=token,
     )
 
 
@@ -90,11 +87,6 @@ def main(
         type=str,
         help="Hugging Face API username",
     )
-    parser.add_argument(
-        "--token",
-        type=str,
-        help="Hugging Face API token",
-    )
     parser.add_argument(
         "--test-funct", action="store_true", help="Enable test function mode"
     )
@@ -109,6 +101,5 @@ def main(
         args.batch_size,
         args.dataset_name,
         args.username,
-        args.token,
        args.test_funct,
    )
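With the explicit --token flag removed, the scripts rely on huggingface_hub's ambient credential resolution instead. A minimal sketch of that pattern (not part of this commit), assuming a token has been cached via `huggingface-cli login` or exported as HF_TOKEN:

# Sketch of ambient Hugging Face auth; assumes a token is already available
# from the cached login or the HF_TOKEN environment variable.
from huggingface_hub import HfApi

api = HfApi()        # no token kwarg: falls back to the cached or env token
user = api.whoami()  # raises if no valid token can be resolved
print(user["name"])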
2 changes: 1 addition & 1 deletion scripts/run_training.py

@@ -100,7 +100,7 @@ def main():
     # run training
     results, run_context = run_training(config)
     final_out_dir = os.path.join(config.output_dir, "final")
-    save_results(config, results, run_context, final_out_dir)
+    save_results(config, results, run_context, final_out_dir, final=True)
     print(f"Saved results to {final_out_dir}")
 
 
8 changes: 1 addition & 7 deletions scripts/tokenize_dataset.py

@@ -52,12 +52,6 @@
     required=True,
     help="Context size of the tokenized dataset as input of the model",
 )
-parser.add_argument(
-    "--hf-token",
-    "-t",
-    type=str,
-    help="Hugging Face API token",
-)
 parser.add_argument(
     "--batch-size",
     "-b",
@@ -86,7 +80,7 @@
 assert tokenizer.bos_token_id is not None, "Tokenizer must have a bos_token_id"
 assert tokenizer.eos_token_id is not None, "Tokenizer must have a eos_token_id"
 
-api = HfApi(token=args.hf_token)
+api = HfApi()
 api.create_repo(repo_id=args.out_repo_id, repo_type="dataset", exist_ok=True)
 tokenize_and_upload_split(
     dataset_split=in_dataset_split,
7 changes: 0 additions & 7 deletions scripts/train_tokenizer.py

@@ -64,12 +64,6 @@ def train_byte_level_bpe(
     required=True,
     help="Where to push the resulting tokenizer",
 )
-parser.add_argument(
-    "--hf-token",
-    "-t",
-    type=str,
-    help="Hugging Face API token",
-)
 args = parser.parse_args()
 
 print(f"Loading dataset '{args.in_repo_id}'...")
@@ -86,5 +80,4 @@ def train_byte_level_bpe(
 )
 tokenizer.push_to_hub(
     repo_id=args.out_repo_id,
-    token=args.hf_token,
 )
3 changes: 2 additions & 1 deletion src/delphi/test_configs/debug.json

@@ -16,5 +16,6 @@
   },
   "dataset": {
     "name": "delphi-suite/v0-tinystories-v2-clean-tokenized"
-  }
+  },
+  "out_repo_id": ""
 }
3 changes: 2 additions & 1 deletion src/delphi/test_configs/debug_transformers_bloom.json

@@ -24,5 +24,6 @@
   "torch_seed": 1337,
   "dataset": {
     "name": "delphi-suite/v0-tinystories-v2-clean-tokenized"
-  }
+  },
+  "out_repo_id": ""
 }
3 changes: 2 additions & 1 deletion src/delphi/test_configs/v0-llama2-100k.json

@@ -26,5 +26,6 @@
   "torch_seed": 1337,
   "dataset": {
     "name": "delphi-suite/v0-tinystories-v2-clean-tokenized"
-  }
+  },
+  "out_repo_id": ""
 }
19 changes: 0 additions & 19 deletions src/delphi/train/config/huggingface_config.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/delphi/train/config/training_config.py

@@ -9,7 +9,6 @@
 from .adam_config import AdamConfig
 from .dataset_config import DatasetConfig
 from .debug_config import DebugConfig
-from .huggingface_config import HuggingfaceConfig
 from .wandb_config import WandbConfig
 
 
@@ -83,6 +82,7 @@ class TrainingConfig:
         metadata={"help": "seed used for pseudorandomly sampling data during training"},
     )
     torch_seed: int = field(metadata={"help": "seed used for torch"})
+    save_optimizer: bool = True
 
     # data
     dataset: DatasetConfig = field(
@@ -91,7 +91,7 @@ class TrainingConfig:
 
     # third party
     wandb: Optional[WandbConfig] = None
-    hf: Optional[HuggingfaceConfig] = None
+    out_repo_id: str
 
     # debug
     debug_config: DebugConfig = field(default_factory=DebugConfig)
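The net effect on the config surface: a single out_repo_id string replaces the nested HuggingfaceConfig. A trimmed, hypothetical sketch of how the two new fields behave (the real TrainingConfig has many more fields and uses field(metadata=...)):

from dataclasses import dataclass

# Hypothetical, trimmed stand-in for TrainingConfig after this change.
@dataclass
class ConfigSketch:
    out_repo_id: str             # "" (as in the test configs above) disables Hub uploads
    save_optimizer: bool = True  # gates the optimizer.pt dump in save_results

cfg = ConfigSketch(out_repo_id="")
if cfg.out_repo_id:  # the same truthiness gate save_results uses
    print(f"would push to {cfg.out_repo_id}")
else:
    print("Hub upload disabled")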
32 changes: 21 additions & 11 deletions src/delphi/train/utils.py

@@ -239,6 +239,7 @@ def save_results(
     train_results: ModelTrainingState,
     run_context: RunContext,
     results_path: str,
+    final: bool = False,
 ):
     """
     save results to disk, and to huggingface if configured to do so.
@@ -247,21 +248,21 @@ def save_results(
     config, context (e.g. hardware), training step, etc
     """
     os.makedirs(results_path, exist_ok=True)
-    with open(os.path.join(results_path, "config.json"), "w") as file:
+    with open(os.path.join(results_path, "training_config.json"), "w") as file:
         json.dump(asdict(config), file, indent=2)
     model = train_results.model
     if isinstance(model, PreTrainedModel):
         model = cast(PreTrainedModel, model)
         model.save_pretrained(
-            save_directory=os.path.join(results_path, "model"),
+            save_directory=results_path,
         )
     else:
         st.save_model(
-            train_results.model,
-            os.path.join(results_path, "model", "model.safetensors"),
+            model,
+            os.path.join(results_path, "model.safetensors"),
         )
-    with open(os.path.join(results_path, "opt.pt"), "wb") as f:
-        torch.save(train_results.optimizer.state_dict(), f)
+    if config.save_optimizer:
+        with open(os.path.join(results_path, "optimizer.pt"), "wb") as f:
+            torch.save(train_results.optimizer.state_dict(), f)
     with open(os.path.join(results_path, "training_state.json"), "w") as file:
         training_state_dict = {
             "iter_num": train_results.iter_num,
@@ -274,13 +275,22 @@ def save_results(
         run_context_dict = asdict(run_context)
         run_context_dict["device"] = str(run_context.device)
         json.dump(run_context_dict, file, indent=2)
-    if config.hf and config.hf.push_checkpoints_to_hub:
-        api = HfApi(token=config.hf.token)
+    if config.out_repo_id:
+        api = HfApi()
+        api.create_repo(config.out_repo_id, exist_ok=True)
+        branch_name = f"iter{train_results.iter_num}"
+        api.create_branch(config.out_repo_id, branch=branch_name)
         api.upload_folder(
             folder_path=results_path,
-            repo_id=str(config.hf.repo_id),
-            revision=f"iter_{train_results.iter_num}",
+            repo_id=config.out_repo_id,
+            revision=branch_name,
         )
+        if final:
+            api.upload_folder(
+                folder_path=results_path,
+                repo_id=config.out_repo_id,
+                revision="main",
+            )
 
 
 def count_tokens_so_far(config: TrainingConfig, mts: ModelTrainingState) -> int:
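Since each checkpoint now lands on its own iter{N} branch and the final run is mirrored to main, and since the model files now sit at the repo root rather than in a model/ subfolder, any checkpoint can be loaded directly by revision. A sketch of the consumer side, with a hypothetical repo id and iteration number:

# Hypothetical example of loading from the branches created above.
from transformers import AutoModelForCausalLM

# main carries the final model; intermediate checkpoints live on iter{N} branches
final_model = AutoModelForCausalLM.from_pretrained("delphi-suite/example-model")
checkpoint = AutoModelForCausalLM.from_pretrained(
    "delphi-suite/example-model",
    revision="iter2000",  # assumed iteration number for illustration
)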
