Skip to content

Commit

Permalink
send message on slack (pytorch script)
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Feb 6, 2024
1 parent cafb64b commit 114d7f0
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,32 @@
from doctr.utils.metrics import LocalizationConfusion
from utils import EarlyStopper, plot_recorder, plot_samples

SLACK_WEBHOOK_URL = None
SLACK_WEBHOOK_PATH = os.path.join(os.path.expanduser("~"), ".config", "doctr", "slack_webhook_url.txt")
if SLACK_WEBHOOK_PATH.exists():
with open(SLACK_WEBHOOK_PATH) as f:
SLACK_WEBHOOK_URL = f.read().strip()
else:
print(f"{SLACK_WEBHOOK_PATH} does not exist, skip Slack integration configuration...")


def send_on_slack(text: str):
"""Send a message on Slack.
Args:
text (str): message to send on Slack
"""
if SLACK_WEBHOOK_URL:
try:
import requests

requests.post(
url=SLACK_WEBHOOK_URL,
json={"text": f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]: {text}"},
)
except Exception:
print("Impossible to send message on Slack, continue...")


def record_lr(
model: torch.nn.Module,
Expand Down Expand Up @@ -106,6 +132,9 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a

model.train()
# Iterate over the batches of the dataset
total = 0
last_progress = 0
interval_progress = 5
pbar = tqdm(train_loader, position=1)
for images, targets in pbar:
if torch.cuda.is_available():
Expand All @@ -130,8 +159,13 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
optimizer.step()

scheduler.step()

total += len(images)
pbar.set_description(f"Training loss: {train_loss.item():.6}")
current_progress = total / len(pbar) * 100
if current_progress - last_progress > interval_progress:
send_on_slack(f"Training loss: {train_loss.item():.6} ({current_progress:.2})")
last_progress = int(current_progress)
send_on_slack(f"Final training loss: {train_loss.item():.6}")


@torch.no_grad()
Expand Down Expand Up @@ -170,6 +204,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):

def main(args):
print(args)
send_on_slack(f"Start training: {args}")

if args.push_to_hub:
login_to_hub()
Expand Down Expand Up @@ -212,6 +247,9 @@ def main(args):
collate_fn=val_set.collate_fn,
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " f"{len(val_loader)} batches)")
send_on_slack(
f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " f"{len(val_loader)} batches)"
)
with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

Expand All @@ -227,6 +265,7 @@ def main(args):
# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
send_on_slack(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint)

Expand Down Expand Up @@ -306,6 +345,9 @@ def main(args):
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " f"{len(train_loader)} batches)")
send_on_slack(
f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " f"{len(train_loader)} batches)"
)
with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

Expand Down Expand Up @@ -379,6 +421,7 @@ def main(args):
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
send_on_slack(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
min_loss = val_loss
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
Expand All @@ -387,6 +430,7 @@ def main(args):
else:
log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
print(log_msg)
send_on_slack(log_msg)
# W&B
if args.wb:
wandb.log(
Expand Down

0 comments on commit 114d7f0

Please sign in to comment.