Commit: Add comments to STDPLearner

metr0jw committed Aug 7, 2024
1 parent 68a5c86 commit 29fabfb

Showing 2 changed files with 63 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/snntorch.functional.rst
@@ -57,6 +57,14 @@ State Quantization
:members:
:undoc-members:
:show-inheritance:

STDP Learner
^^^^^^^^^^^^^^^^^^^^^^^^

.. automodule:: snntorch.functional.stdp_learner
:members:
:undoc-members:
:show-inheritance:

Probe
^^^^^^^^^^^^^^^^^^^^^^^^
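For orientation, the functions this commit documents live at the module path named in the automodule directive above; a minimal import sketch (the function names are taken from the Python diff below, the import path from the directive):

    from snntorch.functional.stdp_learner import (
        stdp_linear_single_step,
        mstdp_linear_single_step,
        mstdpet_linear_single_step,
        stdp_conv1d_single_step,
        stdp_conv2d_single_step,
    )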
55 changes: 55 additions & 0 deletions snntorch/functional/stdp_learner.py
@@ -20,6 +20,11 @@ def stdp_linear_single_step(
f_pre: Callable = lambda x: x,
f_post: Callable = lambda x: x,
):
"""
Single step of STDP learning rule for Linear layer.
"""

if trace_pre is None:
trace_pre = 0.0

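The function body is elided by the diff; for readers new to the rule, here is a self-contained sketch of the standard pair-based trace formulation that functions like this implement. The decay form and the outer-product pairing are assumptions about the elided code, not taken from it:

    import torch

    def pair_based_stdp_sketch(in_spike, out_spike, trace_pre, trace_post,
                               tau_pre, tau_post):
        # Decay each trace toward zero, then add the incoming spikes.
        trace_pre = trace_pre - trace_pre / tau_pre + in_spike        # [batch, n_in]
        trace_post = trace_post - trace_post / tau_post + out_spike   # [batch, n_out]
        # Potentiation pairs the pre trace with postsynaptic spikes;
        # depression pairs the post trace with presynaptic spikes.
        delta_w = (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)
                   - trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
        return trace_pre, trace_post, delta_w  # delta_w matches fc.weight's shape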
@@ -55,6 +60,11 @@ def mstdp_linear_single_step(
f_pre: Callable = lambda x: x,
f_post: Callable = lambda x: x,
):
"""
Single step of mSTDP learning rule for Linear layer.
"""

if trace_pre is None:
trace_pre = 0.0

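The body is again elided; mSTDP is usually read as modulated STDP, where a reward signal gates the instantaneous STDP term before it reaches the weights. A toy sketch of that gating (the reward-times-STDP form follows the common formulation, e.g. Florian 2007, and is an assumption about this file):

    import torch

    delta_w_stdp = torch.randn(2, 4)  # hypothetical raw STDP term per synapse
    reward = 0.5                      # hypothetical neuromodulatory signal
    delta_w = reward * delta_w_stdp   # mSTDP: feedback gates the weight change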
@@ -88,6 +98,11 @@ def mstdpet_linear_single_step(
f_pre: Callable = lambda x: x,
f_post: Callable = lambda x: x,
):
"""
Single step of mSTDP learning rule with Eligibility Trace for Linear layer.
"""

if trace_pre is None:
trace_pre = 0.0

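mSTDPET differs from mSTDP by low-pass filtering the STDP term into an eligibility trace before the reward is applied, so delayed rewards can still credit earlier spike pairings. A sketch of that recursion (tau_e and the linear decay are assumptions, following the usual eligibility-trace formulation):

    import torch

    def mstdpet_step_sketch(elig, stdp_term, reward, tau_e):
        # The eligibility trace decays and accumulates the instantaneous
        # STDP term; the reward converts eligibility into a weight change.
        elig = elig * (1.0 - 1.0 / tau_e) + stdp_term
        delta_w = reward * elig
        return elig, delta_w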
@@ -115,6 +130,11 @@ def stdp_conv2d_single_step(
f_pre: Callable = lambda x: x,
f_post: Callable = lambda x: x,
):
"""
Single step of STDP learning rule for Conv2d layer.
"""

if conv.dilation != (1, 1):
raise NotImplementedError(
"STDP with dilation != 1 for Conv2d has not been implemented!"
@@ -198,6 +218,11 @@ def stdp_conv1d_single_step(
f_pre: Callable = lambda x: x,
f_post: Callable = lambda x: x,
):
"""
Single step of STDP learning rule for Conv1d layer.
"""

if conv.dilation != (1,):
raise NotImplementedError(
"STDP with dilation != 1 for Conv1d has not been implemented!"
@@ -285,19 +310,49 @@ def __init__(
self.trace_post = None

def reset(self):
"""
Reset the recorded data in the monitors.
"""

super(STDPLearner, self).reset()
self.in_spike_monitor.clear_recorded_data()
self.out_spike_monitor.clear_recorded_data()

def disable(self):
"""
Disable the recording of the data in the monitors.
"""

self.in_spike_monitor.disable()
self.out_spike_monitor.disable()

def enable(self):
"""
Enable the recording of the data in the monitors.
"""

self.in_spike_monitor.enable()
self.out_spike_monitor.enable()

def step(self, on_grad: bool = True, scale: float = 1.0):
"""
Perform a single step of STDP learning rule.
:param on_grad: If True, the delta_w is added to the weight.grad of the synapse.
If False, the delta_w is returned.
:type on_grad: bool
:param scale: Scaling factor for the delta_w.
:type scale: float
:return: delta_w if on_grad is False.
:rtype: torch.Tensor
"""

length = len(self.in_spike_monitor.records)
delta_w = None

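A hypothetical sketch of how a learner like this is typically driven in a training loop; the constructor arguments are assumptions, since __init__ is elided from this diff:

    import torch
    import torch.nn as nn

    fc = nn.Linear(4, 2, bias=False)
    optimizer = torch.optim.SGD(fc.parameters(), lr=0.1)

    # learner = STDPLearner(synapse=fc, ...)  # assumed constructor, not shown here
    # learner.enable()                        # start recording in/out spikes
    # ... run the network so the monitors record spikes ...
    # learner.step(on_grad=True, scale=1.0)   # add delta_w into fc.weight.grad
    optimizer.step()       # apply whatever has been accumulated in .grad
    optimizer.zero_grad()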
