From eca28a8882ccbddf18c73dff6ac9a09757733f5c Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 8 Feb 2024 16:26:40 -0500 Subject: [PATCH] add edge type assertions --- tests/test_adapter.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 7adc9d1..0828361 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -5,7 +5,7 @@ from dgl import DGLGraph, DGLHeteroGraph from dgl.view import EdgeSpace, NodeSpace from pandas import DataFrame -from torch import Tensor, cat, long, tensor +from torch import Tensor, cat, int64, long, tensor from adbdgl_adapter import ADBDGL_Adapter from adbdgl_adapter.encoders import CategoricalEncoder, IdentityEncoder @@ -573,6 +573,8 @@ def test_adb_partial_to_dgl() -> None: # Grab the same nodes from the Homogeneous graph from_nodes_new, to_nodes_new = dgl_g_new.edges(etype=None) + assert from_nodes.dtype == from_nodes_new.dtype + assert to_nodes.dtype == to_nodes_new.dtype assert from_nodes.tolist() == from_nodes_new.tolist() assert to_nodes.tolist() == to_nodes_new.tolist() @@ -772,8 +774,11 @@ def assert_adb_to_dgl( from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() - assert from_nodes == dgl_g.edges(etype=e_key)[0].tolist() - assert to_nodes == dgl_g.edges(etype=e_key)[1].tolist() + src, dst = dgl_g.edges(etype=e_key) + assert src.dtype == int64 + assert dst.dtype == int64 + assert from_nodes == src.tolist() + assert to_nodes == dst.tolist() assert_adb_to_dgl_meta(meta, et_df, dgl_g.edges[e_key].data)