From 19944a073e536921f64713e7b82afe48e555dda1 Mon Sep 17 00:00:00 2001 From: Paul Prescod Date: Mon, 6 Jun 2022 10:59:44 -0700 Subject: [PATCH] Refactor YAML loading to use add_representer --- snowfakery/data_generator.py | 6 ++-- snowfakery/data_generator_runtime.py | 31 +++++++++----------- snowfakery/object_rows.py | 19 +++++++----- snowfakery/plugins.py | 18 ++---------- snowfakery/standard_plugins/datasets.py | 4 +-- snowfakery/utils/yaml_utils.py | 39 +++++++++++++++++++++++-- 6 files changed, 69 insertions(+), 48 deletions(-) diff --git a/snowfakery/data_generator.py b/snowfakery/data_generator.py index 53c1b6a6..9d027707 100644 --- a/snowfakery/data_generator.py +++ b/snowfakery/data_generator.py @@ -19,7 +19,7 @@ from .data_gen_exceptions import DataGenError from .plugins import SnowfakeryPlugin, PluginOption -from .utils.yaml_utils import SnowfakeryDumper, hydrate +from .utils.yaml_utils import SnowfakeryContinuationDumper, hydrate from snowfakery.standard_plugins.UniqueId import UniqueId # This tool is essentially a three stage interpreter. 
@@ -95,9 +95,9 @@ def load_continuation_yaml(continuation_file: OpenFileLike): def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike): """Save the global interpreter state from Globals into a continuation_file""" yaml.dump( - continuation_data.__getstate__(), + continuation_data, continuation_file, - Dumper=SnowfakeryDumper, + Dumper=SnowfakeryContinuationDumper, ) diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index 400cb9c0..d05c6d43 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -13,7 +13,7 @@ import yaml from .utils.template_utils import FakerTemplateLibrary -from .utils.yaml_utils import SnowfakeryDumper, hydrate +from .utils.yaml_utils import hydrate from .row_history import RowHistory from .template_funcs import StandardFuncs from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError @@ -27,6 +27,7 @@ ) from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes from snowfakery.utils.collections import OrderedSet +from snowfakery.utils.yaml_utils import register_for_continuation OutputStream = "snowfakery.output_streams.OutputStream" VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition" @@ -60,17 +61,15 @@ def generate_id(self, table_name: str) -> int: def __getitem__(self, table_name: str) -> int: return self.last_used_ids[table_name] - def __getstate__(self): + # TODO: Fix this to use the new convention of get_continuation_data + def get_continuation_state(self): return {"last_used_ids": dict(self.last_used_ids)} - def __setstate__(self, state): + def restore_from_continuation(self, state): self.last_used_ids = defaultdict(lambda: 0, state["last_used_ids"]) self.start_ids = {name: val + 1 for name, val in self.last_used_ids.items()} -SnowfakeryDumper.add_representer(defaultdict, SnowfakeryDumper.represent_dict) - - class Dependency(NamedTuple): table_name_from: str 
table_name_to: str @@ -195,29 +194,22 @@ def check_slots_filled(self): def first_new_id(self, tablename): return self.transients.first_new_id(tablename) - def __getstate__(self): - def serialize_dict_of_object_rows(dct): - return {k: v.__getstate__() for k, v in dct.items()} - - persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames) - persistent_objects_by_table = serialize_dict_of_object_rows( - self.persistent_objects_by_table - ) + def get_continuation_state(self): intertable_dependencies = [ dict(v._asdict()) for v in self.intertable_dependencies ] # converts ordered-dict to dict for Python 3.6 and 3.7 state = { - "persistent_nicknames": persistent_nicknames, - "persistent_objects_by_table": persistent_objects_by_table, - "id_manager": self.id_manager.__getstate__(), + "persistent_nicknames": self.persistent_nicknames, + "persistent_objects_by_table": self.persistent_objects_by_table, + "id_manager": self.id_manager.get_continuation_state(), "today": self.today, "nicknames_and_tables": self.nicknames_and_tables, "intertable_dependencies": intertable_dependencies, } return state - def __setstate__(self, state): + def restore_from_continuation(self, state): def deserialize_dict_of_object_rows(dct): return {k: hydrate(ObjectRow, v) for k, v in dct.items()} @@ -244,6 +236,9 @@ def deserialize_dict_of_object_rows(dct): self.reset_slots() +register_for_continuation(Globals, Globals.get_continuation_state) + + class JinjaTemplateEvaluatorFactory: def __init__(self, native_types: bool): if native_types: diff --git a/snowfakery/object_rows.py b/snowfakery/object_rows.py index 3e836a35..ac9696df 100644 --- a/snowfakery/object_rows.py +++ b/snowfakery/object_rows.py @@ -2,7 +2,7 @@ import yaml import snowfakery # noqa -from .utils.yaml_utils import SnowfakeryDumper +from .utils.yaml_utils import register_for_continuation from contextvars import ContextVar IdManager = "snowfakery.data_generator_runtime.IdManager" @@ -14,10 +14,6 @@ class ObjectRow: 
Uses __getattr__ so that the template evaluator can use dot-notation.""" - yaml_loader = yaml.SafeLoader - yaml_dumper = SnowfakeryDumper - yaml_tag = "!snowfakery_objectrow" - # be careful changing these slots because these objects must be serializable # to YAML and JSON __slots__ = ["_tablename", "_values", "_child_index"] @@ -49,19 +45,28 @@ def __repr__(self): except Exception: return super().__repr__() - def __getstate__(self): + def get_continuation_state(self): """Get the state of this ObjectRow for serialization. Do not include related ObjectRows because circular references in serialization formats cause problems.""" + + # If we decided to try to serialize hierarchies, we could + # do it like this: + # * keep track of if an object has already been serialized using a + # property of the SnowfakeryContinuationDumper + # * If so, output an ObjectReference instead of an ObjectRow values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)} return {"_tablename": self._tablename, "_values": values} - def __setstate__(self, state): + def restore_from_continuation(self, state): for slot, value in state.items(): setattr(self, slot, value) +register_for_continuation(ObjectRow, ObjectRow.get_continuation_state) + + class ObjectReference(yaml.YAMLObject): def __init__(self, tablename: str, id: int): self._tablename = tablename diff --git a/snowfakery/plugins.py b/snowfakery/plugins.py index c7c548ce..71b25429 100644 --- a/snowfakery/plugins.py +++ b/snowfakery/plugins.py @@ -8,13 +8,11 @@ from functools import wraps import typing as T -import yaml -from yaml.representer import Representer from faker.providers import BaseProvider as FakerProvider from dateutil.relativedelta import relativedelta import snowfakery.data_gen_exceptions as exc -from .utils.yaml_utils import SnowfakeryDumper +from snowfakery.utils.yaml_utils import register_for_continuation from .utils.collections import CaseInsensitiveDict from numbers import Number @@ -306,17 +304,7 @@ 
def _from_continuation(cls, args): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - _register_for_continuation(cls) - - -def _register_for_continuation(cls): - SnowfakeryDumper.add_representer(cls, Representer.represent_object) - yaml.SafeLoader.add_constructor( - f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}", - lambda loader, node: cls._from_continuation( - loader.construct_mapping(node.value[0]) - ), - ) + register_for_continuation(cls) class PluginResultIterator(PluginResult): @@ -372,4 +360,4 @@ def convert(self, value): # round-trip PluginResult objects through continuation YAML if needed. -_register_for_continuation(PluginResult) +register_for_continuation(PluginResult) diff --git a/snowfakery/standard_plugins/datasets.py b/snowfakery/standard_plugins/datasets.py index 51368a8b..c72a1796 100644 --- a/snowfakery/standard_plugins/datasets.py +++ b/snowfakery/standard_plugins/datasets.py @@ -17,7 +17,7 @@ memorable, ) from snowfakery.utils.files import FileLike, open_file_like -from snowfakery.utils.yaml_utils import SnowfakeryDumper +from snowfakery.utils.yaml_utils import SnowfakeryContinuationDumper def _open_db(db_url): @@ -258,4 +258,4 @@ def chdir(path): os.chdir(cwd) -SnowfakeryDumper.add_representer(quoted_name, Representer.represent_str) +SnowfakeryContinuationDumper.add_representer(quoted_name, Representer.represent_str) diff --git a/snowfakery/utils/yaml_utils.py b/snowfakery/utils/yaml_utils.py index 73a5a367..4125bffc 100644 --- a/snowfakery/utils/yaml_utils.py +++ b/snowfakery/utils/yaml_utils.py @@ -1,11 +1,44 @@ -from yaml import SafeDumper +from typing import Callable +from yaml import SafeDumper, SafeLoader +from yaml.representer import Representer +from collections import defaultdict -class SnowfakeryDumper(SafeDumper): +class SnowfakeryContinuationDumper(SafeDumper): pass +SnowfakeryContinuationDumper.add_representer( + defaultdict, SnowfakeryContinuationDumper.represent_dict +) + + def 
hydrate(cls, data): obj = cls.__new__(cls) - obj.__setstate__(data) + obj.restore_from_continuation(data) return obj + + +# Evaluate whether it's cleaner for functions to bypass register_for_continuation +# and go directly to SnowfakeryContinuationDumper.add_representer. +# +# + + +def represent_continuation(dumper: SnowfakeryContinuationDumper, data): + if isinstance(data, dict): + return Representer.represent_dict(dumper, data) + else: + return Representer.represent_object(dumper, data) + + +def register_for_continuation(cls, dump_transformer: Callable = lambda x: x): + SnowfakeryContinuationDumper.add_representer( + cls, lambda self, data: represent_continuation(self, dump_transformer(data)) + ) + SafeLoader.add_constructor( + f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}", + lambda loader, node: cls._from_continuation( + loader.construct_mapping(node.value[0]) + ), + )