From d9a252bc8e8a2741d8a2997032a94208fb8f29d9 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 20 Jun 2024 22:12:35 -0700
Subject: [PATCH] [Core][Distributed] add shm broadcast (#5399)

Co-authored-by: Cody Yu
---
 .buildkite/test-pipeline.yaml                 |   4 +-
 tests/distributed/test_shm_broadcast.py       |  82 ++++++
 .../device_communicators/shm_broadcast.py     | 259 ++++++++++++++++++
 vllm/distributed/parallel_state.py            |  44 ++-
 vllm/envs.py                                  |   5 +
 5 files changed, 384 insertions(+), 10 deletions(-)
 create mode 100644 tests/distributed/test_shm_broadcast.py
 create mode 100644 vllm/distributed/device_communicators/shm_broadcast.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 5e92ba3c24f55..c337a81d4a0d2 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -28,9 +28,11 @@ steps:
 
 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
-  command: pytest -v -s distributed/test_comm_ops.py
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
+  commands:
+  - pytest -v -s distributed/test_comm_ops.py
+  - pytest -v -s distributed/test_shm_broadcast.py
 
 - label: Distributed Tests (2 GPUs)
   mirror_hardwares: [amd]
diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py
new file mode 100644
index 0000000000000..d92900ffce00b
--- /dev/null
+++ b/tests/distributed/test_shm_broadcast.py
@@ -0,0 +1,82 @@
+import multiprocessing
+import random
+import time
+
+import torch.distributed as dist
+
+from vllm.distributed.device_communicators.shm_broadcast import (
+    ShmRingBuffer, ShmRingBufferIO)
+from vllm.utils import update_environment_variables
+
+
+def distributed_run(fn, world_size):
+    number_of_processes = world_size
+    processes = []
+    for i in range(number_of_processes):
+        env = {}
+        env['RANK'] = str(i)
+        env['LOCAL_RANK'] = str(i)
+        env['WORLD_SIZE'] = str(number_of_processes)
+        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
+        env['MASTER_ADDR'] = 'localhost'
+        env['MASTER_PORT'] = '12345'
+        p = multiprocessing.Process(target=fn, args=(env, ))
+        processes.append(p)
+        p.start()
+
+    for p in processes:
+        p.join()
+
+    for p in processes:
+        assert p.exitcode == 0
+
+
+def worker_fn_wrapper(fn):
+    # `multiprocessing.Process` cannot accept environment variables directly
+    # so we need to pass the environment variables as arguments
+    # and update the environment variables in the function
+    def wrapped_fn(env):
+        update_environment_variables(env)
+        dist.init_process_group(backend="gloo")
+        fn()
+
+    return wrapped_fn
+
+
+@worker_fn_wrapper
+def worker_fn():
+    writer_rank = 2
+    broadcaster = ShmRingBufferIO.create_from_process_group(
+        dist.group.WORLD, 1024, 2, writer_rank)
+    if dist.get_rank() == writer_rank:
+        time.sleep(random.random())
+        broadcaster.broadcast_object(0)
+        time.sleep(random.random())
+        broadcaster.broadcast_object({})
+        time.sleep(random.random())
+        broadcaster.broadcast_object([])
+    else:
+        time.sleep(random.random())
+        a = broadcaster.broadcast_object(None)
+        time.sleep(random.random())
+        b = broadcaster.broadcast_object(None)
+        time.sleep(random.random())
+        c = broadcaster.broadcast_object(None)
+        assert a == 0
+        assert b == {}
+        assert c == []
+    dist.barrier()
+
+
+def test_shm_broadcast():
+    distributed_run(worker_fn, 4)
+
+
+def test_single_process():
+    buffer = ShmRingBuffer(1, 1024, 4)
+    reader = ShmRingBufferIO(buffer, reader_rank=0)
+    writer = ShmRingBufferIO(buffer, reader_rank=-1)
+    writer.enqueue([0])
+    writer.enqueue([1])
+    assert reader.dequeue() == [0]
+    assert reader.dequeue() == [1]
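The round trip exercised above can also be reproduced with two plain OS processes and no torch.distributed process group, because `ShmRingBuffer` pickles down to the name of its shared memory segment (see the class docstring in the new module below). A minimal sketch follows, assuming the patched vllm package is importable; the helper name `reader_process` is illustrative and not part of the patch.

    # Sketch: one writer process and one reader process sharing the ring
    # buffer by pickling it across a "spawn" boundary (assumed setup).
    import multiprocessing

    from vllm.distributed.device_communicators.shm_broadcast import (
        ShmRingBuffer, ShmRingBufferIO)


    def reader_process(buffer: ShmRingBuffer):
        # the unpickled buffer attaches to the existing segment by name
        reader = ShmRingBufferIO(buffer, reader_rank=0)
        assert reader.dequeue() == {"hello": "world"}


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        buffer = ShmRingBuffer(n_reader=1, max_chunk_bytes=1024, max_chunks=4)
        p = ctx.Process(target=reader_process, args=(buffer, ))
        p.start()
        ShmRingBufferIO(buffer, reader_rank=-1).enqueue({"hello": "world"})
        p.join()
        assert p.exitcode == 0

With the spawn start method the child receives the buffer through pickling, so only the parent (the creator) unlinks the shared memory segment when its copy is garbage collected.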
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
new file mode 100644
index 0000000000000..119befcf64052
--- /dev/null
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -0,0 +1,259 @@
+import pickle
+import time
+from contextlib import contextmanager
+from multiprocessing import shared_memory
+from typing import Optional
+from unittest.mock import patch
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+import vllm.envs as envs
+from vllm.logger import init_logger
+
+VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
+
+logger = init_logger(__name__)
+
+
+class ShmRingBuffer:
+
+    def __init__(self,
+                 n_reader: int,
+                 max_chunk_bytes: int,
+                 max_chunks: int,
+                 name: Optional[str] = None):
+        """
+        A shared memory ring buffer implementation for broadcast communication.
+        Essentially, it is a queue where only one process will `enqueue` and
+        multiple processes will `dequeue`. The max size of each item and the
+        max number of items that can be stored in the buffer are known in
+        advance, so we do not need to synchronize access to the buffer.
+
+        Buffer memory layout:
+                  data                                 metadata
+                    |                                      |
+                    | (current_idx)                        | (current_idx)
+                    v                                      v
+        +-------------------------------+----------------------------------------+
+        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata  |
+        +-------------------------------+----------------------------------------+
+        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes       |
+        +-------------------------------+----------------------------------------+
+
+        metadata memory layout: each byte is a flag, the first byte is the
+        written flag, and the rest are reader flags. The flags are set to 0
+        by default.
+        +--------------+--------------+--------------+-----+--------------+
+        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+        +--------------+--------------+--------------+-----+--------------+
+
+        During creation, `name` is None and the buffer is created. We can pass
+        the created object to other processes by pickling it. The other
+        processes will get the name of the shared memory and open it, so that
+        they can access the same shared memory buffer.
+        """# noqa
+        self.n_reader = n_reader
+        self.metadata_size = 1 + n_reader
+        self.max_chunk_bytes = max_chunk_bytes
+        self.max_chunks = max_chunks
+        self.total_bytes_of_buffer = (self.max_chunk_bytes +
+                                      self.metadata_size) * self.max_chunks
+        self.data_offset = 0
+        self.metadata_offset = self.max_chunk_bytes * self.max_chunks
+
+        if name is None:
+            # we are creating a buffer
+            self.is_creator = True
+            self.shared_memory = shared_memory.SharedMemory(
+                create=True, size=self.total_bytes_of_buffer)
+            # initialize the metadata section to 0
+            with memoryview(self.shared_memory.buf[self.metadata_offset:]
+                            ) as metadata_buffer:
+                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
+        else:
+            # we are opening an existing buffer
+            self.is_creator = False
+            # fix for https://stackoverflow.com/q/62748654/9191338 :
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch("multiprocessing.resource_tracker.register",
+                       lambda *args, **kwargs: None):
+                self.shared_memory = shared_memory.SharedMemory(name=name)
+            assert self.shared_memory.size == self.total_bytes_of_buffer
+            with memoryview(self.shared_memory.buf[self.metadata_offset:]
+                            ) as metadata_buffer:
+                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
+                assert torch.all(tensor == 0)
+
+    def __reduce__(self):
+        return (
+            self.__class__,
+            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+             self.shared_memory.name),
+        )
+
+    def __del__(self):
+        self.shared_memory.close()
+        if self.is_creator:
+            self.shared_memory.unlink()
+
+    @contextmanager
+    def get_data(self, current_idx: int):
+        start = self.data_offset + current_idx * self.max_chunk_bytes
+        end = start + self.max_chunk_bytes
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+    @contextmanager
+    def get_metadata(self, current_idx: int):
+        start = self.metadata_offset + current_idx * self.metadata_size
+        end = start + self.metadata_size
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+
+class ShmRingBufferIO:
+
+    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
+        self.buffer = buffer
+        self.reader_rank = reader_rank
+        self._is_writer = self.reader_rank == -1
+        self._is_reader = not self._is_writer
+        if self._is_reader:
+            assert 0 <= self.reader_rank < buffer.n_reader, \
+                (f"Invalid reader rank {self.reader_rank} for buffer"
+                 f" created with {buffer.n_reader} readers")
+        self.current_idx = 0
+
+    @contextmanager
+    def acquire_write(self):
+        assert self._is_writer, "Only writers can acquire write"
+        start_index = self.current_idx
+        start_time = time.time()
+        n_warning = 1
+        while True:
+            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+                read_count = sum(metadata_buffer[1:])
+                written_flag = metadata_buffer[0]
+                if written_flag and read_count != self.buffer.n_reader:
+                    # this block is written and not read by all readers
+                    # try to write to the next block
+                    self.current_idx = (self.current_idx +
+                                        1) % self.buffer.max_chunks
+                    if self.current_idx == start_index:
+                        # no empty block found
+                        if time.time(
+                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                            logger.warning(
+                                "No available block found in %s seconds.",
+                                VLLM_RINGBUFFER_WARNING_INTERVAL)
", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + # wait for a while (0.1 us) + time.sleep(1e-7) + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # mark the block as written + metadata_buffer[0] = 1 + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + break + + @contextmanager + def acquire_read(self): + assert self._is_reader, "Only readers can acquire read" + start_index = self.current_idx + start_time = time.time() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + # try to read the next block + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + if self.current_idx == start_index: + # no block found + if time.time( + ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + # wait for a while (0.1 us) + time.sleep(1e-7) + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.reader_rank + 1] = 1 + break + + def enqueue(self, obj): + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if len(serialized_obj) > self.buffer.max_chunk_bytes: + raise RuntimeError( + f"{len(serialized_obj)=} larger than the allowed value " + f"{self.buffer.max_chunk_bytes}," + "Please increase the max_chunk_bytes parameter.") + with self.acquire_write() as buf: + buf[:len(serialized_obj)] = serialized_obj + + def dequeue(self): + assert self._is_reader, "Only readers can dequeue" + with self.acquire_read() as buf: + # no need to know the size of serialized object + # pickle format itself contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf) + return obj + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + def create_from_process_group(pg: ProcessGroup, + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "ShmRingBufferIO": + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + ranks_inside_group = list(range(group_world_size)) + global_ranks = dist.get_process_group_ranks(pg) + n_reader = group_world_size - 1 + buffer: ShmRingBuffer + if group_rank == writer_rank: + buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks) + dist.broadcast_object_list([buffer], src=global_ranks[writer_rank]) + dist.barrier(pg) + return ShmRingBufferIO(buffer, -1) + else: + recv = [None] + dist.broadcast_object_list(recv, src=global_ranks[writer_rank]) + dist.barrier(pg) + buffer = recv[0] # type: ignore + rest_ranks = [r for r in ranks_inside_group if r != writer_rank] + return ShmRingBufferIO(buffer, rest_ranks.index(group_rank)) diff --git a/vllm/distributed/parallel_state.py 
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 02b0dcbcb6b24..5188fadbb92a5 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -98,6 +98,7 @@ class GroupCoordinator:
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
+    shm_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
         self,
@@ -162,6 +163,13 @@ def __init__(
         else:
             self.ca_comm = None
 
+        from vllm.distributed.device_communicators.shm_broadcast import (
+            ShmRingBufferIO)
+        self.shm_broadcaster: Optional[ShmRingBufferIO] = None
+        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
+            self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
+                self.cpu_group, 1 << 20, 6)
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -324,6 +332,30 @@ def broadcast(self, input_: torch.Tensor, src: int = 0):
                                     group=self.device_group)
         return input_
 
+    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
+        """Broadcast the input object.
+        NOTE: `src` is the rank of the source within this group, not its global rank.
+        """
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return obj
+        if self.shm_broadcaster is not None:
+            assert src == 0, "Shared memory broadcaster only supports src=0"
+            return self.shm_broadcaster.broadcast_object(obj)
+        if self.rank_in_group == src:
+            torch.distributed.broadcast_object_list([obj],
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return obj
+        else:
+            recv = [None]
+            torch.distributed.broadcast_object_list(recv,
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return recv[0]
+
     def broadcast_object_list(self,
                               obj_list: List[Any],
                               src: int = 0,
@@ -371,9 +403,7 @@ def broadcast_tensor_dict(
             # `metadata_list` lives in CPU memory.
             # `broadcast_object_list` has serialization & deserialization,
             # all happening on CPU. Therefore, we can use the CPU group.
-            torch.distributed.broadcast_object_list([metadata_list],
-                                                    src=src,
-                                                    group=metadata_group)
+            self.broadcast_object(metadata_list, src=src)
             async_handles = []
             for tensor in tensor_list:
                 if tensor.numel() == 0:
@@ -396,14 +426,10 @@ def broadcast_tensor_dict(
                     async_handle.wait()
 
         else:
-            recv_metadata_list = [None]
-            torch.distributed.broadcast_object_list(recv_metadata_list,
-                                                    src=src,
-                                                    group=metadata_group)
-            assert recv_metadata_list[0] is not None
+            metadata_list = self.broadcast_object(None, src=src)
             tensor_dict = {}
             async_handles = []
-            for key, value in recv_metadata_list[0]:
+            for key, value in metadata_list:
                 if isinstance(value, TensorMetadata):
                     tensor = torch.empty(value.size,
                                          dtype=value.dtype,
diff --git a/vllm/envs.py b/vllm/envs.py
index ae2fcd0826fb1..49277e2d3519f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -5,6 +5,7 @@
     VLLM_HOST_IP: str = ""
     VLLM_PORT: Optional[int] = None
     VLLM_USE_MODELSCOPE: bool = False
+    VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_INSTANCE_ID: Optional[str] = None
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
@@ -114,6 +115,10 @@
     "VLLM_INSTANCE_ID":
     lambda: os.environ.get("VLLM_INSTANCE_ID", None),
 
+    # Interval in seconds to log a warning message when the ring buffer is full
+    "VLLM_RINGBUFFER_WARNING_INTERVAL":
+    lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
+
     # path to cudatoolkit home directory, under which should be bin, include,
     # and lib directories.
     "CUDA_HOME":