Skip to content

Commit

Permalink
create a data class to handle loading
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Aug 30, 2024
1 parent 9ec44a4 commit ee595b5
Showing 1 changed file with 125 additions and 44 deletions.
169 changes: 125 additions & 44 deletions pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,26 @@ def rf(package_name, resource_path):
package_name).submodule_search_locations[0]
return os.path.join(package_path, resource_path)

class SubgroupData:
"""
Properties for Lazy Loading
"""
class SymmetryData:
def __init__(self):
self._wyckoff_sg = None
self._wyckoff_lg = None
self._wyckoff_rg = None
self._wyckoff_pg = None
self._symmetry_sg = None
self._symmetry_lg = None
self._symmetry_rg = None
self._symmetry_pg = None
self._generator_sg = None
self._generator_lg = None
self._generator_rg = None
self._generator_pg = None
self._t_subgroup = None
self._k_subgroup = None
self._hall_table = None

def get_t_subgroup(self):
if self._t_subgroup is None:
Expand All @@ -52,12 +68,77 @@ def get_k_subgroup(self):
if self._k_subgroup is None:
self._k_subgroup = loadfn(rf("pyxtal", "database/k_subgroup.json"))
return self._k_subgroup

def get_wyckoff_sg(self):
if self._wyckoff_sg is None:
self._wyckoff_sg = read_csv(rf("pyxtal", "database/wyckoff_list.csv"))
return self._wyckoff_sg

@property
def get_wyckoff_lg(self):
if self._wyckoff_lg is None:
self._wyckoff_lg = read_csv(rf("pyxtal", "database/layer.csv"))
return self._wyckoff_lg

def get_wyckoff_rg(self):
if self._wyckoff_rg is None:
self._wyckoff_rg = read_csv(rf("pyxtal", "database/rod.csv"))
return self._wyckoff_rg

def get_wyckoff_pg(self):
if self._wyckoff_pg is None:
self._wyckoff_pg = read_csv(rf("pyxtal", "database/point.csv"))
return self._wyckoff_pg

def get_symmetry_sg(self):
if self._symmetry_sg is None:
self._symmetry_sg = read_csv(rf("pyxtal", "database/wyckoff_symmetry.csv"))
return self._symmetry_sg

def get_symmetry_lg(self):
if self._symmetry_lg is None:
self._symmetry_lg = read_csv(rf("pyxtal", "database/layer_symmetry.csv"))
return self._wyckoff_lg

def get_symmetry_rg(self):
if self._symmetry_rg is None:
self._symmetry_rg = read_csv(rf("pyxtal", "database/rod_symmetry.csv"))
return self._symmetry_rg

def get_symmetry_pg(self):
if self._symmetry_pg is None:
self._symmetry_pg = read_csv(rf("pyxtal", "database/point_symmetry.csv"))
return self._symmetry_pg

def get_generator_sg(self):
if self._generator_sg is None:
self._generator_sg = read_csv(rf("pyxtal", "database/wyckoff_generators.csv"))
return self._generator_sg

def get_generator_lg(self):
if self._generator_lg is None:
self._generator_lg = read_csv(rf("pyxtal", "database/layer_generators.csv"))
return self._generator_lg

def get_symmetry_rg(self):
if self._generator_rg is None:
self._generator_rg = read_csv(rf("pyxtal", "database/rod_generators.csv"))
return self._generator_rg

def get_symmetry_pg(self):
if self._generator_pg is None:
self._generator_pg = read_csv(rf("pyxtal", "database/point_generators.csv"))
return self._generator_pg

def get_hall_table(self):
if self._hall_table is None:
self._hall_table = read_csv(rf("pyxtal", "database/HM_Full.csv"), sep=",")
return self._hall_table

# ------------------------------ Constants ---------------------------------------
#t_subgroup = loadfn(rf("pyxtal", "database/t_subgroup.json"))
#k_subgroup = loadfn(rf("pyxtal", "database/k_subgroup.json"))
subgroup_data = SubgroupData()
hall_table = read_csv(rf("pyxtal", "database/HM_Full.csv"), sep=",")
# The map between spglib default space group and hall numbers
SYMDATA = SymmetryData()
HALL_TABLE = SYMDATA.get_hall_table()

spglib_hall_numbers = [
1,
2,
Expand Down Expand Up @@ -548,15 +629,15 @@ def __init__(self, spgnum, style="pyxtal", permutation=False):
self.hall_symbols = []
self.Ps = [] # convertion from standard
self.P1s = [] # inverse convertion to standard
for id in range(len(hall_table["Hall"])):
if hall_table["Spg_num"][id] == spgnum:
include = True if permutation else hall_table["Permutation"][id] == 0
for id in range(len(HALL_TABLE["Hall"])):
if HALL_TABLE["Spg_num"][id] == spgnum:
include = True if permutation else HALL_TABLE["Permutation"][id] == 0
if include:
self.hall_numbers.append(hall_table["Hall"][id])
self.hall_symbols.append(hall_table["Symbol"][id])
self.Ps.append(abc2matrix(hall_table["P"][id]))
self.P1s.append(abc2matrix(hall_table["P^-1"][id]))
elif hall_table["Spg_num"][id] > spgnum:
self.hall_numbers.append(HALL_TABLE["Hall"][id])
self.hall_symbols.append(HALL_TABLE["Symbol"][id])
self.Ps.append(abc2matrix(HALL_TABLE["P"][id]))
self.P1s.append(abc2matrix(HALL_TABLE["P^-1"][id]))
elif HALL_TABLE["Spg_num"][id] > spgnum:
break
if len(self.hall_numbers) == 0:
msg = "hall numbers cannot be found, check input " + spgnum
Expand Down Expand Up @@ -662,8 +743,8 @@ def __init__(self, group, dim=3, use_hall=False, style="pyxtal", quick=False):
if not use_hall:
self.symbol, self.number = get_symbol_and_number(group, dim)
else:
self.symbol = hall_table["Symbol"][group - 1]
self.number = hall_table["Spg_num"][group - 1]
self.symbol = HALL_TABLE["Symbol"][group - 1]
self.number = HALL_TABLE["Spg_num"][group - 1]

self.PBC, self.lattice_type = get_pbc_and_lattice(self.number, dim)

Expand All @@ -684,8 +765,8 @@ def __init__(self, group, dim=3, use_hall=False, style="pyxtal", quick=False):
self.hall_number = spglib_hall_numbers[self.number - 1]
else:
self.hall_number = group
self.P = abc2matrix(hall_table["P"][self.hall_number - 1])
self.P1 = abc2matrix(hall_table["P^-1"][self.hall_number - 1])
self.P = abc2matrix(HALL_TABLE["P"][self.hall_number - 1])
self.P1 = abc2matrix(HALL_TABLE["P^-1"][self.hall_number - 1])
else:
self.hall_number = None
self.P = None
Expand Down Expand Up @@ -1052,7 +1133,7 @@ def get_max_k_subgroup(self):
Returns the maximal k-subgroups as a dictionary
"""
if self.dim == 3:
k_subgroup = subgroup_data.get_k_subgroup()
k_subgroup = SYMDATA.get_k_subgroup()
return k_subgroup[str(self.number)]
else:
msg = "Only supports the subgroups for space group"
Expand All @@ -1063,7 +1144,7 @@ def get_max_t_subgroup(self):
Returns the maximal t-subgroups as a dictionary
"""
if self.dim == 3:
t_subgroup = subgroup_data.get_t_subgroup()
t_subgroup = SYMDATA.get_t_subgroup()
return t_subgroup[str(self.number)]
else:
msg = "Only supports the subgroups for space group"
Expand Down Expand Up @@ -1541,8 +1622,8 @@ def add_k_transitions(self, path, n=1):
a list of maximal subgroup chains with extra k type transitions
"""

k_subgroup = subgroup_data.get_k_subgroup()
t_subgroup = subgroup_data.get_t_subgroup()
k_subgroup = SYMDATA.get_k_subgroup()
t_subgroup = SYMDATA.get_t_subgroup()

if n != 1:
print("only 1 extra k type supported at this time")
Expand Down Expand Up @@ -1880,8 +1961,8 @@ def from_group_and_index(cls, group, index, dim=3, use_hall=False, style="pyxtal
if not use_hall:
symbol, number = get_symbol_and_number(group, dim)
else:
symbol = hall_table["Symbol"][group - 1]
number = hall_table["Spg_num"][group - 1]
symbol = HALL_TABLE["Symbol"][group - 1]
number = HALL_TABLE["Spg_num"][group - 1]
pbc, lattice_type = get_pbc_and_lattice(number, dim)

if dim == 3:
Expand All @@ -1891,12 +1972,12 @@ def from_group_and_index(cls, group, index, dim=3, use_hall=False, style="pyxtal
hall_number = pyxtal_hall_numbers[number - 1]
else:
hall_number = spglib_hall_numbers[number - 1]
P = abc2matrix(hall_table["P"][hall_number - 1])
P1 = abc2matrix(hall_table["P^-1"][hall_number - 1])
P = abc2matrix(HALL_TABLE["P"][hall_number - 1])
P1 = abc2matrix(HALL_TABLE["P^-1"][hall_number - 1])
else:
hall_number = group
P = abc2matrix(hall_table["P"][hall_number - 1])
P1 = abc2matrix(hall_table["P^-1"][hall_number - 1])
P = abc2matrix(HALL_TABLE["P"][hall_number - 1])
P1 = abc2matrix(HALL_TABLE["P^-1"][hall_number - 1])
directions = get_symmetry_directions(lattice_type, symbol[0])

elif dim == 2:
Expand Down Expand Up @@ -2127,8 +2208,8 @@ def update_hall(self, hall_numbers=None):
candidates = self.process_ops()
success = False
for hall_number in hall_numbers:
P = abc2matrix(hall_table["P"][hall_number - 1])
P1 = abc2matrix(hall_table["P^-1"][hall_number - 1])
P = abc2matrix(HALL_TABLE["P"][hall_number - 1])
P1 = abc2matrix(HALL_TABLE["P^-1"][hall_number - 1])
wyckoffs = get_wyckoffs(self.number, dim=self.dim)

# Fist check the original index
Expand Down Expand Up @@ -2281,7 +2362,7 @@ def get_hm_symbol(self):
"""
get Hermann-Mauguin symbol
"""
return hall_table["Symbol"][self.hall_number - 1]
return HALL_TABLE["Symbol"][self.hall_number - 1]

def get_dof(self):
"""
Expand Down Expand Up @@ -3747,13 +3828,13 @@ def get_wyckoffs(num, organized=False, dim=3):
a list of Wyckoff positions, each of which is a list of SymmOp's
"""
if dim == 3:
df = read_csv(rf("pyxtal", "database/wyckoff_list.csv"))
df = SYMDATA.get_wyckoff_sg()
elif dim == 2:
df = read_csv(rf("pyxtal", "database/layer.csv"))
df = SYMDATA.get_wyckoff_lg()
elif dim == 1:
df = read_csv(rf("pyxtal", "database/rod.csv"))
df = SYMDATA.get_wyckoff_rg()
elif dim == 0:
df = read_csv(rf("pyxtal", "database/point.csv"))
df = SYMDATA.get_wyckoff_pg()

wyckoff_strings = eval(df["0"][num])

Expand Down Expand Up @@ -3795,13 +3876,13 @@ def get_wyckoff_symmetry(num, dim=3):
point in each Wyckoff position
"""
if dim == 3:
symmetry_df = read_csv(rf("pyxtal", "database/wyckoff_symmetry.csv"))
symmetry_df = SYMDATA.get_symmetry_sg()
elif dim == 2:
symmetry_df = read_csv(rf("pyxtal", "database/layer_symmetry.csv"))
symmetry_df = SYMDATA.get_symmetry_lg()
elif dim == 1:
symmetry_df = read_csv(rf("pyxtal", "database/rod_symmetry.csv"))
symmetry_df = SYMDATA.get_symmetry_rg()
elif dim == 0:
symmetry_df = read_csv(rf("pyxtal", "database/point_symmetry.csv"))
symmetry_df = SYMDATA.get_symmetry_pg()

symmetry_strings = eval(symmetry_df["0"][num])

Expand Down Expand Up @@ -3838,13 +3919,13 @@ def get_generators(num, dim=3):
"""

if dim == 3:
generators_df = read_csv(rf("pyxtal", "database/wyckoff_generators.csv"))
generators_df = SYMDATA.get_generator_sg()
elif dim == 2:
generators_df = read_csv(rf("pyxtal", "database/layer_generators.csv"))
generators_df = SYMDATA.get_generator_lg()
elif dim == 1:
generators_df = read_csv(rf("pyxtal", "database/rod_generators.csv"))
generators_df = SYMDATA.get_generator_rg()
elif dim == 0:
generators_df = read_csv(rf("pyxtal", "database/point_generators.csv"))
generators_df = SYMDATA.get_generator_pg()

generator_strings = eval(generators_df["0"][num])

Expand Down Expand Up @@ -4442,7 +4523,7 @@ def get_symmetry_from_ops(ops, tol=1e-5):
rot = [op.rotation_matrix for op in ops]
tran = [op.translation_vector for op in ops]
hall_number = get_hall_number_from_symmetry(rot, tran, tol)
spg_number = hall_table["Spg_num"][hall_number - 1]
spg_number = HALL_TABLE["Spg_num"][hall_number - 1]
return hall_number, spg_number


Expand Down

0 comments on commit ee595b5

Please sign in to comment.