Fix broken master (pyg-team#8646)
rusty1s authored Dec 20, 2023
1 parent c4d6d02 commit 11a29b0
Showing 1 changed file with 18 additions and 18 deletions.
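In short, the tests are brought in line with the current torch_geometric.distributed API: DistNeighborSampler, LocalFeatureStore, and LocalGraphStore are imported from the top-level torch_geometric.distributed package, the removed rpc_worker_names argument is dropped from both DistNeighborSampler and init_rpc, shutdown_rpc is registered at exit in place of close_sampler, and the sampler's event loop is constructed explicitly via ConcurrentEventLoop instead of the former init_event_loop() call.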
test/distributed/test_dist_link_neighbor_sampler.py (36 changes: 18 additions & 18 deletions)

@@ -6,13 +6,14 @@
 import torch
 
 from torch_geometric.data import Data
-from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
-from torch_geometric.distributed.dist_context import DistContext
-from torch_geometric.distributed.dist_neighbor_sampler import (
+from torch_geometric.distributed import (
     DistNeighborSampler,
-    close_sampler,
+    LocalFeatureStore,
+    LocalGraphStore,
 )
-from torch_geometric.distributed.rpc import init_rpc
+from torch_geometric.distributed.dist_context import DistContext
+from torch_geometric.distributed.event_loop import ConcurrentEventLoop
+from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc
 from torch_geometric.sampler import EdgeSamplerInput, NeighborSampler
 from torch_geometric.sampler.neighbor_sampler import edge_sample
 from torch_geometric.testing import onlyLinux, withPackage
@@ -88,24 +89,23 @@ def dist_link_neighbor_sampler(
     dist_sampler = DistNeighborSampler(
         data=dist_data,
         current_ctx=current_ctx,
-        rpc_worker_names={},
         num_neighbors=[-1, -1],
         shuffle=False,
         disjoint=disjoint,
     )
 
     # Close RPC & worker group at exit:
-    atexit.register(close_sampler, 0, dist_sampler)
+    atexit.register(shutdown_rpc)
 
     init_rpc(
         current_ctx=current_ctx,
-        rpc_worker_names={},
         master_addr='localhost',
         master_port=master_port,
     )
 
     dist_sampler.init_sampler_instance()
     dist_sampler.register_sampler_rpc()
-    dist_sampler.init_event_loop()
+    dist_sampler.event_loop = ConcurrentEventLoop(2)
+    dist_sampler.event_loop.start_loop()
 
     if rank == 0:  # Seed nodes:
         input_row = torch.tensor([1, 6], dtype=torch.int64)
@@ -170,7 +170,6 @@ def dist_link_neighbor_sampler_temporal(
     dist_sampler = DistNeighborSampler(
         data=dist_data,
         current_ctx=current_ctx,
-        rpc_worker_names={},
         num_neighbors=num_neighbors,
         shuffle=False,
         disjoint=True,
@@ -179,17 +178,17 @@
     )
 
     # Close RPC & worker group at exit:
-    atexit.register(close_sampler, 0, dist_sampler)
+    atexit.register(shutdown_rpc)
 
     init_rpc(
         current_ctx=current_ctx,
-        rpc_worker_names={},
         master_addr='localhost',
         master_port=master_port,
     )
 
     dist_sampler.init_sampler_instance()
     dist_sampler.register_sampler_rpc()
-    dist_sampler.init_event_loop()
+    dist_sampler.event_loop = ConcurrentEventLoop(2)
+    dist_sampler.event_loop.start_loop()
 
     if rank == 0:  # Seed nodes:
         input_row = torch.tensor([1, 6], dtype=torch.int64)
@@ -294,9 +293,10 @@ def test_dist_link_neighbor_sampler_temporal(seed_time, temporal_strategy):
 @withPackage('pyg_lib')
 @pytest.mark.parametrize('seed_time', [[1, 1], [3, 7]])
 @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
-def test_dist_neighbor_sampler_edge_level_temporal(seed_time,
-                                                   temporal_strategy):
-
+def test_dist_neighbor_sampler_edge_level_temporal(
+    seed_time,
+    temporal_strategy,
+):
     seed_time = torch.tensor(seed_time)
 
     mp_context = torch.multiprocessing.get_context('spawn')
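For reference, the setup sequence shared by both test helpers after this change looks roughly as follows. This is a minimal sketch assembled from the added lines in the diff above; setup_dist_sampler is a hypothetical wrapper name, and dist_data, current_ctx, and master_port are assumed to be built by the surrounding test exactly as elsewhere in this file.

import atexit

from torch_geometric.distributed import DistNeighborSampler
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.event_loop import ConcurrentEventLoop
from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc


def setup_dist_sampler(dist_data, current_ctx: DistContext, master_port: int):
    # `rpc_worker_names` is no longer passed to `DistNeighborSampler`:
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        num_neighbors=[-1, -1],
        shuffle=False,
        disjoint=False,
    )

    # Close RPC & worker group at exit via the plain `shutdown_rpc` helper
    # (previously `close_sampler`):
    atexit.register(shutdown_rpc)

    # `rpc_worker_names` is dropped from `init_rpc` as well:
    init_rpc(
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.init_sampler_instance()
    dist_sampler.register_sampler_rpc()

    # The event loop is now created and started explicitly, replacing the
    # former `init_event_loop()` call:
    dist_sampler.event_loop = ConcurrentEventLoop(2)
    dist_sampler.event_loop.start_loop()

    return dist_sampler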
