diff --git a/src/python/txtai/graph/base.py b/src/python/txtai/graph/base.py index 8eade3819..b9d98f930 100644 --- a/src/python/txtai/graph/base.py +++ b/src/python/txtai/graph/base.py @@ -103,6 +103,16 @@ def addnode(self, node, **attrs): raise NotImplementedError + def addnodes(self, nodes): + """ + Adds nodes to the graph. + + Args: + nodes: list of (node, attributes) to add + """ + + raise NotImplementedError + def removenode(self, node): """ Removes a node and all it's edges from graph. @@ -200,6 +210,16 @@ def addedge(self, source, target, **attrs): raise NotImplementedError + def addedges(self, edges): + """ + Adds edges to graph. + + Args: + edges: list of (source, target, attributes) to add + """ + + raise NotImplementedError + def hasedge(self, source, target=None): """ Returns True if edge found, False otherwise. If target is None, this method @@ -409,6 +429,7 @@ def insert(self, documents, index=0): # Initialize graph backend self.initialize() + nodes = [] for uid, document, _ in documents: # Relationships are manually-provided edges relations = None @@ -427,13 +448,15 @@ def insert(self, documents, index=0): document = " ".join(document) # Create node - self.addnode(index, id=uid, data=document) + nodes.append((index, {"id": uid, "data": document})) # Add relationships self.addrelations(index, relations) index += 1 + self.addnodes(nodes) + def delete(self, ids): """ Deletes ids from graph. 
@@ -471,8 +494,8 @@ def index(self, search, ids, similarity): # Add relationship edges self.resolverelations(ids) - # Add node edges - self.addedges(self.scan(), search) + # Infer node edges using search function + self.inferedges(self.scan(), search) # Label categories/topics if "topics" in self.config: @@ -494,8 +517,8 @@ def upsert(self, search, ids, similarity=None): # Add relationship edges self.resolverelations(ids) - # Add node edges using new/updated nodes, set updated flag for topic processing, if necessary - self.addedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None) + # Infer node edges using new/updated nodes, set updated flag for topic processing, if necessary + self.inferedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None) # Infer topics with topics of connected nodes if hastopics: @@ -505,20 +528,21 @@ def upsert(self, search, ids, similarity=None): else: self.addtopics(similarity) - def filter(self, nodes): + def filter(self, nodes, graph=None): """ Creates a subgraph of this graph using the list of input nodes. This method creates a new graph selecting only matching nodes, edges, topics and categories. 
Args: nodes: nodes to select as a list of ids or list of (id, score) tuples + graph: optional graph used to store filtered results Returns: graph """ - # Create a new empty graph of the same type - graph = type(self)(self.config) + # Set graph if available, otherwise create a new empty graph of the same type + graph = graph if graph else type(self)(self.config) # Initalize subgraph graph.initialize() @@ -587,6 +611,9 @@ def resolverelations(self, ids): ids: internal id resolver """ + # Relationship edges + edges = [] + # Resolve ids and create edges for relationships for node, relations in self.relations.items(): # Resolve internal ids @@ -607,14 +634,17 @@ def resolverelations(self, ids): relation["weight"] = relation.get("weight", 1.0) # Add edge and all other attributes - self.addedge(node, target, **relation) + edges.append((node, target, relation)) + + # Add relationships + self.addedges(edges) # Clear temporary relationship storage self.relations = {} - def addedges(self, nodes, search, attributes=None): + def inferedges(self, nodes, search, attributes=None): """ - Adds edges for a list of nodes using a score-based search function. + Infers edges for a list of nodes using a score-based search function. 
Args: nodes: list of nodes @@ -641,7 +671,7 @@ def addedges(self, nodes, search, attributes=None): self.addattribute(node, field, value) # Skip nodes with existing edges when building an approximate network - if not self.hasedge(node) or not approximate: + if not approximate or not self.hasedge(node): batch.append((node, data)) # Process batch @@ -664,14 +694,17 @@ def addbatch(self, search, batch, limit, minscore): minscore: min score to add node edge """ + edges = [] for x, result in enumerate(search([data for _, data in batch], limit)): # Get input node id x, _ = batch[x] # Add edges for each input node id and result node id pair that meets specified criteria for y, score in result: - if x != y and score > minscore and not self.hasedge(x, y): - self.addedge(x, y, weight=score) + if str(x) != str(y) and score > minscore: + edges.append((x, y, {"weight": score})) + + self.addedges(edges) def addtopics(self, similarity=None): """ diff --git a/src/python/txtai/graph/networkx.py b/src/python/txtai/graph/networkx.py index 82e5deaa7..72edec38a 100644 --- a/src/python/txtai/graph/networkx.py +++ b/src/python/txtai/graph/networkx.py @@ -51,16 +51,19 @@ def scan(self, attribute=None): def node(self, node): return self.backend.nodes.get(node) - def hasnode(self, node): - return self.backend.has_node(node) - def addnode(self, node, **attrs): self.backend.add_node(node, **attrs) + def addnodes(self, nodes): + self.backend.add_nodes_from(nodes) + def removenode(self, node): if self.hasnode(node): self.backend.remove_node(node) + def hasnode(self, node): + return self.backend.has_node(node) + def attribute(self, node, field): return self.node(node).get(field) if self.hasnode(node) else None @@ -81,6 +84,12 @@ def edges(self, node): return None + def addedge(self, source, target, **attrs): + self.backend.add_edge(source, target, **attrs) + + def addedges(self, edges): + self.backend.add_edges_from(edges) + def hasedge(self, source, target=None): if not target: edges = 
self.backend.adj.get(source) @@ -88,9 +97,6 @@ def hasedge(self, source, target=None): return self.backend.has_edge(source, target) - def addedge(self, source, target, **attrs): - self.backend.add_edge(source, target, **attrs) - def centrality(self): rank = nx.degree_centrality(self.backend) return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True)) diff --git a/test/python/testgraph.py b/test/python/testgraph.py index 568ef4ced..322eb9bdf 100644 --- a/test/python/testgraph.py +++ b/test/python/testgraph.py @@ -124,6 +124,20 @@ def testDelete(self): self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 5) self.assertEqual(len(graph.categories), 6) + def testEdges(self): + """ + Test edges + """ + + # Create graph + graph = GraphFactory.create({}) + graph.initialize() + graph.addedge(0, 1) + + # Test edge exists + self.assertTrue(graph.hasedge(0)) + self.assertTrue(graph.hasedge(0, 1)) + def testFilter(self): """ Test creating filtered subgraphs @@ -199,6 +213,7 @@ def testNotImplemented(self): self.assertRaises(NotImplementedError, graph.scan, None) self.assertRaises(NotImplementedError, graph.node, None) self.assertRaises(NotImplementedError, graph.addnode, None) + self.assertRaises(NotImplementedError, graph.addnodes, None) self.assertRaises(NotImplementedError, graph.removenode, None) self.assertRaises(NotImplementedError, graph.hasnode, None) self.assertRaises(NotImplementedError, graph.attribute, None, None) @@ -207,6 +222,7 @@ def testNotImplemented(self): self.assertRaises(NotImplementedError, graph.edgecount) self.assertRaises(NotImplementedError, graph.edges, None) self.assertRaises(NotImplementedError, graph.addedge, None, None) + self.assertRaises(NotImplementedError, graph.addedges, None) self.assertRaises(NotImplementedError, graph.hasedge, None, None) self.assertRaises(NotImplementedError, graph.centrality) self.assertRaises(NotImplementedError, graph.pagerank)