Skip to content

Commit

Permalink
Allow using non-terminals as right-hand side of simple rules
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaMuravjov committed Mar 2, 2024
1 parent 4b7110e commit b04d4aa
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 112 deletions.
59 changes: 2 additions & 57 deletions src/algo_setting/preprocessor_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from src.algo_setting.algo_setting import AlgoSetting
from src.grammar.cnf_grammar_template import CnfGrammarTemplate, Symbol
from src.graph.label_decomposed_graph import LabelDecomposedGraph
from src.problems.Base.template_cfg.utils import explode_indices


class PreProcessorSetting(AlgoSetting, ABC):
Expand Down Expand Up @@ -67,60 +68,4 @@ def preprocess(
if not self.is_enabled:
return graph, grammar

block_matrix_space = graph.block_matrix_space
block_count = block_matrix_space.block_count

matrices = dict()
for symbol, matrix in graph.matrices.items():
if block_matrix_space.is_single_cell(matrix.shape):
matrices[symbol] = matrix
else:
for i, block in enumerate(block_matrix_space.get_hyper_vector_blocks(matrix)):
matrices[_index_symbol(symbol, i)] = block

epsilon_rules = []
for non_terminal in grammar.epsilon_rules:
if non_terminal.is_indexed:
for i in range(block_count):
epsilon_rules.append(_index_symbol(non_terminal, i))
else:
epsilon_rules.append(non_terminal)

simple_rules = []
for (non_terminal, terminal) in grammar.simple_rules:
if non_terminal.is_indexed or terminal.is_indexed:
for i in range(block_count):
simple_rules.append((_index_symbol(non_terminal, i), _index_symbol(terminal, i)))
else:
simple_rules.append((non_terminal, terminal))

complex_rules = []
for (non_terminal, symbol1, symbol2) in grammar.complex_rules:
if non_terminal.is_indexed or symbol1.is_indexed or symbol2.is_indexed:
for i in range(block_count):
complex_rules.append((
_index_symbol(non_terminal, i),
_index_symbol(symbol1, i),
_index_symbol(symbol2, i),
))
else:
complex_rules.append((non_terminal, symbol1, symbol2))

return (
LabelDecomposedGraph(
vertex_count=graph.vertex_count,
block_matrix_space=block_matrix_space,
dtype=graph.dtype,
matrices=matrices
),
CnfGrammarTemplate(
start_nonterm=grammar.start_nonterm,
epsilon_rules=epsilon_rules,
simple_rules=simple_rules,
complex_rules=complex_rules
)
)


def _index_symbol(symbol: Symbol, index: int) -> Symbol:
return Symbol(f"{symbol.label}_{index}") if symbol.is_indexed else symbol
return explode_indices(graph, grammar)
14 changes: 5 additions & 9 deletions src/grammar/cnf_grammar_template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Set


class Symbol:
Expand All @@ -11,7 +12,7 @@ def __repr__(self):
return self.label

def __eq__(self, other):
return self.label == other.label
return isinstance(other, Symbol) and self.label == other.label

def __hash__(self) -> int:
return self.label.__hash__()
Expand All @@ -30,11 +31,6 @@ def __init__(
self.simple_rules = simple_rules
self.complex_rules = complex_rules

for (non_terminal, terminal) in simple_rules:
if terminal in self.non_terminals:
raise ValueError(f"Invalid rule '{non_terminal} {terminal}'. "
f"Right hand side of a simple rule should be a terminal symbol.")

@property
def non_terminals(self):
return set.union(
Expand All @@ -51,7 +47,7 @@ def read_from_pocr_cnf_file(path: Union[Path, str]) -> "CnfGrammarTemplate":
The file format is expected to be as follows:
- Each non-empty line represents a rule, except the last two lines.
- Complex rules are in the format: `<NON_TERMINAL> <SYMBOL_1> <SYMBOL_2>`
- Simple rules are in the format: `<NON_TERMINAL> <TERMINAL>`
- Simple rules are in the format: `<NON_TERMINAL> <SYMBOL_1>`
- Epsilon rules are in the format: `<NON_TERMINAL>`
- Indexed symbols names must end with suffix `_i`.
- Whitespace characters are used to separate values on one line
Expand Down Expand Up @@ -93,7 +89,7 @@ def read_from_pocr_cnf_file(path: Union[Path, str]) -> "CnfGrammarTemplate":
raise ValueError(
f"Invalid rule format: `{line}` in file `{path}`. "
f"Expected formats are `<NON_TERMINAL> <SYMBOL_1> <SYMBOL_2>` for complex rules, "
f"`<NON_TERMINAL> <TERMINAL>` for simple rules, and `<NON_TERMINAL>` for epsilon rules."
f"`<NON_TERMINAL> <SYMBOL_1>` for simple rules, and `<NON_TERMINAL>` for epsilon rules."
)

return CnfGrammarTemplate(start_nonterm, epsilon_rules, simple_rules, complex_rules)
7 changes: 7 additions & 0 deletions src/graph/label_decomposed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def read_from_pocr_graph_file(path: Union[Path, str]) -> "LabelDecomposedGraph":
def __sizeof__(self) -> int:
return sum(m.__sizeof__() for m in self.matrices.values())

def __getitem__(self, symbol: Symbol):
return (
self.matrices[symbol]
if symbol in self.matrices
else self.block_matrix_space.create_space_element(self.dtype, is_vector=symbol.is_indexed)
)


class OptimizedLabelDecomposedGraph:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/matrix/optimized_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Optional, Tuple

import graphblas
from graphblas import Matrix
from graphblas.core.dtypes import DataType
from graphblas.core.matrix import Matrix
from graphblas.core.operator import Semiring, Monoid

from src.utils.subtractable_semiring import SubOp
Expand Down
12 changes: 12 additions & 0 deletions src/matrix/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any

import graphblas
from graphblas.core.dtypes import DataType
from graphblas.core.matrix import Matrix
from graphblas.core.vector import Vector


def complimentary_mask(matrix: Matrix, mask: Matrix) -> Matrix:
Expand All @@ -9,3 +13,11 @@ def complimentary_mask(matrix: Matrix, mask: Matrix) -> Matrix:
res.ss.config["format"] = matrix.ss.config["format"]
res(~mask.S) << zero.ewise_add(matrix, op=graphblas.monoid.any)
return res


def identity_matrix(one: Any, dtype: DataType, size: int) -> Matrix:
return Vector.from_scalar(
value=one,
size=size,
dtype=dtype
).diag()
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

from graphblas.core.matrix import Matrix
from graphblas.core.operator import Semiring, Monoid
from graphblas.core.vector import Vector
from graphblas.semiring import any_pair

from src.algo_setting.algo_setting import AlgoSetting
from src.grammar.cnf_grammar_template import CnfGrammarTemplate
from src.graph.label_decomposed_graph import OptimizedLabelDecomposedGraph, LabelDecomposedGraph
from src.matrix.matrix_optimizer_setting import get_matrix_optimizer_settings
from src.matrix.utils import complimentary_mask
from src.matrix.utils import complimentary_mask, identity_matrix
from src.problems.Base.template_cfg.template_cfg_all_pairs_reachability import AllPairsCflReachabilityAlgoInstance
from src.utils.subtractable_semiring import SubtractableSemiring

Expand Down Expand Up @@ -43,7 +42,7 @@ def monoid(self) -> Monoid:

def solve(self) -> Matrix:
self.add_epsilon_edges()
self.add_edges_for_simple_rules()
self.add_edges_for_simple_terminal_rules()
self.compute_transitive_closure()
return self.graph[self.grammar.start_nonterm]

Expand All @@ -54,14 +53,14 @@ def compute_transitive_closure(self):
def add_epsilon_edges(self):
if len(self.grammar.epsilon_rules) == 0:
return
identity_matrix = Vector.from_scalar(
self.algebraic_structure.one,
id_matrix = identity_matrix(
one=self.algebraic_structure.one,
size=self.graph.vertex_count,
dtype=self.graph.dtype
).diag()
)
for non_terminal in self.grammar.epsilon_rules:
self.graph.iadd_by_symbol(non_terminal, identity_matrix, op=self.monoid)
self.graph.iadd_by_symbol(non_terminal, id_matrix, op=self.monoid)

def add_edges_for_simple_rules(self):
for (non_terminal, terminal) in self.grammar.simple_rules:
self.graph.iadd_by_symbol(non_terminal, self.graph[terminal], op=self.monoid)
def add_edges_for_simple_terminal_rules(self):
for (lhs, rhs) in self.grammar.simple_rules:
self.graph.iadd_by_symbol(lhs, self.graph[rhs], op=self.monoid)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def compute_transitive_closure(self):
accum=new_front,
op=self.semiring
)
for (lhs, rhs) in self.grammar.simple_rules:
if rhs in self.grammar.non_terminals:
new_front.iadd_by_symbol(lhs, front[rhs], op=self.monoid)
front = new_front.to_unoptimized()
front = self.graph.rsub(front, op=self.algebraic_structure.sub_op)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __init__(self, *args, **kwargs):
def compute_transitive_closure(self) -> OptimizedLabelDecomposedGraph:
old_nvals = self.graph.nvals
while True:
for (lhs, rhs) in self.grammar.simple_rules:
if rhs in self.grammar.non_terminals:
self.graph.iadd_by_symbol(lhs, self.graph[rhs], op=self.monoid)
self.graph.mxm(
self.graph.to_unoptimized(),
self.grammar,
Expand Down
65 changes: 65 additions & 0 deletions src/problems/Base/template_cfg/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from src.grammar.cnf_grammar_template import CnfGrammarTemplate, Symbol
from src.graph.label_decomposed_graph import LabelDecomposedGraph


def explode_indices(
graph: LabelDecomposedGraph,
grammar: CnfGrammarTemplate
) -> (LabelDecomposedGraph, CnfGrammarTemplate):
block_matrix_space = graph.block_matrix_space
block_count = block_matrix_space.block_count

matrices = dict()
for symbol, matrix in graph.matrices.items():
if block_matrix_space.is_single_cell(matrix.shape):
matrices[symbol] = matrix
else:
for i, block in enumerate(block_matrix_space.get_hyper_vector_blocks(matrix)):
matrices[_index_symbol(symbol, i)] = block

epsilon_rules = []
for non_terminal in grammar.epsilon_rules:
if non_terminal.is_indexed:
for i in range(block_count):
epsilon_rules.append(_index_symbol(non_terminal, i))
else:
epsilon_rules.append(non_terminal)

simple_rules = []
for (non_terminal, terminal) in grammar.simple_rules:
if non_terminal.is_indexed or terminal.is_indexed:
for i in range(block_count):
simple_rules.append((_index_symbol(non_terminal, i), _index_symbol(terminal, i)))
else:
simple_rules.append((non_terminal, terminal))

complex_rules = []
for (non_terminal, symbol1, symbol2) in grammar.complex_rules:
if non_terminal.is_indexed or symbol1.is_indexed or symbol2.is_indexed:
for i in range(block_count):
complex_rules.append((
_index_symbol(non_terminal, i),
_index_symbol(symbol1, i),
_index_symbol(symbol2, i),
))
else:
complex_rules.append((non_terminal, symbol1, symbol2))

return (
LabelDecomposedGraph(
vertex_count=graph.vertex_count,
block_matrix_space=block_matrix_space,
dtype=graph.dtype,
matrices=matrices
),
CnfGrammarTemplate(
start_nonterm=grammar.start_nonterm,
epsilon_rules=epsilon_rules,
simple_rules=simple_rules,
complex_rules=complex_rules
)
)


def _index_symbol(symbol: Symbol, index: int) -> Symbol:
return Symbol(f"{symbol.label}_{index}") if symbol.is_indexed else symbol
47 changes: 12 additions & 35 deletions test/pocr_data/c_alias/c_alias.cnf
Original file line number Diff line number Diff line change
@@ -1,37 +1,14 @@
S H1 H0
S H2 H0
V H3 V3
V V2 V3
V V1 V3
V V1 V2
V H4 V3
V H5 V3
V H5 V2
V a
V H1 H0
V H2 H0
V H6 V1
V V2 H7
V H7 V1
V a_r
V1 H6 V1
V1 V2 H7
V1 H7 V1
V1 a_r
V2 H1 H0
V2 H2 H0
V3 H4 V3
V3 H5 V3
V3 H5 V2
V3 a
H0 d
H1 H2 V
H2 d_r
H3 V1 V2
H4 H5 V2
H5 a
H6 V2 H7
H7 a_r
S d_r V_d
V1
V3 a V2_V3
V2_V3 V2 V3
V2 d_r V_d
a_r_V1 a_r V1
V V1 V2_V3
V3
V2
V_d V d
V1 V2 a_r_V1

Count:
S
S

0 comments on commit b04d4aa

Please sign in to comment.