Ported back new grammar changes from C++ to Python implementation #1637

Merged
10 commits merged on Aug 7, 2024

@ExtReMLapin ExtReMLapin commented Jul 29, 2024

@ExtReMLapin
Contributor Author

Not working yet. For example:

Works: root ::= ("EYEYAHA"){5}

Fails: root ::= ("EYEYAHA"){1,5}

from_string grammar:
root ::= root_1 root_5
root_1 ::= [E] [Y] [E] [Y] [A] [H] [A]
root_2 ::= root_1 | print_grammar: error printing grammar: unexpected end of rule: 2,2

@ExtReMLapin
Contributor Author

I've been looking again and again, and I don't see what I missed from the PR diff 😕

@ExtReMLapin
Contributor Author

ExtReMLapin commented Jul 30, 2024

Any help is welcome, @abetlen.

Right now, root ::= "A"{1,6} generates :

root ::= [A] root_5
root_1 ::= [A] root_4 |
root_2 ::= [A] root_4 |
root_3 ::= [A] root_4 |
root_4 ::= [A] root_4 |
root_5 ::= [A] root_4 |
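
For comparison, the expansion one would expect for a bounded repetition can be sketched in a few lines. This is only a minimal sketch, assuming the rewrite S{m,n} -> m copies of S followed by a chain of n-m optional helper rules; the function name expand_repetition and the rule-naming scheme are hypothetical and are not taken from llama_grammar.py.

```python
from typing import List


def expand_repetition(rule: str, item: str, min_times: int, max_times: int) -> List[str]:
    """Expand `rule ::= item{min_times,max_times}` into brace-free GBNF rules."""
    if not 0 <= min_times <= max_times:
        raise ValueError("need 0 <= min_times <= max_times")
    optional = max_times - min_times
    # The mandatory part: min_times copies of the item.
    body = [item] * min_times
    # The optional part starts at the helper rule allowing `optional` more items.
    if optional > 0:
        body.append(f"{rule}_{optional}")
    rules = [f"{rule} ::= {' '.join(body)}".rstrip()]
    # Helper x matches one item followed by helper x-1, or nothing at all.
    for x in range(optional, 0, -1):
        tail = f" {rule}_{x - 1}" if x > 1 else ""
        rules.append(f"{rule}_{x} ::= {item}{tail} |")
    return rules


for line in expand_repetition("root", "[A]", 1, 6):
    print(line)
```

In this sketch each helper rule refers to the next lower one (root_5 to root_4, root_4 to root_3, and so on down to root_1, which has no tail), whereas in the output above every helper refers to root_4.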

@abetlen
Owner

abetlen commented Aug 1, 2024

Hey @ExtReMLapin thanks for starting on this fix, just getting back to everything now after vacation. I'll take a stab at this over the next couple days as well.

@abetlen abetlen marked this pull request as ready for review August 4, 2024 21:16
@abetlen
Owner

abetlen commented Aug 4, 2024

@ExtReMLapin, I got the new grammar features back-ported and ended up rewriting most of llama_grammar.py. Together with #1649, this should bring the grammar implementation in line with llama.cpp.

@ExtReMLapin
Contributor Author

ExtReMLapin commented Aug 4, 2024

Thank you abetlen.
While checking the code, I was a little surprised by the chained else-ifs and by the lookup list being hardcoded inside the function rather than at module level, so it is rebuilt on every call (e.g. in decode_utf8).

As a proud lazy man, I asked GPT-4 to:

  1. Try to write an optimized version of parse_hex, decode_utf8 and parse_char (I actually expected it to build a jump table)
  2. Write a benchmark and tests for it

The issue is that it seems not all UTF-8 characters are supported (see the test code below). Is that really a problem?

import timeit
import typing

# Original Functions
def original_parse_hex(src: str, size: int) -> typing.Tuple[int, str]:
    pos = 0
    value = 0
    for _ in range(size):
        value <<= 4
        c = src[pos]
        if "a" <= c <= "f":
            value += ord(c) - ord("a") + 10
        elif "A" <= c <= "F":
            value += ord(c) - ord("A") + 10
        elif "0" <= c <= "9":
            value += ord(c) - ord("0")
        else:
            break
        pos += 1
    if pos != size:
        raise ValueError(f"expecting {size} hex chars at {src}")
    return value, src[pos:]


def original_decode_utf8(src: str) -> typing.Tuple[int, str]:
    lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
    first_byte: int = ord(src[0])
    highbits: int = first_byte >> 4
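    # src is a str, so first_byte is a code point rather than a UTF-8 byte;
    # for code points above 0xFF, highbits exceeds 15 and the lookup below raises IndexError.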
    length: int = lookup[highbits]
    mask: int = (1 << (8 - length)) - 1
    value: int = first_byte & mask
    end: int = min(len(src), length)

    pos: int = 1
    for pos in range(1, end):
        if not src[pos]:
            break
        value = (value << 6) + (ord(src[pos]) & 0x3F)

    return value, src[pos:] if pos < len(src) else ""


def original_parse_char(src: str) -> typing.Tuple[int, str]:
    if src[0] == "\\":
        if src[1] == "x":
            return original_parse_hex(src[2:], 2)
        elif src[1] == "u":
            return original_parse_hex(src[2:], 4)
        elif src[1] == "U":
            return original_parse_hex(src[2:], 8)
        elif src[1] == "t":
            return ord("\t"), src[2:]
        elif src[1] == "r":
            return ord("\r"), src[2:]
        elif src[1] == "n":
            return ord("\n"), src[2:]
        elif src[1] in ('\\', '"', '[', ']'):
            return ord(src[1]), src[2:]
        else:
            raise ValueError(f"unknown escape at {src}")
    elif src:
        return original_decode_utf8(src)
    raise ValueError("unexpected end of input")


# Optimized Functions
# Hex digit -> integer value lookup, built once at import time.
hex_map = {
    **{f"{x}": x for x in range(10)},
    **{chr(x): x - ord('a') + 10 for x in range(ord('a'), ord('f') + 1)},
    **{chr(x): x - ord('A') + 10 for x in range(ord('A'), ord('F') + 1)},
}


def optimized_parse_hex(src: str, size: int) -> typing.Tuple[int, str]:
    value = 0
    for i in range(size):
        c = src[i]
        if c in hex_map:
            value = (value << 4) + hex_map[c]
        else:
            raise ValueError(f"expecting {size} hex chars at {src}")
    return value, src[size:]


prealloc = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]

def optimized_decode_utf8(src: str) -> typing.Tuple[int, str]:
    first_byte = ord(src[0])
    highbits = first_byte >> 4
    length = prealloc[highbits]
    value = first_byte & ((1 << (8 - length)) - 1)

    for i in range(1, length):
        value = (value << 6) + (ord(src[i]) & 0x3F)
    
    return value, src[length:]


escape_sequences = {
    "x": 2, "u": 4, "U": 8,
    "t": ord("\t"), "r": ord("\r"), "n": ord("\n"),
    "\\": ord("\\"), '"': ord('"'), '[': ord('['), ']': ord(']')
}

def optimized_parse_char(src: str) -> typing.Tuple[int, str]:
    if src[0] == "\\":

        esc = src[1]
        if esc in escape_sequences:
            if esc in 'xuU':
                return optimized_parse_hex(src[2:], escape_sequences[esc])
            return escape_sequences[esc], src[2:]
        raise ValueError(f"unknown escape at {src}")
    elif src:
        return optimized_decode_utf8(src)
    raise ValueError("unexpected end of input")

import random 
def generate_utf8_string(length: int) -> str:
    utf8_chars = [
        chr(random.randint(0x20, 0x7E)),    # ASCII characters
        chr(random.randint(0x80, 0x07FF)),  # Extended Latin and similar
        chr(random.randint(0x0800, 0xFFFF)),  # Multilingual Plane
        chr(random.randint(0x10000, 0x10FFFF)) # Supplementary Planes (Emoji, etc.)
    ]
    return ''.join(random.choice(utf8_chars) for _ in range(length))

def benchmark():
    # Generate a random UTF-8 string of 500 characters
    test_string = generate_utf8_string(500)
    print('Random string : ', test_string)
    
    # Ensure both functions return the same result
    original_result = original_parse_char(test_string)
    optimized_result = optimized_parse_char(test_string)
    
    assert original_result == optimized_result, "The results of original and optimized functions do not match!"
    
    original_time = timeit.timeit(lambda: original_parse_char(test_string), number=100000)
    optimized_time = timeit.timeit(lambda: optimized_parse_char(test_string), number=100000)
    
    print(f"Original parse_char time: {original_time:.6f} seconds")
    print(f"Optimized parse_char time: {optimized_time:.6f} seconds")

if __name__ == "__main__":
    benchmark()

It can easily be fixed by automatically capping the length to 4 when the index falls beyond the end of the byte-length array.
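
The lookup failure can be reproduced in isolation. The following is a minimal sketch, assuming the table is meant to be indexed by the high nibble of a UTF-8 byte (values 0 to 15): a Python str yields code points rather than bytes, so the index can exceed 15. The helper decode_utf8_bytes is hypothetical, not part of llama_grammar.py, and simply applies the same table to encoded bytes instead of a str.

```python
import typing

# Byte-length table indexed by the high nibble of the first UTF-8 byte
# (same values as the table in the snippet above).
UTF8_LENGTHS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]

# On a str, ord() returns the code point, so the index can leave the table.
euro_index = ord("€") >> 4  # U+20AC >> 4 == 522, far beyond index 15


def decode_utf8_bytes(src: bytes) -> typing.Tuple[int, bytes]:
    """Decode one code point from UTF-8 bytes (hypothetical helper)."""
    first_byte = src[0]
    length = UTF8_LENGTHS[first_byte >> 4]  # a byte's high nibble is always 0..15
    value = first_byte & ((1 << (8 - length)) - 1)
    # Fold in the 6 payload bits of each continuation byte.
    for b in src[1:length]:
        value = (value << 6) + (b & 0x3F)
    return value, src[length:]


code_point, rest = decode_utf8_bytes("€a".encode("utf-8"))
print(hex(code_point), rest)
```

Like the snippets above, this sketch does not validate its input; it only shows that the table lookup stays in range once the input is bytes.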

Benchmark results anyway

Original parse_char time: 0.075589 seconds
Optimized parse_char time: 0.051987 seconds

@ExtReMLapin
Contributor Author

ExtReMLapin commented Aug 5, 2024

Alright, I gave it a try at the office: rule parsing is broken, and a few functions are missing (e.g. from_file).

Test code :

from llama_cpp import LlamaGrammar, Llama

gbnf_str = r"""# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root   ::= arr
value  ::= object | array | string | number | ("true" | "false" | "null") ws

arr  ::=
  "[\n" ws (
            value
    (",\n" ws value)*
  )? "]"

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\bfnrt] | "u" [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9])) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9])? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= | " " | "\n" [ \t]
"""
gguf = "/opt/IdExtend/models/llm/mistral-7b-instruct-v0.2.Q5_K_M.gguf"



grammar = LlamaGrammar.from_string(gbnf_str, verbose=False)
model = Llama(gguf, n_ctx=8192, n_gpu_layers=-1, tensor_split=[1,0,0], verbose=False)


stream = model.create_completion("In a json format give me a list of known stars :", grammar=grammar, stream=True, max_tokens=1024)
for output in stream:
    print(output['choices'][0]['text'], end="")

@abetlen
Owner

abetlen commented Aug 7, 2024

@ExtReMLapin, I just fixed the last bug: out_elements was being re-assigned by mistake inside parse_sequence.
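
The exact change is in the merged commit. As a general illustration of this class of bug in Python, the sketch below contrasts re-assigning a list parameter with mutating it in place. The function names are hypothetical and this is not the code from parse_sequence.

```python
from typing import List


def collect_rebinding(out_elements: List[int], new_items: List[int]) -> None:
    # Re-assignment binds the local name to a brand-new list;
    # the list object the caller passed in is left untouched.
    out_elements = out_elements + new_items


def collect_in_place(out_elements: List[int], new_items: List[int]) -> None:
    # extend() mutates the very list object the caller passed in.
    out_elements.extend(new_items)


caller_list: List[int] = []
collect_rebinding(caller_list, [1, 2])
after_rebinding = list(caller_list)  # snapshot: still empty

collect_in_place(caller_list, [1, 2])
after_in_place = list(caller_list)  # snapshot: now holds the new items

print(after_rebinding, after_in_place)
```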

Do you mind opening another PR for those changes? For now I just wanted to keep the implementation as close to the C++ as possible, but obviously there's room to optimize (it may be better to do some other kind of caching here, though).

@abetlen abetlen merged commit dff186c into abetlen:main Aug 7, 2024
13 checks passed
@ExtReMLapin
Contributor Author

Thanks for the fix, will do !

@ExtReMLapin ExtReMLapin deleted the patch-1 branch August 7, 2024 04:23
Development

Successfully merging this pull request may close these issues.

Grammars bracket repetition symbol not working
2 participants