diff --git a/.gitmodules b/.gitmodules
index ac9bafe076..8939f0ee53 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -13,3 +13,6 @@
 [submodule "3rdparty/stb"]
 	path = 3rdparty/stb
 	url = https://github.com/nothings/stb.git
+[submodule "3rdparty/xgrammar"]
+	path = 3rdparty/xgrammar
+	url = https://github.com/mlc-ai/xgrammar.git
diff --git a/3rdparty/tokenizers-cpp b/3rdparty/tokenizers-cpp
index c0fab1e14a..4bb7533776 160000
--- a/3rdparty/tokenizers-cpp
+++ b/3rdparty/tokenizers-cpp
@@ -1 +1 @@
-Subproject commit c0fab1e14a9421c1501acee5b7703e5dafa60479
+Subproject commit 4bb753377680e249345b54c6b10e6d0674c8af03
diff --git a/3rdparty/tvm b/3rdparty/tvm
index 35a317f387..79a69ae4a9 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 35a317f387249f9592d176c9f12ddf44e2dc3853
+Subproject commit 79a69ae4a92c9d4f23e62f93ce5b0d90ed29e5ed
diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar
new file mode 160000
index 0000000000..d4f57c440f
--- /dev/null
+++ b/3rdparty/xgrammar
@@ -0,0 +1 @@
+Subproject commit d4f57c440f3da8e7330a1e5d50bba9c31f9433ea
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e09728727c..08eef03b5f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,8 +68,11 @@
 set(MLC_LLM_RUNTIME_LINKER_LIB "")
 set(TOKENZIER_CPP_PATH 3rdparty/tokenizers-cpp)
 add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)
-
+set(XGRAMMAR_PATH 3rdparty/xgrammar)
 tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc)
+tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc)
+list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc")
+list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS})
 add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS})
 
 set(
@@ -83,12 +86,14 @@ set(
 set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} DMLC_USE_LOGGING_LIBRARY=)
 set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} __STDC_FORMAT_MACROS=1)
 set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} PICOJSON_USE_INT64)
+set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} XGRAMMAR_ENABLE_LOG_DEBUG=0)
 
-target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES})
 target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS})
-target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include)
 target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS)
+target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES})
 target_include_directories(mlc_llm_objs PRIVATE 3rdparty/stb)
+target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include)
+target_include_directories(mlc_llm_objs PRIVATE ${XGRAMMAR_PATH}/include)
 
 add_library(mlc_llm SHARED $<TARGET_OBJECTS:mlc_llm_objs>)
 add_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)
@@ -135,7 +140,6 @@
 add_library(mlc_llm_module SHARED $<TARGET_OBJECTS:mlc_llm_objs>)
 target_link_libraries(mlc_llm_module PUBLIC tvm)
 target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp)
 
-
 set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}")
 set_property(TARGET mlc_llm APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}")
diff --git a/cmake/gen_cmake_config.py b/cmake/gen_cmake_config.py
index 31972862dc..b03f686c4f 100644
--- a/cmake/gen_cmake_config.py
+++ b/cmake/gen_cmake_config.py
@@ -1,6 +1,6 @@
 from collections import namedtuple
 
-Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str"])
+Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str", "parent"])
 
 if __name__ == "__main__":
     tvm_home = ""  # pylint: disable=invalid-name
@@ -13,65 +13,73 @@
     cmake_config_str = f"set(TVM_SOURCE_DIR {tvm_home})\n"
cmake_config_str += "set(CMAKE_BUILD_TYPE RelWithDebInfo)\n" + cuda_backend = Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): ", None) + opencl_backend = Backend("OpenCL", "USE_OPENCL", "Use OpenCL? (y/n) ", None) backends = [ - Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): "), - Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): "), - Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): "), - Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): "), - Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): "), + cuda_backend, + Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): ", cuda_backend), + Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): ", cuda_backend), + Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): ", None), + Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): ", None), + Backend("Metal", "USE_METAL", "Use Metal (Apple M1/M2 GPU) ? (y/n): ", None), + opencl_backend, Backend( - "Metal", - "USE_METAL", - "Use Metal (Apple M1/M2 GPU) ? (y/n): ", + "OpenCLHostPtr", + "USE_OPENCL_ENABLE_HOST_PTR", + "Use OpenCLHostPtr? (y/n): ", + opencl_backend, ), - Backend( - "OpenCL", - "USE_OPENCL", - "Use OpenCL? (y/n) ", - ), - Backend("OpenCLHostPtr", "USE_OPENCL_ENABLE_HOST_PTR", "Use OpenCLHostPtr? (y/n): "), ] enabled_backends = set() for backend in backends: - while True: - use_backend = input(backend.prompt_str) - if use_backend in ["yes", "Y", "y"]: - cmake_config_str += f"set({backend.cmake_config_name} ON)\n" - enabled_backends.add(backend.name) - break - elif use_backend in ["no", "N", "n"]: - cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" - break - else: - print(f"Invalid input: {use_backend}. Please input again.") + if backend.parent is not None and backend.parent.name not in enabled_backends: + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" + else: + while True: + use_backend = input(backend.prompt_str) + if use_backend in ["yes", "Y", "y"]: + cmake_config_str += f"set({backend.cmake_config_name} ON)\n" + enabled_backends.add(backend.name) + break + elif use_backend in ["no", "N", "n"]: + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" + break + else: + print(f"Invalid input: {use_backend}. Please input again.") if "CUDA" in enabled_backends: cmake_config_str += f"set(USE_THRUST ON)\n" # FlashInfer related use_flashInfer = False # pylint: disable=invalid-name - while True: - user_input = input("Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): ") - if user_input in ["yes", "Y", "y"]: - cmake_config_str += "set(USE_FLASHINFER ON)\n" - cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n" - cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n" - cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n" - cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n" - cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n" - cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n" - cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n" - cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n' - cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n' - use_flashInfer = True # pylint: disable=invalid-name - break - elif user_input in ["no", "N", "n"]: - cmake_config_str += "set(USE_FLASHINFER OFF)\n" - break - else: - print(f"Invalid input: {use_flashInfer}. Please input again.") + if "CUDA" in enabled_backends: + while True: + user_input = input( + "Use FlashInfer? 
+            )
+            if user_input in ["yes", "Y", "y"]:
+                cmake_config_str += "set(USE_FLASHINFER ON)\n"
+                cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n"
+                cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n"
+                cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n"
+                cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n"
+                cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n"
+                cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n"
+                cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n"
+                cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n'
+                cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n'
+                use_flashInfer = True  # pylint: disable=invalid-name
+                break
+            elif user_input in ["no", "N", "n"]:
+                cmake_config_str += "set(USE_FLASHINFER OFF)\n"
+                break
+            else:
+                print(f"Invalid input: {use_flashInfer}. Please input again.")
+    else:
+        cmake_config_str += "set(USE_FLASHINFER OFF)\n"
+
     if use_flashInfer:
         while True:
             user_input = input("Enter your CUDA compute capability: ")
diff --git a/cpp/grammar/grammar.cc b/cpp/grammar/grammar.cc
deleted file mode 100644
index 1f5d38ba14..0000000000
--- a/cpp/grammar/grammar.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar.cc
- */
-
-#include "grammar.h"
-
-#include "grammar_functor.h"
-#include "grammar_parser.h"
-#include "grammar_serializer.h"
-#include "json_schema_converter.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-TVM_REGISTER_OBJECT_TYPE(BNFGrammarNode);
-
-std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) {
-  os << BNFGrammarPrinter(grammar).ToString();
-  return os;
-}
-
-BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string,
-                                      const std::string& main_rule) {
-  auto grammar = EBNFParser::Parse(ebnf_string, main_rule);
-  // Normalize the grammar by default
-  grammar = BNFGrammarNormalizer().Apply(grammar);
-  return grammar;
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarFromEBNFString")
-    .set_body_typed([](String ebnf_string, String main_rule) {
-      return BNFGrammar::FromEBNFString(ebnf_string, main_rule);
-    });
-
-// Parse the EBNF string but not normalize it
-BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string,
-                                          const std::string& main_rule) {
-  return EBNFParser::Parse(ebnf_string, main_rule);
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarDebugFromEBNFStringNoNormalize")
-    .set_body_typed([](String ebnf_string, String main_rule) {
-      return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule);
-    });
-
-BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) {
-  return BNFJSONParser::Parse(json_string);
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarFromJSON").set_body_typed([](String json_string) {
-  return BNFGrammar::FromJSON(json_string);
-});
-
-BNFGrammar BNFGrammar::FromSchema(const std::string& schema, std::optional<int> indent,
-                                  std::optional<std::pair<std::string, std::string>> separators,
-                                  bool strict_mode) {
-  return FromEBNFString(JSONSchemaToEBNF(schema, indent, separators, strict_mode));
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarFromSchema").set_body([](TVMArgs args, TVMRetValue* rv) {
-  std::optional<int> indent;
-  if (args[1].type_code() != kTVMNullptr) {
-    indent = args[1];
-  } else {
-    indent = std::nullopt;
-  }
-
-  std::optional<std::pair<std::string, std::string>> separators;
-  if (args[2].type_code() != kTVMNullptr) {
-    Array<String> separators_arr = args[2];
-    CHECK(separators_arr.size() == 2);
-    separators =
-        std::make_pair(separators_arr[0], separators_arr[1]);
-  } else {
-    separators = std::nullopt;
-  }
-
-  *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]);
-});
-
-// Optimized json grammar for the speed of the grammar state matcher
-const std::string kJSONGrammarString = R"(
-main ::= (
-  "{" [ \n\t]* members_and_embrace |
-  "[" [ \n\t]* elements_or_embrace
-)
-value_non_str ::= (
-  "{" [ \n\t]* members_and_embrace |
-  "[" [ \n\t]* elements_or_embrace |
-  "0" fraction exponent |
-  [1-9] [0-9]* fraction exponent |
-  "-" [0-9] fraction exponent |
-  "-" [1-9] [0-9]* fraction exponent |
-  "true" |
-  "false" |
-  "null"
-) (= [ \n\t,}\]])
-members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]])
-members_suffix ::= (
-  value_non_str [ \n\t]* member_suffix_suffix |
-  "\"" characters_and_embrace |
-  "\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
-) (= [ \n\t,}\]])
-member_suffix_suffix ::= (
-  "}" |
-  "," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
-) (= [ \n\t,}\]])
-elements_or_embrace ::= (
-  "{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" |
-  "[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" |
-  "\"" characters_item elements_rest [ \n\t]* "]" |
-  "0" fraction exponent elements_rest [ \n\t]* "]" |
-  [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
-  "-" "0" fraction exponent elements_rest [ \n\t]* "]" |
-  "-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
-  "true" elements_rest [ \n\t]* "]" |
-  "false" elements_rest [ \n\t]* "]" |
-  "null" elements_rest [ \n\t]* "]" |
-  "]"
-)
-elements ::= (
-  "{" [ \n\t]* members_and_embrace elements_rest |
-  "[" [ \n\t]* elements_or_embrace elements_rest |
-  "\"" characters_item elements_rest |
-  "0" fraction exponent elements_rest |
-  [1-9] [0-9]* fraction exponent elements_rest |
-  "-" [0-9] fraction exponent elements_rest |
-  "-" [1-9] [0-9]* fraction exponent elements_rest |
-  "true" elements_rest |
-  "false" elements_rest |
-  "null" elements_rest
-)
-elements_rest ::= (
-  "" |
-  [ \n\t]* "," [ \n\t]* elements
-)
-characters_and_colon ::= (
-  "\"" [ \n\t]* ":" |
-  [^"\\\x00-\x1F] characters_and_colon |
-  "\\" escape characters_and_colon
-) (=[ \n\t]* [\"{[0-9tfn-])
-characters_and_comma ::= (
-  "\"" [ \n\t]* "," |
-  [^"\\\x00-\x1F] characters_and_comma |
-  "\\" escape characters_and_comma
-) (=[ \n\t]* "\"")
-characters_and_embrace ::= (
-  "\"" [ \n\t]* "}" |
-  [^"\\\x00-\x1F] characters_and_embrace |
-  "\\" escape characters_and_embrace
-) (=[ \n\t]* [},])
-characters_item ::= (
-  "\"" |
-  [^"\\\x00-\x1F] characters_item |
-  "\\" escape characters_item
-) (= [ \n\t]* [,\]])
-escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
-fraction ::= "" | "." [0-9] [0-9]*
-exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]*
-sign ::= "" | "+" | "-"
-)";
-
-BNFGrammar BNFGrammar::GetGrammarOfJSON() {
-  static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main");
-  return grammar;
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarGetGrammarOfJSON").set_body_typed([]() {
-  return BNFGrammar::GetGrammarOfJSON();
-});
-
-}  // namespace serve
-}  // namespace llm
-}  // namespace mlc
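For reference, a minimal sketch of how the API deleted above was driven from C++ before this migration (this functionality now comes from the 3rdparty/xgrammar submodule). The EBNF string and the `main` rule name follow the conventions in the deleted file; the `main()` wrapper is illustrative only and assumes the old headers are still on the include path:

```cpp
#include <iostream>

#include "grammar.h"  // header deleted by this PR

using namespace mlc::llm::serve;

int main() {
  // Parse and normalize a small EBNF grammar; "main" is the default entry
  // rule, matching BNFGrammar::FromEBNFString in the file removed above.
  BNFGrammar grammar = BNFGrammar::FromEBNFString("main ::= \"yes\" | \"no\"", "main");

  // The matcher-optimized JSON grammar shown above is cached statically.
  BNFGrammar json_grammar = BNFGrammar::GetGrammarOfJSON();

  // operator<< prints the grammar via BNFGrammarPrinter.
  std::cout << grammar << std::endl;
  return 0;
}
```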
diff --git a/cpp/grammar/grammar.h b/cpp/grammar/grammar.h
deleted file mode 100644
index 2e304dadb2..0000000000
--- a/cpp/grammar/grammar.h
+++ /dev/null
@@ -1,226 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar.h
- * \brief The header for the support of grammar-guided generation.
- */
-
-#ifndef MLC_LLM_GRAMMAR_GRAMMAR_H_
-#define MLC_LLM_GRAMMAR_GRAMMAR_H_
-
-#include <tvm/runtime/container/string.h>
-#include <tvm/runtime/object.h>
-
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <vector>
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-using namespace tvm;
-using namespace tvm::runtime;
-
-/*!
- * \brief This class stores the abstract syntax tree (AST) of the Backus-Naur Form (BNF) grammar.
- * The BNF definition here is standard BNF, and the characters are represented using regex-style
- * character classes (e.g. [a-z], [^a-z]).
- *
- * \details
- * ### Rules
- * The BNF grammar AST consists of a set of rules. Each rule contains a name and a definition, and
- * corresponds to a production in the grammar. The definition of a rule is a RuleExpr. Each rule
- * has a rule_id for reference.
- *
- * ### RuleExprs
- * RuleExpr is the definition of a rule or part of the definition of a rule. It can contain
- * elements, empty string, reference to other RuleExprs, or reference to other rules. Each RuleExpr
- * corresponds to an rule_expr_id for reference.
- *
- * For example, in the following rule: rule ::= ("a" "b") | "c"
- * ("a" "b"), "c", ("a" "b") | "c" are all RuleExprs.
- *
- * #### Types of RuleExprs
- * Every RuleExpr is represented by a type as well as a variable-length array containing its data.
- * RuleExpr has several types:
- * - Byte string: a string of bytes (0~255). Supports UTF-8 strings.
- * - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z],
- *   [ac-z]. Can be negated: [^a-z], [^ac-z]. Now only ascii chars is allowed in [], but this
- *   expression can accept/reject unicode chars.
- * - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*.
- * - EmptyStr: an empty string, i.e. ""
- * - Rule reference: a reference to another rule
- * - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together.
- * - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched.
- *
- * #### Storage of RuleExprs
- * Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see
- * docs in BNFGrammarNode::RuleExprType.
- *
- * We store all RuleExprs in csr_matrix style. That is, they are stored consecutively in one vector
- * (data vector) and the starting position of each RuleExpr is recorded in the indptr vector.
- *
- * \remark The character class star RuleExpr is for the special support for elements like [a-z]*
- * in the grammar. We add it to make the matching more efficient, as we can avoid recursion into
- * rules when matching a sequence of characters. It should be used like:
- * rule1 ::= ((element1 element2 rule2 ...) | ...)
- * rule2 ::= character_class_star_rule_expr(id_of_a_character_class_rule_expr)
- */
-class BNFGrammarNode : public Object {
- public:
-  /*! \brief A rule with name. */
-  struct Rule {
-    /*! \brief The name of the rule. */
-    std::string name;
-    /*! \brief The RuleExpr id of the body of the rule. */
-    int32_t body_expr_id;
-    /*! \brief The id of the associated lookahead assertion expr. For now it must be a id of a
-     * sequence RuleExpr. -1 if not exists. */
-    int32_t lookahead_assertion_id = -1;
-  };
-
-  /*! \brief Get the number of rules. */
-  size_t NumRules() const { return rules_.size(); }
-  /*! \brief Get the rule with the given id. */
*/ - const Rule& GetRule(int32_t rule_id) const { - DCHECK(rule_id >= 0 && rule_id < static_cast(rules_.size())) - << "rule_id " << rule_id << " is out of bound"; - return rules_[rule_id]; - } - /*! \brief Get the main rule id of the grammar. */ - int32_t GetMainRuleId() const { return main_rule_id_; } - /*! \brief Get the main rule of the grammar. */ - const Rule& GetMainRule() const { - DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast(rules_.size())) - << "main_rule_id " << main_rule_id_ << " is out of bound"; - return rules_[main_rule_id_]; - } - - /*! \brief The type of the rule expr. */ - enum class RuleExprType : int32_t { - // data format: [byte0, byte1, ...] - kByteString, - // data format: [is_negative, lower0, upper0, lower1, upper1, ...] - kCharacterClass, - kCharacterClassStar, - // data format: [] - kEmptyStr, - // data format: [rule_id] - kRuleRef, - // data format: [rule_expr_id0, rule_expr_id1, ...] - kSequence, - // data format: [rule_expr_id0, rule_expr_id1, ...] - kChoices, - }; - - /*! \brief The object representing a rule expr. */ - struct RuleExpr { - /*! \brief The type of the rule expr. */ - RuleExprType type; - /*! \brief The data of the RuleExpr. A variable-length array. */ - const int32_t* data; - /*! \brief The length of the data array. */ - int32_t data_len; - - const int32_t size() const { return data_len; } - /*! \brief Get the i-th element of the data array. */ - const int32_t& operator[](int i) const { - DCHECK(i >= 0 && i < static_cast(data_len)) << "Index " << i << " is out of bound"; - return data[i]; - } - const int32_t* begin() const { return data; } - const int32_t* end() const { return data + data_len; } - }; - - /*! \brief Get the number of rule_exprs. */ - size_t NumRuleExprs() const { return rule_expr_indptr_.size(); } - /*! \brief Get the rule_expr with the given id. */ - RuleExpr GetRuleExpr(int32_t rule_expr_id) const { - DCHECK(rule_expr_id >= 0 && rule_expr_id < static_cast(rule_expr_indptr_.size())) - << "rule_expr_id " << rule_expr_id << " is out of bound"; - int start_index = rule_expr_indptr_[rule_expr_id]; - auto start_ptr = rule_expr_data_.data() + start_index; - auto type = static_cast(start_ptr[0]); - auto data_ptr = start_ptr + 2; - auto data_len = start_ptr[1]; - return {type, data_ptr, data_len}; - } - - static constexpr const char* _type_key = "mlc.grammar.BNFGrammar"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(BNFGrammarNode, Object); - - private: - /*! \brief The rules of the grammar. rule_id corresponds the index of this vector. */ - std::vector rules_; - /*! \brief The data of all rule_exprs. */ - std::vector rule_expr_data_; - /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index - * to the elements in this vector. */ - std::vector rule_expr_indptr_; - /*! \brief The id of the main rule. */ - int32_t main_rule_id_ = -1; - - friend class BNFGrammarBuilder; - friend class BNFGrammarJSONSerializer; - friend class BNFJSONParser; -}; - -class BNFGrammar : public ObjectRef { - public: - /*! - * \brief Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized - * (simplified) by default. - * \param ebnf_string The EBNF-formatted string. - * \param main_rule The name of the main rule. - */ - static BNFGrammar FromEBNFString(const std::string& ebnf_string, - const std::string& main_rule = "main"); - - /*! 
-   * \brief Construct a BNF grammar from the dumped JSON string.
-   * \param json_string The JSON-formatted string. This string should have the same format as
-   * the result of BNFGrammarJSONSerializer::ToString.
-   */
-  static BNFGrammar FromJSON(const std::string& json_string);
-
-  /*!
-   * \brief Construct a BNF grammar from the json schema string. The schema string should be in the
-   * format of the schema of a JSON file. We will parse the schema and generate a BNF grammar.
-   * \param schema The schema string.
-   * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be
-   * in one line. Default: 2.
-   * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"},
-   * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the
-   * indent is not nullopt, and {", ", ": "} otherwise. This follows the convention in python
-   * json.dumps(). Default: std::nullopt.
-   * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not
-   * allow properties and items that is not specified in the schema. This is equivalent to
-   * setting unevaluatedProperties and unevaluatedItems to false.
-   *
-   * This helps LLM to generate accurate output in the grammar-guided generation with JSON
-   * schema. Default: true.
-   */
-  static BNFGrammar FromSchema(
-      const std::string& schema, std::optional<int> indent = std::nullopt,
-      std::optional<std::pair<std::string, std::string>> separators = std::nullopt,
-      bool strict_mode = true);
-
-  /*!
-   * \brief Get the grammar of standard JSON format. We have built-in support for JSON.
-   */
-  static BNFGrammar GetGrammarOfJSON();
-
-  /*! \brief Print a BNF grammar. */
-  friend std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(BNFGrammar, ObjectRef, BNFGrammarNode);
-};
-
-}  // namespace serve
-}  // namespace llm
-}  // namespace mlc
-
-#endif  // MLC_LLM_GRAMMAR_GRAMMAR_H_
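To make the csr_matrix-style storage described in the deleted header concrete, here is a hedged, standalone sketch (not part of the patch) of how one kSequence RuleExpr lands in the two flat arrays; the literal values are illustrative only:

```cpp
#include <cstdint>
#include <vector>

// Illustrative layout only: each RuleExpr is stored as [type, data_len, data...]
// in one flat data vector, and rule_expr_indptr_[i] records where RuleExpr i
// begins, exactly as AddRuleExpr/GetRuleExpr in the deleted code do.
int main() {
  std::vector<int32_t> rule_expr_data;
  std::vector<int32_t> rule_expr_indptr;

  // Append a kSequence expr referring to rule_exprs 3 and 5.
  // kSequence is the sixth enumerator (value 5) in RuleExprType above.
  rule_expr_indptr.push_back(static_cast<int32_t>(rule_expr_data.size()));
  rule_expr_data.insert(rule_expr_data.end(), {/*type=*/5, /*data_len=*/2, 3, 5});

  // GetRuleExpr(id) then reads the type at indptr[id], the length at
  // indptr[id] + 1, and the payload starting at indptr[id] + 2.
  return 0;
}
```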
diff --git a/cpp/grammar/grammar_builder.h b/cpp/grammar/grammar_builder.h
deleted file mode 100644
index 05ffaff4fe..0000000000
--- a/cpp/grammar/grammar_builder.h
+++ /dev/null
@@ -1,254 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_builder.h
- * \brief The header for the building the BNF AST.
- */
-
-#ifndef MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_
-#define MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_
-#include <cstdint>
-
-#include <unordered_map>
-
-#include "grammar.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-using namespace tvm;
-using namespace tvm::runtime;
-
-/*!
- * \brief Helper class to build a BNF grammar.
- */
-class BNFGrammarBuilder {
- public:
-  using Rule = BNFGrammarNode::Rule;
-  using RuleExprType = BNFGrammarNode::RuleExprType;
-  using RuleExpr = BNFGrammarNode::RuleExpr;
-
-  /*! \brief Default constructor. Creates a new grammar object. */
-  BNFGrammarBuilder() : grammar_(make_object<BNFGrammarNode>()) {}
-
-  /*!
-   * \brief Get the result grammar. This function will also set the main rule to the rule with the
-   * specified name. The rule should be already added to the grammar.
-   * \param main_rule The name of the main rule. Default is "main".
-   */
-  BNFGrammar Get(const std::string& main_rule = "main") {
-    int32_t main_rule_id = GetRuleId(main_rule);
-    CHECK(main_rule_id != -1) << "The main rule with name \"" << main_rule << "\" is not found.";
-    grammar_->main_rule_id_ = main_rule_id;
-
-    return BNFGrammar(grammar_);
-  }
-
-  /****************** RuleExpr handling ******************/
-
-  /*! \brief Add a rule_expr and return the rule_expr id. */
-  int32_t AddRuleExpr(const RuleExpr& rule_expr) {
-    grammar_->rule_expr_indptr_.push_back(grammar_->rule_expr_data_.size());
-    grammar_->rule_expr_data_.push_back(static_cast<int32_t>(rule_expr.type));
-    grammar_->rule_expr_data_.push_back(rule_expr.data_len);
-    grammar_->rule_expr_data_.insert(grammar_->rule_expr_data_.end(), rule_expr.data,
-                                     rule_expr.data + rule_expr.data_len);
-    return static_cast<int32_t>(grammar_->rule_expr_indptr_.size()) - 1;
-  }
-
-  /*!
-   * \brief Add a RuleExpr for string stored in bytes.
-   * \param bytes A vector of int32_t, each representing a byte (0~255) in the string.
-   * The string is stored in int32 vector to match the storage format of the grammar.
-   */
-  int32_t AddByteString(const std::vector<int32_t>& bytes) {
-    return AddRuleExpr(
-        {RuleExprType::kByteString, bytes.data(), static_cast<int32_t>(bytes.size())});
-  }
-
-  /*!
-   * \brief One element of a character class, containing a lower and a upper bound. Both bounds are
-   * inclusive.
-   */
-  struct CharacterClassElement {
-    int32_t lower;
-    int32_t upper;
-  };
-
-  /*!
-   * \brief Add a RuleExpr for a character class.
-   * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound.
-   * \param is_negative Whether the character class is negated.
-   */
-  int32_t AddCharacterClass(const std::vector<CharacterClassElement>& elements,
-                            bool is_negative = false) {
-    std::vector<int32_t> data;
-    data.reserve(1 + elements.size() * 2);
-    data.push_back(static_cast<int32_t>(is_negative));
-    for (const auto& range : elements) {
-      data.push_back(range.lower);
-      data.push_back(range.upper);
-    }
-    return AddRuleExpr(
-        {RuleExprType::kCharacterClass, data.data(), static_cast<int32_t>(data.size())});
-  }
-
-  /*!
-   * \brief Add a RuleExpr for a star quantifier of a character class.
-   * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound.
-   * \param is_negative Whether the character class is negated.
-   */
-  int32_t AddCharacterClassStar(const std::vector<CharacterClassElement>& elements,
-                                bool is_negative = false) {
-    std::vector<int32_t> data;
-    data.reserve(1 + elements.size() * 2);
-    data.push_back(static_cast<int32_t>(is_negative));
-    for (const auto& range : elements) {
-      data.push_back(range.lower);
-      data.push_back(range.upper);
-    }
-    return AddRuleExpr(
-        {RuleExprType::kCharacterClassStar, data.data(), static_cast<int32_t>(data.size())});
-  }
-
-  /*! \brief Add a RuleExpr for empty string.*/
-  int32_t AddEmptyStr() { return AddRuleExpr({RuleExprType::kEmptyStr, nullptr, 0}); }
-
-  /*! \brief Add a RuleExpr for rule reference.*/
-  int32_t AddRuleRef(int32_t rule_id) {
-    std::vector<int32_t> data;
-    data.push_back(rule_id);
-    return AddRuleExpr({RuleExprType::kRuleRef, data.data(), static_cast<int32_t>(data.size())});
-  }
-
-  /*! \brief Add a RuleExpr for RuleExpr sequence.*/
-  int32_t AddSequence(const std::vector<int32_t>& elements) {
-    return AddRuleExpr(
-        {RuleExprType::kSequence, elements.data(), static_cast<int32_t>(elements.size())});
-  }
-
-  /*! \brief Add a RuleExpr for RuleExpr choices.*/
-  int32_t AddChoices(const std::vector<int32_t>& choices) {
-    return AddRuleExpr(
-        {RuleExprType::kChoices, choices.data(), static_cast<int32_t>(choices.size())});
-  }
-
-  size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); }
-  /*! \brief Get the rule_expr with the given id. */
-  RuleExpr GetRuleExpr(int32_t rule_expr_id) { return grammar_->GetRuleExpr(rule_expr_id); }
-
-  /****************** Rule handling ******************/
-
-  /*! \brief Add a rule and return the rule id. */
-  int32_t AddRule(const Rule& rule) {
-    int32_t id = grammar_->rules_.size();
-    auto rules = grammar_->rules_;
-    grammar_->rules_.push_back(rule);
-    CHECK_EQ(rule_name_to_id_.count(rule.name), 0);
-    rule_name_to_id_[rule.name] = id;
-    return id;
-  }
-
-  int32_t AddRule(const std::string& name, int32_t body_expr_id) {
-    return AddRule({name, body_expr_id});
-  }
-
-  int32_t AddRuleWithHint(const std::string& name_hint, int32_t body_expr_id) {
-    return AddRule({GetNewRuleName(name_hint), body_expr_id});
-  }
-
-  size_t NumRules() const { return grammar_->NumRules(); }
-
-  /*! \brief Get the rule with the given id. */
-  const Rule& GetRule(int32_t rule_id) const { return grammar_->rules_[rule_id]; }
-
-  /*!
-   * \brief Add an rule without body, and return the rule id. The rule body should be set later
-   * with BNFGrammarBuilder::UpdateRuleBody. This method is useful for cases where the rule id is
-   * required to build the rule body.
-   * \sa BNFGrammarBuilder::UpdateRuleBody
-   */
-  int32_t AddEmptyRule(const std::string& name) { return AddRule({name, -1}); }
-
-  /*!
-   * \brief Update the rule body of the given rule, specified by rule id. Can be used to set the
-   * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule.
-   */
-  void UpdateRuleBody(int32_t rule_id, int32_t body_expr_id) {
-    CHECK(rule_id >= 0 && rule_id < static_cast<int32_t>(grammar_->rules_.size()))
-        << "Rule id " << rule_id << " is out of range.";
-    grammar_->rules_[rule_id].body_expr_id = body_expr_id;
-  }
-
-  /*!
-   * \brief Update the rule body of the given rule, specified by rule name. Can be used to set the
-   * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule.
-   */
-  void UpdateRuleBody(std::string rule_name, int32_t body_expr_id) {
-    int32_t rule_id = GetRuleId(rule_name);
-    CHECK(rule_id != -1) << "Rule " << rule_name << " is not found.";
-    UpdateRuleBody(rule_id, body_expr_id);
-  }
-
-  /*!
-   * \brief Add a lookahead assertion to a rule referred by the given rule_id. The lookahead
-   * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion.
-   */
-  void AddLookaheadAssertion(int32_t rule_id, int32_t lookahead_assertion_id) {
-    CHECK(rule_id < static_cast<int32_t>(grammar_->rules_.size()))
-        << "Rule id " << rule_id << " is out of range.";
-    CHECK(grammar_->rules_[rule_id].lookahead_assertion_id == -1)
-        << "Rule " << rule_id << " already has a lookahead assertion.";
-    grammar_->rules_[rule_id].lookahead_assertion_id = lookahead_assertion_id;
-  }
-
-  /*!
-   * \brief Add a lookahead assertion to a rule referred by the given name. The lookahead
-   * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion.
-   */
-  void AddLookaheadAssertion(std::string rule_name, int32_t lookahead_assertion_id) {
-    int32_t rule_id = GetRuleId(rule_name);
-    CHECK(rule_id != -1) << "Rule " << rule_name << " is not found.";
-    AddLookaheadAssertion(rule_id, lookahead_assertion_id);
-  }
-
-  /*!
-   * \brief Find a name for a new rule starting with the given name hint. Some integer suffix (_1,
-   * _2, ...) may be added to avoid name conflict.
-   */
-  std::string GetNewRuleName(const std::string& name_hint) {
-    if (rule_name_to_id_.count(name_hint) == 0) {
-      return name_hint;
-    } else {
-      int cnt = 1;
-      while (rule_name_to_id_.count(name_hint + "_" + std::to_string(cnt)) != 0) {
-        ++cnt;
-      }
-      return name_hint + "_" + std::to_string(cnt);
-    }
-  }
-
-  /*!
-   * \brief Get the rule id of the rule with the given name. Return -1 if not found.
-   */
-  int32_t GetRuleId(const std::string& name) const {
-    auto it = rule_name_to_id_.find(name);
-    if (it == rule_name_to_id_.end()) {
-      return -1;
-    } else {
-      return it->second;
-    }
-  }
-
- private:
-  // Mutable pointer to the grammar object.
-  ObjectPtr<BNFGrammarNode> grammar_;
-  // Map from rule name to rule id.
-  std::unordered_map<std::string, int32_t> rule_name_to_id_;
-};
-
-}  // namespace serve
-}  // namespace llm
-}  // namespace mlc
-
-#endif  // MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_
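A hedged sketch of how the deleted builder API fit together, building `main ::= "a" | "b"` by hand; all calls come from the removed header above, and only the `main()` wrapper is invented for illustration:

```cpp
#include "grammar.h"
#include "grammar_builder.h"  // header deleted by this PR

using namespace mlc::llm::serve;

int main() {
  BNFGrammarBuilder builder;

  // The rule must exist before Get("main") can resolve the main rule id.
  int32_t rule_id = builder.AddEmptyRule("main");

  // "a" and "b" as byte strings (bytes are stored as int32_t).
  int32_t a_id = builder.AddByteString({97});  // 'a'
  int32_t b_id = builder.AddByteString({98});  // 'b'

  // main ::= "a" | "b"
  builder.UpdateRuleBody(rule_id, builder.AddChoices({a_id, b_id}));

  BNFGrammar grammar = builder.Get("main");
  (void)grammar;
  return 0;
}
```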
diff --git a/cpp/grammar/grammar_functor.cc b/cpp/grammar/grammar_functor.cc
deleted file mode 100644
index 32378c559f..0000000000
--- a/cpp/grammar/grammar_functor.cc
+++ /dev/null
@@ -1,320 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_functor.cc
- */
-
-#include "grammar_functor.h"
-
-#include "../support/encoding.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-/*!
- * \brief Eliminates single-element sequence or choice or character class in the grammar.
- * \example `A ::= choices("a")` --> `A ::= "a"` (the body is a string)
- * \example `A ::= sequence("a")` --> `A ::= "a"` (the body is a string)
- * \example `A ::= [a-a]` --> `A ::= "a"` (the body is a string)
- */
-class SingleElementExprEliminator : public BNFGrammarMutator {
- public:
-  using BNFGrammarMutator::Apply;
-  using BNFGrammarMutator::BNFGrammarMutator;
-
- private:
-  // Keep the sequence expr in lookahead assertion
-  int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final {
-    if (lookahead_assertion_id == -1) {
-      return -1;
-    }
-    auto rule_expr = grammar_->GetRuleExpr(lookahead_assertion_id);
-    CHECK(rule_expr.type == RuleExprType::kSequence);
-
-    std::vector<int32_t> sequence_ids;
-    for (int32_t i : rule_expr) {
-      sequence_ids.push_back(VisitExpr(i));
-    }
-    return builder_.AddSequence(sequence_ids);
-  }
-
-  int32_t VisitSequence(const RuleExpr& rule_expr) final {
-    std::vector<int32_t> sequence_ids;
-    for (int32_t i : rule_expr) {
-      sequence_ids.push_back(VisitExpr(i));
-    }
-    if (sequence_ids.size() == 1) {
-      return sequence_ids[0];
-    }
-    return builder_.AddSequence(sequence_ids);
-  }
-
-  int32_t VisitChoices(const RuleExpr& rule_expr) final {
-    std::vector<int32_t> choice_ids;
-    for (int32_t i : rule_expr) {
-      choice_ids.push_back(VisitExpr(i));
-    }
-    if (choice_ids.size() == 1) {
-      return choice_ids[0];
-    }
-    return builder_.AddChoices(choice_ids);
-  }
-
-  int32_t VisitCharacterClass(const RuleExpr& rule_expr) final {
-    if (rule_expr.data_len == 3 && rule_expr[0] == 0 && rule_expr[1] == rule_expr[2]) {
-      std::string str = PrintAsUTF8(rule_expr[1]);
-      std::vector<int32_t> bytes;
-      bytes.reserve(str.size());
-      for (char c : str) {
-        bytes.push_back(static_cast<int32_t>(c));
-      }
-      return builder_.AddByteString(bytes);
-    }
-    return builder_.AddRuleExpr(rule_expr);
-  }
-};
-
-/*!
- * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in
- * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`.
- *
- * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class
- * or a rule reference. And if the rule can be empty, the first choice will be an empty string.
- *
- * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice
- * containing a sequence of three elements. The empty string is removed.
- * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by
- * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three
- * choices is a sequence containing a single element.
- * \example The rule `A ::= (a | (b (c | d)))` will be replaced by
- * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested
- * choices.
- */
-class NestedRuleUnwrapper : public BNFGrammarMutator {
- public:
-  using BNFGrammarMutator::BNFGrammarMutator;
-
-  BNFGrammar Apply(const BNFGrammar& grammar) final {
-    Init(grammar);
-    for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
-      builder_.AddEmptyRule(grammar_->GetRule(i).name);
-    }
-    for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
-      auto rule = grammar_->GetRule(i);
-      auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id);
-      cur_rule_name_ = rule.name;
-      auto new_body_expr_id = VisitRuleBody(rule_expr);
-      builder_.UpdateRuleBody(i, new_body_expr_id);
-      builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
-    }
-    return builder_.Get(grammar_->GetMainRule().name);
-  }
-
- private:
-  int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final {
-    if (lookahead_assertion_id == -1) {
-      return -1;
-    }
-    auto assertion_expr = grammar_->GetRuleExpr(lookahead_assertion_id);
-    return builder_.AddSequence(VisitSequence_(assertion_expr));
-  }
-
-  /*! \brief Visit a RuleExpr as a rule body. */
-  int32_t VisitRuleBody(const RuleExpr& rule_expr) {
-    switch (rule_expr.type) {
-      case RuleExprType::kSequence:
-        return builder_.AddChoices({builder_.AddSequence(VisitSequence_(rule_expr))});
-      case RuleExprType::kChoices:
-        return builder_.AddChoices(VisitChoices_(rule_expr));
-      case RuleExprType::kEmptyStr:
-        return builder_.AddChoices({builder_.AddEmptyStr()});
-      case RuleExprType::kByteString:
-      case RuleExprType::kCharacterClass:
-      case RuleExprType::kCharacterClassStar:
-      case RuleExprType::kRuleRef:
-        return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})});
-      default:
-        LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type);
-    }
-  }
-
-  /*!
-   * \brief Visit a RuleExpr containing choices.
-   * \returns A list of new choice RuleExpr ids.
-   */
-  std::vector<int32_t> VisitChoices_(const RuleExpr& rule_expr) {
-    std::vector<int32_t> new_choice_ids;
-    bool found_empty = false;
-    for (auto i : rule_expr) {
-      auto choice_expr = grammar_->GetRuleExpr(i);
-      switch (choice_expr.type) {
-        case RuleExprType::kSequence:
-          VisitSequenceInChoices(choice_expr, &new_choice_ids, &found_empty);
-          break;
-        case RuleExprType::kChoices:
-          VisitChoicesInChoices(choice_expr, &new_choice_ids, &found_empty);
-          break;
-        case RuleExprType::kEmptyStr:
-          found_empty = true;
-          break;
-        case RuleExprType::kByteString:
-        case RuleExprType::kCharacterClass:
-        case RuleExprType::kCharacterClassStar:
-        case RuleExprType::kRuleRef:
-          VisitElementInChoices(choice_expr, &new_choice_ids);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected choice type: " << static_cast<int>(choice_expr.type);
-      }
-    }
-    if (found_empty) {
-      new_choice_ids.insert(new_choice_ids.begin(), builder_.AddEmptyStr());
-    }
-    ICHECK_GE(new_choice_ids.size(), 1);
-    return new_choice_ids;
-  }
-
-  /*! \brief Visit a sequence RuleExpr that is one of a list of choices. */
-  void VisitSequenceInChoices(const RuleExpr& rule_expr, std::vector<int32_t>* new_choice_ids,
-                              bool* found_empty) {
-    auto sub_sequence_ids = VisitSequence_(rule_expr);
-    if (sub_sequence_ids.size() == 0) {
-      *found_empty = true;
-    } else {
-      new_choice_ids->push_back(builder_.AddSequence(sub_sequence_ids));
-    }
-  }
-
-  /*! \brief Visit a choice RuleExpr that is one of a list of choices. */
-  void VisitChoicesInChoices(const RuleExpr& rule_expr, std::vector<int32_t>* new_choice_ids,
-                             bool* found_empty) {
-    auto sub_choice_ids = VisitChoices_(rule_expr);
-    bool contains_empty = builder_.GetRuleExpr(sub_choice_ids[0]).type == RuleExprType::kEmptyStr;
-    if (contains_empty) {
-      *found_empty = true;
-      new_choice_ids->insert(new_choice_ids->end(), sub_choice_ids.begin() + 1,
-                             sub_choice_ids.end());
-    } else {
-      new_choice_ids->insert(new_choice_ids->end(), sub_choice_ids.begin(), sub_choice_ids.end());
-    }
-  }
-
-  /*! \brief Visit an atom element RuleExpr that is one of a list of choices. */
-  void VisitElementInChoices(const RuleExpr& rule_expr, std::vector<int32_t>* new_choice_ids) {
-    auto sub_expr_id = builder_.AddRuleExpr(rule_expr);
-    new_choice_ids->push_back(builder_.AddSequence({sub_expr_id}));
-  }
-
-  /*!
-   * \brief Visit a RuleExpr containing a sequence.
-   * \returns A list of new sequence RuleExpr ids.
-   */
-  std::vector<int32_t> VisitSequence_(const RuleExpr& rule_expr) {
-    std::vector<int32_t> new_sequence_ids;
-    for (auto i : rule_expr) {
-      auto element_expr = grammar_->GetRuleExpr(i);
-      switch (element_expr.type) {
-        case RuleExprType::kSequence:
-          VisitSequenceInSequence(element_expr, &new_sequence_ids);
-          break;
-        case RuleExprType::kChoices:
-          VisitChoiceInSequence(element_expr, &new_sequence_ids);
-          break;
-        case RuleExprType::kEmptyStr:
-          break;
-        case RuleExprType::kByteString:
-        case RuleExprType::kCharacterClass:
-        case RuleExprType::kCharacterClassStar:
-        case RuleExprType::kRuleRef:
-          VisitElementInSequence(element_expr, &new_sequence_ids);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(element_expr.type);
-      }
-    }
-    return new_sequence_ids;
-  }
-
-  /*! \brief Visit a sequence RuleExpr that is one element in another sequence. */
-  void VisitSequenceInSequence(const RuleExpr& rule_expr, std::vector<int32_t>* new_sequence_ids) {
-    auto sub_sequence_ids = VisitSequence_(rule_expr);
-    new_sequence_ids->insert(new_sequence_ids->end(), sub_sequence_ids.begin(),
-                             sub_sequence_ids.end());
-  }
-
-  /*! \brief Visit a choice RuleExpr that is one element in a sequence. */
-  void VisitChoiceInSequence(const RuleExpr& rule_expr, std::vector<int32_t>* new_sequence_ids) {
-    auto sub_choice_ids = VisitChoices_(rule_expr);
-    if (sub_choice_ids.size() == 1) {
-      auto choice_element_expr = builder_.GetRuleExpr(sub_choice_ids[0]);
-      if (choice_element_expr.type != RuleExprType::kEmptyStr) {
-        new_sequence_ids->insert(new_sequence_ids->end(), choice_element_expr.begin(),
-                                 choice_element_expr.end());
-      }
-    } else {
-      auto new_choice_id = builder_.AddChoices(sub_choice_ids);
-      auto new_choice_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_choice", new_choice_id);
-      new_sequence_ids->push_back(builder_.AddRuleRef(new_choice_rule_id));
-    }
-  }
-
-  /*! \brief Visit an atom element RuleExpr that is in a sequence. */
-  void VisitElementInSequence(const RuleExpr& rule_expr, std::vector<int32_t>* new_sequence_ids) {
-    new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr));
-  }
-};
-
-class ByteStringFuser : public BNFGrammarMutator {
- public:
-  using BNFGrammarMutator::Apply;
-  using BNFGrammarMutator::BNFGrammarMutator;
-
- private:
-  /*!
-   * \brief Visit a RuleExpr containing a sequence.
-   * \returns A list of new sequence RuleExpr ids.
-   */
-  int32_t VisitSequence(const RuleExpr& rule_expr) final {
-    std::vector<int32_t> new_sequence_ids;
-    std::vector<int32_t> cur_byte_string;
-    for (auto i : rule_expr) {
-      auto element_expr = grammar_->GetRuleExpr(i);
-      if (element_expr.type == RuleExprType::kByteString) {
-        cur_byte_string.insert(cur_byte_string.end(), element_expr.begin(), element_expr.end());
-        continue;
-      } else {
-        if (!cur_byte_string.empty()) {
-          new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string));
-          cur_byte_string.clear();
-        }
-        new_sequence_ids.push_back(builder_.AddRuleExpr(element_expr));
-      }
-    }
-    if (!cur_byte_string.empty()) {
-      new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string));
-    }
-    return builder_.AddSequence(new_sequence_ids);
-  }
-};
-
-// Return the list of all normalizers in the class. The normalizers are applied one by one.
-std::vector<std::unique_ptr<BNFGrammarMutator>> BNFGrammarNormalizer::GetNormalizerList() {
-  std::vector<std::unique_ptr<BNFGrammarMutator>> normalizer_mutators;
-  normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>());
-  normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>());
-  normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>());
-  return normalizer_mutators;
-}
-
-BNFGrammar BNFGrammarNormalizer::Apply(const BNFGrammar& grammar) {
-  std::vector<std::unique_ptr<BNFGrammarMutator>> normalizer_mutators = GetNormalizerList();
-  grammar_ = grammar;
-  for (auto& mutator : normalizer_mutators) {
-    grammar_ = mutator->Apply(grammar_);
-  }
-  return grammar_;
-}
-
-}  // namespace serve
-}  // namespace llm
-}  // namespace mlc
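Taken together, the three mutators deleted above ran as a fixed pipeline inside BNFGrammarNormalizer::Apply. A hedged sketch of the observable effect on a nested rule, using only the deleted API (FromEBNFString already normalizes, so the explicit Apply below is redundant and shown purely for illustration):

```cpp
#include "grammar.h"
#include "grammar_functor.h"  // headers deleted by this PR

using namespace mlc::llm::serve;

int main() {
  // Nested choices, per the example in the NestedRuleUnwrapper docs above.
  BNFGrammar grammar = BNFGrammar::FromEBNFString(
      "main ::= \"x\" | (\"y\" (\"c\" | \"d\"))", "main");
  grammar = BNFGrammarNormalizer().Apply(grammar);
  // Expected shape after unwrapping:
  //   main ::= (("x") | ("y" main_choice))
  //   main_choice ::= (("c") | ("d"))
  // where the new rule name comes from AddRuleWithHint's "_choice" hint.
  return 0;
}
```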
diff --git a/cpp/grammar/grammar_functor.h b/cpp/grammar/grammar_functor.h
deleted file mode 100644
index 07da50519d..0000000000
--- a/cpp/grammar/grammar_functor.h
+++ /dev/null
@@ -1,219 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_functor.h
- * \brief The header for the simplification of the BNF AST.
- */
-
-#ifndef MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_
-#define MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_
-
-#include <memory>
-#include <string>
-
-#include "grammar.h"
-#include "grammar_builder.h"
-#include "grammar_serializer.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-/*!
- * \brief Base class for visitors and mutators of the BNF grammar.
- * \tparam T The type of the return value of visitor functions. Typical values:
- * - int32_t: the id of the new rule_expr
- * - void: no return value
- * \tparam ReturnType The type of the return value of the transform function Apply(). Typical values
- * are void (for visitor) and BNFGrammar (for mutator).
- */
-template <typename T, typename ReturnType>
-class BNFGrammarFunctor {
- public:
-  /*!
-   * \brief Constructor.
-   * \param grammar The grammar to visit or mutate.
-   */
-  explicit BNFGrammarFunctor() {}
-
-  /*!
-   * \brief Apply the transformation to the grammar, or visit the grammar.
-   * \return The transformed grammar, or the visiting result, or void.
-   */
-  virtual ReturnType Apply(const BNFGrammar& grammar) {
-    Init(grammar);
-    if constexpr (std::is_same<T, void>::value) {
-      for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
-        auto rule = grammar_->GetRule(i);
-        cur_rule_name_ = rule.name;
-        VisitExpr(rule.body_expr_id);
-        VisitLookaheadAssertion(rule.lookahead_assertion_id);
-      }
-    } else if constexpr (std::is_same<T, int32_t>::value &&
-                         std::is_same<ReturnType, BNFGrammar>::value) {
-      // First add empty rules to ensure the new rule ids the same as the old ones, then update
-      // the rule bodies
-      for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
-        builder_.AddEmptyRule(grammar_->GetRule(i).name);
-      }
-      for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
-        auto rule = grammar_->GetRule(i);
-        cur_rule_name_ = rule.name;
-        auto new_body_expr_id = VisitExpr(rule.body_expr_id);
-        builder_.UpdateRuleBody(i, new_body_expr_id);
-        // Handle lookahead assertion
-        builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
-      }
-      return builder_.Get(grammar_->GetMainRule().name);
-    } else {
-      return ReturnType();
-    }
-  }
-
- protected:
-  using Rule = BNFGrammarNode::Rule;
-  using RuleExpr = BNFGrammarNode::RuleExpr;
-  using RuleExprType = BNFGrammarNode::RuleExprType;
-
-  /*! \brief Initialize the functor. Should be called at the beginning of Apply(). */
-  virtual void Init(const BNFGrammar& grammar) {
-    grammar_ = grammar;
-    builder_ = BNFGrammarBuilder();
-  }
-
-  /*! \brief Visit a lookahead assertion expr referred by id. */
-  virtual T VisitLookaheadAssertion(int32_t lookahead_assertion_id) {
-    if (lookahead_assertion_id == -1) {
-      return -1;
-    }
-    return VisitExpr(lookahead_assertion_id);
-  }
-
-  /*! \brief Visit a RuleExpr by id. */
-  virtual T VisitExpr(int32_t old_rule_expr_id) {
-    return VisitExpr(grammar_->GetRuleExpr(old_rule_expr_id));
-  }
-
-  /*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */
-  virtual T VisitExpr(const RuleExpr& rule_expr) {
-    switch (rule_expr.type) {
-      case RuleExprType::kSequence:
-        return VisitSequence(rule_expr);
-      case RuleExprType::kChoices:
-        return VisitChoices(rule_expr);
-      case RuleExprType::kEmptyStr:
-        return VisitEmptyStr(rule_expr);
-      case RuleExprType::kByteString:
-        return VisitByteString(rule_expr);
-      case RuleExprType::kCharacterClass:
-        return VisitCharacterClass(rule_expr);
-      case RuleExprType::kCharacterClassStar:
-        return VisitCharacterClassStar(rule_expr);
-      case RuleExprType::kRuleRef:
-        return VisitRuleRef(rule_expr);
-      default:
-        LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type);
-    }
-  }
-
-  /*! \brief Visit a choices RuleExpr. */
-  virtual T VisitChoices(const RuleExpr& rule_expr) {
-    if constexpr (std::is_same<T, void>::value) {
-      for (auto i : rule_expr) {
-        VisitExpr(i);
-      }
-    } else if constexpr (std::is_same<T, int32_t>::value) {
-      std::vector<int32_t> choice_ids;
-      for (int32_t i : rule_expr) {
-        choice_ids.push_back(VisitExpr(i));
-      }
-      return builder_.AddChoices(choice_ids);
-    } else {
-      return T();
-    }
-  }
-
-  /*! \brief Visit a sequence RuleExpr. */
-  virtual T VisitSequence(const RuleExpr& rule_expr) {
-    if constexpr (std::is_same<T, void>::value) {
-      for (auto i : rule_expr) {
-        VisitExpr(i);
-      }
-    } else if constexpr (std::is_same<T, int32_t>::value) {
-      std::vector<int32_t> sequence_ids;
-      for (int32_t i : rule_expr) {
-        sequence_ids.push_back(VisitExpr(i));
-      }
-      return builder_.AddSequence(sequence_ids);
-    } else {
-      return T();
-    }
-  }
-
-  /*! \brief Visit an element RuleExpr, including empty string, character class, and rule ref. */
-  virtual T VisitElement(const RuleExpr& rule_expr) {
-    if constexpr (std::is_same<T, void>::value) {
-      return;
-    } else if constexpr (std::is_same<T, int32_t>::value) {
-      return builder_.AddRuleExpr(rule_expr);
-    } else {
-      return T();
-    }
-  }
-
-  /*! \brief Visit an empty string RuleExpr. */
-  virtual T VisitEmptyStr(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }
-
-  /*! \brief Visit a character class RuleExpr. */
-  virtual T VisitByteString(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }
-
-  /*! \brief Visit a character class RuleExpr. */
-  virtual T VisitCharacterClass(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }
-
-  /*! \brief Visit a star quantifier RuleExpr. */
-  virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }
-
-  /*! \brief Visit a rule reference RuleExpr. */
-  virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }
-
-  /*! \brief The grammar to visit or mutate. */
-  BNFGrammar grammar_;
-  /*!
-   * \brief The builder to build the new grammar. It is empty when the mutator is constructed, and
-   * can be used to build a new grammar in subclasses.
-   */
-  BNFGrammarBuilder builder_;
-  /*! \brief The name of the current rule being visited. */
-  std::string cur_rule_name_;
-};
-
-/*!
- * \brief Visitor of BNFGrammar.
- * \tparam ReturnType The return type of the Apply() function. Denotes the collected information.
- */
-template <typename ReturnType>
-using BNFGrammarVisitor = BNFGrammarFunctor<void, ReturnType>;
-
-/*!
- * \brief Mutator of BNFGrammar. The Apply() function returns the updated grammar.
- */
-using BNFGrammarMutator = BNFGrammarFunctor<int32_t, BNFGrammar>;
-
-/*!
- * \brief Normalize a BNFGrammar: expand the nested rules, combine consequent sequences and strings,
- * etc.
- */
-class BNFGrammarNormalizer : public BNFGrammarMutator {
- public:
-  using BNFGrammarMutator::BNFGrammarMutator;
-
-  BNFGrammar Apply(const BNFGrammar& grammar) final;
-
- private:
-  std::vector<std::unique_ptr<BNFGrammarMutator>> GetNormalizerList();
-};
-
-}  // namespace serve
-}  // namespace llm
-}  // namespace mlc
-
-#endif  // MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_
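As an example of extending the functor template above, a hedged sketch of a new pass; the class itself is hypothetical, but the subclassing pattern (override only the Visit functions you care about, let the base Apply with T = int32_t, ReturnType = BNFGrammar rebuild the rules) mirrors the mutators in grammar_functor.cc exactly:

```cpp
#include "grammar_functor.h"  // header deleted by this PR

using namespace mlc::llm::serve;

// Hypothetical identity pass over character classes. A real pass would
// inspect or transform rule_expr before re-adding it; copying it through is
// precisely what the default VisitElement does.
class IdentityCharClassPass : public BNFGrammarMutator {
 public:
  using BNFGrammarMutator::Apply;
  using BNFGrammarMutator::BNFGrammarMutator;

 private:
  int32_t VisitCharacterClass(const RuleExpr& rule_expr) final {
    return builder_.AddRuleExpr(rule_expr);
  }
};
```

Applied as `BNFGrammar out = IdentityCharClassPass().Apply(grammar);`, the same way BNFGrammarNormalizer chains its passes.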
diff --git a/cpp/grammar/grammar_parser.cc b/cpp/grammar/grammar_parser.cc
deleted file mode 100644
index b585798d3c..0000000000
--- a/cpp/grammar/grammar_parser.cc
+++ /dev/null
@@ -1,485 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_parser.cc
- */
-
-#include "grammar_parser.h"
-
-#include "../support/encoding.h"
-#include "../support/json_parser.h"
-#include "grammar_builder.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-class EBNFParserImpl {
- public:
-  /*! \brief The logic of parsing the grammar string. */
-  BNFGrammar DoParse(std::string ebnf_string, std::string main_rule);
-
- private:
-  using Rule = BNFGrammarNode::Rule;
-  using ParseError = EBNFParser::ParseError;
-
-  // Parsing different parts of the grammar
-  std::string ParseName(bool accept_empty = false);
-  int32_t ParseCharacterClass();
-  int32_t ParseString();
-  int32_t ParseRuleRef();
-  int32_t ParseElement();
-  int32_t ParseQuantifier();
-  int32_t ParseLookaheadAssertion();
-  int32_t ParseSequence();
-  int32_t ParseChoices();
-  Rule ParseRule();
-
-  // Helper functions
-  // Helper for ParseQuantifier
-  int32_t HandleStarQuantifier(int32_t rule_expr_id);
-  int32_t HandlePlusQuantifier(int32_t rule_expr_id);
-  int32_t HandleQuestionQuantifier(int32_t rule_expr_id);
-
-  // When parsing, we first find the names of all rules, and build the mapping from name to rule id.
-  void BuildRuleNameToId();
-  // Consumes several spaces (newline, space, tab, comment, etc.)
-  void ConsumeSpace(bool allow_newline = true);
-  // Check the validity of a name
-  static bool IsNameChar(TCodepoint c, bool first_char = false);
-  // Reset the parser to the beginning of the string.
-  void ResetStringIterator(const char* cur);
-
-  // Consume a specified number of characters, and maintain the line and column number.
-  void Consume(int cnt = 1) {
-    for (int i = 0; i < cnt; ++i) {
-      // \n \r \r\n
-      if (Peek() == '\n' || (Peek() == '\r' && Peek(1) != '\n')) {
-        ++cur_line_;
-        cur_column_ = 1;
-      } else {
-        ++cur_column_;
-      }
-      ++cur_;
-    }
-  }
-
-  // Peek the next character.
-  char Peek(int delta = 0) const { return *(cur_ + delta); }
-
-  // Throw a ParseError with the given message and the line and column number.
-  [[noreturn]] void ThrowParseError(const std::string& msg) {
-    throw ParseError("EBNF parse error at line " + std::to_string(cur_line_) + ", column " +
-                     std::to_string(cur_column_) + ": " + msg);
-  }
-
-  // The grammar builder
-  BNFGrammarBuilder builder_;
-  // A pointer to the current parse position in the string
-  const char* cur_ = nullptr;
-  // The current line and column number
-  int cur_line_ = 1;
-  int cur_column_ = 1;
-  // The current rule name. Help to generate a name for a new rule.
-  std::string cur_rule_name_;
-  // Whether the current element is in parentheses.
-  // A sequence expression cannot contain newline, unless it is in parentheses.
-  bool in_parentheses_ = false;
-};
-
-void EBNFParserImpl::ConsumeSpace(bool allow_newline) {
-  while (Peek() && (Peek() == ' ' || Peek() == '\t' || Peek() == '#' ||
-                    (allow_newline && (Peek() == '\n' || Peek() == '\r')))) {
-    Consume();
-    if (Peek(-1) == '#') {
-      while (Peek() && Peek() != '\n' && Peek() != '\r') {
-        Consume();
-      }
-      if (!Peek()) {
-        return;
-      }
-      Consume();
-      if (Peek(-1) == '\r' && Peek() == '\n') {
-        Consume();
-      }
-    }
-  }
-}
-
-bool EBNFParserImpl::IsNameChar(TCodepoint c, bool first_char) {
-  return c == '_' || c == '-' || c == '.' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
-         (!first_char && c >= '0' && c <= '9');
-}
-
-// name should be a char string (not a utf8 string)
-std::string EBNFParserImpl::ParseName(bool accept_empty) {
-  auto start = cur_;
-  bool first_char = true;
-  while (Peek() && IsNameChar(Peek(), first_char)) {
-    Consume();
-    first_char = false;
-  }
-  if (start == cur_ && !accept_empty) {
-    ThrowParseError("Expect rule name");
-  }
-  return std::string(start, cur_);
-}
-
-// Character class:
-// 1. Examples: [a-z] [ab] [a-zA-Z0-9] [^a-z] [测] [\u0123]
-// 2. The "-" character is treated as a literal character if it is the last or the first (after
-//    the "^"", if present) character within the brackets. E.g. [a-] and [-a] means "a" or "-"
-// 3. "-" and "]" should be escaped when used as a literal character:
-//    [\-] means -
-//    [\]] means ]
-// Character class should not contain newlines.
-int32_t EBNFParserImpl::ParseCharacterClass() {
-  static constexpr TCodepoint kUnknownUpperBound = -4;
-  static const std::unordered_map<std::string, TCodepoint> kCustomEscapeMap = {{"\\-", '-'},
-                                                                               {"\\]", ']'}};
-
-  std::vector<BNFGrammarBuilder::CharacterClassElement> elements;
-
-  bool is_negated = false;
-  if (Peek() == '^') {
-    is_negated = true;
-    Consume();
-  }
-
-  bool past_is_hyphen = false;
-  bool past_is_single_char = false;
-  while (Peek() && Peek() != ']') {
-    if (Peek() == '\r' || Peek() == '\n') {
-      ThrowParseError("Character class should not contain newline");
-    } else if (Peek() == '-' && Peek(1) != ']' && !past_is_hyphen && past_is_single_char) {
-      Consume();
-      past_is_hyphen = true;
-      past_is_single_char = false;
-      continue;
-    }
-
-    auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap);
-    if (codepoint == CharHandlingError::kInvalidUTF8) {
-      ThrowParseError("Invalid UTF8 sequence");
-    }
-    if (codepoint == CharHandlingError::kInvalidEscape) {
-      ThrowParseError("Invalid escape sequence");
-    }
-    Consume(new_cur - cur_);
-    if (past_is_hyphen) {
-      ICHECK(!elements.empty());
-      if (elements.back().lower > codepoint) {
-        ThrowParseError("Invalid character class: lower bound is larger than upper bound");
-      }
-      elements.back().upper = codepoint;
-      past_is_hyphen = false;
-      ICHECK(past_is_single_char == false);
-    } else {
-      elements.push_back({codepoint, kUnknownUpperBound});
-      past_is_single_char = true;
-    }
-  }
-
-  for (auto& element : elements) {
-    if (element.upper == kUnknownUpperBound) {
-      element.upper = element.lower;
-    }
-  }
-
-  return builder_.AddCharacterClass(elements, is_negated);
-}
-
-// parse a c style string with utf8 support
-int32_t EBNFParserImpl::ParseString() {
-  std::vector<TCodepoint> codepoints;
-  while (Peek() && Peek() != '\"') {
-    if (Peek() == '\r' || Peek() == '\n') {
-      ThrowParseError("There should be no newline character in a string literal");
-    }
-
-    auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_);
-    if (codepoint == CharHandlingError::kInvalidUTF8) {
-      ThrowParseError("Invalid utf8 sequence");
-    }
-    if (codepoint == CharHandlingError::kInvalidEscape) {
-      ThrowParseError("Invalid escape sequence");
-    }
-    Consume(new_cur - cur_);
-    codepoints.push_back(codepoint);
-  }
-  if (codepoints.empty()) {
-    return builder_.AddEmptyStr();
-  }
-
-  // convert codepoints to string
-  std::string str;
-  for (auto codepoint : codepoints) {
-    str += PrintAsUTF8(codepoint);
-  }
-  // convert str to int32_t vector
-  std::vector<int32_t> bytes;
-  for (auto c : str) {
-    bytes.push_back(static_cast<int32_t>(c));
-  }
-  return builder_.AddByteString(bytes);
-}
-
-int32_t EBNFParserImpl::ParseRuleRef() {
-  std::string name = ParseName();
-  auto rule_id = builder_.GetRuleId(name);
-  if (rule_id == -1) {
-    ThrowParseError("Rule \"" + name + "\" is not defined");
-  }
-  return builder_.AddRuleRef(rule_id);
-}
-
-int32_t EBNFParserImpl::ParseElement() {
-  switch (Peek()) {
-    case '(': {
-      Consume();
-      ConsumeSpace();
-      auto prev_in_parentheses = in_parentheses_;
-      in_parentheses_ = true;
-      auto rule_expr_id = ParseChoices();
-      ConsumeSpace();
-      if (Peek() != ')') {
-        ThrowParseError("Expect )");
-      }
-      Consume();
-      in_parentheses_ = prev_in_parentheses;
-      return rule_expr_id;
-    }
-    case '[': {
-      Consume();
-      auto rule_expr_id = ParseCharacterClass();
-      if (Peek() != ']') {
-        ThrowParseError("Expect ]");
-      }
-      Consume();
-      return rule_expr_id;
-    }
-    case '\"': {
-      Consume();
-      auto rule_expr_id = ParseString();
-      if (Peek() != '\"') {
-        ThrowParseError("Expect \"");
-      }
-      Consume();
-      return rule_expr_id;
-    }
-    default: {
-      if (IsNameChar(Peek(), true)) {
true)) { - return ParseRuleRef(); - } - ThrowParseError("Expect element"); - } - } -} - -int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - BNFGrammarNode::RuleExpr rule_expr = builder_.GetRuleExpr(rule_expr_id); - if (rule_expr.type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { - // We have special handling for character class star, e.g. [a-z]* - rule_expr.type = BNFGrammarBuilder::RuleExprType::kCharacterClassStar; - return builder_.AddRuleExpr(rule_expr); - } else { - // For other star quantifiers, we transform it into a rule: - // a* --> rule ::= a rule | "" - auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_id = builder_.AddEmptyRule(new_rule_name); - auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id); - auto new_rule_expr_id = builder_.AddChoices( - {builder_.AddSequence({rule_expr_id, ref_to_new_rule}), builder_.AddEmptyStr()}); - builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); - - // Return the reference to the new rule - return builder_.AddRuleRef(new_rule_id); - } -} - -int32_t EBNFParserImpl::HandlePlusQuantifier(int32_t rule_expr_id) { - // a+ --> rule ::= a rule | a - auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_id = builder_.AddEmptyRule(new_rule_name); - auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id); - auto new_rule_expr_id = - builder_.AddChoices({builder_.AddSequence({rule_expr_id, ref_to_new_rule}), rule_expr_id}); - builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); - - // Return the reference to the new rule - return builder_.AddRuleRef(new_rule_id); -} - -int32_t EBNFParserImpl::HandleQuestionQuantifier(int32_t rule_expr_id) { - // a? --> rule ::= a | empty - auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_expr_id = builder_.AddChoices({rule_expr_id, builder_.AddEmptyStr()}); - auto new_rule_id = builder_.AddRule({new_rule_name, new_rule_expr_id}); - return builder_.AddRuleRef(new_rule_id); -} - -int32_t EBNFParserImpl::ParseQuantifier() { - int32_t rule_expr_id = ParseElement(); - ConsumeSpace(in_parentheses_); - if (Peek() != '*' && Peek() != '+' && Peek() != '?') { - return rule_expr_id; - } - Consume(); - - // We will transform a*, a+, a? 
into a rule, and return the reference to this rule
-  switch (Peek(-1)) {
-    case '*':
-      // We assume that the star quantifier should be the body of some rule now
-      return HandleStarQuantifier(rule_expr_id);
-    case '+':
-      return HandlePlusQuantifier(rule_expr_id);
-    case '?':
-      return HandleQuestionQuantifier(rule_expr_id);
-    default:
-      LOG(FATAL) << "Unreachable";
-  }
-}
-
-int32_t EBNFParserImpl::ParseSequence() {
-  std::vector<int32_t> elements;
-  do {
-    elements.push_back(ParseQuantifier());
-    ConsumeSpace(in_parentheses_);
-  } while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r' &&
-           (Peek() != '(' || Peek(1) != '='));
-  return builder_.AddSequence(elements);
-}
-
-int32_t EBNFParserImpl::ParseChoices() {
-  std::vector<int32_t> choices;
-
-  choices.push_back(ParseSequence());
-  ConsumeSpace();
-  while (Peek() == '|') {
-    Consume();
-    ConsumeSpace();
-    choices.push_back(ParseSequence());
-    ConsumeSpace();
-  }
-  return builder_.AddChoices(choices);
-}
-
-int32_t EBNFParserImpl::ParseLookaheadAssertion() {
-  if (Peek() != '(' || Peek(1) != '=') {
-    return -1;
-  }
-  Consume(2);
-  auto prev_in_parentheses = in_parentheses_;
-  in_parentheses_ = true;
-  ConsumeSpace(in_parentheses_);
-  auto result = ParseSequence();
-  ConsumeSpace(in_parentheses_);
-  if (Peek() != ')') {
-    ThrowParseError("Expect )");
-  }
-  Consume();
-  in_parentheses_ = prev_in_parentheses;
-  return result;
-}
-
-EBNFParserImpl::Rule EBNFParserImpl::ParseRule() {
-  std::string name = ParseName();
-  cur_rule_name_ = name;
-  ConsumeSpace();
-  if (Peek() != ':' || Peek(1) != ':' || Peek(2) != '=') {
-    ThrowParseError("Expect ::=");
-  }
-  Consume(3);
-  ConsumeSpace();
-  auto body_id = ParseChoices();
-  ConsumeSpace();
-  auto lookahead_id = ParseLookaheadAssertion();
-  return {name, body_id, lookahead_id};
-}
-
-void EBNFParserImpl::BuildRuleNameToId() {
-  ConsumeSpace();
-  while (Peek()) {
-    auto name = ParseName(true);
-    ConsumeSpace(false);
-    if (Peek() == ':' && Peek(1) == ':' && Peek(2) == '=') {
-      if (name.empty()) {
-        ThrowParseError("Expect rule name");
-      }
-      Consume(3);
-      if (builder_.GetRuleId(name) != -1) {
-        ThrowParseError("Rule \"" + name + "\" is defined multiple times");
-      }
-      builder_.AddEmptyRule(name);
-    }
-    while (Peek() && Peek() != '\n' && Peek() != '\r') {
-      Consume();
-    }
-    ConsumeSpace();
-  }
-}
-
-void EBNFParserImpl::ResetStringIterator(const char* cur) {
-  cur_ = cur;
-  cur_line_ = 1;
-  cur_column_ = 1;
-  cur_rule_name_ = "";
-  in_parentheses_ = false;
-}
-
-BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rule) {
-  ResetStringIterator(ebnf_string.c_str());
-  BuildRuleNameToId();
-
-  ResetStringIterator(ebnf_string.c_str());
-  ConsumeSpace();
-  while (Peek()) {
-    // Throw error when there are multiple lookahead assertions
-    if (Peek() == '(' && Peek(1) == '=') {
-      ThrowParseError("Unexpected lookahead assertion");
-    }
-    auto new_rule = ParseRule();
-    builder_.UpdateRuleBody(new_rule.name, new_rule.body_expr_id);
-    // Update the lookahead assertion
-    builder_.AddLookaheadAssertion(new_rule.name, new_rule.lookahead_assertion_id);
-
-    ConsumeSpace();
-  }
-
-  // Check that the main rule is defined
-  if (builder_.GetRuleId(main_rule) == -1) {
-    ThrowParseError("The main rule with name \"" + main_rule + "\" is not found.");
-  }
-
-  return builder_.Get(main_rule);
-}
-
-BNFGrammar EBNFParser::Parse(std::string ebnf_string, std::string main_rule) {
-  EBNFParserImpl parser;
-  return parser.DoParse(ebnf_string, main_rule);
-}
-
-BNFGrammar BNFJSONParser::Parse(std::string json_string) {
-  auto node = make_object<BNFGrammarNode>();
-  auto grammar_json = json::ParseToJSONObject(json_string);
-  auto rules_json = json::Lookup<picojson::array>(grammar_json, "rules");
-  for (const auto& rule_json : rules_json) {
-    auto rule_json_obj = rule_json.get<picojson::object>();
-    auto name = json::Lookup<std::string>(rule_json_obj, "name");
-    auto rule_expr = static_cast<int32_t>(
-        json::Lookup<int64_t>(rule_json_obj, "body_expr_id"));
-    node->rules_.push_back(BNFGrammarNode::Rule({name, rule_expr}));
-  }
-  auto rule_expr_data_json = json::Lookup<picojson::array>(grammar_json, "rule_expr_data");
-  for (const auto& data_json : rule_expr_data_json) {
-    node->rule_expr_data_.push_back(static_cast<int32_t>(data_json.get<int64_t>()));
-  }
-  auto rule_expr_indptr_json = json::Lookup<picojson::array>(grammar_json, "rule_expr_indptr");
-  for (const auto& index_ptr_json : rule_expr_indptr_json) {
-    node->rule_expr_indptr_.push_back(static_cast<int32_t>(index_ptr_json.get<int64_t>()));
-  }
-  return BNFGrammar(std::move(node));
-}
-
-} // namespace serve
-} // namespace llm
-} // namespace mlc
diff --git a/cpp/grammar/grammar_parser.h b/cpp/grammar/grammar_parser.h
deleted file mode 100644
index b55b726e14..0000000000
--- a/cpp/grammar/grammar_parser.h
+++ /dev/null
@@ -1,68 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_parser.h
- * \brief The header for the parser of BNF/EBNF grammar into BNF AST.
- */
-
-#ifndef MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_
-#define MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_
-
-#include
-#include
-
-#include "grammar.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-using namespace tvm::runtime;
-
-/*!
- * \brief This class parses a BNF/EBNF grammar string into a BNF abstract syntax tree (AST).
- * \details This function accepts the EBNF notation defined in the W3C XML Specification
- * (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following
- * changes:
- * - Using # as comment mark instead of C-style comments
- * - Accept C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123
- * - Rule A-B (match A and not match B) is not supported yet
- *
- * See tests/python/serve/json.ebnf for an example.
- */
-class EBNFParser {
- public:
-  /*!
-   * \brief Parse the grammar string. If fails, throw ParseError with the error message.
-   * \param ebnf_string The grammar string.
-   * \param main_rule The name of the main rule. Default is "main".
-   * \return The parsed grammar.
-   */
-  static BNFGrammar Parse(std::string ebnf_string, std::string main_rule = "main");
-
-  /*!
-   * \brief The exception thrown when parsing fails.
-   */
-  class ParseError : public Error {
-   public:
-    ParseError(const std::string& msg) : Error(msg) {}
-  };
-};
-
-/*!
- * \brief Parse a BNF grammar from the raw representation of the AST in JSON format.
- */
-class BNFJSONParser {
- public:
-  /*!
-   * \brief Parse the JSON string
-   * \param json_string The JSON string.
-   * \return The parsed BNF grammar.
-   */
-  static BNFGrammar Parse(std::string json_string);
-};
-
-} // namespace serve
-} // namespace llm
-} // namespace mlc
-
-#endif // MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_
diff --git a/cpp/grammar/grammar_serializer.cc b/cpp/grammar/grammar_serializer.cc
deleted file mode 100644
index 6f4125ce6c..0000000000
--- a/cpp/grammar/grammar_serializer.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_serializer.cc
- */
-
-#include "grammar_serializer.h"
-
-#include
-#include
-
-#include "../support/encoding.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-using namespace tvm::runtime;
-
-std::string BNFGrammarPrinter::PrintRule(const Rule& rule) {
-  std::string res = rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id);
-  if (rule.lookahead_assertion_id != -1) {
-    res += " (=" + PrintRuleExpr(rule.lookahead_assertion_id) + ")";
-  }
-  return res;
-}
-
-std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) {
-  return PrintRule(grammar_->GetRule(rule_id));
-}
-
-std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) {
-  std::string result;
-  switch (rule_expr.type) {
-    case RuleExprType::kByteString:
-      return PrintByteString(rule_expr);
-    case RuleExprType::kCharacterClass:
-      return PrintCharacterClass(rule_expr);
-    case RuleExprType::kCharacterClassStar:
-      return PrintCharacterClassStar(rule_expr);
-    case RuleExprType::kEmptyStr:
-      return PrintEmptyStr(rule_expr);
-    case RuleExprType::kRuleRef:
-      return PrintRuleRef(rule_expr);
-    case RuleExprType::kSequence:
-      return PrintSequence(rule_expr);
-    case RuleExprType::kChoices:
-      return PrintChoices(rule_expr);
-    default:
-      LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast<int>(rule_expr.type);
-  }
-}
-
-std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) {
-  return PrintRuleExpr(grammar_->GetRuleExpr(rule_expr_id));
-}
-
-std::string BNFGrammarPrinter::PrintByteString(const RuleExpr& rule_expr) {
-  std::string internal_str;
-  internal_str.reserve(rule_expr.data_len);
-  for (int i = 0; i < rule_expr.data_len; ++i) {
-    internal_str += static_cast<char>(rule_expr[i]);
-  }
-  auto codepoints = ParseUTF8(internal_str.c_str(), UTF8ErrorPolicy::kReturnByte);
-  std::string result;
-  for (auto codepoint : codepoints) {
-    result += PrintAsEscaped(codepoint);
-  }
-  return "\"" + result + "\"";
-}
-
-std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) {
-  static const std::unordered_map<TCodepoint, std::string> kCustomEscapeMap = {{'-', "\\-"},
-                                                                               {']', "\\]"}};
-  std::string result = "[";
-  bool is_negative = static_cast<bool>(rule_expr[0]);
-  if (is_negative) {
-    result += "^";
-  }
-  for (auto i = 1; i < rule_expr.data_len; i += 2) {
-    result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap);
-    if (rule_expr[i] == rule_expr[i + 1]) {
-      continue;
-    }
-    result += "-";
-    result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap);
-  }
-  result += "]";
-  return result;
-}
-
-std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) {
-  return PrintCharacterClass(rule_expr) + "*";
-}
-
-std::string BNFGrammarPrinter::PrintEmptyStr(const RuleExpr& rule_expr) { return "\"\""; }
-
-std::string BNFGrammarPrinter::PrintRuleRef(const RuleExpr& rule_expr) {
-  return grammar_->GetRule(rule_expr[0]).name;
-}
-
-std::string BNFGrammarPrinter::PrintSequence(const RuleExpr& rule_expr) {
-  std::string result;
-  result += "(";
-  for (int i = 0; i < rule_expr.data_len; ++i) {
-    result += PrintRuleExpr(rule_expr[i]);
-    if (i + 1 != rule_expr.data_len) {
-      result += " ";
-    }
-  }
-  result += ")";
-  return result;
-}
-
-std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) {
-  std::string result;
-
-  result += "(";
-  for (int i = 0; i < rule_expr.data_len; ++i) {
-    result += PrintRuleExpr(rule_expr[i]);
-    if (i + 1 != rule_expr.data_len) {
-      result += " | ";
-    }
-  }
-  result += ")";
-  return result;
-}
-
-std::string BNFGrammarPrinter::ToString() {
-  std::string result;
-  auto num_rules = grammar_->NumRules();
-  for (auto i = 0; i < num_rules; ++i) {
-    result += PrintRule(grammar_->GetRule(i)) + "\n";
-  }
-  return result;
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarToString").set_body_typed([](const BNFGrammar& grammar) {
-  return BNFGrammarPrinter(grammar).ToString();
-});
-
-std::string BNFGrammarJSONSerializer::ToString() {
-  picojson::object grammar_json_obj;
-
-  picojson::array rules_json;
-  for (const auto& rule : grammar_->rules_) {
-    picojson::object rule_json;
-    rule_json["name"] = picojson::value(rule.name);
-    rule_json["body_expr_id"] = picojson::value(static_cast<int64_t>(rule.body_expr_id));
-    rules_json.push_back(picojson::value(rule_json));
-  }
-  grammar_json_obj["rules"] = picojson::value(rules_json);
-
-  picojson::array rule_expr_data_json;
-  for (const auto& data : grammar_->rule_expr_data_) {
-    rule_expr_data_json.push_back(picojson::value(static_cast<int64_t>(data)));
-  }
-  grammar_json_obj["rule_expr_data"] = picojson::value(rule_expr_data_json);
-  picojson::array rule_expr_indptr_json;
-  for (const auto& index_ptr : grammar_->rule_expr_indptr_) {
-    rule_expr_indptr_json.push_back(picojson::value(static_cast<int64_t>(index_ptr)));
-  }
-  grammar_json_obj["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json);
-
-  auto grammar_json = picojson::value(grammar_json_obj);
-  return grammar_json.serialize(prettify_);
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.BNFGrammarToJSON")
-    .set_body_typed([](const BNFGrammar& grammar, bool prettify) {
-      return BNFGrammarJSONSerializer(grammar, prettify).ToString();
-    });
-
-} // namespace serve
-} // namespace llm
-} // namespace mlc
diff --git a/cpp/grammar/grammar_serializer.h b/cpp/grammar/grammar_serializer.h
deleted file mode 100644
index bb8ded5099..0000000000
--- a/cpp/grammar/grammar_serializer.h
+++ /dev/null
@@ -1,117 +0,0 @@
-/*!
- * Copyright (c) 2023 by Contributors
- * \file grammar/grammar_serializer.h
- * \brief The header for printing the AST of a BNF grammar.
- */
-
-#ifndef MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_
-#define MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_
-
-#include
-
-#include "grammar.h"
-
-namespace mlc {
-namespace llm {
-namespace serve {
-
-/*!
- * \brief Serialize the abstract syntax tree of a BNF grammar to a string.
- */
-class BNFGrammarSerializer {
- public:
-  /*!
-   * \brief Constructor.
-   * \param grammar The grammar to print.
-   */
-  explicit BNFGrammarSerializer(const BNFGrammar& grammar) : grammar_(grammar) {}
-
-  /*! \brief Serialize the grammar to string. */
-  virtual std::string ToString() = 0;
-
- protected:
-  const BNFGrammar& grammar_;
-};
-
-/*!
- * \brief Prints the BNF AST with standard BNF format.
- */
-class BNFGrammarPrinter : public BNFGrammarSerializer {
- private:
-  using Rule = BNFGrammarNode::Rule;
-  using RuleExprType = BNFGrammarNode::RuleExprType;
-  using RuleExpr = BNFGrammarNode::RuleExpr;
-
- public:
-  /*!
-   * \brief Constructor.
-   * \param grammar The grammar to print.
-   */
-  explicit BNFGrammarPrinter(const BNFGrammar& grammar) : BNFGrammarSerializer(grammar) {}
-
-  /*! \brief Print the complete grammar. */
-  std::string ToString() final;
-
-  /*! \brief Print a rule. */
-  std::string PrintRule(const Rule& rule);
-  /*! \brief Print a rule corresponding to the given id. */
-  std::string PrintRule(int32_t rule_id);
-  /*! \brief Print a RuleExpr. */
-  std::string PrintRuleExpr(const RuleExpr& rule_expr);
-  /*! \brief Print a RuleExpr corresponding to the given id.
*/ - std::string PrintRuleExpr(int32_t rule_expr_id); - - private: - /*! \brief Print a RuleExpr for byte string. */ - std::string PrintByteString(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for character class. */ - std::string PrintCharacterClass(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for a star quantifier of a character class. */ - std::string PrintCharacterClassStar(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for empty string. */ - std::string PrintEmptyStr(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for rule reference. */ - std::string PrintRuleRef(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for rule_expr sequence. */ - std::string PrintSequence(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for rule_expr choices. */ - std::string PrintChoices(const RuleExpr& rule_expr); -}; - -/*! - * \brief Serialize the raw representation of the BNF AST to a string with JSON format. - * \sa BNFJSONParser::Parse for parsing the JSON string. - * \details JSON format: - * { - * "rules": [ - * {"name": "...", "rule_expr": rule_expr_id}, - * {"name": "...", "rule_expr": rule_expr_id}, - * ], - * "rule_expr_data": [integers...], - * "rule_expr_indptr": [integers...], - * } - */ -class BNFGrammarJSONSerializer : public BNFGrammarSerializer { - public: - /*! - * \brief Constructor. - * \param grammar The grammar to print. - */ - explicit BNFGrammarJSONSerializer(const BNFGrammar& grammar, bool prettify = true) - : BNFGrammarSerializer(grammar), prettify_(prettify) {} - - /*! - * \brief Dump the raw representation of the AST to a JSON file. - * \param prettify Whether to format the JSON string. If false, all whitespaces will be removed. - */ - std::string ToString() final; - - private: - bool prettify_; -}; - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_ diff --git a/cpp/grammar/grammar_state_matcher.cc b/cpp/grammar/grammar_state_matcher.cc deleted file mode 100644 index 660b8a5e3d..0000000000 --- a/cpp/grammar/grammar_state_matcher.cc +++ /dev/null @@ -1,760 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/grammar_state_matcher.cc - */ -// #define TVM_LOG_DEBUG 1 -#include "grammar_state_matcher.h" - -#include -#include - -#include "../support/dynamic_bitset.h" -#include "../tokenizers/tokenizers.h" -#include "grammar.h" -#include "grammar_serializer.h" -#include "grammar_state_matcher_base.h" -#include "grammar_state_matcher_preproc.h" -#include "grammar_state_matcher_state.h" -#include "support.h" - -namespace mlc { -namespace llm { -namespace serve { - -/* - * Note on the matching algorithm - * - * Given a context-free grammar, we match the characters in a string one by one. - * - * We adopt a non-deterministic pushdown automata (NPDA) in matching. To be specific, we maintain - * several stacks, each of which represents a possible path in the NPDA, and update the stacks - * during matching. - * - * ## Stack Structure (see grammar_state_matcher_state.h) - * The element of every stack is a RulePosition object, referring a position in the grammar. If a - * RulePosition is a RuleRef element (referring to another rule), the next element of the stack will - * be a position in this rule. If a RulePosition is a CharacterClass element, it will be the last - * in the stack, meaning *the next* character to match. 
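A rough shape of such a stack element can be pictured with a small sketch (illustrative only; the real RulePosition in grammar_state_matcher_state.h also tracks fields such as element_in_string and left_utf8_bytes, which appear in the code further below):

#include <cstdint>

// Illustrative only: a position in the grammar, as kept on the matching stacks.
struct RulePositionSketch {
  int32_t rule_id;      // which rule we are in
  int32_t sequence_id;  // which sequence (choice) of that rule
  int32_t element_id;   // which element of the sequence is matched next
  int32_t parent_id;    // parent node in the stack tree; -1 for a root
};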
- *
- * ## Matching Process (see grammar_state_matcher_base.h)
- * When accepting a new character and it is accepted by a stack, the last element of the stack will
- * be advanced to the next position in the grammar. If it gets to the end of the rule, several
- * elements at the end may be popped out, and the last element of the stack will be advanced.
- *
- * One stack may split since there may be multiple possible next positions. In this case, similar
- * stacks with different top elements will be added. When one stack cannot accept the new character,
- * it will be removed from the stacks.
- *
- * ## Storage of Stacks (see grammar_state_matcher_state.h)
- * Note these stacks form a tree structure as when splitting, the new stacks share the same prefix.
- * We store all RulePositions as a tree, where every path from tree root to a node represents a
- * stack. To represent stack tops, we attach additional pointers pointing the stack top nodes.
- * Also, we maintain a history of the stack top pointers, so we can rollback to the previous state.
- *
- * All tree nodes are maintained by a buffer, and utilize reference counting to recycle. If a node
- * is neither pointed by a stack top pointer, nor pointed by some child nodes, it will be freed.
- *
- * ## Example
- * ### Grammar
- * main ::= [a] R
- * R ::= [b] S [c] | [b] [c] T
- * S ::= "" | [c] [d]
- * T ::= [e]
- *
- * ### The previous step
- * Previous accepted string: ab
- * Previous stack tree:
- *     A------
- *     |     \     \
- *     B      D<    E<
- *     |
- *     C<
- *
- * A: (rule main, choice 0, element 1)
- * B: (rule R, choice 0, element 1)
- * C: (rule S, choice 1, element 0)
- * D: (rule R, choice 0, element 2)
- * E: (rule R, choice 1, element 1)
- * < means the stack top pointers in the previous step.
- * The stacks in the previous step are: (A, B, C), (A, D), (A, E)
- *
- * ### The current step
- * Current accepted string: abc
- * Current stack tree:
- *     A-----------------      G<<
- *     |     \     \     \
- *     B---   D<    E<    H
- *     |   \              |
- *     C<   F<<           I<<
- *
- * F: (rule S, choice 1, element 1)
- * G: (rule main, choice 0, element 2) (means the matching process has finished, and will be deleted
- * when the next char comes)
- * H: (rule R, choice 1, element 2)
- * I: (rule T, choice 0, element 0)
- * << means the stack top pointers in the current step.
- * The stacks in the current step are: (A, B, F), (A, H, I), (G,)
- *
- * ## Preprocess (see grammar_state_matcher_preproc.h)
- * We will store all information about tokens that is needed in matching in a
- * GrammarStateInitContext object. Tokens are sorted by codepoint, allowing us to reuse the
- * repeated prefixes between different tokens.
- *
- * For a given position in a rule, if we only consider this rule and its sub-rules during matching,
- * without considering its parent rules (in actual matching, we also need to consider its parent
- * rules), we can already determine that some tokens are acceptable while others are definitely
- * rejected. Therefore, for a position in a rule, we can divide the token set into three categories:
- * - accepted_indices: If a token is accepted by this rule
- * - rejected_indices: If a token is rejected by this rule
- * - uncertain_indices: Whether it can be accepted depends on the information from the parent
- * level during actual matching. To be specific, if this token has a prefix that has not been
- * rejected and has reached the end of this rule, then it is possible for it to be further accepted
- * by the parent rule.
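The buffer-with-reference-counting storage described under "Storage of Stacks" can be sketched in isolation (illustrative only; the real RulePositionTree carries the grammar positions and more bookkeeping):

#include <cstdint>
#include <vector>

struct NodeSketch {
  int32_t parent;     // index of the parent node; -1 for a root
  int32_t ref_count;  // references from child nodes and from stack-top pointers
};

class NodeBufferSketch {
 public:
  // Create a node whose creator (a stack-top pointer) holds one reference.
  int32_t NewNode(int32_t parent) {
    int32_t id;
    if (!free_.empty()) {
      id = free_.back();
      free_.pop_back();
      nodes_[id] = {parent, 1};
    } else {
      nodes_.push_back({parent, 1});
      id = static_cast<int32_t>(nodes_.size()) - 1;
    }
    if (parent != -1) ++nodes_[parent].ref_count;  // a child keeps its parent alive
    return id;
  }

  // Drop one reference; recycle the node and cascade upward while ancestors
  // become unreferenced, matching the rule stated above.
  void Release(int32_t id) {
    while (id != -1 && --nodes_[id].ref_count == 0) {
      int32_t parent = nodes_[id].parent;
      free_.push_back(id);
      id = parent;
    }
  }

 private:
  std::vector<NodeSketch> nodes_;
  std::vector<int32_t> free_;
};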
- *
- * During actual matching, we will directly accept or reject the tokens in accepted_indices and
- * rejected_indices, and only consider the tokens in uncertain_indices. That speeds up the matching
- * process.
- */
-
-using namespace tvm::runtime;
-
-TVM_REGISTER_OBJECT_TYPE(GrammarStateMatcherNode);
-
-/* \brief The concrete implementation of GrammarStateMatcherNode. */
-class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase {
- private:
-  using RuleExpr = BNFGrammarNode::RuleExpr;
-  using RuleExprType = BNFGrammarNode::RuleExprType;
-  using SaveType = CatagorizedTokens::SaveType;
-
- public:
-  GrammarStateMatcherNodeImpl(std::shared_ptr<GrammarStateInitContext> init_ctx,
-                              int max_rollback_steps = 0)
-      : GrammarStateMatcherBase(init_ctx->grammar),
-        init_ctx_(init_ctx),
-        max_rollback_steps_(max_rollback_steps),
-        tmp_accepted_bitset_(init_ctx_->vocab_size) {}
-
-  bool AcceptToken(int32_t token_id, bool verbose = false) final;
-
-  void FindNextTokenBitmask(DLTensor* next_token_bitmask) final;
-
-  std::string FindJumpForwardString() final;
-
-  void Rollback(int num_tokens) final;
-
-  int MaxRollbackSteps() const final { return max_rollback_steps_; }
-
-  bool IsTerminated() const { return stack_tops_history_.GetLatest().empty(); }
-
-  void ResetState() final {
-    stack_tops_history_.Reset();
-    token_length_history.clear();
-    PushInitialState(kInvalidRulePosition, true);
-  }
-
-  void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) final {
-    init_ctx_->stop_token_ids = stop_token_ids;
-  }
-
- private:
-  /*!
-   * \brief If is_uncertain_saved is true, find the next token in uncertain_indices. Otherwise,
-   * find the next token that is set to true in uncertain_tokens_bitset.
-   * \param iterator_uncertain The helper iterator to iterate over uncertain_indices or
-   * uncertain_tokens_bitset.
-   * \returns The index of the next token, or -1 if no more token.
-   */
-  int GetNextUncertainToken(bool is_uncertain_saved, int* iterator_uncertain,
-                            const std::vector<int>& uncertain_indices,
-                            const std::vector<bool>& uncertain_tokens_bitset);
-
-  /*! \brief Set the acceptable next token in next_token_bitmask. */
-  void SetTokenBitmask(DLTensor* next_token_bitmask, const DynamicBitset& accepted_bitset,
-                       const std::vector<int>& rejected_indices, bool can_reach_end);
-
-  /*!
-   * \brief Accept the stop token and terminates the matcher.
-   * \returns Whether the stop token can be accepted.
-   */
-  bool AcceptStopToken();
-
-  friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose);
-  friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size);
-
-  std::shared_ptr<GrammarStateInitContext> init_ctx_;
-  int max_rollback_steps_;
-  std::deque<int> token_length_history;
-
-  // Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation.
-  DynamicBitset tmp_accepted_bitset_;
-  std::vector<int> tmp_rejected_indices_;
-  std::vector<int> tmp_rejected_indices_delta_;
-};
-
-bool GrammarStateMatcherNodeImpl::AcceptStopToken() {
-  if (!CanReachEnd()) {
-    return false;
-  }
-  stack_tops_history_.PushHistory({});  // Terminate the matcher by setting the stack to empty
-  return true;
-}
-
-bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id, bool verbose) {
-  CHECK(!IsTerminated())
-      << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to "
-         "accept another token id "
-      << token_id;
-
-  CHECK(token_id >= 0 && token_id < init_ctx_->vocab_size)
-      << "Invalid token id " << token_id << " for GrammarStateMatcher";
-
-  if (verbose) {
-    LOG(INFO) << "Accepting token id " << token_id << ", string: \""
-              << PrintAsEscaped(init_ctx_->token_table[token_id]) << "\", stack state:\n"
-              << PrintStackState();
-  }
-
-  // Handle the stop token
-  if (std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), token_id) !=
-      init_ctx_->stop_token_ids.end()) {
-    bool accepted = AcceptStopToken();
-    if (verbose) {
-      LOG(INFO) << "The token is an end token. Is accepted: " << accepted;
-    }
-    return accepted;
-  }
-
-  if (init_ctx_->special_token_ids.count(token_id) > 0) {
-    LOG(FATAL)
-        << "Token id " << token_id << ": " << init_ctx_->token_table[token_id]
-        << " is regarded as a special token, and cannot be accepted by the GrammarStateMatcher";
-  }
-
-  const auto& token = init_ctx_->token_table[token_id];
-  int pos = 0;
-  for (auto char_value : token) {
-    if (!AcceptChar(char_value, false)) {
-      if (verbose) {
-        LOG(INFO) << "The token is rejected at position " << pos << ", character "
-                  << PrintAsEscaped(char_value);
-      }
-      return false;
-    }
-    ++pos;
-  }
-  token_length_history.push_back(token.size());
-  if (token_length_history.size() > max_rollback_steps_) {
-    DiscardEarliestChars(token_length_history.front());
-    token_length_history.pop_front();
-  }
-  if (verbose) {
-    LOG(INFO) << "The token is accepted. State after accepting:\n" << PrintStackState();
-  }
-  return true;
-}
-
-void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitmask) {
-  CHECK(!IsTerminated())
-      << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to "
-         "find the next token mask";
-  const auto& sorted_token_table = init_ctx_->sorted_token_table;
-  const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar;
-  const auto& latest_stack_tops = stack_tops_history_.GetLatest();
-
-  // We check all the stacks one by one, and find the accepted token set or the rejected token set
-  // for each stack. We will try to find the smaller one of the two sets.
-  // The final accepted token set is the union of the accepted token sets of all stacks.
-  // The final rejected token set is the intersection of the rejected token sets of all stacks.
-
-  // Note these indices store the indices in sorted_token_table, instead of the token ids.
-  tmp_accepted_bitset_.Reset();
-  // {-1} means the universal set, i.e. all tokens initially
-  tmp_rejected_indices_.assign({-1});
-
-  int check_cnt = 0;
-
-  for (auto top : latest_stack_tops) {
-    auto cur_rule_position = tree_[top];
-    if (tree_.IsEndPosition(cur_rule_position)) {
-      continue;
-    }
-
-    const auto& catagorized_tokens = catagorized_tokens_for_grammar.at(cur_rule_position);
-
-    // For each stack, we will check every uncertain token and put them into the accepted or
-    // rejected list.
-
-    // Step 2.
Update the accepted tokens in accepted_indices_delta, or the rejected tokens in - // rejected_indices_delta. - - // If the accepted tokens are saved, it means it is likely to be smaller than the rejected - // tokens, so we will just find the accepted tokens, and vice versa. - - tmp_rejected_indices_delta_.clear(); - - // Examine only the current one stack - stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - - const std::string* prev_token = nullptr; - int prev_matched_size = 0; - - for (auto cur_token_idx : catagorized_tokens.uncertain_indices) { - const auto& cur_token = sorted_token_table[cur_token_idx].second; - bool accepted = true; - - // Step 2.1. Find the longest common prefix with the accepted part of the previous token. - // We can reuse the previous matched size to avoid unnecessary matching. - if (prev_token) { - int lcp_len = std::mismatch(cur_token.begin(), cur_token.end(), prev_token->begin(), - prev_token->end()) - .first - - cur_token.begin(); - if (lcp_len > prev_matched_size) { - accepted = false; - } else if (lcp_len < prev_matched_size) { - RollbackChars(prev_matched_size - lcp_len); - } - prev_matched_size = std::min(prev_matched_size, lcp_len); - } - - // Step 2.2. Find if the current token is accepted or rejected. - if (accepted) { - for (int j = prev_matched_size; j < cur_token.size(); ++j) { - ++check_cnt; - if (!AcceptChar(cur_token[j], false)) { - accepted = false; - break; - } - prev_matched_size = j + 1; - } - } - - // Step 2.3. Push the result to the delta list. - if (catagorized_tokens.save_type == SaveType::kAcceptedBitset || - catagorized_tokens.save_type == SaveType::kAccepted) { - if (accepted) { - tmp_accepted_bitset_.Set(sorted_token_table[cur_token_idx].first, true); - } - } else { - if (!accepted) { - tmp_rejected_indices_delta_.push_back(cur_token_idx); - } - } - - prev_token = &cur_token; - } - - RollbackChars(prev_matched_size + 1); - - // Step 3. Update the accepted_indices or rejected_indices - if (catagorized_tokens.save_type == SaveType::kAcceptedBitset) { - tmp_accepted_bitset_ |= catagorized_tokens.accepted_bitset; - } else if (catagorized_tokens.save_type == SaveType::kAccepted) { - for (auto idx : catagorized_tokens.accepted_indices) { - tmp_accepted_bitset_.Set(sorted_token_table[idx].first, true); - } - } else { - // rejected_indices = Intersect( - // rejected_indices, - // catagorized_tokens.rejected_indices + rejected_indices_delta) - IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices); - IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_); - } - } - - // Finally update the rejected_ids bitset - bool can_reach_end = CanReachEnd(); - SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end); - - // Up till now, we use vocab_size from `GetVocabSize()`, while `next_token_bitmask` is of - // vocab_size read from `config.json`. For models like QWen2 and Phi3, the latter can be larger. - // So we further mask out the dummy padded tokens. 
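That tail-masking step, isolated into a standalone sketch (a hypothetical helper; the real code performs this in place on the DLTensor buffer through DynamicBitset):

#include <cstdint>
#include <vector>

// Clear the bits of the dummy padded token ids beyond the tokenizer's real
// vocabulary size; the bitmask is stored as 32-bit words.
void MaskPaddedTokens(std::vector<uint32_t>* bitmask_words, int real_vocab_size) {
  int total_bits = static_cast<int>(bitmask_words->size()) * 32;
  for (int i = real_vocab_size; i < total_bits; ++i) {
    (*bitmask_words)[i / 32] &= ~(1u << (i % 32));
  }
}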
- CHECK(next_token_bitmask->ndim == 1); - DynamicBitset next_token_bitset(next_token_bitmask->shape[0] * 32, - reinterpret_cast(next_token_bitmask->data)); - for (int i = init_ctx_->vocab_size; i < next_token_bitmask->shape[0] * 32; i++) { - next_token_bitset.Set(i, false); - } -} - -std::string GrammarStateMatcherNodeImpl::FindJumpForwardString() { - CHECK(!IsTerminated()) - << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " - "get the jump forward string"; - - std::string result; - int num_accepted_chars = 0; - bool can_find_next_char = true; - - while (can_find_next_char) { - const auto& stack_tops = stack_tops_history_.GetLatest(); - - // 1. Check that for every stack top, the next possible char is unique and the same - // -1 means not found yet; 0~255 means the next char - int next_char = -1; - for (auto stack_top : stack_tops) { - auto rule_position = tree_[stack_top]; - auto cur_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - if (rule_position.parent_id == RulePosition::kNoParent && - rule_position.element_id == cur_sequence.size()) { - can_find_next_char = false; - break; - } - - auto cur_element = grammar_->GetRuleExpr(cur_sequence[rule_position.element_id]); - - if (cur_element.type == RuleExprType::kByteString) { - DCHECK(rule_position.element_in_string < cur_element.size()); - if (next_char == -1) { - next_char = cur_element[rule_position.element_in_string]; - } else if (next_char != cur_element[rule_position.element_in_string]) { - can_find_next_char = false; - break; - } - } else { - DCHECK(cur_element.type == RuleExprType::kCharacterClass || - cur_element.type == RuleExprType::kCharacterClassStar); - if (rule_position.left_utf8_bytes > 0 || cur_element.size() != 3 || cur_element[0] != 0 || - cur_element[1] != cur_element[2]) { - can_find_next_char = false; - break; - } else if (next_char == -1) { - next_char = cur_element[1]; - } else if (next_char != cur_element[1]) { - can_find_next_char = false; - break; - } - } - } - - if (next_char == -1) { - can_find_next_char = false; - } - - // 2. If found, accept the char and iterate to the next position - if (can_find_next_char) { - result += static_cast(next_char); - - tmp_new_stack_tops_.clear(); - for (auto stack_top : stack_tops) { - auto cur_rule_position = tree_[stack_top]; - auto new_rule_position = UpdatePositionWithChar(cur_rule_position, next_char); - - if (new_rule_position == cur_rule_position) { - ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, stack_top); - } else { - ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true); - } - } - stack_tops_history_.PushHistory(tmp_new_stack_tops_); - ++num_accepted_chars; - } - } - - // Rollback all chars accepted - RollbackChars(num_accepted_chars); - return result; -} - -void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { - CHECK(num_tokens <= token_length_history.size()) - << "Intended to rollback " << num_tokens << " tokens, but only the last " - << token_length_history.size() << " steps of history are saved"; - while (num_tokens > 0) { - int steps = token_length_history.back(); - RollbackChars(steps); - token_length_history.pop_back(); - --num_tokens; - } -} - -void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, - const DynamicBitset& accepted_bitset, - const std::vector& rejected_indices, - bool can_reach_end) { - // next_token_bitmask = set(all accepted tokens) = - // 1. all_tokens - (rejected_ids / accepted_ids) - // (when rejected_ids != {-1}, i.e. 
rejected_ids is not the universal set)
-  //     2. accepted_ids
-  //     (otherwise, when rejected_ids is the universal set)
-  CHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 &&
-        next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape)
-      << "The provided bitmask's shape or dtype is not valid.";
-  CHECK(next_token_bitmask->shape[0] >= DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size))
-      << "The provided bitmask is not large enough to store the token set. The length should be "
-      << DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size) << " at least";
-
-  DynamicBitset next_token_bitset(init_ctx_->vocab_size,
-                                  reinterpret_cast<uint32_t*>(next_token_bitmask->data));
-  const auto& sorted_token_table = init_ctx_->sorted_token_table;
-
-  if (rejected_indices.size() == 1 && rejected_indices[0] == -1) {
-    // If rejected_indices is the universal set, the final accepted token set is just
-    // accepted_indices
-    next_token_bitset = accepted_bitset;
-
-    if (can_reach_end) {
-      // add end tokens
-      for (int id : init_ctx_->stop_token_ids) {
-        next_token_bitset.Set(id, true);
-      }
-    }
-  } else {
-    // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices)
-    next_token_bitset.Set();
-
-    for (auto i : rejected_indices) {
-      auto id = sorted_token_table[i].first;
-      if (!accepted_bitset[id]) {
-        next_token_bitset.Set(id, false);
-      }
-    }
-
-    for (int id : init_ctx_->special_token_ids) {
-      next_token_bitset.Set(id, false);
-    }
-    if (!can_reach_end) {
-      for (int id : init_ctx_->stop_token_ids) {
-        next_token_bitset.Set(id, false);
-      }
-    }
-  }
-}
-
-int GrammarStateMatcherNodeImpl::GetNextUncertainToken(
-    bool is_uncertain_saved, int* iterator_uncertain, const std::vector<int>& uncertain_indices,
-    const std::vector<bool>& uncertain_tokens_bitset) {
-  if (is_uncertain_saved) {
-    ++*iterator_uncertain;
-    if (*iterator_uncertain == uncertain_indices.size()) {
-      return -1;
-    }
-    return uncertain_indices[*iterator_uncertain];
-  } else {
-    ++*iterator_uncertain;
-    while (*iterator_uncertain < uncertain_tokens_bitset.size() &&
-           !uncertain_tokens_bitset[*iterator_uncertain]) {
-      ++*iterator_uncertain;
-    }
-    if (*iterator_uncertain == uncertain_tokens_bitset.size()) {
-      return -1;
-    }
-    return *iterator_uncertain;
-  }
-}
-
-GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr<GrammarStateInitContext> init_ctx,
-                                         int max_rollback_steps)
-    : ObjectRef(make_object<GrammarStateMatcherNodeImpl>(init_ctx, max_rollback_steps)) {}
-
-#ifndef COMPILE_MLC_WASM_RUNTIME
-// This creates tokenizer dependency issue in WASM building for web, hence skipped
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFromTokenizer")
-    .set_body_typed([](BNFGrammar grammar, Optional<Tokenizer> tokenizer, int max_rollback_steps) {
-      auto preproc_start = std::chrono::high_resolution_clock::now();
-      std::shared_ptr<GrammarStateInitContext> init_ctx;
-      if (tokenizer) {
-        init_ctx = GrammarStateMatcher::CreateInitContext(
-            grammar, tokenizer.value()->PostProcessedTokenTable());
-      } else {
-        init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {});
-      }
-
-      auto preproc_end = std::chrono::high_resolution_clock::now();
-      LOG(INFO) << "GrammarStateMatcher preprocess takes "
-                << std::chrono::duration_cast<std::chrono::microseconds>(preproc_end -
-                                                                         preproc_start)
-                       .count()
-                << "us";
-      return GrammarStateMatcher(init_ctx, max_rollback_steps);
-    });
-#endif
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFromTokenTable")
-    .set_body([](TVMArgs args, TVMRetValue* rv) {
-      BNFGrammar grammar = args[0];
-      Array<String> token_table_arr = args[1];
-      std::vector<std::string> token_table;
-      for (int i = 0; i < token_table_arr.size(); ++i) {
-        token_table.push_back(token_table_arr[i]);
-      }
-      int max_rollback_steps = args[args.size() - 1];
-      auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table);
-      *rv = GrammarStateMatcher(init_ctx, max_rollback_steps);
-    });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherDebugAcceptChar")
-    .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint, bool verbose) {
-      auto mutable_node =
-          const_cast<GrammarStateMatcherNodeImpl*>(matcher.as<GrammarStateMatcherNodeImpl>());
-      return mutable_node->AcceptChar(codepoint, verbose);
-    });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherAcceptToken")
-    .set_body_typed([](GrammarStateMatcher matcher, int32_t token_id, bool verbose) {
-      return matcher->AcceptToken(token_id, verbose);
-    });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindJumpForwardString")
-    .set_body_typed([](GrammarStateMatcher matcher) { return matcher->FindJumpForwardString(); });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherRollback")
-    .set_body_typed([](GrammarStateMatcher matcher, int num_tokens) {
-      matcher->Rollback(num_tokens);
-    });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherMaxRollbackSteps")
-    .set_body_typed([](GrammarStateMatcher matcher) { return matcher->MaxRollbackSteps(); });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherIsTerminated")
-    .set_body_typed([](GrammarStateMatcher matcher) { return matcher->IsTerminated(); });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherResetState")
-    .set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); });
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherSetStopTokenIds")
-    .set_body_typed([](GrammarStateMatcher matcher, IntTuple stop_token_ids) {
-      std::vector<int32_t> stop_token_ids_vector{stop_token_ids.begin(), stop_token_ids.end()};
-      matcher->SetStopTokenIds(stop_token_ids_vector);
-    });
-
-/*! \brief Check if a matcher can accept the complete string, and then reach the end of the
- * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. */
-bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) {
-  auto mutable_node =
-      const_cast<GrammarStateMatcherNodeImpl*>(matcher.as<GrammarStateMatcherNodeImpl>());
-  int accepted_cnt = 0;
-  for (auto char_value : str.operator std::string()) {
-    if (!mutable_node->AcceptChar(char_value, verbose)) {
-      if (verbose) {
-        LOG(INFO) << "Matching failed after accepting " << accepted_cnt << " characters";
-      }
-      mutable_node->RollbackChars(accepted_cnt);
-      return false;
-    }
-    ++accepted_cnt;
-  }
-  auto accepted = mutable_node->CanReachEnd();
-  if (verbose) {
-    if (accepted) {
-      LOG(INFO) << "Matching succeeded after accepting " << accepted_cnt << " characters";
-    } else {
-      LOG(INFO) << "Matching failed due to the end state not reached after all " << accepted_cnt
-                << " characters are accepted";
-    }
-  }
-  mutable_node->RollbackChars(accepted_cnt);
-  return accepted;
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherDebugMatchCompleteString")
-    .set_body_typed([](GrammarStateMatcher matcher, String str, bool verbose) {
-      return MatchCompleteString(matcher, str, verbose);
-    });
-
-/*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. */
-std::string PrintAcceptedRejectedTokens(
-    const std::shared_ptr<GrammarStateInitContext>& init_ctx,
-    const DynamicBitset& bitset, int threshold = 300) {
-  std::stringstream ss;
-  auto vocab_size = init_ctx->vocab_size;
-  std::vector<int> accepted_ids;
-  std::vector<int> rejected_ids;
-  for (int i = 0; i < vocab_size; i++) {
-    if (bitset[i]) {
-      accepted_ids.push_back(i);
-    } else {
-      rejected_ids.push_back(i);
-    }
-  }
-
-  ss << "Accepted: ";
-  auto end_it =
-      accepted_ids.size() > threshold ? accepted_ids.begin() + threshold : accepted_ids.end();
-  for (auto it = accepted_ids.begin(); it != end_it; ++it) {
-    ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> ";
-  }
-  if (accepted_ids.size() > threshold) {
-    ss << "...";
-  }
-  ss << "\n";
-
-  ss << "Rejected: ";
-  end_it = rejected_ids.size() > threshold ? rejected_ids.begin() + threshold : rejected_ids.end();
-  for (auto it = rejected_ids.begin(); it != end_it; ++it) {
-    ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> ";
-  }
-  if (rejected_ids.size() > threshold) {
-    ss << "...";
-  }
-  ss << "\n";
-  return ss.str();
-}
-
-/*!
- * \brief Find the ids of the rejected tokens for the next step. For debug purposes.
- * \param matcher The matcher to test.
- * \param verbose Whether to print information about the timing and results to stderr.
- * \returns A tuple of rejected token ids.
- */
-IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) {
-  auto init_ctx = matcher.as<GrammarStateMatcherNodeImpl>()->init_ctx_;
-  auto vocab_size = init_ctx->vocab_size;
-  auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size);
-  auto ndarray = NDArray::Empty(ShapeTuple{static_cast<int64_t>(bitset_size)},
-                                DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0});
-  auto dltensor = const_cast<DLTensor*>(ndarray.operator->());
-
-  std::chrono::time_point<std::chrono::high_resolution_clock> start, end;
-  if (verbose) {
-    start = std::chrono::high_resolution_clock::now();
-  }
-  matcher->FindNextTokenBitmask(dltensor);
-  if (verbose) {
-    end = std::chrono::high_resolution_clock::now();
-  }
-
-  auto bitset = DynamicBitset(vocab_size, reinterpret_cast<uint32_t*>(dltensor->data));
-  std::vector<int64_t> rejected_ids;
-  for (int i = 0; i < vocab_size; i++) {
-    if (bitset[i] == 0) {
-      rejected_ids.push_back(i);
-    }
-  }
-
-  if (verbose) {
-    LOG(INFO) << "FindNextTokenBitmask takes "
-              << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << "us"
-              << ", found accepted: " << vocab_size - rejected_ids.size()
-              << ", rejected: " << rejected_ids.size();
-  }
-
-  auto ret = IntTuple(rejected_ids);
-  return ret;
-}
-
-TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindNextRejectedTokens")
-    .set_body_typed(FindNextRejectedTokens);
-
-/*!
- * \brief Find the bitmask for the next token as an NDArray.
- * \param full_vocab_size Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`,
- * this is the vocab_size read from `config.json` that can be potentially larger.
- * \returns An NDArray of the bitmask for the next token of shape (bitmask_size,).
- */ -NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size) { - auto bitset_size = DynamicBitset::CalculateBufferSize(full_vocab_size); - auto bitmask = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, - DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); - auto dltensor = const_cast(bitmask.operator->()); - matcher->FindNextTokenBitmask(dltensor); - return bitmask; -} - -TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindNextTokenBitmaskAsNDArray") - .set_body_typed(FindNextTokenBitmaskAsNDArray); - -} // namespace serve -} // namespace llm -} // namespace mlc diff --git a/cpp/grammar/grammar_state_matcher.h b/cpp/grammar/grammar_state_matcher.h deleted file mode 100644 index 98fda522d0..0000000000 --- a/cpp/grammar/grammar_state_matcher.h +++ /dev/null @@ -1,185 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/grammar_state_matcher.h - * \brief The header for the support of matching tokens to BNF grammar. This is the core - * logic of the grammar-guided generation. - */ - -#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ -#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ - -#include -#include - -#include -#include -#include - -#include "../support/encoding.h" -#include "grammar.h" -#include "support.h" - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -/*! - * \brief A stateful matcher to match tokens to the specified BNF grammar. This class is the core - * logic of the grammar-guided generation. - * - * \details This class implements the non-deterministic pushdown automaton (NPDA) matching algorithm - * to match characters to a BNF grammar. It keep track of the current state of the matching process - * by maintaining several stacks internally as possible paths in the NPDA. It also supports - * backtracking. - * - * It is particularly capable of finding the set of tokens that are acceptable for the next step - * and storing them in a bitmask. This aids in grammar-guided generation. - * - * \example - * \code - * Tokenizer tokenizer = ...; - * auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, - * tokenizer->PostProcessedTokenTable()); - * GrammarStateMatcher matcher(init_ctx, 10); - * matcher->AcceptToken(67); - * - * // Construct a DLTensor with shape (tokenizer.GetVocabSize() + 31) / 32, and dtype uint32. - * DLTensor next_token_bitmask = ...; - * matcher->FindNextTokenBitmask(&next_token_bitmask); - * - * // Rollback is supported - * matcher->Rollback(1); - * \endcode - */ -class GrammarStateMatcherNode : public Object { - public: - /*! - * \brief Accept one token and update the state of the matcher. - * \param token_id The id of the token to accept. - * \return Whether the token is accepted. - * \note Termination state. - * When the end of the main rule is reached, the matcher can only accept the stop token. - * The matcher is terminated after accepting the stop token, i.e. no AcceptToken or - * FindNextTokenMask operations can be performed. The termination state can be canceled - * using Rollback(). - */ - virtual bool AcceptToken(int32_t token_id, bool verbose = false) = 0; - - /*! - * \brief Find the set of tokens that are acceptable for the next step and store them in a - * bitmask. - * \param next_token_bitmask The bitmask to store the result. The bitmask must be pre-allocated, - * and its shape needs to be (ceil(vocab_size, 32),), with a dtype of uint32. - */ - virtual void FindNextTokenBitmask(DLTensor* next_token_bitmask) = 0; - - /*! 
-   * \brief Find the jump-forward string for jump-forward decoding. This is the longest string
-   * that will be valid according to the current syntax.
-   * \note This method does not change the grammar state.
-   */
-  virtual std::string FindJumpForwardString() = 0;
-
-  /*!
-   * \brief Rollback the matcher to a previous state.
-   * \param num_tokens The number of tokens to rollback. It cannot exceed the current number of
-   * steps, nor can it exceed the specified maximum number of rollback steps.
-   */
-  virtual void Rollback(int num_tokens) = 0;
-
-  /*! \brief Get the maximum number of rollback steps allowed. */
-  virtual int MaxRollbackSteps() const = 0;
-
-  /*!
-   * \brief Check if the matcher has accepted the stop token and terminated.
-   * \sa AcceptToken
-   */
-  virtual bool IsTerminated() const = 0;
-
-  /*! \brief Reset the matcher to the initial state. */
-  virtual void ResetState() = 0;
-
-  /*! \brief Set the stop token ids, overriding the existing default ones. */
-  virtual void SetStopTokenIds(const std::vector<int32_t>& stop_token_ids) = 0;
-
-  static constexpr const char* _type_key = "mlc.grammar.GrammarStateMatcher";
-  static constexpr const bool _type_has_method_sequal_reduce = false;
-  static constexpr const bool _type_has_method_shash_reduce = false;
-  TVM_DECLARE_BASE_OBJECT_INFO(GrammarStateMatcherNode, Object);
-};
-
-/*!
- * \brief The init context of a GrammarStateMatcher. It contains the preprocessing results of the
- * grammar and tokenizer.
- */
-class GrammarStateInitContext;
-
-class GrammarStateMatcher : public ObjectRef {
- public:
-  /*!
-   * \brief Construct a GrammarStateMatcher from the preprocessing result of type
-   * GrammarStateInitContext.
-   * \param init_ctx The init context. It is obtained through
-   * CreateInitContext as a result of preprocessing the grammar and tokenizer.
-   */
-  GrammarStateMatcher(std::shared_ptr<GrammarStateInitContext> init_ctx,
-                      int max_rollback_steps = 0);
-
-  /*!
-   * \brief Specify a grammar and token_table to return their preprocessing results. These results
-   * are used to construct a GrammarStateMatcher. They can be stored elsewhere for quick
-   * construction of GrammarStateMatcher.
-   * \param grammar The grammar that the matcher follows.
-   * \param token_table The tokens that the matcher requires for matching.
-   */
-  static std::shared_ptr<GrammarStateInitContext> CreateInitContext(
-      const BNFGrammar& grammar, const std::vector<std::string>& token_table);
-
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarStateMatcher, ObjectRef, GrammarStateMatcherNode);
-};
-
-/*!
- * \brief A cache to get the grammar state init context for grammar or schema. This class avoids
- * redundant preprocessing of the grammar or schema when constructing a GrammarStateInitContext.
- * \note This class is associated with a token table when constructed. The token table is used to
- * create every grammar state init context. If multiple token tables are used to create init
- * contexts, an instance of this class for each token table should be created.
- */
-class GrammarInitContextCacheNode : public Object {
- public:
-  /*! \brief Get the init context for pure JSON. */
-  virtual std::shared_ptr<GrammarStateInitContext> GetInitContextForJSON() = 0;
-
-  /*! \brief Get the init context for a JSON schema string. */
-  virtual std::shared_ptr<GrammarStateInitContext> GetInitContextForJSONSchema(
-      const std::string& schema) = 0;
-
-  /*! \brief Clear the internal cache of init contexts.
*/ - virtual void Clear() = 0; - - static constexpr const char* _type_key = "mlc.serve.GrammarInitContextCacheNode"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextCacheNode, Object); -}; - -class GrammarInitContextCache : public ObjectRef { - public: - /*! - * \brief Construct a GrammarInitContextCache with a token table. This class will always create - * grammar state init contexts with this token table. - * \param token_table The token table that the grammar will use. - */ - GrammarInitContextCache(const std::vector& token_table); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextCache, ObjectRef, - GrammarInitContextCacheNode); -}; - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ diff --git a/cpp/grammar/grammar_state_matcher_base.h b/cpp/grammar/grammar_state_matcher_base.h deleted file mode 100644 index a26a482eac..0000000000 --- a/cpp/grammar/grammar_state_matcher_base.h +++ /dev/null @@ -1,401 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/grammar_state_matcher_base.h - * \brief The base class of GrammarStateMatcher. It implements a character-based matching automata. - */ -#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ -#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ - -#include - -#include "grammar.h" -#include "grammar_state_matcher_state.h" - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -/*! \brief The base class of GrammarStateMatcher. It implements a character-based matching - * automata, and supports accepting a character, rolling back by character, etc. - */ -class GrammarStateMatcherBase { - protected: - using RuleExpr = BNFGrammarNode::RuleExpr; - using RuleExprType = BNFGrammarNode::RuleExprType; - - public: - /*! - * \brief Construct a GrammarStateMatcherBase with the given grammar and initial rule position. - * \param grammar The grammar to match. - * \param init_rule_position The initial rule position. If not specified, the main rule will be - * used. - * \param expand_init_rule_position Whether to expand the initial rule position to all possible - * locations. See ExpandRulePosition. - */ - GrammarStateMatcherBase(const BNFGrammar& grammar, - RulePosition init_rule_position = kInvalidRulePosition, - bool expand_init_rule_position = true) - : grammar_(grammar), tree_(grammar), stack_tops_history_(&tree_) { - PushInitialState(init_rule_position, expand_init_rule_position); - } - - /*! \brief Accept one character. */ - bool AcceptChar(uint8_t char_value, bool verbose = false); - - /*! \brief Check if the end of the main rule is reached. If so, the stop token can be accepted. */ - bool CanReachEnd() const; - - /*! \brief Rollback the matcher to a previous state by the number of characters. */ - void RollbackChars(int rollback_cnt); - - /*! \brief Discard the earliest history by the number of characters. */ - void DiscardEarliestChars(int discard_cnt); - - /*! \brief Print the stack state. */ - std::string PrintStackState(int steps_behind_latest = 0) const; - - protected: - // Push an initial stack state according to the given rule position. - // If init_rule_position is kInvalidRulePosition, init the stack with the main rule. 
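The per-character history that makes RollbackChars and DiscardEarliestChars possible can be sketched generically (illustrative only; the real StackTopsHistory below also maintains the tree-node reference counts and bounds the retained history):

#include <vector>

// Illustrative only: keep one immutable snapshot of the stack tops per
// accepted character; rolling back k characters pops the last k snapshots.
template <typename State>
class HistorySketch {
 public:
  explicit HistorySketch(State init) { history_.push_back(std::move(init)); }
  const State& GetLatest() const { return history_.back(); }
  void PushHistory(State s) { history_.push_back(std::move(s)); }
  void Rollback(int cnt) { history_.erase(history_.end() - cnt, history_.end()); }

 private:
  std::vector<State> history_;
};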
- void PushInitialState(RulePosition init_rule_position, bool expand_init_rule_position); - - // Check if the character is accepted by the current rule position. - bool CheckIfAccepted(const RulePosition& rule_position, uint8_t char_value) const; - - /*! - * \brief Find the next position in the rule. If the next position is at the end of the rule, - * and consider_parent is true, will iteratively find the next position in the parent rule. - * \param rule_position The current position. - * \param consider_parent Whether to consider the parent position if the current position is - * at the end of the rule. - * \returns (success, next_rule_position), indicating if the iteration is successful and the - * next rule position. - */ - std::pair GetNextPositionInSequence(const RulePosition& rule_position, - bool consider_parent) const; - - // Return the updated rule position after accepting the char - RulePosition UpdatePositionWithChar(const RulePosition& rule_position, uint8_t char_value) const; - - /*! - * \brief Expand the given rule position to all possible positions approachable in the grammar. - * The expanded positions must refers to an element (CharacterClass or CharacterClassStar or - * ByteString) in a rule. Push all new positions into new_stack_tops. - * \example - * A ::= "a" B [a-z]* "c" - * B ::= "b" | "" - * - * Input position: (rule=A, position=B) - * Approachable positions: (rule=B, position="b"), (rule=A, position=[a-z]*), - * (rule=A, position="c"), since B and [a-z]* can be empty. - * \param cur_rule_position The current rule position. - * \param new_stack_tops The vector to store the new stack tops. - * \param consider_parent Whether consider expanding the elements in the parent rule. Useful for - * inner recursion. - * \param first_id_if_inserted An optimization. When cur_rule_position is already inserted to - * the state tree, pass its id to avoid inserting it again. -1 (ignore it) by default. - * \return Whether the end of the rule can be reached. Useful for inner recursion. - */ - bool ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool consider_parent = true, int32_t first_id_if_inserted = -1); - - // The matched grammar. - BNFGrammar grammar_; - // The tree storing all states - RulePositionTree tree_; - // The tracked history of stack tops (each stack top refers to a node in the tree). - // We store the stack tops in different steps in the history to support rollback. - StackTopsHistory stack_tops_history_; - - // Temporary data for AcceptChar, PushInitialState, etc to store new stacks. - // They are stored here to avoid repeated allocation. - std::vector tmp_new_stack_tops_; -}; - -/*! \brief Check the codepoint is contained in the character class. 
*/ -inline bool GrammarStateMatcherBase::CheckIfAccepted(const RulePosition& rule_position, - uint8_t char_value) const { - auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); - if (current_element.type == RuleExprType::kCharacterClass || - current_element.type == RuleExprType::kCharacterClassStar) { - if (rule_position.left_utf8_bytes > 0) { - return (char_value & 0xC0) == 0x80; - } - auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); - if (!accepted) { - return false; - } - bool is_negative = static_cast(current_element[0]); - if (num_bytes > 1) { - return is_negative; - } - for (int i = 1; i < current_element.size(); i += 2) { - if (current_element[i] <= char_value && char_value <= current_element[i + 1]) { - return !is_negative; - } - } - return is_negative; - } else if (current_element.type == RuleExprType::kByteString) { - return current_element[rule_position.element_in_string] == char_value; - } else { - LOG(FATAL) << "Unexpected RuleExprType in CheckIfAccepted: " - << static_cast(current_element.type); - } -} - -inline RulePosition GrammarStateMatcherBase::UpdatePositionWithChar( - const RulePosition& rule_position, uint8_t char_value) const { - auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); - RulePosition new_rule_position = rule_position; - switch (current_element.type) { - case RuleExprType::kCharacterClass: { - if (rule_position.left_utf8_bytes > 1) { - new_rule_position.left_utf8_bytes -= 1; - return new_rule_position; - } else if (rule_position.left_utf8_bytes == 1) { - return GetNextPositionInSequence(rule_position, true).second; - } - // If no left utf8 bytes, check the first byte to find the left bytes needed. 
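The `(char_value & 0xC0) == 0x80` test above is the UTF-8 continuation-byte check; `HandleUTF8FirstByte` (defined in `../support/encoding.h`, outside this hunk) appears from its call sites to classify a leading byte into `(accepted, num_bytes, codepoint)`. A self-contained sketch of such a classifier, written from the UTF-8 encoding rules; the exact return shape is an assumption:

```cpp
#include <cstdint>
#include <tuple>

// Sketch of a UTF-8 leading-byte classifier with the (accepted, num_bytes,
// codepoint-prefix) shape inferred from the call sites above. The real
// HandleUTF8FirstByte may differ in details.
inline std::tuple<bool, int, uint32_t> ClassifyUTF8FirstByte(uint8_t b) {
  if ((b & 0x80) == 0x00) return {true, 1, b};          // 0xxxxxxx: 1-byte char
  if ((b & 0xE0) == 0xC0) return {true, 2, b & 0x1Fu};  // 110xxxxx: 2-byte char
  if ((b & 0xF0) == 0xE0) return {true, 3, b & 0x0Fu};  // 1110xxxx: 3-byte char
  if ((b & 0xF8) == 0xF0) return {true, 4, b & 0x07u};  // 11110xxx: 4-byte char
  return {false, 0, 0};  // 10xxxxxx (continuation) or an invalid leader
}
```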
-      DCHECK(rule_position.left_utf8_bytes == 0);
-      auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value);
-      DCHECK(accepted);
-      if (num_bytes > 1) {
-        new_rule_position.left_utf8_bytes = num_bytes - 1;
-        return new_rule_position;
-      }
-      return GetNextPositionInSequence(rule_position, true).second;
-    }
-    case RuleExprType::kCharacterClassStar: {
-      if (rule_position.left_utf8_bytes >= 1) {
-        new_rule_position.left_utf8_bytes -= 1;
-      } else {
-        DCHECK(rule_position.left_utf8_bytes == 0);
-        auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value);
-        DCHECK(accepted);
-        new_rule_position.left_utf8_bytes = num_bytes - 1;
-      }
-      return new_rule_position;
-    }
-    case RuleExprType::kByteString: {
-      if (rule_position.element_in_string + 1 < current_element.size()) {
-        new_rule_position.element_in_string += 1;
-        return new_rule_position;
-      }
-      return GetNextPositionInSequence(rule_position, true).second;
-    }
-    default:
-      LOG(FATAL) << "Unexpected RuleExprType in UpdatePositionWithChar: "
-                 << static_cast<int>(current_element.type);
-  }
-}
-
-inline bool GrammarStateMatcherBase::AcceptChar(uint8_t char_value, bool verbose) {
-  if (verbose) {
-    LOG(INFO) << "Matching char: " << static_cast<int>(char_value) << " \""
-              << PrintAsEscaped(char_value) << "\"";
-    LOG(INFO) << "Previous stack: " << PrintStackState();
-  }
-  const auto& prev_stack_tops = stack_tops_history_.GetLatest();
-
-  tmp_new_stack_tops_.clear();
-  for (auto prev_top : prev_stack_tops) {
-    auto cur_rule_position = tree_[prev_top];
-    auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id);
-    if (cur_rule_position.parent_id == RulePosition::kNoParent &&
-        cur_rule_position.element_id == current_sequence.size()) {
-      // This RulePosition means the previous elements have matched the complete rule.
-      // But we still need to accept a new character, so this stack will become invalid.
- continue; - } - - auto accepted = CheckIfAccepted(cur_rule_position, char_value); - if (!accepted) { - continue; - } - - auto new_rule_position = UpdatePositionWithChar(cur_rule_position, char_value); - - if (new_rule_position == cur_rule_position) { - ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, prev_top); - } else { - ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true); - } - } - if (tmp_new_stack_tops_.empty()) { - if (verbose) { - LOG(INFO) << "Character " << static_cast(char_value) << " \"" - << PrintAsEscaped(char_value) << "\" Rejected"; - } - return false; - } - stack_tops_history_.PushHistory(tmp_new_stack_tops_); - if (verbose) { - LOG(INFO) << "Character: " << static_cast(char_value) << " \"" - << PrintAsEscaped(char_value) << "\" Accepted"; - LOG(INFO) << "New stack after acceptance: " << PrintStackState(); - } -#if TVM_LOG_DEBUG - stack_tops_history_.CheckWellFormed(); -#endif - return true; -} - -inline bool GrammarStateMatcherBase::CanReachEnd() const { - const auto& last_stack_tops = stack_tops_history_.GetLatest(); - return std::any_of(last_stack_tops.begin(), last_stack_tops.end(), - [&](int32_t id) { return tree_.IsEndPosition(tree_[id]); }); -} - -inline void GrammarStateMatcherBase::RollbackChars(int rollback_cnt) { - stack_tops_history_.Rollback(rollback_cnt); -} - -inline void GrammarStateMatcherBase::DiscardEarliestChars(int discard_cnt) { - stack_tops_history_.DiscardEarliest(discard_cnt); -} - -inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_latest) const { - return stack_tops_history_.PrintHistory(steps_behind_latest); -} - -inline void GrammarStateMatcherBase::PushInitialState(RulePosition init_rule_position, - bool expand_init_rule_position) { - if (init_rule_position == kInvalidRulePosition) { - // Initialize the stack with the main rule. 
- auto main_rule = grammar_->GetMainRule(); - auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); - tmp_new_stack_tops_.clear(); - for (auto i : main_rule_body) { - auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent); - if (expand_init_rule_position) { - ExpandRulePosition(init_rule_position, &tmp_new_stack_tops_, true); - } else { - tmp_new_stack_tops_.push_back(tree_.NewNode(init_rule_position)); - } - } - stack_tops_history_.PushHistory(tmp_new_stack_tops_); - } else { - if (expand_init_rule_position) { - tmp_new_stack_tops_.clear(); - ExpandRulePosition(init_rule_position, &tmp_new_stack_tops_, true); - stack_tops_history_.PushHistory(tmp_new_stack_tops_); - } else { - stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); - } - } -} - -inline std::pair GrammarStateMatcherBase::GetNextPositionInSequence( - const RulePosition& rule_position, bool consider_parent) const { - auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - - auto next_position = rule_position; - next_position.element_id += 1; - next_position.element_in_string = 0; - next_position.left_utf8_bytes = 0; - - DCHECK(next_position.element_id <= sequence.size()); - - if (next_position.element_id < sequence.size()) { - return {true, next_position}; - } - - if (!consider_parent) { - return {false, kInvalidRulePosition}; - } - - // Find the next position in the parent rule - while (next_position.parent_id != RulePosition::kNoParent) { - next_position = tree_[next_position.parent_id]; - next_position.element_id += 1; - DCHECK(next_position.element_in_string == 0); - DCHECK(next_position.left_utf8_bytes == 0); - - sequence = grammar_->GetRuleExpr(next_position.sequence_id); - DCHECK(next_position.element_id <= sequence.size()); - - if (next_position.element_id < sequence.size()) { - break; - } - } - - return {true, next_position}; -} - -inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_position, - std::vector* new_stack_tops, - bool consider_parent, - int32_t first_id_if_inserted) { - bool is_first = false; - bool is_iteration_successful = true; - - for (; is_iteration_successful; - std::tie(is_iteration_successful, cur_rule_position) = - GetNextPositionInSequence(cur_rule_position, consider_parent)) { - // Insert the node to the tree, if not inserted before. - int32_t new_node_id; - if (is_first && first_id_if_inserted != -1) { - new_node_id = first_id_if_inserted; - } else { - new_node_id = tree_.NewNode(cur_rule_position); - } - is_first = false; - - // Case 1. The current position points to the end of the grammar. - if (consider_parent) { - if (tree_.IsEndPosition(cur_rule_position)) { - new_stack_tops->push_back(new_node_id); - return true; - } - } else { - DCHECK(!tree_.IsEndPosition(cur_rule_position)); - } - - auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); - auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); - bool can_be_empty = false; - - if (element.type == RuleExprType::kRuleRef) { - // Case 2. The current position refers to another rule. 
- auto ref_rule = grammar_->GetRule(element[0]); - auto ref_rule_body = grammar_->GetRuleExpr(ref_rule.body_expr_id); - DCHECK(ref_rule_body.type == RuleExprType::kChoices); - - for (auto sequence_id : ref_rule_body) { - auto ref_rule_sequence = grammar_->GetRuleExpr(sequence_id); - if (ref_rule_sequence.type == RuleExprType::kEmptyStr) { - can_be_empty = true; - continue; - } - auto ref_rule_position = RulePosition(element[0], sequence_id, 0, new_node_id); - // Find the positions in every choice of the referred rule - can_be_empty |= ExpandRulePosition(ref_rule_position, new_stack_tops, false); - } - } else if (element.type == RuleExprType::kCharacterClass || - element.type == RuleExprType::kByteString) { - // Case 3. Character class or byte string. cannot be empty. - new_stack_tops->push_back(new_node_id); - can_be_empty = false; - } else { - DCHECK(element.type == RuleExprType::kCharacterClassStar); - // Case 4. Character class star. Might be empty. - new_stack_tops->push_back(new_node_id); - can_be_empty = cur_rule_position.left_utf8_bytes == 0; - } - - if (!can_be_empty) { - return false; - } - } - return true; -} - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ diff --git a/cpp/grammar/grammar_state_matcher_preproc.h b/cpp/grammar/grammar_state_matcher_preproc.h deleted file mode 100644 index bad42683d0..0000000000 --- a/cpp/grammar/grammar_state_matcher_preproc.h +++ /dev/null @@ -1,438 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/grammar_state_matcher_preproc.h - * \brief The header for the preprocessing of the grammar state matcher. - */ -#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_ -#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_ - -#include - -#include "../support/dynamic_bitset.h" -#include "../support/encoding.h" -#include "../support/utils.h" -#include "grammar.h" -#include "grammar_state_matcher_base.h" - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -/*! - * \brief Preprocessed information, for a given specific RulePosition, divides the token set - * into three categories: accepted, rejected, and uncertain. - * Accepted: tokens that can be determined by the current RulePosition to be acceptable - * Rejected: tokens that can be determined by the current RulePosition to be unacceptable - * Uncertain: tokens that need the state of the parent RulePositions to determine if acceptable - * - * \note uncertain indices are stored directly. Accepted / rejected indices have three ways to - * store to reduce memory and computation usage. See SaveType. - * \note These indices are the indices of sorted_token_table in the GrammarStateInitContext - * object, instead of the token ids. That helps the matching process. - */ -struct CatagorizedTokens { - enum class SaveType { - // Only store all accepted token indices. Then rejected indices = all_indices - accepted_indices - // - uncertain_indices. This is useful when |accepted_indices| < |rejected_indices|. - kAccepted = 0, - // Only store all accepted token indices. Then accepted indices = all_indices - rejected_indices - // - uncertain_indices. This is useful when |accepted_indices| > |rejected_indices|. - kRejected = 1, - // Store all accepted token indices in a bitset. This is useful when both |accepted_indices| and - // |rejected_indices| are large. 
-    kAcceptedBitset = 2
-  };
-  SaveType save_type;
-
-  static constexpr int USE_BITSET_THRESHOLD = 200;
-
-  std::vector<int32_t> accepted_indices;
-  std::vector<int32_t> rejected_indices;
-  DynamicBitset accepted_bitset;
-
-  std::vector<int32_t> uncertain_indices;
-
-  CatagorizedTokens() = default;
-
-  CatagorizedTokens(int vocab_size,
-                    const std::vector<std::pair<int32_t, std::string>>& sorted_token_table,
-                    const std::vector<int32_t>& accepted_indices,
-                    const std::vector<int32_t>& rejected_indices,
-                    const std::vector<int32_t>& uncertain_indices);
-};
-
-/*!
- * \brief All information that we need to match tokens in the tokenizer to the specified grammar.
- * It is the result of preprocessing.
- * \sa mlc::llm::serve::GrammarStateMatcher
- */
-class GrammarStateInitContext {
- public:
-  /******************* Information about the tokenizer *******************/
-
-  /*! \brief The vocabulary size of the tokenizer. Special tokens are included. */
-  size_t vocab_size;
-  /*! \brief The token table. Special tokens are included. */
-  std::vector<std::string> token_table;
-  /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to
-   * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */
-  std::vector<std::pair<int32_t, std::string>> sorted_token_table;
-  /*! \brief The stop tokens. When the GrammarStateMatcher can reach the end of the grammar,
-   * stop tokens can be accepted. */
-  std::vector<int32_t> stop_token_ids;
-  /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided
-   * generation. */
-  std::unordered_set<int32_t> special_token_ids;
-
-  /******************* Information about the grammar *******************/
-
-  /*! \brief The grammar for the GrammarStateMatcher. */
-  BNFGrammar grammar;
-
-  /******************* Grammar-specific tokenizer information *******************/
-
-  struct RulePositionEqual {
-    bool operator()(const RulePosition& lhs, const RulePosition& rhs) const noexcept {
-      return lhs.sequence_id == rhs.sequence_id && lhs.element_id == rhs.element_id &&
-             lhs.left_utf8_bytes == rhs.left_utf8_bytes &&
-             lhs.element_in_string == rhs.element_in_string;
-    }
-  };
-
-  struct RulePositionHash {
-    std::size_t operator()(const RulePosition& rule_position) const noexcept {
-      return HashCombine(rule_position.sequence_id, rule_position.element_id,
-                         rule_position.left_utf8_bytes, rule_position.element_in_string);
-    }
-  };
-
-  /*! \brief Mapping from RulePositions to the catagorized tokens. */
-  std::unordered_map<RulePosition, CatagorizedTokens, RulePositionHash, RulePositionEqual>
-      catagorized_tokens_for_grammar;
-};
-
-/*! \brief A helper matcher used in preprocessing to compute the catagorized tokens. */
-class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase {
- public:
-  // Do not expand the initial rule position: we want to find the accepted/rejected tokens
-  // that exactly start from the initial rule position.
-  GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position)
-      : GrammarStateMatcherBase(grammar, init_rule_position, false),
-        init_rule_id(init_rule_position.rule_id) {}
-
-  /*!
-   * \brief Get the catagorized tokens for the given RulePosition.
-   * \param consider_parent_rule Whether to consider the parent rule. If false, there will be
-   * no uncertain tokens. Useful for the main rule.
-   */
-  CatagorizedTokens GetCatagorizedTokens(
-      int vocab_size, const std::vector<std::pair<int32_t, std::string>>& sorted_token_table,
-      bool consider_parent_rule);
-
- private:
-  using RuleExpr = BNFGrammarNode::RuleExpr;
-  using RuleExprType = BNFGrammarNode::RuleExprType;
-
-  /*! \brief Check if a token can pass the lookahead assertion.
*/ - bool IsTokenPassLookaheadAssertion(const std::string& token, - const std::vector& can_reach_end_stack); - - // The id of the initial rule. - int32_t init_rule_id; - - // Temporary data for GetCatagorizedTokens. - std::vector tmp_accepted_indices_; - std::vector tmp_rejected_indices_; - std::vector tmp_uncertain_indices_; - std::vector tmp_can_reach_end_stack_; - std::vector tmp_can_reach_end_prefix_or_stack_; -}; - -inline CatagorizedTokens::CatagorizedTokens( - int vocab_size, const std::vector>& sorted_token_table, - const std::vector& accepted_indices, const std::vector& rejected_indices, - const std::vector& uncertain_indices) { - auto size_acc = accepted_indices.size(); - auto size_rej = rejected_indices.size(); - - save_type = size_acc >= USE_BITSET_THRESHOLD && size_rej >= USE_BITSET_THRESHOLD - ? SaveType::kAcceptedBitset - : size_acc < size_rej ? SaveType::kAccepted - : SaveType::kRejected; - - if (save_type == SaveType::kAcceptedBitset) { - accepted_bitset = DynamicBitset(vocab_size); - for (auto idx : accepted_indices) { - accepted_bitset.Set(sorted_token_table[idx].first, true); - } - } else if (save_type == SaveType::kAccepted) { - this->accepted_indices = accepted_indices; - } else { - this->rejected_indices = rejected_indices; - } - - this->uncertain_indices = uncertain_indices; -} - -bool GrammarStateMatcherForInitContext::IsTokenPassLookaheadAssertion( - const std::string& token, const std::vector& can_reach_end_stack) { - auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id; - if (lookahead_assertion_id == -1) { - return true; - } - auto lookahead_rule_position = RulePosition(-1, lookahead_assertion_id, 0); - PushInitialState(lookahead_rule_position, true); - int token_len = token.size(); - - // Find all positions that can come to and end. Then check if the suffix from that position - // can be accepted by the lookahead assertion. - for (int i = static_cast(can_reach_end_stack.size()); i >= 0; --i) { - if (!can_reach_end_stack[i]) { - continue; - } - int last_accept_pos = i - 1; - for (int pos = i; pos < token_len; ++pos) { - if (!AcceptChar(token[pos])) { - break; - } - last_accept_pos = pos; - // Case 1. The whole rule is finished. - if (CanReachEnd()) { - // accepted chars: pos - i + 1 - // we need to rollback the pushed initial state as well - RollbackChars(pos - i + 2); - return true; - } - } - // Case 2. The whole token is accepted - if (last_accept_pos == token_len - 1) { - RollbackChars(last_accept_pos - i + 2); - return true; - } - // Case 3. The token is not accepted. Check the next position. - RollbackChars(last_accept_pos - i + 1); - } - - RollbackChars(1); - return false; -} - -inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( - int vocab_size, const std::vector>& sorted_token_table, - bool consider_parent_rule) { - tmp_accepted_indices_.clear(); - tmp_rejected_indices_.clear(); - tmp_uncertain_indices_.clear(); - - // For every character in the current token, stores whether it is possible to reach the end of - // the rule when matching until this character. Store it in a stack for later rollback. 
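The constructor above picks the cheapest of the three representations; a consumer then has to expand them back into a dense mask. The sketch below shows that expansion for the two index-list forms, using the complement identities from the SaveType comments; it is illustrative, not repository code (the bitset form is keyed by token id rather than sorted index and is consumed directly):

```cpp
#include <cstdint>
#include <vector>

// Sketch: expand CatagorizedTokens into an accept-mask over sorted_token_table
// indices. Precondition for this sketch: save_type is kAccepted or kRejected.
// Uncertain tokens stay unset; they are resolved by running the matcher.
std::vector<bool> ToAcceptMask(const CatagorizedTokens& ct, int num_sorted_tokens) {
  std::vector<bool> mask;
  if (ct.save_type == CatagorizedTokens::SaveType::kAccepted) {
    mask.assign(num_sorted_tokens, false);
    for (int32_t idx : ct.accepted_indices) mask[idx] = true;
  } else {
    // kRejected: accepted = all - rejected - uncertain
    mask.assign(num_sorted_tokens, true);
    for (int32_t idx : ct.rejected_indices) mask[idx] = false;
    for (int32_t idx : ct.uncertain_indices) mask[idx] = false;
  }
  return mask;
}
```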
- tmp_can_reach_end_stack_.assign({CanReachEnd()}); - tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()}); - - int prev_matched_size = 0; - for (int i = 0; i < static_cast(sorted_token_table.size()); ++i) { - const auto& token = sorted_token_table[i].second; - - bool accepted = true; - - // Many tokens may contain the same prefix, so we will avoid unnecessary matching - // by finding the longest common prefix with the previous token. - if (i > 0) { - const auto& prev_token = sorted_token_table[i - 1].second; - int lcp_len = - std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first - - token.begin(); - if (lcp_len > prev_matched_size) { - // Case 1. The common prefix is rejected by the matcher in the last token. Reject directly. - accepted = false; - } else if (lcp_len < prev_matched_size) { - // Case 2. The common prefix is shorter than the previous matched size. Rollback - // the non-common part. - RollbackChars(prev_matched_size - lcp_len); - tmp_can_reach_end_stack_.erase( - tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len), - tmp_can_reach_end_stack_.end()); - tmp_can_reach_end_prefix_or_stack_.erase( - tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len), - tmp_can_reach_end_prefix_or_stack_.end()); - } - prev_matched_size = std::min(prev_matched_size, lcp_len); - } - - if (accepted) { - // Accept the rest chars one by one - for (int j = prev_matched_size; j < token.size(); ++j) { - if (!AcceptChar(token[j], false)) { - accepted = false; - break; - } - tmp_can_reach_end_stack_.push_back(CanReachEnd()); - tmp_can_reach_end_prefix_or_stack_.push_back(tmp_can_reach_end_stack_.back() || - tmp_can_reach_end_prefix_or_stack_.back()); - prev_matched_size = j + 1; - } - } - - bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); - - if (accepted) { - tmp_accepted_indices_.push_back(i); - } else if (can_reach_end && consider_parent_rule && - IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) { - // 1. If the current rule is the main rule (consider_parent_rule=false), there are no - // uncertain tokens. Not accepted tokens are just rejected. - // 2. If a token cannot pass the lookahead assertion, it is rejected. - tmp_uncertain_indices_.push_back(i); - } else { - tmp_rejected_indices_.push_back(i); - } - } - // Rollback the last matched part - RollbackChars(prev_matched_size); - return CatagorizedTokens(vocab_size, sorted_token_table, tmp_accepted_indices_, - tmp_rejected_indices_, tmp_uncertain_indices_); -} - -inline std::shared_ptr GrammarStateMatcher::CreateInitContext( - const BNFGrammar& grammar, const std::vector& token_table) { - using RuleExprType = BNFGrammarNode::RuleExprType; - auto ptr = std::make_shared(); - - ptr->grammar = grammar; - ptr->vocab_size = token_table.size(); - ptr->token_table = token_table; - - if (ptr->vocab_size == 0) { - return ptr; - } - - for (int i = 0; i < token_table.size(); ++i) { - const auto& token = token_table[i]; - // TODO(yixin): Now we detect stop tokens from the token string. We should be able to pass - // the stop token set in. 
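The categorization loop above amortizes work across the lexicographically sorted token table: each token re-matches only the bytes it does not share with its predecessor. A standalone sketch of the same prefix-reuse pattern, with `accept_char` and `rollback` standing in for AcceptChar/RollbackChars:

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Standalone sketch of the sorted-table prefix-reuse trick used above.
// accept_char(uint8_t)->bool and rollback(int) stand in for the matcher calls.
template <typename AcceptFn, typename RollbackFn>
void MatchSortedTokens(const std::vector<std::string>& sorted_tokens,
                       AcceptFn accept_char, RollbackFn rollback) {
  int matched = 0;  // chars of the previous token still applied to the matcher
  for (size_t i = 0; i < sorted_tokens.size(); ++i) {
    const std::string& tok = sorted_tokens[i];
    int lcp = 0;
    if (i > 0) {
      const std::string& prev = sorted_tokens[i - 1];
      lcp = static_cast<int>(
          std::mismatch(tok.begin(), tok.end(), prev.begin(), prev.end()).first -
          tok.begin());
    }
    if (lcp > matched) continue;  // shared prefix already failed: reject outright
    if (lcp < matched) {          // undo the part that is no longer shared
      rollback(matched - lcp);
      matched = lcp;
    }
    for (int j = matched; j < static_cast<int>(tok.size()); ++j) {
      if (!accept_char(static_cast<uint8_t>(tok[j]))) break;
      matched = j + 1;
    }
    // ... token i is accepted iff the loop consumed all of tok ...
  }
  rollback(matched);  // restore the initial matcher state
}
```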
- // LLaMA2: - // LLaMA3: <|end_of_text|>, <|eot_id|> - // Phi-2: <|endoftext|> - // Gemma: , - if (token == "" || token == "<|end_of_text|>" || token == "<|eot_id|>" || - token == "<|endoftext|>" || token == "" || token == "") { - ptr->stop_token_ids.push_back(i); - } else if ((token[0] == '<' && token.back() == '>' && token.size() >= 3) || - token == "[@BOS@]") { - // gemma treats [@BOS@] as a special token - ptr->special_token_ids.insert(i); - } else { - ptr->sorted_token_table.push_back({i, token}); - } - } - - auto f_compare_token = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(ptr->sorted_token_table.begin(), ptr->sorted_token_table.end(), f_compare_token); - - // Find the corresponding catagorized tokens for: - // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) - // 2. All byte strings (with element_in_string=0, 1, 2, ...) - auto main_rule_id = grammar->GetMainRuleId(); - for (int rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { - auto rule = grammar->GetRule(rule_id); - auto rule_body = grammar->GetRuleExpr(rule.body_expr_id); - DCHECK(rule_body.type == RuleExprType::kChoices); - for (auto sequence_id : rule_body) { - auto sequence = grammar->GetRuleExpr(sequence_id); - if (sequence.type == RuleExprType::kEmptyStr) { - continue; - } - DCHECK(sequence.type == RuleExprType::kSequence); - for (int element_id = 0; element_id < sequence.size(); ++element_id) { - auto element = grammar->GetRuleExpr(sequence[element_id]); - if (element.type == RuleExprType::kRuleRef) { - continue; - } - - auto add_catagorized_tokens = [&](const RulePosition& rule_position) { - auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, rule_position); - auto cur_catagorized_tokens_for_grammar = grammar_state_matcher.GetCatagorizedTokens( - ptr->vocab_size, ptr->sorted_token_table, rule_id != main_rule_id); - ptr->catagorized_tokens_for_grammar[rule_position] = cur_catagorized_tokens_for_grammar; - }; - - auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); - if (element.type == RuleExprType::kByteString) { - for (int idx = 0; idx < element.size(); ++idx) { - cur_rule_position.element_in_string = idx; - add_catagorized_tokens(cur_rule_position); - } - } else { - DCHECK(element.type == RuleExprType::kCharacterClassStar || - element.type == RuleExprType::kCharacterClass); - for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { - cur_rule_position.left_utf8_bytes = left_utf8_bytes; - add_catagorized_tokens(cur_rule_position); - } - } - } - } - } - return ptr; -} - -class GrammarInitContextCacheImpl : public GrammarInitContextCacheNode { - public: - GrammarInitContextCacheImpl(const std::vector& token_table); - - std::shared_ptr GetInitContextForJSONSchema( - const std::string& schema) final; - - std::shared_ptr GetInitContextForJSON() final; - - void Clear() final; - - private: - /*! \brief The token table associated with this storage class. */ - std::vector token_table_; - /*! \brief The cache for the init context of a JSON schema. */ - std::unordered_map> - init_ctx_for_schema_cache_; - /*! \brief The init context for JSON. 
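A hypothetical call site for this cache, to make the memoization behavior concrete (the real owner of the cache is engine code outside this hunk):

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical usage of GrammarInitContextCache as declared above.
void ExampleUse(const std::vector<std::string>& token_table, const std::string& schema) {
  GrammarInitContextCache cache(token_table);
  auto json_ctx = cache->GetInitContextForJSON();                // prebuilt in the constructor
  auto schema_ctx = cache->GetInitContextForJSONSchema(schema);  // compiled once...
  auto again = cache->GetInitContextForJSONSchema(schema);       // ...then served from the memo table
  // Each result is a shared_ptr<GrammarStateInitContext> handed to matchers.
}
```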
*/ - std::shared_ptr init_ctx_for_json_; -}; - -inline GrammarInitContextCacheImpl::GrammarInitContextCacheImpl( - const std::vector& token_table) - : token_table_(token_table) { - init_ctx_for_json_ = - GrammarStateMatcher::CreateInitContext(BNFGrammar::GetGrammarOfJSON(), token_table_); -} - -inline std::shared_ptr -GrammarInitContextCacheImpl::GetInitContextForJSONSchema(const std::string& schema) { - auto it = init_ctx_for_schema_cache_.find(schema); - if (it != init_ctx_for_schema_cache_.end()) { - return it->second; - } - auto init_ctx = - GrammarStateMatcher::CreateInitContext(BNFGrammar::FromSchema(schema), token_table_); - init_ctx_for_schema_cache_[schema] = init_ctx; - return init_ctx; -} - -inline std::shared_ptr -GrammarInitContextCacheImpl::GetInitContextForJSON() { - return init_ctx_for_json_; -} - -inline void GrammarInitContextCacheImpl::Clear() { init_ctx_for_schema_cache_.clear(); } - -GrammarInitContextCache::GrammarInitContextCache(const std::vector& token_table) - : ObjectRef(make_object(token_table)) {} - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_ diff --git a/cpp/grammar/grammar_state_matcher_state.h b/cpp/grammar/grammar_state_matcher_state.h deleted file mode 100644 index 1a132b8980..0000000000 --- a/cpp/grammar/grammar_state_matcher_state.h +++ /dev/null @@ -1,446 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/grammar_state_matcher_state.h - * \brief The header for the definition of the state used in the grammar state matcher. - */ -#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ -#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ - -#include -#include - -#include "grammar.h" -#include "grammar_serializer.h" - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -/*! \brief Specifies a position in a rule. */ -struct RulePosition { - /*! \brief The rule's id. Used for debug purposes. */ - int32_t rule_id = -1; - /*! \brief Which choice in this rule is selected. */ - int32_t sequence_id = -1; - /*! \brief Which element of the choice sequence is to be visited. */ - int32_t element_id = -1; - - /*! \brief The number of left utf8 bytes in the current element. Used when the element is - * a character class or a character class star. */ - int32_t left_utf8_bytes = 0; - /*! \brief The next position to match in the current byte string. Used when the element is - * a byte string. */ - int32_t element_in_string = 0; - - /*! \brief The id of the parent node in the RulePositionTree. */ - int32_t parent_id = -1; - /*! \brief The reference count of this RulePosition. If reduces to zero, the node will be - * removed from the RulePositionBuffer. */ - int reference_count = 0; - - /*! \brief A parent_id value of kNoParent means this RulePosition is the root of the tree. */ - static constexpr int32_t kNoParent = -1; - - constexpr RulePosition() = default; - constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, - int32_t parent_id = kNoParent) - : rule_id(rule_id), sequence_id(sequence_id), element_id(element_id), parent_id(parent_id) {} - - // The position is invalid when sequence_id is -1. 
- bool IsInvalid() const { return sequence_id == -1; } - - bool operator==(const RulePosition& other) const { - return rule_id == other.rule_id && sequence_id == other.sequence_id && - element_id == other.element_id && parent_id == other.parent_id && - left_utf8_bytes == other.left_utf8_bytes && element_in_string == other.element_in_string; - } -}; - -/*! \brief A special value for invalid RulePosition. */ -inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1); - -/*! \brief A buffer to manage all RulePositions. */ -class RulePositionBuffer { - public: - /*! - * \brief Allocate a new RulePosition. with given initial value. - * \returns The id of the allocated node. - */ - int32_t Allocate(RulePosition rule_position) { - int32_t id; - if (free_nodes_.empty()) { - buffer_.emplace_back(); - id = static_cast(buffer_.size()) - 1; - } else { - id = free_nodes_.back(); - DCHECK(buffer_[id].IsInvalid()); - free_nodes_.pop_back(); - } - rule_position.reference_count = 0; - buffer_[id] = rule_position; - return id; - } - - /*! \brief Free the RulePosition with the given id. */ - void Free(int32_t id) { - DCHECK(!buffer_[id].IsInvalid()); - buffer_[id] = kInvalidRulePosition; - free_nodes_.push_back(id); - } - - /*! \brief Get the capacity of the buffer. */ - size_t Capacity() const { return buffer_.size(); } - - /*! \brief Get the number of allocated nodes. */ - size_t Size() const { - DCHECK(buffer_.size() >= free_nodes_.size()); - return buffer_.size() - free_nodes_.size(); - } - - /*! \brief Get the RulePosition with the given id. */ - RulePosition& operator[](int32_t id) { - DCHECK(id >= 0 && id < static_cast(buffer_.size())); - DCHECK(!buffer_[id].IsInvalid()); - return buffer_[id]; - } - const RulePosition& operator[](int32_t id) const { - DCHECK(id >= 0 && id < static_cast(buffer_.size())); - DCHECK(!buffer_[id].IsInvalid()); - return buffer_[id]; - } - - void Reset() { - buffer_.clear(); - free_nodes_.clear(); - } - - friend class RulePositionTree; - - private: - /*! \brief The buffer to store all RulePositions. */ - std::vector buffer_; - /*! \brief A stack to store all free node ids. */ - std::vector free_nodes_; -}; - -/*! - * \brief A tree structure to store all stacks. Every stack contains several RulePositions, and - * is represented as a path from the root to a leaf node. - */ -class RulePositionTree { - public: - /*! \brief Construct a RulePositionTree associated with the given grammar. */ - RulePositionTree(const BNFGrammar& grammar) : grammar_(grammar) {} - - /*! - * \brief Create a new node with the given RulePosition. The reference count of the new node - * is zero. - * - * \note Later, this node should either be pointed by some child rule, or become a stack top - * node (so it will be pointed to by an attached pointer) to be maintained in the - * reference-counting based memory management. - */ - int32_t NewNode(const RulePosition& rule_position) { - auto id = node_buffer_.Allocate(rule_position); - if (rule_position.parent_id != RulePosition::kNoParent) { - DCHECK(rule_position.parent_id < static_cast(node_buffer_.Capacity()) && - !node_buffer_[rule_position.parent_id].IsInvalid()); - node_buffer_[rule_position.parent_id].reference_count++; - } - return id; - } - - /*! - * \brief Check if the given RulePosition points to the end of the grammar. For a position, if its - * rule id is the main rule id, and the element id equals to the length of the sequence it refers - * to, it would be the end position. 
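RulePositionBuffer above is a slot pool: freed ids are kept on a stack and reused before the backing vector grows, so ids stay stable and allocation is O(1). A generic sketch of the same pattern, independent of the grammar types:

```cpp
#include <cstdint>
#include <vector>

// Generic version of the free-list trick RulePositionBuffer uses above.
template <typename T>
class SlotPool {
 public:
  int32_t Allocate(const T& value) {
    if (free_.empty()) {
      slots_.push_back(value);
      return static_cast<int32_t>(slots_.size()) - 1;
    }
    int32_t id = free_.back();  // reuse a previously freed slot
    free_.pop_back();
    slots_[id] = value;
    return id;
  }
  void Free(int32_t id) { free_.push_back(id); }
  T& operator[](int32_t id) { return slots_[id]; }

 private:
  std::vector<T> slots_;
  std::vector<int32_t> free_;  // stack of reusable slot ids
};
```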
- */ - bool IsEndPosition(const RulePosition& rule_position) const; - - /*! \brief Attach an additional reference to the node with the given id. */ - void AttachRefTo(int32_t id) { - DCHECK(id != RulePosition::kNoParent); - node_buffer_[id].reference_count++; - } - - /*! \brief Remove a reference to the node with the given id. If the reference count becomes zero, - * free the node and recursively all its ancestors with zero reference count. */ - void RemoveRefTo(int32_t id) { - DCHECK(id != RulePosition::kNoParent); - auto cur_node = id; - while (cur_node != RulePosition::kNoParent) { - node_buffer_[cur_node].reference_count--; - if (node_buffer_[cur_node].reference_count != 0) { - break; - } - auto next_node = node_buffer_[cur_node].parent_id; - node_buffer_.Free(cur_node); - cur_node = next_node; - } - } - - /*! \brief Get the RulePosition with the given id. */ - const RulePosition& operator[](int32_t id) const { - DCHECK(id != RulePosition::kNoParent); - DCHECK(!node_buffer_[id].IsInvalid()); - return node_buffer_[id]; - } - - /*! \brief Print the given rule_position to a string. */ - std::string PrintNode(const RulePosition& rule_position) const; - - /*! \brief Print the rule_position associated with the given id to a string. */ - std::string PrintNode(int32_t id) const; - - /*! \brief Print the stack with the given top id to a string. */ - std::string PrintStackByTopId(int32_t top_id) const; - - /*! - * \brief Check the well-formedness of the tree and the associated buffer. For debug purpose. - * \details This function checks the following properties: - * 1. Every node is pointed directly or indirectly by a outside pointer. - * 2. Every node's reference count is consistent with the actual reference count. - * 3. All ids and positions are valid. - * 4. If a node in the buffer is free, it should be equal to kInvalidRulePosition. - */ - void CheckWellFormed(const std::vector& outside_pointers) const; - - /*! \brief Reset the tree and the associated buffer. */ - void Reset() { node_buffer_.Reset(); } - - private: - /*! \brief The grammar associated with this RulePositionTree. */ - BNFGrammar grammar_; - /*! \brief The buffer to store all RulePositions. */ - RulePositionBuffer node_buffer_; -}; - -/*! - * \brief A class to maintain the stack tops and its history to support rollback. - * \details This class helps to maintain nodes by automatically maintaining the attached references. - * If a node is not existing in any stack in the history record, it will be freed. - * - * It can store up to the previous max_rollback_steps + 1 steps of history, and thus supports - * rolling back up to max_rollback_steps steps. - */ -class StackTopsHistory { - public: - /*! - * \param tree The RulePositionTree to be associated with. Possibly modify the tree by attaching - * and removing references to the stack top nodes. - * \param max_rollback_steps The maximum number of rollback steps to be supported. - */ - StackTopsHistory(RulePositionTree* tree) : tree_(tree) {} - - /*! - * \brief Push a new history record consisting a list of stack tops. These nodes will be recorded - * as existing in a stack (by attaching a reference to them). - * \param stack_tops The stack tops to be pushed. - * \param drop_old Whether to drop the oldest history record if the history size exceeds the - * limit. If the history is dropped, node that do not exist in any stack any more will be freed. 
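The tree stores many stacks as parent-linked paths with reference counts, so stacks sharing a prefix share nodes, and releasing a stack top frees exactly the nodes no other stack reaches. A toy version of that discipline; unlike NewNode above, which starts at zero references and relies on AttachRefTo, a pushed top here carries one external reference for brevity:

```cpp
#include <cstdint>
#include <vector>

struct ToyNode {
  int32_t parent = -1;
  int ref_count = 0;
  char payload = 0;
};

class SharedStacks {
 public:
  // Push a new stack top over `parent` (a simplification of NewNode + AttachRefTo).
  int32_t Push(int32_t parent, char payload) {
    if (parent != -1) nodes_[parent].ref_count++;
    nodes_.push_back({parent, /*ref_count=*/1, payload});
    return static_cast<int32_t>(nodes_.size()) - 1;
  }
  // Drop a stack top; walk upward releasing every ancestor whose count hits
  // zero, mirroring RulePositionTree::RemoveRefTo above.
  void Release(int32_t id) {
    while (id != -1 && --nodes_[id].ref_count == 0) id = nodes_[id].parent;
  }

 private:
  std::vector<ToyNode> nodes_;  // a real version would recycle freed slots
};
```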
- */ - void PushHistory(const std::vector& stack_tops) { - stack_tops_history_.push_back(stack_tops); - for (auto id : stack_tops) { - tree_->AttachRefTo(id); - } - } - - /*! \brief Roll back to several previous steps. Possibly frees node that do not exist in any stack - * any more. */ - void Rollback(int rollback_steps) { - DCHECK(rollback_steps < stack_tops_history_.size()) - << "The number of requested rollback steps is greater than or equal to the current " - "history " - << "size: " << rollback_steps << " vs " << stack_tops_history_.size() << "."; - while (rollback_steps--) { - PopLatest(); - } - } - - /*! \brief Discard the earliest several steps. Possibly frees node that do not exist in any stack - * any more. */ - void DiscardEarliest(int discard_steps) { - DCHECK(discard_steps < stack_tops_history_.size()) - << "The number of requested discard steps is greater than or equal to the current " - "history " - << "size: " << discard_steps << " vs " << stack_tops_history_.size() << "."; - while (discard_steps--) { - PopEarliest(); - } - } - - /*! \brief Get the latest stack tops. */ - const std::vector& GetLatest() const { return stack_tops_history_.back(); } - - /*! - * \brief Print one history record. - * \param history_position_to_latest The number of steps behind the latest record. 0 means the - * latest record. - */ - std::string PrintHistory(int history_position_to_latest = 0) const; - - /*! \brief Get the number of history records. */ - int Size() const { return stack_tops_history_.size(); } - - /*! \brief Check the well-formedness of the tree and the associated buffer. */ - void CheckWellFormed() const; - - /*! \brief Reset the history and the associated node tree. */ - void Reset() { - stack_tops_history_.clear(); - tree_->Reset(); - } - - private: - /*! \brief Pop the oldest history record. Possibly frees node that do not exist in any stack any - * more. */ - void PopEarliest() { - const auto& old_stack_tops = stack_tops_history_.front(); - for (auto id : old_stack_tops) { - tree_->RemoveRefTo(id); - } - stack_tops_history_.pop_front(); - } - - /*! \brief Pop the latest history record. Possibly frees node that do not exist in any stack any - * more. */ - void PopLatest() { - const auto& new_stack_tops = stack_tops_history_.back(); - for (auto id : new_stack_tops) { - tree_->RemoveRefTo(id); - } - stack_tops_history_.pop_back(); - } - - /*! \brief Modifiable pointer to the RulePositionTree. */ - RulePositionTree* tree_; - /*! \brief The history of stack tops. 
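Per the class comment, the history holds at most max_rollback_steps + 1 records. A hypothetical maintenance routine built on the calls above; kMaxRollbackSteps is illustrative:

```cpp
#include <cstdint>
#include <vector>

constexpr int kMaxRollbackSteps = 32;  // illustrative bound, not from the repo

// Push a record, then trim the oldest ones so Rollback() can always undo up to
// kMaxRollbackSteps characters while memory stays bounded.
void PushAndTrim(StackTopsHistory* history, const std::vector<int32_t>& new_tops) {
  history->PushHistory(new_tops);
  if (history->Size() > kMaxRollbackSteps + 1) {
    history->DiscardEarliest(history->Size() - (kMaxRollbackSteps + 1));
  }
}
```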
*/ - std::deque> stack_tops_history_; -}; - -inline bool RulePositionTree::IsEndPosition(const RulePosition& rule_position) const { - return rule_position.parent_id == RulePosition::kNoParent && - grammar_->GetRuleExpr(rule_position.sequence_id).size() == rule_position.element_id; -} - -inline std::string RulePositionTree::PrintNode(int32_t id) const { - return "id: " + std::to_string(id) + ", " + PrintNode(node_buffer_[id]); -} - -inline std::string RulePositionTree::PrintNode(const RulePosition& rule_position) const { - std::stringstream ss; - ss << "RulePosition: rule " << rule_position.rule_id; - if (rule_position.rule_id != -1) { - ss << ": " << grammar_->GetRule(rule_position.rule_id).name; - } - ss << ", sequence " << rule_position.sequence_id << ": " - << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); - ss << ", element id: " << rule_position.element_id; - - auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - if (rule_position.element_id < static_cast(sequence.size())) { - auto element = grammar_->GetRuleExpr(sequence[rule_position.element_id]); - if (element.type == BNFGrammarNode::RuleExprType::kByteString) { - ss << ", element in string: " << rule_position.element_in_string; - } else { - DCHECK(element.type == BNFGrammarNode::RuleExprType::kCharacterClass || - element.type == BNFGrammarNode::RuleExprType::kCharacterClassStar); - ss << ", left utf8 bytes: " << rule_position.left_utf8_bytes; - } - } - - ss << ", parent id: " << rule_position.parent_id - << ", ref count: " << rule_position.reference_count; - return ss.str(); -} - -inline std::string RulePositionTree::PrintStackByTopId(int32_t top_id) const { - std::stringstream ss; - std::vector stack; - for (auto cur_id = top_id; cur_id != RulePosition::kNoParent; - cur_id = node_buffer_[cur_id].parent_id) { - stack.push_back(cur_id); - } - ss << "{\n"; - for (auto it = stack.rbegin(); it != stack.rend(); ++it) { - ss << PrintNode(*it) << "\n"; - } - ss << "}"; - return ss.str(); -} - -inline void RulePositionTree::CheckWellFormed(const std::vector& outside_pointers) const { - const auto& buffer = node_buffer_.buffer_; - std::unordered_set free_nodes_set(node_buffer_.free_nodes_.begin(), - node_buffer_.free_nodes_.end()); - int buffer_size = static_cast(buffer.size()); - std::vector new_reference_counter(buffer_size, 0); - std::vector visited(buffer_size, false); - std::queue visit_queue; - for (auto id : outside_pointers) { - CHECK(id >= 0 && id < buffer_size); - CHECK(!buffer[id].IsInvalid()); - new_reference_counter[id]++; - if (visited[id] == false) { - visited[id] = true; - visit_queue.push(id); - } - } - while (!visit_queue.empty()) { - auto cur_id = visit_queue.front(); - visit_queue.pop(); - const auto& rule_position = buffer[cur_id]; - if (rule_position.parent_id != RulePosition::kNoParent) { - CHECK(rule_position.parent_id >= 0 && rule_position.parent_id < buffer_size); - CHECK(!buffer[rule_position.parent_id].IsInvalid()); - new_reference_counter[rule_position.parent_id]++; - if (visited[rule_position.parent_id] == false) { - visited[rule_position.parent_id] = true; - visit_queue.push(rule_position.parent_id); - } - } - } - - for (int i = 0; i < static_cast(buffer.size()); ++i) { - if (free_nodes_set.count(i)) { - CHECK(buffer[i].IsInvalid()); - CHECK(visited[i] == false); - } else { - CHECK(visited[i] == true); - CHECK(!buffer[i].IsInvalid()); - CHECK(new_reference_counter[i] == buffer[i].reference_count) - << "Reference counters unmatch for node #" << i << ": Updated " - << 
new_reference_counter[i] << ", Original " << buffer[i].reference_count; - } - } -} - -inline std::string StackTopsHistory::PrintHistory(int history_position_to_latest) const { - const auto& latest_tops = stack_tops_history_[static_cast(stack_tops_history_.size()) - - 1 - history_position_to_latest]; - std::stringstream ss; - ss << "Stacks tops size: " << latest_tops.size() << std::endl; - int cnt = 0; - for (auto id : latest_tops) { - ss << "Stack #" << cnt << ": " << tree_->PrintStackByTopId(id) << "\n"; - ++cnt; - } - return ss.str(); -} - -inline void StackTopsHistory::CheckWellFormed() const { - std::vector outside_pointers; - for (const auto& stack_tops : stack_tops_history_) { - outside_pointers.insert(outside_pointers.end(), stack_tops.begin(), stack_tops.end()); - } - tree_->CheckWellFormed(outside_pointers); -} - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ diff --git a/cpp/grammar/json_schema_converter.cc b/cpp/grammar/json_schema_converter.cc deleted file mode 100644 index 10c1dbe76b..0000000000 --- a/cpp/grammar/json_schema_converter.cc +++ /dev/null @@ -1,997 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/json_schema_converter.cc - */ -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -// EMCC somehow cannot pickup operator overload from picojson.h, so we copy here. -#ifdef COMPILE_MLC_WASM_RUNTIME -inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) { - x.serialize(std::ostream_iterator(os)); - return os; -} -#endif - -/*! - * \brief Manage the indent and separator for the generation of EBNF grammar. - * \param indent The number of spaces for each indent. If it is std::nullopt, there will be no - * indent or newline. - * \param separator The separator between different elements in json. Examples include "," and ", ". - */ -class IndentManager { - public: - IndentManager(std::optional indent, const std::string& separator) - : enable_newline_(indent.has_value()), - indent_(indent.value_or(0)), - separator_(separator), - total_indent_(0), - is_first_({true}) {} - - /*! \brief Enter a new indent level. */ - void StartIndent() { - total_indent_ += indent_; - is_first_.push_back(true); - } - - /*! \brief Exit the current indent level. */ - void EndIndent() { - total_indent_ -= indent_; - is_first_.pop_back(); - } - - /*! - * \brief Get the next separator in the current level. When first called in the current - * level, the starting separator will be returned. When called again, the middle separator will be - * returned. When called with `is_end=True`, the ending separator will be returned. - * \param is_end Get the separator for the end of the current level. - * \example - * \code - * IndentManager indent_manager(2, ", "); - * indent_manager.StartIndent(); - * indent_manager.GetSep(); // get the start separator: "\"\n \"" - * indent_manager.GetSep(); // get the middle separator: "\",\n \"" - * indent_manager.GetSep(true); // get the end separator: "\"\n\"" - * \endcode - */ - std::string NextSeparator(bool is_end = false); - - /*! \brief Get the separator itself. 
*/ - std::string GetBareSeparator() { return separator_; } - - private: - bool enable_newline_; - int indent_; - std::string separator_; - int total_indent_; - std::vector is_first_; - friend class JSONSchemaToEBNFConverter; -}; - -std::string IndentManager::NextSeparator(bool is_end) { - std::string res = ""; - if (!is_first_.back() && !is_end) { - res += separator_; - } - is_first_.back() = false; - - if (enable_newline_) { - res += "\\n"; - } - - if (!is_end) { - res += std::string(total_indent_, ' '); - } else { - res += std::string(total_indent_ - indent_, ' '); - } - - return "\"" + res + "\""; -} - -/*! - * \brief Convert JSON schema string to EBNF grammar string. The parameters follow - * JSONSchemaToEBNF(). - * - * \note About the representation of json schema in this converter. JSON schema could be two types: - * bool (true or false) or dict (a json dict) containing attributes. We use picojson::value to - * represent the json schema. - */ -class JSONSchemaToEBNFConverter { - public: - JSONSchemaToEBNFConverter( - const picojson::value& json_schema, std::optional indent = std::nullopt, - std::optional> separators = std::nullopt, - bool strict_mode = false); - - /*! \brief The main method. Convert the JSON schema to EBNF grammar string. */ - std::string Convert(); - - private: - // The name of the basic rules - inline static const std::string kBasicAny = "basic_any"; - inline static const std::string kBasicInteger = "basic_integer"; - inline static const std::string kBasicNumber = "basic_number"; - inline static const std::string kBasicString = "basic_string"; - inline static const std::string kBasicBoolean = "basic_boolean"; - inline static const std::string kBasicNull = "basic_null"; - inline static const std::string kBasicArray = "basic_array"; - inline static const std::string kBasicObject = "basic_object"; - - // The name of the helper rules to construct basic rules - inline static const std::string kBasicEscape = "basic_escape"; - inline static const std::string kBasicStringSub = "basic_string_sub"; - - /*! \brief Add the basic rules to the rules list and the basic_rules_cache. */ - void AddBasicRules(); - - /*! \brief Add helper rules for the basic rules. */ - void AddHelperRules(); - - /*! \brief Create a rule for the given schema and name, and add it to the basic_rules_cache. */ - void CreateBasicRule(const picojson::value& schema, const std::string& name); - - /*! \brief Get the index for the schema in the cache. Keys that do not effect the validation - * will be ignored when finding the corresponding cache rule. */ - std::string GetSchemaCacheIndex(const picojson::value& schema); - - /*! - * \brief Create a rule with the given schema and rule name hint. - * \returns The name of the rule will be returned. That is not necessarily the same as the - * rule_name_hint due to the caching mechanism. - */ - std::string CreateRuleFromSchema(const picojson::value& schema, - const std::string& rule_name_hint); - - /*! \brief Get the next separator in the current level from the indent manager. */ - std::string NextSeparator(bool is_end = false); - - /*! \brief Warn if any keyword is existing in the schema but not supported. */ - static void WarnUnsupportedKeywords(const picojson::value& schema, - const std::vector& keywords); - - /*! \brief Warn if any keyword is existing in the object but not supported. */ - static void WarnUnsupportedKeywords(const picojson::object& schema, - const std::vector& keywords); - - /*! 
\brief Visit the schema and return the rule body for later constructing the rule. */ - std::string VisitSchema(const picojson::value& schema, const std::string& rule_name); - - /*! \brief Visit a reference schema. */ - std::string VisitRef(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Get the schema from the URI. */ - picojson::value URIToSchema(const picojson::value& uri); - - /*! \brief Visit a const schema. */ - std::string VisitConst(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Visit an enum schema. */ - std::string VisitEnum(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Convert the JSON string to a printable string that can be shown in BNF. */ - std::string JSONStrToPrintableStr(const std::string& json_str); - - /*! \brief Visit an anyOf schema. */ - std::string VisitAnyOf(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Visit a true schema that can match anything. */ - std::string VisitAny(const picojson::value& schema, const std::string& rule_name); - - /*! \brief Visit an integer schema. */ - std::string VisitInteger(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Visit a number schema. */ - std::string VisitNumber(const picojson::object& schema, const std::string& rule_name); - /*! \brief Visit a string schema. */ - std::string VisitString(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Visit a boolean schema. */ - std::string VisitBoolean(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Visit a null schema. */ - std::string VisitNull(const picojson::object& schema, const std::string& rule_name); - - /*! - * \brief Visit an array schema. - * \example - * Schema: - * \code - * { - * "type": "array", - * "prefixItems": [ - * {"type": "boolean"}, - * {"type": "integer"} - * ], - * "items": { - * "type": "string" - * } - * } - * \endcode - * Rule (not considering the indent): - * \code - * main ::= "[" basic_boolean ", " basic_integer (", " basic_string)* "]" - * \endcode - */ - std::string VisitArray(const picojson::object& schema, const std::string& rule_name); - - /*! - * \brief Visit an object schema. - * \example - * Schema: - * \code - * { - * "type": "object", - * "properties": { - * "a": {"type": "string"}, - * "b": {"type": "integer"} - * }, - * "required": ["a"], - * "additionalProperties": true - * } - * \endcode - * - * Rule (not considering the indent): - * \code - * main ::= "{" "a" ":" basic_string (", " "b" ":" basic_integer)* - * (", " basic_string ": " basic_any)* "}" - * \endcode - - * We need special handling when all properties are optional, since the handling of separators - * is tricky in this case. E.g. - - * Schema: - * \code - * { - * "type": "object", - * "properties": { - * "a": {"type": "string"}, - * "b": {"type": "integer"}, - * "c": {"type": "boolean"} - * }, - * "additionalProperties": true - * } - * \endcode - * - * Rule (indent=2): - * \code - * main ::= "{" ("\n " (a main_sub_1 | b main_sub_2 | c main_sub_3 | d main_sub_3) - * "\n" | "") "}" - * main_sub_1 ::= ",\n " b r2 | r2 - * main_sub_2 ::= ",\n " c r3 | r3 - * main_sub_3 ::= (",\n " d)* - * \endcode - */ - std::string VisitObject(const picojson::object& schema, const std::string& rule_name); - - /*! \brief Get the pattern for a property in the object schema. 
*/ - std::string GetPropertyPattern(const std::string& prop_name, const picojson::value& prop_schema, - const std::string& rule_name, int idx); - - /*! \brief Get the pattern for the additional/unevaluated properties in the object schema. */ - std::string GetOtherPropertyPattern(const std::string& key_pattern, - const picojson::value& prop_schema, - const std::string& rule_name, - const std::string& rule_name_suffix); - - /*! \brief Get the partial rule for the properties when all properties are optional. See the - * example in VisitObject(). */ - std::string GetPartialRuleForPropertiesAllOptional( - const std::vector>& properties, - const picojson::value& additional, const std::string& rule_name, - const std::string& additional_suffix = ""); - - /*! - * \brief Get the partial rule for the properties when some properties are required. See the - * example in VisitObject(). - * - * The constructed rule should be: - * \code - * start_separator (optional_property separator)? (optional_property separator)? ... - * first_required_property (separator optional_property)? separator required_property ... - * end_separator - * \endcode - * - * i.e. Before the first required property, all properties are in the form - * (property separator) ; and after the first required property, all properties are in the form - * (separator property) . */ - std::string GetPartialRuleForPropertiesContainRequired( - const std::vector>& properties, - const std::unordered_set& required, const std::string& rule_name); - - // The indent manager to get separators - std::unique_ptr indentManager_; - // The root JSON schema - picojson::value json_schema_; - // Whether to use strict mode in conversion. See JSONSchemaToEBNF(). - bool strict_mode_; - // The colon separator - std::string colon_; - // The rules constructed - std::vector> rules_; - // The cache for basic rules. Mapping from the key of schema returned by GetSchemaCacheIndex() - // to the basic rule name. - std::map basic_rules_cache_; -}; - -JSONSchemaToEBNFConverter::JSONSchemaToEBNFConverter( - const picojson::value& json_schema, std::optional indent, - std::optional> separators, bool strict_mode) - : json_schema_(json_schema), strict_mode_(strict_mode) { - if (!separators.has_value()) { - separators = (indent == std::nullopt) ? 
std::make_pair(", ", ": ") : std::make_pair(",", ": "); - } - indentManager_ = std::make_unique(indent, separators->first); - colon_ = separators->second; - - AddBasicRules(); -} - -std::string JSONSchemaToEBNFConverter::Convert() { - CreateRuleFromSchema(json_schema_, "main"); - std::string res; - for (auto& rule : rules_) { - res += rule.first + " ::= " + rule.second + "\n"; - } - return res; -} - -void JSONSchemaToEBNFConverter::AddBasicRules() { - bool past_strict_mode = strict_mode_; - strict_mode_ = false; - - auto past_indent_manager = std::move(indentManager_); - indentManager_ = - std::make_unique(std::nullopt, past_indent_manager->GetBareSeparator()); - - AddHelperRules(); - CreateBasicRule(picojson::value(true), kBasicAny); - basic_rules_cache_[GetSchemaCacheIndex(picojson::value(picojson::object()))] = kBasicAny; - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("integer")}}), - kBasicInteger); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("number")}}), - kBasicNumber); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("string")}}), - kBasicString); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("boolean")}}), - kBasicBoolean); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("null")}}), kBasicNull); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("array")}}), - kBasicArray); - CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("object")}}), - kBasicObject); - - strict_mode_ = past_strict_mode; - indentManager_ = std::move(past_indent_manager); -} - -void JSONSchemaToEBNFConverter::AddHelperRules() { - rules_.push_back(std::make_pair( - kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); - rules_.push_back(std::make_pair( - kBasicStringSub, "(\"\\\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + - kBasicEscape + " " + kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])")); -} - -void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, - const std::string& name) { - std::string rule_name = CreateRuleFromSchema(schema, name); - basic_rules_cache_[GetSchemaCacheIndex(schema)] = rule_name; -} - -std::string JSONSchemaToEBNFConverter::NextSeparator(bool is_end) { - return indentManager_->NextSeparator(is_end); -} - -void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::value& schema, - const std::vector& keywords) { - if (schema.is()) { - return; - } - - ICHECK(schema.is()); - WarnUnsupportedKeywords(schema.get(), keywords); -} - -void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::object& schema, - const std::vector& keywords) { - for (const auto& keyword : keywords) { - if (schema.find(keyword) != schema.end()) { - LOG(WARNING) << "Keyword " << keyword << " is not supported in schema " - << picojson::value(schema); - } - } -} - -std::string JSONSchemaToEBNFConverter::CreateRuleFromSchema(const picojson::value& schema, - const std::string& rule_name_hint) { - std::string idx = GetSchemaCacheIndex(schema); - if (basic_rules_cache_.count(idx)) { - return basic_rules_cache_[idx]; - } - - rules_.push_back(std::make_pair(rule_name_hint, VisitSchema(schema, rule_name_hint))); - return rule_name_hint; -} - -std::string JSONSchemaToEBNFConverter::GetSchemaCacheIndex(const picojson::value& schema) { - // Keys that do not effect the validation - static const std::unordered_set kSkippedKeys = { - "title", 
"default", "description", "examples", "deprecated", - "readOnly", "writeOnly", "$comment", "$schema", - }; - if (schema.is()) { - // remove skipped keys and sort key by lexicographical order - std::string result = "{"; - std::vector> sorted_kv; - for (const auto& kv : schema.get()) { - if (kSkippedKeys.count(kv.first) == 0) { - sorted_kv.push_back(kv); - } - } - std::sort(sorted_kv.begin(), sorted_kv.end(), - [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); - int idx = 0; - for (const auto& [key, value] : sorted_kv) { - if (idx != 0) { - result += ","; - } - ++idx; - result += "\"" + key + "\":" + GetSchemaCacheIndex(value); - } - return result + "}"; - } else if (schema.is()) { - std::string result = "["; - int idx = 0; - for (const auto& item : schema.get()) { - if (idx != 0) { - result += ","; - } - ++idx; - result += GetSchemaCacheIndex(item); - } - return result + "]"; - } - // If the object is neither an array nor an object, return it directly - return schema.serialize(false); -} - -std::string JSONSchemaToEBNFConverter::VisitSchema(const picojson::value& schema, - const std::string& rule_name) { - if (schema.is()) { - ICHECK(schema.get()); - return VisitAny(schema, rule_name); - } - - WarnUnsupportedKeywords(schema, { - "allof", - "oneof", - "not", - "if", - "then", - "else", - "dependentRequired", - "dependentSchemas", - }); - - ICHECK(schema.is()); - - const auto& schema_obj = schema.get(); - - if (schema_obj.count("$ref")) { - return VisitRef(schema_obj, rule_name); - } else if (schema_obj.count("const")) { - return VisitConst(schema_obj, rule_name); - } else if (schema_obj.count("enum")) { - return VisitEnum(schema_obj, rule_name); - } else if (schema_obj.count("anyOf")) { - return VisitAnyOf(schema_obj, rule_name); - } else if (schema_obj.count("type")) { - const std::string& type = schema_obj.at("type").get(); - if (type == "integer") { - return VisitInteger(schema_obj, rule_name); - } else if (type == "number") { - return VisitNumber(schema_obj, rule_name); - } else if (type == "string") { - return VisitString(schema_obj, rule_name); - } else if (type == "boolean") { - return VisitBoolean(schema_obj, rule_name); - } else if (type == "null") { - return VisitNull(schema_obj, rule_name); - } else if (type == "array") { - return VisitArray(schema_obj, rule_name); - } else if (type == "object") { - return VisitObject(schema_obj, rule_name); - } else { - LOG(FATAL) << "Unsupported type " << type << " in schema " << schema; - } - } - - // If no above keyword is detected, we treat it as any - return VisitAny(schema, rule_name); -} - -std::string JSONSchemaToEBNFConverter::VisitRef(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("$ref")); - picojson::value new_schema = URIToSchema(schema.at("$ref")); - if (!new_schema.is()) { - picojson::object new_schema_obj = new_schema.get(); - for (const auto& [k, v] : schema) { - if (k != "$ref") { - new_schema_obj[k] = v; - } - } - new_schema = picojson::value(new_schema_obj); - } - return VisitSchema(new_schema, rule_name); -} - -picojson::value JSONSchemaToEBNFConverter::URIToSchema(const picojson::value& uri) { - if (uri.get().substr(0, 8) == "#/$defs/") { - return json_schema_.get("$defs").get(uri.get().substr(8)); - } - LOG(WARNING) << "Now only support URI starting with '#/$defs/' but got " << uri; - return picojson::value(true); -} - -std::string JSONSchemaToEBNFConverter::VisitConst(const picojson::object& schema, - const std::string& rule_name) { - 
ICHECK(schema.count("const")); - // TODO(yixin): Customize serialize to support indent logics - return "\"" + JSONStrToPrintableStr(schema.at("const").serialize()) + "\""; -} - -std::string JSONSchemaToEBNFConverter::VisitEnum(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("enum")); - std::string result = ""; - int idx = 0; - for (auto value : schema.at("enum").get()) { - if (idx != 0) { - result += " | "; - } - ++idx; - result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")"; - } - return result; -} - -std::string JSONSchemaToEBNFConverter::JSONStrToPrintableStr(const std::string& json_str) { - static const std::vector> kReplaceMapping = {{"\\", "\\\\"}, - {"\"", "\\\""}}; - std::string result = json_str; - for (const auto& [k, v] : kReplaceMapping) { - size_t pos = 0; - while ((pos = result.find(k, pos)) != std::string::npos) { - result.replace(pos, k.length(), v); - pos += v.length(); - } - } - return result; -} - -std::string JSONSchemaToEBNFConverter::VisitAnyOf(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("anyOf")); - std::string result = ""; - int idx = 0; - for (auto anyof_schema : schema.at("anyOf").get()) { - if (idx != 0) { - result += " | "; - } - result += CreateRuleFromSchema(anyof_schema, rule_name + "_case_" + std::to_string(idx)); - ++idx; - } - return result; -} - -std::string JSONSchemaToEBNFConverter::VisitAny(const picojson::value& schema, - const std::string& rule_name) { - // Note integer is a subset of number, so we don't need to add integer here - return kBasicNumber + " | " + kBasicString + " | " + kBasicBoolean + " | " + kBasicNull + " | " + - kBasicArray + " | " + kBasicObject; -} - -std::string JSONSchemaToEBNFConverter::VisitInteger(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "integer"); - WarnUnsupportedKeywords(schema, { - "multipleOf", - "minimum", - "maximum", - "exclusiveMinimum", - "exclusiveMaximum", - }); - return "(\"0\" | \"-\"? [1-9] [0-9]*) \".0\"?"; -} - -std::string JSONSchemaToEBNFConverter::VisitNumber(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "number"); - WarnUnsupportedKeywords(schema, { - "multipleOf", - "minimum", - "maximum", - "exclusiveMinimum", - "exclusiveMaximum", - }); - return "(\"0\" | \"-\"? [1-9] [0-9]*) (\".\" [0-9]+)? ([eE] [+-]? 
[0-9]+)?"; -} - -std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "string"); - WarnUnsupportedKeywords(schema, { - "minLength", - "maxLength", - "pattern", - "format", - }); - return "[\"] " + kBasicStringSub; -} - -std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "boolean"); - return "\"true\" | \"false\""; -} - -std::string JSONSchemaToEBNFConverter::VisitNull(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "null"); - return "\"null\""; -} - -std::string JSONSchemaToEBNFConverter::VisitArray(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "array"); - WarnUnsupportedKeywords(schema, { - "uniqueItems", - "contains", - "minContains", - "maxContains", - "minItems", - "maxItems", - }); - - std::string result = "\"[\""; - - indentManager_->StartIndent(); - - // 1. Handle prefix items - if (schema.count("prefixItems")) { - const auto& prefix_items = schema.at("prefixItems").get(); - for (int i = 0; i < prefix_items.size(); ++i) { - ICHECK(prefix_items[i].is()); - result += " " + NextSeparator() + " "; - result += CreateRuleFromSchema(prefix_items[i], rule_name + "_item_" + std::to_string(i)); - } - } - - // 2. Find additional items - picojson::value additional_item = picojson::value(false); - std::string additional_suffix = ""; - - if (schema.count("items") && (!schema.at("items").is() || schema.at("items").get())) { - additional_item = schema.at("items"); - additional_suffix = "items"; - } - - // If items is specified in the schema, we don't need to consider unevaluatedItems - if (schema.count("items") == 0) { - picojson::value unevaluated = schema.count("unevaluatedItems") ? schema.at("unevaluatedItems") - : picojson::value(!strict_mode_); - if (!unevaluated.is() || unevaluated.get()) { - additional_item = unevaluated; - additional_suffix = "uneval"; - } - } - - // 3. 
Handle additional items and the end separator - bool could_be_empty = false; - if (additional_item.is() && !additional_item.get()) { - result += " " + NextSeparator(true); - } else { - std::string additional_pattern = - CreateRuleFromSchema(additional_item, rule_name + "_" + additional_suffix); - if (schema.count("prefixItems")) { - result += " (" + NextSeparator() + " " + additional_pattern + ")* "; - result += NextSeparator(true); - } else { - result += " " + NextSeparator() + " " + additional_pattern + " ("; - result += NextSeparator() + " " + additional_pattern + ")* "; - result += NextSeparator(true); - could_be_empty = true; - } - } - - indentManager_->EndIndent(); - - result += " \"]\""; - - if (could_be_empty) { - result = "(" + result + ") | \"[]\""; - } - - return result; -} - -std::string JSONSchemaToEBNFConverter::GetPropertyPattern(const std::string& prop_name, - const picojson::value& prop_schema, - const std::string& rule_name, int idx) { - // the outer quote is for the string in EBNF grammar, and the inner quote is for - // the string in JSON - std::string key = "\"\\\"" + prop_name + "\\\"\""; - std::string colon = "\"" + colon_ + "\""; - std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_prop_" + std::to_string(idx)); - return key + " " + colon + " " + value; -} - -std::string JSONSchemaToEBNFConverter::GetOtherPropertyPattern( - const std::string& key_pattern, const picojson::value& prop_schema, - const std::string& rule_name, const std::string& rule_name_suffix) { - std::string colon = "\"" + colon_ + "\""; - std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_" + rule_name_suffix); - return key_pattern + " " + colon + " " + value; -} - -std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesAllOptional( - const std::vector>& properties, - const picojson::value& additional, const std::string& rule_name, - const std::string& additional_suffix) { - ICHECK(properties.size() >= 1); - - std::string first_sep = NextSeparator(); - std::string mid_sep = NextSeparator(); - std::string last_sep = NextSeparator(true); - - std::string res = ""; - - std::vector prop_patterns; - int idx = 0; - for (const auto& [prop_name, prop_schema] : properties) { - prop_patterns.push_back(GetPropertyPattern(prop_name, prop_schema, rule_name, idx)); - ++idx; - } - - std::vector rule_names(properties.size(), ""); - - // construct the last rule - std::string additional_prop_pattern; - if (!additional.is() || additional.get()) { - additional_prop_pattern = - GetOtherPropertyPattern(kBasicString, additional, rule_name, additional_suffix); - std::string last_rule_body = "(" + mid_sep + " " + additional_prop_pattern + ")*"; - std::string last_rule_name = - rule_name + "_part_" + std::to_string(static_cast(properties.size()) - 1); - rules_.push_back(std::make_pair(last_rule_name, last_rule_body)); - rule_names.back() = last_rule_name; - } else { - rule_names.back() = "\"\""; - } - - // construct 0~(len(properties) - 2) rules - for (int i = properties.size() - 2; i >= 0; --i) { - const std::string& prop_pattern = prop_patterns[i + 1]; - const std::string& last_rule_name = rule_names[i + 1]; - std::string cur_rule_body = - last_rule_name + " | " + mid_sep + " " + prop_pattern + " " + last_rule_name; - std::string cur_rule_name = rule_name + "_part_" + std::to_string(i); - rules_.push_back(std::make_pair(cur_rule_name, cur_rule_body)); - rule_names[i] = cur_rule_name; - } - - // construct the main rule - for (int i = 0; i < properties.size(); ++i) { - if (i != 0) 
{ - res += " | "; - } - res += "(" + prop_patterns[i] + " " + rule_names[i] + ")"; - } - - if (!additional.is() || additional.get()) { - res += " | " + additional_prop_pattern + " " + rule_names.back(); - } - - // add separators and the empty string option - res = first_sep + " (" + res + ") " + last_sep; - return res; -} - -std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesContainRequired( - const std::vector>& properties, - const std::unordered_set& required, const std::string& rule_name) { - // Find the index of the first required property - int first_required_idx = properties.size(); - for (int i = 0; i < properties.size(); ++i) { - if (required.count(properties[i].first)) { - first_required_idx = i; - break; - } - } - ICHECK(first_required_idx < properties.size()); - - std::string res = NextSeparator(); - - // Handle the properties before the first required property - for (int i = 0; i < first_required_idx; ++i) { - const auto& [prop_name, prop_schema] = properties[i]; - ICHECK(!prop_schema.is() || prop_schema.get()); - std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); - res += " (" + property_pattern + " " + NextSeparator() + ")?"; - } - - // Handle the first required property - const auto& [prop_name, prop_schema] = properties[first_required_idx]; - std::string property_pattern = - GetPropertyPattern(prop_name, prop_schema, rule_name, first_required_idx); - res += " " + property_pattern; - - // Handle the properties after the first required property - for (int i = first_required_idx + 1; i < properties.size(); ++i) { - const auto& [prop_name, prop_schema] = properties[i]; - ICHECK(!prop_schema.is() || prop_schema.get()); - std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); - if (required.count(prop_name)) { - res += " " + NextSeparator() + " " + property_pattern; - } else { - res += " (" + NextSeparator() + " " + property_pattern + ")?"; - } - } - - return res; -} - -std::string JSONSchemaToEBNFConverter::VisitObject(const picojson::object& schema, - const std::string& rule_name) { - ICHECK(schema.count("type")); - ICHECK(schema.at("type").get() == "object"); - WarnUnsupportedKeywords(schema, { - "patternProperties", - "minProperties", - "maxProperties", - "propertyNames", - }); - - std::string result = "\"{\""; - - // could_be_empty will be set to True when the rule could be "{}". We will handle this case at - // last, and handle non-empty cases before that. - bool could_be_empty = false; - - indentManager_->StartIndent(); - - // 1. Handle properties - std::vector> properties; - if (schema.count("properties")) { - auto properties_obj = schema.at("properties").get(); - for (const auto& key : properties_obj.ordered_keys()) { - properties.push_back({key, properties_obj.at(key)}); - } - } - - std::unordered_set required; - if (schema.count("required")) { - for (const auto& required_prop : schema.at("required").get()) { - required.insert(required_prop.get()); - } - } - - // 2. Find additional properties - picojson::value additional_property = picojson::value(false); - std::string additional_suffix = ""; - - if (schema.count("additionalProperties") && (!schema.at("additionalProperties").is() || - schema.at("additionalProperties").get())) { - additional_property = schema.at("additionalProperties"); - additional_suffix = "addl"; - } - - if (schema.count("additionalProperties") == 0) { - picojson::value unevaluated = schema.count("unevaluatedProperties") - ? 
schema.at("unevaluatedProperties") - : picojson::value(!strict_mode_); - if (!unevaluated.is() || unevaluated.get()) { - additional_property = unevaluated; - additional_suffix = "uneval"; - } - } - - bool is_all_properties_optional = - std::all_of(properties.begin(), properties.end(), - [&](const auto& prop) { return required.count(prop.first) == 0; }); - - if (is_all_properties_optional && properties.size() > 0) { - // 3.1 Case 1: properties are defined and all properties are optional - result += " " + GetPartialRuleForPropertiesAllOptional(properties, additional_property, - rule_name, additional_suffix); - could_be_empty = true; - } else if (properties.size() > 0) { - // 3.2 Case 2: properties are defined and some properties are required - result += " " + GetPartialRuleForPropertiesContainRequired(properties, required, rule_name); - if (!additional_property.is() || additional_property.get()) { - std::string other_property_pattern = - GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); - result += " (" + NextSeparator() + " " + other_property_pattern + ")*"; - } - result += " " + NextSeparator(true); - } else if (!additional_property.is() || additional_property.get()) { - // 3.3 Case 3: no properties are defined and additional properties are allowed - std::string other_property_pattern = - GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); - result += " " + NextSeparator() + " " + other_property_pattern + " ("; - result += NextSeparator() + " " + other_property_pattern + ")* "; - result += NextSeparator(true); - could_be_empty = true; - } - - indentManager_->EndIndent(); - - result += " \"}\""; - if (could_be_empty) { - result = "(" + result + ") | \"{}\""; - } - - return result; -}; - -std::string JSONSchemaToEBNF(std::string schema, std::optional indent, - std::optional> separators, - bool strict_mode) { - picojson::value schema_value; - std::string err = picojson::parse(schema_value, schema); - if (!err.empty()) { - LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << schema; - } - JSONSchemaToEBNFConverter converter(schema_value, indent, separators, strict_mode); - return converter.Convert(); -} - -TVM_REGISTER_GLOBAL("mlc.grammar.DebugJSONSchemaToEBNF") - .set_body([](TVMArgs args, TVMRetValue* rv) { - std::optional indent; - if (args[1].type_code() != kTVMNullptr) { - indent = args[1]; - } else { - indent = std::nullopt; - } - - std::optional> separators; - if (args[2].type_code() != kTVMNullptr) { - Array separators_arr = args[2]; - CHECK(separators_arr.size() == 2); - separators = std::make_pair(separators_arr[0], separators_arr[1]); - } else { - separators = std::nullopt; - } - - *rv = JSONSchemaToEBNF(args[0], indent, separators, args[3]); - }); - -} // namespace serve -} // namespace llm -} // namespace mlc diff --git a/cpp/grammar/json_schema_converter.h b/cpp/grammar/json_schema_converter.h deleted file mode 100644 index 52044d21bb..0000000000 --- a/cpp/grammar/json_schema_converter.h +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/json_grammar_converter.h - * \brief The header for translating JSON schema to EBNF grammar. - */ - -#ifndef MLC_LLM_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ -#define MLC_LLM_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ - -#include -#include -#include - -namespace mlc { -namespace llm { -namespace serve { - -/*! - * \brief Convert JSON schema string to EBNF grammar string. - * \param json_schema The JSON schema string. 
- * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be - * in one line. Default: 2. - * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"}, - * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the - * indent is not -1, and {", ", ": "} otherwise. This follows the convention in python json.dumps(). - * Default: std::nullopt. - * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not - * allow properties and items that is not specified in the schema. This is equivalent to - * setting unevaluatedProperties and unevaluatedItems to false. - * - * This helps LLM to generate accurate output in the grammar-guided generation with JSON - * schema. Default: true. - * \returns The EBNF grammar string. - */ -std::string JSONSchemaToEBNF( - std::string schema, std::optional indent = 2, - std::optional> separators = std::nullopt, - bool strict_mode = true); - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ diff --git a/cpp/grammar/support.h b/cpp/grammar/support.h deleted file mode 100644 index ec721aa004..0000000000 --- a/cpp/grammar/support.h +++ /dev/null @@ -1,94 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file grammar/support.h - * \brief The header for utilities used in grammar-guided generation. - */ -#ifndef MLC_LLM_GRAMMAR_SUPPORT_H_ -#define MLC_LLM_GRAMMAR_SUPPORT_H_ - -#include - -#include -#include -#include -#include - -namespace mlc { -namespace llm { -namespace serve { - -/*! - * \brief Let lhs be the union of lhs and rhs. Suppose that both sets are sorted. - * \note No additional vectors are allocated, and the time complexity is O(n) - */ -inline void IntsetUnion(std::vector* lhs, const std::vector& rhs) { - int original_lhs_size = lhs->size(); - int rhs_size = rhs.size(); - - lhs->resize(original_lhs_size + rhs_size); - - auto it_lhs = lhs->rbegin() + rhs_size; - auto it_rhs = rhs.rbegin(); - auto it_result = lhs->rbegin(); - - while (it_lhs != lhs->rend() && it_rhs != rhs.rend()) { - if (*it_lhs > *it_rhs) { - *it_result = *it_lhs; - ++it_lhs; - } else if (*it_lhs < *it_rhs) { - *it_result = *it_rhs; - ++it_rhs; - } else { - *it_result = *it_lhs; - ++it_lhs; - ++it_rhs; - } - ++it_result; - } - - while (it_rhs != rhs.rend()) { - *it_result = *it_rhs; - ++it_result; - ++it_rhs; - } - - auto last = std::unique(lhs->begin(), lhs->end()); - lhs->erase(last, lhs->end()); -} - -/*! - * \brief Let lhs be the intersection of lhs and rhs. Suppose that both sets are sorted. - * \note No additional vector is allocated, and the time complexity is O(n). - * \note Support the case where lhs is the universal set by setting lhs to {-1}. The result will be - * rhs then. 
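A minimal self-contained sketch of the contract documented for these sorted-set helpers (an illustrative re-implementation with hypothetical names, not the deleted functions themselves):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// In-place intersection of two sorted integer sets, mirroring the contract
// documented above: lhs == {-1} denotes the universal set, so the result is
// simply rhs in that case.
void SortedIntersect(std::vector<int32_t>* lhs, const std::vector<int32_t>& rhs) {
  if (lhs->size() == 1 && (*lhs)[0] == -1) {
    *lhs = rhs;
    return;
  }
  auto it_lhs = lhs->begin();
  auto it_rhs = rhs.begin();
  auto it_result = lhs->begin();
  while (it_lhs != lhs->end() && it_rhs != rhs.end()) {
    if (*it_lhs < *it_rhs) {
      ++it_lhs;
    } else if (*it_lhs > *it_rhs) {
      ++it_rhs;
    } else {
      *it_result++ = *it_lhs;  // element present in both sets
      ++it_lhs;
      ++it_rhs;
    }
  }
  lhs->erase(it_result, lhs->end());
}

int main() {
  std::vector<int32_t> a = {1, 3, 5, 7};
  SortedIntersect(&a, {3, 4, 5, 7});
  assert((a == std::vector<int32_t>{3, 5, 7}));
  std::vector<int32_t> u = {-1};  // the universal set convention
  SortedIntersect(&u, {2, 4});
  assert((u == std::vector<int32_t>{2, 4}));
  return 0;
}
```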
- */ -inline void IntsetIntersection(std::vector* lhs, const std::vector& rhs) { - if (lhs->size() == 1 && (*lhs)[0] == -1) { - *lhs = rhs; - return; - } - - auto it_lhs = lhs->begin(); - auto it_rhs = rhs.begin(); - auto it_result = lhs->begin(); - - while (it_lhs != lhs->end() && it_rhs != rhs.end()) { - if (*it_lhs < *it_rhs) { - ++it_lhs; - } else if (*it_lhs > *it_rhs) { - ++it_rhs; - } else { - *it_result = *it_lhs; - ++it_lhs; - ++it_rhs; - ++it_result; - } - } - lhs->erase(it_result, lhs->end()); -} - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_GRAMMAR_SUPPORT_H_ diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 67b312a0bf..728d76b78d 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -21,7 +22,6 @@ #include #include -#include "../grammar/grammar_state_matcher.h" #include "../support/json_parser.h" #include "../support/result.h" #include "../support/utils.h" @@ -382,7 +382,7 @@ class EngineImpl : public Engine { // - Initialize tokenizer and grammar n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0])); n->token_table_ = n->tokenizer_->PostProcessedTokenTable(); - n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_); + n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_); // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -495,13 +495,12 @@ class EngineImpl : public Engine { int n = request->generation_cfg->n; int rng_seed = request->generation_cfg->seed; - auto grammar_state_init_ctx = - GetGrammarInitCtxFromResponseFormat(request->generation_cfg->response_format); + auto compiled_grammar = GetGrammarFromResponseFormat(request->generation_cfg->response_format); std::vector rsentries; // Create the request state entry for the input. rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed, - token_table_, grammar_state_init_ctx); + token_table_, compiled_grammar); if (n > 1) { // Then create a request state entry for each parallel generation branch. // We add a offset to the rng seed so that to make generations different. @@ -510,7 +509,7 @@ class EngineImpl : public Engine { for (int i = 0; i < n; ++i) { rsentries[0]->child_indices.push_back(rsentries.size()); rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), - rng_seed + i + 1, token_table_, grammar_state_init_ctx, + rng_seed + i + 1, token_table_, compiled_grammar, /*parent_idx=*/0); } } @@ -814,14 +813,14 @@ class EngineImpl : public Engine { /*! \brief Create a grammar init context according to the response format. If the response format * is not JSON, return std::nullopt. 
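The engine changes below swap the hand-rolled grammar-init-context cache for xgrammar's CachedGrammarCompiler; the idea is memoization, since compiling a grammar against a fixed token table is expensive while many requests share the same schema. A hypothetical sketch of that caching shape (names and internals are illustrative, not xgrammar's actual API):

```cpp
#include <string>
#include <unordered_map>

// Stand-in for a compiled grammar; only the caching shape matters here.
struct CompiledGrammar {
  std::string fingerprint;
};

class SchemaGrammarCache {
 public:
  const CompiledGrammar& GetForJSONSchema(const std::string& schema) {
    auto it = cache_.find(schema);
    if (it == cache_.end()) {
      // The first request with this schema pays the compilation cost;
      // every later request with the same schema reuses the result.
      it = cache_.emplace(schema, CompiledGrammar{Compile(schema)}).first;
    }
    return it->second;
  }

 private:
  std::string Compile(const std::string& schema) {
    return "compiled:" + schema;  // placeholder for the real grammar compilation
  }
  std::unordered_map<std::string, CompiledGrammar> cache_;
};
```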
*/ - std::optional> GetGrammarInitCtxFromResponseFormat( + std::optional GetGrammarFromResponseFormat( const ResponseFormat& response_format) { if (response_format.type != "json_object") { return std::nullopt; } else if (!response_format.schema) { - return grammar_init_context_cache_->GetInitContextForJSON(); + return cached_grammar_compiler_.GetCompiledGrammarForJSON(); } else { - return grammar_init_context_cache_->GetInitContextForJSONSchema( + return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema( response_format.schema.value()); } } @@ -833,8 +832,8 @@ class EngineImpl : public Engine { // internal tokenizer Tokenizer tokenizer_; std::vector token_table_; - // Helper to get the grammar init context for requests. - GrammarInitContextCache grammar_init_context_cache_; + // Cached grammar compiler for grammar matching. + xgrammar::CachedGrammarCompiler cached_grammar_compiler_; // Models Array models_; // Device that the models run on. diff --git a/cpp/serve/engine_actions/batch_jumpforward.cc b/cpp/serve/engine_actions/batch_jumpforward.cc index 8d9858e73a..894adb6bc4 100644 --- a/cpp/serve/engine_actions/batch_jumpforward.cc +++ b/cpp/serve/engine_actions/batch_jumpforward.cc @@ -65,7 +65,7 @@ class BatchJumpForwardActionObj : public EngineActionObj { } auto mstate = rsentry->mstates[0]; - auto jump_forward_str = mstate->grammar_state_matcher.value()->FindJumpForwardString(); + auto jump_forward_str = mstate->grammar_matcher->FindJumpForwardString(); if (jump_forward_str.empty()) { continue; @@ -116,7 +116,7 @@ class BatchJumpForwardActionObj : public EngineActionObj { if (rsentry->request->generation_cfg->logprobs) { return false; } - if (!rsentry->mstates[0]->grammar_state_matcher) { + if (!rsentry->mstates[0]->grammar_matcher) { return false; } return true; diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index e7d8b9b273..de595f0b9c 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -232,7 +232,7 @@ void FunctionTable::_InitFunctions() { this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true); this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); - if (!this->create_kv_cache_func_.defined()) { + if (this->model_metadata_.sliding_window_size != -1 || !this->create_kv_cache_func_.defined()) { PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); if (f_create_rnn_state.defined()) { this->create_kv_cache_func_ = f_create_rnn_state; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index f20303490d..7355fa1e5b 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -62,7 +62,7 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, preferred_host_device); penalties_host_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, preferred_host_device); bitmask_host_ = - NDArray::Empty({max_num_token, bitmask_size_}, dtype_u32_, preferred_host_device); + NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, preferred_host_device); temperature_host_ = NDArray::Empty({max_num_token}, dtype_f32_, preferred_host_device); // Initialize auxiliary arrays on GPU. 
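The bitmask dtype change in the logit processor above reflects that the token bitmask packs one bit per token id into ceildiv(vocab_size, 32) 32-bit words. A CPU-side sketch of that packing and of applying the mask to logits (the real apply_bitmask_inplace kernel runs on device; names here are illustrative):

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Token id i lives at bit (i % 32) of word (i / 32); a cleared bit rejects
// the token by forcing its logit to -inf before sampling.
void ApplyBitmask(std::vector<float>* logits, const std::vector<int32_t>& bitmask) {
  for (size_t id = 0; id < logits->size(); ++id) {
    bool allowed = (bitmask[id / 32] >> (id % 32)) & 1;
    if (!allowed) {
      (*logits)[id] = -std::numeric_limits<float>::infinity();
    }
  }
}

int main() {
  int vocab_size = 100;
  int bitmask_size = (vocab_size + 31) / 32;  // ceildiv(vocab_size, 32) == 4
  std::vector<int32_t> bitmask(bitmask_size, 0);
  bitmask[0] = (1 << 1) | (1 << 3);  // allow only token ids 1 and 3
  std::vector<float> logits(vocab_size, 0.0f);
  ApplyBitmask(&logits, bitmask);
  return 0;
}
```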
seq_ids_device_ = NDArray::Empty({max_num_token}, dtype_i32_, device); @@ -403,7 +403,7 @@ class LogitProcessorImpl : public LogitProcessorObj { (*draft_mstates)[i]->draft_token_parent_idx[cur_draft_token_index]; } for (auto it = draft_token_seq.rbegin(); it != draft_token_seq.rend(); ++it) { - mstates[i]->grammar_state_matcher.value()->AcceptToken(it->GetTokenId()); + mstates[i]->grammar_matcher.value().AcceptToken(it->GetTokenId()); } } // Find a slice of bitmask_host_: bitmask_host_[num_token_for_mask, :] @@ -413,11 +413,11 @@ class LogitProcessorImpl : public LogitProcessorObj { bitmask_dltensor.shape = bitmask_shape; bitmask_dltensor.ndim = 1; - mstates[i]->FindNextTokenBitmask(&bitmask_dltensor); + mstates[i]->GetNextTokenBitmask(&bitmask_dltensor); p_seq_ids[token_start_offset + j] = 1; if (draft_token_seq.size() > 0) { - mstates[i]->grammar_state_matcher.value()->Rollback(draft_token_seq.size()); + mstates[i]->grammar_matcher.value().Rollback(draft_token_seq.size()); } } } diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index caa762f598..d582803c16 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -15,15 +15,16 @@ TVM_REGISTER_OBJECT_TYPE(RequestModelStateNode); RequestModelState::RequestModelState( Request request, int model_id, int64_t internal_id, Array inputs, - const std::optional>& grammar_state_init_ctx) { + const std::optional& compiled_grammar) { ObjectPtr n = make_object(); n->model_id = model_id; n->internal_id = internal_id; n->inputs = std::move(inputs); - if (grammar_state_init_ctx.has_value()) { + if (compiled_grammar.has_value()) { // TODO(yixin): set rollback limit to a configurable value. - n->grammar_state_matcher = GrammarStateMatcher(grammar_state_init_ctx.value(), 10); + n->grammar_matcher = + xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10); } n->request = std::move(request); @@ -38,12 +39,12 @@ int RequestModelStateNode::GetInputLength() const { return total_length; } -bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_state_matcher.defined(); } +bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.has_value(); } -void RequestModelStateNode::FindNextTokenBitmask(DLTensor* bitmask) { - ICHECK(grammar_state_matcher.defined()); +void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) { + ICHECK(grammar_matcher.has_value()); - grammar_state_matcher.value()->FindNextTokenBitmask(bitmask); + grammar_matcher->GetNextTokenBitmask(bitmask); } void RequestModelStateNode::CommitToken(SampleResult sampled_token) { @@ -53,8 +54,8 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { ++num_tokens_for_next_decode; // Update the grammar matcher state if it exists. 
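The logit-processor hunk above encodes the speculative-decoding protocol: accept the draft tokens, read the bitmask for the position after them, then roll back so the matcher state is unchanged if the draft is rejected. A toy stand-in makes the call sequence concrete (a hypothetical class mirroring only the AcceptToken/Rollback calls shown above, not the real matcher):

```cpp
#include <cassert>
#include <vector>

// Toy matcher: tracks accepted ids on a stack so Rollback can undo them.
class ToyMatcher {
 public:
  bool AcceptToken(int token_id) {
    accepted_.push_back(token_id);
    return true;
  }
  void Rollback(int num_tokens) { accepted_.resize(accepted_.size() - num_tokens); }
  int NumAccepted() const { return static_cast<int>(accepted_.size()); }

 private:
  std::vector<int> accepted_;
};

int main() {
  ToyMatcher matcher;
  std::vector<int> draft_tokens = {11, 42, 7};
  for (int id : draft_tokens) matcher.AcceptToken(id);  // walk through the draft
  // ... at this point the next-token bitmask would be computed ...
  matcher.Rollback(static_cast<int>(draft_tokens.size()));  // restore the state
  assert(matcher.NumAccepted() == 0);
  return 0;
}
```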
- if (grammar_state_matcher) { - bool accepted = grammar_state_matcher.value()->AcceptToken(sampled_token.GetTokenId()); + if (grammar_matcher) { + bool accepted = grammar_matcher->AcceptToken(sampled_token.GetTokenId()); ICHECK(accepted) << "Token id " << sampled_token.GetTokenId() << " is not accepted by the grammar state matcher."; } @@ -69,8 +70,8 @@ void RequestModelStateNode::RollbackTokens(int count) { appeared_token_ids.erase(it); } committed_tokens.pop_back(); - if (grammar_state_matcher) { - grammar_state_matcher.value()->Rollback(1); + if (grammar_matcher) { + grammar_matcher->Rollback(1); } } } @@ -143,8 +144,7 @@ TVM_REGISTER_OBJECT_TYPE(RequestStateEntryNode); RequestStateEntry::RequestStateEntry( Request request, int num_models, int64_t internal_id, int rng_seed, const std::vector<std::string>& token_table, - const std::optional<std::shared_ptr<GrammarStateInitContext>>& grammar_state_init_ctx, - int parent_idx) { + const std::optional<xgrammar::CompiledGrammar>& compiled_grammar, int parent_idx) { ObjectPtr<RequestStateEntryNode> n = make_object<RequestStateEntryNode>(); Array<RequestModelState> mstates; Array<Data> inputs; @@ -153,7 +153,7 @@ } mstates.reserve(num_models); for (int i = 0; i < num_models; ++i) { - mstates.push_back(RequestModelState(request, i, internal_id, inputs, grammar_state_init_ctx)); + mstates.push_back(RequestModelState(request, i, internal_id, inputs, compiled_grammar)); } n->status = RequestStateStatus::kPending; n->rng = RandomGenerator(rng_seed); @@ -233,8 +233,8 @@ void RequestStateEntryNode::GetDeltaRequestReturn(const Tokenizer& tokenizer, // Case 4. When stop token is not detected (e.g. ignore_eos is set), but the grammar state is // terminated, stop the generation and pop the last token (used to trigger the termination). if ((*delta_stream_output)->group_finish_reason[idx] != "stop" && - this->mstates[0]->grammar_state_matcher.defined() && - this->mstates[0]->grammar_state_matcher.value()->IsTerminated()) { + this->mstates[0]->grammar_matcher.has_value() && + this->mstates[0]->grammar_matcher->IsTerminated()) { (*delta_stream_output)->group_delta_token_ids[idx].pop_back(); (*delta_stream_output)->group_finish_reason[idx] = "stop"; } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 855c05bfd8..e6e3821b2c 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -9,10 +9,10 @@ #include #include #include +#include <xgrammar/xgrammar.h> #include -#include "../grammar/grammar_state_matcher.h" #include "../support/random.h" #include "../tokenizers/streamer.h" #include "config.h" @@ -86,9 +86,9 @@ class RequestModelStateNode : public Object { /*! * \brief The current state of the generated token matching the grammar. Used in grammar-guided - * generation, otherwise it's NullOpt. + * generation, otherwise it's std::nullopt. */ - Optional<GrammarStateMatcher> grammar_state_matcher; + std::optional<xgrammar::GrammarMatcher> grammar_matcher; /*! \brief Return the total length of the input data. */ int GetInputLength() const; @@ -102,7 +102,7 @@ * \param bitmask The DLTensor to store the next token bitmask. The bitmask should be a tensor * with dtype uint32_t and shape (ceildiv(vocab_size, 32),). */ - void FindNextTokenBitmask(DLTensor* bitmask); + void GetNextTokenBitmask(DLTensor* bitmask); /*! \brief Commit a new token into committed_tokens. Does not affect the kv cache. Update * appeared_token_ids and the grammar state.
*/ void CommitToken(SampleResult sampled_token); @@ -123,9 +123,8 @@ class RequestModelStateNode : public Object { class RequestModelState : public ObjectRef { public: - explicit RequestModelState( - Request request, int model_id, int64_t internal_id, Array inputs, - const std::optional>& grammar_state_init_ctx); + explicit RequestModelState(Request request, int model_id, int64_t internal_id, Array inputs, + const std::optional& compiled_grammar); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode); }; @@ -255,11 +254,10 @@ class RequestStateEntryNode : public Object { class RequestStateEntry : public ObjectRef { public: - explicit RequestStateEntry( - Request request, int num_models, int64_t internal_id, int rng_seed, - const std::vector& token_table, - const std::optional>& grammar_state_init_ctx, - int parent_idx = -1); + explicit RequestStateEntry(Request request, int num_models, int64_t internal_id, int rng_seed, + const std::vector& token_table, + const std::optional& compiled_grammar, + int parent_idx = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestStateEntry, ObjectRef, RequestStateEntryNode); }; diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 8a49f1936f..5a91b297ea 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "../../support/random.h" #include "sampler.h" diff --git a/cpp/support/utils.h b/cpp/support/utils.h index b1e3875f8d..7674699907 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -40,24 +40,6 @@ inline bool StartsWith(const std::string& str, const char* prefix) { return prefix[n] == '\0'; } -/*! - * \brief Hash and combine value into seed. - * \ref https://www.boost.org/doc/libs/1_84_0/boost/intrusive/detail/hash_combine.hpp - */ -inline void HashCombineBinary(uint32_t& seed, uint32_t value) { - seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - -/*! - * \brief Find the hash sum of several uint32_t args. - */ -template -uint32_t HashCombine(Args... args) { - uint32_t seed = 0; - (..., HashCombineBinary(seed, args)); - return seed; -} - } // namespace llm } // namespace mlc diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 6eda3b3537..690cf418b0 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -996,7 +996,7 @@ Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` ar --conv-template CONV_TEMPLATE Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model For existing pre-defined templates, see ``CONV_TEMPLATES`` - `here `_. + `here `_. --context-window-size CONTEXT_WINDOW_SIZE Option to provide the maximum sequence length supported by the model. This is usually explicitly shown as context length or context window in the model card. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index c0b9ea2fbb..1bca2de439 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -102,8 +102,8 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). You can also simply use the default configuration. - `conversation_template.py `__ - contains a full list of conversation templates that MLC provides. 
If the model you are adding + `conversation_template `__ + directory contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. Follow `this PR `__ as an example. However, adding your own template would require you :ref:`build mlc_llm from source ` in order for it diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 4b2317012c..0cd724af77 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -136,7 +136,7 @@ We have a one-line command to build and prepare all the model libraries: .. code:: bash cd /path/to/MLCChat # e.g., "android/MLCChat" - export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm # e.g., "../.." + export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm # has to be absolute path, ../.. does not work mlc_llm package This command mainly executes the following two steps: diff --git a/docs/deploy/mlc_chat_config.rst b/docs/deploy/mlc_chat_config.rst index d5e5628fc2..4222a2ccd8 100644 --- a/docs/deploy/mlc_chat_config.rst +++ b/docs/deploy/mlc_chat_config.rst @@ -110,7 +110,7 @@ supported conversation templates: - ``phi-2`` - ... -Please refer to `conversation_template.py `_ for the full list of supported templates and their implementations. +Please refer to `conversation_template `_ directory for the full list of supported templates and their implementations. Below is a generic structure of a JSON conversation configuration (we use vicuna as an example): diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 72122c343c..1123596dff 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -276,7 +276,7 @@ Below is an example command of compiling model libraries in MLC LLM: .. code:: bash - export $MODEL_LIB=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. + export MODEL_LIB=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. # ".dll" for Windows. # ".wasm" for web. # ".tar" for iPhone/Android. diff --git a/python/mlc_llm/grammar/__init__.py b/python/mlc_llm/grammar/__init__.py deleted file mode 100644 index 89cff27828..0000000000 --- a/python/mlc_llm/grammar/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Namespace for grammar handling""" - -from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/grammar/_ffi_api.py b/python/mlc_llm/grammar/_ffi_api.py deleted file mode 100644 index 549457fb94..0000000000 --- a/python/mlc_llm/grammar/_ffi_api.py +++ /dev/null @@ -1,6 +0,0 @@ -"""FFI APIs for mlc_llm grammar""" - -import tvm._ffi - -# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.grammar" prefix. -tvm._ffi._init_api("mlc.grammar", __name__) # pylint: disable=protected-access diff --git a/python/mlc_llm/grammar/grammar.py b/python/mlc_llm/grammar/grammar.py deleted file mode 100644 index 97cb30c719..0000000000 --- a/python/mlc_llm/grammar/grammar.py +++ /dev/null @@ -1,406 +0,0 @@ -"""Classes handling the grammar guided generation of MLC LLM serving""" - -from typing import List, Optional, Tuple, Union - -import tvm -import tvm._ffi -from tvm.runtime import Object - -from ..tokenizers import Tokenizer -from . import _ffi_api - - -@tvm._ffi.register_object("mlc.grammar.BNFGrammar") # pylint: disable=protected-access -class BNFGrammar(Object): - """This class stores the abstract syntax tree (AST) of the Backus-Naur Form (BNF) grammar and - provides utilities to parse and print the AST. 
User should provide a BNF/EBNF (Extended - Backus-Naur Form) grammar, and use from_ebnf_string to parse and simplify the grammar into an - AST of BNF grammar. - """ - - @staticmethod - def from_ebnf_string( - ebnf_string: str, - main_rule: str = "main", - ) -> "BNFGrammar": - r"""Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized - (simplified) by default. - - EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: - 1. Use # as the comment mark - 2. Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB - 3. A-B (match A and not match B) is not supported yet - 4. Lookahead assertion can be added at the end of a rule to speed up matching. E.g. - ``` - main ::= "ab" a [a-z] - a ::= "cd" (=[a-z]) - ``` - The assertion (=[a-z]) means a must be followed by [a-z]. - - Parameters - ---------- - ebnf_string : str - The grammar string. - - main_rule : str - The name of the main rule. Default: "main". - - Returns - ------- - grammar : BNFGrammar - The parsed BNF grammar. - """ - return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule - ) - - def to_string(self) -> str: - """Print the BNF grammar to a string, in standard BNF format. - - Returns - ------- - grammar_string : str - The BNF grammar string. - """ - return str(_ffi_api.BNFGrammarToString(self)) # type: ignore # pylint: disable=no-member - - def __str__(self) -> str: - return self.to_string() - - @staticmethod - def from_json(json_string: str) -> "BNFGrammar": - """Load a BNF grammar from the raw representation of the AST in JSON format. - - Parameters - ---------- - json_string : str - The JSON string. - - Returns - ------- - grammar : BNFGrammar - The loaded BNF grammar. - """ - return _ffi_api.BNFGrammarFromJSON(json_string) # type: ignore # pylint: disable=no-member - - def to_json(self, prettify: bool = True) -> str: - """Serialize the AST. Dump the raw representation of the AST to a JSON file. - - Parameters - ---------- - prettify : bool - Whether to format the JSON string. If False, all whitespaces will be removed. - - Returns - ------- - json_string : str - The JSON string. - """ - return str( - _ffi_api.BNFGrammarToJSON(self, prettify) # type: ignore # pylint: disable=no-member - ) - - @staticmethod - def from_schema( - schema: str, - *, - indent: Optional[int] = 2, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True - ) -> "BNFGrammar": - """Construct a BNF grammar from the json schema string. The schema string should be in the - format of the schema of a JSON file. We will parse the schema and generate a BNF grammar. - - Parameters - ---------- - schema : str - The schema string. - - indent : Optional[int] - The number of spaces for indentation. If None, the output will be in one line. - Default: None. - - separators : Optional[Tuple[str, str]] - Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). - If None, the default separators will be used: (",", ": ") when the indent is not None, - and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. - - strict_mode : bool - Whether to use strict mode. In strict mode, the generated grammar will not allow - properties and items that is not specified in the schema. This is equivalent to - setting unevaluatedProperties and unevaluatedItems to false. - - This helps LLM to generate accurate output in the grammar-guided generation with JSON - schema. Default: True. 
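To make the from_schema/strict_mode semantics documented above concrete, here is a usage sketch against the pre-PR API (this module is deleted by this diff; the example assumes the pre-change package layout):

```python
import json

from mlc_llm.grammar import BNFGrammar  # pre-PR module removed by this diff

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name"],
    }
)
# strict_mode=True means properties/items not listed in the schema are
# rejected, i.e. unevaluatedProperties/unevaluatedItems are treated as false.
ebnf = BNFGrammar.debug_json_schema_to_ebnf(schema, indent=2, strict_mode=True)
print(ebnf)
```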
- - Returns - ------- - grammar : BNFGrammar - The generated BNF grammar. - """ - return _ffi_api.BNFGrammarFromSchema( # type: ignore # pylint: disable=no-member - schema, indent, separators, strict_mode - ) - - @staticmethod - def get_grammar_of_json() -> "BNFGrammar": - """Get the grammar of standard JSON. - - Returns - ------- - grammar : BNFGrammar - The JSON grammar. - """ - return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member - - @staticmethod - def debug_from_ebnf_string_no_normalize( - ebnf_string: str, - main_rule: str = "main", - ) -> "BNFGrammar": - r"""Construct a BNF grammar with a EBNF-formatted string, but not normalize it. - For test purposes. - - Parameters - ---------- - ebnf_string : str - The grammar string. - - main_rule : str - The name of the main rule. Default: "main". - - Returns - ------- - grammar : BNFGrammar - The parsed BNF grammar. - """ - return _ffi_api.BNFGrammarDebugFromEBNFStringNoNormalize( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule - ) - - @staticmethod - def debug_json_schema_to_ebnf( - schema: str, - *, - indent: Optional[int] = 2, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True - ) -> str: - """Convert JSON schema string to EBNF grammar string. For test purposes. - - Parameters - ---------- - json_schema : str - The JSON schema string. - - indent : Optional[int] - The number of spaces for indentation. If None, the output will be in one line. - Default: 2. - - separators : Optional[Tuple[str, str]] - Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). - If None, the default separators will be used: (",", ": ") when the indent is not None, - and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. - - strict_mode : bool - Whether to use strict mode. In strict mode, the generated grammar will not allow - properties and items that is not specified in the schema. This is equivalent to - setting unevaluatedProperties and unevaluatedItems to false. - - This helps LLM to generate accurate output in the grammar-guided generation with JSON - schema. Default: True. - - Returns - ------- - ebnf_string : str - The EBNF grammar string. - """ - return _ffi_api.DebugJSONSchemaToEBNF( # type: ignore # pylint: disable=no-member - schema, indent, separators, strict_mode - ) - - -@tvm._ffi.register_object("mlc.grammar.GrammarStateMatcher") # pylint: disable=protected-access -class GrammarStateMatcher(Object): - """A stateful matcher to match tokens to the specified BNF grammar. This class is the core logic - of the grammar-guided generation. - - This class implements the non-deterministic pushdown automaton (NPDA) matching algorithm to - match characters to a BNF grammar. It keep track of the current state of the matching process by - maintaining several stacks internally as possible paths in the NPDA. It also supports - backtracking. - - It is particularly capable of finding the set of tokens that are acceptable for the next step - and storing them in a bitmask. This aids in grammar-guided generation. - - Parameters - ---------- - grammar : BNFGrammar - The BNF grammar to match. - - tokenizer : Union[None, Tokenizer, List[str]] - The tokenizer to use, or the list of tokens. - - (For debug purpose) If None, the matcher will use an empty token set, and can only accept - and match characters. Default: None. - - max_rollback_steps : int - The maximum number of steps to rollback when backtracking. Default: 0. 
- """ - - def __init__( - self, - grammar: BNFGrammar, - tokenizer: Union[None, Tokenizer, List[str]] = None, - max_rollback_steps: int = 0, - ): - if isinstance(tokenizer, list): - self.__init_handle_by_constructor__( - _ffi_api.GrammarStateMatcherFromTokenTable, # type: ignore # pylint: disable=no-member - grammar, - tokenizer, - max_rollback_steps, - ) - else: - self.__init_handle_by_constructor__( - _ffi_api.GrammarStateMatcherFromTokenizer, # type: ignore # pylint: disable=no-member - grammar, - tokenizer, - max_rollback_steps, - ) - - def accept_token(self, token_id: int) -> bool: - """Accept one token and update the state of the matcher. - - Parameters - ---------- - token_id : int - The id of the token to accept. - - Returns - ------- - accepted : bool - Whether the token is accepted. - - Note - ---- - Termination state. - - When the end of the main rule is reached, the matcher can only accept the stop token. - The matcher is terminated after accepting the stop token, i.e. no accept_token or - find_next_rejected_tokens operations can be performed. The termination state can be canceled - using Rollback(). - """ - return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id, False) # type: ignore # pylint: disable=no-member - - def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]: - """Find the ids of the rejected tokens for the next step. - - Parameters - ---------- - verbose : bool - Whether to print information about timing and result counts to stderr. - For debug purposes. Default: False. - - Returns - ------- - rejected_token_ids : List[int] - A list of rejected token ids. - """ - - return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self, verbose) # type: ignore # pylint: disable=no-member - - def find_next_token_bitmask_as_ndarray(self, full_vocab_size: int) -> tvm.nd.array: - """Find the bitmask for the next step. - - Parameters - ---------- - full_vocab_size: int - Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`, this is the - vocab_size read from `config.json` that can be potentially larger. - - Returns - ------- - bitmask_ndarray : tvm.nd.array - Bitmask for the next step. - """ - - return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self, full_vocab_size) # type: ignore # pylint: disable=no-member - - def find_jump_forward_string(self) -> str: - """Find the jump-forward string for jump-forward decoding. This is the longest string that - will be valid according to the current syntax. - - Notes - ----- - This method does not change the grammar state. - - Returns - ------- - jump_forward_string : str - The jump-forward string. - """ - return _ffi_api.GrammarStateMatcherFindJumpForwardString(self) # type: ignore # pylint: disable=no-member - - def rollback(self, num_tokens: int) -> None: - """Rollback the matcher to a previous state. - - Parameters - ---------- - num_tokens : int - The number of tokens to rollback. It cannot exceed the current number of steps, nor can - it exceed the specified maximum number of rollback steps. - """ - _ffi_api.GrammarStateMatcherRollback(self, num_tokens) # type: ignore # pylint: disable=no-member - - def max_rollback_steps(self) -> int: - """Get the maximum number of rollback steps allowed. - - Returns - ------- - max_rollback_steps : int - The maximum number of rollback steps. 
- """ - return _ffi_api.GrammarStateMatcherMaxRollbackSteps(self) # type: ignore # pylint: disable=no-member - - def reset_state(self) -> None: - """Reset the matcher to the initial state.""" - _ffi_api.GrammarStateMatcherResetState(self) # type: ignore # pylint: disable=no-member - - def is_terminated(self) -> bool: - """Check if the matcher has accepted the stop token and terminated. See also - GrammarStateMatcher.accept_token. - - Returns - ------- - terminated : bool - Whether the matcher has terminated. - """ - return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member - - def debug_accept_char(self, codepoint: int, verbose: bool = False) -> bool: - """Accept one unicode codepoint to the current state. For test purposes. - - Parameters - ---------- - codepoint : int - The unicode codepoint of the character to be accepted. - """ - return _ffi_api.GrammarStateMatcherDebugAcceptChar( # type: ignore # pylint: disable=no-member - self, codepoint, verbose - ) - - def debug_match_complete_string(self, string: str, verbose: bool = False) -> bool: - """Check if the matcher can accept the complete string, and then reach the end of the - grammar. Does not change the state of the GrammarStateMatcher. For test purposes. - - Parameters - ---------- - string : str - The string to be matched. - """ - return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member - - def set_stop_token_ids(self, stop_token_ids: List[int]) -> None: - """Set the stop token ids, overriding the default ones.""" - _ffi_api.GrammarStateMatcherSetStopTokenIds(self, tvm.runtime.ShapeTuple(stop_token_ids)) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/model/deepseek/deepseek_model.py b/python/mlc_llm/model/deepseek/deepseek_model.py index e6bf120478..96f544162b 100644 --- a/python/mlc_llm/model/deepseek/deepseek_model.py +++ b/python/mlc_llm/model/deepseek/deepseek_model.py @@ -79,17 +79,17 @@ def __post_init__(self): logger.info( "%s defaults to %d", bold("prefill_chunk_size"), - min(self.context_window_size, 2048), + min(self.context_window_size, 8192), ) - self.prefill_chunk_size = min(self.context_window_size, 2048) + self.prefill_chunk_size = min(self.context_window_size, 8192) elif self.prefill_chunk_size > self.context_window_size: logger.info( "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - min(self.context_window_size, 2048), + min(self.context_window_size, 8192), ) - self.prefill_chunk_size = min(self.context_window_size, 2048) + self.prefill_chunk_size = min(self.context_window_size, 8192) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index c492b47be2..a0f3b8d197 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1856,7 +1856,7 @@ def _generate( # pylint: disable=too-many-locals generation_config: GenerationConfig, request_id: str, ) -> Iterator[List[engine_base.CallbackStreamOutput]]: - """Internal synchronous text generation interface of AsyncMLCEngine. + """Internal synchronous text generation interface of MLCEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. 
The returned list length is the number of parallel generations specified by `generation_config.n` diff --git a/python/setup.py b/python/setup.py index 60fb57d88c..9022920f5e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -104,6 +104,7 @@ def main(): "safetensors", "requests", "tqdm", + "sentencepiece", "tiktoken", "prompt_toolkit", "openai", diff --git a/tests/python/grammar/test_grammar_parser.py b/tests/python/grammar/test_grammar_parser.py deleted file mode 100644 index 4a53743dbc..0000000000 --- a/tests/python/grammar/test_grammar_parser.py +++ /dev/null @@ -1,339 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -import json -import os - -import pytest -import tvm.testing -from tvm import TVMError - -from mlc_llm.grammar import BNFGrammar - - -def test_bnf_simple(): - before = """main ::= b c -b ::= "b" -c ::= "c" -""" - expected = """main ::= ((b c)) -b ::= (("b")) -c ::= (("c")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - print(after) - print(expected) - assert after == expected - - -def test_ebnf(): - before = """main ::= b c | b main -b ::= "ab"* -c ::= [acep-z]+ -d ::= "d"? -""" - expected = """main ::= ((b c) | (b main)) -b ::= ((b_1)) -c ::= ((c_1)) -d ::= ((d_1)) -b_1 ::= ("" | ("ab" b_1)) -c_1 ::= (([acep-z] c_1) | ([acep-z])) -d_1 ::= ("" | ("d")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_star_quantifier(): - before = """main ::= b c d -b ::= [b]* -c ::= "b"* -d ::= ([b] [c] [d] | ([p] [q]))* -e ::= [e]* [f]* | [g]* -""" - expected = """main ::= ((b c d)) -b ::= (([b]*)) -c ::= ((c_1)) -d ::= ((d_1)) -e ::= (([e]* [f]*) | ([g]*)) -c_1 ::= ("" | ("b" c_1)) -d_1 ::= ("" | (d_1_choice d_1)) -d_1_choice ::= (("bcd") | ("pq")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_lookahead_assertion(): - before = """main ::= ((b c d)) -b ::= (("abc" [a-z])) (=("abc")) -c ::= (("a") | ("b")) (=([a-z] "b")) -d ::= (("ac") | ("b" d_choice)) (=("abc")) -d_choice ::= (("e") | ("d")) -""" - expected = """main ::= ((b c d)) -b ::= (("abc" [a-z])) (=("abc")) -c ::= (("a") | ("b")) (=([a-z] "b")) -d ::= (("ac") | ("b" d_choice)) (=("abc")) -d_choice ::= (("e") | ("d")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_char(): - before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest -rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 -rest1 ::= "\?\"\'测试あc" "👀" "" [a-a] [b-b] -""" - expected = r"""main ::= (([a-z] [A-z] "\u0234\u0345\xff" [\-A-Z] [\-\-] [^a] rest)) -rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) -rest1 ::= (("\?\"\'\u6d4b\u8bd5\u3042c\U0001f440ab")) -""" - # Disable unwrap_nesting_rules to expose the result before unwrapping. 
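The star-quantifier expectations in these tests all follow a single rewrite: `e*` is lowered to a fresh right-recursive helper rule. A toy string-level version of that rewrite, reproducing the expected output for `c ::= "b"*` in test_star_quantifier above:

```python
def star_to_recursion(rule_name: str, element: str) -> str:
    # e* is rewritten as  r ::= (("" | (e r)))  under a fresh helper name.
    helper = f"{rule_name}_1"
    return f'{rule_name} ::= (({helper}))\n{helper} ::= ("" | ({element} {helper}))\n'

print(star_to_recursion("c", '"b"'), end="")
# c ::= ((c_1))
# c_1 ::= ("" | ("b" c_1))
```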
- bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_space(): - before = """ - -main::="a" "b" ("c""d" -"e") | - -"f" | "g" -""" - expected = """main ::= (("abcde") | ("f") | ("g")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_nest(): - before = """main::= "a" ("b" | "c" "d") | (("e" "f")) -""" - expected = """main ::= (("a" main_choice) | ("ef")) -main_choice ::= (("b") | ("cd")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_flatten(): - before = """main ::= or_test sequence_test nested_test empty_test -or_test ::= ([a] | "b") | "de" | "" | or_test | [^a-z] -sequence_test ::= [a] "a" ("b" ("c" | "d")) ("d" "e") sequence_test "" -nested_test ::= ("a" ("b" ("c" "d"))) | ("a" | ("b" | "c")) | nested_rest -nested_rest ::= ("a" | ("b" "c" | ("d" | "e" "f"))) | ((("g"))) -empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" -""" - expected = """main ::= ((or_test sequence_test nested_test empty_test)) -or_test ::= ("" | ("a") | ("b") | ("de") | (or_test) | ([^a-z])) -sequence_test ::= (("aab" sequence_test_choice "de" sequence_test)) -nested_test ::= (("abcd") | ("a") | ("b") | ("c") | (nested_rest)) -nested_rest ::= (("a") | ("bc") | ("d") | ("ef") | ("g")) -empty_test ::= ("" | ("d") | ("a")) -sequence_test_choice ::= (("c") | ("d")) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - print(after) - assert after == expected - - -def test_json(): - # Adopted from https://www.crockford.com/mckeeman.html. Not optimized - before = r"""main ::= element -value ::= object | array | string | number | "true" | "false" | "null" -object ::= "{" ws "}" | "{" members "}" -members ::= member | member "," members -member ::= ws string ws ":" element -array ::= "[" ws "]" | "[" elements "]" -elements ::= element | element "," elements -element ::= ws value ws -string ::= "\"" characters "\"" -characters ::= "" | character characters -character ::= [^"\\] | "\\" escape -escape ::= "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" hex hex hex hex -hex ::= [A-Fa-f0-9] -number ::= integer fraction exponent -integer ::= digit | onenine digits | "-" digit | "-" onenine digits -digits ::= digit | digit digits -digit ::= [0-9] -onenine ::= [1-9] -fraction ::= "" | "." digits -exponent ::= "" | ("e" | "E") ("" | "+" | "-") digits -ws ::= "" | "\u0020" ws | "\u000A" ws | "\u000D" ws | "\u0009" ws -""" - - expected = r"""main ::= ((element)) -value ::= ((object) | (array) | (string) | (number) | ("true") | ("false") | ("null")) -object ::= (("{" ws "}") | ("{" members "}")) -members ::= ((member) | (member "," members)) -member ::= ((ws string ws ":" element)) -array ::= (("[" ws "]") | ("[" elements "]")) -elements ::= ((element) | (element "," elements)) -element ::= ((ws value ws)) -string ::= (("\"" characters "\"")) -characters ::= ("" | (character characters)) -character ::= (([^\"\\]) | ("\\" escape)) -escape ::= (("\"") | ("\\") | ("/") | ("b") | ("f") | ("n") | ("r") | ("t") | ("u" hex hex hex hex)) -hex ::= (([A-Fa-f0-9])) -number ::= ((integer fraction exponent)) -integer ::= ((digit) | (onenine digits) | ("-" digit) | ("-" onenine digits)) -digits ::= ((digit) | (digit digits)) -digit ::= (([0-9])) -onenine ::= (([1-9])) -fraction ::= ("" | ("." 
digits)) -exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) -ws ::= ("" | (" " ws) | ("\n" ws) | ("\r" ws) | ("\t" ws)) -exponent_choice ::= (("e") | ("E")) -exponent_choice_1 ::= ("" | ("+") | ("-")) -""" - - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - after = bnf_grammar.to_string() - assert after == expected - - -def test_to_string_roundtrip(): - """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" - - before = r"""main ::= ((b c) | (b main)) -b ::= ((b_1 d)) -c ::= ((c_1)) -d ::= ((d_1)) -b_1 ::= ("" | ("b" b_1)) -c_1 ::= ((c_2 c_1) | (c_2)) (=("abc" [a-z])) -c_2 ::= (([acep-z])) -d_1 ::= ("" | ("d")) -""" - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") - output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main") - output_string_2 = bnf_grammar_2.to_string() - assert before == output_string_1 - assert output_string_1 == output_string_2 - - -def test_error(): - with pytest.raises( - TVMError, match='TVMError: EBNF parse error at line 1, column 11: Rule "a" is not defined' - ): - BNFGrammar.from_ebnf_string("main ::= a b") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 15: Expect element" - ): - BNFGrammar.from_ebnf_string('main ::= "a" |') - - with pytest.raises(TVMError, match='TVMError: EBNF parse error at line 1, column 15: Expect "'): - BNFGrammar.from_ebnf_string('main ::= "a" "') - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 1: Expect rule name" - ): - BNFGrammar.from_ebnf_string('::= "a"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 12: Character class should not contain " - "newline", - ): - BNFGrammar.from_ebnf_string("main ::= [a\n]") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" - ): - BNFGrammar.from_ebnf_string(r'main ::= "\@"') - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" - ): - BNFGrammar.from_ebnf_string(r'main ::= "\uFF"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 14: Invalid character class: " - "lower bound is larger than upper bound", - ): - BNFGrammar.from_ebnf_string(r"main ::= [Z-A]") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 6: Expect ::=" - ): - BNFGrammar.from_ebnf_string(r'main := "a"') - - with pytest.raises( - TVMError, - match='TVMError: EBNF parse error at line 2, column 9: Rule "main" is defined multiple ' - "times", - ): - BNFGrammar.from_ebnf_string('main ::= "a"\nmain ::= "b"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 10: " - 'The main rule with name "main" is not found.', - ): - BNFGrammar.from_ebnf_string('a ::= "a"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 21: Unexpected lookahead assertion", - ): - BNFGrammar.from_ebnf_string('main ::= "a" (="a") (="b")') - - -def test_to_json(): - before = """main ::= b c | b main -b ::= "bcd" -c ::= [a-z] -""" - expected_obj = { - "rules": [ - {"body_expr_id": 6, "name": "main"}, - {"body_expr_id": 9, "name": "b"}, - {"body_expr_id": 12, "name": "c"}, - ], - "rule_expr_indptr": [0, 3, 6, 10, 13, 16, 20, 24, 29, 32, 35, 40, 43], - "rule_expr_data": [ - # fmt: off - 
4,1,1,4,1,2,5,2,0,1,4,1,1,4,1,0,5,2,3,4,6,2,2,5,0,3,98,99, - 100,5,1,7,6,1,8,1,3,0,97,122,5,1,10,6,1,11 - # fmt: on - ], - } - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") - print(bnf_grammar) - after_str = bnf_grammar.to_json(False) - after_obj = json.loads(after_str) - assert after_obj == expected_obj - - def test_to_json_roundtrip(): - before = r"""main ::= ((b c) | (b main)) -b ::= ((b_1 d [a]*)) -c ::= ((c_1)) -d ::= ((d_1)) -b_1 ::= ("" | ("b" b_1)) -c_1 ::= ((c_2 c_1) | (c_2)) -c_2 ::= (([acep-z])) -d_1 ::= ("" | ("d")) -""" - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") - output_json_1 = bnf_grammar_1.to_json(False) - bnf_grammar_2 = BNFGrammar.from_json(output_json_1) - output_json_2 = bnf_grammar_2.to_json(False) - output_str = bnf_grammar_2.to_string() - assert output_json_1 == output_json_2 - assert output_str == before - - if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/grammar/test_grammar_state_matcher_custom.py b/tests/python/grammar/test_grammar_state_matcher_custom.py deleted file mode 100644 index a497f4e2d8..0000000000 --- a/tests/python/grammar/test_grammar_state_matcher_custom.py +++ /dev/null @@ -1,465 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking -"""This test is adapted from test_grammar_state_matcher_json.py, but the grammar is parsed from -an unoptimized, non-simplified EBNF string. This is to test the robustness of the grammar state -matcher.""" -import json -import sys -from typing import Dict, List, Optional, Tuple - -import pytest -import tvm -import tvm.testing -from pydantic import BaseModel - -from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher -from mlc_llm.tokenizers import Tokenizer - - -def get_json_grammar(): - json_grammar_ebnf = r""" -main ::= basic_array | basic_object -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
-basic_string ::= (([\"] basic_string_1 [\"])) -basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 -escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" -basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" -ws ::= [ \n\t]* -""" - grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf) - return grammar - - -@pytest.fixture(scope="function") -def json_grammar(): - return get_json_grammar() - - -def test_simple(): - grammar_str = """main ::= rule1 rule2 -rule1 ::= (rule2 | rule3) "a" -rule2 ::= "b" -rule3 ::= "c" -""" - - grammar = BNFGrammar.from_ebnf_string(grammar_str) - matcher = GrammarStateMatcher(grammar) - assert matcher.debug_match_complete_string("bab") - assert not matcher.debug_match_complete_string("abb") - assert matcher.debug_match_complete_string("cab") - - -(json_input_accepted,) = tvm.testing.parameters( - ('{"name": "John"}',), - ('{ "name" : "John" }',), - ("{}",), - ("[]",), - ('{"name": "Alice", "age": 30, "city": "New York"}',), - ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), - ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), - ('[{"name": "David"}, {"name": "Sophia"}]',), - ( - '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' - ' "hasPets": false}', - ), - ( - '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' - '{"city": "Chicago", "zipcode": "60601"}}}', - ), - ( - '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' - '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', - ), - ( - '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' - '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' - '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', - ), - ( - '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' - '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' - '["CAD", "Project Management"], "projects": [{"name": "Project A", ' - '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', - ), -) - - -def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) - - -(json_input_refused,) = tvm.testing.parameters( - (r'{ name: "John" }',), - (r'{ "name": "John" } ',), # trailing space is not accepted - (r'{ "name": "John", "age": 30, }',), - (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), - (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), - (r'{ "name": "John", "age": 30.5.7 }',), - (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' - r'["hiking", "swimming",]}] }', - ), - (r'{ "name": "John", "age": 30, "status": "\P\J" }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' - r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' - r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', - ), -) - - -def test_json_refuse(json_grammar: BNFGrammar, 
json_input_refused): - assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) - - -(json_input_pressure,) = tvm.testing.parameters( - # Extra long string: 1k chars - ( - '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' - "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " - "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " - "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. Class aptent taciti sociosqu " - "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " - "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " - "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " - "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " - "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " - "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " - "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " - "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " - "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " - "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " - "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " - "Integer euismod lacus luctus magna. Quisque cursus, metus vitae pharetra auctor, sem " - 'massa mattis sem, at interdum magna augue eget diam."]', - ), - # long and complex json: 3k chars - ( - r"""{ - "web-app": { - "servlet": [ - { - "servlet-name": "cofaxCDS", - "servlet-class": "org.cofax.cds.CDSServlet", - "init-param": { - "configGlossary:installationAt": "Philadelphia, PA", - "configGlossary:adminEmail": "ksm@pobox.com", - "configGlossary:poweredBy": "Cofax", - "configGlossary:poweredByIcon": "/images/cofax.gif", - "configGlossary:staticPath": "/content/static", - "templateProcessorClass": "org.cofax.WysiwygTemplate", - "templateLoaderClass": "org.cofax.FilesTemplateLoader", - "templatePath": "templates", - "templateOverridePath": "", - "defaultListTemplate": "listTemplate.htm", - "defaultFileTemplate": "articleTemplate.htm", - "useJSP": false, - "jspListTemplate": "listTemplate.jsp", - "jspFileTemplate": "articleTemplate.jsp", - "cachePackageTagsTrack": 200, - "cachePackageTagsStore": 200, - "cachePackageTagsRefresh": 60, - "cacheTemplatesTrack": 100, - "cacheTemplatesStore": 50, - "cacheTemplatesRefresh": 15, - "cachePagesTrack": 200, - "cachePagesStore": 100, - "cachePagesRefresh": 10, - "cachePagesDirtyRead": 10, - "searchEngineListTemplate": "forSearchEnginesList.htm", - "searchEngineFileTemplate": "forSearchEngines.htm", - "searchEngineRobotsDb": "WEB-INF/robots.db", - "useDataStore": true, - "dataStoreClass": "org.cofax.SqlDataStore", - "redirectionClass": "org.cofax.SqlRedirection", - "dataStoreName": "cofax", - "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", - "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", - "dataStoreUser": "sa", - "dataStorePassword": "dataStoreTestQuery", - "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", - "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", - "dataStoreInitConns": 10, - "dataStoreMaxConns": 100, - "dataStoreConnUsageLimit": 100, - 
"dataStoreLogLevel": "debug", - "maxUrlLength": 500 - } - }, - { - "servlet-name": "cofaxEmail", - "servlet-class": "org.cofax.cds.EmailServlet", - "init-param": { - "mailHost": "mail1", - "mailHostOverride": "mail2" - } - }, - { - "servlet-name": "cofaxAdmin", - "servlet-class": "org.cofax.cds.AdminServlet" - }, - { - "servlet-name": "fileServlet", - "servlet-class": "org.cofax.cds.FileServlet" - }, - { - "servlet-name": "cofaxTools", - "servlet-class": "org.cofax.cms.CofaxToolsServlet", - "init-param": { - "templatePath": "toolstemplates/", - "log": 1, - "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", - "logMaxSize": "", - "dataLog": 1, - "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", - "dataLogMaxSize": "", - "removePageCache": "/content/admin/remove?cache=pages&id=", - "removeTemplateCache": "/content/admin/remove?cache=templates&id=", - "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", - "lookInContext": 1, - "adminGroupID": 4, - "betaServer": true - } - } - ], - "servlet-mapping": { - "cofaxCDS": "/", - "cofaxEmail": "/cofaxutil/aemail/*", - "cofaxAdmin": "/admin/*", - "fileServlet": "/static/*", - "cofaxTools": "/tools/*" - }, - "taglib": { - "taglib-uri": "cofax.tld", - "taglib-location": "/WEB-INF/tlds/cofax.tld" - } - } -}""", - ), -) - - -def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) - - -(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( - ( - # short test - '{"id": 1,"name": "Example"}', - [ - # fmt: off - 31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, - 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 - # fmt: on - ], - ), - ( - # long test - """{ -"id": 1, -"na": "ex", -"ac": true, -"t": ["t1", "t2"], -"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, -"res": "res" -}""", - [ - # fmt: off - 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, - 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, - 31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, - 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, - 31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915, - 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, - 31846, 265, 265, 265, 265, 31974, 31974, 31999 - # fmt: on - ], - ), -) - - -def test_find_next_rejected_tokens( - json_grammar: BNFGrammar, - input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]] = None, -): - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) - - real_sizes = [] - for c in input_find_rejected_tokens: - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c, file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(ord(c)) - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - - if expected_rejected_sizes is not None: - assert 
real_sizes == expected_rejected_sizes - - -def test_token_based_operations(json_grammar: BNFGrammar): - """Test accepting token and finding the next token mask.""" - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - expected = [ - ["{"], - ['"', "}", "\n", " ", '"a":true'], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - [":", "\n", " ", ':"'], - ['"', "{", "6", "\n", " "], - ["}", ", ", "6", "\n", " "], - [" ", "\n", '"', '"a":true'], - [" ", "\n", '"', '"a":true'], - ["}", ", ", "\n", " "], - [""], - ] - - result = [] - - for id in input_ids: - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - assert id in accepted - assert grammar_state_matcher.accept_token(id) - - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - - assert result == expected - - -def test_custom_main_rule() -> None: - json_grammar_ebnf = r""" -main ::= basic_object -basic_any ::= basic_string | basic_object -basic_string ::= (([\"] basic_string_1 [\"])) -basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 -escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" -ws ::= [ \n\t]* -""" - grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf, "basic_string") - assert GrammarStateMatcher(grammar).debug_match_complete_string(r'"abc\r\n"') - assert not GrammarStateMatcher(grammar).debug_match_complete_string(r'{"name": "John" }') - - -def test_find_next_rejected_tokens_schema() -> None: - class MainModel(BaseModel): - integer_field: int - number_field: float - boolean_field: bool - any_array_field: List - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - nested_object_field: Dict[str, Dict[str, int]] - - schema = MainModel.model_json_schema() - schema_str = json.dumps(schema) - ebnf_grammar = BNFGrammar.from_schema(schema_str, indent=2) - - instance = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[3.14, "foo", None, True], - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - nested_object_field={"foo": {"bar": 42}}, - ) - instance_str = instance.model_dump_json(indent=2, round_trip=True) - - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - matcher = GrammarStateMatcher(ebnf_grammar, tokenizer) - - for c in instance_str: - matcher.find_next_rejected_tokens(True) - print("Accepting char:", c, file=sys.stderr) - assert matcher.debug_accept_char(ord(c)) - assert 2 not in matcher.find_next_rejected_tokens(True) - - -def test_get_jump_forward_string(): - grammar_ebnf = r"""main ::= "abb" | "abbd" | other_rule -other_rule ::= "a" sub_rule "b" -sub_rule ::= "b" -""" 
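The call exercised just below, find_jump_forward_string(), returns the longest continuation that the grammar forces after the characters accepted so far, which an engine can append without sampling. A minimal sketch of the idea over a finite set of allowed strings (illustrative only; the real matcher walks the grammar rather than enumerating strings):

from os.path import commonprefix
from typing import List

def jump_forward(accepted: str, allowed: List[str]) -> str:
    # Keep only strings consistent with what was accepted so far, then take
    # the longest prefix shared by all of their remaining tails.
    tails = [s[len(accepted) :] for s in allowed if s.startswith(accepted)]
    return commonprefix(tails)

# Mirrors the grammar above: after accepting "a", the alternatives "abb",
# "abbd", and other_rule's expansion "abb" all force the next characters "bb".
assert jump_forward("a", ["abb", "abbd", "abb"]) == "bb"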
- grammar = BNFGrammar.from_ebnf_string(grammar_ebnf) - matcher = GrammarStateMatcher(grammar) - assert matcher.debug_accept_char(ord("a"), True) - assert matcher.find_jump_forward_string() == "bb" - - -def test_find_jump_forward_string_schema(): - class MainModel(BaseModel): - integer_field: int - number_field: float - boolean_field: bool - any_array_field: List - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - nested_object_field: Dict[str, Dict[str, int]] - - schema = MainModel.model_json_schema() - schema_str = json.dumps(schema) - ebnf_grammar = BNFGrammar.from_schema(schema_str, indent=2) - - instance = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[3.14, "foo", None, True], - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - nested_object_field={"foo": {"bar": 42}}, - ) - instance_str = instance.model_dump_json(indent=2, round_trip=True) - - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - matcher = GrammarStateMatcher(ebnf_grammar, tokenizer) - - for i, c in enumerate(instance_str): - jump_forward_str = matcher.find_jump_forward_string() - print(f"Jump forward string at {i}: {jump_forward_str}") - assert instance_str[i : i + len(jump_forward_str)] == jump_forward_str - print("Accepting char:", c, file=sys.stderr) - assert matcher.debug_accept_char(ord(c)) - assert matcher.find_jump_forward_string() == "" - - -if __name__ == "__main__": - # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') - test_find_next_rejected_tokens_schema() - - tvm.testing.main() diff --git a/tests/python/grammar/test_grammar_state_matcher_json.py b/tests/python/grammar/test_grammar_state_matcher_json.py deleted file mode 100644 index 333e36a283..0000000000 --- a/tests/python/grammar/test_grammar_state_matcher_json.py +++ /dev/null @@ -1,478 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking -"""This test uses the optimized JSON grammar provided by the grammar library.""" -import sys -from typing import List, Optional - -import pytest -import tvm -import tvm.testing -from tvm import TVMError - -from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher -from mlc_llm.tokenizers import Tokenizer - - -@pytest.fixture(scope="function") -def json_grammar(): - return BNFGrammar.get_grammar_of_json() - - -(json_input_accepted,) = tvm.testing.parameters( - ('{"name": "John"}',), - ('{ "name" : "John" }',), - ("{}",), - ("[]",), - ('{"name": "Alice", "age": 30, "city": "New York"}',), - ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), - ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), - ('[{"name": "David"}, {"name": "Sophia"}]',), - ( - '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' - ' "hasPets": false}', - ), - ( - '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' - '{"city": "Chicago", "zipcode": "60601"}}}', - ), - ( - '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' - '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', - ), - ( - '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' - '"university": "XYZ University"}, "work": [{"company": "ABC 
Corp", "position": ' - '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', - ), - ( - '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' - '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' - '["CAD", "Project Management"], "projects": [{"name": "Project A", ' - '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', - ), -) - - -def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) - - -(json_input_refused,) = tvm.testing.parameters( - (r'{ name: "John" }',), - (r'{ "name": "John" } ',), # trailing space is not accepted - (r'{ "name": "John", "age": 30, }',), - (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), - (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), - (r'{ "name": "John", "age": 30.5.7 }',), - (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' - r'["hiking", "swimming",]}] }', - ), - (r'{ "name": "John", "age": 30, "status": "\P\J" }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' - r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' - r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', - ), -) - - -def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): - assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) - - -(json_input_pressure,) = tvm.testing.parameters( - # Extra long string: 1k chars - ( - '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' - "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " - "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " - "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. Class aptent taciti sociosqu " - "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " - "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " - "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " - "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " - "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " - "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " - "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " - "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " - "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " - "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " - "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " - "Integer euismod lacus luctus magna. 
Quisque cursus, metus vitae pharetra auctor, sem " - 'massa mattis sem, at interdum magna augue eget diam."]', - ), - # long and complex json: 3k chars - ( - r"""{ - "web-app": { - "servlet": [ - { - "servlet-name": "cofaxCDS", - "servlet-class": "org.cofax.cds.CDSServlet", - "init-param": { - "configGlossary:installationAt": "Philadelphia, PA", - "configGlossary:adminEmail": "ksm@pobox.com", - "configGlossary:poweredBy": "Cofax", - "configGlossary:poweredByIcon": "/images/cofax.gif", - "configGlossary:staticPath": "/content/static", - "templateProcessorClass": "org.cofax.WysiwygTemplate", - "templateLoaderClass": "org.cofax.FilesTemplateLoader", - "templatePath": "templates", - "templateOverridePath": "", - "defaultListTemplate": "listTemplate.htm", - "defaultFileTemplate": "articleTemplate.htm", - "useJSP": false, - "jspListTemplate": "listTemplate.jsp", - "jspFileTemplate": "articleTemplate.jsp", - "cachePackageTagsTrack": 200, - "cachePackageTagsStore": 200, - "cachePackageTagsRefresh": 60, - "cacheTemplatesTrack": 100, - "cacheTemplatesStore": 50, - "cacheTemplatesRefresh": 15, - "cachePagesTrack": 200, - "cachePagesStore": 100, - "cachePagesRefresh": 10, - "cachePagesDirtyRead": 10, - "searchEngineListTemplate": "forSearchEnginesList.htm", - "searchEngineFileTemplate": "forSearchEngines.htm", - "searchEngineRobotsDb": "WEB-INF/robots.db", - "useDataStore": true, - "dataStoreClass": "org.cofax.SqlDataStore", - "redirectionClass": "org.cofax.SqlRedirection", - "dataStoreName": "cofax", - "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", - "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", - "dataStoreUser": "sa", - "dataStorePassword": "dataStoreTestQuery", - "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", - "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", - "dataStoreInitConns": 10, - "dataStoreMaxConns": 100, - "dataStoreConnUsageLimit": 100, - "dataStoreLogLevel": "debug", - "maxUrlLength": 500 - } - }, - { - "servlet-name": "cofaxEmail", - "servlet-class": "org.cofax.cds.EmailServlet", - "init-param": { - "mailHost": "mail1", - "mailHostOverride": "mail2" - } - }, - { - "servlet-name": "cofaxAdmin", - "servlet-class": "org.cofax.cds.AdminServlet" - }, - { - "servlet-name": "fileServlet", - "servlet-class": "org.cofax.cds.FileServlet" - }, - { - "servlet-name": "cofaxTools", - "servlet-class": "org.cofax.cms.CofaxToolsServlet", - "init-param": { - "templatePath": "toolstemplates/", - "log": 1, - "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", - "logMaxSize": "", - "dataLog": 1, - "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", - "dataLogMaxSize": "", - "removePageCache": "/content/admin/remove?cache=pages&id=", - "removeTemplateCache": "/content/admin/remove?cache=templates&id=", - "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", - "lookInContext": 1, - "adminGroupID": 4, - "betaServer": true - } - } - ], - "servlet-mapping": { - "cofaxCDS": "/", - "cofaxEmail": "/cofaxutil/aemail/*", - "cofaxAdmin": "/admin/*", - "fileServlet": "/static/*", - "cofaxTools": "/tools/*" - }, - "taglib": { - "taglib-uri": "cofax.tld", - "taglib-location": "/WEB-INF/tlds/cofax.tld" - } - } -}""", - ), -) - - -def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) - - -(tokenizer_path, input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( - ( - # short test - 
"dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - '{"id": 1,"name": "Example"}', - [ - # fmt: off - 31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, - 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 - # fmt: on - ], - ), - ( - # short test - "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", - '{"id": 1,"name": "Example哈哈"}', - [ - # fmt: off - 128235, 127497, 5002, 5002, 5002, 127849, 126399, 126399, 126760, 127499, 5002, 5002, - 5002, 5002, 5002, 127849, 126399, 126399, 4952, 4952, 4952, 4952, 4952, 4952, 4952, - 4952, 128066, 128111, 4952, 128066, 128111, 4952, 127873, 128254 - # fmt: on - ], - ), - ( - # long test - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - """{ -"id": 1, -"na": "ex", -"ac": true, -"t": ["t1", "t2"], -"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, -"res": "res" -}""", - [ - # fmt: off - 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, - 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, - 31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, - 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, - 31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915, - 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, - 31846, 265, 265, 265, 265, 31974, 31974, 31999 - # fmt: on - ], - ), -) - - -def test_find_next_rejected_tokens( - json_grammar: BNFGrammar, - tokenizer_path: str, - input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]], -): - tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) - input_bytes = input_find_rejected_tokens.encode("utf-8") - rejected_sizes = [] - - for i, c in enumerate(input_bytes): - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - rejected_sizes.append(len(rejected_token_ids)) - if expected_rejected_sizes is not None: - assert rejected_sizes[-1] == expected_rejected_sizes[i], ( - len(rejected_token_ids), - expected_rejected_sizes[i], - ) - print("Accepting char:", c, bytes([c]), file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(c) - - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - rejected_sizes.append(len(rejected_token_ids)) - if expected_rejected_sizes is not None: - assert rejected_sizes[-1] == expected_rejected_sizes[-1] - - -def test_token_based_operations(json_grammar: BNFGrammar): - """Test accepting token and finding the next token mask.""" - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - expected = [ - ["{"], - ['"', "}", "\n", " ", '"a":true'], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - [":", "\n", " ", ':"'], - ['"', "{", "6", "\n", " "], - ["}", ", ", "6", "\n", " "], - [" ", "\n", '"', '"a":true'], - [" ", "\n", '"', '"a":true'], - ["}", ", ", "\n", " "], - [""], - ] - - result = [] - 
- for id in input_ids: - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - assert id in accepted, token_table[id] - assert grammar_state_matcher.accept_token(id) - - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - - assert result == expected - - -def test_rollback(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) - - assert grammar_state_matcher.max_rollback_steps() == 5 - - input_ids_splitted = [input_ids[i : i + 2] for i in range(0, len(input_ids), 2)] - - for i_1, i_2 in input_ids_splitted: - orig_result = [] - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_1) - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_2) - grammar_state_matcher.rollback(2) - result_after_rollback = [] - result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_1) - result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_2) - assert orig_result == result_after_rollback - - -def test_reset(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - orig_result = [] - - for i in input_ids: - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i) - - grammar_state_matcher.reset_state() - - result_after_reset = [] - - for i in input_ids: - result_after_reset.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i) - - assert orig_result == result_after_reset - - -def test_set_stop_token_ids(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", ""] - input_ids = [token_table.index(t) for t in input_splitted] - - # 1. Will accept as last token for stop token - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - for i in input_ids: - assert grammar_state_matcher.accept_token(i) - - # 2. Will reject as last token for stop token - grammar_state_matcher.reset_state() - grammar_state_matcher.set_stop_token_ids([2]) - for i in input_ids: - if i == 1: - # 1 is , will be rejected - assert not grammar_state_matcher.accept_token(i) - else: - assert grammar_state_matcher.accept_token(i) - - # 3. 
Will accept "a" as stop token - grammar_state_matcher.reset_state() - grammar_state_matcher.set_stop_token_ids([2]) - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "a"] - input_ids = [token_table.index(t) for t in input_splitted] - for i in input_ids: - assert grammar_state_matcher.accept_token(i) - - -def test_termination(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", ""] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) - - orig_result = [] - - for i in input_ids: - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i) - - assert grammar_state_matcher.is_terminated() - - with pytest.raises(TVMError): - grammar_state_matcher.accept_token(0) - - with pytest.raises(TVMError): - grammar_state_matcher.find_next_rejected_tokens() - - grammar_state_matcher.rollback(2) - - assert not grammar_state_matcher.is_terminated() - assert grammar_state_matcher.accept_token(input_ids[-2]) - - -if __name__ == "__main__": - # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens( - BNFGrammar.get_grammar_of_json(), - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - '{"id": 1,"name": "Example"}', - None, - ) - - test_find_next_rejected_tokens( - BNFGrammar.get_grammar_of_json(), - "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", - '{"id": 1,"name": "Example哈哈"}', - None, - ) - - tvm.testing.main() diff --git a/tests/python/grammar/test_json_schema_converter.py b/tests/python/grammar/test_json_schema_converter.py deleted file mode 100644 index 0ec250992a..0000000000 --- a/tests/python/grammar/test_json_schema_converter.py +++ /dev/null @@ -1,478 +0,0 @@ -import json -from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union - -import tvm.testing -from pydantic import BaseModel, Field, TypeAdapter - -from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher - - -def check_schema_with_grammar( - schema: Dict[str, Any], - expected_grammar: str, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - schema_str = json.dumps(schema) - grammar = BNFGrammar.debug_json_schema_to_ebnf( - schema_str, indent=indent, separators=separators, strict_mode=strict_mode - ) - assert grammar == expected_grammar - - -def check_schema_with_json( - schema: Dict[str, Any], - json_str: str, - check_accepted: bool = True, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - ebnf_grammar = BNFGrammar.from_schema( - json.dumps(schema, indent=2), indent=indent, separators=separators, strict_mode=strict_mode - ) - matcher = GrammarStateMatcher(ebnf_grammar) - - if check_accepted: - assert matcher.debug_match_complete_string(json_str) - else: - assert not matcher.debug_match_complete_string(json_str) - - -def check_schema_with_instance( - schema: Dict[str, Any], - instance: BaseModel, - check_accepted: bool = True, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - instance_obj = instance.model_dump(mode="json", round_trip=True) - instance_str = json.dumps(instance_obj, indent=indent, separators=separators) - 
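The check_schema_with_json call below wraps the end-to-end flow these tests rely on: serialize a Pydantic schema, convert it to an EBNF grammar, and match candidate JSON against it. A condensed sketch of that flow, assuming the pre-removal mlc_llm.grammar API; Point is a made-up model, not one of the test models in this file:

import json
from pydantic import BaseModel
from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher

class Point(BaseModel):
    x: int
    y: int

schema_str = json.dumps(Point.model_json_schema(), indent=2)
grammar = BNFGrammar.from_schema(schema_str, indent=2)
matcher = GrammarStateMatcher(grammar)

# A conforming, identically formatted instance is accepted...
instance_str = json.dumps(Point(x=1, y=2).model_dump(mode="json"), indent=2)
assert matcher.debug_match_complete_string(instance_str)
# ...while a type-violating document is refused.
assert not matcher.debug_match_complete_string('{"x": "one", "y": 2}')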
check_schema_with_json(schema, instance_str, check_accepted, indent, separators, strict_mode) - - -def test_basic() -> None: - class MainModel(BaseModel): - integer_field: int - number_field: float - boolean_field: bool - any_array_field: List - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - nested_object_field: Dict[str, Dict[str, int]] - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_3 ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -main_prop_4 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_prop_5_item_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_prop_5 ::= "[" "" basic_string ", " basic_integer ", " main_prop_5_item_2 "" "]" -main_prop_6 ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_prop_7_addl ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_prop_7 ::= ("{" "" basic_string ": " main_prop_7_addl (", " basic_string ": " main_prop_7_addl)* "" "}") | "{}" -main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_prop_3 ", " "\"array_field\"" ": " main_prop_4 ", " "\"tuple_field\"" ": " main_prop_5 ", " "\"object_field\"" ": " main_prop_6 ", " "\"nested_object_field\"" ": " main_prop_7 "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[3.14, "foo", None, True], - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - nested_object_field={"foo": {"bar": 42}}, - ) - check_schema_with_instance(schema, instance) - - instance_empty = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[], - array_field=[], - tuple_field=("foo", 42, []), - object_field={}, - nested_object_field={}, - ) - - schema = MainModel.model_json_schema() - check_schema_with_instance(schema, instance_empty) - - -def test_indent() -> None: - class MainModel(BaseModel): - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_prop_1_item_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_prop_1 ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_prop_1_item_2 "\n " "]" -main_prop_2 ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" -main ::= "{" "\n " "\"array_field\"" ": " main_prop_0 ",\n " "\"tuple_field\"" ": " main_prop_1 ",\n " "\"object_field\"" ": " main_prop_2 "\n" "}" -""" - - instance = MainModel( - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - ) - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar, indent=2) - check_schema_with_instance(schema, instance, indent=2) - check_schema_with_instance(schema, instance, indent=None, separators=(",", ":")) - - -def test_non_strict() -> None: - class Foo(BaseModel): - pass - - class MainModel(BaseModel): - tuple_field: Tuple[str, Tuple[int, int]] - foo_field: Foo - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0_item_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" -main_prop_0 ::= "[" "\n " basic_string ",\n " main_prop_0_item_1 (",\n " basic_any)* "\n " "]" -main_prop_1 ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" -main ::= "{" "\n " "\"tuple_field\"" ": " main_prop_0 ",\n " "\"foo_field\"" ": " main_prop_1 (",\n " basic_string ": " basic_any)* "\n" "}" -""" - - instance_json = """{ - "tuple_field": [ - "foo", - [ - 12, - 13, - "ext" - ], - "extra" - ], - "foo_field": { - "tmp": "str" - }, - "extra": "field" -}""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar, indent=2, strict_mode=False) - check_schema_with_json(schema, instance_json, indent=2, strict_mode=False) - - -def test_enum_const() -> None: - class Field(Enum): - FOO = "foo" - BAR = "bar" - - class MainModel(BaseModel): - bars: Literal["a"] - str_values: Literal['a\n\r"'] - foo: Literal["a", "b", "c"] - values: Literal[1, "a", True] - field: Field - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? 
[1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= "\"a\"" -main_prop_1 ::= "\"a\\n\\r\\\"\"" -main_prop_2 ::= ("\"a\"") | ("\"b\"") | ("\"c\"") -main_prop_3 ::= ("1") | ("\"a\"") | ("true") -main_prop_4 ::= ("\"foo\"") | ("\"bar\"") -main ::= "{" "" "\"bars\"" ": " main_prop_0 ", " "\"str_values\"" ": " main_prop_1 ", " "\"foo\"" ": " main_prop_2 ", " "\"values\"" ": " main_prop_3 ", " "\"field\"" ": " main_prop_4 "" "}" -""" - - schema = MainModel.model_json_schema() - instance = MainModel(foo="a", values=1, bars="a", str_values='a\n\r"', field=Field.FOO) - check_schema_with_grammar(schema, ebnf_grammar) - check_schema_with_instance(schema, instance) - - -def test_optional() -> None: - class MainModel(BaseModel): - num: int = 0 - opt_bool: Optional[bool] = None - size: Optional[float] - name: str = "" - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_1 ::= basic_boolean | basic_null -main_prop_2 ::= basic_number | basic_null -main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_prop_1 ", ")? "\"size\"" ": " main_prop_2 (", " "\"name\"" ": " basic_string)? "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel(num=42, opt_bool=True, size=3.14, name="foo") - check_schema_with_instance(schema, instance) - - instance = MainModel(size=None) - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"size": null}') - check_schema_with_json(schema, '{"size": null, "name": "foo"}') - check_schema_with_json(schema, '{"num": 1, "size": null, "name": "foo"}') - - -def test_all_optional() -> None: - class MainModel(BaseModel): - size: int = 0 - state: bool = False - num: float = 0 - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_part_1 ::= "" | ", " "\"num\"" ": " basic_number "" -main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel(size=42, state=True, num=3.14) - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"state": false}') - check_schema_with_json(schema, '{"size": 1, "num": 1.5}') - - ebnf_grammar_non_strict = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_part_2 ::= (", " basic_string ": " basic_any)* -main_part_1 ::= main_part_2 | ", " "\"num\"" ": " basic_number main_part_2 -main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number main_part_2) | basic_string ": " basic_any main_part_2) "" "}") | "{}" -""" - - check_schema_with_grammar(schema, ebnf_grammar_non_strict, strict_mode=False) - - check_schema_with_json(schema, '{"size": 1, "num": 1.5, "other": false}', strict_mode=False) - check_schema_with_json(schema, '{"other": false}', strict_mode=False) - - -def test_empty() -> None: - class MainModel(BaseModel): - pass - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main ::= "{" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel() - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"tmp": 123}', strict_mode=False) - - -def test_reference() -> None: - class Foo(BaseModel): - count: int - size: Optional[float] = None - - class Bar(BaseModel): - apple: str = "x" - banana: str = "y" - - class MainModel(BaseModel): - foo: Foo - bars: List[Bar] - - instance = MainModel( - foo=Foo(count=42, size=3.14), - bars=[Bar(apple="a", banana="b"), Bar(apple="c", banana="d")], - ) - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0_prop_1 ::= basic_number | basic_null -main_prop_0 ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_prop_0_prop_1)? "" "}" -main_prop_1_items_part_0 ::= "" | ", " "\"banana\"" ": " basic_string "" -main_prop_1_items ::= ("{" "" (("\"apple\"" ": " basic_string main_prop_1_items_part_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" -main_prop_1 ::= ("[" "" main_prop_1_items (", " main_prop_1_items)* "" "]") | "[]" -main ::= "{" "" "\"foo\"" ": " main_prop_0 ", " "\"bars\"" ": " main_prop_1 "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - check_schema_with_instance(schema, instance) - - -def test_union() -> None: - class Cat(BaseModel): - name: str - color: str - - class Dog(BaseModel): - name: str - breed: str - - ta = TypeAdapter(Union[Cat, Dog]) - - model_schema = ta.json_schema() - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
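In the deleted `test_reference` case, nested models surface in the schema as `$defs` entries referenced via `$ref`, and the converter inlined each definition as a dedicated sub-rule (`main_prop_0` for `Foo`, `main_prop_1_items` for `Bar`). A small sketch of the schema side of that, assuming Pydantic v2's `$defs` layout:

```python
from typing import List, Optional

from pydantic import BaseModel


class Foo(BaseModel):
    count: int
    size: Optional[float] = None


class Bar(BaseModel):
    apple: str = "x"
    banana: str = "y"


class MainModel(BaseModel):
    foo: Foo
    bars: List[Bar]


schema = MainModel.model_json_schema()
# Nested models are factored into $defs and pointed at via $ref;
# the converter resolved each reference and inlined the definition.
print(sorted(schema["$defs"]))      # ['Bar', 'Foo']
print(schema["properties"]["foo"])  # {'$ref': '#/$defs/Foo'}
```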
-basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_case_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" -main_case_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" -main ::= main_case_0 | main_case_1 -""" - - check_schema_with_grammar(model_schema, ebnf_grammar) - - check_schema_with_instance(model_schema, Cat(name="kitty", color="black")) - check_schema_with_instance(model_schema, Dog(name="doggy", breed="bulldog")) - check_schema_with_json(model_schema, '{"name": "kitty", "test": "black"}', False) - - -def test_alias() -> None: - class MainModel(BaseModel): - test: str = Field(..., alias="name") - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main ::= "{" "" "\"name\"" ": " basic_string "" "}" -""" - - check_schema_with_grammar(MainModel.model_json_schema(), ebnf_grammar) - - instance = MainModel(name="kitty") - instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=False)) - check_schema_with_json(MainModel.model_json_schema(by_alias=False), instance_str) - - instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) - check_schema_with_json(MainModel.model_json_schema(by_alias=True), instance_str) - - # property name contains space - class MainModelSpace(BaseModel): - test: Literal["abc"] = Field(..., alias="name 1") - - ebnf_grammar_space = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
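The `test_union` checks above rely on `TypeAdapter` producing a top-level `anyOf` for the union, which the converter lowered to one rule per branch (`main ::= main_case_0 | main_case_1`). A sketch of the schema that drives this, again assuming Pydantic v2:

```python
import json
from typing import Union

from pydantic import BaseModel, TypeAdapter


class Cat(BaseModel):
    name: str
    color: str


class Dog(BaseModel):
    name: str
    breed: str


ta = TypeAdapter(Union[Cat, Dog])
schema = ta.json_schema()
# The union surfaces as anyOf over two $refs, one per model,
# matching the two main_case_* alternatives in the grammar.
print(json.dumps(schema["anyOf"], indent=2))
```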
-basic_string ::= ["] basic_string_sub -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= "\"abc\"" -main ::= "{" "" "\"name 1\"" ": " main_prop_0 "" "}" -""" - - check_schema_with_grammar(MainModelSpace.model_json_schema(), ebnf_grammar_space) - - instance_space = MainModelSpace(**{"name 1": "abc"}) - instance_space_str = json.dumps( - instance_space.model_dump(mode="json", round_trip=True, by_alias=True) - ) - check_schema_with_json(MainModelSpace.model_json_schema(by_alias=True), instance_space_str) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index 8bb47a7946..978f5533bd 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -33,15 +33,3 @@ #define PICOJSON_USE_INT64 #define DMLC_USE_LOGGING_LIBRARY - -// Grammar related -#include "grammar/grammar.cc" -#include "grammar/grammar_functor.cc" -#include "grammar/grammar_parser.cc" -#include "grammar/grammar_serializer.cc" -#include "grammar/grammar_state_matcher.cc" -#include "grammar/json_schema_converter.cc" -#include "support/encoding.cc" - -// Only compiles necessary functions for mlc.PostProcessTokenTable -#include "tokenizers/tokenizers.cc"
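Two closing observations on the hunks above. First, the deleted `test_alias` cases turn on Pydantic's alias machinery: the grammar matches the serialized key (`"name"`, or `"name 1"` for the space-containing alias), and `by_alias` on both `model_json_schema` and `model_dump` selects which key that is. A sketch, with no converter involved:

```python
import json

from pydantic import BaseModel, Field


class MainModel(BaseModel):
    test: str = Field(..., alias="name")


# by_alias controls which key appears in the serialized JSON, and that
# serialized key is what the generated grammar matched ("\"name\"").
instance = MainModel(name="kitty")
print(json.dumps(instance.model_dump(mode="json", by_alias=True)))   # {"name": "kitty"}
print(json.dumps(instance.model_dump(mode="json", by_alias=False)))  # {"test": "kitty"}
print(list(MainModel.model_json_schema(by_alias=True)["properties"]))  # ['name']
```

Second, the `mlc_wasm_runtime.cc` hunk drops the unity-build `#include`s of the in-tree `grammar/*.cc` sources (and the `tokenizers/tokenizers.cc` include that existed only for `mlc.PostProcessTokenTable`): with grammar handling migrated to the xgrammar submodule, the implementation those includes compiled into the wasm runtime no longer exists in-tree.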