Skip to content

Commit

Permalink
Remove unused functions in outlines_core.fsm.regex
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 29, 2024
1 parent bc936ae commit c6cf2cf
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 411 deletions.
208 changes: 0 additions & 208 deletions python/outlines_core/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TYPE_CHECKING,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Expand All @@ -18,7 +17,6 @@
from interegular.fsm import (
FSM,
Alphabet,
OblivionError,
State,
TransitionKey,
_AnythingElseCls,
Expand Down Expand Up @@ -270,17 +268,6 @@ def create_seq_transitions(
)


def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM:
    """Run `make_byte_level_fsm` on `fsm` and re-wrap the result as a `BetterFSM`.

    Parameters
    ----------
    fsm
        The FSM handed to `make_byte_level_fsm`.
    keep_utf8
        Forwarded unchanged to `make_byte_level_fsm`.

    Returns
    -------
    A `BetterFSM` sharing the states, initial state, finals and transition map
    of the FSM produced by `make_byte_level_fsm`, with its alphabet rebuilt as
    a `BetterAlphabet`.
    """
    byte_fsm = make_byte_level_fsm(fsm, keep_utf8)

    # Rebuild the alphabet from the raw symbol mapping of the converted FSM
    byte_alphabet = BetterAlphabet(byte_fsm.alphabet._symbol_mapping)

    return BetterFSM(
        alphabet=byte_alphabet,
        states=byte_fsm.states,
        initial=byte_fsm.initial,
        finals=byte_fsm.finals,
        map=byte_fsm.map,
    )


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
Expand Down Expand Up @@ -355,201 +342,6 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
return new_fsm, old_to_new_states


def walk_fsm(
    fsm: BetterFSM,
    token_transition_keys: Sequence[int],
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Walk `fsm` along a sequence of transition keys and collect the visited states.

    Parameters
    ----------
    fsm
        The FSM to traverse; its `finals` and `flat_transition_map` are read.
    token_transition_keys
        The transition keys to follow, in order.
    start_state
        The state from which the walk begins.
    full_match
        When ``True``, every key must be consumed and the walk must end in a
        final state.  When ``False``, the longest prefix of the walk that ends
        in a final state is accepted if the walk gets stuck.

    Returns
    -------
    The states reached after each consumed transition key, or an empty list
    when the walk is rejected.  An empty `token_transition_keys` always yields
    an empty list.
    """
    fsm_finals = fsm.finals
    fsm_transitions = fsm.flat_transition_map

    state = start_state
    accepted_states: List[int] = []
    # Number of transitions consumed when a final state was last reached
    last_final_idx: int = 0

    # Iterate over token transition key sequence.  The transition key
    # sequence represents the FSM traversal rules of the tokens symbols.
    for i, trans_key in enumerate(token_transition_keys):
        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            # No transition for this key: fall back to the last prefix that
            # ended in a final state, if partial matches are allowed
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return []

        state = new_state

        if state in fsm_finals:
            last_final_idx = i + 1

        accepted_states.append(state)

    # A full match requires the last visited state to be final.  Comparing
    # against the number of consumed transitions (instead of the loop index)
    # keeps this check well-defined when no transition keys were given.
    if full_match and last_final_idx != len(accepted_states):
        return []

    return accepted_states


def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Build the union of `fsms` and record how each component maps into it.

    This is an updated version of `interegular.fsm.FSM.union` that also
    returns, for every component FSM, the aggregate-state transitions it
    takes part in, the aggregate states in which it is final, and a map from
    its own states to the aggregate states that contain them.

    Parameters
    ----------
    fsms
        The component FSMs to union.

    Returns
    -------
    A tuple holding the aggregate FSM (after `make_deterministic_fsm`) and a
    dict keyed by component index, in ascending order, whose values are
    ``(transitions, finals, old_to_new)`` expressed in the state labels
    produced by `make_deterministic_fsm`.
    """
    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    def follow(current_state, new_transition: int):
        """Return the aggregate state reached from `current_state` via `new_transition`."""
        successors = {}
        for idx, component in indexed_fsms:
            old_transition = new_to_old[idx][new_transition]
            if idx not in current_state:
                continue
            component_state = current_state[idx]
            if component_state not in component.map:
                continue
            outgoing = component.map[component_state]
            if old_transition in outgoing:
                successors[idx] = outgoing[old_transition]
        if not successors:
            raise OblivionError
        return successors

    # Every aggregate state is a dict: component index -> component state.
    # The list starts with the aggregate initial state and grows as new
    # aggregate states are discovered.
    states = [{idx: component.initial for idx, component in indexed_fsms}]
    finals: Set[int] = set()
    transition_map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    src_idx = 0
    while src_idx < len(states):
        state = states[src_idx]

        # The aggregate state is final as soon as one component is final
        for idx, component in indexed_fsms:
            if state.get(idx, -1) in component.finals:
                finals.add(src_idx)
                break

        # Compute the outgoing transitions of this aggregate state
        outgoing_map: Dict[int, int] = {}
        transition_map[src_idx] = outgoing_map
        for transition in alphabet.by_transition:
            try:
                successor = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue

            try:
                # TODO: Seems like this could--and should--be avoided
                dst_idx = states.index(successor)
            except ValueError:
                dst_idx = len(states)
                states.append(successor)

            outgoing_map[transition] = dst_idx

            for fsm_id, fsm_state in successor.items():
                (
                    fsm_transitions,
                    fsm_finals,
                    fsm_old_to_new,
                ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                fsm_old_to_new.setdefault(state[fsm_id], set()).add(src_idx)
                fsm_old_to_new.setdefault(fsm_state, set()).add(dst_idx)
                fsm_transitions.add((src_idx, dst_idx))
                if fsm_state in fsms[fsm_id].finals:
                    fsm_finals.add(dst_idx)

        src_idx += 1

    union_fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=transition_map,
        __no_validation__=True,
    )

    union_fsm, old_to_new_states = make_deterministic_fsm(union_fsm)

    # Translate the per-component bookkeeping into the relabelled states
    relabelled: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}
    for fsm_id in sorted(fsms_to_trans_finals):
        transitions, component_finals, old_to_new = fsms_to_trans_finals[fsm_id]
        relabelled[fsm_id] = (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in component_finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )

    return union_fsm, relabelled


def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Yield the sub-FSMs that could have produced the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state
        transitions and sets of the final/accept states (the third tuple
        element is not used here).

    Returns
    -------
    A generator of ``(fsm_idx, can_continue, is_final)`` tuples, in the
    iteration order of `fsms_to_trans_finals`, for every sub-FSM whose
    transitions contain all consecutive state pairs of `state_seq`.
    ``can_continue`` tells whether the sub-FSM has a transition leaving the
    last state of the sequence; ``is_final`` tells whether that last state is
    a final state of the sub-FSM.
    """
    state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last_fsm_state = state_seq[-1]

    for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items():
        # Skip sub-FSMs that cannot account for every observed transition
        if not state_seq_transitions.issubset(transitions):
            continue

        # Is there another possible transition in this sub-FSM?
        can_continue = any(from_s == last_fsm_state for (from_s, _to_s) in transitions)

        # Is this sub-FSM in a final state?
        is_final = state_seq[-1] in finals

        yield fsm_idx, can_continue, is_final


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")

# The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*"
Expand Down
Loading

0 comments on commit c6cf2cf

Please sign in to comment.