add separate prefill detokenization thread (#152)
Co-authored-by: Zhihao Shan <[email protected]>
zhihaoshan-google and Zhihao Shan authored Nov 19, 2024
1 parent 15e3963 commit d462ca9
Showing 1 changed file with 49 additions and 19 deletions.
68 changes: 49 additions & 19 deletions jetstream/core/orchestrator.py
@@ -38,7 +38,7 @@
      of the generation loop at the relevant slot.
    - Regardless, it performs a step.
    - It takes the sampled tokens, and places them on a 'detokenizing_queue'.
-7. Within the detokenizing thread:
+7. Within the detokenizing thread (Prefill and Generate separately):
    - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
    - When an end condition is met, the 'slot' integer is returned to the
      respective generation queue.
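To make the change concrete before the wiring below: a minimal, runnable sketch of the split this commit introduces, with one detokenizing consumer per role so prefill first-token detokenization never waits behind generate-step detokenization. All names here are illustrative stand-ins, not JetStream's actual API.

import queue
import threading

def detokenize_loop(backlog: queue.Queue, label: str):
  # Drain this role's backlog; None is a sentinel used only by this sketch.
  while True:
    item = backlog.get(block=True)
    if item is None:
      break
    print(f"{label}: detokenized {item}")

# Unbounded: the upstream transfer queue already applies backpressure.
prefill_detokenize_backlog: queue.Queue = queue.Queue()
# Bounded: never let more than 8 generate steps accumulate.
generate_detokenize_backlog: queue.Queue = queue.Queue(8)

for backlog, label in (
    (prefill_detokenize_backlog, "prefill"),
    (generate_detokenize_backlog, "generate"),
):
  threading.Thread(
      target=detokenize_loop, args=(backlog, label), daemon=True
  ).start()

prefill_detokenize_backlog.put("first token")  # e.g. a prefill's first token
prefill_detokenize_backlog.put(None)           # stop the sketch's consumer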
@@ -210,7 +210,8 @@ class Driver:
   # Stage 4
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
-  _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

@@ -270,11 +271,11 @@ def __init__(
     # one of the generate backlogs.
     # Interleaved Mode: Max size is 1 to increase the HBM utilization
     # during generate.
-    # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
-    # while 1 transfer is enqueued while 1 is being transferred.
+    # Disaggregated Mode: Max size is 16 to allow up to 16 prefill results
+    # to be enqueued while 1 is being transferred.
     # TODO: Make queue size configurable.
     self._transfer_backlogs = [
-        queue.Queue(1 if self._interleaved_mode else 4)
+        queue.Queue(1 if self._interleaved_mode else 16)
         for i in range(len(self._prefill_engines))
     ]
     if self._metrics_collector:
Expand Down Expand Up @@ -302,10 +303,11 @@ def __init__(
functools.partial(float, backlog.qsize())
)
# Stage 4
# After generation, ActiveRequests are placed on the detokenization backlog
# for tokens to be sent into each ActiveRequest's return channel.
# We have one of these per generate engine to simplify the logic keeping
# track of which generation engine to replace slots on.
# After prefill and generation, ActiveRequests are placed on the
# detokenization backlog for tokens to be sent into each ActiveRequest's
# return channel.
# We have one of these per prefill / generate engine to simplify
# the logic keeping track of which generation engine to replace slots on.
# This is a queue of either - tuple[int, ActiveRequest] which represents our
# active requests, or tuple[int, sample_tokens]. We combine these into one
# queue because it allows us to be somewhat clever with how we do
@@ -320,7 +322,16 @@ def __init__(
     # the possibility of race conditions where a slot is made live before the
     # tokens are ready and it receives tokens from a different sequence,
     # or tokens detokenized before the relevant slot is live.
-    self._detokenize_backlogs = [
+
+    self._prefill_detokenize_backlogs = [
+        # No need to set maxsize, as the transfer queue already
+        # provides backpressure on the prefill workload
+        # (to avoid overwhelming prefill).
+        queue.Queue()
+        for _ in self._prefill_engines
+    ]
+
+    self._generate_detokenize_backlogs = [
         # We don't let detokenization accumulate more than 8 steps to avoid
         # synchronization issues.
         queue.Queue(8)
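The sizing rationale in these comments rests on a standard property of queue.Queue: put() blocks once maxsize items are waiting, so a bounded queue throttles its producer, while an unbounded queue must inherit backpressure from an upstream bounded one. A quick stdlib illustration, independent of any JetStream types:

import queue

bounded = queue.Queue(maxsize=16)  # like the disaggregated transfer backlog
for i in range(16):
  bounded.put(i)                   # fills up without blocking
print(bounded.full())              # True; a 17th put() would block until get()

unbounded = queue.Queue()          # like the prefill detokenize backlog:
unbounded.put("anything")          # put() never blocks here, so the bounded
                                   # transfer queue upstream is what throttles
                                   # the prefill producer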
Expand Down Expand Up @@ -376,13 +387,25 @@ def __init__(
)
for idx in range(len(self._generate_engines))
]
self.detokenize_threads = [
self.prefill_detokenize_threads = [
JetThread(
target=functools.partial(
self._detokenize_thread,
idx,
is_prefill=True,
idx=idx,
),
name=f"prefill_detokenize-{idx}",
)
for idx in range(len(self._generate_engines))
]
self.generate_detokenize_threads = [
JetThread(
target=functools.partial(
self._detokenize_thread,
is_prefill=False,
idx=idx,
),
name=f"detokenize-{idx}",
name=f"generate_detokenize-{idx}",
)
for idx in range(len(self._generate_engines))
]
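Both lists reuse one thread body by baking the role flag and engine index in with functools.partial. A condensed sketch of the same pattern, with threading.Thread standing in for JetThread (whose definition this view does not show):

import functools
import threading

def _detokenize_thread(is_prefill: bool, idx: int):
  role = "prefill" if is_prefill else "generate"
  print(f"{role}_detokenize-{idx} running")

num_engines = 2  # illustrative
threads = [
    threading.Thread(
        target=functools.partial(_detokenize_thread, is_prefill=flag, idx=idx),
        name=f"{'prefill' if flag else 'generate'}_detokenize-{idx}",
    )
    for flag in (True, False)
    for idx in range(num_engines)
]
for t in threads:
  t.start()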
@@ -391,7 +414,8 @@ def __init__(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
-            self.detokenize_threads,
+            self.prefill_detokenize_threads,
+            self.generate_detokenize_threads,
         )
     )
     self.live = True
@@ -410,7 +434,8 @@ def stop(self):
             [self._prefill_backlog],
             self._transfer_backlogs,
             self._generate_backlogs.values(),
-            self._detokenize_backlogs,
+            self._prefill_detokenize_backlogs,
+            self._generate_detokenize_backlogs,
         )
     )

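stop() only needs a flat view of every backlog, so chaining the two new detokenize lists keeps shutdown symmetric with construction. A rough sketch of that aggregation; the sentinel-based wakeup below is an assumption, since the rest of stop() is collapsed in this view:

import itertools
import queue

prefill_backlog = queue.Queue()
transfer_backlogs = [queue.Queue(16)]
prefill_detokenize_backlogs = [queue.Queue()]
generate_detokenize_backlogs = [queue.Queue(8)]

all_backlogs = list(
    itertools.chain(
        [prefill_backlog],
        transfer_backlogs,
        prefill_detokenize_backlogs,
        generate_detokenize_backlogs,
    )
)
for backlog in all_backlogs:
  backlog.put(None)  # hypothetical shutdown sentinel for each consumer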
@@ -523,7 +548,7 @@ def _prefill_thread(self, idx: int):

       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._detokenize_backlogs[idx]
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
       request.metadata.transfer_enqueue_time = time.perf_counter()
       my_detokenize_backlog.put(
           (first_token, request, request.metadata.prefill_dequeue_time),
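The tuple enqueued here carries the prefill dequeue timestamp, and transfer_enqueue_time is stamped just before the put, so the consumer can attribute queue latency per request. The timing pattern in isolation (field names illustrative):

import queue
import time

backlog: queue.Queue = queue.Queue()

# Producer: stamp the enqueue time alongside the payload.
backlog.put(("first_token", time.perf_counter()))

# Consumer: compare against the dequeue time to measure time spent queued.
token, enqueue_time = backlog.get()
queue_latency_ms = (time.perf_counter() - enqueue_time) * 1e3
print(f"{token} waited {queue_latency_ms:.2f} ms in the backlog")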
@@ -619,7 +644,7 @@ def _generate_thread(self, idx: int):
     generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
     my_generate_backlog = self._generate_backlogs[idx]
-    my_detokenize_backlog = self._detokenize_backlogs[idx]
+    my_detokenize_backlog = self._generate_detokenize_backlogs[idx]

     # Keep track of what step tokens were generated at.
     generate_timestep = 0
@@ -749,12 +774,17 @@ def _generate_thread(self, idx: int):
       )
       time_of_last_generate = time.time()

-  def _detokenize_thread(self, idx: int):
+  def _detokenize_thread(self, is_prefill: bool, idx: int):
     """Detokenize sampled tokens and return them to the user."""
     # One of these per generate engine.
     # For all filled my_slots, pop the sampled token onto the relevant
     # request's return channel. If it is done, place it back onto free slots.
-    my_detokenize_backlog = self._detokenize_backlogs[idx]
+
+    if is_prefill:
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+    else:
+      my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
+
     my_generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
