Skip to content

Commit

Permalink
Context free grammar examples&bug fixes
Browse files Browse the repository at this point in the history
- Add utilities in `FormatterBuilder` to make using custom CFG easier
- Fix a bug when JsonExtractor is used together with ChoiceExtractor
- Fix readme example
  • Loading branch information
Dan-wanna-M committed Aug 21, 2024
1 parent 5faff8a commit 57f980d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 5 deletions.
71 changes: 70 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128

f = FormatterBuilder()
schema = infer_mapping({"name":"foo", "age": 28})
f.append_line(f"{f.str(not_contain='{')}{f.schema(schema, JsonGenerator(), capture_name='json')}")
f.append_line(f"{f.schema(schema, JsonGenerator(), capture_name='json')}")
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
inputs = tokenizer(["""<|system|>
You are a helpful assistant.<|end|>
Expand Down Expand Up @@ -203,7 +203,76 @@ print(logits_processor[0].formatters_captures)
# possible output:
# [{'json': 14}]
```
### CFG-Constrained generation
Context-free grammars are written in [kbnf's syntax](https://docs.rs/kbnf/latest/kbnf/#kbnf-grammar), a variant of EBNF.
Since formatron uses [kbnf](https://github.com/Dan-wanna-M/kbnf?tab=readme-ov-file#features) under the hood, all of kbnf's performance claims hold.
```python
from formatron import extractor
from formatron.formatter import FormatterBuilder
from transformers import AutoModelForCausalLM
import transformers
import typing
from formatron.integrations.transformers import create_formatter_logits_processor_list
import torch
# Grammar for arithmetic expressions with +, -, *, / and parentheses.
# `expression` is the entry rule; the example below substitutes the text
# 'expression' with the nonterminal name it is handed (via rules.replace).
# NOTE(review): `number` uses the #"..." form, presumably kbnf's regex
# terminal, with the backslash escaped once for Python and once for kbnf —
# confirm against the kbnf grammar documentation.
rules = """
expression ::= term { ("+" | "-") term };
term ::= factor { ("*" | "/") factor };
factor ::= number | "(" expression ")";
number ::= #"[0-9]+(\\\\.[0-9]+)?";
"""
class ArithmeticExpressionExtractor(extractor.Extractor):
    """Extractor that peels an arithmetic expression off the front of a string."""

    def __init__(self, nonterminal: str, capture_name: typing.Optional[str] = None):
        """Store the nonterminal name and forward `capture_name` to the base class."""
        super().__init__(capture_name)
        self._nonterminal = nonterminal

    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
        """Consume the longest prefix made of digits, ``+-*/.`` and parentheses.

        :param input_str: The text to scan from its first character.
        :return: ``(remaining_text, consumed_prefix)``, or ``None`` when the
            consumed prefix has a different number of ``(`` and ``)``.
        """
        depth = 0  # count of "(" minus count of ")" seen so far
        consumed = 0  # length of the accepted prefix
        for ch in input_str:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif not (ch.isdigit() or ch in "+-*/."):
                # First character that cannot belong to an expression.
                break
            consumed += 1
        if depth:
            return None
        return input_str[consumed:], input_str[:consumed]

    @property
    def nonterminal(self) -> str:
        """The nonterminal name this extractor was constructed with."""
        return self._nonterminal

    @property
    def kbnf_representation(self) -> str:
        """The text used for this extractor in the grammar: its nonterminal name."""
        return self._nonterminal

# Load the model in float16 on the GPU, plus the tokenizer for the same checkpoint.
model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
device_map="cuda",
torch_dtype=torch.float16)
tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
# Prompt: one solved "Repeat it" exchange followed by the expression the
# model has to repeat.
# NOTE(review): the <|system|>/<|user|>/<|end|> markers may not match this
# checkpoint's own chat template — confirm.
inputs = tokenizer(["""<|system|>
You are a helpful assistant.<|end|>
<|user|>Repeat it: ((32+43)*114)<|end|>
<|assistant|>((32+43)*114)<|end|>
<|user|>Repeat it: ((32+43)*(114-514))<|end|>
<|assistant|>"""], return_tensors="pt").to("cuda")
f = FormatterBuilder()
# Register the custom extractor. The first callback builds the extractor for
# the nonterminal name it is handed; the second produces the grammar text with
# 'expression' replaced by that same nonterminal name.
f.append_line(
f"{f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'json'),lambda nonterminal: rules.replace('expression', nonterminal), capture_name='json')}")
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
# NOTE(review): top_p/temperature presumably only take effect when sampling is
# enabled, and do_sample is not passed here — confirm.
print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
max_new_tokens=100, logits_processor=logits_processor)))
# Show the captures recorded by the first logits processor in the list.
print(logits_processor[0].formatters_captures)
# possible output: [{'json': '(((32+43)*(114-514)))*1.5'}]
```
### Json Schema
You can use [pydantic's code generator](https://docs.pydantic.dev/latest/integrations/datamodel_code_generator/)
to generate pydantic models from json schema.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "formatron"
version = "0.2.0"
version = "0.3.0"
authors = [
{name = "Xintong Sun", email = "[email protected]"},
]
Expand Down
11 changes: 9 additions & 2 deletions src/formatron/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import textwrap
import typing
from copy import copy
from json import JSONDecodeError

import kbnf
from kbnf import AcceptTokenResult, Engine
Expand Down Expand Up @@ -312,12 +313,18 @@ def schema(self, schema: typing.Type[schemas.schema.Schema],
:param capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
:return: The schema extractor.
"""
def to_json(json: str):
    """Parse `json` with ``schema.from_json``; return ``None`` if it is not valid JSON."""
    try:
        parsed = schema.from_json(json)
    except JSONDecodeError:
        # Report malformed JSON as "no result" instead of raising, so that
        # ChoiceExtractor works appropriately.
        return None
    else:
        return parsed

return self._add_extractor(capture_name, "schema",
lambda nonterminal: grammar_generator.get_extractor(nonterminal, capture_name,
lambda json: schema.from_json(
json)),
to_json),
lambda nonterminal: grammar_generator.generate(schema, nonterminal))


def str(self, *, stop: typing.Union[str, list[str]] = None,
capture_name: typing.Optional[str] = None) -> RegexExtractor:
"""
Expand Down
66 changes: 65 additions & 1 deletion tests/test_transformers_integration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import typing

import torch
import transformers
from formatron.schemas.pydantic import ClassSchema
from transformers import GPT2LMHeadModel, AutoModelForCausalLM

import extractor
from formatter import FormatterBuilder
from grammar_generators.json_generator import JsonGenerator
from integrations.transformers import create_formatter_logits_processor_list
Expand Down Expand Up @@ -52,7 +55,7 @@ def test_readme_example2(snapshot):

f = FormatterBuilder()
schema = infer_mapping({"name":"foo", "age": 28})
f.append_line(f"{f.str(not_contain='{')}{f.schema(schema, JsonGenerator(), capture_name='json')}")
f.append_line(f"{f.schema(schema, JsonGenerator(), capture_name='json')}")
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
inputs = tokenizer(["""<|system|>
You are a helpful assistant.<|end|>
Expand Down Expand Up @@ -116,6 +119,67 @@ def add(a: int, b: int, /, *, c: int):
# possible output:
# [{'json': 14}]

def test_readme_example5(snapshot):
    """Run the README's CFG-constrained generation example and print its results."""
    rules = """
expression ::= term { ("+" | "-") term };
term ::= factor { ("*" | "/") factor };
factor ::= number | "(" expression ")";
number ::= #"[0-9]+(\\\\.[0-9]+)?";
"""

    class ArithmeticExpressionExtractor(extractor.Extractor):
        """Extractor that peels an arithmetic expression off the front of a string."""

        def __init__(self, nonterminal: str, capture_name: typing.Optional[str] = None):
            """Store the nonterminal name and forward `capture_name` to the base class."""
            super().__init__(capture_name)
            self._nonterminal = nonterminal

        def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
            """Consume the longest prefix made of digits, ``+-*/.`` and parentheses.

            :param input_str: The text to scan from its first character.
            :return: ``(remaining_text, consumed_prefix)``, or ``None`` when the
                consumed prefix has a different number of ``(`` and ``)``.
            """
            depth = 0  # count of "(" minus count of ")" seen so far
            consumed = 0  # length of the accepted prefix
            for ch in input_str:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                elif not (ch.isdigit() or ch in "+-*/."):
                    # First character that cannot belong to an expression.
                    break
                consumed += 1
            if depth:
                return None
            return input_str[consumed:], input_str[:consumed]

        @property
        def nonterminal(self) -> str:
            """The nonterminal name this extractor was constructed with."""
            return self._nonterminal

        @property
        def kbnf_representation(self) -> str:
            """The text used for this extractor in the grammar: its nonterminal name."""
            return self._nonterminal

    # Model and tokenizer come from the same checkpoint.
    checkpoint = "NurtureAI/Meta-Llama-3-8B-Instruct-32k"
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="cuda",
        torch_dtype=torch.float16,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    # One solved "Repeat it" exchange followed by the expression to repeat.
    prompt = """<|system|>
You are a helpful assistant.<|end|>
<|user|>Repeat it: ((32+43)*114)<|end|>
<|assistant|>((32+43)*114)<|end|>
<|user|>Repeat it: ((32+43)*(114-514))<|end|>
<|assistant|>"""
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    f = FormatterBuilder()
    # First callback builds the extractor for the nonterminal it is handed;
    # second callback renames 'expression' in the grammar to that nonterminal.
    expression = f.extractor(
        lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'json'),
        lambda nonterminal: rules.replace('expression', nonterminal),
        capture_name='json',
    )
    f.append_line(f"{expression}")
    logits_processor = create_formatter_logits_processor_list(tokenizer, f)

    generated = model.generate(
        **inputs,
        top_p=0.5,
        temperature=1,
        max_new_tokens=100,
        logits_processor=logits_processor,
    )
    print(tokenizer.batch_decode(generated))
    print(logits_processor[0].formatters_captures)
    # possible output: [{'json': '(((32+43)*(114-514)))*1.5'}]

def test_transformers_batched_inference(snapshot):
f = FormatterBuilder()
f.append_line(f"Hello, Huggingface!")
Expand Down

0 comments on commit 57f980d

Please sign in to comment.