diff --git a/hamilton/graph_types.py b/hamilton/graph_types.py index f61336fe7..f6f060fd5 100644 --- a/hamilton/graph_types.py +++ b/hamilton/graph_types.py @@ -100,20 +100,20 @@ class HamiltonNode: documentation: typing.Optional[str] required_dependencies: typing.Set[str] optional_dependencies: typing.Set[str] - optional_input_values: typing.Dict[str, typing.Any] + optional_dependencies_default_values: typing.Dict[str, typing.Any] - def as_dict(self): + def as_dict(self, include_optional_dependencies_default_values: bool = False) -> dict: """Create a dictionary representation of the Node that is JSON serializable. - Note: optional values could be anything and might not be JSON serializable. + :param include_optional_dependencies_default_values: Include optional dependencies default values in the output. + Note: optional values could be anything and might not be JSON serializable. """ - return { + dict_representation = { "name": self.name, "tags": self.tags, "output_type": (get_type_as_string(self.type) if get_type_as_string(self.type) else ""), "required_dependencies": sorted(self.required_dependencies), "optional_dependencies": sorted(self.optional_dependencies), - "optional_input_values": self.optional_input_values, "source": ( inspect.getsource(self.originating_functions[0]) if self.originating_functions @@ -122,6 +122,11 @@ def as_dict(self): "documentation": self.documentation, "version": self.version, } + if include_optional_dependencies_default_values: + dict_representation["optional_dependencies_default_values"] = ( + self.optional_dependencies_default_values + ) + return dict_representation @staticmethod def from_node(n: node.Node) -> "HamiltonNode": @@ -147,7 +152,9 @@ def from_node(n: node.Node) -> "HamiltonNode": for dep, (type_, dep_type) in n.input_types.items() if dep_type == node.DependencyType.OPTIONAL }, - optional_input_values={name: value for name, value in n.default_input_values.items()}, + optional_dependencies_default_values={ + name: value for name, value in n.default_parameter_values.items() + }, ) @functools.cached_property diff --git a/hamilton/node.py b/hamilton/node.py index 779bde460..b875ee6b9 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -148,8 +148,8 @@ def input_types(self) -> Dict[Any, Tuple[Any, DependencyType]]: return self._input_types @property - def default_input_values(self) -> Dict[str, Any]: - """Only returns keys for which we have optional values.""" + def default_parameter_values(self) -> Dict[str, Any]: + """Only returns parameters for which we have optional values.""" return self._default_parameter_values def requires(self, dependency: str) -> bool: @@ -322,7 +322,9 @@ def copy_with(self, include_refs: bool = True, **overrides) -> "Node": input_types=self.input_types.copy(), tags=self.tags.copy(), originating_functions=self.originating_functions, - optional_values=self.default_input_values.copy() if self.default_input_values else {}, + optional_values=self.default_parameter_values.copy() + if self.default_parameter_values + else {}, ) constructor_args.update(**overrides) out = Node(**constructor_args) diff --git a/tests/plugins/test_h_schema.py b/tests/plugins/test_h_schema.py index 1726c6dd2..c44831c98 100644 --- a/tests/plugins/test_h_schema.py +++ b/tests/plugins/test_h_schema.py @@ -245,7 +245,7 @@ def foo(x: pd.DataFrame) -> pd.DataFrame: originating_functions=(foo,), required_dependencies=set(), optional_dependencies=set(), - optional_input_values={}, + optional_dependencies_default_values={}, ) df = pd.DataFrame({"a": [0, 1], "b": [True, False]}) @@ -281,7 +281,7 @@ def foo(x: pd.DataFrame) -> pd.DataFrame: originating_functions=(foo,), required_dependencies=set(), optional_dependencies=set(), - optional_input_values={}, + optional_dependencies_default_values={}, ) h_graph = graph_types.HamiltonGraph([node]) df = pd.DataFrame({"a": [0, 1], "b": [True, False]})