
Commit: fix comment
KuilongCui committed Jan 17, 2025
1 parent 40e9125 commit 169e4cd
Showing 13 changed files with 165 additions and 90 deletions.
21 changes: 11 additions & 10 deletions llumnix/arg_utils.py
@@ -137,7 +137,7 @@ class ManagerArgs:
     pd_ratio: Union[str, List[int]] = None

     # init from instance args
-    group_kind_migration_backend: bool = None
+    is_group_kind_migration_backend: bool = None
     enable_engine_pd_disagg: bool = None

     def __post_init__(self):
@@ -162,9 +162,9 @@ def parse_ratio(ratio_str):

     def init_from_instance_args(self, instance_args: 'InstanceArgs'):
         self.enable_engine_pd_disagg = instance_args.enable_engine_pd_disagg
-        self.group_kind_migration_backend = instance_args.migration_backend in ['gloo', 'nccl']
+        self.is_group_kind_migration_backend = instance_args.migration_backend in ['gloo', 'nccl']

-    def create_global_scheduler_config(self, group_kind_migration_backend) -> Tuple[GlobalSchedulerConfig]:
+    def create_global_scheduler_config(self, is_group_kind_migration_backend) -> Tuple[GlobalSchedulerConfig]:
         # Create the GlobalScheduler Configuration.
         global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
                                                         self.dispatch_policy,
@@ -175,7 +175,7 @@ def create_global_scheduler_config(self, group_kind_migration_backend) -> Tuple[
                                                         self.scale_up_threshold,
                                                         self.scale_down_threshold,
                                                         self.enable_pd_disagg,
-                                                        group_kind_migration_backend)
+                                                        is_group_kind_migration_backend)
         return global_scheduler_config

     @classmethod
@@ -280,7 +280,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             help='enable prefill decoding disaggregation')
         parser.add_argument('--pd-ratio',
                             type=str,
-                            help='the p:d ratio used in gloabl launch model.')
+                            help='the prefill:decode ratio used in global launch mode, e.g. "1:1"')
         return parser

 @dataclass
@@ -381,12 +381,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             choices=['prefill', 'decode', 'no_constraints'],
                             help="instance type for the engine. When pd-disaggregation is not enabled, set instance_type \
                             to no_constraints. For pd-disaggregation implemented via Llumnix, specify instance_type \
-                            as either prefill or decode for local launch model, and configure --pd-ratio \
-                            appropriately for global launch model. When pd-disaggregation is handled internally \
-                            within the LLM engine, don't set --enable-pd-disagg. --instance-type parameters should not \
+                            as either prefill or decode for local launch mode; it is not necessary to set it for \
+                            global launch mode, as the manager will automatically determine the instance type and \
+                            quantity based on the --pd-ratio. When pd-disaggregation is handled internally within \
+                            the LLM engine, don't set --enable-pd-disagg. The --instance-type parameter should not \
                             also be set. Instead, the instance_type will be automatically assigned to either prefill \
-                            or decode based on engine_args for local launch mode, and the p:d ratio must also be \
-                            configured for global launch model.")
+                            or decode based on engine_args for local launch mode, and do not set it for global launch \
+                            mode.")
         parser.add_argument('--profiling-result-file-path',
                             type=str,
                             help='profiling result file path when using simulator')
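
For reference, the --pd-ratio value documented above is a "P:D" string parsed into a
pair of integers by parse_ratio (named in the hunk header above). A minimal sketch of
such a parser, assuming the "1:1"-style format from the help text; the actual llumnix
implementation may differ in details:

    # Hypothetical sketch; mirrors the "1:1"-style format from the help text.
    from typing import List

    def parse_ratio(ratio_str: str) -> List[int]:
        parts = ratio_str.split(":")
        assert len(parts) == 2, f"invalid pd-ratio '{ratio_str}', expected 'P:D'"
        ratio = [int(part) for part in parts]
        assert all(r >= 0 for r in ratio), "pd-ratio terms must be non-negative"
        return ratio  # e.g. "1:1" -> [1, 1]
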
3 changes: 2 additions & 1 deletion llumnix/backends/vllm/llm_engine.py
@@ -170,7 +170,8 @@ async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]:
         instance_info.instance_id = self.instance_id
         instance_info.step_id = next(self.step_counter)
         instance_info.timestamp = time.time()
-        instance_info.profiling_data=(instance_info.inference_type.value,
+        # TODO(KuilongCui): add cli_args to determine whether to collect profiling data
+        instance_info.profiling_data=(instance_info.inference_type.value if instance_info.inference_type else "",
                                       instance_info.num_seqs,
                                       sum(instance_info.running_seq_lens),
                                       self.model_executor.last_inference_latency)
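
The new conditional guards against inference_type being None, now that the field no
longer defaults to RequestInferenceType.UNKNOWN (see the instance_info.py and
request.py hunks below). A self-contained illustration of the guard; the enum values
follow the diff, the helper function is hypothetical:

    from enum import Enum
    from typing import Optional

    class RequestInferenceType(str, Enum):
        PREFILL = "prefill"
        DECODE = "decode"

    # Hypothetical helper: fall back to "" before the first engine step.
    def profiling_tag(inference_type: Optional[RequestInferenceType]) -> str:
        return inference_type.value if inference_type else ""

    assert profiling_tag(RequestInferenceType.PREFILL) == "prefill"
    assert profiling_tag(None) == ""
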
1 change: 1 addition & 0 deletions llumnix/entrypoints/vllm/client.py
@@ -1,6 +1,7 @@
 import copy
 import time
 import asyncio
+from typing import Dict
 import ray

 from vllm.engine.async_llm_engine import AsyncStream
2 changes: 1 addition & 1 deletion llumnix/global_scheduler/global_scheduler.py
@@ -38,7 +38,7 @@ def __init__(self, global_scheduler_config: GlobalSchedulerConfig) -> None:
         # migrate args
         self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
                                                       global_scheduler_config.migrate_out_load_threshold,
-                                                      global_scheduler_config.group_kind_migration_backend)
+                                                      global_scheduler_config.is_group_kind_migration_backend)
         # auto-scaling args
         self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold,
                                                   global_scheduler_config.scale_down_threshold,
10 changes: 5 additions & 5 deletions llumnix/global_scheduler/migration_scheduler.py
@@ -24,10 +24,10 @@
 class MigrationScheduler:
     def __init__(self, pair_migration_policy: str,
                  migrate_out_load_threshold: float,
-                 group_kind_migration_backend: bool) -> None:
+                 is_group_kind_migration_backend: bool) -> None:
         filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold)
         self.migration_filter = MigrationInstanceFilter(filter_config)
-        self._register_migration_backend_init_filter(group_kind_migration_backend)
+        self._register_migration_backend_init_filter(is_group_kind_migration_backend)

         self.pair_migration_policy = PairMigrationPolicyFactory.get_policy(
             pair_migration_policy, migrate_out_load_threshold=migrate_out_load_threshold)
@@ -36,13 +36,13 @@ def __init__(self, pair_migration_policy: str,
         self.instance_id_set: Set[str] = set()
         self.instance_info: Dict[str, InstanceInfo] = None

-    def _register_migration_backend_init_filter(self, group_kind_migration_backend: bool) -> None:
+    def _register_migration_backend_init_filter(self, is_group_kind_migration_backend: bool) -> None:
         # some migration backends require init_process_group before passing the KV cache. Here, we add a filter
         # to prevent instances of migration backends that have not been initialized from participating in migration.
         migration_backend_init_filter = CustomFilter()
         migration_backend_init_filter.set_filter_condtition(
-            src_filter=lambda _: not group_kind_migration_backend,
-            dst_filter=lambda _: not group_kind_migration_backend)
+            src_filter=lambda _: not is_group_kind_migration_backend,
+            dst_filter=lambda _: not is_group_kind_migration_backend)
         self.migration_filter.register_filter("migration_backend_init_filter", migration_backend_init_filter)

     # migration_filter must ensure that the specific instance_info does not appear in both src and dst simultaneously
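
To make the filter mechanics above concrete: a migration filter carries independent
predicates for the source and destination side of a candidate migration pair, and any
instance failing a predicate is excluded from pairing. A stripped-down sketch of the
pattern; these are simplified stand-ins, not the actual llumnix classes:

    from dataclasses import dataclass
    from typing import Callable, List, Tuple

    # Simplified stand-in for llumnix's CustomFilter: one predicate per side.
    @dataclass
    class SideFilter:
        src_filter: Callable[[str], bool]
        dst_filter: Callable[[str], bool]

    def split_candidates(instance_ids: List[str], f: SideFilter) -> Tuple[List[str], List[str]]:
        srcs = [i for i in instance_ids if f.src_filter(i)]
        dsts = [i for i in instance_ids if f.dst_filter(i)]
        return srcs, dsts

    # With a group-kind backend (gloo/nccl) whose process group is not yet
    # initialized, both predicates reject everything, disabling migration.
    is_group_kind_migration_backend = True
    gate = SideFilter(src_filter=lambda _: not is_group_kind_migration_backend,
                      dst_filter=lambda _: not is_group_kind_migration_backend)
    assert split_candidates(["i0", "i1"], gate) == ([], [])
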
5 changes: 2 additions & 3 deletions llumnix/instance_info.py
@@ -24,23 +24,22 @@
 logger = init_logger(__name__)

 class InstanceType(str, Enum):
-    UNKKNOWN = "unknown"
     NO_CONSTRAINTS = "no_constraints"
     PREFILL = "prefill"
     DECODE = "decode"

 @dataclass
 class InstanceInfo:
     instance_id: str = ""
-    instance_type: InstanceType = InstanceType.UNKKNOWN
+    instance_type: InstanceType = None

     step_id: int = None
     timestamp: float = None
     num_batched_tokens: int = None
     num_seqs = None
     running_seq_lens: List[int] = field(default_factory=list)
     last_inference_latency: float = None
-    inference_type: RequestInferenceType = RequestInferenceType.UNKNOWN
+    inference_type: RequestInferenceType = None

     num_total_gpu_blocks: int = 0
     num_watermark_blocks: int = 0
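
With UNKKNOWN removed and the defaults switched to None, consumers of these fields
must treat them as optional. A small sketch of the resulting pattern; the Optional
annotations are an editorial suggestion, not part of the diff:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional

    class InstanceType(str, Enum):
        NO_CONSTRAINTS = "no_constraints"
        PREFILL = "prefill"
        DECODE = "decode"

    @dataclass
    class InstanceInfoSketch:
        # None now means "not yet assigned"; check before dereferencing.
        instance_type: Optional[InstanceType] = None

    info = InstanceInfoSketch()
    label = info.instance_type.value if info.instance_type else ""
    assert label == ""
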
4 changes: 2 additions & 2 deletions llumnix/internal_config.py
@@ -48,7 +48,7 @@ def __init__(
         scale_up_threshold: float,
         scale_down_threshold: float,
         enable_pd_disagg: bool,
-        group_kind_migration_backend: bool,) -> None:
+        is_group_kind_migration_backend: bool,) -> None:
         self.initial_instances = initial_instances
         self.dispatch_policy = dispatch_policy
         self.pair_migration_policy = pair_migration_policy
@@ -60,4 +60,4 @@ def __init__(
         self.scale_down_threshold = scale_down_threshold

         self.enable_pd_disagg = enable_pd_disagg
-        self.group_kind_migration_backend = group_kind_migration_backend
+        self.is_group_kind_migration_backend = is_group_kind_migration_backend
4 changes: 2 additions & 2 deletions llumnix/launcher.py
@@ -173,7 +173,7 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint
             self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0
             self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0
             if instance_finish_cb:
-                # manager.scale_up will be called after the instance is ready
+                # manager.scale_up will be called here after the instance is ready
                 instance_finish_cb(instance_id, instance, instance_args)
             if server_finish_cb:
                 server_finish_cb(instance_id, server)
@@ -188,7 +188,7 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint
             logger.error("[_init_server_and_instance] unexpected exception occurs: {}".format(e))
             logger.error("[_init_server_and_instance] exception traceback: {}".format(traceback.format_exc()))
             self.clear_instance_ray_resources(instance_id)
-        logger.info("init_server_and_instance is called")
+
         request_output_queue_type = QueueType(entrypoints_args.request_output_queue_type)
         next_instance_args = self._get_next_instance_args(instance_args)
         instance = self.init_instance(instance_id, next_instance_args, placement_group,
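
The pattern above defers manager bookkeeping to callbacks that fire only once the
instance is actually ready. A hypothetical, stripped-down sketch of that flow; the
callback name mirrors the diff, everything else here is illustrative:

    import asyncio
    from typing import Callable, Optional

    async def done_scale_up(instance_id: str,
                            wait_until_ready: Callable[[], "asyncio.Future[None]"],
                            instance_finish_cb: Optional[Callable[[str], None]] = None) -> None:
        # Block until the instance reports ready (e.g. a Ray actor health check).
        await wait_until_ready()
        if instance_finish_cb:
            instance_finish_cb(instance_id)  # manager.scale_up is invoked here

    async def main() -> None:
        ready = asyncio.Event()
        task = asyncio.create_task(done_scale_up("instance-0", ready.wait,
                                                 lambda iid: print(f"scale_up({iid})")))
        ready.set()
        await task

    asyncio.run(main())
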
1 change: 1 addition & 0 deletions llumnix/llumlet/llumlet.py
@@ -189,6 +189,7 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds
             raise
         return migrated_request

+    # TODO(KuilongCui): only the metrics-related information needs to be synchronously loaded for the manager
     def get_instance_info(self) -> InstanceInfo:
         instance_info: InstanceInfo = self.backend_engine.engine.instance_info
         instance_info.instance_type = self.instance_args.instance_type
1 change: 0 additions & 1 deletion llumnix/llumlet/request.py
@@ -18,7 +18,6 @@


 class RequestInferenceType(str, Enum):
-    UNKNOWN = "unknown"
     PREFILL = "prefill"
     DECODE = "decode"

30 changes: 7 additions & 23 deletions llumnix/manager.py
@@ -97,8 +97,8 @@ def __init__(self,

         self.polling_interval = manager_args.polling_interval

-        self.group_kind_migration_backend = manager_args.group_kind_migration_backend
-        global_scheduler_config = manager_args.create_global_scheduler_config(self.group_kind_migration_backend)
+        self.is_group_kind_migration_backend = manager_args.is_group_kind_migration_backend
+        global_scheduler_config = manager_args.create_global_scheduler_config(self.is_group_kind_migration_backend)
         self.global_scheduler = GlobalScheduler(global_scheduler_config)

         self.launcher: Launcher = Launcher(self.global_scheduler, manager_args.enable_port_increment,
@@ -422,7 +422,7 @@ def scale_up(self, instance_id: Union[str, Iterable[str]],
         # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes
         # caused by this scale-up (see rebuild_migration_backend for details). Therefore, we simply return in this case.
         # Specifically, for non-group-kind migration backends, there is no need to rebuild the group.
-        if self.enable_migration and self.group_kind_migration_backend \
+        if self.enable_migration and self.is_group_kind_migration_backend \
             and indeed_update and no_pending_instance:
             asyncio.create_task(self._rebuild_migration_backend())

@@ -457,7 +457,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b
         self.global_scheduler.scale_down(instance_ids)
         self.num_instances = len(self.instances)

-        if self.enable_migration and self.group_kind_migration_backend:
+        if self.enable_migration and self.is_group_kind_migration_backend:
             if len(self.instances) == 0:
                 self.pending_rebuild_migration_instances = 0
                 clear_gloo_backend_state()
@@ -562,20 +562,19 @@ def _inner_check_pd_deployment(self) -> str:

         return scale_down_instance_id

-    # TODO(KuilongCui): currently, only one fallback strategy is implemented, which prevents the
+    # TODO(KuilongCui): currently, only one naive state check policy is implemented, which prevents the
     # cluster from consisting entirely of prefill or decode instances.
     async def _check_pd_deployment_states_loop(self, interval: float) -> None:
         previous_penging_pg_names = None

         while True:
             try:
                 pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
-                pending_pg_states.extend(list_placement_groups(filters=[("state", "=", "RESCHEDULING")]))
+                rescheduling_pg_states = list_placement_groups(filters=[("state", "=", "RESCHEDULING")])
                 all_penging_pg_names = [pg.name for pg in pending_pg_states]

-                if previous_penging_pg_names:
+                if previous_penging_pg_names and len(rescheduling_pg_states) == 0:
                     new_pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
-                    new_pending_pg_states.extend(list_placement_groups(filters=[("state", "=", "RESCHEDULING")]))
                     all_new_penging_pg_names = [pg.name for pg in new_pending_pg_states]
                     if len(set(previous_penging_pg_names).difference(set(all_new_penging_pg_names))) == 0:
                         self._inner_check_pd_deployment()
Expand Down Expand Up @@ -734,18 +733,3 @@ def _log_instance_infos_to_csv(self, instance_infos: List[InstanceInfo]) -> None
instance_info.num_blocks_all_waiting_requests,
instance_info.waiting_time_first_waiting_request])
self.instance_info_file.flush()

def init_placement_group(self, *args, **kwargs):
return self.launcher.init_placement_group(*args, **kwargs)

def init_server_and_instance(self, *args, **kwargs):
return self.launcher.init_server_and_instance(*args, **kwargs)

def clear_instance_ray_resources(self, instance_id: str):
self.launcher.clear_instance_ray_resources(instance_id)

def get_cluster_deployment(self):
return self.launcher.get_cluster_deployment()

def get_instance_deployment_states(self, instance_id: str):
return self.launcher.get_instance_deployment_states(instance_id)
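
For context on the loop changed above: it polls Ray's state API for placement groups
stuck in PENDING. A minimal standalone sketch of that polling, assuming Ray 2.x with
ray.util.state available and a running cluster; llumnix's loop additionally handles
RESCHEDULING groups and calls _inner_check_pd_deployment instead of printing:

    import time
    import ray
    from ray.util.state import list_placement_groups

    ray.init(ignore_reinit_error=True)

    def pending_pg_names() -> set:
        return {pg.name for pg in list_placement_groups(filters=[("state", "=", "PENDING")])}

    previous = pending_pg_names()
    time.sleep(30)  # poll interval; configurable in the real loop
    # Everything previously pending is still pending -> deployment may be stuck.
    if previous and previous.issubset(pending_pg_names()):
        print("placement groups stuck in PENDING:", sorted(previous))
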
12 changes: 6 additions & 6 deletions tests/e2e_test/test_bench.py
@@ -67,6 +67,11 @@ def get_markdown_data(key: str, head_name: str):
 @pytest.mark.parametrize("launch_mode", ['global', 'local'])
 @pytest.mark.parametrize("enable_pd_disagg", [True, False])
 async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch_mode, enable_pd_disagg):
+    if launch_mode == 'local':
+        num_prompts = 500 if not enable_pd_disagg else 50
+    else:
+        num_prompts = 50 if not enable_pd_disagg else 50
+
     ip = "127.0.0.1"
     base_port = 37037
     ip_ports = []
@@ -129,11 +134,6 @@ def run_bench_command(command):
         process = subprocess.Popen(command, shell=True)
         return process

-    if launch_mode == 'local':
-        num_prompts = 500 if not enable_pd_disagg else 50
-    else:
-        num_prompts = 50 if not enable_pd_disagg else 50
-
     tasks = []
     for i in range(device_count):
         bench_command = generate_bench_command(
Expand Down Expand Up @@ -161,7 +161,7 @@ def run_bench_command(command):
process.kill()
assert False, "bench_test timed out after {} minutes.".format(BENCH_TEST_TIMEOUT_MINS)

if launch_mode == 'local':
if launch_mode == 'local' and not enable_pd_disagg:
with open("performance.txt", "w", encoding="utf-8") as f:
f.write(parse_log_file())
