diff --git a/README.md b/README.md index aaa11ece2..38eb32850 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,7 @@

- - - - Jumanji Logo - + + Jumanji logo +

[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/) diff --git a/docs/img/jumanji_logo_dm.png b/docs/img/jumanji_logo_dm.png deleted file mode 100644 index 0a13109cb..000000000 Binary files a/docs/img/jumanji_logo_dm.png and /dev/null differ diff --git a/jumanji/env.py b/jumanji/env.py index 7f855b4fd..d3ddac6bd 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -14,6 +14,8 @@ """Abstract environment class""" +from __future__ import annotations + import abc from typing import Any, Generic, Tuple, TypeVar @@ -105,7 +107,7 @@ def discount_spec(self) -> specs.BoundedArray: ) @property - def unwrapped(self) -> "Environment": + def unwrapped(self) -> Environment: return self def render(self, state: State) -> Any: @@ -119,7 +121,7 @@ def render(self, state: State) -> Any: def close(self) -> None: """Perform any necessary cleanup.""" - def __enter__(self) -> "Environment": + def __enter__(self) -> Environment: return self def __exit__(self, *args: Any) -> None: diff --git a/jumanji/environments/packing/bin_pack/space.py b/jumanji/environments/packing/bin_pack/space.py index 45d1fccf1..89f38dc45 100644 --- a/jumanji/environments/packing/bin_pack/space.py +++ b/jumanji/environments/packing/bin_pack/space.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import TYPE_CHECKING, Any @@ -32,7 +33,7 @@ class Space: z1: chex.Numeric z2: chex.Numeric - def astype(self, dtype: Any) -> "Space": + def astype(self, dtype: Any) -> Space: space_dict = { key: jnp.asarray(value, dtype) for key, value in self.__dict__.items() } @@ -90,7 +91,7 @@ def volume(self) -> chex.Numeric: z_len = jnp.asarray(self.z2 - self.z1, float) return x_len * y_len * z_len - def intersection(self, space: "Space") -> "Space": + def intersection(self, space: Space) -> Space: """Returns the intersected space with another space (i.e. the space that is included in both spaces whose volume is maximum). """ @@ -102,7 +103,7 @@ def intersection(self, space: "Space") -> "Space": z2 = jnp.minimum(self.z2, space.z2) return Space(x1=x1, x2=x2, y1=y1, y2=y2, z1=z1, z2=z2) - def intersect(self, space: "Space") -> chex.Numeric: + def intersect(self, space: Space) -> chex.Numeric: """Returns whether a space intersect another space or not.""" return ~(self.intersection(space).is_empty()) @@ -110,7 +111,7 @@ def is_empty(self) -> chex.Numeric: """A space is empty if at least one dimension is negative or zero.""" return (self.x1 >= self.x2) | (self.y1 >= self.y2) | (self.z1 >= self.z2) - def is_included(self, space: "Space") -> chex.Numeric: + def is_included(self, space: Space) -> chex.Numeric: """Returns whether self is included into another space.""" return ( (self.x1 >= space.x1) @@ -121,7 +122,7 @@ def is_included(self, space: "Space") -> chex.Numeric: & (self.z2 <= space.z2) ) - def hyperplane(self, axis: str, direction: str) -> "Space": + def hyperplane(self, axis: str, direction: str) -> Space: """Returns the hyperplane (e.g. lower hyperplane on the x axis) for EMS creation when packing an item. diff --git a/jumanji/environments/routing/connector/types.py b/jumanji/environments/routing/connector/types.py index 53f12edf8..acf81ea6e 100644 --- a/jumanji/environments/routing/connector/types.py +++ b/jumanji/environments/routing/connector/types.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import TYPE_CHECKING, Any, NamedTuple @@ -42,7 +43,7 @@ def connected(self) -> chex.Array: """returns: True if the agent has reached its target.""" return jnp.all(self.position == self.target, axis=-1) - def __eq__(self: "Agent", agent_2: Any) -> chex.Array: + def __eq__(self: Agent, agent_2: Any) -> chex.Array: if not isinstance(agent_2, Agent): return NotImplemented same_ids = (agent_2.id == self.id).all() diff --git a/jumanji/environments/routing/snake/types.py b/jumanji/environments/routing/snake/types.py index 6b6b0354d..7617240e8 100644 --- a/jumanji/environments/routing/snake/types.py +++ b/jumanji/environments/routing/snake/types.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from enum import IntEnum from typing import TYPE_CHECKING, NamedTuple @@ -27,12 +28,12 @@ class Position(NamedTuple): row: chex.Array col: chex.Array - def __eq__(self, other: "Position") -> chex.Array: # type: ignore[override] + def __eq__(self, other: Position) -> chex.Array: # type: ignore[override] if not isinstance(other, Position): return NotImplemented return (self.row == other.row) & (self.col == other.col) - def __add__(self, other: "Position") -> "Position": # type: ignore[override] + def __add__(self, other: Position) -> Position: # type: ignore[override] if not isinstance(other, Position): return NotImplemented return Position(row=self.row + other.row, col=self.col + other.col) diff --git a/jumanji/training/loggers.py b/jumanji/training/loggers.py index 8ccf07b26..722bd9987 100644 --- a/jumanji/training/loggers.py +++ b/jumanji/training/loggers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import abc import collections @@ -56,7 +57,7 @@ def close(self) -> None: def upload_checkpoint(self) -> None: """Uploads a checkpoint when exiting the logger.""" - def __enter__(self) -> "Logger": + def __enter__(self) -> Logger: logging.info("Starting logger.") self._variables_enter = self._get_variables() return self diff --git a/jumanji/training/networks/distribution.py b/jumanji/training/networks/distribution.py index 0e7a751fb..481790d76 100644 --- a/jumanji/training/networks/distribution.py +++ b/jumanji/training/networks/distribution.py @@ -13,6 +13,7 @@ # limitations under the License. """Adapted from Brax.""" +from __future__ import annotations import abc @@ -39,7 +40,7 @@ def entropy(self) -> chex.Array: pass @abc.abstractmethod - def kl_divergence(self, other: "Distribution") -> chex.Array: + def kl_divergence(self, other: Distribution) -> chex.Array: pass @@ -77,7 +78,7 @@ def entropy(self) -> chex.Array: def kl_divergence( # type: ignore[override] self, - other: "CategoricalDistribution", + other: CategoricalDistribution, ) -> chex.Array: log_probs = jax.nn.log_softmax(self.logits) probs = jax.nn.softmax(self.logits) diff --git a/jumanji/training/timer.py b/jumanji/training/timer.py index 24c7a446b..3d03d55d3 100644 --- a/jumanji/training/timer.py +++ b/jumanji/training/timer.py @@ -14,6 +14,7 @@ # Inspired from https://stackoverflow.com/questions/51849395/how-can-we-associate-a-python-context-m # anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables +from __future__ import annotations import inspect import logging @@ -45,7 +46,7 @@ def _get_variables(self) -> Dict: """ return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()} - def __enter__(self) -> "Timer": + def __enter__(self) -> Timer: self._variables_enter = self._get_variables() self._start_time = time.perf_counter() return self diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index a96e29b57..98d1ec91d 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import ( Any, @@ -120,7 +121,7 @@ def close(self) -> None: """ return self._env.close() - def __enter__(self) -> "Wrapper": + def __enter__(self) -> Wrapper: return self def __exit__(self, *args: Any) -> None: