Skip to content

Commit

Permalink
get rid of the dimensionality check
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Oct 8, 2024
1 parent 0df8653 commit 2b9dda2
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 107 deletions.
122 changes: 60 additions & 62 deletions pyxtal/lego/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0,
self.norm = np.sqrt(2*np.sqrt(2)*np.pi/np.sqrt(2*self.ls+1))
self.keys = ['keys', '_nmax', '_lmax', '_rcut', '_alpha',
'_cutoff_function', 'weight_on', 'neighborcalc',
'ncoefs', 'ls', 'norm', 'tril_indices']
'ncoefs', 'ls', 'norm', 'tril_indices', 'args']
self.args = (self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)

def __str__(self):
s = "SO3 descriptor with Cutoff: {:6.3f}".format(self.rcut)
Expand Down Expand Up @@ -160,20 +161,21 @@ def compute_p(self, atoms, atom_ids=None):
p array (N, M)
"""

if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
plist = np.zeros((len(atoms), self.ncoefs), dtype=np.float64)
plist = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
if len(self.neighborlist) > 0:
cs = compute_cs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
cs = compute_cs(self.neighborlist, *self.args)
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)

# Get r_ij and compute C*np.conj(C)
for i in range(len(atoms)):
centers = self.neighbor_indices[:,0] == i
for _i, i in enumerate(atom_ids):
centers = self.neighbor_indices[:, 0] == i
if len(centers) > 0:
ctot = cs[centers].sum(axis=0)
P = np.einsum('ijk,ljk->ilj', ctot, np.conj(ctot)).real
plist[i] = P[self.tril_indices].flatten()
plist[_i] = P[self.tril_indices].flatten()
return plist

def compute_dpdr(self, atoms, atom_ids=None):
Expand All @@ -188,13 +190,14 @@ def compute_dpdr(self, atoms, atom_ids=None):
dpdr array (N, N, M, 3) and p array (N, M)
"""

if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
p_list = np.zeros((self.natoms, self.ncoefs), dtype=np.float64)
dp_list = np.zeros((self.natoms, self.natoms, self.ncoefs, 3), dtype=np.float64)
p_list = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
dp_list = np.zeros((len(atom_ids), self.natoms, self.ncoefs, 3), dtype=np.float64)

if len(self.neighborlist) > 0:
# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
cs, dcs = compute_dcs(self.neighborlist, *self.args)

# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
Expand All @@ -203,12 +206,11 @@ def compute_dpdr(self, atoms, atom_ids=None):
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# Assign cs and dcs to P and dP
# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension

for i in range(len(atoms)):
for _i, i in enumerate(atom_ids):
# find atoms for which i is the center
centers = self.neighbor_indices[:, 0] == i

Expand All @@ -219,7 +221,7 @@ def compute_dpdr(self, atoms, atom_ids=None):
# power spectrum P = c*c_conj
# eq_3 (n, n', l) eliminate m
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[i] += P[self.tril_indices].flatten()
p_list[_i] += P[self.tril_indices].flatten()

# gradient of P for each neighbor, eq_26
# (N_ijs, n, n', l, 3)
Expand All @@ -233,12 +235,13 @@ def compute_dpdr(self, atoms, atom_ids=None):
# QZ: to check
ijs = self.neighbor_indices[centers]
for _id, j in enumerate(ijs[:, 1]):
dp_list[i, j, :, :] += dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
dp_list[i, i, :, :] -= dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
tmp = dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
dp_list[_i, j, :, :] += tmp
dp_list[_i, i, :, :] -= tmp

return dp_list, p_list

def compute_dpdr_5d(self, atoms):
def compute_dpdr_5d(self, atoms, atom_ids=None):
"""
Compute the powerspectrum function with respect to supercell
Expand All @@ -248,57 +251,52 @@ def compute_dpdr_5d(self, atoms):
Returns:
dpdr array (N, N, M, 3, 27) and p array (N, M)
"""
if atom_ids is None: atom_ids = range(len(atoms))
self.init_atoms(atoms, atom_ids)
p_list = np.zeros((len(atom_ids), self.ncoefs), dtype=np.float64)
dp_list = np.zeros((len(atom_ids), self.natoms, self.ncoefs, 3, 27), dtype=np.float64)

self.init_atoms(atoms)
p_list = np.zeros((self.natoms, self.ncoefs), dtype=np.float64)
dp_list = np.zeros((self.natoms, self.natoms, self.ncoefs, 3, 27), dtype=np.float64)

# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)

# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# Assign cs and dcs to P and dP
# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension
neigh_ids = np.arange(len(self.neighbor_indices))
for i in range(len(atoms)):
# find atoms for which i is the center
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
if len(pair_ids) > 0:
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
dctot = dcs[pair_ids].sum(axis=0)
# power spectrum P = c*c_conj
# eq_3 (n, n', l) eliminate m
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[i] += P[self.tril_indices].flatten()

# loop over each pair
for pair_id in pair_ids:
(_, j, x, y, z) = self.neighbor_indices[pair_id]
# map from (x, y, z) to (0, 27)
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1
if len(self.neighborlist) > 0:
# get expansion coefficients and derivatives
cs, dcs = compute_dcs(self.neighborlist, *self.args)

# gradient of P for each neighbor, eq_26
# (N_ijs, n, n', l, 3)
# dc * c_conj + c * dc_conj
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)
#dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
#dP += np.einsum('ijkn, ljk->iljn', np.conj(dctot), cs[pair_id])
#dP += np.einsum('ijkn, ljk->iljn', dctot, np.conj(cs[pair_id]))
# weight cs and dcs
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
cs = np.einsum('inlm,l->inlm', cs, self.norm)
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)

# cs: (N_ij, n, l, m) => P (N_i, N_des)
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
# (n, l, m) needs to be merged to 1 dimension
neigh_ids = np.arange(len(self.neighbor_indices))
for _i, i in enumerate(atom_ids):
# find atoms for which i is the center
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
if len(pair_ids) > 0:
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
dctot = dcs[pair_ids].sum(axis=0)
# power spectrum P = c * c_conj
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
p_list[_i] += P[self.tril_indices].flatten()

# loop over each pair
for pair_id in pair_ids:
(_, j, x, y, z) = self.neighbor_indices[pair_id]
# map from (x, y, z) to (0, 27)
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1

# (N_ijs, n, n', l, 3)
# dp = dc * c_conj + c * dc_conj
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)

dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)
dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)

dp_list[i, j, :, :, cell_id] += dP
dp_list[i, i, :, :, 13] -= dP
dp_list[_i, j, :, :, cell_id] += dP
dp_list[_i, i, :, :, 13] -= dP

return dp_list, p_list

Expand Down
Loading

0 comments on commit 2b9dda2

Please sign in to comment.