Skip to content

Commit

Permalink
make format
Browse files: browse the repository at this point in the history
  • Loading branch information
anmolagarwalcp810 committed Aug 30, 2024
1 parent 1f60e35 commit 2e3bafd
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 45 deletions.
84 changes: 42 additions & 42 deletions etalon/capacity_search/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,54 +119,54 @@ def to_config_dict(self):
"seed": self.seed,
}
if self.request_interval_generator_provider == "gamma":
config_dict["gamma-request-interval-generator-cv"] = (
self.gamma_request_interval_generator_cv
)
config_dict[
"gamma-request-interval-generator-cv"
] = self.gamma_request_interval_generator_cv
elif self.request_interval_generator_provider == "trace":
config_dict["trace-request-interval-generator-trace-file"] = (
self.trace_request_interval_generator_trace_file
)
config_dict["trace-request-interval-generator-start-time"] = (
self.trace_request_interval_generator_start_time
)
config_dict["trace-request-interval-generator-end-time"] = (
self.trace_request_interval_generator_end_time
)
config_dict["trace-request-interval-generator-time-scale-factor"] = (
self.trace_request_interval_generator_time_scale_factor
)
config_dict[
"trace-request-interval-generator-trace-file"
] = self.trace_request_interval_generator_trace_file
config_dict[
"trace-request-interval-generator-start-time"
] = self.trace_request_interval_generator_start_time
config_dict[
"trace-request-interval-generator-end-time"
] = self.trace_request_interval_generator_end_time
config_dict[
"trace-request-interval-generator-time-scale-factor"
] = self.trace_request_interval_generator_time_scale_factor

if self.request_length_generator_provider == "trace":
config_dict["trace-request-length-generator-trace-file"] = (
self.trace_request_length_generator_trace_file
)
config_dict["trace-request-length-generator-prefill-scale-factor"] = (
self.trace_request_length_generator_prefill_scale_factor
)
config_dict["trace-request-length-generator-decode-scale-factor"] = (
self.trace_request_length_generator_decode_scale_factor
)
config_dict[
"trace-request-length-generator-trace-file"
] = self.trace_request_length_generator_trace_file
config_dict[
"trace-request-length-generator-prefill-scale-factor"
] = self.trace_request_length_generator_prefill_scale_factor
config_dict[
"trace-request-length-generator-decode-scale-factor"
] = self.trace_request_length_generator_decode_scale_factor
elif self.request_length_generator_provider == "fixed":
config_dict["fixed-request-generator-prefill-tokens"] = (
self.fixed_request_generator_prefill_tokens
)
config_dict["fixed-request-generator-decode-tokens"] = (
self.fixed_request_generator_decode_tokens
)
config_dict[
"fixed-request-generator-prefill-tokens"
] = self.fixed_request_generator_prefill_tokens
config_dict[
"fixed-request-generator-decode-tokens"
] = self.fixed_request_generator_decode_tokens
elif self.request_length_generator_provider == "synthetic":
config_dict["synthetic-request-generator-min-tokens"] = (
self.synthetic_request_generator_min_tokens
)
config_dict["synthetic-request-generator-prefill-to-decode-ratio"] = (
self.synthetic_request_generator_prefill_to_decode_ratio
)
config_dict[
"synthetic-request-generator-min-tokens"
] = self.synthetic_request_generator_min_tokens
config_dict[
"synthetic-request-generator-prefill-to-decode-ratio"
] = self.synthetic_request_generator_prefill_to_decode_ratio
elif self.request_length_generator_provider == "zipf":
config_dict["zipf-request-length-generator-theta"] = (
self.zipf_request_length_generator_theta
)
config_dict["zipf-request-length-generator-scramble"] = (
self.zipf_request_length_generator_scramble
)
config_dict[
"zipf-request-length-generator-theta"
] = self.zipf_request_length_generator_theta
config_dict[
"zipf-request-length-generator-scramble"
] = self.zipf_request_length_generator_scramble
return config_dict

def to_args(self):
Expand Down
6 changes: 5 additions & 1 deletion etalon/core/llm_clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@


def construct_clients(
model_name: str, tokenizer_name: str, llm_api: str, num_clients: int, use_ray: bool = True
model_name: str,
tokenizer_name: str,
llm_api: str,
num_clients: int,
use_ray: bool = True,
) -> List[BaseLLMClient]:
"""Construct LLMClients that will be used to make requests to the LLM API.
Expand Down
7 changes: 6 additions & 1 deletion etalon/core/requests_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ class AsyncRequestsManager:
"""Manages requests for single LLM API client."""

def __init__(
self, client_id: int, model: str, tokenizer_name: str, llm_api: str, max_concurrent_requests: int
self,
client_id: int,
model: str,
tokenizer_name: str,
llm_api: str,
max_concurrent_requests: int,
):
self.max_concurrent_requests = max_concurrent_requests
self.requests_queue = asyncio.Queue(maxsize=max_concurrent_requests)
Expand Down
5 changes: 4 additions & 1 deletion etalon/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ def parse_args():
"--model", type=str, required=True, help="The model to use for this load test."
)
args.add_argument(
"--tokenizer", type=str, required=False, help="The tokenizer to use for this load test. By default, the tokenizer is inferred from the model."
"--tokenizer",
type=str,
required=False,
help="The tokenizer to use for this load test. By default, the tokenizer is inferred from the model.",
)
args.add_argument(
"--num-ray-clients",
Expand Down

0 comments on commit 2e3bafd

Please sign in to comment.