diff --git a/openeo_pg_parser_networkx/graph.py b/openeo_pg_parser_networkx/graph.py index 7ec630a..5bd57f1 100644 --- a/openeo_pg_parser_networkx/graph.py +++ b/openeo_pg_parser_networkx/graph.py @@ -355,41 +355,39 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs): for func in parent_callables: func(*args, named_parameters=named_parameters, **kwargs) - try: - # If this node has already been computed once, just grab that result from the results_cache instead of recomputing it. - # This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values. - # This also means the results_cache will be useless for these functions. - # TODO: track how often functions need to be called and check if they have been called that many times, if yes, we can - # use the cache for aggregate functions, but this is probably not super necessary - no_cache_processes = [ - "aggregate_temporal_period", - "fit_curve", - "predict_curve", - ] - for n in self.nodes: - if n[1]["process_id"] in no_cache_processes: - raise KeyError() + # If this node has already been computed once, just grab that result from the results_cache instead of recomputing it. + # This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values. + # This also means the results_cache will be useless for these functions. + # TODO: track how often functions need to be called and check if they have been called that many times, if yes, we can + # use the cache for aggregate functions, but this is probably not super necessary + # we now found that this is also necessary for curve fitting + no_cache_processes = [ + "aggregate_temporal_period", + "fit_curve", + "predict_curve", + ] + for n in self.nodes: + if n[1]["process_id"] in no_cache_processes: + for _, source_node, data in self.G.out_edges(node, data=True): + if data["reference_type"] == PGEdgeType.ResultReference: + for arg_sub in data["arg_substitutions"]: + arg_sub.access_func( + new_value=results_cache[source_node], set_bool=True + ) + + kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][ + "resolved_kwargs" + ].__getitem__(arg_sub.arg_name) + + result = prebaked_process_impl( + *args, named_parameters=named_parameters, **kwargs + ) - return results_cache.__getitem__(node) - except KeyError: - for _, source_node, data in self.G.out_edges(node, data=True): - if data["reference_type"] == PGEdgeType.ResultReference: - for arg_sub in data["arg_substitutions"]: - arg_sub.access_func( - new_value=results_cache[source_node], set_bool=True - ) - - kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][ - "resolved_kwargs" - ].__getitem__(arg_sub.arg_name) - - result = prebaked_process_impl( - *args, named_parameters=named_parameters, **kwargs - ) + results_cache[node] = result - results_cache[node] = result + return result - return result + return results_cache.__getitem__(node) return partial(node_callable, parent_callables=parent_callables) diff --git a/pyproject.toml b/pyproject.toml index e4a58de..f4033f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openeo-pg-parser-networkx" -version = "2023.8.3" +version = "2023.8.4" description = "Parse OpenEO process graphs from JSON to traversible Python objects." authors = ["Lukas Weidenholzer ", "Sean Hoyal ", "Valentina Hutter ", "Gerald Irsiegler "]