Skip to content

Commit

Permalink
Add batch node and edge creation for graphs, closes #693
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Apr 17, 2024
1 parent 9f9ed6d commit 8277ce9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 20 deletions.
61 changes: 47 additions & 14 deletions src/python/txtai/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def addnode(self, node, **attrs):

raise NotImplementedError

def addnodes(self, nodes):
"""
Adds nodes to the graph.
Args:
nodes: list of (node, attributes) to add
"""

raise NotImplementedError

def removenode(self, node):
"""
Removes a node and all its edges from graph.
Expand Down Expand Up @@ -200,6 +210,16 @@ def addedge(self, source, target, **attrs):

raise NotImplementedError

def addedges(self, edges):
"""
Adds edges to graph.
Args:
edges: list of (source, target, attributes) to add
"""

raise NotImplementedError

def hasedge(self, source, target=None):
"""
Returns True if edge found, False otherwise. If target is None, this method
Expand Down Expand Up @@ -409,6 +429,7 @@ def insert(self, documents, index=0):
# Initialize graph backend
self.initialize()

nodes = []
for uid, document, _ in documents:
# Relationships are manually-provided edges
relations = None
Expand All @@ -427,13 +448,15 @@ def insert(self, documents, index=0):
document = " ".join(document)

# Create node
self.addnode(index, id=uid, data=document)
nodes.append((index, {"id": uid, "data": document}))

# Add relationships
self.addrelations(index, relations)

index += 1

self.addnodes(nodes)

def delete(self, ids):
"""
Deletes ids from graph.
Expand Down Expand Up @@ -471,8 +494,8 @@ def index(self, search, ids, similarity):
# Add relationship edges
self.resolverelations(ids)

# Add node edges
self.addedges(self.scan(), search)
# Infer node edges using search function
self.inferedges(self.scan(), search)

# Label categories/topics
if "topics" in self.config:
Expand All @@ -494,8 +517,8 @@ def upsert(self, search, ids, similarity=None):
# Add relationship edges
self.resolverelations(ids)

# Add node edges using new/updated nodes, set updated flag for topic processing, if necessary
self.addedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None)
# Infer node edges using new/updated nodes, set updated flag for topic processing, if necessary
self.inferedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None)

# Infer topics with topics of connected nodes
if hastopics:
Expand All @@ -505,20 +528,21 @@ def upsert(self, search, ids, similarity=None):
else:
self.addtopics(similarity)

def filter(self, nodes):
def filter(self, nodes, graph=None):
"""
Creates a subgraph of this graph using the list of input nodes. This method creates a new graph
selecting only matching nodes, edges, topics and categories.
Args:
nodes: nodes to select as a list of ids or list of (id, score) tuples
graph: optional graph used to store filtered results
Returns:
graph
"""

# Create a new empty graph of the same type
graph = type(self)(self.config)
# Set graph if available, otherwise create a new empty graph of the same type
graph = graph if graph else type(self)(self.config)

# Initialize subgraph
graph.initialize()
Expand Down Expand Up @@ -587,6 +611,9 @@ def resolverelations(self, ids):
ids: internal id resolver
"""

# Relationship edges
edges = []

# Resolve ids and create edges for relationships
for node, relations in self.relations.items():
# Resolve internal ids
Expand All @@ -607,14 +634,17 @@ def resolverelations(self, ids):
relation["weight"] = relation.get("weight", 1.0)

# Add edge and all other attributes
self.addedge(node, target, **relation)
edges.append((node, target, relation))

# Add relationships
self.addedges(edges)

# Clear temporary relationship storage
self.relations = {}

def addedges(self, nodes, search, attributes=None):
def inferedges(self, nodes, search, attributes=None):
"""
Adds edges for a list of nodes using a score-based search function.
Infers edges for a list of nodes using a score-based search function.
Args:
nodes: list of nodes
Expand All @@ -641,7 +671,7 @@ def addedges(self, nodes, search, attributes=None):
self.addattribute(node, field, value)

# Skip nodes with existing edges when building an approximate network
if not self.hasedge(node) or not approximate:
if not approximate or not self.hasedge(node):
batch.append((node, data))

# Process batch
Expand All @@ -664,14 +694,17 @@ def addbatch(self, search, batch, limit, minscore):
minscore: min score to add node edge
"""

edges = []
for x, result in enumerate(search([data for _, data in batch], limit)):
# Get input node id
x, _ = batch[x]

# Add edges for each input node id and result node id pair that meets specified criteria
for y, score in result:
if x != y and score > minscore and not self.hasedge(x, y):
self.addedge(x, y, weight=score)
if str(x) != str(y) and score > minscore:
edges.append((x, y, {"weight": score}))

self.addedges(edges)

def addtopics(self, similarity=None):
"""
Expand Down
18 changes: 12 additions & 6 deletions src/python/txtai/graph/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ def scan(self, attribute=None):
def node(self, node):
return self.backend.nodes.get(node)

def hasnode(self, node):
return self.backend.has_node(node)

def addnode(self, node, **attrs):
self.backend.add_node(node, **attrs)

def addnodes(self, nodes):
self.backend.add_nodes_from(nodes)

def removenode(self, node):
if self.hasnode(node):
self.backend.remove_node(node)

def hasnode(self, node):
return self.backend.has_node(node)

def attribute(self, node, field):
return self.node(node).get(field) if self.hasnode(node) else None

Expand All @@ -81,16 +84,19 @@ def edges(self, node):

return None

def addedge(self, source, target, **attrs):
self.backend.add_edge(source, target, **attrs)

def addedges(self, edges):
self.backend.add_edges_from(edges)

def hasedge(self, source, target=None):
if not target:
edges = self.backend.adj.get(source)
return len(edges) > 0 if edges else False

return self.backend.has_edge(source, target)

def addedge(self, source, target, **attrs):
self.backend.add_edge(source, target, **attrs)

def centrality(self):
rank = nx.degree_centrality(self.backend)
return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))
Expand Down
16 changes: 16 additions & 0 deletions test/python/testgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ def testDelete(self):
self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 5)
self.assertEqual(len(graph.categories), 6)

def testEdges(self):
"""
Test edges
"""

# Create graph
graph = GraphFactory.create({})
graph.initialize()
graph.addedge(0, 1)

# Test edge exists
self.assertTrue(graph.hasedge(0))
self.assertTrue(graph.hasedge(0, 1))

def testFilter(self):
"""
Test creating filtered subgraphs
Expand Down Expand Up @@ -199,6 +213,7 @@ def testNotImplemented(self):
self.assertRaises(NotImplementedError, graph.scan, None)
self.assertRaises(NotImplementedError, graph.node, None)
self.assertRaises(NotImplementedError, graph.addnode, None)
self.assertRaises(NotImplementedError, graph.addnodes, None)
self.assertRaises(NotImplementedError, graph.removenode, None)
self.assertRaises(NotImplementedError, graph.hasnode, None)
self.assertRaises(NotImplementedError, graph.attribute, None, None)
Expand All @@ -207,6 +222,7 @@ def testNotImplemented(self):
self.assertRaises(NotImplementedError, graph.edgecount)
self.assertRaises(NotImplementedError, graph.edges, None)
self.assertRaises(NotImplementedError, graph.addedge, None, None)
self.assertRaises(NotImplementedError, graph.addedges, None)
self.assertRaises(NotImplementedError, graph.hasedge, None, None)
self.assertRaises(NotImplementedError, graph.centrality)
self.assertRaises(NotImplementedError, graph.pagerank)
Expand Down

0 comments on commit 8277ce9

Please sign in to comment.