Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kjysmu authored Nov 2, 2023
1 parent 062e499 commit 02332b5
Showing 1 changed file with 2 additions and 34 deletions.
36 changes: 2 additions & 34 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,12 @@ def main( vm = "" , isPrintArgs = True ):

##### Not smoothing evaluation loss #####
eval_loss_func = nn.CrossEntropyLoss(ignore_index=CHORD_PAD)
# eval_loss_func_root = nn.CrossEntropyLoss(ignore_index=CHORD_ROOT_PAD)
# eval_loss_func_attr = nn.CrossEntropyLoss(ignore_index=CHORD_ATTR_PAD)

##### SmoothCrossEntropyLoss or CrossEntropyLoss for training #####
if(args.ce_smoothing is None):
train_loss_func = eval_loss_func
# train_loss_func_root = eval_loss_func_root
# train_loss_func_attr = eval_loss_func_attr
else:
train_loss_func = SmoothCrossEntropyLoss(args.ce_smoothing, CHORD_SIZE, ignore_index=CHORD_PAD)
# train_loss_func_root = SmoothCrossEntropyLoss(args.ce_smoothing, CHORD_ROOT_SIZE, ignore_index=CHORD_ROOT_PAD)
# train_loss_func_attr = SmoothCrossEntropyLoss(args.ce_smoothing, CHORD_ATTR_SIZE, ignore_index=CHORD_ATTR_PAD)

eval_loss_emotion_func = nn.BCEWithLogitsLoss()
train_loss_emotion_func = eval_loss_emotion_func
Expand All @@ -172,8 +166,6 @@ def main( vm = "" , isPrintArgs = True ):
else:
lr_scheduler = None

# opt = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# lr_scheduler = LambdaLR(opt, lr_stepper.step)

##### Tracking best evaluation accuracy #####
best_eval_acc = 0.0
Expand Down Expand Up @@ -208,27 +200,14 @@ def main( vm = "" , isPrintArgs = True ):
print(SEPERATOR)
print("Baseline model evaluation (Epoch 0):")

# Eval v19
# train_loss, train_acc = eval_model(model, train_loader,
# train_loss_func, train_loss_func_root, train_loss_func_attr,
# isVideo= args.is_video)

# eval_loss, eval_acc = eval_model(model, val_loader,
# eval_loss_func, eval_loss_func_root, eval_loss_func_attr,
# isVideo= args.is_video)

# Eval

train_metric_dict = eval_model(model, train_loader,
train_loss_func, train_loss_emotion_func,
isVideo= args.is_video)

train_total_loss = train_metric_dict["avg_total_loss"]
train_loss_chord = train_metric_dict["avg_loss_chord"]
train_loss_emotion = train_metric_dict["avg_loss_emotion"]
train_acc = train_metric_dict["avg_acc"]
train_cor = train_metric_dict["avg_cor"]
train_acc_cor = train_metric_dict["avg_acc_cor"]

train_h1 = train_metric_dict["avg_h1"]
train_h3 = train_metric_dict["avg_h3"]
train_h5 = train_metric_dict["avg_h5"]
Expand All @@ -240,9 +219,7 @@ def main( vm = "" , isPrintArgs = True ):
eval_total_loss = eval_metric_dict["avg_total_loss"]
eval_loss_chord = eval_metric_dict["avg_loss_chord"]
eval_loss_emotion = eval_metric_dict["avg_loss_emotion"]
eval_acc = eval_metric_dict["avg_acc"]
eval_cor = eval_metric_dict["avg_cor"]
eval_acc_cor = eval_metric_dict["avg_acc_cor"]

eval_h1 = eval_metric_dict["avg_h1"]
eval_h3 = eval_metric_dict["avg_h3"]
eval_h5 = eval_metric_dict["avg_h5"]
Expand All @@ -254,23 +231,14 @@ def main( vm = "" , isPrintArgs = True ):
print("Avg train loss (total):", train_total_loss)
print("Avg train loss (chord):", train_loss_chord)
print("Avg train loss (emotion):", train_loss_emotion)

print("Avg train acc:", train_acc)
print("Avg train cor:", train_cor)
print("Avg train acc_cor:", train_acc_cor)

print("Avg train h1:", train_h1)
print("Avg train h3:", train_h3)
print("Avg train h5:", train_h5)


print("Avg val loss (total):", eval_total_loss)
print("Avg val loss (chord):", eval_loss_chord)
print("Avg val loss (emotion):", eval_loss_emotion)

print("Avg val acc:", eval_acc)
print("Avg val cor:", eval_cor)
print("Avg val acc_cor:", eval_acc_cor)

print("Avg val h1:", eval_h1)
print("Avg val h3:", eval_h3)
Expand Down

0 comments on commit 02332b5

Please sign in to comment.