From fa07f863ad501329ef67963d2a598e1fdaf13fcd Mon Sep 17 00:00:00 2001
From: Yi Hong
Date: Wed, 30 Oct 2024 14:06:19 +0800
Subject: [PATCH] feat(learning): feature_store & graph_store V1 (#4237)

## What do these changes do?

Step 1: Implement the GraphScope-based PyG remote backend and complete the end-to-end integration of GraphScope and PyG. (Finished)
Step 2: Fetch data from the server through the PyG remote backend and support sampling on the client side. (Finished)

## Related issue number

PyG Remote Backend Based on GraphScope #3739
---
 learning_engine/graphlearn-for-pytorch       |   2 +-
 python/graphscope/client/session.py          |  11 +
 .../gl_torch_examples/pyg_remote_backend.py  | 157 ++++++++++++
 .../graphscope/learning/gs_feature_store.py  | 232 ++++++++++++++++++
 python/graphscope/learning/gs_graph_store.py | 133 ++++++++++
 5 files changed, 534 insertions(+), 1 deletion(-)
 create mode 100644 python/graphscope/learning/gl_torch_examples/pyg_remote_backend.py
 create mode 100644 python/graphscope/learning/gs_feature_store.py
 create mode 100644 python/graphscope/learning/gs_graph_store.py

diff --git a/learning_engine/graphlearn-for-pytorch b/learning_engine/graphlearn-for-pytorch
index 6d7f31aae7cb..2034837f6705 160000
--- a/learning_engine/graphlearn-for-pytorch
+++ b/learning_engine/graphlearn-for-pytorch
@@ -1 +1 @@
-Subproject commit 6d7f31aae7cb9a719e4009c3f18fbb4e85b2a0e1
+Subproject commit 2034837f670532b787b69925bdc895509a924e7a
diff --git a/python/graphscope/client/session.py b/python/graphscope/client/session.py
index 6f4a42e60cca..2a04cce73cf0 100755
--- a/python/graphscope/client/session.py
+++ b/python/graphscope/client/session.py
@@ -1331,8 +1331,11 @@ def graphlearn_torch(
         num_clients=1,
         manifest_path=None,
         client_folder_path="./",
+        return_pyg_remote_backend=False,
     ):
         from graphscope.learning.gl_torch_graph import GLTorchGraph
+        from graphscope.learning.gs_feature_store import GsFeatureStore
+        from graphscope.learning.gs_graph_store import GsGraphStore
         from graphscope.learning.utils import fill_params_in_yaml
         from graphscope.learning.utils import read_folder_files_content
 
@@ -1380,6 +1383,12 @@ def graphlearn_torch(
         g = GLTorchGraph(endpoints)
         self._learning_instance_dict[graph.vineyard_id] = g
         graph._attach_learning_instance(g)
+
+        if return_pyg_remote_backend:
+            feature_store = GsFeatureStore(config)
+            graph_store = GsGraphStore(config)
+            return g, feature_store, graph_store
+
         return g
 
     def nx(self):
@@ -1682,6 +1691,7 @@ def graphlearn_torch(
     num_clients=1,
    manifest_path=None,
     client_folder_path="./",
+    return_pyg_remote_backend=False,
 ):
     assert graph is not None, "graph cannot be None"
     assert (
@@ -1699,4 +1709,5 @@ def graphlearn_torch(
         num_clients,
         manifest_path,
         client_folder_path,
+        return_pyg_remote_backend,
     )  # pylint: disable=protected-access
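The new `return_pyg_remote_backend` flag is the only API change to `session.py`. A minimal sketch of the call, assuming the `ogbn-arxiv` setup used by the bundled example below (any property graph with node features would work the same way):

```python
import graphscope as gs
from graphscope.dataset import load_ogbn_arxiv

sess = gs.session(cluster_type="hosts", num_workers=1)
g = load_ogbn_arxiv(sess=sess)

# With return_pyg_remote_backend=True, the learning-instance handle is
# returned together with a (FeatureStore, GraphStore) pair that PyG's
# remote-backend loaders can consume directly.
glt_graph, feature_store, graph_store = gs.graphlearn_torch(
    g,
    edges=[("paper", "citation", "paper")],
    node_features={"paper": [f"feat_{i}" for i in range(128)]},
    node_labels={"paper": "label"},
    edge_dir="out",
    return_pyg_remote_backend=True,
)
```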
diff --git a/python/graphscope/learning/gl_torch_examples/pyg_remote_backend.py b/python/graphscope/learning/gl_torch_examples/pyg_remote_backend.py
new file mode 100644
index 000000000000..2e8c8a522eb7
--- /dev/null
+++ b/python/graphscope/learning/gl_torch_examples/pyg_remote_backend.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn.functional as F
+from ogb.nodeproppred import Evaluator
+from torch_geometric.data.feature_store import TensorAttr
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.nn import GraphSAGE
+from tqdm import tqdm
+
+import graphscope as gs
+import graphscope.learning.graphlearn_torch as glt
+from graphscope.dataset import load_ogbn_arxiv
+
+NUM_EPOCHS = 10
+BATCH_SIZE = 4096
+NUM_SERVERS = 1
+NUM_NEIGHBORS = [2, 2, 2]
+
+print("Batch size:", BATCH_SIZE)
+print("Number of epochs:", NUM_EPOCHS)
+print("Number of neighbors:", NUM_NEIGHBORS)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+
+gs.set_option(show_log=True)
+
+# Load the ogbn_arxiv graph.
+sess = gs.session(cluster_type="hosts", num_workers=NUM_SERVERS)
+g = load_ogbn_arxiv(sess=sess)
+
+print("-- Initializing store ...")
+glt_graph, feature_store, graph_store = gs.graphlearn_torch(
+    g,
+    edges=[
+        ("paper", "citation", "paper"),
+    ],
+    node_features={
+        "paper": [f"feat_{i}" for i in range(128)],
+    },
+    node_labels={
+        "paper": "label",
+    },
+    edge_dir="out",
+    random_node_split={
+        "num_val": 0.1,
+        "num_test": 0.1,
+    },
+    return_pyg_remote_backend=True,
+)
+
+print("-- Initializing client ...")
+glt.distributed.init_client(
+    num_servers=1,
+    num_clients=1,
+    client_rank=0,
+    master_addr=glt_graph.master_addr,
+    master_port=glt_graph.server_client_master_port,
+    num_rpc_threads=4,
+    is_dynamic=True,
+)
+
+print("-- Initializing loader ...")
+# Build disjoint train/test masks over a random 80/20 split.
+num_nodes = feature_store.get_tensor_size(TensorAttr(group_name="paper"))[0]
+print("Node num:", num_nodes)
+shuffle_id = torch.randperm(num_nodes)
+train_indices = shuffle_id[: int(0.8 * num_nodes)]
+test_indices = shuffle_id[int(0.8 * num_nodes) :]
+train_mask = torch.zeros(num_nodes, dtype=torch.bool)
+test_mask = torch.zeros(num_nodes, dtype=torch.bool)
+train_mask[train_indices] = True
+test_mask[test_indices] = True
+
+train_loader = NeighborLoader(
+    data=(feature_store, graph_store),
+    batch_size=BATCH_SIZE,
+    num_neighbors=NUM_NEIGHBORS,
+    shuffle=False,
+    input_nodes=("paper", train_mask),
+)
+
+test_loader = NeighborLoader(
+    data=(feature_store, graph_store),
+    batch_size=BATCH_SIZE,
+    num_neighbors=NUM_NEIGHBORS,
+    shuffle=False,
+    input_nodes=("paper", test_mask),
+)
+
+model = GraphSAGE(
+    in_channels=128,
+    hidden_channels=256,
+    num_layers=3,
+    out_channels=47,
+).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+
+@torch.no_grad()
+def test(model, test_loader, dataset_name):
+    evaluator = Evaluator(name=dataset_name)
+    model.eval()
+    xs = []
+    y_true = []
+    for i, batch in enumerate(test_loader):
+        if i == 0:
+            device = batch["paper"].x.device
+        batch["paper"].x = batch["paper"].x.to(torch.float32)  # TODO
+        x = model(batch["paper"].x, batch[("paper", "citation", "paper")].edge_index)[
+            : batch["paper"].batch_size
+        ]
+        xs.append(x.cpu())
+        y_true.append(batch["paper"].label[: batch["paper"].batch_size].cpu())
+        del batch
+
+    xs = [t.to(device) for t in xs]
+    y_true = [t.to(device) for t in y_true]
+    y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)
+    y_true = torch.cat(y_true, dim=0).unsqueeze(-1)
+    test_acc = evaluator.eval(
+        {
+            "y_true": y_true,
+            "y_pred": y_pred,
+        }
+    )["acc"]
+    return test_acc
+
+
+dataset_name = "ogbn-arxiv"
+for epoch in range(NUM_EPOCHS):
+    model.train()
+    with tqdm(
+        total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch"
+    ) as pbar:
+        for batch in train_loader:
+            optimizer.zero_grad()
+            batch["paper"].x = batch["paper"].x.to(torch.float32)  # TODO
+            out = model(
+                batch["paper"].x, batch[("paper", "citation", "paper")].edge_index
+            )[: batch["paper"].batch_size].log_softmax(dim=-1)
+            label = batch["paper"].label[: batch["paper"].batch_size].long()
+            loss = F.nll_loss(out, label)
+            loss.backward()
+            optimizer.step()
+            pbar.set_postfix({"Loss": f"{loss:.4f}"})
+            pbar.update(1)
+
+    # Test accuracy.
+    if epoch % 2 == 0:
+        test_acc = test(model, test_loader, dataset_name)
+        print(f"-- Test Accuracy: {test_acc:.4f}", flush=True)
+
+print("-- Shutting down ...")
+glt.distributed.shutdown_client()
+
+print("-- Exited ...")
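Because the returned pair implements PyG's `FeatureStore`/`GraphStore` protocols, it can also be queried directly, outside `NeighborLoader`. A minimal sketch, assuming `feature_store` and `graph_store` from the example above and an already-initialized client:

```python
import torch
from torch_geometric.data.feature_store import TensorAttr

# Number of "paper" nodes, as the example uses for its split.
num_papers = feature_store.get_tensor_size(TensorAttr(group_name="paper"))[0]

# Features and labels of the first ten papers.
feats = feature_store.get_tensor(
    TensorAttr(group_name="paper", attr_name="x", index=torch.arange(10))
)
labels = feature_store.get_tensor(
    TensorAttr(group_name="paper", attr_name="label", index=torch.arange(10))
)

# The full citation topology in COO layout, merged across servers.
row, col = graph_store.get_edge_index(("paper", "citation", "paper"), layout="coo")
```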
diff --git a/python/graphscope/learning/gs_feature_store.py b/python/graphscope/learning/gs_feature_store.py
new file mode 100644
index 000000000000..2fd9728ce690
--- /dev/null
+++ b/python/graphscope/learning/gs_feature_store.py
@@ -0,0 +1,232 @@
+import base64
+import json
+from multiprocessing.reduction import ForkingPickler
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch_geometric.data.feature_store import FeatureStore
+from torch_geometric.data.feature_store import IndexType
+from torch_geometric.data.feature_store import TensorAttr
+from torch_geometric.typing import FeatureTensorType
+
+from graphscope.learning.graphlearn_torch.distributed.dist_client import request_server
+from graphscope.learning.graphlearn_torch.distributed.dist_server import DistServer
+from graphscope.learning.graphlearn_torch.typing import NodeType
+
+KeyType = Tuple[str, ...]
+
+
+class GsFeatureStore(FeatureStore):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
+        self.tensor_attrs: Dict[Tuple[NodeType, str], TensorAttr] = {}
+
+        assert config is not None
+        config = json.loads(
+            base64.b64decode(config.encode("utf-8", errors="ignore")).decode(
+                "utf-8", errors="ignore"
+            )
+        )
+        self.node_features = config["node_features"]
+        self.node_labels = config["node_labels"]
+        self.edges = config["edges"]
+
+        assert self.node_features is not None
+        self.node_types = set()
+        for node in self.node_features:
+            self.node_types.add(node)
+
+        for edge in self.edges:
+            self.node_types.add(edge[0])
+            self.node_types.add(edge[-1])
+
+        for node_type in self.node_types:
+            self.tensor_attrs[(node_type, "x")] = TensorAttr(node_type, "x")
+
+        assert self.node_labels is not None
+        for node_type, node_label in self.node_labels.items():
+            self.tensor_attrs[(node_type, node_label)] = TensorAttr(
+                node_type, node_label
+            )
+
+    @staticmethod
+    def key(attr: TensorAttr) -> KeyType:
+        return (attr.group_name, attr.attr_name, attr.index)
+
+    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
+        r"""To be implemented by :class:`GsFeatureStore`."""
+        raise NotImplementedError
+
+    def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:
+        r"""Obtains a :class:`torch.Tensor` from the remote server.
+
+        Args:
+            attr (TensorAttr): Uniquely corresponds to a node/edge feature tensor.
+
+        Raises:
+            ValueError: If the attr cannot be found in the attribute list of
+                the feature store.
+
+        Returns:
+            feature (torch.Tensor): The node/edge feature tensor.
+        """
+        group_name, attr_name, index = self.key(attr)
+        if not self._check_attr(attr):
+            raise ValueError(
+                f"Attribute {group_name}-{attr_name} not found in feature store."
+            )
+        result = torch.tensor([])
+        index = self.index_to_tensor(index)
+        if index.numel() == 0:
+            return result
+
+        # Labels are served by get_node_label; everything else is a feature.
+        server_fun = DistServer.get_node_feature
+        labels = self.node_labels.get(group_name)
+        is_label = False
+        if isinstance(labels, list) and attr_name in labels:
+            server_fun = DistServer.get_node_label
+            is_label = True
+        elif isinstance(labels, str) and attr_name == labels:
+            server_fun = DistServer.get_node_label
+            is_label = True
+
+        num_partitions, _, _, _ = request_server(0, DistServer.get_dataset_meta)
+        partition_ids = self._get_partition_id(attr)
+        indexes = []
+        features = []
+        input_order = torch.arange(index.size(0), dtype=torch.long)
+        for pidx in range(0, num_partitions):
+            remote_mask = partition_ids == pidx
+            remote_ids = torch.masked_select(index, remote_mask)
+            if remote_ids.shape[0] > 0:
+                feature = request_server(pidx, server_fun, group_name, remote_ids)
+                features.append(feature)
+                indexes.append(torch.masked_select(input_order, remote_mask))
+
+        if not is_label:
+            result = torch.zeros(
+                index.shape[0], features[0].shape[1], dtype=features[0].dtype
+            )
+        else:
+            result = torch.zeros(index.shape[0], 1, dtype=features[0].dtype)
+
+        for i, feature in enumerate(features):
+            result[indexes[i]] = feature
+        if is_label:
+            result = result.reshape(-1)
+        return result
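`_get_tensor` fans the requested ids out to their owning partitions and scatters each partition's answer back into the caller's original order. A standalone toy version of that gather pattern, with a hypothetical `fake_store` dict standing in for the `request_server` RPCs:

```python
import torch

index = torch.tensor([7, 2, 9, 4])          # global ids, in caller order
partition_ids = torch.tensor([1, 0, 1, 0])  # owning partition of each id
# Hypothetical per-partition lookup; partition p adds 100 * p so the
# provenance of every row is visible in the output.
fake_store = {
    p: (lambda ids, p=p: ids.float().unsqueeze(1) + 100 * p) for p in (0, 1)
}

input_order = torch.arange(index.size(0))
result = torch.zeros(index.size(0), 1)
for p in (0, 1):
    mask = partition_ids == p
    remote_ids = index[mask]
    if remote_ids.numel() > 0:
        # Rows come back partition by partition but land at their
        # original positions, exactly as in _get_tensor.
        result[input_order[mask]] = fake_store[p](remote_ids)
print(result.squeeze(1))  # tensor([107.,   2., 109.,   4.])
```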
+    def _get_partition_id(self, attr: TensorAttr) -> Optional[Tensor]:
+        r"""Obtains the ids of the partitions where the indexed nodes are
+        stored from the remote server.
+
+        Args:
+            attr (TensorAttr): Uniquely corresponds to a node/edge feature tensor.
+
+        Returns:
+            partition_ids (torch.Tensor): The partition id of each indexed node.
+        """
+        group_name, _, gid = self.key(attr)
+        gid = self.index_to_tensor(gid)
+        result = request_server(0, DistServer.get_node_partition_id, group_name, gid)
+        return result
+
+    def _remove_tensor(self, attr: TensorAttr) -> bool:
+        r"""To be implemented by :class:`GsFeatureStore`."""
+        raise NotImplementedError
+
+    def _check_attr(self, attr: TensorAttr) -> bool:
+        r"""Checks whether the given :class:`TensorAttr` is stored on the
+        remote server.
+
+        Args:
+            attr (TensorAttr): Uniquely corresponds to a node/edge feature tensor.
+
+        Returns:
+            flag (bool): True if the :class:`TensorAttr` is stored on the
+                remote server, False otherwise.
+        """
+        group_name, attr_name, _ = self.key(attr)
+        if not attr.is_set("attr_name"):
+            return any(group_name in key for key in self.tensor_attrs.keys())
+        return (group_name, attr_name) in self.tensor_attrs
+
+    def _get_tensor_size(self, attr: TensorAttr) -> Optional[torch.Size]:
+        r"""Obtains the size of the feature tensor from the remote server.
+
+        Args:
+            attr (TensorAttr): Uniquely corresponds to a node/edge feature
+                tensor type.
+
+        Returns:
+            tensor_size (torch.Size): The size of the corresponding tensor type.
+        """
+        group_name, _, _ = self.key(attr)
+        size = request_server(0, DistServer.get_tensor_size, group_name)
+        return size
+
+    def get_all_tensor_attrs(self) -> List[TensorAttr]:
+        r"""Obtains all tensor attributes stored on the remote server.
+
+        Returns:
+            tensor_attrs (List[TensorAttr]): All tensor attributes stored on
+                the remote server.
+        """
+        return [attr for attr in self.tensor_attrs.values()]
+
+    def index_to_tensor(self, index: IndexType) -> torch.Tensor:
+        r"""Converts an index to type :class:`torch.Tensor`.
+
+        Args:
+            index (IndexType): The index that needs to be converted.
+
+        Returns:
+            index (torch.Tensor): The index as a :class:`torch.Tensor`.
+        """
+        if isinstance(index, torch.Tensor):
+            return index
+        elif isinstance(index, np.ndarray):
+            return torch.from_numpy(index)
+        elif isinstance(index, slice):
+            start = index.start if index.start is not None else 0
+            stop = index.stop if index.stop is not None else -1
+            step = index.step if index.step is not None else 1
+            return torch.arange(start, stop, step)
+        elif isinstance(index, int):
+            return torch.tensor([index])
+        elif isinstance(index, list):
+            return torch.tensor(index)
+        else:
+            raise TypeError(f"Unsupported index type: {type(index)}")
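A quick sketch of the index forms `index_to_tensor` accepts. Constructing the store requires a valid base64-encoded `config` string (as produced by `gs.graphlearn_torch`), so treat `config` here as given; note also that slices are materialized with `torch.arange`, so an open-ended `stop` is not supported:

```python
import numpy as np
import torch

from graphscope.learning.gs_feature_store import GsFeatureStore

store = GsFeatureStore(config)  # `config` assumed from gs.graphlearn_torch
assert store.index_to_tensor(3).tolist() == [3]
assert store.index_to_tensor([1, 2]).tolist() == [1, 2]
assert store.index_to_tensor(np.array([4, 5])).tolist() == [4, 5]
assert store.index_to_tensor(slice(0, 6, 2)).tolist() == [0, 2, 4]
assert store.index_to_tensor(torch.tensor([6])).tolist() == [6]
```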
+ """ + if isinstance(index, torch.Tensor): + return index + elif isinstance(index, np.ndarray): + return torch.from_numpy(index) + elif isinstance(index, slice): + start = index.start if index.start is not None else 0 + stop = index.stop if index.stop is not None else -1 + step = index.step if index.step is not None else 1 + return torch.arange(start, stop, step) + elif isinstance(index, int): + return torch.tensor([index]) + elif isinstance(index, list): + return torch.tensor(index) + else: + raise TypeError(f"Unsupported index type: {type(index)}") + + @classmethod + def from_ipc_handle(cls, ipc_handle): + return cls(*ipc_handle) + + def share_ipc(self): + ipc_hanlde = self.config + return ipc_hanlde + + +# Pickling Registration + + +def rebuild_featurestore(ipc_handle): + fs = GsFeatureStore.from_ipc_handle(ipc_handle) + return fs + + +def reduce_featurestore(FeatureStore: GsFeatureStore): + ipc_handle = FeatureStore.share_ipc() + return (rebuild_featurestore, (ipc_handle,)) + + +ForkingPickler.register(GsFeatureStore, reduce_featurestore) diff --git a/python/graphscope/learning/gs_graph_store.py b/python/graphscope/learning/gs_graph_store.py new file mode 100644 index 000000000000..4aa8e3e067d6 --- /dev/null +++ b/python/graphscope/learning/gs_graph_store.py @@ -0,0 +1,133 @@ +import base64 +import json +from multiprocessing.reduction import ForkingPickler +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import torch +from torch_geometric.data.graph_store import EdgeAttr +from torch_geometric.data.graph_store import GraphStore +from torch_geometric.typing import EdgeTensorType +from torch_geometric.utils import index_sort + +from graphscope.learning.graphlearn_torch.distributed.dist_client import request_server +from graphscope.learning.graphlearn_torch.distributed.dist_server import DistServer + + +class GsGraphStore(GraphStore): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.edge_attrs: Dict[Tuple[Tuple[str, str, str], str, bool], EdgeAttr] = {} + + assert config is not None + config = json.loads( + base64.b64decode(config.encode("utf-8", errors="ignore")).decode( + "utf-8", errors="ignore" + ) + ) + self.edges = config["edges"] + self.edge_dir = config["edge_dir"] + + assert self.edges is not None + for edge in self.edges: + edge = tuple(edge) + # Only support COO layout + layout = "coo" + new_edge_attr = EdgeAttr(edge, layout, True) + self.edge_attrs[(edge, layout, True)] = new_edge_attr + + @staticmethod + def key(attr: EdgeAttr) -> Tuple: + return (attr.edge_type, attr.layout.value, attr.is_sorted, attr.size) + + def _put_edge_index( + self, + edge_index: EdgeTensorType, + edge_attr: EdgeAttr, + ) -> bool: + r"""To be implemented by :class:`GsFeatureStore`.""" + raise NotImplementedError + + def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: + r"""Obtains a :class:`EdgeTensorType` from the remote server with :class:`EdgeAttr`. + + Args: + edge_attr(`EdgeAttr`): Uniquely corresponds to a topology of subgraph . 
+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        r"""To be implemented by :class:`GsGraphStore`."""
+        raise NotImplementedError
+
+    def _get_edge_size(self, edge_attr: EdgeAttr) -> Tuple[int, int]:
+        r"""Obtains the size of the subgraph that corresponds to the given
+        :class:`EdgeAttr` from the remote server.
+
+        Args:
+            edge_attr (EdgeAttr): Uniquely corresponds to a subgraph topology.
+
+        Returns:
+            size (tuple(int, int)): The size of the subgraph.
+        """
+        group_name, layout, is_sorted, _ = self.key(edge_attr)
+        (row, col) = self._get_edge_index(edge_attr)
+        size = (int(row.max()) + 1, int(col.max()) + 1)
+        new_edge_attr = EdgeAttr(group_name, layout, is_sorted, size)
+        self.edge_attrs[(group_name, layout, is_sorted)] = new_edge_attr
+        return size
+
+    def get_all_edge_attrs(self) -> List[EdgeAttr]:
+        r"""Obtains all edge attributes stored on the remote server.
+
+        Returns:
+            edge_attrs (List[EdgeAttr]): All edge attributes stored on the
+                remote server.
+        """
+        return list(self.edge_attrs.values())
+
+    @classmethod
+    def from_ipc_handle(cls, ipc_handle):
+        return cls(ipc_handle)
+
+    def share_ipc(self):
+        ipc_handle = self.config
+        return ipc_handle
+
+
+# Pickling Registration
+
+
+def rebuild_graphstore(ipc_handle):
+    gs = GsGraphStore.from_ipc_handle(ipc_handle)
+    return gs
+
+
+def reduce_graphstore(graph_store: GsGraphStore):
+    ipc_handle = graph_store.share_ipc()
+    return (rebuild_graphstore, (ipc_handle,))
+
+
+ForkingPickler.register(GsGraphStore, reduce_graphstore)
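The `ForkingPickler` registrations exist because loader worker processes receive the stores across a process boundary: only the base64 config string is shipped, and each worker rebuilds a fresh store from it (no sockets or tensors are pickled). A sketch of the round trip, assuming `graph_store` from the example above:

```python
import io
import pickle
from multiprocessing.reduction import ForkingPickler

buf = io.BytesIO()
ForkingPickler(buf).dump(graph_store)  # serializes just graph_store.config
clone = pickle.loads(buf.getvalue())   # calls rebuild_graphstore(config)
assert isinstance(clone, type(graph_store))
assert clone.config == graph_store.config
```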