Skip to content

Commit

Permalink
Merge pull request #194 from e-cal/print-trace
Browse files Browse the repository at this point in the history
Print single trace
  • Loading branch information
haz authored May 26, 2023
2 parents 5feb8ed + 57ed5d8 commit d190062
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
38 changes: 38 additions & 0 deletions macq/trace/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from dataclasses import dataclass
from typing import List, Type, Iterable, Callable, Set
from inspect import cleandoc
from warnings import warn
from rich.table import Table
from rich.text import Text
from rich.console import Console
from . import Action, Step, State
from ..observation import Observation, NoisyPartialDisorderedParallelObservation
from ..utils import TokenizationError
Expand Down Expand Up @@ -198,6 +200,42 @@ def colorgrid(self, filter_func=lambda _: True, wrap=True):

return colorgrid

def get_printable(self, view="details", filter_func=lambda _: True, wrap=None):
"""Returns a printable representation of the trace in the specified view."""
views = ["details", "color", "actions"]
if view not in views:
warn(f'Invalid view {view}. Defaulting to "details".')
view = "details"

if view == "details":
if wrap is None: wrap = False
return self.details(wrap=wrap)
elif view == "color":
if wrap is None: wrap = True
return self.colorgrid(filter_func=filter_func, wrap=wrap)
elif view == "actions":
return [step.action for step in self]


def print(self, view="details", filter_func=lambda _: True, wrap=None):
"""Pretty prints the trace in the specified view.
Arguments:
view ("details" | "color" | "actions"):
Specifies the view format to print in. "details" prints a
detailed summary of each step in a trace. "color" prints a
color grid, mapping fluents in a step to either red or green
corresponding to the truth value. "actions" prints the actions
in the trace.
filter_func (Callable):
A function used to filter the fluents to be printed.
wrap (bool):
Specifies whether or not to wrap the text in the printed output.
"""
console = Console()
console.print(self.get_printable(view=view, filter_func=filter_func, wrap=wrap))
print()

def get_static_fluents(self):
fstates = defaultdict(list)
for step in self:
Expand Down
27 changes: 3 additions & 24 deletions macq/trace/trace_list.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from collections.abc import MutableSequence
from logging import warn
from typing import Callable, List, Optional, Type, Union
from typing import Callable, List, Type, Union
from warnings import warn

from rich.console import Console

from ..observation import Observation, ObservedTraceList
from . import Action, Trace

Expand Down Expand Up @@ -148,28 +145,10 @@ def print(self, view="details", filter_func=lambda _: True, wrap=None):
corresponding to the truth value. "actions" prints the actions
in the traces.
"""
console = Console()

views = ["details", "color", "actions"]
if view not in views:
warn(f'Invalid view {view}. Defaulting to "details".')
view = "details"

traces = []
if view == "details":
if wrap is None:
wrap = False
traces = [trace.details(wrap=wrap) for trace in self]

elif view == "color":
if wrap is None:
wrap = True
traces = [
trace.colorgrid(filter_func=filter_func, wrap=wrap) for trace in self
]
elif view == "actions":
traces = [[step.action for step in trace] for trace in self]

for trace in traces:
console.print(trace)
print()
for trace in self:
trace.print(view, filter_func, wrap)

0 comments on commit d190062

Please sign in to comment.