From d462ca9bbc55531bbe785203cb076e7797250f2a Mon Sep 17 00:00:00 2001 From: Zhihao Shan <60905719+zhihaoshan-google@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:56:16 -0800 Subject: [PATCH] add seperate prefill detokenization thread (#152) Co-authored-by: Zhihao Shan --- jetstream/core/orchestrator.py | 68 ++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 0fd64c5e..15fc36dd 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -38,7 +38,7 @@ of the generation loop at the relevant slot. - Regardless, it performs a step. - It takes the sampled tokens, and places them on a 'detokenizing_queue'. -7. Within the detokenizing thread: +7. Within the detokenizing thread (Prefill and Generate separately): - Tokens are detokenized for every 'slot' in a given set of sampled tokens. - When an end condition is met, the 'slot' integer is returned to the respective generation queue. @@ -210,7 +210,8 @@ class Driver: # Stage 4 # This can be a list because we can pass it as an arg to generate and # detokenize threads. It is a list of tokens to be detokenized. - _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] _generate_slots: list[queue.Queue[int]] = [] _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = [] @@ -270,11 +271,11 @@ def __init__( # one of the generate backlogs. # Interleaved Mode: Max size is 1 to increase the HBM utilization # during generate. - # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued - # while 1 transfer is enqueued while 1 is being transferred. + # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to + # be enqueued or enqueued while 1 is being transferred. # TODO: Make queue size configurable. self._transfer_backlogs = [ - queue.Queue(1 if self._interleaved_mode else 4) + queue.Queue(1 if self._interleaved_mode else 16) for i in range(len(self._prefill_engines)) ] if self._metrics_collector: @@ -302,10 +303,11 @@ def __init__( functools.partial(float, backlog.qsize()) ) # Stage 4 - # After generation, ActiveRequests are placed on the detokenization backlog - # for tokens to be sent into each ActiveRequest's return channel. - # We have one of these per generate engine to simplify the logic keeping - # track of which generation engine to replace slots on. + # After prefill and generation, ActiveRequests are placed on the + # detokenization backlog for tokens to be sent into each ActiveRequest's + # return channel. + # We have one of these per prefill / generate engine to simplify + # the logic keeping track of which generation engine to replace slots on. # This is a queue of either - tuple[int, ActiveRequest] which represents our # active requests, or tuple[int, sample_tokens]. We combine these into one # queue because it allows us to be somewhat clever with how we do @@ -320,7 +322,16 @@ def __init__( # the possibility of race conditions where a slot is made live before the # tokens are ready and it receives tokens from a different sequence, # or tokens detokenized before the relevant slot is live. - self._detokenize_backlogs = [ + + self._prefill_detokenize_backlogs = [ + # No need to set maxsize, as transfer queue can + # provide the backpressure to the prefill workload + # (to avoid the overwhelming prefill). + queue.Queue() + for _ in self._prefill_engines + ] + + self._generate_detokenize_backlogs = [ # We don't let detokenization accumulate more than 8 steps to avoid # synchronization issues. queue.Queue(8) @@ -376,13 +387,25 @@ def __init__( ) for idx in range(len(self._generate_engines)) ] - self.detokenize_threads = [ + self.prefill_detokenize_threads = [ JetThread( target=functools.partial( self._detokenize_thread, - idx, + is_prefill=True, + idx=idx, + ), + name=f"prefill_detokenize-{idx}", + ) + for idx in range(len(self._generate_engines)) + ] + self.generate_detokenize_threads = [ + JetThread( + target=functools.partial( + self._detokenize_thread, + is_prefill=False, + idx=idx, ), - name=f"detokenize-{idx}", + name=f"generate_detokenize-{idx}", ) for idx in range(len(self._generate_engines)) ] @@ -391,7 +414,8 @@ def __init__( self._prefill_threads, self._transfer_threads, self._generate_threads, - self.detokenize_threads, + self.prefill_detokenize_threads, + self.generate_detokenize_threads, ) ) self.live = True @@ -410,7 +434,8 @@ def stop(self): [self._prefill_backlog], self._transfer_backlogs, self._generate_backlogs.values(), - self._detokenize_backlogs, + self._prefill_detokenize_backlogs, + self._generate_detokenize_backlogs, ) ) @@ -523,7 +548,7 @@ def _prefill_thread(self, idx: int): # put first token to detokenize queue request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_) - my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] request.metadata.transfer_enqueue_time = time.perf_counter() my_detokenize_backlog.put( (first_token, request, request.metadata.prefill_dequeue_time), @@ -619,7 +644,7 @@ def _generate_thread(self, idx: int): generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] my_generate_backlog = self._generate_backlogs[idx] - my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] # Keep track of what step tokens were generated at. generate_timestep = 0 @@ -749,12 +774,17 @@ def _generate_thread(self, idx: int): ) time_of_last_generate = time.time() - def _detokenize_thread(self, idx: int): + def _detokenize_thread(self, is_prefill: bool, idx: int): """Detokenize sampled tokens and returns them to the user.""" # One of these per generate engine. # For all filled my_slots, pop the sampled token onto the relevant # requests return channel. If it done, place it back onto free slots. - my_detokenize_backlog = self._detokenize_backlogs[idx] + + if is_prefill: + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] + else: + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] + my_generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx]