Skip to content

Commit

Permalink
Checkpoint standalone + print_every + run
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Apr 18, 2019
1 parent b4ad65d commit 71a49e3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 42 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
__pycache__
shakespeare.txt
dist/
gpt_2_simple.egg-info/
gpt_2_simple.egg-info/
build/
30 changes: 20 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This package incorporates and makes minimal low-level changes to:
* Model finetuning from Neil Shepperd's [fork](https://github.com/nshepperd/gpt-2) of GPT-2 (MIT License)
* Text generation output management from [textgenrnn](https://github.com/minimaxir/textgenrnn) (MIT License / also created by me)

For finetuning, it is **strongly** recommended to use a GPU, although you can generate using a CPU. If you are training in the cloud, using a Colaboratory notebook or a Google Compute Engine VM w/ the TensorFlow Deep Learning image is strongly recommended. (as the GPT-2 model is hosted on GCP)
For finetuning, it is **strongly** recommended to use a GPU, although you can generate using a CPU (albeit much more slowly). If you are training in the cloud, using a Colaboratory notebook or a Google Compute Engine VM w/ the [TensorFlow Deep Learning](https://cloud.google.com/deep-learning-vm/) image is strongly recommended. (as the GPT-2 model is hosted on GCP)

## Usage

Expand All @@ -20,51 +20,61 @@ pip3 install gpt-2-simple

An example for downloading the model to the local system, fineturning it on a dataset. and generating some text.

Warning: the pretrained model, and thus any finetuned model, is 500MB!
Warning: the pretrained model, and thus any finetuned model, is 500 MB!

```python
import gpt_2_simple as gpt2

gpt2.download_gpt2() # model is saved into current directory under /117M/

sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', steps=100) # steps is max number of training steps
gpt2.finetune(sess, 'shakespeare.txt', steps=1000) # steps is max number of training steps

text = gpt2.generate(sess)
gpt2.generate(sess)
```

The generated model checkpoints are by default in `checkpoint/run1`. If you want to load a model from that folder and generate text from it:
The generated model checkpoints are by default in `/checkpoint/run1`. If you want to load a model from that folder and generate text from it:

```python
import gpt_2_simple as gpt2

gpt2.download_gpt2() # model is saved into current directory under /117M/

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

text = gpt2.generate(sess)
gpt2.generate(sess)
```

As with textgenrnn, you can generate and save text for later use (e.g. an API or a bot) by using the `return_as_list` parameter.

```python
single_text = gpt2.generate(sess, return_as_list=True)[0]
print(single_text)
```

You can pass a `run_name` parameter to `finetune` and `load_gpt2` if you want to store/load multiple models in a `checkpoint` folder.

NB: *Restart the Python session first* if you want to finetune on another dataset or load another model.

## Differences Between gpt-2-simple And Other Text Generation Utilities

The method GPT-2 uses to generate text is slightly different than those like other packages like textgenrnn (specifically, generating full sequence purely in the GPU and decoding it later), and cannot easily be fixed without hacking the underlying model code. As a result:

* In general, GPT-2 is better at maintaining context over its entire gereration length, making it good for generating conversational text. The text is also generally gramatically correct, with proper capitalization and few typoes.
* The original GPT-2 model was trained on a *very* large variety of sources, allowing the model to incorporate trends not seen in the input text.
* GPT-2 can only generate a maximum of 1024 tokens per request (about 3-4 paragraphs of English text).
* GPT-2 cannot stop early upon reaching a specific end token. (workaround: pass the `truncate` parameter to a `generate` function to only collect text until a specified end token)
* Higher temperatures work better (e.g. 0.7 - 1.0) to generate more interesting text (while other frameworks work better between 0.2 - 0.5)
* Higher temperatures work better (e.g. 0.7 - 1.0) to generate more interesting text, while other frameworks work better between 0.2 - 0.5.
* When finetuning GPT-2, it has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then while generating, you can specify a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence.
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` to ~20 on Colaboratory's K80)!
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` up to 20 on Colaboratory's K80)!

## Planned Work

Note: this project is intended to have a very tight scope unless demand dictates otherwise.

* Allow users to generate texts longer than 1024 tokens.
* Allow users to use Colaboratory's TPU for finetuning.
* Allow users to use multiple GPUs (e.g. Horovod)
* For Colaboratory, allow model to automatically save checkpoints to Google Drive during training to prevent timeouts.

## Maintainer/Creator

Expand Down
72 changes: 44 additions & 28 deletions gpt_2_simple/gpt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def finetune(sess,
sample_length=1023,
sample_num=1,
save_every=1000,
print_every=1,
max_checkpoints=1,
model_load=False):
"""Finetunes the model on the given dataset.
Expand All @@ -83,15 +84,23 @@ def finetune(sess,
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

def maketree(path):
try:
os.makedirs(path)
except:
pass

enc = encoder.get_encoder(model_name)
maketree(checkpoint_path)
if not model_load:
for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
shutil.copyfile(os.path.join('models', model_name, file),
os.path.join(checkpoint_path, file))

enc = encoder.get_encoder(checkpoint_path)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

if sample_length > hparams.n_ctx:
Expand Down Expand Up @@ -127,17 +136,15 @@ def maketree(path):
loss, var_list=train_vars)
summary_loss = tf.summary.scalar('loss', loss)

summary_log = tf.summary.FileWriter(
os.path.join(CHECKPOINT_DIR, run_name))
summary_log = tf.summary.FileWriter(checkpoint_path)

saver = tf.train.Saver(
var_list=train_vars,
max_to_keep=max_checkpoints)
sess.run(tf.global_variables_initializer())

if restore_from == 'latest':
ckpt = tf.train.latest_checkpoint(
os.path.join(CHECKPOINT_DIR, run_name))
ckpt = tf.train.latest_checkpoint(checkpoint_path)
if ckpt is None:
# Get fresh GPT weights if new run.
ckpt = tf.train.latest_checkpoint(
Expand All @@ -160,22 +167,22 @@ def maketree(path):
print('Training...')

counter = 1
counter_path = os.path.join(CHECKPOINT_DIR, run_name, 'counter')
counter_path = os.path.join(checkpoint_path, 'counter')
if os.path.exists(counter_path):
# Load the step number if we're resuming a run
# Add 1 so we don't immediately try to save again
with open(counter_path, 'r') as fp:
counter = int(fp.read()) + 1

def save():
maketree(os.path.join(CHECKPOINT_DIR, run_name))
maketree(checkpoint_path)
print(
'Saving',
os.path.join(CHECKPOINT_DIR, run_name,
os.path.join(checkpoint_path,
'model-{}').format(counter))
saver.save(
sess,
os.path.join(CHECKPOINT_DIR, run_name, 'model'),
os.path.join(checkpoint_path, 'model'),
global_step=counter)
with open(counter_path, 'w') as fp:
fp.write(str(counter) + '\n')
Expand Down Expand Up @@ -230,16 +237,17 @@ def sample_batch():

summary_log.add_summary(v_summary, counter)

avg_loss = (avg_loss[0] * 0.99 + v_loss,
avg_loss[1] * 0.99 + 1.0)
if counter % print_every == 0:
avg_loss = (avg_loss[0] * 0.99 + v_loss,
avg_loss[1] * 0.99 + 1.0)

print(
'[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
.format(
counter=counter,
time=time.time() - start_time,
loss=v_loss,
avg=avg_loss[0] / avg_loss[1]))
print(
'[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
.format(
counter=counter,
time=time.time() - start_time,
loss=v_loss,
avg=avg_loss[0] / avg_loss[1]))

counter += 1
except KeyboardInterrupt:
Expand All @@ -248,12 +256,12 @@ def sample_batch():


def load_gpt2(sess,
checkpoint_path=os.path.join('models', '117M')):
run_name="run1"):
"""Loads the model checkpoint into a TensorFlow session
for repeated predictions.
"""

finetune(sess, '', model_load=True)
finetune(sess, '', run_name=run_name, model_load=True)


def generate(sess,
Expand All @@ -266,9 +274,10 @@ def generate(sess,
seed=None,
nsamples=1,
batch_size=1,
length=1024,
length=1023,
temperature=0.7,
top_k=0):
top_k=0,
run_name='run1'):
"""Generates text from a model loaded into memory.
Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
Expand All @@ -284,9 +293,14 @@ def generate(sess,
if prefix:
context = tf.placeholder(tf.int32, [batch_size, None])

enc = encoder.get_encoder(model_name)
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

enc = encoder.get_encoder(checkpoint_path)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

np.random.seed(seed)
Expand Down Expand Up @@ -345,9 +359,10 @@ def generate_to_file(sess,
seed=None,
nsamples=1,
batch_size=1,
length=1024,
length=1023,
temperature=0.7,
top_k=0):
top_k=0,
run_name='run1'):
"""Generates the texts to a file.
sample_delim separates texts: set to '' if each text is a small document.
Expand All @@ -367,7 +382,8 @@ def generate_to_file(sess,
batch_size,
length,
temperature,
top_k)
top_k,
run_name)


def mount_gdrive():
Expand Down
6 changes: 3 additions & 3 deletions gpt_2_simple/src/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def decode(self, tokens):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text

def get_encoder(model_name):
with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
def get_encoder(checkpoint_path):
with open(os.path.join(checkpoint_path, 'encoder.json'), 'r') as f:
encoder = json.load(f)
with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
with open(os.path.join(checkpoint_path, 'vocab.bpe'), 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(
Expand Down

0 comments on commit 71a49e3

Please sign in to comment.