Change gfile to tf.io.gfile, since we can't do `from tensorflow.io import gfile` anymore.

PiperOrigin-RevId: 282105830
afrozenator authored and copybara-github committed Nov 23, 2019
1 parent 6dcf02d commit 59b21ca
Showing 7 changed files with 37 additions and 37 deletions.
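
For orientation, here is the change in miniature: a minimal, self-contained sketch of the before/after pattern (the example path is illustrative; only `tf.io.gfile` calls that appear in the diffs below are used):

# Before (no longer importable, per the commit message):
#   from tensorflow.io import gfile
#   gfile.makedirs('/tmp/example_dir')
# After: import the top-level package and use the fully qualified module.
import tensorflow as tf

tf.io.gfile.makedirs('/tmp/example_dir')       # create dir and missing parents
print(tf.io.gfile.exists('/tmp/example_dir'))  # True
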
5 changes: 2 additions & 3 deletions trax/jaxboard.py
@@ -38,7 +38,6 @@
 from tensorflow import HistogramProto
 from tensorflow import Summary
 from tensorflow import SummaryMetadata
-from tensorflow.io import gfile
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.core.util import event_pb2
@@ -83,8 +82,8 @@ def __init__(self, log_dir, enable=True):
       multihost training.
     """
     # If needed, create log_dir directory as well as missing parent directories.
-    if not gfile.isdir(log_dir):
-      gfile.makedirs(log_dir)
+    if not tf.io.gfile.isdir(log_dir):
+      tf.io.gfile.makedirs(log_dir)
 
     self._event_writer = EventFileWriter(log_dir, 10, 120, None)
     self._step = 0
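
`tf.io.gfile` is used here instead of `os` because the same calls work on remote filesystems (e.g. `gs://...` paths). A runnable sketch of the guarded-creation idiom above, with an illustrative local path:

import tensorflow as tf

log_dir = '/tmp/jaxboard_demo'  # illustrative; could also be a gs:// path
# Create log_dir as well as any missing parent directories.
if not tf.io.gfile.isdir(log_dir):
  tf.io.gfile.makedirs(log_dir)
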
15 changes: 8 additions & 7 deletions trax/rl/base_trainer.py
@@ -22,7 +22,7 @@
 import os
 
 from absl import logging
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import utils
 

@@ -61,7 +61,7 @@ def __init__(
 
   def reset(self, output_dir):
     self._output_dir = output_dir
-    gfile.makedirs(self._output_dir)
+    tf.io.gfile.makedirs(self._output_dir)
 
   @property
   def async_mode(self):
@@ -101,7 +101,7 @@ def dump_trajectories(self, force=False):
     pkl_module = utils.get_pickle_module()
     if self.trajectory_dump_dir is None:
       return
-    gfile.makedirs(self.trajectory_dump_dir)
+    tf.io.gfile.makedirs(self.trajectory_dump_dir)
 
     trajectories = self.train_env.trajectories
     if force:
@@ -124,13 +124,13 @@ def has_any_action(trajectory):
     if ready or force:
       shard_path = os.path.join(
           self.trajectory_dump_dir, '{}.pkl'.format(self.epoch))
-      if gfile.exists(shard_path):
+      if tf.io.gfile.exists(shard_path):
         # Since we do an extra dump at the end of the training loop, we
         # sometimes dump 2 times in the same epoch. When this happens, merge the
         # two sets of trajectories.
-        with gfile.GFile(shard_path, 'rb') as f:
+        with tf.io.gfile.GFile(shard_path, 'rb') as f:
           self._trajectory_buffer = pkl_module.load(f) + self._trajectory_buffer
-      with gfile.GFile(shard_path, 'wb') as f:
+      with tf.io.gfile.GFile(shard_path, 'wb') as f:
         pkl_module.dump(self._trajectory_buffer, f)
       self._trajectory_buffer = []
 
@@ -152,5 +152,6 @@ def indicate_done(self):
     """If in async mode, workers need to know we are done."""
     if not self.async_mode:
       return
-    with gfile.GFile(os.path.join(self._output_dir, '__done__'), 'wb') as f:
+    with tf.io.gfile.GFile(
+        os.path.join(self._output_dir, '__done__'), 'wb') as f:
       f.write('')
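
A self-contained sketch of the merge-then-dump pattern from `dump_trajectories` above; the path and trajectory values are illustrative, and plain `pickle` stands in for `utils.get_pickle_module()`:

import os
import pickle

import tensorflow as tf

dump_dir = '/tmp/trajectory_dump_demo'        # illustrative
shard_path = os.path.join(dump_dir, '0.pkl')  # one shard per epoch
tf.io.gfile.makedirs(dump_dir)

buffer = ['trajectory-a']  # stand-in for real trajectory objects
if tf.io.gfile.exists(shard_path):
  # A second dump in the same epoch: merge with the shard already on disk.
  with tf.io.gfile.GFile(shard_path, 'rb') as f:
    buffer = pickle.load(f) + buffer
with tf.io.gfile.GFile(shard_path, 'wb') as f:
  pickle.dump(buffer, f)
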
6 changes: 3 additions & 3 deletions trax/rl/envs/online_tune_env.py
@@ -23,7 +23,7 @@
 import os
 
 import gym
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import inputs as trax_inputs
 from trax import layers
 from trax import models as trax_models
@@ -112,7 +112,7 @@ def __init__(self,
     self._observation_range = observation_range
 
     self._output_dir = output_dir
-    gfile.makedirs(self._output_dir)
+    tf.io.gfile.makedirs(self._output_dir)
     # Actions are indices in self._action_multipliers.
     self.action_space = gym.spaces.MultiDiscrete(
         [len(self._action_multipliers)] * len(self._control_configs)
@@ -141,7 +141,7 @@ def _next_trajectory_dir(self):
     Returns:
       A path of the new directory.
     """
-    trajectory_dirs = gfile.listdir(self._output_dir)
+    trajectory_dirs = tf.io.gfile.listdir(self._output_dir)
 
     def int_or_none(s):
       try:
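
`_next_trajectory_dir` is truncated in the hunk above; a hedged reconstruction of the numbering idea (everything after `try:` is an assumption, not the original code):

import os

import tensorflow as tf

output_dir = '/tmp/online_tune_demo'  # illustrative
tf.io.gfile.makedirs(output_dir)

def int_or_none(s):
  try:
    return int(s)
  except ValueError:
    return None

# Entries may carry a trailing '/' on some filesystems; strip before parsing.
indices = [int_or_none(d.rstrip('/')) for d in tf.io.gfile.listdir(output_dir)]
next_index = max([i for i in indices if i is not None], default=-1) + 1
print(os.path.join(output_dir, str(next_index)))
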
14 changes: 7 additions & 7 deletions trax/rl/ppo.py
@@ -66,10 +66,9 @@
 from jax import numpy as np
 from jax import random as jax_random
 import numpy as onp
-
 from tensor2tensor.envs import env_problem
 from tensor2tensor.envs import env_problem_utils
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import history as trax_history
 from trax import layers as tl
 from trax import utils
@@ -805,7 +804,8 @@ def masked_entropy(log_probs, mask):
 def get_policy_model_files(output_dir):
   return list(
       reversed(
-          sorted(gfile.glob(os.path.join(output_dir, 'model-??????.pkl')))))
+          sorted(
+              tf.io.gfile.glob(os.path.join(output_dir, 'model-??????.pkl')))))
 
 
 def get_epoch_from_policy_model_file(policy_model_file):
@@ -843,7 +843,7 @@ def maybe_restore_opt_state(output_dir,
   for model_file in get_policy_model_files(output_dir):
     logging.info('Trying to restore model from %s', model_file)
     try:
-      with gfile.GFile(model_file, 'rb') as f:
+      with tf.io.gfile.GFile(model_file, 'rb') as f:
         (policy_and_value_opt_state, policy_and_value_state, total_opt_step,
          history) = pkl_module.load(f)
       epoch = get_epoch_from_policy_model_file(model_file)
@@ -874,22 +874,22 @@ def save_opt_state(output_dir,
   pkl_module = utils.get_pickle_module()
   old_model_files = get_policy_model_files(output_dir)
   params_file = os.path.join(output_dir, 'model-%06d.pkl' % epoch)
-  with gfile.GFile(params_file, 'wb') as f:
+  with tf.io.gfile.GFile(params_file, 'wb') as f:
     pkl_module.dump((policy_and_value_opt_state, policy_and_value_state,
                      total_opt_step, history), f)
   # Keep the last k model files lying around (note k > 1 because the latest
   # model file might be in the process of getting read async).
   for path in old_model_files[LAST_N_POLICY_MODELS_TO_KEEP:]:
     if path != params_file:
-      gfile.remove(path)
+      tf.io.gfile.remove(path)
 
 
 def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
   """Initializes policy parameters from world model parameters."""
   pkl_module = utils.get_pickle_module()
   params_file = os.path.join(model_output_dir, 'model.pkl')
   # Don't use trax.load_trainer_state to avoid a circular import.
-  with gfile.GFile(params_file, 'rb') as f:
+  with tf.io.gfile.GFile(params_file, 'rb') as f:
     model_params = pkl_module.load(f)[0][0]
   # TODO(pkozakowski): The following, brittle line of code is hardcoded for
   # transplanting parameters from TransformerLM to TransformerDecoder-based
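
The glob-and-prune pattern in `save_opt_state` keeps only the newest checkpoints; a sketch (the keep count and directory are illustrative):

import os

import tensorflow as tf

output_dir = '/tmp/ppo_demo'  # illustrative
keep_last_n = 4               # illustrative stand-in for the trax constant
tf.io.gfile.makedirs(output_dir)

def get_policy_model_files(output_dir):
  # Fixed-width epoch numbers make the lexicographic sort chronological;
  # reversing it puts the newest checkpoint first.
  return list(reversed(sorted(
      tf.io.gfile.glob(os.path.join(output_dir, 'model-??????.pkl')))))

# Prune everything older than the newest keep_last_n checkpoints.
for path in get_policy_model_files(output_dir)[keep_last_n:]:
  tf.io.gfile.remove(path)
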
6 changes: 3 additions & 3 deletions trax/rl/simple.py
@@ -27,7 +27,7 @@
 import numpy as np
 from tensor2tensor.envs import env_problem_utils
 from tensor2tensor.envs import trajectory
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import utils
 

@@ -37,11 +37,11 @@ def load_trajectories(trajectory_dir, eval_frac):
   train_trajectories = []
   eval_trajectories = []
   # Search the entire directory subtree for trajectories.
-  for (subdir, _, filenames) in gfile.walk(trajectory_dir):
+  for (subdir, _, filenames) in tf.io.gfile.walk(trajectory_dir):
     for filename in filenames:
       shard_path = os.path.join(subdir, filename)
       try:
-        with gfile.GFile(shard_path, 'rb') as f:
+        with tf.io.gfile.GFile(shard_path, 'rb') as f:
           trajectories = pkl_module.load(f)
           pivot = int(len(trajectories) * (1 - eval_frac))
           train_trajectories.extend(trajectories[:pivot])
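
A runnable sketch of the subtree walk in `load_trajectories` (illustrative path; plain `pickle` stands in for `utils.get_pickle_module()`, and the original's `try`/skip around unreadable shards is dropped for brevity):

import os
import pickle

import tensorflow as tf

trajectory_dir = '/tmp/trajectories_demo'  # illustrative
eval_frac = 0.1
tf.io.gfile.makedirs(trajectory_dir)

train_trajectories, eval_trajectories = [], []
# Search the entire directory subtree for pickled trajectory shards.
for subdir, _, filenames in tf.io.gfile.walk(trajectory_dir):
  for filename in filenames:
    shard_path = os.path.join(subdir, filename)
    with tf.io.gfile.GFile(shard_path, 'rb') as f:
      trajectories = pickle.load(f)
    # Per-shard split: the first (1 - eval_frac) go to training.
    pivot = int(len(trajectories) * (1 - eval_frac))
    train_trajectories.extend(trajectories[:pivot])
    eval_trajectories.extend(trajectories[pivot:])
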
6 changes: 3 additions & 3 deletions trax/rl/simple_trainer.py
@@ -28,7 +28,7 @@
 from absl import logging
 import gin
 from matplotlib import pyplot as plt
-from tensorflow.io import gfile
+import tensorflow as tf
 from trax import inputs as trax_inputs
 from trax import jaxboard
 from trax import trainer_lib
@@ -83,9 +83,9 @@ def __init__(self,
     self._n_model_train_steps_per_epoch = n_model_train_steps_per_epoch
     self._data_eval_frac = data_eval_frac
 
-    gfile.makedirs(self._model_dir)
+    tf.io.gfile.makedirs(self._model_dir)
     if initial_model is not None:
-      gfile.copy(
+      tf.io.gfile.copy(
           initial_model,
           os.path.join(self._model_dir, 'model.pkl'),
           overwrite=True,
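
The `tf.io.gfile.copy` call above seeds the model directory from a pre-trained checkpoint; a sketch with illustrative paths (the `exists` guard is added here so the snippet runs standalone):

import os

import tensorflow as tf

model_dir = '/tmp/simple_trainer_demo'    # illustrative
initial_model = '/tmp/initial_model.pkl'  # illustrative

tf.io.gfile.makedirs(model_dir)
if tf.io.gfile.exists(initial_model):
  # overwrite=True clobbers any model.pkl already at the destination.
  tf.io.gfile.copy(
      initial_model,
      os.path.join(model_dir, 'model.pkl'),
      overwrite=True,
  )
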
22 changes: 11 additions & 11 deletions trax/trainer_lib.py
@@ -35,7 +35,6 @@
 import numpy
 import six
 import tensorflow.compat.v2 as tf
-from tensorflow.io import gfile
 from trax import backend
 from trax import history as trax_history
 from trax import inputs as trax_inputs
@@ -271,7 +270,7 @@ def reset(self, output_dir):
       output_dir: Output directory.
     """
     self._output_dir = output_dir
-    gfile.makedirs(output_dir)
+    tf.io.gfile.makedirs(output_dir)
     # Create summary writers and history.
     if self._should_write_summaries:
       self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
@@ -426,7 +425,7 @@ def update_nontrainable_params(self):
   def save_gin(self):
     config_path = os.path.join(self._output_dir, 'config.gin')
     config_str = gin.operative_config_str()
-    with gfile.GFile(config_path, 'w') as f:
+    with tf.io.gfile.GFile(config_path, 'w') as f:
       f.write(config_str)
     sw = self._train_sw
     if sw:
@@ -448,11 +447,11 @@ def save_state(self, keep):
 
     pkl_module = utils.get_pickle_module()
     weights_file = os.path.join(output_dir, 'model.pkl')
-    with gfile.GFile(weights_file, 'wb') as f:
+    with tf.io.gfile.GFile(weights_file, 'wb') as f:
       pkl_module.dump((tuple(opt_state), step, history, model_state), f)
     if keep:
       weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
-      with gfile.GFile(weights_file, 'wb') as f:
+      with tf.io.gfile.GFile(weights_file, 'wb') as f:
         pkl_module.dump((tuple(opt_state), step, history, model_state), f)
     log('Model saved to %s' % weights_file, stdout=False)
 
@@ -468,17 +467,18 @@ def save_computation_graphs(self, save_backward_graph):
     forward_computation = jax.xla_computation(self._model_predict_eval)(
         batch, weights=weights, state=self._model_state[0],
         rng=self._rngs[0])
-    with gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
       f.write(forward_computation.GetHloText())
-    with gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
       f.write(forward_computation.GetHloDotGraph())
     backward_computation = jax.xla_computation(self._jit_update_fn)(
         self._step, self._opt_state, batch, self._model_state,
         self._rngs)
-    with gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
+    with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
       f.write(backward_computation.GetHloText())
     if save_backward_graph:  # Backward graphs can be large so we guard it.
-      with gfile.GFile(os.path.join(output_dir, 'backward.dot'), 'w') as f:
+      with tf.io.gfile.GFile(
+          os.path.join(output_dir, 'backward.dot'), 'w') as f:
         f.write(backward_computation.GetHloDotGraph())
 
   def log_step(self, step_message):
@@ -782,12 +782,12 @@ def epochs(total_steps, steps_to_skip, epoch_steps):
 def load_trainer_state(output_dir):
   """Returns a TrainerState instance loaded from the given `output_dir`."""
   weights_file = os.path.join(output_dir, 'model.pkl')
-  if not gfile.exists(weights_file):
+  if not tf.io.gfile.exists(weights_file):
     return TrainerState(step=None, opt_state=None,
                         history=trax_history.History(), model_state=None)
 
   pkl_module = utils.get_pickle_module()
-  with gfile.GFile(weights_file, 'rb') as f:
+  with tf.io.gfile.GFile(weights_file, 'rb') as f:
     (opt_state, step, history, model_state) = pkl_module.load(f)
   log('Model loaded from %s at step %d' % (weights_file, step))
   logging.debug('From loaded model : history = %s', history)
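
`load_trainer_state` uses a restore-or-default guard so first runs and restarts share one code path; a minimal sketch (illustrative path, plain `pickle`, and `None` standing in for a fresh `TrainerState`):

import os
import pickle

import tensorflow as tf

output_dir = '/tmp/trainer_demo'  # illustrative
weights_file = os.path.join(output_dir, 'model.pkl')

if not tf.io.gfile.exists(weights_file):
  state = None  # stand-in for a fresh, empty TrainerState
else:
  with tf.io.gfile.GFile(weights_file, 'rb') as f:
    opt_state, step, history, model_state = pickle.load(f)
    state = (opt_state, step, history, model_state)
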
