From d0ba4d4b3040f60f381b0b194b140bcb745a620e Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Sat, 23 Sep 2023 01:22:24 +1000 Subject: [PATCH 1/5] Fix bugs when collecting results from `mp.spawn` --- recbole/quick_start/__init__.py | 1 + recbole/quick_start/quick_start.py | 64 ++++++++++++++++++- run_recbole.py | 36 ++++------- run_recbole_group.py | 44 ++++---------- significance_test.py | 98 ++++++++++++++++-------------- 5 files changed, 136 insertions(+), 107 deletions(-) diff --git a/recbole/quick_start/__init__.py b/recbole/quick_start/__init__.py index 58b937d6a..2fe193a15 100644 --- a/recbole/quick_start/__init__.py +++ b/recbole/quick_start/__init__.py @@ -1,4 +1,5 @@ from recbole.quick_start.quick_start import ( + run, run_recbole, objective_function, load_data_and_model, diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index a898584e2..2e73d410a 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -39,8 +39,62 @@ ) +def run( + model, + dataset, + config_file_list=None, + config_dict=None, + saved=True, + nproc=1, + world_size=-1, + ip="localhost", + port="5678", + group_offset=0, +): + if nproc == 1 and world_size <= 0: + res = run_recbole( + model=model, + dataset=dataset, + config_file_list=config_file_list, + config_dict=config_dict, + saved=saved, + ) + else: + if world_size == -1: + world_size = nproc + import torch.multiprocessing as mp + + # Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2 + # https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2 + queue = mp.get_context('spawn').SimpleQueue() + + config_dict = config_dict or {} + config_dict.update({ + "world_size": world_size, + "ip": ip, + "port": port, + "nproc": nproc, + "offset": group_offset, + }) + kwargs = { + "config_dict": config_dict, + "queue": queue, + } + + mp.spawn( + run_recboles, + args=(model, dataset, config_file_list, kwargs), + nprocs=nproc, + join=True, + ) + + # Normally, there should be only one item in the queue + res = None if queue.empty() else queue.get() + return res + + def run_recbole( - model=None, dataset=None, config_file_list=None, config_dict=None, saved=True + model=None, dataset=None, config_file_list=None, config_dict=None, saved=True, queue=None ): r"""A fast running api, which includes the complete process of training and testing a model on a specified dataset @@ -51,6 +105,7 @@ def run_recbole( config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``. config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``. saved (bool, optional): Whether to save the model. Defaults to ``True``. + queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``. """ # configurations initialization config = Config( @@ -104,13 +159,18 @@ def run_recbole( logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}") logger.info(set_color("test result", "yellow") + f": {test_result}") - return { + result = { "best_valid_score": best_valid_score, "valid_score_bigger": config["valid_metric_bigger"], "best_valid_result": best_valid_result, "test_result": test_result, } + if config["local_rank"] == 0 and queue is not None: + queue.put(result) # for multiprocessing, e.g., mp.spawn + + return result # for the single process + def run_recboles(rank, *args): ip, port, world_size, nproc, offset = args[3:] diff --git a/run_recbole.py b/run_recbole.py index 09e0740e2..2badf308c 100644 --- a/run_recbole.py +++ b/run_recbole.py @@ -8,9 +8,8 @@ # @Email : chenyuwuxinn@gmail.com, houyupeng@ruc.edu.cn, zhlin@ruc.edu.cn import argparse -from ast import arg -from recbole.quick_start import run_recbole, run_recboles +from recbole.quick_start import run if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -44,26 +43,13 @@ args.config_files.strip().split(" ") if args.config_files else None ) - if args.nproc == 1 and args.world_size <= 0: - run_recbole( - model=args.model, dataset=args.dataset, config_file_list=config_file_list - ) - else: - if args.world_size == -1: - args.world_size = args.nproc - import torch.multiprocessing as mp - - mp.spawn( - run_recboles, - args=( - args.model, - args.dataset, - config_file_list, - args.ip, - args.port, - args.world_size, - args.nproc, - args.group_offset, - ), - nprocs=args.nproc, - ) + run( + args.model, + args.dataset, + config_file_list=config_file_list, + nproc=args.nproc, + world_size=args.world_size, + ip=args.ip, + port=args.port, + group_offset=args.group_offset, + ) diff --git a/run_recbole_group.py b/run_recbole_group.py index 925f1d41a..2468b9577 100644 --- a/run_recbole_group.py +++ b/run_recbole_group.py @@ -4,41 +4,10 @@ import argparse -from ast import arg -from recbole.quick_start import run_recbole, run_recboles +from recbole.quick_start import run from recbole.utils import list_to_latex - -def run(args, model, config_file_list): - if args.nproc == 1 and args.world_size <= 0: - res = run_recbole( - model=model, - dataset=args.dataset, - config_file_list=config_file_list, - ) - else: - if args.world_size == -1: - args.world_size = args.nproc - import torch.multiprocessing as mp - - res = mp.spawn( - run_recboles, - args=( - args.model, - args.dataset, - config_file_list, - args.ip, - args.port, - args.world_size, - args.nproc, - args.group_offset, - ), - nprocs=args.nproc, - ) - return res - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -92,7 +61,16 @@ def run(args, model, config_file_list): valid_res_dict = {"Model": model} test_res_dict = {"Model": model} - result = run(args, model, config_file_list) + result = run( + model, + args.dataset, + config_file_list=config_file_list, + nproc=args.nproc, + world_size=args.world_size, + ip=args.ip, + port=args.port, + group_offset=args.group_offset, + ) valid_res_dict.update(result["best_valid_result"]) test_res_dict.update(result["test_result"]) bigger_flag = result["valid_score_bigger"] diff --git a/significance_test.py b/significance_test.py index bcd65c321..589883482 100644 --- a/significance_test.py +++ b/significance_test.py @@ -8,43 +8,41 @@ # @Email : import argparse -from ast import arg import random -import sys from collections import defaultdict -from scipy import stats - -from recbole.quick_start import run_recbole, run_recboles +from scipy import stats -def run(args, seed): - if args.nproc == 1 and args.world_size <= 0: - res = run_recbole( - model=args.model, - dataset=args.dataset, - config_file_list=config_file_list, +from recbole.quick_start import run + + +def run_test( + model, + dataset, + config_files, + seeds, + nproc, + world_size, + ip, + port, + group_offset, +): + results = defaultdict(list) + for seed in seeds: + res = run( + model, + dataset, + config_files, config_dict={"seed": seed}, + nproc=nproc, + world_size=world_size, + ip=ip, + port=port, + group_offset=group_offset, ) - else: - if args.world_size == -1: - args.world_size = args.nproc - import torch.multiprocessing as mp - - res = mp.spawn( - run_recboles, - args=( - args.model, - args.dataset, - config_file_list, - args.ip, - args.port, - args.world_size, - args.nproc, - args.group_offset, - ), - nprocs=args.nproc, - ) - return res + for _key, _value in res["test_result"].items(): + results[_key].append(_value) + return results if __name__ == "__main__": @@ -101,24 +99,30 @@ def run(args, seed): random.seed(args.st_seed) random_seeds = [random.randint(0, 2**32 - 1) for _ in range(args.run_times)] - result_ours = defaultdict(list) - result_baseline = defaultdict(list) - config_file_ours, config_file_baseline = config_file_list - args.model = args.model_ours - args.config_file_list = [result_ours] - for seed in random_seeds: - res = run(args, seed) - for key, value in res["test_result"].items(): - result_ours[key].append(value) - - args.model = args.model_baseline - args.config_file_list = [config_file_baseline] - for seed in random_seeds: - res = run(args, seed) - for key, value in res["test_result"].items(): - result_baseline[key].append(value) + result_ours = run_test( + args.model_ours, + args.dataset, + [config_file_ours], + random_seeds, + args.nproc, + args.world_size, + args.ip, + args.port, + args.group_offset, + ) + result_baseline = run_test( + args.model_baseline, + args.dataset, + [config_file_baseline], + random_seeds, + args.nproc, + args.world_size, + args.ip, + args.port, + args.group_offset, + ) final_result = {} for key, value in result_ours.items(): From f3d92df1aff39deb5128407be6fec0bbd3c539d8 Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Sat, 23 Sep 2023 01:30:40 +1000 Subject: [PATCH 2/5] Update docs of distributed training --- .../get_started/distributed_training.rst | 68 +++++++++++++------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/docs/source/get_started/distributed_training.rst b/docs/source/get_started/distributed_training.rst index ed7e6960c..b9ecf3be5 100644 --- a/docs/source/get_started/distributed_training.rst +++ b/docs/source/get_started/distributed_training.rst @@ -121,21 +121,33 @@ In above example, you can create a new python file (e.g., `run_a.py`) on node A, nproc = 4, group_offset = 0 ) + + # Optional, only needed if you want to get the result of each process. + queue = mp.get_context('spawn').SimpleQueue() + + config_dict = config_dict or {} + config_dict.update({ + "world_size": args.world_size, + "ip": args.ip, + "port": args.port, + "nproc": args.nproc, + "offset": args.group_offset, + }) + kwargs = { + "config_dict": config_dict, + "queue": queue, # Optional + } + mp.spawn( run_recboles, - args=( - args.model, - args.dataset, - config_file_list, - args.ip, - args.port, - args.world_size, - args.nproc, - args.group_offset, - ), - nprocs=args.nproc, + args=(args.model, args.dataset, args.config_file_list, kwargs), + nprocs=nproc, + join=True, ) + # Normally, there should be only one item in the queue + res = None if queue.empty() else queue.get() + Then run the following command: @@ -159,21 +171,33 @@ Similarly, you can create a new python file (e.g., `run_b.py`) on node B, and wr nproc = 4, group_offset = 4 ) + + # Optional, only needed if you want to get the result of each process. + queue = mp.get_context('spawn').SimpleQueue() + + config_dict = config_dict or {} + config_dict.update({ + "world_size": args.world_size, + "ip": args.ip, + "port": args.port, + "nproc": args.nproc, + "offset": args.group_offset, + }) + kwargs = { + "config_dict": config_dict, + "queue": queue, # Optional + } + mp.spawn( run_recboles, - args=( - args.model, - args.dataset, - config_file_list, - args.ip, - args.port, - args.world_size, - args.nproc, - args.group_offset, - ), - nprocs=args.nproc, + args=(args.model, args.dataset, args.config_file_list, kwargs), + nprocs=nproc, + join=True, ) + # Normally, there should be only one item in the queue + res = None if queue.empty() else queue.get() + Then run the following command: From 85634a6ac78d4f83f2fe865e0865b84185cd5b33 Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Wed, 27 Sep 2023 10:25:55 +1000 Subject: [PATCH 3/5] Update `run_recboles` function --- recbole/quick_start/quick_start.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index 2e73d410a..097acc6d9 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -173,18 +173,12 @@ def run_recbole( def run_recboles(rank, *args): - ip, port, world_size, nproc, offset = args[3:] - args = args[:3] + kwargs = args[-1] + kwargs["config_dict"] = kwargs.get("config_dict", {}) + kwargs["config_dict"]["local_rank"] = rank run_recbole( - *args, - config_dict={ - "local_rank": rank, - "world_size": world_size, - "ip": ip, - "port": port, - "nproc": nproc, - "offset": offset, - }, + *args[:3], + **kwargs, ) From d6c1cf2f8258630de5fed8c0ed9339edd7d03676 Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Wed, 27 Sep 2023 10:44:07 +1000 Subject: [PATCH 4/5] Add data type check --- recbole/quick_start/quick_start.py | 37 ++++++++++++++++++------------ 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index 097acc6d9..4d8368e9c 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -12,20 +12,16 @@ ######################## """ import logging -from logging import getLogger - import sys +from collections.abc import MutableMapping +from logging import getLogger - -import pickle from ray import tune from recbole.config import Config from recbole.data import ( create_dataset, data_preparation, - save_split_dataloaders, - load_split_dataloaders, ) from recbole.data.transform import construct_transform from recbole.utils import ( @@ -69,13 +65,15 @@ def run( queue = mp.get_context('spawn').SimpleQueue() config_dict = config_dict or {} - config_dict.update({ - "world_size": world_size, - "ip": ip, - "port": port, - "nproc": nproc, - "offset": group_offset, - }) + config_dict.update( + { + "world_size": world_size, + "ip": ip, + "port": port, + "nproc": nproc, + "offset": group_offset, + } + ) kwargs = { "config_dict": config_dict, "queue": queue, @@ -94,7 +92,12 @@ def run( def run_recbole( - model=None, dataset=None, config_file_list=None, config_dict=None, saved=True, queue=None + model=None, + dataset=None, + config_file_list=None, + config_dict=None, + saved=True, + queue=None, ): r"""A fast running api, which includes the complete process of training and testing a model on a specified dataset @@ -169,11 +172,15 @@ def run_recbole( if config["local_rank"] == 0 and queue is not None: queue.put(result) # for multiprocessing, e.g., mp.spawn - return result # for the single process + return result # for the single process def run_recboles(rank, *args): kwargs = args[-1] + if not isinstance(kwargs, MutableMapping): + raise ValueError( + f"The last argument of run_recboles should be a dict, but got {type(kwargs)}" + ) kwargs["config_dict"] = kwargs.get("config_dict", {}) kwargs["config_dict"]["local_rank"] = rank run_recbole( From d7fd7932554eee24907c71ffb4a1db4569c0401d Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Sat, 30 Sep 2023 01:16:26 +1000 Subject: [PATCH 5/5] Add clean-up function --- recbole/quick_start/quick_start.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index 4d8368e9c..0300703fe 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -13,6 +13,7 @@ """ import logging import sys +import torch.distributed as dist from collections.abc import MutableMapping from logging import getLogger @@ -169,6 +170,9 @@ def run_recbole( "test_result": test_result, } + if not config["single_spec"]: + dist.destroy_process_group() + if config["local_rank"] == 0 and queue is not None: queue.put(result) # for multiprocessing, e.g., mp.spawn