diff --git a/universe/wrappers/recording.py b/universe/wrappers/recording.py index bb749b24..633f14ab 100644 --- a/universe/wrappers/recording.py +++ b/universe/wrappers/recording.py @@ -3,7 +3,8 @@ import os import json import numpy as np -from universe import rewarder, spaces, vectorized +import threading, queue +from universe import rewarder, spaces, vectorized, pyprofile from universe.utils import random_alphanumeric logger = logging.getLogger(__name__) @@ -91,13 +92,6 @@ def _get_writer(self, i): self._log_n[i] = RecordingWriter(self._recording_dir, self._instance_id, i) return self._log_n[i] - def _close_log_files(self, i): - if self._log_n is None: - return - if self._log_n[i] is not None: - self._log_n[i].close() - self._log_n[i] = None - def _reset(self): for i in range(self.n): writer = self._get_writer(i) @@ -135,9 +129,16 @@ def _step(self, action_n): return observation_n, reward_n, done_n, info + def _close(self): + super(Recording, self)._close() + if self._log_n is not None: + for i in range(self.n): + if self._log_n[i] is not None: + self._log_n[i].close() + self._log_n[i] = None class RecordingWriter(object): - def __init__(self, recording_dir, instance_id, channel_id): + def __init__(self, recording_dir, instance_id, channel_id, async_write=True): self.log_fn = 'universe.recording.{}.{}.{}.jsonl'.format(os.getpid(), instance_id, channel_id) log_path = os.path.join(recording_dir, self.log_fn) self.bin_fn = 'universe.recording.{}.{}.{}.bin'.format(os.getpid(), instance_id, channel_id) @@ -145,8 +146,20 @@ def __init__(self, recording_dir, instance_id, channel_id): extra_logger.info('Logging to %s and %s', log_path, self.bin_fn) self.log_f = open(log_path, 'w') self.bin_f = open(bin_path, 'wb') + self.async_write = async_write + if self.async_write: + self.q = queue.Queue() + self.t = threading.Thread(target=self.writer_main) + self.t.start() def close(self): + if self.async_write: + self.q.put(None) + self.t.join() + else: + self.close_files() + + def close_files(self): if self.bin_f is not None: self.bin_f.close() self.bin_f = None @@ -176,10 +189,26 @@ def json_encode(self, obj): else: return obj + def writer_main(self): + while True: + item = self.q.get() + if item is None: break + self.write_item(item) + self.q.task_done() + self.close_files() + def __call__(self, **kwargs): - l = json.dumps(kwargs, skipkeys=True, default=self.json_encode) - self.log_f.write(l + '\n') - self.log_f.flush() + if self.async_write: + pyprofile.gauge('recording.qsize', self.q.qsize()) + self.q.put(kwargs) + else: + self.write_item(kwargs) + + def write_item(self, item): + with pyprofile.push('recording.write'): + l = json.dumps(item, skipkeys=True, default=self.json_encode) + self.log_f.write(l + '\n') + self.log_f.flush() class RecordingAnnotator(object): def __init__(self, writer, episode_id, step_id):