Skip to content

Commit

Permalink
add pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Jan 8, 2025
1 parent 7f106fd commit b02e29b
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions plai/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,66 @@


class Pass:
def __init__(self, name, fn):
def __init__(self, name):
self.name = name
self.fn = fn

def __call__(self, x):
return self.fn(x)
def __call__(self, graph) -> bool:
"""
:param graph:
:return: True when changed.
"""
return False

def __repr__(self):
return f"Pass({self.name})"


class FnPass(Pass):
def __init__(self, fn):
super().__init__(fn.__name__)
self.fn = fn

def __call__(self, graph) -> bool:
"""
:param graph:
:return: True when changed.
"""
return self.fn(graph)


class UntilStablePass(Pass):
def __init__(self, name: str = None, step: Pass = None):
super().__init__(name or f'until_stable')
self.step = step

def __call__(self, graph) -> bool:
changed = True
while changed:
changed = self.step(graph)
return changed

def __repr__(self):
return f"UntilStable({repr(self.step)})"


class Pipeline(Pass):
def __init__(self, name: str = None, steps: typing.List[Pass] = None, metadata: dict = None):
super().__init__(name or f'pipeline', self.call_steps)
super().__init__(name or f'pipeline')
self.steps = steps or []
self.metadata = metadata or {}

def call_steps(self, x):
def __call__(self, graph) -> bool:
changed = False
for step in self.steps:
x = step(x)
return x
step_changed = step(graph)
changed = changed or step_changed
return changed

def __repr__(self):
return f"Pipeline({repr(self.steps)})"

def add_step(self, step: Pass):
self.steps.append(step)



0 comments on commit b02e29b

Please sign in to comment.