diff --git a/alpaca/graph.py b/alpaca/graph.py
index 20c2d78..5983273 100644
--- a/alpaca/graph.py
+++ b/alpaca/graph.py
@@ -611,13 +611,26 @@ def aggregate(self, group_node_attributes, use_function_parameters=True,
         group_node_attributes : dict
             Dictionary selecting which attributes are used in the aggregation.
             The keys are the possible labels in the graph, and the values
-            are tuples of the node attributes used for determining supernodes.
+            are tuples of the node attributes or callables used for
+            determining supernodes.
+
             For example, to aggregate `Quantity` nodes based on different
             `shape` attribute values, `group_node_attributes` would be
             `{'Quantity': ('shape',)}`.
             If passing an empty dictionary, no attributes will be considered,
             and the aggregation will be based on the topology (i.e., nodes
             at similar levels will be grouped according to the connectivity).
+
+            In addition to attribute names, callables can be used. Each
+            callable takes the arguments `(graph, node, data)`, where `graph`
+            is the graph being aggregated, `node` is the node being evaluated
+            for grouping, and `data` is the dictionary of attributes. The
+            returned value defines the group. This allows flexibility when
+            grouping, as attribute values can be transformed (e.g., extracting
+            a token such as the file extension from an attribute that stores
+            the path as a string), or the relationship of the node to its
+            neighbors and values of edges can be checked. However, this will
+            increase the time to evaluate the grouping criteria of a node.
         use_function_parameters : bool, optional
             If True, the parameters of function nodes in the graph will be
             considered in the aggregation, i.e., if the same function is called
@@ -661,8 +674,8 @@ def aggregate(self, group_node_attributes, use_function_parameters=True,
         [1]_.
 
         The function was modified to group the nodes based on different
-        attributes (using a dictionary based on the labels) instead of a
-        single attribute that is common to all nodes.
+        attributes or callables (using a dictionary based on the labels)
+        instead of attributes that are common to all nodes.
 
         During the summary graph generation, the attribute values are also
         summarized, so that the user has an idea of all the possible values in
@@ -679,7 +692,7 @@ def aggregate(self, group_node_attributes, use_function_parameters=True,
            June 2008.
         """
 
-        def _fetch_group_tuple(data, label, data_attributes,
+        def _fetch_group_tuple(graph, node, data, label, data_attributes,
                                use_function_params):
             group_info = [label]
 
@@ -694,9 +707,24 @@ def _fetch_group_tuple(data, label, data_attributes,
             else:
                 if data_attributes is not None:
                     # We have requested grouping for this object based on
-                    # selected attributes. Otherwise, we will use the label
+                    # selected attributes/callables. Otherwise, we will use
+                    # the label only
+
                     for attr in data_attributes:
-                        group_info.append(data[attr])
+
+                        if callable(attr):
+                            # We have requested grouping using a function that
+                            # takes the graph, the node, and the node
+                            # attributes as parameters. This allows a more
+                            # customized grouping that can extract specific
+                            # information from the attribute value or use the
+                            # node relationships
+                            group_info.append(attr(graph, node, data))
+                        else:
+                            # Fetch the attribute value for this node, if
+                            # available
+                            group_info.append(data.get(attr, None))
+
             return tuple(group_info)
 
         # We don't consider edges
@@ -704,7 +732,8 @@ def _fetch_group_tuple(data, label, data_attributes,
 
         # Create the groups based on the selected conditions
         group_lookup = {
-            node: _fetch_group_tuple(attrs, attrs['label'],
+            node: _fetch_group_tuple(
+                self.graph, node, attrs, attrs['label'],
                 group_node_attributes.get(attrs['label'], None),
                 use_function_parameters)
             for node, attrs in self.graph.nodes.items()
diff --git a/alpaca/test/res/multiple_file_output.ttl b/alpaca/test/res/multiple_file_output.ttl
new file mode 100644
index 0000000..4e1d0f2
--- /dev/null
+++ b/alpaca/test/res/multiple_file_output.ttl
@@ -0,0 +1,103 @@
+@prefix alpaca: .
+@prefix prov: .
+@prefix xsd: .
+
+ a alpaca:FileEntity ;
+    alpaca:filePath "/outputs/1.png"^^xsd:string ;
+    prov:wasDerivedFrom ;
+    prov:wasAttributedTo ;
+    prov:wasGeneratedBy .
+
+ a alpaca:FileEntity ;
+    alpaca:filePath "/outputs/2.png"^^xsd:string ;
+    prov:wasDerivedFrom ;
+    prov:wasAttributedTo ;
+    prov:wasGeneratedBy .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    prov:wasDerivedFrom ;
+    prov:wasGeneratedBy ;
+    alpaca:hashSource "UUID" .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    prov:wasDerivedFrom ;
+    prov:wasGeneratedBy ;
+    alpaca:hashSource "UUID" .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    prov:wasDerivedFrom ;
+    prov:wasGeneratedBy ;
+    alpaca:hashSource "joblib_SHA1" .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    prov:wasDerivedFrom ;
+    prov:wasGeneratedBy ;
+    alpaca:hashSource "joblib_SHA1" .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    alpaca:hashSource "joblib_SHA1" .
+
+ a alpaca:FileEntity ;
+    alpaca:filePath "/full.png"^^xsd:string ;
+    prov:wasDerivedFrom ;
+    prov:wasAttributedTo ;
+    prov:wasGeneratedBy .
+
+ a alpaca:DataObjectEntity ;
+    prov:wasAttributedTo ;
+    prov:wasDerivedFrom ;
+    prov:wasGeneratedBy ;
+    alpaca:hashSource "UUID" .
+
+ a alpaca:FunctionExecution ;
+    prov:startedAtTime "2022-05-02T12:34:56.123456"^^xsd:dateTime ;
+    prov:endedAtTime "2022-05-02T12:35:56.123456"^^xsd:dateTime ;
+    prov:used ;
+    prov:wasAssociatedWith ;
+    alpaca:codeStatement "plot_function(input, out_file)" ;
+    alpaca:executionOrder 3 ;
+    alpaca:usedFunction .
+
+ a alpaca:FunctionExecution ;
+    prov:startedAtTime "2022-05-02T12:34:56.123456"^^xsd:dateTime ;
+    prov:endedAtTime "2022-05-02T12:35:56.123456"^^xsd:dateTime ;
+    prov:used ;
+    prov:wasAssociatedWith ;
+    alpaca:codeStatement "plot_function(input, out_file)" ;
+    alpaca:executionOrder 4 ;
+    alpaca:usedFunction .
+
+ a alpaca:FunctionExecution ;
+    prov:startedAtTime "2022-05-02T12:34:56.123456"^^xsd:dateTime ;
+    prov:endedAtTime "2022-05-02T12:35:56.123456"^^xsd:dateTime ;
+    prov:used ;
+    prov:wasAssociatedWith ;
+    alpaca:codeStatement "plot_function(input, out_file)" ;
+    alpaca:executionOrder 1 ;
+    alpaca:usedFunction .
+
+ a alpaca:FunctionExecution ;
+    prov:startedAtTime "2022-05-02T12:34:56.123456"^^xsd:dateTime ;
+    prov:endedAtTime "2022-05-02T12:35:56.123456"^^xsd:dateTime ;
+    prov:used ;
+    prov:wasAssociatedWith ;
+    alpaca:codeStatement "cut_function(full_data)" ;
+    alpaca:executionOrder 2 ;
+    alpaca:usedFunction .
+
+ a alpaca:Function ;
+    alpaca:functionName "plot_function" ;
+    alpaca:implementedIn "test" ;
+    alpaca:functionVersion "0.0.1" .
+
+ a alpaca:Function ;
+    alpaca:functionName "cut_function" ;
+    alpaca:implementedIn "test" ;
+    alpaca:functionVersion "0.0.1" .
+
+ a alpaca:ScriptAgent ;
+    alpaca:scriptPath "/script.py" .
diff --git a/alpaca/test/res/parallel_graph.ttl b/alpaca/test/res/parallel_graph.ttl
index 0b2054b..9d137e3 100644
--- a/alpaca/test/res/parallel_graph.ttl
+++ b/alpaca/test/res/parallel_graph.ttl
@@ -136,7 +136,10 @@
             alpaca:pairValue 5 ],
         [ a alpaca:NameValuePair ;
             alpaca:pairName "shape" ;
-            alpaca:pairValue "(2,)" ] ;
+            alpaca:pairValue "(2,)" ],
+        [ a alpaca:NameValuePair ;
+            alpaca:pairName "id" ;
+            alpaca:pairValue 1 ] ;
     alpaca:hashSource "joblib_SHA1" .
 
  a alpaca:DataObjectEntity ;
diff --git a/alpaca/test/test_graph.py b/alpaca/test/test_graph.py
index 303a316..46bb78b 100644
--- a/alpaca/test/test_graph.py
+++ b/alpaca/test/test_graph.py
@@ -1,3 +1,4 @@
+import sys
 import unittest
 
 from pathlib import Path
@@ -494,7 +495,8 @@ class GraphAggregationTestCase(unittest.TestCase):
     def setUpClass(cls):
         cls.ttl_path = Path(__file__).parent / "res"
         input_file = cls.ttl_path / "parallel_graph.ttl"
-        cls.graph = ProvenanceGraph(input_file, attributes=['shape', 'metadata'])
+        cls.graph = ProvenanceGraph(input_file, attributes=['shape',
+                                                            'metadata', 'id'])
         alpaca_setting('authority', "my-authority")
 
     def test_serialization(self):
@@ -542,6 +544,123 @@ def test_overall_aggregation(self):
             for key, value in expected_values_per_node[label].items():
                 self.assertEqual(attrs[key], value)
 
+    def test_aggregation_by_callable(self):
+        graph_file = self.ttl_path / "multiple_file_output.ttl"
+
+        # Non-aggregated graph
+        graph = ProvenanceGraph(graph_file)
+
+        # Aggregate without attributes
+        aggregated = graph.aggregate({}, output_file=None)
+
+        # Aggregate separating by file path in File nodes
+        aggregated_path = graph.aggregate({'File': ('File_path',)},
+                                          output_file=None)
+
+        # Aggregate using a callable to separate files whose path starts with
+        # "/outputs/"
+        is_cut_plot = lambda g, n, d: d['File_path'].startswith("/outputs/")
+        aggregated_callable = graph.aggregate({'File': (is_cut_plot,)},
+                                              output_file=None)
+
+        # Define a dictionary with the expected values for each case, which
+        # are used in the subtests below
+        tests = {
+            'non_aggregated': {'graph': graph.graph, 'length': 10,
+                               'counts': {'InputObject': 3,
+                                          'plot_function': 3,
+                                          'cut_function': 1,
+                                          'File': 3},
+                               'paths': ["/full.png",
+                                         "/outputs/1.png",
+                                         "/outputs/2.png"]
+                               },
+
+            'aggregated': {'graph': aggregated, 'length': 5,
+                           'counts': {'InputObject': 2,
+                                      'plot_function': 1,
+                                      'cut_function': 1,
+                                      'File': 1},
+                           'paths': "/full.png;/outputs/1.png;/outputs/2.png"
+                           },
+
+            'aggregated_path': {'graph': aggregated_path, 'length': 10,
+                                'counts': {'InputObject': 3,
+                                           'plot_function': 3,
+                                           'cut_function': 1,
+                                           'File': 3},
+                                'paths': ["/full.png",
+                                          "/outputs/1.png",
+                                          "/outputs/2.png"]
+                                },
+            'aggregated_callable': {'graph': aggregated_callable, 'length': 7,
+                                    'counts': {'InputObject': 2,
+                                               'plot_function': 2,
+                                               'cut_function': 1,
+                                               'File': 2},
+                                    'paths': ["/full.png",
+                                              "/outputs/1.png;/outputs/2.png"]
+                                    },
+        }
+
+        for key, expected in tests.items():
+            with self.subTest(f"Graph {key}"):
+                test_graph = expected['graph']
+                nodes = test_graph.nodes
+                self.assertEqual(len(nodes), expected['length'])
+
+                # Check if the node counts are as expected
+                all_labels = [nodes[node]['label'] for node in nodes]
+                counts = Counter(all_labels)
+                for label, count in expected['counts'].items():
+                    self.assertEqual(counts[label], count)
+
+                # Check if file paths in the nodes are as expected
+                paths = expected['paths']
+                for node, attrs in nodes.items():
+                    # Check value of file paths in File nodes
+                    if attrs['label'] == "File":
+                        if isinstance(paths, list):
+                            self.assertTrue(attrs['File_path'] in paths)
+                        else:
+                            self.assertEqual(attrs['File_path'], paths)
+
+    def test_aggregation_by_attribute_with_missing(self):
+        aggregated = self.graph.aggregate({'InputObject': ('id',)},
+                                          use_function_parameters=False,
+                                          output_file=None)
+        nodes = aggregated.nodes
+
+        self.assertEqual(len(nodes), 5)
+
+        expected_values_per_node = {
+            'OutputObject': {'metadata': "0;1",
+                             'shape': "(2,);(3,);(4,);(5,)"},
+            'InputObject': {'metadata': "5",
+                            'shape': ["(2,)", "(3,);(4,);(5,)"],
+                            'id': ["1", None]},
+            'process': {'process:value': "0;1;2;3"},
+            'list': {}
+        }
+
+        all_labels = [nodes[node]['label'] for node in nodes]
+        counts = Counter(all_labels)
+        self.assertEqual(counts['OutputObject'], 1)
+        self.assertEqual(counts['InputObject'], 2)
+        self.assertEqual(counts['process'], 1)
+        self.assertEqual(counts['list'], 1)
+
+        for node, attrs in nodes.items():
+            label = attrs['label']
+            with self.subTest(f"Node label {label}"):
+                self.assertTrue(label in expected_values_per_node)
+                for key, value in expected_values_per_node[label].items():
+                    attr_val = attrs[key] if key in attrs else None
+                    if not isinstance(value, list):
+                        self.assertEqual(attr_val, value)
+                    else:
+                        self.assertTrue(attr_val in value)
+
     def test_aggregation_by_attribute(self):
         aggregated = self.graph.aggregate({'InputObject': ('shape',)},
                                           use_function_parameters=False,
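
Example of the callable-based grouping introduced by this patch (a minimal usage sketch, not part of the patch itself; the input file name and the `file_extension` helper are hypothetical, and it assumes `ProvenanceGraph` is importable from the top-level `alpaca` package as in the test suite; the `File` label and its `File_path` attribute are taken from the tests above):

    from alpaca import ProvenanceGraph

    # Load a serialized provenance trace (hypothetical file name)
    graph = ProvenanceGraph("provenance.ttl")

    # Grouping callable: receives the graph being aggregated, the node being
    # evaluated, and its attribute dictionary; the returned value defines the
    # supernode. Here, File nodes are grouped by the extension of 'File_path'
    # instead of by the full path.
    def file_extension(graph, node, data):
        return data.get('File_path', "").rsplit(".", 1)[-1]

    # File nodes with the same extension are merged into one supernode; other
    # labels fall back to topology-based grouping since no attributes are
    # given for them
    aggregated = graph.aggregate({'File': (file_extension,)},
                                 output_file=None)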