Skip to content

Commit

Permalink
Better restore and counter behavior (#13, #14)
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Apr 23, 2019
1 parent afaf43f commit 9172625
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions gpt_2_simple/gpt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,25 @@ def maketree(path):

counter = 1
counter_path = os.path.join(checkpoint_path, 'counter')
if os.path.exists(counter_path):
if os.path.exists(counter_path) and restore_from == 'latest':
# Load the step number if we're resuming a run
# Add 1 so we don't immediately try to save again
with open(counter_path, 'r') as fp:
counter = int(fp.read()) + 1
counter_base = counter

def save():
maketree(checkpoint_path)
print(
'Saving',
os.path.join(checkpoint_path,
'model-{}').format(counter))
'model-{}').format(counter-1))
saver.save(
sess,
os.path.join(checkpoint_path, 'model'),
global_step=counter)
global_step=counter-1)
with open(counter_path, 'w') as fp:
fp.write(str(counter) + '\n')
fp.write(str(counter-1) + '\n')

def generate_samples():
context_tokens = data_sampler.sample(1)
Expand Down Expand Up @@ -219,7 +220,7 @@ def sample_batch():

try:
while True:
if counter == steps:
if steps > 0 and counter == (counter_base + steps):
save()
return
if counter % save_every == 0:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
setup(
name='gpt_2_simple',
packages=['gpt_2_simple'], # this must be the same as the name above
version='0.3',
version='0.3.1',
description="Python package to easily retrain OpenAI's GPT-2 " \
"text-generating model on new texts.",
long_description=long_description,
Expand Down

0 comments on commit 9172625

Please sign in to comment.