From 78a5c0d05ed6031621560653c4ca944cdc4079e6 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 10 Mar 2024 11:35:20 +0800 Subject: [PATCH] fix --- .../_src/optimizers/tests/test_scheduler.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py index 8c53f33d..82f2fd1e 100644 --- a/brainpy/_src/optimizers/tests/test_scheduler.py +++ b/brainpy/_src/optimizers/tests/test_scheduler.py @@ -18,8 +18,8 @@ class TestMultiStepLR(parameterized.TestCase): ) def test2(self, last_epoch): bm.random.seed() - scheduler1 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + scheduler1 = sgd_scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + scheduler2 = sgd_scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) for i in range(1, 25): lr1 = scheduler1(i + last_epoch) @@ -38,8 +38,8 @@ class TestStepLR(parameterized.TestCase): ) def test1(self, last_epoch): bm.random.seed() - scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + scheduler1 = sgd_scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + scheduler2 = sgd_scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) for i in range(1, 25): lr1 = scheduler1(i + last_epoch) lr2 = scheduler2() @@ -54,7 +54,7 @@ def test1(self): bm.random.seed() max_epoch = 50 iters = 200 - sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) + sch = sgd_scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) all_lr1 = [[], []] all_lr2 = [[], []] for epoch in range(max_epoch): @@ -81,11 +81,11 @@ def test1(self): bm.random.seed() max_epoch = 50 iters = 200 - sch = scheduler.CosineAnnealingWarmRestarts(0.1, - iters, - T_0=5, - T_mult=1, - last_call=-1) + sch = sgd_scheduler.CosineAnnealingWarmRestarts(0.1, + iters, + T_0=5, + T_mult=1, + last_call=-1) all_lr1 = [] all_lr2 = [] for epoch in range(max_epoch):