
Commit

stop using custom ds for val
odulcy-mindee committed Mar 20, 2024
1 parent 1172e24 commit 3518e73
Showing 1 changed file with 88 additions and 90 deletions.
references/detection/train_pytorch.py (178 changes: 88 additions & 90 deletions)
@@ -267,72 +267,72 @@ def main(args):

batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))

funsd_ds = DetectionDataset(
img_folder=os.path.join(args.funsd_path, "images"),
label_path=os.path.join(args.funsd_path, "labels.json"),
sample_transforms=T.SampleCompose(
(
[T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
if not args.rotation or args.eval_straight
else []
)
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation and not args.eval_straight
else []
)
),
use_polygons=args.rotation and not args.eval_straight,
)

funsd_test_loader = DataLoader(
funsd_ds,
batch_size=args.batch_size,
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(funsd_ds),
pin_memory=torch.cuda.is_available(),
collate_fn=funsd_ds.collate_fn,
)
print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in " f"{len(funsd_test_loader)} batches)")


cord_ds = DetectionDataset(
img_folder=os.path.join(args.cord_path, "images"),
label_path=os.path.join(args.cord_path, "labels.json"),
sample_transforms=T.SampleCompose(
(
[T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
if not args.rotation or args.eval_straight
else []
)
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation and not args.eval_straight
else []
)
),
use_polygons=args.rotation and not args.eval_straight,
)

cord_test_loader = DataLoader(
cord_ds,
batch_size=args.batch_size,
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(cord_ds),
pin_memory=torch.cuda.is_available(),
collate_fn=cord_ds.collate_fn,
)
print(f"CORD Test set loaded in {time.time() - st:.4}s ({len(cord_ds)} samples in " f"{len(funsd_test_loader)} batches)")
#funsd_ds = DetectionDataset(
# img_folder=os.path.join(args.funsd_path, "images"),
# label_path=os.path.join(args.funsd_path, "labels.json"),
# sample_transforms=T.SampleCompose(
# (
# [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
# if not args.rotation or args.eval_straight
# else []
# )
# + (
# [
# T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
# T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
# T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
# ]
# if args.rotation and not args.eval_straight
# else []
# )
# ),
# use_polygons=args.rotation and not args.eval_straight,
#)

#funsd_test_loader = DataLoader(
# funsd_ds,
# batch_size=args.batch_size,
# drop_last=False,
# num_workers=args.workers,
# sampler=SequentialSampler(funsd_ds),
# pin_memory=torch.cuda.is_available(),
# collate_fn=funsd_ds.collate_fn,
#)
#print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in " f"{len(funsd_test_loader)} batches)")


#cord_ds = DetectionDataset(
# img_folder=os.path.join(args.cord_path, "images"),
# label_path=os.path.join(args.cord_path, "labels.json"),
# sample_transforms=T.SampleCompose(
# (
# [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
# if not args.rotation or args.eval_straight
# else []
# )
# + (
# [
# T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
# T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
# T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
# ]
# if args.rotation and not args.eval_straight
# else []
# )
# ),
# use_polygons=args.rotation and not args.eval_straight,
#)

#cord_test_loader = DataLoader(
# cord_ds,
# batch_size=args.batch_size,
# drop_last=False,
# num_workers=args.workers,
# sampler=SequentialSampler(cord_ds),
# pin_memory=torch.cuda.is_available(),
# collate_fn=cord_ds.collate_fn,
#)
#print(f"CORD Test set loaded in {time.time() - st:.4}s ({len(cord_ds)} samples in " f"{len(funsd_test_loader)} batches)")

# Load doctr model
model = detection.__dict__[args.arch](
@@ -369,16 +369,16 @@ def main(args):
mask_shape=(args.input_size, args.input_size),
use_broadcasting=True if system_available_memory > 62 else False,
)
funsd_val_metric = LocalizationConfusion(
use_polygons=args.rotation and not args.eval_straight,
mask_shape=(args.input_size, args.input_size),
use_broadcasting=True if system_available_memory > 62 else False,
)
cord_val_metric = LocalizationConfusion(
use_polygons=args.rotation and not args.eval_straight,
mask_shape=(args.input_size, args.input_size),
use_broadcasting=True if system_available_memory > 62 else False,
)
#funsd_val_metric = LocalizationConfusion(
# use_polygons=args.rotation and not args.eval_straight,
# mask_shape=(args.input_size, args.input_size),
# use_broadcasting=True if system_available_memory > 62 else False,
#)
#cord_val_metric = LocalizationConfusion(
# use_polygons=args.rotation and not args.eval_straight,
# mask_shape=(args.input_size, args.input_size),
# use_broadcasting=True if system_available_memory > 62 else False,
#)

if args.test_only:
print("Running evaluation")
@@ -510,18 +510,18 @@ def main(args):
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
funsd_recall, funsd_precision, funsd_mean_iou = 0.0, 0.0, 0.0
cord_recall, cord_precision, cord_mean_iou = 0.0, 0.0, 0.0
try:
_, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
)
except Exception:
pass
try:
_, cord_recall, cord_precision, cord_mean_iou = evaluate(
model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
)
except Exception:
pass
#try:
# _, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
# model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
# )
#except Exception:
# pass
#try:
# _, cord_recall, cord_precision, cord_mean_iou = evaluate(
# model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
# )
#except Exception:
# pass
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
send_on_slack(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
@@ -569,8 +569,6 @@ def parse_args():

parser.add_argument("train_path", type=str, help="path to training data folder")
parser.add_argument("val_path", type=str, help="path to validation data folder")
parser.add_argument("funsd_path", type=str, help="path to FUNSD data folder")
parser.add_argument("cord_path", type=str, help="path to Cord data folder")
parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
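For anyone who still wants to benchmark a detection model on FUNSD after this change, the deleted pieces can be assembled into a small standalone script instead of living inside the training loop. The sketch below is illustrative only and not part of the commit: the import locations, the db_resnet50 model choice, the dataset path, the batch size, and the reuse of the script's evaluate helper are all assumptions.

# Standalone FUNSD evaluation sketch (assumed imports and paths; reuses the
# doctr building blocks that appear in the diff above).
import os

import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize

from doctr import transforms as T
from doctr.datasets import DetectionDataset
from doctr.models import detection
from doctr.utils.metrics import LocalizationConfusion

funsd_path = "/path/to/FUNSD"  # hypothetical location of the prepared dataset
input_size = 1024

# Same normalization constants as the training script
batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))

# Straight (non-rotated) setup, mirroring the deleted block
funsd_ds = DetectionDataset(
    img_folder=os.path.join(funsd_path, "images"),
    label_path=os.path.join(funsd_path, "labels.json"),
    sample_transforms=T.Resize((input_size, input_size), preserve_aspect_ratio=True, symmetric_pad=True),
)
funsd_loader = DataLoader(
    funsd_ds,
    batch_size=2,
    drop_last=False,
    sampler=SequentialSampler(funsd_ds),
    pin_memory=torch.cuda.is_available(),
    collate_fn=funsd_ds.collate_fn,
)
funsd_metric = LocalizationConfusion(mask_shape=(input_size, input_size))

model = detection.db_resnet50(pretrained=True).eval()

# Reuse the evaluate helper defined earlier in references/detection/train_pytorch.py:
# _, recall, precision, mean_iou = evaluate(model, funsd_loader, batch_transforms, funsd_metric)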
