Skip to content
This repository has been archived by the owner on Jan 16, 2022. It is now read-only.

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
mgrankin committed Oct 25, 2019
1 parent 6e241ef commit 7f1e6c0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,4 @@ tpu/.gcp_credentials.json
.terraform
*terraform.tfstate*

pytorch_model.bin
5 changes: 1 addition & 4 deletions tpu_lm_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,7 @@ def save_pretrained(model, save_directory):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

best_model_state_dict = {k:v.to('cpu') for k, v in model.state_dict().items()}
best_model_state_dict = OrderedDict(best_model_state_dict)
torch.save(best_model_state_dict, output_model_file)
#xm.save(model_to_save.state_dict(), output_model_file)
xm.save(model_to_save.state_dict(), output_model_file)
log_info(f"Model weights saved in {output_model_file}")

def save_state(args, model, tokenizer, global_step):
Expand Down

0 comments on commit 7f1e6c0

Please sign in to comment.