diff --git a/src/generate.py b/src/generate.py
new file mode 100644
index 000000000..ab1f92393
--- /dev/null
+++ b/src/generate.py
@@ -0,0 +1,96 @@
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+import model, sample, encoder
+
+
+def samples(
+    prompt,
+    model_name='124M',
+    seed=None,
+    nsamples=1,
+    batch_size=1,
+    length=None,
+    temperature=1,
+    top_k=40,
+    top_p=1,
+    models_dir='models'
+):
+    """
+    Run the sampling model on a prompt and return the completions as a list.
+    :prompt : String to condition the model on.
+    :model_name=124M : String, which model to use.
+    :seed=None : Integer seed for the random number generators; fix the
+     seed to reproduce results.
+    :nsamples=1 : Number of samples to return; must be divisible by
+     batch_size.
+    :batch_size=1 : Number of samples per batch (only affects speed/memory).
+    :length=None : Number of tokens in the generated text; if None
+     (default), it is determined by the model hyperparameters.
+    :temperature=1 : Float value controlling randomness in the Boltzmann
+     distribution. Lower temperature results in less random completions. As
+     the temperature approaches zero, the model becomes deterministic and
+     repetitive. Higher temperature results in more random completions.
+    :top_k=40 : Integer value controlling diversity. 1 means only 1 token
+     is considered at each step, resulting in deterministic completions,
+     while 40 means 40 tokens are considered at each step. 0 is a special
+     setting meaning no restrictions. 40 (default) is generally a good
+     value.
+    :top_p=1 : Float value controlling diversity via nucleus sampling;
+     restricts sampling to the smallest set of tokens whose cumulative
+     probability exceeds top_p. 1 (default) means no restriction.
+    :models_dir='models' : Path to the parent folder containing model
+     subfolders (i.e. contains the <model_name> folder).
+    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    if batch_size is None:
+        batch_size = 1
+    assert nsamples % batch_size == 0
+
+    enc = encoder.get_encoder(model_name, models_dir)
+    hparams = model.default_hparams()
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
+        hparams.update(json.load(f))
+
+    if length is None:
+        length = hparams['n_ctx'] // 2
+    elif length > hparams['n_ctx']:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams['n_ctx'])
+
+    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
+        np.random.seed(seed)
+        tf.compat.v1.set_random_seed(seed)
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            context=context,
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k, top_p=top_p
+        )
+
+        # Restore the latest checkpoint for the requested model.
+        saver = tf.compat.v1.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
+        saver.restore(sess, ckpt)
+
+        context_tokens = enc.encode(prompt)
+        generated_text_samples = []
+        for _ in range(nsamples // batch_size):
+            # Feed the prompt to every row of the batch, then strip the
+            # prompt tokens so only the completion remains.
+            out = sess.run(output, feed_dict={
+                context: [context_tokens for _ in range(batch_size)]
+            })[:, len(context_tokens):]
+            for i in range(batch_size):
+                generated_text_samples.append(enc.decode(out[i]))
+
+        return generated_text_samples
+
+
+if __name__ == "__main__":
+    prompt = 'The FitnessGram Pacer Test.'
+    output = samples(prompt)
+    with open('test.json', 'w') as f:
+        f.write(json.dumps(output, indent=4))
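
For reviewers, a minimal, hypothetical usage sketch (not part of the diff). It assumes the standard openai/gpt-2 layout: `src/` on `PYTHONPATH`, with the 124M weights already downloaded into `models/124M` (e.g. via the repository's `download_model.py`).

```python
# Hypothetical driver script -- illustrative only, not part of this diff.
# Assumes src/ is importable and models/124M contains hparams.json plus
# the checkpoint files fetched by openai/gpt-2's download_model.py.
from generate import samples

texts = samples(
    'In a shocking finding, scientists discovered',
    nsamples=2,        # must be divisible by batch_size (default 1)
    length=64,         # tokens to generate; None defaults to n_ctx // 2
    temperature=0.8,   # lower -> less random completions
    top_k=40,          # sample from the 40 most likely tokens per step
    seed=42,           # fix RNG seeds for reproducibility
)
for i, text in enumerate(texts):
    print('=== SAMPLE %d ===' % i)
    print(text)
```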