Skip to content

Commit

Permalink
fix #251 regarding distance check
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Apr 28, 2024
1 parent 344bb64 commit 4c79be5
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 63 deletions.
2 changes: 1 addition & 1 deletion pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,7 @@ def get_init_translations(self, ref_struc, tol=0.75):
match = False
for trans_ref in good_translations:
diff = trans - trans_ref
diff -= np.round(diff)
diff -= np.rint(diff)
diff = np.dot(diff, self.lattice.matrix)
if np.linalg.norm(diff) < tol:
match = True
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def resort(self, molecules):
p0 = positions[id0]
p1, wp, _ = self.wyc.merge(p0, new_lat, 0.1)
diff = p1 - p0
diff -= np.round(diff)
diff -= np.rint(diff)
if np.abs(diff).sum() < 1e-2: #sort position by mapping
wps.append(wp)
ids.append(id0) #find the right ids
Expand Down
9 changes: 5 additions & 4 deletions pyxtal/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def filtered_coords_euclidean(coords, PBC=[1, 1, 1]):
def filter_vector_euclidean(vector):
for i, a in enumerate(PBC):
if a:
# QZ: check if this is equivalent to -= np.rint()
vector[i] -= np.floor(vector[i])
if vector[i] > 0.5:
vector[i] = 1 - vector[i]
Expand Down Expand Up @@ -606,7 +607,7 @@ def are_equal(op1, op2, PBC=[1, 1, 1], rtol=1e-3, atol=1e-3):

for i, a in enumerate(PBC):
if a:
difference[i] -= np.floor(difference[i])
difference[i] -= np.rint(difference[i])

d = np.linalg.norm(difference)

Expand Down Expand Up @@ -643,7 +644,7 @@ def get_order(angle, rotoinversion=False, tol=1e-2):
found = False
for n in range(1, 61):
x = (n * angle) / (2.0 * np.pi)
y = x - np.round(x)
y = x - np.rint(x)
if abs(y) <= tol:
found = True
break
Expand Down Expand Up @@ -832,7 +833,7 @@ def find_ids(coords, ref, tol=1e-3):
#print('ref', ref)
for coord in coords:
diffs = ref - coord
diffs -= np.round(diffs)
diffs -= np.rint(diffs)
norms = np.linalg.norm(diffs, axis=1)
#print(norms, diffs)
for i, norm in enumerate(norms):
Expand All @@ -856,7 +857,7 @@ def get_best_match(positions, ref, cell):
id: matched id
"""
diffs = positions - ref
diffs -= np.round(diffs)
diffs -= np.rint(diffs)
diffs = np.dot(diffs, cell)
dists = np.linalg.norm(diffs, axis=1)
id = np.argmin(dists)
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def group_slabs(self, slabs, tol):
new = False
else:
dist = center - group[0]
shift = np.round(dist)
shift = np.rint(dist)
if abs(dist-shift) < tol:
new = False
lower -= shift
Expand Down
6 changes: 3 additions & 3 deletions pyxtal/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,16 @@ def get_dist(self, rep):
# symmmetry variation
xyzs = wp.apply_ops(tmp2[:3])
diff_xyzs = xyzs - tmp1[:3]
diff_xyzs -= np.round(diff_xyzs)
diff_xyzs -= np.rint(diff_xyzs)
id = np.argmin(np.linalg.norm(diff_xyzs, axis=1))
diff_xyz = diff_xyzs[id]
diff_ori = tmp2[3:6] - tmp1[3:6]
diff_ori /= [360.0, 180.0, 360.0]
diff_ori -= np.round(diff_ori)
diff_ori -= np.rint(diff_ori)
diff_ori *= [360.0, 180.0, 360.0]
diff_tor = tmp2[6:] - tmp1[6:]
diff_tor /= 360.0
diff_tor -= np.round(diff_tor)
diff_tor -= np.rint(diff_tor)
diff_tor *= 360.0
diffs.extend(diff_xyz)
diffs.extend(diff_ori)
Expand Down
22 changes: 11 additions & 11 deletions pyxtal/supergroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def search_G1(G, rot, tran, pos, wp1, op):
res = np.dot(rot, pos + shift) + tran.T
tmp = sym.search_cloest_wp(G, wp1, op, res)
diff = res - tmp
diff -= np.round(diff)
diff -= np.rint(diff)
dist = np.linalg.norm(diff)
diffs.append(dist)
coords.append(tmp)
Expand All @@ -197,7 +197,7 @@ def search_G1(G, rot, tran, pos, wp1, op):
diffs = np.array(diffs)
minID = np.argmin(diffs)
tmp = coords[minID]
tmp -= np.round(tmp)
tmp -= np.rint(tmp)
return tmp, np.min(diffs)


Expand All @@ -217,14 +217,14 @@ def search_G2(rot, tran, pos1, pos2, cell=None):
dist: relative distance
"""

pos1 -= np.round(pos1)
pos1 -= np.rint(pos1)
shifts = ALL_SHIFTS

dists = []
for shift in shifts:
res = np.dot(rot, pos1 + shift + tran.T)
diff = res - pos2
diff -= np.round(diff)
diff -= np.rint(diff)
dist = np.linalg.norm(diff)
dists.append(dist)
if dist < 1e-1:
Expand All @@ -236,7 +236,7 @@ def search_G2(rot, tran, pos1, pos2, cell=None):


diff = pos - pos2
diff -= np.round(diff)
diff -= np.rint(diff)

if cell is not None:
diff = np.dot(diff, cell)
Expand Down Expand Up @@ -661,7 +661,7 @@ def symmetrize_site_single(self, splitter, id, base, translation, run_type=1):
if translation is None:
coord_G2, dist1 = search_G2(inv_rot, -tran, tmp, coord_H, None)
diff = coord_G2 - coord_H
diff -= np.round(diff)
diff -= np.rint(diff)
translation = diff.copy()
for m in range(3):
if abs(diff[m])<1e-4:
Expand Down Expand Up @@ -712,7 +712,7 @@ def symmetrize_site_double_k(self, splitter, id, coord_H, translation, run_type=
if np.sum(diff**2) < 1e-3:
trans = op_G22.translation_vector - op_G21.translation_vector
break
trans -= np.round(trans)
trans -= np.rint(trans)
coords11 = apply_ops(coord1_G2, ops_H1)
coords11 += trans
tmp, dist = get_best_match(coords11, coord2_G2, self.cell)
Expand All @@ -722,7 +722,7 @@ def symmetrize_site_double_k(self, splitter, id, coord_H, translation, run_type=
return dist/2 #np.linalg.norm(np.dot(d/2, self.cell))
else:
d = coord2_G2 - tmp
d -= np.round(d)
d -= np.rint(d)
op_G11 = splitter.G1_orbits[id][0][0]
coord2_G2 -= d/2
coord1_G2 += d/2
Expand Down Expand Up @@ -769,7 +769,7 @@ def symmetrize_site_double_t(self, splitter, id, coord_H, translation, run_type=
else:
# G1->G2->H
d = coord2_G1 - tmp
d -= np.round(d)
d -= np.rint(d)
coord2_G1 -= d/2

coords22 = apply_ops(coord2_G1, ops_G1)
Expand Down Expand Up @@ -886,7 +886,7 @@ def print_detail(self, solution, coords_H, coords_G, elements):
x, y, ele = coords_H[count], coords_G[count], elements[count]
label = wp.get_label() + '->' + wp1.get_label()
dis = y - x - translation
dis -= np.round(dis)
dis -= np.rint(dis)
dis_abs = np.linalg.norm(dis.dot(self.cell))
output = "{:2s}[{:8s}] {:8.4f}{:8.4f}{:8.4f}".format(ele, label, *x)
output += " -> {:8.4f}{:8.4f}{:8.4f}".format(*y)
Expand Down Expand Up @@ -927,7 +927,7 @@ def make_pyxtals_in_subgroup(self, solution, N_images=5):
for wp in wp2:
x, y, ele = coords_H1[count], coords_G2[count], elements[count]
disp = y - x - translation
disp -= np.round(disp)
disp -= np.rint(disp)
disps.append(disp)
count += 1

Expand Down
20 changes: 10 additions & 10 deletions pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,7 +1990,7 @@ def has_equivalent_ops(self, wp2, tol=1e-3):
for i, op0 in enumerate(ops0):
for j, op1 in enumerate(self.ops):
diff0 = op0.translation_vector - op1.translation_vector
diff0 -= np.round(diff0)
diff0 -= np.rint(diff0)
diff1 = op0.rotation_matrix - op1.rotation_matrix
if max([np.abs(diff0).sum(), np.abs(diff1).sum()]) < tol:
count += 1
Expand Down Expand Up @@ -2077,7 +2077,7 @@ def are_equivalent_pts(self, pt1, pt2, cell=np.eye(3), tol=0.05):
pt2 = np.array(pt2); pt2 -= np.floor(pt2)
pts = self.apply_ops(pt1); pts -= np.floor(pts)
diffs = pt2 - pts
diffs -= np.round(diffs)
diffs -= np.rint(diffs)
diffs = np.dot(diffs, cell)
dists = np.linalg.norm(diffs, axis=1)
#print(dists)
Expand Down Expand Up @@ -2306,7 +2306,7 @@ def search_generator(self, pos, ops=None, tol=1e-2):
pos1 = op.operate(pos) #
pos0 = self.ops[0].operate(pos1)
diff = pos1 - pos0
diff -= np.round(diff)
diff -= np.rint(diff)
diff = np.abs(diff)
#print(self.letter, "{:24s}".format(op.as_xyz_str()), pos, pos0, pos1, diff)
if diff.sum() < tol:
Expand Down Expand Up @@ -2340,7 +2340,7 @@ def search_all_generators(self, pos, ops=None, tol=1e-2):
pos1 = op.operate(pos)
pos0 = self.ops[0].operate(pos1)
diff = pos1 - pos0
diff -= np.round(diff)
diff -= np.rint(diff)
diff = np.abs(diff)
#print(wp.letter, pos1, pos0, diff)
if diff.sum() < tol:
Expand Down Expand Up @@ -2601,7 +2601,7 @@ def are_equivalent_ops(op1, op2, tol=1e-2):
check if two ops are equivalent, assuming the same ordering
"""
diff = op1.affine_matrix - op2.affine_matrix
diff[:,3] -= np.round(diff[:,3])
diff[:,3] -= np.rint(diff[:,3])
diff = np.abs(diff.flatten())
if np.sum(diff) < tol:
return True
Expand Down Expand Up @@ -3308,7 +3308,7 @@ def site_symm(point, gen_pos, tol=1e-3, lattice=np.eye(3), PBC=None):
by (+1,-1,0) and (0,0,+1), respectively.
"""
el = SymmOp.from_rotation_and_translation(
op.rotation_matrix, op.translation_vector - np.round(displacement)
op.rotation_matrix, op.translation_vector - np.rint(displacement)
)
symmetry.append(el)
return symmetry
Expand Down Expand Up @@ -3533,7 +3533,7 @@ def search_cloest_wp(G, wp, op, pos):
for coord in coords:
tmp = op.operate(coord)
diff1 = tmp - pos
diff1 -= np.round(diff1)
diff1 -= np.rint(diff1)
dist = np.linalg.norm(diff1)
if dist < 1e-3:
return tmp
Expand All @@ -3547,12 +3547,12 @@ def search_cloest_wp(G, wp, op, pos):
# extract all possible xyzs
all_xyz = apply_ops(pos, wp0)[1:]
dists = all_xyz - pos
dists -= np.round(dists)
dists -= np.rint(dists)
ds = np.linalg.norm(dists, axis=1)
ids = np.argsort(ds)
for id in ids:
d = all_xyz[id] - pos
d -= np.round(d)
d -= np.rint(d)
res = pos + d/2
if wp.search_generator(res, wp0) is not None:
#print(ds[id], pos, res)
Expand Down Expand Up @@ -3865,7 +3865,7 @@ def transform_ops(ops, P, P1):
for i, op in enumerate(ops):
inv = np.linalg.inv(op.affine_matrix)
trans = inv[:3, 3] + base
trans -= np.round(trans)
trans -= np.rint(trans)
rot_ops = ops[i].rotation_matrix
ops[i] = SymmOp.from_rotation_and_translation(rot_ops, trans)

Expand Down
20 changes: 18 additions & 2 deletions pyxtal/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pyxtal.operations import get_inverse
from pyxtal.supergroup import supergroups, supergroup
from pyxtal.util import generate_wp_lib

from pyxtal.wyckoff_site import atom_site

def resource_filename(package_name, resource_path):
package_path = importlib.util.find_spec(package_name).submodule_search_locations[0]
Expand Down Expand Up @@ -593,7 +593,7 @@ def check_error(spg, pt, cell):
p2 += op1.translation_vector

diff = p1-p2
diff -= np.round(diff)
diff -= np.rint(diff)
if np.linalg.norm(diff) > 0.02:
#res = '{:2d} {:28s}'.format(i, op0.as_xyz_str())
#res += ' {:28s}'.format(op1.as_xyz_str())
Expand Down Expand Up @@ -1143,6 +1143,22 @@ def test_atom(self):
N2 = len(struc.atom_sites)
self.assertTrue(N1 == N2)

class Test_wyckoff_site(unittest.TestCase):
def test_atom_site(self):
"""
Test the search function
"""
wp = Group(227)[-5]
arr = np.array([0.1, 0.1, 0.1])
for xyz in [[0.6501, 0.15001, 0.5999],
[0.6501, 0.14999, 0.5999],
[0.6499, 0.14999, 0.5999],
[0.6499, 0.14999, 0.60001],
]:
site = atom_site(wp, xyz, search=True)
self.assertTrue(np.allclose(site.position, arr, rtol=1e-3))


class Test_operations(unittest.TestCase):
def test_inverse(self):
coord0 = [0.35, 0.1, 0.4]
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def sort_by_dimer(atoms, N_mols, id=10, tol=4.0):
for j in lefts[1:]:
ref_j = refs[j]
dist = ref_j - ref_i
shift = np.round(dist)
shift = np.rint(dist)
dist -= shift
dist = np.linalg.norm(dist.dot(atoms.cell[:]))
if dist < tol:
Expand Down
Loading

0 comments on commit 4c79be5

Please sign in to comment.