Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Space group cosmetics #230

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 109 additions & 95 deletions gflownet/envs/crystals/spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,40 @@ def _get_space_groups():
return SPACE_GROUPS


class Prop(Enum):
class Prop:
"""
Enumeration of the 3 properties of the SpaceGroup Environment:
- Crystal lattice system
- Point symmetry
- Space group
Encodes the 3 properties of the SpaceGroup Environment:
0: Crystal lattice system
1: Point symmetry
2: Space group
"""

CLS = 0
PS = 1
SG = 2
ALL = (CLS, PS, SG)


class StateType:
"""
Encodes the 4 types of state from which transitions can originate:
0: Source - both crystal-lattice system and point symmetry are unset (== 0)
1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset
2: PS - crystal-lattice system is unset; point symmetry is set
3: CLS_PS - both crystal-lattice system and point symmetry are set
"""

SOURCE = 0
CLS = 1
PS = 2
CLS_PS = 3
ALL = (SOURCE, CLS, PS, CLS_PS)

def get_state_type(state: List[int]) -> int:
"""
Returns the value of the type of the state passed as an argument.
"""
return sum([int(s > 0) * f for s, f in zip(state, (1, 2))])


class SpaceGroup(GFlowNetEnv):
Expand Down Expand Up @@ -102,24 +125,14 @@ def __init__(
the space group. 0's are removed from the list. If None, composition/space
group constraints are ignored.
"""
# Get dictionaries
# Read dictionaries from YAML files
self.crystal_lattice_systems = _get_crystal_lattice_systems()
self.point_symmetries = _get_point_symmetries()
self.space_groups = _get_space_groups()
# Restrict spacce groups to a subset
self._restrict_space_groups(space_groups_subset)
# Set dictionary of compatibility with number of atoms
self.set_n_atoms_compatibility_dict(n_atoms)
# Indices in the state representation: crystal-lattice system (cls), point
# symmetry (ps) and space group (sg)
self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2
# Dictionary of all properties
self.properties = {
Prop.CLS: self.crystal_lattice_systems,
Prop.PS: self.point_symmetries,
Prop.SG: self.space_groups,
}
# Indices of state types (see self.get_state_type)
self.state_type_indices = [0, 1, 2, 3]
# End-of-sequence action
self.eos = (-1, -1, -1)
# Source state: index 0 (empty) for all three properties (crystal-lattice
Expand All @@ -138,16 +151,22 @@ def get_action_space(self):
(property, index, state_from_type), where property is (0: crystal-lattice
system, 1: point symmetry, 2: space group), index is the index of the property
set by the action and state_from_type is the state type of the originating
state (see self.state_type_indices).
state (see StateType).
"""
actions = []
for prop, indices in self.properties.items():
for s_from_type in self.state_type_indices:
if prop == Prop.CLS and s_from_type in [1, 3]:
# Create dictionary with of all properties
properties = {
Prop.CLS: self.crystal_lattice_systems,
Prop.PS: self.point_symmetries,
Prop.SG: self.space_groups,
}
for prop, indices in properties.items():
for state_type in StateType.ALL:
if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]:
continue
if prop == Prop.PS and s_from_type in [2, 3]:
if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]:
continue
actions_prop = [(prop.value, idx, s_from_type) for idx in indices]
actions_prop = [(prop, idx, state_type) for idx in indices]
actions += actions_prop
actions += [self.eos]
return actions
Expand All @@ -168,67 +187,71 @@ def get_mask_invalid_actions_forward(
done = self.done
if done:
return [True for _ in self.action_space]
cls_idx, ps_idx, sg_idx = state
cls_state, ps_state, sg_state = state
# If space group has been selected, only valid action is EOS
if sg_idx != 0:
if sg_state != 0:
mask = [True for _ in self.action_space]
mask[-1] = False
return mask
state_type = self.get_state_type(state)
# If neither crystal-lattice system nor point symmetry selected, apply only
# composition-compatibility constraints
if cls_idx == 0 and ps_idx == 0:
if cls_state == 0 and ps_state == 0:
crystal_lattice_systems = [
(self.cls_idx, idx, state_type)
for idx in self.crystal_lattice_systems
if self._is_compatible(cls_idx=idx)
(Prop.CLS, cls, state_type)
for cls in self.crystal_lattice_systems
if self._is_compatible(cls=cls)
]
point_symmetries = [
(self.ps_idx, idx, state_type)
for idx in self.point_symmetries
if self._is_compatible(ps_idx=idx)
(Prop.PS, ps, state_type)
for ps in self.point_symmetries
if self._is_compatible(ps=ps)
]
# Constraints after having selected crystal-lattice system
if cls_idx != 0:
if cls_state != 0:
crystal_lattice_systems = []
space_groups_cls = [
(self.sg_idx, sg, state_type)
for sg in self.crystal_lattice_systems[cls_idx]["space_groups"]
(Prop.SG, sg, state_type)
for sg in self.crystal_lattice_systems[cls_state]["space_groups"]
if self.n_atoms_compatibility_dict[sg]
]
# If no point symmetry selected yet
if ps_idx == 0:
if ps_state == 0:
point_symmetries = [
(self.ps_idx, idx, state_type)
for idx in self.crystal_lattice_systems[cls_idx]["point_symmetries"]
if self._is_compatible(cls_idx=cls_idx, ps_idx=idx)
(Prop.PS, ps, state_type)
for ps in self.crystal_lattice_systems[cls_state][
"point_symmetries"
]
if self._is_compatible(cls=cls_state, ps=ps)
]
else:
space_groups_cls = [
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
(Prop.SG, sg, state_type)
for sg in self.space_groups
if self.n_atoms_compatibility_dict[sg]
]
# Constraints after having selected point symmetry
if ps_idx != 0:
if ps_state != 0:
point_symmetries = []
space_groups_ps = [
(self.sg_idx, sg, state_type)
for sg in self.point_symmetries[ps_idx]["space_groups"]
(Prop.SG, sg, state_type)
for sg in self.point_symmetries[ps_state]["space_groups"]
if self.n_atoms_compatibility_dict[sg]
]
# If no crystal-lattice system selected yet
if cls_idx == 0:
if cls_state == 0:
crystal_lattice_systems = [
(self.cls_idx, idx, state_type)
for idx in self.point_symmetries[ps_idx]["crystal_lattice_systems"]
if self._is_compatible(cls_idx=idx, ps_idx=ps_idx)
(Prop.CLS, cls, state_type)
for cls in self.point_symmetries[ps_state][
"crystal_lattice_systems"
]
if self._is_compatible(cls=cls, ps=ps_state)
]
else:
space_groups_ps = [
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
(Prop.SG, sg, state_type)
for sg in self.space_groups
if self.n_atoms_compatibility_dict[sg]
]
# Merge space_groups constraints and determine valid space group actions
space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps)))
Expand Down Expand Up @@ -258,11 +281,11 @@ def state2oracle(self, state: List = None) -> Tensor:
"""
if state is None:
state = self.state
if state[self.sg_idx] == 0:
if state[Prop.SG] == 0:
raise ValueError(
"The space group must have been set in order to call the oracle"
)
return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long)
return torch.tensor(state[Prop.SG], device=self.device, dtype=torch.long)

def statebatch2oracle(
self, states: List[List]
Expand Down Expand Up @@ -300,7 +323,7 @@ def statetorch2oracle(
----
oracle_state : Tensor
"""
return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long)
return torch.unsqueeze(states[:, Prop.SG], dim=1).to(torch.long)

def state2readable(self, state=None):
"""
Expand Down Expand Up @@ -381,24 +404,24 @@ def get_parents(self, state=None, done=None, action=None):
parents = []
actions = []
# Catch cases where space group has been selected
if state[self.sg_idx] != 0:
sg = state[self.sg_idx]
if state[Prop.SG] != 0:
sg = state[Prop.SG]
# Add parent: source
parents.append(self.source)
action = (self.sg_idx, sg, 0)
action = (Prop.SG, sg, 0)
actions.append(action)
# Add parents: states before setting space group
state[self.sg_idx] = 0
state[Prop.SG] = 0
for prop in range(len(state)):
parent = state.copy()
parent[prop] = 0
parents.append(parent)
parent_type = self.get_state_type(parent)
action = (self.sg_idx, sg, parent_type)
action = (Prop.SG, sg, parent_type)
actions.append(action)
else:
# Catch other parents
for prop, idx in enumerate(state[: self.sg_idx]):
for prop, idx in enumerate(state[: Prop.SG]):
if idx != 0:
parent = state.copy()
parent[prop] = 0
Expand Down Expand Up @@ -457,16 +480,14 @@ def get_max_traj_length(self):
return len(self.source) + 1

def _set_constrained_properties(self, state: List[int]) -> List[int]:
cls_idx, ps_idx, sg_idx = state
if sg_idx != 0:
if cls_idx == 0:
state[self.cls_idx] = self.space_groups[state[self.sg_idx]][
cls, ps, sg = state
if sg != 0:
if cls == 0:
state[Prop.CLS] = self.space_groups[state[Prop.SG]][
"crystal_lattice_system_idx"
]
if ps_idx == 0:
state[self.ps_idx] = self.space_groups[state[self.sg_idx]][
"point_symmetry_idx"
]
if ps == 0:
state[Prop.PS] = self.space_groups[state[Prop.SG]]["point_symmetry_idx"]
return state

def get_crystal_system(self, state: List[int] = None) -> str:
Expand All @@ -475,8 +496,8 @@ def get_crystal_system(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.cls_idx] != 0:
return self.crystal_lattice_systems[state[self.cls_idx]]["crystal_system"]
if state[Prop.CLS] != 0:
return self.crystal_lattice_systems[state[Prop.CLS]]["crystal_system"]
else:
return "None"

Expand All @@ -490,8 +511,8 @@ def get_lattice_system(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.cls_idx] != 0:
return self.crystal_lattice_systems[state[self.cls_idx]]["lattice_system"]
if state[Prop.CLS] != 0:
return self.crystal_lattice_systems[state[Prop.CLS]]["lattice_system"]
else:
return "None"

Expand Down Expand Up @@ -522,8 +543,8 @@ def get_point_symmetry(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.ps_idx] != 0:
return self.point_symmetries[state[self.ps_idx]]["point_symmetry"]
if state[Prop.PS] != 0:
return self.point_symmetries[state[Prop.PS]]["point_symmetry"]
else:
return "None"

Expand All @@ -537,8 +558,8 @@ def get_space_group_symbol(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["full_symbol"]
if state[Prop.SG] != 0:
return self.space_groups[state[Prop.SG]]["full_symbol"]
else:
return "None"

Expand All @@ -554,8 +575,8 @@ def get_crystal_class(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["crystal_class"]
if state[Prop.SG] != 0:
return self.space_groups[state[Prop.SG]]["crystal_class"]
else:
return "None"

Expand All @@ -571,8 +592,8 @@ def get_point_group(self, state: List[int] = None) -> str:
"""
if state is None:
state = self.state
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["point_group"]
if state[Prop.SG] != 0:
return self.space_groups[state[Prop.SG]]["point_group"]
else:
return "None"

Expand All @@ -582,16 +603,11 @@ def point_group(self) -> str:

def get_state_type(self, state: List[int] = None) -> int:
"""
Returns the index of the type of the state passed as an argument. The state
type is one of the following (self.state_type_indices):
0: both crystal-lattice system and point symmetry are unset (== 0)
1: crystal-lattice system is set (!= 0); point symmetry is unset
2: crystal-lattice system is unset; point symmetry is set
3: both crystal-lattice system and point symmetry are set
Returns the value of the type of the state passed as an argument.
"""
if state is None:
state = self.state
return sum([int(s > 0) * f for s, f in zip(state, (1, 2))])
return StateType.get_state_type(state)

def set_n_atoms_compatibility_dict(self, n_atoms: List):
"""
Expand All @@ -616,9 +632,7 @@ def set_n_atoms_compatibility_dict(self, n_atoms: List):
n_atoms, self.space_groups.keys()
)

def _is_compatible(
self, cls_idx: Optional[int] = None, ps_idx: Optional[int] = None
):
def _is_compatible(self, cls: Optional[int] = None, ps: Optional[int] = None):
"""
Returns True if there is exists at least one space group compatible with the
atom composition (according to self.n_atoms_compatibility_dict), with the
Expand All @@ -630,14 +644,14 @@ def _is_compatible(

# Prune the list of space groups to those compatible with the provided crystal-
# lattice system
if cls_idx is not None:
space_groups_cls = self.crystal_lattice_systems[cls_idx]["space_groups"]
if cls is not None:
space_groups_cls = self.crystal_lattice_systems[cls]["space_groups"]
space_groups = list(set(space_groups).intersection(set(space_groups_cls)))

# Prune the list of space groups to those compatible with the provided point
# symmetry
if ps_idx is not None:
space_groups_ps = self.point_symmetries[ps_idx]["space_groups"]
if ps is not None:
space_groups_ps = self.point_symmetries[ps]["space_groups"]
space_groups = list(set(space_groups).intersection(set(space_groups_ps)))

return len(space_groups) > 0
Expand Down
Loading