Commit e0a0886 (0 parents): 9 changed files with 41,387 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,2 @@
__*
log
@@ -0,0 +1,3 @@
# Recreating gpt2
## About
Recreated OpenAI's GPT-2 by studying the GPT-2 and GPT-3 papers and following Andrej Karpathy's Make More series. Trained the model on 8 H100s rented through Lambda Labs, achieving a lower final loss than OpenAI's original GPT-2.
@@ -0,0 +1,77 @@
"""
Loads the FineWeb-Edu dataset from Hugging Face and tokenizes it with the GPT-2 tokenizer.
Saves shards of the tokenized dataset to the local directory "edu_fineweb10B".
"""
import os
import multiprocessing as mp
import numpy as np
import tiktoken
from datasets import load_dataset  # pip install datasets
from tqdm import tqdm

# --------------------------------------------------------------------------------------------------
local_dir = "edu_fineweb10B"
remote_name = "sample-10BT"
shard_size = int(1e8)  # 100M tokens per shard, 100 shards in total

# create the local cache dir if it doesn't exist
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# download the dataset from the internet
fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")

# init the tokenizer
enc = tiktoken.get_encoding('gpt2')
eot = enc._special_tokens["<|endoftext|>"]  # end-of-text token, delimits documents

def tokenize(doc):
    # tokenizes a single document and returns the tokens as an array of np.uint16
    tokens = [eot]  # this special token begins every document
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    # GPT-2 token ids are below 2**16, so they fit in uint16
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all()
    tokens_np_uint16 = tokens_np.astype(np.uint16)
    return tokens_np_uint16

def write_datafile(filename, tokens_np):
    # writes a numpy array of tokens to a binary .npy file
    np.save(filename, tokens_np)

# tokenize all documents and save the output shards, each holding shard_size tokens
nprocs = max(1, os.cpu_count() // 2)
with mp.Pool(nprocs) as pool:
    shard_index = 0
    # preallocate a buffer to hold the current shard
    all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
    token_count = 0
    progress_bar = None
    for tokens in pool.imap(tokenize, fw, chunksize=16):

        # is there enough space in the current shard for the new tokens?
        if token_count + len(tokens) < shard_size:
            # append tokens to the current shard
            all_tokens_np[token_count:token_count+len(tokens)] = tokens
            token_count += len(tokens)
            # update the progress bar
            if progress_bar is None:
                progress_bar = tqdm(total=shard_size, unit='tokens', desc=f'Shard {shard_index}')
            progress_bar.update(len(tokens))
        else:
            # write the current shard and start a new one
            split = 'val' if shard_index == 0 else 'train'
            filename = os.path.join(DATA_CACHE_DIR, f"edu_fineweb_{split}_{shard_index:06d}")
            # the document spills over: fill the shard with whatever fits
            remainder = shard_size - token_count
            all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
            write_datafile(filename, all_tokens_np)
            shard_index += 1
            progress_bar = None
            # populate the next shard with the leftovers of the current doc
            token_count = len(tokens) - remainder
            all_tokens_np[:token_count] = tokens[remainder:]

    # write any remaining tokens as the last shard
    if token_count != 0:
        split = 'val' if shard_index == 0 else 'train'
        filename = os.path.join(DATA_CACHE_DIR, f"edu_fineweb_{split}_{shard_index:06d}")
        write_datafile(filename, all_tokens_np[:token_count])
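As a quick sanity check (not part of the commit), each shard is written with np.save as a flat uint16 array, so it can be loaded back and the first document decoded. A minimal sketch, assuming the script was run from the repo root so the first validation shard lands at edu_fineweb10B/edu_fineweb_val_000000.npy:

import numpy as np
import tiktoken

# load one shard back (np.save appends the .npy extension)
shard = np.load("edu_fineweb10B/edu_fineweb_val_000000.npy")
print(shard.shape, shard.dtype)  # expected: (100000000,) uint16

# decode the first document: index 0 is the <|endoftext|> delimiter,
# so scan forward until the next delimiter
enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens["<|endoftext|>"]
end = 1
while end < len(shard) and shard[end] != eot:
    end += 1
print(enc.decode(shard[1:end].tolist())[:500])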
@@ -0,0 +1,64 @@
import torch
from torch.nn import functional as F
import tiktoken
from train_gpt2 import GPT2Config, GPT

device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
device = 'cpu'  # override the autodetected device and force CPU
print(f"{device=}")

# load the trained base gpt-2 model (124M params) from the checkpoint
checkpoint = torch.load("./log/model_19072.pt", map_location=device)
print(checkpoint['config'])
model = GPT(checkpoint['config'])
model.load_state_dict(checkpoint['model'])
model.eval()  # inference mode: disables dropout

num_return_sequences = 5
max_length = 60

# init the tokenizer for the prefix tokens
enc = tiktoken.get_encoding('gpt2')

# get an input prompt, then generate possible completions
print("This is Alex's GPT-2. Please enter a phrase or sentence you wish to be completed:")
in_string = input()

tokens = enc.encode(in_string)
tokens = torch.tensor(tokens, dtype=torch.long)  # (T,)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # (num_return_sequences, T)
x = tokens.to(device)

# generate! x is (B, T) with B = num_return_sequences
# set the seed to 42 for reproducible sampling
torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
    with torch.no_grad():
        # forward the model to get the logits
        logits = model(x)[0]
        # take the logits at the last position of each sequence
        logits = logits[:, -1, :]  # (B, vocab_size)
        # get the probabilities via softmax
        probs = F.softmax(logits, dim=-1)
        # top-k sampling with k=50 (the Hugging Face default)
        # topk_probs and topk_indices are both (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # sample a token position from the top-k probabilities
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)
        # map back to the corresponding vocabulary indices
        xcol = torch.gather(topk_indices, -1, ix)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)

# print the generated text
for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)
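For intuition, here is a tiny self-contained illustration (not part of the commit) of the top-k sample-and-gather step the loop above relies on, using a toy 6-token vocabulary:

import torch
from torch.nn import functional as F

torch.manual_seed(0)
# pretend logits for B=2 sequences over a toy vocabulary of 6 tokens
logits = torch.randn(2, 6)
probs = F.softmax(logits, dim=-1)

# keep only the k=3 most probable tokens per row
topk_probs, topk_indices = torch.topk(probs, 3, dim=-1)  # both (2, 3)

# multinomial samples a position 0..2 within each row of topk_probs
ix = torch.multinomial(topk_probs, 1)  # (2, 1)
# gather maps that position back to the real vocabulary id
next_token = torch.gather(topk_indices, -1, ix)  # (2, 1)
print(next_token)

Note that torch.multinomial renormalizes each row internally, so the truncated top-k probabilities do not need to sum to 1 before sampling.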
@@ -0,0 +1,177 @@
"""
Downloads and evaluates HellaSwag in Python.
https://github.com/rowanz/hellaswag

Example HellaSwag json item:
{"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}

ind: dataset ID
activity_label: The ActivityNet or WikiHow label for this example
context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
endings: a list of 4 endings. The correct index is given by label (0, 1, 2, or 3)
split: train, val, or test.
split_type: indomain if the activity label is seen during training, else zeroshot
source_id: Which video or WikiHow article this example came from

gpt2 (124M)
- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)
gpt2-xl (1558M)
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)

The validation set of HellaSwag has a total of 10,042 examples.
"""

import os
import json
import requests
import tiktoken
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

hellaswags = {
    "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
}

enc = tiktoken.get_encoding("gpt2")

def download(split):
    """Downloads HellaSwag to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    data_url = hellaswags[split]
    data_filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)

def render_example(example):
    """
    Given the example as a dictionary, render it as three torch tensors:
    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
    - label (the index of the correct completion, which we hope has the highest likelihood)
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]

    # data needed to reproduce this eval on the C side
    data = {
        "label": label,
        "ctx_tokens": None,
        "ending_tokens": [],
    }

    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    data["ctx_tokens"] = ctx_tokens
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(" " + end)  # note: prepending " " because of the GPT-2 tokenizer
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
        data["ending_tokens"].append(end_tokens)

    # be careful during collation because the number of tokens in each row can differ
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((4, max_len), dtype=torch.long)
    mask = torch.zeros((4, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)

    return data, tokens, mask, label

def iterate_examples(split):
    # there are 10,042 examples in total in val
    download(split)
    with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
        for line in f:
            example = json.loads(line)
            yield example

@torch.no_grad()
def evaluate(model_type, device):

    torch.set_float32_matmul_precision('high')  # use tf32
    model = GPT2LMHeadModel.from_pretrained(model_type)
    model.to(device)
    # model = torch.compile(model)  # optionally torch compile the model

    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in iterate_examples("val"):
        data, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)

        # get the logits
        logits = model(tokens).logits
        # evaluate the autoregressive loss at all positions
        shift_logits = (logits[..., :-1, :]).contiguous()
        shift_tokens = (tokens[..., 1:]).contiguous()
        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_shift_tokens = shift_tokens.view(-1)
        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
        shift_losses = shift_losses.view(tokens.size(0), -1)
        # now get the average loss just for the completion region (where mask == 1), in each row
        shift_mask = (mask[..., 1:]).contiguous()  # we must shift the mask, so we start at the last prompt token
        masked_shift_losses = shift_losses * shift_mask
        # sum and divide by the number of 1s in the mask
        sum_loss = masked_shift_losses.sum(dim=1)
        avg_loss = sum_loss / shift_mask.sum(dim=1)
        # now we have a loss for each of the 4 completions
        # the one with the lowest loss should be the most likely
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == label)
        num_correct_norm += int(pred_norm == label)
        print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")

        # debug: pretty print a few examples, and the losses in each case
        if num_total < 10:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print("Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use")
    parser.add_argument("-d", "--device", type=str, default="cuda", help="the device to use")
    args = parser.parse_args()
    evaluate(args.model_type, args.device)
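To see what render_example produces, one can feed it the sample item from the docstring. A minimal sketch (not part of the commit), assuming the file is saved as hellaswag.py so it is importable under that hypothetical module name:

# the sample item from the docstring; render_example only reads ctx, label, endings
from hellaswag import render_example

example = {
    "ctx": "A man is sitting on a roof. he",
    "label": 3,
    "endings": [
        "is using wrap to wrap a pair of skis.",
        "is ripping level tiles off.",
        "is holding a rubik's cube.",
        "starts pulling up roofing on a roof.",
    ],
}
data, tokens, mask, label = render_example(example)
print(tokens.shape, mask.shape, label)  # e.g. torch.Size([4, N]) torch.Size([4, N]) 3
# row i of tokens holds context + ending i; row i of mask is 1 exactly over ending i's tokens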