Generate to max_total_tokens during warmup (#286)
tgaddair authored Feb 28, 2024
1 parent 74f0c28 commit e51f078
Showing 7 changed files with 35 additions and 13 deletions.
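
In short: the warmup batch previously capped every request at max_new_tokens: 2, so the KV cache was barely touched before serving started. With this change the router threads max_total_tokens through to warmup, each synthetic request is allowed max_total_tokens - truncate_length new tokens, and the server decodes step by step until that budget is spent, so a configuration that cannot fit max_total_tokens fails at startup instead of under load. A rough sketch of the resulting flow, in Python with hypothetical helper names (the real logic is split between the Rust router client and the Python server shown below):

def warmup_budget(max_input_length: int, max_total_tokens: int) -> int:
    # Decode steps to run during warmup; mirrors the arithmetic in client.rs.
    return max_total_tokens - max_input_length

def run_warmup(model, batch, max_input_length: int, max_total_tokens: int):
    # Each decode step appends one token per sequence, so the KV cache
    # grows toward the max_total_tokens budget before serving begins.
    for _ in range(warmup_budget(max_input_length, max_total_tokens)):
        _, batch = model.generate_token(batch)
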
3 changes: 3 additions & 0 deletions proto/generate.proto
@@ -215,6 +215,9 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
+
+    /// Maximum number of new tokens to warmup
+    uint32 max_new_tokens = 2;
 }

 /// Empty response
15 changes: 11 additions & 4 deletions router/client/src/client.rs
@@ -104,17 +104,19 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();

         // Create requests
         while n_tokens < max_prefill_tokens {
+            // We truncate the input on the server side to be sure that it has the correct size
+            let truncate_length = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
-                // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate: truncate_length,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -129,7 +131,7 @@ impl Client {
                     schema: None,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate_length,
                     stop_sequences: vec![],
                     ignore_eos_token: false,
                 }),
@@ -147,7 +149,12 @@ impl Client {
             max_tokens: 0,
         };

-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let max_new_tokens = max_total_tokens - max_input_length;
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_new_tokens,
+        })
+        .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
         Ok(response.max_supported_total_tokens)
     }
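
Each warmup request's stopping criterion now scales with the configured limits: the prompt is truncated to truncate_length tokens and the request may generate max_total_tokens - truncate_length new ones, so prompt plus completion together exercise the full max_total_tokens window, while the WarmupRequest itself carries max_new_tokens = max_total_tokens - max_input_length for the server-side loop. A worked example with made-up numbers (illustrative only, not values from the repository):

# Hypothetical configuration, for illustration only.
max_input_length = 1024
max_prefill_tokens = 4096
max_total_tokens = 2048

n_tokens = 0  # first request in the warmup batch
truncate_length = min(max_input_length, max_prefill_tokens - n_tokens)  # 1024
request_max_new_tokens = max_total_tokens - truncate_length             # 1024
warmup_max_new_tokens = max_total_tokens - max_input_length             # 1024, sent in WarmupRequest
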
5 changes: 4 additions & 1 deletion router/client/src/sharded_client.rs
@@ -95,11 +95,14 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
             .collect();
         // Take the minimum value
         let results = join_all(futures)
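
ShardedClient::warmup issues the same call to every shard concurrently and, as the surrounding code's comment says, keeps the minimum max_supported_total_tokens any shard reports. A loose asyncio analogue of that fan-out-and-min pattern (a sketch of the idea, not the Rust implementation; each shard_warmup stands in for Client::warmup):

import asyncio
from typing import Awaitable, Callable, Optional, Sequence

async def sharded_warmup(
    shard_warmups: Sequence[Callable[[], Awaitable[Optional[int]]]],
) -> Optional[int]:
    # Warm up all shards at once, then keep the smallest supported
    # token count so every shard can honor the final limit.
    results = await asyncio.gather(*(warmup() for warmup in shard_warmups))
    supported = [r for r in results if r is not None]
    return min(supported) if supported else None
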
6 changes: 5 additions & 1 deletion router/src/main.rs
@@ -295,7 +295,11 @@ async fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_total_tokens as u32,
+        )
         .await
         .map_err(RouterError::Warmup)?
     {
9 changes: 7 additions & 2 deletions server/lorax_server/models/flash_causal_lm.py
@@ -8,6 +8,7 @@
 import torch.distributed
 from loguru import logger
 from opentelemetry import trace
+from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase

 from lorax_server.models import Model
@@ -719,7 +720,7 @@ def __init__(
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

-    def warmup(self, batch: FlashCausalLMBatch):
+    def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
@@ -731,7 +732,11 @@ def warmup(self, batch: FlashCausalLMBatch):
                 self.dtype,
                 self.device,
             )
-            _, batch = self.generate_token(batch)
+
+            with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
+                for _ in range(max_new_tokens):
+                    _, batch = self.generate_token(batch)
+                    pbar.update(1)
         except RuntimeError as e:
             if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
                 raise RuntimeError(
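
Running generate_token for max_new_tokens steps is what actually stresses the memory ceiling: each decode step allocates further KV-cache blocks, so by the end of warmup the cache reflects a request that truly reached max_total_tokens, and an out-of-memory condition surfaces at startup rather than in production. A minimal sketch of that loop with the OOM guard made explicit (run_warmup_decode and its error message are illustrative, not the repository's API):

import torch
from tqdm import tqdm

def run_warmup_decode(model, batch, max_new_tokens: int):
    # Decode one step at a time so KV-cache usage grows to its worst case.
    try:
        with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
            for _ in range(max_new_tokens):
                _, batch = model.generate_token(batch)
                pbar.update(1)
    except torch.cuda.OutOfMemoryError as err:
        raise RuntimeError(
            "Not enough memory to warm up to max_total_tokens; "
            "lower max_total_tokens or the batch limits"
        ) from err
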
2 changes: 1 addition & 1 deletion server/lorax_server/models/model.py
@@ -81,7 +81,7 @@ def batch_type(self) -> Type[B]:
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError

-    def warmup(self, batch: B) -> Optional[int]:
+    def warmup(self, batch: B, max_new_tokens: int) -> Optional[int]:
         self.generate_token(batch)
         return None

8 changes: 4 additions & 4 deletions server/lorax_server/server.py
@@ -68,21 +68,21 @@ async def FilterBatch(self, request, context):

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

-    async def Warmup(self, request, context):
+    async def Warmup(self, request: generate_pb2.WarmupRequest, context):
         batch = self.model.batch_type.from_pb(
             request.batch,
             self.model.tokenizer,
             self.model.tokenizers,
             self.model.dtype,
             self.model.device,
         )
-        max_supported_total_tokens = self.model.warmup(batch)
+        max_supported_total_tokens = self.model.warmup(batch, request.max_new_tokens)

         return generate_pb2.WarmupResponse(
             max_supported_total_tokens=max_supported_total_tokens
         )

-    async def Prefill(self, request, context):
+    async def Prefill(self, request: generate_pb2.PrefillRequest, context):
         batch = self.model.batch_type.from_pb(
             request.batch,
             self.model.tokenizer,
@@ -99,7 +99,7 @@ async def Prefill(self, request, context):
             batch=next_batch.to_pb() if next_batch else None,
         )

-    async def Decode(self, request, context):
+    async def Decode(self, request: generate_pb2.DecodeRequest, context):
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")

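
With the new proto field, the Warmup handler reads request.max_new_tokens directly off the message and passes it to Model.warmup. For reference, a message carrying the new field could be built from the generated Python stubs roughly as below (a hedged sketch: the import path is assumed from server.py's use of generate_pb2, and batch_pb is a Batch built elsewhere):

from lorax_server.pb import generate_pb2  # assumed location of the generated stubs

def build_warmup_request(
    batch_pb: generate_pb2.Batch, max_input_length: int, max_total_tokens: int
) -> generate_pb2.WarmupRequest:
    # Mirrors the Rust client: warm up for the tokens left after the prompt.
    return generate_pb2.WarmupRequest(
        batch=batch_pb,
        max_new_tokens=max_total_tokens - max_input_length,
    )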
