From 7564f936a4575c54b4f5260ae0eace453ea23a69 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 25 Sep 2023 17:01:37 -0400 Subject: [PATCH 1/6] Use Prop instead of self._idx --- gflownet/envs/crystals/spacegroup.py | 62 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 8de313991..4e74bf0b3 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -179,12 +179,12 @@ def get_mask_invalid_actions_forward( # composition-compatibility constraints if cls_idx == 0 and ps_idx == 0: crystal_lattice_systems = [ - (self.cls_idx, idx, state_type) + (Prop.CLS.value, idx, state_type) for idx in self.crystal_lattice_systems if self._is_compatible(cls_idx=idx) ] point_symmetries = [ - (self.ps_idx, idx, state_type) + (Prop.PS.value, idx, state_type) for idx in self.point_symmetries if self._is_compatible(ps_idx=idx) ] @@ -192,20 +192,20 @@ def get_mask_invalid_actions_forward( if cls_idx != 0: crystal_lattice_systems = [] space_groups_cls = [ - (self.sg_idx, sg, state_type) + (Prop.SG.value, sg, state_type) for sg in self.crystal_lattice_systems[cls_idx]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no point symmetry selected yet if ps_idx == 0: point_symmetries = [ - (self.ps_idx, idx, state_type) + (Prop.PS.value, idx, state_type) for idx in self.crystal_lattice_systems[cls_idx]["point_symmetries"] if self._is_compatible(cls_idx=cls_idx, ps_idx=idx) ] else: space_groups_cls = [ - (self.sg_idx, idx, state_type) + (Prop.SG.value, idx, state_type) for idx in self.space_groups if self.n_atoms_compatibility_dict[idx] ] @@ -213,20 +213,20 @@ def get_mask_invalid_actions_forward( if ps_idx != 0: point_symmetries = [] space_groups_ps = [ - (self.sg_idx, sg, state_type) + (Prop.SG.value, sg, state_type) for sg in self.point_symmetries[ps_idx]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no crystal-lattice system selected yet if cls_idx == 0: crystal_lattice_systems = [ - (self.cls_idx, idx, state_type) + (Prop.CLS.value, idx, state_type) for idx in self.point_symmetries[ps_idx]["crystal_lattice_systems"] if self._is_compatible(cls_idx=idx, ps_idx=ps_idx) ] else: space_groups_ps = [ - (self.sg_idx, idx, state_type) + (Prop.SG.value, idx, state_type) for idx in self.space_groups if self.n_atoms_compatibility_dict[idx] ] @@ -258,11 +258,11 @@ def state2oracle(self, state: List = None) -> Tensor: """ if state is None: state = self.state - if state[self.sg_idx] == 0: + if state[Prop.SG.value] == 0: raise ValueError( "The space group must have been set in order to call the oracle" ) - return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) + return torch.tensor(state[Prop.SG.value], device=self.device, dtype=torch.long) def statebatch2oracle( self, states: List[List] @@ -300,7 +300,7 @@ def statetorch2oracle( ---- oracle_state : Tensor """ - return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) + return torch.unsqueeze(states[:, Prop.SG.value], dim=1).to(torch.long) def state2readable(self, state=None): """ @@ -381,24 +381,24 @@ def get_parents(self, state=None, done=None, action=None): parents = [] actions = [] # Catch cases where space group has been selected - if state[self.sg_idx] != 0: - sg = state[self.sg_idx] + if state[Prop.SG.value] != 0: + sg = state[Prop.SG.value] # Add parent: source parents.append(self.source) - action = (self.sg_idx, sg, 0) + action = (Prop.SG.value, sg, 0) actions.append(action) # Add parents: states before setting space group - state[self.sg_idx] = 0 + state[Prop.SG.value] = 0 for prop in range(len(state)): parent = state.copy() parent[prop] = 0 parents.append(parent) parent_type = self.get_state_type(parent) - action = (self.sg_idx, sg, parent_type) + action = (Prop.SG.value, sg, parent_type) actions.append(action) else: # Catch other parents - for prop, idx in enumerate(state[: self.sg_idx]): + for prop, idx in enumerate(state[: Prop.SG.value]): if idx != 0: parent = state.copy() parent[prop] = 0 @@ -460,11 +460,11 @@ def _set_constrained_properties(self, state: List[int]) -> List[int]: cls_idx, ps_idx, sg_idx = state if sg_idx != 0: if cls_idx == 0: - state[self.cls_idx] = self.space_groups[state[self.sg_idx]][ + state[Prop.CLS.value] = self.space_groups[state[Prop.SG.value]][ "crystal_lattice_system_idx" ] if ps_idx == 0: - state[self.ps_idx] = self.space_groups[state[self.sg_idx]][ + state[Prop.PS.value] = self.space_groups[state[Prop.SG.value]][ "point_symmetry_idx" ] return state @@ -475,8 +475,8 @@ def get_crystal_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.cls_idx] != 0: - return self.crystal_lattice_systems[state[self.cls_idx]]["crystal_system"] + if state[Prop.CLS.value] != 0: + return self.crystal_lattice_systems[state[Prop.CLS.value]]["crystal_system"] else: return "None" @@ -490,8 +490,8 @@ def get_lattice_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.cls_idx] != 0: - return self.crystal_lattice_systems[state[self.cls_idx]]["lattice_system"] + if state[Prop.CLS.value] != 0: + return self.crystal_lattice_systems[state[Prop.CLS.value]]["lattice_system"] else: return "None" @@ -522,8 +522,8 @@ def get_point_symmetry(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.ps_idx] != 0: - return self.point_symmetries[state[self.ps_idx]]["point_symmetry"] + if state[Prop.PS.value] != 0: + return self.point_symmetries[state[Prop.PS.value]]["point_symmetry"] else: return "None" @@ -537,8 +537,8 @@ def get_space_group_symbol(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["full_symbol"] + if state[Prop.SG.value] != 0: + return self.space_groups[state[Prop.SG.value]]["full_symbol"] else: return "None" @@ -554,8 +554,8 @@ def get_crystal_class(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["crystal_class"] + if state[Prop.SG.value] != 0: + return self.space_groups[state[Prop.SG.value]]["crystal_class"] else: return "None" @@ -571,8 +571,8 @@ def get_point_group(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[self.sg_idx] != 0: - return self.space_groups[state[self.sg_idx]]["point_group"] + if state[Prop.SG.value] != 0: + return self.space_groups[state[Prop.SG.value]]["point_group"] else: return "None" From 76cd939f43af3c56fcf56d35874c1870e8fbd711 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 25 Sep 2023 17:28:45 -0400 Subject: [PATCH 2/6] Refactor all using Prop --- gflownet/envs/crystals/spacegroup.py | 139 ++++++++++++------------- tests/gflownet/envs/test_spacegroup.py | 16 +-- 2 files changed, 76 insertions(+), 79 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 4e74bf0b3..d71d20d59 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -45,9 +45,9 @@ def _get_space_groups(): return SPACE_GROUPS -class Prop(Enum): +class Prop: """ - Enumeration of the 3 properties of the SpaceGroup Environment: + Encodes the 3 properties of the SpaceGroup Environment: - Crystal lattice system - Point symmetry - Space group @@ -109,9 +109,6 @@ def __init__( self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) - # Indices in the state representation: crystal-lattice system (cls), point - # symmetry (ps) and space group (sg) - self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2 # Dictionary of all properties self.properties = { Prop.CLS: self.crystal_lattice_systems, @@ -147,7 +144,7 @@ def get_action_space(self): continue if prop == Prop.PS and s_from_type in [2, 3]: continue - actions_prop = [(prop.value, idx, s_from_type) for idx in indices] + actions_prop = [(prop, idx, s_from_type) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -168,67 +165,71 @@ def get_mask_invalid_actions_forward( done = self.done if done: return [True for _ in self.action_space] - cls_idx, ps_idx, sg_idx = state + cls_state, ps_state, sg_state = state # If space group has been selected, only valid action is EOS - if sg_idx != 0: + if sg_state != 0: mask = [True for _ in self.action_space] mask[-1] = False return mask state_type = self.get_state_type(state) # If neither crystal-lattice system nor point symmetry selected, apply only # composition-compatibility constraints - if cls_idx == 0 and ps_idx == 0: + if cls_state == 0 and ps_state == 0: crystal_lattice_systems = [ - (Prop.CLS.value, idx, state_type) - for idx in self.crystal_lattice_systems - if self._is_compatible(cls_idx=idx) + (Prop.CLS, cls, state_type) + for cls in self.crystal_lattice_systems + if self._is_compatible(cls=cls) ] point_symmetries = [ - (Prop.PS.value, idx, state_type) - for idx in self.point_symmetries - if self._is_compatible(ps_idx=idx) + (Prop.PS, ps, state_type) + for ps in self.point_symmetries + if self._is_compatible(ps=ps) ] # Constraints after having selected crystal-lattice system - if cls_idx != 0: + if cls_state != 0: crystal_lattice_systems = [] space_groups_cls = [ - (Prop.SG.value, sg, state_type) - for sg in self.crystal_lattice_systems[cls_idx]["space_groups"] + (Prop.SG, sg, state_type) + for sg in self.crystal_lattice_systems[cls_state]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no point symmetry selected yet - if ps_idx == 0: + if ps_state == 0: point_symmetries = [ - (Prop.PS.value, idx, state_type) - for idx in self.crystal_lattice_systems[cls_idx]["point_symmetries"] - if self._is_compatible(cls_idx=cls_idx, ps_idx=idx) + (Prop.PS, ps, state_type) + for ps in self.crystal_lattice_systems[cls_state][ + "point_symmetries" + ] + if self._is_compatible(cls=cls_state, ps=ps) ] else: space_groups_cls = [ - (Prop.SG.value, idx, state_type) - for idx in self.space_groups - if self.n_atoms_compatibility_dict[idx] + (Prop.SG, sg, state_type) + for sg in self.space_groups + if self.n_atoms_compatibility_dict[sg] ] # Constraints after having selected point symmetry - if ps_idx != 0: + if ps_state != 0: point_symmetries = [] space_groups_ps = [ - (Prop.SG.value, sg, state_type) - for sg in self.point_symmetries[ps_idx]["space_groups"] + (Prop.SG, sg, state_type) + for sg in self.point_symmetries[ps_state]["space_groups"] if self.n_atoms_compatibility_dict[sg] ] # If no crystal-lattice system selected yet - if cls_idx == 0: + if cls_state == 0: crystal_lattice_systems = [ - (Prop.CLS.value, idx, state_type) - for idx in self.point_symmetries[ps_idx]["crystal_lattice_systems"] - if self._is_compatible(cls_idx=idx, ps_idx=ps_idx) + (Prop.CLS, cls, state_type) + for cls in self.point_symmetries[ps_state][ + "crystal_lattice_systems" + ] + if self._is_compatible(cls=cls, ps=ps_state) ] else: space_groups_ps = [ - (Prop.SG.value, idx, state_type) - for idx in self.space_groups - if self.n_atoms_compatibility_dict[idx] + (Prop.SG, sg, state_type) + for sg in self.space_groups + if self.n_atoms_compatibility_dict[sg] ] # Merge space_groups constraints and determine valid space group actions space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps))) @@ -258,11 +259,11 @@ def state2oracle(self, state: List = None) -> Tensor: """ if state is None: state = self.state - if state[Prop.SG.value] == 0: + if state[Prop.SG] == 0: raise ValueError( "The space group must have been set in order to call the oracle" ) - return torch.tensor(state[Prop.SG.value], device=self.device, dtype=torch.long) + return torch.tensor(state[Prop.SG], device=self.device, dtype=torch.long) def statebatch2oracle( self, states: List[List] @@ -300,7 +301,7 @@ def statetorch2oracle( ---- oracle_state : Tensor """ - return torch.unsqueeze(states[:, Prop.SG.value], dim=1).to(torch.long) + return torch.unsqueeze(states[:, Prop.SG], dim=1).to(torch.long) def state2readable(self, state=None): """ @@ -381,24 +382,24 @@ def get_parents(self, state=None, done=None, action=None): parents = [] actions = [] # Catch cases where space group has been selected - if state[Prop.SG.value] != 0: - sg = state[Prop.SG.value] + if state[Prop.SG] != 0: + sg = state[Prop.SG] # Add parent: source parents.append(self.source) - action = (Prop.SG.value, sg, 0) + action = (Prop.SG, sg, 0) actions.append(action) # Add parents: states before setting space group - state[Prop.SG.value] = 0 + state[Prop.SG] = 0 for prop in range(len(state)): parent = state.copy() parent[prop] = 0 parents.append(parent) parent_type = self.get_state_type(parent) - action = (Prop.SG.value, sg, parent_type) + action = (Prop.SG, sg, parent_type) actions.append(action) else: # Catch other parents - for prop, idx in enumerate(state[: Prop.SG.value]): + for prop, idx in enumerate(state[: Prop.SG]): if idx != 0: parent = state.copy() parent[prop] = 0 @@ -457,16 +458,14 @@ def get_max_traj_length(self): return len(self.source) + 1 def _set_constrained_properties(self, state: List[int]) -> List[int]: - cls_idx, ps_idx, sg_idx = state - if sg_idx != 0: - if cls_idx == 0: - state[Prop.CLS.value] = self.space_groups[state[Prop.SG.value]][ + cls, ps, sg = state + if sg != 0: + if cls == 0: + state[Prop.CLS] = self.space_groups[state[Prop.SG]][ "crystal_lattice_system_idx" ] - if ps_idx == 0: - state[Prop.PS.value] = self.space_groups[state[Prop.SG.value]][ - "point_symmetry_idx" - ] + if ps == 0: + state[Prop.PS] = self.space_groups[state[Prop.SG]]["point_symmetry_idx"] return state def get_crystal_system(self, state: List[int] = None) -> str: @@ -475,8 +474,8 @@ def get_crystal_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.CLS.value] != 0: - return self.crystal_lattice_systems[state[Prop.CLS.value]]["crystal_system"] + if state[Prop.CLS] != 0: + return self.crystal_lattice_systems[state[Prop.CLS]]["crystal_system"] else: return "None" @@ -490,8 +489,8 @@ def get_lattice_system(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.CLS.value] != 0: - return self.crystal_lattice_systems[state[Prop.CLS.value]]["lattice_system"] + if state[Prop.CLS] != 0: + return self.crystal_lattice_systems[state[Prop.CLS]]["lattice_system"] else: return "None" @@ -522,8 +521,8 @@ def get_point_symmetry(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.PS.value] != 0: - return self.point_symmetries[state[Prop.PS.value]]["point_symmetry"] + if state[Prop.PS] != 0: + return self.point_symmetries[state[Prop.PS]]["point_symmetry"] else: return "None" @@ -537,8 +536,8 @@ def get_space_group_symbol(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.SG.value] != 0: - return self.space_groups[state[Prop.SG.value]]["full_symbol"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["full_symbol"] else: return "None" @@ -554,8 +553,8 @@ def get_crystal_class(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.SG.value] != 0: - return self.space_groups[state[Prop.SG.value]]["crystal_class"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["crystal_class"] else: return "None" @@ -571,8 +570,8 @@ def get_point_group(self, state: List[int] = None) -> str: """ if state is None: state = self.state - if state[Prop.SG.value] != 0: - return self.space_groups[state[Prop.SG.value]]["point_group"] + if state[Prop.SG] != 0: + return self.space_groups[state[Prop.SG]]["point_group"] else: return "None" @@ -616,9 +615,7 @@ def set_n_atoms_compatibility_dict(self, n_atoms: List): n_atoms, self.space_groups.keys() ) - def _is_compatible( - self, cls_idx: Optional[int] = None, ps_idx: Optional[int] = None - ): + def _is_compatible(self, cls: Optional[int] = None, ps: Optional[int] = None): """ Returns True if there is exists at least one space group compatible with the atom composition (according to self.n_atoms_compatibility_dict), with the @@ -630,14 +627,14 @@ def _is_compatible( # Prune the list of space groups to those compatible with the provided crystal- # lattice system - if cls_idx is not None: - space_groups_cls = self.crystal_lattice_systems[cls_idx]["space_groups"] + if cls is not None: + space_groups_cls = self.crystal_lattice_systems[cls]["space_groups"] space_groups = list(set(space_groups).intersection(set(space_groups_cls))) # Prune the list of space groups to those compatible with the provided point # symmetry - if ps_idx is not None: - space_groups_ps = self.point_symmetries[ps_idx]["space_groups"] + if ps is not None: + space_groups_ps = self.point_symmetries[ps]["space_groups"] space_groups = list(set(space_groups).intersection(set(space_groups_ps))) return len(space_groups) > 0 diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index 50e82d61c..41a66b235 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -5,7 +5,7 @@ import torch from pyxtal.symmetry import Group -from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.envs.crystals.spacegroup import Prop, SpaceGroup N_ATOMS = [3, 7, 9] SG_SUBSET = [1, 17, 39, 123, 230] @@ -284,14 +284,14 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( """ all_x = env_with_composition.get_all_terminating_states() for state in all_x: - state[env_with_composition.sg_idx] = 0 + state[Prop.SG] = 0 env_with_composition.set_state(state=state, done=False) mask_f = env_with_composition.get_mask_invalid_actions_forward() state_type = env_with_composition.get_state_type(state) for sg in env_with_composition.space_groups: sg_pyxtal = Group(sg) is_compatible = sg_pyxtal.check_compatible(N_ATOMS)[0] - action = (env_with_composition.sg_idx, sg, state_type) + action = (Prop.SG, sg, state_type) if not is_compatible: assert mask_f[env_with_composition.action_space.index(action)] is True @@ -302,31 +302,31 @@ def test__states_are_compatible_with_pymatgen(env): env.step((2, idx, 0)) sg_int = pmgg.sg_symbol_from_int_number(idx) sg = pmgg.SpaceGroup(sg_int) - assert sg.int_number == env.state[env.sg_idx] + assert sg.int_number == env.state[Prop.SG] assert sg.crystal_system == env.crystal_system assert sg.symbol == env.space_group_symbol assert sg.point_group == env.point_group @pytest.mark.parametrize( - "n_atoms, cls_idx, ps_idx", + "n_atoms, cls, ps", [ [[1], 5, 1], [[17], 5, 1], [[1, 13], 5, 1], ], ) -def test__special_cases_composition_compatibility(n_atoms, cls_idx, ps_idx): +def test__special_cases_composition_compatibility(n_atoms, cls, ps): env = SpaceGroup(n_atoms=n_atoms) # Crystal lattice system space groups must not compatible with composition # constraints - assert env._is_compatible(cls_idx=cls_idx) is False + assert env._is_compatible(cls=cls) is False # Setting crystal lattice system should fail action_cls_5_from_0 = (0, 5, 0) state_new, action, valid = env.step(action_cls_5_from_0) assert valid is False # Point symmetry space groups must be compatible with composition constraints - assert env._is_compatible(ps_idx=ps_idx) is True + assert env._is_compatible(ps=ps) is True # Setting point symmetry should be valid action_ps_1_from_0 = (1, 1, 0) state_new, action, valid = env.step(action_ps_1_from_0) From 9332f1e48f9b60592d754db82a55ae938a99c93c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 25 Sep 2023 17:43:37 -0400 Subject: [PATCH 3/6] Dictionary of properties created and only used in get_action_space --- gflownet/envs/crystals/spacegroup.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index d71d20d59..c2b5d030d 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -102,19 +102,14 @@ def __init__( the space group. 0's are removed from the list. If None, composition/space group constraints are ignored. """ - # Get dictionaries + # Read dictionaries from YAML files self.crystal_lattice_systems = _get_crystal_lattice_systems() self.point_symmetries = _get_point_symmetries() self.space_groups = _get_space_groups() + # Restrict spacce groups to a subset self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) - # Dictionary of all properties - self.properties = { - Prop.CLS: self.crystal_lattice_systems, - Prop.PS: self.point_symmetries, - Prop.SG: self.space_groups, - } # Indices of state types (see self.get_state_type) self.state_type_indices = [0, 1, 2, 3] # End-of-sequence action @@ -138,7 +133,13 @@ def get_action_space(self): state (see self.state_type_indices). """ actions = [] - for prop, indices in self.properties.items(): + # Create dictionary with of all properties + properties = { + Prop.CLS: self.crystal_lattice_systems, + Prop.PS: self.point_symmetries, + Prop.SG: self.space_groups, + } + for prop, indices in properties.items(): for s_from_type in self.state_type_indices: if prop == Prop.CLS and s_from_type in [1, 3]: continue From ece309108754ab724bcc309b483e272d9a61f01f Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 25 Sep 2023 18:03:22 -0400 Subject: [PATCH 4/6] Make state type indices an Enum. --- gflownet/envs/crystals/spacegroup.py | 39 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index c2b5d030d..a4cce0a3e 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -48,9 +48,9 @@ def _get_space_groups(): class Prop: """ Encodes the 3 properties of the SpaceGroup Environment: - - Crystal lattice system - - Point symmetry - - Space group + 0: Crystal lattice system + 1: Point symmetry + 2: Space group """ CLS = 0 @@ -58,6 +58,23 @@ class Prop: SG = 2 +class StateType(Enum): + """ + Enumeration of the 5 types of state: + 0: Source - both crystal-lattice system and point symmetry are unset (== 0) + 1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset + 2: PS - crystal-lattice system is unset; point symmetry is set + 3: CLS_PS - both crystal-lattice system and point symmetry are set + 4: SG: space group is set (trajectory done) + """ + + SOURCE = 0 + CLS = 1 + PS = 2 + CLS_PS = 3 + SG = 4 + + class SpaceGroup(GFlowNetEnv): """ SpaceGroup environment for ionic conductivity. @@ -110,8 +127,6 @@ def __init__( self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) - # Indices of state types (see self.get_state_type) - self.state_type_indices = [0, 1, 2, 3] # End-of-sequence action self.eos = (-1, -1, -1) # Source state: index 0 (empty) for all three properties (crystal-lattice @@ -130,7 +145,7 @@ def get_action_space(self): (property, index, state_from_type), where property is (0: crystal-lattice system, 1: point symmetry, 2: space group), index is the index of the property set by the action and state_from_type is the state type of the originating - state (see self.state_type_indices). + state (see StateType). """ actions = [] # Create dictionary with of all properties @@ -140,12 +155,14 @@ def get_action_space(self): Prop.SG: self.space_groups, } for prop, indices in properties.items(): - for s_from_type in self.state_type_indices: - if prop == Prop.CLS and s_from_type in [1, 3]: + for state_type in StateType: + if state_type == StateType.SG: + continue + if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]: continue - if prop == Prop.PS and s_from_type in [2, 3]: + if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]: continue - actions_prop = [(prop, idx, s_from_type) for idx in indices] + actions_prop = [(prop, idx, state_type.value) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -583,7 +600,7 @@ def point_group(self) -> str: def get_state_type(self, state: List[int] = None) -> int: """ Returns the index of the type of the state passed as an argument. The state - type is one of the following (self.state_type_indices): + type is one of the following (StateType): 0: both crystal-lattice system and point symmetry are unset (== 0) 1: crystal-lattice system is set (!= 0); point symmetry is unset 2: crystal-lattice system is unset; point symmetry is set From f26b700e22ea0647bb32c856a00cef89ceb95b60 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 16 Oct 2023 08:37:50 -0400 Subject: [PATCH 5/6] Remove unnecessary 5th StateType Make missing change --- gflownet/envs/crystals/spacegroup.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index a4cce0a3e..02ced5070 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -60,19 +60,17 @@ class Prop: class StateType(Enum): """ - Enumeration of the 5 types of state: + Enumeration of the 4 types of state from which transitions can originate: 0: Source - both crystal-lattice system and point symmetry are unset (== 0) 1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset 2: PS - crystal-lattice system is unset; point symmetry is set 3: CLS_PS - both crystal-lattice system and point symmetry are set - 4: SG: space group is set (trajectory done) """ SOURCE = 0 CLS = 1 PS = 2 CLS_PS = 3 - SG = 4 class SpaceGroup(GFlowNetEnv): @@ -156,8 +154,6 @@ def get_action_space(self): } for prop, indices in properties.items(): for state_type in StateType: - if state_type == StateType.SG: - continue if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]: continue if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]: From bf432398b97f460cca2cd862444fc4a716bcf979 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 17 Oct 2023 22:00:51 -0400 Subject: [PATCH 6/6] StateType not Enum anymore --- gflownet/envs/crystals/spacegroup.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 02ced5070..52d371cb7 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -56,11 +56,12 @@ class Prop: CLS = 0 PS = 1 SG = 2 + ALL = (CLS, PS, SG) -class StateType(Enum): +class StateType: """ - Enumeration of the 4 types of state from which transitions can originate: + Encodes the 4 types of state from which transitions can originate: 0: Source - both crystal-lattice system and point symmetry are unset (== 0) 1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset 2: PS - crystal-lattice system is unset; point symmetry is set @@ -71,6 +72,13 @@ class StateType(Enum): CLS = 1 PS = 2 CLS_PS = 3 + ALL = (SOURCE, CLS, PS, CLS_PS) + + def get_state_type(state: List[int]) -> int: + """ + Returns the value of the type of the state passed as an argument. + """ + return sum([int(s > 0) * f for s, f in zip(state, (1, 2))]) class SpaceGroup(GFlowNetEnv): @@ -153,12 +161,12 @@ def get_action_space(self): Prop.SG: self.space_groups, } for prop, indices in properties.items(): - for state_type in StateType: + for state_type in StateType.ALL: if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]: continue if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]: continue - actions_prop = [(prop, idx, state_type.value) for idx in indices] + actions_prop = [(prop, idx, state_type) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -595,16 +603,11 @@ def point_group(self) -> str: def get_state_type(self, state: List[int] = None) -> int: """ - Returns the index of the type of the state passed as an argument. The state - type is one of the following (StateType): - 0: both crystal-lattice system and point symmetry are unset (== 0) - 1: crystal-lattice system is set (!= 0); point symmetry is unset - 2: crystal-lattice system is unset; point symmetry is set - 3: both crystal-lattice system and point symmetry are set + Returns the value of the type of the state passed as an argument. """ if state is None: state = self.state - return sum([int(s > 0) * f for s, f in zip(state, (1, 2))]) + return StateType.get_state_type(state) def set_n_atoms_compatibility_dict(self, n_atoms: List): """