diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py index 37297a2..2103532 100644 --- a/gpt_2_simple/gpt_2.py +++ b/gpt_2_simple/gpt_2.py @@ -93,7 +93,7 @@ def download_gpt2(model_dir='models', model_name='124M'): file_name=file_name) -def start_tf_sess(threads=-1, server=None): +def start_tf_sess(threads=-1, server=None, reuse=False): """ Returns a tf.Session w/ config """ @@ -105,9 +105,9 @@ def start_tf_sess(threads=-1, server=None): config.inter_op_parallelism_threads = threads if server is not None: - return tf.compat.v1.Session(target=server.target, config=config) + return tf.compat.v1.Session(target=server.target, config=config, reuse=reuse) - return tf.compat.v1.Session(config=config) + return tf.compat.v1.Session(config=config, reuse=reuse) def reset_session(sess, threads=-1, server=None):