Skip to content

Commit

Permalink
Merge branch 'main' into feat/add_sliding_tile_puzzle_environment
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Jan 16, 2024
2 parents 3fa7677 + d21d23b commit 0d8e386
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 20 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="docs/img/jumanji_logo_dm.png">
<source media="(prefers-color-scheme: light)" srcset="docs/img/jumanji_logo.png">
<img alt="Jumanji Logo" src="docs/img/jumanji_logo.png", width="50%">
</picture>
<a href="docs/img/jumanji_logo.png">
<img src="docs/img/jumanji_logo.png" alt="Jumanji logo" width="50%"/>
</a>
</p>

[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/)
Expand Down
Binary file removed docs/img/jumanji_logo_dm.png
Binary file not shown.
6 changes: 4 additions & 2 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Abstract environment class"""

from __future__ import annotations

import abc
from typing import Any, Generic, Tuple, TypeVar

Expand Down Expand Up @@ -105,7 +107,7 @@ def discount_spec(self) -> specs.BoundedArray:
)

@property
def unwrapped(self) -> "Environment":
def unwrapped(self) -> Environment:
return self

def render(self, state: State) -> Any:
Expand All @@ -119,7 +121,7 @@ def render(self, state: State) -> Any:
def close(self) -> None:
"""Perform any necessary cleanup."""

def __enter__(self) -> "Environment":
def __enter__(self) -> Environment:
return self

def __exit__(self, *args: Any) -> None:
Expand Down
11 changes: 6 additions & 5 deletions jumanji/environments/packing/bin_pack/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

Expand All @@ -32,7 +33,7 @@ class Space:
z1: chex.Numeric
z2: chex.Numeric

def astype(self, dtype: Any) -> "Space":
def astype(self, dtype: Any) -> Space:
space_dict = {
key: jnp.asarray(value, dtype) for key, value in self.__dict__.items()
}
Expand Down Expand Up @@ -90,7 +91,7 @@ def volume(self) -> chex.Numeric:
z_len = jnp.asarray(self.z2 - self.z1, float)
return x_len * y_len * z_len

def intersection(self, space: "Space") -> "Space":
def intersection(self, space: Space) -> Space:
"""Returns the intersected space with another space (i.e. the space that is included in both
spaces whose volume is maximum).
"""
Expand All @@ -102,15 +103,15 @@ def intersection(self, space: "Space") -> "Space":
z2 = jnp.minimum(self.z2, space.z2)
return Space(x1=x1, x2=x2, y1=y1, y2=y2, z1=z1, z2=z2)

def intersect(self, space: "Space") -> chex.Numeric:
def intersect(self, space: Space) -> chex.Numeric:
"""Returns whether a space intersect another space or not."""
return ~(self.intersection(space).is_empty())

def is_empty(self) -> chex.Numeric:
"""A space is empty if at least one dimension is negative or zero."""
return (self.x1 >= self.x2) | (self.y1 >= self.y2) | (self.z1 >= self.z2)

def is_included(self, space: "Space") -> chex.Numeric:
def is_included(self, space: Space) -> chex.Numeric:
"""Returns whether self is included into another space."""
return (
(self.x1 >= space.x1)
Expand All @@ -121,7 +122,7 @@ def is_included(self, space: "Space") -> chex.Numeric:
& (self.z2 <= space.z2)
)

def hyperplane(self, axis: str, direction: str) -> "Space":
def hyperplane(self, axis: str, direction: str) -> Space:
"""Returns the hyperplane (e.g. lower hyperplane on the x axis) for EMS creation when
packing an item.
Expand Down
3 changes: 2 additions & 1 deletion jumanji/environments/routing/connector/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple

Expand Down Expand Up @@ -42,7 +43,7 @@ def connected(self) -> chex.Array:
"""returns: True if the agent has reached its target."""
return jnp.all(self.position == self.target, axis=-1)

def __eq__(self: "Agent", agent_2: Any) -> chex.Array:
def __eq__(self: Agent, agent_2: Any) -> chex.Array:
if not isinstance(agent_2, Agent):
return NotImplemented
same_ids = (agent_2.id == self.id).all()
Expand Down
5 changes: 3 additions & 2 deletions jumanji/environments/routing/snake/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from enum import IntEnum
from typing import TYPE_CHECKING, NamedTuple
Expand All @@ -27,12 +28,12 @@ class Position(NamedTuple):
row: chex.Array
col: chex.Array

def __eq__(self, other: "Position") -> chex.Array: # type: ignore[override]
def __eq__(self, other: Position) -> chex.Array: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return (self.row == other.row) & (self.col == other.col)

def __add__(self, other: "Position") -> "Position": # type: ignore[override]
def __add__(self, other: Position) -> Position: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return Position(row=self.row + other.row, col=self.col + other.col)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/training/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
import collections
Expand Down Expand Up @@ -56,7 +57,7 @@ def close(self) -> None:
def upload_checkpoint(self) -> None:
"""Uploads a checkpoint when exiting the logger."""

def __enter__(self) -> "Logger":
def __enter__(self) -> Logger:
logging.info("Starting logger.")
self._variables_enter = self._get_variables()
return self
Expand Down
5 changes: 3 additions & 2 deletions jumanji/training/networks/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Adapted from Brax."""
from __future__ import annotations

import abc

Expand All @@ -39,7 +40,7 @@ def entropy(self) -> chex.Array:
pass

@abc.abstractmethod
def kl_divergence(self, other: "Distribution") -> chex.Array:
def kl_divergence(self, other: Distribution) -> chex.Array:
pass


Expand Down Expand Up @@ -77,7 +78,7 @@ def entropy(self) -> chex.Array:

def kl_divergence( # type: ignore[override]
self,
other: "CategoricalDistribution",
other: CategoricalDistribution,
) -> chex.Array:
log_probs = jax.nn.log_softmax(self.logits)
probs = jax.nn.softmax(self.logits)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/training/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Inspired from https://stackoverflow.com/questions/51849395/how-can-we-associate-a-python-context-m
# anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables
from __future__ import annotations

import inspect
import logging
Expand Down Expand Up @@ -45,7 +46,7 @@ def _get_variables(self) -> Dict:
"""
return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}

def __enter__(self) -> "Timer":
def __enter__(self) -> Timer:
self._variables_enter = self._get_variables()
self._start_time = time.perf_counter()
return self
Expand Down
3 changes: 2 additions & 1 deletion jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import (
Any,
Expand Down Expand Up @@ -120,7 +121,7 @@ def close(self) -> None:
"""
return self._env.close()

def __enter__(self) -> "Wrapper":
def __enter__(self) -> Wrapper:
return self

def __exit__(self, *args: Any) -> None:
Expand Down

0 comments on commit 0d8e386

Please sign in to comment.