Skip to content

Commit

Permalink
fix unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Jan 18, 2024
1 parent 2b40bdb commit 14cdebe
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 40 deletions.
78 changes: 44 additions & 34 deletions src/agentscope/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def register_budget(
self,
model_name: str,
value: float,
prefix: Optional[str] = 'local'
prefix: Optional[str] = "local",
) -> bool:
"""Register model call budget to the monitor, the monitor will raise
QuotaExceededError, when budget is exceeded.
Expand All @@ -212,9 +212,11 @@ def register_budget(
class QuotaExceededError(Exception):
"""An Exception used to indicate that a certain metric exceeds quota"""

def __init__(self,
metric_name: Optional[str] = None,
quota: Optional[float] = None) -> None:
def __init__(
self,
metric_name: Optional[str] = None,
quota: Optional[float] = None,
) -> None:
if metric_name is not None and quota is not None:
self.message = f"Metric [{metric_name}] exceeds quota [{quota}]"
super().__init__(self.message)
Expand Down Expand Up @@ -357,10 +359,12 @@ def get_metrics(self, filter_regex: Optional[str] = None) -> dict:
if pattern.search(key)
}

def register_budget(self,
model_name: str,
value: float,
prefix: Optional[str] = 'local') -> bool:
def register_budget(
self,
model_name: str,
value: float,
prefix: Optional[str] = "local",
) -> bool:
logger.warning("DictMonitor doesn't support register_budget")
return False

Expand Down Expand Up @@ -446,12 +450,12 @@ def _create_monitor_table(self, drop_exists: bool = False) -> None:
BEGIN
SELECT RAISE(FAIL, 'QuotaExceeded');
END;
"""
""",
)

def _get_trigger_name(self, metric_name: str) -> str:
"""Get the name of the trigger on a certain metric"""
return f'{self.table_name}.{metric_name}.trigger'
return f"{self.table_name}.{metric_name}.trigger"

def register(
self,
Expand Down Expand Up @@ -630,8 +634,8 @@ def _create_update_cost_trigger(
self,
token_metric: str,
cost_metric: str,
unit_price: float
) -> bool:
unit_price: float,
) -> None:
with sqlite_transaction(self.db_path) as cursor:
cursor.execute(
f"""
Expand All @@ -645,36 +649,42 @@ def _create_update_cost_trigger(
SET value = value + (NEW.value - OLD.value) * {unit_price}
WHERE name = "{cost_metric}";
END;
"""
""",
)

def register_budget(
self,
model_name: str,
value: float,
prefix: Optional[str] = None
prefix: Optional[str] = None,
) -> bool:
logger.info(f"set budget {value} to {model_name}")
pricing = get_pricing()
if model_name in pricing:
budget_metric_name = f'{prefix}.{model_name}.cost'
budget_metric_name = f"{prefix}.{model_name}.cost"
ok = self.register(
metric_name=budget_metric_name,
metric_unit='dollor',
quota=value)
metric_unit="dollor",
quota=value,
)
if not ok:
return False
for metric_name, unit_price in pricing[model_name].items():
token_metric_name = f'{prefix}.{model_name}.{metric_name}'
token_metric_name = f"{prefix}.{model_name}.{metric_name}"
self.register(
metric_name=token_metric_name,
metric_unit='token')
metric_unit="token",
)
self._create_update_cost_trigger(
token_metric_name, budget_metric_name, unit_price)
token_metric_name,
budget_metric_name,
unit_price,
)
return True
else:
logger.warning(
f'Calculate budgets for model [{model_name}] is not supported')
f"Calculate budgets for model [{model_name}] is not supported",
)
return False


Expand All @@ -685,22 +695,22 @@ def get_pricing() -> dict:
`dict`: the dict with pricing information.
"""
return {
'gpt-4-turbo': {
'prompt_tokens': 0.00001,
'completion_tokens': 0.00003
"gpt-4-turbo": {
"prompt_tokens": 0.00001,
"completion_tokens": 0.00003,
},
'gpt-4': {
'prompt_tokens': 0.00003,
'completion_tokens': 0.00006
"gpt-4": {
"prompt_tokens": 0.00003,
"completion_tokens": 0.00006,
},
'gpt-4-32k': {
'prompt_tokens': 0.00006,
'completion_tokens': 0.00012
"gpt-4-32k": {
"prompt_tokens": 0.00006,
"completion_tokens": 0.00012,
},
"gpt-3.5-turbo": {
"prompt_tokens": 0.000001,
"completion_tokens": 0.000002,
},
'gpt-3.5-turbo': {
'prompt_tokens': 0.000001,
'completion_tokens': 0.000002
}
}


Expand Down
26 changes: 20 additions & 6 deletions tests/monitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,29 +183,43 @@ def test_register_budget(self) -> None:
"""Test register_budget method of monitor"""
self.assertTrue(
self.monitor.register_budget(
model_name='gpt-4', value=5, prefix="agent_A")
model_name="gpt-4",
value=5,
prefix="agent_A",
),
)
# register an existing model with different prefix is ok
self.assertTrue(
self.monitor.register_budget(
model_name='gpt-4', value=15, prefix="agent_B")
model_name="gpt-4",
value=15,
prefix="agent_B",
),
)
gpt_4_3d = {
"agent_A.gpt-4.prompt_tokens": 50000,
"agent_A.gpt-4.completion_tokens": 25000,
"agent_A.gpt-4.total_tokens": 750000
"agent_A.gpt-4.total_tokens": 750000,
}
# agentA uses 3 dollors
self.monitor.update(**gpt_4_3d)
# agentA uses another 3 dollors and exceeds quota
self.assertRaises(
QuotaExceededError,
self.monitor.update,
**gpt_4_3d
**gpt_4_3d,
)
self.assertLess(
self.monitor.get_value( # type: ignore [arg-type]
"agent_A.gpt-4.cost",
),
5,
)
self.assertLess(self.monitor.get_value('agent_A.gpt-4.cost'), 5)
# register an existing model with existing prefix is wrong
self.assertFalse(
self.monitor.register_budget(
model_name='gpt-4', value=5, prefix="agent_A")
model_name="gpt-4",
value=5,
prefix="agent_A",
),
)

0 comments on commit 14cdebe

Please sign in to comment.