Skip to content

Commit

Permalink
TTST test passing
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Nov 22, 2023
1 parent 300b015 commit a017a6c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
9 changes: 5 additions & 4 deletions dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,11 +984,12 @@ def _run_main_layer_forward(self, concat_s_t, masks,
return outputs, (sa_weights,)

def merge_batches_weights(self, weights, new_weights, device):
# weights is a single attention tensor (encoder)
new_weights = [a.to(device) for a in new_weights]
# weights is a single attention tensor (encoder): a tuple of 1.
new_weights = [a.to(device) for a in new_weights[0]]

if weights is None:
return new_weights
return (new_weights,)
else:
weights.extend(new_weights)
return weights
return (weights,)

8 changes: 5 additions & 3 deletions scripts_python/tests/test_all_steps_ttst.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_execution(script_runner, experiments_path):
subj_id = TEST_EXPECTED_SUBJ_NAMES[0]

# Test visu loss
prefix = os.path.join(experiments_path, 'test_visu')
prefix = os.path.join(experiments_path, 'test_visu_loss')
ret = script_runner.run('tt_visualize_loss.py', whole_experiment_path,
hdf5_file, subj_id, input_group_name,
streamline_group_name, prefix,
Expand All @@ -76,10 +76,12 @@ def test_execution(script_runner, experiments_path):
assert ret.success

# Test visu weights
in_sft = os.path.join(data_dir, 'dwi_ml_ready/subjX/example_bundle/Fornix.trk')
in_sft = os.path.join(data_dir,
'dwi_ml_ready/subjX/example_bundle/Fornix.trk')
ret = script_runner.run(
'tt_visualize_weights.py', whole_experiment_path, hdf5_file, subj_id,
input_group, in_sft, '--visu_type', 'as_matrix', 'colored_sft',
'bertviz_locally', '--subset', 'training', '--logging', 'INFO',
'--resample_attention', '25')
'--resample_attention', '25', '--rescale')
assert ret.success

0 comments on commit a017a6c

Please sign in to comment.