Skip to content

Commit

Permalink
datasets: add node renumbering code #29
Browse files Browse the repository at this point in the history
  • Loading branch information
abhidg committed Jan 10, 2025
1 parent 355842b commit d37375d
Showing 1 changed file with 56 additions and 20 deletions.
76 changes: 56 additions & 20 deletions l2gv2/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EDGE_COLUMNS = {"source", "dest"} # required columns

EdgeList = list[tuple[str, str]]
NodeIndex = dict[str | int, int]


def is_graph_dataset(p: Path) -> bool:
Expand All @@ -42,6 +43,8 @@ def __init__(self, dset: str | Path, timestamp_fmt: str = "%Y-%m-%d"):
if (nodes_path := self.path / (self.path.stem + "_nodes.parquet")).exists():
self.paths["nodes"] = nodes_path

self._node_index_map = None

self._load_files()

def timestamp_from_string(self, ts: str) -> datetime.datetime:
Expand Down Expand Up @@ -108,15 +111,16 @@ def _load_files(self):
x for x in self.nodes.columns if x not in ["timestamp", "label", "nodes"]
]

def get_dates(self) -> list[str]:
    """Return the entries of ``self.datelist`` as a Python list."""
    dates = self.datelist.to_list()
    return dates

def get_edges(self) -> pl.DataFrame:
    """Return the edge table held in ``self.edges`` (a polars DataFrame)."""
    edge_frame = self.edges
    return edge_frame

def get_nodes(self, ts: str | None = None) -> pl.DataFrame:
@property
def timestamps(self) -> list:
    """Return the entries of ``self.datelist`` as a list in ascending order."""
    ordered = list(self.datelist.to_list())
    ordered.sort()
    return ordered

def get_nodes(self, ts: str | datetime.datetime | None = None) -> pl.DataFrame:
"""Returns node data as a polars DataFrame
Args:
Expand All @@ -127,11 +131,10 @@ def get_nodes(self, ts: str | None = None) -> pl.DataFrame:
"""
if ts is None:
return self.nodes
if isinstance(ts, str):
ts = self.timestamp_from_string(ts)
return self.nodes.filter(pl.col("timestamp") == ts)
ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
return self.nodes.filter(pl.col("timestamp") == ts_cast)

def get_node_list(self, ts: str | None = None) -> list[str]:
def get_node_list(self, ts: str | datetime.datetime | None = None) -> list[str]:
"""Returns node list
Args:
Expand All @@ -143,19 +146,10 @@ def get_node_list(self, ts: str | None = None) -> list[str]:
nodes = self.nodes

if ts is not None:
if isinstance(ts, str):
ts = self.timestamp_from_string(ts)
nodes = nodes.filter(pl.col("timestamp") == ts)
ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
nodes = nodes.filter(pl.col("timestamp") == ts_cast)
return nodes.select("nodes").unique(maintain_order=True).to_series().to_list()

def get_node_features(self) -> list[str]:
    """Return the node features (``self.node_features``) as a list of strings."""
    features = self.node_features
    return features

def get_edge_features(self) -> list[str]:
    """Return the edge features (``self.edge_features``) as a list of strings."""
    features = self.edge_features
    return features

def get_graph(self) -> rp.Graph: # pylint: disable=no-member
"Returns a raphtory.Graph representation"
g = rp.Graph() # pylint: disable=no-member
Expand Down Expand Up @@ -253,6 +247,48 @@ def get_edge_index(
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
return edge_index

def get_node_index_map(
    self, ts: str | datetime.datetime | None = None
) -> NodeIndex:
    """Returns mapping from node value to an integer index.

    Local2Global requires 0..|V|-1 indexing for a node set V.
    Nodes have to be indexed from 0 both within a patch (timestamp),
    and globally. This function returns a dictionary mapping
    node values to the index, numbered in the order in which
    ``get_node_list`` yields them.

    Parameters
    ----------
    ts
        If specified, return index map for patch with timestamp `ts`.
        A string is first converted with ``timestamp_from_string``.
        If omitted, the global map over all nodes is returned; it is
        built on the first call and cached on the instance afterwards.

    Returns
    -------
    NodeIndex
        Dictionary from node value to its 0-based index.
    """
    if ts is None:
        if self._node_index_map is None:
            self._node_index_map = {
                node: idx for idx, node in enumerate(self.get_node_list())
            }
        # Return the cached map directly: falling through to the per-patch
        # path would rebuild the very same mapping a second time.
        return self._node_index_map

    ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
    return {node: idx for idx, node in enumerate(self.get_node_list(ts_cast))}

def get_renumbered_nodes(self) -> list[list[int]]:
    """Returns the nodes of every timestamp patch, renumbered globally.

    Entry $R_i$ of the result holds the nodes present in patch $i$,
    each replaced by its index in the global node index map (see
    :meth:`DataLoader.get_node_index_map`). Patches follow the order of
    ``self.timestamps``, so index 0 is the earliest timestamp.
    """
    global_index = self.get_node_index_map()
    return [
        [global_index[node] for node in self.get_node_list(stamp)]
        for stamp in self.timestamps
    ]

def get_tgeometric(
self, temp: bool = True
) -> torch_geometric.data.Data | dict[datetime.datetime, torch_geometric.data.Data]:
Expand Down

0 comments on commit d37375d

Please sign in to comment.