Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert a personal-graph to and from NetworkX #20

Merged
merged 40 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8031a87
feature: added networkx dependencies
anubhuti24 Apr 19, 2024
1299df3
Merge branch 'main' into enhancement/export-to-PyG-NetworkX
anubhuti24 Apr 22, 2024
224428f
feature: added to_networkX method
anubhuti24 Apr 22, 2024
0ac7be9
refactor: updated examples
anubhuti24 Apr 22, 2024
6f866d3
refactor: networkX method
anubhuti24 Apr 22, 2024
57b9320
feature: added post visualize parameter
anubhuti24 Apr 22, 2024
dfae1a9
feature: added to_personal method
anubhuti24 Apr 23, 2024
c5ce2a7
added helper functions
anubhuti24 Apr 23, 2024
cdf104d
updated examples
anubhuti24 Apr 23, 2024
a5f9731
updated to_networkX method
anubhuti24 Apr 23, 2024
82c2795
refactor: updated from_networkX method
anubhuti24 Apr 23, 2024
5ebb0b9
refactor: examples and ml.py
anubhuti24 Apr 23, 2024
00b39df
refactor: updated logic for networkX
anubhuti24 Apr 23, 2024
1a32843
refactor: updated examples
anubhuti24 Apr 23, 2024
8764036
updated ml.py
anubhuti24 Apr 23, 2024
a43b2a8
refactor: updated conftest.py
anubhuti24 Apr 24, 2024
eebb854
feature: added unit tests
anubhuti24 Apr 24, 2024
b679162
refactor: updated ml.py
anubhuti24 Apr 24, 2024
d095365
refactor: updated unit tests
anubhuti24 Apr 24, 2024
5ae463a
refactor: ml.py
anubhuti24 Apr 24, 2024
abac4af
refactor: updated networkX visualisation
anubhuti24 Apr 24, 2024
3ae3bc5
bugfx: unit tests
anubhuti24 Apr 24, 2024
a274585
refactor: updated unit tests
anubhuti24 Apr 24, 2024
e720fcd
refactor: added override param
anubhuti24 Apr 24, 2024
0bfb0b8
updated examples
anubhuti24 Apr 24, 2024
c3e16ea
bugfix: retrieve answer as sub graph
anubhuti24 Apr 25, 2024
2eeec71
refactor: updated ml.py
anubhuti24 Apr 25, 2024
ba909cc
bugfix: test_ml.py
anubhuti24 Apr 25, 2024
3edd84b
refactor: unit test workflow
anubhuti24 Apr 25, 2024
3152523
refactor: updated override logic
anubhuti24 Apr 25, 2024
f76218d
refctor: kgchat script
anubhuti24 Apr 25, 2024
4f45957
refactor: library version
anubhuti24 Apr 25, 2024
f67dcc1
feature: added __eq__ to compare objects
anubhuti24 Apr 25, 2024
c25d4bd
refactor: updated examples to compare objects
anubhuti24 Apr 25, 2024
2c7c8b2
bugfix: updted ml.py
anubhuti24 Apr 25, 2024
5846f10
bugfix: resolved transaction timeout issue
anubhuti24 Apr 25, 2024
02d037d
bugfix: unit tests
anubhuti24 Apr 25, 2024
5f6192a
bugfix: test_ml
anubhuti24 Apr 25, 2024
a5926a2
refactor: added high level imports
anubhuti24 Apr 25, 2024
388dc01
updated __init__.py
anubhuti24 Apr 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/personal-graph.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ jobs:

- name: Test with pytest
run: |
cd personal_graph
cd tests
poetry run pytest -vvv
9 changes: 9 additions & 0 deletions examples/pythonic_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import logging
from personal_graph.graph import Graph
from personal_graph.ml import to_networkx, from_networkx
from personal_graph.models import Node, EdgeInput, KnowledgeGraph, Edge
sutyum marked this conversation as resolved.
Show resolved Hide resolved


Expand Down Expand Up @@ -104,6 +105,14 @@ def main(url, token):
)
logging.info(graph.visualize_graph(kg))

# Transforms to and from networkx do not alter the graph
g2 = from_networkx(
to_networkx(graph, post_visualize=True), post_visualize=True, override=False
)
if graph == g2:
logging.info("TRUE")
assert graph == g2

graph.save()


Expand Down
41 changes: 41 additions & 0 deletions personal_graph/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,16 @@ def _find_node(cursor, connection):
return _find_node


def find_label(identifier: Any) -> CursorExecFunction:
def _find_label(cursor, connection):
node_label = cursor.execute(
"SELECT label from nodes where id=?", (identifier,)
).fetchone()
return node_label

return _find_label


def _parse_search_results(results: List[Tuple], idx: int = 0) -> List[Dict]:
return [json.loads(item[idx]) for item in results]

Expand Down Expand Up @@ -650,6 +660,9 @@ def _merge(cursor, connection):
cursor, connection
)

if similar_nodes is None:
continue

if len(similar_nodes) < 1:
continue

Expand Down Expand Up @@ -800,6 +813,34 @@ def _fetch_nodes_from_db(cursor, connection):
return _fetch_nodes_from_db


def find_indegree_edges(target_id: Any) -> CursorExecFunction:
def _indegree_edges(cursor, connection):
indegree = cursor.execute(
"SELECT source, label, attributes from edges where target=? ", (target_id,)
)

if indegree:
indegree = indegree.fetchall()

return indegree

return _indegree_edges


def find_outdegree_edges(source_id: Any) -> CursorExecFunction:
def _outdegree_edges(cursor, connection):
outdegree = cursor.execute(
"SELECT target, label, attributes from edges where source=? ", (source_id,)
)

if outdegree:
outdegree = outdegree.fetchall()

return outdegree

return _outdegree_edges


def all_connected_nodes(node_or_edge: Union[Node | Edge]) -> CursorExecFunction:
def _connected_nodes(cursor, connection):
nodes = None
Expand Down
17 changes: 17 additions & 0 deletions personal_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
find_similar_nodes,
nodes_list,
vector_search_node,
find_label,
find_outdegree_edges,
find_indegree_edges,
)
from .natural import insert_into_graph, search_from_graph, visualize_knowledge_graph
from .visualizers import graphviz_visualize
Expand All @@ -38,6 +41,11 @@ def __init__(self, db_url: Optional[str] = None, auth_token: Optional[str] = Non
self.db_url = db_url
self.auth_token = auth_token

def __eq__(self, other):
if not isinstance(other, Graph):
return "Not of Graph Type"
return self.db_url == other.db_url and self.auth_token == other.auth_token

def __enter__(self, schema_file: str = "schema.sql") -> Graph:
if not self.db_url:
# Support for local SQLite database
Expand Down Expand Up @@ -131,6 +139,9 @@ def remove_nodes(self, ids: List[Any]) -> None:
def search_node(self, node_id: Any) -> Any:
return atomic(find_node(node_id), self.db_url, self.auth_token)

def search_node_label(self, node_id: Any) -> Any:
return atomic(find_label(node_id), self.db_url, self.auth_token)

def traverse(
self, source: Any, target: Optional[Any] = None, with_bodies: bool = False
) -> List:
Expand Down Expand Up @@ -170,6 +181,12 @@ def visualize(self, file: str, path: List[str]) -> Digraph:
def fetch_ids_from_db(self) -> List[str]:
return atomic(nodes_list(), self.db_url, self.auth_token)

def search_indegree_edges(self, target) -> List[Any]:
return atomic(find_indegree_edges(target), self.db_url, self.auth_token)

def search_outdegree_edges(self, source) -> List[Any]:
return atomic(find_outdegree_edges(source), self.db_url, self.auth_token)

def is_unique_prompt(self, text: str, threshold: float) -> bool:
similar_nodes = atomic(
vector_search_node({"body": text}, threshold, 1),
Expand Down
187 changes: 183 additions & 4 deletions personal_graph/ml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,183 @@
"""
Provide functions to convert the data in the graph db into formats compatible with
graph analysis frameworks such as Networkx and DGL
"""
import json
import os

import networkx as nx # type: ignore
from typing import Dict, Any
import matplotlib.pyplot as plt
from graphviz import Digraph # type: ignore

import personal_graph.graph as pg
from personal_graph.models import Node, KnowledgeGraph, Edge, EdgeInput


def to_networkx(graph: pg.Graph, *, post_visualize: bool = False) -> nx.Graph:
"""
Convert the graph database to a NetworkX DiGraph object.
"""
G = nx.Graph() # Empty Graph with no nodes and edges

node_ids = graph.fetch_ids_from_db()
# Add edges to networkX
for source_id in node_ids:
for target_id, edge_label, edge_data in graph.search_outdegree_edges(source_id):
edge_data = json.loads(edge_data)
edge_data["label"] = edge_label
G.add_edge(source_id, target_id, **edge_data)

for target_id in node_ids:
for source_id, edge_label, edge_data in graph.search_indegree_edges(target_id):
edge_data = json.loads(edge_data)
edge_data["label"] = edge_label
G.add_edge(source_id, target_id, **edge_data)

node_ids_with_edges = set([node for edge in G.edges() for node in edge])
for node_id in node_ids:
if node_id not in node_ids_with_edges:
node_data = graph.search_node(node_id)
node_label = graph.search_node_label(node_id)
node_data["label"] = node_label
G.add_node(node_id, **node_data)

if post_visualize:
# Visualizing the NetworkX Graph
plt.figure(figsize=(20, 20), dpi=100) # Increase the figure size and resolution
pos = nx.spring_layout(
G, scale=6
) # Use spring layout for better node positioning

nx.draw_networkx(
G,
pos,
with_labels=True,
nodelist=G.nodes(),
edgelist=G.edges(),
node_size=600,
node_color="skyblue",
edge_color="gray",
width=1.5,
)
nx.draw_networkx_edge_labels(
G, pos, edge_labels=nx.get_edge_attributes(G, "label")
)
plt.axis("off") # Show the axes
plt.savefig("networkX_graph.png")

return G


def from_networkx(
network_graph: nx, *, post_visualize: bool = False, override: bool = True
) -> pg.Graph:
with pg.Graph(os.getenv("LIBSQL_URL"), os.getenv("LIBSQL_AUTH_TOKEN")) as graph:
if override:
node_ids = graph.fetch_ids_from_db()
graph.remove_nodes(node_ids)

node_ids_with_edges = set()
kg = KnowledgeGraph()

# Convert networkX edges to personal graph edges
for source_id, target_id, edge_data in network_graph.edges(data=True):
edge_attributes: Dict[str, Any] = edge_data
edge_label: str = edge_attributes["label"]
node_ids_with_edges.add(str(source_id))
node_ids_with_edges.add(str(target_id))

if not override:
# Check if the node with the given id exists, if not then firstly add the node.
source = graph.search_node(source_id)
if source is []:
node_ids_with_edges.remove(str(source_id))
graph.add_node(
Node(
id=str(source_id),
label=edge_label if edge_label else "",
attributes=edge_attributes,
)
)

target = graph.search_node(target_id)
if target is []:
node_ids_with_edges.remove(str(target_id))
graph.add_node(
Node(
id=str(target_id),
label=edge_label if edge_label else "",
attributes=edge_attributes,
)
)

# After adding the new nodes if exists , add an edge
edge = Edge(
source=str(source_id),
target=str(target_id),
label=edge_label if edge_label else "",
attributes=edge_attributes,
)
kg.edges.append(edge)

# Convert networkX nodes to personal graph nodes
for node_id, node_data in network_graph.nodes(data=True):
if str(node_id) not in node_ids_with_edges:
node_attributes: Dict[str, Any] = node_data
node_label: str = node_attributes.pop("label", "")
node = Node(
id=str(node_id),
label=node_label[0] if node_label else "",
attributes=node_attributes,
)

if not override:
# Check if the node exists
if_node_exists = graph.search_node(node_id)

if if_node_exists:
graph.update_node(node)
else:
graph.add_node(node)
kg.nodes.append(node)

for edge in kg.edges:
source_node_attributes = graph.search_node(edge.source)
source_node_label = graph.search_node_label(edge.source)
target_node_attributes = graph.search_node(edge.target)
target_node_label = graph.search_node_label(edge.target)
final_edge_to_be_inserted = EdgeInput(
source=Node(
id=edge.source,
label=source_node_label
if isinstance(source_node_label, str)
else "Sample label",
attributes=source_node_attributes
if isinstance(source_node_attributes, Dict)
else "Sample Attributes",
),
target=Node(
id=edge.target,
label=target_node_label
if isinstance(target_node_label, str)
else "Sample label",
attributes=target_node_attributes
if isinstance(target_node_attributes, Dict)
else "Sample Attributes",
),
label=edge.label if isinstance(edge.label, str) else "Sample label",
attributes=edge.attributes
if isinstance(edge_attributes, Dict)
else "Sample Attributes",
)
graph.add_edge(final_edge_to_be_inserted)

if post_visualize:
# Visualize the personal graph using graphviz
dot = Digraph()

for node in kg.nodes:
dot.node(node.id, label=f"{node.label}: {node.id}")

for edge in kg.edges:
dot.edge(edge.source, edge.target, label=edge.label)

dot.render("personal_graph.gv", view=True)

return graph
1 change: 0 additions & 1 deletion personal_graph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class Edge(BaseModel):
)
label: str = Field(
...,
min_length=2,
description="Most related and unique name associated with the edge.",
)
attributes: Union[str, Dict[str, str]] = Field(
Expand Down
Empty file removed personal_graph/tests/test_ml.py
Empty file.
Loading
Loading