Skip to content

Commit

Permalink
corrected bug with batchsize and explained hyperparams
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulysse Rancon committed Jan 13, 2022
1 parent c389b53 commit 5434a00
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def set_random_seed(seed):
# GENERAL PARAMETERS #
######################

nfpdm = 1 # (!) don't choose it too big because of memory limitations (!)
N_inference = 1
N_warmup = 1
batchsize = 1
learned_metric = 'LIN'
nfpdm = 1 # number of frames per depth map (1 label every 50 ms)
N_inference = 1 # number of chunks for training/testing (1 chunk = 50 ms = nfpdm frames)
N_warmup = 1 # number of chunks for warmup (if you want to use a stateful model)
batchsize = 1
learned_metric = 'LIN' # learn metric depth ('LIN'), normalized log depth ('LOG') or disparity ('DISP')
learning_rate = 0.0002
weight_decay = 0.0
n_epochs = 70
show = False
show = False # display network's predictions during training / validation


###########################
Expand All @@ -90,7 +90,7 @@ def set_random_seed(seed):
])

train_set, val_set, test_set = load_MVSEC('./datasets/MVSEC/data/', scenario='indoor_flying', split='1',
num_frames_per_depth_map=1, warmup_chunks=1, train_chunks=1,
num_frames_per_depth_map=nfpdm, warmup_chunks=1, train_chunks=1,
transform=tsfm, normalize=False, learn_on='LIN')

train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
Expand All @@ -102,13 +102,13 @@ def set_random_seed(seed):
val_data_loader = torch.utils.data.DataLoader(dataset=val_set,
batch_size=1,
shuffle=False,
drop_last=False,
drop_last=True
pin_memory=True)

test_data_loader = torch.utils.data.DataLoader(dataset=test_set,
batch_size=1,
shuffle=False,
drop_last=False,
drop_last=True,
pin_memory=True)

###########
Expand Down

0 comments on commit 5434a00

Please sign in to comment.