diff --git a/README.md b/README.md
index aaa11ece2..38eb32850 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,7 @@
-
-
-
-
-
+
+
+
[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/)
diff --git a/docs/img/jumanji_logo_dm.png b/docs/img/jumanji_logo_dm.png
deleted file mode 100644
index 0a13109cb..000000000
Binary files a/docs/img/jumanji_logo_dm.png and /dev/null differ
diff --git a/jumanji/env.py b/jumanji/env.py
index 7f855b4fd..d3ddac6bd 100644
--- a/jumanji/env.py
+++ b/jumanji/env.py
@@ -14,6 +14,8 @@
"""Abstract environment class"""
+from __future__ import annotations
+
import abc
from typing import Any, Generic, Tuple, TypeVar
@@ -105,7 +107,7 @@ def discount_spec(self) -> specs.BoundedArray:
)
@property
- def unwrapped(self) -> "Environment":
+ def unwrapped(self) -> Environment:
return self
def render(self, state: State) -> Any:
@@ -119,7 +121,7 @@ def render(self, state: State) -> Any:
def close(self) -> None:
"""Perform any necessary cleanup."""
- def __enter__(self) -> "Environment":
+ def __enter__(self) -> Environment:
return self
def __exit__(self, *args: Any) -> None:
diff --git a/jumanji/environments/packing/bin_pack/space.py b/jumanji/environments/packing/bin_pack/space.py
index 45d1fccf1..89f38dc45 100644
--- a/jumanji/environments/packing/bin_pack/space.py
+++ b/jumanji/environments/packing/bin_pack/space.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
from typing import TYPE_CHECKING, Any
@@ -32,7 +33,7 @@ class Space:
z1: chex.Numeric
z2: chex.Numeric
- def astype(self, dtype: Any) -> "Space":
+ def astype(self, dtype: Any) -> Space:
space_dict = {
key: jnp.asarray(value, dtype) for key, value in self.__dict__.items()
}
@@ -90,7 +91,7 @@ def volume(self) -> chex.Numeric:
z_len = jnp.asarray(self.z2 - self.z1, float)
return x_len * y_len * z_len
- def intersection(self, space: "Space") -> "Space":
+ def intersection(self, space: Space) -> Space:
"""Returns the intersected space with another space (i.e. the space that is included in both
spaces whose volume is maximum).
"""
@@ -102,7 +103,7 @@ def intersection(self, space: "Space") -> "Space":
z2 = jnp.minimum(self.z2, space.z2)
return Space(x1=x1, x2=x2, y1=y1, y2=y2, z1=z1, z2=z2)
- def intersect(self, space: "Space") -> chex.Numeric:
+ def intersect(self, space: Space) -> chex.Numeric:
"""Returns whether a space intersect another space or not."""
return ~(self.intersection(space).is_empty())
@@ -110,7 +111,7 @@ def is_empty(self) -> chex.Numeric:
"""A space is empty if at least one dimension is negative or zero."""
return (self.x1 >= self.x2) | (self.y1 >= self.y2) | (self.z1 >= self.z2)
- def is_included(self, space: "Space") -> chex.Numeric:
+ def is_included(self, space: Space) -> chex.Numeric:
"""Returns whether self is included into another space."""
return (
(self.x1 >= space.x1)
@@ -121,7 +122,7 @@ def is_included(self, space: "Space") -> chex.Numeric:
& (self.z2 <= space.z2)
)
- def hyperplane(self, axis: str, direction: str) -> "Space":
+ def hyperplane(self, axis: str, direction: str) -> Space:
"""Returns the hyperplane (e.g. lower hyperplane on the x axis) for EMS creation when
packing an item.
diff --git a/jumanji/environments/routing/connector/types.py b/jumanji/environments/routing/connector/types.py
index 53f12edf8..acf81ea6e 100644
--- a/jumanji/environments/routing/connector/types.py
+++ b/jumanji/environments/routing/connector/types.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple
@@ -42,7 +43,7 @@ def connected(self) -> chex.Array:
"""returns: True if the agent has reached its target."""
return jnp.all(self.position == self.target, axis=-1)
- def __eq__(self: "Agent", agent_2: Any) -> chex.Array:
+ def __eq__(self: Agent, agent_2: Any) -> chex.Array:
if not isinstance(agent_2, Agent):
return NotImplemented
same_ids = (agent_2.id == self.id).all()
diff --git a/jumanji/environments/routing/snake/types.py b/jumanji/environments/routing/snake/types.py
index 6b6b0354d..7617240e8 100644
--- a/jumanji/environments/routing/snake/types.py
+++ b/jumanji/environments/routing/snake/types.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
from enum import IntEnum
from typing import TYPE_CHECKING, NamedTuple
@@ -27,12 +28,12 @@ class Position(NamedTuple):
row: chex.Array
col: chex.Array
- def __eq__(self, other: "Position") -> chex.Array: # type: ignore[override]
+ def __eq__(self, other: Position) -> chex.Array: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return (self.row == other.row) & (self.col == other.col)
- def __add__(self, other: "Position") -> "Position": # type: ignore[override]
+ def __add__(self, other: Position) -> Position: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return Position(row=self.row + other.row, col=self.col + other.col)
diff --git a/jumanji/training/loggers.py b/jumanji/training/loggers.py
index 8ccf07b26..722bd9987 100644
--- a/jumanji/training/loggers.py
+++ b/jumanji/training/loggers.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
import abc
import collections
@@ -56,7 +57,7 @@ def close(self) -> None:
def upload_checkpoint(self) -> None:
"""Uploads a checkpoint when exiting the logger."""
- def __enter__(self) -> "Logger":
+ def __enter__(self) -> Logger:
logging.info("Starting logger.")
self._variables_enter = self._get_variables()
return self
diff --git a/jumanji/training/networks/distribution.py b/jumanji/training/networks/distribution.py
index 0e7a751fb..481790d76 100644
--- a/jumanji/training/networks/distribution.py
+++ b/jumanji/training/networks/distribution.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Adapted from Brax."""
+from __future__ import annotations
import abc
@@ -39,7 +40,7 @@ def entropy(self) -> chex.Array:
pass
@abc.abstractmethod
- def kl_divergence(self, other: "Distribution") -> chex.Array:
+ def kl_divergence(self, other: Distribution) -> chex.Array:
pass
@@ -77,7 +78,7 @@ def entropy(self) -> chex.Array:
def kl_divergence( # type: ignore[override]
self,
- other: "CategoricalDistribution",
+ other: CategoricalDistribution,
) -> chex.Array:
log_probs = jax.nn.log_softmax(self.logits)
probs = jax.nn.softmax(self.logits)
diff --git a/jumanji/training/timer.py b/jumanji/training/timer.py
index 24c7a446b..3d03d55d3 100644
--- a/jumanji/training/timer.py
+++ b/jumanji/training/timer.py
@@ -14,6 +14,7 @@
# Inspired from https://stackoverflow.com/questions/51849395/how-can-we-associate-a-python-context-m
# anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables
+from __future__ import annotations
import inspect
import logging
@@ -45,7 +46,7 @@ def _get_variables(self) -> Dict:
"""
return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}
- def __enter__(self) -> "Timer":
+ def __enter__(self) -> Timer:
self._variables_enter = self._get_variables()
self._start_time = time.perf_counter()
return self
diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py
index a96e29b57..98d1ec91d 100644
--- a/jumanji/wrappers.py
+++ b/jumanji/wrappers.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
from typing import (
Any,
@@ -120,7 +121,7 @@ def close(self) -> None:
"""
return self._env.close()
- def __enter__(self) -> "Wrapper":
+ def __enter__(self) -> Wrapper:
return self
def __exit__(self, *args: Any) -> None: