Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print single trace #194

Merged
merged 1 commit into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions macq/trace/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from dataclasses import dataclass
from typing import List, Type, Iterable, Callable, Set
from inspect import cleandoc
from warnings import warn
from rich.table import Table
from rich.text import Text
from rich.console import Console
from . import Action, Step, State
from ..observation import Observation, NoisyPartialDisorderedParallelObservation
from ..utils import TokenizationError
Expand Down Expand Up @@ -198,6 +200,42 @@ def colorgrid(self, filter_func=lambda _: True, wrap=True):

return colorgrid

def get_printable(self, view="details", filter_func=lambda _: True, wrap=None):
"""Returns a printable representation of the trace in the specified view."""
views = ["details", "color", "actions"]
if view not in views:
warn(f'Invalid view {view}. Defaulting to "details".')
view = "details"

if view == "details":
if wrap is None: wrap = False
return self.details(wrap=wrap)
elif view == "color":
if wrap is None: wrap = True
return self.colorgrid(filter_func=filter_func, wrap=wrap)
elif view == "actions":
return [step.action for step in self]


def print(self, view="details", filter_func=lambda _: True, wrap=None):
"""Pretty prints the trace in the specified view.

Arguments:
view ("details" | "color" | "actions"):
Specifies the view format to print in. "details" prints a
detailed summary of each step in a trace. "color" prints a
color grid, mapping fluents in a step to either red or green
corresponding to the truth value. "actions" prints the actions
in the trace.
filter_func (Callable):
A function used to filter the fluents to be printed.
wrap (bool):
Specifies whether or not to wrap the text in the printed output.
"""
console = Console()
console.print(self.get_printable(view=view, filter_func=filter_func, wrap=wrap))
print()

def get_static_fluents(self):
fstates = defaultdict(list)
for step in self:
Expand Down
27 changes: 3 additions & 24 deletions macq/trace/trace_list.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from collections.abc import MutableSequence
from logging import warn
from typing import Callable, List, Optional, Type, Union
from typing import Callable, List, Type, Union
from warnings import warn

from rich.console import Console

from ..observation import Observation, ObservedTraceList
from . import Action, Trace

Expand Down Expand Up @@ -148,28 +145,10 @@ def print(self, view="details", filter_func=lambda _: True, wrap=None):
corresponding to the truth value. "actions" prints the actions
in the traces.
"""
console = Console()

views = ["details", "color", "actions"]
if view not in views:
warn(f'Invalid view {view}. Defaulting to "details".')
view = "details"

traces = []
if view == "details":
if wrap is None:
wrap = False
traces = [trace.details(wrap=wrap) for trace in self]

elif view == "color":
if wrap is None:
wrap = True
traces = [
trace.colorgrid(filter_func=filter_func, wrap=wrap) for trace in self
]
elif view == "actions":
traces = [[step.action for step in trace] for trace in self]

for trace in traces:
console.print(trace)
print()
for trace in self:
trace.print(view, filter_func, wrap)