Record number of skipped tokens in the response (#681)
tgaddair authored Nov 14, 2024
1 parent fc2cebb commit 0204c2e
Showing 13 changed files with 81 additions and 44 deletions.
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
@@ -276,6 +276,8 @@ class Token(BaseModel):
special: bool
# Alternative tokens
alternative_tokens: Optional[List[AlternativeToken]] = None
# If token was skipped due to speculative decoding
skipped: bool


# Generation finish reason
@@ -312,6 +314,8 @@ class Details(BaseModel):
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Number of skipped tokens
skipped_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int] = None
# Decoder input tokens, empty if decoder_input_details is False
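For reference, a minimal sketch of how a client could read the new fields after this change. The endpoint URL, prompt, and generation parameters below are illustrative assumptions, not part of this diff; it assumes a running LoRAX server and that the response carries per-token details.

from lorax import Client

# Hypothetical local endpoint; any running LoRAX deployment would do.
client = Client("http://127.0.0.1:8080")
response = client.generate("What is speculative decoding?", max_new_tokens=32)

details = response.details
print(details.generated_tokens, details.skipped_tokens)

# Per-token view: tokens accepted via speculative decoding are flagged as skipped.
for token in details.tokens:
    if token.skipped:
        print(f"skipped token {token.id}: {token.text!r}")
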
6 changes: 4 additions & 2 deletions proto/generate.proto
@@ -206,10 +206,12 @@ message GeneratedText {
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Number of skipped tokens due to speculative decoding hits
uint32 skipped_tokens = 3;
/// Finish reason
FinishReason finish_reason = 3;
FinishReason finish_reason = 4;
/// Seed
optional uint64 seed = 4;
optional uint64 seed = 5;
}

message PrefillTokens {
5 changes: 4 additions & 1 deletion router/src/infer.rs
@@ -1398,9 +1398,11 @@ fn send_responses(
next_tokens.is_special,
alternative_tokens,
))
.enumerate()
.peekable();

while let Some((id, logprob, text, special, alternative_tokens)) = iterator.next() {
while let Some((idx, (id, logprob, text, special, alternative_tokens))) = iterator.next() {
let skipped = idx > 0;
let token = Token {
id,
text,
@@ -1416,6 +1418,7 @@
.collect(),
)
}),
skipped,
};

match (&generation.generated_text, iterator.peek()) {
4 changes: 4 additions & 0 deletions router/src/lib.rs
@@ -414,6 +414,8 @@ pub struct Token {
#[schema(nullable = true)]
#[serde(skip_serializing_if = "Option::is_none")]
alternative_tokens: Option<Vec<AlternativeToken>>,
#[schema(example = "false")]
skipped: bool,
}

#[derive(Debug, Serialize, ToSchema)]
@@ -462,6 +464,8 @@ pub(crate) struct Details {
pub prompt_tokens: u32,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 1)]
pub skipped_tokens: u32,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
7 changes: 7 additions & 0 deletions router/src/server.rs
@@ -645,6 +645,7 @@ async fn generate(
};

let generated_tokens = response.generated_text.generated_tokens;
let skipped_tokens = response.generated_text.skipped_tokens;
let prompt_tokens = response.prompt_tokens;
let total_tokens = prompt_tokens + generated_tokens;

Expand Down Expand Up @@ -680,6 +681,7 @@ async fn generate(
finish_reason: FinishReason::from(response.generated_text.finish_reason),
prompt_tokens: prompt_tokens,
generated_tokens: generated_tokens,
skipped_tokens: skipped_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
@@ -705,6 +707,7 @@
span.record("seed", format!("{:?}", response.generated_text.seed));
span.record("prompt_tokens", format!("{prompt_tokens:?}"));
span.record("generated_tokens", format!("{generated_tokens:?}"));
span.record("skipped_tokens", format!("{skipped_tokens:?}"));

// Headers
let mut headers = HeaderMap::new();
@@ -729,6 +732,10 @@
"x-generated-tokens",
generated_tokens.to_string().parse().unwrap(),
);
headers.insert(
"x-skipped-tokens",
skipped_tokens.to_string().parse().unwrap(),
);
headers.insert("x-total-tokens", total_tokens.to_string().parse().unwrap());
headers.insert(
"x-validation-time",
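The new x-skipped-tokens header sits alongside the existing token-count headers. A quick sketch of reading it over HTTP, assuming a LoRAX server exposing the standard /generate route on localhost (URL and payload here are illustrative):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "Hello", "parameters": {"max_new_tokens": 16}},
)
print(resp.headers.get("x-generated-tokens"))
print(resp.headers.get("x-skipped-tokens"))  # header added in this commit
print(resp.headers.get("x-total-tokens"))
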
4 changes: 3 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -679,7 +679,9 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
else:
seed = None

generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed
)
else:
generated_text = None

16 changes: 6 additions & 10 deletions server/lorax_server/models/custom_modeling/mllama.py
@@ -225,11 +225,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):

out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
self.fc1 = TensorParallelMultiAdapterLinear.load(
fc1,
layer_id,
[f'{model_type}_{FC1}'],
sizes=[out_size],
process_group=weights.process_group
fc1, layer_id, [f"{model_type}_{FC1}"], sizes=[out_size], process_group=weights.process_group
)
self.fc2 = TensorParallelAdapterRowLinear.load(
TensorParallelRowLinear.load(
@@ -239,7 +235,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
bias=True,
),
layer_id,
f'{model_type}_{FC2}',
f"{model_type}_{FC2}",
process_group=weights.process_group,
)

@@ -261,7 +257,7 @@ def load_attention(config, prefix, weights, layer_id, model_type, head_dim, n_he
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[f'{model_type}_{Q_PROJ}', f'{model_type}_{K_PROJ}', f'{model_type}_{V_PROJ}'],
[f"{model_type}_{Q_PROJ}", f"{model_type}_{K_PROJ}", f"{model_type}_{V_PROJ}"],
sizes=[
head_dim * n_head,
head_dim * n_head_kv,
@@ -306,7 +302,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
bias=False,
),
layer_id,
f'{model_type}_{O_PROJ}',
f"{model_type}_{O_PROJ}",
process_group=weights.process_group,
)

@@ -557,15 +553,15 @@ def __init__(self, *, prefix, config, weights):
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
model_type='VISION_TRANSFORMER',
model_type="VISION_TRANSFORMER",
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
model_type='VISION_GLOBAL_TRANSFORMER',
model_type="VISION_GLOBAL_TRANSFORMER",
)

def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
6 changes: 6 additions & 0 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1980,6 +1980,8 @@ def generate_token(
if n_accepted_ids > 1:
logger.debug(f"speculated ids {n_accepted_ids - 1}")

# First token is not skipped, next tokens are
skipped = False
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
@@ -1995,8 +1997,11 @@
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
skipped=skipped,
)

# All subsequent tokens are skipped
skipped = True
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
@@ -2022,6 +2027,7 @@
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
stopping_criteria.current_skipped,
reason,
seed if do_sample else None,
)
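To make the bookkeeping above concrete, a small standalone sketch (not the server code) of the rule applied per decode step: of the n_accepted_ids tokens accepted in one step, only the first needed its own forward pass, and the remaining n_accepted_ids - 1 are counted as skipped.

from typing import Tuple

def count_accepted(n_accepted_ids: int) -> Tuple[int, int]:
    """Return (tokens_emitted, tokens_skipped) for one speculative decode step."""
    emitted = 0
    skipped_count = 0
    skipped = False  # first token is not skipped, next tokens are
    for _ in range(n_accepted_ids):
        emitted += 1
        if skipped:
            skipped_count += 1
        skipped = True  # all subsequent tokens are skipped
    return emitted, skipped_count

print(count_accepted(1))  # (1, 0): no speculation hit this step
print(count_accepted(4))  # (4, 3): three tokens skipped a forward pass
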
57 changes: 31 additions & 26 deletions server/lorax_server/models/mllama.py
@@ -24,6 +24,7 @@
TEXT_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
VISION_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2]


@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
image_indices: List[int] = 42
@@ -179,33 +180,34 @@ def from_pb(


class MllamaCausalLM(VlmCausalLM):

@property
def supports_adapter_loading(self) -> bool:
return True

@property
def adapter_layers(self) -> List[str]:
return [f'TEXT_{layer_type}' for layer_type in TEXT_ADAPTER_LAYERS] \
+ [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
+ [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
return (
[f"TEXT_{layer_type}" for layer_type in TEXT_ADAPTER_LAYERS]
+ [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+ [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
)

@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]

def get_num_layers_for_type(self, layer_type: str) -> int:
if 'LM_HEAD' in layer_type:
if "LM_HEAD" in layer_type:
return 1
if 'TEXT_' in layer_type:
if "TEXT_" in layer_type:
return [
layer_id
for layer_id, layer in enumerate(self.model.text_model.model.layers)
if not isinstance(layer, FlashLlamaCrossLayer)
if not isinstance(layer, FlashLlamaCrossLayer)
]
if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
return len(self.model.vision_model.global_transformer.layers)
if 'VISION_TRANSFORMER_' in layer_type:
if "VISION_TRANSFORMER_" in layer_type:
return len(self.model.vision_model.transformer.layers)

def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
@@ -215,51 +217,54 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
for i, layer in enumerate(self.model.text_model.model.layers):
if isinstance(layer, FlashLlamaCrossLayer):
continue
layer_weights[(i, f'TEXT_{Q_PROJ}')] = (
layer_weights[(i, f"TEXT_{Q_PROJ}")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, f'TEXT_{K_PROJ}')] = (
layer_weights[(i, f"TEXT_{K_PROJ}")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, f'TEXT_{V_PROJ}')] = (
layer_weights[(i, f"TEXT_{V_PROJ}")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, f'TEXT_{O_PROJ}')] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)

layer_weights[(i, f'TEXT_{GATE_PROJ}')] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
layer_weights[(i, f'TEXT_{UP_PROJ}')] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)
layer_weights[(i, f"TEXT_{O_PROJ}")] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)

layer_weights[(i, f"TEXT_{GATE_PROJ}")] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
layer_weights[(i, f"TEXT_{UP_PROJ}")] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
layer_weights[(i, f"TEXT_{DOWN_PROJ}")] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
layer_weights[(0, f"TEXT_{LM_HEAD}")] = (
"base_model.model.language_model.lm_head",
self.model.text_model.lm_head,
)

vision_layer_mappings = [
("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
("vision_model.transformer.layers", self.model.vision_model.transformer.layers),
]
for prefix, layer_list in vision_layer_mappings:
layer_type_prefix = 'VISION_GLOBAL_TRANSFORMER' if 'global_transformer' in prefix else 'VISION_TRANSFORMER'
layer_type_prefix = "VISION_GLOBAL_TRANSFORMER" if "global_transformer" in prefix else "VISION_TRANSFORMER"
for i, layer in enumerate(layer_list):
layer_weights[(i, f'{layer_type_prefix}_{Q_PROJ}')] = (
layer_weights[(i, f"{layer_type_prefix}_{Q_PROJ}")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.qkv_proj,
)
layer_weights[(i, f'{layer_type_prefix}_{K_PROJ}')] = (
layer_weights[(i, f"{layer_type_prefix}_{K_PROJ}")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.qkv_proj,
)
layer_weights[(i, f'{layer_type_prefix}_{V_PROJ}')] = (
layer_weights[(i, f"{layer_type_prefix}_{V_PROJ}")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.qkv_proj,
)
layer_weights[(i, f'{layer_type_prefix}_{O_PROJ}')] = (
layer_weights[(i, f"{layer_type_prefix}_{O_PROJ}")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj
layer.self_attn.o_proj,
)

layer_weights[(i, f'{layer_type_prefix}_{FC1}')] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
layer_weights[(i, f'{layer_type_prefix}_{FC2}')] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)
layer_weights[(i, f"{layer_type_prefix}_{FC1}")] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
layer_weights[(i, f"{layer_type_prefix}_{FC2}")] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)

return layer_weights

4 changes: 3 additions & 1 deletion server/lorax_server/models/seq2seq_lm.py
@@ -654,7 +654,9 @@ def generate_token(self, batch: Seq2SeqLMBatch) -> Tuple[List[Generation], Optio
else:
seed = None

generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed
)
else:
generated_text = None

2 changes: 2 additions & 0 deletions server/lorax_server/models/types.py
@@ -50,13 +50,15 @@ def __len__(self):
class GeneratedText:
text: str
generated_tokens: int
skipped_tokens: int
finish_reason: FinishReason
seed: Optional[int]

def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(
text=self.text,
generated_tokens=self.generated_tokens,
skipped_tokens=self.skipped_tokens,
finish_reason=self.finish_reason,
seed=self.seed,
)
4 changes: 2 additions & 2 deletions server/lorax_server/utils/lora.py
@@ -8,7 +8,7 @@
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"

FC1 = 'fc1'
FC2 = 'fc2'
FC1 = "fc1"
FC2 = "fc2"

LM_HEAD = "lm_head"
6 changes: 5 additions & 1 deletion server/lorax_server/utils/tokens.py
@@ -185,9 +185,13 @@ def __init__(
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
self.current_skipped = 0
self.ignore_eos_token = ignore_eos_token

def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]:
if skipped:
self.current_skipped += 1

self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
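A self-contained sketch of the updated stopping-criteria interface, reduced to the new skipped accounting. This is a simplified stand-in, not the real class, which also handles EOS, stop sequences, and ignore_eos_token; the "length" finish reason stands in for FinishReason.FINISH_REASON_LENGTH.

from typing import Optional, Tuple

class StoppingCriteriaSketch:
    """Simplified stand-in for StoppingCriteria; only counts tokens."""

    def __init__(self, max_new_tokens: int):
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_skipped = 0

    def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]:
        if skipped:
            self.current_skipped += 1
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        return False, None

criteria = StoppingCriteriaSketch(max_new_tokens=8)
for token_id, skipped in [(11, False), (42, True), (7, True)]:
    stop, reason = criteria(token_id, "", skipped=skipped)

print(criteria.current_tokens, criteria.current_skipped)  # 3 2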
