diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 8de313991..52d371cb7 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -45,17 +45,40 @@ def _get_space_groups(): return SPACE_GROUPS -class Prop(Enum): +class Prop: """ - Enumeration of the 3 properties of the SpaceGroup Environment: - - Crystal lattice system - - Point symmetry - - Space group + Encodes the 3 properties of the SpaceGroup Environment: + 0: Crystal lattice system + 1: Point symmetry + 2: Space group """ CLS = 0 PS = 1 SG = 2 + ALL = (CLS, PS, SG) + + +class StateType: + """ + Encodes the 4 types of state from which transitions can originate: + 0: Source - both crystal-lattice system and point symmetry are unset (== 0) + 1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset + 2: PS - crystal-lattice system is unset; point symmetry is set + 3: CLS_PS - both crystal-lattice system and point symmetry are set + """ + + SOURCE = 0 + CLS = 1 + PS = 2 + CLS_PS = 3 + ALL = (SOURCE, CLS, PS, CLS_PS) + + def get_state_type(state: List[int]) -> int: + """ + Returns the value of the type of the state passed as an argument. + """ + return sum([int(s > 0) * f for s, f in zip(state, (1, 2))]) class SpaceGroup(GFlowNetEnv): @@ -102,24 +125,14 @@ def __init__( the space group. 0's are removed from the list. If None, composition/space group constraints are ignored. """ - # Get dictionaries + # Read dictionaries from YAML files self.crystal_lattice_systems = _get_crystal_lattice_systems() self.point_symmetries = _get_point_symmetries() self.space_groups = _get_space_groups() + # Restrict spacce groups to a subset self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) - # Indices in the state representation: crystal-lattice system (cls), point - # symmetry (ps) and space group (sg) - self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2 - # Dictionary of all properties - self.properties = { - Prop.CLS: self.crystal_lattice_systems, - Prop.PS: self.point_symmetries, - Prop.SG: self.space_groups, - } - # Indices of state types (see self.get_state_type) - self.state_type_indices = [0, 1, 2, 3] # End-of-sequence action self.eos = (-1, -1, -1) # Source state: index 0 (empty) for all three properties (crystal-lattice @@ -138,16 +151,22 @@ def get_action_space(self): (property, index, state_from_type), where property is (0: crystal-lattice system, 1: point symmetry, 2: space group), index is the index of the property set by the action and state_from_type is the state type of the originating - state (see self.state_type_indices). + state (see StateType). """ actions = [] - for prop, indices in self.properties.items(): - for s_from_type in self.state_type_indices: - if prop == Prop.CLS and s_from_type in [1, 3]: + # Create dictionary with of all properties + properties = { + Prop.CLS: self.crystal_lattice_systems, + Prop.PS: self.point_symmetries, + Prop.SG: self.space_groups, + } + for prop, indices in properties.items(): + for state_type in StateType.ALL: + if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]: continue - if prop == Prop.PS and s_from_type in [2, 3]: + if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]: continue - actions_prop = [(prop.value, idx, s_from_type) for idx in indices] + actions_prop = [(prop, idx, state_type) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -168,67 +187,71 @@ def get_mask_invalid_actions_forward( done = self.done if done: return [True for _ in self.action_space] - cls_idx, ps_idx, sg_idx = state + cls_state, ps_state, sg_state = state # If space group has been selected, only valid action is EOS - if sg_idx != 0: + if sg_state != 0: mask = [True for _ in self.action_space] mask[-1] = False return mask state_type = self.get_state_type(state) # If neither crystal-lattice system nor point symmetry selected, apply only # composition-compatibility constraints - if cls_idx == 0 and ps_idx == 0: + if cls_state == 0 and ps_state == 0: crystal_lattice_systems = [ - (self.cls_idx, idx, state_type) - for idx in self.crystal_lattice_systems - if self._is_compatible(cls_idx=idx) + (Prop.CLS, cls, state_type) + for cls in self.crystal_lattice_systems + if self._is_compatible(cls=cls) ] point_symmetries = [ - (self.ps_idx, idx, state_type) - for idx in self.point_symmetries - if self._is_compatible(ps_idx=idx) + (Prop.PS, ps, state_type) + for ps in self.point_symmetries + if self._is_compatible(ps=ps) ] # Constraints after having selected crystal-lattice system - if cls_idx != 0: + if cls_state != 0: crystal_lattice_systems = [] space_groups_cls = [ - (self.sg_idx, sg, state_type) - for sg in self.crystal_lattice_systems[cls_idx]["space_groups"] + (Prop.SG, sg, state_type) + for sg in self.crystal_lattice_systems[cls_state]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no point symmetry selected yet - if ps_idx == 0: + if ps_state == 0: point_symmetries = [ - (self.ps_idx, idx, state_type) - for idx in self.crystal_lattice_systems[cls_idx]["point_symmetries"] - if self._is_compatible(cls_idx=cls_idx, ps_idx=idx) + (Prop.PS, ps, state_type) + for ps in self.crystal_lattice_systems[cls_state][ + "point_symmetries" + ] + if self._is_compatible(cls=cls_state, ps=ps) ] else: space_groups_cls = [ - (self.sg_idx, idx, state_type) - for idx in self.space_groups - if self.n_atoms_compatibility_dict[idx] + (Prop.SG, sg, state_type) + for sg in self.space_groups + if self.n_atoms_compatibility_dict[sg] ] # Constraints after having selected point symmetry - if ps_idx != 0: + if ps_state != 0: point_symmetries = [] space_groups_ps = [ - (self.sg_idx, sg, state_type) - for sg in self.point_symmetries[ps_idx]["space_groups"] + (Prop.SG, sg, state_type) + for sg in self.point_symmetries[ps_state]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no crystal-lattice system selected yet - if cls_idx == 0: + if cls_state == 0: crystal_lattice_systems = [ - (self.cls_idx, idx, state_type) - for idx in self.point_symmetries[ps_idx]["crystal_lattice_systems"] - if self._is_compatible(cls_idx=idx, ps_idx=ps_idx) + (Prop.CLS, cls, state_type) + for cls in self.point_symmetries[ps_state][ + "crystal_lattice_systems" + ] + if self._is_compatible(cls=cls, ps=ps_state) ] else: space_groups_ps = [ - (self.sg_idx, idx, state_type) - for idx in self.space_groups - if self.n_atoms_compatibility_dict[idx] + (Prop.SG, sg, state_type) + for sg in self.space_groups + if self.n_atoms_compatibility_dict[sg] ] # Merge space_groups constraints and determine valid space group actions space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps))) @@ -258,11 +281,11 @@ def state2oracle(self, state: List = None) -> Tensor: """ if state is None: state = self.state - if state[self.sg_idx] == 0: + if state[Prop.SG] == 0: raise ValueError( "The space group must have been set in order to call the oracle" ) - return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) + return torch.tensor(state[Prop.SG], device=self.device, dtype=torch.long) def statebatch2oracle( self, states: List[List] @@ -300,7 +323,7 @@ def statetorch2oracle( ---- oracle_state : Tensor """ - return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) + return torch.unsqueeze(states[:, Prop.SG], dim=1).to(torch.long) def state2readable(self, state=None): """ @@ -381,24 +404,24 @@ def get_parents(self, state=None, done=None, action=None): parents = [] actions = [] # Catch cases where space group has been selected - if state[self.sg_idx] != 0: - sg = state[self.sg_idx] + if state[Prop.SG] != 0: + sg = state[Prop.SG] # Add parent: source parents.append(self.source) - action = (self.sg_idx, sg, 0) + action = (Prop.SG, sg, 0) actions.append(action) # Add parents: states before setting space group - state[self.sg_idx] = 0 + state[Prop.SG] = 0 for prop in range(len(state)): parent = state.copy() parent[prop] = 0 parents.append(parent) parent_type = self.get_state_type(parent) - action = (self.sg_idx, sg, parent_type) + action = (Prop.SG, sg, parent_type) actions.append(action) else: # Catch other parents - for prop, idx in enumerate(state[: self.sg_idx]): + for prop, idx in enumerate(state[: Prop.SG]): if idx != 0: parent = state.copy() parent[prop] = 0 @@ -457,16 +480,14 @@ def get_max_traj_length(self): return len(self.source) + 1 def _set_constrained_properties(self, state: List[int]) -> List[int]: - cls_idx, ps_idx, sg_idx = state - if sg_idx != 0: - if cls_idx == 0: - state[self.cls_idx] = self.space_groups[state[self.sg_idx]][ + cls, ps, sg = state + if sg != 0: + if cls == 0: + state[Prop.CLS] = self.space_groups[state[Prop.SG]][ "crystal_lattice_system_idx" ] - if ps_idx == 0: - state[self.ps_idx] = self.space_groups[state[self.sg_idx]][ - "point_symmetry_idx" - ] + if ps == 0: + state[Prop.PS] = self.space_groups[state[Prop.SG]]["point_symmetry_idx"] return state def get_crystal_system(self, state: List[int] = None) -> str: @@ -475,8 +496,8 @@ def get_crystal_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.cls_idx] != 0: - return self.crystal_lattice_systems[state[self.cls_idx]]["crystal_system"] + if state[Prop.CLS] != 0: + return self.crystal_lattice_systems[state[Prop.CLS]]["crystal_system"] else: return "None" @@ -490,8 +511,8 @@ def get_lattice_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.cls_idx] != 0: - return self.crystal_lattice_systems[state[self.cls_idx]]["lattice_system"] + if state[Prop.CLS] != 0: + return self.crystal_lattice_systems[state[Prop.CLS]]["lattice_system"] else: return "None" @@ -522,8 +543,8 @@ def get_point_symmetry(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.ps_idx] != 0: - return self.point_symmetries[state[self.ps_idx]]["point_symmetry"] + if state[Prop.PS] != 0: + return self.point_symmetries[state[Prop.PS]]["point_symmetry"] else: return "None" @@ -537,8 +558,8 @@ def get_space_group_symbol(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["full_symbol"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["full_symbol"] else: return "None" @@ -554,8 +575,8 @@ def get_crystal_class(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["crystal_class"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["crystal_class"] else: return "None" @@ -571,8 +592,8 @@ def get_point_group(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["point_group"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["point_group"] else: return "None" @@ -582,16 +603,11 @@ def point_group(self) -> str: def get_state_type(self, state: List[int] = None) -> int: """ - Returns the index of the type of the state passed as an argument. The state - type is one of the following (self.state_type_indices): - 0: both crystal-lattice system and point symmetry are unset (== 0) - 1: crystal-lattice system is set (!= 0); point symmetry is unset - 2: crystal-lattice system is unset; point symmetry is set - 3: both crystal-lattice system and point symmetry are set + Returns the value of the type of the state passed as an argument. """ if state is None: state = self.state - return sum([int(s > 0) * f for s, f in zip(state, (1, 2))]) + return StateType.get_state_type(state) def set_n_atoms_compatibility_dict(self, n_atoms: List): """ @@ -616,9 +632,7 @@ def set_n_atoms_compatibility_dict(self, n_atoms: List): n_atoms, self.space_groups.keys() ) - def _is_compatible( - self, cls_idx: Optional[int] = None, ps_idx: Optional[int] = None - ): + def _is_compatible(self, cls: Optional[int] = None, ps: Optional[int] = None): """ Returns True if there is exists at least one space group compatible with the atom composition (according to self.n_atoms_compatibility_dict), with the @@ -630,14 +644,14 @@ def _is_compatible( # Prune the list of space groups to those compatible with the provided crystal- # lattice system - if cls_idx is not None: - space_groups_cls = self.crystal_lattice_systems[cls_idx]["space_groups"] + if cls is not None: + space_groups_cls = self.crystal_lattice_systems[cls]["space_groups"] space_groups = list(set(space_groups).intersection(set(space_groups_cls))) # Prune the list of space groups to those compatible with the provided point # symmetry - if ps_idx is not None: - space_groups_ps = self.point_symmetries[ps_idx]["space_groups"] + if ps is not None: + space_groups_ps = self.point_symmetries[ps]["space_groups"] space_groups = list(set(space_groups).intersection(set(space_groups_ps))) return len(space_groups) > 0 diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index 50e82d61c..41a66b235 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -5,7 +5,7 @@ import torch from pyxtal.symmetry import Group -from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.envs.crystals.spacegroup import Prop, SpaceGroup N_ATOMS = [3, 7, 9] SG_SUBSET = [1, 17, 39, 123, 230] @@ -284,14 +284,14 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( """ all_x = env_with_composition.get_all_terminating_states() for state in all_x: - state[env_with_composition.sg_idx] = 0 + state[Prop.SG] = 0 env_with_composition.set_state(state=state, done=False) mask_f = env_with_composition.get_mask_invalid_actions_forward() state_type = env_with_composition.get_state_type(state) for sg in env_with_composition.space_groups: sg_pyxtal = Group(sg) is_compatible = sg_pyxtal.check_compatible(N_ATOMS)[0] - action = (env_with_composition.sg_idx, sg, state_type) + action = (Prop.SG, sg, state_type) if not is_compatible: assert mask_f[env_with_composition.action_space.index(action)] is True @@ -302,31 +302,31 @@ def test__states_are_compatible_with_pymatgen(env): env.step((2, idx, 0)) sg_int = pmgg.sg_symbol_from_int_number(idx) sg = pmgg.SpaceGroup(sg_int) - assert sg.int_number == env.state[env.sg_idx] + assert sg.int_number == env.state[Prop.SG] assert sg.crystal_system == env.crystal_system assert sg.symbol == env.space_group_symbol assert sg.point_group == env.point_group @pytest.mark.parametrize( - "n_atoms, cls_idx, ps_idx", + "n_atoms, cls, ps", [ [[1], 5, 1], [[17], 5, 1], [[1, 13], 5, 1], ], ) -def test__special_cases_composition_compatibility(n_atoms, cls_idx, ps_idx): +def test__special_cases_composition_compatibility(n_atoms, cls, ps): env = SpaceGroup(n_atoms=n_atoms) # Crystal lattice system space groups must not compatible with composition # constraints - assert env._is_compatible(cls_idx=cls_idx) is False + assert env._is_compatible(cls=cls) is False # Setting crystal lattice system should fail action_cls_5_from_0 = (0, 5, 0) state_new, action, valid = env.step(action_cls_5_from_0) assert valid is False # Point symmetry space groups must be compatible with composition constraints - assert env._is_compatible(ps_idx=ps_idx) is True + assert env._is_compatible(ps=ps) is True # Setting point symmetry should be valid action_ps_1_from_0 = (1, 1, 0) state_new, action, valid = env.step(action_ps_1_from_0)