Revert partition refactor (#158)
Co-authored-by: Zeyuan Tan <[email protected]>
ZenoTan and Zeyuan Tan authored Jun 10, 2023
1 parent bd723ab commit 3454393
Showing 3 changed files with 145 additions and 36 deletions.
2 changes: 1 addition & 1 deletion benchmarks/ogbn-mag240m/preprocess.py
@@ -7,7 +7,7 @@
from torch_sparse import SparseTensor
import time
import numpy as np
-from quiver.partition import partition_with_replication, partition_without_replication, select_nodes
+from quiver.partition import partition_without_replication, select_nodes

SCALE = 1
GPU_CACHE_GB = 8
2 changes: 1 addition & 1 deletion benchmarks/ogbn-papers100M/preprocess.py
@@ -12,7 +12,7 @@
from pathlib import Path

import quiver
-from quiver.partition import partition_with_replication, partition_without_replication, select_nodes
+from quiver.partition import partition_without_replication, select_nodes

# data_root = "/data/papers/ogbn_papers100M/raw/"
# label = np.load(osp.join(data_root, "node-label.npz"))
177 changes: 143 additions & 34 deletions srcs/python/quiver/partition.py
@@ -4,14 +4,95 @@
from typing import List
import quiver.utils as quiver_util

-__all__ = ["quiver_partition_feature", "load_quiver_feature_partition"]
+__all__ = [
+    "quiver_partition_feature", "load_quiver_feature_partition",
+    "partition_without_replication", "select_nodes"
+]

QUIVER_MAGIC_NUMBER = 256

+
+def partition_without_replication(device, probs, ids):
+    """Partition nodes with given node IDs and node access distribution.
+    The result will cause no replication between partitions.
+    We assume node IDs can be placed in the given device.
+
+    Args:
+        device (int): device which computes the partitioning strategy
+        probs (torch.Tensor): node access distribution
+        ids (Optional[torch.Tensor]): specified node IDs
+
+    Returns:
+        [torch.Tensor]: list of IDs for each partition
+    """
+    ranks = len(probs)
+    if ids is not None:
+        ids = ids.to(device)
+    probs = [
+        prob[ids].to(device) if ids is not None else prob.to(device)
+        for prob in probs
+    ]
+    total_size = ids.size(0) if ids is not None else probs[0].size(0)
+    res = [None] * ranks
+    for rank in range(ranks):
+        res[rank] = []
+    CHUNK_SIZE = (total_size + CHUNK_NUM - 1) // CHUNK_NUM
+    chunk_beg = 0
+    beg_rank = 0
+    for i in range(CHUNK_NUM):
+        chunk_end = min(total_size, chunk_beg + CHUNK_SIZE)
+        chunk_size = chunk_end - chunk_beg
+        chunk = torch.arange(chunk_beg,
+                             chunk_end,
+                             dtype=torch.int64,
+                             device=device)
+        probs_sum_chunk = [
+            torch.zeros(chunk_size, device=device) + 1e-6 for i in range(ranks)
+        ]
+        for rank in range(ranks):
+            for dst_rank in range(ranks):
+                if dst_rank == rank:
+                    probs_sum_chunk[rank] += probs[dst_rank][chunk] * ranks
+                else:
+                    probs_sum_chunk[rank] -= probs[dst_rank][chunk]
+        acc_size = 0
+        rank_size = (chunk_size + ranks - 1) // ranks
+        picked_chunk_parts = torch.LongTensor([]).to(device)
+        for rank_ in range(beg_rank, beg_rank + ranks):
+            rank = rank_ % ranks
+            probs_sum_chunk[rank][picked_chunk_parts] -= 1e6
+            rank_size = min(rank_size, chunk_size - acc_size)
+            _, rank_order = torch.sort(probs_sum_chunk[rank], descending=True)
+            pick_chunk_part = rank_order[:rank_size]
+            pick_ids = chunk[pick_chunk_part]
+            picked_chunk_parts = torch.cat(
+                (picked_chunk_parts, pick_chunk_part))
+            res[rank].append(pick_ids)
+            acc_size += rank_size
+        beg_rank += 1
+        chunk_beg += chunk_size
+    for rank in range(ranks):
+        res[rank] = torch.cat(res[rank])
+        if ids is not None:
+            res[rank] = ids[res[rank]]
+    return res
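For context, the re-added partition_without_replication splits the node range into CHUNK_NUM chunks, scores each node per rank as ranks * own_prob minus the sum of the other ranks' probabilities, and lets ranks take turns picking their top-scoring unassigned nodes. Below is a minimal usage sketch, assuming a CUDA device, synthetic access distributions, and a module-level CHUNK_NUM constant (e.g. 32), which this hunk does not show:

import torch
from quiver.partition import partition_without_replication

num_nodes, num_gpus = 1000, 2
# Synthetic access distributions: probs[i][n] is how often rank i
# is expected to touch node n (made-up numbers, not real traces).
probs = [torch.rand(num_nodes) for _ in range(num_gpus)]

# Device 0 computes the strategy; pass ids=None to partition all nodes.
parts = partition_without_replication(0, probs, None)
assert sum(p.numel() for p in parts) == num_nodes  # disjoint cover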


+def select_nodes(device, probs, ids):
+    nodes = probs[0].size(0)
+    prob_sum = torch.zeros(nodes, device=device)
+    for prob in probs:
+        if ids is None:
+            prob_sum += prob
+        else:
+            prob_sum[ids] += prob[ids]
+    node_ids = torch.nonzero(prob_sum)
+    return prob_sum, node_ids
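select_nodes is the small helper re-added alongside it: it sums the per-rank access distributions (optionally restricted to ids) and returns that sum together with the indices of the nodes that are accessed at all. A tiny sketch with made-up tensors:

import torch
from quiver.partition import select_nodes

probs = [torch.tensor([0.0, 0.5, 0.0]), torch.tensor([0.2, 0.0, 0.0])]
prob_sum, node_ids = select_nodes("cpu", probs, None)
# prob_sum -> tensor([0.2000, 0.5000, 0.0000])
# node_ids -> tensor([[0], [1]])  (torch.nonzero keeps a column shape)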

-def partition_feature_without_replication(probs: List[torch.Tensor], chunk_size: int):
+def partition_feature_without_replication(probs: List[torch.Tensor],
+                                          chunk_size: int):
    """Partition nodes with node access distribution.
    The result will cause no replication between partitions.
@@ -38,24 +119,32 @@ def partition_feature_without_replication(probs: List[torch.Tensor], chunk_size:
    current_chunk_start_pos = 0
    current_partition_idx = 0
    for _ in range(chunk_num):
-        current_chunk_end_pos = min(total_node_num, current_chunk_start_pos + blob_size)
+        current_chunk_end_pos = min(total_node_num,
+                                    current_chunk_start_pos + blob_size)
        current_chunk_size = current_chunk_end_pos - current_chunk_start_pos
-        chunk = torch.arange(current_chunk_start_pos, current_chunk_end_pos, device=device)
+        chunk = torch.arange(current_chunk_start_pos,
+                             current_chunk_end_pos,
+                             device=device)
        probs_sum_chunk = [
-            torch.zeros(current_chunk_size, device=device) + 1e-6 for _ in range(partitioned_num)
+            torch.zeros(current_chunk_size, device=device) + 1e-6
+            for _ in range(partitioned_num)
        ]
        for src_rank in range(partitioned_num):
            for dst_rank in range(partitioned_num):
                if dst_rank == src_rank:
-                    probs_sum_chunk[src_rank] += probs[dst_rank][chunk] * partitioned_num
+                    probs_sum_chunk[
+                        src_rank] += probs[dst_rank][chunk] * partitioned_num
                else:
                    probs_sum_chunk[src_rank] -= probs[dst_rank][chunk]
        assigned_node_size = 0
        per_partition_size = chunk_size
-        for partition_idx in range(current_partition_idx, current_partition_idx + partitioned_num):
+        for partition_idx in range(current_partition_idx,
+                                   current_partition_idx + partitioned_num):
            partition_idx = partition_idx % partitioned_num
-            actual_per_partition_size = min(per_partition_size, current_chunk_size - assigned_node_size)
+            actual_per_partition_size = min(
+                per_partition_size, current_chunk_size - assigned_node_size)
-            _, sorted_res_order = torch.sort(probs_sum_chunk[partition_idx], descending=True)
+            _, sorted_res_order = torch.sort(probs_sum_chunk[partition_idx],
+                                             descending=True)
            pick_chunk_part = sorted_res_order[:actual_per_partition_size]
            pick_ids = chunk[pick_chunk_part]
            res[partition_idx].append(pick_ids)
@@ -70,7 +159,11 @@ def partition_feature_without_replication(probs: List[torch.Tensor], chunk_size:
    return res, probs
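The scoring rule inside the loop above is the heart of the heuristic: for partition p, a node's score is partitioned_num * probs[p] minus every other partition's probability, so each node gravitates toward the partition that accesses it most. A two-partition illustration with made-up numbers:

import torch

p0 = torch.tensor([0.9, 0.1])  # partition 0's access distribution
p1 = torch.tensor([0.2, 0.8])  # partition 1's access distribution
score0 = 2 * p0 - p1           # tensor([ 1.6000, -0.6000])
score1 = 2 * p1 - p0           # tensor([-0.5000,  1.5000])
# Node 0 scores highest for partition 0 and node 1 for partition 1,
# so the round-robin picking assigns each node to its heaviest user.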


-def quiver_partition_feature(probs:torch.Tensor, result_path: str, cache_memory_budget=0, per_feature_size=0, chunk_size=QUIVER_MAGIC_NUMBER):
+def quiver_partition_feature(probs: torch.Tensor,
+                             result_path: str,
+                             cache_memory_budget=0,
+                             per_feature_size=0,
+                             chunk_size=QUIVER_MAGIC_NUMBER):
"""
Partition graph feature based on access probability and generate result folder. The final result folder will be like:
@@ -99,51 +192,63 @@ def quiver_partition_feature(probs:torch.Tensor, result_path: str, cache_memory_
"""

if os.path.exists(result_path):
res = input(f"{result_path} already exists, enter Y/N to continue, If continue, {result_path} will be deleted:")
res = input(
f"{result_path} already exists, enter Y/N to continue, If continue, {result_path} will be deleted:"
)
res = res.upper()
if res == "Y":
shutil.rmtree(result_path)
else:
print("exiting ...")
exit()

    partition_num = len(probs)

    # create result folder
    for partition_idx in range(partition_num):
-        os.makedirs(os.path.join(result_path, f"feature_partition_{partition_idx}"))
+        os.makedirs(
+            os.path.join(result_path, f"feature_partition_{partition_idx}"))

    # calculate cached feature count
    cache_memory_budget_bytes = quiver_util.parse_size(cache_memory_budget)
    per_feature_size_bytes = quiver_util.parse_size(per_feature_size)
-    cache_count = int(cache_memory_budget_bytes / (per_feature_size_bytes + 1e-6))
+    cache_count = int(cache_memory_budget_bytes /
+                      (per_feature_size_bytes + 1e-6))
    per_partition_cache_count = cache_count // partition_num

-    partition_book = torch.zeros(probs[0].shape, dtype=torch.int64, device=torch.cuda.current_device())
+    partition_book = torch.zeros(probs[0].shape,
+                                 dtype=torch.int64,
+                                 device=torch.cuda.current_device())
-    partition_res, changed_probs = partition_feature_without_replication(probs, chunk_size)
+    partition_res, changed_probs = partition_feature_without_replication(
+        probs, chunk_size)

    cache_res = [None] * partition_num

    if cache_count > 0:
        for partition_idx in range(partition_num):
-            _, prev_order = torch.sort(changed_probs[partition_idx], descending=True)
+            _, prev_order = torch.sort(changed_probs[partition_idx],
+                                       descending=True)
-            cache_res[partition_idx] = prev_order[: per_partition_cache_count]
+            cache_res[partition_idx] = prev_order[:per_partition_cache_count]

    for partition_idx in range(partition_num):
-        partition_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "partition_res.pth")
+        partition_result_path = os.path.join(
+            result_path, f"feature_partition_{partition_idx}",
+            "partition_res.pth")
-        cache_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "cache_res.pth")
+        cache_result_path = os.path.join(result_path,
+                                         f"feature_partition_{partition_idx}",
+                                         "cache_res.pth")
        partition_book[partition_res[partition_idx]] = partition_idx
        torch.save(partition_res[partition_idx], partition_result_path)
        torch.save(cache_res[partition_idx], cache_result_path)

-    partition_book_path = os.path.join(result_path, f"feature_partition_book.pth")
+    partition_book_path = os.path.join(result_path,
+                                       f"feature_partition_book.pth")
    torch.save(partition_book, partition_book_path)

    return partition_book, partition_res, cache_res
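A usage sketch for the full pipeline follows. The paths and size strings are illustrative: it assumes quiver.utils.parse_size accepts human-readable sizes such as "1G", and a CUDA device is required because the partition book is allocated on torch.cuda.current_device():

import torch
from quiver.partition import quiver_partition_feature

num_nodes, num_gpus = 10000, 2
# Made-up per-GPU access probabilities, one tensor per partition.
probs = [torch.rand(num_nodes, device=0) for _ in range(num_gpus)]

partition_book, partition_res, cache_res = quiver_partition_feature(
    probs,
    "partition_result/",       # hypothetical result folder
    cache_memory_budget="1G",  # assumed parse_size-compatible string
    per_feature_size="2K")     # assumed bytes per node feature row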


-def load_quiver_feature_partition(partition_idx: int, result_path:str):
+def load_quiver_feature_partition(partition_idx: int, result_path: str):
"""
Load partition result for partition ${partition_idx}
@@ -160,11 +265,15 @@ def load_quiver_feature_partition(partition_idx: int, result_path:str):

    if not os.path.exists(result_path):
        raise Exception("Result path not exists")

-    partition_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "partition_res.pth")
+    partition_result_path = os.path.join(result_path,
+                                         f"feature_partition_{partition_idx}",
+                                         "partition_res.pth")
-    cache_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "cache_res.pth")
+    cache_result_path = os.path.join(result_path,
+                                     f"feature_partition_{partition_idx}",
+                                     "cache_res.pth")
-    partition_book_path = os.path.join(result_path, f"feature_partition_book.pth")
+    partition_book_path = os.path.join(result_path,
+                                       f"feature_partition_book.pth")

    partition_book = torch.load(partition_book_path)
    partition_res = torch.load(partition_result_path)
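Reading a partition back is symmetric with the save paths above. The tail of load_quiver_feature_partition is collapsed in this diff, so the exact return value is not visible here; the index and path below are illustrative:

from quiver.partition import load_quiver_feature_partition

# Rank 0 loads its own slice of the feature partition result.
res = load_quiver_feature_partition(0, "partition_result/")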
