From a017a6c92190b5968ba83a90c7c9da3ada15c21c Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 21 Nov 2023 16:06:18 -0500 Subject: [PATCH] TTST test passing --- dwi_ml/models/projects/transformer_models.py | 9 +++++---- scripts_python/tests/test_all_steps_ttst.py | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 07aee8c9..9486cd39 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -984,11 +984,12 @@ def _run_main_layer_forward(self, concat_s_t, masks, return outputs, (sa_weights,) def merge_batches_weights(self, weights, new_weights, device): - # weights is a single attention tensor (encoder) - new_weights = [a.to(device) for a in new_weights] + # weights is a single attention tensor (encoder): a tuple of 1. + new_weights = [a.to(device) for a in new_weights[0]] if weights is None: - return new_weights + return (new_weights,) else: weights.extend(new_weights) - return weights + return (weights,) + diff --git a/scripts_python/tests/test_all_steps_ttst.py b/scripts_python/tests/test_all_steps_ttst.py index d3d4a1ce..d116ac60 100644 --- a/scripts_python/tests/test_all_steps_ttst.py +++ b/scripts_python/tests/test_all_steps_ttst.py @@ -62,7 +62,7 @@ def test_execution(script_runner, experiments_path): subj_id = TEST_EXPECTED_SUBJ_NAMES[0] # Test visu loss - prefix = os.path.join(experiments_path, 'test_visu') + prefix = os.path.join(experiments_path, 'test_visu_loss') ret = script_runner.run('tt_visualize_loss.py', whole_experiment_path, hdf5_file, subj_id, input_group_name, streamline_group_name, prefix, @@ -76,10 +76,12 @@ def test_execution(script_runner, experiments_path): assert ret.success # Test visu weights - in_sft = os.path.join(data_dir, 'dwi_ml_ready/subjX/example_bundle/Fornix.trk') + in_sft = os.path.join(data_dir, + 'dwi_ml_ready/subjX/example_bundle/Fornix.trk') ret = script_runner.run( 'tt_visualize_weights.py', whole_experiment_path, hdf5_file, subj_id, input_group, in_sft, '--visu_type', 'as_matrix', 'colored_sft', 'bertviz_locally', '--subset', 'training', '--logging', 'INFO', - '--resample_attention', '25') + '--resample_attention', '25', '--rescale') assert ret.success +