Skip to content

Commit

Permalink
Merge pull request #285 from MaterSim/network
Browse files Browse the repository at this point in the history
Network
  • Loading branch information
qzhu2017 authored Oct 12, 2024
2 parents 515fc66 + d31af9e commit 6b25f1f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 65 deletions.
16 changes: 9 additions & 7 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ def subgroup_once(
min_cell=0,
mut_lat=True,
ignore_special=False,
verbose=False,
):
"""
Generate a structure with lower symmetry (for atomic crystals only)
Expand All @@ -861,8 +862,9 @@ def subgroup_once(
# print("Test subgroup_once", self.group.number, idx, sites)

if idx is None or len(idx) == 0:
msg = "Cannot find valid splitter id (likely due to insufficient cellsize)"
print(msg)
if verbose:
msg = "Cannot find valid splitter (likely due to large cellsize)"
print(msg)
return None

# Try 100 times to see if a valid split can be found
Expand Down Expand Up @@ -3680,7 +3682,7 @@ def get_tabular_representations(
if len(ids) > N_max:
ids = self.random_state.choice(ids, N_max)

print(f"N_reps {len(ids)}/{N_max}: ", self.get_xtal_string())
#print(f"N_reps {len(ids)}/{N_max}: ", self.get_xtal_string())
for sites_id in ids:
rep = self.get_tabular_representation(
sites_id, normalize, N_wp, perturb,
Expand All @@ -3694,7 +3696,7 @@ def get_tabular_representations(
def get_tabular_representation(
self,
ids=None,
normalize=True,
normalize=False,
N_wp=6,
perturb=False,
max_abc=50.0,
Expand Down Expand Up @@ -3777,7 +3779,7 @@ def from_tabular_representation(
rep,
max_abc=50.0,
max_angle=180,
normalize=True,
normalize=False,
tol=0.1,
discrete=False,
N_grids=50,
Expand Down Expand Up @@ -3857,9 +3859,9 @@ def from_tabular_representation(
# ; print(wp.get_label(), xyz)
xyz = wp.search_generator([x, y, z], tol=tol, symmetrize=True)
if xyz is not None:
xyz, wp, _ = wp.merge(xyz, np.eye(3), 1e-3)
#xyz, wp, _ = wp.merge(xyz, np.eye(3), 1e-3)
label = wp.get_label()
# ; print(x, y, z, label, xyz[0], xyz[1], xyz[2])
#print('after merge', x, y, z, label, xyz[0], xyz[1], xyz[2])
sites.append((label, xyz[0], xyz[1], xyz[2]))
numIons += wp.multiplicity
if verbose:
Expand Down
47 changes: 29 additions & 18 deletions pyxtal/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,15 @@ def gulp_opt_single(id, xtal, ff_lib, path, criteria):
return xtal, eng, status


def mace_opt_single(id, xtal, criteria, step=250):
def mace_opt_single(id, xtal, step, criteria):
"""
Perform a single MACE optimization for a given atomic crystal structure.
Args:
id (int): Identifier for the current structure.
xtal: PyXtal instance representing the crystal structure.
criteria (dict): Dictionary to check the validity of the optimized structure.
step (int): Maximum number of relaxation steps. Default is 250.
criteria (dict): Dictionary to check the validity of the optimized structure.
Returns:
tuple:
Expand Down Expand Up @@ -1112,61 +1112,72 @@ def get_max_id(self):
max_id = row.id + 1
return max_id

def select_xtals(self, ids, overwrite=False, attribute=None, use_relaxed=None):
def select_xtals(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None):
"""
Extract xtals based on attribute name.
Mostly called by update_row_energy
Args:
ids:
N_atoms:
overwrite:
atttribute:
use_relaxed (str): 'ff_relaxed' or 'vasp_relaxed'
"""
(min_id, max_id) = ids
if min_id is None:
min_id = 1
if max_id is None:
max_id = self.get_max_id()
if min_id is None: min_id = 1
if max_id is None: max_id = self.get_max_id()

(min_atoms, max_atoms) = N_atoms
if min_atoms is None: min_atoms = 1
if max_atoms is None: max_atoms = 5000

ids, xtals = [], []
for row in self.db.select():
if overwrite or attribute is None or not hasattr(row, attribute):
if min_id <= row.id <= max_id:
if min_id <= row.id <= max_id and min_atoms < natoms <= max_atoms:
xtal = self.get_pyxtal(row.id, use_relaxed)
ids.append(row.id)
xtals.append(xtal)
if len(xtals) % 100 == 0:
print("Loading xtals from db", len(xtals))
return ids, xtals

def select_xtal(self, ids, overwrite=False, attribute=None, use_relaxed=None):
def select_xtal(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None):
"""
Lazy extraction of select xtals
Args:
ids:
N_atoms:
overwrite:
atttribute:
use_relaxed (str): 'ff_relaxed' or 'vasp_relaxed'
"""
(min_id, max_id) = ids
if min_id is None:
min_id = 1
if max_id is None:
max_id = self.get_max_id()
if min_id is None: min_id = 1
if max_id is None: max_id = self.get_max_id()

(min_atoms, max_atoms) = N_atoms
if min_atoms is None: min_atoms = 1
if max_atoms is None: max_atoms = 5000

ids, xtals = [], []
for row in self.db.select():
if overwrite or attribute is None or not hasattr(row, attribute):
id = row.id
if min_id <= id <= max_id and id % self.size== self.rank:
id, natoms = row.id, row.natoms
if min_id <= id <= max_id and \
min_atoms < natoms <= max_atoms \
and id % self.size== self.rank:

xtal = self.get_pyxtal(id, use_relaxed)
yield id, xtal

def update_row_energy(
self,
calculator='GULP',
ids=(None, None),
N_atoms=(None, None),
ncpu=1,
criteria=None,
symmetrize=False,
Expand Down Expand Up @@ -1211,21 +1222,21 @@ def update_row_energy(
label = calculator.lower() + "_energy"
if calc_folder is None:
calc_folder = calculator.lower() + "_calc"
# MACE does not need a folder

if calculator != 'MACE':
#self.logging.info("make new folders", calc_folder, os.getpwd())
os.makedirs(calc_folder, exist_ok=True)

# Generate structures for calculation
generator = self.select_xtal(ids, overwrite, label, use_relaxed)
generator = self.select_xtal(ids, N_atoms, overwrite, label, use_relaxed)

# Set up arguments for the chosen calculator
args_up = []
if calculator == 'GULP':
args = [calculator, ff_lib, calc_folder, criteria]
args_up = [ff_lib]
elif calculator == 'MACE':
args = [calculator, criteria]
args = [calculator, steps, criteria]
elif calculator == 'DFTB':
args = [calculator, skf_dir, steps, symmetrize, criteria]
elif calculator == 'VASP':
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/interface/ase_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1, logfile='as
def handler(signum, frame):
raise TimeoutError("Optimization timed out")

step_init = 40
step_init = min([30, int(step/2)])
logger = logging.getLogger()
max_time *= 60
timeout = int(max_time)
Expand Down
49 changes: 38 additions & 11 deletions pyxtal/lego/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def init_atoms(self, atoms, atom_ids=None):
self.natoms = len(atoms)
self.build_neighbor_list(atom_ids)

def compute_p(self, atoms, atom_ids=None):
def compute_p(self, atoms, atom_ids=None, return_CN=False):
"""
Compute the powerspectrum function
Expand All @@ -163,6 +163,8 @@ def compute_p(self, atoms, atom_ids=None):
if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
plist = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
dist = np.zeros((len(atom_ids), 2))

if len(self.neighborlist) > 0:
cs = compute_cs(self.neighborlist, *self.args)
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
Expand All @@ -171,11 +173,18 @@ def compute_p(self, atoms, atom_ids=None):
# Get r_ij and compute C*np.conj(C)
for _i, i in enumerate(atom_ids):
centers = self.neighbor_indices[:, 0] == i
if len(centers) > 0:
CN = len(self.neighborlist[centers])
if CN > 0:
ctot = cs[centers].sum(axis=0)
P = np.einsum('ijk,ljk->ilj', ctot, np.conj(ctot)).real
plist[_i] = P[self.tril_indices].flatten()
return plist
dist[_i, 0] += np.linalg.norm(self.neighborlist[centers], axis=1).max()
dist[_i, 1] += CN
if return_CN:
return plist, dist
else:
return plist


def compute_dpdr(self, atoms, atom_ids=None):
"""
Expand Down Expand Up @@ -335,12 +344,13 @@ def build_neighbor_list(self, atom_ids=None):
neighbor_indices = []
atomic_weights = []

#if True: #atom_ids is None:
if atom_ids is None:
atom_ids = range(len(atoms))

cutoffs = [self.rcut/2]*len(atoms)
nl = NeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0)
nl.update(atoms, atom_ids)
nl.update(atoms)

for i in atom_ids:
# get center atom position vector
Expand All @@ -350,17 +360,22 @@ def build_neighbor_list(self, atom_ids=None):
#print(indices); import sys; sys.exit()

for j, offset in zip(indices, offsets):
(x, y, z) = offset
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1
pos = atoms.positions[j] + offset@cell_matrix - center_atom
# to prevent division by zero
if np.sum(np.abs(pos)) < 1e-3: pos += 0.001
if np.sum(np.abs(pos)) < 1e-3:
# to skip self
if cell_id == 13:
continue
# to prevent division by zero
else:
pos += 1e-3
neighbors.append(pos)
if self.weight_on and atoms[j].number != atoms[i].number:
factor = -1
else:
factor = 1
atomic_weights.append(factor*atoms[j].number)
(x, y, z) = offset
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1
neighbor_indices.append([i, j, cell_id])
else:
# A short cut version if we only compute the neighbors for a few atoms
Expand Down Expand Up @@ -674,11 +689,11 @@ def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):
)

parser.add_option("-l", "--lmax", dest="lmax", default=2, type=int,
help="lmax, default: 1"
help="lmax, default: 2"
)

parser.add_option("-n", "--nmax", dest="nmax", default=1, type=int,
help="nmax, default: 1"
parser.add_option("-n", "--nmax", dest="nmax", default=2, type=int,
help="nmax, default: 2"
)

parser.add_option("-a", "--alpha", dest="alpha", default=2.0, type=float,
Expand Down Expand Up @@ -715,3 +730,15 @@ def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):
print('x', x['x'])
dp, p = f.compute_dpdr(test); print('from dP', p)
dp, p = f.compute_dpdr_5d(test); print('from dP5d', p)

(x, spg, wps) = ([13.45493847, 0.98, 0.04, 0.86, 0.08, 0.94, 0.14, 0.22, 0.66, 0.22, 0.74, 0.28, 0.2, 0.18, 0.68], 226, ['192j', '192j', '192j', '96i', '192j'])

from pyxtal import pyxtal
xtal = pyxtal()
xtal.from_spg_wps_rep(spg, wps, x, ['C']*len(wps))
atoms = xtal.to_ase()
p1 = f.compute_p(atoms)[0]; print("from ase", p1)
p2 = f.compute_p(atoms, atom_ids=[0])[0]; print("from self", p2)
print(np.sum((p1-p2)**2))
assert(np.sum((p1-p2)**2)<1e-4)

Loading

0 comments on commit 6b25f1f

Please sign in to comment.