From e5c5054474f583d6d9499624649353995d63c70a Mon Sep 17 00:00:00 2001 From: Memo Akten Date: Thu, 16 May 2019 19:42:58 +0300 Subject: [PATCH] allow models to be in a separate folder via models_dir argument (#129) * models_dir argument to allow models in a separate folder * default value for models_dir to be same as before * allow environment variables and user home in models_dir --- src/encoder.py | 6 +++--- src/generate_unconditional_samples.py | 10 +++++++--- src/interactive_conditional_samples.py | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/encoder.py b/src/encoder.py index 03e0ce211..5f52e723c 100644 --- a/src/encoder.py +++ b/src/encoder.py @@ -105,10 +105,10 @@ def decode(self, tokens): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text -def get_encoder(model_name): - with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f: +def get_encoder(model_name, models_dir): + with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: encoder = json.load(f) - with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: + with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: bpe_data = f.read() bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] return Encoder( diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py index 87e212972..f18a83891 100755 --- a/src/generate_unconditional_samples.py +++ b/src/generate_unconditional_samples.py @@ -16,6 +16,7 @@ def sample_model( length=None, temperature=1, top_k=0, + models_dir='models', ): """ Run the sample_model @@ -35,10 +36,13 @@ def sample_model( considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. 
+    :models_dir : path to parent folder containing model subfolders +     (i.e. contains the <model_name> folder) """ -    enc = encoder.get_encoder(model_name) +    models_dir = os.path.expanduser(os.path.expandvars(models_dir)) +    enc = encoder.get_encoder(model_name, models_dir) hparams = model.default_hparams() -    with open(os.path.join('models', model_name, 'hparams.json')) as f: +    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if length is None: @@ -58,7 +62,7 @@ def sample_model( )[:, 1:] saver = tf.train.Saver() -        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) +        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) saver.restore(sess, ckpt) generated = 0 diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py index 166171aaf..ae348d842 100755 --- a/src/interactive_conditional_samples.py +++ b/src/interactive_conditional_samples.py @@ -16,6 +16,7 @@ def interact_model( length=None, temperature=1, top_k=0, +    models_dir='models', ): """ Interactively run the model @@ -34,14 +35,17 @@ considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. +    :models_dir : path to parent folder containing model subfolders +     (i.e. 
contains the <model_name> folder) """ +    models_dir = os.path.expanduser(os.path.expandvars(models_dir)) if batch_size is None: batch_size = 1 assert nsamples % batch_size == 0 -    enc = encoder.get_encoder(model_name) +    enc = encoder.get_encoder(model_name, models_dir) hparams = model.default_hparams() -    with open(os.path.join('models', model_name, 'hparams.json')) as f: +    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if length is None: @@ -61,7 +65,7 @@ def interact_model( ) saver = tf.train.Saver() -        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) +        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) saver.restore(sess, ckpt) while True: