
Commit

(dcy)git add readme and typing
Berit-chengyi committed Feb 28, 2025
1 parent b594a38 commit 4acb28d
Showing 5 changed files with 228 additions and 30 deletions.
15 changes: 8 additions & 7 deletions ding/rl_utils/grpo.py
@@ -1,32 +1,33 @@
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_info = namedtuple('grpo_info', ['approx_kl', 'clipfrac'])


def grpo_policy_error(
data: namedtuple,
log_prob_fn=efficient_method, # Method to calculate the log probabilities
log_prob_fn: LogProbFunction = efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
beta: float = 0.1 # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
Implementation of Generalized Reward-Conditioned Policy Optimization( arXiv:2405.20304) .
Implementation of Group Relative Policy Optimization (GRPO) (arXiv:2402.03300).
Arguments:
- data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
- logpro_cal (:obj:`function`): the method to calculate the log probabilities, defaults to efficient_method.
- log_prob_fn (:obj:`LogProbFunction`): The method to calculate the log probabilities, \
defaults to `efficient_method`.
Returns:
- loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
- grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of which are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
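To illustrate the renamed `log_prob_fn` argument, here is a minimal usage sketch. The tensor shapes follow the docstring above; the `(B,)` advantage shape and `weight=None` are our assumptions (the relevant lines are truncated in this diff), so treat this as a sketch rather than the library's documented usage.

```python
import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error
from ding.rl_utils.log_prob_utils import efficient_method

B, S, V = 2, 8, 32  # small illustrative sizes
data = grpo_policy_data(
    logit_new=torch.randn(B, S, V, requires_grad=True),
    logit_old=torch.randn(B, S, V),
    logit_ref=torch.randn(B, S, V),
    action=torch.randint(0, V, (B, S)),
    adv=torch.randn(B),  # assumption: one advantage per sequence
    weight=None,         # assumption: no per-token weighting
)
loss, info = grpo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2, beta=0.1)
loss.backward()  # per the docstring, loss is a differentiable 0-dim tensor
print(info.approx_kl, info.clipfrac)
```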
15 changes: 11 additions & 4 deletions ding/rl_utils/log_prob_utils.py
@@ -1,4 +1,4 @@
from typing import List, Callable, Optional
from typing import List, Callable, Optional, Any
import torch
from torch import Tensor

@@ -62,19 +62,26 @@ def efficient_method(logits: Tensor, index: Tensor) -> Tensor:
return per_token_logps


def less_efficient_method(logits: Tensor, action: Tensor) -> Tensor:
def less_efficient_method(logits: Tensor, index: Tensor) -> Tensor:
"""Calculate per-token log probabilities using categorical distribution.
Args:
logits: Token logits of shape [B, S, V] or [S, V] where:
B = batch size
S = sequence length
V = vocabulary size
action: Selected token indices of shape [B, S] or [S]
index: Selected token indices of shape [B, S] or [S]
Returns:
Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
"""
dist = torch.distributions.categorical.Categorical(logits=logits)
logp: Tensor = dist.log_prob(action)
logp: Tensor = dist.log_prob(index)
return logp



# A unified type alias for log-prob functions
LogProbFunction = Callable[[Tensor, Tensor], Tensor]

# Export all methods
__all__ = ['naive_method', 'efficient_method', 'less_efficient_method', 'LogProbFunction']
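Because all three helpers share the `LogProbFunction` signature `(logits, index) -> per-token log-probs`, they can be swapped freely, e.g. as the `log_prob_fn` argument above. A quick interchangeability check (a sketch with small random tensors; the `1e-5` bound mirrors the float32 test below):

```python
import torch
from ding.rl_utils.log_prob_utils import (
    naive_method, efficient_method, less_efficient_method, LogProbFunction
)

logits = torch.randn(4, 16, 100)        # (B, S, V)
index = torch.randint(0, 100, (4, 16))  # (B, S)

fns: list[LogProbFunction] = [naive_method, efficient_method, less_efficient_method]
outputs = [fn(logits, index) for fn in fns]
for out in outputs[1:]:
    # all three methods should agree up to float32 rounding
    assert torch.allclose(outputs[0], out, atol=1e-5)
```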
16 changes: 9 additions & 7 deletions ding/rl_utils/rloo.py
@@ -1,29 +1,31 @@
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
rloo_info = namedtuple('rloo_info', ['approx_kl', 'clipfrac'])


def rloo_policy_error(
data: namedtuple,
logpro_cal=efficient_method, # Method to calculate the log probabilities
log_prob_fn: LogProbFunction = efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
Implementation of Rejection Learning with Optimistic Optimization (RLOO) for RLHF.
Implementation of REINFORCE Leave-One-Out (RLOO) (arXiv:2402.14740).
Arguments:
- data (:obj:`namedtuple`): the rloo input data with fields shown in ``rloo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- log_prob_fn (:obj:`LogProbFunction`): The method to calculate the log probabilities, \
defaults to `efficient_method`.
Returns:
- loss (:obj:`torch.FloatTensor`): the rloo policy loss, a differentiable 0-dim tensor.
- rloo_info (:obj:`namedtuple`): the rloo optim information for monitoring, all of which are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,\
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
@@ -41,8 +43,8 @@ def rloo_policy_error(
adv = adv.flatten()

# Get log probabilities for selected actions
per_token_logps = logpro_cal(data.logit_new, data.action)
per_token_old_logps = logpro_cal(data.logit_old, data.action)
per_token_logps = log_prob_fn(data.logit_new, data.action)
per_token_old_logps = log_prob_fn(data.logit_old, data.action)

# Calculate policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
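A minimal usage sketch of the renamed argument, analogous to the GRPO example above. The docstring gives `reward` shape `(K, B)`; since the visible code flattens the advantage, we assume the logit batch equals samples-per-prompt times prompts. Both that relationship and `weight=None` are unverified assumptions.

```python
import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error
from ding.rl_utils.log_prob_utils import efficient_method

K, N, S, V = 4, 2, 8, 32  # K samples per prompt, N prompts (assumed)
B = K * N                 # assumption: batch = K * N, consistent with adv.flatten()
data = rloo_policy_data(
    logit_new=torch.randn(B, S, V, requires_grad=True),
    logit_old=torch.randn(B, S, V),
    action=torch.randint(0, V, (B, S)),
    reward=torch.randn(K, N),  # documented shape (K, B), with B read as prompts here
    weight=None,               # assumption: no per-token weighting
)
loss, info = rloo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2)
```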
103 changes: 103 additions & 0 deletions ding/rl_utils/tests/readme.md
@@ -0,0 +1,103 @@
# Testing Log Probability Methods with GPU

The script `test_log_prob_fn.py` benchmarks different methods for calculating per-token log probabilities (`naive_method`, `efficient_method`, `less_efficient_method`) in both `float32` and `bfloat16` precision on a GPU.

## Overview

The script benchmarks three methods for calculating log probabilities and reports the execution time and peak GPU memory usage of each:

- **Naive Method**
- **Efficient Method**
- **Less Efficient Method**

It runs the tests in two precision formats:

- `float32`
- `bfloat16`

## Test Functions

There are two main test functions in the script:

1. **`test_log_prob_methods_float32`**: This function benchmarks the three methods using `float32` precision.
2. **`test_log_prob_methods_bfloat16`**: This function benchmarks the three methods using `bfloat16` precision.

### Workflow

1. **Parameters Setup**: The tests run with a batch size of `16`, a sequence length of `1024`, and a dictionary (vocabulary) size of `32768`; the benchmark data is randomly generated.
2. **GPU Memory Tracking**: Peak GPU memory usage is measured with `torch.cuda.max_memory_allocated()`, with the statistics reset before each method.
3. **Method Execution**: Each method is run for 10 timed iterations to measure execution time and check stability (see the CUDA-event timing sketch after this list).
4. **Results Validation**: Each method's output is compared against the `Naive` method's to check correctness, with a larger tolerance applied for `bfloat16` precision.
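Both test functions time each method with CUDA events rather than wall-clock timers, since GPU kernels run asynchronously. A minimal sketch of that pattern (the `time_on_gpu` helper name is ours, not the test's):

```python
import torch

def time_on_gpu(fn, *args, iters: int = 10) -> list:
    """Time a CUDA function with torch.cuda.Event, mirroring the tests' loop."""
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()               # drain pending kernels before timing
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()               # wait until `end` has been recorded
        times.append(start.elapsed_time(end))  # elapsed time in milliseconds
    return times
```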

### Benchmarked Methods:

- **Naive Method**: The basic, unoptimized way of calculating log probabilities.
- **Efficient Method**: An optimized variant of the naive method that reduces peak memory usage.
- **Less Efficient Method**: A method that consumes the most memory of the three (all three approaches are sketched below).
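As a rough sketch of how the three strategies differ: the `less_efficient` body matches the `Categorical`-based code in the diff above, while the other two are our reading of the naive and memory-saving approaches, not the library's exact implementations.

```python
import torch
from torch import Tensor

def naive_sketch(logits: Tensor, index: Tensor) -> Tensor:
    # Materialize the full (B, S, V) log-softmax, then gather the selected tokens.
    return torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

def efficient_sketch(logits: Tensor, index: Tensor) -> Tensor:
    # Process one sequence at a time so only an (S, V) log-softmax is live at once.
    rows = []
    for row_logits, row_index in zip(logits, index):
        row_logp = torch.log_softmax(row_logits, dim=-1)
        rows.append(row_logp.gather(-1, row_index.unsqueeze(-1)).squeeze(-1))
    return torch.stack(rows)

def less_efficient_sketch(logits: Tensor, index: Tensor) -> Tensor:
    # Categorical normalizes (copies) the logits internally, hence the extra memory.
    return torch.distributions.Categorical(logits=logits).log_prob(index)
```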

### GPU Memory Usage:

The function `get_gpu_memory()` fetches the peak GPU memory usage recorded during each method's execution.

## Output Example

### Testing with `float32` Precision

```
==================================================
Testing with float32 precision
==================================================
Naive:
Time: 5.07 ± 0.83 ms
Peak GPU Memory: 4096.31 MB
Efficient:
Time: 15.76 ± 21.19 ms
Peak GPU Memory: 2176.44 MB
Less_Efficient:
Time: 14.63 ± 5.06 ms
Peak GPU Memory: 4608.39 MB
PASSED [100%]
```

### Testing with `bfloat16` Precision

```
==================================================
Testing with bfloat16 precision
==================================================
Naive:
Time: 1.42 ± 0.00 ms
Peak GPU Memory: 2048.22 MB
Efficient:
Time: 1.83 ± 0.01 ms
Peak GPU Memory: 1152.25 MB
Less_Efficient:
Time: 8.67 ± 0.07 ms
Peak GPU Memory: 2560.27 MB
```

## Results Analysis

- Execution Time
    - The Naive method is the fastest in both precisions but sacrifices memory efficiency.
    - The Efficient method balances memory usage and execution time, though it is slower than the Naive method.
    - The Less Efficient method consumes the most memory and is clearly the slowest in `bfloat16` (in `float32` its mean time is close to the Efficient method's, which shows high variance), making it the least desirable overall.
- GPU Memory
    - The Efficient method consistently uses the least memory in both precision formats.
    - The Naive method uses more memory than the Efficient method but runs faster.
    - The Less Efficient method consumes the most memory in both precision formats.

## How to Run the Tests

To run the tests:

```bash
pytest -v -s test_log_prob_fn.py
```
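Note that both test functions carry the `@pytest.mark.gputest` marker, so depending on the project's pytest configuration you may need to select that marker explicitly (an assumption about the local setup):

```bash
pytest -v -s -m gputest test_log_prob_fn.py
```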

109 changes: 97 additions & 12 deletions ding/rl_utils/tests/test_log_prob_fn.py
@@ -1,46 +1,64 @@
import pytest
import numpy as np
import torch
from ding.rl_utils.log_prob_utils import (efficient_method, naive_method, less_efficient_method)
from torch import Tensor
from typing import Dict, List, Tuple
from ding.rl_utils.log_prob_utils import (efficient_method, naive_method, less_efficient_method, LogProbFunction)


def get_gpu_memory() -> float:
"""获取当前GPU内存使用情况"""
if torch.cuda.is_available():
return torch.cuda.max_memory_allocated() / 1024 / 1024  # convert bytes to MB
return 0



@pytest.fixture
def batch_size():
def batch_size() -> int:
return 16



@pytest.fixture
def seq_length():
def seq_length() -> int:
return 1024



@pytest.fixture
def dictionary_num():
def dictionary_num() -> int:
return 32768



@pytest.mark.gputest
def test_log_prob_methods_benchmark():
"""Benchmark different methods for calculating log probabilities"""
# Set up parameters
def test_log_prob_methods_float32(batch_size: int, seq_length: int, dictionary_num: int) -> None:
"""Benchmark different methods for calculating log probabilities with float32"""
print("\n" + "=" * 50)
print("Testing with float32 precision")
print("=" * 50)


# Set up parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()  # reset memory statistics


# Generate test data
logits = torch.randn(batch_size, seq_length, dictionary_num, device=device, dtype=torch.float32)
input_ids = torch.randint(0, dictionary_num, (batch_size, seq_length), device=device)
logits: Tensor = torch.randn(batch_size, seq_length, dictionary_num, device=device, dtype=torch.float32)
input_ids: Tensor = torch.randint(0, dictionary_num, (batch_size, seq_length), device=device)


# Warm up the GPU
for _ in range(3):
_ = naive_method(logits[:2], input_ids[:2])
torch.cuda.synchronize()


# Benchmark each method
results = {}
results: Dict[str, Tensor] = {}
peak_memory: Dict[str, float] = {}
for method, name in [(naive_method, "Naive"), (efficient_method, "Efficient"), (less_efficient_method, "Less_Efficient")]:
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()  # reset peak-memory stats for each method


# Run each method multiple times and record timings
times = []
times: List[float] = []
for _ in range(10):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
@@ -53,13 +71,80 @@ def test_log_prob_methods_benchmark():
if len(times) == 1:
results[name] = result


# Record memory usage
peak_memory[name] = get_gpu_memory()


# Compute statistics
mean_time = np.mean(times)
std_time = np.std(times)
print(f"\n{name}: {mean_time:.2f} ± {std_time:.2f} ms")
print(f"\n{name}:")
print(f"Time: {mean_time:.2f} ± {std_time:.2f} ms")
print(f"Peak GPU Memory: {peak_memory[name]:.2f} MB")


# Verify correctness of the results
for name, result in results.items():
if name != "Naive":
diff = (results["Naive"] - result).abs().max().item()
assert diff < 1e-5, f"Results mismatch between Naive and {name}: {diff}"



@pytest.mark.gputest
def test_log_prob_methods_bfloat16(batch_size: int, seq_length: int, dictionary_num: int) -> None:
"""Benchmark different methods for calculating log probabilities with bfloat16"""
print("\n" + "=" * 50)
print("Testing with bfloat16 precision")
print("=" * 50)


# Set up parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
tolerance = 0.1  # bfloat16 needs a larger tolerance


if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()  # reset memory statistics


# Generate test data
logits: Tensor = torch.randn(batch_size, seq_length, dictionary_num, device=device, dtype=torch.bfloat16)
input_ids: Tensor = torch.randint(0, dictionary_num, (batch_size, seq_length), device=device)


# Warm up the GPU
for _ in range(3):
_ = naive_method(logits[:2], input_ids[:2])
torch.cuda.synchronize()


# Benchmark each method
results: Dict[str, Tensor] = {}
peak_memory: Dict[str, float] = {}
for method, name in [(naive_method, "Naive"), (efficient_method, "Efficient"), (less_efficient_method, "Less_Efficient")]:
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()  # reset peak-memory stats for each method


# Run each method multiple times and record timings
times: List[float] = []
for _ in range(10):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
result = method(logits, input_ids)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
if len(times) == 1:
results[name] = result


# Record memory usage
peak_memory[name] = get_gpu_memory()


# Compute statistics
mean_time = np.mean(times)
std_time = np.std(times)
print(f"\n{name}:")
print(f"Time: {mean_time:.2f} ± {std_time:.2f} ms")
print(f"Peak GPU Memory: {peak_memory[name]:.2f} MB")


# Verify correctness of the results
for name, result in results.items():
if name != "Naive":
diff = (results["Naive"] - result).abs().max().item()
assert diff < tolerance, f"Results mismatch between Naive and {name}: {diff}"

