-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import os | ||
import io | ||
import contextlib | ||
import pytest | ||
import pandas as pd | ||
import numpy as np | ||
import resistnet | ||
from resistnet.hall_of_fame import HallOfFame | ||
|
||
|
||
def create_fake_population(num_models, num_variables): | ||
"""Creates a fake population for testing.""" | ||
population = [] | ||
for _ in range(num_models): | ||
model = {'fitness': np.random.rand()} | ||
for var in range(num_variables): | ||
model[f'var{var+1}'] = np.random.rand() | ||
model[f'var{var+1}_weight'] = 1.0 | ||
model[f'var{var+1}_trans'] = 0 | ||
model[f'var{var+1}_shape'] = 1.0 | ||
model.update( | ||
{'loglik': np.random.rand(), 'r2m': np.random.rand(), | ||
'aic': -1 * np.random.rand(), 'delta_aic_null': np.random.rand()} | ||
) | ||
population.append(model) | ||
return population | ||
|
||
|
||
@pytest.fixture | ||
def hall_of_fame_df_fixture(): | ||
base_path = os.path.dirname(resistnet.__file__) | ||
file_path = os.path.join( | ||
base_path, 'data', 'test_ensemble', 'replicates', | ||
'test_18.out.HallOfFame.tsv' | ||
) | ||
|
||
# Read the TSV file into a DataFrame | ||
df = pd.read_csv(file_path, sep='\t', header=0) | ||
return df | ||
|
||
|
||
@pytest.fixture | ||
def hall_of_fame_fixture(): | ||
variables = ['var1', 'var2'] | ||
max_size = 10 | ||
num_models = 5 | ||
|
||
fake_population = create_fake_population(num_models, len(variables)) | ||
|
||
hall_of_fame = HallOfFame(variables, max_size, init_pop=fake_population) | ||
return hall_of_fame | ||
|
||
|
||
def test_hall_of_fame_initialization(): | ||
variables = ['var1', 'var2'] | ||
max_size = 10 | ||
init_pop = None | ||
|
||
# Initialize the HallOfFame | ||
hall_of_fame = HallOfFame(variables, max_size, init_pop) | ||
|
||
# Check the DataFrame structure | ||
expected_columns = ["fitness"] | ||
for v in variables: | ||
expected_columns.extend( | ||
[str(v), f"{v}_weight", f"{v}_trans", f"{v}_shape"] | ||
) | ||
expected_columns.extend(["loglik", "r2m", "aic", "delta_aic_null"]) | ||
|
||
assert list(hall_of_fame.data.columns) == expected_columns, \ | ||
"Data columns do not match expected columns" | ||
|
||
# Check other attributes | ||
assert hall_of_fame.variables == variables, \ | ||
"Variables attribute not set correctly" | ||
assert hall_of_fame.max_size == max_size, \ | ||
"Max size attribute not set correctly" | ||
assert hall_of_fame.min_fitness == float("-inf"), \ | ||
"Min fitness attribute not set correctly" | ||
assert hall_of_fame.rvi is None, \ | ||
"RVI attribute should be None initially" | ||
assert hall_of_fame.maw is None, \ | ||
"MAW attribute should be None initially" | ||
assert hall_of_fame.best is None, \ | ||
"Best attribute should be None initially" | ||
assert hall_of_fame.zero_threshold == 1e-17, \ | ||
"Zero threshold attribute not set correctly" | ||
|
||
|
||
def test_hall_of_fame_from_dataframe(hall_of_fame_df_fixture): | ||
max_size = 200 # Example max_size, adjust as needed | ||
|
||
# Create a HallOfFame instance from the DataFrame | ||
hall_of_fame_instance = HallOfFame.from_dataframe(hall_of_fame_df_fixture, | ||
max_size) | ||
|
||
# Asserts to verify the HallOfFame instance | ||
assert len(hall_of_fame_instance.data) == len(hall_of_fame_df_fixture), \ | ||
"Data length in HallOfFame instance does not match input DataFrame" | ||
|
||
|
||
def test_check_population(): | ||
variables = ['var1', 'var2'] | ||
max_size = 10 | ||
num_models = 5 | ||
|
||
# Create a fake population | ||
fake_population = create_fake_population(num_models, len(variables)) | ||
|
||
# Initialize HallOfFame | ||
hall_of_fame = HallOfFame(variables, max_size) | ||
|
||
# Check the population | ||
hall_of_fame.check_population(fake_population) | ||
|
||
# Assertions to verify the HallOfFame data is updated correctly | ||
assert len(hall_of_fame.data) <= max_size, \ | ||
"Hall of fame size exceeds maximum size" | ||
assert all( | ||
col in hall_of_fame.data.columns for col in ['fitness', 'var1', 'var2'] | ||
), "Hall of fame data does not contain expected columns" | ||
|
||
|
||
def test_print_hof(hall_of_fame_fixture): | ||
# Redirect the stdout to a string buffer | ||
buffer = io.StringIO() | ||
with contextlib.redirect_stdout(buffer): | ||
hall_of_fame_fixture.printHOF() | ||
|
||
# Get the content from the buffer | ||
output = buffer.getvalue() | ||
|
||
# Asserts to verify the output | ||
assert "fitness" in output, "Output should contain 'fitness'" | ||
assert "var1" in output, "Output should contain 'var1'" | ||
assert "var2" in output, "Output should contain 'var2'" | ||
|
||
|
||
def test_calculate_bic(hall_of_fame_fixture): | ||
n = 10 | ||
|
||
hall_of_fame_fixture.calculate_bic(n) | ||
|
||
assert "bic" in hall_of_fame_fixture.data.columns, \ | ||
"BIC column not added to data" | ||
|
||
|
||
def test_delta_aic(hall_of_fame_fixture): | ||
hall_of_fame_fixture.delta_aic() | ||
|
||
assert "delta_aic_best" in hall_of_fame_fixture.data.columns, \ | ||
"delta_aic_best column not added to data" | ||
|
||
|
||
def test_akaike_weights(hall_of_fame_fixture): | ||
hall_of_fame_fixture.akaike_weights() | ||
|
||
assert "akaike_weight" in hall_of_fame_fixture.data.columns, \ | ||
"akaike_weight column not added to data" | ||
|
||
|
||
def test_cumulative_akaike(hall_of_fame_fixture): | ||
threshold = 0.8 | ||
|
||
hall_of_fame_fixture.cumulative_akaike(threshold) | ||
|
||
assert "acc_akaike_weight" in hall_of_fame_fixture.data.columns, \ | ||
"acc_akaike_weight column not added to data" | ||
assert "keep" in hall_of_fame_fixture.data.columns, \ | ||
"keep column not added to data" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import os | ||
import pytest | ||
import tempfile | ||
import resistnet | ||
from resistnet.resistance_network import ResistanceNetwork | ||
from resistnet.model_optimisation import ModelRunner | ||
|
||
|
||
@pytest.fixture | ||
def ga_parameters_fixture(): | ||
return { | ||
'mutpb': 0.1, | ||
'cxpb': 0.8, | ||
'indpb': 0.05, | ||
'popsize': 10, | ||
'maxpopsize': 10, | ||
'posWeight': True, | ||
'fixWeight': False, | ||
'fixShape': True, | ||
'allShapes': False, | ||
'min_weight': 0.5, | ||
'max_shape': 2.0, | ||
'max_hof_size': 100, | ||
'tournsize': 3, | ||
'fitmetric': 'aic', | ||
'awsum': 0.75, | ||
'only_keep': True, | ||
'verbose': False, | ||
'report_all': False | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def network_graph_path(): | ||
base_path = os.path.dirname(resistnet.__file__) | ||
file_path = os.path.join(base_path, 'data', 'test.network') | ||
return file_path | ||
|
||
|
||
@pytest.fixture | ||
def shapefile_path(): | ||
base_path = os.path.dirname(resistnet.__file__) | ||
file_path = os.path.join(base_path, 'data', 'test.shp') | ||
return file_path | ||
|
||
|
||
@pytest.fixture | ||
def coords_path(): | ||
base_path = os.path.dirname(resistnet.__file__) | ||
file_path = os.path.join(base_path, 'data', 'test.pointCoords.txt') | ||
return file_path | ||
|
||
|
||
@pytest.fixture | ||
def inmat_path(): | ||
base_path = os.path.dirname(resistnet.__file__) | ||
file_path = os.path.join(base_path, 'data', 'test.popGenDistMat.txt') | ||
return file_path | ||
|
||
|
||
@pytest.fixture | ||
def resistance_network_fixture( | ||
network_graph_path, shapefile_path, coords_path, inmat_path): | ||
# Use temporary directory for output | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
output_prefix = os.path.join(temp_dir, "output") | ||
|
||
network = ResistanceNetwork( | ||
network=network_graph_path, | ||
shapefile=shapefile_path, | ||
coords=coords_path, | ||
variables=["run_mm_cyr"], | ||
inmat=inmat_path, | ||
agg_opts={"run_mm_cyr": "ARITH"}, | ||
pop_agg="ARITH", | ||
reachid_col="EDGE_ID", | ||
length_col="LENGTH_KM", | ||
out=output_prefix, | ||
verbose=False | ||
) | ||
|
||
yield network | ||
|
||
|
||
@pytest.fixture | ||
def model_runner_fixture(resistance_network_fixture): | ||
seed = 1234 | ||
verbose = True | ||
|
||
model_runner = ModelRunner( | ||
resistance_network=resistance_network_fixture, | ||
seed=seed, verbose=verbose) | ||
|
||
yield model_runner | ||
|
||
|
||
def test_model_runner_initialization(resistance_network_fixture): | ||
seed = 1234 | ||
verbose = True | ||
|
||
model_runner = ModelRunner( | ||
resistance_network=resistance_network_fixture, | ||
seed=seed, verbose=verbose) | ||
|
||
# Check if the passed parameters are correctly assigned | ||
assert model_runner.seed == seed | ||
assert model_runner.resistance_network == resistance_network_fixture | ||
assert model_runner.verbose == verbose | ||
|
||
# Check if the default values are correctly initialized | ||
assert isinstance(model_runner.workers, list) | ||
assert model_runner.bests is None | ||
assert model_runner.toolbox is None | ||
assert isinstance(model_runner.logger, list) | ||
|
||
# Check the default values of GA parameters | ||
assert model_runner.cxpb is None | ||
assert model_runner.mutpb is None | ||
assert model_runner.indpb is None | ||
assert model_runner.popsize is None | ||
assert model_runner.maxpopsize is None | ||
assert model_runner.posWeight is None | ||
assert model_runner.fixWeight is None | ||
assert model_runner.fixShape is None | ||
assert model_runner.allShapes is None | ||
assert model_runner.min_weight is None | ||
assert model_runner.max_shape is None | ||
assert model_runner.max_hof_size is None | ||
assert model_runner.tournsize is None | ||
assert model_runner.fitmetric is None | ||
assert model_runner.awsum is None | ||
assert model_runner.only_keep is None | ||
assert model_runner.report_all is None | ||
|
||
|
||
def test_set_ga_parameters( | ||
model_runner_fixture, ga_parameters_fixture): | ||
# Call set_ga_parameters on the fixture instance with the test parameters | ||
model_runner_fixture.set_ga_parameters( | ||
**ga_parameters_fixture | ||
) | ||
|
||
# Assert that each parameter is set correctly | ||
for param, value in ga_parameters_fixture.items(): | ||
assert getattr(model_runner_fixture, param) == \ | ||
value, f"Parameter {param} not set correctly" | ||
|
||
|
||
def test_run_ga(model_runner_fixture, ga_parameters_fixture): | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
out_prefix = os.path.join(temp_dir, "ga_output") | ||
ga_parameters_fixture["verbose"] = False | ||
ga_parameters_fixture["out"] = out_prefix | ||
ga_parameters_fixture["threads"] = 1 | ||
ga_parameters_fixture["maxgens"] = 1 | ||
|
||
# Run the genetic algorithm with the modified parameters | ||
try: | ||
model_runner_fixture.run_ga(**ga_parameters_fixture) | ||
error_occurred = False | ||
except Exception as e: | ||
error_occurred = True | ||
print(f"Error during GA run: {e}") | ||
|
||
# Assert that no errors occurred during the run | ||
assert not error_occurred, "An error occurred during the GA run" | ||
|
||
# Assert that the best models are identified (bests is not None) | ||
assert model_runner_fixture.bests is not None, \ | ||
"ModelRunner.bests missing" | ||
|
||
# Assert that the expected output files are generated | ||
expected_files = [ | ||
f"{out_prefix}.varImportance.tsv", | ||
f"{out_prefix}.HallOfFame.tsv", | ||
f"{out_prefix}.FitnessLog.tsv", | ||
f"{out_prefix}.Model-Average.streamsByResistance.pdf" | ||
] | ||
|
||
for file in expected_files: | ||
assert os.path.isfile(file), f"Expected output file not found: \ | ||
{file}" | ||
|