Skip to content

Commit

Permalink
allow models to be in a separate folder via models_dir argument (openai#129)
Browse files Browse the repository at this point in the history

* models_dir argument to allow models in a separate folder

* default value for models_dir to be same as before

* allow environment variables and user home in models_dir
  • Loading branch information
memo authored and WuTheFWasThat committed May 16, 2019
1 parent dd75299 commit e5c5054
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def decode(self, tokens):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text

def get_encoder(model_name):
with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
def get_encoder(model_name, models_dir):
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
encoder = json.load(f)
with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(
Expand Down
10 changes: 7 additions & 3 deletions src/generate_unconditional_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def sample_model(
length=None,
temperature=1,
top_k=0,
models_dir='models',
):
"""
Run the sample_model
Expand All @@ -35,10 +36,13 @@ def sample_model(
considered for each step (token), resulting in deterministic completions,
while 40 means 40 words are considered at each step. 0 (default) is a
special setting meaning no restrictions. 40 generally is a good value.
:models_dir : path to parent folder containing model subfolders
(i.e. contains the <model_name> folder)
"""
enc = encoder.get_encoder(model_name)
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
enc = encoder.get_encoder(model_name, models_dir)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

if length is None:
Expand All @@ -58,7 +62,7 @@ def sample_model(
)[:, 1:]

saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
saver.restore(sess, ckpt)

generated = 0
Expand Down
10 changes: 7 additions & 3 deletions src/interactive_conditional_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
models_dir='models',
):
"""
Interactively run the model
Expand All @@ -34,14 +35,17 @@ def interact_model(
considered for each step (token), resulting in deterministic completions,
while 40 means 40 words are considered at each step. 0 (default) is a
special setting meaning no restrictions. 40 generally is a good value.
:models_dir : path to parent folder containing model subfolders
(i.e. contains the <model_name> folder)
"""
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
if batch_size is None:
batch_size = 1
assert nsamples % batch_size == 0

enc = encoder.get_encoder(model_name)
enc = encoder.get_encoder(model_name, models_dir)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

if length is None:
Expand All @@ -61,7 +65,7 @@ def interact_model(
)

saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
saver.restore(sess, ckpt)

while True:
Expand Down

0 comments on commit e5c5054

Please sign in to comment.