Skip to content

Commit

Permalink
change layout for TB logging, add steps_per_sec
Browse files Browse the repository at this point in the history
  • Loading branch information
SunQpark committed Apr 23, 2019
1 parent 10f05e0 commit 9bd42ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
10 changes: 8 additions & 2 deletions logger/visualization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
from utils import Timer


class WriterTensorboardX():
Expand All @@ -21,11 +22,16 @@ def __init__(self, log_dir, logger, enable):
'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
]
self.tag_mode_exceptions = ['add_histogram', 'add_embedding']

self.timer = Timer()

def set_step(self, step, mode='train'):
self.mode = mode
self.step = step
if step == 0:
self.timer.reset()
else:
duration = self.timer.check()
self.add_scalar('steps_per_sec', 1 / duration)

def __getattr__(self, name):
"""
Expand All @@ -41,7 +47,7 @@ def wrapper(tag, data, *args, **kwargs):
if add_data is not None:
# add mode(train/valid) tag
if name not in self.tag_mode_exceptions:
tag = '{}/{}'.format(self.mode, tag)
tag = '{}/{}'.format(tag, self.mode)
add_data(tag, data, self.step, *args, **kwargs)
return wrapper
else:
Expand Down
14 changes: 14 additions & 0 deletions utils/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from datetime import datetime
from collections import OrderedDict


Expand All @@ -15,3 +16,16 @@ def read_json(fname):
def write_json(content, fname):
with fname.open('wt') as handle:
json.dump(content, handle, indent=4, sort_keys=False)

class Timer:
def __init__(self):
self.cache = datetime.now()

def check(self):
now = datetime.now()
duration = now - self.cache
self.cache = now
return duration.total_seconds()

def reset(self):
self.cache = datetime.now()

0 comments on commit 9bd42ab

Please sign in to comment.