forked from sjiggins/carl-torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_loss.py
23 lines (18 loc) · 881 Bytes
/
plot_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import argparse
import matplotlib.pyplot as plt
def build_parser():
    """Build the command-line parser for the loss-plotting script.

    Returns:
        argparse.ArgumentParser: parser with a ``--version`` flag and a
        ``-g/--global_name`` option (default ``'Test'``).
    """
    # argparse interpolates ``%(prog)s``. The optparse-style ``%prog`` that was
    # used before made ``-h`` (and every usage error) raise ValueError inside
    # argparse's usage formatter, and was printed literally by ``--version``.
    # argparse prepends "usage: " itself, so it is not repeated here.
    parser = argparse.ArgumentParser(usage="%(prog)s [opts]")
    parser.add_argument('--version', action='version', version='%(prog)s 1.0')
    parser.add_argument(
        '-g', '--global_name',
        action='store',
        type=str,
        dest='global_name',
        default='Test',
        help='Global name for identifying this run - used in folder naming and output naming',
    )
    return parser


def main(argv=None):
    """Plot the training and validation loss curves of one run and save a PNG.

    Reads ``loss_train_<global_name>.npy`` and ``loss_val_<global_name>.npy``
    from the current working directory. Writes the figure to
    ``plots/train_val_loss_<global_name>.png`` and creates ``plots/`` if it
    is missing.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    import os

    opts = build_parser().parse_args(argv)

    # Keep file names and loaded arrays in separate variables.
    train_path = f"loss_train_{opts.global_name}.npy"
    val_path = f"loss_val_{opts.global_name}.npy"
    train_loss = np.load(train_path)
    val_loss = np.load(val_path)

    plt.plot(train_loss, label="train loss")
    plt.plot(val_loss, label="val loss")
    plt.ylabel("loss")
    plt.legend(frameon=False, title="")

    # savefig does not create missing directories, so ensure the folder exists.
    out_dir = "plots"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, f"train_val_loss_{opts.global_name}.png"))
    plt.close()


if __name__ == "__main__":
    main()