From b01d6993ce2c647172f279ea90dd155d36cd76fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Mu=CC=88ller?= Date: Tue, 15 Oct 2024 15:49:54 +0200 Subject: [PATCH] fix(ml): change early stopping implementation Early stopping now terminates one epoch later --- ml/gnn/src/gnn.py | 2 +- ml/gnn/src/model/base_model.py | 2 +- ml/tree-lstm/src/paper/mdeoperation.py | 2 +- ml/tree-lstm/src/tree_dataset.py | 4 ---- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/ml/gnn/src/gnn.py b/ml/gnn/src/gnn.py index eff0dd2..bb1ef48 100644 --- a/ml/gnn/src/gnn.py +++ b/ml/gnn/src/gnn.py @@ -33,7 +33,7 @@ "Please provide the train, validation, and test dataset file paths as arguments" ) -num_epochs = 2000 +num_epochs = 100 start_epoch = 0 patience = 10 diff --git a/ml/gnn/src/model/base_model.py b/ml/gnn/src/model/base_model.py index f7769fd..6a90dbf 100644 --- a/ml/gnn/src/model/base_model.py +++ b/ml/gnn/src/model/base_model.py @@ -130,7 +130,7 @@ def fit( remaining_patience = patience else: remaining_patience -= 1 - if remaining_patience == 0: + if remaining_patience < 0: self.layout_proxy.print( f"{text_padding}Early stopping in epoch {epoch}" ) diff --git a/ml/tree-lstm/src/paper/mdeoperation.py b/ml/tree-lstm/src/paper/mdeoperation.py index 6ef0faf..6c43020 100644 --- a/ml/tree-lstm/src/paper/mdeoperation.py +++ b/ml/tree-lstm/src/paper/mdeoperation.py @@ -298,7 +298,7 @@ def train( remaining_patience = args.patience else: remaining_patience -= 1 - if remaining_patience == 0: + if remaining_patience < 0: print(f"Early stopping in epoch {epoch}") break sys.stdout.flush() diff --git a/ml/tree-lstm/src/tree_dataset.py b/ml/tree-lstm/src/tree_dataset.py index 8b55a25..9f8ace1 100644 --- a/ml/tree-lstm/src/tree_dataset.py +++ b/ml/tree-lstm/src/tree_dataset.py @@ -24,8 +24,6 @@ def load(self): dataset_load_start_time = time.perf_counter() if self.is_cached: self.data, self.metadata, self.vocabulary = torch.load(self.dataset_cache_file) - # TODO? - # self.to(device) else: with open(self.dataset_path, "r") as file: dataset_input = json.load(file) @@ -42,8 +40,6 @@ def load(self): (self.data, self.metadata, self.vocabulary), self.dataset_cache_file, ) - # TODO? - # self.to(device) dataset_load_end_time = time.perf_counter() print(