Apply black and codespell pre-commit hooks #222

Merged · 8 commits · Nov 5, 2024
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -14,18 +14,22 @@ repos:
- id: end-of-file-fixer
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: ^(tests/strategies.py|tests/utils.py)
- id: no-commit-to-branch
- id: trailing-whitespace
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
exclude: ^(docs/source/_static|docs/_build|pyproject.toml)
additional_dependencies:
- tomli
- repo: https://github.com/python/black
rev: 24.8.0
hooks:
- id: black
args: ["--line-length=120"]
exclude: ^(docs/)
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
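Note (not part of the diff): the two hooks added above can be reproduced locally through the `pre-commit` CLI. A minimal sketch, assuming `pre-commit` is installed in the current environment; the loop is only an illustration, the usual workflow is simply `pre-commit run --all-files`.

```python
# Illustration only: run the newly added hooks over the whole repository.
# Assumes the pre-commit package is installed (pip install pre-commit).
import subprocess

for hook_id in ("codespell", "black"):
    # `pre-commit run <hook-id> --all-files` runs a single hook against all tracked files
    subprocess.run(["pre-commit", "run", hook_id, "--all-files"], check=False)
```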
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Make flattened tensor storage in memory the default option (revert changed introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).

### Changed (breaking changes: style)
- Format code using Black code formatter (it's ugly, yes, but it does its job)

### Fixed
- Moved the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)

2 changes: 1 addition & 1 deletion docs/source/api/agents/sarsa.rst
@@ -3,7 +3,7 @@ State Action Reward State Action (SARSA)

SARSA is a **model-free** **on-policy** algorithm that uses a **tabular** Q-function to handle **discrete** observations and action spaces

Paper: `On-Line Q-Learning Using Connectionist Systems <https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.17.2539>`_
Paper: `On-Line Q-Learning Using Connectionist Systems <https://scholar.google.com/scholar?q=On-line+Q-learning+using+connectionist+system>`_

.. raw:: html

@@ -141,7 +141,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# end current motion
if self.motion is not None:
@@ -84,7 +84,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
self.robot.command_joint_position(self.robot_default_dof_pos)
@@ -113,8 +113,8 @@ def _callback_joint_states(self, msg):
self.robot_state["joint_velocity"] = np.array(msg.velocity)

def _callback_end_effector_pose(self, msg):
positon = msg.position
self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z])
position = msg.position
self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z])

def _get_observation_reward_done(self):
# observation
@@ -146,7 +146,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
msg = sensor_msgs.msg.JointState()
@@ -90,8 +90,8 @@ def _callback_joint_states(self, msg):
self.robot_state["joint_velocity"] = np.array(msg.velocity)

def _callback_end_effector_pose(self, msg):
positon = msg.position
self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z])
position = msg.position
self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z])

def _get_observation_reward_done(self):
# observation
@@ -123,7 +123,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
msg = sensor_msgs.msg.JointState()
2 changes: 1 addition & 1 deletion docs/source/examples/utils/tensorboard_file_iterator.py
@@ -19,7 +19,7 @@
mean = np.mean(rewards[:,:,1], axis=0)
std = np.std(rewards[:,:,1], axis=0)

# creae two subplots (one for each reward and one for the mean)
# create two subplots (one for each reward and one for the mean)
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# plot the rewards for each experiment
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -60,6 +60,11 @@ tests = [

[tool.black]
line-length = 120
extend-exclude = """
(
^/docs
)
"""


[tool.codespell]
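Note (not part of the diff): `line-length = 120` above mirrors the `--line-length=120` argument given to the hook, and `extend-exclude` keeps `docs/` out of Black's reach. A minimal sketch of what that line length means in practice, assuming Black's Python API (`black.format_str` / `black.Mode`), which this PR does not itself use:

```python
# Illustration only: format a snippet with the same line length configured above.
# Assumes Black is installed (pip install black).
import black

src = "def f(a,b ,c):\n    return {'a':a,'b':b,'c':c}\n"
formatted = black.format_str(src, mode=black.Mode(line_length=120))
print(formatted)  # spacing and quotes normalized; lines are wrapped only past 120 columns
```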
57 changes: 35 additions & 22 deletions skrl/__init__.py
@@ -13,6 +13,7 @@
# read library version from metadata
try:
import importlib.metadata

__version__ = importlib.metadata.version("skrl")
except ImportError:
__version__ = "unknown"
@@ -21,15 +22,18 @@
# logger with format
class _Formatter(logging.Formatter):
_format = "[%(name)s:%(levelname)s] %(message)s"
_formats = {logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m"}
_formats = {
logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m",
}

def format(self, record):
return logging.Formatter(self._formats.get(record.levelno)).format(record)


_handler = logging.StreamHandler()
_handler.setLevel(logging.DEBUG)
_handler.setFormatter(_Formatter())
@@ -42,13 +46,11 @@ def format(self, record):
# machine learning framework configuration
class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""
"""Machine learning framework specific configuration"""

class PyTorch(object):
def __init__(self) -> None:
"""PyTorch configuration
"""
"""PyTorch configuration"""
self._device = None
# torch.distributed config
self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -59,7 +61,10 @@ def __init__(self) -> None:
# set up distributed runs
if self._is_distributed:
import torch
logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")

logger.info(
f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
)
torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
torch.cuda.set_device(self._local_rank)

@@ -72,6 +77,7 @@ def device(self) -> "torch.device":
"""
try:
import torch

if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
@@ -116,8 +122,7 @@ def is_distributed(self) -> bool:

class JAX(object):
def __init__(self) -> None:
"""JAX configuration
"""
"""JAX configuration"""
self._backend = "numpy"
self._key = np.array([0, 0], dtype=np.uint32)
# distributed config (based on torch.distributed, since JAX doesn't implement it)
@@ -126,19 +131,26 @@ def __init__(self) -> None:
self._local_rank = int(os.getenv("JAX_LOCAL_RANK", "0"))
self._rank = int(os.getenv("JAX_RANK", "0"))
self._world_size = int(os.getenv("JAX_WORLD_SIZE", "1"))
self._coordinator_address = os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
self._coordinator_address = (
os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
)
self._is_distributed = self._world_size > 1
# device
self._device = f"cuda:{self._local_rank}"

# set up distributed runs
if self._is_distributed:
import jax
logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")
jax.distributed.initialize(coordinator_address=self._coordinator_address,
num_processes=self._world_size,
process_id=self._rank,
local_device_ids=self._local_rank)

logger.info(
f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
)
jax.distributed.initialize(
coordinator_address=self._coordinator_address,
num_processes=self._world_size,
process_id=self._rank,
local_device_ids=self._local_rank,
)

@staticmethod
def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -148,7 +160,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":

This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).

:param device: Device specification. If the specified device is ``None`` ot it cannot be resolved,
:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
the default available device will be returned instead.

:return: JAX Device.
@@ -158,7 +170,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
if isinstance(device, jax.Device):
return device
elif isinstance(device, str):
device_type, device_index = f"{device}:0".split(':')[:2]
device_type, device_index = f"{device}:0".split(":")[:2]
try:
return jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError) as e:
@@ -196,11 +208,11 @@ def backend(self, value: str) -> None:

@property
def key(self) -> "jax.Array":
"""Pseudo-random number generator (PRNG) key
"""
"""Pseudo-random number generator (PRNG) key"""
if isinstance(self._key, np.ndarray):
try:
import jax

with jax.default_device(self.device):
self._key = jax.random.PRNGKey(self._key[1])
except ImportError:
@@ -257,4 +269,5 @@ def is_distributed(self) -> bool:
self.jax = JAX()
self.torch = PyTorch()


config = _Config()
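Note (not part of the diff): the module ends by instantiating the `config` singleton shown above, so downstream code reads the framework configuration through it. A minimal usage sketch, limited to attributes visible in this diff and assuming PyTorch and JAX are installed:

```python
# Illustration only: query the configuration object created by `config = _Config()`.
from skrl import config

# PyTorch: resolves to "cuda:<local rank>" when CUDA is available, otherwise "cpu"
device = config.torch.device

# JAX: parse a PyTorch-like "type:ordinal" device string and read the PRNG key
jax_device = config.jax.parse_device("cuda:0")  # falls back to the default device if unresolvable
key = config.jax.key
```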