
Release v0.80

@takuseno takuseno released this 24 Apr 05:16

Algorithms

New algorithms are introduced in this version.

Model-based RL

Model-based RL was already supported in earlier versions, with the model-based logic implemented on the dynamics side. That design made it possible to combine model-based algorithms with arbitrary model-free algorithms, but it required complex designs to implement recent model-based RL methods. The dynamics interface has therefore been refactored, and MOPO is the first algorithm to show how d3rlpy now supports model-based RL algorithms.

# train dynamics model
from d3rlpy.datasets import get_pendulum
from d3rlpy.dynamics import ProbabilisticEnsembleDynamics
from d3rlpy.metrics.scorer import dynamics_observation_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_reward_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_prediction_variance_scorer
from sklearn.model_selection import train_test_split

dataset, _ = get_pendulum()

train_episodes, test_episodes = train_test_split(dataset)

# use the class imported above (the d3rlpy package itself is never imported here)
dynamics = ProbabilisticEnsembleDynamics(learning_rate=1e-4, use_gpu=True)

dynamics.fit(train_episodes,
             eval_episodes=test_episodes,
             n_epochs=100,
             scorers={
                'observation_error': dynamics_observation_prediction_error_scorer,
                'reward_error': dynamics_reward_prediction_error_scorer,
                'variance': dynamics_prediction_variance_scorer,
             })

# train Model-based RL algorithm
from d3rlpy.algos import MOPO

# pass the trained dynamics model to MOPO via the dynamics argument
mopo = MOPO(dynamics=dynamics)

mopo.fit(dataset, n_steps=100000)
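
The idea behind MOPO is to subtract an uncertainty penalty from the rewards predicted by the learned dynamics model, so that the policy is discouraged from exploiting regions where the model is unreliable. The following is a minimal sketch of that penalty, not d3rlpy's implementation: the function name `penalized_reward`, the parameter `penalty_coef`, and the use of ensemble disagreement (population standard deviation of the predicted rewards) as the uncertainty estimate are all assumptions made for illustration.

```python
from statistics import pstdev
from typing import Sequence


def penalized_reward(ensemble_rewards: Sequence[float],
                     penalty_coef: float = 1.0) -> float:
    """Return the mean predicted reward minus an uncertainty penalty.

    ensemble_rewards: one reward prediction per ensemble member for the
        same (observation, action) pair (hypothetical input format).
    penalty_coef: weight of the penalty (hypothetical name).
    """
    mean_reward = sum(ensemble_rewards) / len(ensemble_rewards)
    # Assumption: disagreement between ensemble members stands in for
    # model uncertainty; the estimator used by d3rlpy may differ.
    uncertainty = pstdev(ensemble_rewards)
    return mean_reward - penalty_coef * uncertainty


# Members that agree incur no penalty; members that disagree lower the reward.
agreeing = penalized_reward([1.0, 1.0, 1.0])
disagreeing = penalized_reward([0.0, 2.0])
```

A larger `penalty_coef` makes the resulting policy more conservative with respect to model-generated transitions.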

Enhancements

  • fitter method has been implemented (thanks @jamartinh)
  • tensorboard_dir replaces the tensorboard flag in the fit method (thanks @navidmdn)
  • show warning messages when unused arguments are passed
  • show comprehensive error messages when the action space is not compatible
  • fit method accepts MDPDataset object
  • dropout option has been implemented in encoders
  • add appropriate __repr__ methods to show pretty outputs when calling print(algo)
  • metrics collection is refactored
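
The fitter method in the list above can be read as exposing training as an iterator, so that the caller can inspect metrics or stop between epochs. The sketch below illustrates that general pattern only. `ToyAlgo` is a hypothetical stand-in and not part of d3rlpy, and the `(epoch, metrics)` tuple it yields is an assumption about the shape of the interface rather than something these notes confirm.

```python
from typing import Dict, Iterator, List, Tuple


class ToyAlgo:
    """Hypothetical stand-in for an algorithm; not part of d3rlpy."""

    def __init__(self, learning_rate: float = 0.5) -> None:
        self.learning_rate = learning_rate
        self.weight = 0.0

    def fitter(self, targets: List[float],
               n_epochs: int) -> Iterator[Tuple[int, Dict[str, float]]]:
        """Yield (epoch, metrics) after each epoch instead of blocking."""
        for epoch in range(1, n_epochs + 1):
            for y in targets:
                # Move the single weight towards each target value.
                self.weight += self.learning_rate * (y - self.weight)
            loss = sum((y - self.weight) ** 2 for y in targets) / len(targets)
            yield epoch, {"loss": loss}

    def fit(self, targets: List[float],
            n_epochs: int) -> List[Tuple[int, Dict[str, float]]]:
        """Blocking variant: run the generator to completion."""
        return list(self.fitter(targets, n_epochs))


algo = ToyAlgo()
history = []
for epoch, metrics in algo.fitter([1.0, 1.0, 1.0], n_epochs=5):
    history.append((epoch, metrics["loss"]))
    # Early stopping decided by the caller, not by the algorithm.
    if metrics["loss"] < 1e-3:
        break
```

Keeping `fit` as a thin wrapper around the generator means both entry points share a single training loop.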

Bug fixes

  • fix core dump errors by pinning the numpy version
  • fix CQL backup