GRPO is in Unsloth!
- Experience the "aha moment" from DeepSeek R1's paper now with Unsloth!
- LoRA (16-bit) and QLoRA (4-bit) now actually work for GRPO!
- Unsloth can run GRPO for Phi-4 (14B) and Llama 3.1 (8B) on a free 15GB Colab GPU!
- Unsloth now has native fast inference (20x more throughput) via vLLM! Use it by calling `model.fast_generate` after setting `FastLanguageModel.from_pretrained(..., fast_inference = True)` and installing vLLM with `pip install vllm`
- Llama 3.3 70B QLoRA GRPO should fit on a single 48GB GPU (a single 80GB GPU is best)
- Update Unsloth with `pip install --upgrade --no-cache-dir --force-reinstall unsloth_zoo unsloth vllm`
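The 48GB figure for Llama 3.3 70B can be sanity-checked with back-of-envelope arithmetic. This sketch counts only the base weights. It assumes exactly 4 bits per parameter and round parameter counts of 70B and 8B. It ignores quantization constants, LoRA weights, activations, optimizer state and the vLLM KV cache, so real usage is higher. The helper name is hypothetical.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate storage for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Assumption: round parameter counts, not the models' exact sizes
for name, params in [("Llama 3.3 70B", 70e9), ("Llama 3.1 8B", 8e9)]:
    print(f"{name}: {weight_memory_gb(params, 16):.0f} GB at 16-bit, "
          f"{weight_memory_gb(params, 4):.0f} GB at 4-bit")
```

At 4 bits the 70B weights alone take roughly 35 GB. That leaves about 13 GB of a 48GB card for everything else, which is why an 80GB card is the more comfortable choice.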
GRPO Notebooks
Model | Type | Colab Link |
---|---|---|
Phi 4 (14B) | GRPO | Open in Colab |
Llama 3.1 (8B) | GRPO | Open in Colab |
Qwen 2.5 (3B) | GRPO | Open in Colab |
Minimal GRPO example (courtesy of Will Brown)
```bash
pip install unsloth vllm
pip install git+https://github.com/huggingface/trl.git
```

```python
from unsloth import FastLanguageModel, PatchFastRL
# Patch TRL's GRPO trainer for Unsloth; call this before importing from trl
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch

max_seq_length = 512
lora_rank = 32

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,           # 4-bit QLoRA; set False for 16-bit LoRA
    fast_inference = True,         # enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6,  # fraction of GPU memory vLLM may reserve
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    lora_alpha = lora_rank,
)
```
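With `r = lora_rank` and `lora_alpha = lora_rank`, the standard LoRA scaling factor `lora_alpha / r` is 1.0. A rank-r adapter on a d_out x d_in projection adds r x (d_in + d_out) weights. That gives a rough count of trainable parameters. The sketch below makes two assumptions:

- adapters sit on all seven attention and MLP projections, which is assumed to be what `get_peft_model` targets when `target_modules` is omitted;
- the model has Llama 3.1 8B's published shapes.

The helper names are hypothetical.

```python
# Assumed Llama 3.1 8B shapes (from its public config), not read from the model
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 8 * 128   # 8 key/value heads x head dim 128 (grouped-query attention)
NUM_LAYERS = 32

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside a frozen d_out x d_in weight."""
    return r * (d_in + d_out)

def trainable_lora_params(r: int) -> int:
    # Assumption: adapters on all seven attention and MLP projections
    per_layer = (
        lora_params(HIDDEN, HIDDEN, r)           # q_proj
        + lora_params(HIDDEN, KV_DIM, r)         # k_proj
        + lora_params(HIDDEN, KV_DIM, r)         # v_proj
        + lora_params(HIDDEN, HIDDEN, r)         # o_proj
        + lora_params(HIDDEN, INTERMEDIATE, r)   # gate_proj
        + lora_params(HIDDEN, INTERMEDIATE, r)   # up_proj
        + lora_params(INTERMEDIATE, HIDDEN, r)   # down_proj
    )
    return per_layer * NUM_LAYERS

lora_rank = 32
lora_alpha = lora_rank
print("scaling alpha / r =", lora_alpha / lora_rank)
print(f"trainable parameters: {trainable_lora_params(lora_rank):,}")
```

Under these assumptions the rank-32 adapters add about 84 million trainable parameters, roughly 1% of the 8B base model. Halving the rank halves that count.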
```python
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    # Take the text between the last <answer> and the following </answer>
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    # GSM8K reference solutions put the final answer after "####"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
```
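The next sketch shows what `get_gsm8k_questions` produces without downloading the dataset. It applies the same mapping to one hand-written, GSM8K-style row. The row and the shortened system prompt are made up for illustration. Only the `####` convention for the final answer comes from GSM8K. `to_record` is a hypothetical helper mirroring the lambda passed to `data.map`.

```python
from typing import Optional

SYSTEM_PROMPT = "Respond in the following format: <reasoning>...</reasoning> <answer>...</answer>"  # abbreviated

def extract_hash_answer(text: str) -> Optional[str]:
    # Same logic as the example: the final answer follows "####"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def to_record(x: dict) -> dict:
    # Hypothetical helper: same fields as the lambda in get_gsm8k_questions
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    }

# Hypothetical GSM8K-style row, not taken from the dataset
raw = {
    "question": "Tom has 3 apples and buys 4 more. How many apples does he have?",
    "answer": "Tom starts with 3 apples and buys 4, so 3 + 4 = 7.\n#### 7",
}

record = to_record(raw)
print(record["prompt"][1]["content"])
print(repr(record["answer"]))
```

The model only ever sees the chat-formatted `prompt`. The bare `answer` string is kept alongside it so the correctness reward can compare against it.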
```python
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Log the first sample of the batch for inspection
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion strictly follows the XML format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span multi-line reasoning
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely follows the XML format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    # 0.125 per correctly placed tag, minus 0.001 per character after </answer>
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
```
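The format rewards rely on `re.match`. By default, `.` in a Python regular expression does not match a newline. This sketch checks the strict pattern against a hypothetical completion whose reasoning spans two lines. It fails without `re.DOTALL` and matches with it. A completion whose reasoning fits on one line matches either way.

```python
import re

# Same pattern as strict_format_reward_func
STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# Hypothetical completions
multi_line = (
    "<reasoning>\n"
    "Tom starts with 3 apples.\n"
    "He buys 4 more, so 3 + 4 = 7.\n"
    "</reasoning>\n"
    "<answer>\n"
    "7\n"
    "</answer>\n"
)
single_line = "<reasoning>\n3 + 4 = 7.\n</reasoning>\n<answer>\n7\n</answer>\n"

without_dotall = re.match(STRICT, multi_line)
with_dotall = re.match(STRICT, multi_line, flags=re.DOTALL)
single_plain = re.match(STRICT, single_line)

print("multi-line, no DOTALL:", without_dotall is not None)
print("multi-line, DOTALL:   ", with_dotall is not None)
print("single-line, no DOTALL:", single_plain is not None)
```

Without the flag, any model that writes more than one line of reasoning would never earn the strict or soft format rewards. That quietly removes part of the training signal.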
```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    num_generations = 6, # completions sampled per prompt
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1,
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
```
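With `num_generations = 6`, GRPO samples six completions per prompt and scores each one. Each completion is then judged against the others in its group, so no separate value model is needed. The sketch below shows that group-relative normalization in plain Python. It assumes each completion's reward is the sum of the five `reward_funcs` outputs, which is how TRL's `GRPOTrainer` combines them. It also adds a small epsilon to avoid dividing by zero. The reward values are hypothetical. TRL's own implementation works on batched tensors and may differ in details such as the standard-deviation estimator.

```python
from statistics import mean, stdev

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize each reward against the mean and std of its own group."""
    mu = mean(rewards)
    sigma = stdev(rewards)  # sample standard deviation; needs >= 2 rewards
    return [(r - mu) / (sigma + eps) for r in rewards]

# Hypothetical summed rewards for the 6 completions of one prompt
rewards = [4.0, 0.5, 0.0, 2.5, 0.5, 1.0]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward {r:4.1f} -> advantage {a:+.2f}")
```

Completions that beat their group's average get positive advantages and are reinforced, while below-average ones are pushed down. If every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal.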
Bug Fixes
- Gemma 2 should be fixed now
- Mistral base mapping should be fixed
- Fixed some syntax warnings
- And many more bug fixes!
What's Changed
- Add use_exact_model_name option to prevent automatic model name modification by @niryuu in #1339
- Improve debugging experience by @Erland366 in #1512
- changing model to base_model if peft model is already used by @mosama1994 in #1509
- All attention refactor fix by @KareemMusleh in #1491
- Update granite to work with latest post_patch methods by @Datta0 in #1502
- Minor fixes for granite models by @CoffeeVampir3 in #1503
- support modelscope models and datasets by @tastelikefeet in #1481
- Update README.md by @shimmyshimmer in #1529
- Update bug_report.md by @danielhanchen in #1538
- Update README.md by @shimmyshimmer in #1542
- Torch.Cuda Is Available Condition and Warning by @aminwhat in #1545
- Add dropout to granite to match HF's implementation by @Datta0 in #1557
- fix: flash_attn_detection_error by @Zzhiter in #1556
- Fix Mistral, Qwen by @danielhanchen in #1565
- Update README.md by @shimmyshimmer in #1569
- Update README.md by @shimmyshimmer in #1580
- Update README.md by @shimmyshimmer in #1595
- Mistral 24B, Qwen 2.5 VL support by @danielhanchen in #1598
- GRPO, vLLM, Bug Fixes, Reinforcement Learning by @danielhanchen in #1620
New Contributors
- @niryuu made their first contribution in #1339
- @mosama1994 made their first contribution in #1509
- @KareemMusleh made their first contribution in #1491
- @tastelikefeet made their first contribution in #1481
- @aminwhat made their first contribution in #1545
- @Zzhiter made their first contribution in #1556
Full Changelog: 2025-01...2025-02