Skip to content

Commit

Permalink
add optimize_reps to reduce the memory storage
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Sep 8, 2024
1 parent ab69c53 commit 7ca5fec
Showing 1 changed file with 86 additions and 2 deletions.
88 changes: 86 additions & 2 deletions pyxtal/lego/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def minimize_from_x(x, dim, spg, wps, elements, calculator, ref_environments,
while True:
count += 1
try:
xtal.from_random(dim, g, elements, numIons,
sites=sites_wp, factor=1.0,
xtal.from_random(dim, g, elements, numIons,
sites=sites_wp, factor=1.0,
random_state=random_state)
except RuntimeError:
print(g.number, numIons, sites)
Expand Down Expand Up @@ -600,6 +600,7 @@ def print_memory_usage(self):
process = psutil.Process(os.getpid())
mem = process.memory_info().rss / 1024 ** 2
self.logging.info(f"Rank {self.rank} memory: {mem:.1f} MB")
print(f"Rank {self.rank} memory: {mem:.1f} MB")

def set_descriptor_calculator(self, dtype='SO3', mykwargs={}):
"""
Expand Down Expand Up @@ -793,6 +794,89 @@ def optimize_xtals_serial(self, xtals, args):
xtals_opt.append(xtal)
return xtals_opt

def optimize_reps(self, reps, ncpu=1, opt_type='local',
                  T=0.2, niter=20, early_quit=0.02,
                  add_db=True, symmetrize=False,
                  minimizers=None,
                  ):
    """
    Perform optimization for each structure given as a representation.

    The optimization settings are packed into a single tuple and handed to
    ``self.optimize_reps_mproc``; only the multiprocess mode is implemented.

    Args:
        reps: list of reps
        ncpu (int): number of parallel python processes, must be > 1
        opt_type (str): optimization type forwarded to the workers
        T (float): forwarded to the workers
            (presumably a temperature-like parameter -- TODO confirm)
        niter (int): forwarded to the workers
            (presumably an iteration count -- TODO confirm)
        early_quit (float): forwarded to the workers
        add_db (bool): forwarded with the other settings
        symmetrize (bool): forwarded with the other settings
        minimizers (list): list of (method_name, steps) tuples; when None,
            ``[('Nelder-Mead', 100), ('L-BFGS-B', 100)]`` is used

    Returns:
        The value returned by ``self.optimize_reps_mproc``.

    Raises:
        NotImplementedError: if ``ncpu`` is not greater than 1.
    """
    # Build the default per call: a list literal in the signature would be
    # one shared object, so a mutation downstream would leak into later calls.
    if minimizers is None:
        minimizers = [('Nelder-Mead', 100), ('L-BFGS-B', 100)]

    args = (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers)
    if ncpu > 1:
        return self.optimize_reps_mproc(reps, ncpu, args)
    raise NotImplementedError("optimize_reps works in parallel mode")

def optimize_reps_mproc(self, reps, ncpu, args):
    """
    Optimization in multiprocess mode.

    The representations are processed in minibatches of ``4 * ncpu`` entries.
    Within a minibatch, each worker task receives a strided slice of the
    entries; every entry is converted to a ``(x, spg, wps)`` tuple just
    before the task is submitted. After each minibatch the database is
    de-duplicated and the garbage collector is run.

    Args:
        reps: list of reps
        ncpu (int): number of parallel python processes
        args: (opt_type, T, n_iter, early_quit, add_db, symmetrize, minimizers)

    Returns:
        list: everything returned by ``self.process_xtals`` over all
        minibatches, concatenated.
    """
    from multiprocessing import Pool
    from collections import deque
    import gc

    (opt_type, T, niter, early_quit, add_db, symmetrize, minimizers) = args
    xtals_opt = deque()

    def generate_args(ids):
        """
        A generator to yield argument lists for minimize_from_x_par.
        """
        for j in range(ncpu):
            _ids = ids[j::ncpu]
            # A stride slice is empty when the minibatch holds fewer than
            # ncpu entries; there is nothing to optimize, so submit no task.
            if not _ids:
                continue
            wp_libs = []
            for rep_id in _ids:
                xtal = pyxtal()
                xtal.from_tabular_representation(reps[rep_id], normalize=False)
                x = xtal.get_1d_rep_x()
                spg, wps, _ = self.get_input_from_ref_xtal(xtal)
                wp_libs.append((x, spg, wps))
            yield (self.dim, wp_libs, self.elements, self.calculator,
                   self.ref_environments, opt_type, T, niter,
                   early_quit, minimizers)

    # Split the input structures to minibatches
    N_rep = 4
    N_batches = N_rep * ncpu

    # The context manager terminates the workers on exit, including when an
    # exception escapes the loop, so no worker process is left behind.
    with Pool(processes=ncpu) as pool:
        for start in range(0, len(reps), N_batches):
            end = min(start + N_batches, len(reps))
            ids = list(range(start, end))
            print(f"Rank {self.rank} minibatch {start} {end}")
            self.print_memory_usage()

            # Bind both names up front so the cleanup below is safe even
            # when every worker of this minibatch returns None.
            _xtals, _xs = None, None
            # Use the generator to pass args to reduce memory usage
            for result in pool.imap_unordered(minimize_from_x_par,
                                              generate_args(ids)):
                if result is not None:
                    (_xtals, _xs) = result
                    valid_xtals = self.process_xtals(
                        _xtals, _xs, add_db, symmetrize)
                    xtals_opt.extend(valid_xtals)  # Use deque to reduce memory

            # Remove the duplicate structures
            self.db.update_row_topology(overwrite=False, prefix=self.prefix)
            self.db.clean_structures_spg_topology(dim=self.dim)

            # After each minibatch, delete the local variables and run
            # garbage collection
            del ids, _xtals, _xs
            gc.collect()  # Explicitly call garbage collector to free memory

    xtals_opt = list(xtals_opt)
    print(f"Rank {self.rank} finish optimize_reps_mproc {len(xtals_opt)}")
    return xtals_opt

def optimize_xtals_mproc(self, xtals, ncpu, args):
"""
Optimization in multiprocess mode.
Expand Down

0 comments on commit 7ca5fec

Please sign in to comment.