Reworked training resume again.
MillionIntegrals committed Apr 7, 2019
1 parent 2547da7 commit 599a8c4
Showing 19 changed files with 64 additions and 43 deletions.
7 changes: 5 additions & 2 deletions vel/api/learner.py
@@ -33,9 +33,12 @@ def number_of_parameters(self):
""" Count model parameters """
return sum(p.numel() for p in self.model.parameters())

- def initialize_training(self, training_info: TrainingInfo):
+ def initialize_training(self, training_info: TrainingInfo, model_state=None, hidden_state=None):
""" Prepare for training """
- self.model.reset_weights()
+ if model_state is None:
+     self.model.reset_weights()
+ else:
+     self.model.load_state_dict(model_state)

def run_epoch(self, epoch_info: EpochInfo, source: 'vel.api.Source'):
""" Run full epoch of learning """
2 changes: 1 addition & 1 deletion vel/api/storage.py
@@ -8,7 +8,7 @@ def last_epoch_idx(self) -> int:
""" Return last checkpointed epoch idx for given configuration. Returns 0 if no results have been stored """
raise NotImplementedError

- def resume(self, train_info: TrainingInfo, model: Model) -> dict:
+ def load(self, train_info: TrainingInfo) -> (dict, dict):
"""
Resume learning process and return loaded hidden state dictionary
"""
3 changes: 2 additions & 1 deletion vel/commands/phase_train_command.py
@@ -134,7 +134,8 @@ def resume_training(self, learner, callbacks, metrics) -> (TrainingInfo, dict):
learner.initialize_training(training_info)
hidden_state = None
else:
- hidden_state = self.storage.resume(training_info, learner.model)
+ model_state, hidden_state = self.storage.load(training_info)
+ learner.initialize_training(training_info, model_state, hidden_state)

return training_info, hidden_state
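With storage.load() returning both dictionaries, the command-level resume path reduces to the pattern sketched below. FakeStorage and FakeLearner are hypothetical stand-ins; only the call pattern mirrors the diff.

class FakeStorage:
    """ Hypothetical storage returning (model_state, hidden_state) tuples """
    def __init__(self, checkpoints):
        self.checkpoints = checkpoints

    def last_epoch_idx(self):
        return max(self.checkpoints, default=0)

    def load(self, training_info):
        return self.checkpoints[training_info["start_epoch_idx"]]

class FakeLearner:
    """ Hypothetical learner that applies whatever state it is given """
    def initialize_training(self, training_info, model_state=None, hidden_state=None):
        self.weights = {"w": 0.0} if model_state is None else dict(model_state)

storage = FakeStorage({3: ({"w": 1.5}, {"optimizer": {}})})
learner = FakeLearner()
training_info = {"start_epoch_idx": storage.last_epoch_idx()}

if training_info["start_epoch_idx"] == 0:
    learner.initialize_training(training_info)
    hidden_state = None
else:
    model_state, hidden_state = storage.load(training_info)
    learner.initialize_training(training_info, model_state, hidden_state)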

3 changes: 2 additions & 1 deletion vel/commands/rnn/generate_text.py
@@ -32,7 +32,8 @@ def run(self):
run_name=self.model_config.run_name,
)

- self.storage.resume(training_info, model)
+ model_state, hidden_state = self.storage.load(training_info)
+ model.load_state_dict(model_state)

model.eval()
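The evaluation-style commands (text generation, enjoy, environment evaluation, movie recording) all follow the same two-step pattern after this change. A minimal sketch, with a dictionary of zero tensors standing in for whatever storage.load() would return:

import torch
import torch.nn as nn

model = nn.Linear(16, 16)

# Stand-in for: model_state, hidden_state = self.storage.load(training_info)
model_state = {name: torch.zeros_like(value) for name, value in model.state_dict().items()}

model.load_state_dict(model_state)  # apply the checkpointed weights
model.eval()                        # switch to inference mode before generating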

3 changes: 2 additions & 1 deletion vel/commands/train_command.py
@@ -93,7 +93,8 @@ def resume_training(self, learner, callbacks, metrics) -> api.TrainingInfo:
training_info.initialize()
learner.initialize_training(training_info)
else:
- self.storage.resume(training_info, learner.model)
+ model_state, hidden_state = self.storage.load(training_info)
+ learner.initialize_training(training_info, model_state, hidden_state)

return training_info

2 changes: 1 addition & 1 deletion vel/rl/algo/distributional_dqn.py
@@ -27,7 +27,7 @@ def __init__(self, model_factory: ModelFactory, discount_factor: float, double_d
self.support_atoms = None
self.atom_delta = None

- def initialize(self, model, environment, device):
+ def initialize(self, training_info, model, environment, device):
""" Initialize policy gradient from reinforcer settings """
self.target_model = self.model_factory.instantiate(action_space=environment.action_space).to(device)
self.target_model.load_state_dict(model.state_dict())
2 changes: 1 addition & 1 deletion vel/rl/algo/dqn.py
@@ -22,7 +22,7 @@ def __init__(self, model_factory: ModelFactory, discount_factor: float, double_d

self.target_model = None

- def initialize(self, model, environment, device):
+ def initialize(self, training_info, model, environment, device):
""" Initialize policy gradient from reinforcer settings """
self.target_model = self.model_factory.instantiate(action_space=environment.action_space).to(device)
self.target_model.load_state_dict(model.state_dict())
21 changes: 8 additions & 13 deletions vel/rl/algo/policy_gradient/acer.py
@@ -13,9 +13,9 @@ def select_indices(tensor, indices):
class AcerPolicyGradient(OptimizerAlgoBase):
""" Actor-Critic with Experience Replay - policy gradient calculations """

- def __init__(self, model_factory, discount_factor, trust_region: bool=True, entropy_coefficient: float=0.01,
-              q_coefficient: float=0.5, rho_cap: float=10.0, retrace_rho_cap: float=1.0, max_grad_norm: float=None,
-              average_model_alpha=0.99, trust_region_delta=1.0):
+ def __init__(self, model_factory, discount_factor, trust_region: bool = True, entropy_coefficient: float = 0.01,
+              q_coefficient: float = 0.5, rho_cap: float = 10.0, retrace_rho_cap: float = 1.0,
+              max_grad_norm: float = None, average_model_alpha: float = 0.99, trust_region_delta: float = 1.0):
super().__init__(max_grad_norm)

self.discount_factor = discount_factor
@@ -31,25 +31,20 @@ def __init__(self, model_factory, discount_factor, trust_region: bool=True, entr

# Trust region settings
self.average_model = None
- self.average_model_initialized = False
self.average_model_alpha = average_model_alpha
self.trust_region_delta = trust_region_delta

- def initialize(self, model, environment, device):
+ def initialize(self, training_info, model, environment, device):
""" Initialize policy gradient from reinforcer settings """
if self.trust_region:
self.average_model = self.model_factory.instantiate(action_space=environment.action_space).to(device)
+ self.average_model.load_state_dict(model.state_dict())

def update_average_model(self, model):
""" Update weights of the average model with new model observation """
- if not self.average_model_initialized:
-     # Initialize average model to have the same weights as the main model
-     self.average_model.load_state_dict(model.state_dict())
-     self.average_model_initialized = True
- else:
-     for model_param, average_param in zip(model.parameters(), self.average_model.parameters()):
-         # EWMA average model update
-         average_param.data.mul_(self.average_model_alpha).add_(model_param.data * (1 - self.average_model_alpha))
+ for model_param, average_param in zip(model.parameters(), self.average_model.parameters()):
+     # EWMA average model update
+     average_param.data.mul_(self.average_model_alpha).add_(model_param.data * (1 - self.average_model_alpha))

def calculate_gradient(self, batch_info, device, model, rollout):
""" Calculate loss of the supplied rollout """
3 changes: 2 additions & 1 deletion vel/rl/algo/policy_gradient/ddpg.py
@@ -19,10 +19,11 @@ def __init__(self, model_factory, discount_factor: float, tau: float, max_grad_n

self.target_model = None

- def initialize(self, model, environment, device):
+ def initialize(self, training_info, model, environment, device):
""" Initialize algo from reinforcer settings """
self.target_model = self.model_factory.instantiate(action_space=environment.action_space).to(device)
self.target_model.load_state_dict(model.state_dict())
+ self.target_model.eval()

def calculate_gradient(self, batch_info, device, model, rollout):
""" Calculate loss of the supplied rollout """
2 changes: 1 addition & 1 deletion vel/rl/api/algo_base.py
@@ -17,7 +17,7 @@ def clip_gradients(batch_result, model, max_grad_norm):
class AlgoBase:
""" Base class for algo reinforcement calculations """

- def initialize(self, model, environment, device):
+ def initialize(self, training_info, model, environment, device):
""" Initialize algo from reinforcer settings """
pass
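Concrete algorithms override this hook; with the extra argument they can inspect the training run, for example its start epoch, while building their auxiliary models. A hypothetical subclass sketch (NoisyExplorationAlgo and its warm-up flag are invented for illustration):

from types import SimpleNamespace

class AlgoBase:
    """ Base class for algo reinforcement calculations (as in the diff above) """
    def initialize(self, training_info, model, environment, device):
        pass

class NoisyExplorationAlgo(AlgoBase):
    """ Hypothetical subclass showing how the new argument can be used """
    def initialize(self, training_info, model, environment, device):
        # Resume-aware setup: e.g. skip warm-up noise if we are continuing a run
        self.warmup_needed = training_info.start_epoch_idx == 0

algo = NoisyExplorationAlgo()
algo.initialize(SimpleNamespace(start_epoch_idx=3), model=None, environment=None, device="cpu")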

2 changes: 1 addition & 1 deletion vel/rl/api/reinforcer_base.py
@@ -9,7 +9,7 @@ class ReinforcerBase:
Learner version for reinforcement-learning problems.
"""

- def initialize_training(self, training_info: TrainingInfo):
+ def initialize_training(self, training_info: TrainingInfo, model_state=None, hidden_state=None):
""" Run the initialization procedure """
pass

2 changes: 1 addition & 1 deletion vel/rl/commands/enjoy.py
@@ -33,7 +33,7 @@ def run(self):
run_name=self.model_config.run_name
)

- self.storage.resume(training_info, model)
+ model_state, hidden_state = self.storage.load(training_info)
+ model.load_state_dict(model_state)

model.eval()

3 changes: 2 additions & 1 deletion vel/rl/commands/evaluate_env_command.py
@@ -38,7 +38,8 @@ def run(self):
start_epoch_idx=self.storage.last_epoch_idx(), run_name=self.model_config.run_name
)

- self.storage.resume(training_info, model)
+ model_state, hidden_state = self.storage.load(training_info)
+ model.load_state_dict(model_state)

print("Loading model trained for {} epochs".format(training_info.start_epoch_idx))

3 changes: 2 additions & 1 deletion vel/rl/commands/record_movie_command.py
@@ -35,7 +35,8 @@ def run(self):
run_name=self.model_config.run_name
)

- self.storage.resume(training_info, model)
+ model_state, hidden_state = self.storage.load(training_info)
+ model.load_state_dict(model_state)

model.eval()

3 changes: 2 additions & 1 deletion vel/rl/commands/rl_train_command.py
@@ -133,7 +133,8 @@ def resume_training(self, reinforcer, callbacks, metrics) -> TrainingInfo:
training_info.initialize()
reinforcer.initialize_training(training_info)
else:
- self.storage.resume(training_info, reinforcer.model)
+ model_state, hidden_state = self.storage.load(training_info)
+ reinforcer.initialize_training(training_info, model_state, hidden_state)

return training_info

14 changes: 10 additions & 4 deletions vel/rl/reinforcers/buffered_mixed_policy_iteration_reinforcer.py
@@ -4,7 +4,7 @@
import torch
import tqdm

- from vel.api import EpochInfo, BatchInfo, Model, ModelFactory
+ from vel.api import TrainingInfo, EpochInfo, BatchInfo, Model, ModelFactory
from vel.openai.baselines.common.vec_env import VecEnv
from vel.rl.api import (
ReinforcerBase, ReinforcerFactory, VecEnvFactory, ReplayEnvRollerBase, AlgoBase, ReplayEnvRollerFactoryBase
@@ -61,10 +61,16 @@ def model(self) -> Model:
""" Model trained by this reinforcer """
return self._trained_model

- def initialize_training(self, training_info):
+ def initialize_training(self, training_info: TrainingInfo, model_state=None, hidden_state=None):
""" Prepare models for training """
- self.model.reset_weights()
- self.algo.initialize(model=self.model, environment=self.environment, device=self.device)
+ if model_state is not None:
+     self.model.load_state_dict(model_state)
+ else:
+     self.model.reset_weights()
+
+ self.algo.initialize(
+     training_info=training_info, model=self.model, environment=self.environment, device=self.device
+ )

def train_epoch(self, epoch_info: EpochInfo, interactive=True):
""" Train model on an epoch of a fixed number of batch updates """
14 changes: 10 additions & 4 deletions vel/rl/reinforcers/buffered_off_policy_iteration_reinforcer.py
@@ -4,7 +4,7 @@

import torch

- from vel.api import BatchInfo, EpochInfo, Model, ModelFactory
+ from vel.api import TrainingInfo, EpochInfo, BatchInfo, Model, ModelFactory
from vel.openai.baselines.common.vec_env import VecEnv
from vel.rl.api import (
ReinforcerBase, ReinforcerFactory, ReplayEnvRollerBase, AlgoBase, VecEnvFactory, ReplayEnvRollerFactoryBase
@@ -59,10 +59,16 @@ def metrics(self) -> list:
def model(self) -> Model:
return self._trained_model

- def initialize_training(self, training_info):
+ def initialize_training(self, training_info: TrainingInfo, model_state=None, hidden_state=None):
""" Prepare models for training """
- self.model.reset_weights()
- self.algo.initialize(model=self.model, environment=self.environment, device=self.device)
+ if model_state is not None:
+     self.model.load_state_dict(model_state)
+ else:
+     self.model.reset_weights()
+
+ self.algo.initialize(
+     training_info=training_info, model=self.model, environment=self.environment, device=self.device
+ )

def train_epoch(self, epoch_info: EpochInfo, interactive=True) -> None:
""" Train model for a single epoch """
12 changes: 8 additions & 4 deletions vel/rl/reinforcers/on_policy_iteration_reinforcer.py
@@ -4,7 +4,7 @@
import torch
import tqdm

- from vel.api import Model, ModelFactory, EpochInfo, BatchInfo
+ from vel.api import Model, ModelFactory, TrainingInfo, EpochInfo, BatchInfo
from vel.rl.api import ReinforcerBase, ReinforcerFactory, VecEnvFactory, EnvRollerFactoryBase, EnvRollerBase, AlgoBase
from vel.rl.metrics import (
FPSMetric, EpisodeLengthMetric, EpisodeRewardMetricQuantile,
@@ -60,11 +60,15 @@ def model(self) -> Model:
""" Model trained by this reinforcer """
return self._trained_model

- def initialize_training(self, training_info):
+ def initialize_training(self, training_info: TrainingInfo, model_state=None, hidden_state=None):
""" Prepare models for training """
- self.model.reset_weights()
+ if model_state is not None:
+     self.model.load_state_dict(model_state)
+ else:
+     self.model.reset_weights()
+
self.algo.initialize(
- model=self.model, environment=self.env_roller.environment, device=self.device
+ training_info=training_info, model=self.model, environment=self.env_roller.environment, device=self.device
)

def train_epoch(self, epoch_info: EpochInfo, interactive=True) -> None:
6 changes: 3 additions & 3 deletions vel/storage/classic.py
@@ -31,19 +31,19 @@ def reset(self, configuration: dict) -> None:
self.clean(0)
self.backend.store_config(configuration)

- def resume(self, train_info: TrainingInfo, model: Model) -> dict:
+ def load(self, train_info: TrainingInfo) -> (dict, dict):
"""
Resume learning process and return loaded hidden state dictionary
"""
last_epoch = train_info.start_epoch_idx

- model.load_state_dict(torch.load(self.checkpoint_filename(last_epoch)))
+ model_state = torch.load(self.checkpoint_filename(last_epoch))
hidden_state = torch.load(self.checkpoint_hidden_filename(last_epoch))

self.checkpoint_strategy.restore(hidden_state)
train_info.restore(hidden_state)

- return hidden_state
+ return model_state, hidden_state

def get_metrics_frame(self):
""" Get a frame of metrics from backend """