Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed optimize in lego module #282

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3427,8 +3427,9 @@ def update_from_1d_rep(self, x):
cell, pos = x[:N], x[N:]

# update cell
l_type = self.lattice.ltype
self.lattice = Lattice.from_1d_representation(cell, l_type)
self.lattice.update_from_1d_representation(cell)
#l_type = self.lattice.ltype
#self.lattice = Lattice.from_1d_representation(cell, l_type)
# print("lattice dof", N, cell, l_type, self.lattice)

# update position
Expand Down
19 changes: 19 additions & 0 deletions pyxtal/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,25 @@ def from_1d_representation(self, v, ltype):
except:
print(a, b, c, alpha, beta, gamma, ltype)

def update_from_1d_representation(self, v):
"""
Update the cell para and matrix from the 1d rep
"""
if self.ltype == "triclinic":
self.a, self.b, self.c, self.alpha, self.beta, self.gamma = v[:6]
elif self.ltype == "monoclinic":
self.a, self.b, self.c, self.beta = v[:4]
elif self.ltype == "orthorhombic":
self.a, self.b, self.c = v[:3]
elif self.ltype in ["tetragonal", "trigonal", "hexagonal"]:
self.a, self.b, self.c = v[0], v[0], v[1]
else:
self.a, self.b, self.c = v[0], v[0], v[0]

para = (self.a, self.b, self.c, self.alpha, self.beta, self.gamma)
self.set_matrix(para2matrix(para))


def mutate(self, degree=0.10, frozen=False):
"""
Mutate the lattice object
Expand Down
239 changes: 128 additions & 111 deletions pyxtal/lego/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0,
self.ls = np.arange(self.lmax+1)
self.norm = np.sqrt(2*np.sqrt(2)*np.pi/np.sqrt(2*self.ls+1))
self.keys = ['keys', '_nmax', '_lmax', '_rcut', '_alpha',
'_cutoff_function', 'weight_on', 'neighborcalc',
'_cutoff_function', 'weight_on',
'ncoefs', 'ls', 'norm', 'tril_indices', 'args']
self.args = (self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)

def __str__(self):
s = "SO3 descriptor with Cutoff: {:6.3f}".format(self.rcut)
s += " lmax: {:d}, nmax: {:d}, alpha: {:.3f}\n".format(self.lmax, self.nmax, self.alpha)
s += "neighborlist: {:s}\n".format(self.neighborcalc)
return s

def __repr__(self):
Expand Down Expand Up @@ -283,9 +282,8 @@ def compute_dpdr_5d(self, atoms, atom_ids=None):

# loop over each pair
for pair_id in pair_ids:
(_, j, x, y, z) = self.neighbor_indices[pair_id]
(_, j, cell_id) = self.neighbor_indices[pair_id]
# map from (x, y, z) to (0, 27)
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1

# (N_ijs, n, n', l, 3)
# dp = dc * c_conj + c * dc_conj
Expand Down Expand Up @@ -331,48 +329,74 @@ def build_neighbor_list(self, atom_ids=None):
a given ASE atoms object given in the calculate method.
'''
atoms = self._atoms
cell_matrix = atoms.get_cell()
VECTORS = np.array([[x1, y1, z1] for x1 in range(-1, 2) for y1 in range(-1, 2) for z1 in range(-1, 2)])
neighbors = []
neighbor_indices = []
atomic_weights = []

if atom_ids is None:
atom_ids = range(len(atoms))

cutoffs = [self.rcut/2]*len(atoms)
nl = NeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0)
nl.update(atoms)

#print(atoms, atom_ids)
#print(atoms.get_scaled_positions())
cutoffs = [self.rcut/2]*len(atoms)
nl = NeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0)
nl.update(atoms, atom_ids)

for i in atom_ids:
# get center atom position vector
center_atom = atoms.positions[i]
# get indices and cell offsets for each neighbor
indices, offsets = nl.get_neighbors(i)
#print(indices); import sys; sys.exit()

for j, offset in zip(indices, offsets):
pos = atoms.positions[j] + offset@cell_matrix - center_atom
# to prevent division by zero
if np.sum(np.abs(pos)) < 1e-3: pos += 0.001
neighbors.append(pos)
if self.weight_on and atoms[j].number != atoms[i].number:
factor = -1
else:
factor = 1
atomic_weights.append(factor*atoms[j].number)
(x, y, z) = offset
cell_id = (x+1) * 9 + (y+1) * 3 + z + 1
neighbor_indices.append([i, j, cell_id])
else:
# A short cut version if we only compute the neighbors for a few atoms
ref_pos = np.repeat(atoms.positions[:, :, np.newaxis], 27, axis=2)
cell_shifts = np.dot(VECTORS, cell_matrix)
ref_pos += cell_shifts.T[np.newaxis, :, :]
ref_ids = np.arange(len(atoms))
for i in atom_ids:
center_atom = atoms.positions[i]

# Compute distances relative to the center_atom across all cells
dists = np.linalg.norm(ref_pos - center_atom[:, np.newaxis], axis=1)

# Create mask for distances that are within the cutoff and greater than 1e-2
mask = (dists > 1e-2) & (dists < self.rcut)

# Use the mask to find neighbors in all cells at once
valid_atoms, valid_cells = np.where(mask)
# Append the neighbors and weights
for j, cell_id in zip(valid_atoms, valid_cells):
neighbors.append(ref_pos[j, :, cell_id] - center_atom)
factor = -1 if self.weight_on and atoms[j].number != atoms[i].number else 1
atomic_weights.append(factor * atoms[j].number)
neighbor_indices.append([i, j, cell_id])
#for cell_id in range(27):
# dists = np.linalg.norm(ref_pos[:, :, cell_id] - center_atom, axis=1)
# mask = (dists > 1e-2) & (dists < self.rcut)
# for j in ref_ids[mask]:
# neighbors.append(ref_pos[j, :, cell_id] - center_atom)
# factor = -1 if self.weight_on and atoms[j].number != atoms[i].number else 1
# atomic_weights.append(factor*atoms[j].number)
# neighbor_indices.append([i, j, cell_id])

center_atoms = []
neighbors = []
neighbor_indices = []
atomic_weights = []
temp_indices = []

for i in atom_ids:
# get center atom position vector
center_atom = atoms.positions[i]
# get indices and cell offsets for each neighbor
indices, offsets = nl.get_neighbors(i)
#print(indices); import sys; sys.exit()
temp_indices.append(indices)
for j, offset in zip(indices, offsets):
pos = atoms.positions[j] + np.dot(offset, atoms.get_cell()) - center_atom
# to prevent division by zero
if np.sum(np.abs(pos)) < 1e-3: pos += 0.001
center_atoms.append(center_atom)
neighbors.append(pos)
if self.weight_on and atoms[j].number != atoms[i].number:
factor = -1
else:
factor = 1
atomic_weights.append(factor*atoms[j].number)
neighbor_indices.append([i, j, *offset])

neighbor_indices = np.array(neighbor_indices, dtype=np.int64)

self.center_atoms = np.array(center_atoms, dtype=np.float64)
self.neighborlist = np.array(neighbors, dtype=np.float64)
self.atomic_weights = np.array(atomic_weights, dtype=np.int64)
self.neighbor_indices = neighbor_indices
self.neighbor_indices = np.array(neighbor_indices, dtype=np.int64)

def Cosine(Rij, Rc, derivative=False):
# Rij is the norm
Expand Down Expand Up @@ -422,83 +446,77 @@ def GaussChebyshevQuadrature(nmax, lmax):

def compute_cs(pos, nmax, lmax, rcut, alpha, cutoff):
"""
Compute exapnsion coefficients
Compute expansion coefficients for a system based on the input positions.

This function calculates the expansion coefficients for a set of atomic positions using
Gauss-Chebyshev quadrature, spherical Bessel functions, and spherical harmonics. It is
typically used in models that require high-dimensional projections of atomic environments.

Args:
pos:
nmax (int):
lmax (int):
rcut (float):
alpha (float):
cutoff (callable):
pos (numpy.ndarray): An array of atomic positions (N x 3) where N is the number of atoms.
nmax (int): Maximum radial quantum number used in the expansion.
lmax (int): Maximum angular momentum quantum number used in the expansion.
rcut (float): Cutoff radius for interactions and the radial expansion.
alpha (float): Gaussian decay factor applied to the radial functions.
cutoff (callable): A function to compute cutoff values for the radial distances.

Returns:
C(N_ij, nmax, lmax+1, 2lmax+1)
numpy.ndarray: A 4D array of expansion coefficients with shape (N_neighbors, nmax, lmax+1, 2*lmax+1),
where `N_neighbors` is the number of neighbor atoms, `nmax` is the number of radial terms, and
`lmax` and `m` correspond to angular momentum quantum numbers.
"""
# compute the overlap matrix
w = W(nmax)

# get the norm of the position vectors
Ris = np.linalg.norm(pos, axis=1) # (Nneighbors)

# initialize Gauss Chebyshev Quadrature
GCQuadrature, weight = GaussChebyshevQuadrature(nmax, lmax) #(Nquad)
weight *= rcut/2
# transform the quadrature from (-1,1) to (0, rcut)
Quadrature = rcut/2*(GCQuadrature+1)

# compute the arguments for the bessel functions
BesselArgs = 2*alpha*np.outer(Ris,Quadrature)#(Nneighbors x Nquad)

# initalize the arrays for the bessel function values
# and the G function values
Bessels = np.zeros((len(Ris), len(Quadrature), lmax+1), dtype=np.float64) #(Nneighbors x Nquad x lmax+1)
Gs = np.zeros((nmax, len(Quadrature)), dtype=np.float64) # (nmax, nquad)

# compute the g values
for n in range(1,nmax+1,1):
Gs[n-1,:] = g(Quadrature, n, nmax, rcut, w)

# compute the bessel values
for l in range(lmax+1):
Bessels[:,:,l] = spherical_in(l, BesselArgs)

# mutliply the terms in the integral separate from the Bessels
Quad_Squared = Quadrature**2
Gs *= Quad_Squared * np.exp(-alpha*Quad_Squared) * np.sqrt(1-GCQuadrature**2) * weight

# perform the integration with the Bessels
integral_array = np.einsum('ij,kjl->kil', Gs, Bessels) # (Nneighbors x nmax x lmax+1)

# compute the gaussian for each atom and multiply with 4*pi
# to minimize floating point operations
# weight can also go here since the Chebyshev gauss quadrature weights are uniform
exparray = 4*np.pi*np.exp(-alpha*Ris**2) # (Nneighbors)

# 1. Init Overlap matrix, distances and quadrature points
w = W(nmax)
Ris = np.linalg.norm(pos, axis=1) # (N_neighbors)
GCQuadrature, weight = GaussChebyshevQuadrature(nmax, lmax) # (Nquad)
weight *= rcut / 2
Quadrature = rcut / 2 * (GCQuadrature + 1)

# 2. Rdial G functions
Gs = np.zeros((nmax, len(Quadrature)), dtype=np.float64) # (nmax, Nquad)
for n in range(1, nmax + 1):
Gs[n - 1, :] = g(Quadrature, n, nmax, rcut, w)

Quad_Squared = Quadrature ** 2
Gs *= Quad_Squared * np.exp(-alpha * Quad_Squared) * np.sqrt(1 - GCQuadrature**2) * weight

# 3. Bessel functions
BesselArgs = 2 * alpha * np.outer(Ris, Quadrature) # (N_neighbors x Nquad)
Bessels = np.zeros((len(Ris), len(Quadrature), lmax + 1), dtype=np.float64) # (N_neighbors x Nquad x lmax+1)
for l in range(lmax + 1):
Bessels[:, :, l] = spherical_in(l, BesselArgs)

# 4. Integration over radial coordinates using Einstein summation
integral_array = np.einsum('ij,kjl->kil', Gs, Bessels) # (N_neighbors x nmax x lmax+1)

# 5. Gaussian factors for each atom and apply the cutoff function
exparray = 4 * np.pi * np.exp(-alpha * Ris**2) # (N_neighbors)
cutoff_array = cutoff(Ris, rcut)

exparray *= cutoff_array

# get the spherical coordinates of each atom
thetas = np.arccos(pos[:,2]/Ris[:])
phis = np.arctan2(pos[:,1], pos[:,0])

# determine the size of the m axis
msize = 2*lmax+1
# initialize an array for the spherical harmonics
ylms = np.zeros((len(Ris), lmax+1, msize), dtype=np.complex128)

# compute the spherical harmonics
for l in range(lmax+1):
for m in range(-l,l+1,1):
midx = msize//2 + m
ylms[:,l,midx] = sph_harm(m, l, phis, thetas)

# multiply the spherical harmonics and the radial inner product
np.multiply(exparray, cutoff_array, out=exparray)

# 6. Compute the spherical harmonics for each atom
# From Cartesian coordinates to spherical coordinates
thetas = pos[:, 2] / Ris[:]
thetas = np.arccos(thetas)
phis = np.arctan2(pos[:, 1], pos[:, 0])
msize = 2 * lmax + 1
ylms = np.zeros((len(Ris), lmax + 1, msize), dtype=np.complex128)

#for l in range(lmax + 1):
# for m in range(-l, l + 1):
# midx = msize // 2 + m
# ylms[:, l, midx] = sph_harm(m, l, phis, thetas)
l_vals = np.repeat(np.arange(lmax + 1), [2 * l + 1 for l in range(lmax + 1)])
m_vals = np.concatenate([np.arange(-l, l + 1) for l in range(lmax + 1)])
midx_vals = msize // 2 + m_vals
ylms[:, l_vals, midx_vals] = sph_harm(m_vals, l_vals, phis[:, np.newaxis], thetas[:, np.newaxis])


# 7. Multiply Ylm and Gaussian factors
Y_mul_innerprod = np.einsum('ijk,ilj->iljk', ylms, integral_array)

# multiply the gaussians into the expression
C = np.einsum('i,ijkl->ijkl', exparray, Y_mul_innerprod)

return C

def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):
Expand Down Expand Up @@ -690,11 +708,10 @@ def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):

start1 = time.time()
f = SO3(nmax=nmax, lmax=lmax, rcut=rcut, alpha=alpha)
x = f.calculate(test, derivative=der)
start2 = time.time()
print(f)
p = f.compute_p(test, atom_ids=[0]); print('from P', p)
x = f.calculate(test, derivative=der)
print('x', x['x'])
#print('dxdr', x['dxdr'])
p = f.compute_p(test); print('from P', p)
dp, p = f.compute_dpdr(test); print('from dP', p)
dp, p = f.compute_dpdr_5d(test); print('from dP5d', p)
2 changes: 1 addition & 1 deletion pyxtal/wyckoff_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def optimize_orientation_by_dist(self, ori_attempts):
else:
ang_lo, fun_lo = ang, fun

print("optimize_orientation_by_dist", _it, ang, fun)
#print("optimize_orientation_by_dist", _it, ang, fun)

return None

Expand Down