diff --git a/brian/synapses/synapses.py b/brian/synapses/synapses.py index 23e19e8b..a08d6026 100644 --- a/brian/synapses/synapses.py +++ b/brian/synapses/synapses.py @@ -985,21 +985,18 @@ def synapse_index(self,i): raise NotImplementedError, "The first two coordinates must be integers" return i - def save_connectivity(self, fn): + def save_connectivity(self, fn, synaptic_vars=[]): ''' Saves the connectivity matrices and delays to a file ``fn``, so that they can be reloaded afterwards. - Notice that this only saves the connectivity, not the current state of the variables in the Synapses class. In fact, it is completely decoupled from the pre/post synaptic groups, and the models of the Synapses object. + Additionally, the states of a list of synaptic variables (identified by their name as string) provided in ``synaptic_vars`` is saved to the file and automatically restored upon a subsequent loading operation. - *Example*: Say we want to save the connectivity of Synapses, and some other state of the network, say ``my_state``. We would simply do:: + *Example*: Say we want to save the connectivity of Synapses along with their weights stored in a synaptic variable ``w``. We would simply do:: - - array_to_save = synapses.my_state[:,:] - synapses.save_connectivity('./somefile') + synapses.save_connectivity('./somefile',['w']) ... new_synapses = Synapses(newgroup0, newgroup0, model = newmodel, pre = newpre, ...) - new_synapses.load_connectivity('./somefile') - new_synapses.my_state[:,:] = array_that_was_saved_and_then_reloaded + new_synapses.load_connectivity('./somefile') # has w assigned already Note: You have to deal with dynamical delays as you would with any other variable. ''' @@ -1017,8 +1014,8 @@ def save_connectivity(self, fn): '_delay_pre' : self._delay_pre, '_delay_post' : self._delay_post } - + savez_args.update({ key: self.__getattr__(key)[:,:] for key in custom_vars if not key in self.__dict__}) np.savez(f, **savez_args) return 1 @@ -1038,6 +1035,11 @@ def load_connectivity(self, fn): self._delay_pre = data['_delay_pre'] self._delay_post = data['_delay_post'] + for key in data.iterkeys(): + if key in ['presynaptic','postsynaptic','_delay_pre', '_delay_post']: + continue + self.__setattr__(key,data[key]) + def __repr__(self): return 'Synapses object with '+ str(len(self))+ ' synapses'