diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/correctness_test.yml similarity index 86% rename from .github/workflows/e2e_test.yml rename to .github/workflows/correctness_test.yml index f247a6eb..fd3455c2 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/correctness_test.yml @@ -1,4 +1,4 @@ -name: e2e_test +name: correctness_test on: pull_request: @@ -17,7 +17,7 @@ jobs: with: all_but_latest: true - e2e_tests: + correctness_tests: needs: cancel_previous_workflows runs-on: [self-hosted] timeout-minutes: 30 @@ -28,4 +28,4 @@ jobs: run: | [[ -n $(docker ps -q) ]] && docker kill $(docker ps -q) || echo "No running containers to kill." - name: Build And Test - run: ./tools/run_test.sh e2e_test + run: ./tools/run_test.sh correctness_test diff --git a/Makefile b/Makefile index 887e342f..e7cc4703 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ proto-clean: test: check_pytest_installed @pytest -v --ignore=third_party --ignore=tests/e2e_test --disable-warnings @python examlpes/offline_inference.py - @pytest -v -x -s --tb=long ./tests/e2e_test/test_e2e.py + @pytest -v -x -s --tb=long ./tests/e2e_test/test_correctness.py @pytest -v -x -s --tb=long ./tests/e2e_test/test_bench.py @pytest -v -x -s --tb=long ./tests/e2e_test/test_migration.py @@ -67,9 +67,9 @@ unit_test: check_pytest_installed offline_test: @python examlpes/offline_inference.py -.PHONY: e2e_test -e2e_test: - @pytest -v -x -s --tb=long ./tests/e2e_test/test_e2e.py +.PHONY: correctness_test +correctness_test: + @pytest -v -x -s --tb=long ./tests/e2e_test/test_correctness.py .PHONY: bench_test bench_test: diff --git a/docs/Arguments.md b/docs/Arguments.md index e2cab101..73173d5f 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -20,7 +20,9 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--log-request-timestamps] [--config-file CONFIG_FILE] [--initial-instances INITIAL_INSTANCES] - [--load-metric {remaining_steps,usage_ratio}] + [--dispatch-load-metric {remaining_steps,usage_ratio}] + [--migration-load-metric {remaining_steps,usage_ratio}] + [--scaling-load-metric {remaining_steps,usage_ratio}] [--polling-interval POLLING_INTERVAL] [--dispatch-policy {balanced,load,queue,rr}] [--enable-migration] @@ -49,12 +51,14 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--migration-backend-transfer-type {cuda_ipc,rdma,}] [--grpc-migration-backend-server-address GRPC_MIGRATION_BACKEND_SERVER_ADDRESS] [--kvtransfer-migration-backend-naming-url KVTRANSFER_MIGRATION_BACKEND_NAMING_URL] - [--max-stages MAX_STAGES] - [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] + [--migration-max-stages MIGRATION_MAX_STAGES] + [--migration-last-stage-max-blocks MIGRATION_LAST_STAGE_MAX_BLOCKS] [--enable-pd-disagg] - [--num-dispatch-instances NUM_DISPATCH_INSTANCES] + [--pd-ratio PD_RATIO] [--enable-port-increment] [--enable-port-offset-store] + [--instance-type INSTANCE_TYPE] + ``` `--host` @@ -111,8 +115,18 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Number of instances created at initialization. - Default: 1 -`--load-metric` -- Instance load metric. +`--dispatch-load-metric` +- Instance dispatch load metric. +- Possible choices: remaining_steps, usage_ratio +- Default: "remaining_steps" + +`--migration-load-metric` +- Instance migration load metric. +- Possible choices: remaining_steps, usage_ratio +- Default: "remaining_steps" + +`--scaling-load-metric` +- Instance scaling load metric. 
 - Possible choices: remaining_steps, usage_ratio
 - Default: "remaining_steps"
@@ -224,20 +238,20 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 - URL of naming server for kvtransfer migration backend
 - Default: "file:/tmp/llumnix/naming/"

-`--max-stages`
-- Drop migration if the number of stages > max_stages.
+`--migration-max-stages`
+- Drop migration if the number of stages > migration_max_stages.
 - Default: 3

-`--last-stage-max-blocks`
-- If the number of remaining blocks < last_stage_max_blocks, do last stage migration.
+`--migration-last-stage-max-blocks`
+- If the number of remaining blocks < migration_last_stage_max_blocks, do last stage migration.
 - Default: 16

 `--enable-pd-disagg`
 - Enable prefill decoding disaggregation.

-`--num-dispatch-instances`
-- Number of available instances for dispatch.
-- Default: math.inf
+`--pd-ratio`
+- The p:d (prefill:decode) ratio used in global launch mode.
+- Default: "1:1"

 `--enable-port-increment`
 - Enable port increment when desploying multiple servers.
@@ -245,6 +259,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 `--enable-port-offset-store`
 - Enable store port offset when desploying multiple servers.

+`--instance-type`
+- Instance type for the engine.
+- Possible choices: prefill, decode, no_constraints
+
 # Unsupported vLLM feature options

 `--device`
diff --git a/examlpes/offline_inference.py b/examlpes/offline_inference.py
index 5148a9e8..0624da3b 100644
--- a/examlpes/offline_inference.py
+++ b/examlpes/offline_inference.py
@@ -5,7 +5,7 @@
 import ray

 from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager
-from llumnix import (ManagerArgs, EngineArgs, Manager,
+from llumnix import (ManagerArgs, InstanceArgs, EngineArgs, Manager,
                      Llumlet, ServerInfo, QueueType, BackendType, SamplingParams)
 from llumnix.utils import random_uuid
@@ -35,6 +35,7 @@

 # Set manager args and engine args.
 manager_args = ManagerArgs()
+instance_args = InstanceArgs()
 engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True,
                          trust_remote_code=True, max_model_len=370)
@@ -45,7 +46,8 @@
 # Create instances.
 instance_ids: List[str] = None
 instances: List[Llumlet] = None
-instance_ids, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.VLLM, engine_args))
+instance_ids, instances = ray.get(manager.init_instances.remote(
+    QueueType("rayqueue"), BackendType.VLLM, instance_args, engine_args))

 # The requests‘ outputs will be put to the request_output_queue no matter which instance it's running in.
server_id = random_uuid() diff --git a/llumnix/__init__.py b/llumnix/__init__.py index fba69575..cae5978d 100644 --- a/llumnix/__init__.py +++ b/llumnix/__init__.py @@ -15,7 +15,7 @@ from llumnix.entrypoints.setup import (launch_ray_cluster, connect_to_ray_cluster, init_manager) -from llumnix.arg_utils import ManagerArgs +from llumnix.arg_utils import ManagerArgs, InstanceArgs from llumnix.manager import Manager from llumnix.llumlet.llumlet import Llumlet from llumnix.queue.queue_type import QueueType @@ -29,6 +29,7 @@ "connect_to_ray_cluster", "init_manager", "ManagerArgs", + "InstanceArgs", "Manager", "Llumlet", "QueueType", diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 5046be9c..491c5525 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -16,7 +16,7 @@ import dataclasses from dataclasses import dataclass import argparse -from typing import Tuple +from typing import List, Tuple, Union from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig from llumnix.config import LlumnixConfig, get_llumnix_config @@ -42,9 +42,6 @@ def add_argument(self, *args, **kwargs): kwargs['default'] = None super().add_argument(*args, **kwargs) - -# All the default values of llumnix arguments are set in default.py. So all the arguments here are set to None. - @dataclass class EntrypointsArgs: host: str = None @@ -71,9 +68,10 @@ def __post_init__(self): def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'EntrypointsArgs': # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] + cfg_attrs = [attr for attr in attrs if hasattr(cfg.SERVER, attr.upper())] # Set the attributes from the parsed arguments. # The defalut values of attributes are defined in default.py. 
- entrypoints_args = cls(**{attr: getattr(cfg.SERVER, attr.upper()) for attr in attrs}) + entrypoints_args = cls(**{attr: getattr(cfg.SERVER, attr.upper()) for attr in cfg_attrs}) return entrypoints_args @classmethod @@ -110,24 +108,20 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--config-file", type=str, help="path to config file of arguments") - return parser @dataclass class ManagerArgs: initial_instances: int = None - load_metric: str = None polling_interval: float = None - dispatch_policy: str = None + scaling_load_metric: str = None enable_migration: bool = None - enable_defrag: bool = None pair_migration_frequency: int = None pair_migration_policy: str = None migrate_out_threshold: float = None - request_migration_policy: str = None enable_scaling: bool = None min_instances: int = None @@ -140,25 +134,15 @@ class ManagerArgs: disable_log_requests_manager: bool = None log_instance_info: bool = None log_filename: str = None - simulator_mode: bool = None - profiling_result_file_path: str = None - - migration_backend: str = None - migration_buffer_blocks: int = None - migration_num_layers: int = None - migration_backend_init_timeout: float = None - migration_backend_transfer_type: str = None - grpc_migration_backend_server_address: str = None - kvtransfer_migration_backend_naming_url: str = None - last_stage_max_blocks: int = None - max_stages: int = None - - enable_pd_disagg: bool = None - num_dispatch_instances: int = None - enable_port_increment: bool = None enable_port_offset_store: bool = None + enable_pd_disagg: bool = None + pd_ratio: Union[str, List[int]] = None + + # init from instance args + is_group_kind_migration_backend: bool = None + enable_engine_pd_disagg: bool = None def __post_init__(self): # Check if all fields default to None @@ -168,47 +152,44 @@ def __post_init__(self): for attr in dataclasses.fields(self): if getattr(self, attr.name) is None: - setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper())) - - def create_global_scheduler_config( - self, - ) -> Tuple[GlobalSchedulerConfig]: - + if hasattr(_C.MANAGER, attr.name.upper()): + setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper())) + + def parse_ratio(ratio_str): + parts = ratio_str.split(':') + if len(parts) != 2: + raise ValueError(f"Invalid format for --pd-ratio : '{ratio_str}'. Expected format 'a:b'.") + num_prefill, num_decode = int(parts[0].strip()), int(parts[1].strip()) + assert num_prefill > 0 and num_decode > 0, "Both parts of --pd-ratio must be non-negative." + return [num_prefill, num_decode] + self.pd_ratio = parse_ratio(self.pd_ratio) + + def init_from_instance_args(self, instance_args: 'InstanceArgs'): + self.enable_engine_pd_disagg = instance_args.enable_engine_pd_disagg + self.is_group_kind_migration_backend = instance_args.migration_backend in ['gloo', 'nccl'] + + def create_global_scheduler_config(self, is_group_kind_migration_backend) -> Tuple[GlobalSchedulerConfig]: # Create the GlobalScheduler Configuration. 
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances, - self.load_metric, self.dispatch_policy, - self.num_dispatch_instances, self.pair_migration_policy, self.migrate_out_threshold, - self.enable_defrag, self.scaling_policy, + self.scaling_load_metric, self.scale_up_threshold, self.scale_down_threshold, self.enable_pd_disagg, - self.migration_backend) + is_group_kind_migration_backend) return global_scheduler_config - def create_migration_config(self) -> MigrationConfig: - migration_config = MigrationConfig(self.request_migration_policy, - self.migration_backend, - self.migration_buffer_blocks, - self.migration_num_layers, - self.last_stage_max_blocks, - self.max_stages, - self.migration_backend_init_timeout, - self.migration_backend_transfer_type, - self.grpc_migration_backend_server_address, - self.kvtransfer_migration_backend_naming_url) - return migration_config - @classmethod def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'ManagerArgs': # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] + cfg_attrs = [attr for attr in attrs if hasattr(cfg.MANAGER, attr.upper())] # Set the attributes from the parsed arguments. # The defalut values of attributes are defined in default.py. - manager_args = cls(**{attr: getattr(cfg.MANAGER, attr.upper()) for attr in attrs}) + manager_args = cls(**{attr: getattr(cfg.MANAGER, attr.upper()) for attr in cfg_attrs}) return manager_args @classmethod @@ -219,15 +200,6 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser): cur_arg = getattr(args, action.dest) assert cur_arg in action.choices, f"{action.dest} should be one of {action.choices}, but {cur_arg} is set." - # bladellm only - assert args.migration_backend not in ['kvtransfer'] or (args.migration_backend == 'kvtransfer' \ - and args.migration_backend_transfer_type), \ - ("When using kvTransfer as migration backend, " - "do not set --migration-backend-transfer-type as empty.") - - assert not args.simulator_mode or args.profiling_result_file_path is not None, \ - "Set profiling_result_file_path args when enable simulator mode" - assert not args.enable_port_offset_store or args.enable_port_increment, \ "Set enable_port_increment when enable_port_offset_store" @@ -236,15 +208,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument('--initial-instances', type=int, help='number of instances created at initialzation') - - parser.add_argument('--load-metric', - type=str, - choices=['remaining_steps', 'usage_ratio'], - help='instance load metric') parser.add_argument('--polling-interval', type=float, help='time interval(s) to update instance info and pair migration') - + parser.add_argument('--scaling-load-metric', + type=str, + choices=['remaining_steps', 'usage_ratio'], + help='instance scaling load metric') parser.add_argument('--dispatch-policy', type=str, choices=['balanced', 'load', 'queue', 'flood', 'rr'], @@ -254,13 +224,9 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: '* "queue" dispatch request to the instance with minimum waiting request queue length.\n' '* "flood" dispatch request to the instance with maximum requests dispatched.\n' '* "rr" dispatch requests with round-robin policy.\n') - parser.add_argument('--enable-migration', action='store_true', help='enable migrate requests between instances') - parser.add_argument('--enable-defrag', - type=bool, - help='enable defragmentation through migration 
based on virtual usage')
         parser.add_argument('--pair-migration-frequency',
                             type=int,
                             help='pair migration frequency')
@@ -276,17 +242,6 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser.add_argument('--migrate-out-threshold',
                             type=float,
                             help='migrate out instance load threshold')
-        parser.add_argument('--request-migration-policy',
-                            type=str,
-                            default=None,
-                            choices=['LCR', 'SR', 'LR', 'FCW', 'FCWSR'],
-                            help='The request migration policy.\n\n'
-                            '* "LCR" migrate the running request last come.\n'
-                            '* "SR" migrate the running request shortest.\n'
-                            '* "LR" migrate the running request longest.\n'
-                            '* "FCW" migrate the waiting request first come.\n'
-                            '* "FCWSR" migrate the waiting request first come and running request shortest.\n')
-
         parser.add_argument('--enable-scaling',
                             action='store_true',
                             help='enable auto scaling')
@@ -309,7 +264,6 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser.add_argument('--scale-down-threshold',
                             type=float,
                             help='scale down threshold')
-
         parser.add_argument('--disable-log-requests-manager',
                             action='store_true',
                             help='disable logging requests in manager')
@@ -319,12 +273,152 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser.add_argument('--log-filename',
                             type=str,
                             help='log filename')
+        parser.add_argument('--enable-port-increment',
+                            action='store_true',
+                            help='enable port increment when deploying multiple servers')
+        parser.add_argument('--enable-port-offset-store',
+                            action='store_true',
+                            help='enable store port offset when deploying multiple servers')
+        parser.add_argument('--enable-pd-disagg',
+                            action='store_true',
+                            help='enable prefill decoding disaggregation')
+        parser.add_argument('--pd-ratio',
+                            type=str,
+                            help='the prefill:decode ratio used in global launch mode, e.g. 
"1:1"') + return parser + +@dataclass +class LaunchArgs: + launch_mode: LaunchMode = None + backend_type: BackendType = None + +@dataclass +class InstanceArgs: + instance_type: str = None + + simulator_mode: bool = None + profiling_result_file_path: str = None + + dispatch_load_metric: str = None + migration_load_metric: str = None + enable_defrag: bool = None + + request_migration_policy: str = None + + migration_backend: str = None + migration_buffer_blocks: int = None + migration_num_layers: int = None + migration_backend_init_timeout: float = None + migration_backend_transfer_type: str = None + grpc_migration_backend_server_address: str = None + kvtransfer_migration_backend_naming_url: str = None + migration_last_stage_max_blocks: int = None + migration_max_stages: int = None + + # init from engine args + enable_engine_pd_disagg: bool = None + + def __post_init__(self): + # Check if all fields default to None + for field_info in dataclasses.fields(self): + if field_info.default is not None: + raise ValueError(f"The default value of '{field_info.name}' should be None") + + for attr in dataclasses.fields(self): + if getattr(self, attr.name) is None: + if hasattr(_C.INSTANCE, attr.name.upper()): + setattr(self, attr.name, getattr(_C.INSTANCE, attr.name.upper())) + + def init_from_engine_args(self, engine_args, backend_type: BackendType): + if backend_type == BackendType.BLADELLM: + self.enable_engine_pd_disagg = engine_args.enable_disagg + elif backend_type == BackendType.VLLM: + self.enable_engine_pd_disagg = False + elif backend_type == BackendType.SIM_VLLM: + self.enable_engine_pd_disagg = False + else: + raise ValueError(f"Unsupported backend type: {backend_type}") + + @classmethod + def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'InstanceArgs': + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + cfg_attrs = [attr for attr in attrs if hasattr(cfg.INSTANCE, attr.upper())] + # Set the attributes from the parsed arguments. + # The defalut values of attributes are defined in default.py. + instance_args = cls(**{attr: getattr(cfg.INSTANCE, attr.upper()) for attr in cfg_attrs}) + return instance_args + + @classmethod + def check_args(cls, args: 'InstanceArgs', manager_args: ManagerArgs, + launch_model: LaunchMode, parser: argparse.ArgumentParser): + # pylint: disable=protected-access + for action in parser._optionals._actions: + if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): + assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}." + + assert not args.simulator_mode or args.profiling_result_file_path is not None, \ + "Set profiling_result_file_path args when enable simulator mode" + + # instance_type check + if manager_args.enable_pd_disagg and launch_model == LaunchMode.LOCAL: + assert args.instance_type in ['prefill', 'decode'], \ + "instance_type should be prefill or decode if enable_pd_disagg is set." 
+
+    def create_migration_config(self) -> MigrationConfig:
+        migration_config = MigrationConfig(self.request_migration_policy,
+                                           self.migration_backend,
+                                           self.migration_buffer_blocks,
+                                           self.migration_num_layers,
+                                           self.migration_last_stage_max_blocks,
+                                           self.migration_max_stages,
+                                           self.migration_backend_init_timeout,
+                                           self.migration_backend_transfer_type,
+                                           self.grpc_migration_backend_server_address,
+                                           self.kvtransfer_migration_backend_naming_url)
+        return migration_config
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        parser.add_argument('--instance-type',
+                            type=str,
+                            choices=['prefill', 'decode', 'no_constraints'],
+                            help="instance type for the engine. When prefill-decode disaggregation is disabled, \
+                                set instance_type to no_constraints. For prefill-decode disaggregation implemented \
+                                by Llumnix, set instance_type to either prefill or decode in local launch mode; \
+                                in global launch mode it does not need to be set, as the manager automatically \
+                                determines the instance type and quantity based on --pd-ratio. When prefill-decode \
+                                disaggregation is handled inside the LLM engine, do not set --enable-pd-disagg or \
+                                --instance-type: in local launch mode the instance type is assigned automatically \
+                                from engine_args, and in global launch mode it should not be set either.")
         parser.add_argument('--profiling-result-file-path',
                             type=str,
                             help='profiling result file path when using simulator')
         parser.add_argument('--simulator-mode',
                             action='store_true',
                             help='enable simulator mode')
+        parser.add_argument('--dispatch-load-metric',
+                            type=str,
+                            choices=['remaining_steps', 'usage_ratio'],
+                            help='instance dispatch load metric')
+        parser.add_argument('--migration-load-metric',
+                            type=str,
+                            choices=['remaining_steps', 'usage_ratio'],
+                            help='instance migration load metric')
+        parser.add_argument('--enable-defrag',
+                            type=bool,
+                            help='enable defragmentation through migration based on virtual usage')
+        parser.add_argument('--request-migration-policy',
+                            type=str,
+                            default=None,
+                            choices=['LCR', 'SR', 'LR', 'FCW', 'FCWSR'],
+                            help='The request migration policy.\n\n'
+                            '* "LCR" migrate the running request last come.\n'
+                            '* "SR" migrate the running request shortest.\n'
+                            '* "LR" migrate the running request longest.\n'
+                            '* "FCW" migrate the waiting request first come.\n'
+                            '* "FCWSR" migrate the waiting request first come and running request shortest.\n')
         parser.add_argument('--migration-backend',
                             type=str,
                             choices=['gloo','nccl','rayrpc','grpc','kvtransfer'],
@@ -341,7 +435,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             help='timeout(s) for initializing migration backend')
         parser.add_argument('--migration-backend-transfer-type',
                             type=str,
-                            choices=['cuda_ipc','rdma', ''],
+                            choices=['cuda_ipc','rdma'],
                             help='transfer type for migration backend grpc and kvTransfer')
         parser.add_argument('--grpc-migration-backend-server-address',
                             type=str,
@@ -349,30 +443,10 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser.add_argument('--kvtransfer-migration-backend-naming-url',
                             type=str,
                             help='url of naming server for kvtransfer migration backend')
-        parser.add_argument('--max-stages',
+        parser.add_argument('--migration-max-stages',
                             type=int,
-                            help='drop migration if the number of stages > max_stages')
-        parser.add_argument('--last-stage-max-blocks',
+                            help='drop migration if the number of stages > migration_max_stages')
+        parser.add_argument('--migration-last-stage-max-blocks',
                             type=int,
-                            help='if the number of remaining blocks < last_stage_max_blocks, do last stage migration')
-
-        parser.add_argument('--enable-pd-disagg',
-                            action='store_true',
-                            help='enable prefill decoding disaggregation')
-        parser.add_argument('--num-dispatch-instances',
-                            type=int,
-                            help='number of available instances for dispatch')
-
-        parser.add_argument('--enable-port-increment',
-                            action='store_true',
-                            help='enable port increment when desploying multiple servers')
-        parser.add_argument('--enable-port-offset-store',
-                            action='store_true',
-                            help='enable store port offset when desploying multiple servers')
-
+                            help='if the number of remaining blocks < migration_last_stage_max_blocks, do last stage migration')
         return parser
-
-@dataclass
-class LaunchArgs:
-    launch_mode: LaunchMode = None
-    backend_type: BackendType = None
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 03180cad..0b7f6c9a 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -170,7 +170,8 @@ async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]:
         instance_info.instance_id = self.instance_id
         instance_info.step_id = next(self.step_counter)
         instance_info.timestamp = time.time()
-        instance_info.profiling_data=(instance_info.inference_type.value,
+        # TODO(KuilongCui): add cli_args to determine whether to collect profiling data
+        instance_info.profiling_data=(instance_info.inference_type.value if instance_info.inference_type else "",
                                       instance_info.num_seqs,
                                       sum(instance_info.running_seq_lens),
                                       self.model_executor.last_inference_latency)
diff --git a/llumnix/backends/vllm/utils.py b/llumnix/backends/vllm/utils.py
index 9b9826b4..a44ba1e4 100644
--- a/llumnix/backends/vllm/utils.py
+++ b/llumnix/backends/vllm/utils.py
@@ -23,7 +23,7 @@
     _modify_greedy_probs_inplace, _beam_search_sample

 from llumnix.logging.logger import init_logger
-from llumnix.arg_utils import ManagerArgs
+from llumnix.arg_utils import InstanceArgs

 logger = init_logger(__name__)
@@ -41,15 +41,15 @@ def detect_unsupported_feature(engine_args: EngineArgs) -> None:
     if unsupported_feature:
         raise ValueError(f'Unsupported feature: Llumnix does not support "{unsupported_feature}" currently.')

-def check_engine_args(engine_args: AsyncEngineArgs, manager_args: ManagerArgs) -> None:
+def check_engine_args(engine_args: AsyncEngineArgs, instance_args: InstanceArgs) -> None:
     assert engine_args.engine_use_ray and engine_args.worker_use_ray, \
         ("In Llumnix, engine and worker must be ray actor.")
-    migration_config = manager_args.create_migration_config()
+    migration_config = instance_args.create_migration_config()
     engine_config = engine_args.create_engine_config()
     parallel_config = engine_config.parallel_config
     if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl':
         logger.warning("Llumnix does not support TP or PP when the migration backend is nccl, change migration backend to gloo.")
-        manager_args.migration_backend = 'gloo'
+        instance_args.migration_backend = 'gloo'
     detect_unsupported_feature(engine_args)

 def _get_dtype_size(dtype: torch.dtype) -> int:
diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py
index 5879c11f..00fe4881 100644
--- a/llumnix/backends/vllm/worker.py
+++ b/llumnix/backends/vllm/worker.py
@@ -141,9 +141,5 @@ def warmup(self) -> bool:
     def shutdown(self) -> None:
         torch.cuda.synchronize()
-        del self.model_runner
-        del 
self.cache_engine - del self.gpu_cache - del self.migration_backend torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 0527cba6..7576c257 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -11,8 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - from .config import LlumnixConfig as LC # ----------------------------------------------------------------------------- @@ -47,9 +45,7 @@ # Disable keep serve process alive _C.SERVER.DISABLE_KEEP_SERVE_PROCESS_ALIVE = False -# ----------------------------------------------------------------------------- -# RAY CONFIGURATION -# ----------------------------------------------------------------------------- +# ----------------------------- RAY CONFIGURATION ----------------------------- # If True, launch Ray cluster in API server _C.SERVER.LAUNCH_RAY_CLUSTER = False # Port number for the Ray cluster @@ -71,63 +67,34 @@ _C.MANAGER.LOG_INSTANCE_INFO = False # Log filename _C.MANAGER.LOG_FILENAME = "server.log" -# Enable simulator mode -_C.MANAGER.SIMULATOR_MODE = False -# Profiling result file path when using simulator -_C.MANAGER.PROFILING_RESULT_FILE_PATH = None # Enable port increment when deploying multiple servers _C.MANAGER.ENABLE_PORT_INCREMENT = False # Enable store port offset when deploying multiple servers _C.MANAGER.ENABLE_PORT_OFFSET_STORE = False +# Enable prefill decoding disaggregation +_C.MANAGER.ENABLE_PD_DISAGG = False +# The p:d ratio used in gloabl launch model +_C.MANAGER.PD_RATIO = "1:1" -# ----------------------------------------------------------------------------- -# DISPATCH CONFIGURATION -# ----------------------------------------------------------------------------- -# Instance load metric -_C.MANAGER.LOAD_METRIC = 'remaining_steps' +# -------------------------- DISPATCH CONFIGURATION --------------------------- # Request dispatch policy _C.MANAGER.DISPATCH_POLICY = 'load' -# ----------------------------------------------------------------------------- -# MIGRATION CONFIGURATION -# ----------------------------------------------------------------------------- +# -------------------------- MIGRATION CONFIGURATION -------------------------- # Enable migrate requests between instances _C.MANAGER.ENABLE_MIGRATION = False -# Enable defragmentation through migration based on virtual usage -_C.MANAGER.ENABLE_DEFRAG = False # Pair migration frequency _C.MANAGER.PAIR_MIGRATION_FREQUENCY = 1 # Pair migration policy _C.MANAGER.PAIR_MIGRATION_POLICY = 'defrag_constrained' # Migrate out instance load threshold -_C.MANAGER.MIGRATE_OUT_THRESHOLD = 3.0 -# Request migration policy -_C.MANAGER.REQUEST_MIGRATION_POLICY = 'SR' -# Drop migration if the number of stages > max_stages -_C.MANAGER.MAX_STAGES = 3 -# If the number of remain blocks < last_stage_max_blocks, do last stage migration -_C.MANAGER.LAST_STAGE_MAX_BLOCKS = 16 - -# Communication backend of migration -_C.MANAGER.MIGRATION_BACKEND = "gloo" -# Number of cache blocks in migration -_C.MANAGER.MIGRATION_BUFFER_BLOCKS = 512 -# Number of kv-cache layers to transfer in each round during migration -_C.MANAGER.MIGRATION_NUM_LAYERS = 1 -# Timeout(s) for initializing migration backend -_C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 -# Transfer type for migration backend kvTransfer -_C.MANAGER.MIGRATION_BACKEND_TRANSFER_TYPE = "rdma" -# Address of grpc server for migration backend 
-_C.MANAGER.GRPC_MIGRATION_BACKEND_SERVER_ADDRESS = "127.0.0.1:50051" -# URL of naming server for kvtransfer migration backend -_C.MANAGER.KVTRANSFER_MIGRATION_BACKEND_NAMING_URL = "file:/tmp/llumnix/naming/" +_C.MANAGER.MIGRATE_OUT_THRESHOLD = -3.0 -# ----------------------------------------------------------------------------- -# SCALING CONFIGURATION -# ----------------------------------------------------------------------------- +# --------------------------- SCALING CONFIGURATION --------------------------- # Enable auto scaling _C.MANAGER.ENABLE_SCALING = False +# Instance scaling load metric +_C.MANAGER.SCALING_LOAD_METRIC = 'remaining_steps' # Minimum number of instances _C.MANAGER.MIN_INSTANCES = 1 # Maximum number of instances @@ -137,14 +104,47 @@ # Scaling policy _C.MANAGER.SCALING_POLICY = 'avg_load' # Scale up threshold -_C.MANAGER.SCALE_UP_THRESHOLD = 10 +_C.MANAGER.SCALE_UP_THRESHOLD = -10 # Scale down threshold -_C.MANAGER.SCALE_DOWN_THRESHOLD = 60 +_C.MANAGER.SCALE_DOWN_THRESHOLD = -60 # ----------------------------------------------------------------------------- -# PREFILL DECODING DISAGGREGATION CONFIGURATION +# INSTANCE CONFIGURATION # ----------------------------------------------------------------------------- -# Enable prefill decoding disaggregation -_C.MANAGER.ENABLE_PD_DISAGG = False -# Number of available instances for dispatch. math.inf indicates that all instances can be used for dispatching -_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf +_C.INSTANCE = LC() +# Engine types: prefill, decode, no_constraints +_C.INSTANCE.INSTANCE_TYPE = "no_constraints" +# Enable simulator mode +_C.INSTANCE.SIMULATOR_MODE = False +# Profiling result file path when using simulator +_C.INSTANCE.PROFILING_RESULT_FILE_PATH = None + +# ------------------------- LOAD METRICS CONFIGURATION ------------------------ +# Instance dispatch load metric +_C.INSTANCE.DISPATCH_LOAD_METRIC = 'remaining_steps' +# Instance migration load metric +_C.INSTANCE.MIGRATION_LOAD_METRIC = 'remaining_steps' + +# -------------------------- MIGRATION CONFIGURATION -------------------------- +# Enable defragmentation through migration based on virtual usage +_C.INSTANCE.ENABLE_DEFRAG = False +# Request migration policy +_C.INSTANCE.REQUEST_MIGRATION_POLICY = 'SR' +# Drop migration if the number of stages > migration_max_stages +_C.INSTANCE.MIGRATION_MAX_STAGES = 3 +# If the number of remain blocks < migration_last_stage_max_blocks, do last stage migration +_C.INSTANCE.MIGRATION_LAST_STAGE_MAX_BLOCKS = 16 +# Communication backend of migration +_C.INSTANCE.MIGRATION_BACKEND = "gloo" +# Number of cache blocks in migration +_C.INSTANCE.MIGRATION_BUFFER_BLOCKS = 512 +# Number of kv-cache layers to transfer in each round during migration +_C.INSTANCE.MIGRATION_NUM_LAYERS = 1 +# Timeout(s) for initializing migration backend +_C.INSTANCE.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 +# Transfer type for migration backend kvTransfer +_C.INSTANCE.MIGRATION_BACKEND_TRANSFER_TYPE = "rdma" +# Address of grpc server for migration backend +_C.INSTANCE.GRPC_MIGRATION_BACKEND_SERVER_ADDRESS = "127.0.0.1:50051" +# URL of naming server for kvtransfer migration backend +_C.INSTANCE.KVTRANSFER_MIGRATION_BACKEND_NAMING_URL = "file:/tmp/llumnix/naming/" diff --git a/llumnix/entrypoints/bladellm/api_server.py b/llumnix/entrypoints/bladellm/api_server.py index 537798f5..3ac445a4 100644 --- a/llumnix/entrypoints/bladellm/api_server.py +++ b/llumnix/entrypoints/bladellm/api_server.py @@ -17,7 +17,7 @@ from llumnix.config import 
get_llumnix_config from llumnix.backends.backend_interface import BackendType from llumnix.arg_utils import (EntrypointsArgs, ManagerArgs, LlumnixArgumentParser, - LaunchArgs) + LaunchArgs, InstanceArgs) from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix from llumnix.entrypoints.bladellm.client import LlumnixClientBladeLLM from llumnix.entrypoints.bladellm.utils import get_args @@ -29,19 +29,19 @@ def setup_llumnix_api_server(bladellm_args: ServingArgs, loop: asyncio.AbstractE llumnix_parser = LlumnixArgumentParser() llumnix_parser = EntrypointsArgs.add_cli_args(llumnix_parser) llumnix_parser = ManagerArgs.add_cli_args(llumnix_parser) + llumnix_parser = InstanceArgs.add_cli_args(llumnix_parser) llumnix_config = get_llumnix_config(bladellm_args.llumnix_config) - entrypoints_args, manager_args, engine_args = get_args(llumnix_config, llumnix_parser, bladellm_args) + entrypoints_args, manager_args, instance_args, engine_args = \ + get_args(llumnix_config, llumnix_parser, bladellm_args) - assert not manager_args.simulator_mode, "Only support the simulator mode for vLLM." launch_args = LaunchArgs(launch_mode=LaunchMode.LOCAL, backend_type=BackendType.BLADELLM) - setup_ray_cluster(entrypoints_args) llumnix_client = None # if gpu is not available, it means that this node is head pod x any llumnix components if is_gpu_available(): llumnix_context: EntrypointsContext = \ - setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + setup_llumnix(entrypoints_args, manager_args, instance_args, engine_args, launch_args) llumnix_client = LlumnixClientBladeLLM(bladellm_args, llumnix_context, loop) return llumnix_client diff --git a/llumnix/entrypoints/bladellm/utils.py b/llumnix/entrypoints/bladellm/utils.py index 8283e1fe..7d6675db 100644 --- a/llumnix/entrypoints/bladellm/utils.py +++ b/llumnix/entrypoints/bladellm/utils.py @@ -19,7 +19,10 @@ logger = init_logger(__name__) -def detect_unsupported_feature(engine_args: ServingArgs) -> None: +from llumnix.backends.backend_interface import BackendType +from llumnix.arg_utils import EntrypointsArgs, ManagerArgs, InstanceArgs, LaunchMode + +def detect_unsupported_engine_feature(engine_args: ServingArgs) -> None: unsupported_feature = None if engine_args.enable_lora: unsupported_feature = "multi-lora serving" @@ -29,28 +32,33 @@ def detect_unsupported_feature(engine_args: ServingArgs) -> None: unsupported_feature = "speculative decoding" elif engine_args.enable_remote_worker: unsupported_feature = "enable_remote_worker" + elif engine_args.enable_hybrid_dp: + unsupported_feature = "hybrid data parallel" if unsupported_feature: raise ValueError(f'Llumnix does not support "{unsupported_feature}" for bladeLLM currently.') -def check_engine_args(engine_args: ServingArgs, manager_args: ManagerArgs) -> None: - migration_config = manager_args.create_migration_config() - if (engine_args.tensor_parallel_size > 1 or engine_args.tensor_parallel_size > 1) and \ - migration_config.migration_backend == 'nccl': - logger.warning("Llumnix does not support TP or PP when the migration backend is nccl, \ - change migration backend to gloo.") - manager_args.migration_backend = 'gloo' - detect_unsupported_feature(engine_args) - -def get_args(llumnix_cfg, llumnix_parser, engine_args): +def get_args(llumnix_cfg, llumnix_parser, engine_args: ServingArgs): + instance_args = InstanceArgs.from_llumnix_config(llumnix_cfg) + instance_args.init_from_engine_args(engine_args, BackendType.BLADELLM) + manager_args = ManagerArgs.from_llumnix_config(llumnix_cfg) 
+    manager_args.init_from_instance_args(instance_args)
     entrypoints_args = EntrypointsArgs.from_llumnix_config(llumnix_cfg)
+
     EntrypointsArgs.check_args(entrypoints_args, llumnix_parser)
-    manager_args = ManagerArgs.from_llumnix_config(llumnix_cfg)
+    InstanceArgs.check_args(instance_args, manager_args, LaunchMode.LOCAL, llumnix_parser)
     ManagerArgs.check_args(manager_args, llumnix_parser)
-    check_engine_args(engine_args, manager_args)
-    logger.info("entrypoints_args: {}".format(entrypoints_args))
-    logger.info("manager_args: {}".format(manager_args))
-    logger.info("engine_args: {}".format(engine_args))
+    assert not instance_args.simulator_mode, "Simulator mode is only supported for vLLM."
+
+    assert not (engine_args.enable_disagg and manager_args.enable_pd_disagg), \
+        "Cannot enable both pd-disaggregation inside the LLM engine and pd-disaggregation from Llumnix."
+
+    detect_unsupported_engine_feature(engine_args)
+
+    logger.info("entrypoints_args: {}".format(entrypoints_args))
+    logger.info("manager_args: {}".format(manager_args))
+    logger.info("instance_args: {}".format(instance_args))
+    logger.info("engine_args: {}".format(engine_args))

-    return entrypoints_args, manager_args, engine_args
+    return entrypoints_args, manager_args, instance_args, engine_args
diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py
index 6fc48c3a..ea504275 100644
--- a/llumnix/entrypoints/setup.py
+++ b/llumnix/entrypoints/setup.py
@@ -22,7 +22,7 @@
 from llumnix.llumlet.llumlet import Llumlet
 from llumnix.logging.logger import init_logger
 from llumnix.utils import random_uuid, get_manager_name
-from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs
+from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs, InstanceArgs
 from llumnix.queue.queue_type import QueueType
 from llumnix.server_info import ServerInfo
 from llumnix.queue.utils import init_request_output_queue_server
@@ -92,33 +92,41 @@ def setup_ray_cluster(entrypoints_args) -> None:
                  log_to_driver=not entrypoints_args.disable_log_to_driver)

 def init_manager(manager_args: ManagerArgs,
+                 instance_args: InstanceArgs = None,
                  entrypoints_args: EntrypointsArgs = None,
                  engine_args = None,
                  launch_args: LaunchArgs = None,
                  ) -> Manager:
     # Only one instance create the manager actor, the other instances get the existing manager actor through ray.
try: - manager = Manager.from_args(manager_args=manager_args, - entrypoints_args=entrypoints_args, - engine_args=engine_args, - launch_args=launch_args) + manager = Manager.from_args( + entrypoints_args=entrypoints_args, + manager_args=manager_args, + instance_args=instance_args, + engine_args=engine_args, + launch_args=launch_args) logger.info("Init Manager on current node.") except ValueError: manager = ray.get_actor(get_manager_name(), namespace='llumnix') logger.info("Get existing Manager.") return manager -def init_llumnix_components(manager_args: ManagerArgs, +def init_llumnix_components(entrypoints_args: EntrypointsArgs, + manager_args: ManagerArgs, + instance_args: InstanceArgs, engine_args, - request_output_queue_type: QueueType, - request_output_queue_port: str, - backend_type: BackendType) -> Tuple[Manager, List[str], List[Llumlet], QueueServerBase]: + launch_args: LaunchArgs, + ) -> Tuple[Manager, List[str], List[Llumlet], QueueServerBase]: manager = init_manager(manager_args) + backend_type: BackendType = launch_args.backend_type + request_output_queue_type: QueueType = QueueType(entrypoints_args.request_output_queue_type) instance_ids, instances = retry_manager_method_sync( - manager.init_instances.remote, 'init_instances', request_output_queue_type, backend_type, engine_args) + manager.init_instances.remote, 'init_instances', request_output_queue_type, + backend_type, instance_args, engine_args) ip = get_ip_address() + request_output_queue_port: str = entrypoints_args.request_output_queue_port request_output_queue = init_request_output_queue_server(ip, request_output_queue_port, request_output_queue_type) return manager, instance_ids, instances, request_output_queue @@ -138,32 +146,25 @@ def setup_entrypoints_context(entrypoints_args, manager, instance_ids, instances log_requests = not entrypoints_args.disable_log_requests_server log_request_timestamps = entrypoints_args.log_request_timestamps - logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps)) - entrypoints_context = EntrypointsContext(manager, instances_dict, request_output_queue, server_info, log_requests, log_request_timestamps) - return entrypoints_context -def _setup_llumnix_local(manager_args, entrypoints_args, engine_args, launch_args) -> EntrypointsContext: +def _setup_llumnix_local(entrypoints_args, manager_args, instance_args, engine_args, launch_args) -> EntrypointsContext: manager, instance_ids, instances, request_output_queue = \ - init_llumnix_components(manager_args, - engine_args, - QueueType(entrypoints_args.request_output_queue_type), - entrypoints_args.request_output_queue_port, - launch_args.backend_type) + init_llumnix_components(entrypoints_args, manager_args, instance_args, engine_args, launch_args) return setup_entrypoints_context(entrypoints_args, manager, instance_ids, instances, request_output_queue) -def _setup_llumnix_global(manager_args, entrypoints_args, engine_args, launch_args) -> None: - _ = init_manager(manager_args, entrypoints_args, engine_args, launch_args) +def _setup_llumnix_global(entrypoints_args, manager_args, instance_args, engine_args, launch_args) -> None: + _ = init_manager(manager_args, instance_args, entrypoints_args, engine_args, launch_args) -def setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) -> Optional[EntrypointsContext]: +def setup_llumnix(entrypoints_args, manager_args, instance_args, engine_args, launch_args) -> Optional[EntrypointsContext]: if launch_args.launch_mode == LaunchMode.LOCAL: 
- return _setup_llumnix_local(manager_args, entrypoints_args, engine_args, launch_args) + return _setup_llumnix_local(entrypoints_args, manager_args, instance_args, engine_args, launch_args) - return _setup_llumnix_global(manager_args, entrypoints_args, engine_args, launch_args) + return _setup_llumnix_global(entrypoints_args, manager_args, instance_args, engine_args, launch_args) diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 0e89a8ee..5abdeb69 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -179,9 +179,9 @@ async def is_ready() -> bool: cli_args = add_cli_args(parser) cfg = get_llumnix_config(cli_args.config_file, cli_args) - entrypoints_args, manager_args, engine_args = get_args(cfg, parser, cli_args) - backend_type = BackendType.VLLM if not manager_args.simulator_mode else BackendType.SIM_VLLM + entrypoints_args, manager_args, instance_args, engine_args = get_args(cfg, LaunchMode.LOCAL, parser, cli_args) + backend_type = BackendType.VLLM if not instance_args.simulator_mode else BackendType.SIM_VLLM launch_args = LaunchArgs(launch_mode=LaunchMode.LOCAL, backend_type=backend_type) # Launch or connect to the ray cluster for multi-node serving. @@ -189,7 +189,7 @@ async def is_ready() -> bool: # if gpu is not available, it means that this node is head pod without any llumnix components if is_gpu_available(): - entrypoints_context = setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + entrypoints_context = setup_llumnix(entrypoints_args, manager_args, instance_args, engine_args, launch_args) llumnix_client = LlumnixClientVLLM(entrypoints_context) # Start the api server after all the components of llumnix are ready. diff --git a/llumnix/entrypoints/vllm/arg_utils.py b/llumnix/entrypoints/vllm/arg_utils.py index dbb2dad0..ac35d308 100644 --- a/llumnix/entrypoints/vllm/arg_utils.py +++ b/llumnix/entrypoints/vllm/arg_utils.py @@ -1,31 +1,40 @@ from vllm.engine.arg_utils import AsyncEngineArgs -from llumnix.backends.vllm.utils import check_engine_args from llumnix.arg_utils import EntrypointsArgs, ManagerArgs from llumnix.logging.logger import init_logger +from llumnix.backends.backend_interface import BackendType +from llumnix.backends.vllm.utils import check_engine_args +from llumnix.arg_utils import EntrypointsArgs, ManagerArgs, InstanceArgs, LlumnixArgumentParser logger = init_logger(__name__) -def add_cli_args(parser): +def add_cli_args(parser: LlumnixArgumentParser): parser.set_namespace("llumnix") parser = EntrypointsArgs.add_cli_args(parser) parser = ManagerArgs.add_cli_args(parser) + parser = InstanceArgs.add_cli_args(parser) parser.set_namespace("vllm") parser = AsyncEngineArgs.add_cli_args(parser) cli_args = parser.parse_args() return cli_args -def get_args(cfg, parser, cli_args): +def get_args(cfg, launch_model, parser, cli_args): + engine_args = AsyncEngineArgs.from_cli_args(cli_args) + instance_args: InstanceArgs = InstanceArgs.from_llumnix_config(cfg) + instance_args.init_from_engine_args(engine_args, BackendType.VLLM) + manager_args = ManagerArgs.from_llumnix_config(cfg) + manager_args.init_from_instance_args(instance_args) entrypoints_args = EntrypointsArgs.from_llumnix_config(cfg) + EntrypointsArgs.check_args(entrypoints_args, parser) - manager_args = ManagerArgs.from_llumnix_config(cfg) ManagerArgs.check_args(manager_args, parser) - engine_args = AsyncEngineArgs.from_cli_args(cli_args) - check_engine_args(engine_args, manager_args) + 
InstanceArgs.check_args(instance_args, manager_args, launch_model, parser) + check_engine_args(engine_args, instance_args) logger.info("entrypoints_args: {}".format(entrypoints_args)) logger.info("manager_args: {}".format(manager_args)) + logger.info("instance_args: {}".format(instance_args)) logger.info("engine_args: {}".format(engine_args)) - return entrypoints_args, manager_args, engine_args + return entrypoints_args, manager_args, instance_args, engine_args diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py index 31a7bf93..e1f88ca5 100644 --- a/llumnix/entrypoints/vllm/client.py +++ b/llumnix/entrypoints/vllm/client.py @@ -1,6 +1,7 @@ import copy import time import asyncio +from typing import Dict import ray from vllm.engine.async_llm_engine import AsyncStream diff --git a/llumnix/entrypoints/vllm/serve.py b/llumnix/entrypoints/vllm/serve.py index a73f1ce9..25ba4cd9 100644 --- a/llumnix/entrypoints/vllm/serve.py +++ b/llumnix/entrypoints/vllm/serve.py @@ -23,9 +23,9 @@ cli_args = add_cli_args(parser) cfg = get_llumnix_config(cli_args.config_file, cli_args) - entrypoints_args, manager_args, engine_args = get_args(cfg, parser, cli_args) + entrypoints_args, manager_args, instance_args, engine_args = get_args(cfg, LaunchMode.GLOBAL, parser, cli_args) - backend_type = BackendType.VLLM if not manager_args.simulator_mode else BackendType.SIM_VLLM + backend_type = BackendType.VLLM if not instance_args.simulator_mode else BackendType.SIM_VLLM launch_args = LaunchArgs(launch_mode=LaunchMode.GLOBAL, backend_type=BackendType.VLLM) # Assume that there is an existing ray cluster when using centralized deployment. @@ -35,7 +35,7 @@ request_output_queue = RayQueue(actor_options={"namespace": "llumnix", "name": "magic_ray_queue"}) - setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + setup_llumnix(entrypoints_args, manager_args, instance_args, engine_args, launch_args) # keep the process alive to get the terminal output. 
if not entrypoints_args.disable_keep_serve_process_alive: diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index bfa5b79f..665a82b1 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -16,107 +16,70 @@ import random from llumnix.logging.logger import init_logger -from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo +from llumnix.instance_info import InstanceInfo +from llumnix.instance_info import InstanceInfo +from llumnix.arg_utils import InstanceArgs +from llumnix.instance_info import InstanceType logger = init_logger(__name__) +DISPATCH_LOG_FREQUENCY = 100 class DispatchScheduler: - def __init__(self, - dispatch_policy: str, - instance_load_calculator: InstanceLoadCalculator, - num_dispatch_instances: int) -> None: + def __init__(self, dispatch_policy: str,) -> None: self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy) - self.instance_load_calculator = instance_load_calculator - self.num_instances = 0 - self.instance_id_set: Set[str] = set() self.available_dispatch_instance_set: Set[str] = set() - self.num_dispatch_instances = num_dispatch_instances - # instance info args self.instance_info: Dict[str, InstanceInfo] = {} - self.sorted_instance_infos: List[InstanceInfo] = None # statistics - self.num_requests = 0 + self.total_num_requests = 0 self.instance_num_requests: Dict[str, int] = {} def dispatch(self) -> str: - self.num_requests += 1 - if isinstance(self.dispatch_policy, (Load, Queue)): - self._sort_instance_infos(descending=False) + self.total_num_requests += 1 dispatch_instance_id = self.dispatch_policy.dispatch(self.instance_num_requests, - self.sorted_instance_infos) + self.instance_info.values()) self.instance_num_requests[dispatch_instance_id] += 1 - if self.num_requests % 100 == 0: - logger.info("num_requests: {}".format(self.num_requests)) + if self.total_num_requests % DISPATCH_LOG_FREQUENCY == 0: + logger.info("dispatch scheduler total_dispatched_requests: {}".format(self.total_num_requests)) for instance_id, num_requests in self.instance_num_requests.items(): logger.info("instance {} num_dispatched_requests: {}".format(instance_id, num_requests)) return dispatch_instance_id - def update_instance_infos(self, - instance_info: Dict[str, InstanceInfo]) -> None: - self.instance_info = instance_info + def update_instance_infos(self, instance_infos: Dict[str, InstanceInfo]) -> None: + for instance_id, instance_info in instance_infos.items(): + if instance_id not in self.available_dispatch_instance_set: + continue + self.instance_info[instance_id] = instance_info - def add_instance(self, instance_id: str) -> None: - self.instance_id_set.add(instance_id) - self.num_instances = len(self.instance_id_set) - - # TODO(KuilongCui): a hacky method is being used to avoid the only-decode type engine dispatched - if "decode" not in instance_id: - if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and - len(self.available_dispatch_instance_set) < self.num_dispatch_instances): - self.available_dispatch_instance_set.add(instance_id) - self.instance_num_requests[instance_id] = 0 + def add_instance(self, instance_id: str, instance_args: InstanceArgs) -> None: + if instance_args.instance_type in [InstanceType.NO_CONSTRAINTS, InstanceType.PREFILL]: + self.available_dispatch_instance_set.add(instance_id) + self.instance_num_requests[instance_id] = 0 def remove_instance(self, instance_id: str) -> None: - 
self.instance_id_set.remove(instance_id) - self.num_instances = len(self.instance_id_set) - if instance_id in self.instance_num_requests: - del self.instance_num_requests[instance_id] - else: - logger.warning("instance {} not in instance_num_requests".format(instance_id)) if instance_id in self.available_dispatch_instance_set: self.available_dispatch_instance_set.remove(instance_id) - # TODO(KuilongCui): Check it when there is no decode instance. - if self.num_instances >= self.num_dispatch_instances: - free_instance_id = next(iter(self.instance_id_set - self.available_dispatch_instance_set)) - self.available_dispatch_instance_set.add(free_instance_id) - - def _sort_instance_infos(self, - descending: bool = True) -> None: - instance_infos: List[InstanceInfo] = list(self.instance_info.values()) - available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set] - if isinstance(self.dispatch_policy, Queue): - key_attr = 'num_waiting_requests' - else: - key_attr = 'instance_load_dispatch_scale' - self.sorted_instance_infos = sorted( - available_instance_infos, - key=lambda instance_info: getattr(instance_info, key_attr), - reverse=descending - ) + self.instance_num_requests.pop(instance_id, None) class DispatchPolicy(ABC): - def __init__(self): - self.instance_ptr = 0 - @abstractmethod def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> int: + available_instance_infos: List[InstanceInfo]) -> str: pass # Dispatch all requests to a single instance, used only for testing class Flood(DispatchPolicy): def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> str: + available_instance_infos: List[InstanceInfo]) -> str: instance_id = max(instance_num_requests, key=instance_num_requests.get) return instance_id class Balanced(DispatchPolicy): def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> str: + available_instance_infos: List[InstanceInfo]) -> str: # dispatch request according to the number of requests dispatched to instance by manager instance_id = min(instance_num_requests, key=instance_num_requests.get) return instance_id @@ -124,35 +87,42 @@ def dispatch(self, class Load(DispatchPolicy): def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> str: + available_instance_infos: List[InstanceInfo]) -> str: + sorted_instance_infos = sorted( + available_instance_infos, + key=lambda instance_info: getattr(instance_info, 'dispatch_load_metric'), + ) instance_id = sorted_instance_infos[0].instance_id - logger.info("dispatch to {}, load: {}".format(instance_id, sorted_instance_infos[0].instance_load_dispatch_scale)) + logger.debug("dispatch to {}, load: {}".format(instance_id, sorted_instance_infos[0].dispatch_load_metric)) return instance_id class Queue(DispatchPolicy): def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> str: + available_instance_infos: List[InstanceInfo]) -> str: + sorted_instance_infos = sorted( + available_instance_infos, + key=lambda instance_info: getattr(instance_info, 'num_waiting_requests'), + ) min_queue_size = sorted_instance_infos[0].num_waiting_requests instance_id_list = [] for instance_info in sorted_instance_infos: if instance_info.num_waiting_requests == min_queue_size: instance_id_list.append(instance_info.instance_id) instance_id = 
random.choice(instance_id_list) - logger.info("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_requests)) + logger.debug("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_requests)) return instance_id class RoundRobin(DispatchPolicy): - prev_instance_idx: int = -1 + next_instance_idx: int = 0 def dispatch(self, instance_num_requests: Dict[str, int], - sorted_instance_infos: List[InstanceInfo]) -> str: + available_instance_infos: List[InstanceInfo]) -> str: all_instance_ids = sorted(instance_num_requests.keys()) - cur_instance_idx = (self.prev_instance_idx + 1) % len(all_instance_ids) - - target_instance_id = all_instance_ids[cur_instance_idx] - self.prev_instance_idx = cur_instance_idx + assert len(all_instance_ids) > 0 + target_instance_id = all_instance_ids[self.next_instance_idx % len(all_instance_ids)] + self.next_instance_idx += 1 return target_instance_id class DispatchPolicyFactory: diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index 2d162452..548e7fe4 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -16,60 +16,45 @@ from llumnix.logging.logger import init_logger from llumnix.internal_config import GlobalSchedulerConfig -from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo +from llumnix.instance_info import InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler from llumnix.global_scheduler.migration_scheduler import MigrationScheduler from llumnix.global_scheduler.migration_policy import PairMigrationConstraints from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler +from llumnix.arg_utils import InstanceArgs logger = init_logger(__name__) class GlobalScheduler: - def __init__(self, - global_scheduler_config: GlobalSchedulerConfig) -> None: + def __init__(self, global_scheduler_config: GlobalSchedulerConfig) -> None: self.global_scheduler_config = global_scheduler_config - # instance load and instance info args - self.load_metric = global_scheduler_config.load_metric - self.enable_defrag = global_scheduler_config.enable_defrag - self.enable_pd_disagg = global_scheduler_config.enable_pd_disagg - self.instance_load_calculator = InstanceLoadCalculator(load_metric=self.load_metric, - enable_defrag=self.enable_defrag) + self.num_instances = 0 + self.instance_id_set: Set[str] = set() + self.instance_info: Dict[str, InstanceInfo] = {} + # dispatch args - self.dispatch_policy = global_scheduler_config.dispatch_policy - self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy, - self.instance_load_calculator, - global_scheduler_config.num_dispatch_instances) + self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy) # migrate args self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy, global_scheduler_config.migrate_out_load_threshold, - self.instance_load_calculator, - global_scheduler_config.migration_backend) + global_scheduler_config.is_group_kind_migration_backend) # auto-scaling args self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold, global_scheduler_config.scale_down_threshold, global_scheduler_config.scaling_policy, - self.instance_load_calculator, - self.enable_pd_disagg, - global_scheduler_config.num_dispatch_instances) - - self.num_instances = 0 - self.instance_id_set: Set[str] = set() - 
self.instance_info: Dict[str, InstanceInfo] = {} + global_scheduler_config.scaling_load_metric, + global_scheduler_config.enable_pd_disagg,) def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None: for instance_info in instance_infos: if instance_info.instance_id in self.instance_id_set: - # Llumnix have different instance load compuatation methods for dispatch/migrate/scale. - instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch') - instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate') - instance_info.instance_type = self.scaling_scheduler.get_instance_type_info(instance_info.instance_id) self.instance_info[instance_info.instance_id] = instance_info def dispatch(self) -> str: self.dispatch_scheduler.update_instance_infos(self.instance_info) instance_id = self.dispatch_scheduler.dispatch() - request_expected_steps = 1 if self.enable_pd_disagg else math.inf + request_expected_steps = 1 if self.global_scheduler_config.enable_pd_disagg else math.inf return instance_id, request_expected_steps def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: @@ -82,17 +67,17 @@ def check_scale(self) -> Tuple[str, str]: scale_up_num, scale_down_num = self.scaling_scheduler.check_scale() return scale_up_num, scale_down_num - def scale_up(self, instance_id: Union[str, Iterable[str]]) -> int: + def scale_up(self, instance_id: Union[str, Iterable[str]], instance_args: List[InstanceArgs]) -> int: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) - for ins_id in instance_ids: + for ins_id, ins_args in zip(instance_ids, instance_args): if ins_id not in self.instance_id_set: logger.info("Scale up instance: {}.".format(ins_id)) new_intance_info = self._get_empty_instance_info() new_intance_info.instance_id = ins_id self.instance_info[ins_id] = new_intance_info - self._add_instance(ins_id) + self._add_instance(ins_id, ins_args) logger.info("num_instances: {}, instances: {}".format(self.num_instances, self.instance_id_set)) return self.num_instances @@ -111,12 +96,12 @@ def scale_down(self, instance_id: Union[str, Iterable[str]]) -> int: logger.info("num_instances: {}, instances: {}".format(self.num_instances, self.instance_id_set)) return self.num_instances - def _add_instance(self, instance_id: str) -> None: + def _add_instance(self, instance_id: str, instance_args: InstanceArgs) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) for scheduler in (self.dispatch_scheduler, self.migration_scheduler, self.scaling_scheduler): scheduler.update_instance_infos(self.instance_info) - scheduler.add_instance(instance_id) + scheduler.add_instance(instance_id, instance_args) def _remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py index 2d1d049a..b2bfd943 100644 --- a/llumnix/global_scheduler/migration_filter.py +++ b/llumnix/global_scheduler/migration_filter.py @@ -16,7 +16,7 @@ from llumnix.logging.logger import init_logger from llumnix.instance_info import InstanceInfo -from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.instance_info import InstanceType from llumnix.global_scheduler.migration_policy import PairMigrationConstraints logger = init_logger(__name__) @@ 
-80,12 +80,12 @@ class LoadConstrainedFilter(MigrationFilterPolicy): def filter_src_condition(self, filter_config: MigrationFilterConfig, pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: return lambda instance_info: instance_info.num_killed_requests > 0 \ - or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold + or instance_info.migration_load_metric > filter_config.migrate_out_load_threshold def filter_dst_condition(self, filter_config: MigrationFilterConfig, pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: return lambda instance_info: instance_info.num_killed_requests == 0 \ - and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold + and instance_info.migration_load_metric < filter_config.migrate_out_load_threshold class PddFilter(MigrationFilterPolicy): INSTANCE_FILTER_RULES = { @@ -102,7 +102,7 @@ def filter_src_condition(self, filter_config: MigrationFilterConfig, inner_policy = MigrationFilterPolicyFactory.get_policy('load') policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type) else: - policy_filter = lambda instance_info: True + policy_filter = lambda _: True return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info) diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py index 5e9eae4c..6ade3fcb 100644 --- a/llumnix/global_scheduler/migration_policy.py +++ b/llumnix/global_scheduler/migration_policy.py @@ -14,11 +14,11 @@ from typing import List, Tuple from abc import ABC, abstractmethod from enum import Enum -import copy import numpy as np from llumnix.logging.logger import init_logger -from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator +from llumnix.instance_info import InstanceInfo +from llumnix.instance_info import InstanceInfo logger = init_logger(__name__) @@ -31,11 +31,8 @@ class PairMigrationConstraints(str, Enum): PREFILL_2_DECODING = "PREFILL_2_DECODING" class PairMigrationPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator) -> None: + def __init__(self, migrate_out_load_threshold: float) -> None: self.migrate_out_load_threshold = migrate_out_load_threshold - self.instance_load_calculator = instance_load_calculator @abstractmethod def pair_migration(self, @@ -45,7 +42,7 @@ def pair_migration(self, raise NotImplementedError def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bool = True) -> None: - key_attr = 'instance_load_migrate' + key_attr = 'migration_load_metric' sorted_instance_infos = sorted( instance_infos, key=lambda instance_info: getattr(instance_info, key_attr), @@ -62,33 +59,18 @@ def pair_migration(self, sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate - - left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) - right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) - + load_diff_before_mig = sorted_src_instance_infos[i].migration_load_metric - sorted_dst_instance_infos[i].migration_load_metric + 
left_load_after_mig = sorted_src_instance_infos[i].migration_load_metric_after_migrate_out + right_load_after_mig = sorted_dst_instance_infos[i].migration_load_metric_after_migrate_in # Add some constrains to reduce unnecessary migrations if right_load_after_mig > self.migrate_out_load_threshold: continue load_diff_after_mig = left_load_after_mig - right_load_after_mig - if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): + if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].migration_load_metric == -np.inf): migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id)) return migrate_instance_pairs - def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: - instance_info_after_migrate = copy.deepcopy(instance_info) - num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request - - if is_migrate_in: - instance_info_after_migrate.num_running_requests += 1 - instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request - else: - instance_info_after_migrate.num_running_requests -= 1 - instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request - - return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') - class DefragConstrained(PairMigrationPolicy): def pair_migration(self, src_instance_infos: List[InstanceInfo], diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 86f2d37f..69e1b4cc 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -11,66 +11,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Set +from typing import Dict, List, Set, Tuple from llumnix.logging.logger import init_logger -from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator +from llumnix.instance_info import InstanceInfo from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig, CustomFilter from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory +from llumnix.arg_utils import InstanceArgs logger = init_logger(__name__) class MigrationScheduler: - def __init__(self, - pair_migration_policy: str, + def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator, - migration_backend: str,) -> None: - self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) - self.migration_filter = MigrationInstanceFilter(self.filter_config) + is_group_kind_migration_backend: bool) -> None: + filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) + self.migration_filter = MigrationInstanceFilter(filter_config) + self._register_migration_backend_init_filter(is_group_kind_migration_backend) - # Some migration backends require init_process_group before passing the KV cache. Here, we add a filter - # to prevent instances of migration backends that have not been initialized from participating in migration. 
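# --- Illustrative sketch (not part of the patch) --------------------------------
# A worked example of the pairing rule above, using made-up load values. With the
# "remaining_steps" metric a lightly loaded instance has a strongly negative load,
# so sources are the least negative (busiest) instances and destinations the most
# negative (idlest) ones. A candidate pair is kept when migrating one request
# shrinks the load gap without flipping it (or when the destination is empty,
# load == -inf), and the destination's projected load stays at or below the
# migrate-out threshold; the threshold value here is hypothetical.
migrate_out_load_threshold = -5.0            # hypothetical value

src_load, src_load_after_out = -2.0, -3.5    # busiest instance, before/after migrate-out
dst_load, dst_load_after_in = -20.0, -15.0   # idlest instance, before/after migrate-in

load_diff_before = src_load - dst_load                     # 18.0
load_diff_after = src_load_after_out - dst_load_after_in   # 11.5

accept = (dst_load_after_in <= migrate_out_load_threshold
          and 0 < load_diff_after < load_diff_before)
print(accept)  # True: the pair (src, dst) would be scheduled for migration
# --- end of sketch ---------------------------------------------------------------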
- migration_backend_init_filter = CustomFilter() - migration_backend_init_filter.set_filter_condtition( - src_filter=lambda _: migration_backend not in ['gloo', 'nccl'], - dst_filter=lambda _: migration_backend not in ['gloo', 'nccl']) - self.migration_filter.register_filter("migration_backend_init_filter", - migration_backend_init_filter) - - self.instance_load_calculator = instance_load_calculator - self.enable_defrag = instance_load_calculator.enable_defrag - if not self.enable_defrag: - self.pair_migration_policy \ - = PairMigrationPolicyFactory.get_policy("balanced", - migrate_out_load_threshold=migrate_out_load_threshold, - instance_load_calculator=instance_load_calculator) - else: - self.pair_migration_policy \ - = PairMigrationPolicyFactory.get_policy(pair_migration_policy, - migrate_out_load_threshold=migrate_out_load_threshold, - instance_load_calculator=instance_load_calculator) + self.pair_migration_policy = PairMigrationPolicyFactory.get_policy( + pair_migration_policy, migrate_out_load_threshold=migrate_out_load_threshold) self.num_instances = 0 self.instance_id_set: Set[str] = set() - # instance info args self.instance_info: Dict[str, InstanceInfo] = None - self.sorted_instance_infos: List[InstanceInfo] = None + def _register_migration_backend_init_filter(self, is_group_kind_migration_backend: bool) -> None: + # some migration backends require init_process_group before passing the KV cache. Here, we add a filter + # to prevent instances of migration backends that have not been initialized from participating in migration. + migration_backend_init_filter = CustomFilter() + migration_backend_init_filter.set_filter_condtition( + src_filter=lambda _: not is_group_kind_migration_backend, + dst_filter=lambda _: not is_group_kind_migration_backend) + self.migration_filter.register_filter("migration_backend_init_filter", migration_backend_init_filter) + + # migration_filter must ensure that the specific instance_info does not appear in both src and dst simultaneously def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: src_instance_infos, dst_instance_infos = self.migration_filter.filter_instances( self.instance_info.values(), pair_migration_type) return self.pair_migration_policy.pair_migration(src_instance_infos, dst_instance_infos) - def update_instance_infos(self, - instance_info: Dict[str, InstanceInfo]) -> None: + def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: self.instance_info = instance_info - def add_instance(self, instance_id: str) -> None: + # pylint: disable=unused-argument + def add_instance(self, instance_id: str, instance_args: InstanceArgs) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) def remove_instance(self, instance_id: str) -> None: - self.instance_id_set.remove(instance_id) + if instance_id in self.instance_id_set: + self.instance_id_set.remove(instance_id) self.num_instances = len(self.instance_id_set) diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index c85b211e..ff966617 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -13,40 +13,30 @@ from typing import Dict, List, Tuple, Set from abc import ABC, abstractmethod -from enum import Enum import numpy as np from llumnix.logging.logger import init_logger -from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator +from llumnix.instance_info import 
InstanceInfo, ScalingLoadComputation, InstanceType +from llumnix.arg_utils import InstanceArgs logger = init_logger(__name__) -class InstanceType(str, Enum): - NO_CONSTRAINTS = "NO_CONSTRAINTS" - # Specific to Prefill-Decoding disaggregation. - PREFILL = "PREFILL" - DECODE = "DECODE" - - class ScalingScheduler: def __init__(self, scale_up_threshold: float, scale_down_threshold: float, scaling_policy: str, - instance_load_calculator: InstanceLoadCalculator, - enable_pd_disagg: bool, - maximum_prefill_instance_num: int) -> None: + scaling_load_metric: str, + enable_pd_disagg: bool,) -> None: self.scale_up_threshold = scale_up_threshold self.scale_down_threshold = scale_down_threshold - self.scaling_policy = ScalePolicyFactory.get_policy(scaling_policy, - instance_load_calculator=instance_load_calculator) - self.instance_load_calculator = instance_load_calculator + self.scaling_policy = ScalePolicyFactory.get_policy(scaling_policy, scaling_load_metric=scaling_load_metric) self.num_instances = 0 self.instance_id_set: Set[str] = set() - self.maximum_prefill_instance_num = maximum_prefill_instance_num self.enable_pd_disagg = enable_pd_disagg + # instance info args self.instance_info: Dict[str, InstanceInfo] = None self.sorted_instance_infos: List[InstanceInfo] = None @@ -71,26 +61,18 @@ def check_scale(self) -> Tuple[str, str]: scale_down_num = 1 return scale_up_num, scale_down_num - def update_instance_infos(self, - instance_info: Dict[str, InstanceInfo]) -> None: + def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: self.instance_info = instance_info - def add_instance(self, instance_id: str) -> None: + def add_instance(self, instance_id: str, instance_args: InstanceArgs) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) - instance_type = None - if not self.enable_pd_disagg: - instance_type = InstanceType.NO_CONSTRAINTS - else: - if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num: - instance_type = InstanceType.PREFILL - else: - instance_type = InstanceType.DECODE + instance_type = InstanceType(instance_args.instance_type) self.instance_type_id_set[instance_type].add(instance_id) - return instance_type def remove_instance(self, instance_id: str) -> None: - self.instance_id_set.remove(instance_id) + if instance_id in self.instance_id_set: + self.instance_id_set.remove(instance_id) for instance_type in InstanceType: if instance_id in self.instance_type_id_set[instance_type]: self.instance_type_id_set[instance_type].remove(instance_id) @@ -108,16 +90,9 @@ def get_empty_instance_info(self) -> InstanceInfo: dummy_intance_info.num_available_gpu_blocks_waiting = np.inf return dummy_intance_info - def get_instance_type_info(self, instance_id: str) -> InstanceInfo: - for instance_type in InstanceType: - if instance_id in self.instance_type_id_set[instance_type]: - return instance_type - return self.add_instance(instance_id) - class ScalePolicy(ABC): - def __init__(self, - instance_load_calculator: InstanceLoadCalculator) -> None: - self.instance_load_calculator = instance_load_calculator + def __init__(self, scaling_load_metric: str) -> None: + self.scaling_load_calculator = ScalingLoadComputation(scaling_load_metric) @abstractmethod def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: @@ -138,7 +113,7 @@ def compute_load_metric_avg(self, instance_infos: List[InstanceInfo]) -> float: tot_instance_info.num_watermark_blocks = sum([i.num_watermark_blocks for i in 
instance_infos]) tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos]) tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks - return self.instance_load_calculator.compute_instance_load(tot_instance_info, action="scale") + return self.scaling_load_calculator.compute_instance_load(tot_instance_info) class MaxLoad(ScalePolicy): def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: @@ -175,7 +150,7 @@ def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float: for i in instance_infos]) tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos]) tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks - return self.instance_load_calculator.compute_instance_load(tot_instance_info, action='scale') + return self.scaling_load_calculator.compute_instance_load(tot_instance_info) class ScalePolicyFactory: _POLICY_REGISTRY = { diff --git a/llumnix/instance_info.py b/llumnix/instance_info.py index 576d467f..b20f4ed0 100644 --- a/llumnix/instance_info.py +++ b/llumnix/instance_info.py @@ -12,7 +12,10 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Dict +import copy +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Tuple import numpy as np from llumnix.logging.logger import init_logger @@ -20,111 +23,104 @@ logger = init_logger(__name__) +class InstanceType(str, Enum): + NO_CONSTRAINTS = "no_constraints" + PREFILL = "prefill" + DECODE = "decode" +@dataclass class InstanceInfo: - def __init__(self, - num_total_gpu_blocks: int = 0, - num_watermark_blocks: int= 0, - num_used_gpu_blocks: int = 0, - num_free_gpu_blocks: int = 0, - gpu_cache_usage: float = 0.0, - num_running_requests: int = 0, - num_waiting_requests: int = 0, - num_killed_requests: int = 0, - num_blocks_first_waiting_request: int = 0, - waiting_time_first_waiting_request: int = 0, - num_blocks_all_waiting_requests: int = 0, - inference_type: RequestInferenceType = RequestInferenceType.PREFILL, - instance_type: str = "", - num_batched_tokens: int = 0, - instance_id: str = "",) -> None: - self.num_total_gpu_blocks = num_total_gpu_blocks - self.num_watermark_blocks = num_watermark_blocks - self.num_used_gpu_blocks = num_used_gpu_blocks - self.num_free_gpu_blocks = num_free_gpu_blocks + instance_id: str = "" + instance_type: InstanceType = None + + step_id: int = None + timestamp: float = None + num_batched_tokens: int = None + num_seqs = None + running_seq_lens: List[int] = field(default_factory=list) + last_inference_latency: float = None + inference_type: RequestInferenceType = None + + num_total_gpu_blocks: int = 0 + num_watermark_blocks: int = 0 + num_used_gpu_blocks: int = 0 + num_free_gpu_blocks: int = 0 + gpu_cache_usage: float = 0.0 + num_running_requests: int = 0 + num_waiting_requests: int = 0 + num_killed_requests: int = 0 + num_blocks_first_waiting_request: int = 0 + waiting_time_first_waiting_request: int = 0 + num_blocks_all_waiting_requests: int = 0 + num_blocks_last_running_request: int = 0 + + # on-demand init infos + dispatch_load_metric: float = -np.inf + migration_load_metric: float = np.inf + migration_load_metric_after_migrate_in: float = -np.inf + migration_load_metric_after_migrate_out: float = np.inf + + # lazy init infos + 
num_available_gpu_blocks: int = 0 + num_available_gpu_blocks_waiting: int = 0 + + # manual init infos + profiling_data: Tuple[str, int, int, float] = None + + def __post_init__(self) -> None: self.num_available_gpu_blocks = self.num_free_gpu_blocks - self.num_watermark_blocks - self.gpu_cache_usage = gpu_cache_usage - self.num_running_requests = num_running_requests - self.num_waiting_requests = num_waiting_requests - self.num_killed_requests = num_killed_requests - self.num_blocks_first_waiting_request = num_blocks_first_waiting_request - self.waiting_time_first_waiting_request = waiting_time_first_waiting_request - self.num_blocks_all_waiting_requests = num_blocks_all_waiting_requests self.num_available_gpu_blocks_waiting = self.num_available_gpu_blocks - self.num_blocks_all_waiting_requests - # For instance load computation before migration. - self.num_blocks_last_running_request = 0 - - # For global scheduling. - self.instance_load_migrate = -np.inf - self.instance_load_dispatch_scale = -np.inf - self.instance_type = instance_type - - # For record statistics, assigned in scheduler. - self.inference_type = inference_type - self.num_batched_tokens = num_batched_tokens - self.running_seq_lens = [] - self.num_seqs = 0 - self.max_tot_tokens = 0 - self.finished_request_ids = None - - # For record statistics, assigned in backend engine. - self.instance_id = instance_id - self.step_id = None - self.timestamp = None - self.profiling_data = () - -class InstanceLoadInfo: - def __init__(self, instance_info: InstanceInfo) -> None: - self.num_total_gpu_blocks = instance_info.num_total_gpu_blocks - self.num_watermark_blocks = instance_info.num_watermark_blocks - self.num_used_gpu_blocks = instance_info.num_used_gpu_blocks - self.num_free_gpu_blocks = instance_info.num_free_gpu_blocks - self.num_available_gpu_blocks = instance_info.num_available_gpu_blocks - - self.num_waiting_requests = instance_info.num_waiting_requests - self.num_running_requests = instance_info.num_running_requests - self.num_killed_requests = instance_info.num_killed_requests - - self.num_blocks_first_waiting_request = instance_info.num_blocks_first_waiting_request - self.waiting_time_first_waiting_request = instance_info.waiting_time_first_waiting_request - self.num_blocks_all_waiting_requests = instance_info.num_blocks_all_waiting_requests - - self.instance_id = instance_info.instance_id - self.step_id = instance_info.step_id class InstanceLoadCalculator: - def __init__(self, - load_metric: str, - enable_defrag: bool) -> None: - self.load_metric = load_metric - self.enable_defrag = enable_defrag - self.load_computation_strategies: Dict[str, LoadComputationStrategy] = { - 'migrate': MigrationLoadComputation(load_metric, enable_defrag), - 'dispatch': DispatchAndScalingLoadComputation(load_metric, enable_defrag), - 'scale': DispatchAndScalingLoadComputation(load_metric, enable_defrag), - } - - def compute_instance_load(self, - instance_info: InstanceInfo, - action: str = 'migrate') -> float: - instance_load_info = InstanceLoadInfo(instance_info) - assert action in self.load_computation_strategies - load_computation_strategy = self.load_computation_strategies[action] - return load_computation_strategy.compute_instance_load(instance_load_info) + def __init__(self, dispatch_load_metric: str, migration_load_metric: str, enable_defrag: bool) -> None: + self.dispatch_load_calculator = DispatchLoadComputation(migration_load_metric) + self.migration_load_calculator = MigrationLoadComputation(dispatch_load_metric, enable_defrag) + + def 
compute_instance_load(self, instance_info: InstanceInfo): + instance_info.dispatch_load_metric = self.dispatch_load_calculator.compute_instance_load(instance_info) + instance_info.migration_load_metric = self.migration_load_calculator.compute_instance_load(instance_info) + instance_info.migration_load_metric_after_migrate_out = self.migration_load_calculator.\ + compute_instance_load_after_migrate(instance_info, is_migrate_in=False) + instance_info.migration_load_metric_after_migrate_in = self.migration_load_calculator.\ + compute_instance_load_after_migrate(instance_info, is_migrate_in=True) class LoadComputationStrategy(ABC): - def __init__(self, - load_metric: str, - enable_defrag: bool) -> None: + def __init__(self, load_metric: str, enable_defrag: bool = False) -> None: self.load_metric = load_metric self.enable_defrag = enable_defrag @abstractmethod - def compute_instance_load(self, i: InstanceLoadInfo) -> float: + def compute_instance_load(self, i: InstanceInfo) -> float: pass +class DispatchLoadComputation(LoadComputationStrategy): + def compute_instance_load(self, i: InstanceInfo) -> float: + instance_load = -np.inf + if self.load_metric == 'usage_ratio': + instance_load = (i.num_used_gpu_blocks + i.num_blocks_all_waiting_requests) / i.num_total_gpu_blocks + elif self.load_metric == 'remaining_steps': + num_requests = i.num_running_requests + i.num_waiting_requests + num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_all_waiting_requests + if num_requests == 0: + return -np.inf + instance_load = (num_available_gpu_blocks / num_requests)*(-1) + return instance_load + class MigrationLoadComputation(LoadComputationStrategy): - def compute_instance_load(self, i: InstanceLoadInfo) -> float: + def compute_instance_load_after_migrate(self, i: InstanceInfo, is_migrate_in: bool) -> float: + instance_info_after_migrate = copy.deepcopy(i) + num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request + + if is_migrate_in: + instance_info_after_migrate.num_running_requests += 1 + instance_info_after_migrate.num_available_gpu_blocks -= num_blocks_last_running_request + else: + instance_info_after_migrate.num_running_requests -= 1 + instance_info_after_migrate.num_available_gpu_blocks += num_blocks_last_running_request + + return self.compute_instance_load(instance_info_after_migrate) + + def compute_instance_load(self, i: InstanceInfo) -> float: instance_load = -np.inf if self.load_metric == 'usage_ratio': instance_load = (i.num_used_gpu_blocks + i.num_blocks_first_waiting_request) / i.num_total_gpu_blocks @@ -136,23 +132,19 @@ def compute_instance_load(self, i: InstanceLoadInfo) -> float: num_requests = i.num_running_requests if i.num_waiting_requests != 0: num_requests += 1 - # num_requests = i.num_running_requests + i.num_waiting_requests num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_first_waiting_request - # num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_all_waiting_requests if num_requests == 0: return -np.inf - instance_load = (num_available_gpu_blocks / num_requests)*(-1) + instance_load = (num_available_gpu_blocks / num_requests) * (-1) return instance_load -class DispatchAndScalingLoadComputation(LoadComputationStrategy): - def compute_instance_load(self, i: InstanceLoadInfo) -> float: - instance_load = -np.inf - if self.load_metric == 'usage_ratio': - instance_load = (i.num_used_gpu_blocks + i.num_blocks_all_waiting_requests) / i.num_total_gpu_blocks - elif self.load_metric == 
'remaining_steps': - num_requests = i.num_running_requests + i.num_waiting_requests - num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_all_waiting_requests - if num_requests == 0: - return -np.inf - instance_load = (num_available_gpu_blocks / num_requests)*(-1) - return instance_load +# TODO(KuilongCui): currently scaling and dispatch use the same load calculator, leave +# it in the future to refine +class ScalingLoadComputation(LoadComputationStrategy): + def __init__(self, load_metric): + super().__init__(load_metric) + self.load_calculator = DispatchLoadComputation(load_metric) + + def compute_instance_load(self, i: InstanceInfo) -> float: + return self.load_calculator.compute_instance_load(i) + \ No newline at end of file diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 9e1db037..013882c0 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -18,8 +18,8 @@ def __init__( migration_backend: str, migration_buffer_blocks: int, migration_num_layers: int, - last_stage_max_blocks: int, - max_stages: int, + migration_last_stage_max_blocks: int, + migration_max_stages: int, migration_backend_init_timeout: float, migration_backend_transfer_type: str = "", grpc_migration_backend_server_address: str = "", @@ -30,8 +30,8 @@ def __init__( self.migration_backend_transfer_type = migration_backend_transfer_type self.migration_num_layers = migration_num_layers self.migration_buffer_blocks = migration_buffer_blocks - self.last_stage_max_blocks = last_stage_max_blocks - self.max_stages = max_stages + self.migration_last_stage_max_blocks = migration_last_stage_max_blocks + self.migration_max_stages = migration_max_stages self.migration_backend_init_timeout = migration_backend_init_timeout self.grpc_migration_backend_server_address = grpc_migration_backend_server_address self.kvtransfer_migration_backend_naming_url = kvtransfer_migration_backend_naming_url @@ -41,33 +41,24 @@ class GlobalSchedulerConfig: def __init__( self, initial_instances: int, - load_metric: str, dispatch_policy: str, - num_dispatch_instances: int, pair_migration_policy: str, migrate_out_threshold: float, - enable_defrag: bool, scaling_policy: str, + scaling_load_metric: str, scale_up_threshold: float, scale_down_threshold: float, enable_pd_disagg: bool, - migration_backend: str,) -> None: + is_group_kind_migration_backend: bool,) -> None: self.initial_instances = initial_instances - self.load_metric = load_metric - self.dispatch_policy = dispatch_policy - self.pair_migration_policy = pair_migration_policy - # TODO(KuilongCui): Use a better way to set the threshold, as having both positive and negative - # values can cause confusion. 
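# --- Illustrative sketch (not part of the patch) --------------------------------
# The "remaining_steps" dispatch load defined above, restated as a standalone
# function with made-up block counts. The metric is the negated number of free
# KV-cache blocks per request: an idle instance tends toward -inf and a saturated
# one toward 0, so smaller (more negative) means lighter. This sign convention
# appears to be what the threshold negation removed just below was compensating for.
import math

def remaining_steps_load(num_running: int, num_waiting: int,
                         num_available_gpu_blocks: int,
                         num_blocks_all_waiting_requests: int) -> float:
    num_requests = num_running + num_waiting
    if num_requests == 0:
        return -math.inf
    free_blocks = num_available_gpu_blocks - num_blocks_all_waiting_requests
    return -(free_blocks / num_requests)

print(remaining_steps_load(8, 2, 1000, 100))  # -90.0: plenty of headroom
print(remaining_steps_load(8, 2, 150, 100))   # -5.0: close to saturation
# --- end of sketch ---------------------------------------------------------------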
- self.migrate_out_load_threshold = migrate_out_threshold*(-1) - self.enable_defrag = enable_defrag + self.migrate_out_load_threshold = migrate_out_threshold self.scaling_policy = scaling_policy - self.scale_up_threshold = scale_up_threshold*(-1) - self.scale_down_threshold = scale_down_threshold*(-1) + self.scaling_load_metric = scaling_load_metric + self.scale_up_threshold = scale_up_threshold + self.scale_down_threshold = scale_down_threshold self.enable_pd_disagg = enable_pd_disagg - self.num_dispatch_instances = num_dispatch_instances - - self.migration_backend = migration_backend + self.is_group_kind_migration_backend = is_group_kind_migration_backend diff --git a/llumnix/launcher.py b/llumnix/launcher.py new file mode 100644 index 00000000..7978ea78 --- /dev/null +++ b/llumnix/launcher.py @@ -0,0 +1,224 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import copy +import traceback +from typing import Callable, Dict, List, Tuple + +import ray +from ray.util.state import list_placement_groups, list_actors +from ray.util.placement_group import PlacementGroup + +from llumnix.logger import init_logger +from llumnix.instance_info import InstanceType +from llumnix.global_scheduler.global_scheduler import GlobalScheduler +from llumnix.llumlet.llumlet import Llumlet +from llumnix.queue.queue_type import QueueType +from llumnix.backends.backend_interface import BackendType +from llumnix.arg_utils import EntrypointsArgs, InstanceArgs +from llumnix.entrypoints.vllm.api_server_actor import FastAPIServerActor +from llumnix.backends.utils import get_engine_world_size +from llumnix.utils import (remove_placement_group, get_manager_name, INSTANCE_NAME_PREFIX, get_instance_name, + SERVER_NAME_PREFIX, kill_server, kill_instance, get_actor_data_from_ray_internal_kv, + initialize_placement_group, get_server_name, put_actor_data_to_ray_internal_kv, + get_placement_group_name) + +logger = init_logger(__name__) + +class Launcher: + def __init__(self, global_scheduler: GlobalScheduler, enable_port_increment: bool, + enable_port_offset_store: bool, enable_pd_disagg: bool, + enablde_engine_pd_disagg: bool, pd_ratio: List[int]): + self.global_scheduler = global_scheduler + self.enable_port_increment = enable_port_increment + self.enable_port_offset_store = enable_port_offset_store + self.enable_pd_disagg = enable_pd_disagg + self.enablde_engine_pd_disagg = enablde_engine_pd_disagg + self.pd_ratio = pd_ratio + + if enable_port_increment: + self.port_offset = 0 + if enable_port_offset_store: + value = get_actor_data_from_ray_internal_kv("manager", "port_offset") + if value is not None: + self.port_offset = int(value) + + self.inflight_num_prefill = 0 + self.inflight_num_decode = 0 + + def init_placement_group(self, + placement_group_name: str, + engine_args, + backend_type: BackendType, + init_server: bool = False, + block: bool = True) -> PlacementGroup: + if not BackendType.is_sim_backend(backend_type): + # num_cpus=3, for Llumlet + AsyncPutQueueActor + 
ProxyActor + # num_gpus=world_size, for world_size Workers + world_size = get_engine_world_size(engine_args, backend_type) + placement_group = initialize_placement_group(placement_group_name, num_cpus=3+int(init_server), + num_gpus=world_size, detached=True, block=block) + else: + # num_cpus=1, for Llumlet + AsyncPutQueueActor + placement_group = initialize_placement_group(placement_group_name, num_cpus=2+int(init_server), + num_gpus=0, detached=True, block=block) + + return placement_group + + def get_instance_deployment_states(self, instance_id: str): + pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))]) + pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED" + server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))]) + server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE" + instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))]) + instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE" + + return pg_created, server_alive, instance_alive + + def get_cluster_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]: + curr_pgs: Dict[str, PlacementGroup] = {} + curr_servers: Dict[str, PlacementGroup] = {} + curr_instances: Dict[str, Llumlet] = {} + + created_pg_states = list_placement_groups(filters=[("state", "=", "CREATED")]) + for created_pg_state in created_pg_states: + instance_id = created_pg_state["name"].split("_")[-1] + curr_pgs[instance_id] = ray.util.get_placement_group(created_pg_state["name"]) + + alive_actor_states = list_actors(filters=[("state", "=", "ALIVE")]) + for alive_actor_state in alive_actor_states: + if alive_actor_state["name"].startswith(SERVER_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_servers[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") + elif alive_actor_state["name"].startswith(INSTANCE_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_instances[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") + + return curr_pgs, curr_servers, curr_instances + + def clear_instance_ray_resources(self, instance_id: str): + if not remove_placement_group(instance_id): + logger.debug("Failed to remove placement group {}.".format(instance_id)) + if not kill_server(instance_id): + logger.debug("Failed to kill server {}.".format(instance_id)) + if not kill_instance(instance_id): + logger.debug("Failed to kill instance {}.".format(instance_id)) + + def _get_next_instance_type(self, cur_num_prefill, cur_num_decode, pd_ratio) -> str: + instance_type = InstanceType.NO_CONSTRAINTS + + if self.enable_pd_disagg: + # Note: There are no instances simultaneously in inflight_num_prefill and cur_num_prefill as + # inflight_num will decrease before scaling up the instances. The same applies to num_decode. 
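# --- Illustrative sketch (not part of the patch) --------------------------------
# The instance-type choice made by the code below, restated as a standalone
# function so the pd_ratio behaviour is easy to poke at. The prefill/decode counts
# are assumed to already include in-flight launches, per the note above; the
# example values are made up, and this path only applies when PD disaggregation
# is enabled.
def next_instance_type(total_prefill: int, total_decode: int,
                       pd_ratio=(1, 1)) -> str:
    if total_prefill == 0:
        return "prefill"
    if total_decode == 0:
        return "decode"
    # Pick whichever choice keeps (prefill - decode) closest to the target gap.
    normal_distance = pd_ratio[0] - pd_ratio[1]
    gap_if_prefill = abs((total_prefill + 1 - total_decode) - normal_distance)
    gap_if_decode = abs((total_prefill - (total_decode + 1)) - normal_distance)
    return "prefill" if gap_if_prefill <= gap_if_decode else "decode"

print(next_instance_type(0, 0))                    # prefill: bootstrap the prefill side
print(next_instance_type(2, 1))                    # decode: 2:1 overshoots the 1:1 target
print(next_instance_type(3, 0, pd_ratio=(3, 1)))   # decode: never leave decode empty
# --- end of sketch ---------------------------------------------------------------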
+ total_num_prefill = self.inflight_num_prefill + cur_num_prefill + total_num_decode = self.inflight_num_decode + cur_num_decode + + if total_num_prefill == 0: + instance_type = InstanceType.PREFILL + elif total_num_decode == 0: + instance_type = InstanceType.DECODE + else: + # compute distance if launch prefill or decode + normal_distance = pd_ratio[0] - pd_ratio[1] + distance_if_prefill = total_num_prefill + 1 - total_num_decode + distance_if_decode = total_num_prefill - (total_num_decode + 1) + gap_to_normal_if_prefill = abs(distance_if_prefill - normal_distance) + gap_to_normal_if_decode = abs(distance_if_decode - normal_distance) + instance_type = InstanceType.PREFILL if gap_to_normal_if_prefill <= gap_to_normal_if_decode \ + else InstanceType.DECODE + + return instance_type + + def _get_next_instance_args(self, instance_args) -> InstanceArgs: + assert not self.enablde_engine_pd_disagg, \ + "currently not support engine based pd-disaggregation in Global Launch Model." + + config: InstanceArgs = copy.deepcopy(instance_args) + cur_num_prefill = len(self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) + cur_num_decode = len(self.global_scheduler.instance_id_set - + self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) + config.instance_type = self._get_next_instance_type(cur_num_prefill, cur_num_decode, self.pd_ratio) + return config + + def _get_next_entrypoints_args(self, entrypoints_args: EntrypointsArgs) -> EntrypointsArgs: + config = copy.deepcopy(entrypoints_args) + if self.enable_port_increment: + config.port += self.port_offset + config.request_output_queue_port += self.port_offset + self.port_offset += 1 + if self.enable_port_offset_store: + put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset) + return config + + def init_server_and_instance(self, instance_id: str, entrypoints_args: EntrypointsArgs, + instance_args: InstanceArgs, engine_args, backend_type: BackendType, + placement_group: PlacementGroup, instance_finish_cb: Callable = None, + server_finish_cb: Callable = None): + async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: EntrypointsArgs): + try: + manager = ray.get_actor(get_manager_name(), namespace="llumnix") + await instance.is_ready.remote() + await server.run.remote(manager, instance_id, instance) + self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 + if instance_finish_cb: + # manager.scale_up will be called here after the instance is ready + instance_finish_cb(instance_id, instance, instance_args) + if server_finish_cb: + server_finish_cb(instance_id, server) + logger.info("launcher init_server_and_instance done, instance_id: {}, instance_type: {}, " + "api_server_port: {}, request_output_queue_port: {}".format(instance_id, + instance_args.instance_type, entrypoint_args.port, + entrypoint_args.request_output_queue_port)) + # pylint: disable=broad-except + except Exception as e: + self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 + logger.error("[_init_server_and_instance] unexpected exception occurs: {}".format(e)) + logger.error("[_init_server_and_instance] exception traceback: {}".format(traceback.format_exc())) + self.clear_instance_ray_resources(instance_id) + + request_output_queue_type = 
QueueType(entrypoints_args.request_output_queue_type) + next_instance_args = self._get_next_instance_args(instance_args) + instance = self.init_instance(instance_id, next_instance_args, placement_group, + request_output_queue_type, backend_type, engine_args) + next_entrypoints_args = self._get_next_entrypoints_args(entrypoints_args) + server = self.init_server(get_server_name(instance_id), placement_group, next_entrypoints_args) + + self.inflight_num_prefill += 1 if next_instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode += 1 if next_instance_args.instance_type == InstanceType.DECODE else 0 + asyncio.create_task(done_scale_up(next_instance_args, next_entrypoints_args)) + + def init_server(self, server_name: str, placement_group: PlacementGroup, + entrypoints_args: EntrypointsArgs) -> FastAPIServerActor: + fastapi_server = FastAPIServerActor.from_args(server_name, placement_group, entrypoints_args) + return fastapi_server + + def init_instance(self, + instance_id: str, + instance_args: InstanceArgs, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, + backend_type: BackendType, + engine_args + ) -> Tuple[str, Llumlet]: + instance = Llumlet.from_args( + instance_id, + instance_args, + placement_group, + request_output_queue_type, + backend_type, + engine_args) + + return instance diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index e0958deb..2a865e47 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -21,7 +21,7 @@ from ray.util.placement_group import PlacementGroup from llumnix.logging.logger import init_logger -from llumnix.instance_info import InstanceInfo +from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState from llumnix.backends.utils import init_backend_engine, get_engine_world_size from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus @@ -30,6 +30,7 @@ from llumnix.internal_config import MigrationConfig from llumnix.queue.queue_type import QueueType from llumnix.llumlet.request import LlumnixRequest, RequestStatus +from llumnix.arg_utils import InstanceArgs from llumnix.utils import get_instance_name from llumnix.constants import CHECK_ENGINE_STATE_INTERVAL @@ -39,12 +40,11 @@ class Llumlet: def __init__(self, instance_id: str, + instance_args: InstanceArgs, placement_group: PlacementGroup, request_output_queue_type: QueueType, - migration_config: MigrationConfig, backend_type: BackendType, - engine_args, - profiling_result_file_path: str = None) -> None: + engine_args) -> None: try: self.job_id = ray.get_runtime_context().get_job_id() self.worker_id = ray.get_runtime_context().get_worker_id() @@ -54,17 +54,24 @@ def __init__(self, logger.info("Llumlet(job_id={}, worker_id={}, actor_id={}, node_id={}, instance_id={})".format( self.job_id, self.worker_id, self.actor_id, self.node_id, self.instance_id)) logger.info("Llumlet backend type: {}".format(backend_type)) + self.instance_args = instance_args self.actor_name = get_instance_name(instance_id) + self.instance_load_calculator = InstanceLoadCalculator( + dispatch_load_metric=instance_args.dispatch_load_metric, + migration_load_metric=instance_args.migration_load_metric, + enable_defrag=instance_args.enable_defrag + ) + migration_config: MigrationConfig = instance_args.create_migration_config() self.backend_engine: BackendInterface = init_backend_engine(instance_id, placement_group, 
request_output_queue_type, migration_config, backend_type, engine_args, - profiling_result_file_path) + instance_args.profiling_result_file_path) self.migration_coordinator = MigrationCoordinator(self.backend_engine, - migration_config.last_stage_max_blocks, - migration_config.max_stages) + migration_config.migration_last_stage_max_blocks, + migration_config.migration_max_stages) self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy, self.backend_engine) self.log_requests = True @@ -82,12 +89,11 @@ def __repr__(self): @classmethod def from_args(cls, instance_id: str, + instance_args: InstanceArgs, placement_group: PlacementGroup, request_output_queue_type: QueueType, - migration_config: MigrationConfig, backend_type: BackendType, - engine_args, - profiling_result_file_path: str = None): + engine_args): try: assert backend_type in [backend_type.VLLM, backend_type.BLADELLM, backend_type.SIM_VLLM], \ f'unimplemented backend {backend_type}' @@ -109,12 +115,11 @@ def from_args(cls, ) ) llumlet = llumlet_class.remote(instance_id, + instance_args, placement_group, request_output_queue_type, - migration_config, backend_type, - engine_args, - profiling_result_file_path) + engine_args) # pylint: disable=broad-except except Exception as e: logger.error("Failed to initialize Llumlet: {}".format(e)) @@ -131,7 +136,7 @@ async def _check_engine_state_loop(self): # pylint: disable=protected-access self.backend_engine._stop_event.set() await asyncio.sleep(0) - self_actor = ray.get_actor(self.actor_name) + self_actor = ray.get_actor(name=self.actor_name, namespace="llumnix") ray.kill(self_actor) async def migrate_out(self, dst_instance_name: str) -> List[str]: @@ -192,11 +197,15 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds raise return migrated_request + # TODO(KuilongCui): only the metrics-related information needs to be synchronously loaded for the manager def get_instance_info(self) -> InstanceInfo: - return self.backend_engine.engine.instance_info + instance_info: InstanceInfo = self.backend_engine.engine.instance_info + instance_info.instance_type = self.instance_args.instance_type + self.instance_load_calculator.compute_instance_load(instance_info) + return instance_info - def is_ready(self) -> bool: - return True + def is_ready(self) -> InstanceArgs: + return self.instance_args def get_all_request_ids(self) -> List[str]: return self.backend_engine.get_all_request_ids() diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index 338fda3e..56bff867 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -43,10 +43,10 @@ def is_finished(status: "MigrationStatus") -> bool: class MigrationCoordinator: def __init__(self, backend_engine: BackendInterface, - last_stage_max_blocks: int, - max_stages: int) -> None: - self.last_stage_max_blocks = last_stage_max_blocks - self.max_stages = max_stages + migration_last_stage_max_blocks: int, + migration_max_stages: int) -> None: + self.migration_last_stage_max_blocks = migration_last_stage_max_blocks + self.migration_max_stages = migration_max_stages self.backend_engine = backend_engine async def migrate_out_running_request(self, @@ -95,7 +95,7 @@ async def _migrate_out_multistage(self, """ try: stage_count = 0 - while stage_count < self.max_stages: + while stage_count < self.migration_max_stages: stage_count += 1 status = await self._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) 
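# --- Illustrative sketch (not part of the patch) --------------------------------
# A toy walk-through of the renamed knobs used in this loop. Each stage copies
# (almost all of) the KV blocks appended since the previous stage, leaving the
# block still being written; once the remaining delta is at most
# migration_last_stage_max_blocks, the final (blocking) stage runs. If that never
# happens within migration_max_stages stages, the migration is dropped. Block
# counts below are invented.
migration_max_stages = 3
migration_last_stage_max_blocks = 16

total_blocks_per_stage = [100, 130, 140]   # request's block count observed at each stage
sent = 0
for stage, total in enumerate(total_blocks_per_stage, start=1):
    incremental = total - sent
    if incremental <= migration_last_stage_max_blocks:
        print(f"stage {stage}: last stage, copy remaining {incremental} blocks")
        break
    print(f"stage {stage}: copy {incremental} blocks while the request keeps running")
    sent = total
else:
    print(f"gave up after {migration_max_stages} stages")
# Output: stage 1 copies 100 blocks, stage 2 copies 30, stage 3 is the last stage (10 blocks).
# --- end of sketch ---------------------------------------------------------------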
if MigrationStatus.is_finished(status): @@ -119,7 +119,7 @@ async def _migrate_out_onestage(self, pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list) incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) # live migration, transfer all blocks except last one(currently updating) - is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) or migrate_out_request.blocking_migration + is_last_stage = (len(incremental_blocks) <= self.migration_last_stage_max_blocks) or migrate_out_request.blocking_migration if not is_last_stage: migration_status = MigrationStatus.RUNNING src_blocks = incremental_blocks[:-1] diff --git a/llumnix/manager.py b/llumnix/manager.py index ba8c13a4..921e0fe0 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -12,17 +12,18 @@ # limitations under the License. import asyncio +import random import time import csv -import copy import os from typing import Dict, List, Tuple, Union, Iterable from collections import defaultdict import traceback from functools import partial + import ray +import ray.actor from ray.util.state import list_placement_groups, list_actors -from ray.util.placement_group import PlacementGroup from llumnix.llumlet.llumlet import Llumlet from llumnix.logging.logger import init_logger @@ -30,24 +31,19 @@ from llumnix.global_scheduler.migration_scheduler import PairMigrationConstraints from llumnix.global_scheduler.migration_filter import CustomFilter from llumnix.instance_info import InstanceInfo -from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs +from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, InstanceArgs, LaunchArgs from llumnix.server_info import ServerInfo from llumnix.backends.backend_interface import BackendType -from llumnix.utils import (random_uuid, clear_gloo_backend_state, remove_placement_group, - get_instance_name, get_manager_name, INSTANCE_NAME_PREFIX, - SERVER_NAME_PREFIX, get_placement_group_name, run_async_func_sync, - kill_server, kill_instance, initialize_placement_group, - get_server_name, get_actor_data_from_ray_internal_kv, - put_actor_data_to_ray_internal_kv) +from llumnix.utils import (random_uuid, clear_gloo_backend_state, get_instance_name, + get_manager_name, INSTANCE_NAME_PREFIX, get_placement_group_name, + run_async_func_sync,) from llumnix.entrypoints.utils import LaunchMode -from llumnix.backends.utils import get_engine_world_size from llumnix.queue.queue_type import QueueType -from llumnix.entrypoints.vllm.api_server_actor import APIServerActor from llumnix.constants import (CLEAR_REQUEST_INSTANCE_INTERVAL, NO_INSTANCE_RETRY_INTERVAL, WAIT_ALL_MIGRATIONS_DONE_INTERVAL, AUTO_SCALE_UP_INTERVAL, WAIT_PLACEMENT_GROUP_TIMEOUT, CHECK_DEPLOYMENT_STATES_INTERVAL, WATCH_DEPLOYMENT_INTERVAL, WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE) - +from llumnix.launcher import Launcher logger = init_logger(__name__) @@ -56,11 +52,12 @@ class Manager: def __init__(self, + entrypoints_args: EntrypointsArgs, manager_args: ManagerArgs, + instance_args: InstanceArgs, + engine_args, + launch_args: LaunchArgs, work_dir: str, - entrypoints_args: EntrypointsArgs = None, - engine_args = None, - launch_args: LaunchArgs = None ) -> None: os.chdir(work_dir) self.job_id = ray.get_runtime_context().get_job_id() @@ -71,8 +68,10 @@ def __init__(self, self.job_id, self.worker_id, self.actor_id, self.node_id)) self.actor_name = get_manager_name() self.manager_args = manager_args - # engine_args and entrypoints_args are used in 
global deployment. + + # used in global deployment. self.entrypoints_args = entrypoints_args + self.instance_args = instance_args self.engine_args = engine_args self.launch_args = launch_args @@ -97,9 +96,14 @@ def __init__(self, self.polling_interval = manager_args.polling_interval - global_scheduler_config = manager_args.create_global_scheduler_config() + self.is_group_kind_migration_backend = manager_args.is_group_kind_migration_backend + global_scheduler_config = manager_args.create_global_scheduler_config(self.is_group_kind_migration_backend) self.global_scheduler = GlobalScheduler(global_scheduler_config) + self.launcher: Launcher = Launcher(self.global_scheduler, manager_args.enable_port_increment, + manager_args.enable_port_offset_store, manager_args.enable_pd_disagg, + manager_args.enable_engine_pd_disagg, manager_args.pd_ratio) + # log args self.log_requests = not manager_args.disable_log_requests_manager self.log_instance_info = manager_args.log_instance_info @@ -133,17 +137,13 @@ def __init__(self, asyncio.create_task(self._update_instance_info_loop(self.polling_interval)) asyncio.create_task(self._clear_request_instance_loop(CLEAR_REQUEST_INSTANCE_INTERVAL)) - if self.manager_args.enable_port_increment: - self.port_offset = 0 - if self.manager_args.enable_port_offset_store: - value = get_actor_data_from_ray_internal_kv("manager", "port_offset") - if value is not None: - self.port_offset = int(value) if hasattr(self, "launch_mode") and self.launch_mode == LaunchMode.GLOBAL: assert self.entrypoints_args is not None and self.engine_args is not None self.last_timeout_instance_id = None asyncio.create_task(self._auto_scale_up_loop(AUTO_SCALE_UP_INTERVAL)) asyncio.create_task(self._check_deployment_states_loop(CHECK_DEPLOYMENT_STATES_INTERVAL)) + if self.manager_args.enable_pd_disagg: + asyncio.create_task(self._check_pd_deployment_states_loop(CHECK_DEPLOYMENT_STATES_INTERVAL)) async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None: while self.num_instances == 0: @@ -312,7 +312,7 @@ async def _auto_scale_up_loop(self, interval: float) -> None: self.scale_down(instance_id) if new_pg is None: new_instance_id = random_uuid() - new_pg = self._init_placement_group(get_placement_group_name(new_instance_id), self.engine_args, self.backend_type, + new_pg = self.launcher.init_placement_group(get_placement_group_name(new_instance_id), self.engine_args, self.backend_type, init_server=True, block=False) try: await asyncio.wait_for(new_pg.ready(), WAIT_PLACEMENT_GROUP_TIMEOUT) @@ -323,7 +323,9 @@ async def _auto_scale_up_loop(self, interval: float) -> None: self.last_timeout_instance_id = new_instance_id await asyncio.sleep(interval) continue - self._init_server_and_instance(new_instance_id, new_pg) + self.launcher.init_server_and_instance(new_instance_id, self.entrypoints_args, self.instance_args, + self.engine_args, self.backend_type, new_pg, + instance_finish_cb=self.scale_up) logger.info("Deploy server and instance to new placement group done, instance_id: {}.".format(new_instance_id)) # pylint: disable=broad-except except Exception as e: @@ -352,16 +354,13 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): dead_instances.add(instance_name) if len(dead_instances) > 0: self.scale_down(dead_instances, rebuild_migration_backend=False) - if self.manager_args.migration_backend == 'gloo': - clear_gloo_backend_state() + clear_gloo_backend_state() return dead_instances alive_instances = sorted(self.instances.keys()) pending_task 
= self.pending_rebuild_migration_instances group_name = None - - if self.manager_args.migration_backend == 'gloo': - clear_gloo_backend_state() + clear_gloo_backend_state() while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0: dead_instances = set() @@ -392,14 +391,16 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): # Restore migrate config self.enable_migration = origin_config - def scale_up(self, - instance_id: Union[str, Iterable[str]], - instance_actor_handle: Union["ray.actor.ActorHandle", List["ray.actor.ActorHandle"]]) -> None: + def scale_up(self, instance_id: Union[str, Iterable[str]], + instance_actor_handle: Union[ray.actor.ActorHandle, List[ray.actor.ActorHandle]], + instance_arg: Union[InstanceArgs, Iterable[InstanceArgs]]) -> None: if isinstance(instance_id, str): instance_id = [instance_id,] instance_actor_handle = [instance_actor_handle,] + instance_arg = [instance_arg,] instance_ids = list(instance_id) instance_actor_handles = list(instance_actor_handle) + instance_args = list(instance_arg) indeed_update = False no_pending_instance = (self.pending_rebuild_migration_instances == 0) @@ -412,14 +413,14 @@ def scale_up(self, if self.log_instance_info: self.instance_last_logged_empty[ins_id] = False self.pending_rebuild_migration_instances += 1 - self.global_scheduler.scale_up(instance_ids) + self.global_scheduler.scale_up(instance_ids, instance_args) self.num_instances = len(self.instances) # When scaling up, we need to rebuild the migration backend. But if initially self.pending_rebuild_migration_instances != 0, # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes # caused by this scale-up (see rebuild_migration_backend for details). Therefore, we simply return in this case. - # Specifically, for RayRPC migration backend, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. - if self.enable_migration and self.manager_args.migration_backend in ['gloo', 'nccl'] \ + # Specifically, for not group kind migration backend, there is no need to rebuild the group. 
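# --- Illustrative sketch (not part of the patch) --------------------------------
# The branch below replaces the old check on concrete backend names with the
# precomputed is_group_kind_migration_backend flag. How ManagerArgs derives the
# flag is not shown in this diff; judging from the condition it replaces, the
# intent is presumably equivalent to the following: backends that form a
# collective process group need the group rebuilt whenever membership changes.
def is_group_kind(migration_backend: str) -> bool:
    return migration_backend in ("gloo", "nccl")    # assumed mapping, inferred from the removed check

print(is_group_kind("gloo"))     # True: collective group must be rebuilt on scale up/down
print(is_group_kind("rayrpc"))   # False: actor-handle based ("rayrpc" name taken from the removed comment)
# --- end of sketch ---------------------------------------------------------------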
@@ -434,7 +435,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b
         no_pending_instance = self.pending_rebuild_migration_instances == 0

         for ins_id in instance_ids:
-            self._clear_instance_ray_states(ins_id)
+            self.launcher.clear_instance_ray_resources(ins_id)
             if ins_id in self.instances:
                 indeed_update = True
             if ins_id in self.instances:
@@ -454,30 +455,22 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b
         self.global_scheduler.scale_down(instance_ids)
         self.num_instances = len(self.instances)

-        if self.enable_migration and self.manager_args.migration_backend in ['gloo', 'nccl']:
+        if self.enable_migration and self.is_group_kind_migration_backend:
             if len(self.instances) == 0:
                 self.pending_rebuild_migration_instances = 0
-                if self.manager_args.migration_backend == 'gloo':
-                    clear_gloo_backend_state()
+                clear_gloo_backend_state()
             elif indeed_update and no_pending_instance and rebuild_migration_backend:
                 asyncio.create_task(self._rebuild_migration_backend())

         return self.num_instances

-    def _clear_instance_ray_states(self, instance_id: str):
-        if not remove_placement_group(instance_id):
-            logger.debug("Failed to remove placement group {}.".format(instance_id))
-        if not kill_server(instance_id):
-            logger.debug("Failed to kill server {}.".format(instance_id))
-        if not kill_instance(instance_id):
-            logger.debug("Failed to kill instance {}.".format(instance_id))
-
     async def _connect_to_instances(self):
         def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: "ray.actor.ActorHandle", fut):
             ret = fut.result()[0]
             if not isinstance(ret, Exception):
                 scale_up_instance_ids.append(instance_id)
                 scale_up_instance_actor_handles.append(instance_actor_handle)
+                scale_up_instance_args.append(ret)
                 logger.info("Connect to instance {}".format(instance_id))
             else:
                 logger.warning("Connect to instance {} failed, exception: {}".format(instance_id, ret))
@@ -489,6 +482,7 @@ def connect_to_instances_done_callback(instance_id: str, instance_actor_handle:
         instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names]
         scale_up_instance_ids = []
         scale_up_instance_actor_handles = []
+        scale_up_instance_args = []
         tasks = []
         for instance_actor_name, instance_actor_handle in zip(instance_actor_names, instance_actor_handles):
             instance_id = instance_actor_name[len('instance_'):]
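In the callback above, the value returned by each instance's is_ready call is appended as that instance's InstanceArgs, so reconnecting to already-running instances recovers their type along with their handle. The mock Llumlet in the unit tests reflects this contract; the class below is only an illustrative sketch of an instance-side is_ready under that assumption, not the real Llumlet implementation.

from llumnix.arg_utils import InstanceArgs

class LlumletSketch:
    def __init__(self, instance_args: InstanceArgs):
        self.instance_args = instance_args

    def is_ready(self) -> InstanceArgs:
        # Returning the args (instead of a bare bool) lets the manager re-register
        # a reconnected instance with its instance_type and migration settings.
        return self.instance_args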
@@ -498,116 +492,99 @@ def connect_to_instances_done_callback(instance_id: str, instance_actor_handle:
             tasks.append(task)
         await asyncio.gather(*tasks)
         # The only function that can add instance actor handles to manager.
-        self.scale_up(scale_up_instance_ids, scale_up_instance_actor_handles)
+        self.scale_up(scale_up_instance_ids, scale_up_instance_actor_handles, scale_up_instance_args)

     @classmethod
     def from_args(cls,
+                  entrypoints_args: EntrypointsArgs,
                   manager_args: ManagerArgs,
-                  entrypoints_args: EntrypointsArgs = None,
-                  engine_args = None,
-                  launch_args: LaunchArgs = None,
+                  instance_args: InstanceArgs,
+                  engine_args,
+                  launch_args: LaunchArgs,
                   ) -> "Manager":
         manager_class = ray.remote(num_cpus=1,
                                    max_restarts=-1,
                                    name=get_manager_name(),
                                    namespace="llumnix",
                                    lifetime="detached")(cls)
-        manager = manager_class.remote(manager_args,
-                                       os.getcwd(),
-                                       entrypoints_args,
-                                       engine_args,
-                                       launch_args)
-
+        manager = manager_class.remote(
+            entrypoints_args,
+            manager_args,
+            instance_args,
+            engine_args,
+            launch_args,
+            os.getcwd())
         return manager
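from_args now requires all five argument bundles explicitly (the previous defaults are gone) and passes the working directory last. A usage sketch mirroring the unit tests; the import path of LaunchMode is not shown in this diff and is assumed here:

from vllm import EngineArgs
from llumnix.arg_utils import ManagerArgs, InstanceArgs, EntrypointsArgs, LaunchArgs
from llumnix.manager import Manager
from llumnix.backends.backend_interface import BackendType
# LaunchMode is referenced by the tests; its import path is an assumption here.

manager = Manager.from_args(
    entrypoints_args=EntrypointsArgs(host="127.0.0.1", port=8000, request_output_queue_type="rayqueue"),
    manager_args=ManagerArgs(enable_port_increment=True),
    instance_args=InstanceArgs(migration_backend="rayrpc"),
    engine_args=EngineArgs(model="facebook/opt-125m", worker_use_ray=True),
    launch_args=LaunchArgs(launch_mode=LaunchMode.LOCAL, backend_type=BackendType.VLLM),
)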
-    def _init_placement_group(self,
-                              placement_group_name: str,
-                              engine_args,
-                              backend_type: BackendType,
-                              init_server: bool = False,
-                              block: bool = True) -> PlacementGroup:
-        if not BackendType.is_sim_backend(backend_type):
-            # num_cpus=3, for Llumlet + AsyncPutQueueActor + ProxyActor
-            # num_gpus=world_size, for world_size Workers
-            world_size = get_engine_world_size(engine_args, backend_type)
-            placement_group = initialize_placement_group(placement_group_name,
-                num_cpus=3+int(init_server), num_gpus=world_size, detached=True, block=block)
-        else:
-            # num_cpus=1, for Llumlet + AsyncPutQueueActor
-            placement_group = initialize_placement_group(placement_group_name,
-                num_cpus=2+int(init_server), num_gpus=0, detached=True, block=block)
-
-        return placement_group
-
-    def _init_server(self,
-                     server_name: str,
-                     placement_group: PlacementGroup,
-                     entrypoints_args: EntrypointsArgs) -> APIServerActor:
-        entrypoints_args = copy.deepcopy(entrypoints_args)
-        if self.manager_args.enable_port_increment:
-            entrypoints_args.port += self.port_offset
-            entrypoints_args.request_output_queue_port += self.port_offset
-            self.port_offset += 1
-            if self.manager_args.enable_port_offset_store:
-                put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset)
-        api_server = APIServerActor.from_args(server_name, placement_group, entrypoints_args)
-        return api_server
-
-    def _init_instance(self,
-                       instance_id: str,
-                       placement_group: PlacementGroup,
-                       request_output_queue_type: QueueType,
-                       backend_type: BackendType,
-                       engine_args
-                       ) -> Tuple[str, Llumlet]:
-        instance = Llumlet.from_args(
-            instance_id,
-            placement_group,
-            request_output_queue_type,
-            self.manager_args.create_migration_config(),
-            backend_type,
-            engine_args,
-            self.manager_args.profiling_result_file_path)
-
-        return instance
-
     def init_instances(self,
                        request_output_queue_type: QueueType,
                        backend_type: BackendType,
+                       instance_args: InstanceArgs,
                        engine_args
                        ) -> Tuple[List[str], List[Llumlet]]:
         instance_ids: List[str] = []
         instances: List[Llumlet] = []
         for _ in range(self.manager_args.initial_instances):
             instance_id = random_uuid()
-            placement_group = self._init_placement_group(get_placement_group_name(instance_id), engine_args, backend_type)
-            instance = self._init_instance(instance_id, placement_group, request_output_queue_type, backend_type, engine_args)
+            placement_group = self.launcher.init_placement_group(get_placement_group_name(instance_id), engine_args, backend_type)
+            instance = self.launcher.init_instance(instance_id, instance_args, placement_group, request_output_queue_type,
+                                                   backend_type, engine_args)
             instance_ids.append(instance_id)
             instances.append(instance)
-        self.scale_up(instance_ids, instances)
+        self.scale_up(instance_ids, instances, [instance_args]*len(instance_ids))
         return instance_ids, instances

-    def _init_server_and_instance(self,
-                                  instance_id: str,
-                                  placement_group: PlacementGroup):
-        async def done_scale_up():
+    def _inner_check_pd_deployment(self) -> str:
+        prefill_instance_ids = self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set
+        cur_num_prefill = len(prefill_instance_ids)
+        decode_instance_ids = self.global_scheduler.instance_id_set - prefill_instance_ids
+        cur_num_decode = len(decode_instance_ids)
+
+        scale_down_instance_id = ""
+        if cur_num_prefill == 0 and cur_num_decode > 0:
+            scale_down_instance_id = random.choice(list(decode_instance_ids))
+            logger.info("[_inner_check_pd_deployment] pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, "
+                        "all decode, scale down decode instance {}".format(self.manager_args.pd_ratio,
+                        cur_num_prefill, cur_num_decode, scale_down_instance_id))
+
+        if cur_num_decode == 0 and cur_num_prefill > 0:
+            scale_down_instance_id = random.choice(list(prefill_instance_ids))
+            logger.info("[_inner_check_pd_deployment] pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, "
+                        "all prefill, scale down prefill instance {}".format(self.manager_args.pd_ratio,
+                        cur_num_prefill, cur_num_decode, scale_down_instance_id))
+
+        if scale_down_instance_id:
+            self.scale_down(scale_down_instance_id)
+
+        return scale_down_instance_id
+
+    # TODO(KuilongCui): currently, only one naive state check policy is implemented, which prevents the
+    # cluster from consisting entirely of prefill or decode instances.
+    async def _check_pd_deployment_states_loop(self, interval: float) -> None:
+        previous_pending_pg_names = None
+
+        while True:
            try:
-                manager = ray.get_actor(get_manager_name(), namespace="llumnix")
-                await instance.is_ready.remote()
-                await server.run.remote(manager, instance_id, instance)
-                self.scale_up(instance_id, instance)
+                pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
+                rescheduling_pg_states = list_placement_groups(filters=[("state", "=", "RESCHEDULING")])
+                all_pending_pg_names = [pg.name for pg in pending_pg_states]
+
+                if previous_pending_pg_names and len(rescheduling_pg_states) == 0:
+                    new_pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
+                    all_new_pending_pg_names = [pg.name for pg in new_pending_pg_states]
+                    if len(set(previous_pending_pg_names).difference(set(all_new_pending_pg_names))) == 0:
+                        self._inner_check_pd_deployment()
+                    previous_pending_pg_names = all_new_pending_pg_names
+                else:
+                    previous_pending_pg_names = all_pending_pg_names
+
+                await asyncio.sleep(interval)
            # pylint: disable=broad-except
            except Exception as e:
                logger.error("Unexpected exception: {}".format(e))
                logger.error("Exception traceback: {}".format(traceback.format_exc()))
-                self._clear_instance_ray_resources(instance_id)
-
-        request_output_queue_type = QueueType(self.entrypoints_args.request_output_queue_type)
-        instance = self._init_instance(instance_id, placement_group, request_output_queue_type, self.backend_type, self.engine_args)
-        server = self._init_server(get_server_name(instance_id), placement_group, self.entrypoints_args)
-        asyncio.create_task(done_scale_up())

     async def _check_deployment_states_loop(self, interval: float) -> None:
         async def watch_instance_deployment_states(instance_id: str):
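The loop above only runs the prefill/decode balance check when the set of PENDING placement groups has stayed the same across two consecutive polls and nothing is RESCHEDULING, i.e. when placement-group scheduling has stopped making progress. A condensed sketch of that stability condition as a standalone helper (illustrative only; the names previous_pending, current_pending and num_rescheduling are placeholders for the values gathered from Ray's state API above):

# Sketch of the stability check: trigger the PD-deployment fix-up only when the same
# placement groups have stayed PENDING for a whole polling interval.
def should_check_pd_deployment(previous_pending, current_pending, num_rescheduling) -> bool:
    if not previous_pending or num_rescheduling > 0:
        return False
    # No previously-pending group has been scheduled since the last poll.
    return len(set(previous_pending) - set(current_pending)) == 0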
@@ -623,7 +600,7 @@ async def watch_instance_deployment_states(instance_id: str):
                 wait_pending_instance_time += WATCH_DEPLOYMENT_INTERVAL
                 if wait_pending_instance_time >= WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE:
                     break
-            pg_created, server_alive, instance_alive = self._get_instance_deployment_states(instance_id)
+            pg_created, server_alive, instance_alive = self.launcher.get_instance_deployment_states(instance_id)
             if pg_created and (not server_alive or not instance_alive):
                 logger.warning("Instance {} deployment states incorrect, states: (pg {}, server {}, instance {})"
                                .format(instance_id, pg_created, server_alive, instance_alive))
@@ -631,7 +608,7 @@ async def watch_instance_deployment_states(instance_id: str):
         while True:
             try:
-                curr_pgs, curr_servers, curr_instances = self._get_cluster_deployment()
+                curr_pgs, curr_servers, curr_instances = self.launcher.get_cluster_deployment()
                 assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances))
                 tasks = []
                 for instance_id in curr_pgs:
@@ -670,37 +647,6 @@ def check_instance_error_done_callback(idx: int, instance_id: str, fut):
         return results

-    def _get_cluster_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, APIServerActor], Dict[str, Llumlet]]:
-        curr_pgs: Dict[str, PlacementGroup] = {}
-        curr_servers: Dict[str, PlacementGroup] = {}
-        curr_instances: Dict[str, Llumlet] = {}
-
-        created_pg_states = list_placement_groups(filters=[("state", "=", "CREATED")])
-        for created_pg_state in created_pg_states:
-            instance_id = created_pg_state["name"].split("_")[-1]
-            curr_pgs[instance_id] = ray.util.get_placement_group(created_pg_state["name"])
-
-        alive_actor_states = list_actors(filters=[("state", "=", "ALIVE")])
-        for alive_actor_state in alive_actor_states:
-            if alive_actor_state["name"].startswith(SERVER_NAME_PREFIX):
-                instance_id = alive_actor_state["name"].split("_")[-1]
-                curr_servers[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix")
-            elif alive_actor_state["name"].startswith(INSTANCE_NAME_PREFIX):
-                instance_id = alive_actor_state["name"].split("_")[-1]
-                curr_instances[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix")
-
-        return curr_pgs, curr_servers, curr_instances
-
-    def _get_instance_deployment_states(self, instance_id: str):
-        pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))])
-        pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED"
-        server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))])
-        server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE"
-        instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))])
-        instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE"
-
-        return pg_created, server_alive, instance_alive
-
     async def _get_request_instance(self) -> None:
         def get_request_instance_done_callback(instance_id: str, fut):
             ret = fut.result()[0]
@@ -742,8 +688,8 @@ def _init_instance_info_csv(self, manager_args: ManagerArgs) -> None:
             'step_id',
             'gpu_cache_usage',
             'num_available_gpu_blocks',
-            'instance_load',
-            'max_tot_tokens',
+            'dispatch_load_metric',
+            'migration_load_metric',
             'num_running_requests',
             'num_waiting_requests',
             'num_killed_requests',
@@ -770,8 +716,8 @@ def _log_instance_infos_to_csv(self, instance_infos: List[InstanceInfo]) -> None
             instance_info.step_id,
             instance_info.gpu_cache_usage,
             instance_info.num_available_gpu_blocks,
-            instance_info.instance_load_migrate,
-            instance_info.max_tot_tokens,
+            instance_info.dispatch_load_metric,
+            instance_info.migration_load_metric,
             instance_info.num_running_requests,
instance_info.num_waiting_requests, instance_info.num_killed_requests, diff --git a/llumnix/utils.py b/llumnix/utils.py index 3140ba32..dd6426c5 100644 --- a/llumnix/utils.py +++ b/llumnix/utils.py @@ -99,7 +99,8 @@ def clear_gloo_backend_state(): try: # clear gloo migrate backend intermediate state ray.kill(ray.get_actor("gloo_queue", "llumnix")) - except ValueError: + # pylint: disable=broad-except + except Exception: # gloo_queue may not have been created yet; just ignore this error. pass diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index 5567db1a..3491b3e1 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -62,25 +62,57 @@ def get_markdown_data(key: str, head_name: str): return to_markdown_table(prefill_data) + "\n\n" + to_markdown_table(decode_data) @pytest.mark.asyncio -@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) @pytest.mark.parametrize("launch_mode", ['global', 'local']) -async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch_mode): +@pytest.mark.parametrize("enable_pd_disagg", [True, False]) +async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch_mode, enable_pd_disagg): + if launch_mode == 'local': + num_prompts = 500 if not enable_pd_disagg else 50 + else: + num_prompts = 50 if not enable_pd_disagg else 50 + ip = "127.0.0.1" base_port = 37037 ip_ports = [] if launch_mode == 'local': device_count = torch.cuda.device_count() - for i in range(device_count): - port = base_port+i - ip_port = f"{ip}:{port}" - ip_ports.append(ip_port) - launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", - launch_ray_cluster=False, - ip=ip, - port=port, - model=model) - subprocess.run(launch_command, shell=True, check=True) + if enable_pd_disagg: + for i in range(device_count//2): + port = base_port+i + ip_port = f"{ip}:{port}" + ip_ports.append(ip_port) + launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", + launch_ray_cluster=False, + ip=ip, + port=port, + model=model, + enable_pd_disagg=enable_pd_disagg, + instance_type="prefill") + subprocess.run(launch_command, shell=True, check=True) + for i in range(device_count//2): + port = base_port+i+device_count//2 + ip_port = f"{ip}:{port}" + ip_ports.append(ip_port) + launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", + launch_ray_cluster=False, + ip=ip, + port=port, + model=model, + enable_pd_disagg=enable_pd_disagg, + instance_type="decode") + subprocess.run(launch_command, shell=True, check=True) + else: + for i in range(device_count): + port = base_port+i + ip_port = f"{ip}:{port}" + ip_ports.append(ip_port) + launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", + launch_ray_cluster=False, + ip=ip, + port=port, + model=model) + subprocess.run(launch_command, shell=True, check=True) else: # global device_count = torch.cuda.device_count() for i in range(device_count): @@ -90,11 +122,11 @@ async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch serve_command = generate_serve_command(result_filename=str(base_port)+".out", ip=ip, port=base_port, - model=model) + model=model, + enable_pd_disagg=enable_pd_disagg) # pylint: disable=subprocess-run-check subprocess.run('ray start 
--head', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) subprocess.run(serve_command, shell=True, check=True) - wait_for_llumnix_service_ready(ip_ports) def run_bench_command(command): @@ -107,7 +139,7 @@ def run_bench_command(command): bench_command = generate_bench_command( ip_ports=f"127.0.0.1:{base_port + i}", model=model, - num_prompts=200, + num_prompts=num_prompts, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl", qps=5, @@ -129,7 +161,7 @@ def run_bench_command(command): process.kill() assert False, "bench_test timed out after {} minutes.".format(BENCH_TEST_TIMEOUT_MINS) - if launch_mode == 'local': + if launch_mode == 'local' and not enable_pd_disagg: with open("performance.txt", "w", encoding="utf-8") as f: f.write(parse_log_file()) diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_correctness.py similarity index 53% rename from tests/e2e_test/test_e2e.py rename to tests/e2e_test/test_correctness.py index 3fa3b3a2..7cf7f94c 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_correctness.py @@ -22,7 +22,7 @@ # pylint: disable=unused-import from tests.conftest import ray_env -from .utils import (generate_launch_command, wait_for_llumnix_service_ready, +from .utils import (generate_launch_command, generate_serve_command, wait_for_llumnix_service_ready, shutdown_llumnix_service) @@ -61,10 +61,11 @@ def run_vllm(model, max_model_len, sampling_params): return vllm_output @pytest.mark.asyncio -@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for e2e test") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for correctness test") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo']) -async def test_e2e(ray_env, shutdown_llumnix_service, model, migration_backend): +@pytest.mark.parametrize("launch_mode", ['global', 'local']) +@pytest.mark.parametrize("enable_pd_disagg", [True, False]) +async def test_correctness(ray_env, shutdown_llumnix_service, model, launch_mode, enable_pd_disagg): max_model_len = 370 sampling_params = { "n": 1, @@ -87,12 +88,38 @@ async def test_e2e(ray_env, shutdown_llumnix_service, model, migration_backend): # generate llumnix outputs ip = "127.0.0.1" base_port = 37037 - launch_command = generate_launch_command(model=model, - max_model_len=max_model_len, - ip=ip, - port=base_port, - migration_backend=migration_backend) - subprocess.run(launch_command, shell=True, check=True) + + launch_commands = [] + if launch_mode == "local": + if enable_pd_disagg: + launch_commands.append(generate_launch_command(result_filename=str(base_port)+".out", + model=model, + max_model_len=max_model_len, + port=base_port, + enable_pd_disagg=enable_pd_disagg, + instance_type="prefill")) + launch_commands.append(generate_launch_command(result_filename=str(base_port+1)+".out", + launch_ray_cluster=False, + model=model, + max_model_len=max_model_len, + ip=ip, + port=base_port+1, + enable_pd_disagg=enable_pd_disagg, + instance_type="decode")) + else: + launch_commands.append(generate_launch_command(model=model, + max_model_len=max_model_len, + ip=ip, + port=base_port)) + else: + launch_commands.append(generate_serve_command(result_filename=str(base_port)+".out", + ip=ip, + port=base_port, + model=model, + enable_pd_disagg=enable_pd_disagg)) + for launch_command in launch_commands: + subprocess.run(launch_command, shell=True, check=True) + await asyncio.sleep(3) 
wait_for_llumnix_service_ready(ip_ports=[f"{ip}:{base_port}"]) diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index da71f32a..6cc2372f 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -29,7 +29,9 @@ def generate_launch_command(result_filename: str = "", max_model_len: int = 4096, log_instance_info: bool = False, request_migration_policy: str = 'SR', - max_num_batched_tokens: int = 16000): + max_num_batched_tokens: int = 16000, + enable_pd_disagg: bool = False, + instance_type: str = "no_constraints"): command = ( f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 " f"nohup python -u -m llumnix.entrypoints.vllm.api_server " @@ -51,12 +53,15 @@ def generate_launch_command(result_filename: str = "", f"--tensor-parallel-size 1 " f"--request-output-queue-port {1234+port} " f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}" + f"{'--enable-pd-disagg ' if enable_pd_disagg else ''}" + f"--instance-type {instance_type} " f"--max-num-batched-tokens {max_num_batched_tokens} " f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &" ) return command def generate_serve_command(result_filename: str = "", + launch_ray_cluster: bool = True, ip: str = "127.0.0.1", port: int = 37000, dispatch_policy: str = "load", @@ -65,12 +70,15 @@ def generate_serve_command(result_filename: str = "", max_model_len: int = 4096, log_instance_info: bool = False, request_migration_policy: str = 'SR', - max_num_batched_tokens: int = 16000): + max_num_batched_tokens: int = 16000, + enable_pd_disagg: bool = False, + pd_ratio: str = "1:1"): command = ( f"RAY_DEDUP_LOGS=0 " f"nohup python -u -m llumnix.entrypoints.vllm.serve " f"--host {ip} " f"--port {port} " + f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}" f"{'--log-filename manager ' if log_instance_info else ''}" f"{'--log-instance-info ' if log_instance_info else ''}" f"--enable-migration " @@ -86,7 +94,9 @@ def generate_serve_command(result_filename: str = "", f"--tensor-parallel-size 1 " f"--request-output-queue-port {1234+port} " f"--max-num-batched-tokens {max_num_batched_tokens} " + f"--pd-ratio {pd_ratio} " f"--enable-port-increment " + f"{'--enable-pd-disagg ' if enable_pd_disagg else ''}" f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &" ) return command diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index 4f28d753..24628fe9 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -24,9 +24,9 @@ from llumnix.backends.vllm.llm_engine import BackendVLLM from llumnix.llumlet.llumlet import Llumlet from llumnix.backends.utils import BackendType -from llumnix.internal_config import MigrationConfig from llumnix.llumlet.request import RequestInferenceType, RequestStatus from llumnix.queue.queue_type import QueueType +from llumnix.arg_utils import InstanceArgs from llumnix.utils import initialize_placement_group, get_placement_group_name from tests.unit_test.queue.utils import request_output_queue_server @@ -44,13 +44,13 @@ "Swahili: 'The early bird catches the worm.'\n" ] -def init_llumlet(request_output_queue_type, instance_id, migration_config, engine_args): +def init_llumlet(request_output_queue_type, instance_id, instance_args, engine_args): placement_group = initialize_placement_group(get_placement_group_name(instance_id), num_cpus=3, num_gpus=1, detached=True) llumlet = Llumlet.from_args( instance_id=instance_id, + 
instance_args=instance_args, placement_group=placement_group, request_output_queue_type=request_output_queue_type, - migration_config=migration_config, backend_type=BackendType.VLLM, engine_args=engine_args) return llumlet @@ -107,27 +107,31 @@ async def test_migration_correctness(ray_env, migration_backend, migration_reque request_migration_policy = "SR" elif migration_request_status == 'waiting': request_migration_policy = "FCW" - migration_config = MigrationConfig(request_migration_policy, migration_backend, 16, 1, 4, 5, 20) + + instance_args = InstanceArgs() + instance_args.request_migration_policy = request_migration_policy + instance_args.migration_backend = migration_backend request_output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(request_output_queue_type) asyncio.create_task(que.run_server_loop()) - llumlet_0 = init_llumlet(request_output_queue_type, "0", migration_config, engine_args) - llumlet_1 = init_llumlet(request_output_queue_type, "1", migration_config, engine_args) + llumlet_0: Llumlet = init_llumlet(request_output_queue_type, "0", instance_args, engine_args) + llumlet_1: Llumlet = init_llumlet(request_output_queue_type, "1", instance_args, engine_args) llumlet_2: Llumlet = MockLlumletDoNotSchedule.options( name='instance_2', namespace='llumnix').remote( instance_id="2", + instance_args=instance_args, request_output_queue_type=request_output_queue_type, backend_type=BackendType.VLLM, - migration_config=migration_config, engine_args=engine_args, ) while True: res = ray.get([llumlet_0.is_ready.remote(), llumlet_1.is_ready.remote(), llumlet_2.is_ready.remote()]) + print("--------", res) if all(res): break @@ -199,14 +203,17 @@ async def test_correctness(prompt): async def test_pd_diaggregation_correctness(ray_env, migration_backend): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) id_rank_map = {"0":0, "1":1} - migration_config = MigrationConfig("SR", migration_backend, 16, 1, 4, 5, 20) + + instance_args = InstanceArgs() + instance_args.request_migration_policy = "SR" + instance_args.migration_backend = migration_backend request_output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(request_output_queue_type) asyncio.create_task(que.run_server_loop()) - llumlet_0 = init_llumlet(request_output_queue_type, "0", migration_config, engine_args) - llumlet_1 = init_llumlet(request_output_queue_type, "1", migration_config, engine_args) + llumlet_0 = init_llumlet(request_output_queue_type, "0", instance_args, engine_args) + llumlet_1 = init_llumlet(request_output_queue_type, "1", instance_args, engine_args) while True: res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index f6b1d50d..a5456a30 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -19,7 +19,7 @@ from vllm.engine.arg_utils import EngineArgs from llumnix.backends.vllm.worker import MigrationWorker -from llumnix.arg_utils import ManagerArgs +from llumnix.arg_utils import InstanceArgs from llumnix.utils import random_uuid, initialize_placement_group, get_placement_group_name # pylint: disable=unused-import @@ -40,7 +40,7 @@ def get_gpu_cache(self): @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_migrate_cache(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', 
max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = ManagerArgs(migration_buffer_blocks=3, migration_num_layers=5).create_migration_config() + migraiton_config = InstanceArgs(migration_buffer_blocks=3, migration_num_layers=5).create_migration_config() migraiton_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config, diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index 15e8e6d6..7b5a107f 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -21,7 +21,7 @@ from vllm.config import EngineConfig from vllm.executor.ray_gpu_executor import RayWorkerWrapper -from llumnix.arg_utils import ManagerArgs +from llumnix.arg_utils import InstanceArgs from llumnix.utils import random_uuid from llumnix.utils import initialize_placement_group, get_placement_group_name @@ -60,7 +60,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_reserve_memory_for_migration(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migration_config = ManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config = InstanceArgs(migration_buffer_blocks=1).create_migration_config() migration_config.migration_backend = backend worker = create_worker(rank=0, local_rank=0, engine_config=engine_config) ray.get(worker.execute_method.remote('init_device')) @@ -81,7 +81,7 @@ def test_reserve_memory_for_migration(ray_env, backend): @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_rebuild_migration_backend(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migration_config = ManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config = InstanceArgs(migration_buffer_blocks=1).create_migration_config() migration_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config) diff --git a/tests/unit_test/entrypoints/test_utils.py b/tests/unit_test/entrypoints/test_utils.py index 413906ee..079778b8 100644 --- a/tests/unit_test/entrypoints/test_utils.py +++ b/tests/unit_test/entrypoints/test_utils.py @@ -33,8 +33,7 @@ def test_launch_ray_cluster(): assert result.returncode == 0 def test_init_manager(ray_env): - manager_args = ManagerArgs() - manager = init_manager(manager_args) + manager = init_manager(ManagerArgs()) assert manager is not None manager_actor_handle = ray.get_actor(get_manager_name(), namespace='llumnix') assert manager_actor_handle is not None @@ -47,14 +46,12 @@ def test_init_zmq(ray_env): assert request_output_queue is not None def test_retry_manager_method_sync(ray_env): - manager_args = ManagerArgs() - manager = init_manager(manager_args) + manager = init_manager(ManagerArgs()) ret = retry_manager_method_sync(manager.is_ready.remote, 'is_ready') assert ret is True @pytest.mark.asyncio async def test_retry_manager_method_async(ray_env): - manager_args = ManagerArgs() - manager = init_manager(manager_args) + manager = init_manager(ManagerArgs()) ret = await retry_manager_method_async(manager.is_ready.remote, 'is_ready') assert ret is True diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py 
b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 11408902..e32faf3b 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -11,56 +11,73 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import random -import pytest -from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo +from llumnix.instance_info import InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler +from llumnix.arg_utils import InstanceArgs INSTANCE_NUM = 4 -def init_dispatch_scheduler(policy='load'): - instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, 2) - return dispatch_scheduler - -@pytest.fixture -def dispatch_scheduler(): - dispatch_scheduler = init_dispatch_scheduler() - yield dispatch_scheduler - -@pytest.mark.parametrize("num_dispatch_instances", [1, 2, 3]) -def test_add_instance_and_remove_instance(dispatch_scheduler, num_dispatch_instances): - dispatch_scheduler.num_dispatch_instances = num_dispatch_instances - dispatch_scheduler.add_instance('instance_1') - assert dispatch_scheduler.num_instances == 1 +def test_add_instance_and_remove_instance(): + dispatch_scheduler = DispatchScheduler('balanced') + + dispatch_scheduler.add_instance('instance_1', InstanceArgs(instance_type="no_constraints")) assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 dispatch_scheduler.remove_instance('instance_1') - assert dispatch_scheduler.num_instances == 0 assert len(dispatch_scheduler.available_dispatch_instance_set) == 0 - dispatch_scheduler.add_instance('instance_2') - assert dispatch_scheduler.num_instances == 1 + dispatch_scheduler.add_instance('instance_2', InstanceArgs(instance_type="no_constraints")) assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 - dispatch_scheduler.add_instance('instance_3') - assert dispatch_scheduler.num_instances == 2 - assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances) + dispatch_scheduler.add_instance('instance_3', InstanceArgs(instance_type="no_constraints")) + assert len(dispatch_scheduler.available_dispatch_instance_set) == 2 dispatch_scheduler.remove_instance('instance_2') - assert dispatch_scheduler.num_instances == 1 assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 dispatch_scheduler.remove_instance('instance_3') - assert dispatch_scheduler.num_instances == 0 + assert len(dispatch_scheduler.available_dispatch_instance_set) == 0 + +def test_dispatch_to_no_constraints_and_prefill(): + dispatch_scheduler = DispatchScheduler('rr') + instance_num_requests = {} + instance_info_dict = {} + for instance_id in [f'instance_{i}' for i in range(INSTANCE_NUM)]: + instance_info = InstanceInfo( + instance_id=instance_id, + dispatch_load_metric=random.randint(1, 10), + ) + instance_info_dict[instance_id] = instance_info + dispatch_scheduler.instance_num_requests = instance_num_requests + dispatch_scheduler.instance_info = instance_info_dict + + dispatched_instance_ids = [] + available_instance_type = ['no_constraints', 'prefill', 'decode'] + for instance_id, _ in dispatch_scheduler.instance_info.items(): + instance_type = random.choice(available_instance_type) + dispatch_scheduler.add_instance(instance_id, InstanceArgs(instance_type=instance_type)) + if 
instance_type != 'decode': + dispatched_instance_ids.append(instance_id) + else: + assert instance_id not in dispatch_scheduler.available_dispatch_instance_set + + instance_dispatch_info = defaultdict(int) + for _ in range(INSTANCE_NUM * 2): + instance_id = dispatch_scheduler.dispatch() + instance_dispatch_info[instance_id] += 1 + + for instance_id, num_requests in instance_dispatch_info.items(): + assert instance_id in dispatched_instance_ids + assert num_requests >= 2 def test_dispatch_balanced(): num_tests = 100 for _ in range(num_tests): - dispatch_scheduler = init_dispatch_scheduler('balanced') + dispatch_scheduler = DispatchScheduler('balanced') instance_num_requests = {} for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: - if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: - dispatch_scheduler.available_dispatch_instance_set.add(instance_id) - instance_num_requests[instance_id] = random.randint(1, 10) + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = random.randint(1, 10) dispatch_scheduler.instance_num_requests = instance_num_requests min_instance_id = next(key for key, value in sorted(instance_num_requests.items(), key=lambda item: item[1])) instance_id = dispatch_scheduler.dispatch() @@ -69,31 +86,29 @@ def test_dispatch_balanced(): def test_dispatch_load(): num_tests = 100 for _ in range(num_tests): - dispatch_scheduler = init_dispatch_scheduler('load') + dispatch_scheduler = DispatchScheduler('load') instance_num_requests = {} instance_info_dict = {} for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: instance_info = InstanceInfo() instance_info.instance_id = instance_id - instance_info.instance_load_dispatch_scale = random.random() + instance_info.dispatch_load_metric = random.random() instance_info_dict[instance_id] = instance_info - if dispatch_scheduler.num_dispatch_instances <= 0 or (dispatch_scheduler.num_dispatch_instances > 0 - and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances): - dispatch_scheduler.available_dispatch_instance_set.add(instance_id) - instance_num_requests[instance_id] = 0 + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict available_instance_dict = {key: value for key, value in instance_info_dict.items() if key in dispatch_scheduler.available_dispatch_instance_set} min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), - key=lambda item: item[1].instance_load_dispatch_scale)) + key=lambda item: item[1].dispatch_load_metric)) instance_id = dispatch_scheduler.dispatch() assert min_instance_id == instance_id def test_dispatch_queue(): num_tests = 100 for _ in range(num_tests): - dispatch_scheduler = init_dispatch_scheduler('queue') + dispatch_scheduler = DispatchScheduler('queue') instance_num_requests = {} instance_info_dict = {} for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: @@ -101,9 +116,8 @@ def test_dispatch_queue(): instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info - if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: - 
dispatch_scheduler.available_dispatch_instance_set.add(instance_id) - instance_num_requests[instance_id] = 0 + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict available_instance_dict = {key: value for key, value in instance_info_dict.items() @@ -115,8 +129,7 @@ def test_dispatch_queue(): def test_dispatch_rr(): instance_num = 7 - instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler('rr', instance_load_calculator, 3) + dispatch_scheduler = DispatchScheduler("rr") instance_num_requests = {} instance_info_dict = {} @@ -125,23 +138,13 @@ def test_dispatch_rr(): instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info - if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: - dispatch_scheduler.available_dispatch_instance_set.add(instance_id) - instance_num_requests[instance_id] = 0 + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - num_request = 2 * instance_num + 2 + num_request = 2 * instance_num + 1 for idx in range(0, num_request): instance_id = dispatch_scheduler.dispatch() - target_instance_id = idx%dispatch_scheduler.num_dispatch_instances + target_instance_id = idx%instance_num assert instance_id == f'instance_{target_instance_id}' - - for idx in range(instance_num): - if idx < dispatch_scheduler.num_dispatch_instances: - dispatch_scheduler.instance_num_requests[f'instance_{idx}'] = \ - num_request // dispatch_scheduler.num_dispatch_instances + (1 \ - if num_request % dispatch_scheduler.num_dispatch_instances > \ - idx % dispatch_scheduler.num_dispatch_instances else 0) - else: - dispatch_scheduler.instance_num_requests[f'instance_{idx}'] = 0 diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index 9a30b6c9..2c431bf0 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -16,16 +16,16 @@ from llumnix.internal_config import GlobalSchedulerConfig from llumnix.global_scheduler.global_scheduler import GlobalScheduler -from llumnix.instance_info import InstanceInfo +from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator from llumnix.utils import random_uuid +from llumnix.arg_utils import InstanceArgs from .test_manager import get_instance_info_migrate_in, get_instance_info_migrate_out def init_global_scheduler(): - global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf, - 'defrag_constrained', 3.0, True, 'avg_load', - 10, 60, False, 'rayrpc') + global_scheduler_config = GlobalSchedulerConfig(0, 'load', 'defrag_constrained', 3.0, + 'avg_load', 'remaining_steps', 10, 60, False, False) global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler @@ -47,7 +47,7 @@ def test_scale_up_and_scale_down(global_scheduler): initial_instances = 4 instance_infos = init_instance_infos(initial_instances) instance_ids = [instance_info.instance_id for instance_info in instance_infos] - num_instances = 
global_scheduler.scale_up(instance_ids) + num_instances = global_scheduler.scale_up(instance_ids, [InstanceArgs(instance_type="no_constraints")]*len(instance_ids)) assert num_instances == initial_instances instance_infos = init_instance_infos(initial_instances) instance_ids_1 = [instance_info.instance_id for instance_info in instance_infos] @@ -62,7 +62,7 @@ def test_update_instance_infos(global_scheduler): global_scheduler.update_instance_infos(instance_infos) assert len(global_scheduler.instance_info) == 0 instance_ids = [instance_info.instance_id for instance_info in instance_infos] - global_scheduler.scale_up(instance_ids) + global_scheduler.scale_up(instance_ids, [InstanceArgs(instance_type="no_constraints")]*len(instance_ids)) global_scheduler.update_instance_infos(instance_infos) assert len(global_scheduler.instance_info) == initial_instances @@ -70,7 +70,7 @@ def test_dispatch(global_scheduler): initial_instances = 4 instance_infos = init_instance_infos(initial_instances) instance_ids = [instance_info.instance_id for instance_info in instance_infos] - global_scheduler.scale_up(instance_ids) + global_scheduler.scale_up(instance_ids, [InstanceArgs(instance_type="no_constraints")]*len(instance_ids)) global_scheduler.update_instance_infos(instance_infos) instance_id, request_expected_steps = global_scheduler.dispatch() assert instance_id in instance_ids @@ -82,9 +82,16 @@ def test_pair_migration(global_scheduler): instance_ids = [instance_id, instance_id_1] instance_info_migrate_in = get_instance_info_migrate_in(instance_id) instance_info_migrate_out = get_instance_info_migrate_out(instance_id_1) + instance_load_calculator = InstanceLoadCalculator("remaining_steps", "remaining_steps", False) + instance_load_calculator.compute_instance_load(instance_info_migrate_in) + instance_load_calculator.compute_instance_load(instance_info_migrate_out) + print("-------", instance_info_migrate_in.migration_load_metric) + print(instance_info_migrate_out.migration_load_metric) instance_infos = [instance_info_migrate_in, instance_info_migrate_out] - global_scheduler.scale_up(instance_ids) + global_scheduler.scale_up(instance_ids, [InstanceArgs(instance_type="no_constraints")]*len(instance_ids)) global_scheduler.update_instance_infos(instance_infos) + migrate_instace_pairs = global_scheduler.pair_migration("NO_CONSTRAINTS") + assert len(migrate_instace_pairs) > 0 assert migrate_instace_pairs[0][0] == instance_id_1 assert migrate_instace_pairs[0][1] == instance_id diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 9cbbe2df..e41b2a37 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -11,20 +11,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio +import os import time import math import ray import pytest import numpy as np +import torch from vllm import EngineArgs -from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs +from llumnix.launcher import Launcher +from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs, InstanceArgs from llumnix.manager import Manager -from llumnix.instance_info import InstanceInfo +from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator from llumnix.server_info import ServerInfo from llumnix.queue.queue_type import QueueType -from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.instance_info import InstanceType from llumnix.backends.vllm.simulator import BackendSimVLLM from llumnix.backends.backend_interface import BackendType from llumnix.backends.profiling import LatencyMemData @@ -57,8 +61,8 @@ def set_instance_info(self, instance_info): def get_instance_info(self): return self.instance_info - def is_ready(self) -> bool: - return True + def is_ready(self) -> InstanceArgs: + return InstanceArgs() def get_all_request_ids(self): return list(self.request_id_set) @@ -108,24 +112,53 @@ def _get_lantecy_mem(self, *args, **kwargs): def init_manager(): try: - manager_args = ManagerArgs(migration_backend="rayrpc", enable_migration=True) + manager_args = ManagerArgs(enable_migration=True) manager_args.log_instance_info = False - manager = Manager.from_args(manager_args=manager_args) + manager = Manager.from_args( + entrypoints_args=None, + manager_args=manager_args, + instance_args=InstanceArgs(migration_backend="rayrpc"), + engine_args=None, + launch_args=None, + ) except ValueError: manager = ray.get_actor(get_manager_name(), namespace='llumnix') ray.get(manager.is_ready.remote()) return manager -def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue"): - manager_args = ManagerArgs(migration_backend="rayrpc", enable_port_increment=True) +class MockManager(Manager): + async def init_placement_group(self, *args, **kwargs): + return self.launcher.init_placement_group(*args, **kwargs) + + async def init_server_and_instance(self, *args, **kwargs): + return self.launcher.init_server_and_instance(*args, **kwargs) + + async def clear_instance_ray_resources(self, instance_id: str): + return self.launcher.clear_instance_ray_resources(instance_id) + + async def get_cluster_deployment(self): + return self.launcher.get_cluster_deployment() + + async def get_instance_deployment_states(self, instance_id: str): + return self.launcher.get_instance_deployment_states(instance_id) + +def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue", + enable_pd_disagg=False, pd_ratio="1:3"): + manager_args = ManagerArgs(enable_port_increment=True, enable_port_offset_store=True, + enable_pd_disagg=enable_pd_disagg, pd_ratio=pd_ratio) + instance_args = InstanceArgs(migration_backend="rayrpc") entrypoints_args = EntrypointsArgs(host="127.0.0.1", port=8000, request_output_queue_type=request_output_queue_type) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) launch_args = LaunchArgs(launch_mode=launch_mode, backend_type=BackendType.VLLM) - manager = Manager.from_args(manager_args=manager_args, - entrypoints_args=entrypoints_args, - engine_args=engine_args, - launch_args=launch_args) - ray.get(manager.is_ready.remote()) + + # As mock_manager can not be initialized to ray actor, it is initialized as local variable. 
+ # But, some place need to get the manager actor, so create the dummy manager actor here. + dummy_manager_actor = init_manager() + ray.get(dummy_manager_actor.is_ready.remote()) + manager = MockManager(entrypoints_args=entrypoints_args, manager_args=manager_args, + instance_args=instance_args, engine_args=engine_args, + launch_args=launch_args, work_dir=os.getcwd()) + return manager, manager_args, entrypoints_args, engine_args, launch_args def init_instances(initial_instances): @@ -182,7 +215,7 @@ def test_init_llumlet(ray_env, llumlet): def test_init_instances(ray_env, manager): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.VLLM, engine_args)) + _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.VLLM, InstanceArgs(), engine_args)) num_instances = len(instances) manager_args = ManagerArgs() assert num_instances == manager_args.initial_instances @@ -193,7 +226,7 @@ def test_init_instances_sim(ray_env, manager): import llumnix.backends.vllm.simulator llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, engine_args)) + _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, InstanceArgs(), engine_args)) num_instances = len(instances) manager_args = ManagerArgs() assert num_instances == manager_args.initial_instances @@ -201,12 +234,12 @@ def test_init_instances_sim(ray_env, manager): def test_scale_up_and_down(ray_env, manager): initial_instances = 4 instance_ids, instances = init_instances(initial_instances) - num_instances = ray.get(manager.scale_up.remote(instance_ids, instances)) + num_instances = ray.get(manager.scale_up.remote(instance_ids, instances, [InstanceArgs()]*initial_instances)) assert num_instances == initial_instances instance_ids_1, instances_1 = init_instances(initial_instances) num_instances = ray.get(manager.scale_down.remote(instance_ids_1)) assert num_instances == initial_instances - num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1)) + num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1, [InstanceArgs()]*initial_instances)) assert num_instances == initial_instances * 2 num_instances = ray.get(manager.scale_down.remote(instance_ids)) assert num_instances == initial_instances @@ -219,14 +252,14 @@ def test_connect_to_instances(ray_env): ray.get([instance.is_ready.remote() for instance in instances]) manager = init_manager() instance_ids_1, instances_1 = init_instances(initial_instances) - num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1)) + num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1, [InstanceArgs()]*initial_instances)) assert num_instances == initial_instances * 2 num_instances = ray.get(manager.scale_down.remote(instance_ids)) assert num_instances == initial_instances def test_generate_and_abort(ray_env, manager, llumlet): instance_id = ray.get(llumlet.get_instance_id.remote()) - ray.get(manager.scale_up.remote(instance_id, llumlet)) + ray.get(manager.scale_up.remote(instance_id, llumlet, InstanceArgs())) request_id = random_uuid() num_requests = ray.get(llumlet.get_num_requests.remote()) assert num_requests == 0 @@ -263,21 +296,26 @@ def test_get_request_instance(ray_env): assert 
num_requests_1 == 0

 def get_instance_info_migrate_in(instance_id):
-    instance_info = InstanceInfo()
-    instance_info.instance_id = instance_id
-    instance_info.num_available_gpu_blocks = np.inf
-    instance_info.num_running_requests = 1
-    instance_info.num_blocks_first_waiting_request = 0
-    instance_info.instance_type = InstanceType.NO_CONSTRAINTS
+    instance_info = InstanceInfo(
+        instance_id=instance_id,
+        instance_type=InstanceType.NO_CONSTRAINTS,
+        num_available_gpu_blocks=np.inf,
+        num_running_requests=1,
+        num_blocks_first_waiting_request=0,
+        num_killed_requests=0
+    )
+
     return instance_info

 def get_instance_info_migrate_out(instance_id):
-    instance_info = InstanceInfo()
-    instance_info.instance_id = instance_id
-    instance_info.num_available_gpu_blocks = 0
-    instance_info.num_running_requests = 1
-    instance_info.num_blocks_first_waiting_request = np.inf
-    instance_info.instance_type = InstanceType.NO_CONSTRAINTS
+    instance_info = InstanceInfo(
+        instance_id=instance_id,
+        instance_type=InstanceType.NO_CONSTRAINTS,
+        num_available_gpu_blocks=0,
+        num_running_requests=1,
+        num_blocks_first_waiting_request=np.inf,
+        num_killed_requests=np.inf
+    )
     return instance_info

 def test_update_instance_info_loop_and_migrate(ray_env, manager):
@@ -288,58 +326,60 @@ def test_update_instance_info_loop_and_migrate(ray_env, manager):
         for _ in range(2*(i+1)):
             ray.get(instances[i].generate.remote(random_uuid(), None, math.inf, None, None))

-    instance_info = InstanceInfo()
-    instance_info.instance_type = InstanceType.NO_CONSTRAINTS
-
+    instance_load_calculator = InstanceLoadCalculator("remaining_steps", "remaining_steps", True)
     for i in range(num_instances):
-        instance_info.instance_id = instance_ids[i]
-        instance_info.num_available_gpu_blocks = 40 - i * 10
-        instance_info.num_running_requests = i
-        instance_info.num_blocks_first_waiting_request = i
+        instance_info = InstanceInfo(
+            instance_id=instance_ids[i],
+            instance_type=InstanceType.NO_CONSTRAINTS,
+            num_free_gpu_blocks=40-i*10,
+            num_running_requests=i+1,
+            num_blocks_first_waiting_request=i,
+        )
+        instance_load_calculator.compute_instance_load(instance_info)
         ray.get(instances[i].set_instance_info.remote(instance_info))

     for i in range(num_instances):
         num_migrate_out = ray.get(instances[i].get_num_migrate_out.remote())
         assert num_migrate_out == 0

-    ray.get(manager.scale_up.remote(instance_ids, instances))
-    time.sleep(2)
+    ray.get(manager.scale_up.remote(instance_ids, instances, [InstanceArgs()]*len(instance_ids)))
+    time.sleep(3)

     for i in range(num_instances):
         num_migrate_out = ray.get(instances[i].get_num_migrate_out.remote())
         num_migrate_in = ray.get(instances[i].get_num_migrate_in.remote())
-
         if i == 0:
             assert num_migrate_in > 1 and num_migrate_out == 0
         elif i == num_instances - 1:
             assert num_migrate_in == 0 and num_migrate_out > 1
-        else:
-            assert num_migrate_in == 0 and num_migrate_out == 0

-def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_instance_ray_resources(ray_env):
+@pytest.mark.asyncio
+async def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_instance_ray_resources(ray_env):
     manager, _, _, engine_args, _ = init_manager_with_launch_mode(LaunchMode.LOCAL)
     instance_id = random_uuid()
-    pg = ray.get(manager._init_placement_group.remote(get_placement_group_name(instance_id),
-                                                      engine_args, BackendType.VLLM, init_server=True))
+    pg = await manager.init_placement_group(get_placement_group_name(instance_id),
+                                            engine_args, BackendType.VLLM, init_server=True)
     pg = ray.util.get_placement_group(get_placement_group_name(instance_id))
     ray.get(pg.ready())
-    ray.get(manager._init_server_and_instance.remote(instance_id, pg))
+    await manager.init_server_and_instance(instance_id, EntrypointsArgs(), InstanceArgs(), engine_args, BackendType.VLLM, pg)
+
     # wait for scale up
-    time.sleep(5.0)
+    await asyncio.sleep(5.0)
     server = ray.get_actor(get_server_name(instance_id), namespace="llumnix")
     ray.get(server.is_ready.remote())
     instance = ray.get_actor(get_instance_name(instance_id), namespace="llumnix")
     ray.get(instance.is_ready.remote())
-    num_instances = ray.get(manager.scale_up.remote(instance_id, instance))
+    num_instances = manager.scale_up(instance_id, instance, InstanceArgs())
     assert num_instances == 1
-    pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
+    pg_created, server_alive, instance_alive = await manager.get_instance_deployment_states(instance_id)
     assert pg_created and server_alive and instance_alive

     # test clear_instance_ray_resources
-    ray.get(manager._clear_instance_ray_states.remote(instance_id))
+    await manager.clear_instance_ray_resources(instance_id)
     # wait for remove and kill
-    time.sleep(1.0)
+    await asyncio.sleep(5.0)
+
     pg_exists = is_placement_group_exists(get_placement_group_name(instance_id))
     assert not pg_exists
     server_exists = is_actor_exists(get_server_name(instance_id))
@@ -347,37 +387,42 @@ def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_i
     instance_exists = is_actor_exists(get_instance_name(instance_id))
     assert not instance_exists

-    pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
+    pg_created, server_alive, instance_alive = await manager.get_instance_deployment_states(instance_id)
     assert not pg_created and not server_alive and not instance_alive

+@pytest.mark.asyncio
 @pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
-def test_auto_scale_up_loop_and_get_cluster_deployment(ray_env, request_output_queue_type):
+async def test_auto_scale_up_loop_and_get_cluster_deployment(ray_env, request_output_queue_type):
     manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
-    time.sleep(30.0)
-    num_instances = ray.get(manager.scale_up.remote([], []))
+    await asyncio.sleep(30.0)
+
+    num_instances = manager.scale_up([], [], [])
     assert num_instances == 4
-    curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
     assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

     actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
     instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
     assert len(instance_ids) == 4

-    ray.get(manager._clear_instance_ray_states.remote(instance_ids[0]))
-    ray.get(manager._clear_instance_ray_states.remote(instance_ids[1]))
-    time.sleep(30.0)
-    num_instances = ray.get(manager.scale_up.remote([], []))
+    await manager.clear_instance_ray_resources(instance_ids[0])
+    await manager.clear_instance_ray_resources(instance_ids[1])
+    await asyncio.sleep(30.0)
+
+    num_instances = manager.scale_up([], [], [])
     assert num_instances == 4
-    curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
     assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

+@pytest.mark.asyncio
 @pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
-def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_output_queue_type):
+async def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_output_queue_type):
     manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
-    time.sleep(30.0)
-    num_instances = ray.get(manager.scale_up.remote([], []))
+    await asyncio.sleep(30.0)
+
+    num_instances = manager.scale_up([], [], [])
     assert num_instances == 4
-    curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
     assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

     actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
@@ -388,8 +433,108 @@ def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_ou
     kill_server(instance_ids[1])
     kill_instance(instance_ids[2])
     # Wait for check deployment states, scale down instance and auto scale up.
-    time.sleep(90.0)
-    num_instances = ray.get(manager.scale_up.remote([], []))
+    await asyncio.sleep(90.0)
+
+    num_instances = manager.scale_up([], [], [])
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+def test_pd_disagg_gloal_launch_instance_type():
+    launcher = Launcher(None, True, False, True, False, [1, 2])
+
+    assert launcher._get_next_instance_type(0, 0, [1, 2]) == InstanceType.PREFILL
+    launcher.inflight_num_prefill += 1
+
+    assert launcher._get_next_instance_type(0, 0, [1, 2]) == InstanceType.DECODE
+    launcher.inflight_num_decode += 1
+
+    launcher.inflight_num_prefill = 0
+    launcher.inflight_num_decode = 0
+    assert launcher._get_next_instance_type(1, 1, [1, 2]) == InstanceType.DECODE
+
+    assert launcher._get_next_instance_type(1, 2, [1, 2]) == InstanceType.PREFILL
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
+async def test_pd_disagg_gloal_launch_deployment_and_auto_scale_up_loop(ray_env, request_output_queue_type):
+    manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type,
+                                                        enable_pd_disagg=True, pd_ratio="1:1")
+    await asyncio.sleep(30.0)
+
+    num_instances = manager.scale_up([], [], [])
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+    num_prefill_instances = 0
+    num_decode_instances = 0
+    prefill_instance_ids = []
+    decode_instance_ids = []
+    for _, instance_handle in curr_instances.items():
+        instance_type = ray.get(instance_handle.is_ready.remote()).instance_type
+        if instance_type == InstanceType.PREFILL:
+            num_prefill_instances += 1
+            prefill_instance_ids.append(ray.get(instance_handle.get_instance_info.remote()).instance_id)
+        elif instance_type == InstanceType.DECODE:
+            num_decode_instances += 1
+            decode_instance_ids.append(ray.get(instance_handle.get_instance_info.remote()).instance_id)
+
+    assert torch.cuda.device_count() == 4
+    assert num_prefill_instances == 2 and num_decode_instances == 2
+    assert set(prefill_instance_ids).union(set(decode_instance_ids)) == set(curr_instances.keys())
+
+    kill_instance(prefill_instance_ids[0])
+    await asyncio.sleep(10.0)
+
+    kill_instance(prefill_instance_ids[1])
+    await asyncio.sleep(10.0)
+
+    kill_instance(decode_instance_ids[1])
+    await asyncio.sleep(90.0)
+    alive_decode_instance_id = decode_instance_ids[0]
+
+    num_instances = manager.scale_up([], [], [])
     assert num_instances == 4
-    curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
+    curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment()
     assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+    num_prefill_instances = 0
+    num_decode_instances = 0
+    decode_instance_ids = []
+    for instance_id, instance_handle in curr_instances.items():
+        instance_type = ray.get(instance_handle.is_ready.remote()).instance_type
+        if instance_type == InstanceType.PREFILL:
+            num_prefill_instances += 1
+        elif instance_type == InstanceType.DECODE:
+            num_decode_instances += 1
+            decode_instance_ids.append(instance_id)
+
+    assert num_prefill_instances == 2 and num_decode_instances == 2
+    assert alive_decode_instance_id in decode_instance_ids
+
+@pytest.mark.asyncio
+async def test_pd_disagg_deployment_states():
+    manager_args = ManagerArgs(enable_migration=True, enable_pd_disagg=True, pd_ratio="1:2")
+    engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
+    manager = Manager(entrypoints_args=EntrypointsArgs(), manager_args=manager_args,
+                      instance_args=InstanceArgs(migration_backend="rayrpc"),
+                      engine_args=engine_args, launch_args=LaunchArgs(LaunchMode.LOCAL, BackendType.VLLM),
+                      work_dir=os.getcwd())
+    assert not manager._inner_check_pd_deployment()
+
+    prefill_instance_ids = [random_uuid() for _ in range(3)]
+    decode_instance_ids = [random_uuid() for _ in range(3)]
+
+    manager.scale_up(prefill_instance_ids, [None]*len(prefill_instance_ids),
+                     [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids))
+    assert manager._inner_check_pd_deployment() in prefill_instance_ids
+
+    manager.scale_down(prefill_instance_ids)
+    manager.scale_up(decode_instance_ids, [None]*len(decode_instance_ids),
+                     [InstanceArgs(instance_type="decode")]*len(decode_instance_ids))
+    assert manager._inner_check_pd_deployment() in decode_instance_ids
+
+    manager.scale_up(prefill_instance_ids, [None]*len(prefill_instance_ids),
+                     [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids))
+    assert not manager._inner_check_pd_deployment()
diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py
index 3ed1655a..669ce074 100644
--- a/tests/unit_test/global_scheduler/test_migration_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py
@@ -14,31 +14,26 @@
 import math
 import random
 import pytest
-import numpy as np

-from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
+from llumnix.instance_info import InstanceInfo
 from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
-from llumnix.global_scheduler.scaling_scheduler import InstanceType
+from llumnix.instance_info import InstanceType
 from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig
 from llumnix.global_scheduler.migration_policy import PairMigrationConstraints
+from llumnix.arg_utils import InstanceArgs

-MIGRATE_OUT_LOAD_THRESHOLD = 3.0
+MIGRATE_OUT_LOAD_THRESHOLD = -3.0
 INSTANCE_NUM = 16

 def init_migration_scheduler(policy='balanced'):
-    instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
-    migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator, 'rayrpc')
+    migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, False)
     return migration_scheduler

-@pytest.fixture
-def migration_scheduler():
-    migration_scheduler = init_migration_scheduler()
-    yield migration_scheduler
-
-def test_add_instance_and_remove_instance(migration_scheduler):
-    migration_scheduler.add_instance('instance_1')
+def test_add_instance_and_remove_instance():
+    migration_scheduler = init_migration_scheduler('balanced')
+    migration_scheduler.add_instance('instance_1', InstanceArgs(instance_type="no_constraints"))
     assert migration_scheduler.num_instances == 1
-    migration_scheduler.add_instance('instance_2')
+    migration_scheduler.add_instance('instance_2', InstanceArgs(instance_type="no_constraints"))
     assert migration_scheduler.num_instances == 2
     migration_scheduler.remove_instance('instance_1')
     assert migration_scheduler.num_instances == 1
@@ -58,7 +53,7 @@ def test_migration_filter(pair_migration_type):
     for instance_id in range(1, INSTANCE_NUM + 1):
         instance_info = InstanceInfo()
         instance_info.instance_id = instance_id
-        instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1)
+        instance_info.migration_load_metric = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1)
         instance_info.num_killed_requests = random.randint(0, 1)

         if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
@@ -82,7 +77,7 @@ def test_migration_filter(pair_migration_type):
     for instance in src_instance_infos:
         if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING:
             assert instance.num_killed_requests > 0 \
-                or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD
+                or instance.migration_load_metric > MIGRATE_OUT_LOAD_THRESHOLD
         if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
             assert instance.instance_type == InstanceType.NO_CONSTRAINTS
         elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
@@ -92,7 +87,7 @@ def test_migration_filter(pair_migration_type):
     for instance in dst_instance_infos:
         if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING:
-            assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD
+            assert instance.num_killed_requests == 0 and instance.migration_load_metric < MIGRATE_OUT_LOAD_THRESHOLD
         if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
             assert instance.instance_type == InstanceType.NO_CONSTRAINTS
         elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
@@ -104,6 +99,7 @@ def test_migration_filter(pair_migration_type):
 @pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained'])
 def test_pair_migration(policy):
     num_tests = 1000
+    exist_migration = False
     for _ in range(num_tests):
         migration_scheduler = init_migration_scheduler(policy)
@@ -111,18 +107,22 @@ def test_pair_migration(policy):
         for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
-            instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1)
+            instance_info.migration_load_metric = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-6, 3)
+            instance_info.migration_load_metric_after_migrate_out = instance_info.migration_load_metric - random.uniform(0, 1)
+            instance_info.migration_load_metric_after_migrate_in = instance_info.migration_load_metric + random.uniform(0, 1)
             instance_info.num_killed_requests = random.randint(0, 1)
-            instance_info.num_blocks_last_running_request = random.randint(0, 1) * np.inf
             instance_info.instance_type = InstanceType.NO_CONSTRAINTS
             instance_info_dict[instance_id] = instance_info
         migration_scheduler.instance_info = instance_info_dict
         migrate_instance_pairs = migration_scheduler.pair_migration(PairMigrationConstraints.NO_CONSTRAINTS)
+        exist_migration = exist_migration or len(migrate_instance_pairs) > 0
         for migrate_out_instance, migrate_in_instance in migrate_instance_pairs:
             assert migrate_out_instance != migrate_in_instance
             if policy == 'balanced':
                 assert instance_info_dict[migrate_out_instance].num_blocks_last_running_request == 0
             if instance_info_dict[migrate_out_instance].num_killed_requests == 0:
-                assert instance_info_dict[migrate_out_instance].instance_load_migrate > instance_info_dict[migrate_in_instance].instance_load_migrate
+                assert instance_info_dict[migrate_out_instance].migration_load_metric > instance_info_dict[migrate_in_instance].migration_load_metric
+
+    assert exist_migration
diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py
index a15dc52f..92a85a1f 100644
--- a/tests/unit_test/llumlet/test_engine_step_exception.py
+++ b/tests/unit_test/llumlet/test_engine_step_exception.py
@@ -19,9 +19,9 @@
 from vllm.engine.arg_utils import EngineArgs

+from llumnix.arg_utils import InstanceArgs
 from llumnix.backends.backend_interface import BackendType
 from llumnix.llumlet.llumlet import Llumlet
-from llumnix.internal_config import MigrationConfig
 from llumnix.queue.queue_type import QueueType
 from llumnix.utils import initialize_placement_group, get_placement_group_name

@@ -54,10 +54,9 @@ async def raise_error_step():
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
 def test_engine_step_exception(ray_env):
     engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True)
-    migration_config = MigrationConfig("SR", "rayrpc", 16, 1, 4, 5, 20)
+
     # wait previous test to release the GPU memory
     time.sleep(5.0)
-
     device_count = torch.cuda.device_count()
     origin_free_memory_list = []
     for device_id in range(device_count):
@@ -67,9 +66,9 @@ def test_engine_step_exception(ray_env):
     actor_name = "instance_0"
     llumlet = MockLlumlet.options(name=actor_name, namespace='llumnix').remote(
         instance_id="0",
+        instance_args=InstanceArgs(),
         request_output_queue_type=QueueType.RAYQUEUE,
         backend_type=BackendType.VLLM,
-        migration_config=migration_config,
         engine_args=engine_args,
     )
     ray.get(llumlet.is_ready.remote())
diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py
index d73389e5..d9f7c93f 100644
--- a/tests/unit_test/llumlet/test_migration_coordinator.py
+++ b/tests/unit_test/llumlet/test_migration_coordinator.py
@@ -38,7 +38,7 @@ async def test_migrate_out_onestage(ray_env):
     migrate_out_request = MagicMock()

     # Create an instance of MigrationCoordinator
-    coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3)
+    coordinator = MigrationCoordinator(backend_engine, migration_last_stage_max_blocks=1, migration_max_stages=3)

     # Mock method return values and test data
     src_blocks = [1, 2, 3]
@@ -94,8 +94,8 @@ async def test_migrate_out_running_request(_, ray_env):
     migrate_out_request = MockRequest("1", 1, math.inf)

     # Create an instance of MigrationCoordinator
-    max_stages = 3
-    coordinator = MigrationCoordinator(backend_engine, 1, max_stages)
+    migration_max_stages = 3
+    coordinator = MigrationCoordinator(backend_engine, 1, migration_max_stages)
     migrate_in_ray_actor = MagicMock()
     migrate_in_ray_actor.execute_engine_method = MagicMock()
     migrate_in_ray_actor.execute_engine_method.remote = MagicMock()
@@ -106,13 +106,13 @@ async def test_migrate_out_running_request(_, ray_env):
     assert coordinator._migrate_out_onestage.call_count == 1
     assert status == MigrationStatus.FINISHED

-    max_stages = 3
+    migration_max_stages = 3
     coordinator._migrate_out_onestage.side_effect = [MigrationStatus.RUNNING,
                                                      MigrationStatus.RUNNING,
                                                      MigrationStatus.RUNNING,
                                                      MigrationStatus.RUNNING]
     status = await coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request)
-    assert coordinator._migrate_out_onestage.call_count == max_stages + 1
+    assert coordinator._migrate_out_onestage.call_count == migration_max_stages + 1
     assert status == MigrationStatus.ABORTED_SRC

 @pytest.mark.asyncio
@@ -123,7 +123,7 @@ async def test_migrate_out_waiting_request():
     migrate_out_request = MagicMock()

     # Create an instance of MigrationCoordinator
-    coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3)
+    coordinator = MigrationCoordinator(backend_engine, migration_last_stage_max_blocks=1, migration_max_stages=3)

     # Test FINISHED
     migrate_out_request.prefill_num_blocks = 3