Adds speculative decoding #27

Open · wants to merge 1 commit into base: main

This PR adds a greedy speculative decoding path to gpt2.py: the 124M checkpoint acts as a draft model that proposes a few tokens at a time, the main model verifies them in a single forward pass, and a new benchmark_speculative.py script compares the result against standard generation.
50 changes: 50 additions & 0 deletions benchmark_speculative.py
@@ -0,0 +1,50 @@
import time

from gpt2 import main as gpt2_main


def benchmark_generation(prompt, n_tokens_to_generate, model_size, use_speculative):
    start_time = time.time()
    gpt2_main(prompt, n_tokens_to_generate, model_size, use_generate_speculative=use_speculative)
    end_time = time.time()
    return end_time - start_time


def run_benchmark(prompt, n_tokens_to_generate, model_sizes):
    results = {}

    for model_size in model_sizes:
        print(f"Benchmarking {model_size} model...")

        # warm-up runs; note that each call reloads the checkpoints from disk,
        # so the timings below include model loading as well as generation
        benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        # actual benchmark
        standard_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        speculative_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        improvement = (standard_time - speculative_time) / standard_time * 100
        results[model_size] = {
            "standard_time": standard_time,
            "speculative_time": speculative_time,
            "improvement": improvement,
        }

    return results


def main():
    prompt = "In a world where artificial intelligence has become ubiquitous"
    n_tokens_to_generate = 50
    model_sizes = ["124M", "355M"]

    results = run_benchmark(prompt, n_tokens_to_generate, model_sizes)

    print("\nBenchmark Results:")
    print("==================")
    for model_size, data in results.items():
        print(f"\nModel Size: {model_size}")
        print(f"Standard Generation Time: {data['standard_time']:.4f} seconds")
        print(f"Speculative Generation Time: {data['speculative_time']:.4f} seconds")
        print(f"Improvement: {data['improvement']:.2f}%")


if __name__ == "__main__":
    main()
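A minimal sketch of how this script would be run, assuming the GPT-2 checkpoints have already been downloaded into models/ by the base repo's setup:

python benchmark_speculative.py

One caveat worth noting: when model_size is "124M", the draft and main models are identical, so under greedy decoding every draft token is accepted, yet each iteration still pays an extra verification pass; a speedup should only be expected for the larger checkpoints.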
47 changes: 44 additions & 3 deletions gpt2.py
@@ -94,20 +94,61 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids


def generate_speculative(inputs, params, draft_params, n_head, draft_n_head, n_tokens_to_generate, n_speculative=3):
    from tqdm import tqdm

    n_prompt = len(inputs)
    with tqdm(total=n_tokens_to_generate, desc="generating") as pbar:
        while len(inputs) - n_prompt < n_tokens_to_generate:
            # propose n_speculative tokens greedily with the cheap draft model
            # (the draft model has its own head count, hence draft_n_head)
            draft_inputs = inputs.copy()
            draft_tokens = []
            for _ in range(n_speculative):
                draft_logits = gpt2(draft_inputs, **draft_params, n_head=draft_n_head)
                next_id = int(np.argmax(draft_logits[-1]))
                draft_tokens.append(next_id)
                draft_inputs.append(next_id)

            # verify all draft tokens with a single main-model forward pass;
            # main_probs[i] is the main model's distribution for draft_tokens[i]
            main_logits = gpt2(inputs + draft_tokens, **params, n_head=n_head)
            main_probs = softmax(main_logits[-n_speculative - 1 :])

            # accept the longest prefix of draft tokens that matches the main
            # model's greedy predictions
            accepted_tokens = 0
            for i, token in enumerate(draft_tokens):
                if np.argmax(main_probs[i]) == token:
                    accepted_tokens += 1
                else:
                    break

            inputs.extend(draft_tokens[:accepted_tokens])

            # if nothing was accepted, fall back to the main model's own
            # prediction so every iteration makes progress
            if accepted_tokens == 0:
                inputs.append(int(np.argmax(main_probs[0])))

            pbar.update(max(accepted_tokens, 1))

    # return only the newly generated ids, truncated to the requested count
    return inputs[n_prompt : n_prompt + n_tokens_to_generate]
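To make the verification offsets concrete, here is a toy, self-contained sketch (hypothetical token ids, not part of the PR) of why the slice main_logits[-n_speculative - 1:] lines up with the draft tokens: the logits at sequence position i predict the token at position i + 1.

inputs = [10, 11, 12]        # prompt ids
draft_tokens = [20, 21, 22]  # n_speculative = 3 proposals
seq = inputs + draft_tokens  # what the main model verifies, length 6

# the last n_speculative + 1 logit rows start at the position whose
# logits predict draft_tokens[0]; row i therefore scores draft_tokens[i]
for i in range(len(draft_tokens)):
    row_position = len(seq) - len(draft_tokens) - 1 + i
    assert seq[row_position + 1] == draft_tokens[i]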


def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", use_generate_speculative: bool = True):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    # the smallest checkpoint serves as the draft model; load its hparams too,
    # since its head count differs from the larger checkpoints
    _, draft_hparams, draft_params = load_encoder_hparams_and_params("124M", models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)

    # make sure we are not surpassing the max sequence length of our model
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids
    if use_generate_speculative:
        output_ids = generate_speculative(input_ids, params, draft_params, hparams["n_head"], draft_hparams["n_head"], n_tokens_to_generate)
    else:
        output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)
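Assuming the file keeps its fire.Fire(main) entry point from the base repo (not shown in this truncated diff), the new flag could be exercised from the command line like so:

python gpt2.py "In a world where artificial intelligence has become ubiquitous" --model_size "355M" --use_generate_speculative True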