From c4d6d02973fd70e4bb120e6a0a69745c2a0fb1f7 Mon Sep 17 00:00:00 2001
From: Kinga Gajdamowicz
Date: Wed, 20 Dec 2023 16:08:44 +0100
Subject: [PATCH] Enable homo distributed link sampling (#8375)

This implementation covers distributed homogeneous edge sampling;
heterogeneous support is in progress.

The `edge_sample()` function is almost the same as the one from the
`NeighborSampler` class. The difference is that it calls `node_sample()`
from the distributed package.

Why did I copy it into the `DistNeighborSampler`?
- To avoid interfering with the original `edge_sample()` function of
  `NeighborSampler`.
- To make this function async (seamlessly).
- The heterogeneous case will require further changes.

This PR introduces support for edge sampling:
- disjoint, negative, temporal, and bidirectional sampling
- with unit tests
- unit tests for edge-level temporal sampling

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey
---
 CHANGELOG.md                                  |   2 +-
 .../test_dist_link_neighbor_loader.py         |   2 -
 .../test_dist_link_neighbor_sampler.py        | 321 ++++++++++++++++++
 .../distributed/dist_neighbor_sampler.py      | 218 +++++++++++-
 4 files changed, 525 insertions(+), 18 deletions(-)
 create mode 100644 test/distributed/test_dist_link_neighbor_sampler.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bd3866f60d52..a3e838ec0761 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))
 - Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
 - Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
-- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367))
+- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375))
 - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
 - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
 - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
diff --git a/test/distributed/test_dist_link_neighbor_loader.py b/test/distributed/test_dist_link_neighbor_loader.py
index 744f151bca8c..07d8960e0c88 100644
--- a/test/distributed/test_dist_link_neighbor_loader.py
+++ b/test/distributed/test_dist_link_neighbor_loader.py
@@ -90,7 +90,6 @@ def dist_link_neighbor_loader_homo(
     for batch in loader:
         assert isinstance(batch, Data)
         assert batch.n_id.size() == (batch.num_nodes, )
-        assert batch.input_id.numel() == batch.batch_size == 10
         assert batch.edge_index.min() >= 0
         assert batch.edge_index.max() < batch.num_nodes
     assert loader.channel.empty()
@@ -163,7 +162,6 @@ def dist_link_neighbor_loader_hetero(
 @pytest.mark.parametrize('num_workers', [0])
 @pytest.mark.parametrize('async_sampling', [True])
 @pytest.mark.parametrize('neg_ratio', [None])
-@pytest.mark.skip(reason="'sample_from_edges' not yet implemented") def test_dist_link_neighbor_loader_homo( tmp_path, num_parts, diff --git a/test/distributed/test_dist_link_neighbor_sampler.py b/test/distributed/test_dist_link_neighbor_sampler.py new file mode 100644 index 000000000000..e177a00f5e82 --- /dev/null +++ b/test/distributed/test_dist_link_neighbor_sampler.py @@ -0,0 +1,321 @@ +import atexit +import socket +from typing import Optional + +import pytest +import torch + +from torch_geometric.data import Data +from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore +from torch_geometric.distributed.dist_context import DistContext +from torch_geometric.distributed.dist_neighbor_sampler import ( + DistNeighborSampler, + close_sampler, +) +from torch_geometric.distributed.rpc import init_rpc +from torch_geometric.sampler import EdgeSamplerInput, NeighborSampler +from torch_geometric.sampler.neighbor_sampler import edge_sample +from torch_geometric.testing import onlyLinux, withPackage + + +def create_data(rank, world_size, time_attr: Optional[str] = None): + if rank == 0: # Partition 0: + node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9]) + edge_index = torch.tensor([ # Sorted by destination. + [1, 2, 3, 4, 5, 0, 0], + [0, 1, 2, 3, 4, 4, 9], + ]) + else: # Partition 1: + node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9]) + edge_index = torch.tensor([ # Sorted by destination. + [5, 6, 7, 8, 9, 5, 0], + [4, 5, 6, 7, 8, 9, 9], + ]) + + feature_store = LocalFeatureStore.from_data(node_id) + graph_store = LocalGraphStore.from_data( + edge_id=None, + edge_index=edge_index, + num_nodes=10, + is_sorted=True, + ) + + graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + graph_store.meta.update({'num_parts': 2}) + graph_store.partition_idx = rank + graph_store.num_partitions = world_size + + edge_index = torch.tensor([ # Create reference data: + [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0], + [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9], + ]) + data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10) + + if time_attr == 'time': # Create node-level time data: + data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4]) + feature_store.put_tensor(data.time, group_name=None, attr_name='time') + + elif time_attr == 'edge_time': # Create edge-level time data: + data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11]) + + if rank == 0: + edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 11]) + if rank == 1: + edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11]) + + feature_store.put_tensor(edge_time, group_name=None, + attr_name=time_attr) + + return (feature_store, graph_store), data + + +def dist_link_neighbor_sampler( + world_size: int, + rank: int, + master_port: int, + disjoint: bool = False, +): + dist_data, data = create_data(rank, world_size) + + current_ctx = DistContext( + rank=rank, + global_rank=rank, + world_size=world_size, + global_world_size=world_size, + group_name='dist-sampler-test', + ) + + dist_sampler = DistNeighborSampler( + data=dist_data, + current_ctx=current_ctx, + rpc_worker_names={}, + num_neighbors=[-1, -1], + shuffle=False, + disjoint=disjoint, + ) + + # Close RPC & worker group at exit: + atexit.register(close_sampler, 0, dist_sampler) + + init_rpc( + current_ctx=current_ctx, + rpc_worker_names={}, + master_addr='localhost', + master_port=master_port, + ) + + dist_sampler.register_sampler_rpc() + dist_sampler.init_event_loop() + + if rank == 0: # Seed nodes: + input_row = torch.tensor([1, 6], dtype=torch.int64) + input_col = torch.tensor([2, 7], 
dtype=torch.int64) + else: + input_row = torch.tensor([4, 9], dtype=torch.int64) + input_col = torch.tensor([5, 0], dtype=torch.int64) + + inputs = EdgeSamplerInput( + input_id=None, + row=input_row, + col=input_col, + input_type=None, + ) + + # evaluate distributed edge sample function + out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( + inputs, dist_sampler.node_sample, data.num_nodes, disjoint)) + + sampler = NeighborSampler(data=data, num_neighbors=[-1, -1], + disjoint=disjoint) + + # Evaluate edge sample function: + out = edge_sample( + inputs, + sampler._sample, + data.num_nodes, + disjoint, + node_time=None, + neg_sampling=None, + ) + + # Compare distributed output with single machine output: + assert torch.equal(out_dist.node, out.node) + assert torch.equal(out_dist.row, out.row) + assert torch.equal(out_dist.col, out.col) + if disjoint: + assert torch.equal(out_dist.batch, out.batch) + assert out_dist.num_sampled_nodes == out.num_sampled_nodes + assert out_dist.num_sampled_edges == out.num_sampled_edges + + +def dist_link_neighbor_sampler_temporal( + world_size: int, + rank: int, + master_port: int, + seed_time: torch.tensor = None, + temporal_strategy: str = 'uniform', + time_attr: str = 'time', +): + dist_data, data = create_data(rank, world_size, time_attr) + + current_ctx = DistContext( + rank=rank, + global_rank=rank, + world_size=world_size, + global_world_size=world_size, + group_name='dist-sampler-test', + ) + + num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1] + dist_sampler = DistNeighborSampler( + data=dist_data, + current_ctx=current_ctx, + rpc_worker_names={}, + num_neighbors=num_neighbors, + shuffle=False, + disjoint=True, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + ) + + # Close RPC & worker group at exit: + atexit.register(close_sampler, 0, dist_sampler) + + init_rpc( + current_ctx=current_ctx, + rpc_worker_names={}, + master_addr='localhost', + master_port=master_port, + ) + + dist_sampler.register_sampler_rpc() + dist_sampler.init_event_loop() + + if rank == 0: # Seed nodes: + input_row = torch.tensor([1, 6], dtype=torch.int64) + input_col = torch.tensor([2, 7], dtype=torch.int64) + else: + input_row = torch.tensor([4, 9], dtype=torch.int64) + input_col = torch.tensor([5, 0], dtype=torch.int64) + + inputs = EdgeSamplerInput( + input_id=None, + row=input_row, + col=input_col, + time=seed_time, + ) + + # Evaluate distributed edge sample function + out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample( + inputs, dist_sampler.node_sample, data.num_nodes, disjoint=True, + node_time=seed_time, neg_sampling=None)) + + sampler = NeighborSampler( + data=data, + num_neighbors=num_neighbors, + disjoint=True, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + ) + + # Evaluate edge sample function + out = edge_sample( + inputs, + sampler._sample, + data.num_nodes, + disjoint=True, + node_time=seed_time, + neg_sampling=None, + ) + + # Compare distributed output with single machine output + assert torch.equal(out_dist.node, out.node) + assert torch.equal(out_dist.row, out.row) + assert torch.equal(out_dist.col, out.col) + assert torch.equal(out_dist.batch, out.batch) + assert out_dist.num_sampled_nodes == out.num_sampled_nodes + assert out_dist.num_sampled_edges == out.num_sampled_edges + + +@onlyLinux +@withPackage('pyg_lib') +@pytest.mark.parametrize('disjoint', [False, True]) +def test_dist_link_neighbor_sampler(disjoint): + mp_context = 
torch.multiprocessing.get_context('spawn') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + + world_size = 2 + w0 = mp_context.Process( + target=dist_link_neighbor_sampler, + args=(world_size, 0, port, disjoint), + ) + + w1 = mp_context.Process( + target=dist_link_neighbor_sampler, + args=(world_size, 1, port, disjoint), + ) + + w0.start() + w1.start() + w0.join() + w1.join() + + +@onlyLinux +@withPackage('pyg_lib') +@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])]) +@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) +def test_dist_link_neighbor_sampler_temporal(seed_time, temporal_strategy): + mp_context = torch.multiprocessing.get_context('spawn') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + + world_size = 2 + w0 = mp_context.Process( + target=dist_link_neighbor_sampler_temporal, + args=(world_size, 0, port, seed_time, temporal_strategy, 'time'), + ) + + w1 = mp_context.Process( + target=dist_link_neighbor_sampler_temporal, + args=(world_size, 1, port, seed_time, temporal_strategy, 'time'), + ) + + w0.start() + w1.start() + w0.join() + w1.join() + + +@onlyLinux +@withPackage('pyg_lib') +@pytest.mark.parametrize('seed_time', [[1, 1], [3, 7]]) +@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) +def test_dist_neighbor_sampler_edge_level_temporal(seed_time, + temporal_strategy): + + seed_time = torch.tensor(seed_time) + + mp_context = torch.multiprocessing.get_context('spawn') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + + world_size = 2 + w0 = mp_context.Process( + target=dist_link_neighbor_sampler_temporal, + args=(world_size, 0, port, seed_time, temporal_strategy, 'edge_time'), + ) + + w1 = mp_context.Process( + target=dist_link_neighbor_sampler_temporal, + args=(world_size, 1, port, seed_time, temporal_strategy, 'edge_time'), + ) + + w0.start() + w1.start() + w0.join() + w1.join() diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index e3419185706a..2bd0f0654fc1 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -1,6 +1,7 @@ import itertools import logging -from typing import Any, Dict, List, Optional, Tuple, Union +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -26,12 +27,15 @@ remove_duplicates, ) from torch_geometric.sampler import ( + EdgeSamplerInput, HeteroSamplerOutput, + NegativeSampling, NeighborSampler, NodeSamplerInput, SamplerOutput, ) from torch_geometric.sampler.base import NumNeighbors, SubgraphType +from torch_geometric.sampler.neighbor_sampler import neg_sample from torch_geometric.sampler.utils import remap_keys from torch_geometric.typing import EdgeType, NodeType @@ -89,6 +93,7 @@ def __init__( self.disjoint = disjoint self.temporal_strategy = temporal_strategy self.time_attr = time_attr + self.temporal = time_attr is not None self.with_edge_attr = self.feature_store.has_edge_attr() self.csc = True @@ -145,6 +150,28 @@ def sample_from_nodes( coro=self._sample_from(self.node_sample, inputs), callback=cb) return None + # Edge-based distributed sampling ######################################### + + def sample_from_edges( + self, + inputs: EdgeSamplerInput, + neg_sampling: 
Optional[NegativeSampling] = None, + **kwargs, + ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]: + if self.channel is None: + # synchronous sampling + return self.event_loop.run_task(coro=self._sample_from( + self.edge_sample, inputs, self.node_sample, self._sampler. + num_nodes, self.disjoint, self.node_time, neg_sampling)) + + # asynchronous sampling + cb = kwargs.get("callback", None) + self.event_loop.add_task( + coro=self._sample_from(self.edge_sample, inputs, self.node_sample, + self._sampler.num_nodes, self.disjoint, + self.node_time, neg_sampling), callback=cb) + return None + async def _sample_from( self, async_func, @@ -180,10 +207,12 @@ async def node_sample( seed = inputs.node.to(self.device) batch_size = len(inputs.node) + seed_batch = torch.arange(batch_size) if self.disjoint else None + metadata = (inputs.input_id, inputs.time, batch_size) seed_time: Optional[Tensor] = None - if self.time_attr is not None: + if self.temporal: if inputs.time is not None: seed_time = inputs.time.to(self.device) elif self.node_time is not None: @@ -339,10 +368,12 @@ async def node_sample( ) else: src = seed - node = src + node = src.clone() + + src_batch = seed_batch.clone() if self.disjoint else None + batch = seed_batch.clone() if self.disjoint else None - src_batch = torch.arange(len(seed)) if self.disjoint else None - batch = src_batch + src_seed_time = seed_time.clone() if self.temporal else None node_with_dupl = [torch.empty(0, dtype=torch.int64)] batch_with_dupl = [torch.empty(0, dtype=torch.int64)] @@ -354,8 +385,8 @@ async def node_sample( # Loop over the layers: for i, one_hop_num in enumerate(self.num_neighbors): - out = await self.sample_one_hop(src, one_hop_num, seed_time, - src_batch) + out = await self.sample_one_hop(src, one_hop_num, + src_seed_time, src_batch) if out.node.numel() == 0: # No neighbors were sampled: num_zero_layers = self.num_hops - i @@ -373,11 +404,14 @@ async def node_sample( if self.disjoint: batch_with_dupl.append(out.batch) - if seed_time is not None and i < self.num_hops - 1: - # Get the seed time for the next layer based on the - # previous seed_time and sampled neighbors per node info: - seed_time = torch.repeat_interleave( - seed_time, torch.as_tensor(out.metadata[0])) + if self.temporal and i < self.num_hops - 1: + # Assign seed time based on src nodes subgraph IDs. + src_seed_time = [ + seed_time[(seed_batch == batch_idx).nonzero()] + for batch_idx in src_batch + ] + src_seed_time = torch.as_tensor(src_seed_time, + dtype=torch.int64) num_sampled_nodes.append(len(src)) num_sampled_edges.append(len(out.node)) @@ -406,6 +440,160 @@ async def node_sample( return sampler_output + async def edge_sample( + self, + inputs: EdgeSamplerInput, + sample_fn: Callable, + num_nodes: Union[int, Dict[NodeType, int]], + disjoint: bool, + node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None, + neg_sampling: Optional[NegativeSampling] = None, + ) -> Union[SamplerOutput, HeteroSamplerOutput]: + r"""Performs distributed asynchronous sampling from an edge sampler + input, leveraging a sampling function of the same signature as + `node_sample`. This function is almost the same as the `edge_sample` + in the :class:`NeighborSampler`, but calls the `node_sample` from + the distributed package. 
+        """
+        input_id = inputs.input_id
+        src = inputs.row
+        dst = inputs.col
+        edge_label = inputs.label
+        edge_label_time = inputs.time
+        input_type = inputs.input_type
+
+        src_time = dst_time = edge_label_time
+        assert edge_label_time is None or disjoint
+
+        assert isinstance(num_nodes, (dict, int))
+        if not isinstance(num_nodes, dict):
+            num_src_nodes = num_dst_nodes = num_nodes
+        else:
+            num_src_nodes = num_nodes[input_type[0]]
+            num_dst_nodes = num_nodes[input_type[-1]]
+
+        num_pos = src.numel()
+        num_neg = 0
+
+        # Negative Sampling ###############################################
+
+        if neg_sampling is not None:
+            # When we are doing negative sampling, we append negative
+            # information of nodes/edges to `src`, `dst`, `src_time`,
+            # `dst_time`. Later on, we can easily reconstruct what belongs
+            # to positive and negative examples by slicing via `num_pos`.
+            num_neg = math.ceil(num_pos * neg_sampling.amount)
+
+            if neg_sampling.is_binary():
+                # In the "binary" case, we randomly sample negative pairs
+                # of nodes.
+                if isinstance(node_time, dict):
+                    src_node_time = node_time.get(input_type[0])
+                else:
+                    src_node_time = node_time
+
+                src_neg = neg_sample(src, neg_sampling, num_src_nodes,
+                                     src_time, src_node_time)
+                src = torch.cat([src, src_neg], dim=0)
+
+                if isinstance(node_time, dict):
+                    dst_node_time = node_time.get(input_type[-1])
+                else:
+                    dst_node_time = node_time
+
+                dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,
+                                     dst_time, dst_node_time)
+                dst = torch.cat([dst, dst_neg], dim=0)
+
+                if edge_label is None:
+                    edge_label = torch.ones(num_pos)
+                size = (num_neg, ) + edge_label.size()[1:]
+                edge_neg_label = edge_label.new_zeros(size)
+                edge_label = torch.cat([edge_label, edge_neg_label])
+
+                if edge_label_time is not None:
+                    src_time = dst_time = edge_label_time.repeat(
+                        1 + math.ceil(neg_sampling.amount))[:num_pos +
+                                                            num_neg]
+
+            elif neg_sampling.is_triplet():
+                # In the "triplet" case, we randomly sample negative
+                # destinations.
+                if isinstance(node_time, dict):
+                    dst_node_time = node_time.get(input_type[-1])
+                else:
+                    dst_node_time = node_time
+
+                dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,
+                                     dst_time, dst_node_time)
+                dst = torch.cat([dst, dst_neg], dim=0)
+
+                assert edge_label is None
+
+                if edge_label_time is not None:
+                    dst_time = edge_label_time.repeat(1 +
+                                                      neg_sampling.amount)
+
+        # Heterogeneous Neighborhood Sampling #############################
+
+        if input_type is not None:  # TODO: (kgajdamo)
+            raise NotImplementedError
+
+        # Homogeneous Neighborhood Sampling ###############################
+
+        else:
+
+            seed = torch.cat([src, dst], dim=0)
+            seed_time = None
+
+            if not disjoint:
+                seed, inverse_seed = seed.unique(return_inverse=True)
+
+            if edge_label_time is not None:  # Always disjoint.
+                seed_time = torch.cat([src_time, dst_time])
+
+            out = await sample_fn(
+                NodeSamplerInput(
+                    input_id=inputs.input_id,
+                    node=seed,
+                    time=seed_time,
+                    input_type=None,
+                ))
+
+            # Enhance `out` by label information ##########################
+            if neg_sampling is None or neg_sampling.is_binary():
+                if disjoint:
+                    out.batch = out.batch % num_pos
+                    edge_label_index = torch.arange(seed.numel()).view(2, -1)
+                else:
+                    edge_label_index = inverse_seed.view(2, -1)
+
+                out.metadata = (input_id, edge_label_index, edge_label,
+                                src_time)
+
+            elif neg_sampling.is_triplet():
+                if disjoint:
+                    out.batch = out.batch % num_pos
+                    src_index = torch.arange(num_pos)
+                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)
+                    # `dst_neg_index` needs to be offset such that indices
+                    # with offset `num_pos` belong to the same triplet:
+                    dst_neg_index = torch.arange(2 * num_pos, seed.numel())
+                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
+                else:
+                    src_index = inverse_seed[:num_pos]
+                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]
+                    dst_neg_index = inverse_seed[2 * num_pos:]
+                    dst_neg_index = dst_neg_index.view(num_pos,
+                                                       -1).squeeze(-1)
+
+                out.metadata = (
+                    input_id,
+                    src_index,
+                    dst_pos_index,
+                    dst_neg_index,
+                    src_time,
+                )
+
+        return out
+
     def get_sampler_output(
         self,
         outputs: List[SamplerOutput],
@@ -543,7 +731,7 @@ async def sample_one_hop(
             p_mask = partition_ids == p_id
             p_srcs = torch.masked_select(srcs, p_mask)
             p_seed_time = (torch.masked_select(seed_time, p_mask)
-                           if seed_time is not None else None)
+                           if self.temporal else None)
 
             p_indices = torch.arange(len(p_srcs), dtype=torch.long)
             partition_orders[p_mask] = p_indices
@@ -619,12 +807,12 @@ def _sample_one_hop(
             True,  # csc
             self.replace,
             self.subgraph_type != SubgraphType.induced,
-            self.disjoint and self.time_attr is not None,
+            self.disjoint and self.temporal,
             self.temporal_strategy,
         )
         node, edge, cumsum_neighbors_per_node = out
 
-        if self.disjoint and self.time_attr is not None:
+        if self.disjoint and self.temporal:
             # We create a batch during the step of merging sampler outputs.
             _, node = node.t().contiguous()
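
Usage note (editorial addition, not part of the patch): the sketch below shows
one way the new `sample_from_edges()` entry point can be driven once a
`DistNeighborSampler` has been set up and its RPC workers and event loop have
been initialized, as in the tests above. The `dist_sampler` object and the seed
tensors are placeholders; only the homogeneous, synchronous path (no message
channel registered) is shown.

    import torch
    from torch_geometric.sampler import EdgeSamplerInput

    # Seed edges for this partition, given as (row, col) node pairs:
    inputs = EdgeSamplerInput(
        input_id=None,
        row=torch.tensor([1, 6]),
        col=torch.tensor([2, 7]),
        input_type=None,  # Heterogeneous sampling is still in progress.
    )

    # With no channel registered, `sample_from_edges()` runs `edge_sample()`
    # on the internal event loop and returns a `SamplerOutput` directly:
    out = dist_sampler.sample_from_edges(inputs, neg_sampling=None)

    # For positive-only or binary negative sampling, the metadata tuple
    # carries (input_id, edge_label_index, edge_label, seed_time):
    input_id, edge_label_index, edge_label, seed_time = out.metadata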