From 57f980d642f3a33cde5e4f24a0c5391c93823a57 Mon Sep 17 00:00:00 2001
From: Huanghe
Date: Tue, 20 Aug 2024 20:14:49 -0500
Subject: [PATCH] Context free grammar examples & bug fixes

- Add utilities in `FormatterBuilder` to make using custom CFGs easier
- Fix a bug when JsonExtractor is used together with ChoiceExtractor
- Fix README example
---
 README.md                              | 71 +++++++++++++++++++++++++-
 pyproject.toml                         |  2 +-
 src/formatron/formatter.py             | 11 +++-
 tests/test_transformers_integration.py | 66 +++++++++++++++++++++++-
 4 files changed, 145 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index ee591ac1..b3e45555 100644
--- a/README.md
+++ b/README.md
@@ -137,7 +137,7 @@ tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128
 f = FormatterBuilder()
 schema = infer_mapping({"name":"foo", "age": 28})
-f.append_line(f"{f.str(not_contain='{')}{f.schema(schema, JsonGenerator(), capture_name='json')}")
+f.append_line(f"{f.schema(schema, JsonGenerator(), capture_name='json')}")
 logits_processor = create_formatter_logits_processor_list(tokenizer, f)
 inputs = tokenizer(["""<|system|>
 You are a helpful assistant.<|end|>
@@ -203,7 +203,76 @@ print(logits_processor[0].formatters_captures)
 # possible output:
 # [{'json': 14}]
 ```
+### CFG-Constrained generation
+Context-free grammars use [kbnf's syntax](https://docs.rs/kbnf/latest/kbnf/#kbnf-grammar), a variant of EBNF.
+Since formatron uses [kbnf](https://github.com/Dan-wanna-M/kbnf?tab=readme-ov-file#features) under the hood, all of kbnf's performance claims hold here as well.
+```python
+from formatron import extractor
+from formatron.formatter import FormatterBuilder
+from transformers import AutoModelForCausalLM
+import transformers
+import typing
+from formatron.integrations.transformers import create_formatter_logits_processor_list
+import torch
+rules = """
+expression ::= term { ("+" | "-") term };
+term ::= factor { ("*" | "/") factor };
+factor ::= number | "(" expression ")";
+number ::= #"[0-9]+(\\\\.[0-9]+)?";
+"""
+class ArithmeticExpressionExtractor(extractor.Extractor):
+    def __init__(self, nonterminal: str, capture_name: typing.Optional[str] = None):
+        super().__init__(capture_name)
+        self._nonterminal = nonterminal
+
+    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+        i = 0
+        left_bracket = 0
+        while i < len(input_str):
+            if input_str[i].isdigit() or input_str[i] in "+-*/.":
+                i += 1
+                continue
+            if input_str[i] == "(":
+                i += 1
+                left_bracket += 1
+                continue
+            if input_str[i] == ")":
+                i += 1
+                left_bracket -= 1
+                continue
+            else:
+                break
+        if left_bracket != 0:
+            return None
+        return input_str[i:], input_str[:i]
+
+    @property
+    def nonterminal(self) -> str:
+        return self._nonterminal
+
+    @property
+    def kbnf_representation(self) -> str:
+        return self._nonterminal
+
+model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
+                                             device_map="cuda",
+                                             torch_dtype=torch.float16)
+tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
+inputs = tokenizer(["""<|system|>
+    You are a helpful assistant.<|end|>
+    <|user|>Repeat it: ((32+43)*114)<|end|>
+    <|assistant|>((32+43)*114)<|end|>
+    <|user|>Repeat it: ((32+43)*(114-514))<|end|>
+    <|assistant|>"""], return_tensors="pt").to("cuda")
+f = FormatterBuilder()
+f.append_line(
+    f"{f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'json'), lambda nonterminal: rules.replace('expression', nonterminal), capture_name='json')}")
+logits_processor = create_formatter_logits_processor_list(tokenizer, f)
+print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
+                                            max_new_tokens=100, logits_processor=logits_processor)))
+print(logits_processor[0].formatters_captures)
+# possible output: [{'json': '(((32+43)*(114-514)))*1.5'}]
+```
 ### Json Schema
 You can use [pydantic's code generator](https://docs.pydantic.dev/latest/integrations/datamodel_code_generator/)
 to generate pydantic models from json schema.
diff --git a/pyproject.toml b/pyproject.toml
index cd93dc02..a78e6fc9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"]
 build-backend = "setuptools.build_meta"
 [project]
 name = "formatron"
-version = "0.2.0"
+version = "0.3.0"
 authors = [
     {name = "Xintong Sun", email = "xs28@rice.edu"},
 ]
diff --git a/src/formatron/formatter.py b/src/formatron/formatter.py
index c8ef6845..dcac05e5 100644
--- a/src/formatron/formatter.py
+++ b/src/formatron/formatter.py
@@ -6,6 +6,7 @@
 import textwrap
 import typing
 from copy import copy
+from json import JSONDecodeError
 
 import kbnf
 from kbnf import AcceptTokenResult, Engine
@@ -312,12 +313,18 @@ def schema(self, schema: typing.Type[schemas.schema.Schema],
         :param capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
         :return: The schema extractor.
         """
+        def to_json(json: str):
+            try:
+                return schema.from_json(json)
+            except JSONDecodeError:  # make ChoiceExtractor work appropriately
+                return None
+
         return self._add_extractor(capture_name,
                                    "schema",
                                    lambda nonterminal: grammar_generator.get_extractor(nonterminal, capture_name,
-                                                                                       lambda json: schema.from_json(
-                                                                                           json)),
+                                                                                       to_json),
                                    lambda nonterminal: grammar_generator.generate(schema, nonterminal))
+
     def str(self, *, stop: typing.Union[str, list[str]] = None,
             capture_name: typing.Optional[str] = None) -> RegexExtractor:
         """
diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py
index d81ca056..3c63769b 100644
--- a/tests/test_transformers_integration.py
+++ b/tests/test_transformers_integration.py
@@ -1,8 +1,11 @@
+import typing
+
 import torch
 import transformers
 from formatron.schemas.pydantic import ClassSchema
 from transformers import GPT2LMHeadModel, AutoModelForCausalLM
 
+import extractor
 from formatter import FormatterBuilder
 from grammar_generators.json_generator import JsonGenerator
 from integrations.transformers import create_formatter_logits_processor_list
@@ -52,7 +55,7 @@ def test_readme_example2(snapshot):
     f = FormatterBuilder()
     schema = infer_mapping({"name":"foo", "age": 28})
-    f.append_line(f"{f.str(not_contain='{')}{f.schema(schema, JsonGenerator(), capture_name='json')}")
+    f.append_line(f"{f.schema(schema, JsonGenerator(), capture_name='json')}")
     logits_processor = create_formatter_logits_processor_list(tokenizer, f)
     inputs = tokenizer(["""<|system|>
 You are a helpful assistant.<|end|>
@@ -116,6 +119,67 @@ def add(a: int, b: int, /, *, c: int):
     # possible output:
     # [{'json': 14}]
 
+def test_readme_example5(snapshot):
+    rules = """
+expression ::= term { ("+" | "-") term };
+term ::= factor { ("*" | "/") factor };
+factor ::= number | "(" expression ")";
+number ::= #"[0-9]+(\\\\.[0-9]+)?";
+"""
+
+    class ArithmeticExpressionExtractor(extractor.Extractor):
+        def __init__(self, nonterminal: str, capture_name: typing.Optional[str] = None):
+            super().__init__(capture_name)
+            self._nonterminal = nonterminal
+
+        def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+            i = 0
+            left_bracket = 0
+            while i < len(input_str):
+                if input_str[i].isdigit() or input_str[i] in "+-*/.":
+                    i += 1
+                    continue
+                if input_str[i] == "(":
+                    i += 1
+                    left_bracket += 1
+                    continue
+                if input_str[i] == ")":
+                    i += 1
+                    left_bracket -= 1
+                    continue
+                else:
+                    break
+            if left_bracket != 0:
+                return None
+            return input_str[i:], input_str[:i]
+
+        @property
+        def nonterminal(self) -> str:
+            return self._nonterminal
+
+        @property
+        def kbnf_representation(self) -> str:
+            return self._nonterminal
+
+    model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
+                                                 device_map="cuda",
+                                                 torch_dtype=torch.float16)
+    tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
+    inputs = tokenizer(["""<|system|>
+    You are a helpful assistant.<|end|>
+    <|user|>Repeat it: ((32+43)*114)<|end|>
+    <|assistant|>((32+43)*114)<|end|>
+    <|user|>Repeat it: ((32+43)*(114-514))<|end|>
+    <|assistant|>"""], return_tensors="pt").to("cuda")
+    f = FormatterBuilder()
+    f.append_line(
+        f"{f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'json'), lambda nonterminal: rules.replace('expression', nonterminal), capture_name='json')}")
+    logits_processor = create_formatter_logits_processor_list(tokenizer, f)
+    print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
+                                                max_new_tokens=100, logits_processor=logits_processor)))
+    print(logits_processor[0].formatters_captures)
+    # possible output: [{'json': '(((32+43)*(114-514)))*1.5'}]
+
 def test_transformers_batched_inference(snapshot):
     f = FormatterBuilder()
     f.append_line(f"Hello, Huggingface!")
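
For a smaller end-to-end illustration of the `FormatterBuilder.extractor` utility this patch adds, here is a minimal sketch. It assumes only what the patch itself shows: `f.extractor(...)` takes a factory producing the custom `Extractor` and a factory producing the kbnf rules for the assigned nonterminal. The bit-list grammar, the `BitListExtractor` class, and the `'bits'` capture name below are illustrative and are not part of the patch.

```python
# Minimal sketch (illustrative, not part of the patch): constraining output
# to a custom CFG of comma-separated binary digits via FormatterBuilder.extractor.
import typing

from formatron import extractor
from formatron.formatter import FormatterBuilder

# Hypothetical kbnf grammar: one or more bits separated by commas.
# 'start' is renamed to whatever nonterminal FormatterBuilder assigns.
rules = """
start ::= bit { "," bit };
bit ::= "0" | "1";
"""


class BitListExtractor(extractor.Extractor):
    """Illustrative extractor: captures a prefix made of bits and commas."""

    def __init__(self, nonterminal: str, capture_name: typing.Optional[str] = None):
        super().__init__(capture_name)
        self._nonterminal = nonterminal

    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
        # Consume the longest prefix the grammar could have produced.
        i = 0
        while i < len(input_str) and input_str[i] in "01,":
            i += 1
        # Same convention as ArithmeticExpressionExtractor in this patch:
        # return (remaining input, captured value).
        return input_str[i:], input_str[:i]

    @property
    def nonterminal(self) -> str:
        return self._nonterminal

    @property
    def kbnf_representation(self) -> str:
        return self._nonterminal


f = FormatterBuilder()
f.append_line(
    f"{f.extractor(lambda nt: BitListExtractor(nt, 'bits'), lambda nt: rules.replace('start', nt), capture_name='bits')}")
# From here, `f` is used exactly as in the README examples, e.g.
# logits_processor = create_formatter_logits_processor_list(tokenizer, f)
```

As in the patch's arithmetic example, the extractor only splits the generated text back out of the full output and optionally post-processes it; the kbnf rules alone decide what the model is allowed to generate.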