Skip to content

Commit

Permalink
Add Xgrammar Support (#701)
Browse files Browse the repository at this point in the history
Change:
- change launch argument from `--simple_constraint_mode` to
`--output_constraint_mode`, now user can choose the constraint decode
backend from ['outlines', 'xgrammar']
- add `XgrammarBackend` used for xgrammar constraint decode, maybe we
should merge it with `SimpleConstraintBackend` later?
- now we adopt the same request body, the same as vLLM with
xgrammar(https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html)
- user can add `guided_grammar` to pass an EBNF grammar and execute the
constraint decode
- user can add `guided_json` to pass a standard json schema and do the
constraint decode

---------

Co-authored-by: hiworldwzj <[email protected]>
  • Loading branch information
flyinglandlord and hiworldwzj authored Feb 26, 2025
1 parent c483b1e commit c8c892a
Show file tree
Hide file tree
Showing 13 changed files with 517 additions and 16 deletions.
9 changes: 8 additions & 1 deletion lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")

parser.add_argument(
"--output_constraint_mode",
type=str,
choices=["outlines", "xgrammar", "none"],
default="none",
help="set the output constraint backend, none means no output constraint",
)
parser.add_argument(
"--first_token_constraint_mode",
action="store_true",
Expand Down
8 changes: 7 additions & 1 deletion lightllm/server/core/objs/py_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def __init__(
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
input_penalty: bool = DEFAULT_INPUT_PENALTY,
regular_constraint: Optional[str] = None, # Regular expressions constrain the output.
guided_grammar: Optional[str] = None, # EBNF constrain the output.
guided_json: Optional[Union[str, dict]] = None, # JSON schema constrain the output.
# If provided, the engine will construct a logits,
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
allowed_token_ids: Optional[List[int]] = None,
# p d mode used params
group_request_id: Optional[int] = None,
Expand Down Expand Up @@ -81,6 +83,8 @@ def __init__(
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token
self.regular_constraint = regular_constraint
self.guided_grammar = guided_grammar
self.guided_json = guided_json
self.allowed_token_ids = allowed_token_ids
self.group_request_id = group_request_id
self.move_kv_to_decode_node = move_kv_to_decode_node
Expand Down Expand Up @@ -257,6 +261,8 @@ def to_dict(self):
ret["best_of"] = self.best_of
ret["input_penalty"] = self.input_penalty
ret["regular_constraint"] = self.regular_constraint
ret["guided_grammar"] = self.guided_grammar
ret["guided_json"] = self.guided_json
ret["allowed_token_ids"] = self.allowed_token_ids
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
return ret
Expand Down
93 changes: 91 additions & 2 deletions lightllm/server/core/objs/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))


class StopSequence(ctypes.Structure):
Expand Down Expand Up @@ -76,7 +78,7 @@ def to_list(self):
class RegularConstraint(ctypes.Structure):
_pack_ = 4
_fields_ = [
("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH),
("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH),
("length", ctypes.c_int),
]

Expand All @@ -98,6 +100,66 @@ def to_str(self):
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")


class GuidedGrammar(ctypes.Structure):
    """Shared-memory holder for an EBNF grammar output constraint.

    The grammar text is kept UTF-8 encoded inside a fixed-size ctypes byte
    buffer so the structure can be placed in shared memory; ``length``
    records how many bytes of the buffer are in use (0 means "no constraint").
    """

    _pack_ = 4
    _fields_ = [
        ("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH),
        ("length", ctypes.c_int),
    ]

    def initialize(self, constraint: str, tokenizer):
        """Copy *constraint* into the buffer and eagerly verify it compiles.

        Raises:
            ValueError: if xgrammar cannot compile the grammar (or xgrammar
                is unavailable while a non-empty grammar was supplied).
        """
        encoded = constraint.encode("utf-8")
        assert len(encoded) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long."

        ctypes.memmove(self.constraint, encoded, len(encoded))
        self.length = len(encoded)
        try:
            if self.length > 0:
                # Compile up-front so a malformed grammar is rejected at
                # request-admission time instead of during decoding.
                import xgrammar as xgr

                info = xgr.TokenizerInfo.from_huggingface(tokenizer)
                compiler = xgr.GrammarCompiler(info, max_threads=8)
                compiler.compile_grammar(constraint)
        except Exception as e:
            raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")

    def to_str(self):
        """Return the stored grammar text, or "" when no grammar is set."""
        if self.length == 0:
            return ""
        return bytes(self.constraint[: self.length]).decode("utf-8").rstrip("\x00")


class GuidedJsonSchema(ctypes.Structure):
    """Fixed-size shared-memory container for a JSON-schema output constraint.

    The schema text is stored UTF-8 encoded in a fixed ctypes byte buffer so
    the structure can live in shared memory; ``length`` tracks the payload
    size (0 means "no constraint").
    """

    _pack_ = 4
    _fields_ = [
        ("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH),
        ("length", ctypes.c_int),
    ]

    def initialize(self, constraint: str, tokenizer):
        """Store *constraint* and eagerly validate it with xgrammar.

        Args:
            constraint: JSON schema text; "" means no constraint is applied.
            tokenizer: HuggingFace tokenizer used to build xgrammar's
                tokenizer info for compilation.

        Raises:
            ValueError: when xgrammar cannot compile the schema (or xgrammar
                is unavailable while a non-empty schema was supplied).
        """
        constraint_bytes = constraint.encode("utf-8")
        assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long."

        ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
        self.length = len(constraint_bytes)
        try:
            if self.length > 0:
                # Compile up-front so a malformed schema is rejected at
                # request-admission time instead of during decoding.
                import xgrammar as xgr

                tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
                xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
                xgrammar_compiler.compile_json_schema(constraint)
        except Exception as e:
            # Bug fix: the original message was copy-pasted from GuidedGrammar
            # and misreported the failing field as "guided_grammar".
            raise ValueError(f"guided_json '{constraint}' has compile_json_schema_error: {str(e)}") from e
        return

    def to_str(self):
        """Return the stored schema as a str; "" when no schema is set."""
        if self.length == 0:
            return ""
        return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")


class AllowedTokenIds(ctypes.Structure):
_pack_ = 4
_fields_ = [
Expand Down Expand Up @@ -191,9 +253,11 @@ class SamplingParams(ctypes.Structure):
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
("input_penalty", ctypes.c_bool),
("regular_constraint", RegularConstraint),
("guided_grammar", GuidedGrammar),
("guided_json", GuidedJsonSchema),
# If provided, the engine will construct a logits,
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
("allowed_token_ids", AllowedTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
Expand Down Expand Up @@ -251,6 +315,16 @@ def init(self, tokenizer, **kwargs):
self.regular_constraint = RegularConstraint()
self.regular_constraint.initialize(regular_constraint)

# Initialize guided_grammar
guided_grammar = kwargs.get("guided_grammar", "")
self.guided_grammar = GuidedGrammar()
self.guided_grammar.initialize(guided_grammar, tokenizer)

# Initialize guided_json
guided_json = kwargs.get("guided_json", "")
self.guided_json = GuidedJsonSchema()
self.guided_json.initialize(guided_json, tokenizer)

# Initialize stop_sequence_groups
stop_sequences = kwargs.get("stop_sequences", [])
self.stop_sequences = StopSequenceGroups()
Expand Down Expand Up @@ -316,13 +390,26 @@ def verify(self):
)

self._verify_allowed_token_ids()
self._verify_grammar_constraint()

return

def _verify_grammar_constraint(self):
    """Reject mutually-exclusive combinations of output-constraint settings.

    ``guided_grammar`` cannot be combined with ``regular_constraint`` or
    ``guided_json``; a non-zero ``length`` means the field is in use.
    """
    if self.guided_grammar.length == 0:
        return
    if self.regular_constraint.length != 0:
        raise ValueError("guided_grammar and regular_constraint can not be used in same time")
    if self.guided_json.length != 0:
        raise ValueError("guided_grammar and guided_json can not be used in same time")

def _verify_allowed_token_ids(self):
    """Ensure allowed_token_ids is not combined with any other output constraint.

    A non-empty ``allowed_token_ids`` is mutually exclusive with
    ``regular_constraint``, ``guided_grammar`` and ``guided_json``.
    """
    if self.allowed_token_ids.size == 0:
        return
    if self.regular_constraint.length != 0:
        raise ValueError("allowed_token_ids and regular_constraint can not be used in same time")
    if self.guided_grammar.length != 0:
        raise ValueError("allowed_token_ids and guided_grammar can not be used in same time")
    if self.guided_json.length != 0:
        raise ValueError("allowed_token_ids and guided_json can not be used in same time")

def to_dict(self):
Expand All @@ -342,6 +429,8 @@ def to_dict(self):
"best_of": self.best_of,
"input_penalty": self.input_penalty,
"regular_constraint": self.regular_constraint.to_str(),
"guided_grammar": self.guided_grammar.to_str(),
"guided_json": self.guided_json.to_str(),
"allowed_token_ids": self.allowed_token_ids.to_list(),
"group_request_id": self.group_request_id,
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class StartArgs:
enable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
simple_constraint_mode: bool = field(default=False)
output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
cache_capacity: int = field(default=200)
Expand Down
16 changes: 13 additions & 3 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import collections

from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
from typing import List, Dict, Tuple, Optional, Union, Any
from lightllm.common.req_manager import ReqManager
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
Expand Down Expand Up @@ -194,10 +194,15 @@ def __init__(

# output constraint states
self.regular_constraint = self.shm_param.regular_constraint.to_str()
self.guided_grammar = self.shm_param.guided_grammar.to_str()
self.guided_json = self.shm_param.guided_json.to_str()
if len(self.regular_constraint) == 0:
self.regular_constraint = None
if len(self.guided_grammar) == 0:
self.guided_grammar = None
if len(self.guided_json) == 0:
self.guided_json = None

self.regex_guide = None
self.fsm_current_state: int = 0
self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list()
if len(self.allowed_token_ids) == 0:
Expand All @@ -217,7 +222,12 @@ def __init__(
return

def has_constraint_setting(self) -> bool:
    """Return True when any output-constraint option is active for this request."""
    constraint_fields = (
        self.regular_constraint,
        self.allowed_token_ids,
        self.guided_grammar,
        self.guided_json,
    )
    return any(field is not None for field in constraint_fields)


class InferReq:
Expand Down
3 changes: 2 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from .chunked_prefill.impl import ChunkedPrefillBackend
from .diverse_backend.impl import DiversehBackend
from .continues_batch.impl_for_token_healing import TokenHealingBackend
from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
from .dp_backend.impl import DPBackend
from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = init_logger(__name__)


class SimpleConstraintBackend(ContinuesBatchBackend):
class OutlinesConstraintBackend(ContinuesBatchBackend):
def __init__(self) -> None:
super().__init__()

Expand Down
Loading

0 comments on commit c8c892a

Please sign in to comment.