Skip to content

Commit

Permalink
corrected typos in ANN model name and train/test loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulysse Rancon committed Dec 20, 2021
1 parent dfd3ceb commit c389b53
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion network/ANN_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def count_trainable_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)


class SteroSpike_equivalentANN(AnalogNet):
class StereoSpike_equivalentANN(AnalogNet):
"""
An Analog Neural Network (ANN) with the exact same architecture as StereoSpike.
Uses biases in convolution layers, batch normalization, and classical activation functions such as Sigmoid.
Expand Down
7 changes: 4 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
RandomEventDrop

from network.SNN_models import StereoSpike, fromZero_feedforward_multiscale_tempo_Matt_SpikeFlowNetLike
from network.ANN_models import SteroSpike_equivalentANN
from network.ANN_models import StereoSpike_equivalentANN

from network.metrics import MeanDepthError, log_to_lin_depths, disparity_to_depth
from network.loss import Total_Loss
Expand Down Expand Up @@ -78,7 +78,8 @@
###########

net = StereoSpike(surrogate_function=surrogate.ATan(), detach_reset=True, v_threshold=1.0, v_reset=0.).to(device)
# net = SteroSpike_equivalentANN(activation_function=nn.Sigmoid()).to(device)
net = StereoSpike_equivalentANN(activation_function=nn.Sigmoid()).to(device)
# net = fromZero_feedforward_multiscale_tempo_Matt_SpikeFlowNetLike(tau=3., v_threshold=1.0, v_reset=0.0, use_plif=True, multiply_factor=10.).to(device)

net.load_state_dict(torch.load('./results/checkpoints/stereospike.pth'))

Expand Down Expand Up @@ -111,7 +112,7 @@
init_pots = init_pots.to(device)
warmup_chunks_left = warmup_chunks_left.to(device, dtype=torch.float)
warmup_chunks_right = warmup_chunks_right.to(device, dtype=torch.float)
test_chunks_left = test_chunks_right.to(device, dtype=torch.float)
test_chunks_left = test_chunks_left.to(device, dtype=torch.float)
test_chunks_right = test_chunks_right.to(device, dtype=torch.float)
label = label.to(device)

Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RandomEventDrop

from network.SNN_models import StereoSpike, fromZero_feedforward_multiscale_tempo_Matt_SpikeFlowNetLike
from network.ANN_models import SteroSpike_equivalentANN
from network.ANN_models import StereoSpike_equivalentANN

from network.metrics import MeanDepthError, log_to_lin_depths, disparity_to_depth
from network.loss import Total_Loss
Expand Down Expand Up @@ -116,8 +116,8 @@ def set_random_seed(seed):
###########

net = StereoSpike(surrogate_function=surrogate.ATan(), detach_reset=True, v_threshold=1.0, v_reset=0.).to(device)
# net = SteroSpike_equivalentANN(activation_function=nn.Sigmoid()).to(device)

# net = StereoSpike_equivalentANN(activation_function=nn.Sigmoid()).to(device)
# net = fromZero_feedforward_multiscale_tempo_Matt_SpikeFlowNetLike(tau=3., v_threshold=1.0, v_reset=0.0, use_plif=True, multiply_factor=10.).to(device)

################
# OPTIMIZATION #
Expand Down Expand Up @@ -193,7 +193,7 @@ def set_random_seed(seed):
init_pots = init_pots.to(device)
warmup_chunks_left = warmup_chunks_left.to(device, dtype=torch.float)
warmup_chunks_right = warmup_chunks_right.to(device, dtype=torch.float)
train_chunks_left = train_chunks_right.to(device, dtype=torch.float)
train_chunks_left = train_chunks_left.to(device, dtype=torch.float)
train_chunks_right = train_chunks_right.to(device, dtype=torch.float)
label = label.to(device)

Expand Down Expand Up @@ -282,7 +282,7 @@ def set_random_seed(seed):
init_pots = init_pots.to(device)
warmup_chunks_left = warmup_chunks_left.to(device, dtype=torch.float)
warmup_chunks_right = warmup_chunks_right.to(device, dtype=torch.float)
test_chunks_left = test_chunks_right.to(device, dtype=torch.float)
test_chunks_left = test_chunks_left.to(device, dtype=torch.float)
test_chunks_right = test_chunks_right.to(device, dtype=torch.float)
label = label.to(device)

Expand Down

0 comments on commit c389b53

Please sign in to comment.