-
Notifications
You must be signed in to change notification settings - Fork 1
/
callback.py
20 lines (18 loc) · 1022 Bytes
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import math

from transformers import TrainerCallback
class EarlyStoppingCallback(TrainerCallback):
    """Stop training when a monitored evaluation metric stops improving.

    After every evaluation the monitored metric is compared with the best
    value seen so far. A result counts as an improvement only if it beats the
    best value by more than ``early_stopping_threshold``; otherwise a patience
    counter is incremented. Once the counter reaches
    ``early_stopping_patience``, ``control.should_training_stop`` is set.

    Args:
        early_stopping_patience: Number of consecutive non-improving
            evaluations tolerated before training is stopped.
        early_stopping_threshold: Minimum margin (a "min delta", expected to
            be >= 0) by which the metric must beat the best value to count as
            an improvement. ``0.0`` means any strict improvement counts.
        metric_name: Key of the monitored value in the ``metrics`` dict passed
            to ``on_evaluate``. Defaults to ``"eval_loss"``.
        greater_is_better: ``True`` if larger values are better (e.g.
            accuracy); ``False`` (default) if smaller values are better
            (e.g. loss).
    """

    def __init__(
        self,
        early_stopping_patience: int,
        early_stopping_threshold: float = 0.0,
        *,
        metric_name: str = "eval_loss",
        greater_is_better: bool = False,
    ):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.metric_name = metric_name
        self.greater_is_better = greater_is_better
        # Best metric value seen so far; None until the first usable evaluation.
        self.best_score = None
        # Number of consecutive evaluations without a sufficient improvement.
        self.patience_counter = 0

    def _is_improvement(self, metric) -> bool:
        """Return True if *metric* beats the best score by more than the threshold."""
        # NaN never counts as an improvement: storing it as best_score would
        # make every later comparison False and freeze the best value.
        if math.isnan(metric):
            return False
        if self.best_score is None:
            return True
        # The threshold is a margin the new value must clear, so it is added
        # for higher-is-better metrics and subtracted for lower-is-better ones.
        if self.greater_is_better:
            return metric > self.best_score + self.early_stopping_threshold
        return metric < self.best_score - self.early_stopping_threshold

    def on_evaluate(self, args, state, control, **kwargs):
        """Update the patience counter from the latest evaluation metrics.

        Reads ``kwargs["metrics"][self.metric_name]``. Does nothing if the
        metrics dict or the monitored key is missing. Sets
        ``control.should_training_stop = True`` once the patience budget is
        exhausted.
        """
        # `or {}` also covers an explicit metrics=None.
        metrics = kwargs.get("metrics") or {}
        eval_metric = metrics.get(self.metric_name)
        if eval_metric is None:
            return

        if self._is_improvement(eval_metric):
            self.best_score = eval_metric
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter >= self.early_stopping_patience:
            print("Early stopping triggered")
            control.should_training_stop = True