Skip to content

Commit

Permalink
Unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Hjorthmedh committed Oct 10, 2023
1 parent 480fce0 commit 63e054f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 53 deletions.
2 changes: 1 addition & 1 deletion snudda/init/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ def define_striatum(self,
struct_mesh=os.path.join("$SNUDDA_DATA", "mesh", "Striatum-d.obj"),
mesh_bin_width=1e-4,
d_min=d_min,
nn_putative_points=num_neurons*3)
n_putative_points=num_neurons*3)

density_file = os.path.join("$SNUDDA_DATA", "density", "dorsal_striatum_density.json")

Expand Down
5 changes: 3 additions & 2 deletions tests/networks/network_testing_input/network-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"type": "mesh",
"dMin": 1.5e-05,
"meshFile": "networks/network_testing_input/mesh/Striatum-cube-mesh-5e-05.obj",
"meshBinWidth": 5e-05
"meshBinWidth": 5e-05,
"n_putative_points": 30
}
},
"Connectivity": {
Expand Down Expand Up @@ -519,4 +520,4 @@
"structure": "Striatum"
}
}
}
}
2 changes: 1 addition & 1 deletion tests/networks/network_testing_project/network-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,4 @@
"volumeID": "VolumeB"
}
}
}
}
2 changes: 1 addition & 1 deletion tests/test_degeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_something(self):
orig_load = SnuddaLoad(self.network_A)
degen_load = SnuddaLoad(self.network_C)

self.assertEqual(orig_load.data["nSynapses"], 159)
self.assertEqual(orig_load.data["nSynapses"], 156)

# Verify that it should be 99 synapses -- now it is just a regression test
# Old version gave 99, new gives 165 --- CHECK WHY!
Expand Down
68 changes: 20 additions & 48 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,12 @@ def test_input_1(self):
f_gen = len(t_idx) / (n_traces * (et - st))
print(f"ID {neuron_id_str} {neuron_name} {input_type} f={f}, f_gen={f_gen}")

try:
if np.max(input_info["spikes"].attrs["correlation"]) == 0:
self.assertTrue(f_gen > f - 5*np.sqrt(f)/np.sqrt(n_traces))
self.assertTrue(f_gen < f + 5*np.sqrt(f)/np.sqrt(n_traces))
else:
# For high correlations and short durations we have huge fluctuations, so skip those
pass
except:
import pdb
import traceback
print(traceback.format_exc())
pdb.set_trace()
if np.max(input_info["spikes"].attrs["correlation"]) == 0:
self.assertTrue(f_gen > f - 5*np.sqrt(f)/np.sqrt(n_traces))
self.assertTrue(f_gen < f + 5*np.sqrt(f)/np.sqrt(n_traces))
else:
# For high correlations and short durations we have huge fluctuations, so skip those
pass

if "populationUnitCorrelation" in config_data[neuron_type][input_type]:
correlation = config_data[neuron_type][input_type]["populationUnitCorrelation"]
Expand All @@ -227,16 +221,9 @@ def test_input_1(self):
readout = np.zeros((spikes.size, ))
ctr = 0
for t_idx in (spikes.flatten() / bin_size).astype(int):
try:
if t_idx > 0:
readout[ctr] = binned_data[t_idx]
ctr += 1
except:
import traceback
t_str = traceback.format_exc()
print(t_str)
import pdb
pdb.set_trace()
if t_idx > 0:
readout[ctr] = binned_data[t_idx]
ctr += 1

readout = readout[:ctr]

Expand Down Expand Up @@ -272,19 +259,9 @@ def test_input_1(self):

print(f"Simultaneous spikes: {np.mean(readout):.2f} (expected {expected_mean:.2f}) "
f"- correlation {correlation}")
try:
if jitter <= 0.001:
# Only do check for non-jittered input
self.assertTrue(expected_mean * 0.75 < np.mean(readout) < expected_mean * 1.25)

except:
import traceback

t_str = traceback.format_exc()
print(t_str)
import pdb

pdb.set_trace()
if jitter <= 0.001:
# Only do check for non-jittered input
self.assertTrue(expected_mean * 0.75 < np.mean(readout) < expected_mean * 1.25)

def test_input_2(self):

Expand All @@ -305,23 +282,17 @@ def test_input_2(self):

# OBS, population unit 0 does not get any of the extra mother spikes specified
# So we need to check FS neuron that belongs to population unit 1 or 2.
some_spikes = input_data["input/1/Cortical/spikes"][()].flatten()
some_spikes = input_data["input/3/Cortical/spikes"][()].flatten()
some_spikes = some_spikes[some_spikes >= 0]
n_trains = input_data["input/1/Cortical/spikes"][()].shape[0]
n_trains = input_data["input/3/Cortical/spikes"][()].shape[0]

for extra_spike in [0.2, 0.3, 0.45]:

try:
self.assertTrue(np.sum(np.abs(some_spikes - extra_spike) < 1e-4)
>= n_trains)
self.assertTrue(np.sum(np.abs(some_spikes - extra_spike + 0.05) < 1e-3) < 50)
except:
import traceback
print(traceback.format_exc())
import pdb
pdb.set_trace()

some_spikes2 = input_data["input/1/Thalamic/spikes"][()].flatten()
self.assertTrue(np.sum(np.abs(some_spikes - extra_spike) < 1e-4)
>= n_trains)
self.assertTrue(np.sum(np.abs(some_spikes - extra_spike + 0.05) < 1e-3) < 50)

some_spikes2 = input_data["input/3/Thalamic/spikes"][()].flatten()
some_spikes2 = some_spikes2[some_spikes2 >= 0]

for spike in [0.1, 0.2, 0.3]:
Expand All @@ -332,6 +303,7 @@ def test_input_2(self):
# Check input generated, this focuses on the frequency function generation
# and also checks input correlation

# TODO: New cell numbering, so need to pick other cell numbers
some_spikes_c3 = input_data["input/3/CorticalSignal/spikes"][()]
some_spikes_c8 = input_data["input/8/CorticalSignal/spikes"][()]

Expand Down

0 comments on commit 63e054f

Please sign in to comment.