From 57ed5d8cde69b55cf41acb06505ecae56b0a2479 Mon Sep 17 00:00:00 2001 From: ecal Date: Fri, 26 May 2023 14:08:29 -0400 Subject: [PATCH] Move printing logic to Trace --- macq/trace/trace.py | 38 ++++++++++++++++++++++++++++++++++++++ macq/trace/trace_list.py | 27 +++------------------------ 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/macq/trace/trace.py b/macq/trace/trace.py index 314bf1b3..0e77b32c 100644 --- a/macq/trace/trace.py +++ b/macq/trace/trace.py @@ -2,8 +2,10 @@ from dataclasses import dataclass from typing import List, Type, Iterable, Callable, Set from inspect import cleandoc +from warnings import warn from rich.table import Table from rich.text import Text +from rich.console import Console from . import Action, Step, State from ..observation import Observation, NoisyPartialDisorderedParallelObservation from ..utils import TokenizationError @@ -198,6 +200,42 @@ def colorgrid(self, filter_func=lambda _: True, wrap=True): return colorgrid + def get_printable(self, view="details", filter_func=lambda _: True, wrap=None): + """Returns a printable representation of the trace in the specified view.""" + views = ["details", "color", "actions"] + if view not in views: + warn(f'Invalid view {view}. Defaulting to "details".') + view = "details" + + if view == "details": + if wrap is None: wrap = False + return self.details(wrap=wrap) + elif view == "color": + if wrap is None: wrap = True + return self.colorgrid(filter_func=filter_func, wrap=wrap) + elif view == "actions": + return [step.action for step in self] + + + def print(self, view="details", filter_func=lambda _: True, wrap=None): + """Pretty prints the trace in the specified view. + + Arguments: + view ("details" | "color" | "actions"): + Specifies the view format to print in. "details" prints a + detailed summary of each step in a trace. "color" prints a + color grid, mapping fluents in a step to either red or green + corresponding to the truth value. "actions" prints the actions + in the trace. + filter_func (Callable): + A function used to filter the fluents to be printed. + wrap (bool): + Specifies whether or not to wrap the text in the printed output. + """ + console = Console() + console.print(self.get_printable(view=view, filter_func=filter_func, wrap=wrap)) + print() + def get_static_fluents(self): fstates = defaultdict(list) for step in self: diff --git a/macq/trace/trace_list.py b/macq/trace/trace_list.py index 002e6eb1..4a5e0ba8 100644 --- a/macq/trace/trace_list.py +++ b/macq/trace/trace_list.py @@ -1,10 +1,7 @@ from collections.abc import MutableSequence -from logging import warn -from typing import Callable, List, Optional, Type, Union +from typing import Callable, List, Type, Union from warnings import warn -from rich.console import Console - from ..observation import Observation, ObservedTraceList from . import Action, Trace @@ -148,28 +145,10 @@ def print(self, view="details", filter_func=lambda _: True, wrap=None): corresponding to the truth value. "actions" prints the actions in the traces. """ - console = Console() - views = ["details", "color", "actions"] if view not in views: warn(f'Invalid view {view}. Defaulting to "details".') view = "details" - traces = [] - if view == "details": - if wrap is None: - wrap = False - traces = [trace.details(wrap=wrap) for trace in self] - - elif view == "color": - if wrap is None: - wrap = True - traces = [ - trace.colorgrid(filter_func=filter_func, wrap=wrap) for trace in self - ] - elif view == "actions": - traces = [[step.action for step in trace] for trace in self] - - for trace in traces: - console.print(trace) - print() + for trace in self: + trace.print(view, filter_func, wrap)