Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Add separate init, expose gather/scatter for WholeMemoryTensor and up…
Browse files Browse the repository at this point in the history
…date example (#81)

Some Updates:

- Add separate init
- Expose gather/scatter for WholeMemoryTensor
- Some updates on examples and flags
- Fix integer scatter bug (closes issue #69 ).

Authors:
  - https://github.com/dongxuy04

Approvers:
  - Brad Rees (https://github.com/BradReesWork)

URL: #81
  • Loading branch information
dongxuy04 authored Oct 17, 2023
1 parent 01ad40f commit 4f9f39d
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ void scatter_integer_int64_temp_func(const void* input,

REGISTER_DISPATCH_TWO_TYPES(ScatterFuncIntegerInt64,
scatter_integer_int64_temp_func,
HALF_FLOAT_DOUBLE,
HALF_FLOAT_DOUBLE)
ALLSINT,
ALLSINT)

wholememory_error_code_t scatter_integer_int64_func(const void* input,
wholememory_matrix_description_t input_desc,
Expand Down
10 changes: 8 additions & 2 deletions python/pylibwholegraph/examples/node_classfication.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
wgth.add_common_sampler_options(parser)
wgth.add_node_classfication_options(parser)
wgth.add_dataloader_options(parser)
parser.add_option(
"--fp16_embedding", action="store_true", dest="fp16_mbedding", default=False, help="Whether to use fp16 embedding"
)


(options, args) = parser.parse_args()

Expand Down Expand Up @@ -188,13 +192,15 @@ def main_func():
else wgth.create_wholememory_optimizer("adam", {})
)

embedding_dtype = torch.float32 if not options.fp16_mbedding else torch.float16

if wm_optimizer is None:
node_feat_wm_embedding = wgth.create_embedding_from_filelist(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
os.path.join(options.root_dir, "node_feat.bin"),
torch.float,
embedding_dtype,
options.feat_dim,
optimizer=wm_optimizer,
cache_policy=cache_policy,
Expand All @@ -204,7 +210,7 @@ def main_func():
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
torch.float,
embedding_dtype,
[graph_structure.node_count, options.feat_dim],
optimizer=wm_optimizer,
cache_policy=cache_policy,
Expand Down
86 changes: 86 additions & 0 deletions python/pylibwholegraph/examples/ogbn_papers100m_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import argparse
import os
import numpy as np
from scipy.sparse import coo_matrix
import pickle
from ogb.nodeproppred import NodePropPredDataset


def save_array(np_array, save_path, array_file_name):
array_full_path = os.path.join(save_path, array_file_name)
with open(array_full_path, 'wb') as f:
np_array.tofile(f)


def convert_papers100m_dataset(args):
ogb_root = args.ogb_root_dir
dataset = NodePropPredDataset(name='ogbn-papers100M', root=ogb_root)
graph, label = dataset[0]
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (
split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
train_label = label[train_idx]
valid_label = label[valid_idx]
test_label = label[test_idx]
data_and_label = {
"train_idx": train_idx,
"valid_idx": valid_idx,
"test_idx": test_idx,
"train_label": train_label,
"valid_label": valid_label,
"test_label": test_label,
}
num_nodes = graph["num_nodes"]
edge_index = graph["edge_index"]
node_feat = graph["node_feat"].astype(np.dtype(args.node_feat_format))
if not os.path.exists(args.convert_dir):
print(f"creating directory {args.convert_dir}...")
os.makedirs(args.convert_dir)
print("saving idx and labels...")
with open(
os.path.join(args.convert_dir, 'ogbn_papers100M_data_and_label.pkl'), "wb"
) as f:
pickle.dump(data_and_label, f)
print("saving node feature...")
with open(
os.path.join(args.convert_dir, 'node_feat.bin'), "wb"
) as f:
node_feat.tofile(f)

print("converting graph to csr...")
assert len(edge_index.shape) == 2
assert edge_index.shape[0] == 2
coo_src_ids = edge_index[0, :].astype(np.int32)
coo_dst_ids = edge_index[1, :].astype(np.int32)
if args.add_reverse_edges:
arg_graph_src = np.concatenate([coo_src_ids, coo_dst_ids])
arg_graph_dst = np.concatenate([coo_dst_ids, coo_src_ids])
else:
arg_graph_src = coo_src_ids
arg_graph_dst = coo_dst_ids
values = np.arange(len(arg_graph_src), dtype='int64')
coo_graph = coo_matrix((values, (arg_graph_src, arg_graph_dst)), shape=(num_nodes, num_nodes))
csr_graph = coo_graph.tocsr()
csr_row_ptr = csr_graph.indptr.astype(dtype='int64')
csr_col_ind = csr_graph.indices.astype(dtype='int32')
print("saving csr graph...")
save_array(csr_row_ptr, args.convert_dir, 'homograph_csr_row_ptr')
save_array(csr_col_ind, args.convert_dir, 'homograph_csr_col_idx')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ogb_root_dir', type=str, default='dataset',
help='root dir of containing ogb datasets')
parser.add_argument('--convert_dir', type=str, default='dataset_papers100m_converted',
help='output dir containing converted datasets')
parser.add_argument('--node_feat_format', type=str, default='float32',
choices=['float32', 'float16'],
help='save format of node feature')
parser.add_argument('--add_reverse_edges', type=bool, default=True,
help='whether to add reverse edges')
args = parser.parse_args()
convert_papers100m_dataset(args)
2 changes: 1 addition & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from .embedding import WholeMemoryEmbeddingModule

from .initialize import init_torch_env, init_torch_env_and_create_wm_comm, finalize
from .initialize import init, init_torch_env, init_torch_env_and_create_wm_comm, finalize

from .tensor import (
WholeMemoryTensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def add_common_model_options(parser: OptionParser):
default="cugraph",
help="framework type, valid values are: dgl, pyg, wg, cugraph",
)
parser.add_option("--heads", type="int", dest="heads", default=1, help="num heads")
parser.add_option("--heads", type="int", dest="heads", default=4, help="num heads")
parser.add_option(
"-d", "--dropout", type="float", dest="dropout", default=0.5, help="dropout"
)
Expand All @@ -126,9 +126,8 @@ def add_common_sampler_options(parser: OptionParser):
parser.add_option(
"-s",
"--inferencesample",
type="int",
dest="inferencesample",
default=30,
default="30",
help="inference sample count, -1 is all",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ def __init__(
self.reset_parameters()

def reset_parameters(self):
self.lin.reset_parameters()
gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_normal_(self.lin.weight, gain=gain)
torch.nn.init.xavier_normal_(
self.att.view(2, self.heads, self.out_channels), gain=gain
self.att.view(2, self.heads, self.out_channels)[0, :, :], gain=gain
)
torch.nn.init.xavier_normal_(
self.att.view(2, self.heads, self.out_channels)[1, :, :], gain=gain
)
torch.nn.init.zeros_(self.bias)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def __init__(
self.reset_parameters()

def reset_parameters(self):
gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_uniform_(self.lin.weight, gain=gain)
if self.project:
self.pre_lin.reset_parameters()
self.lin.reset_parameters()
torch.nn.init.xavier_uniform_(self.pre_lin.weight, gain=gain)
torch.nn.init.xavier_uniform_(self.lin.weight, gain=gain)

def forward(
self,
Expand Down
37 changes: 16 additions & 21 deletions python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,26 +150,19 @@ def create_sub_graph(
return edge_index
elif framework_name == "dgl":
if add_self_loop:
self_loop_ids = torch.arange(
0,
target_gid_1.numel(),
dtype=edge_data[0].dtype,
device=target_gid.device,
)
block = dgl.create_block(
csr_row_ptr, csr_col_ind = add_csr_self_loop(csr_row_ptr, csr_col_ind)
block = dgl.create_block(
(
'csc',
(
torch.cat([edge_data[0], self_loop_ids]),
torch.cat([edge_data[1], self_loop_ids]),
csr_row_ptr,
csr_col_ind,
torch.empty(0, dtype=torch.int),
),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
else:
block = dgl.create_block(
(edge_data[0], edge_data[1]),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
return block
elif framework_name == "cugraph":
if add_self_loop:
Expand Down Expand Up @@ -224,19 +217,21 @@ def __init__(
self.gather_fn = WholeMemoryEmbeddingModule(self.node_embedding)
self.dropout = options.dropout
self.max_neighbors = parse_max_neighbors(options.layernum, options.neighbors)
self.max_inference_neighbors = parse_max_neighbors(options.layernum, options.inferencesample)

def forward(self, ids):
global framework_name
max_neighbors = self.max_neighbors if self.training else self.max_inference_neighbors
ids = ids.to(self.graph_structure.csr_col_ind.dtype).cuda()
(
target_gids,
edge_indice,
csr_row_ptrs,
csr_col_inds,
) = self.graph_structure.multilayer_sample_without_replacement(
ids, self.max_neighbors
ids, max_neighbors
)
x_feat = self.gather_fn(target_gids[0])
x_feat = self.gather_fn(target_gids[0], force_dtype=torch.float32)
for i in range(self.num_layer):
x_target_feat = x_feat[: target_gids[i + 1].numel()]
sub_graph = create_sub_graph(
Expand All @@ -245,7 +240,7 @@ def forward(self, ids):
edge_indice[i],
csr_row_ptrs[i],
csr_col_inds[i],
self.max_neighbors[i],
max_neighbors[self.num_layer - 1 - i],
self.add_self_loop,
)
x_feat = layer_forward(
Expand Down
5 changes: 5 additions & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from .comm import set_world_info, get_global_communicator, get_local_node_communicator


def init(world_rank: int, world_size: int, local_rank: int, local_size: int):
wmb.init(0)
set_world_info(world_rank, world_size, local_rank, local_size)


def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int):
r"""Init WholeGraph environment for PyTorch.
:param world_rank: world rank of current process
Expand Down
38 changes: 38 additions & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .comm import WholeMemoryCommunicator
from typing import Union, List
from .dlpack_utils import torch_import_from_dlpack
from .wholegraph_env import wrap_torch_tensor, get_wholegraph_env_fns, get_stream


WholeMemoryMemoryType = wmb.WholeMemoryMemoryType
Expand Down Expand Up @@ -57,6 +58,43 @@ def get_comm(self):
self.wmb_tensor.get_wholememory_handle().get_communicator()
)

def gather(self,
indice: torch.Tensor,
*,
force_dtype: Union[torch.dtype, None] = None):
assert indice.dim() == 1
embedding_dim = self.shape[1]
embedding_count = indice.shape[0]
current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
output_dtype = (
force_dtype if force_dtype is not None else self.embedding_tensor.dtype
)
output_tensor = torch.empty(
[embedding_count, embedding_dim],
device=current_cuda_device,
dtype=output_dtype,
requires_grad=False,
)
wmb.wholememory_gather_op(self.wmb_tensor,
wrap_torch_tensor(indice),
wrap_torch_tensor(output_tensor),
get_wholegraph_env_fns(),
get_stream())
return output_tensor

def scatter(self,
input_tensor: torch.Tensor,
indice: torch.Tensor):
assert indice.dim() == 1
assert input_tensor.dim() == 2
assert indice.shape[0] == input_tensor.shape[0]
assert input_tensor.shape[1] == self.shape[1]
wmb.wholememory_scatter_op(wrap_torch_tensor(input_tensor),
wrap_torch_tensor(indice),
self.wmb_tensor,
get_wholegraph_env_fns(),
get_stream())

def get_sub_tensor(self, starts, ends):
"""
Get sub tensor of WholeMemory Tensor
Expand Down

0 comments on commit 4f9f39d

Please sign in to comment.