diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py
index 688e48564..82c2bd46a 100644
--- a/references/classification/train_pytorch_orientation.py
+++ b/references/classification/train_pytorch_orientation.py
@@ -11,6 +11,7 @@
 import logging
 import multiprocessing as mp
 import time
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -35,6 +36,33 @@
 from doctr.models.utils import export_model_to_onnx
 from utils import EarlyStopper, plot_recorder, plot_samples
 
+SLACK_WEBHOOK_URL = None
+SLACK_WEBHOOK_PATH = Path(os.path.join(os.path.expanduser("~"), ".config", "doctr", "slack_webhook_url.txt"))
+if SLACK_WEBHOOK_PATH.exists():
+    with open(SLACK_WEBHOOK_PATH) as f:
+        SLACK_WEBHOOK_URL = f.read().strip()
+else:
+    print(f"{SLACK_WEBHOOK_PATH} does not exist, skipping Slack integration...")
+
+
+def send_on_slack(text: str):
+    """Send a message to Slack.
+
+    Args:
+        text (str): message to send to Slack
+    """
+    if SLACK_WEBHOOK_URL:
+        try:
+            import requests
+
+            requests.post(
+                url=SLACK_WEBHOOK_URL,
+                json={"text": f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]: {text}"},
+            )
+        except Exception:
+            print("Could not send message to Slack, continuing...")
+
+
 CLASSES = [0, 90, 180, 270]
 
 
@@ -121,7 +149,10 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
     model.train()
 
     # Iterate over the batches of the dataset
+    last_progress = 0
+    interval_progress = 5
     pbar = tqdm(train_loader, position=1)
+    send_on_slack(str(pbar))
     for images, targets in pbar:
         if torch.cuda.is_available():
             images = images.cuda()
@@ -146,15 +177,24 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
         scheduler.step()
 
         pbar.set_description(f"Training loss: {train_loss.item():.6}")
+        current_progress = pbar.n / pbar.total * 100
+        if current_progress - last_progress > interval_progress:
+            send_on_slack(str(pbar))
+            last_progress = int(current_progress)
+    send_on_slack(f"Final training loss: {train_loss.item():.6}")
 
 
 @torch.no_grad()
 def evaluate(model, val_loader, batch_transforms, amp=False):
     # Model in eval mode
     model.eval()
+    last_progress = 0
+    interval_progress = 5
+    pbar = tqdm(val_loader)
+    send_on_slack(str(pbar))
     # Validation loop
     val_loss, correct, samples, batch_cnt = 0.0, 0.0, 0.0, 0.0
-    for images, targets in tqdm(val_loader):
+    for images, targets in pbar:
         images = batch_transforms(images)
 
         if torch.cuda.is_available():
@@ -175,6 +215,11 @@ def evaluate(model, val_loader, batch_transforms, amp=False):
         batch_cnt += 1
         samples += images.shape[0]
 
+        current_progress = pbar.n / pbar.total * 100
+        if current_progress - last_progress > interval_progress:
+            send_on_slack(str(pbar))
+            last_progress = int(current_progress)
+
     val_loss /= batch_cnt
     acc = correct / samples
     return val_loss, acc
@@ -214,6 +259,9 @@ def main(args):
         pin_memory=torch.cuda.is_available(),
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " f"{len(val_loader)} batches)")
+    send_on_slack(
+        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " f"{len(val_loader)} batches)"
+    )
 
     batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))
 
@@ -223,6 +271,7 @@
     # Resume weights
     if isinstance(args.resume, str):
         print(f"Resuming {args.resume}")
+        send_on_slack(f"Resuming {args.resume}")
         checkpoint = torch.load(args.resume, map_location="cpu")
         model.load_state_dict(checkpoint)
 
@@ -276,6 +325,9 @@ def main(args):
         pin_memory=torch.cuda.is_available(),
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " f"{len(train_loader)} batches)")
+    send_on_slack(
+        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " f"{len(train_loader)} batches)"
+    )
 
     if args.show_samples:
         x, target = next(iter(train_loader))
@@ -338,9 +390,11 @@ def main(args):
         val_loss, acc = evaluate(model, val_loader, batch_transforms)
         if val_loss < min_loss:
             print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
+            send_on_slack(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
             torch.save(model.state_dict(), f"./{exp_name}.pt")
             min_loss = val_loss
         print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
+        send_on_slack(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
         # W&B
         if args.wb:
             wandb.log({
@@ -349,6 +403,7 @@ def main(args):
             })
         if args.early_stop and early_stopper.early_stop(val_loss):
             print("Training halted early due to reaching patience limit.")
+            send_on_slack("Training halted early due to reaching patience limit.")
             break
     if args.wb:
         run.finish()
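
The Slack notifications above are only active when `~/.config/doctr/slack_webhook_url.txt` exists; otherwise the script prints a notice and skips the integration. A minimal sketch of how one might create that file locally (the webhook URL below is a placeholder, not a real endpoint):

```python
from pathlib import Path

# Write the webhook URL to the location the training script reads at import time.
# Replace the placeholder with an incoming-webhook URL from your Slack workspace.
config_path = Path.home() / ".config" / "doctr" / "slack_webhook_url.txt"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text("https://hooks.slack.com/services/XXX/YYY/ZZZ\n")
print(f"Slack webhook URL written to {config_path}")
```

With the file in place, the next training run should post progress, validation, and early-stopping messages to the configured channel; deleting the file disables the integration without any code change.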