diff --git a/trax/jaxboard.py b/trax/jaxboard.py
index 024c9942f..241c9694e 100644
--- a/trax/jaxboard.py
+++ b/trax/jaxboard.py
@@ -38,7 +38,6 @@
 from tensorflow import HistogramProto
 from tensorflow import Summary
 from tensorflow import SummaryMetadata
-from tensorflow.io import gfile
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.core.util import event_pb2
@@ -83,8 +82,8 @@ def __init__(self, log_dir, enable=True):
         multihost training.
     """
     # If needed, create log_dir directory as well as missing parent directories.
-    if not gfile.isdir(log_dir):
-      gfile.makedirs(log_dir)
+    if not tf.io.gfile.isdir(log_dir):
+      tf.io.gfile.makedirs(log_dir)
 
     self._event_writer = EventFileWriter(log_dir, 10, 120, None)
     self._step = 0
diff --git a/trax/rl/base_trainer.py b/trax/rl/base_trainer.py
index 5717dfec4..96e4c23e0 100644
--- a/trax/rl/base_trainer.py
+++ b/trax/rl/base_trainer.py
@@ -22,7 +22,7 @@
 import os
 
 from absl import logging
-from tensorflow.io import gfile
+import tensorflow as tf
 
 from trax import utils
 
@@ -61,7 +61,7 @@ def __init__(
 
   def reset(self, output_dir):
     self._output_dir = output_dir
-    gfile.makedirs(self._output_dir)
+    tf.io.gfile.makedirs(self._output_dir)
 
   @property
   def async_mode(self):
@@ -101,7 +101,7 @@ def dump_trajectories(self, force=False):
     pkl_module = utils.get_pickle_module()
     if self.trajectory_dump_dir is None:
       return
-    gfile.makedirs(self.trajectory_dump_dir)
+    tf.io.gfile.makedirs(self.trajectory_dump_dir)
 
     trajectories = self.train_env.trajectories
     if force:
@@ -124,13 +124,13 @@ def has_any_action(trajectory):
     if ready or force:
       shard_path = os.path.join(
           self.trajectory_dump_dir, '{}.pkl'.format(self.epoch))
-      if gfile.exists(shard_path):
+      if tf.io.gfile.exists(shard_path):
         # Since we do an extra dump at the end of the training loop, we
         # sometimes dump 2 times in the same epoch. When this happens, merge the
         # two sets of trajectories.
-        with gfile.GFile(shard_path, 'rb') as f:
+        with tf.io.gfile.GFile(shard_path, 'rb') as f:
           self._trajectory_buffer = pkl_module.load(f) + self._trajectory_buffer
-      with gfile.GFile(shard_path, 'wb') as f:
+      with tf.io.gfile.GFile(shard_path, 'wb') as f:
         pkl_module.dump(self._trajectory_buffer, f)
       self._trajectory_buffer = []
 
@@ -152,5 +152,6 @@ def indicate_done(self):
     """If in async mode, workers need to know we are done."""
     if not self.async_mode:
       return
-    with gfile.GFile(os.path.join(self._output_dir, '__done__'), 'wb') as f:
+    with tf.io.gfile.GFile(
+        os.path.join(self._output_dir, '__done__'), 'wb') as f:
       f.write('')
diff --git a/trax/rl/envs/online_tune_env.py b/trax/rl/envs/online_tune_env.py
index 8297759a6..d21708f22 100644
--- a/trax/rl/envs/online_tune_env.py
+++ b/trax/rl/envs/online_tune_env.py
@@ -23,7 +23,7 @@
 import os
 
 import gym
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import inputs as trax_inputs
 from trax import layers
 from trax import models as trax_models
@@ -112,7 +112,7 @@ def __init__(self,
     self._observation_range = observation_range
 
     self._output_dir = output_dir
-    gfile.makedirs(self._output_dir)
+    tf.io.gfile.makedirs(self._output_dir)
     # Actions are indices in self._action_multipliers.
     self.action_space = gym.spaces.MultiDiscrete(
         [len(self._action_multipliers)] * len(self._control_configs)
@@ -141,7 +141,7 @@ def _next_trajectory_dir(self):
     Returns:
       A path of the new directory.
     """
-    trajectory_dirs = gfile.listdir(self._output_dir)
+    trajectory_dirs = tf.io.gfile.listdir(self._output_dir)
 
     def int_or_none(s):
       try:
diff --git a/trax/rl/ppo.py b/trax/rl/ppo.py
index 60984d656..f990f322e 100644
--- a/trax/rl/ppo.py
+++ b/trax/rl/ppo.py
@@ -66,10 +66,9 @@
 from jax import numpy as np
 from jax import random as jax_random
 import numpy as onp
-
 from tensor2tensor.envs import env_problem
 from tensor2tensor.envs import env_problem_utils
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import history as trax_history
 from trax import layers as tl
 from trax import utils
@@ -805,7 +804,8 @@ def masked_entropy(log_probs, mask):
 def get_policy_model_files(output_dir):
   return list(
       reversed(
-          sorted(gfile.glob(os.path.join(output_dir, 'model-??????.pkl')))))
+          sorted(
+              tf.io.gfile.glob(os.path.join(output_dir, 'model-??????.pkl')))))
 
 
 def get_epoch_from_policy_model_file(policy_model_file):
@@ -843,7 +843,7 @@ def maybe_restore_opt_state(output_dir,
   for model_file in get_policy_model_files(output_dir):
     logging.info('Trying to restore model from %s', model_file)
     try:
-      with gfile.GFile(model_file, 'rb') as f:
+      with tf.io.gfile.GFile(model_file, 'rb') as f:
         (policy_and_value_opt_state, policy_and_value_state, total_opt_step,
          history) = pkl_module.load(f)
       epoch = get_epoch_from_policy_model_file(model_file)
@@ -874,14 +874,14 @@ def save_opt_state(output_dir,
   pkl_module = utils.get_pickle_module()
   old_model_files = get_policy_model_files(output_dir)
   params_file = os.path.join(output_dir, 'model-%06d.pkl' % epoch)
-  with gfile.GFile(params_file, 'wb') as f:
+  with tf.io.gfile.GFile(params_file, 'wb') as f:
     pkl_module.dump((policy_and_value_opt_state, policy_and_value_state,
                      total_opt_step, history), f)
   # Keep the last k model files lying around (note k > 1 because the latest
   # model file might be in the process of getting read async).
   for path in old_model_files[LAST_N_POLICY_MODELS_TO_KEEP:]:
     if path != params_file:
-      gfile.remove(path)
+      tf.io.gfile.remove(path)
 
 
 def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
@@ -889,7 +889,7 @@ def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
   """Initializes policy parameters from world model parameters."""
   pkl_module = utils.get_pickle_module()
   params_file = os.path.join(model_output_dir, 'model.pkl')
   # Don't use trax.load_trainer_state to avoid a circular import.
-  with gfile.GFile(params_file, 'rb') as f:
+  with tf.io.gfile.GFile(params_file, 'rb') as f:
     model_params = pkl_module.load(f)[0][0]
   # TODO(pkozakowski): The following, brittle line of code is hardcoded for
   # transplanting parameters from TransformerLM to TransformerDecoder-based
diff --git a/trax/rl/simple.py b/trax/rl/simple.py
index cefc8551e..6ef2f32d8 100644
--- a/trax/rl/simple.py
+++ b/trax/rl/simple.py
@@ -27,7 +27,7 @@
 import numpy as np
 from tensor2tensor.envs import env_problem_utils
 from tensor2tensor.envs import trajectory
-from tensorflow.io import gfile
+import tensorflow as tf
 
 from trax import utils
 
@@ -37,11 +37,11 @@ def load_trajectories(trajectory_dir, eval_frac):
   train_trajectories = []
   eval_trajectories = []
   # Search the entire directory subtree for trajectories.
-  for (subdir, _, filenames) in gfile.walk(trajectory_dir):
+  for (subdir, _, filenames) in tf.io.gfile.walk(trajectory_dir):
     for filename in filenames:
       shard_path = os.path.join(subdir, filename)
       try:
-        with gfile.GFile(shard_path, 'rb') as f:
+        with tf.io.gfile.GFile(shard_path, 'rb') as f:
           trajectories = pkl_module.load(f)
         pivot = int(len(trajectories) * (1 - eval_frac))
         train_trajectories.extend(trajectories[:pivot])
diff --git a/trax/rl/simple_trainer.py b/trax/rl/simple_trainer.py
index db6d1de10..66e6d8710 100644
--- a/trax/rl/simple_trainer.py
+++ b/trax/rl/simple_trainer.py
@@ -28,7 +28,7 @@
 from absl import logging
 import gin
 from matplotlib import pyplot as plt
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import inputs as trax_inputs
 from trax import jaxboard
 from trax import trainer_lib
@@ -83,9 +83,9 @@ def __init__(self,
     self._n_model_train_steps_per_epoch = n_model_train_steps_per_epoch
     self._data_eval_frac = data_eval_frac
 
-    gfile.makedirs(self._model_dir)
+    tf.io.gfile.makedirs(self._model_dir)
     if initial_model is not None:
-      gfile.copy(
+      tf.io.gfile.copy(
           initial_model,
           os.path.join(self._model_dir, 'model.pkl'),
           overwrite=True,
diff --git a/trax/trainer_lib.py b/trax/trainer_lib.py
index 93e4648b7..44b33330f 100644
--- a/trax/trainer_lib.py
+++ b/trax/trainer_lib.py
@@ -35,7 +35,6 @@
 import numpy
 import six
 import tensorflow.compat.v2 as tf
-from tensorflow.io import gfile
 from trax import backend
 from trax import history as trax_history
 from trax import inputs as trax_inputs
@@ -271,7 +270,7 @@ def reset(self, output_dir):
       output_dir: Output directory.
     """
    self._output_dir = output_dir
-    gfile.makedirs(output_dir)
+    tf.io.gfile.makedirs(output_dir)
     # Create summary writers and history.
     if self._should_write_summaries:
       self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
@@ -426,7 +425,7 @@ def update_nontrainable_params(self):
   def save_gin(self):
     config_path = os.path.join(self._output_dir, 'config.gin')
     config_str = gin.operative_config_str()
-    with gfile.GFile(config_path, 'w') as f:
+    with tf.io.gfile.GFile(config_path, 'w') as f:
       f.write(config_str)
     sw = self._train_sw
     if sw:
@@ -448,11 +447,11 @@ def save_state(self, keep):
 
     pkl_module = utils.get_pickle_module()
     weights_file = os.path.join(output_dir, 'model.pkl')
-    with gfile.GFile(weights_file, 'wb') as f:
+    with tf.io.gfile.GFile(weights_file, 'wb') as f:
       pkl_module.dump((tuple(opt_state), step, history, model_state), f)
     if keep:
       weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
-      with gfile.GFile(weights_file, 'wb') as f:
+      with tf.io.gfile.GFile(weights_file, 'wb') as f:
         pkl_module.dump((tuple(opt_state), step, history, model_state), f)
     log('Model saved to %s' % weights_file, stdout=False)
 
@@ -468,17 +467,18 @@ def save_computation_graphs(self, save_backward_graph):
     forward_computation = jax.xla_computation(self._model_predict_eval)(
         batch, weights=weights, state=self._model_state[0],
         rng=self._rngs[0])
-    with gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
       f.write(forward_computation.GetHloText())
-    with gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
       f.write(forward_computation.GetHloDotGraph())
     backward_computation = jax.xla_computation(self._jit_update_fn)(
         self._step, self._opt_state, batch, self._model_state, self._rngs)
-    with gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
       f.write(backward_computation.GetHloText())
     if save_backward_graph:  # Backward graphs can be large so we guard it.
-      with gfile.GFile(os.path.join(output_dir, 'backward.dot'), 'w') as f:
+      with tf.io.gfile.GFile(
+          os.path.join(output_dir, 'backward.dot'), 'w') as f:
         f.write(backward_computation.GetHloDotGraph())
 
 
   def log_step(self, step_message):
@@ -782,12 +782,12 @@ def epochs(total_steps, steps_to_skip, epoch_steps):
 def load_trainer_state(output_dir):
   """Returns a TrainerState instance loaded from the given `output_dir`."""
   weights_file = os.path.join(output_dir, 'model.pkl')
-  if not gfile.exists(weights_file):
+  if not tf.io.gfile.exists(weights_file):
     return TrainerState(step=None, opt_state=None,
                         history=trax_history.History(), model_state=None)
 
   pkl_module = utils.get_pickle_module()
-  with gfile.GFile(weights_file, 'rb') as f:
+  with tf.io.gfile.GFile(weights_file, 'rb') as f:
     (opt_state, step, history, model_state) = pkl_module.load(f)
   log('Model loaded from %s at step %d' % (weights_file, step))
   logging.debug('From loaded model : history = %s', history)
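Reviewer note: every call site above is the same mechanical rewrite. The direct import `from tensorflow.io import gfile` (flagged by the `g-direct-tensorflow-import` lint) is dropped in favor of plain `import tensorflow as tf`, and each `gfile.<op>(...)` becomes `tf.io.gfile.<op>(...)`. Below is a minimal, self-contained sketch of the resulting idiom, not part of the patch; the scratch path is hypothetical, and only the public TF 2.x `tf.io.gfile` API is assumed:

```python
import os
import tensorflow as tf

# Hypothetical scratch directory, for illustration only.
output_dir = '/tmp/gfile_demo'

# Same directory-creation idiom as SummaryWriter.__init__ and reset() above.
if not tf.io.gfile.isdir(output_dir):
  tf.io.gfile.makedirs(output_dir)

# Same read/write idiom as the checkpoint paths above.
path = os.path.join(output_dir, 'model.pkl')
with tf.io.gfile.GFile(path, 'wb') as f:
  f.write(b'payload')
assert tf.io.gfile.exists(path)
```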