Skip to content

Commit

Permalink
Merge pull request #267 from alexhernandezgarcia/jdv/test_fix
Browse files Browse the repository at this point in the history
Jdv/test fix
  • Loading branch information
josephdviviano authored Jan 31, 2024
2 parents a30a1c8 + a9d722b commit 2bf5c91
Show file tree
Hide file tree
Showing 15 changed files with 773 additions and 568 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ pyxtal = ">=0.6.0"
rdkit = "*"
scikit-learn = ">=1.3.1"
scipy = ">=1.11.2"
six = ">=1.16.0"
#torch = ">=2.0.1"
six = "*"
torch = "==2.0.1"
torch-geometric = ">=2.3.1"
torch-scatter = ">=2.1.1"
torchani = "*"
Expand Down
962 changes: 495 additions & 467 deletions tests/gflownet/envs/common.py

Large diffs are not rendered by default.

50 changes: 33 additions & 17 deletions tests/gflownet/envs/test_ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,20 +1587,36 @@ def test__get_logprobs_backward__returns_valid_actions(env, states, actions):
assert torch.all(torch.isfinite(logprobs))


def test__continuous_env_common(env):
print(
"\n\nCommon tests for crystal without composition <-> space group constraints\n"
)
return common.test__continuous_env_common(env)


def test__continuous_env_with_stoichiometry_sg_check_common(
env_with_stoichiometry_sg_check,
):
print("\n\nCommon tests for crystal with composition <-> space group constraints\n")
return common.test__continuous_env_common(env_with_stoichiometry_sg_check)


def test__continuous_env_common(env_sg_first):
print("\n\nCommon tests for crystal with space group first\n")
return common.test__continuous_env_common(env_sg_first)
class TestContinuousCrystalBasic(common.BaseTestsContinuous):
"""Common tests for crystal without composition <-> space group constraints."""

@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
}


class TestContinuousCrystalSGCheck(common.BaseTestsContinuous):
"""Common tests for crystal with composition <-> space group constraints."""

@pytest.fixture(autouse=True)
def setup(self, env_with_stoichiometry_sg_check):
self.env = env_with_stoichiometry_sg_check
self.repeats = {
"test__set_state__creates_new_copy_of_state": 10, # Overrides no repeat.
"test__reset__state_is_source": 0,
}


class TestContinuousCrystalSGFirst(common.BaseTestsContinuous):
"""Common tests for crystal with space group first."""

@pytest.fixture(autouse=True)
def setup(self, env_sg_first):
self.env = env_sg_first
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
24 changes: 18 additions & 6 deletions tests/gflownet/envs/test_ccube.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,9 +1135,21 @@ def test__get_mask_invalid_actions_forward__returns_expected(env, state, expecte
)


def test__continuous_env_common__cube1d(cube1d):
return common.test__continuous_env_common(cube1d)


def test__continuous_env_common__cube2d(cube2d):
return common.test__continuous_env_common(cube2d)
class TestContinuousCubeBasic(common.BaseTestsContinuous):
@pytest.fixture(autouse=True)
def setup(self, cube1d):
self.env = cube1d
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}


class TestContinuousCubeBasic(common.BaseTestsContinuous):
@pytest.fixture(autouse=True)
def setup(self, cube2d):
self.env = cube2d
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
10 changes: 8 additions & 2 deletions tests/gflownet/envs/test_clattice_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,11 @@ def test__readable2state__gives_expected_results_for_initial_states(
"lattice_system",
[CUBIC, HEXAGONAL, MONOCLINIC, ORTHORHOMBIC, RHOMBOHEDRAL, TETRAGONAL, TRICLINIC],
)
def test__continuous_env_common(env, lattice_system):
return common.test__continuous_env_common(env)
class TestContinuousLatticeBasic(common.BaseTestsContinuous):
@pytest.fixture(autouse=True)
def setup(self, env, lattice_system):
self.env = env # lattice_system intializes env fixture.
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
24 changes: 18 additions & 6 deletions tests/gflownet/envs/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,21 @@ def test__insufficient_elements_left_does_not_cause_environment_to_get_stuck():
env.step(action)


def test__all_env_common(env):
return common.test__all_env_common(env)


def test__all_env_common__with_spacegroup_constraints(env_with_spacegroup):
return common.test__all_env_common(env_with_spacegroup)
class TestCompositionBasic(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}


class TestCompositionWithSpaceGroup(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, env_with_spacegroup):
self.env = env_with_spacegroup
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
24 changes: 18 additions & 6 deletions tests/gflownet/envs/test_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,21 @@ def test__get_mask_invalid_actions_forward__masks_all_actions_from_different_sta
)


def test__all_env_common(env):
return common.test__all_env_common(env)


def test__all_env_common(env_with_stoichiometry_sg_check):
return common.test__all_env_common(env_with_stoichiometry_sg_check)
class TestCrystalBasic(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}


class TestCrystalStoichiometrySGCheck(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, env_with_stoichiometry_sg_check):
self.env = env_with_stoichiometry_sg_check
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
10 changes: 8 additions & 2 deletions tests/gflownet/envs/test_ctorus.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,11 @@ def test__sample_actions_batch__not_special_cases(
assert action_sampled != action_special


def test__continuous_env_common(env):
return common.test__continuous_env_common(env)
class TestContinuousTorusBasic(common.BaseTestsContinuous):
@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__get_logprobs__backward__returns_zero_if_done": 100, # Overrides no repeat.
"test__reset__state_is_source": 10,
}
56 changes: 45 additions & 11 deletions tests/gflownet/envs/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ def env():
return Grid(n_dim=3, length=5, cell_min=-1.0, cell_max=1.0)


@pytest.fixture
def env_default():
return Grid()


@pytest.fixture
def env_extended_action_space_2d():
return Grid(
Expand All @@ -34,11 +39,6 @@ def env_extended_action_space_3d():
)


@pytest.fixture
def env_default():
return Grid()


@pytest.fixture
def config_path():
return "../../../config/env/grid.yaml"
Expand Down Expand Up @@ -94,11 +94,45 @@ def test__get_action_space__returns_expected(
assert set(action_space) == set(env_extended_action_space_2d.action_space)


def test__all_env_common__standard(env_extended_action_space_3d):
print("\n\nCommon tests for 5x5 Grid with extended action space\n")
return common.test__all_env_common(env_extended_action_space_3d)
class TestGridBasic(common.BaseTestsContinuous):
"""Common tests for 5x5 Grid with standard action space."""

@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
}


class TestGridDefaults(common.BaseTestsContinuous):
"""Common tests for 5x5 Grid with standard action space."""

@pytest.fixture(autouse=True)
def setup(self, env_default):
self.env = env_default
self.repeats = {
"test__reset__state_is_source": 10,
}


class TestGridExtended2D(common.BaseTestsContinuous):
"""Common tests for 5x5 Grid with extended action space."""

@pytest.fixture(autouse=True)
def setup(self, env_extended_action_space_2d):
self.env = env_extended_action_space_2d
self.repeats = {
"test__reset__state_is_source": 10,
}


class TestGridExtended3D(common.BaseTestsContinuous):
"""Common tests for 5x5 Grid with extended action space."""

def test__all_env_common__extended(env):
print("\n\nCommon tests for 5x5 Grid with standard action space\n")
return common.test__all_env_common(env)
@pytest.fixture(autouse=True)
def setup(self, env_extended_action_space_3d):
self.env = env_extended_action_space_3d
self.repeats = {
"test__reset__state_is_source": 10,
}
12 changes: 11 additions & 1 deletion tests/gflownet/envs/test_htorus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,18 @@ def env():
(0, 0),
(1, 0),
(2, 0),
],
]
],
)
def test__get_action_space__returns_expected(env, action_space):
assert set(action_space) == set(env.action_space)


@pytest.mark.skip(reason="skip while the environment remains outdated")
class TestHybridTorus(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
}
24 changes: 12 additions & 12 deletions tests/gflownet/envs/test_lattice_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def env(lattice_system):
return LatticeParameters(lattice_system=lattice_system, grid_size=61)


@pytest.fixture()
def triclinic_env():
return LatticeParameters(lattice_system=TRICLINIC, grid_size=5)


@pytest.mark.parametrize("lattice_system", LATTICE_SYSTEMS)
def test__environment__initializes_properly(env, lattice_system):
pass
Expand Down Expand Up @@ -159,17 +164,14 @@ def test__get_mask_invalid_actions_forward__returns_expected_mask(
@pytest.mark.parametrize("lattice_system", LATTICE_SYSTEMS)
def test__get_parents__returns_no_parents_in_initial_state(env, lattice_system):
parents, actions = env.get_parents()

assert len(parents) == 0
assert len(actions) == 0


@pytest.mark.parametrize("lattice_system", LATTICE_SYSTEMS)
def test__get_parents__returns_parents_after_step(env, lattice_system):
env.step((1, 1, 1, 0, 0, 0))

parents, actions = env.get_parents()

assert len(parents) != 0
assert len(actions) != 0

Expand All @@ -189,9 +191,7 @@ def test__get_parents__returns_same_number_of_parents_and_actions(
):
for action in actions:
env.step(action=action)

parents, actions = env.get_parents()

assert len(parents) == len(actions)


Expand Down Expand Up @@ -289,11 +289,8 @@ def test__state2oracle__returns_expected_tensor(env, lattice_system, state, exp_
@pytest.mark.parametrize("lattice_system", [TRICLINIC])
def test__reset(env, lattice_system):
env.step((1, 1, 1, 0, 0, 0))

assert env.state != env.source

env.reset()

assert env.state == env.source


Expand Down Expand Up @@ -328,7 +325,10 @@ def test__readable2state__returns_initial_state_for_rhombohedral_and_triclinic(
assert env.readable2state(readable) == [0, 0, 0, 0, 0, 0]


def test__all_env_common():
env = LatticeParameters(lattice_system=TRICLINIC, grid_size=5)

return common.test__all_env_common(env)
class TestLattice(common.BaseTestsDiscrete):
@pytest.fixture(autouse=True)
def setup(self, triclinic_env):
self.env = triclinic_env
self.repeats = {
"test__reset__state_is_source": 10,
}
38 changes: 27 additions & 11 deletions tests/gflownet/envs/test_spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,18 +385,34 @@ def test__special_cases_composition_compatibility(n_atoms, cls_idx, ps_idx):
assert valid is False


def test__all_common__env(env):
print("\n\nCommon tests for SpaceGroup without composition restrictions\n")
return common.test__all_env_common(env)
class TestSpaceGroupBasic(common.BaseTestsDiscrete):
"""Common tests for SpaceGroup without composition restrictions."""

@pytest.fixture(autouse=True)
def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
}


class TestSpaceGroupWithComposition(common.BaseTestsDiscrete):
"""Common tests for SpaceGroup with restrictions from composition."""

@pytest.fixture(autouse=True)
def setup(self, env_with_composition):
self.env = env_with_composition
self.repeats = {
"test__reset__state_is_source": 10,
}

def test__all_common__env_with_composition(env_with_composition):
print(
f"\n\nCommon tests for SpaceGroup with restrictions from composition {N_ATOMS}\n"
)
return common.test__all_env_common(env_with_composition)

class TestSpaceGroupWithRestrictedSpaceGroups(common.BaseTestsDiscrete):
"""Common tests for SpaceGroup with restricted space groups."""

def test__all_common__env_with_restricted_spacegroups(env_with_restricted_spacegroups):
print(f"\n\nCommon tests for SpaceGroup with restricted space groups {SG_SUBSET}")
return common.test__all_env_common(env_with_restricted_spacegroups)
@pytest.fixture(autouse=True)
def setup(self, env_with_restricted_spacegroups):
self.env = env_with_restricted_spacegroups
self.repeats = {
"test__reset__state_is_source": 10,
}
Loading

0 comments on commit 2bf5c91

Please sign in to comment.