add the print out of memory
qzhu2017 committed Sep 8, 2024
1 parent b2acbc6 commit ab69c53
Showing 3 changed files with 28 additions and 10 deletions.
19 changes: 15 additions & 4 deletions pyxtal/lego/builder.py
@@ -335,8 +335,11 @@ def print_fun(x, f, accepted):

# Extract the optimized xtal
xtal = pyxtal()
xtal.from_1d_rep(res.x, sites, dim=dim)
return xtal, (x0, res.x)
try:
xtal.from_1d_rep(res.x, sites, dim=dim)
return xtal, (x0, res.x)
except:
return None


def calculate_dSdx(x, xtal, des_ref, f, eps=1e-4, symmetry=True, verbose=False):
@@ -592,6 +595,12 @@ def __str__(self):
def __repr__(self):
return str(self)

def print_memory_usage(self):
import psutil
process = psutil.Process(os.getpid())
mem = process.memory_info().rss / 1024 ** 2
self.logging.info(f"Rank {self.rank} memory: {mem:.1f} MB")

def set_descriptor_calculator(self, dtype='SO3', mykwargs={}):
"""
Set up the calculator for descriptor computation.
@@ -808,6 +817,7 @@ def optimize_xtals_mproc(self, xtals, ncpu, args):
start, end = i, min([i+N_batches, len(xtals)])
ids = list(range(start, end))
print(f"Rank {self.rank} minibatch {start} {end}")
self.print_memory_usage()

def generate_args():
"""
@@ -825,7 +835,8 @@ def generate_args():
self.ref_environments, opt_type, T, niter,
early_quit, minimizers)
# Use the generator to pass args to reduce memory usage
for result in pool.imap_unordered(minimize_from_x_par, args_list):
_xtals, _xs = None, None
for result in pool.imap_unordered(minimize_from_x_par, generate_args()):
if result is not None:
(_xtals, _xs) = result
valid_xtals = self.process_xtals(
@@ -837,7 +848,7 @@ def generate_args():
self.db.clean_structures_spg_topology(dim=self.dim)

# After each minibatch, delete the local variables and run garbage collection
del ids, wp_libs, _xtals, _xs
del ids, _xtals, _xs
gc.collect() # Explicitly call garbage collector to free memory

xtals_opt = list(xtals_opt)
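The builder.py hunks above swap a pre-built args_list for a generator passed straight to pool.imap_unordered, so each argument tuple is created lazily when a worker needs it instead of being held in memory for the whole minibatch. A minimal, self-contained sketch of that pattern (work and generate_args are placeholder names, not pyxtal functions):

from multiprocessing import Pool

def work(args):
    # stand-in for minimize_from_x_par: unpack one argument tuple, return a result
    i, payload = args
    return i, sum(payload)

def generate_args(n):
    # each tuple is built only when the pool requests the next task
    for i in range(n):
        yield (i, list(range(1000)))

if __name__ == "__main__":
    with Pool(4) as pool:
        # results stream back as workers finish, in no particular order
        for i, total in pool.imap_unordered(work, generate_args(20)):
            print(i, total)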
2 changes: 1 addition & 1 deletion pyxtal/optimize/WFS.py
@@ -199,7 +199,7 @@ def _run(self, pool=None):

# Local optimization
gen_results = self.local_optimization(cur_xtals, pool=pool)
self.logging.info(f"Rank {self.rank} finishes local_opt")
self.logging.info(f"Rank {self.rank} finishes local_opt {len(gen_results)}")

prev_xtals = None
if self.rank == 0:
17 changes: 12 additions & 5 deletions pyxtal/optimize/base.py
@@ -52,7 +52,6 @@ def handler(signum, frame):
return None # or some other placeholder for timeout results

def process_task(args):
#logger = args[-1]
#logger.info(f"Rank start processing task")
result = run_optimizer_with_timeout(args)#[:-1])
#logger.info(f"Rank finished processing task")
@@ -330,6 +329,13 @@ def __str__(self):
def __repr__(self):
return str(self)

def print_memory_usage(self):
import psutil
process = psutil.Process(os.getpid())
mem = process.memory_info().rss / 1024 ** 2
gen = self.generation
self.logging.info(f"Rank {self.rank} memory: {mem:.1f} MB in gen {gen}")

def new_struc(self, xtal, xtals):
return new_struc(xtal, xtals)

@@ -967,7 +973,7 @@ def local_optimization(self, xtals, qrs=False, pool=None):
elif self.ncpu == 1:
return self.local_optimization_serial(xtals, qrs)
else:
print(f"Local optimization by multi-threads {ncpu}")
print(f"Local optimization by multi-threads {self.ncpu}")
return self.local_optimization_mproc(xtals, self.ncpu, qrs=qrs, pool=pool)

def local_optimization_serial(self, xtals, qrs=False):
@@ -1057,7 +1063,6 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
ids = range(len(xtals))

N_cycle = int(np.ceil(len(xtals) / ncpu))

# Generator to create arg_lists for multiprocessing tasks
def generate_args_lists():
for i in range(ncpu):
@@ -1071,7 +1076,8 @@ def generate_args_lists():
my_args = [_xtals, _ids, mutates, job_tags, *args, self.timeout]
yield tuple(my_args) # Yield args instead of appending to a list

self.logging.info(f"Rank {self.rank} assign args in local_opt_mproc")
self.print_memory_usage()
self.logging.info(f"Rank {self.rank} assign local_opt")

gen_results = []
# Stream the results to avoid holding too much in memory at once
@@ -1081,8 +1087,9 @@ def generate_args_lists():
for _res in result:
gen_results.append(_res)
# Explicitly delete the result and call garbage collection
del result
gc.collect()
self.logging.info(f"Rank {self.rank} finish local_opt {len(gen_results)}")

return gen_results

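Both of the new print_memory_usage methods read the resident set size of the current process via the third-party psutil package, so psutil becomes a runtime dependency of this code path. A minimal sketch of that reporting, assuming psutil is installed; report_memory, the rank argument, and the bare logging module are placeholders for the class attributes used above:

import logging
import os

import psutil  # third-party dependency: pip install psutil

logging.basicConfig(level=logging.INFO)

def report_memory(rank=0):
    process = psutil.Process(os.getpid())            # handle to the current process
    mem_mb = process.memory_info().rss / 1024 ** 2   # resident set size in MB
    logging.info(f"Rank {rank} memory: {mem_mb:.1f} MB")

report_memory()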
