Skip to content

Commit

Permalink
add edge type assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Feb 8, 2024
1 parent 9698df3 commit eca28a8
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dgl import DGLGraph, DGLHeteroGraph
from dgl.view import EdgeSpace, NodeSpace
from pandas import DataFrame
from torch import Tensor, cat, long, tensor
from torch import Tensor, cat, int64, long, tensor

from adbdgl_adapter import ADBDGL_Adapter
from adbdgl_adapter.encoders import CategoricalEncoder, IdentityEncoder
Expand Down Expand Up @@ -573,6 +573,8 @@ def test_adb_partial_to_dgl() -> None:
# Grab the same nodes from the Homogeneous graph
from_nodes_new, to_nodes_new = dgl_g_new.edges(etype=None)

assert from_nodes.dtype == from_nodes_new.dtype
assert to_nodes.dtype == to_nodes_new.dtype
assert from_nodes.tolist() == from_nodes_new.tolist()
assert to_nodes.tolist() == to_nodes_new.tolist()

Expand Down Expand Up @@ -772,8 +774,11 @@ def assert_adb_to_dgl(
from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()

assert from_nodes == dgl_g.edges(etype=e_key)[0].tolist()
assert to_nodes == dgl_g.edges(etype=e_key)[1].tolist()
src, dst = dgl_g.edges(etype=e_key)
assert src.dtype == int64
assert dst.dtype == int64
assert from_nodes == src.tolist()
assert to_nodes == dst.tolist()

assert_adb_to_dgl_meta(meta, et_df, dgl_g.edges[e_key].data)

Expand Down

0 comments on commit eca28a8

Please sign in to comment.