Commit

(dcy) add readme and typing
Berit-chengyi committed Feb 28, 2025
1 parent 1bfc4f2 commit aafbc18
Showing 7 changed files with 253 additions and 99 deletions.
31 changes: 13 additions & 18 deletions ding/rl_utils/grpo.py
@@ -1,32 +1,32 @@
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

grpo_policy_data = namedtuple('grpo_policy_data',
['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_info = namedtuple('grpo_info', ['approx_kl', 'clipfrac'])


def grpo_policy_error(
data: namedtuple,
log_prob_fn = efficient_method, # Method to calculate the log probabilities
log_prob_fn: LogProbFunction = efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
beta: float = 0.1 # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
beta: float = 0.1 # Weight coefficient for KL divergence
) -> Tuple[torch.Tensor, namedtuple]:
"""
Overview:
Implementation of Generalized Reward-Conditioned Policy Optimization (arXiv:2405.20304).
Arguments:
- data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
- logpro_cal (:obj:`function`): the method to calculate the log probabilities, defaults to efficient_method.
- log_prob_fn (:obj:`LogProbFunction`): The method to calculate the log probabilities, defaults to `efficient_method`.
Returns:
- loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
- loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
- grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of them are Python scalar.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
@@ -39,17 +39,13 @@ def grpo_policy_error(
"""

# Calculate log probabilities for selected token
per_token_logps= log_prob_fn(data.logit_new, data.action)
per_token_logps = log_prob_fn(data.logit_new, data.action)
per_token_ref_logps = log_prob_fn(data.logit_ref, data.action)
per_token_old_logps = log_prob_fn(data.logit_old, data.action)


# Calculate KL divergence: exp(q-p) - (q-p) - 1,
# where p is current policy and q is reference policy
per_token_kl = (
torch.exp(per_token_ref_logps - per_token_logps) -
(per_token_ref_logps - per_token_logps) - 1
)
per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

# Calculate policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
@@ -59,8 +55,7 @@
advantages = data.adv.unsqueeze(1) # [B, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
per_token_loss = -torch.min(per_token_loss_unclipped,
per_token_loss_clipped)
per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Add KL divergence regularization term
per_token_loss = per_token_loss + beta * per_token_kl
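
For orientation, a minimal usage sketch of the updated `grpo_policy_error` API, based on the shapes documented in the docstring above. The import path is inferred from the file layout and the tensors are random placeholders:

```python
import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error

B, S, V = 4, 8, 1000  # batch size, sequence length, vocabulary size
logit_new = torch.randn(B, S, V, requires_grad=True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, V, (B, S))
adv = torch.randn(B)       # one advantage per sequence
weight = torch.ones(B, S)  # token-level loss mask

data = grpo_policy_data(
    logit_new=logit_new, logit_old=logit_old, logit_ref=logit_ref,
    action=action, adv=adv, weight=weight
)
loss, info = grpo_policy_error(data, clip_ratio=0.2, beta=0.1)
loss.backward()            # loss is a differentiable 0-dim tensor
print(info.approx_kl, info.clipfrac)
```
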
31 changes: 15 additions & 16 deletions ding/rl_utils/log_prob_utils.py
@@ -1,9 +1,10 @@
from typing import List, Callable, Optional
from typing import List, Callable, Optional, Any
import torch
from torch import Tensor

LogitsProcessor = Callable[[Tensor, Tensor], Tensor]


def naive_method(logits: Tensor, index: Tensor) -> Tensor:
"""Calculate per-token log probabilities using naive method.
@@ -39,16 +40,10 @@ def efficient_method(logits: Tensor, index: Tensor) -> Tensor:
Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits: Tensor = torch.gather(
logits,
dim=-1,
index=index.unsqueeze(-1)
).squeeze(-1)
selected_logits: Tensor = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

# Loop to reduce peak mem consumption
logsumexp_values: Tensor = torch.stack([
torch.logsumexp(lg, dim=-1) for lg in logits
])
logsumexp_values: Tensor = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])

# log_softmax(x_i) = x_i - logsumexp(x)
per_token_logps: Tensor = selected_logits - logsumexp_values
@@ -59,30 +54,34 @@ def efficient_method(logits: Tensor, index: Tensor) -> Tensor:
# Loop to reduce peak mem consumption
for row_logits, row_labels in zip(logits, index): # Iterate over sequence length
row_logps: Tensor = torch.log_softmax(row_logits, dim=-1)
row_per_token_logps: Tensor = row_logps.gather(
dim=-1,
index=row_labels.unsqueeze(-1)
).squeeze(-1)
row_per_token_logps: Tensor = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)

[Codecov warning on line 58 in ding/rl_utils/log_prob_utils.py: added lines #L55-L58 were not covered by tests]

per_token_logps = torch.stack(per_token_logps)

[Codecov warning on line 60 in ding/rl_utils/log_prob_utils.py: added line #L60 was not covered by tests]

return per_token_logps


def less_efficient_method(logits: Tensor, action: Tensor) -> Tensor:
def less_efficient_method(logits: Tensor, index: Tensor) -> Tensor:
"""Calculate per-token log probabilities using categorical distribution.
Args:
logits: Token logits of shape [B, S, V] or [S, V] where:
B = batch size
S = sequence length
V = vocabulary size
action: Selected token indices of shape [B, S] or [S]
index: Selected token indices of shape [B, S] or [S]
Returns:
Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
"""
dist = torch.distributions.categorical.Categorical(logits=logits)
logp: Tensor = dist.log_prob(action)
logp: Tensor = dist.log_prob(index)
return logp

[Codecov warning on line 80 in ding/rl_utils/log_prob_utils.py: added lines #L78-L80 were not covered by tests]


# Define a unified type alias
LogProbFunction = Callable[[Tensor, Tensor], Tensor]

# Export all methods
__all__ = ['naive_method', 'efficient_method', 'less_efficient_method', 'LogProbFunction']
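
As a quick sanity check, the three exported methods compute the same quantity and should agree up to floating-point tolerance. A minimal sketch, assuming the package import path:

```python
import torch
from ding.rl_utils.log_prob_utils import naive_method, efficient_method, less_efficient_method

B, S, V = 2, 16, 1000
logits = torch.randn(B, S, V)
index = torch.randint(0, V, (B, S))

naive = naive_method(logits, index)
efficient = efficient_method(logits, index)
categorical = less_efficient_method(logits, index)

assert naive.shape == (B, S)
assert torch.allclose(naive, efficient, atol=1e-5)
assert torch.allclose(naive, categorical, atol=1e-5)
```
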
28 changes: 13 additions & 15 deletions ding/rl_utils/rloo.py
@@ -1,16 +1,15 @@
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction


rloo_policy_data = namedtuple('rloo_policy_data',
['logit_new', 'logit_old', 'action', 'reward', 'weight'])
rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
rloo_info = namedtuple('rloo_info', ['approx_kl', 'clipfrac'])


def rloo_policy_error(
data: namedtuple,
logpro_cal = efficient_method, # Method to calculate the log probabilities
log_prob_fn: LogProbFunction = efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
"""
@@ -19,12 +18,13 @@ def rloo_policy_error(
Arguments:
- data (:obj:`namedtuple`): the rloo input data with fields shown in ``rloo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- log_prob_fn (:obj:`LogProbFunction`): The method to calculate the log probabilities, defaults to `efficient_method`.
Returns:
- loss (:obj:`torch.FloatTensor`): the rloo policy loss, a differentiable 0-dim tensor.
- rloo_info (:obj:`namedtuple`): the rloo optim information for monitoring, all of them are Python scalar.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
@@ -42,8 +42,8 @@
adv = adv.flatten()

# Get log probabilities for selected actions
per_token_logps= logpro_cal(data.logit_new, data.action)
per_token_old_logps = logpro_cal(data.logit_old, data.action)
per_token_logps = log_prob_fn(data.logit_new, data.action)
per_token_old_logps = log_prob_fn(data.logit_old, data.action)

# Calculate policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
@@ -53,18 +53,16 @@
advantages = adv.unsqueeze(1) # [B, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
per_token_loss = -torch.min(per_token_loss_unclipped,
per_token_loss_clipped)
per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Calculate average loss using weight mask
weight = data.weight if data.weight is not None else (
torch.ones_like(per_token_loss))
weight = data.weight if data.weight is not None else (torch.ones_like(per_token_loss))
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Calculate additional metrics
# Calculate additional metrics
with torch.no_grad():
approx_kl = (per_token_old_logps - per_token_logps).mean().item()
clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
clipfrac = torch.as_tensor(clipped).float().mean().item()

return loss, rloo_info(approx_kl=approx_kl, clipfrac=clipfrac)
return loss, rloo_info(approx_kl=approx_kl, clipfrac=clipfrac)
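
A minimal usage sketch of `rloo_policy_error` based on the shapes documented above. The reward layout (K samples per prompt, with the flattened advantages matching the logits batch) and the import path are assumptions inferred from the docstring and file layout:

```python
import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error

K, P, S, V = 4, 2, 8, 1000  # samples per prompt, prompts, sequence length, vocabulary size
B = K * P                   # total sequences seen by the policy
logit_new = torch.randn(B, S, V, requires_grad=True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
action = torch.randint(0, V, (B, S))
reward = torch.randn(K, P)  # one reward per sample per prompt (assumed layout)
weight = torch.ones(B, S)

data = rloo_policy_data(
    logit_new=logit_new, logit_old=logit_old, action=action,
    reward=reward, weight=weight
)
loss, info = rloo_policy_error(data, clip_ratio=0.2)
loss.backward()
print(info.approx_kl, info.clipfrac)
```
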
103 changes: 103 additions & 0 deletions ding/rl_utils/tests/readme.md
@@ -0,0 +1,103 @@
# Testing Log Probability Methods with GPU

The script `test_log_prob_fn.py` benchmarks the different methods for calculating per-token log probabilities (`naive_method`, `efficient_method`, `less_efficient_method`) in both `float32` and `bfloat16` precision on a GPU.

## Overview

The script benchmarks three methods for calculating log probabilities and reports the execution time and peak GPU memory usage of each:

- **Naive Method**
- **Efficient Method**
- **Less Efficient Method**

It runs the tests for two types of precision:

- `float32`
- `bfloat16`

## Test Functions

There are two main test functions in the script:

1. **`test_log_prob_methods_float32`**: This function benchmarks the three methods using `float32` precision.
2. **`test_log_prob_methods_bfloat16`**: This function benchmarks the three methods using `bfloat16` precision.

### Workflow

1. **Parameters Setup**: The tests are executed using a batch size of `16`, a sequence length of `1024`, and a dictionary size of `32768`. The data is randomly generated for benchmarking.
2. **GPU Memory Tracking**: The GPU memory is tracked using `torch.cuda.max_memory_allocated()` to measure the peak memory usage during the benchmark.
3. **Method Execution**: Each method is run multiple times (10 iterations) to measure execution time and ensure stable results (a sketch of such a loop is shown after this list).
4. **Results Validation**: The results from each method are compared with the `Naive` method to check for correctness, with a tolerance value applied for `bfloat16` precision.
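
A minimal sketch of how such a benchmarking loop can be structured. This is illustrative only: `run_benchmark` and the exact timing details are assumptions, not the test's actual code, and it needs a CUDA-capable GPU:

```python
import time

import torch

from ding.rl_utils.log_prob_utils import naive_method, efficient_method, less_efficient_method


def run_benchmark(fn, logits, index, n_iter=10):
    # Reset the peak-memory counter, then time n_iter calls of the method on the GPU.
    torch.cuda.reset_peak_memory_stats()
    times = []
    for _ in range(n_iter):
        torch.cuda.synchronize()
        start = time.perf_counter()
        fn(logits, index)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000)  # milliseconds
    peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
    return times, peak_mb


B, S, V = 16, 1024, 32768  # parameters from the setup above
logits = torch.randn(B, S, V, device='cuda', dtype=torch.float32)
index = torch.randint(0, V, (B, S), device='cuda')
for name, fn in [('Naive', naive_method), ('Efficient', efficient_method),
                 ('Less_Efficient', less_efficient_method)]:
    times, peak = run_benchmark(fn, logits, index)
    print(f'{name}: {sum(times) / len(times):.2f} ms, peak {peak:.2f} MB')
```
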

### Benchmarked Methods:

- **Naive Method**: The basic, unoptimized method for calculating log probabilities (see the sketch after this list).
- **Efficient Method**: An optimized version of the naive method to reduce memory usage.
- **Less Efficient Method**: A method with a higher memory consumption compared to the efficient method.
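
For contrast with the `efficient_method` shown in the diff above, the naive approach materializes the full log-softmax over the vocabulary before gathering the selected tokens, which explains its higher peak memory. A sketch (the name `naive_log_probs` is illustrative; the repository's actual `naive_method` may differ in detail):

```python
import torch
from torch import Tensor


def naive_log_probs(logits: Tensor, index: Tensor) -> Tensor:
    # Builds a full [B, S, V] log-softmax tensor, which drives the higher peak memory.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```
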

### GPU Memory Usage:

The function `get_gpu_memory()` is used to fetch the current peak GPU memory usage during the execution of each method.
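
A helper along these lines is enough (a sketch; the actual implementation in the test may differ):

```python
import torch


def get_gpu_memory() -> float:
    # Peak GPU memory allocated since the last reset, in MB.
    return torch.cuda.max_memory_allocated() / 1024 ** 2
```
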

## Output Example

### Testing with `float32` Precision

```
==================================================
Testing with float32 precision
==================================================
Naive:
Time: 5.07 ± 0.83 ms
Peak GPU Memory: 4096.31 MB
Efficient:
Time: 15.76 ± 21.19 ms
Peak GPU Memory: 2176.44 MB
Less_Efficient:
Time: 14.63 ± 5.06 ms
Peak GPU Memory: 4608.39 MB
PASSED [100%]
```

### Testing with `bfloat16` Precision

```
==================================================
Testing with bfloat16 precision
==================================================
Naive:
Time: 1.42 ± 0.00 ms
Peak GPU Memory: 2048.22 MB
Efficient:
Time: 1.83 ± 0.01 ms
Peak GPU Memory: 1152.25 MB
Less_Efficient:
Time: 8.67 ± 0.07 ms
Peak GPU Memory: 2560.27 MB
```

## Results Analysis

- Execution Time
- The Naive method is the fastest in both precisions but sacrifices memory efficiency.
- The Efficient method balances memory usage and execution time, though it is slower than the Naive method.
- The Less Efficient method is slower than both the Naive and Efficient methods and consumes the most memory, making it the least desirable for both speed and memory usage.
- GPU Memory
- The Efficient method consistently uses the least memory, especially in the `bfloat16` precision where it achieves the lowest memory consumption.
- The Naive method uses more memory than the Efficient method but has lower execution times.
- The Less Efficient method consumes the most memory in both precision formats.

## How to Run the Tests

To run the tests:

```bash
pytest -v -s test_log_prob_fn.py
```

17 changes: 4 additions & 13 deletions ding/rl_utils/tests/test_grpo_rlhf.py
@@ -21,14 +21,10 @@ def dictionary_num():


@pytest.mark.unittest
def test_grpo_policy_loss_with_mask(
batch_size: int = 4,
seq_length: int = 8,
vocab_size: int = 1000):
def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation with mask"""
# 1. Create test data
logit_new = (torch.randn(batch_size, seq_length,
vocab_size).requires_grad_(True))
logit_new = (torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True))
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))
@@ -64,21 +60,16 @@ def test_grpo_policy_loss_with_mask(
loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)


assert 'approx_kl' in info._asdict()
assert 'clipfrac' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_grpo_policy_loss_without_mask(
batch_size: int = 4,
seq_length: int = 8,
vocab_size: int = 1000):
def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation without mask"""
# 1. Create test data
logit_new = torch.randn(batch_size,
seq_length, vocab_size).requires_grad_(True)
logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))