Skip to content

Commit

Permalink
send_on_slack tf
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Feb 13, 2024
1 parent 50b3dc2 commit 99b796e
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import hashlib
import multiprocessing as mp
import time
from pathlib import Path

import numpy as np
import psutil
Expand All @@ -31,6 +32,32 @@
from doctr.utils.metrics import LocalizationConfusion
from utils import EarlyStopper, load_backbone, plot_recorder, plot_samples

SLACK_WEBHOOK_URL = None
SLACK_WEBHOOK_PATH = Path(os.path.join(os.path.expanduser("~"), ".config", "doctr", "slack_webhook_url.txt"))
if SLACK_WEBHOOK_PATH.exists():
with open(SLACK_WEBHOOK_PATH) as f:
SLACK_WEBHOOK_URL = f.read().strip()
else:
print(f"{SLACK_WEBHOOK_PATH} does not exist, skip Slack integration configuration...")


def send_on_slack(text: str):
"""Send a message on Slack.
Args:
text (str): message to send on Slack
"""
if SLACK_WEBHOOK_URL:
try:
import requests

requests.post(
url=SLACK_WEBHOOK_URL,
json={"text": f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]: {text}"},
)
except Exception:
print("Impossible to send message on Slack, continue...")


def record_lr(
model: tf.keras.Model,
Expand Down Expand Up @@ -87,6 +114,8 @@ def record_lr(
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
train_iter = iter(train_loader)
# Iterate over the batches of the dataset
last_progress = 0
interval_progress = 5
pbar = tqdm(train_iter, position=1)
for images, targets in pbar:
images = batch_transforms(images)
Expand All @@ -99,6 +128,11 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
optimizer.apply_gradients(zip(grads, model.trainable_weights))

pbar.set_description(f"Training loss: {train_loss.numpy():.6}")
current_progress = pbar.n / pbar.total * 100
if current_progress - last_progress > interval_progress:
send_on_slack(str(pbar))
last_progress = int(current_progress)
send_on_slack(f"Final training loss: {train_loss.item():.6}")


def evaluate(model, val_loader, batch_transforms, val_metric):
Expand Down Expand Up @@ -129,6 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):

def main(args):
print(args)
send_on_slack(f"Start training: {args}")

if args.push_to_hub:
login_to_hub()
Expand Down Expand Up @@ -175,6 +210,10 @@ def main(args):
f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{val_loader.num_batches} batches)"
)
send_on_slack(
f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{val_loader.num_batches} batches)"
)
with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

Expand Down Expand Up @@ -264,6 +303,10 @@ def main(args):
f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)"
)
send_on_slack(
f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)"
)
with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

Expand Down Expand Up @@ -347,17 +390,20 @@ def main(args):
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
send_on_slack(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
model.save_weights(f"./{exp_name}/weights")
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
send_on_slack(f"Saving state at epoch: {epoch + 1}")
model.save_weights(f"./{exp_name}_{epoch + 1}/weights")
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
else:
log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
print(log_msg)
send_on_slack(log_msg)
# W&B
if args.wb:
wandb.log(
Expand Down

0 comments on commit 99b796e

Please sign in to comment.