diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 959b2b629..872e5d78f 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -114,7 +114,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) @@ -144,7 +144,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) diff --git a/unit_tests/server/core/objs/test_sampling_params.py b/unit_tests/server/core/objs/test_sampling_params.py index 489f8ae34..ca684274e 100644 --- a/unit_tests/server/core/objs/test_sampling_params.py +++ b/unit_tests/server/core/objs/test_sampling_params.py @@ -12,6 +12,8 @@ STOP_SEQUENCE_MAX_LENGTH, REGULAR_CONSTRAINT_MAX_LENGTH, ALLOWED_TOKEN_IDS_MAX_LENGTH, + JSON_SCHEMA_MAX_LENGTH, + GRAMMAR_CONSTRAINT_MAX_LENGTH, ) grammar_str = r"""root ::= (expr "=" term)+ @@ -80,20 +82,20 @@ def test_regular_constraint_initialization(): def test_guided_grammar_initialization(): grammar = GuidedGrammar() - grammar.initialize(grammar_str) + grammar.initialize(grammar_str, None) assert grammar.to_str() == grammar_str with pytest.raises(AssertionError): - grammar.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + grammar.initialize("a" * (GRAMMAR_CONSTRAINT_MAX_LENGTH + 1), None) def test_guided_json_schema_initialization(): schema = GuidedJsonSchema() - schema.initialize(schema_str) + schema.initialize(schema_str, None) assert schema.to_str() == schema_str with pytest.raises(AssertionError): - schema.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + schema.initialize("a" * (JSON_SCHEMA_MAX_LENGTH + 1), None) def test_allowed_token_ids_initialization():