Skip to content

Commit

Permalink
Remove bucket zone checks
Browse files (browse the repository at this point in the history)
  • Loading branch information
dedeswim committed May 4, 2022
1 parent b30fe8f commit 49eab48
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 18 deletions.
3 changes: 0 additions & 3 deletions src/setup_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@


def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
if args.data_dir.startswith("gs://"):
utils.check_bucket_zone(args.data_dir, "large-ds")

data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
data_config['normalize'] = not (args.no_normalize or args.normalize_model)

Expand Down
12 changes: 0 additions & 12 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def get_outdir(path: str, *paths: str, inc=False) -> str:
"""Adapted to get out dir from GCS"""
outdir = os.path.join(path, *paths)
if path.startswith('gs://'):
check_bucket_zone(path, "robust-vits")
os_module = tf.io.gfile
exists_fn = lambda x: os_module.exists(x)
else:
Expand Down Expand Up @@ -67,17 +66,6 @@ def upload_checkpoints_gcs(checkpoints_dir: str, output_dir: str):
tf.io.gfile.copy(checkpoint, gcs_checkpoint_path)


def check_bucket_zone(data_dir, prefix):
    """Verify that a GCS path lives in the bucket matching this machine's zone.

    The zone is read from the ``ZONE`` environment variable. With ``ZONE`` set
    to ``"US"`` the path must start with ``gs://{prefix}-us/``; with ``"EU"``
    it must start with ``gs://{prefix}/``. Any other zone value is accepted
    without a check.

    Args:
        data_dir: Path to validate, e.g. ``gs://large-ds-us/...``.
        prefix: Bucket name without the zone suffix, e.g. ``"large-ds"``.

    Raises:
        ValueError: If the ``ZONE`` environment variable is not set.
        AssertionError: If ``data_dir`` is not in the bucket expected for the
            configured zone.
    """
    try:
        zone = os.environ['ZONE']
    except KeyError:
        raise ValueError(
            "The zone is not set for this machine, set the ZONE env variable to the zone of the machine"
        ) from None

    # Bucket prefix each known zone must use; zones not listed are not checked.
    required_prefix = {
        "US": f"gs://{prefix}-us/",
        "EU": f"gs://{prefix}/",
    }.get(zone)

    if required_prefix is not None and not data_dir.startswith(required_prefix):
        # Raised explicitly rather than via `assert` so the check is not
        # stripped when Python runs with -O; AssertionError is kept so
        # existing callers observe the same exception type.
        raise AssertionError(f"The given dir {data_dir} is in the wrong zone")


class GCSSummaryCsv(bits.monitor.SummaryCsv):
"""SummaryCSV version to work with GCS"""
def __init__(self, output_dir, filename='summary.csv'):
Expand Down
13 changes: 10 additions & 3 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,6 @@ def validate(args):
model, criterion = dev_env.to_device(model, nn.CrossEntropyLoss())
model.to(dev_env.device)

if args.data.startswith("gs://"):
utils.check_bucket_zone(args.data, "large-ds")

dataset = create_dataset(root=args.data,
name=args.dataset,
split=args.split,
Expand Down Expand Up @@ -327,6 +324,7 @@ def validate(args):
logger = Monitor(logger=_logger)
tracker = Tracker()
losses = AvgTensor()
adv_losses = AvgTensor()
accuracy = AccuracyTopK(dev_env=dev_env)
adv_accuracy = AccuracyTopK(dev_env=dev_env)

Expand Down Expand Up @@ -370,6 +368,10 @@ def validate(args):
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if adv_output is not None:
adv_loss = criterion(adv_output, target)
else:
adv_loss = None

if dev_env.type_xla:
dev_env.mark_step()
Expand All @@ -383,12 +385,16 @@ def validate(args):
accuracy.update(output.detach(), target)
if adv_output is not None:
adv_accuracy.update(adv_output.detach(), target)
if adv_losses is not None:
adv_losses.update(adv_loss.detach(), sample.size(0))

tracker.mark_iter()
if last_step or step_idx % args.log_freq == 0:
top1, top5 = accuracy.compute().values()
robust_top1, robust_top5 = adv_accuracy.compute().values()
loss_avg = losses.compute()
adv_loss_avg = adv_losses.compute()

logger.log_step(
phase='eval',
step_idx=step_idx,
Expand All @@ -398,6 +404,7 @@ def validate(args):
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),
adv_loss=adv_loss_avg.item(),
robust_top1=robust_top1.item(),
robust_top5=robust_top5.item(),
)
Expand Down

0 comments on commit 49eab48

Please sign in to comment.