From 91726256eec7131fdcfb8e8b4faa0bb9d8c08427 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 22 Apr 2019 20:30:12 -0700 Subject: [PATCH] Better restore and counter behavior (#13, #14) --- gpt_2_simple/gpt_2.py | 11 ++++++----- setup.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py index ed84d74..bec4974 100644 --- a/gpt_2_simple/gpt_2.py +++ b/gpt_2_simple/gpt_2.py @@ -171,24 +171,25 @@ def maketree(path): counter = 1 counter_path = os.path.join(checkpoint_path, 'counter') - if os.path.exists(counter_path): + if os.path.exists(counter_path) and restore_from == 'latest': # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 + counter_base = counter def save(): maketree(checkpoint_path) print( 'Saving', os.path.join(checkpoint_path, - 'model-{}').format(counter)) + 'model-{}').format(counter-1)) saver.save( sess, os.path.join(checkpoint_path, 'model'), - global_step=counter) + global_step=counter-1) with open(counter_path, 'w') as fp: - fp.write(str(counter) + '\n') + fp.write(str(counter-1) + '\n') def generate_samples(): context_tokens = data_sampler.sample(1) @@ -219,7 +220,7 @@ def sample_batch(): try: while True: - if counter == steps: + if steps > 0 and counter == (counter_base + steps): save() return if counter % save_every == 0: diff --git a/setup.py b/setup.py index 3ed3d20..502e492 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( name='gpt_2_simple', packages=['gpt_2_simple'], # this must be the same as the name above - version='0.3', + version='0.3.1', description="Python package to easily retrain OpenAI's GPT-2 " \ "text-generating model on new texts.", long_description=long_description,