Merge pull request #15 from minimaxir/cli
Cli
minimaxir authored Apr 21, 2019
2 parents 15477cf + 636bd79 commit cd257bb
Showing 4 changed files with 185 additions and 13 deletions.
30 changes: 25 additions & 5 deletions README.md
@@ -61,6 +61,26 @@ print(single_text)

You can pass a `run_name` parameter to `finetune` and `load_gpt2` if you want to store/load multiple models in a `checkpoint` folder.
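
For example, a minimal sketch (assuming the conventional `import gpt_2_simple as gpt2` alias; the run name `shakespeare` is just an illustration):

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.finetune(sess, dataset='shakespeare.txt', run_name='shakespeare')

# later, in a fresh Python session, reload that checkpoint by name:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='shakespeare')
```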

There is also a command-line interface for both finetuning and generation, with strong defaults for running on a cloud VM with a GPU. For finetuning (which will also download the model if not present):

```shell
gpt_2_simple finetune shakespeare.txt
```

And for generation, which generates texts to files in a `gen` folder:

```shell
gpt_2_simple generate
```

Most of the same parameters available in the functions are available as CLI arguments, e.g.:

```shell
gpt_2_simple generate --temperature 1.0 --nsamples 20 --batch_size 20 --length 50 --prefix "<|startoftext|>" --truncate "<|endoftext|>" --include_prefix False --nfiles 5
```

See below for what some of the CLI arguments do.

NB: *Restart the Python session first* if you want to finetune on another dataset or load another model.

## Differences Between gpt-2-simple And Other Text Generation Utilities
@@ -72,8 +92,9 @@ The method GPT-2 uses to generate text is slightly different than those like oth
* GPT-2 can only generate a maximum of 1024 tokens per request (about 3-4 paragraphs of English text).
* GPT-2 cannot stop early upon reaching a specific end token. (workaround: pass the `truncate` parameter to a `generate` function to only collect text until a specified end token. You may want to reduce `length` appropriately.)
* Higher temperatures work better (e.g. 0.7 - 1.0) to generate more interesting text, while other frameworks work better between 0.2 - 0.5.
* When finetuning GPT-2, it has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then while generating, you can specify a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence.
* When finetuning, GPT-2 has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then while generating, you can specify a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence. You can also set `include_prefix=False` to discard the prefix token while generating (e.g. if it's something unwanted like `<|startoftext|>`). A short sketch of this workflow follows this list.
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` up to 20 on Colaboratory's K80)!
* Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. If you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80 for about 3x the price, making the cost per unit of training comparable (the V100 is about 1.5x faster than the P100, but about 2x the price). The P100 runs at 100% GPU utilization even with `batch_size=1`; the V100 runs at about 88%.
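
A sketch of the prefix/truncate workflow described above (assuming the `gpt2` alias and a finetuned model already loaded into `sess`; the token strings are examples):

```python
gpt2.generate(sess,
              run_name='run1',
              length=100,
              prefix='<|startoftext|>',
              truncate='<|endoftext|>',
              include_prefix=False,
              nsamples=4,
              batch_size=4)  # batch_size must evenly divide nsamples
```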

## Planned Work

@@ -86,13 +107,12 @@ Note: this project is intended to have a very tight scope unless demand dictates

## Examples Using gpt-2-simple

* [ResetEra](https://www.resetera.com/threads/i-trained-an-ai-on-thousands-of-resetera-thread-conversations-and-it-created-hot-gaming-shitposts.112167/) — Generated video game forum discussions
* [ResetEra](https://www.resetera.com/threads/i-trained-an-ai-on-thousands-of-resetera-thread-conversations-and-it-created-hot-gaming-shitposts.112167/) — Generated video game forum discussions ([GitHub w/ dumps](https://github.com/minimaxir/resetera-gpt-2))
* [/r/legaladvice](https://www.reddit.com/r/legaladviceofftopic/comments/bfqf22/i_trained_a_moreadvanced_ai_on_rlegaladvice/) — Title generation ([GitHub w/ dumps](https://github.com/minimaxir/legaladvice-gpt2))

## Maintainer/Creator

Max Woolf ([@minimaxir](http://minimaxir.com))

*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*
Max Woolf ([@minimaxir](https://minimaxir.com))

## License

161 changes: 155 additions & 6 deletions gpt_2_simple/gpt_2.py
@@ -4,11 +4,13 @@
import sys
import shutil
import re
from tqdm import tqdm
from tqdm import tqdm, trange
import numpy as np
import tensorflow as tf
import time
from datetime import datetime
import csv
import argparse

# if in Google Colaboratory
try:
@@ -278,7 +280,8 @@ def generate(sess,
length=1023,
temperature=0.7,
top_k=0,
run_name='run1'):
run_name='run1',
include_prefix=True):
"""Generates text from a model loaded into memory.
Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
@@ -334,8 +337,15 @@
if prefix:
gen_text = prefix[0] + gen_text
if truncate:
trunc_text = re.search(r'(.*?)(?:{})'.format(truncate),
gen_text, re.S)
truncate_esc = re.escape(truncate)
if prefix and not include_prefix:
prefix_esc = re.escape(prefix)
pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
truncate_esc)
else:
pattern = '(.*?)(?:{})'.format(truncate_esc)

trunc_text = re.search(pattern, gen_text, re.S)
if trunc_text:
gen_text = trunc_text.group(1)
if destination_path:
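
As a sanity check on the truncation patterns above, a minimal sketch with made-up sample text (not part of the commit):

```python
import re

prefix, truncate = "<|startoftext|>", "<|endoftext|>"
gen_text = "<|startoftext|>Hello world.<|endoftext|>trailing tokens"

# include_prefix=False: keep only the text between the two tokens
pattern = '(?:{})(.*?)(?:{})'.format(re.escape(prefix), re.escape(truncate))
print(re.search(pattern, gen_text, re.S).group(1))  # -> Hello world.
```
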
@@ -363,7 +373,8 @@ def generate_to_file(sess,
length=1023,
temperature=0.7,
top_k=0,
run_name='run1'):
run_name='run1',
include_prefix=True):
"""Generates the texts to a file.
sample_delim separates texts: set to '' if each text is a small document.
@@ -384,7 +395,8 @@
length,
temperature,
top_k,
run_name)
run_name,
include_prefix)


def mount_gdrive():
@@ -452,3 +464,140 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
reader = csv.reader(f)
for row in reader:
w.write(start_token + row[0] + end_token + "\n")


def cmd():
"""Function called when invoking from the terminal."""

parser = argparse.ArgumentParser(
description="Easily retrain OpenAI's GPT-2 text-generating model on new texts. (https://github.com/minimaxir/gpt-2-simple)"
)

# Explicit arguments

parser.add_argument(
'--mode', help='Mode for using the CLI (either "finetune" or "generate") [Required]', nargs='?')
parser.add_argument(
'--run_name', help="[finetune/generate] Run number to save/load the model",
nargs='?', default='run1')
parser.add_argument(
'--dataset', help="[finetune] Path to the source text.",
nargs='?', default=None)
parser.add_argument(
'--steps', help="[finetune] Number of steps to train (-1 for infinite)",
nargs='?', default=-1, type=int)
parser.add_argument(
'--restore_from', help="[finetune] Whether to load model 'fresh' or from 'latest' checkpoint.",
nargs='?', default='latest')
parser.add_argument(
'--sample_every', help="[finetune] After how many steps to print sample",
nargs='?', default=1000000, type=int)
parser.add_argument(
'--save_every', help="[finetune] After how many steps to save checkpoint",
nargs='?', default=100, type=int)
parser.add_argument(
'--print_every', help="[finetune] After how many steps to print progress",
nargs='?', default=10, type=int)
parser.add_argument(
'--nfiles', help="[generate] How many files to generate.",
nargs='?', default=1, type=int)
parser.add_argument(
'--nsamples', help="[generate] How many texts to generate.",
nargs='?', default=1, type=int)
parser.add_argument(
'--folder', help="[generate] Folder to save the generated files",
nargs='?', default="gen", type=str)
parser.add_argument(
'--length', help="[generate] Length (tokens) of the generated texts",
nargs='?', default=1023, type=int)
parser.add_argument(
'--temperature', help="[generate] Temperature of the generated texts",
nargs='?', default=0.7, type=float)
parser.add_argument(
'--batch_size', help="[generate] Batch size for generation (increase for GPUs)",
nargs='?', default=1, type=int)
parser.add_argument(
'--prefix', help="[generate] Prefix for generated texts",
nargs='?', default=None)
parser.add_argument(
'--truncate', help="[generate] Truncation for generated texts",
nargs='?', default=None)
# https://stackoverflow.com/a/46951029
parser.add_argument(
'--include_prefix', help="[generate] Include prefix when truncating.",
nargs='?', default=True, type=lambda x: (str(x).lower() == 'true'))
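    # note: `--include_prefix False` parses to False here; a plain type=bool
    # would treat any non-empty string, including "False", as True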
parser.add_argument(
'--sample_delim', help="[generate] Delimiter between each generated sample.",
nargs='?', default='=' * 20 + '\n', type=str)

# Positional arguments
parser.add_argument('mode', nargs='?')
parser.add_argument('dataset', nargs='?')
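    # positional fallbacks let `gpt_2_simple finetune shakespeare.txt` work
    # without spelling out --mode and --dataset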

args = parser.parse_args()
assert args.mode in ['finetune', 'generate'], "Mode must be 'finetune' or 'generate'"

if args.mode == 'finetune':
assert args.dataset is not None, "You need to provide a dataset."

cmd_finetune(dataset=args.dataset, run_name=args.run_name,
steps=args.steps, restore_from=args.restore_from,
sample_every=args.sample_every,
save_every=args.save_every,
print_every=args.print_every)
if args.mode == "generate":
cmd_generate(nfiles=args.nfiles, nsamples=args.nsamples,
folder=args.folder, length=args.length,
temperature=args.temperature, batch_size=args.batch_size,
prefix=args.prefix, truncate=args.truncate,
include_prefix=args.include_prefix,
sample_delim=args.sample_delim)


def cmd_finetune(dataset, run_name, steps, restore_from, sample_every,
save_every, print_every):
"""Wrapper script for finetuning the model via the CLI."""

if not is_gpt2_downloaded():
download_gpt2()

sess = start_tf_sess()
finetune(sess, dataset=dataset, run_name=run_name,
steps=steps, restore_from=restore_from,
sample_every=sample_every, save_every=save_every,
print_every=print_every)


def cmd_generate(nfiles, nsamples, folder,
length, temperature, batch_size,
prefix, truncate, include_prefix,
sample_delim):
"""Wrapper script for generating text via the CLI.
The files are generated into a folder, which can be downloaded
recursively by downloading the entire folder.
"""

sess = start_tf_sess()
load_gpt2(sess)

try:
os.mkdir(folder)
except FileExistsError:
    # the folder already exists; wipe it so generation starts fresh
shutil.rmtree(folder)
os.mkdir(folder)

for _ in trange(nfiles):
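    # trange (tqdm's range) draws a progress bar over the nfiles iterations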
gen_file = os.path.join(folder,
'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow()))

generate_to_file(sess,
destination_path=gen_file,
length=length,
temperature=temperature,
nsamples=nsamples,
batch_size=batch_size,
prefix=prefix,
truncate=truncate,
include_prefix=include_prefix,
sample_delim=sample_delim
)
2 changes: 1 addition & 1 deletion gpt_2_simple/src/load_dataset.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ def load_dataset(enc, path, combine):
token_chunks.append(npz[item])
else:
# Plain text
with open(path, 'r') as fp:
with open(path, 'r', encoding='utf8', errors='ignore') as fp:
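        # errors='ignore' skips undecodable bytes rather than raising UnicodeDecodeError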
raw_text += fp.read()
if len(raw_text) >= combine:
tokens = np.stack(enc.encode(raw_text))
5 changes: 4 additions & 1 deletion setup.py
@@ -47,7 +47,7 @@
setup(
name='gpt_2_simple',
packages=['gpt_2_simple'], # this must be the same as the name above
version='0.2',
version='0.3',
description="Python package to easily retrain OpenAI's GPT-2 " \
"text-generating model on new texts.",
long_description=long_description,
@@ -58,6 +58,9 @@
keywords=['deep learning', 'tensorflow', 'text generation'],
classifiers=[],
license='MIT',
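    # registers the `gpt_2_simple` terminal command, wired to cmd() in gpt_2.py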
entry_points={
'console_scripts': ['gpt_2_simple=gpt_2_simple.gpt_2:cmd'],
},
python_requires='>=3.5',
include_package_data=True,
install_requires=['regex', 'requests', 'tqdm', 'numpy']