diff --git a/src/encoder.py b/src/encoder.py index 03e0ce211..5f52e723c 100644 --- a/src/encoder.py +++ b/src/encoder.py @@ -105,10 +105,10 @@ def decode(self, tokens): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text -def get_encoder(model_name): - with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f: +def get_encoder(model_name, models_dir): + with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: encoder = json.load(f) - with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: + with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: bpe_data = f.read() bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] return Encoder( diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py index 87e212972..f18a83891 100755 --- a/src/generate_unconditional_samples.py +++ b/src/generate_unconditional_samples.py @@ -16,6 +16,7 @@ def sample_model( length=None, temperature=1, top_k=0, + models_dir='models', ): """ Run the sample_model @@ -35,10 +36,13 @@ def sample_model( considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. + :models_dir : path to parent folder containing model subfolders + (i.e. contains the folder) """ - enc = encoder.get_encoder(model_name) + models_dir = os.path.expanduser(os.path.expandvars(models_dir)) + enc = encoder.get_encoder(model_name, models_dir) hparams = model.default_hparams() - with open(os.path.join('models', model_name, 'hparams.json')) as f: + with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if length is None: @@ -58,7 +62,7 @@ def sample_model( )[:, 1:] saver = tf.train.Saver() - ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) + ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) saver.restore(sess, ckpt) generated = 0 diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py index 166171aaf..ae348d842 100755 --- a/src/interactive_conditional_samples.py +++ b/src/interactive_conditional_samples.py @@ -16,6 +16,7 @@ def interact_model( length=None, temperature=1, top_k=0, + models_dir='models', ): """ Interactively run the model @@ -34,14 +35,17 @@ def interact_model( considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. + :models_dir : path to parent folder containing model subfolders + (i.e. contains the folder) """ + models_dir = os.path.expanduser(os.path.expandvars(models_dir)) if batch_size is None: batch_size = 1 assert nsamples % batch_size == 0 - enc = encoder.get_encoder(model_name) + enc = encoder.get_encoder(model_name, models_dir) hparams = model.default_hparams() - with open(os.path.join('models', model_name, 'hparams.json')) as f: + with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if length is None: @@ -61,7 +65,7 @@ def interact_model( ) saver = tf.train.Saver() - ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) + ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) saver.restore(sess, ckpt) while True: