diff --git a/src/setup_task.py b/src/setup_task.py index 4fb0fa9..0050aa8 100644 --- a/src/setup_task.py +++ b/src/setup_task.py @@ -31,9 +31,6 @@ def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool): - if args.data_dir.startswith("gs://"): - utils.check_bucket_zone(args.data_dir, "large-ds") - data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary) data_config['normalize'] = not (args.no_normalize or args.normalize_model) diff --git a/src/utils.py b/src/utils.py index f35ba2f..1737028 100644 --- a/src/utils.py +++ b/src/utils.py @@ -24,7 +24,6 @@ def get_outdir(path: str, *paths: str, inc=False) -> str: """Adapted to get out dir from GCS""" outdir = os.path.join(path, *paths) if path.startswith('gs://'): - check_bucket_zone(path, "robust-vits") os_module = tf.io.gfile exists_fn = lambda x: os_module.exists(x) else: @@ -67,17 +66,6 @@ def upload_checkpoints_gcs(checkpoints_dir: str, output_dir: str): tf.io.gfile.copy(checkpoint, gcs_checkpoint_path) -def check_bucket_zone(data_dir, prefix): - if "ZONE" not in os.environ: - raise ValueError( - "The zone is not set for this machine, set the ZONE env variable to the zone of the machine") - zone = os.environ['ZONE'] - if zone == "US": - assert data_dir.startswith(f"gs://{prefix}-us/"), f"The given dir {data_dir} is in the wrong zone" - elif zone == "EU": - assert data_dir.startswith(f"gs://{prefix}/"), f"The given dir {data_dir} is in the wrong zone" - - class GCSSummaryCsv(bits.monitor.SummaryCsv): """SummaryCSV version to work with GCS""" def __init__(self, output_dir, filename='summary.csv'): diff --git a/validate.py b/validate.py index 7fa692f..a237137 100755 --- a/validate.py +++ b/validate.py @@ -283,9 +283,6 @@ def validate(args): model, criterion = dev_env.to_device(model, nn.CrossEntropyLoss()) model.to(dev_env.device) - if args.data.startswith("gs://"): - utils.check_bucket_zone(args.data, "large-ds") - dataset = create_dataset(root=args.data, 
name=args.dataset, split=args.split, @@ -327,6 +324,7 @@ def validate(args): logger = Monitor(logger=_logger) tracker = Tracker() losses = AvgTensor() + adv_losses = AvgTensor() accuracy = AccuracyTopK(dev_env=dev_env) adv_accuracy = AccuracyTopK(dev_env=dev_env) @@ -370,6 +368,10 @@ def validate(args): if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) + if adv_output is not None: + adv_loss = criterion(adv_output, target) + else: + adv_loss = None if dev_env.type_xla: dev_env.mark_step() @@ -383,12 +385,16 @@ def validate(args): accuracy.update(output.detach(), target) if adv_output is not None: adv_accuracy.update(adv_output.detach(), target) + if adv_loss is not None: + adv_losses.update(adv_loss.detach(), sample.size(0)) tracker.mark_iter() if last_step or step_idx % args.log_freq == 0: top1, top5 = accuracy.compute().values() robust_top1, robust_top5 = adv_accuracy.compute().values() loss_avg = losses.compute() + adv_loss_avg = adv_losses.compute() + logger.log_step( phase='eval', step_idx=step_idx, @@ -398,6 +404,7 @@ def validate(args): loss=loss_avg.item(), top1=top1.item(), top5=top5.item(), + adv_loss=adv_loss_avg.item(), robust_top1=robust_top1.item(), robust_top5=robust_top5.item(), )