Skip to content

Commit

Permalink
Update train_regression.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kjysmu authored Nov 3, 2023
1 parent 030cfc9 commit 3839563
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions train_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def main( vm = "" , isPrintArgs = True ):
if isPrintArgs:
print_train_args(args)
if vm != "":
#VIS_MODELS = vm
args.vis_models = vm

if args.is_video:
Expand Down Expand Up @@ -129,11 +128,6 @@ def main( vm = "" , isPrintArgs = True ):
eval_loss_func = nn.MSELoss()
train_loss_func = nn.MSELoss()

# lr = LR_DEFAULT_START
# lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS, 0)
# opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
# lr_scheduler = LambdaLR(opt, lr_stepper.step)

opt = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
lr_scheduler = None

Expand All @@ -151,7 +145,6 @@ def main( vm = "" , isPrintArgs = True ):

##### TRAIN LOOP #####
for epoch in range(start_epoch, args.epochs):
# Baseline has no training and acts as a base loss and accuracy (epoch 0 in a sense)
if(epoch > BASELINE_EPOCH):
print(SEPERATOR)
print("NEW EPOCH:", epoch+1)
Expand Down

0 comments on commit 3839563

Please sign in to comment.