From 36ce42bb5c19633b63a802d2cbe5d9c3746716ed Mon Sep 17 00:00:00 2001 From: Hongyi ZHANG <50618951+Zhanghyi@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:51:16 +0800 Subject: [PATCH] fix: fix pyg remote backend ut (#147) --- graphlearn_torch/python/distributed/dist_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphlearn_torch/python/distributed/dist_server.py b/graphlearn_torch/python/distributed/dist_server.py index 283893c4..264e650d 100644 --- a/graphlearn_torch/python/distributed/dist_server.py +++ b/graphlearn_torch/python/distributed/dist_server.py @@ -20,7 +20,7 @@ import warnings import torch -from graphscope.learning.graphlearn_torch.partition.base import PartitionBook +from ..partition import PartitionBook from ..channel import ShmChannel, QueueTimeoutError from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig, RemoteSamplerInput @@ -95,7 +95,7 @@ def get_node_partition_id(self, node_type, index): def get_node_feature(self, node_type, index): feature = self.dataset.get_node_feature(node_type) - return feature[index] + return feature[index].cpu() def get_tensor_size(self, node_type): feature = self.dataset.get_node_feature(node_type)