This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

replace optparse with argparse (#61)
replace `optparse` with `argparse` in `pylibwholegraph`

Authors:
  - Chuang Zhu (https://github.com/chuangz0)

Approvers:
  - https://github.com/dongxuy04
  - Brad Rees (https://github.com/BradReesWork)

URL: #61
chuangz0 authored Nov 21, 2023
1 parent 02794d9 commit 405d86c
Showing 4 changed files with 138 additions and 140 deletions.
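The heart of the change is the standard `optparse`-to-`argparse` migration: `OptionParser()` becomes `argparse.ArgumentParser()`, `add_option` becomes `add_argument`, and `parse_args()` returns a single `Namespace` instead of an `(options, args)` tuple, which is why every later `options.<attr>` access in the example script is renamed to `args.<attr>`. Below is a minimal, self-contained sketch of that pattern, using only the `--fp16_embedding` flag that appears in the diff; the project's `wgth.add_*_options` helpers are omitted.

```python
import argparse
from optparse import OptionParser  # deprecated since Python 3.2

# Old style: optparse (dest spelling kept exactly as in the diff)
parser = OptionParser()
parser.add_option(
    "--fp16_embedding", action="store_true", dest="fp16_mbedding",
    default=False, help="Whether to use fp16 embedding",
)
(options, extra) = parser.parse_args([])  # -> (Values, leftover positionals)

# New style: argparse
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--fp16_embedding", action="store_true", dest="fp16_mbedding",
    default=False, help="Whether to use fp16 embedding",
)
args = argparser.parse_args([])  # -> a single argparse.Namespace

# Attribute access has the same shape in both cases, so downstream code only
# needs `options.<attr>` renamed to `args.<attr>`.
assert options.fp16_mbedding is False and args.fp16_mbedding is False
```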
93 changes: 45 additions & 48 deletions python/pylibwholegraph/examples/node_classfication.py
@@ -14,29 +14,26 @@
import datetime
import os
import time
from optparse import OptionParser
import argparse

import apex
import torch
from apex.parallel import DistributedDataParallel as DDP

import pylibwholegraph.torch as wgth

parser = OptionParser()

wgth.add_distributed_launch_options(parser)
wgth.add_training_options(parser)
wgth.add_common_graph_options(parser)
wgth.add_common_model_options(parser)
wgth.add_common_sampler_options(parser)
wgth.add_node_classfication_options(parser)
wgth.add_dataloader_options(parser)
parser.add_option(
argparser = argparse.ArgumentParser()
wgth.add_distributed_launch_options(argparser)
wgth.add_training_options(argparser)
wgth.add_common_graph_options(argparser)
wgth.add_common_model_options(argparser)
wgth.add_common_sampler_options(argparser)
wgth.add_node_classfication_options(argparser)
wgth.add_dataloader_options(argparser)
argparser.add_argument(
"--fp16_embedding", action="store_true", dest="fp16_mbedding", default=False, help="Whether to use fp16 embedding"
)


(options, args) = parser.parse_args()
args = argparser.parse_args()


def valid_test(dataloader, model, name):
@@ -68,7 +65,7 @@ def valid(valid_dataloader, model):


def test(test_dataset, model):
test_dataloader = wgth.get_valid_test_dataloader(test_dataset, options.batchsize)
test_dataloader = wgth.get_valid_test_dataloader(test_dataset, args.batchsize)
valid_test(test_dataloader, model, "TEST")


@@ -77,19 +74,19 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
print("start training...")
train_dataloader = wgth.get_train_dataloader(
train_data,
options.batchsize,
args.batchsize,
replica_id=wgth.get_rank(),
num_replicas=wgth.get_world_size(),
num_workers=options.dataloaderworkers,
num_workers=args.dataloaderworkers,
)
valid_dataloader = wgth.get_valid_test_dataloader(valid_data, options.batchsize)
valid_dataloader = wgth.get_valid_test_dataloader(valid_data, args.batchsize)
valid(valid_dataloader, model)

train_step = 0
epoch = 0
loss_fcn = torch.nn.CrossEntropyLoss()
train_start_time = time.time()
while epoch < options.epochs:
while epoch < args.epochs:
for i, (idx, label) in enumerate(train_dataloader):
label = torch.reshape(label, (-1,)).cuda()
optimizer.zero_grad()
@@ -99,7 +96,7 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
loss.backward()
optimizer.step()
if wm_optimizer is not None:
wm_optimizer.step(options.lr * 0.1)
wm_optimizer.step(args.lr * 0.1)
if wgth.get_rank() == 0 and train_step % 100 == 0:
print(
"[%s] [LOSS] step=%d, loss=%f"
@@ -121,7 +118,7 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
)
print(
"[EPOCH_TIME] %.2f seconds."
% ((train_end_time - train_start_time) / options.epochs,)
% ((train_end_time - train_start_time) / args.epochs,)
)
valid(valid_dataloader, model)

@@ -135,11 +132,11 @@ def main_func():
wgth.get_local_size(),
)

if options.use_cpp_ext:
if args.use_cpp_ext:
wgth.compile_cpp_extension()

train_ds, valid_ds, test_ds = wgth.create_node_claffication_datasets(
options.pickle_data_path
args.pickle_data_path
)

graph_structure = wgth.GraphStructure()
@@ -152,70 +149,70 @@ def main_func():
graph_comm = global_comm
graph_structure_wholememory_type = "continuous"
graph_structure_wholememory_location = "cuda"
if not options.use_global_embedding:
options.use_global_embedding = True
if not args.use_global_embedding:
args.use_global_embedding = True
print("Changing to using global communicator for embedding...")
if options.embedding_memory_type == "chunked":
if args.embedding_memory_type == "chunked":
print("Changing to continuous wholememory for embedding...")
options.embedding_memory_type = "continuous"
args.embedding_memory_type = "continuous"

csr_row_ptr_wm_tensor = wgth.create_wholememory_tensor_from_filelist(
graph_comm,
graph_structure_wholememory_type,
graph_structure_wholememory_location,
os.path.join(options.root_dir, "homograph_csr_row_ptr"),
os.path.join(args.root_dir, "homograph_csr_row_ptr"),
torch.int64,
)
csr_col_ind_wm_tensor = wgth.create_wholememory_tensor_from_filelist(
graph_comm,
graph_structure_wholememory_type,
graph_structure_wholememory_location,
os.path.join(options.root_dir, "homograph_csr_col_idx"),
os.path.join(args.root_dir, "homograph_csr_col_idx"),
torch.int,
)
graph_structure.set_csr_graph(csr_row_ptr_wm_tensor, csr_col_ind_wm_tensor)

feature_comm = global_comm if options.use_global_embedding else local_comm
feature_comm = global_comm if args.use_global_embedding else local_comm

embedding_wholememory_type = options.embedding_memory_type
embedding_wholememory_type = args.embedding_memory_type
embedding_wholememory_location = (
"cpu" if options.cache_type != "none" or options.cache_ratio == 0.0 else "cuda"
"cpu" if args.cache_type != "none" or args.cache_ratio == 0.0 else "cuda"
)
if options.cache_ratio == 0.0:
options.cache_type = "none"
access_type = "readonly" if options.train_embedding is False else "readwrite"
if args.cache_ratio == 0.0:
args.cache_type = "none"
access_type = "readonly" if args.train_embedding is False else "readwrite"
if wgth.get_rank() == 0:
print(
f"graph_structure: type={graph_structure_wholememory_type}, "
f"location={graph_structure_wholememory_location}\n"
f"embedding: type={embedding_wholememory_type}, location={embedding_wholememory_location}, "
f"cache_type={options.cache_type}, cache_ratio={options.cache_ratio}, "
f"trainable={options.train_embedding}"
f"cache_type={args.cache_type}, cache_ratio={args.cache_ratio}, "
f"trainable={args.train_embedding}"
)
cache_policy = wgth.create_builtin_cache_policy(
options.cache_type,
args.cache_type,
embedding_wholememory_type,
embedding_wholememory_location,
access_type,
options.cache_ratio,
args.cache_ratio,
)

wm_optimizer = (
None
if options.train_embedding is False
if args.train_embedding is False
else wgth.create_wholememory_optimizer("adam", {})
)

embedding_dtype = torch.float32 if not options.fp16_mbedding else torch.float16
embedding_dtype = torch.float32 if not args.fp16_mbedding else torch.float16

if wm_optimizer is None:
node_feat_wm_embedding = wgth.create_embedding_from_filelist(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
os.path.join(options.root_dir, "node_feat.bin"),
os.path.join(args.root_dir, "node_feat.bin"),
embedding_dtype,
options.feat_dim,
args.feat_dim,
optimizer=wm_optimizer,
cache_policy=cache_policy,
)
@@ -225,16 +222,16 @@ def main_func():
embedding_wholememory_type,
embedding_wholememory_location,
embedding_dtype,
[graph_structure.node_count, options.feat_dim],
[graph_structure.node_count, args.feat_dim],
optimizer=wm_optimizer,
cache_policy=cache_policy,
random_init=True,
)
wgth.set_framework(options.framework)
model = wgth.HomoGNNModel(graph_structure, node_feat_wm_embedding, options)
wgth.set_framework(args.framework)
model = wgth.HomoGNNModel(graph_structure, node_feat_wm_embedding, args)
model.cuda()
model = DDP(model, delay_allreduce=True)
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=options.lr)
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=args.lr)

train(train_ds, valid_ds, model, optimizer, wm_optimizer, global_comm)
test(test_ds, model)
@@ -243,4 +240,4 @@ def main_func():


if __name__ == "__main__":
wgth.distributed_launch(options, main_func)
wgth.distributed_launch(args, main_func)
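One behavioral difference worth keeping in mind with this kind of port, visible in the `parse_args` lines above: optparse's `parse_args()` hands unrecognized tokens back as the second element of its `(options, args)` tuple, whereas argparse's `parse_args()` treats them as an error; `parse_known_args()` is the argparse equivalent when leftovers must be tolerated. A small illustrative sketch follows; the `--epochs` flag in it is only an example and is not necessarily how `wgth.add_training_options` spells its option.

```python
import argparse
from optparse import OptionParser

# optparse: unknown tokens come back as leftover positionals.
old = OptionParser()
old.add_option("--epochs", type="int", default=1)
options, leftover = old.parse_args(["--epochs", "2", "extra-token"])
assert options.epochs == 2 and leftover == ["extra-token"]

# argparse: parse_args() would error out on "extra-token";
# parse_known_args() reproduces the permissive optparse behaviour.
new = argparse.ArgumentParser()
new.add_argument("--epochs", type=int, default=1)
args, leftover = new.parse_known_args(["--epochs", "2", "extra-token"])
assert args.epochs == 2 and leftover == ["extra-token"]
```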