This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

replace optparse with argparser #61

Merged
93 changes: 45 additions & 48 deletions python/pylibwholegraph/examples/node_classfication.py
@@ -14,29 +14,26 @@
 import datetime
 import os
 import time
-from optparse import OptionParser
+import argparse

 import apex
 import torch
 from apex.parallel import DistributedDataParallel as DDP

 import pylibwholegraph.torch as wgth

-parser = OptionParser()
-
-wgth.add_distributed_launch_options(parser)
-wgth.add_training_options(parser)
-wgth.add_common_graph_options(parser)
-wgth.add_common_model_options(parser)
-wgth.add_common_sampler_options(parser)
-wgth.add_node_classfication_options(parser)
-wgth.add_dataloader_options(parser)
-parser.add_option(
+argparser = argparse.ArgumentParser()
+wgth.add_distributed_launch_options(argparser)
+wgth.add_training_options(argparser)
+wgth.add_common_graph_options(argparser)
+wgth.add_common_model_options(argparser)
+wgth.add_common_sampler_options(argparser)
+wgth.add_node_classfication_options(argparser)
+wgth.add_dataloader_options(argparser)
+argparser.add_argument(
     "--fp16_embedding", action="store_true", dest="fp16_mbedding", default=False, help="Whether to use fp16 embedding"
 )


-(options, args) = parser.parse_args()
+args = argparser.parse_args()


 def valid_test(dataloader, model, name):
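
Note: the core API shift driving this change is that optparse's parse_args() returns an (options, args) tuple, while argparse's parse_args() returns a single namespace, so every downstream options.<name> lookup becomes args.<name>. A minimal standalone sketch of the before/after pattern, separate from the diff itself (the --fp16_embedding flag mirrors the real one above; --batchsize and its default here are illustrative only):

import argparse
from optparse import OptionParser

# argparse (new): one namespace object; attribute names come from dest=
# or from the long flag name with the leading dashes stripped.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--fp16_embedding", action="store_true", dest="fp16_mbedding",
    default=False, help="Whether to use fp16 embedding",
)
argparser.add_argument("--batchsize", type=int, default=1024)  # illustrative flag/default
args = argparser.parse_args([])  # [] -> just take the defaults
print(args.fp16_mbedding, args.batchsize)

# optparse (old, deprecated since Python 3.2): parse_args() returns a
# (values, positional_args) tuple, so callers unpack it and read values.<dest>.
parser = OptionParser()
parser.add_option("--batchsize", type="int", default=1024)  # illustrative flag/default
(options, positional) = parser.parse_args([])
print(options.batchsize)
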
@@ -68,7 +65,7 @@ def valid(valid_dataloader, model):


 def test(test_dataset, model):
-    test_dataloader = wgth.get_valid_test_dataloader(test_dataset, options.batchsize)
+    test_dataloader = wgth.get_valid_test_dataloader(test_dataset, args.batchsize)
     valid_test(test_dataloader, model, "TEST")

@@ -77,19 +74,19 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
     print("start training...")
     train_dataloader = wgth.get_train_dataloader(
         train_data,
-        options.batchsize,
+        args.batchsize,
         replica_id=wgth.get_rank(),
         num_replicas=wgth.get_world_size(),
-        num_workers=options.dataloaderworkers,
+        num_workers=args.dataloaderworkers,
     )
-    valid_dataloader = wgth.get_valid_test_dataloader(valid_data, options.batchsize)
+    valid_dataloader = wgth.get_valid_test_dataloader(valid_data, args.batchsize)
     valid(valid_dataloader, model)

     train_step = 0
     epoch = 0
     loss_fcn = torch.nn.CrossEntropyLoss()
     train_start_time = time.time()
-    while epoch < options.epochs:
+    while epoch < args.epochs:
         for i, (idx, label) in enumerate(train_dataloader):
             label = torch.reshape(label, (-1,)).cuda()
             optimizer.zero_grad()
@@ -99,7 +96,7 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
             loss.backward()
             optimizer.step()
             if wm_optimizer is not None:
-                wm_optimizer.step(options.lr * 0.1)
+                wm_optimizer.step(args.lr * 0.1)
             if wgth.get_rank() == 0 and train_step % 100 == 0:
                 print(
                     "[%s] [LOSS] step=%d, loss=%f"
@@ -121,7 +118,7 @@ def train(train_data, valid_data, model, optimizer, wm_optimizer, global_comm):
         )
         print(
             "[EPOCH_TIME] %.2f seconds."
-            % ((train_end_time - train_start_time) / options.epochs,)
+            % ((train_end_time - train_start_time) / args.epochs,)
         )
     valid(valid_dataloader, model)

@@ -135,11 +132,11 @@ def main_func():
         wgth.get_local_size(),
     )

-    if options.use_cpp_ext:
+    if args.use_cpp_ext:
         wgth.compile_cpp_extension()

     train_ds, valid_ds, test_ds = wgth.create_node_claffication_datasets(
-        options.pickle_data_path
+        args.pickle_data_path
     )

     graph_structure = wgth.GraphStructure()
@@ -152,70 +149,70 @@ def main_func():
         graph_comm = global_comm
         graph_structure_wholememory_type = "continuous"
         graph_structure_wholememory_location = "cuda"
-        if not options.use_global_embedding:
-            options.use_global_embedding = True
+        if not args.use_global_embedding:
+            args.use_global_embedding = True
             print("Changing to using global communicator for embedding...")
-            if options.embedding_memory_type == "chunked":
+            if args.embedding_memory_type == "chunked":
                 print("Changing to continuous wholememory for embedding...")
-                options.embedding_memory_type = "continuous"
+                args.embedding_memory_type = "continuous"

     csr_row_ptr_wm_tensor = wgth.create_wholememory_tensor_from_filelist(
         graph_comm,
         graph_structure_wholememory_type,
         graph_structure_wholememory_location,
-        os.path.join(options.root_dir, "homograph_csr_row_ptr"),
+        os.path.join(args.root_dir, "homograph_csr_row_ptr"),
         torch.int64,
     )
     csr_col_ind_wm_tensor = wgth.create_wholememory_tensor_from_filelist(
         graph_comm,
         graph_structure_wholememory_type,
         graph_structure_wholememory_location,
-        os.path.join(options.root_dir, "homograph_csr_col_idx"),
+        os.path.join(args.root_dir, "homograph_csr_col_idx"),
         torch.int,
     )
     graph_structure.set_csr_graph(csr_row_ptr_wm_tensor, csr_col_ind_wm_tensor)

-    feature_comm = global_comm if options.use_global_embedding else local_comm
+    feature_comm = global_comm if args.use_global_embedding else local_comm

-    embedding_wholememory_type = options.embedding_memory_type
+    embedding_wholememory_type = args.embedding_memory_type
     embedding_wholememory_location = (
-        "cpu" if options.cache_type != "none" or options.cache_ratio == 0.0 else "cuda"
+        "cpu" if args.cache_type != "none" or args.cache_ratio == 0.0 else "cuda"
     )
-    if options.cache_ratio == 0.0:
-        options.cache_type = "none"
-    access_type = "readonly" if options.train_embedding is False else "readwrite"
+    if args.cache_ratio == 0.0:
+        args.cache_type = "none"
+    access_type = "readonly" if args.train_embedding is False else "readwrite"
     if wgth.get_rank() == 0:
         print(
             f"graph_structure: type={graph_structure_wholememory_type}, "
             f"location={graph_structure_wholememory_location}\n"
             f"embedding: type={embedding_wholememory_type}, location={embedding_wholememory_location}, "
-            f"cache_type={options.cache_type}, cache_ratio={options.cache_ratio}, "
-            f"trainable={options.train_embedding}"
+            f"cache_type={args.cache_type}, cache_ratio={args.cache_ratio}, "
+            f"trainable={args.train_embedding}"
         )
     cache_policy = wgth.create_builtin_cache_policy(
-        options.cache_type,
+        args.cache_type,
         embedding_wholememory_type,
         embedding_wholememory_location,
         access_type,
-        options.cache_ratio,
+        args.cache_ratio,
     )

     wm_optimizer = (
         None
-        if options.train_embedding is False
+        if args.train_embedding is False
         else wgth.create_wholememory_optimizer("adam", {})
     )

-    embedding_dtype = torch.float32 if not options.fp16_mbedding else torch.float16
+    embedding_dtype = torch.float32 if not args.fp16_mbedding else torch.float16

     if wm_optimizer is None:
         node_feat_wm_embedding = wgth.create_embedding_from_filelist(
             feature_comm,
             embedding_wholememory_type,
             embedding_wholememory_location,
-            os.path.join(options.root_dir, "node_feat.bin"),
+            os.path.join(args.root_dir, "node_feat.bin"),
             embedding_dtype,
-            options.feat_dim,
+            args.feat_dim,
             optimizer=wm_optimizer,
             cache_policy=cache_policy,
         )
@@ -225,16 +222,16 @@ def main_func():
             embedding_wholememory_type,
             embedding_wholememory_location,
             embedding_dtype,
-            [graph_structure.node_count, options.feat_dim],
+            [graph_structure.node_count, args.feat_dim],
             optimizer=wm_optimizer,
             cache_policy=cache_policy,
             random_init=True,
         )
-    wgth.set_framework(options.framework)
-    model = wgth.HomoGNNModel(graph_structure, node_feat_wm_embedding, options)
+    wgth.set_framework(args.framework)
+    model = wgth.HomoGNNModel(graph_structure, node_feat_wm_embedding, args)
     model.cuda()
     model = DDP(model, delay_allreduce=True)
-    optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=options.lr)
+    optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=args.lr)

     train(train_ds, valid_ds, model, optimizer, wm_optimizer, global_comm)
     test(test_ds, model)
@@ -243,4 +240,4 @@ def main_func():


 if __name__ == "__main__":
-    wgth.distributed_launch(options, main_func)
+    wgth.distributed_launch(args, main_func)
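
Each wgth.add_*_options helper above registers one related group of flags on the parser that is passed in, which is why main_func() can later read args.epochs, args.batchsize, args.lr, and so on. A hypothetical helper in that style, shown only to illustrate the pattern (the real flag names, types, and defaults live in pylibwholegraph and are not part of this diff):

import argparse

def add_training_options_sketch(argparser: argparse.ArgumentParser) -> None:
    # Hypothetical stand-in for helpers like wgth.add_training_options:
    # the flag names mirror attributes read above (args.epochs, args.batchsize,
    # args.lr); the defaults here are made up for illustration.
    argparser.add_argument("--epochs", type=int, default=4, help="number of training epochs")
    argparser.add_argument("--batchsize", type=int, default=1024, help="training batch size")
    argparser.add_argument("--lr", type=float, default=0.003, help="learning rate")

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    add_training_options_sketch(p)
    demo_args = p.parse_args(["--epochs", "2"])
    print(demo_args.epochs, demo_args.batchsize, demo_args.lr)  # -> 2 1024 0.003
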