From a031488062a654d95c54ebbc0ee71434c286c329 Mon Sep 17 00:00:00 2001 From: LiSu Date: Thu, 1 Feb 2024 11:50:10 +0000 Subject: [PATCH 1/2] Fix drop last in distributed sampler --- examples/igbh/dist_train_rgnn.py | 3 ++- graphlearn_torch/python/distributed/dist_loader.py | 4 ++-- .../python/distributed/dist_neighbor_sampler.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/igbh/dist_train_rgnn.py b/examples/igbh/dist_train_rgnn.py index 0e72c962..1930f7c4 100644 --- a/examples/igbh/dist_train_rgnn.py +++ b/examples/igbh/dist_train_rgnn.py @@ -166,7 +166,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, num_neighbors=[int(fanout) for fanout in fan_out.split(',')], input_nodes=('paper', val_idx), batch_size=val_batch_size, - shuffle=False, + shuffle=True, + drop_last=False, edge_dir=edge_dir, collect_features=True, to_device=current_device, diff --git a/graphlearn_torch/python/distributed/dist_loader.py b/graphlearn_torch/python/distributed/dist_loader.py index d52b2919..76f33fe9 100644 --- a/graphlearn_torch/python/distributed/dist_loader.py +++ b/graphlearn_torch/python/distributed/dist_loader.py @@ -388,7 +388,7 @@ def _collate_fn( if self.sampling_config.sampling_type in [SamplingType.NODE, SamplingType.SUBGRAPH]: batch_dict = { - self._input_type: node_dict[self._input_type][:self.batch_size] + self._input_type: msg[f'{self._input_type}.batch'].to(self.to_device) } batch_labels_key = f'{self._input_type}.nlabels' if batch_labels_key in msg: @@ -426,7 +426,7 @@ def _collate_fn( if self.sampling_config.sampling_type in [SamplingType.NODE, SamplingType.SUBGRAPH]: - batch = ids[:self.batch_size] + batch = msg['batch'].to(self.to_device) batch_labels = msg['nlabels'].to(self.to_device) if 'nlabels' in msg else None else: batch = None diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py index 79d2cb70..ec04e87c 100644 --- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py +++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py @@ -282,6 +282,7 @@ async def _sample_from_nodes( if is_hetero: assert input_type is not None src_dict = inducer.init_node({input_type: input_seeds}) + batch = src_dict out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {} num_sampled_nodes, num_sampled_edges = {}, {} merge_dict(src_dict, out_nodes) @@ -329,6 +330,7 @@ async def _sample_from_nodes( {etype: torch.cat(eids) for etype, eids in out_edges.items()} if self.with_edge else None ), + batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, input_type=input_type, @@ -337,6 +339,7 @@ async def _sample_from_nodes( else: srcs = inducer.init_node(input_seeds) + batch = srcs out_nodes, out_edges = [], [] num_sampled_nodes, num_sampled_edges = [], [] out_nodes.append(srcs) @@ -359,6 +362,7 @@ async def _sample_from_nodes( row=torch.cat([e[0] for e in out_edges]), col=torch.cat([e[1] for e in out_edges]), edge=(torch.cat([e[2] for e in out_edges]) if self.with_edge else None), + batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, metadata={} @@ -717,6 +721,10 @@ async def _colloate_fn( result_map[f'{as_str(etype)}.efeats'] = efeats elif self.edge_dir == 'in': result_map[f'{as_str(reverse_edge_type(etype))}.efeats'] = efeats + # Collect batch info + if output.batch is not None: + for ntype, batch in output.batch.items(): + result_map[f'{as_str(ntype)}.batch'] = batch else: result_map['ids'] = output.node result_map['rows'] = output.row @@ -743,5 +751,7 @@ async def _colloate_fn( fut = self.dist_edge_feature.async_get(eids) efeats = await wrap_torch_future(fut) result_map['efeats'] = efeats + # Collect batch info + result_map['batch'] = output.batch return result_map From 0086db6ea0a71e3359e30b2210429176e180868f Mon Sep 17 00:00:00 2001 From: LiSu Date: Thu, 1 Feb 2024 12:24:59 +0000 Subject: [PATCH 2/2] minor --- .../python/distributed/dist_loader.py | 17 +++++++++++++---- .../python/distributed/dist_neighbor_sampler.py | 3 ++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/graphlearn_torch/python/distributed/dist_loader.py b/graphlearn_torch/python/distributed/dist_loader.py index 76f33fe9..56f953e2 100644 --- a/graphlearn_torch/python/distributed/dist_loader.py +++ b/graphlearn_torch/python/distributed/dist_loader.py @@ -387,9 +387,15 @@ def _collate_fn( if self.sampling_config.sampling_type in [SamplingType.NODE, SamplingType.SUBGRAPH]: - batch_dict = { - self._input_type: msg[f'{self._input_type}.batch'].to(self.to_device) - } + batch_key = f'{self._input_type}.batch' + if msg.get(batch_key) is not None: + batch_dict = { + self._input_type: msg[f'{self._input_type}.batch'].to(self.to_device) + } + else: + batch_dict = { + self._input_type: node_dict[self._input_type][:self.batch_size] + } batch_labels_key = f'{self._input_type}.nlabels' if batch_labels_key in msg: batch_labels = msg[batch_labels_key].to(self.to_device) @@ -426,7 +432,10 @@ def _collate_fn( if self.sampling_config.sampling_type in [SamplingType.NODE, SamplingType.SUBGRAPH]: - batch = msg['batch'].to(self.to_device) + if msg.get('batch') is not None: + batch = msg['batch'].to(self.to_device) + else: + batch = ids[:self.batch_size] batch_labels = msg['nlabels'].to(self.to_device) if 'nlabels' in msg else None else: batch = None diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py index ec04e87c..11e83237 100644 --- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py +++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py @@ -752,6 +752,7 @@ async def _colloate_fn( efeats = await wrap_torch_future(fut) result_map['efeats'] = efeats # Collect batch info - result_map['batch'] = output.batch + if output.batch is not None: + result_map['batch'] = output.batch return result_map