Commit

format(dcy): format files
Berit-chengyi committed Feb 14, 2025
1 parent 71190d4 commit eba91a1
Showing 12 changed files with 22 additions and 68 deletions.
1 change: 0 additions & 1 deletion ding/framework/middleware/tests/__init__.py
@@ -1 +0,0 @@
-from .mock_for_test import MockEnv, MockPolicy, MockHerRewardModel, CONFIG
22 changes: 6 additions & 16 deletions ding/rl_utils/grpo.py
@@ -2,10 +2,7 @@
from collections import namedtuple
import torch

-grpo_policy_data = namedtuple(
-'grpo_policy_data',
-['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight']
-)
+grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])


def grpo_policy_error(
@@ -44,10 +41,7 @@ def grpo_policy_error(

# Calculate KL divergence: exp(q-p) - (q-p) - 1,
# where p is current policy and q is reference policy
-per_token_kl = (
-torch.exp(per_token_ref_logps - per_token_logps) -
-(per_token_ref_logps - per_token_logps) - 1
-)
+per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

# Calculate policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
@@ -57,8 +51,7 @@ def grpo_policy_error(
advantages = data.adv.unsqueeze(1) # [B, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
-per_token_loss = -torch.min(per_token_loss_unclipped,
-per_token_loss_clipped)
+per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Add KL divergence regularization term
per_token_loss = per_token_loss + beta * per_token_kl
@@ -70,13 +63,10 @@ def grpo_policy_error(

# Calculate additional metrics
metrics = {
-'mean_kl': ((per_token_kl * weight).sum(dim=1) /
-weight.sum(dim=1)).mean().item(),
-'mean_ratio': ((ratio * weight).sum(dim=1) /
-weight.sum(dim=1)).mean().item(),
+'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
+'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_clipped': (
-(ratio > (1 + clip_ratio)).float().mean().item() +
-(ratio < (1 - clip_ratio)).float().mean().item()
+(ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
),
}

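For orientation, a minimal usage sketch of the reformatted GRPO helper. The tensor shapes follow the tests in this commit; the clip_ratio and beta keyword arguments and the (loss, metrics) return pair are assumptions inferred from the variables used inside grpo_policy_error and from the RLOO helper below, since the full signature is collapsed in this view.

# Hypothetical usage of ding/rl_utils/grpo.py; the clip_ratio/beta kwargs and the return pair are assumed.
import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error

batch_size, seq_length, vocab_size = 4, 8, 1000
logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # behaviour (old) policy logits
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2  # frozen reference policy logits
action = torch.randint(0, vocab_size, (batch_size, seq_length))  # sampled token ids
adv = torch.randn(batch_size)  # one advantage per sequence, broadcast over its tokens
weight = torch.ones(batch_size, seq_length)  # 1 = real token, 0 = padding

data = grpo_policy_data(
    logit_new=logit_new, logit_old=logit_old, logit_ref=logit_ref, action=action, adv=adv, weight=weight
)
# Inside, the per-token KL to the reference policy is estimated as exp(q - p) - (q - p) - 1.
loss, info = grpo_policy_error(data, clip_ratio=0.2, beta=0.1)
loss.policy_loss.backward()
print(info.mean_kl, info.mean_ratio, info.mean_clipped)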
20 changes: 7 additions & 13 deletions ding/rl_utils/rloo.py
@@ -2,10 +2,7 @@
from collections import namedtuple
import torch

-rloo_policy_data = namedtuple(
-'rloo_policy_data',
-['logit_new', 'logit_old',
-'action', 'adv', 'weight'])
+rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])


def rloo_policy_error(
@@ -45,25 +42,22 @@ def rloo_policy_error(
advantages = data.adv.unsqueeze(1) # [B, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
-per_token_loss = -torch.min(per_token_loss_unclipped,
-per_token_loss_clipped)
+per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Calculate average loss using weight mask
-weight = data.weight if data.weight is not None else (
-torch.ones_like(per_token_loss))
+weight = data.weight if data.weight is not None else (torch.ones_like(per_token_loss))
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Calculate additional metrics
metrics = {
-'mean_ratio': ((ratio * weight).sum(dim=1) /
-weight.sum(dim=1)).mean().item(),
-'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() +
-(ratio < (1 - clip_ratio)).float().mean().item(),
+'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
+'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + (ratio <
+(1 - clip_ratio)).float().mean().item(),
'mean_advantage': advantages.mean().item(),
}

# Create return namedtuples
loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

-return loss_info, metric_info
+return loss_info, metric_info
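Both grpo.py and rloo.py reduce the per-token loss and the metrics with the same masked mean, ((x * weight).sum(dim=1) / weight.sum(dim=1)).mean(), so padded tokens do not dilute the per-sequence average. A tiny standalone illustration with made-up numbers:

# Sketch of the masked reduction shared by both helpers; the values are illustrative only.
import torch

per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                               [2.0, 2.0, 2.0, 2.0]])
weight = torch.tensor([[1.0, 1.0, 0.0, 0.0],  # last two tokens of the first sequence are padding
                       [1.0, 1.0, 1.0, 1.0]])

loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()
print(loss)  # tensor(1.7500): mean of 1.5 (first sequence, padding excluded) and 2.0 (second sequence)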
16 changes: 4 additions & 12 deletions ding/rl_utils/tests/test_grpo_rlhf.py
@@ -21,14 +21,10 @@ def dictionary_num()


@pytest.mark.unittest
-def test_grpo_policy_loss_with_mask(
-batch_size: int = 4,
-seq_length: int = 8,
-vocab_size: int = 1000):
+def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation with mask"""
# 1. Create test data
-logit_new = (torch.randn(batch_size, seq_length,
-vocab_size).requires_grad_(True))
+logit_new = (torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True))
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))
@@ -72,14 +68,10 @@ def test_grpo_policy_loss_with_mask(


@pytest.mark.unittest
-def test_grpo_policy_loss_without_mask(
-batch_size: int = 4,
-seq_length: int = 8,
-vocab_size: int = 1000):
+def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation without mask"""
# 1. Create test data
-logit_new = torch.randn(batch_size,
-seq_length, vocab_size).requires_grad_(True)
+logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))
24 changes: 5 additions & 19 deletions ding/rl_utils/tests/test_rloo_rlhf.py
@@ -23,20 +23,13 @@ def dictionary_num():
def test_rloo_policy_loss_without_mask(batch_size, seq_length, dictionary_num):
"""Test RLOO policy loss calculation without mask"""
# Create test data
-logit_new = torch.randn(batch_size,
-seq_length, dictionary_num).requires_grad_(True)
+logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
action = torch.randint(0, dictionary_num, (batch_size, seq_length))
advantages = torch.randn(batch_size)

# Calculate loss
-data = rloo_policy_data(
-logit_new=logit_new,
-logit_old=logit_old,
-action=action,
-adv=advantages,
-weight=None
-)
+data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, adv=advantages, weight=None)
loss, info = rloo_policy_error(data, clip_ratio=0.2)

# Verify outputs
@@ -54,22 +47,15 @@ def test_rloo_policy_loss_without_mask(batch_size, seq_length, dictionary_num):
def test_rloo_policy_loss_with_mask(batch_size, seq_length, dictionary_num):
"""Test RLOO policy loss calculation with mask"""
# Create test data
-logit_new = torch.randn(batch_size,
-seq_length, dictionary_num).requires_grad_(True)
+logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
action = torch.randint(0, dictionary_num, (batch_size, seq_length))
advantages = torch.randn(batch_size)
action_mask = torch.ones(batch_size, seq_length)
action_mask[:, -2:] = 0

# Calculate loss
-data = rloo_policy_data(
-logit_new=logit_new,
-logit_old=logit_old,
-action=action,
-adv=advantages,
-weight=action_mask
-)
+data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, adv=advantages, weight=action_mask)
loss, info = rloo_policy_error(data, clip_ratio=0.2)

# Verify outputs
@@ -85,4 +71,4 @@ def test_rloo_policy_loss_with_mask(batch_size, seq_length, dictionary_num):
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert 'mean_advantage' in info._asdict()
-assert all([np.isscalar(v) for v in info._asdict().values()])
+assert all([np.isscalar(v) for v in info._asdict().values()])
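Both test modules tag their cases with @pytest.mark.unittest, so they can be selected by that marker. One hypothetical way to run them locally through pytest's Python API (the file paths come from this commit; the invocation itself is just an example):

# Hypothetical local run of the two new RLHF loss test modules, filtered by the unittest marker.
import pytest

pytest.main([
    "-m", "unittest",
    "ding/rl_utils/tests/test_grpo_rlhf.py",
    "ding/rl_utils/tests/test_rloo_rlhf.py",
])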
@@ -63,4 +63,3 @@
from ding.entry import serial_pipeline
with DDPContext():
serial_pipeline((main_config, create_config), seed=0)
-
1 change: 0 additions & 1 deletion dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=17,
action_shape=6,
-
),
learn=dict(
data_path=None,
1 change: 0 additions & 1 deletion dizoo/d4rl/config/halfcheetah_medium_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=17,
action_shape=6,
-
),
learn=dict(
data_path=None,
1 change: 0 additions & 1 deletion dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=17,
action_shape=6,
-
),
learn=dict(
data_path=None,
1 change: 0 additions & 1 deletion dizoo/d4rl/config/hopper_medium_expert_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=11,
action_shape=3,
-
),
learn=dict(
data_path=None,
1 change: 0 additions & 1 deletion dizoo/d4rl/config/hopper_medium_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=11,
action_shape=3,
-
),
learn=dict(
data_path=None,
1 change: 0 additions & 1 deletion dizoo/d4rl/config/hopper_medium_replay_iql_config.py
@@ -18,7 +18,6 @@
model=dict(
obs_shape=11,
action_shape=3,
-
),
learn=dict(
data_path=None,
