Skip to content

Commit

Permalink
Merge pull request #280 from MaterSim/network
Browse files Browse the repository at this point in the history
Network
  • Loading branch information
qzhu2017 authored Oct 8, 2024
2 parents 2e5941a + 2b9dda2 commit 97c075d
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 156 deletions.
15 changes: 11 additions & 4 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,7 +1638,6 @@ def build(self, group, species, numIons, lattice, sites, tol=1e-2, dim=3, use_ha
# msg = 'The input lattice needs to be a pyxtal.lattice.Lattice class'
# raise ValueError(msg, lattice)
self.lattice = lattice

self.dim = dim
self.factor = 1.0
self.PBC = self.group.PBC
Expand Down Expand Up @@ -1667,12 +1666,15 @@ def build(self, group, species, numIons, lattice, sites, tol=1e-2, dim=3, use_ha
elif len(wp) == 4: # tuple:
(key, x, y, z) = wp
_wp = choose_wyckoff(self.group, site=key, dim=dim)
#print('debug build', key, x, y, z, _wp.get_label())
if _wp is not False:
if _wp.get_dof() == 0: # fixed pos
pt = [0.0, 0.0, 0.0]
else:
ans = _wp.get_all_positions([x, y, z])
pt = ans[0] if ans is not None else None
#print('debug build', ans, x, y, z)
# print('debug', ans)
if pt is not None:
_sites.append(atom_site(_wp, pt, sp))
else:
Expand Down Expand Up @@ -3843,22 +3845,27 @@ def from_tabular_representation(

# Conversion from discrete to continuous
if discrete:
#print('discrete', x, y, z)
[x, y, z] = wp.from_discrete_grid([x, y, z], N_grids)
#print('conversion', x, y, z)

# ; print(wp.get_label(), xyz)
xyz = wp.search_generator([x, y, z], tol=tol)
xyz = wp.search_generator([x, y, z], tol=tol, symmetrize=True)
if xyz is not None:
xyz, wp, _ = wp.merge(xyz, np.eye(3), 1e-3)
label = wp.get_label()
# ; print(x, y, z, label, xyz[0], xyz[1], xyz[2])
sites.append((label, xyz[0], xyz[1], xyz[2]))
numIons += wp.multiplicity
if verbose:
print('add wp', label, xyz)
else:
if verbose:
print("Cannot find generator from", x, y, z)
print(wp)
print("Cannot get wp in", x, y, z, wp.get_label())

if len(sites) > 0:
try:
# print(sites)
self.build(group, ["C"], [numIons], lattice, [sites])
except:
print("Invalid Build", number, lattice, numIons, sites)
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/database/HM_Full.csv
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Hall,Spg_num,Spg_full,Symbol,P,P^-1,Permutation
49,9,9:-c2,A 1 1 n,"-a-c,a,-b","b,-c,-a-b",1
50,9,9:-c3,I 1 1 a,"c,-a-c,-b","-a-b,-c,a",1
51,9,9:a1,B b 1 1,"b,c,a","c,a,b",1
52,9,9:a2,C n 1 1,"-b,a,c","b,-a,c",1
52,9,9:a2,C n 1 1,"b,a,-a-c","b,a,-b-c",1
53,9,9:a3,I c 1 1,"b,-a-c,c","-b-c,a,c",1
54,9,9:-a1,C c 1 1,"-b,a,c","b,-a,c",1
55,9,9:-a2,B n 1 1,"-b,-a-c,a","c,-a,-b-c",1
Expand Down
122 changes: 60 additions & 62 deletions pyxtal/lego/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0,
self.norm = np.sqrt(2*np.sqrt(2)*np.pi/np.sqrt(2*self.ls+1))
self.keys = ['keys', '_nmax', '_lmax', '_rcut', '_alpha',
'_cutoff_function', 'weight_on', 'neighborcalc',
'ncoefs', 'ls', 'norm', 'tril_indices']
'ncoefs', 'ls', 'norm', 'tril_indices', 'args']
self.args = (self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)

def __str__(self):
s = "SO3 descriptor with Cutoff: {:6.3f}".format(self.rcut)
Expand Down Expand Up @@ -160,20 +161,21 @@ def compute_p(self, atoms, atom_ids=None):
p array (N, M)
"""

if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
plist = np.zeros((len(atoms), self.ncoefs), dtype=np.float64)
plist = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
if len(self.neighborlist) > 0:
cs = compute_cs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
cs = compute_cs(self.neighborlist, *self.args)
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)

# Get r_ij and compute C*np.conj(C)
for i in range(len(atoms)):
centers = self.neighbor_indices[:,0] == i
for _i, i in enumerate(atom_ids):
centers = self.neighbor_indices[:, 0] == i
if len(centers) > 0:
ctot = cs[centers].sum(axis=0)
P = np.einsum('ijk,ljk->ilj', ctot, np.conj(ctot)).real
plist[i] = P[self.tril_indices].flatten()
plist[_i] = P[self.tril_indices].flatten()
return plist

def compute_dpdr(self, atoms, atom_ids=None):
Expand All @@ -188,13 +190,14 @@ def compute_dpdr(self, atoms, atom_ids=None):
dpdr array (N, N, M, 3) and p array (N, M)
"""

if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
p_list = np.zeros((self.natoms, self.ncoefs), dtype=np.float64)
dp_list = np.zeros((self.natoms, self.natoms, self.ncoefs, 3), dtype=np.float64)
p_list = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
dp_list = np.zeros((len(atom_ids), self.natoms, self.ncoefs, 3), dtype=np.float64)

if len(self.neighborlist) > 0:
# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
cs, dcs = compute_dcs(self.neighborlist, *self.args)

# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
Expand All @@ -203,12 +206,11 @@ def compute_dpdr(self, atoms, atom_ids=None):
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# Assign cs and dcs to P and dP
# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension

for i in range(len(atoms)):
for _i, i in enumerate(atom_ids):
# find atoms for which i is the center
centers = self.neighbor_indices[:, 0] == i

Expand All @@ -219,7 +221,7 @@ def compute_dpdr(self, atoms, atom_ids=None):
# power spectrum P = c*c_conj
# eq_3 (n, n', l) eliminate m
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[i] += P[self.tril_indices].flatten()
p_list[_i] += P[self.tril_indices].flatten()

# gradient of P for each neighbor, eq_26
# (N_ijs, n, n', l, 3)
Expand All @@ -233,12 +235,13 @@ def compute_dpdr(self, atoms, atom_ids=None):
# QZ: to check
ijs = self.neighbor_indices[centers]
for _id, j in enumerate(ijs[:, 1]):
dp_list[i, j, :, :] += dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
dp_list[i, i, :, :] -= dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
tmp = dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
dp_list[_i, j, :, :] += tmp
dp_list[_i, i, :, :] -= tmp

return dp_list, p_list

def compute_dpdr_5d(self, atoms):
def compute_dpdr_5d(self, atoms, atom_ids=None):
"""
Compute the powerspectrum function with respect to supercell
Expand All @@ -248,57 +251,52 @@ def compute_dpdr_5d(self, atoms):
Returns:
dpdr array (N, N, M, 3, 27) and p array (N, M)
"""
if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
p_list = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
dp_list = np.zeros((len(atom_ids), self.natoms, self.ncoefs, 3, 27), dtype=np.float64)

self.init_atoms(atoms)
p_list = np.zeros((self.natoms, self.ncoefs), dtype=np.float64)
dp_list = np.zeros((self.natoms, self.natoms, self.ncoefs, 3, 27), dtype=np.float64)

# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)

# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# Assign cs and dcs to P and dP
# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension
neigh_ids = np.arange(len(self.neighbor_indices))
for i in range(len(atoms)):
# find atoms for which i is the center
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
if len(pair_ids) > 0:
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
dctot = dcs[pair_ids].sum(axis=0)
# power spectrum P = c*c_conj
# eq_3 (n, n', l) eliminate m
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[i] += P[self.tril_indices].flatten()

# loop over each pair
for pair_id in pair_ids:
(_, j, x, y, z) = self.neighbor_indices[pair_id]
# map from (x, y, z) to (0, 27)
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1
if len(self.neighborlist) > 0:
# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, *self.args)

# gradient of P for each neighbor, eq_26
# (N_ijs, n, n', l, 3)
# dc * c_conj + c * dc_conj
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)
#dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
#dP += np.einsum('ijkn, ljk->iljn', np.conj(dctot), cs[pair_id])
#dP += np.einsum('ijkn, ljk->iljn', dctot, np.conj(cs[pair_id]))
# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension
neigh_ids = np.arange(len(self.neighbor_indices))
for _i, i in enumerate(atom_ids):
# find atoms for which i is the center
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
if len(pair_ids) > 0:
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
dctot = dcs[pair_ids].sum(axis=0)
# power spectrum P = c * c_conj
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[_i] += P[self.tril_indices].flatten()

# loop over each pair
for pair_id in pair_ids:
(_, j, x, y, z) = self.neighbor_indices[pair_id]
# map from (x, y, z) to (0, 27)
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1

# (N_ijs, n, n', l, 3)
# dp = dc * c_conj + c * dc_conj
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)

dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)
dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)

dp_list[i, j, :, :, cell_id] += dP
dp_list[i, i, :, :, 13] -= dP
dp_list[_i, j, :, :, cell_id] += dP
dp_list[_i, i, :, :, 13] -= dP

return dp_list, p_list

Expand Down
Loading

0 comments on commit 97c075d

Please sign in to comment.