diff --git a/benchmark_speculative.py b/benchmark_speculative.py
new file mode 100644
index 0000000..be2e883
--- /dev/null
+++ b/benchmark_speculative.py
@@ -0,0 +1,50 @@
+import time
+from gpt2 import main as gpt2_main
+from utils import load_encoder_hparams_and_params
+
+def benchmark_generation(prompt, n_tokens_to_generate, model_size, use_speculative):
+    start_time = time.time()
+    gpt2_main(prompt, n_tokens_to_generate, model_size, use_generate_speculative=use_speculative)
+    end_time = time.time()
+    return end_time - start_time
+
+def run_benchmark(prompt, n_tokens_to_generate, model_sizes):
+    results = {}
+
+    for model_size in model_sizes:
+        print(f"Benchmarking {model_size} model...")
+
+        # Warm-up run
+        benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
+        benchmark_generation(prompt, n_tokens_to_generate, model_size, True)
+
+        # Actual benchmark
+        standard_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
+        speculative_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, True)
+
+        improvement = (standard_time - speculative_time) / standard_time * 100
+        results[model_size] = {
+            "standard_time": standard_time,
+            "speculative_time": speculative_time,
+            "improvement": improvement
+        }
+
+    return results
+
+def main():
+    prompt = "In a world where artificial intelligence has become ubiquitous"
+    n_tokens_to_generate = 50
+    model_sizes = ["124M", "355M"]
+
+    results = run_benchmark(prompt, n_tokens_to_generate, model_sizes)
+
+    print("\nBenchmark Results:")
+    print("==================")
+    for model_size, data in results.items():
+        print(f"\nModel Size: {model_size}")
+        print(f"Standard Generation Time: {data['standard_time']:.4f} seconds")
+        print(f"Speculative Generation Time: {data['speculative_time']:.4f} seconds")
+        print(f"Improvement: {data['improvement']:.2f}%")
+
+if __name__ == "__main__":
+    main()
diff --git a/gpt2.py b/gpt2.py
index c85e8b8..2f330f4 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -94,12 +94,56 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
     return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
 
 
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def generate_speculative(inputs, params, draft_params, n_head, n_tokens_to_generate, n_speculative=3):
+    from tqdm import tqdm
+
+    n_generated = 0
+    pbar = tqdm(total=n_tokens_to_generate, desc="generating")
+    while n_generated < n_tokens_to_generate:
+        # Generate speculative tokens using the draft model
+        draft_inputs = inputs.copy()
+        draft_tokens = []
+        for _ in range(n_speculative):
+            # Use the draft model to greedily predict the next token
+            draft_logits = gpt2(draft_inputs, **draft_params, n_head=n_head)
+            next_id = int(np.argmax(draft_logits[-1]))
+            draft_tokens.append(next_id)
+            draft_inputs.append(next_id)
+
+        # Verify the speculative tokens with a single forward pass of the main model
+        main_logits = gpt2(inputs + draft_tokens, **params, n_head=n_head)
+        main_probs = softmax(main_logits[-n_speculative - 1:])
+
+        # Accept draft tokens as long as they match the main model's greedy prediction
+        accepted_tokens = 0
+        for i, token in enumerate(draft_tokens):
+            if np.argmax(main_probs[i]) == token:
+                accepted_tokens += 1
+            else:
+                break
+
+        # Add accepted tokens to the input
+        inputs.extend(draft_tokens[:accepted_tokens])
+
+        # If no tokens were accepted, fall back to the main model's prediction
+        if accepted_tokens == 0:
+            inputs.append(int(np.argmax(main_probs[0])))
+            accepted_tokens = 1
+
+        n_generated += accepted_tokens
+        pbar.update(accepted_tokens)
+    pbar.close()
+
+    # only return the newly generated ids, trimmed to the requested count
+    return inputs[-n_generated:][:n_tokens_to_generate]
+
+
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", use_generate_speculative: bool = True):
     from utils import load_encoder_hparams_and_params
 
     # load encoder, hparams, and params from the released open-ai gpt-2 files
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
-
+    _, _, draft_params = load_encoder_hparams_and_params("124M", models_dir)  # always use the smallest checkpoint as the draft model
     # encode the input string using the BPE tokenizer
     input_ids = encoder.encode(prompt)
 
@@ -107,7 +151,10 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
 
     # generate output ids
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    if use_generate_speculative:
+        output_ids = generate_speculative(input_ids, params, draft_params, hparams["n_head"], n_tokens_to_generate)
+    else:
+        output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
 
     # decode the ids back into a string
     output_text = encoder.decode(output_ids)
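A minimal sketch of how the new flag can be exercised from Python, assuming the usual picoGPT layout (gpt2.py and utils.py importable, checkpoints downloaded under models/); the prompt here is arbitrary:

```python
# Compare speculative and standard greedy decoding on the same prompt.
from gpt2 import main

prompt = "Alan Turing theorized that computers would one day become"

# Speculative path: the 124M draft model proposes tokens, the 355M main model verifies them.
print(main(prompt, n_tokens_to_generate=40, model_size="355M", use_generate_speculative=True))

# Baseline greedy decoding for comparison.
print(main(prompt, n_tokens_to_generate=40, model_size="355M", use_generate_speculative=False))
```

The end-to-end timing comparison for both model sizes is run with `python benchmark_speculative.py`.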