Skip to content

Commit

Permalink
Fit curve case (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
ValentinaHutter authored Oct 2, 2023
1 parent 594feda commit cb78def
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
62 changes: 30 additions & 32 deletions openeo_pg_parser_networkx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,41 +355,39 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs):
for func in parent_callables:
func(*args, named_parameters=named_parameters, **kwargs)

try:
# If this node has already been computed once, just grab that result from the results_cache instead of recomputing it.
# This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values.
# This also means the results_cache will be useless for these functions.
# TODO: track how often functions need to be called and check if they have been called that many times, if yes, we can
# use the cache for aggregate functions, but this is probably not super necessary
no_cache_processes = [
"aggregate_temporal_period",
"fit_curve",
"predict_curve",
]
for n in self.nodes:
if n[1]["process_id"] in no_cache_processes:
raise KeyError()
# If this node has already been computed once, just grab that result from the results_cache instead of recomputing it.
# This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values.
# This also means the results_cache will be useless for these functions.
# TODO: track how often functions need to be called and check if they have been called that many times, if yes, we can
# use the cache for aggregate functions, but this is probably not super necessary
# we now found that this is also necessary for curve fitting
no_cache_processes = [
"aggregate_temporal_period",
"fit_curve",
"predict_curve",
]
for n in self.nodes:
if n[1]["process_id"] in no_cache_processes:
for _, source_node, data in self.G.out_edges(node, data=True):
if data["reference_type"] == PGEdgeType.ResultReference:
for arg_sub in data["arg_substitutions"]:
arg_sub.access_func(
new_value=results_cache[source_node], set_bool=True
)

kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][
"resolved_kwargs"
].__getitem__(arg_sub.arg_name)

result = prebaked_process_impl(
*args, named_parameters=named_parameters, **kwargs
)

return results_cache.__getitem__(node)
except KeyError:
for _, source_node, data in self.G.out_edges(node, data=True):
if data["reference_type"] == PGEdgeType.ResultReference:
for arg_sub in data["arg_substitutions"]:
arg_sub.access_func(
new_value=results_cache[source_node], set_bool=True
)

kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][
"resolved_kwargs"
].__getitem__(arg_sub.arg_name)

result = prebaked_process_impl(
*args, named_parameters=named_parameters, **kwargs
)
results_cache[node] = result

results_cache[node] = result
return result

return result
return results_cache.__getitem__(node)

return partial(node_callable, parent_callables=parent_callables)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openeo-pg-parser-networkx"
version = "2023.8.3"
version = "2023.8.4"

description = "Parse OpenEO process graphs from JSON to traversible Python objects."
authors = ["Lukas Weidenholzer <[email protected]>", "Sean Hoyal <[email protected]>", "Valentina Hutter <[email protected]>", "Gerald Irsiegler <[email protected]>"]
Expand Down

0 comments on commit cb78def

Please sign in to comment.