Skip to content

Commit

Permalink
to debug mpi
Browse files Browse the repository at this point in the history
  • Loading branch information
Qiang Zhu committed Sep 1, 2024
1 parent 065000c commit a146cc6
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 78 deletions.
4 changes: 2 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def __getattr__(cls, name):
author = "Qiang Zhu, Scott Fredericks, Kevin Parrish"

# The short X.Y version
version = "1.0.2"
version = "1.0.3"
# The full version, including alpha/beta/rc tags
release = "1.0.2"
release = "1.0.3"

# -- General configuration ---------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ab-initio generation of random crystal structures. It has the following features
- Structural manipulation via symmetry relation (both subgroup and supergroup)
- Geometry optimization from built-in and external optimization methods

The current version is ``1.0.2`` at `GitHub <https://github.com/MaterSim/PyXtal>`_.
The current version is ``1.0.3`` at `GitHub <https://github.com/MaterSim/PyXtal>`_.
It is available for use under the MIT license. Expect updates upon request by
`Qiang Zhu\'s group <https://qzhu2017.github.io>`_ at the
University of North Carolina at Charlotte.
Expand Down
2 changes: 2 additions & 0 deletions pyxtal/optimize/WFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def run_mpi(self):
success_rate = 0

for gen in range(self.N_gen):
print(f"Rank {self.rank} entering generation {gen}")

current_xtals = None

Expand All @@ -192,6 +193,7 @@ def run_mpi(self):

# broadcast
current_xtals = self.comm.bcast(current_xtals, root=0)
print(f"Rank {self.rank} after broadcast: current_xtals = {current_xtals}")

# Local optimization
gen_results = self.local_optimization(gen, current_xtals)
Expand Down
183 changes: 109 additions & 74 deletions pyxtal/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"""
from __future__ import annotations
import multiprocessing
from multiprocessing import Process, Queue
from concurrent.futures import ProcessPoolExecutor, TimeoutError
from concurrent.futures import ThreadPoolExecutor

import logging
import os
Expand All @@ -26,16 +28,33 @@
from pyxtal.lattice import Lattice
from pyxtal.symmetry import Group

def run_optimizer_single_with_timeout(args, timeout=60.0):
    """
    Run optimizer_single in a worker thread and stop waiting after a timeout.

    Args:
        args (tuple): positional arguments forwarded to `optimizer_single`.
        timeout (float): maximum number of seconds to wait for the result.

    Returns:
        Whatever `optimizer_single(*args)` returns, or `(None, False)` when
        the call raised an exception or did not finish within `timeout`.
    """
    # Import locally so the exception class named in the `except` clause is
    # guaranteed to be bound. The dotted form `concurrent.futures.TimeoutError`
    # needs the top-level name `concurrent`, which a
    # `from concurrent.futures import ...` statement does not provide, and
    # would turn every real timeout into a NameError.
    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    def worker(*worker_args):
        # Convert any failure inside the optimizer into the default result so
        # a single bad structure does not abort the whole batch.
        try:
            return optimizer_single(*worker_args)
        except Exception as e:
            print(f"Error in worker thread: {e}")
            return None, False  # Default value indicating a failed run

    # The executor is managed by hand instead of a `with` block: leaving a
    # `with ThreadPoolExecutor(...)` block calls shutdown(wait=True), which
    # would block until the worker finishes and defeat the timeout.
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(worker, *args)
    try:
        return future.result(timeout=timeout)
    except FuturesTimeoutError:
        print(f"Timeout: Optimization took longer than {timeout} seconds.")
        return None, False  # Default value indicating a timeout
    finally:
        # NOTE: a running Python thread cannot be cancelled; after a timeout
        # the worker keeps running in the background until it finishes, but
        # the caller is no longer held up by it.
        executor.shutdown(wait=False)

def run_optimizer(args, timeout=60):
    """
    Module-level entry point for one timed optimization job.

    Being a plain top-level function, it can be pickled and handed to a
    multiprocessing pool.

    Args:
        args: argument tuple passed on to `run_optimizer_single_with_timeout`.
        timeout: number of seconds passed on as the timeout.

    Returns:
        The result of `run_optimizer_single_with_timeout(args, timeout)`, or
        `(None, False)` if that call raised an exception.
    """
    # Start from the failure value; it is only overwritten on success.
    outcome = (None, False)
    try:
        outcome = run_optimizer_single_with_timeout(args, timeout)
    except Exception as e:
        print(f"Error in run_optimizer: {e}")
    return outcome

class GlobalOptimize:
"""
Expand Down Expand Up @@ -107,17 +126,15 @@ def __init__(
from mpi4py import MPI
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
node_name = MPI.Get_processor_name()
unique_node_names = self.comm.gather(node_name, root=0)
self.node = node_name
self.size = len(unique_node_names)
#self.size = self.comm.Get_size()
else:
self.rank = 0
self.size = self.ncpu

# General information
if isinstance(random_state, Generator):
self.random_state = random_state.spawn(1)[0]
else:
self.random_state = np.random.default_rng(random_state)

# Molecular information
self.smile = smiles
self.smiles = self.smile.split(".") # list
Expand Down Expand Up @@ -173,23 +190,53 @@ def __init__(
self.ff_opt = ff_opt
self.ff_style = ff_style

if info is not None:
self.atom_info = info
self.parameters = None
self.ff_opt = False
# Structure matcher
if matcher is None:
self.matcher = StructureMatcher(ltol=0.3, stol=0.3, angle_tol=5)
else:
self.ff_parameters = ff_parameters
self.reference_file = reference_file
self.parameters = ForceFieldParameters(self.smiles, style=ff_style, f_coef=1.0, s_coef=1.0, ncpu=self.ncpu)
# Only call ForceFieldParameters once
# No need to broadcast self.parameters?
# Just broadcast atom_info should be fine
#parameters = None
atom_info = None
if self.rank == 0:
self.matcher = matcher

self.E_max = E_max
self.tag = tag
self.suffix = f"{self.workdir:s}/{self.name:s}-{self.ff_style:s}"


atom_info = None
if self.rank == 0:

# General information
if isinstance(random_state, Generator):
self.random_state = random_state.spawn(1)[0]
else:
self.random_state = np.random.default_rng(random_state)

# I/O stuff
self.early_quit = early_quit
self.N_min_matches = 10 # The min_num_matches for early termination

# Setup logger
logging.getLogger().handlers.clear()
logging.basicConfig(format="%(asctime)s| %(message)s",
filename=self.log_file, level=logging.INFO)
self.logging = logging

# Some necessary trackers
self.matches = []
self.best_reps = []
self.reps = []
self.engs = []


if info is not None:
atom_info = info
self.parameters = None
self.ff_opt = False
else:
# Only call ForceFieldParameters in rank 0
self.ff_parameters = ff_parameters
self.reference_file = reference_file
os.makedirs(self.workdir, exist_ok=True)

from pyocse.parameters import ForceFieldParameters
self.parameters = ForceFieldParameters(
self.smiles, style=ff_style, f_coef=1.0, s_coef=1.0, ncpu=self.ncpu)
Expand Down Expand Up @@ -226,26 +273,6 @@ def __init__(
self.parameters.export_parameters(
self.workdir + "/" + self.ff_parameters, params0)
atom_info = self._prepare_chm_info(params0, suffix='pyxtal')

if self.use_mpi:
#self.parameters = self.comm.bcast(parameters, root=0)
self.atom_info = self.comm.bcast(atom_info, root=0)
else:
self.atom_info = atom_info

# Structure matcher
if matcher is None:
self.matcher = StructureMatcher(ltol=0.3, stol=0.3, angle_tol=5)
else:
self.matcher = matcher

# I/O stuff
self.early_quit = early_quit
self.N_min_matches = 10 # The min_num_matches for early termination
self.E_max = E_max
self.tag = tag
self.suffix = f"{self.workdir:s}/{self.name:s}-{self.ff_style:s}"
if self.rank == 0:
if cif is None:
self.cif = self.suffix + '.cif'
else:
Expand All @@ -256,17 +283,11 @@ def __init__(
self.matched_cif = self.suffix + "-matched.cif"
# print(self)

# Setup logger
logging.getLogger().handlers.clear()
logging.basicConfig(format="%(asctime)s| %(message)s",
filename=self.log_file, level=logging.INFO)
self.logging = logging
if self.use_mpi:
self.atom_info = self.comm.bcast(atom_info, root=0)
else:
self.atom_info = atom_info

# Some necessary trackers
self.matches = []
self.best_reps = []
self.reps = []
self.engs = []

def print(self, *args, **kwargs):
"""Utility method to print only from rank 0."""
Expand Down Expand Up @@ -333,9 +354,9 @@ def run(self, ref_pmg=None, ref_pxrd=None):
results = self.run_mpi()
else:
results = self.run_serial()

t = (time() - t0)/60
self.print(f"{self.name:s} COMPLETED in {t:.1f} mins {self.N_struc:d} strucs.")
if self.rank == 0:
t = (time() - t0)/60
print(f"{self.name:s} COMPLETED in {t:.1f} mins {self.N_struc:d} strucs.")

return results

Expand Down Expand Up @@ -943,38 +964,52 @@ def local_optimization_mpi(self, gen, xtals, qrs=False):
Perform MPI optimization for each structure in each generation.
Args:
gen (int):
xtals : list of (xtal, tag) tuples
gen (int): Current generation number.
xtals (list) : list of (xtal, tag) tuples
qrs (bool): Force mutation or not (related to QRS)
"""
self.print("Local optimization enabled by MPI", self.size)
print("Local optimization enabled by MPI", self.size, self.rank)

# Prepare arguments for the optimization function
args = self._get_local_optimization_args()
args_lists = []
for i in range(self.N_pop):
job_tag = self.tag + "-g" + str(gen) + "-p" + str(i)
xtal = xtals[i][0]
if qrs:
mutate = False
else:
mutate = xtal is not None
mutate = False if qrs else xtal is not None
my_args = [xtal, i, mutate, job_tag, *args]
args_lists.append(tuple(my_args))

# Distribute args_lists across available ranks (processes)
local_args = args_lists[self.rank::self.size]

# Run the optimizer in parallel using MPI
local_results = []
for args in local_args:
# print('rank', self.rank, 'id', args[1])
xtal, match = run_optimizer_single_with_timeout(args, timeout=60)
# (xtal, match) = optimizer_single(*args)
local_results.append((args[1], xtal, match))
#local_results = []
#for args in local_args:
# # print('rank', self.rank, 'id', args[1])
# xtal, match = run_optimizer_single_with_timeout(args, timeout=60)
# # (xtal, match) = optimizer_single(*args)
# local_results.append((args[1], xtal, match))

# Determine the number of cores available on the node
num_cores = multiprocessing.cpu_count() // self.size
print("Local optimization distribution", self.rank, num_cores)

# Use multiprocessing within each MPI rank (node)
with multiprocessing.Pool(processes=num_cores) as pool:
local_results = pool.map(run_optimizer, local_args)

#local_results = []
#for args in local_args:
# print("Local optimization", self.rank, args[1])
# result = run_optimizer(args)
# print("Local optimization", self.rank, result)
# local_results.append(result)

# Gather all results at the root process
all_results = self.comm.gather(local_results, root=0)

# If root process, process the results
# Process results at the root process
gen_results = None
if self.rank == 0:
gen_results = [(None, None)] * len(xtals)
Expand All @@ -983,7 +1018,7 @@ def local_optimization_mpi(self, gen, xtals, qrs=False):
(id, xtal, match) = res
gen_results[id] = (xtal, match)

# Broadcast
# Broadcast the processed results to all processes
gen_results = self.comm.bcast(gen_results, root=0)

return gen_results
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def run(self):

setup(
name="pyxtal",
version="1.0.2",
version="1.0.3",
author="Scott Fredericks, Kevin Parrish, Qiang Zhu",
author_email="[email protected]",
description="Python code for generation of crystal structures based on symmetry constraints.",
Expand Down

0 comments on commit a146cc6

Please sign in to comment.