Skip to content

Commit

Permalink
fix: fix pyg remote backend ut (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhanghyi authored Nov 1, 2024
1 parent 2034837 commit 36ce42b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions graphlearn_torch/python/distributed/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import warnings

import torch
from graphscope.learning.graphlearn_torch.partition.base import PartitionBook
from ..partition import PartitionBook

from ..channel import ShmChannel, QueueTimeoutError
from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig, RemoteSamplerInput
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_node_partition_id(self, node_type, index):

def get_node_feature(self, node_type, index):
feature = self.dataset.get_node_feature(node_type)
return feature[index]
return feature[index].cpu()

def get_tensor_size(self, node_type):
feature = self.dataset.get_node_feature(node_type)
Expand Down

0 comments on commit 36ce42b

Please sign in to comment.