Skip to content

Commit

Permalink
Refactors code given feedback
Browse files Browse the repository at this point in the history
Made including default values an optional argument, since
people might not want that due to serialization issues.
  • Loading branch information
skrawcz committed Aug 28, 2024
1 parent 52ecf5f commit 8e39dff
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
19 changes: 13 additions & 6 deletions hamilton/graph_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ class HamiltonNode:
documentation: typing.Optional[str]
required_dependencies: typing.Set[str]
optional_dependencies: typing.Set[str]
optional_input_values: typing.Dict[str, typing.Any]
optional_dependencies_default_values: typing.Dict[str, typing.Any]

def as_dict(self):
def as_dict(self, include_optional_dependencies_default_values: bool = False) -> dict:
"""Create a dictionary representation of the Node that is JSON serializable.
Note: optional values could be anything and might not be JSON serializable.
:param include_optional_dependencies_default_values: Include optional dependencies default values in the output.
Note: optional values could be anything and might not be JSON serializable.
"""
return {
dict_representation = {
"name": self.name,
"tags": self.tags,
"output_type": (get_type_as_string(self.type) if get_type_as_string(self.type) else ""),
"required_dependencies": sorted(self.required_dependencies),
"optional_dependencies": sorted(self.optional_dependencies),
"optional_input_values": self.optional_input_values,
"source": (
inspect.getsource(self.originating_functions[0])
if self.originating_functions
Expand All @@ -122,6 +122,11 @@ def as_dict(self):
"documentation": self.documentation,
"version": self.version,
}
if include_optional_dependencies_default_values:
dict_representation["optional_dependencies_default_values"] = (
self.optional_dependencies_default_values
)
return dict_representation

@staticmethod
def from_node(n: node.Node) -> "HamiltonNode":
Expand All @@ -147,7 +152,9 @@ def from_node(n: node.Node) -> "HamiltonNode":
for dep, (type_, dep_type) in n.input_types.items()
if dep_type == node.DependencyType.OPTIONAL
},
optional_input_values={name: value for name, value in n.default_input_values.items()},
optional_dependencies_default_values={
name: value for name, value in n.default_parameter_values.items()
},
)

@functools.cached_property
Expand Down
8 changes: 5 additions & 3 deletions hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def input_types(self) -> Dict[Any, Tuple[Any, DependencyType]]:
return self._input_types

@property
def default_input_values(self) -> Dict[str, Any]:
"""Only returns keys for which we have optional values."""
def default_parameter_values(self) -> Dict[str, Any]:
"""Only returns parameters for which we have optional values."""
return self._default_parameter_values

def requires(self, dependency: str) -> bool:
Expand Down Expand Up @@ -322,7 +322,9 @@ def copy_with(self, include_refs: bool = True, **overrides) -> "Node":
input_types=self.input_types.copy(),
tags=self.tags.copy(),
originating_functions=self.originating_functions,
optional_values=self.default_input_values.copy() if self.default_input_values else {},
optional_values=self.default_parameter_values.copy()
if self.default_parameter_values
else {},
)
constructor_args.update(**overrides)
out = Node(**constructor_args)
Expand Down
4 changes: 2 additions & 2 deletions tests/plugins/test_h_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def foo(x: pd.DataFrame) -> pd.DataFrame:
originating_functions=(foo,),
required_dependencies=set(),
optional_dependencies=set(),
optional_input_values={},
optional_dependencies_default_values={},
)
df = pd.DataFrame({"a": [0, 1], "b": [True, False]})

Expand Down Expand Up @@ -281,7 +281,7 @@ def foo(x: pd.DataFrame) -> pd.DataFrame:
originating_functions=(foo,),
required_dependencies=set(),
optional_dependencies=set(),
optional_input_values={},
optional_dependencies_default_values={},
)
h_graph = graph_types.HamiltonGraph([node])
df = pd.DataFrame({"a": [0, 1], "b": [True, False]})
Expand Down

0 comments on commit 8e39dff

Please sign in to comment.