diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f62452ace..635a719bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,23 @@ -default_stages: [ "commit", "commit-msg", "push" ] +default_stages: [ "pre-commit", "commit-msg", "pre-push" ] default_language_version: python: python3 repos: - - repo: https://github.com/timothycrosley/isort - rev: 5.11.5 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.2 hooks: - - id: isort - - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - name: "Code formatter" + # Run the linter. + - id: ruff + types_or: [ python ] + args: [ --fix ] + # Run the formatter. + - id: ruff-format + types_or: [ python, pyi, jupyter ] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer name: "End of file fixer" @@ -32,22 +33,6 @@ repos: - id: trailing-whitespace name: "Trailing whitespace fixer" - - repo: https://github.com/PyCQA/flake8 - rev: 7.1.1 - hooks: - - id: flake8 - name: "Linter" - args: - - --config=setup.cfg - additional_dependencies: - - pep8-naming - - flake8-builtins - - flake8-comprehensions - - flake8-bugbear - - flake8-pytest-style - - flake8-cognitive-complexity - - importlib-metadata<5.0 - - repo: local hooks: - id: mypy @@ -58,7 +43,7 @@ repos: pass_filenames: false - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook - rev: v4.1.0 + rev: v9.18.0 hooks: - id: commitlint name: "Commit linter" @@ -66,7 +51,7 @@ repos: additional_dependencies: [ '@commitlint/config-conventional' ] - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.3.0 + rev: v1.5.5 hooks: - id: insert-license name: "License inserter" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6f5de9671..cde2f0771 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,16 +48,14 @@ Before sending your pull request for review, make sure your changes are consiste ### Coding Style In general, we follow the [Google Style Guide](https://google.github.io/styleguide/pyguide.html). We use [conventional commit messages](https://www.conventionalcommits.org/en/v1.0.0/) for commit messages. -In addition, to guarantee the quality and uniformity of the code, we use various linters: +In addition, to guarantee the quality and uniformity of the code, we use two tools: -- [Black](https://black.readthedocs.io/en/stable/#) is a deterministic code formatter that is compliant with PEP8 standards. -- [Isort](https://pycqa.github.io/isort/) sorts imports alphabetically and separates them into sections. -- [Flake8](https://flake8.pycqa.org/en/latest/) is a library that wraps PyFlakes and PyCodeStyle. It is a great toolkit for checking your codebase against coding style (PEP8), programming, and syntax errors. Flake8 also benefits from an ecosystem of plugins developed by the community that extend its capabilities. You can read more about Flake8 plugins on the documentation and find a curated list of plugins here. +- [Ruff](https://docs.astral.sh/ruff/) is an extremely fast Python linter and code formatter. - [MyPy](https://mypy.readthedocs.io/en/stable/#) is a static type checker that can help you detect inconsistent typing of variables. #### Pre-Commit -To help in automating the quality of the code, we use [pre-commit](https://pre-commit.com/), a framework that manages the installation and execution of git hooks that will be run before every commit. These hooks help to automatically point out issues in code such as formatting mistakes, unused variables, trailing whitespace, debug statements, etc. By pointing these issues out before code review, it allows a code reviewer to focus on the architecture of a change while not wasting time with trivial style nitpicks. Each commit should be preceded by a call to pre-commit to ensure code quality and formatting. The configuration is in .pre-commit-config.yaml and includes Black, Flake8, MyPy and checks for the yaml formatting, trimming trailing whitespace, etc. +To help in automating the quality of the code, we use [pre-commit](https://pre-commit.com/), a framework that manages the installation and execution of git hooks that will be run before every commit. These hooks help to automatically point out issues in code such as formatting mistakes, unused variables, trailing whitespace, debug statements, etc. By pointing these issues out before code review, it allows a code reviewer to focus on the architecture of a change while not wasting time with trivial style nitpicks. Each commit should be preceded by a call to pre-commit to ensure code quality and formatting. The configuration is in .pre-commit-config.yaml and includes Ruff, MyPy and checks for the yaml formatting, trimming trailing whitespace, etc. Try running: `pre-commit run --all-files`. All linters must pass before committing your change. ### Code of Conduct diff --git a/README.md b/README.md index f18ca16b0..bfe2b1e65 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/) [![PyPI Version](https://badge.fury.io/py/jumanji.svg)](https://badge.fury.io/py/jumanji) [![Tests](https://github.com/instadeepai/jumanji/actions/workflows/tests_linters.yml/badge.svg)](https://github.com/instadeepai/jumanji/actions/workflows/tests_linters.yml) -[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![MyPy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) [![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://opensource.org/licenses/Apache-2.0) [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97-Hugging%20Face-F8D521)](https://huggingface.co/InstaDeepAI) diff --git a/examples/load_checkpoints.ipynb b/examples/load_checkpoints.ipynb index f99cea250..094ffaec2 100644 --- a/examples/load_checkpoints.ipynb +++ b/examples/load_checkpoints.ipynb @@ -2,30 +2,30 @@ "cells": [ { "cell_type": "markdown", + "metadata": { + "collapsed": false + }, "source": [ "\n", " \"Open\n", "" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 1, "metadata": { + "ExecuteTime": { + "end_time": "2023-06-14T10:11:06.832854981Z", + "start_time": "2023-06-14T10:10:51.403505913Z" + }, "jupyter": { "outputs_hidden": true }, "pycharm": { "is_executing": true }, - "scrolled": true, - "ExecuteTime": { - "end_time": "2023-06-14T10:11:06.832854981Z", - "start_time": "2023-06-14T10:10:51.403505913Z" - } + "scrolled": true }, "outputs": [], "source": [ @@ -36,6 +36,13 @@ { "cell_type": "code", "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-06-14T10:11:06.844131189Z", + "start_time": "2023-06-14T10:11:06.837796509Z" + }, + "collapsed": false + }, "outputs": [ { "name": "stdout", @@ -53,7 +60,7 @@ "\n", "# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n", "try:\n", - " subprocess.check_output('nvidia-smi')\n", + " subprocess.check_output(\"nvidia-smi\")\n", " print(\"a GPU is connected.\")\n", "except Exception:\n", " # TPU or CPU\n", @@ -63,29 +70,22 @@ " jax.tools.colab_tpu.setup_tpu()\n", " print(\"A TPU is connected.\")\n", " else:\n", - " print(\"Only CPU accelerator is connected.\")\n" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-06-14T10:11:06.844131189Z", - "start_time": "2023-06-14T10:11:06.837796509Z" - } - } + " print(\"Only CPU accelerator is connected.\")" + ] }, { "cell_type": "code", "execution_count": 3, "metadata": { + "ExecuteTime": { + "end_time": "2023-06-14T10:11:08.370733527Z", + "start_time": "2023-06-14T10:11:06.842722444Z" + }, "jupyter": { "outputs_hidden": false }, "pycharm": { "is_executing": true - }, - "ExecuteTime": { - "end_time": "2023-06-14T10:11:08.370733527Z", - "start_time": "2023-06-14T10:11:06.842722444Z" } }, "outputs": [ @@ -123,29 +123,37 @@ { "cell_type": "code", "execution_count": 4, - "outputs": [], - "source": [ - "env = \"bin_pack\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']\n", - "agent = \"a2c\" # @param ['random', 'a2c']" - ], "metadata": { - "collapsed": false, "ExecuteTime": { "end_time": "2023-06-14T10:11:08.373857448Z", "start_time": "2023-06-14T10:11:08.371354210Z" - } - } + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "env = \"bin_pack\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']\n", + "agent = \"a2c\" # @param ['random', 'a2c']" + ] }, { "cell_type": "code", "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-06-14T10:11:08.479313689Z", + "start_time": "2023-06-14T10:11:08.376210715Z" + }, + "collapsed": false + }, "outputs": [], "source": [ - "#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", + "# @title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", "\n", "import os\n", "import requests\n", "\n", + "\n", "def download_file(url: str, file_path: str) -> None:\n", " # Send an HTTP GET request to the URL\n", " response = requests.get(url)\n", @@ -156,20 +164,14 @@ " else:\n", " print(\"Failed to download the file.\")\n", "\n", + "\n", "os.makedirs(\"configs\", exist_ok=True)\n", "config_url = \"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml\"\n", "download_file(config_url, \"configs/config.yaml\")\n", "env_url = f\"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env}.yaml\"\n", "os.makedirs(\"configs/env\", exist_ok=True)\n", "download_file(env_url, f\"configs/env/{env}.yaml\")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-06-14T10:11:08.479313689Z", - "start_time": "2023-06-14T10:11:08.376210715Z" - } - } + ] }, { "cell_type": "code", @@ -207,6 +209,13 @@ { "cell_type": "code", "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2023-06-14T10:11:10.226606119Z", + "start_time": "2023-06-14T10:11:08.702541986Z" + }, + "collapsed": false + }, "outputs": [], "source": [ "# Chose the corresponding checkpoint from the InstaDeep Model Hub\n", @@ -218,14 +227,7 @@ "\n", "with open(model_checkpoint, \"rb\") as f:\n", " training_state = pickle.load(f)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-06-14T10:11:10.226606119Z", - "start_time": "2023-06-14T10:11:08.702541986Z" - } - } + ] }, { "cell_type": "code", @@ -241,7 +243,7 @@ "params = first_from_device(training_state.params_state.params)\n", "env = setup_env(cfg).unwrapped\n", "agent = setup_agent(cfg, env)\n", - "policy = jax.jit(agent.make_policy(params.actor, stochastic = False))\n", + "policy = jax.jit(agent.make_policy(params.actor, stochastic=False))\n", "if agent == \"a2c\":\n", " policy = lambda *args: policy(*args)[0]" ] @@ -269,7 +271,7 @@ "states = []\n", "key = jax.random.PRNGKey(cfg.seed)\n", "for episode in range(NUM_EPISODES):\n", - " key, reset_key = jax.random.split(key) \n", + " key, reset_key = jax.random.split(key)\n", " state, timestep = reset_fn(reset_key)\n", " states.append(state)\n", " while not timestep.last():\n", @@ -295,35 +297,35 @@ "cell_type": "code", "execution_count": 10, "metadata": { - "pycharm": { - "is_executing": true - }, "ExecuteTime": { "end_time": "2023-06-14T10:11:23.572860540Z", "start_time": "2023-06-14T10:11:19.277668279Z" + }, + "pycharm": { + "is_executing": true } }, "outputs": [ { "data": { - "text/plain": "", - "application/javascript": "/* Put everything inside the global mpl namespace */\n/* global mpl */\nwindow.mpl = {};\n\nmpl.get_websocket_type = function () {\n if (typeof WebSocket !== 'undefined') {\n return WebSocket;\n } else if (typeof MozWebSocket !== 'undefined') {\n return MozWebSocket;\n } else {\n alert(\n 'Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.'\n );\n }\n};\n\nmpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = this.ws.binaryType !== undefined;\n\n if (!this.supports_binary) {\n var warnings = document.getElementById('mpl-warnings');\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent =\n 'This browser does not support binary websocket messages. ' +\n 'Performance may be slow.';\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = document.createElement('div');\n this.root.setAttribute('style', 'display: inline-block');\n this._root_extra_style(this.root);\n\n parent_element.appendChild(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message('supports_binary', { value: fig.supports_binary });\n fig.send_message('send_image_mode', {});\n if (fig.ratio !== 1) {\n fig.send_message('set_device_pixel_ratio', {\n device_pixel_ratio: fig.ratio,\n });\n }\n fig.send_message('refresh', {});\n };\n\n this.imageObj.onload = function () {\n if (fig.image_mode === 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function () {\n fig.ws.close();\n };\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n};\n\nmpl.figure.prototype._init_header = function () {\n var titlebar = document.createElement('div');\n titlebar.classList =\n 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n var titletext = document.createElement('div');\n titletext.classList = 'ui-dialog-title';\n titletext.setAttribute(\n 'style',\n 'width: 100%; text-align: center; padding: 3px;'\n );\n titlebar.appendChild(titletext);\n this.root.appendChild(titlebar);\n this.header = titletext;\n};\n\nmpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._init_canvas = function () {\n var fig = this;\n\n var canvas_div = (this.canvas_div = document.createElement('div'));\n canvas_div.setAttribute('tabindex', '0');\n canvas_div.setAttribute(\n 'style',\n 'border: 1px solid #ddd;' +\n 'box-sizing: content-box;' +\n 'clear: both;' +\n 'min-height: 1px;' +\n 'min-width: 1px;' +\n 'outline: 0;' +\n 'overflow: hidden;' +\n 'position: relative;' +\n 'resize: both;' +\n 'z-index: 2;'\n );\n\n function on_keyboard_event_closure(name) {\n return function (event) {\n return fig.key_event(event, name);\n };\n }\n\n canvas_div.addEventListener(\n 'keydown',\n on_keyboard_event_closure('key_press')\n );\n canvas_div.addEventListener(\n 'keyup',\n on_keyboard_event_closure('key_release')\n );\n\n this._canvas_extra_style(canvas_div);\n this.root.appendChild(canvas_div);\n\n var canvas = (this.canvas = document.createElement('canvas'));\n canvas.classList.add('mpl-canvas');\n canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'pointer-events: none;' +\n 'position: relative;' +\n 'z-index: 0;'\n );\n\n this.context = canvas.getContext('2d');\n\n var backingStore =\n this.context.backingStorePixelRatio ||\n this.context.webkitBackingStorePixelRatio ||\n this.context.mozBackingStorePixelRatio ||\n this.context.msBackingStorePixelRatio ||\n this.context.oBackingStorePixelRatio ||\n this.context.backingStorePixelRatio ||\n 1;\n\n this.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n 'canvas'\n ));\n rubberband_canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'left: 0;' +\n 'pointer-events: none;' +\n 'position: absolute;' +\n 'top: 0;' +\n 'z-index: 1;'\n );\n\n // Apply a ponyfill if ResizeObserver is not implemented by browser.\n if (this.ResizeObserver === undefined) {\n if (window.ResizeObserver !== undefined) {\n this.ResizeObserver = window.ResizeObserver;\n } else {\n var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n this.ResizeObserver = obs.ResizeObserver;\n }\n }\n\n this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n var nentries = entries.length;\n for (var i = 0; i < nentries; i++) {\n var entry = entries[i];\n var width, height;\n if (entry.contentBoxSize) {\n if (entry.contentBoxSize instanceof Array) {\n // Chrome 84 implements new version of spec.\n width = entry.contentBoxSize[0].inlineSize;\n height = entry.contentBoxSize[0].blockSize;\n } else {\n // Firefox implements old version of spec.\n width = entry.contentBoxSize.inlineSize;\n height = entry.contentBoxSize.blockSize;\n }\n } else {\n // Chrome <84 implements even older version of spec.\n width = entry.contentRect.width;\n height = entry.contentRect.height;\n }\n\n // Keep the size of the canvas and rubber band canvas in sync with\n // the canvas container.\n if (entry.devicePixelContentBoxSize) {\n // Chrome 84 implements new version of spec.\n canvas.setAttribute(\n 'width',\n entry.devicePixelContentBoxSize[0].inlineSize\n );\n canvas.setAttribute(\n 'height',\n entry.devicePixelContentBoxSize[0].blockSize\n );\n } else {\n canvas.setAttribute('width', width * fig.ratio);\n canvas.setAttribute('height', height * fig.ratio);\n }\n /* This rescales the canvas back to display pixels, so that it\n * appears correct on HiDPI screens. */\n canvas.style.width = width + 'px';\n canvas.style.height = height + 'px';\n\n rubberband_canvas.setAttribute('width', width);\n rubberband_canvas.setAttribute('height', height);\n\n // And update the size in Python. We ignore the initial 0/0 size\n // that occurs as the element is placed into the DOM, which should\n // otherwise not happen due to the minimum size styling.\n if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n fig.request_resize(width, height);\n }\n }\n });\n this.resizeObserverInstance.observe(canvas_div);\n\n function on_mouse_event_closure(name) {\n /* User Agent sniffing is bad, but WebKit is busted:\n * https://bugs.webkit.org/show_bug.cgi?id=144526\n * https://bugs.webkit.org/show_bug.cgi?id=181818\n * The worst that happens here is that they get an extra browser\n * selection when dragging, if this check fails to catch them.\n */\n var UA = navigator.userAgent;\n var isWebKit = /AppleWebKit/.test(UA) && !/Chrome/.test(UA);\n if(isWebKit) {\n return function (event) {\n /* This prevents the web browser from automatically changing to\n * the text insertion cursor when the button is pressed. We\n * want to control all of the cursor setting manually through\n * the 'cursor' event from matplotlib */\n event.preventDefault()\n return fig.mouse_event(event, name);\n };\n } else {\n return function (event) {\n return fig.mouse_event(event, name);\n };\n }\n }\n\n canvas_div.addEventListener(\n 'mousedown',\n on_mouse_event_closure('button_press')\n );\n canvas_div.addEventListener(\n 'mouseup',\n on_mouse_event_closure('button_release')\n );\n canvas_div.addEventListener(\n 'dblclick',\n on_mouse_event_closure('dblclick')\n );\n // Throttle sequential mouse events to 1 every 20ms.\n canvas_div.addEventListener(\n 'mousemove',\n on_mouse_event_closure('motion_notify')\n );\n\n canvas_div.addEventListener(\n 'mouseenter',\n on_mouse_event_closure('figure_enter')\n );\n canvas_div.addEventListener(\n 'mouseleave',\n on_mouse_event_closure('figure_leave')\n );\n\n canvas_div.addEventListener('wheel', function (event) {\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n on_mouse_event_closure('scroll')(event);\n });\n\n canvas_div.appendChild(canvas);\n canvas_div.appendChild(rubberband_canvas);\n\n this.rubberband_context = rubberband_canvas.getContext('2d');\n this.rubberband_context.strokeStyle = '#000000';\n\n this._resize_canvas = function (width, height, forward) {\n if (forward) {\n canvas_div.style.width = width + 'px';\n canvas_div.style.height = height + 'px';\n }\n };\n\n // Disable right mouse context menu.\n canvas_div.addEventListener('contextmenu', function (_e) {\n event.preventDefault();\n return false;\n });\n\n function set_focus() {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'mpl-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if (!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n continue;\n }\n\n var button = (fig.buttons[name] = document.createElement('button'));\n button.classList = 'mpl-widget';\n button.setAttribute('role', 'button');\n button.setAttribute('aria-disabled', 'false');\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n\n var icon_img = document.createElement('img');\n icon_img.src = '_images/' + image + '.png';\n icon_img.srcset = '_images/' + image + '_large.png 2x';\n icon_img.alt = tooltip;\n button.appendChild(icon_img);\n\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n var fmt_picker = document.createElement('select');\n fmt_picker.classList = 'mpl-widget';\n toolbar.appendChild(fmt_picker);\n this.format_dropdown = fmt_picker;\n\n for (var ind in mpl.extensions) {\n var fmt = mpl.extensions[ind];\n var option = document.createElement('option');\n option.selected = fmt === mpl.default_extension;\n option.innerHTML = fmt;\n fmt_picker.appendChild(option);\n }\n\n var status_bar = document.createElement('span');\n status_bar.classList = 'mpl-message';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n};\n\nmpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n // which will in turn request a refresh of the image.\n this.send_message('resize', { width: x_pixels, height: y_pixels });\n};\n\nmpl.figure.prototype.send_message = function (type, properties) {\n properties['type'] = type;\n properties['figure_id'] = this.id;\n this.ws.send(JSON.stringify(properties));\n};\n\nmpl.figure.prototype.send_draw_message = function () {\n if (!this.waiting) {\n this.waiting = true;\n this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n var format_dropdown = fig.format_dropdown;\n var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n fig.ondownload(fig, format);\n};\n\nmpl.figure.prototype.handle_resize = function (fig, msg) {\n var size = msg['size'];\n if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n fig._resize_canvas(size[0], size[1], msg['forward']);\n fig.send_message('refresh', {});\n }\n};\n\nmpl.figure.prototype.handle_rubberband = function (fig, msg) {\n var x0 = msg['x0'] / fig.ratio;\n var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n var x1 = msg['x1'] / fig.ratio;\n var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n x0 = Math.floor(x0) + 0.5;\n y0 = Math.floor(y0) + 0.5;\n x1 = Math.floor(x1) + 0.5;\n y1 = Math.floor(y1) + 0.5;\n var min_x = Math.min(x0, x1);\n var min_y = Math.min(y0, y1);\n var width = Math.abs(x1 - x0);\n var height = Math.abs(y1 - y0);\n\n fig.rubberband_context.clearRect(\n 0,\n 0,\n fig.canvas.width / fig.ratio,\n fig.canvas.height / fig.ratio\n );\n\n fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n};\n\nmpl.figure.prototype.handle_figure_label = function (fig, msg) {\n // Updates the figure title.\n fig.header.textContent = msg['label'];\n};\n\nmpl.figure.prototype.handle_cursor = function (fig, msg) {\n fig.canvas_div.style.cursor = msg['cursor'];\n};\n\nmpl.figure.prototype.handle_message = function (fig, msg) {\n fig.message.textContent = msg['message'];\n};\n\nmpl.figure.prototype.handle_draw = function (fig, _msg) {\n // Request the server to send over a new figure.\n fig.send_draw_message();\n};\n\nmpl.figure.prototype.handle_image_mode = function (fig, msg) {\n fig.image_mode = msg['mode'];\n};\n\nmpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n for (var key in msg) {\n if (!(key in fig.buttons)) {\n continue;\n }\n fig.buttons[key].disabled = !msg[key];\n fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n }\n};\n\nmpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n if (msg['mode'] === 'PAN') {\n fig.buttons['Pan'].classList.add('active');\n fig.buttons['Zoom'].classList.remove('active');\n } else if (msg['mode'] === 'ZOOM') {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.add('active');\n } else {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.remove('active');\n }\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Called whenever the canvas gets updated.\n this.send_message('ack', {});\n};\n\n// A function to construct a web socket function for onmessage handling.\n// Called in the figure constructor.\nmpl.figure.prototype._make_on_message_function = function (fig) {\n return function socket_on_message(evt) {\n if (evt.data instanceof Blob) {\n var img = evt.data;\n if (img.type !== 'image/png') {\n /* FIXME: We get \"Resource interpreted as Image but\n * transferred with MIME type text/plain:\" errors on\n * Chrome. But how to set the MIME type? It doesn't seem\n * to be part of the websocket stream */\n img.type = 'image/png';\n }\n\n /* Free the memory for the previous frames */\n if (fig.imageObj.src) {\n (window.URL || window.webkitURL).revokeObjectURL(\n fig.imageObj.src\n );\n }\n\n fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n img\n );\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n } else if (\n typeof evt.data === 'string' &&\n evt.data.slice(0, 21) === 'data:image/png;base64'\n ) {\n fig.imageObj.src = evt.data;\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n }\n\n var msg = JSON.parse(evt.data);\n var msg_type = msg['type'];\n\n // Call the \"handle_{type}\" callback, which takes\n // the figure and JSON message as its only arguments.\n try {\n var callback = fig['handle_' + msg_type];\n } catch (e) {\n console.log(\n \"No handler for the '\" + msg_type + \"' message type: \",\n msg\n );\n return;\n }\n\n if (callback) {\n try {\n // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n callback(fig, msg);\n } catch (e) {\n console.log(\n \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n e,\n e.stack,\n msg\n );\n }\n }\n };\n};\n\nfunction getModifiers(event) {\n var mods = [];\n if (event.ctrlKey) {\n mods.push('ctrl');\n }\n if (event.altKey) {\n mods.push('alt');\n }\n if (event.shiftKey) {\n mods.push('shift');\n }\n if (event.metaKey) {\n mods.push('meta');\n }\n return mods;\n}\n\n/*\n * return a copy of an object with only non-object keys\n * we need this to avoid circular references\n * https://stackoverflow.com/a/24161582/3208463\n */\nfunction simpleKeys(original) {\n return Object.keys(original).reduce(function (obj, key) {\n if (typeof original[key] !== 'object') {\n obj[key] = original[key];\n }\n return obj;\n }, {});\n}\n\nmpl.figure.prototype.mouse_event = function (event, name) {\n if (name === 'button_press') {\n this.canvas.focus();\n this.canvas_div.focus();\n }\n\n // from https://stackoverflow.com/q/1114465\n var boundingRect = this.canvas.getBoundingClientRect();\n var x = (event.clientX - boundingRect.left) * this.ratio;\n var y = (event.clientY - boundingRect.top) * this.ratio;\n\n this.send_message(name, {\n x: x,\n y: y,\n button: event.button,\n step: event.step,\n modifiers: getModifiers(event),\n guiEvent: simpleKeys(event),\n });\n\n return false;\n};\n\nmpl.figure.prototype._key_event_extra = function (_event, _name) {\n // Handle any extra behaviour associated with a key event\n};\n\nmpl.figure.prototype.key_event = function (event, name) {\n // Prevent repeat events\n if (name === 'key_press') {\n if (event.key === this._key) {\n return;\n } else {\n this._key = event.key;\n }\n }\n if (name === 'key_release') {\n this._key = null;\n }\n\n var value = '';\n if (event.ctrlKey && event.key !== 'Control') {\n value += 'ctrl+';\n }\n else if (event.altKey && event.key !== 'Alt') {\n value += 'alt+';\n }\n else if (event.shiftKey && event.key !== 'Shift') {\n value += 'shift+';\n }\n\n value += 'k' + event.key;\n\n this._key_event_extra(event, name);\n\n this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n return false;\n};\n\nmpl.figure.prototype.toolbar_button_onclick = function (name) {\n if (name === 'download') {\n this.handle_save(this, null);\n } else {\n this.send_message('toolbar_button', { name: name });\n }\n};\n\nmpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n this.message.textContent = tooltip;\n};\n\n///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n// prettier-ignore\nvar _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\nmpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis\", \"fa fa-square-o\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o\", \"download\"]];\n\nmpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\", \"webp\"];\n\nmpl.default_extension = \"png\";/* global mpl */\n\nvar comm_websocket_adapter = function (comm) {\n // Create a \"websocket\"-like object which calls the given IPython comm\n // object with the appropriate methods. Currently this is a non binary\n // socket, so there is still some room for performance tuning.\n var ws = {};\n\n ws.binaryType = comm.kernel.ws.binaryType;\n ws.readyState = comm.kernel.ws.readyState;\n function updateReadyState(_event) {\n if (comm.kernel.ws) {\n ws.readyState = comm.kernel.ws.readyState;\n } else {\n ws.readyState = 3; // Closed state.\n }\n }\n comm.kernel.ws.addEventListener('open', updateReadyState);\n comm.kernel.ws.addEventListener('close', updateReadyState);\n comm.kernel.ws.addEventListener('error', updateReadyState);\n\n ws.close = function () {\n comm.close();\n };\n ws.send = function (m) {\n //console.log('sending', m);\n comm.send(m);\n };\n // Register the callback with on_msg.\n comm.on_msg(function (msg) {\n //console.log('receiving', msg['content']['data'], msg);\n var data = msg['content']['data'];\n if (data['blob'] !== undefined) {\n data = {\n data: new Blob(msg['buffers'], { type: data['blob'] }),\n };\n }\n // Pass the mpl event to the overridden (by mpl) onmessage function.\n ws.onmessage(data);\n });\n return ws;\n};\n\nmpl.mpl_figure_comm = function (comm, msg) {\n // This is the function which gets called when the mpl process\n // starts-up an IPython Comm through the \"matplotlib\" channel.\n\n var id = msg.content.data.id;\n // Get hold of the div created by the display call when the Comm\n // socket was opened in Python.\n var element = document.getElementById(id);\n var ws_proxy = comm_websocket_adapter(comm);\n\n function ondownload(figure, _format) {\n window.open(figure.canvas.toDataURL());\n }\n\n var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n\n // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n // web socket which is closed, not our websocket->open comm proxy.\n ws_proxy.onopen();\n\n fig.parent_element = element;\n fig.cell_info = mpl.find_output_cell(\"
\");\n if (!fig.cell_info) {\n console.error('Failed to find cell for figure', id, fig);\n return;\n }\n fig.cell_info[0].output_area.element.on(\n 'cleared',\n { fig: fig },\n fig._remove_fig_handler\n );\n};\n\nmpl.figure.prototype.handle_close = function (fig, msg) {\n var width = fig.canvas.width / fig.ratio;\n fig.cell_info[0].output_area.element.off(\n 'cleared',\n fig._remove_fig_handler\n );\n fig.resizeObserverInstance.unobserve(fig.canvas_div);\n\n // Update the output cell to use the data from the current canvas.\n fig.push_to_output();\n var dataURL = fig.canvas.toDataURL();\n // Re-enable the keyboard manager in IPython - without this line, in FF,\n // the notebook keyboard shortcuts fail.\n IPython.keyboard_manager.enable();\n fig.parent_element.innerHTML =\n '';\n fig.close_ws(fig, msg);\n};\n\nmpl.figure.prototype.close_ws = function (fig, msg) {\n fig.send_message('closing', msg);\n // fig.ws.close()\n};\n\nmpl.figure.prototype.push_to_output = function (_remove_interactive) {\n // Turn the data on the canvas into data in the output cell.\n var width = this.canvas.width / this.ratio;\n var dataURL = this.canvas.toDataURL();\n this.cell_info[1]['text/html'] =\n '';\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Tell IPython that the notebook contents must change.\n IPython.notebook.set_dirty(true);\n this.send_message('ack', {});\n var fig = this;\n // Wait a second, then push the new image to the DOM so\n // that it is saved nicely (might be nice to debounce this).\n setTimeout(function () {\n fig.push_to_output();\n }, 1000);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'btn-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if (!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n var button;\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n continue;\n }\n\n button = fig.buttons[name] = document.createElement('button');\n button.classList = 'btn btn-default';\n button.href = '#';\n button.title = name;\n button.innerHTML = '';\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n // Add the status bar.\n var status_bar = document.createElement('span');\n status_bar.classList = 'mpl-message pull-right';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n\n // Add the close button to the window.\n var buttongrp = document.createElement('div');\n buttongrp.classList = 'btn-group inline pull-right';\n button = document.createElement('button');\n button.classList = 'btn btn-mini btn-primary';\n button.href = '#';\n button.title = 'Stop Interaction';\n button.innerHTML = '';\n button.addEventListener('click', function (_evt) {\n fig.handle_close(fig, {});\n });\n button.addEventListener(\n 'mouseover',\n on_mouseover_closure('Stop Interaction')\n );\n buttongrp.appendChild(button);\n var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n titlebar.insertBefore(buttongrp, titlebar.firstChild);\n};\n\nmpl.figure.prototype._remove_fig_handler = function (event) {\n var fig = event.data.fig;\n if (event.target !== this) {\n // Ignore bubbled events from children.\n return;\n }\n fig.close_ws(fig, {});\n};\n\nmpl.figure.prototype._root_extra_style = function (el) {\n el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n};\n\nmpl.figure.prototype._canvas_extra_style = function (el) {\n // this is important to make the div 'focusable\n el.setAttribute('tabindex', 0);\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n } else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n};\n\nmpl.figure.prototype._key_event_extra = function (event, _name) {\n // Check for shift+enter\n if (event.shiftKey && event.which === 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n fig.ondownload(fig, null);\n};\n\nmpl.find_output_cell = function (html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i = 0; i < ncells; i++) {\n var cell = cells[i];\n if (cell.cell_type === 'code') {\n for (var j = 0; j < cell.output_area.outputs.length; j++) {\n var data = cell.output_area.outputs[j];\n if (data.data) {\n // IPython >= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] === html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n};\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel !== null) {\n IPython.notebook.kernel.comm_manager.register_target(\n 'matplotlib',\n mpl.mpl_figure_comm\n );\n}\n" + "application/javascript": "/* Put everything inside the global mpl namespace */\n/* global mpl */\nwindow.mpl = {};\n\nmpl.get_websocket_type = function () {\n if (typeof WebSocket !== 'undefined') {\n return WebSocket;\n } else if (typeof MozWebSocket !== 'undefined') {\n return MozWebSocket;\n } else {\n alert(\n 'Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.'\n );\n }\n};\n\nmpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = this.ws.binaryType !== undefined;\n\n if (!this.supports_binary) {\n var warnings = document.getElementById('mpl-warnings');\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent =\n 'This browser does not support binary websocket messages. ' +\n 'Performance may be slow.';\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = document.createElement('div');\n this.root.setAttribute('style', 'display: inline-block');\n this._root_extra_style(this.root);\n\n parent_element.appendChild(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message('supports_binary', { value: fig.supports_binary });\n fig.send_message('send_image_mode', {});\n if (fig.ratio !== 1) {\n fig.send_message('set_device_pixel_ratio', {\n device_pixel_ratio: fig.ratio,\n });\n }\n fig.send_message('refresh', {});\n };\n\n this.imageObj.onload = function () {\n if (fig.image_mode === 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function () {\n fig.ws.close();\n };\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n};\n\nmpl.figure.prototype._init_header = function () {\n var titlebar = document.createElement('div');\n titlebar.classList =\n 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n var titletext = document.createElement('div');\n titletext.classList = 'ui-dialog-title';\n titletext.setAttribute(\n 'style',\n 'width: 100%; text-align: center; padding: 3px;'\n );\n titlebar.appendChild(titletext);\n this.root.appendChild(titlebar);\n this.header = titletext;\n};\n\nmpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._init_canvas = function () {\n var fig = this;\n\n var canvas_div = (this.canvas_div = document.createElement('div'));\n canvas_div.setAttribute('tabindex', '0');\n canvas_div.setAttribute(\n 'style',\n 'border: 1px solid #ddd;' +\n 'box-sizing: content-box;' +\n 'clear: both;' +\n 'min-height: 1px;' +\n 'min-width: 1px;' +\n 'outline: 0;' +\n 'overflow: hidden;' +\n 'position: relative;' +\n 'resize: both;' +\n 'z-index: 2;'\n );\n\n function on_keyboard_event_closure(name) {\n return function (event) {\n return fig.key_event(event, name);\n };\n }\n\n canvas_div.addEventListener(\n 'keydown',\n on_keyboard_event_closure('key_press')\n );\n canvas_div.addEventListener(\n 'keyup',\n on_keyboard_event_closure('key_release')\n );\n\n this._canvas_extra_style(canvas_div);\n this.root.appendChild(canvas_div);\n\n var canvas = (this.canvas = document.createElement('canvas'));\n canvas.classList.add('mpl-canvas');\n canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'pointer-events: none;' +\n 'position: relative;' +\n 'z-index: 0;'\n );\n\n this.context = canvas.getContext('2d');\n\n var backingStore =\n this.context.backingStorePixelRatio ||\n this.context.webkitBackingStorePixelRatio ||\n this.context.mozBackingStorePixelRatio ||\n this.context.msBackingStorePixelRatio ||\n this.context.oBackingStorePixelRatio ||\n this.context.backingStorePixelRatio ||\n 1;\n\n this.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n 'canvas'\n ));\n rubberband_canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'left: 0;' +\n 'pointer-events: none;' +\n 'position: absolute;' +\n 'top: 0;' +\n 'z-index: 1;'\n );\n\n // Apply a ponyfill if ResizeObserver is not implemented by browser.\n if (this.ResizeObserver === undefined) {\n if (window.ResizeObserver !== undefined) {\n this.ResizeObserver = window.ResizeObserver;\n } else {\n var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n this.ResizeObserver = obs.ResizeObserver;\n }\n }\n\n this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n var nentries = entries.length;\n for (var i = 0; i < nentries; i++) {\n var entry = entries[i];\n var width, height;\n if (entry.contentBoxSize) {\n if (entry.contentBoxSize instanceof Array) {\n // Chrome 84 implements new version of spec.\n width = entry.contentBoxSize[0].inlineSize;\n height = entry.contentBoxSize[0].blockSize;\n } else {\n // Firefox implements old version of spec.\n width = entry.contentBoxSize.inlineSize;\n height = entry.contentBoxSize.blockSize;\n }\n } else {\n // Chrome <84 implements even older version of spec.\n width = entry.contentRect.width;\n height = entry.contentRect.height;\n }\n\n // Keep the size of the canvas and rubber band canvas in sync with\n // the canvas container.\n if (entry.devicePixelContentBoxSize) {\n // Chrome 84 implements new version of spec.\n canvas.setAttribute(\n 'width',\n entry.devicePixelContentBoxSize[0].inlineSize\n );\n canvas.setAttribute(\n 'height',\n entry.devicePixelContentBoxSize[0].blockSize\n );\n } else {\n canvas.setAttribute('width', width * fig.ratio);\n canvas.setAttribute('height', height * fig.ratio);\n }\n /* This rescales the canvas back to display pixels, so that it\n * appears correct on HiDPI screens. */\n canvas.style.width = width + 'px';\n canvas.style.height = height + 'px';\n\n rubberband_canvas.setAttribute('width', width);\n rubberband_canvas.setAttribute('height', height);\n\n // And update the size in Python. We ignore the initial 0/0 size\n // that occurs as the element is placed into the DOM, which should\n // otherwise not happen due to the minimum size styling.\n if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n fig.request_resize(width, height);\n }\n }\n });\n this.resizeObserverInstance.observe(canvas_div);\n\n function on_mouse_event_closure(name) {\n /* User Agent sniffing is bad, but WebKit is busted:\n * https://bugs.webkit.org/show_bug.cgi?id=144526\n * https://bugs.webkit.org/show_bug.cgi?id=181818\n * The worst that happens here is that they get an extra browser\n * selection when dragging, if this check fails to catch them.\n */\n var UA = navigator.userAgent;\n var isWebKit = /AppleWebKit/.test(UA) && !/Chrome/.test(UA);\n if(isWebKit) {\n return function (event) {\n /* This prevents the web browser from automatically changing to\n * the text insertion cursor when the button is pressed. We\n * want to control all of the cursor setting manually through\n * the 'cursor' event from matplotlib */\n event.preventDefault()\n return fig.mouse_event(event, name);\n };\n } else {\n return function (event) {\n return fig.mouse_event(event, name);\n };\n }\n }\n\n canvas_div.addEventListener(\n 'mousedown',\n on_mouse_event_closure('button_press')\n );\n canvas_div.addEventListener(\n 'mouseup',\n on_mouse_event_closure('button_release')\n );\n canvas_div.addEventListener(\n 'dblclick',\n on_mouse_event_closure('dblclick')\n );\n // Throttle sequential mouse events to 1 every 20ms.\n canvas_div.addEventListener(\n 'mousemove',\n on_mouse_event_closure('motion_notify')\n );\n\n canvas_div.addEventListener(\n 'mouseenter',\n on_mouse_event_closure('figure_enter')\n );\n canvas_div.addEventListener(\n 'mouseleave',\n on_mouse_event_closure('figure_leave')\n );\n\n canvas_div.addEventListener('wheel', function (event) {\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n on_mouse_event_closure('scroll')(event);\n });\n\n canvas_div.appendChild(canvas);\n canvas_div.appendChild(rubberband_canvas);\n\n this.rubberband_context = rubberband_canvas.getContext('2d');\n this.rubberband_context.strokeStyle = '#000000';\n\n this._resize_canvas = function (width, height, forward) {\n if (forward) {\n canvas_div.style.width = width + 'px';\n canvas_div.style.height = height + 'px';\n }\n };\n\n // Disable right mouse context menu.\n canvas_div.addEventListener('contextmenu', function (_e) {\n event.preventDefault();\n return false;\n });\n\n function set_focus() {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'mpl-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if (!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n continue;\n }\n\n var button = (fig.buttons[name] = document.createElement('button'));\n button.classList = 'mpl-widget';\n button.setAttribute('role', 'button');\n button.setAttribute('aria-disabled', 'false');\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n\n var icon_img = document.createElement('img');\n icon_img.src = '_images/' + image + '.png';\n icon_img.srcset = '_images/' + image + '_large.png 2x';\n icon_img.alt = tooltip;\n button.appendChild(icon_img);\n\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n var fmt_picker = document.createElement('select');\n fmt_picker.classList = 'mpl-widget';\n toolbar.appendChild(fmt_picker);\n this.format_dropdown = fmt_picker;\n\n for (var ind in mpl.extensions) {\n var fmt = mpl.extensions[ind];\n var option = document.createElement('option');\n option.selected = fmt === mpl.default_extension;\n option.innerHTML = fmt;\n fmt_picker.appendChild(option);\n }\n\n var status_bar = document.createElement('span');\n status_bar.classList = 'mpl-message';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n};\n\nmpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n // which will in turn request a refresh of the image.\n this.send_message('resize', { width: x_pixels, height: y_pixels });\n};\n\nmpl.figure.prototype.send_message = function (type, properties) {\n properties['type'] = type;\n properties['figure_id'] = this.id;\n this.ws.send(JSON.stringify(properties));\n};\n\nmpl.figure.prototype.send_draw_message = function () {\n if (!this.waiting) {\n this.waiting = true;\n this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n var format_dropdown = fig.format_dropdown;\n var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n fig.ondownload(fig, format);\n};\n\nmpl.figure.prototype.handle_resize = function (fig, msg) {\n var size = msg['size'];\n if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n fig._resize_canvas(size[0], size[1], msg['forward']);\n fig.send_message('refresh', {});\n }\n};\n\nmpl.figure.prototype.handle_rubberband = function (fig, msg) {\n var x0 = msg['x0'] / fig.ratio;\n var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n var x1 = msg['x1'] / fig.ratio;\n var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n x0 = Math.floor(x0) + 0.5;\n y0 = Math.floor(y0) + 0.5;\n x1 = Math.floor(x1) + 0.5;\n y1 = Math.floor(y1) + 0.5;\n var min_x = Math.min(x0, x1);\n var min_y = Math.min(y0, y1);\n var width = Math.abs(x1 - x0);\n var height = Math.abs(y1 - y0);\n\n fig.rubberband_context.clearRect(\n 0,\n 0,\n fig.canvas.width / fig.ratio,\n fig.canvas.height / fig.ratio\n );\n\n fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n};\n\nmpl.figure.prototype.handle_figure_label = function (fig, msg) {\n // Updates the figure title.\n fig.header.textContent = msg['label'];\n};\n\nmpl.figure.prototype.handle_cursor = function (fig, msg) {\n fig.canvas_div.style.cursor = msg['cursor'];\n};\n\nmpl.figure.prototype.handle_message = function (fig, msg) {\n fig.message.textContent = msg['message'];\n};\n\nmpl.figure.prototype.handle_draw = function (fig, _msg) {\n // Request the server to send over a new figure.\n fig.send_draw_message();\n};\n\nmpl.figure.prototype.handle_image_mode = function (fig, msg) {\n fig.image_mode = msg['mode'];\n};\n\nmpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n for (var key in msg) {\n if (!(key in fig.buttons)) {\n continue;\n }\n fig.buttons[key].disabled = !msg[key];\n fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n }\n};\n\nmpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n if (msg['mode'] === 'PAN') {\n fig.buttons['Pan'].classList.add('active');\n fig.buttons['Zoom'].classList.remove('active');\n } else if (msg['mode'] === 'ZOOM') {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.add('active');\n } else {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.remove('active');\n }\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Called whenever the canvas gets updated.\n this.send_message('ack', {});\n};\n\n// A function to construct a web socket function for onmessage handling.\n// Called in the figure constructor.\nmpl.figure.prototype._make_on_message_function = function (fig) {\n return function socket_on_message(evt) {\n if (evt.data instanceof Blob) {\n var img = evt.data;\n if (img.type !== 'image/png') {\n /* FIXME: We get \"Resource interpreted as Image but\n * transferred with MIME type text/plain:\" errors on\n * Chrome. But how to set the MIME type? It doesn't seem\n * to be part of the websocket stream */\n img.type = 'image/png';\n }\n\n /* Free the memory for the previous frames */\n if (fig.imageObj.src) {\n (window.URL || window.webkitURL).revokeObjectURL(\n fig.imageObj.src\n );\n }\n\n fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n img\n );\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n } else if (\n typeof evt.data === 'string' &&\n evt.data.slice(0, 21) === 'data:image/png;base64'\n ) {\n fig.imageObj.src = evt.data;\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n }\n\n var msg = JSON.parse(evt.data);\n var msg_type = msg['type'];\n\n // Call the \"handle_{type}\" callback, which takes\n // the figure and JSON message as its only arguments.\n try {\n var callback = fig['handle_' + msg_type];\n } catch (e) {\n console.log(\n \"No handler for the '\" + msg_type + \"' message type: \",\n msg\n );\n return;\n }\n\n if (callback) {\n try {\n // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n callback(fig, msg);\n } catch (e) {\n console.log(\n \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n e,\n e.stack,\n msg\n );\n }\n }\n };\n};\n\nfunction getModifiers(event) {\n var mods = [];\n if (event.ctrlKey) {\n mods.push('ctrl');\n }\n if (event.altKey) {\n mods.push('alt');\n }\n if (event.shiftKey) {\n mods.push('shift');\n }\n if (event.metaKey) {\n mods.push('meta');\n }\n return mods;\n}\n\n/*\n * return a copy of an object with only non-object keys\n * we need this to avoid circular references\n * https://stackoverflow.com/a/24161582/3208463\n */\nfunction simpleKeys(original) {\n return Object.keys(original).reduce(function (obj, key) {\n if (typeof original[key] !== 'object') {\n obj[key] = original[key];\n }\n return obj;\n }, {});\n}\n\nmpl.figure.prototype.mouse_event = function (event, name) {\n if (name === 'button_press') {\n this.canvas.focus();\n this.canvas_div.focus();\n }\n\n // from https://stackoverflow.com/q/1114465\n var boundingRect = this.canvas.getBoundingClientRect();\n var x = (event.clientX - boundingRect.left) * this.ratio;\n var y = (event.clientY - boundingRect.top) * this.ratio;\n\n this.send_message(name, {\n x: x,\n y: y,\n button: event.button,\n step: event.step,\n modifiers: getModifiers(event),\n guiEvent: simpleKeys(event),\n });\n\n return false;\n};\n\nmpl.figure.prototype._key_event_extra = function (_event, _name) {\n // Handle any extra behaviour associated with a key event\n};\n\nmpl.figure.prototype.key_event = function (event, name) {\n // Prevent repeat events\n if (name === 'key_press') {\n if (event.key === this._key) {\n return;\n } else {\n this._key = event.key;\n }\n }\n if (name === 'key_release') {\n this._key = null;\n }\n\n var value = '';\n if (event.ctrlKey && event.key !== 'Control') {\n value += 'ctrl+';\n }\n else if (event.altKey && event.key !== 'Alt') {\n value += 'alt+';\n }\n else if (event.shiftKey && event.key !== 'Shift') {\n value += 'shift+';\n }\n\n value += 'k' + event.key;\n\n this._key_event_extra(event, name);\n\n this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n return false;\n};\n\nmpl.figure.prototype.toolbar_button_onclick = function (name) {\n if (name === 'download') {\n this.handle_save(this, null);\n } else {\n this.send_message('toolbar_button', { name: name });\n }\n};\n\nmpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n this.message.textContent = tooltip;\n};\n\n///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n// prettier-ignore\nvar _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\nmpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis\", \"fa fa-square-o\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o\", \"download\"]];\n\nmpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\", \"webp\"];\n\nmpl.default_extension = \"png\";/* global mpl */\n\nvar comm_websocket_adapter = function (comm) {\n // Create a \"websocket\"-like object which calls the given IPython comm\n // object with the appropriate methods. Currently this is a non binary\n // socket, so there is still some room for performance tuning.\n var ws = {};\n\n ws.binaryType = comm.kernel.ws.binaryType;\n ws.readyState = comm.kernel.ws.readyState;\n function updateReadyState(_event) {\n if (comm.kernel.ws) {\n ws.readyState = comm.kernel.ws.readyState;\n } else {\n ws.readyState = 3; // Closed state.\n }\n }\n comm.kernel.ws.addEventListener('open', updateReadyState);\n comm.kernel.ws.addEventListener('close', updateReadyState);\n comm.kernel.ws.addEventListener('error', updateReadyState);\n\n ws.close = function () {\n comm.close();\n };\n ws.send = function (m) {\n //console.log('sending', m);\n comm.send(m);\n };\n // Register the callback with on_msg.\n comm.on_msg(function (msg) {\n //console.log('receiving', msg['content']['data'], msg);\n var data = msg['content']['data'];\n if (data['blob'] !== undefined) {\n data = {\n data: new Blob(msg['buffers'], { type: data['blob'] }),\n };\n }\n // Pass the mpl event to the overridden (by mpl) onmessage function.\n ws.onmessage(data);\n });\n return ws;\n};\n\nmpl.mpl_figure_comm = function (comm, msg) {\n // This is the function which gets called when the mpl process\n // starts-up an IPython Comm through the \"matplotlib\" channel.\n\n var id = msg.content.data.id;\n // Get hold of the div created by the display call when the Comm\n // socket was opened in Python.\n var element = document.getElementById(id);\n var ws_proxy = comm_websocket_adapter(comm);\n\n function ondownload(figure, _format) {\n window.open(figure.canvas.toDataURL());\n }\n\n var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n\n // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n // web socket which is closed, not our websocket->open comm proxy.\n ws_proxy.onopen();\n\n fig.parent_element = element;\n fig.cell_info = mpl.find_output_cell(\"
\");\n if (!fig.cell_info) {\n console.error('Failed to find cell for figure', id, fig);\n return;\n }\n fig.cell_info[0].output_area.element.on(\n 'cleared',\n { fig: fig },\n fig._remove_fig_handler\n );\n};\n\nmpl.figure.prototype.handle_close = function (fig, msg) {\n var width = fig.canvas.width / fig.ratio;\n fig.cell_info[0].output_area.element.off(\n 'cleared',\n fig._remove_fig_handler\n );\n fig.resizeObserverInstance.unobserve(fig.canvas_div);\n\n // Update the output cell to use the data from the current canvas.\n fig.push_to_output();\n var dataURL = fig.canvas.toDataURL();\n // Re-enable the keyboard manager in IPython - without this line, in FF,\n // the notebook keyboard shortcuts fail.\n IPython.keyboard_manager.enable();\n fig.parent_element.innerHTML =\n '';\n fig.close_ws(fig, msg);\n};\n\nmpl.figure.prototype.close_ws = function (fig, msg) {\n fig.send_message('closing', msg);\n // fig.ws.close()\n};\n\nmpl.figure.prototype.push_to_output = function (_remove_interactive) {\n // Turn the data on the canvas into data in the output cell.\n var width = this.canvas.width / this.ratio;\n var dataURL = this.canvas.toDataURL();\n this.cell_info[1]['text/html'] =\n '';\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Tell IPython that the notebook contents must change.\n IPython.notebook.set_dirty(true);\n this.send_message('ack', {});\n var fig = this;\n // Wait a second, then push the new image to the DOM so\n // that it is saved nicely (might be nice to debounce this).\n setTimeout(function () {\n fig.push_to_output();\n }, 1000);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'btn-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if (!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n var button;\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n continue;\n }\n\n button = fig.buttons[name] = document.createElement('button');\n button.classList = 'btn btn-default';\n button.href = '#';\n button.title = name;\n button.innerHTML = '';\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n // Add the status bar.\n var status_bar = document.createElement('span');\n status_bar.classList = 'mpl-message pull-right';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n\n // Add the close button to the window.\n var buttongrp = document.createElement('div');\n buttongrp.classList = 'btn-group inline pull-right';\n button = document.createElement('button');\n button.classList = 'btn btn-mini btn-primary';\n button.href = '#';\n button.title = 'Stop Interaction';\n button.innerHTML = '';\n button.addEventListener('click', function (_evt) {\n fig.handle_close(fig, {});\n });\n button.addEventListener(\n 'mouseover',\n on_mouseover_closure('Stop Interaction')\n );\n buttongrp.appendChild(button);\n var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n titlebar.insertBefore(buttongrp, titlebar.firstChild);\n};\n\nmpl.figure.prototype._remove_fig_handler = function (event) {\n var fig = event.data.fig;\n if (event.target !== this) {\n // Ignore bubbled events from children.\n return;\n }\n fig.close_ws(fig, {});\n};\n\nmpl.figure.prototype._root_extra_style = function (el) {\n el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n};\n\nmpl.figure.prototype._canvas_extra_style = function (el) {\n // this is important to make the div 'focusable\n el.setAttribute('tabindex', 0);\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n } else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n};\n\nmpl.figure.prototype._key_event_extra = function (event, _name) {\n // Check for shift+enter\n if (event.shiftKey && event.which === 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n fig.ondownload(fig, null);\n};\n\nmpl.find_output_cell = function (html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i = 0; i < ncells; i++) {\n var cell = cells[i];\n if (cell.cell_type === 'code') {\n for (var j = 0; j < cell.output_area.outputs.length; j++) {\n var data = cell.output_area.outputs[j];\n if (data.data) {\n // IPython >= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] === html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n};\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel !== null) {\n IPython.notebook.kernel.comm_manager.register_target(\n 'matplotlib',\n mpl.mpl_figure_comm\n );\n}\n", + "text/plain": "" }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/plain": "", - "text/html": "
" + "text/html": "
", + "text/plain": "" }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/plain": "", - "text/html": "\n\n\n\n\n\n
\n \n
\n \n
\n \n \n \n \n \n \n \n \n \n
\n
\n \n \n \n \n \n \n
\n
\n
\n\n\n\n" + "text/html": "\n\n\n\n\n\n
\n \n
\n \n
\n \n \n \n \n \n \n \n \n \n
\n
\n \n \n \n \n \n \n
\n
\n
\n\n\n\n", + "text/plain": "" }, "execution_count": 10, "metadata": {}, diff --git a/examples/training.ipynb b/examples/training.ipynb index 2d1034ce9..3c3f49f35 100644 --- a/examples/training.ipynb +++ b/examples/training.ipynb @@ -54,7 +54,7 @@ "\n", "# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n", "try:\n", - " subprocess.check_output('nvidia-smi')\n", + " subprocess.check_output(\"nvidia-smi\")\n", " print(\"a GPU is connected.\")\n", "except Exception:\n", " # TPU or CPU\n", @@ -82,6 +82,7 @@ "outputs": [], "source": [ "import warnings\n", + "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "from jumanji.training.train import train\n", @@ -117,7 +118,7 @@ }, "outputs": [], "source": [ - "#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", + "# @title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", "\n", "import os\n", "import requests\n", @@ -407,7 +408,15 @@ ], "source": [ "with initialize(version_base=None, config_path=\"configs\"):\n", - " cfg = compose(config_name=\"config.yaml\", overrides=[f\"env={env}\", f\"agent={agent}\", \"logger.type=terminal\", \"logger.save_checkpoint=true\"])\n", + " cfg = compose(\n", + " config_name=\"config.yaml\",\n", + " overrides=[\n", + " f\"env={env}\",\n", + " f\"agent={agent}\",\n", + " \"logger.type=terminal\",\n", + " \"logger.save_checkpoint=true\",\n", + " ],\n", + " )\n", "\n", "train(cfg)" ] diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 7d5c02807..60ba1da88 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -137,12 +137,8 @@ register(id="PacMan-v1", entry_point="jumanji.environments:PacMan") # SlidingTilePuzzle - A sliding tile puzzle environment with the default grid size of 5x5. -register( - id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle" -) +register(id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle") # LevelBasedForaging with a random generator with 8 grid size, # 2 agents and 2 food items and the maximum agent's level is 2. -register( - id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging" -) +register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging") diff --git a/jumanji/env.py b/jumanji/env.py index 9674960c8..737a9002c 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -49,10 +49,10 @@ def __repr__(self) -> str: def __init__(self) -> None: """Initialize environment.""" - self.observation_spec - self.action_spec - self.reward_spec - self.discount_spec + self.observation_spec # noqa: B018 + self.action_spec # noqa: B018 + self.reward_spec # noqa: B018 + self.discount_spec # noqa: B018 @abc.abstractmethod def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -67,9 +67,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """ @abc.abstractmethod - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -115,9 +113,7 @@ def discount_spec(self) -> specs.BoundedArray: Returns: discount_spec: a `specs.BoundedArray` spec. """ - return specs.BoundedArray( - shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount" - ) + return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount") @property def unwrapped(self) -> Environment[State, ActionSpec, Observation]: diff --git a/jumanji/environments/commons/maze_utils/maze_generation.py b/jumanji/environments/commons/maze_utils/maze_generation.py index 3be0197e8..5fdbadaed 100644 --- a/jumanji/environments/commons/maze_utils/maze_generation.py +++ b/jumanji/environments/commons/maze_utils/maze_generation.py @@ -36,6 +36,7 @@ nodes) through a vertical wall must be at an even y coordinate while a passage through a horizontal wall must be at an even x coordinate. """ + from typing import NamedTuple, Tuple import chex @@ -123,9 +124,7 @@ def create_chamber(chambers: Stack, x: int, y: int, width: int, height: int) -> return new_stack -def split_vertically( - state: MazeGenerationState, chamber: chex.Array -) -> MazeGenerationState: +def split_vertically(state: MazeGenerationState, chamber: chex.Array) -> MazeGenerationState: """Split the chamber vertically. Randomly draw a horizontal wall to split the chamber vertically. Randomly open a passage @@ -215,8 +214,6 @@ def generate_maze(width: int, height: int, key: chex.PRNGKey) -> chex.Array: initial_state = MazeGenerationState(maze, chambers, key) - final_state = jax.lax.while_loop( - chambers_remaining, split_next_chamber, initial_state - ) + final_state = jax.lax.while_loop(chambers_remaining, split_next_chamber, initial_state) return final_state.maze diff --git a/jumanji/environments/commons/maze_utils/maze_generation_test.py b/jumanji/environments/commons/maze_utils/maze_generation_test.py index 4281e277c..8f9a63fcd 100644 --- a/jumanji/environments/commons/maze_utils/maze_generation_test.py +++ b/jumanji/environments/commons/maze_utils/maze_generation_test.py @@ -109,9 +109,7 @@ def test_random_odd(self, key: chex.PRNGKey) -> None: assert i % 2 == 1 assert 0 <= i < max_val - def test_split_vertically( - self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey - ) -> None: + def test_split_vertically(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None: """Test that a horizontal wall is drawn and that subchambers are added to stack.""" chambers, chamber = stack_pop(chambers) state = MazeGenerationState(maze, chambers, key) @@ -124,9 +122,7 @@ def test_split_vertically( assert chambers.insertion_index >= 1 - def test_split_horizontally( - self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey - ) -> None: + def test_split_horizontally(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None: """Test that a vertical wall is drawn and that subchambers are added to stack.""" chambers, chamber = stack_pop(chambers) state = MazeGenerationState(maze, chambers, key) diff --git a/jumanji/environments/commons/maze_utils/maze_rendering.py b/jumanji/environments/commons/maze_utils/maze_rendering.py index 631e4b736..59011f1b9 100644 --- a/jumanji/environments/commons/maze_utils/maze_rendering.py +++ b/jumanji/environments/commons/maze_utils/maze_rendering.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Sequence, Tuple +from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Tuple import chex import matplotlib.animation @@ -32,7 +32,7 @@ class MazeViewer(Viewer): FONT_STYLE = "monospace" FIGURE_SIZE = (10.0, 10.0) # EMPTY is white, WALL is black - COLORS = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]} + COLORS: ClassVar[Dict[int, List[int]]] = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]} def __init__(self, name: str, render_mode: str = "human") -> None: """Viewer for a maze environment. diff --git a/jumanji/environments/commons/maze_utils/stack.py b/jumanji/environments/commons/maze_utils/stack.py index d82e4d352..bd5b911f8 100644 --- a/jumanji/environments/commons/maze_utils/stack.py +++ b/jumanji/environments/commons/maze_utils/stack.py @@ -52,6 +52,7 @@ [. . . .]] """ + from typing import NamedTuple, Tuple import chex diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 45d189d2e..066fd1989 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -76,9 +76,7 @@ class Game2048(Environment[State, specs.DiscreteArray, Observation]): ``` """ - def __init__( - self, board_size: int = 4, viewer: Optional[Viewer[State]] = None - ) -> None: + def __init__(self, board_size: int = 4, viewer: Optional[Viewer[State]] = None) -> None: """Initialize the 2048 game. Args: @@ -166,9 +164,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Updates the environment state after the agent takes an action. Args: @@ -279,9 +275,7 @@ def _add_random_cell(self, board: Board, key: chex.PRNGKey) -> Board: position = jnp.divmod(tile_idx, self.board_size) # Choose the value of the new cell: 1 with probability 90% or 2 with probability of 10% - cell_value = jax.random.choice( - subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1]) - ) + cell_value = jax.random.choice(subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1])) board = board.at[position].set(cell_value) return board @@ -325,9 +319,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/logic/game_2048/utils.py b/jumanji/environments/logic/game_2048/utils.py index 0cea3389a..7c03d1b2e 100644 --- a/jumanji/environments/logic/game_2048/utils.py +++ b/jumanji/environments/logic/game_2048/utils.py @@ -61,9 +61,7 @@ def can_move_left_row_cond(carry: CanMoveCarry) -> chex.Numeric: def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry: """Check if the current tiles can move and increment the indices.""" # Check if tiles can move - can_move = (carry.origin != 0) & ( - (carry.target == 0) | (carry.target == carry.origin) - ) + can_move = (carry.origin != 0) & ((carry.target == 0) | (carry.target == carry.origin)) # Increment indices as if performed a no op # If not performing no op, loop will be terminated anyways @@ -75,17 +73,13 @@ def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry: ) # Return updated carry - return carry._replace( - can_move=can_move, target_idx=target_idx, origin_idx=origin_idx - ) + return carry._replace(can_move=can_move, target_idx=target_idx, origin_idx=origin_idx) def can_move_left_row(row: chex.Array) -> bool: """Check if row can move left.""" carry = CanMoveCarry(can_move=False, row=row, target_idx=0, origin_idx=1) - can_move: bool = jax.lax.while_loop( - can_move_left_row_cond, can_move_left_row_body, carry - )[0] + can_move: bool = jax.lax.while_loop(can_move_left_row_cond, can_move_left_row_body, carry)[0] return can_move diff --git a/jumanji/environments/logic/game_2048/viewer.py b/jumanji/environments/logic/game_2048/viewer.py index b64b48b20..a82f4ff29 100644 --- a/jumanji/environments/logic/game_2048/viewer.py +++ b/jumanji/environments/logic/game_2048/viewer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import ClassVar, Dict, Optional, Sequence, Tuple import jax.numpy as jnp import matplotlib.animation @@ -24,7 +24,7 @@ class Game2048Viewer(Viewer): - COLORS = { + COLORS: ClassVar[Dict[int | str, str]] = { 1: "#ccc0b3", 2: "#eee4da", 4: "#ede0c8", @@ -158,13 +158,9 @@ def render_tile(self, tile_value: int, ax: plt.Axes, row: int, col: int) -> None """ # Set the background color of the tile based on its value. if tile_value <= 16384: - rect = plt.Rectangle( - [col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)] - ) + rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)]) else: - rect = plt.Rectangle( - [col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"] - ) + rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"]) ax.add_patch(rect) if tile_value in [2, 4]: diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index 32de81019..75fcfcbeb 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -97,9 +97,7 @@ def __init__( viewer: environment viewer for rendering. Defaults to `GraphColoringViewer`. """ - self.generator = generator or RandomGenerator( - num_nodes=20, edge_probability=0.8 - ) + self.generator = generator or RandomGenerator(num_nodes=20, edge_probability=0.8) self.num_nodes = self.generator.num_nodes super().__init__() @@ -138,9 +136,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Updates the environment state after the agent takes an action. Specifically, this function allows the agent to choose @@ -182,9 +178,7 @@ def step( # Update the current node index next_node_index = (state.current_node_index + 1) % self.num_nodes - next_action_mask = self._get_valid_actions( - next_node_index, state.adj_matrix, state.colors - ) + next_action_mask = self._get_valid_actions(next_node_index, state.adj_matrix, state.colors) next_state = State( adj_matrix=state.adj_matrix, @@ -263,9 +257,7 @@ def action_spec(self) -> specs.DiscreteArray: Returns: action_spec: specs.DiscreteArray object """ - return specs.DiscreteArray( - num_values=self.num_nodes, name="action", dtype=jnp.int32 - ) + return specs.DiscreteArray(num_values=self.num_nodes, name="action", dtype=jnp.int32) def _get_valid_actions( self, current_node_index: int, adj_matrix: chex.Array, colors: chex.Array @@ -307,9 +299,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._env_viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._env_viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/logic/graph_coloring/env_test.py b/jumanji/environments/logic/graph_coloring/env_test.py index f7b618b1d..b3de0dcaa 100644 --- a/jumanji/environments/logic/graph_coloring/env_test.py +++ b/jumanji/environments/logic/graph_coloring/env_test.py @@ -78,9 +78,7 @@ def test_graph_coloring_get_action_mask(graph_coloring: GraphColoring) -> None: state, _ = graph_coloring.reset(key) num_nodes = graph_coloring.generator.num_nodes get_valid_actions_fn = jax.jit(graph_coloring._get_valid_actions) - action_mask = get_valid_actions_fn( - state.current_node_index, state.adj_matrix, state.colors - ) + action_mask = get_valid_actions_fn(state.current_node_index, state.adj_matrix, state.colors) # Check that the action mask is a boolean array with the correct shape. assert action_mask.dtype == jnp.bool_ diff --git a/jumanji/environments/logic/graph_coloring/generator.py b/jumanji/environments/logic/graph_coloring/generator.py index 4aab23c10..92d975de3 100644 --- a/jumanji/environments/logic/graph_coloring/generator.py +++ b/jumanji/environments/logic/graph_coloring/generator.py @@ -92,9 +92,7 @@ def __call__(self, key: chex.PRNGKey) -> chex.Array: key, edge_key = jax.random.split(key) # Generate a random adjacency matrix with probabilities of connections. - p_matrix = jax.random.uniform( - key=edge_key, shape=(self.num_nodes, self.num_nodes) - ) + p_matrix = jax.random.uniform(key=edge_key, shape=(self.num_nodes, self.num_nodes)) # Threshold the probabilities to create a boolean adjacency matrix. adj_matrix = p_matrix < self.edge_probability diff --git a/jumanji/environments/logic/graph_coloring/viewer.py b/jumanji/environments/logic/graph_coloring/viewer.py index af5d5f179..d18a6ae89 100644 --- a/jumanji/environments/logic/graph_coloring/viewer.py +++ b/jumanji/environments/logic/graph_coloring/viewer.py @@ -205,9 +205,7 @@ def _render_nodes( fill=(colors[i] != -1), ) ) - ax.text( - x, y, str(i), color="white", ha="center", va="center", weight="bold" - ) + ax.text(x, y, str(i), color="white", ha="center", va="center", weight="bold") def _render_edges( self, diff --git a/jumanji/environments/logic/minesweeper/conftest.py b/jumanji/environments/logic/minesweeper/conftest.py index 4f7cc321c..714a7d772 100644 --- a/jumanji/environments/logic/minesweeper/conftest.py +++ b/jumanji/environments/logic/minesweeper/conftest.py @@ -25,9 +25,7 @@ @pytest.fixture def minesweeper_env() -> Minesweeper: """Fixture for a default minesweeper environment with 10 rows and columns, and 10 mines.""" - return Minesweeper( - generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10) - ) + return Minesweeper(generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10)) @pytest.fixture diff --git a/jumanji/environments/logic/minesweeper/done.py b/jumanji/environments/logic/minesweeper/done.py index a6f40d270..8cf39c5d9 100644 --- a/jumanji/environments/logic/minesweeper/done.py +++ b/jumanji/environments/logic/minesweeper/done.py @@ -26,9 +26,7 @@ class DoneFn(abc.ABC): @abc.abstractmethod - def __call__( - self, state: State, next_state: State, action: chex.Array - ) -> chex.Array: + def __call__(self, state: State, next_state: State, action: chex.Array) -> chex.Array: """Call method for computing the done signal given the current and next state, and the action taken. """ @@ -39,9 +37,7 @@ class DefaultDoneFn(DoneFn): or the board is solved. """ - def __call__( - self, state: State, next_state: State, action: chex.Array - ) -> chex.Array: + def __call__(self, state: State, next_state: State, action: chex.Array) -> chex.Array: return ( ~is_valid_action(state=state, action=action) | explored_mine(state=state, action=action) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 1e9d8d4f1..7dba97184 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -129,9 +129,7 @@ def __init__( self.num_cols = self.generator.num_cols self.num_mines = self.generator.num_mines super().__init__() - self._viewer = viewer or MinesweeperViewer( - num_rows=self.num_rows, num_cols=self.num_cols - ) + self._viewer = viewer or MinesweeperViewer(num_rows=self.num_rows, num_cols=self.num_cols) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -149,9 +147,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=observation) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -162,9 +158,7 @@ def step( next_state: `State` corresponding to the next state of the environment, next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ - board = state.board.at[tuple(action)].set( - count_adjacent_mines(state=state, action=action) - ) + board = state.board.at[tuple(action)].set(count_adjacent_mines(state=state, action=action)) step_count = state.step_count + 1 next_state = State( board=board, @@ -279,9 +273,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 7ae532deb..68bd52537 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -144,9 +144,7 @@ def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: # Check that the state has changed, since we took the same action twice assert jnp.array_equal(next_next_state.board, next_state.board) - assert ( - next_next_timestep.observation.num_mines == next_timestep.observation.num_mines - ) + assert next_next_timestep.observation.num_mines == next_timestep.observation.num_mines assert next_next_state.step_count == 2 assert next_next_timestep.observation.step_count == 2 assert jnp.array_equal(next_next_state.board, next_next_timestep.observation.board) @@ -162,9 +160,7 @@ def test_minesweeper__specs_does_not_smoke(minesweeper_env: Minesweeper) -> None check_env_specs_does_not_smoke(minesweeper_env) -def test_minesweeper__render( - monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper -) -> None: +def test_minesweeper__render(monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) state, timestep = jax.jit(minesweeper_env.reset)(jax.random.PRNGKey(0)) @@ -205,9 +201,7 @@ def test_minesweeper__solved(minesweeper_env: Minesweeper) -> None: minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines ) assert collected_rewards == [1.0] * expected_episode_length - assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ - StepType.LAST - ] + assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [StepType.LAST] def test_minesweeper_animation( diff --git a/jumanji/environments/logic/minesweeper/utils_test.py b/jumanji/environments/logic/minesweeper/utils_test.py index 55b2d2b99..f36177e40 100644 --- a/jumanji/environments/logic/minesweeper/utils_test.py +++ b/jumanji/environments/logic/minesweeper/utils_test.py @@ -47,6 +47,7 @@ False, False, ], + strict=False, ), ) def test_explored_mine( @@ -66,7 +67,9 @@ def test_explored_mine( @pytest.mark.parametrize( "action, expected_count_adjacent_mines_result", - zip(product(range(4), range(4)), [2, 4, 2, 2, 4, 8, 4, 3, 2, 4, 2, 2, 2, 3, 2, 1]), + zip( + product(range(4), range(4)), [2, 4, 2, 2, 4, 8, 4, 3, 2, 4, 2, 2, 2, 3, 2, 1], strict=False + ), ) def test_count_adjacent_mines( manual_start_state: State, diff --git a/jumanji/environments/logic/rubiks_cube/conftest.py b/jumanji/environments/logic/rubiks_cube/conftest.py index c1e7ff671..0364d24cf 100644 --- a/jumanji/environments/logic/rubiks_cube/conftest.py +++ b/jumanji/environments/logic/rubiks_cube/conftest.py @@ -77,6 +77,4 @@ def expected_scramble_result() -> chex.Array: @pytest.fixture def rubiks_cube() -> RubiksCube: """Instantiates a `RubiksCube` environment with 10 scrambles on reset.""" - return RubiksCube( - generator=ScramblingGenerator(cube_size=3, num_scrambles_on_reset=10) - ) + return RubiksCube(generator=ScramblingGenerator(cube_size=3, num_scrambles_on_reset=10)) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index a4472e0ed..4d8b6db06 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -135,9 +135,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=observation) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -214,9 +212,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array( - [len(Face), self.generator.cube_size // 2, 3], jnp.int32 - ), + num_values=jnp.array([len(Face), self.generator.cube_size // 2, 3], jnp.int32), name="action", dtype=jnp.int32, ) @@ -249,9 +245,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 59d56cfa3..af7b08a14 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -97,9 +97,7 @@ def test_rubiks_cube__specs_does_not_smoke(cube_size: int) -> None: check_env_specs_does_not_smoke(env) -def test_rubiks_cube__render( - monkeypatch: pytest.MonkeyPatch, rubiks_cube: RubiksCube -) -> None: +def test_rubiks_cube__render(monkeypatch: pytest.MonkeyPatch, rubiks_cube: RubiksCube) -> None: """Test that the render method builds the figure (but does not display it).""" monkeypatch.setattr(plt, "show", lambda fig: None) state, timestep = rubiks_cube.reset(jax.random.PRNGKey(0)) @@ -128,9 +126,7 @@ def test_rubiks_cube__done(time_limit: int) -> None: assert episode_length == time_limit -def test_rubiks_cube__animate( - rubiks_cube: RubiksCube, mocker: pytest_mock.MockerFixture -) -> None: +def test_rubiks_cube__animate(rubiks_cube: RubiksCube, mocker: pytest_mock.MockerFixture) -> None: """Test that the `animate` method creates the animation correctly (but does not display it).""" states = mocker.MagicMock() animation = rubiks_cube.animate(states) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index 040aabe45..db6581332 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -160,9 +160,7 @@ def up_move_function(cube: Cube) -> Cube: return up_move_function -def generate_front_move( - amount: CubeMovementAmount, depth: int -) -> Callable[[Cube], Cube]: +def generate_front_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the front face. Args: @@ -207,9 +205,7 @@ def front_move_function(cube: Cube) -> Cube: return front_move_function -def generate_right_move( - amount: CubeMovementAmount, depth: int -) -> Callable[[Cube], Cube]: +def generate_right_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the right face. Args: @@ -254,9 +250,7 @@ def right_move_function(cube: Cube) -> Cube: return right_move_function -def generate_back_move( - amount: CubeMovementAmount, depth: int -) -> Callable[[Cube], Cube]: +def generate_back_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the back face. Args: @@ -301,9 +295,7 @@ def back_move_function(cube: Cube) -> Cube: return back_move_function -def generate_left_move( - amount: CubeMovementAmount, depth: int -) -> Callable[[Cube], Cube]: +def generate_left_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the left face. Args: @@ -348,9 +340,7 @@ def left_move_function(cube: Cube) -> Cube: return left_move_function -def generate_down_move( - amount: CubeMovementAmount, depth: int -) -> Callable[[Cube], Cube]: +def generate_down_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the down face. Args: @@ -491,9 +481,7 @@ def flatten_action(unflattened_action: chex.Array, cube_size: int) -> chex.Array """ face, depth, amount = unflattened_action return ( - face * len(CubeMovementAmount) * (cube_size // 2) - + depth * len(CubeMovementAmount) - + amount + face * len(CubeMovementAmount) * (cube_size // 2) + depth * len(CubeMovementAmount) + amount ) diff --git a/jumanji/environments/logic/rubiks_cube/utils_test.py b/jumanji/environments/logic/rubiks_cube/utils_test.py index c28154002..e854d8c2d 100644 --- a/jumanji/environments/logic/rubiks_cube/utils_test.py +++ b/jumanji/environments/logic/rubiks_cube/utils_test.py @@ -79,9 +79,7 @@ def test_flatten_and_unflatten_action(cube_size: int) -> None: unflattened_actions = jnp.stack( [ jnp.repeat(faces, len(CubeMovementAmount) * (cube_size // 2)), - jnp.concatenate( - [jnp.repeat(depths, len(CubeMovementAmount)) for _ in Face] - ), + jnp.concatenate([jnp.repeat(depths, len(CubeMovementAmount)) for _ in Face]), jnp.concatenate([amounts for _ in range(len(Face) * (cube_size // 2))]), ] ) @@ -140,9 +138,7 @@ def test_half_turns( assert jnp.array_equal(cube, differently_stickered_cube) -def test_solved_reward( - solved_cube: chex.Array, differently_stickered_cube: chex.Array -) -> None: +def test_solved_reward(solved_cube: chex.Array, differently_stickered_cube: chex.Array) -> None: """Test that the cube fixtures have the expected rewards.""" solved_state = State( cube=solved_cube, @@ -163,6 +159,7 @@ def test_solved_reward( zip( generate_all_moves(cube_size=3), is_face_turn(cube_size=3), + strict=False, ), ) def test_moves_nontrivial( @@ -197,9 +194,7 @@ def test_moves_nontrivial( num_non_face_impacted_cubies = (len(Face) - 2) * differently_stickered_cube_size assert jnp.not_equal( differently_stickered_cube, moved_differently_stickered_cube - ).sum() == num_non_face_impacted_cubies + ( - num_face_impacted_cubies if move_is_face_turn else 0 - ) + ).sum() == num_non_face_impacted_cubies + (num_face_impacted_cubies if move_is_face_turn else 0) if differently_stickered_cube_size % 2 == 1: assert jnp.array_equal( differently_stickered_cube[ @@ -285,9 +280,7 @@ def test_checkerboard(cube_size: int, indices: List[int]) -> None: assert jnp.array_equal(cube[face.value], expected_result) -def test_manual_scramble( - solved_cube: chex.Array, expected_scramble_result: chex.Array -) -> None: +def test_manual_scramble(solved_cube: chex.Array, expected_scramble_result: chex.Array) -> None: """Testing a particular scramble manually. Scramble chosen to have all faces touched at least once.""" scramble = [ @@ -328,17 +321,13 @@ def test_manual_scramble( ], dtype=jnp.int32, ) - flattened_sequence = jnp.array( - [0, 14, 16, 2, 10, 6, 3, 7, 13, 11, 4, 0, 15], dtype=jnp.int32 - ) + flattened_sequence = jnp.array([0, 14, 16, 2, 10, 6, 3, 7, 13, 11, 4, 0, 15], dtype=jnp.int32) assert jnp.array_equal( unflattened_sequence.transpose(), unflatten_action(flattened_action=flattened_sequence, cube_size=3), ) flatten_fn = lambda x: flatten_action(x, 3) - assert jnp.array_equal( - flattened_sequence, jax.vmap(flatten_fn)(unflattened_sequence) - ) + assert jnp.array_equal(flattened_sequence, jax.vmap(flatten_fn)(unflattened_sequence)) cube = scramble_solved_cube( flattened_actions_in_scramble=flattened_sequence, cube_size=3, diff --git a/jumanji/environments/logic/rubiks_cube/viewer.py b/jumanji/environments/logic/rubiks_cube/viewer.py index 77d493a78..98c6473e6 100644 --- a/jumanji/environments/logic/rubiks_cube/viewer.py +++ b/jumanji/environments/logic/rubiks_cube/viewer.py @@ -94,9 +94,7 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, List[plt.Axes]]: fig = plt.figure(self.figure_name) ax = fig.get_axes() else: - fig, ax = plt.subplots( - nrows=3, ncols=2, figsize=self.figure_size, num=self.figure_name - ) + fig, ax = plt.subplots(nrows=3, ncols=2, figsize=self.figure_size, num=self.figure_name) fig.suptitle(self.figure_name) ax = ax.flatten() plt.tight_layout() diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env.py b/jumanji/environments/logic/sliding_tile_puzzle/env.py index 9ab17e6c4..3304dc32d 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/env.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/env.py @@ -92,9 +92,7 @@ def __init__( time_limit: maximum number of steps before the episode is terminated, default to 500. viewer: environment viewer for rendering. """ - self.generator = generator or RandomWalkGenerator( - grid_size=5, num_random_moves=200 - ) + self.generator = generator or RandomWalkGenerator(grid_size=5, num_random_moves=200) self.reward_fn = reward_fn or DenseRewardFn() self.time_limit = time_limit super().__init__() @@ -117,9 +115,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=obs, extras=self._get_extras(state)) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Updates the environment state after the agent takes an action.""" (updated_puzzle, updated_empty_tile_position) = self._move_empty_tile( state.puzzle, state.empty_tile_position, action @@ -175,8 +171,7 @@ def _move_empty_tile( # Predicate for the conditional is_valid_move = jnp.all( - (new_empty_tile_position >= 0) - & (new_empty_tile_position < self.generator.grid_size) + (new_empty_tile_position >= 0) & (new_empty_tile_position < self.generator.grid_size) ) # Swap the empty tile and the tile at the new position @@ -274,9 +269,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._env_viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._env_viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py index 5ed9ddc75..4b7cb3f74 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py @@ -47,9 +47,7 @@ def test_sliding_tile_puzzle_reset_jit(sliding_tile_puzzle: SlidingTilePuzzle) - assert isinstance(state, State) -def test_sliding_tile_puzzle_step_jit( - sliding_tile_puzzle: SlidingTilePuzzle, state: State -) -> None: +def test_sliding_tile_puzzle_step_jit(sliding_tile_puzzle: SlidingTilePuzzle, state: State) -> None: """Confirm that the step is only compiled once when jitted.""" up_action = jnp.array(0) down_action = jnp.array(2) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/generator.py b/jumanji/environments/logic/sliding_tile_puzzle/generator.py index f2848e5d2..4b6a41973 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/generator.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/generator.py @@ -117,15 +117,11 @@ def _make_random_move( move = jax.random.choice(key, MOVES, shape=(), p=valid_moves_mask) new_empty_tile_position = empty_tile_position + move # Swap the empty tile with the tile at the new position using _swap_tiles - updated_puzzle = self._swap_tiles( - puzzle, empty_tile_position, new_empty_tile_position - ) + updated_puzzle = self._swap_tiles(puzzle, empty_tile_position, new_empty_tile_position) return updated_puzzle, new_empty_tile_position - def _swap_tiles( - self, puzzle: chex.Array, pos1: chex.Array, pos2: chex.Array - ) -> chex.Array: + def _swap_tiles(self, puzzle: chex.Array, pos1: chex.Array, pos2: chex.Array) -> chex.Array: """Swaps the tiles at the given positions.""" temp = puzzle[tuple(pos1)] puzzle = puzzle.at[tuple(pos1)].set(puzzle[tuple(pos2)]) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py index 7fa905fcf..8ea9bfd20 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py @@ -50,9 +50,7 @@ def __init__(self, name: str = "SlidingTilePuzzle") -> None: """ self._name = name self._animation: Optional[matplotlib.animation.Animation] = None - self._color_map = mcolors.LinearSegmentedColormap.from_list( - "", ["white", "blue"] - ) + self._color_map = mcolors.LinearSegmentedColormap.from_list("", ["white", "blue"]) def render(self, state: State) -> None: """Renders the current state of the game puzzle. @@ -137,9 +135,7 @@ def draw_puzzle(self, ax: plt.Axes, state: State) -> None: tile_value = state.puzzle[row, col] if tile_value == 0: # Render the empty tile - rect = plt.Rectangle( - [col - 0.5, row - 0.5], 1, 1, color=self.EMPTY_TILE_COLOR - ) + rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.EMPTY_TILE_COLOR) ax.add_patch(rect) else: # Render the numbered tile diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 64a91899d..41c206500 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -98,26 +98,20 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=obs) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: # check if action is valid invalid = ~state.action_mask[tuple(action)] updated_board = apply_action(action=action, board=state.board) updated_action_mask = get_action_mask(board=updated_board) # creating next state - next_state = State( - board=updated_board, action_mask=updated_action_mask, key=state.key - ) + next_state = State(board=updated_board, action_mask=updated_action_mask, key=state.key) no_actions_available = ~jnp.any(updated_action_mask) # computing terminal condition done = invalid | no_actions_available - reward = self._reward_fn( - state=state, new_state=next_state, action=action, done=done - ) + reward = self._reward_fn(state=state, new_state=next_state, action=action, done=done) observation = Observation(board=updated_board, action_mask=updated_action_mask) @@ -157,9 +151,7 @@ def observation_spec(self) -> specs.Spec[Observation]: name="action_mask", ) - return specs.Spec( - Observation, "ObservationSpec", board=board, action_mask=action_mask - ) + return specs.Spec(Observation, "ObservationSpec", board=board, action_mask=action_mask) @cached_property def action_spec(self) -> specs.MultiDiscreteArray: @@ -200,6 +192,4 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) diff --git a/jumanji/environments/logic/sudoku/env_test.py b/jumanji/environments/logic/sudoku/env_test.py index 85bd9d672..a20ab91ac 100644 --- a/jumanji/environments/logic/sudoku/env_test.py +++ b/jumanji/environments/logic/sudoku/env_test.py @@ -95,9 +95,7 @@ def test_sudoku__render(monkeypatch: pytest.MonkeyPatch, sudoku_env: Sudoku) -> sudoku_env.close() -def test_sudoku_animation( - sudoku_env: Sudoku, mocker: pytest_mock.MockerFixture -) -> None: +def test_sudoku_animation(sudoku_env: Sudoku, mocker: pytest_mock.MockerFixture) -> None: """Check that the animation method creates the animation correctly.""" states = mocker.MagicMock() animation = sudoku_env.animate(states) diff --git a/jumanji/environments/logic/sudoku/generator.py b/jumanji/environments/logic/sudoku/generator.py index 65d359ab7..00c93db9b 100644 --- a/jumanji/environments/logic/sudoku/generator.py +++ b/jumanji/environments/logic/sudoku/generator.py @@ -88,9 +88,7 @@ def __init__(self, database: chex.Array): def __call__(self, key: chex.PRNGKey) -> State: key, idx_key = jax.random.split(key) - idx = jax.random.randint( - idx_key, shape=(), minval=0, maxval=self._boards.shape[0] - ) + idx = jax.random.randint(idx_key, shape=(), minval=0, maxval=self._boards.shape[0]) board = self._boards.take(idx, axis=0) board = jnp.asarray(board, dtype=jnp.int32) - 1 action_mask = get_action_mask(board) diff --git a/jumanji/environments/logic/sudoku/utils.py b/jumanji/environments/logic/sudoku/utils.py index 761dc9fb1..f359f6264 100644 --- a/jumanji/environments/logic/sudoku/utils.py +++ b/jumanji/environments/logic/sudoku/utils.py @@ -39,9 +39,7 @@ def _validate_row(row: chex.Array) -> chex.Array: condition_rows = jax.vmap(_validate_row)(board).all() condition_columns = jax.vmap(_validate_row)(board.T).all() - condition_boxes = jax.vmap(_validate_row)( - jnp.take(board, jnp.asarray(BOX_IDX)) - ).all() + condition_boxes = jax.vmap(_validate_row)(jnp.take(board, jnp.asarray(BOX_IDX))).all() return condition_rows & condition_columns & condition_boxes diff --git a/jumanji/environments/packing/bin_pack/conftest.py b/jumanji/environments/packing/bin_pack/conftest.py index 33ca6c437..60e3a3b0f 100644 --- a/jumanji/environments/packing/bin_pack/conftest.py +++ b/jumanji/environments/packing/bin_pack/conftest.py @@ -125,9 +125,7 @@ def dense_reward() -> DenseReward: @pytest.fixture -def bin_pack_dense_reward( - dummy_generator: DummyGenerator, dense_reward: DenseReward -) -> BinPack: +def bin_pack_dense_reward(dummy_generator: DummyGenerator, dense_reward: DenseReward) -> BinPack: return BinPack( generator=dummy_generator, obs_num_ems=5, @@ -141,9 +139,7 @@ def sparse_reward() -> SparseReward: @pytest.fixture -def bin_pack_sparse_reward( - dummy_generator: DummyGenerator, sparse_reward: SparseReward -) -> BinPack: +def bin_pack_sparse_reward(dummy_generator: DummyGenerator, sparse_reward: SparseReward) -> BinPack: return BinPack( generator=dummy_generator, obs_num_ems=5, diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 3506fa0b0..7e06c3353 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -198,9 +198,7 @@ def observation_spec(self) -> specs.Spec[Observation]: if self.normalize_dimensions: ems_dict = { - f"{coord_name}": specs.BoundedArray( - (obs_num_ems,), float, 0.0, 1.0, coord_name - ) + f"{coord_name}": specs.BoundedArray((obs_num_ems,), float, 0.0, 1.0, coord_name) for coord_name in ["x1", "x2", "y1", "y2", "z1", "z2"] } else: @@ -219,18 +217,12 @@ def observation_spec(self) -> specs.Spec[Observation]: } else: items_dict = { - f"{axis}": specs.BoundedArray( - (max_num_items,), jnp.int32, 0, max_dim, axis - ) + f"{axis}": specs.BoundedArray((max_num_items,), jnp.int32, 0, max_dim, axis) for axis in ["x_len", "y_len", "z_len"] } items = specs.Spec(Item, "ItemsSpec", **items_dict) - items_mask = specs.BoundedArray( - (max_num_items,), bool, False, True, "items_mask" - ) - items_placed = specs.BoundedArray( - (max_num_items,), bool, False, True, "items_placed" - ) + items_mask = specs.BoundedArray((max_num_items,), bool, False, True, "items_mask") + items_placed = specs.BoundedArray((max_num_items,), bool, False, True, "items_placed") action_mask = specs.BoundedArray( (obs_num_ems, max_num_items), bool, @@ -258,9 +250,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: - ems_id: int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included). """ - num_values = jnp.array( - [self.obs_num_ems, self.generator.max_num_items], jnp.int32 - ) + num_values = jnp.array([self.obs_num_ems, self.generator.max_num_items], jnp.int32) return specs.MultiDiscreteArray(num_values=num_values, name="action") def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -294,9 +284,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates. @@ -394,9 +382,7 @@ def close(self) -> None: """ self._viewer.close() - def _make_observation_and_extras( - self, state: State - ) -> Tuple[State, Observation, Dict]: + def _make_observation_and_extras(self, state: State) -> Tuple[State, Observation, Dict]: """Computes the observation and the environment metrics to include in `timestep.extras`. Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. @@ -456,9 +442,7 @@ def _get_extras(self, state: State) -> Dict: } return extras - def _normalize_ems_and_items( - self, state: State, obs_ems: EMS, items: Item - ) -> Tuple[EMS, Item]: + def _normalize_ems_and_items(self, state: State, obs_ems: EMS, items: Item) -> Tuple[EMS, Item]: """Normalize the EMSs and items in the observation. Each dimension is divided by the container length so that they are all between 0.0 and 1.0. Hence, the ratio is not kept. """ @@ -477,14 +461,10 @@ def _ems_are_all_valid(self, state: State) -> chex.Array: container. This check is only done in debug mode. """ item_spaces = space_from_item_and_location(state.items, state.items_location) - ems_intersect_items = jax.vmap(Space.intersect, in_axes=(0, None))( - state.ems, item_spaces - ) + ems_intersect_items = jax.vmap(Space.intersect, in_axes=(0, None))(state.ems, item_spaces) ems_intersect_items &= jnp.outer(state.ems_mask, state.items_placed) ems_intersection_with_items = jnp.any(ems_intersect_items) - ems_outside_container = jnp.any( - state.ems_mask & ~state.ems.is_included(state.container) - ) + ems_outside_container = jnp.any(state.ems_mask & ~state.ems.is_included(state.container)) return ~ems_intersection_with_items & ~ems_outside_container def _get_set_of_largest_ems( @@ -492,9 +472,7 @@ def _get_set_of_largest_ems( ) -> Tuple[EMS, chex.Array, chex.Array]: """Returns a subset of EMSs by selecting the `obs_num_ems` largest EMSs.""" ems_volumes = ems.volume() * ems_mask - sorted_ems_indexes = jnp.argsort( - -ems_volumes - ) # minus sign to sort in decreasing order + sorted_ems_indexes = jnp.argsort(-ems_volumes) # minus sign to sort in decreasing order obs_ems_indexes = sorted_ems_indexes[: self.obs_num_ems] obs_ems = jax.tree_util.tree_map(lambda x: x[obs_ems_indexes], ems) obs_ems_mask = ems_mask[obs_ems_indexes] @@ -567,7 +545,7 @@ def _update_ems(self, state: State, item_id: chex.Numeric) -> State: new_ems = state.ems new_ems_mask = ems_mask_after_intersect for intersection_ems, intersection_mask in zip( - intersections_ems_dict.values(), intersections_mask_dict.values() + intersections_ems_dict.values(), intersections_mask_dict.values(), strict=False ): new_ems, new_ems_mask = self._add_ems( intersection_ems, intersection_mask, new_ems, new_ems_mask @@ -585,12 +563,8 @@ def _get_intersections_dict( """ # Create new EMSs from EMSs that intersect the new item intersections_ems_dict: Dict[str, Space] = { - f"{axis}_{direction}": item_space.hyperplane(axis, direction).intersection( - state.ems - ) - for axis, direction in itertools.product( - ["x", "y", "z"], ["lower", "upper"] - ) + f"{axis}_{direction}": item_space.hyperplane(axis, direction).intersection(state.ems) + for axis, direction in itertools.product(["x", "y", "z"], ["lower", "upper"]) } # A new EMS is added if the intersection is not empty and if it is not fully included in @@ -611,12 +585,12 @@ def _get_intersections_dict( for (direction, direction_intersections_ems), ( _, direction_intersections_mask, - ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): + ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items(), strict=False): # Inner loop iterates through alternative directions. for (alt_direction, alt_direction_intersections_ems), ( _, alt_direction_intersections_mask, - ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): + ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items(), strict=False): # The current direction EMS is included in the alternative EMS. directions_included_in_alt_directions = jax.vmap( jax.vmap(Space.is_included, in_axes=(None, 0)), in_axes=(0, None) @@ -629,9 +603,7 @@ def _get_intersections_dict( ) directions_included_in_alt_directions = ( directions_included_in_alt_directions - & jnp.outer( - direction_intersections_mask, alt_direction_intersections_mask - ) + & jnp.outer(direction_intersections_mask, alt_direction_intersections_mask) ) # The alternative EMSs are included in the current direction EMSs. @@ -646,9 +618,7 @@ def _get_intersections_dict( ) alt_directions_included_in_directions = ( alt_directions_included_in_directions - & jnp.outer( - alt_direction_intersections_mask, direction_intersections_mask - ) + & jnp.outer(alt_direction_intersections_mask, direction_intersections_mask) ) # Remove EMSs that are strictly included in other EMSs. This does not remove diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index 967a56538..3daa8b944 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -65,8 +65,8 @@ def normalize_dimensions(request: pytest.mark.FixtureRequest) -> bool: return request.param # type: ignore -@pytest.fixture(scope="function") # noqa: CCR001 -def bin_pack_optimal_policy_select_action( # noqa: CCR001 +@pytest.fixture(scope="function") +def bin_pack_optimal_policy_select_action( request: pytest.mark.FixtureRequest, ) -> Callable[[Observation, State], chex.Array]: """Optimal policy for the BinPack environment. @@ -84,9 +84,7 @@ def unnormalize_obs_ems(obs_ems: Space, solution: State) -> Space: ) return obs_ems - def select_action( # noqa: CCR001 - observation: Observation, solution: State - ) -> chex.Array: + def select_action(observation: Observation, solution: State) -> chex.Array: """Outputs the best action to fully pack the container.""" for obs_ems_id, obs_ems_action_mask in enumerate(observation.action_mask): if not obs_ems_action_mask.any(): @@ -257,9 +255,7 @@ def test_bin_pack__optimal_policy_random_instance( solution = generate_solution_fn(key) while not timestep.last(): - action = bin_pack_optimal_policy_select_action( - timestep.observation, solution - ) + action = bin_pack_optimal_policy_select_action(timestep.observation, solution) assert timestep.observation.action_mask[tuple(action)] state, timestep = step_fn(state, action) assert not timestep.extras["invalid_action"] diff --git a/jumanji/environments/packing/bin_pack/generator.py b/jumanji/environments/packing/bin_pack/generator.py index de32e635d..46557aa4a 100644 --- a/jumanji/environments/packing/bin_pack/generator.py +++ b/jumanji/environments/packing/bin_pack/generator.py @@ -65,9 +65,7 @@ class Generator(abc.ABC): for generating an instance when the environment is reset. """ - def __init__( - self, max_num_items: int, max_num_ems: int, container_dims: Tuple[int, int, int] - ): + def __init__(self, max_num_items: int, max_num_ems: int, container_dims: Tuple[int, int, int]): """Abstract class implementing `max_num_items` and `max_num_ems` properties. Args: @@ -131,9 +129,7 @@ def _unpack_items(self, state: State) -> State: """ state.ems_mask = jnp.zeros(self.max_num_ems, bool).at[0].set(True) state.items_placed = jnp.zeros(self.max_num_items, bool) - state.items_location = Location( - *tuple(jnp.zeros((3, self.max_num_items), jnp.int32)) - ) + state.items_location = Location(*tuple(jnp.zeros((3, self.max_num_items), jnp.int32))) return state @@ -144,9 +140,7 @@ class ToyGenerator(Generator): def __init__(self) -> None: """Instantiate a `ToyGenerator` with 20 items and 60 EMSs maximum.""" - super().__init__( - max_num_items=20, max_num_ems=60, container_dims=TWENTY_FOOT_DIMS - ) + super().__init__(max_num_items=20, max_num_ems=60, container_dims=TWENTY_FOOT_DIMS) def __call__(self, key: chex.PRNGKey) -> State: """Call method responsible for generating a new state. It returns a 20-ft container instance @@ -405,9 +399,7 @@ def __init__( container_dims: (length, width, height) tuple of integers corresponding to the dimensions of the container in millimeters. By default, assume a 20-ft container. """ - self.instance_from_csv = self._parse_csv_file( - csv_path, max_num_ems, container_dims - ) + self.instance_from_csv = self._parse_csv_file(csv_path, max_num_ems, container_dims) max_num_items = self.instance_from_csv.items_mask.shape[0] super().__init__(max_num_items, max_num_ems, container_dims) @@ -497,9 +489,7 @@ def _read_csv(self, csv_path: str) -> List[Tuple[str, int, int, int, int]]: ) return rows - def _generate_list_of_items( - self, rows: List[Tuple[str, int, int, int, int]] - ) -> List[Item]: + def _generate_list_of_items(self, rows: List[Tuple[str, int, int, int, int]]) -> List[Item]: """Generate the list of items from a Pandas DataFrame. Args: @@ -510,7 +500,7 @@ def _generate_list_of_items( their quantity. """ list_of_items = [] - for (_, x_len, y_len, z_len, quantity) in rows: + for _, x_len, y_len, z_len, quantity in rows: identical_items = quantity * [ Item( x_len=jnp.array(x_len, jnp.int32), @@ -538,18 +528,15 @@ def save_instance_to_csv(state: State, path: str) -> None: shape_1,1080,760,300,5 shape_2,1100,430,250,3 """ - items = list(zip(state.items.x_len, state.items.y_len, state.items.z_len)) + items = list(zip(state.items.x_len, state.items.y_len, state.items.z_len, strict=False)) items = [ tuple(x.item() for x in item) - for item, mask in zip(items, state.items_mask) + for item, mask in zip(items, state.items_mask, strict=False) if mask and all(x > 0 for x in item) ] grouped_items = list(collections.Counter(items).items()) grouped_items.sort(key=operator.itemgetter(1), reverse=True) - rows = [ - (f"shape_{i}", *item, count) - for i, (item, count) in enumerate(grouped_items, start=1) - ] + rows = [(f"shape_{i}", *item, count) for i, (item, count) in enumerate(grouped_items, start=1)] with open(path, "w", newline="") as csvfile: writer = csv.writer(csvfile) writer.writerow(CSV_COLUMNS) @@ -659,9 +646,7 @@ def _generate_solved_instance(self, key: chex.PRNGKey) -> State: ems = tree_transpose(list_of_ems) ems_mask = jnp.zeros(self.max_num_ems, bool) - items_spaces, items_mask = self._split_container_into_items_spaces( - container, split_key - ) + items_spaces, items_mask = self._split_container_into_items_spaces(container, split_key) items = item_from_space(items_spaces) sorted_ems_indexes = jnp.arange(0, self.max_num_ems, dtype=jnp.int32) @@ -691,12 +676,10 @@ def _split_container_into_items_spaces( def cond_fun(val: Tuple[Space, chex.Array, chex.PRNGKey]) -> jnp.bool_: _, items_mask, _ = val num_placed_items = jnp.sum(items_mask) - return ( - num_placed_items < self.max_num_items - self._split_num_same_items + 1 - ) + return num_placed_items < self.max_num_items - self._split_num_same_items + 1 def body_fun( - val: Tuple[Space, chex.Array, chex.PRNGKey] + val: Tuple[Space, chex.Array, chex.PRNGKey], ) -> Tuple[Space, chex.Array, chex.PRNGKey]: items_spaces, items_mask, key = val key, subkey = jax.random.split(key) @@ -773,12 +756,8 @@ def _split_along_axis( items_spaces, items_mask = jax.lax.cond( jax.random.uniform(mode_key) < self._prob_split_one_item, - functools.partial( - self._split_item_once, item_space, axis, axis_len, item_id - ), - functools.partial( - self._split_item_multiple_times, item_space, axis, axis_len, item_id - ), + functools.partial(self._split_item_once, item_space, axis, axis_len, item_id), + functools.partial(self._split_item_multiple_times, item_space, axis, axis_len, item_id), items_spaces, items_mask, split_key, @@ -799,10 +778,8 @@ def _split_item_once( space axis length with paddings equal to `_split_eps`% on each side of the space. """ axis_min, axis_max = ( - item_space.get_axis_value(axis, 1) - + jnp.array(self._split_eps * axis_len, jnp.int32), - item_space.get_axis_value(axis, 2) - - jnp.array(self._split_eps * axis_len, jnp.int32), + item_space.get_axis_value(axis, 1) + jnp.array(self._split_eps * axis_len, jnp.int32), + item_space.get_axis_value(axis, 2) - jnp.array(self._split_eps * axis_len, jnp.int32), ) axis_split = jax.random.randint(split_key, (), axis_min, axis_max, jnp.int32) free_index = jnp.argmin(items_mask) @@ -840,28 +817,18 @@ def _split_item_multiple_times( ) items_spaces.set_axis_value(axis, 2, new_items_axis_2) - def body_fn( - i: int, carry: Tuple[Space, chex.Array] - ) -> Tuple[Space, chex.Array]: + def body_fn(i: int, carry: Tuple[Space, chex.Array]) -> Tuple[Space, chex.Array]: items_spaces, items_mask = carry free_index = jnp.argmin(items_mask) items_spaces = jax.tree_util.tree_map( lambda coord: coord.at[free_index].set(coord[item_id]), items_spaces, ) - item_axis_1 = initial_item_axis_1 + jnp.array( - i * axis_len / num_split, jnp.int32 - ) - item_axis_2 = initial_item_axis_1 + jnp.array( - (i + 1) * axis_len / num_split, jnp.int32 - ) - new_items_axis_1 = ( - items_spaces.get_axis_value(axis, 1).at[free_index].set(item_axis_1) - ) + item_axis_1 = initial_item_axis_1 + jnp.array(i * axis_len / num_split, jnp.int32) + item_axis_2 = initial_item_axis_1 + jnp.array((i + 1) * axis_len / num_split, jnp.int32) + new_items_axis_1 = items_spaces.get_axis_value(axis, 1).at[free_index].set(item_axis_1) items_spaces.set_axis_value(axis, 1, new_items_axis_1) - new_items_axis_2 = ( - items_spaces.get_axis_value(axis, 2).at[free_index].set(item_axis_2) - ) + new_items_axis_2 = items_spaces.get_axis_value(axis, 2).at[free_index].set(item_axis_2) items_spaces.set_axis_value(axis, 2, new_items_axis_2) items_mask = items_mask.at[free_index].set(True) return items_spaces, items_mask diff --git a/jumanji/environments/packing/bin_pack/generator_test.py b/jumanji/environments/packing/bin_pack/generator_test.py index e93a4aa51..adaf725cc 100644 --- a/jumanji/environments/packing/bin_pack/generator_test.py +++ b/jumanji/environments/packing/bin_pack/generator_test.py @@ -76,9 +76,7 @@ def test_toy_generator__generate_solution( state1 = toy_generator(jax.random.PRNGKey(1)) chex.clear_trace_counter() - generate_solution = jax.jit( - chex.assert_max_traces(toy_generator.generate_solution, n=1) - ) + generate_solution = jax.jit(chex.assert_max_traces(toy_generator.generate_solution, n=1)) solution_state1 = generate_solution(jax.random.PRNGKey(1)) assert isinstance(solution_state1, State) @@ -87,9 +85,7 @@ def test_toy_generator__generate_solution( assert_trees_are_equal(solution_state1.items, state1.items) assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) assert_trees_are_different(solution_state1.items_placed, state1.items_placed) - assert_trees_are_different( - solution_state1.items_location, state1.items_location - ) + assert_trees_are_different(solution_state1.items_location, state1.items_location) assert jnp.all(solution_state1.items_placed) solution_state2 = generate_solution(jax.random.PRNGKey(2)) @@ -120,9 +116,7 @@ def test_csv_generator__properties( assert csv_generator.max_num_items == dummy_generator.max_num_items assert csv_generator.max_num_ems == dummy_generator.max_num_ems - def test_csv_generator__call( - self, dummy_state: State, csv_generator: CSVGenerator - ) -> None: + def test_csv_generator__call(self, dummy_state: State, csv_generator: CSVGenerator) -> None: """Validate that the csv instance generator's call function is jittable and compiles only once. Also check that the function is independent of the key. """ @@ -146,9 +140,7 @@ def test_csv_generator__call( class TestRandomGenerator: @pytest.fixture - def random_generator( - self, max_num_items: int = 6, max_num_ems: int = 10 - ) -> RandomGenerator: + def random_generator(self, max_num_items: int = 6, max_num_ems: int = 10) -> RandomGenerator: return RandomGenerator(max_num_items, max_num_ems) def test_random_generator__properties( @@ -181,9 +173,7 @@ def test_random_generator__generate_solution( state1 = random_generator(jax.random.PRNGKey(1)) chex.clear_trace_counter() - generate_solution = jax.jit( - chex.assert_max_traces(random_generator.generate_solution, n=1) - ) + generate_solution = jax.jit(chex.assert_max_traces(random_generator.generate_solution, n=1)) solution_state1 = generate_solution(jax.random.PRNGKey(1)) assert isinstance(solution_state1, State) @@ -192,13 +182,9 @@ def test_random_generator__generate_solution( assert_trees_are_equal(solution_state1.items, state1.items) assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) assert_trees_are_different(solution_state1.items_placed, state1.items_placed) - assert_trees_are_different( - solution_state1.items_location, state1.items_location - ) + assert_trees_are_different(solution_state1.items_location, state1.items_location) assert jnp.all(solution_state1.items_placed | ~solution_state1.items_mask) - items_volume = ( - item_volume(solution_state1.items) * solution_state1.items_mask - ).sum() + items_volume = (item_volume(solution_state1.items) * solution_state1.items_mask).sum() assert jnp.isclose(items_volume, solution_state1.container.volume()) solution_state2 = generate_solution(jax.random.PRNGKey(2)) diff --git a/jumanji/environments/packing/bin_pack/reward_test.py b/jumanji/environments/packing/bin_pack/reward_test.py index 1630fc292..f6ef97d3b 100644 --- a/jumanji/environments/packing/bin_pack/reward_test.py +++ b/jumanji/environments/packing/bin_pack/reward_test.py @@ -21,9 +21,7 @@ from jumanji.environments.packing.bin_pack.types import item_volume -def test__sparse_reward( - bin_pack_sparse_reward: BinPack, sparse_reward: SparseReward -) -> None: +def test__sparse_reward(bin_pack_sparse_reward: BinPack, sparse_reward: SparseReward) -> None: reward_fn = jax.jit(sparse_reward) step_fn = jax.jit(bin_pack_sparse_reward.step) state, timestep = bin_pack_sparse_reward.reset(jax.random.PRNGKey(0)) @@ -32,9 +30,7 @@ def test__sparse_reward( for item_id, is_valid in enumerate(timestep.observation.items_mask): action = jnp.array([0, item_id], jnp.int32) next_state, next_timestep = step_fn(state, action) - reward = reward_fn( - state, action, next_state, is_valid, is_done=next_timestep.last() - ) + reward = reward_fn(state, action, next_state, is_valid, is_done=next_timestep.last()) assert reward == next_timestep.reward == 0 # Check that all other invalid actions lead to the 0 reward, any ems_id > 0 is not valid at @@ -62,9 +58,7 @@ def test__sparse_reward( assert jnp.isclose(reward, item_volume(item)) -def test_dense_reward( - bin_pack_dense_reward: BinPack, dense_reward: DenseReward -) -> None: +def test_dense_reward(bin_pack_dense_reward: BinPack, dense_reward: DenseReward) -> None: reward_fn = jax.jit(dense_reward) step_fn = jax.jit(bin_pack_dense_reward.step) state, timestep = bin_pack_dense_reward.reset(jax.random.PRNGKey(0)) @@ -73,9 +67,7 @@ def test_dense_reward( for item_id, is_valid in enumerate(timestep.observation.items_mask): action = jnp.array([0, item_id], jnp.int32) next_state, next_timestep = step_fn(state, action) - reward = reward_fn( - state, action, next_state, is_valid, is_done=next_timestep.last() - ) + reward = reward_fn(state, action, next_state, is_valid, is_done=next_timestep.last()) assert reward == next_timestep.reward if is_valid: item = jumanji.tree_utils.tree_slice(timestep.observation.items, item_id) diff --git a/jumanji/environments/packing/bin_pack/space.py b/jumanji/environments/packing/bin_pack/space.py index 89f38dc45..63805ab46 100644 --- a/jumanji/environments/packing/bin_pack/space.py +++ b/jumanji/environments/packing/bin_pack/space.py @@ -34,9 +34,7 @@ class Space: z2: chex.Numeric def astype(self, dtype: Any) -> Space: - space_dict = { - key: jnp.asarray(value, dtype) for key, value in self.__dict__.items() - } + space_dict = {key: jnp.asarray(value, dtype) for key, value in self.__dict__.items()} return Space(**space_dict) def get_axis_value(self, axis: str, index: int) -> chex.Numeric: @@ -78,9 +76,9 @@ def set_axis_value(self, axis: str, index: int, value: chex.Numeric) -> None: def __repr__(self) -> str: return ( "Space(\n" - f"\tx1={repr(self.x1)}, x2={repr(self.x2)},\n" - f"\ty1={repr(self.y1)}, y2={repr(self.y2)},\n" - f"\tz1={repr(self.z1)}, z2={repr(self.z2)},\n" + f"\tx1={self.x1!r}, x2={self.x2!r},\n" + f"\ty1={self.y1!r}, y2={self.y2!r},\n" + f"\tz1={self.z1!r}, z2={self.z2!r},\n" ")" ) @@ -148,6 +146,4 @@ def hyperplane(self, axis: str, direction: str) -> Space: elif axis_direction == "z_upper": return Space(x1=-inf_, x2=inf_, y1=-inf_, y2=inf_, z1=self.z2, z2=inf_) else: - raise ValueError( - f"arguments not valid, got axis: {axis} and direction: {direction}." - ) + raise ValueError(f"arguments not valid, got axis: {axis} and direction: {direction}.") diff --git a/jumanji/environments/packing/bin_pack/space_test.py b/jumanji/environments/packing/bin_pack/space_test.py index 50eb505b9..dfb34280e 100644 --- a/jumanji/environments/packing/bin_pack/space_test.py +++ b/jumanji/environments/packing/bin_pack/space_test.py @@ -87,9 +87,7 @@ def test_space__volume(space: Space) -> None: ), ], ) -def test_space__intersection( - space1: Space, space2: Space, expected_intersection: Space -) -> None: +def test_space__intersection(space1: Space, space2: Space, expected_intersection: Space) -> None: space = space1.intersection(space2) assert space == expected_intersection @@ -128,9 +126,7 @@ def test_space__is_empty(space: Space, is_empty: bool) -> None: ), ], ) -def test_space__intersect( - space1: Space, space2: Space, expected_intersect: bool -) -> None: +def test_space__intersect(space1: Space, space2: Space, expected_intersect: bool) -> None: assert space1.intersect(space2) == expected_intersect diff --git a/jumanji/environments/packing/bin_pack/viewer.py b/jumanji/environments/packing/bin_pack/viewer.py index 223cc3813..a1f6e8d34 100644 --- a/jumanji/environments/packing/bin_pack/viewer.py +++ b/jumanji/environments/packing/bin_pack/viewer.py @@ -131,9 +131,7 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: ax = fig.get_axes()[0] return fig, ax - def _create_entities( - self, state: State - ) -> List[mpl_toolkits.mplot3d.art3d.Poly3DCollection]: + def _create_entities(self, state: State) -> List[mpl_toolkits.mplot3d.art3d.Poly3DCollection]: entities = [] n_items = len(state.items_mask) cmap = plt.cm.get_cmap("hsv", n_items) @@ -201,9 +199,7 @@ def _add_overlay(self, fig: plt.Figure, ax: plt.Axes, state: State) -> None: n_items = sum(state.items_mask) placed_items = sum(state.items_placed) - container_volume = ( - float(container.x_len) * float(container.y_len) * float(container.z_len) - ) + container_volume = float(container.x_len) * float(container.y_len) * float(container.z_len) used_volume = self._get_used_volume(state) metrics = [ ("Placed", f"{placed_items:{len(str(n_items))}}/{n_items}"), @@ -267,9 +263,7 @@ def _create_box_vertices( def _get_used_volume(self, state: State) -> float: used_volume = sum( - float(state.items.x_len[i]) - * float(state.items.y_len[i]) - * float(state.items.z_len[i]) + float(state.items.x_len[i]) * float(state.items.y_len[i]) * float(state.items.z_len[i]) for i, placed in enumerate(state.items_placed) if placed ) diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py index e1125e98b..dd23f03e2 100644 --- a/jumanji/environments/packing/flat_pack/env.py +++ b/jumanji/environments/packing/flat_pack/env.py @@ -36,7 +36,6 @@ class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]): - """The FlatPack environment with a configurable number of row and column blocks. Here the goal of an agent is to completely fill an empty grid by placing all available blocks. It can be thought of as a discrete 2D version of the `BinPack` @@ -127,9 +126,7 @@ def __init__( compute_grid_dim(self.num_col_blocks), ) self.reward_fn = reward_fn or CellDenseReward() - self.viewer = viewer or FlatPackViewer( - "FlatPack", self.num_blocks, render_mode="human" - ) + self.viewer = viewer or FlatPackViewer("FlatPack", self.num_blocks, render_mode="human") super().__init__() def __repr__(self) -> str: @@ -159,9 +156,7 @@ def reset( return grid_state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Steps the environment. Args: @@ -326,9 +321,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: max_col_position = self.num_cols - 2 return specs.MultiDiscreteArray( - num_values=jnp.array( - [self.num_blocks, 4, max_row_position, max_col_position] - ), + num_values=jnp.array([self.num_blocks, 4, max_row_position, max_col_position]), name="action", ) @@ -404,9 +397,7 @@ def _expand_block_to_grid( grid_with_block = jnp.zeros((self.num_rows, self.num_cols), dtype=jnp.int32) place_location = (row_coord, col_coord) - grid_with_block = jax.lax.dynamic_update_slice( - grid_with_block, block, place_location - ) + grid_with_block = jax.lax.dynamic_update_slice(grid_with_block, block, place_location) return grid_with_block @@ -491,9 +482,7 @@ def _make_action_mask( cols_grid.flatten(), ) - batch_is_legal_action = jax.vmap( - self._is_legal_action, in_axes=(0, None, None, 0) - ) + batch_is_legal_action = jax.vmap(self._is_legal_action, in_axes=(0, None, None, 0)) all_actions = jnp.stack( (blocks_grid, rotations_grid, rows_grid, cols_grid), axis=-1 diff --git a/jumanji/environments/packing/flat_pack/generator.py b/jumanji/environments/packing/flat_pack/generator.py index 412c4d9e1..af71aa864 100644 --- a/jumanji/environments/packing/flat_pack/generator.py +++ b/jumanji/environments/packing/flat_pack/generator.py @@ -171,22 +171,16 @@ def _select_row_interlocks( return (grid, key), row - def _first_nonzero( - self, arr: chex.Array, axis: int, invalid_val: int = 1000 - ) -> chex.Numeric: + def _first_nonzero(self, arr: chex.Array, axis: int, invalid_val: int = 1000) -> chex.Numeric: """Returns the index of the first non-zero value in an array.""" mask = arr != 0 - return jnp.min( - jnp.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val) - ) + return jnp.min(jnp.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)) def _crop_nonzero(self, arr_: chex.Array) -> chex.Array: """Crops a block to be of shape (3, 3).""" - row_roll, col_roll = self._first_nonzero(arr_, axis=0), self._first_nonzero( - arr_, axis=1 - ) + row_roll, col_roll = self._first_nonzero(arr_, axis=0), self._first_nonzero(arr_, axis=1) arr_ = jnp.roll(arr_, -row_roll, axis=0) arr_ = jnp.roll(arr_, -col_roll, axis=1) @@ -281,16 +275,12 @@ def __call__(self, key: chex.PRNGKey) -> State: # Finally shuffle the blocks along the leading dimension to # untangle a block's number from its position in the blocks array. key, shuffle_blocks_key = jax.random.split(key) - blocks = jax.random.permutation( - key=shuffle_blocks_key, x=blocks, axis=0, independent=False - ) + blocks = jax.random.permutation(key=shuffle_blocks_key, x=blocks, axis=0, independent=False) return State( blocks=blocks, num_blocks=num_blocks, - action_mask=jnp.ones( - (num_blocks, 4, grid_row_dim - 2, grid_col_dim - 2), dtype=bool - ), + action_mask=jnp.ones((num_blocks, 4, grid_row_dim - 2, grid_col_dim - 2), dtype=bool), grid=jnp.zeros_like(solved_grid), step_count=0, key=key, @@ -307,7 +297,6 @@ def __init__(self) -> None: super().__init__(num_row_blocks=2, num_col_blocks=2) def __call__(self, key: chex.PRNGKey) -> State: - del key solved_grid = jnp.array( @@ -351,7 +340,6 @@ def __init__(self) -> None: super().__init__(num_row_blocks=2, num_col_blocks=2) def __call__(self, key: chex.PRNGKey) -> State: - del key solved_grid = jnp.array( diff --git a/jumanji/environments/packing/flat_pack/generator_test.py b/jumanji/environments/packing/flat_pack/generator_test.py index 5047c0eae..b86c7ee43 100644 --- a/jumanji/environments/packing/flat_pack/generator_test.py +++ b/jumanji/environments/packing/flat_pack/generator_test.py @@ -84,9 +84,7 @@ def test_random_flat_pack_generator__no_retrace( ) -> None: """Checks that generator call method is only traced once when jitted.""" keys = jax.random.split(key, 2) - jitted_generator = jax.jit( - chex.assert_max_traces((random_flat_pack_generator.__call__), n=1) - ) + jitted_generator = jax.jit(chex.assert_max_traces((random_flat_pack_generator.__call__), n=1)) for key in keys: jitted_generator(key) @@ -121,12 +119,13 @@ def test_random_flat_pack_generator__fill_grid_rows( """ ( - grid, - sum_value, - num_col_blocks, - ), arr_value = random_flat_pack_generator._fill_grid_rows( - (grid_columns_partially_filled, 2, 2), 2 - ) + ( + grid, + sum_value, + num_col_blocks, + ), + arr_value, + ) = random_flat_pack_generator._fill_grid_rows((grid_columns_partially_filled, 2, 2), 2) assert grid.shape == (5, 5) assert jnp.array_equal(grid, grid_rows_partially_filled) @@ -143,9 +142,7 @@ def test_random_flat_pack_generator__select_sides( at index 0 or 2. """ - side_chosen_array = random_flat_pack_generator._select_sides( - jnp.array([1.0, 2.0, 3.0]), key - ) + side_chosen_array = random_flat_pack_generator._select_sides(jnp.array([1.0, 2.0, 3.0]), key) assert side_chosen_array.shape == (3,) # check that the output is different from the input @@ -160,11 +157,12 @@ def test_random_flat_pack_generator__select_col_interlocks( """Checks that interlocks are created along a given column of the grid.""" ( - grid_with_interlocks_selected, - new_key, - ), column = random_flat_pack_generator._select_col_interlocks( - (grid_rows_partially_filled, key), 2 - ) + ( + grid_with_interlocks_selected, + new_key, + ), + column, + ) = random_flat_pack_generator._select_col_interlocks((grid_rows_partially_filled, key), 2) assert grid_with_interlocks_selected.shape == (5, 5) assert jnp.not_equal(key, new_key).all() @@ -185,11 +183,12 @@ def test_random_flat_pack_generator__select_row_interlocks( """Checks that interlocks are created along a given row of the grid.""" ( - grid_with_interlocks_selected, - new_key, - ), row = random_flat_pack_generator._select_row_interlocks( - (grid_rows_partially_filled, key), 2 - ) + ( + grid_with_interlocks_selected, + new_key, + ), + row, + ) = random_flat_pack_generator._select_row_interlocks((grid_rows_partially_filled, key), 2) assert grid_with_interlocks_selected.shape == (5, 5) assert jnp.not_equal(key, new_key).all() @@ -208,12 +207,8 @@ def test_random_flat_pack_generator__first_nonzero( ) -> None: """Checks that the indices of the first non-zero value in a grid is found correctly.""" - first_nonzero_row = random_flat_pack_generator._first_nonzero( - block_one_placed_at_1_1, 0 - ) - first_nonzero_col = random_flat_pack_generator._first_nonzero( - block_one_placed_at_1_1, 1 - ) + first_nonzero_row = random_flat_pack_generator._first_nonzero(block_one_placed_at_1_1, 0) + first_nonzero_col = random_flat_pack_generator._first_nonzero(block_one_placed_at_1_1, 1) assert first_nonzero_row == 1 assert first_nonzero_col == 1 @@ -241,9 +236,7 @@ def test_random_flat_pack_generator__extract_block( """Checks that a block is correctly extracted from a solved grid.""" # extract block number 3 - (_, new_key), block = random_flat_pack_generator._extract_block( - (solved_grid, key), 3 - ) + (_, new_key), block = random_flat_pack_generator._extract_block((solved_grid, key), 3) assert block.shape == (3, 3) assert jnp.not_equal(key, new_key).all() diff --git a/jumanji/environments/packing/flat_pack/reward.py b/jumanji/environments/packing/flat_pack/reward.py index 74ac1166e..6e96ef165 100644 --- a/jumanji/environments/packing/flat_pack/reward.py +++ b/jumanji/environments/packing/flat_pack/reward.py @@ -65,8 +65,7 @@ def __call__( reward = jax.lax.cond( is_valid, - lambda: jnp.sum(placed_block != 0.0, dtype=jnp.float32) - / (num_rows * num_cols), + lambda: jnp.sum(placed_block != 0.0, dtype=jnp.float32) / (num_rows * num_cols), lambda: jnp.float32(0.0), ) diff --git a/jumanji/environments/packing/flat_pack/reward_test.py b/jumanji/environments/packing/flat_pack/reward_test.py index dc846a8b7..8767e2349 100644 --- a/jumanji/environments/packing/flat_pack/reward_test.py +++ b/jumanji/environments/packing/flat_pack/reward_test.py @@ -164,7 +164,6 @@ def test_cell_dense_reward( block_one_placed_at_1_1: chex.Array, block_one_placed_at_2_2: chex.Array, ) -> None: - dense_reward = jax.jit(CellDenseReward()) # Test placing block one completely correctly @@ -215,7 +214,6 @@ def test_block_dense_reward( block_one_placed_at_1_1: chex.Array, block_one_placed_at_2_2: chex.Array, ) -> None: - dense_reward = jax.jit(BlockDenseReward()) # Test placing block one completely correctly diff --git a/jumanji/environments/packing/flat_pack/utils_test.py b/jumanji/environments/packing/flat_pack/utils_test.py index 1981a3010..01f3f9e7a 100644 --- a/jumanji/environments/packing/flat_pack/utils_test.py +++ b/jumanji/environments/packing/flat_pack/utils_test.py @@ -53,7 +53,6 @@ def test_get_significant_idxs(grid_dim: int, expected_idxs: chex.Array) -> None: def test_rotate_block(block: chex.Array) -> None: - # Test with no rotation. rotated_block = rotate_block(block, 0) assert jnp.array_equal(rotated_block, block) diff --git a/jumanji/environments/packing/flat_pack/viewer.py b/jumanji/environments/packing/flat_pack/viewer.py index 639435dd1..21377b92d 100644 --- a/jumanji/environments/packing/flat_pack/viewer.py +++ b/jumanji/environments/packing/flat_pack/viewer.py @@ -94,9 +94,7 @@ def animate( Returns: Animation that can be saved as a GIF, MP4, or rendered with HTML. """ - fig, ax = plt.subplots( - num=f"{self._name}Animation", figsize=FlatPackViewer.FIGURE_SIZE - ) + fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=FlatPackViewer.FIGURE_SIZE) plt.close(fig) def make_frame(state_index: int) -> None: @@ -169,9 +167,7 @@ def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: for col in range(cols): self._draw_grid_cell(grid[row, col], row, col, ax) - def _draw_grid_cell( - self, cell_value: int, row: int, col: int, ax: plt.Axes - ) -> None: + def _draw_grid_cell(self, cell_value: int, row: int, col: int, ax: plt.Axes) -> None: cell = plt.Rectangle((col, row), 1, 1, **self._get_cell_attributes(cell_value)) ax.add_patch(cell) if cell_value != 0: diff --git a/jumanji/environments/packing/job_shop/conftest.py b/jumanji/environments/packing/job_shop/conftest.py index 16f5c0ee2..cb7ad1010 100644 --- a/jumanji/environments/packing/job_shop/conftest.py +++ b/jumanji/environments/packing/job_shop/conftest.py @@ -67,9 +67,7 @@ def __call__(self, key: PRNGKey) -> State: ) # Initially, all machines are available (the index self.num_jobs corresponds to no-op) - machines_job_ids = jnp.array( - [self.num_jobs, self.num_jobs, self.num_jobs], jnp.int32 - ) + machines_job_ids = jnp.array([self.num_jobs, self.num_jobs, self.num_jobs], jnp.int32) machines_remaining_times = jnp.array([0, 0, 0], jnp.int32) # Initial action mask given the problem instance @@ -83,9 +81,7 @@ def __call__(self, key: PRNGKey) -> State: ) # Initially, all ops have yet to be scheduled (ignore the padded element) - ops_mask = jnp.array( - [[True, True, True], [True, True, True], [True, True, False]], bool - ) + ops_mask = jnp.array([[True, True, True], [True, True, True], [True, True, False]], bool) # Initially, none of the operations have been scheduled scheduled_times = jnp.array( diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index ec1e22e79..bb80c60f0 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -168,9 +168,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Updates the status of all machines, the status of the operations, and increments the time step. It updates the environment state and the timestep (which contains the new observation). It calculates the reward based on the three terminal conditions: @@ -226,8 +224,7 @@ def step( # Check if all machines are idle simultaneously all_machines_idle = jnp.all( - (updated_machines_job_ids == self.no_op_idx) - & (updated_machines_remaining_times == 0) + (updated_machines_job_ids == self.no_op_idx) & (updated_machines_remaining_times == 0) ) # Check if the schedule has finished @@ -300,9 +297,7 @@ def _update_operations( is_next_op = jnp.zeros(shape=(self.num_jobs, self.max_num_ops), dtype=bool) is_next_op = is_next_op.at[jnp.arange(self.num_jobs), op_ids].set(True) is_new_job_and_next_op = jnp.logical_and(is_new_job, is_next_op) - updated_scheduled_times = jnp.where( - is_new_job_and_next_op, step_count, scheduled_times - ) + updated_scheduled_times = jnp.where(is_new_job_and_next_op, step_count, scheduled_times) updated_ops_mask = ops_mask & ~is_new_job_and_next_op return updated_ops_mask, updated_scheduled_times @@ -352,9 +347,7 @@ def _update_machines( ) # For busy machines, decrement the remaining time by one - updated_machines_remaining_times = jnp.where( - remaining_times > 0, remaining_times - 1, 0 - ) + updated_machines_remaining_times = jnp.where(remaining_times > 0, remaining_times - 1, 0) return updated_machines_job_ids, updated_machines_remaining_times @@ -528,13 +521,9 @@ def _is_action_valid( """ is_machine_available = machines_remaining_times[machine_id] == 0 is_correct_machine = ops_machine_ids[job_id, op_id] == machine_id - is_job_ready = ~jnp.any( - (machines_job_ids == job_id) & (machines_remaining_times > 0) - ) + is_job_ready = ~jnp.any((machines_job_ids == job_id) & (machines_remaining_times > 0)) is_job_finished = jnp.all(~updated_ops_mask[job_id]) - return ( - is_machine_available & is_correct_machine & is_job_ready & ~is_job_finished - ) + return is_machine_available & is_correct_machine & is_job_ready & ~is_job_finished def _set_busy(self, job_id: jnp.int32, action: chex.Array) -> Any: """Determine, for a given action and job, whether the job is a new job to be @@ -581,9 +570,7 @@ def _create_action_mask( # vmap over the jobs (and their ops) and vmap over the machines action_mask = jax.vmap( - jax.vmap( - self._is_action_valid, in_axes=(0, 0, None, None, None, None, None) - ), + jax.vmap(self._is_action_valid, in_axes=(0, 0, None, None, None, None, None)), in_axes=(None, None, 0, None, None, None, None), )( job_indexes, diff --git a/jumanji/environments/packing/job_shop/env_test.py b/jumanji/environments/packing/job_shop/env_test.py index 964042dac..75ec06294 100644 --- a/jumanji/environments/packing/job_shop/env_test.py +++ b/jumanji/environments/packing/job_shop/env_test.py @@ -136,9 +136,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[True, True, True], [False, True, True], [False, True, False]] - ) + == jnp.array([[True, True, True], [False, True, True], [False, True, False]]) ) assert jnp.all( next_state.scheduled_times @@ -154,9 +152,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=1 -> T=2 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids == jnp.array( @@ -191,9 +187,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[True, True, True], [False, True, True], [False, True, False]] - ) + == jnp.array([[True, True, True], [False, True, True], [False, True, False]]) ) assert jnp.all( next_state.scheduled_times @@ -209,9 +203,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=2 -> T=3 next_action = jnp.array([0, 3, 1]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -247,9 +239,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, True], [False, True, False]] - ) + == jnp.array([[False, True, True], [False, False, True], [False, True, False]]) ) assert jnp.all( next_state.scheduled_times @@ -265,9 +255,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=3 -> T=4 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -303,9 +291,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, True], [False, True, False]] - ) + == jnp.array([[False, True, True], [False, False, True], [False, True, False]]) ) assert jnp.all( next_state.scheduled_times @@ -321,9 +307,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=4 -> T=5 next_action = jnp.array([3, 1, 2]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -359,9 +343,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, True, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -377,9 +359,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=5 -> T=6 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -415,9 +395,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, True, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -433,9 +411,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=6 -> T=7 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -471,9 +447,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, True, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -489,9 +463,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=7 -> T=8 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -527,9 +499,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, True, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, True, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -545,9 +515,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=8 -> T=9 next_action = jnp.array([3, 0, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -583,9 +551,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, False, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, False, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -601,9 +567,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # # STEP T=9 -> T=10 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -639,9 +603,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, False, True], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, False, True], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -657,9 +619,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # # STEP T=10 -> T=11 next_action = jnp.array([3, 3, 0]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -695,9 +655,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, False, False], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, False, False], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times @@ -713,9 +671,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: # STEP T=10 -> T=11 next_action = jnp.array([3, 3, 3]) - next_state, next_timestep = job_shop_env.step( - state=next_state, action=next_action - ) + next_state, next_timestep = job_shop_env.step(state=next_state, action=next_action) assert jnp.all( next_state.ops_machine_ids @@ -751,9 +707,7 @@ def test_job_shop__step(self, job_shop_env: JobShop) -> None: ) assert jnp.all( next_state.ops_mask - == jnp.array( - [[False, False, False], [False, False, False], [False, False, False]] - ) + == jnp.array([[False, False, False], [False, False, False], [False, False, False]]) ) assert jnp.all( next_state.scheduled_times diff --git a/jumanji/environments/packing/job_shop/generator.py b/jumanji/environments/packing/job_shop/generator.py index cfbec6d6c..241745fcd 100644 --- a/jumanji/environments/packing/job_shop/generator.py +++ b/jumanji/environments/packing/job_shop/generator.py @@ -124,9 +124,7 @@ class RandomGenerator(Generator): the max. """ - def __init__( - self, num_jobs: int, num_machines: int, max_num_ops: int, max_op_duration: int - ): + def __init__(self, num_jobs: int, num_machines: int, max_num_ops: int, max_op_duration: int): super().__init__(num_jobs, num_machines, max_num_ops, max_op_duration) def __call__(self, key: chex.PRNGKey) -> State: diff --git a/jumanji/environments/packing/job_shop/generator_test.py b/jumanji/environments/packing/job_shop/generator_test.py index e35ebf3fe..7d94619fa 100644 --- a/jumanji/environments/packing/job_shop/generator_test.py +++ b/jumanji/environments/packing/job_shop/generator_test.py @@ -83,9 +83,7 @@ def random_generator(self) -> RandomGenerator: max_op_duration=8, ) - def test_random_generator__properties( - self, random_generator: RandomGenerator - ) -> None: + def test_random_generator__properties(self, random_generator: RandomGenerator) -> None: """Validate that the random instance generator has the correct properties.""" assert random_generator.num_jobs == 20 assert random_generator.num_machines == 10 diff --git a/jumanji/environments/packing/job_shop/viewer.py b/jumanji/environments/packing/job_shop/viewer.py index 9ad339348..dea9a5add 100644 --- a/jumanji/environments/packing/job_shop/viewer.py +++ b/jumanji/environments/packing/job_shop/viewer.py @@ -118,12 +118,7 @@ def make_frame(state_index: int) -> None: def _prepare_figure(self, ax: plt.Axes) -> None: ax.set_xlabel("Time") ax.set_ylabel("Machine ID") - xlim = ( - self._num_jobs - * self._max_num_ops - * self._max_op_duration - // self._num_machines - ) + xlim = self._num_jobs * self._max_num_ops * self._max_op_duration // self._num_machines ax.set_xlim(0, xlim) ax.set_ylim(-0.9, self._num_machines) ax.xaxis.get_major_locator().set_params(integer=True) diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index 3d544132a..d34f365b4 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -137,9 +137,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step( - self, state: State, action: chex.Numeric - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: diff --git a/jumanji/environments/packing/knapsack/env_test.py b/jumanji/environments/packing/knapsack/env_test.py index 32ea10cf7..ea4639435 100644 --- a/jumanji/environments/packing/knapsack/env_test.py +++ b/jumanji/environments/packing/knapsack/env_test.py @@ -72,21 +72,15 @@ def test_knapsack_sparse__step(self, knapsack_sparse_reward: Knapsack) -> None: assert jnp.array_equal(new_state.packed_items, state.packed_items) assert jnp.array_equal(new_state.remaining_budget, state.remaining_budget) - def test_knapsack_sparse__does_not_smoke( - self, knapsack_sparse_reward: Knapsack - ) -> None: + def test_knapsack_sparse__does_not_smoke(self, knapsack_sparse_reward: Knapsack) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(knapsack_sparse_reward) - def test_knapsack_sparse__specs_does_not_smoke( - self, knapsack_sparse_reward: Knapsack - ) -> None: + def test_knapsack_sparse__specs_does_not_smoke(self, knapsack_sparse_reward: Knapsack) -> None: """Test that we can access specs without any errors.""" check_env_specs_does_not_smoke(knapsack_sparse_reward) - def test_knapsack_sparse__trajectory_action( - self, knapsack_sparse_reward: Knapsack - ) -> None: + def test_knapsack_sparse__trajectory_action(self, knapsack_sparse_reward: Knapsack) -> None: """Checks that the agent stops when the remaining budget does not allow extra items and that the appropriate reward is received. """ @@ -111,9 +105,7 @@ def test_knapsack_sparse__trajectory_action( assert not jnp.any(timestep.observation.action_mask) assert timestep.last() - def test_knapsack_sparse__invalid_action( - self, knapsack_sparse_reward: Knapsack - ) -> None: + def test_knapsack_sparse__invalid_action(self, knapsack_sparse_reward: Knapsack) -> None: """Checks that an invalid action leads to a termination and that the appropriate reward is returned. """ @@ -167,9 +159,7 @@ def test_knapsack_dense__step(self, knapsack_dense_reward: Knapsack) -> None: key = jax.random.PRNGKey(0) state, timestep = knapsack_dense_reward.reset(key) - action = jax.random.randint( - key, shape=(), minval=0, maxval=knapsack_dense_reward.num_items - ) + action = jax.random.randint(key, shape=(), minval=0, maxval=knapsack_dense_reward.num_items) new_state, next_timestep = step_fn(state, action) # Check that the state has changed. @@ -186,15 +176,11 @@ def test_knapsack_dense__step(self, knapsack_dense_reward: Knapsack) -> None: assert jnp.array_equal(new_state.packed_items, state.packed_items) assert jnp.array_equal(new_state.remaining_budget, state.remaining_budget) - def test_knapsack_dense__does_not_smoke( - self, knapsack_dense_reward: Knapsack - ) -> None: + def test_knapsack_dense__does_not_smoke(self, knapsack_dense_reward: Knapsack) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(knapsack_dense_reward) - def test_knapsack_dense__trajectory_action( - self, knapsack_dense_reward: Knapsack - ) -> None: + def test_knapsack_dense__trajectory_action(self, knapsack_dense_reward: Knapsack) -> None: """Checks that the agent stops when the remaining budget does not allow extra items and that the appropriate reward is received. """ @@ -219,9 +205,7 @@ def test_knapsack_dense__trajectory_action( assert not jnp.any(timestep.observation.action_mask) assert timestep.last() - def test_knapsack_dense__invalid_action( - self, knapsack_dense_reward: Knapsack - ) -> None: + def test_knapsack_dense__invalid_action(self, knapsack_dense_reward: Knapsack) -> None: """Checks that an invalid action leads to a termination and that the appropriate reward is returned. """ diff --git a/jumanji/environments/packing/knapsack/generator.py b/jumanji/environments/packing/knapsack/generator.py index 76994a8cf..e294cf049 100644 --- a/jumanji/environments/packing/knapsack/generator.py +++ b/jumanji/environments/packing/knapsack/generator.py @@ -65,9 +65,7 @@ def __call__(self, key: chex.PRNGKey) -> State: key, sample_key = jax.random.split(key) # Sample weights and values of the items from a uniform distribution on [0, 1] - weights, values = jax.random.uniform( - sample_key, (2, self.num_items), minval=0, maxval=1 - ) + weights, values = jax.random.uniform(sample_key, (2, self.num_items), minval=0, maxval=1) # Initially, no items are packed. packed_items = jnp.zeros(self.num_items, dtype=bool) diff --git a/jumanji/environments/packing/knapsack/generator_test.py b/jumanji/environments/packing/knapsack/generator_test.py index 344804633..6ddcd314f 100644 --- a/jumanji/environments/packing/knapsack/generator_test.py +++ b/jumanji/environments/packing/knapsack/generator_test.py @@ -49,9 +49,7 @@ class TestRandomGenerator: def random_generator(self) -> RandomGenerator: return RandomGenerator(num_items=50, total_budget=12.5) - def test_random_generator__properties( - self, random_generator: RandomGenerator - ) -> None: + def test_random_generator__properties(self, random_generator: RandomGenerator) -> None: """Validate that the random instance generator has the correct properties.""" assert random_generator.num_items == 50 assert random_generator.total_budget == 12.5 diff --git a/jumanji/environments/packing/knapsack/reward_test.py b/jumanji/environments/packing/knapsack/reward_test.py index 4a4bbacbc..b6f8c828c 100644 --- a/jumanji/environments/packing/knapsack/reward_test.py +++ b/jumanji/environments/packing/knapsack/reward_test.py @@ -19,9 +19,7 @@ from jumanji.environments.packing.knapsack.reward import DenseReward, SparseReward -def test_dense_reward( - knapsack_dense_reward: Knapsack, dense_reward: DenseReward -) -> None: +def test_dense_reward(knapsack_dense_reward: Knapsack, dense_reward: DenseReward) -> None: dense_reward = jax.jit(dense_reward) step_fn = jax.jit(knapsack_dense_reward.step) state, timestep = knapsack_dense_reward.reset(jax.random.PRNGKey(0)) @@ -39,9 +37,7 @@ def test_dense_reward( assert reward == 0 -def test_sparse_reward( # noqa: CCR001 - knapsack_sparse_reward: Knapsack, sparse_reward: SparseReward -) -> None: +def test_sparse_reward(knapsack_sparse_reward: Knapsack, sparse_reward: SparseReward) -> None: sparse_reward = jax.jit(sparse_reward) step_fn = jax.jit(knapsack_sparse_reward.step) state, timestep = knapsack_sparse_reward.reset(jax.random.PRNGKey(0)) @@ -52,9 +48,7 @@ def test_sparse_reward( # noqa: CCR001 for action, is_valid in enumerate(timestep.observation.action_mask): if is_valid: next_state, timestep = step_fn(state, action) - reward = sparse_reward( - state, action, next_state, is_valid, is_done=timestep.last() - ) + reward = sparse_reward(state, action, next_state, is_valid, is_done=timestep.last()) if timestep.last(): # At the end of the episode, check that the reward is the total values of # packed items. diff --git a/jumanji/environments/packing/knapsack/viewer.py b/jumanji/environments/packing/knapsack/viewer.py index 279a2d365..b6e6b9843 100644 --- a/jumanji/environments/packing/knapsack/viewer.py +++ b/jumanji/environments/packing/knapsack/viewer.py @@ -27,9 +27,7 @@ class KnapsackViewer(Viewer): FIGURE_SIZE = (5.0, 5.0) - def __init__( - self, name: str, render_mode: str = "human", total_budget: float = 2.0 - ) -> None: + def __init__(self, name: str, render_mode: str = "human", total_budget: float = 2.0) -> None: """Viewer for the `Knapsack` environment. Args: diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 995cb1fd6..a7e5854f5 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -93,13 +93,9 @@ def __init__( viewer: `Viewer` used for rendering. Defaults to `TetrisViewer`. """ if num_rows < 4: - raise ValueError( - f"The `num_rows` must be >= 4, but got num_rows={num_rows}" - ) + raise ValueError(f"The `num_rows` must be >= 4, but got num_rows={num_rows}") if num_cols < 4: - raise ValueError( - f"The `num_cols` must be >= 4, but got num_cols={num_cols}" - ) + raise ValueError(f"The `num_cols` must be >= 4, but got num_cols={num_cols}") self.num_rows = num_rows self.num_cols = num_cols self.padded_num_rows = num_rows + 3 @@ -134,12 +130,8 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - grid_padded = jnp.zeros( - shape=(self.padded_num_rows, self.padded_num_cols), dtype=jnp.int32 - ) - tetromino, tetromino_index = utils.sample_tetromino_list( - key, self.TETROMINOES_LIST - ) + grid_padded = jnp.zeros(shape=(self.padded_num_rows, self.padded_num_cols), dtype=jnp.int32) + tetromino, tetromino_index = utils.sample_tetromino_list(key, self.TETROMINOES_LIST) action_mask = self._calculate_action_mask(grid_padded, tetromino_index) state = State( @@ -168,9 +160,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=observation) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -186,9 +176,7 @@ def step( key, sample_key = jax.random.split(state.key) tetromino = self._rotate(rotation_index, tetromino_index) # Place the tetromino in the selected place - grid_padded, y_position = utils.place_tetromino( - state.grid_padded, tetromino, x_position - ) + grid_padded, y_position = utils.place_tetromino(state.grid_padded, tetromino, x_position) # A line is full when it doesn't contain any 0. full_lines = jnp.all(grid_padded[:, : self.num_cols] != 0, axis=1) nbr_full_lines = sum(full_lines) @@ -283,9 +271,7 @@ def observation_spec(self) -> specs.Spec[Observation]: maximum=True, name="action_mask", ), - step_count=specs.DiscreteArray( - self.time_limit, dtype=jnp.int32, name="step_count" - ), + step_count=specs.DiscreteArray(self.time_limit, dtype=jnp.int32, name="step_count"), ) @cached_property @@ -321,9 +307,7 @@ def animate( return self._viewer.animate(states, interval, save_path) - def _calculate_action_mask( - self, grid_padded: chex.Array, tetromino_index: int - ) -> chex.Array: + def _calculate_action_mask(self, grid_padded: chex.Array, tetromino_index: int) -> chex.Array: """Calculate the mask for legal actions in the game. Args: diff --git a/jumanji/environments/packing/tetris/utils.py b/jumanji/environments/packing/tetris/utils.py index 8e283cbfd..01b205710 100644 --- a/jumanji/environments/packing/tetris/utils.py +++ b/jumanji/environments/packing/tetris/utils.py @@ -56,9 +56,7 @@ def check_valid_tetromino_placement( Returns: chex.array of shape (). """ - crop = jax.lax.dynamic_slice( - grid, start_indices=(y_position, x_position), slice_sizes=(4, 4) - ) + crop = jax.lax.dynamic_slice(grid, start_indices=(y_position, x_position), slice_sizes=(4, 4)) crop = crop + tetromino return ~jnp.any(crop >= 2) @@ -82,9 +80,7 @@ def tetromino_action_mask(grid_padded: chex.Array, tetromino: chex.Array) -> che to all possible positions for one side of a tetromino in the `grid_padded`. """ tetromino_mask = tetromino.at[1, :].set(tetromino[1, :] + tetromino[2, :]) - tetromino_mask = tetromino_mask.at[0, :].set( - tetromino_mask[0, :] + tetromino_mask[1, :] - ) + tetromino_mask = tetromino_mask.at[0, :].set(tetromino_mask[0, :] + tetromino_mask[1, :]) tetromino_mask = jnp.clip(tetromino_mask, a_max=1) num_cols = grid_padded.shape[1] - 3 # Check if tetromino can be placed at the top of the grid, if so it means @@ -156,9 +152,7 @@ def place_tetromino( # Update the `grid_padded`. tetromino_color_id = grid_padded.max() + 1 tetromino = tetromino * tetromino_color_id - new_grid_padded = jax.lax.dynamic_update_slice( - grid_padded, tetromino, (y_position, x_position) - ) + new_grid_padded = jax.lax.dynamic_update_slice(grid_padded, tetromino, (y_position, x_position)) # Get the max of the old and the new `grid_padded`. grid_padded = jnp.maximum(grid_padded, new_grid_padded) return grid_padded, y_position diff --git a/jumanji/environments/packing/tetris/utils_test.py b/jumanji/environments/packing/tetris/utils_test.py index 939a2fdc5..a27292790 100644 --- a/jumanji/environments/packing/tetris/utils_test.py +++ b/jumanji/environments/packing/tetris/utils_test.py @@ -50,9 +50,7 @@ def grid_padded() -> chex.Array: @pytest.fixture def full_lines() -> chex.Array: """Full lines related to the grid_padded""" - full_lines = jnp.array( - [False, False, False, False, False, True, False, False, False, False] - ) + full_lines = jnp.array([False, False, False, False, False, True, False, False, False, False]) return full_lines @@ -83,9 +81,7 @@ def test_place_tetromino(grid_padded: chex.Array, tetromino: chex.Array) -> None new_grid_padded, _ = place_tetromino_fn(grid_padded, tetromino, 0) cells_count = jnp.clip(new_grid_padded, a_max=1).sum() old_cells_count = jnp.clip(grid_padded, a_max=1).sum() - assert ( - cells_count == old_cells_count + 4 - ) # 4 is the number of filled cells a tetromino + assert cells_count == old_cells_count + 4 # 4 is the number of filled cells a tetromino expected_binary_grid_padded = grid_padded.at[2:6, 0:4].add(tetromino) new_grid_padded_binary = jnp.clip(new_grid_padded, a_max=1) assert (expected_binary_grid_padded == new_grid_padded_binary).all() diff --git a/jumanji/environments/packing/tetris/viewer.py b/jumanji/environments/packing/tetris/viewer.py index 9a728f210..d18b2dfc8 100644 --- a/jumanji/environments/packing/tetris/viewer.py +++ b/jumanji/environments/packing/tetris/viewer.py @@ -30,9 +30,7 @@ class TetrisViewer(Viewer): FIGURE_SIZE = (6.0, 10.0) - def __init__( - self, num_rows: int, num_cols: int, render_mode: str = "human" - ) -> None: + def __init__(self, num_rows: int, num_cols: int, render_mode: str = "human") -> None: """ Viewer for a `Tetris` environment. @@ -87,9 +85,7 @@ def render(self, state: State) -> Optional[NDArray]: self._add_grid_image(ax, grid) return self._display(fig) - def _move_tetromino( - self, state: State, old_padded_grid: chex.Array - ) -> List[chex.Array]: + def _move_tetromino(self, state: State, old_padded_grid: chex.Array) -> List[chex.Array]: """Shifts the tetromino from center to the selected position. Args: @@ -106,9 +102,7 @@ def _move_tetromino( step = 1 if center_position < state.x_position else -1 for xi in range(center_position, state.x_position + step, step): tetromino_zonne = jnp.zeros((4, state.grid_padded.shape[1])) - tetromino_zonne = tetromino_zonne.at[0:4, xi : xi + 4].add( - state.old_tetromino_rotated - ) + tetromino_zonne = tetromino_zonne.at[0:4, xi : xi + 4].add(state.old_tetromino_rotated) # Delete the cols dedicated for the right padding tetromino_zonne = tetromino_zonne[:, : self.num_cols] # Stack the tetromino with grid position @@ -116,9 +110,7 @@ def _move_tetromino( grids.append(mixed_grid) return grids - def _crush_lines( - self, state: State, grid: chex.Array, n: int = 2 - ) -> List[chex.Array]: + def _crush_lines(self, state: State, grid: chex.Array, n: int = 2) -> List[chex.Array]: """Creates animation when a line is crushed by toggling its value. Args: @@ -133,9 +125,7 @@ def _crush_lines( for _i in range(n): animation_list.append(grid) # `State.full_lines` is a vector of booleans of shape num_rows+3. - full_lines = jnp.concatenate( - [jnp.full((4,), False), state.full_lines[: self.num_rows]] - ) + full_lines = jnp.concatenate([jnp.full((4,), False), state.full_lines[: self.num_rows]]) full_lines_reshaped = full_lines[:, jnp.newaxis] animation_list.append( jnp.where(~full_lines_reshaped, grid, jnp.zeros((1, grid.shape[1]))) @@ -156,15 +146,11 @@ def _create_rendering_grid(self, state: State) -> chex.Array: center_position = self.num_cols - 4 tetromino_color_id = state.grid_padded.max() + 1 colored_tetromino = state.new_tetromino * tetromino_color_id - tetromino = tetromino.at[0:4, center_position : center_position + 4].set( - colored_tetromino - ) + tetromino = tetromino.at[0:4, center_position : center_position + 4].set(colored_tetromino) rendering_grid = jnp.vstack((tetromino, grid)) return rendering_grid - def _drop_tetromino( - self, state: State, old_padded_grid: chex.Array - ) -> List[NDArray]: + def _drop_tetromino(self, state: State, old_padded_grid: chex.Array) -> List[NDArray]: """Creates animation while the tetromino is droping verticaly. Args: @@ -179,15 +165,13 @@ def _drop_tetromino( # `y_position` may contain a value -1 if it bellongs to first tetromino. y_position = state.y_position if state.y_position != -1 else self.num_rows - 1 # Stack the tetromino's rows on top of the grid. - rendering_grid = jnp.vstack( - (jnp.zeros((4, old_padded_grid.shape[1])), old_padded_grid) - ) + rendering_grid = jnp.vstack((jnp.zeros((4, old_padded_grid.shape[1])), old_padded_grid)) # the animation grid contains 4 rows at the top dedicated to show the tetromino. for yi in range(y_position + 4 + 1): # Place the tetromino. - grid = rendering_grid.at[ - yi : yi + 4, state.x_position : state.x_position + 4 - ].add(state.old_tetromino_rotated) + grid = rendering_grid.at[yi : yi + 4, state.x_position : state.x_position + 4].add( + state.old_tetromino_rotated + ) # Crop the grid (delete the 3 rows and columns padding at the bottom and the right.) grid = grid[: self.num_rows + 4, : self.num_cols] grids.append(grid) @@ -210,9 +194,7 @@ def animate( Returns: Animation that can be saved as a GIF, MP4, or rendered with HTML. """ - fig, ax = plt.subplots( - num=f"{self._name}Animation", figsize=TetrisViewer.FIGURE_SIZE - ) + fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=TetrisViewer.FIGURE_SIZE) plt.close(fig) def make_frame(grid_index: int) -> None: @@ -239,9 +221,7 @@ def make_frame(grid_index: int) -> None: grids.extend(x_shift_grids) grids.extend(y_shift_grids) score = state.score - state.reward - scores.extend( - [score for i in range(len(x_shift_grids) + len(y_shift_grids))] - ) + scores.extend([score for i in range(len(x_shift_grids) + len(y_shift_grids))]) if state.full_lines.sum() > 0: grids += self._crush_lines(state, grids[-1]) scores.extend([score for i in range(len(grids) - len(scores))]) @@ -287,20 +267,14 @@ def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: for col in range(cols): self._draw_grid_cell(grid[row, col], row, col, ax) - def _draw_grid_cell( - self, cell_value: int, row: int, col: int, ax: plt.Axes - ) -> None: + def _draw_grid_cell(self, cell_value: int, row: int, col: int, ax: plt.Axes) -> None: is_padd = row < 4 - cell = plt.Rectangle( - (col, row), 1, 1, **self._get_cell_attributes(cell_value, is_padd) - ) + cell = plt.Rectangle((col, row), 1, 1, **self._get_cell_attributes(cell_value, is_padd)) ax.add_patch(cell) def _get_cell_attributes(self, cell_value: int, is_padd: bool) -> Dict[str, Any]: cell_value = int(cell_value) - color_id = ( - cell_value if cell_value == 0 else cell_value % (len(self.colors) - 1) + 1 - ) + color_id = cell_value if cell_value == 0 else cell_value % (len(self.colors) - 1) + 1 color = self.colors[color_id] edge_color = self.edgecolors[is_padd] diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index 28dbe9761..525032f3b 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -99,9 +99,7 @@ def __init__( viewer: `Viewer` used for rendering. Defaults to `CleanerViewer` with "human" render mode. """ - self.generator = generator or RandomGenerator( - num_rows=10, num_cols=10, num_agents=3 - ) + self.generator = generator or RandomGenerator(num_rows=10, num_cols=10, num_agents=3) self.num_agents = self.generator.num_agents self.num_rows = self.generator.num_rows self.num_cols = self.generator.num_cols @@ -141,9 +139,7 @@ def observation_spec(self) -> specs.Spec[Observation]: agents_locations = specs.BoundedArray( (self.num_agents, 2), jnp.int32, [0, 0], self.grid_shape, "agents_locations" ) - action_mask = specs.BoundedArray( - (self.num_agents, 4), bool, False, True, "action_mask" - ) + action_mask = specs.BoundedArray((self.num_agents, 4), bool, False, True, "action_mask") step_count = specs.BoundedArray((), jnp.int32, 0, self.time_limit, "step_count") return specs.Spec( Observation, @@ -196,9 +192,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. If an action is invalid, the corresponding agent does not move and @@ -299,14 +293,9 @@ def _compute_reward(self, prev_state: State, state: State) -> chex.Array: Since walls and dirty tiles do not change, counting the tiles which changed since previeous step is the same as counting the tiles which were cleaned. """ - return ( - jnp.sum(prev_state.grid != state.grid, dtype=float) - - self.penalty_per_timestep - ) + return jnp.sum(prev_state.grid != state.grid, dtype=float) - self.penalty_per_timestep - def _compute_action_mask( - self, grid: chex.Array, agents_locations: chex.Array - ) -> chex.Array: + def _compute_action_mask(self, grid: chex.Array, agents_locations: chex.Array) -> chex.Array: """Compute the action mask. An action is masked if it leads to a WALL or out of the maze. @@ -323,9 +312,9 @@ def is_move_valid(agent_location: chex.Array, move: chex.Array) -> chex.Array: ) # vmap over the moves and agents - action_mask = jax.vmap( - jax.vmap(is_move_valid, in_axes=(None, 0)), in_axes=(0, None) - )(agents_locations, MOVES) + action_mask = jax.vmap(jax.vmap(is_move_valid, in_axes=(None, 0)), in_axes=(0, None))( + agents_locations, MOVES + ) return action_mask @@ -338,9 +327,7 @@ def _observation_from_state(self, state: State) -> Observation: step_count=state.step_count, ) - def _is_action_valid( - self, action: chex.Array, action_mask: chex.Array - ) -> chex.Array: + def _is_action_valid(self, action: chex.Array, action_mask: chex.Array) -> chex.Array: """Compute, for the action of each agent, whether said action is valid. Args: diff --git a/jumanji/environments/routing/cleaner/env_test.py b/jumanji/environments/routing/cleaner/env_test.py index f386a5dad..4aa44b98a 100644 --- a/jumanji/environments/routing/cleaner/env_test.py +++ b/jumanji/environments/routing/cleaner/env_test.py @@ -135,9 +135,7 @@ def test_cleaner__step(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: assert jnp.all(state.agents_locations[1] == jnp.array([0, 0])) assert jnp.all(state.agents_locations[2] == jnp.array([1, 1])) - def test_cleaner__step_invalid_action( - self, cleaner: Cleaner, key: chex.PRNGKey - ) -> None: + def test_cleaner__step_invalid_action(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: state, _ = cleaner.reset(key) step_fn = jax.jit(cleaner.step) @@ -153,9 +151,7 @@ def test_cleaner__step_invalid_action( assert timestep.reward == 1 - cleaner.penalty_per_timestep - def test_cleaner__initial_action_mask( - self, cleaner: Cleaner, key: chex.PRNGKey - ) -> None: + def test_cleaner__initial_action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: state, _ = cleaner.reset(key) # All agents can only move right in the initial state @@ -182,12 +178,8 @@ def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None: def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array: @jax.vmap # map over the keys and agents - def select_action( - key: chex.PRNGKey, agent_action_mask: chex.Array - ) -> chex.Array: - return jax.random.choice( - key, jnp.arange(4), p=agent_action_mask.flatten() - ) + def select_action(key: chex.PRNGKey, agent_action_mask: chex.Array) -> chex.Array: + return jax.random.choice(key, jnp.arange(4), p=agent_action_mask.flatten()) subkeys = jax.random.split(key, cleaner.num_agents) return select_action(subkeys, observation.action_mask) @@ -205,7 +197,5 @@ def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> N assert list(extras.keys()) == ["ratio_dirty_tiles", "num_dirty_tiles"] assert 0 <= extras["ratio_dirty_tiles"] <= 1 grid = state.grid - assert extras["ratio_dirty_tiles"] == jnp.sum(grid == DIRTY) / jnp.sum( - grid != WALL - ) + assert extras["ratio_dirty_tiles"] == jnp.sum(grid == DIRTY) / jnp.sum(grid != WALL) assert extras["num_dirty_tiles"] == jnp.sum(grid == DIRTY) diff --git a/jumanji/environments/routing/cleaner/generator.py b/jumanji/environments/routing/cleaner/generator.py index 52475bde5..f853d0ab1 100644 --- a/jumanji/environments/routing/cleaner/generator.py +++ b/jumanji/environments/routing/cleaner/generator.py @@ -68,9 +68,7 @@ def __call__(self, key: chex.PRNGKey) -> State: state: the generated state. """ generator_key, state_key = jax.random.split(key) - maze = maze_generation.generate_maze( - self.num_cols, self.num_rows, generator_key - ) + maze = maze_generation.generate_maze(self.num_cols, self.num_rows, generator_key) grid = self._adapt_values(maze) diff --git a/jumanji/environments/routing/cleaner/viewer.py b/jumanji/environments/routing/cleaner/viewer.py index 7eeab7662..a1a17b185 100644 --- a/jumanji/environments/routing/cleaner/viewer.py +++ b/jumanji/environments/routing/cleaner/viewer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence +from typing import ClassVar, Dict, List, Optional, Sequence import matplotlib import matplotlib.animation @@ -30,7 +30,7 @@ class CleanerViewer(MazeViewer): AGENT = 3 - COLORS = { + COLORS: ClassVar[Dict[int, List[int]]] = { CLEAN: [1, 1, 1], # White WALL: [0, 0, 0], # Black DIRTY: [0, 1, 0], # Green @@ -121,10 +121,8 @@ def _create_grid_image(self, state: State) -> NDArray: return img def _set_agents_colors(self, img: NDArray, agents_locations: NDArray) -> NDArray: - unique_locations, counts = np.unique( - agents_locations, return_counts=True, axis=0 - ) - for location, count in zip(unique_locations, counts): + unique_locations, counts = np.unique(agents_locations, return_counts=True, axis=0) + for location, count in zip(unique_locations, counts, strict=False): img[location[0], location[1], :3] = np.array(self.AGENT_COLOR) img[location[0], location[1], 3] = 1 - self.ALPHA**count return img diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index b9d15c5c9..dae55c0fc 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -121,9 +121,7 @@ def __init__( self.grid_size = self._generator.grid_size super().__init__() self._agent_ids = jnp.arange(self.num_agents) - self._viewer = viewer or ConnectorViewer( - "Connector", self.num_agents, render_mode="human" - ) + self._viewer = viewer or ConnectorViewer("Connector", self.num_agents, render_mode="human") def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -137,9 +135,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """ state = self._generator(key) - action_mask = jax.vmap(self._get_action_mask, (0, None))( - state.agents, state.grid - ) + action_mask = jax.vmap(self._get_action_mask, (0, None))(state.agents, state.grid) observation = Observation( grid=state.grid, action_mask=action_mask, @@ -149,9 +145,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=observation, extras=extras) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Perform an environment step. Args: @@ -168,9 +162,7 @@ def step( timestep: `TimeStep` object corresponding the timestep returned by the environment. """ agents, grid = self._step_agents(state, action) - new_state = State( - grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key - ) + new_state = State(grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key) # Construct timestep: get reward, legal actions and done reward = self._reward_fn(state, action, new_state) @@ -197,9 +189,7 @@ def step( return new_state, timestep - def _step_agents( - self, state: State, action: chex.Array - ) -> Tuple[Agent, chex.Array]: + def _step_agents(self, state: State, action: chex.Array) -> Tuple[Agent, chex.Array]: """Steps all agents at the same time correcting for possible collisions. If a collision occurs we place the agent with the lower `agent_id` in its previous position. @@ -220,9 +210,7 @@ def _step_agents( # Create a correction mask for possible collisions (see the docs of `get_correction_mask`) correction_fn = jax.vmap(get_correction_mask, in_axes=(None, None, 0)) - correction_masks, collided_agents = correction_fn( - state.grid, joined_grid, agent_ids - ) + correction_masks, collided_agents = correction_fn(state.grid, joined_grid, agent_ids) correction_mask = jnp.sum(correction_masks, 0) # Correct state.agents diff --git a/jumanji/environments/routing/connector/env_test.py b/jumanji/environments/routing/connector/env_test.py index 12a6c1d94..5abe0dbbe 100644 --- a/jumanji/environments/routing/connector/env_test.py +++ b/jumanji/environments/routing/connector/env_test.py @@ -240,9 +240,7 @@ def test_connector__specs_does_not_smoke(connector: Connector) -> None: def test_connector__get_action_mask(state: State, connector: Connector) -> None: """Validates the action masking.""" - action_masks = jax.vmap(connector._get_action_mask, (0, None))( - state.agents, state.grid - ) + action_masks = jax.vmap(connector._get_action_mask, (0, None))(state.agents, state.grid) expected_mask = jnp.array( [ [True, True, False, True, True], diff --git a/jumanji/environments/routing/connector/generator.py b/jumanji/environments/routing/connector/generator.py index 0a63aa2a5..cc8f1bce6 100644 --- a/jumanji/environments/routing/connector/generator.py +++ b/jumanji/environments/routing/connector/generator.py @@ -173,9 +173,7 @@ def generate_board(self, key: chex.PRNGKey) -> Tuple[chex.Array, Agent, chex.Arr stepping_tuple = (step_key, grid, agents) - _, grid, agents = jax.lax.while_loop( - self._continue_stepping, self._step, stepping_tuple - ) + _, grid, agents = jax.lax.while_loop(self._continue_stepping, self._step, stepping_tuple) # Convert heads and targets to format accepted by generator heads = agents.start.T @@ -220,14 +218,10 @@ def _step_agents( keys = jax.random.split(key, num=self.num_agents) # Randomly select action for each agent - actions = jax.vmap(self._select_action, in_axes=(0, None, 0))( - keys, grid, agents - ) + actions = jax.vmap(self._select_action, in_axes=(0, None, 0))(keys, grid, agents) # Step all agents at the same time (separately) and return all of the grids - new_agents, grids = jax.vmap(self._step_agent, in_axes=(0, None, 0))( - agents, grid, actions - ) + new_agents, grids = jax.vmap(self._step_agent, in_axes=(0, None, 0))(agents, grid, actions) # Get grids with only values related to a single agent. # For example: remove all other agents from agent 1's grid. Do this for all agents. @@ -251,9 +245,7 @@ def _step_agents( # Create the new grid by fixing old one with correction mask and adding the obstacles return agents, joined_grid + correction_mask - def _initialize_agents( - self, key: chex.PRNGKey, grid: chex.Array - ) -> Tuple[chex.Array, Agent]: + def _initialize_agents(self, key: chex.PRNGKey, grid: chex.Array) -> Tuple[chex.Array, Agent]: """Initializes agents using random starting point and places heads on the grid. Args: @@ -345,9 +337,7 @@ def _no_available_cells(self, grid: chex.Array, agent: Agent) -> chex.Array: cell = self._convert_tuple_to_flat_position(agent.position) return (self._available_cells(grid, cell) == -1).all() - def _select_action( - self, key: chex.PRNGKey, grid: chex.Array, agent: Agent - ) -> chex.Array: + def _select_action(self, key: chex.PRNGKey, grid: chex.Array, agent: Agent) -> chex.Array: """Selects action for agent to take given its current position. Args: @@ -380,9 +370,7 @@ def _convert_flat_position_to_tuple(self, position: chex.Array) -> chex.Array: def _convert_tuple_to_flat_position(self, position: chex.Array) -> chex.Array: return jnp.array((position[0] * self.grid_size + position[1]), jnp.int32) - def _action_from_positions( - self, position_1: chex.Array, position_2: chex.Array - ) -> chex.Array: + def _action_from_positions(self, position_1: chex.Array, position_2: chex.Array) -> chex.Array: """Compares two positions and returns action id to get from one to the other.""" position_1 = self._convert_flat_position_to_tuple(position_1) position_2 = self._convert_flat_position_to_tuple(position_2) @@ -453,13 +441,11 @@ def _available_cells(self, grid: chex.Array, cell: chex.Array) -> chex.Array: value = grid[jnp.divmod(cell, self.grid_size)] wire_id = (value - 1) // 3 - available_cells_mask = jax.vmap(self._is_cell_free, in_axes=(None, 0))( - grid, adjacent_cells - ) + available_cells_mask = jax.vmap(self._is_cell_free, in_axes=(None, 0))(grid, adjacent_cells) # Also want to check if the cell is touching itself more than once - touching_cells_mask = jax.vmap( - self._is_cell_doubling_back, in_axes=(None, None, 0) - )(grid, wire_id, adjacent_cells) + touching_cells_mask = jax.vmap(self._is_cell_doubling_back, in_axes=(None, None, 0))( + grid, wire_id, adjacent_cells + ) available_cells_mask = available_cells_mask & touching_cells_mask available_cells = jnp.where(available_cells_mask, adjacent_cells, -1) return available_cells @@ -496,9 +482,7 @@ def _is_cell_doubling_back( # Get the adjacent cells of the current cell adjacent_cells = self._adjacent_cells(cell) - def is_cell_doubling_back_inner( - grid: chex.Array, cell: chex.Array - ) -> chex.Array: + def is_cell_doubling_back_inner(grid: chex.Array, cell: chex.Array) -> chex.Array: coordinate = jnp.divmod(cell, self.grid_size) cell_value = grid[tuple(coordinate)] touching_self = ( diff --git a/jumanji/environments/routing/connector/generator_test.py b/jumanji/environments/routing/connector/generator_test.py index 2356bf649..0a4b858c9 100644 --- a/jumanji/environments/routing/connector/generator_test.py +++ b/jumanji/environments/routing/connector/generator_test.py @@ -54,9 +54,7 @@ def test_uniform_random_generator__no_retrace( ) -> None: """Checks that generator only traces the function once and works when jitted.""" keys = jax.random.split(key, 2) - jitted_generator = jax.jit( - chex.assert_max_traces((uniform_random_generator.__call__), n=1) - ) + jitted_generator = jax.jit(chex.assert_max_traces((uniform_random_generator.__call__), n=1)) for key in keys: jitted_generator(key) @@ -246,9 +244,7 @@ def test_random_walk_generator__no_retrace( ) -> None: """Checks that generator only traces the function once and works when jitted.""" keys = jax.random.split(key, 2) - jitted_generator = jax.jit( - chex.assert_max_traces((random_walk_generator.__call__), n=1) - ) + jitted_generator = jax.jit(chex.assert_max_traces((random_walk_generator.__call__), n=1)) for key in keys: jitted_generator(key) @@ -320,9 +316,7 @@ def test_step( assert (new_grid == end_grid).all() assert (new_key == end_key).all() - def test_initialize_agents( - self, random_walk_generator: RandomWalkGenerator - ) -> None: + def test_initialize_agents(self, random_walk_generator: RandomWalkGenerator) -> None: grid, agents = random_walk_generator._initialize_agents(key, empty_grid) assert agents == agents_starting_initialise_agents assert (grid == valid_starting_grid_initialize_agents).all() @@ -392,9 +386,7 @@ def test_no_available_cells( expected_value: chex.Array, ) -> None: grid, agents = function_input - dones = jax.vmap(random_walk_generator._no_available_cells, in_axes=(None, 0))( - grid, agents - ) + dones = jax.vmap(random_walk_generator._no_available_cells, in_axes=(None, 0))(grid, agents) assert (dones == expected_value).all() @staticmethod @@ -411,9 +403,7 @@ def test_convert_flat_position_to_tuple( function_input: chex.Array, expected_value: chex.Array, ) -> None: - position_tuple = random_walk_generator._convert_flat_position_to_tuple( - function_input - ) + position_tuple = random_walk_generator._convert_flat_position_to_tuple(function_input) assert (position_tuple == expected_value).all() @staticmethod @@ -430,9 +420,7 @@ def test_convert_tuple_to_flat_position( function_input: chex.Array, expected_value: chex.Array, ) -> None: - position_tuple = random_walk_generator._convert_tuple_to_flat_position( - function_input - ) + position_tuple = random_walk_generator._convert_tuple_to_flat_position(function_input) assert (position_tuple == expected_value).all() @staticmethod @@ -545,9 +533,9 @@ def test_step_agent( ) -> None: agent, grid, action = function_input expected_grids, expected_agents = expected_value - new_agents, new_grids = jax.vmap( - random_walk_generator._step_agent, in_axes=(0, None, 0) - )(agent, grid, action) + new_agents, new_grids = jax.vmap(random_walk_generator._step_agent, in_axes=(0, None, 0))( + agent, grid, action + ) # assert new_agents == expected_agents assert (new_grids == expected_grids).all() @@ -569,7 +557,5 @@ def test_is_valid_position_rw( expected_value: chex.Array, ) -> None: grid, agent, new_position = function_input - valid_position = random_walk_generator._is_valid_position( - grid, agent, new_position - ) + valid_position = random_walk_generator._is_valid_position(grid, agent, new_position) assert (valid_position == expected_value).all() diff --git a/jumanji/environments/routing/connector/reward.py b/jumanji/environments/routing/connector/reward.py index 3d8e1195a..f29dbf2d5 100644 --- a/jumanji/environments/routing/connector/reward.py +++ b/jumanji/environments/routing/connector/reward.py @@ -70,7 +70,5 @@ def __call__( connected_rewards = self.connected_reward * jnp.asarray( ~state.agents.connected & next_state.agents.connected, float ) - timestep_rewards = self.timestep_reward * jnp.asarray( - ~state.agents.connected, float - ) + timestep_rewards = self.timestep_reward * jnp.asarray(~state.agents.connected, float) return jnp.sum(connected_rewards + timestep_rewards) diff --git a/jumanji/environments/routing/connector/utils.py b/jumanji/environments/routing/connector/utils.py index 9cb1dba35..6f6c2bed8 100644 --- a/jumanji/environments/routing/connector/utils.py +++ b/jumanji/environments/routing/connector/utils.py @@ -79,14 +79,10 @@ def move_position(position: chex.Array, action: jnp.int32) -> chex.Array: move_right = lambda row, col: jnp.array([row, col + 1], jnp.int32) move_down = lambda row, col: jnp.array([row + 1, col], jnp.int32) - return jax.lax.switch( - action, [move_noop, move_up, move_right, move_down, move_left], row, col - ) + return jax.lax.switch(action, [move_noop, move_up, move_right, move_down, move_left], row, col) -def move_agent( - agent: Agent, grid: chex.Array, new_pos: chex.Array -) -> Tuple[Agent, chex.Array]: +def move_agent(agent: Agent, grid: chex.Array, new_pos: chex.Array) -> Tuple[Agent, chex.Array]: """Moves `agent` to `new_pos` on `grid`. Sets `agent`'s position to `new_pos`. Returns: @@ -104,9 +100,7 @@ def move_agent( return new_agent, grid -def is_valid_position( - grid: chex.Array, agent: Agent, position: chex.Array -) -> chex.Array: +def is_valid_position(grid: chex.Array, agent: Agent, position: chex.Array) -> chex.Array: """Checks to see if the specified agent can move to `position`. Args: diff --git a/jumanji/environments/routing/connector/utils_test.py b/jumanji/environments/routing/connector/utils_test.py index 561fc82a7..82a533831 100644 --- a/jumanji/environments/routing/connector/utils_test.py +++ b/jumanji/environments/routing/connector/utils_test.py @@ -93,12 +93,8 @@ def test_move_agent_invalid(state: State) -> None: def test_is_valid_position(state: State) -> None: """Tests that the _is_valid_move method flags invalid moves.""" agent1 = tree_slice(state.agents, 1) - valid_move = is_valid_position( - grid=state.grid, agent=agent1, position=jnp.array([2, 2]) - ) - move_into_path = is_valid_position( - grid=state.grid, agent=agent1, position=jnp.array([4, 2]) - ) + valid_move = is_valid_position(grid=state.grid, agent=agent1, position=jnp.array([2, 2])) + move_into_path = is_valid_position(grid=state.grid, agent=agent1, position=jnp.array([4, 2])) assert valid_move assert not move_into_path diff --git a/jumanji/environments/routing/connector/viewer.py b/jumanji/environments/routing/connector/viewer.py index b08f5dabc..ea5c49836 100644 --- a/jumanji/environments/routing/connector/viewer.py +++ b/jumanji/environments/routing/connector/viewer.py @@ -98,9 +98,7 @@ def animate( Returns: Animation that can be saved as a GIF, MP4, or rendered with HTML. """ - fig, ax = plt.subplots( - num=f"{self._name}Animation", figsize=ConnectorViewer.FIGURE_SIZE - ) + fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=ConnectorViewer.FIGURE_SIZE) plt.close(fig) def make_frame(grid_index: int) -> None: @@ -150,9 +148,7 @@ def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: for col in range(cols): self._draw_grid_cell(grid[row, col], row, col, ax) - def _draw_grid_cell( - self, cell_value: int, row: int, col: int, ax: plt.Axes - ) -> None: + def _draw_grid_cell(self, cell_value: int, row: int, col: int, ax: plt.Axes) -> None: cell = plt.Rectangle((col, row), 1, 1, **self._get_cell_attributes(cell_value)) ax.add_patch(cell) diff --git a/jumanji/environments/routing/cvrp/conftest.py b/jumanji/environments/routing/cvrp/conftest.py index b0e5c0bb9..e40222f5b 100644 --- a/jumanji/environments/routing/cvrp/conftest.py +++ b/jumanji/environments/routing/cvrp/conftest.py @@ -76,9 +76,7 @@ def __call__(self, key: chex.PRNGKey) -> State: """ del key - coordinates = jnp.array( - [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.5, 0.5]], float - ) + coordinates = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.5, 0.5]], float) demands = jnp.array([0, 1, 2, 1, 2], jnp.int32) # The initial position is set at the depot. diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index 921dc646e..941515319 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -158,9 +158,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step( - self, state: State, action: chex.Numeric - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -232,9 +230,7 @@ def observation_spec(self) -> specs.Spec[Observation]: maximum=True, name="unvisited_nodes", ) - position = specs.DiscreteArray( - self.num_nodes + 1, dtype=jnp.int32, name="position" - ) + position = specs.DiscreteArray(self.num_nodes + 1, dtype=jnp.int32, name="position") trajectory = specs.BoundedArray( shape=(2 * self.num_nodes,), minimum=0, diff --git a/jumanji/environments/routing/cvrp/env_test.py b/jumanji/environments/routing/cvrp/env_test.py index c0f828db2..397f6f1f0 100644 --- a/jumanji/environments/routing/cvrp/env_test.py +++ b/jumanji/environments/routing/cvrp/env_test.py @@ -195,9 +195,7 @@ def test_cvrp_sparse__revisit_depot_invalid(self, cvrp_sparse_reward: CVRP) -> N """Checks that the depot cannot be revisited if we are already at the depot.""" key = jax.random.PRNGKey(0) state, timestep = cvrp_sparse_reward.reset(key) - state, timestep = cvrp_sparse_reward.step( - state, jnp.array(DEPOT_IDX, jnp.int32) - ) + state, timestep = cvrp_sparse_reward.step(state, jnp.array(DEPOT_IDX, jnp.int32)) assert timestep.last() @@ -392,6 +390,4 @@ def test_cvrp__equivalence_dense_sparse_reward( return_sparse += timestep.reward # Check that both returns are the same and not the invalid action penalty - assert ( - return_sparse == return_dense > -2 * cvrp_dense_reward.num_nodes * jnp.sqrt(2) - ) + assert return_sparse == return_dense > -2 * cvrp_dense_reward.num_nodes * jnp.sqrt(2) diff --git a/jumanji/environments/routing/cvrp/generator_test.py b/jumanji/environments/routing/cvrp/generator_test.py index f72751700..936e5cb3e 100644 --- a/jumanji/environments/routing/cvrp/generator_test.py +++ b/jumanji/environments/routing/cvrp/generator_test.py @@ -54,9 +54,7 @@ def uniform_generator(self) -> UniformGenerator: max_demand=10, ) - def test_uniform_generator__properties( - self, uniform_generator: UniformGenerator - ) -> None: + def test_uniform_generator__properties(self, uniform_generator: UniformGenerator) -> None: """Validate that the random instance generator has the correct properties.""" assert uniform_generator.num_nodes == 20 assert uniform_generator.max_capacity == 30 diff --git a/jumanji/environments/routing/cvrp/reward.py b/jumanji/environments/routing/cvrp/reward.py index f0f97a1eb..02bbffaad 100644 --- a/jumanji/environments/routing/cvrp/reward.py +++ b/jumanji/environments/routing/cvrp/reward.py @@ -96,9 +96,7 @@ def __call__( return reward -def compute_tour_length( - coordinates: chex.Array, trajectory: chex.Array -) -> chex.Numeric: +def compute_tour_length(coordinates: chex.Array, trajectory: chex.Array) -> chex.Numeric: """Calculate the length of a tour.""" sorted_coordinates = coordinates[trajectory] # Shift coordinates to compute the distance between neighboring cities. diff --git a/jumanji/environments/routing/cvrp/reward_test.py b/jumanji/environments/routing/cvrp/reward_test.py index acb35a9e4..edfd18200 100644 --- a/jumanji/environments/routing/cvrp/reward_test.py +++ b/jumanji/environments/routing/cvrp/reward_test.py @@ -72,9 +72,7 @@ def test_dense_reward(cvrp_dense_reward: CVRP, dense_reward: DenseReward) -> Non assert dense_reward(state, 0, next_state, is_valid=False) == penalty -def test_sparse_reward( # noqa: CCR001 - cvrp_sparse_reward: CVRP, sparse_reward: SparseReward -) -> None: +def test_sparse_reward(cvrp_sparse_reward: CVRP, sparse_reward: SparseReward) -> None: sparse_reward = jax.jit(sparse_reward) step_fn = jax.jit(cvrp_sparse_reward.step) state, timestep = cvrp_sparse_reward.reset(jax.random.PRNGKey(0)) @@ -101,8 +99,5 @@ def test_sparse_reward( # noqa: CCR001 else: # Check that a penalty is given for every invalid action. invalid_next_state, _ = step_fn(state, action) - assert ( - sparse_reward(state, action, invalid_next_state, is_valid) - == penalty - ) + assert sparse_reward(state, action, invalid_next_state, is_valid) == penalty state = next_state diff --git a/jumanji/environments/routing/cvrp/viewer.py b/jumanji/environments/routing/cvrp/viewer.py index e1eba4b6d..be9f7bc34 100644 --- a/jumanji/environments/routing/cvrp/viewer.py +++ b/jumanji/environments/routing/cvrp/viewer.py @@ -62,9 +62,7 @@ def __init__(self, name: str, num_cities: int, render_mode: str = "human") -> No else: raise ValueError(f"Invalid render mode: {render_mode}") - def render( - self, state: State, save_path: Optional[str] = None - ) -> Optional[NDArray]: + def render(self, state: State, save_path: Optional[str] = None) -> Optional[NDArray]: """Render the given state of the `CVRP` environment. Args: @@ -166,9 +164,7 @@ def _group_tour(self, tour: Array) -> list: depot = tour[0] check_depot_fn = lambda x: (x != depot).all() tour_grouped = [ - np.array([depot] + list(g) + [depot]) - for k, g in groupby(tour, key=check_depot_fn) - if k + np.array([depot, *list(g), depot]) for k, g in groupby(tour, key=check_depot_fn) if k ] if (tour[-1] != tour[0]).all(): tour_grouped[-1] = tour_grouped[-1][:-1] @@ -213,7 +209,7 @@ def _add_tour(self, ax: plt.Axes, state: State) -> None: # Draw each route in different colour for coords_route, col_id in zip( - coords_grouped, np.arange(0, len(coords_grouped)) + coords_grouped, np.arange(0, len(coords_grouped)), strict=False ): self._draw_route(ax, coords_route, col_id) diff --git a/jumanji/environments/routing/lbf/env.py b/jumanji/environments/routing/lbf/env.py index cbdd132b1..42de8827e 100644 --- a/jumanji/environments/routing/lbf/env.py +++ b/jumanji/environments/routing/lbf/env.py @@ -119,7 +119,6 @@ def __init__( normalize_reward: bool = True, penalty: float = 0.0, ) -> None: - self._generator = generator or RandomGenerator( grid_size=8, fov=8, @@ -154,9 +153,7 @@ def __init__( super().__init__() # create viewer for rendering environment - self._viewer = viewer or LevelBasedForagingViewer( - self.grid_size, "LevelBasedForaging" - ) + self._viewer = viewer or LevelBasedForagingViewer(self.grid_size, "LevelBasedForaging") def __repr__(self) -> str: return ( @@ -225,21 +222,13 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep]: terminate + 2 * truncate, [ # !terminate !trunc - lambda rew, obs: transition( - reward=rew, observation=obs, shape=self.num_agents - ), + lambda rew, obs: transition(reward=rew, observation=obs, shape=self.num_agents), # terminate !truncate - lambda rew, obs: termination( - reward=rew, observation=obs, shape=self.num_agents - ), + lambda rew, obs: termination(reward=rew, observation=obs, shape=self.num_agents), # !terminate truncate - lambda rew, obs: truncation( - reward=rew, observation=obs, shape=self.num_agents - ), + lambda rew, obs: truncation(reward=rew, observation=obs, shape=self.num_agents), # terminate truncate - lambda rew, obs: termination( - reward=rew, observation=obs, shape=self.num_agents - ), + lambda rew, obs: termination(reward=rew, observation=obs, shape=self.num_agents), ], reward, observation, @@ -250,9 +239,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep]: def _get_extra_info(self, state: State, timestep: TimeStep) -> Dict: """Computes extras metrics to be returned within the timestep.""" - n_eaten = state.food_items.eaten.sum() + timestep.extras.get( - "eaten_food", jnp.int32(0) - ) + n_eaten = state.food_items.eaten.sum() + timestep.extras.get("eaten_food", jnp.int32(0)) percent_eaten = (n_eaten / self.num_food) * 100 return {"percent_eaten": percent_eaten} @@ -288,15 +275,11 @@ def get_reward_per_food( ) # Zero out all agents if food was not eaten and add penalty - reward = ( - adj_loading_agents_levels * eaten_this_step * food.level - ) - penalty + reward = (adj_loading_agents_levels * eaten_this_step * food.level) - penalty # jnp.nan_to_num: Used in the case where no agents are adjacent to the food normalizer = sum_agents_levels * total_food_level - reward = jnp.where( - self.normalize_reward, jnp.nan_to_num(reward / normalizer), reward - ) + reward = jnp.where(self.normalize_reward, jnp.nan_to_num(reward / normalizer), reward) return reward @@ -336,9 +319,7 @@ def animate( matplotlib.animation.FuncAnimation: Animation object that can be saved as a GIF, MP4, or rendered with HTML. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup.""" diff --git a/jumanji/environments/routing/lbf/env_test.py b/jumanji/environments/routing/lbf/env_test.py index 41f4e3c45..89b62f50c 100644 --- a/jumanji/environments/routing/lbf/env_test.py +++ b/jumanji/environments/routing/lbf/env_test.py @@ -33,9 +33,7 @@ def test_lbf_environment_integration( assert isinstance(initial_state, State) assert isinstance(timestep, TimeStep) assert timestep.step_type == StepType.FIRST - assert jnp.isclose( - timestep.reward, jnp.zeros(lbf_environment.num_agents, dtype=float) - ).all() + assert jnp.isclose(timestep.reward, jnp.zeros(lbf_environment.num_agents, dtype=float)).all() assert timestep.extras == {"percent_eaten": jnp.float32(0)} # Test the step function action = jnp.array([NOOP] * lbf_environment.num_agents) @@ -64,9 +62,7 @@ def test_reset(lbf_environment: LevelBasedForaging, key: chex.PRNGKey) -> None: assert timestep.reward.shape == (num_agents,) -def test_reset_grid_obs( - lbf_env_grid_obs: LevelBasedForaging, key: chex.PRNGKey -) -> None: +def test_reset_grid_obs(lbf_env_grid_obs: LevelBasedForaging, key: chex.PRNGKey) -> None: num_agents = lbf_env_grid_obs.num_agents state, timestep = lbf_env_grid_obs.reset(key) @@ -89,9 +85,7 @@ def test_reset_grid_obs( assert timestep.reward.shape == (num_agents,) -def test_get_reward( - lbf_environment: LevelBasedForaging, agents: Agent, food_items: Food -) -> None: +def test_get_reward(lbf_environment: LevelBasedForaging, agents: Agent, food_items: Food) -> None: adj_food0_level = jnp.array([0.0, agents.level[1], agents.level[2]]) adj_food1_level = jnp.array([0.0, 0.0, agents.level[2]]) adj_food2_level = jnp.array([0.0, 0.0, 0.0]) @@ -110,9 +104,7 @@ def test_get_reward( assert jnp.all(reward == expected_reward) -def test_reward_with_penalty( - lbf_with_penalty: LevelBasedForaging, food_items: Food -) -> None: +def test_reward_with_penalty(lbf_with_penalty: LevelBasedForaging, food_items: Food) -> None: adj_food0_level = jnp.array([0.0, 1, 2]) adj_food1_level = jnp.array([0.0, 0.0, 4]) adj_food2_level = jnp.array([2, 0.0, 0.0]) @@ -123,20 +115,18 @@ def test_reward_with_penalty( penalty = lbf_with_penalty.penalty penalty_0 = jnp.where(jnp.sum(adj_food0_level) < food_items.level[0], penalty, 0) - expected_reward_food0 = ( - adj_food0_level * eaten[0] * food_items.level[0] - penalty_0 - ) / (jnp.sum(food_items.level) * jnp.sum(adj_food0_level)) + expected_reward_food0 = (adj_food0_level * eaten[0] * food_items.level[0] - penalty_0) / ( + jnp.sum(food_items.level) * jnp.sum(adj_food0_level) + ) penalty_1 = jnp.where(jnp.sum(adj_food1_level) < food_items.level[1], penalty, 0) - expected_reward_food1 = ( - adj_food1_level * eaten[1] * food_items.level[1] - penalty_1 - ) / (jnp.sum(food_items.level) * jnp.sum(adj_food1_level)) + expected_reward_food1 = (adj_food1_level * eaten[1] * food_items.level[1] - penalty_1) / ( + jnp.sum(food_items.level) * jnp.sum(adj_food1_level) + ) penalty_2 = jnp.where(jnp.sum(adj_food2_level) < food_items.level[2], penalty, 0) - expected_reward_food2 = ( - adj_food2_level * eaten[1] * food_items.level[2] - penalty_2 - ) / (jnp.sum(food_items.level) * jnp.sum(adj_food2_level)) - expected_reward = ( - expected_reward_food0 + expected_reward_food1 + expected_reward_food2 + expected_reward_food2 = (adj_food2_level * eaten[1] * food_items.level[2] - penalty_2) / ( + jnp.sum(food_items.level) * jnp.sum(adj_food2_level) ) + expected_reward = expected_reward_food0 + expected_reward_food1 + expected_reward_food2 assert jnp.all(reward == expected_reward) @@ -158,7 +148,6 @@ def test_reward_with_no_norm( def test_step(lbf_environment: LevelBasedForaging, state: State) -> None: - num_agents = lbf_environment.num_agents ep_return = jnp.zeros((num_agents,), jnp.int32) @@ -196,9 +185,7 @@ def test_step(lbf_environment: LevelBasedForaging, state: State) -> None: assert jnp.sum(ep_return) == 1 -def test_step_done_horizon( - lbf_environment: LevelBasedForaging, key: chex.PRNGKey -) -> None: +def test_step_done_horizon(lbf_environment: LevelBasedForaging, key: chex.PRNGKey) -> None: num_agents = lbf_environment.num_agents # Test the done after 5 steps state, timestep = lbf_environment.reset(key) diff --git a/jumanji/environments/routing/lbf/generator.py b/jumanji/environments/routing/lbf/generator.py index 1cb1b7214..1d2d47f30 100644 --- a/jumanji/environments/routing/lbf/generator.py +++ b/jumanji/environments/routing/lbf/generator.py @@ -91,9 +91,7 @@ def sample_food(self, key: chex.PRNGKey) -> chex.Array: False, indices_are_sorted=True, unique_indices=True ) # right - def take_positions( - mask: chex.Array, key: chex.PRNGKey - ) -> Tuple[chex.Array, chex.Array]: + def take_positions(mask: chex.Array, key: chex.PRNGKey) -> Tuple[chex.Array, chex.Array]: food_flat_pos = jax.random.choice(key=key, a=flat_size, shape=(), p=mask) # Mask out adjacent positions to avoid placing food items next to each other @@ -138,17 +136,13 @@ def sample_agents(self, key: chex.PRNGKey, mask: chex.Array) -> chex.Array: # Stack x and y coordinates to form a 2D array return jnp.stack([agent_positions_x, agent_positions_y], axis=1) - def sample_levels( - self, max_level: int, shape: chex.Shape, key: chex.PRNGKey - ) -> chex.Array: + def sample_levels(self, max_level: int, shape: chex.Shape, key: chex.PRNGKey) -> chex.Array: """Samples levels within specified bounds.""" return jax.random.randint(key, shape=shape, minval=1, maxval=max_level + 1) def __call__(self, key: chex.PRNGKey) -> State: """Generates a state containing grid, agent, and food item configurations.""" - key_food, key_agents, key_food_level, key_agent_level, key = jax.random.split( - key, 5 - ) + key_food, key_agents, key_food_level, key_agent_level, key = jax.random.split(key, 5) # Generate positions for food items food_positions = self.sample_food(key_food) @@ -161,9 +155,7 @@ def __call__(self, key: chex.PRNGKey) -> State: agent_positions = self.sample_agents(key=key_agents, mask=mask) # Generate levels for agents and food items - agent_levels = self.sample_levels( - self.max_agent_level, (self.num_agents,), key_agent_level - ) + agent_levels = self.sample_levels(self.max_agent_level, (self.num_agents,), key_agent_level) # In the worst case, 3 agents are needed to eat a food item max_food_level = jnp.sum(jnp.sort(agent_levels)[:3]) @@ -189,6 +181,4 @@ def __call__(self, key: chex.PRNGKey) -> State: ) step_count = jnp.array(0, jnp.int32) - return State( - key=key, step_count=step_count, agents=agents, food_items=food_items - ) + return State(key=key, step_count=step_count, agents=agents, food_items=food_items) diff --git a/jumanji/environments/routing/lbf/generator_test.py b/jumanji/environments/routing/lbf/generator_test.py index 6c59b7e48..55b5bd54d 100644 --- a/jumanji/environments/routing/lbf/generator_test.py +++ b/jumanji/environments/routing/lbf/generator_test.py @@ -21,9 +21,7 @@ from jumanji.environments.routing.lbf.types import State -def test_random_generator_call( - random_generator: RandomGenerator, key: chex.PRNGKey -) -> None: +def test_random_generator_call(random_generator: RandomGenerator, key: chex.PRNGKey) -> None: state = random_generator(key) assert random_generator.grid_size >= 5 assert 2 <= random_generator.fov <= random_generator.grid_size @@ -48,9 +46,7 @@ def test_sample_food(random_generator: RandomGenerator, key: chex.PRNGKey) -> No food_positions = random_generator.sample_food(key) # Check if positions are within the grid bounds and no food on the edge of the grid - assert jnp.all( - (food_positions > 0) & (food_positions < random_generator.grid_size - 1) - ) + assert jnp.all((food_positions > 0) & (food_positions < random_generator.grid_size - 1)) # Check if no food positions overlap assert not jnp.any( @@ -60,17 +56,13 @@ def test_sample_food(random_generator: RandomGenerator, key: chex.PRNGKey) -> No def test_sample_agents(random_generator: RandomGenerator, key: chex.PRNGKey) -> None: - mask = jnp.ones( - (random_generator.grid_size, random_generator.grid_size), dtype=bool - ) + mask = jnp.ones((random_generator.grid_size, random_generator.grid_size), dtype=bool) mask = mask.ravel() agent_positions = random_generator.sample_agents(key, mask) # Check if positions are within the grid bounds - assert jnp.all( - (agent_positions >= 0) & (agent_positions < random_generator.grid_size) - ) + assert jnp.all((agent_positions >= 0) & (agent_positions < random_generator.grid_size)) # Check if no agent positions overlap assert not jnp.any( @@ -85,9 +77,7 @@ def test_sample_levels(random_generator: RandomGenerator, key: chex.PRNGKey) -> ) # Check if levels are within the specified range - assert jnp.all( - (agent_levels >= 1) & (agent_levels <= random_generator.max_agent_level) - ) + assert jnp.all((agent_levels >= 1) & (agent_levels <= random_generator.max_agent_level)) # Check if levels are generated randomly key2 = jax.random.PRNGKey(43) diff --git a/jumanji/environments/routing/lbf/observer.py b/jumanji/environments/routing/lbf/observer.py index 7565ee9ce..01fafbc71 100644 --- a/jumanji/environments/routing/lbf/observer.py +++ b/jumanji/environments/routing/lbf/observer.py @@ -114,9 +114,7 @@ class VectorObserver(LbfObserver): - num_food (int): The number of food items in the environment. """ - def __init__( - self, fov: int, grid_size: int, num_agents: int, num_food: int - ) -> None: + def __init__(self, fov: int, grid_size: int, num_agents: int, num_food: int) -> None: super().__init__(fov, grid_size, num_agents, num_food) def transform_positions(self, agent: Agent, items: Entity) -> chex.Array: @@ -185,9 +183,7 @@ def extract_agents_info( ] ).ravel() - other_agents_indices = jnp.where( - agent.id != all_agents.id, size=self.num_agents - 1 - ) + other_agents_indices = jnp.where(agent.id != all_agents.id, size=self.num_agents - 1) agent_xs = agent_xs[other_agents_indices] agent_ys = agent_ys[other_agents_indices] agent_levels = agent_levels[other_agents_indices] @@ -236,9 +232,9 @@ def make_agents_view(self, agent: Agent, state: State) -> chex.Array: ) # Always place the current agent's info first. - agent_view = agent_view.at[ - jnp.arange(3 * self.num_food, 3 * self.num_food + 3) - ].set(agent_i_infos, indices_are_sorted=True, unique_indices=True) + agent_view = agent_view.at[jnp.arange(3 * self.num_food, 3 * self.num_food + 3)].set( + agent_i_infos, indices_are_sorted=True, unique_indices=True + ) start_idx = 3 * self.num_food + 3 end_idx = start_idx + 3 * (self.num_agents - 1) diff --git a/jumanji/environments/routing/lbf/observer_test.py b/jumanji/environments/routing/lbf/observer_test.py index 65a7370a4..c1ceea427 100644 --- a/jumanji/environments/routing/lbf/observer_test.py +++ b/jumanji/environments/routing/lbf/observer_test.py @@ -39,28 +39,16 @@ def test_lbf_observer_initialization(lbf_env_2s: LevelBasedForaging) -> None: def test_vector_full_obs(state: State) -> None: observer = VectorObserver(fov=6, grid_size=6, num_agents=3, num_food=3) obs1 = observer.state_to_observation(state) - expected_agent_0_view = jnp.array( - [2, 1, 4, 2, 3, 4, 4, 2, 3, 0, 0, 1, 1, 1, 2, 2, 2, 4] - ) - expected_agent_1_view = jnp.array( - [2, 1, 4, 2, 3, 4, 4, 2, 3, 1, 1, 2, 0, 0, 1, 2, 2, 4] - ) - expected_agent_2_view = jnp.array( - [2, 1, 4, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2] - ) + expected_agent_0_view = jnp.array([2, 1, 4, 2, 3, 4, 4, 2, 3, 0, 0, 1, 1, 1, 2, 2, 2, 4]) + expected_agent_1_view = jnp.array([2, 1, 4, 2, 3, 4, 4, 2, 3, 1, 1, 2, 0, 0, 1, 2, 2, 4]) + expected_agent_2_view = jnp.array([2, 1, 4, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2]) assert jnp.all(obs1.agents_view[0, :] == expected_agent_0_view) - assert jnp.all( - obs1.action_mask[0, :] == jnp.array([True, False, True, False, True, False]) - ) + assert jnp.all(obs1.action_mask[0, :] == jnp.array([True, False, True, False, True, False])) assert jnp.all(obs1.agents_view[1, :] == expected_agent_1_view) - assert jnp.all( - obs1.action_mask[1, :] == jnp.array([True, True, False, True, True, True]) - ) + assert jnp.all(obs1.action_mask[1, :] == jnp.array([True, True, False, True, True, True])) assert jnp.all(obs1.agents_view[2, :] == expected_agent_2_view) - assert jnp.all( - obs1.action_mask[2, :] == jnp.array([True, True, True, False, False, True]) - ) + assert jnp.all(obs1.action_mask[2, :] == jnp.array([True, True, True, False, False, True])) # If agent1 and agent2 eat the food0 eaten = jnp.array([True, False, False]) @@ -73,47 +61,27 @@ def test_vector_full_obs(state: State) -> None: state = state.replace(food_items=food_items) # type: ignore obs2 = observer.state_to_observation(state) - expected_agent_1_view = jnp.array( - [-1, -1, 0, 2, 3, 4, 4, 2, 3, 1, 1, 2, 0, 0, 1, 2, 2, 4] - ) - expected_agent_2_view = jnp.array( - [-1, -1, 0, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2] - ) + expected_agent_1_view = jnp.array([-1, -1, 0, 2, 3, 4, 4, 2, 3, 1, 1, 2, 0, 0, 1, 2, 2, 4]) + expected_agent_2_view = jnp.array([-1, -1, 0, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2]) assert jnp.all(obs2.agents_view[1, :] == expected_agent_1_view) - assert jnp.all( - obs2.action_mask[1, :] == jnp.array([True, True, True, True, True, False]) - ) + assert jnp.all(obs2.action_mask[1, :] == jnp.array([True, True, True, True, True, False])) assert jnp.all(obs2.agents_view[2, :] == expected_agent_2_view) - assert jnp.all( - obs2.action_mask[2, :] == jnp.array([True, True, True, True, False, True]) - ) + assert jnp.all(obs2.action_mask[2, :] == jnp.array([True, True, True, True, False, True])) def test_vector_partial_obs(state: State) -> None: observer = VectorObserver(fov=2, grid_size=6, num_agents=3, num_food=3) obs1 = observer.state_to_observation(state) - expected_agent_0_view = jnp.array( - [2, 1, 4, -1, -1, 0, -1, -1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 4] - ) - expected_agent_1_view = jnp.array( - [2, 1, 4, 2, 3, 4, -1, -1, 0, 1, 1, 2, 0, 0, 1, 2, 2, 4] - ) - expected_agent_2_view = jnp.array( - [2, 1, 4, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2] - ) + expected_agent_0_view = jnp.array([2, 1, 4, -1, -1, 0, -1, -1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 4]) + expected_agent_1_view = jnp.array([2, 1, 4, 2, 3, 4, -1, -1, 0, 1, 1, 2, 0, 0, 1, 2, 2, 4]) + expected_agent_2_view = jnp.array([2, 1, 4, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2]) assert jnp.all(obs1.agents_view[0, :] == expected_agent_0_view) - assert jnp.all( - obs1.action_mask[0, :] == jnp.array([True, False, True, False, True, False]) - ) + assert jnp.all(obs1.action_mask[0, :] == jnp.array([True, False, True, False, True, False])) assert jnp.all(obs1.agents_view[1, :] == expected_agent_1_view) - assert jnp.all( - obs1.action_mask[1, :] == jnp.array([True, True, False, True, True, True]) - ) + assert jnp.all(obs1.action_mask[1, :] == jnp.array([True, True, False, True, True, True])) assert jnp.all(obs1.agents_view[2, :] == expected_agent_2_view) - assert jnp.all( - obs1.action_mask[2, :] == jnp.array([True, True, True, False, False, True]) - ) + assert jnp.all(obs1.action_mask[2, :] == jnp.array([True, True, True, False, False, True])) # test eaten food is not visible eaten = jnp.array([True, False, False]) @@ -126,20 +94,12 @@ def test_vector_partial_obs(state: State) -> None: state = state.replace(food_items=food_items) # type: ignore obs2 = observer.state_to_observation(state) - expected_agent_1_view = jnp.array( - [-1, -1, 0, 2, 3, 4, -1, -1, 0, 1, 1, 2, 0, 0, 1, 2, 2, 4] - ) - expected_agent_2_view = jnp.array( - [-1, -1, 0, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2] - ) + expected_agent_1_view = jnp.array([-1, -1, 0, 2, 3, 4, -1, -1, 0, 1, 1, 2, 0, 0, 1, 2, 2, 4]) + expected_agent_2_view = jnp.array([-1, -1, 0, 2, 3, 4, 4, 2, 3, 2, 2, 4, 0, 0, 1, 1, 1, 2]) assert jnp.all(obs2.agents_view[1, :] == expected_agent_1_view) - assert jnp.all( - obs2.action_mask[1, :] == jnp.array([True, True, True, True, True, False]) - ) + assert jnp.all(obs2.action_mask[1, :] == jnp.array([True, True, True, True, True, False])) assert jnp.all(obs2.agents_view[2, :] == expected_agent_2_view) - assert jnp.all( - obs2.action_mask[2, :] == jnp.array([True, True, True, True, False, True]) - ) + assert jnp.all(obs2.action_mask[2, :] == jnp.array([True, True, True, True, False, True])) def test_grid_observer(state: State) -> None: @@ -225,14 +185,8 @@ def test_grid_observer(state: State) -> None: ) assert jnp.all(obs.agents_view[0, :] == expected_agent_0_view) - assert jnp.all( - obs.action_mask[0, :] == jnp.array([True, False, True, False, True, False]) - ) + assert jnp.all(obs.action_mask[0, :] == jnp.array([True, False, True, False, True, False])) assert jnp.all(obs.agents_view[1, :] == expected_agent_1_view) - assert jnp.all( - obs.action_mask[1, :] == jnp.array([True, True, False, True, True, True]) - ) + assert jnp.all(obs.action_mask[1, :] == jnp.array([True, True, False, True, True, True])) assert jnp.all(obs.agents_view[2, :] == expected_agent_2_view) - assert jnp.all( - obs.action_mask[2, :] == jnp.array([True, True, True, False, False, True]) - ) + assert jnp.all(obs.action_mask[2, :] == jnp.array([True, True, True, False, False, True])) diff --git a/jumanji/environments/routing/lbf/utils.py b/jumanji/environments/routing/lbf/utils.py index 1e45ea180..2ca6da1d4 100644 --- a/jumanji/environments/routing/lbf/utils.py +++ b/jumanji/environments/routing/lbf/utils.py @@ -45,9 +45,7 @@ def flag_duplicates(a: chex.Array) -> chex.Array: flag_duplicates(a) # jnp.array([True, False, True, False, True, True]) """ # https://stackoverflow.com/a/11528078/5768407 - _, indices, counts = jnp.unique( - a, return_inverse=True, return_counts=True, size=len(a), axis=0 - ) + _, indices, counts = jnp.unique(a, return_inverse=True, return_counts=True, size=len(a), axis=0) return ~(counts[indices] == 1) @@ -85,9 +83,7 @@ def simulate_agent_movement( # Move the agent to the new position if it's a valid position, # otherwise keep the current position - new_agent_position = jnp.where( - out_of_bounds | entity_at_position, agent.position, new_position - ) + new_agent_position = jnp.where(out_of_bounds | entity_at_position, agent.position, new_position) # Return the agent with the updated position return agent.replace(position=new_agent_position) # type: ignore @@ -121,9 +117,9 @@ def update_agent_positions( moved_agents = fix_collisions(moved_agents, agents) # set agent's loading status - moved_agents = jax.vmap( - lambda agent, action: agent.replace(loading=(action == LOAD)) - )(moved_agents, actions) + moved_agents = jax.vmap(lambda agent, action: agent.replace(loading=(action == LOAD)))( + moved_agents, actions + ) return moved_agents @@ -222,9 +218,7 @@ def check_pos_fn(next_pos: Any, entities: Entity, condition: bool) -> Any: next_positions, state.food_items, ~state.food_items.eaten ) # Check if the next position is out of bounds - out_of_bounds = jnp.any( - (next_positions < 0) | (next_positions >= grid_size), axis=-1 - ) + out_of_bounds = jnp.any((next_positions < 0) | (next_positions >= grid_size), axis=-1) action_mask = ~(food_occupied | agent_occupied | out_of_bounds) diff --git a/jumanji/environments/routing/lbf/utils_test.py b/jumanji/environments/routing/lbf/utils_test.py index dfbb1feb6..2b4507e6c 100644 --- a/jumanji/environments/routing/lbf/utils_test.py +++ b/jumanji/environments/routing/lbf/utils_test.py @@ -25,35 +25,23 @@ def test_simulate_agent_movement( agent0: Agent, agent1: Agent, agent2: Agent, agents: Agent, food_items: Food ) -> None: grid_size = 6 - agent0_new = utils.simulate_agent_movement( - agent0, RIGHT, food_items, agents, grid_size - ) + agent0_new = utils.simulate_agent_movement(agent0, RIGHT, food_items, agents, grid_size) assert jnp.all(agent0_new.position == jnp.array([0, 1])) - agent1_new = utils.simulate_agent_movement( - agent1, LEFT, food_items, agents, grid_size - ) + agent1_new = utils.simulate_agent_movement(agent1, LEFT, food_items, agents, grid_size) assert jnp.all(agent1_new.position == jnp.array([1, 0])) # Move agent out of bounds - agent0_new = utils.simulate_agent_movement( - agent0, UP, food_items, agents, grid_size - ) + agent0_new = utils.simulate_agent_movement(agent0, UP, food_items, agents, grid_size) assert jnp.all(agent0_new.position == agent0.position) # Move agent1 to take the position of the food0 - agent1_new = utils.simulate_agent_movement( - agent1, DOWN, food_items, agents, grid_size - ) + agent1_new = utils.simulate_agent_movement(agent1, DOWN, food_items, agents, grid_size) assert jnp.all(agent1_new.position == agent1.position) # Try to load and do nothing. - agent2_new = utils.simulate_agent_movement( - agent2, NOOP, food_items, agents, grid_size - ) + agent2_new = utils.simulate_agent_movement(agent2, NOOP, food_items, agents, grid_size) assert jnp.all(agent2_new.position == agent2.position) - agent2_new = utils.simulate_agent_movement( - agent2, LOAD, food_items, agents, grid_size - ) + agent2_new = utils.simulate_agent_movement(agent2, LOAD, food_items, agents, grid_size) assert jnp.all(agent2_new.position == agent2.position) @@ -67,7 +55,6 @@ def test_are_entities_adjacent( food2: Food, food_items: Food, ) -> None: - assert utils.are_entities_adjacent(agent1, food0) assert utils.are_entities_adjacent(agent2, food0) assert utils.are_entities_adjacent(agent2, food1) diff --git a/jumanji/environments/routing/lbf/viewer.py b/jumanji/environments/routing/lbf/viewer.py index 99370c920..d8f66cef6 100644 --- a/jumanji/environments/routing/lbf/viewer.py +++ b/jumanji/environments/routing/lbf/viewer.py @@ -14,11 +14,9 @@ # flake8: noqa: CCR001 -import os from typing import Callable, Optional, Sequence, Tuple import matplotlib.animation as animation -import matplotlib.image as mpimg import matplotlib.pyplot as plt import numpy as np import pkg_resources @@ -218,9 +216,7 @@ def _draw_agents(self, agents: Agent, ax: plt.Axes) -> None: # Create an OffsetImage and add it to the axis imagebox = OffsetImage(img, zoom=self.icon_size / self.grid_size) - ab = AnnotationBbox( - imagebox, (cell_center[0], cell_center[1]), frameon=False, zorder=0 - ) + ab = AnnotationBbox(imagebox, (cell_center[0], cell_center[1]), frameon=False, zorder=0) ax.add_artist(ab) # Add a rectangle (polygon) next to the agent with the agent's level @@ -245,9 +241,7 @@ def _draw_food(self, food_items: Food, ax: plt.Axes) -> None: # Create an OffsetImage and add it to the axis imagebox = OffsetImage(img, zoom=self.icon_size / self.grid_size) - ab = AnnotationBbox( - imagebox, (cell_center[0], cell_center[1]), frameon=False, zorder=0 - ) + ab = AnnotationBbox(imagebox, (cell_center[0], cell_center[1]), frameon=False, zorder=0) ax.add_artist(ab) # Add a rectangle (polygon) next to the agent with the food's level @@ -263,9 +257,7 @@ def _entity_position(self, entity: Entity) -> Tuple[float, float]: y_center, ) - def draw_badge( - self, level: int, anchor_point: Tuple[float, float], ax: plt.Axes - ) -> None: + def draw_badge(self, level: int, anchor_point: Tuple[float, float], ax: plt.Axes) -> None: resolution = 6 radius = self.grid_size / 6 diff --git a/jumanji/environments/routing/lbf/viewer_test.py b/jumanji/environments/routing/lbf/viewer_test.py index 9626d77b1..99e30e35c 100644 --- a/jumanji/environments/routing/lbf/viewer_test.py +++ b/jumanji/environments/routing/lbf/viewer_test.py @@ -38,9 +38,7 @@ def test_lbf_viewer_render( viewer.close() -def test_lbf_viewer_animate( - lbf_environment: LevelBasedForaging, key: chex.PRNGKey -) -> None: +def test_lbf_viewer_animate(lbf_environment: LevelBasedForaging, key: chex.PRNGKey) -> None: """Test animation using LevelBasedForagingViewer.""" state, _ = jax.jit(lbf_environment.reset)(key) diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index c2f0100dd..101fcec2f 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -134,12 +134,8 @@ def observation_spec(self) -> specs.Spec[Observation]: agent_position = specs.Spec( Position, "PositionSpec", - row=specs.BoundedArray( - (), jnp.int32, 0, self.num_rows - 1, "row_coordinate" - ), - col=specs.BoundedArray( - (), jnp.int32, 0, self.num_cols - 1, "col_coordinate" - ), + row=specs.BoundedArray((), jnp.int32, 0, self.num_rows - 1, "row_coordinate"), + col=specs.BoundedArray((), jnp.int32, 0, self.num_cols - 1, "col_coordinate"), ) walls = specs.BoundedArray( shape=(self.num_rows, self.num_cols), @@ -196,9 +192,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """ Run one timestep of the environment's dynamics. @@ -268,9 +262,7 @@ def step( ) return state, timestep - def _compute_action_mask( - self, walls: chex.Array, agent_position: Position - ) -> chex.Array: + def _compute_action_mask(self, walls: chex.Array, agent_position: Position) -> chex.Array: """Compute the action mask. An action is considered invalid if it leads to a WALL or goes outside of the maze. """ diff --git a/jumanji/environments/routing/maze/generator.py b/jumanji/environments/routing/maze/generator.py index 2fd5ba6d8..25a9eb578 100644 --- a/jumanji/environments/routing/maze/generator.py +++ b/jumanji/environments/routing/maze/generator.py @@ -90,9 +90,7 @@ def __call__(self, key: chex.PRNGKey) -> State: """ key, maze_key, agent_key = jax.random.split(key, 3) - walls = maze_generation.generate_maze( - self.num_cols, self.num_rows, maze_key - ).astype(bool) + walls = maze_generation.generate_maze(self.num_cols, self.num_rows, maze_key).astype(bool) # Randomise agent start and target positions. start_and_target_indices = jax.random.choice( diff --git a/jumanji/environments/routing/maze/viewer.py b/jumanji/environments/routing/maze/viewer.py index b8d8d5f64..ae43a6aaa 100644 --- a/jumanji/environments/routing/maze/viewer.py +++ b/jumanji/environments/routing/maze/viewer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence +from typing import ClassVar, Optional, Sequence import chex import matplotlib @@ -29,7 +29,7 @@ class MazeEnvViewer(MazeViewer): AGENT = 2 TARGET = 3 - COLORS = { + COLORS: ClassVar = { EMPTY: [1, 1, 1], # White WALL: [0, 0, 0], # Black AGENT: [0, 1, 0], # Green diff --git a/jumanji/environments/routing/mmst/conftest.py b/jumanji/environments/routing/mmst/conftest.py index f8b01c425..9b5fe1245 100644 --- a/jumanji/environments/routing/mmst/conftest.py +++ b/jumanji/environments/routing/mmst/conftest.py @@ -120,9 +120,7 @@ def deterministic_mmst_env() -> Tuple[MMST, State, TimeStep]: positions = jnp.array([1, 3], dtype=jnp.int32) active_node_edges = jnp.repeat(node_edges[None, ...], num_agents, axis=0) - active_node_edges = update_active_edges( - num_agents, active_node_edges, positions, node_types - ) + active_node_edges = update_active_edges(num_agents, active_node_edges, positions, node_types) finished_agents = jnp.zeros((num_agents), dtype=bool) state = State( diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index 386f2dd3c..be0057ac5 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -189,9 +189,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state), extras=extras) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -233,9 +231,7 @@ def step_agent_fn( return connected_nodes, conn_index, new_node, indices key, step_key = jax.random.split(state.key) - action, next_nodes = self._trim_duplicated_invalid_actions( - state, action, step_key - ) + action, next_nodes = self._trim_duplicated_invalid_actions(state, action, step_key) connected_nodes = jnp.zeros_like(state.connected_nodes) connected_nodes_index = jnp.zeros_like(state.connected_nodes_index) @@ -476,9 +472,7 @@ def _get_agent_node( added_nodes = jnp.ones((self.num_agents), dtype=jnp.int32) * DUMMY_NODE - agent_permutation = jax.random.permutation( - step_key, jnp.arange(self.num_agents) - ) + agent_permutation = jax.random.permutation(step_key, jnp.arange(self.num_agents)) def not_all_agents_actions_examined(arg: Any) -> Any: added_nodes, new_actions, action, nodes, agent_permutation, index = arg @@ -530,9 +524,7 @@ def modify_action_if_agent_target_node_is_selected(arg: Any) -> Any: (added_nodes, new_actions, action, nodes, agent_permutation, 0), ) - def mask_visited_nodes( - node_visited: jnp.int32, old_action: jnp.int32 - ) -> jnp.int32: + def mask_visited_nodes(node_visited: jnp.int32, old_action: jnp.int32) -> jnp.int32: new_action = jax.lax.cond( # type:ignore node_visited != EMPTY_NODE, lambda *_: INVALID_ALREADY_TRAVERSED, @@ -586,9 +578,7 @@ def get_finished_agents(self, state: State) -> chex.Array: Array : array of boolean flags in the shape (number of agents, ). """ - def done_fun( - nodes: chex.Array, connected_nodes: chex.Array, n_comps: int - ) -> jnp.bool_: + def done_fun(nodes: chex.Array, connected_nodes: chex.Array, n_comps: int) -> jnp.bool_: connects = jnp.isin(nodes, connected_nodes) return jnp.sum(connects) == n_comps diff --git a/jumanji/environments/routing/mmst/generator.py b/jumanji/environments/routing/mmst/generator.py index be0e4685a..20c35db6c 100644 --- a/jumanji/environments/routing/mmst/generator.py +++ b/jumanji/environments/routing/mmst/generator.py @@ -136,9 +136,7 @@ def __call__(self, key: chex.PRNGKey) -> State: conn_nodes_index = conn_nodes_index.at[agent, agent_components[0]].set( agent_components[0] ) - state_nodes_to_connect = state_nodes_to_connect.at[agent].set( - agent_components - ) + state_nodes_to_connect = state_nodes_to_connect.at[agent].set(agent_components) active_node_edges = jnp.repeat(node_edges[None, ...], self.num_agents, axis=0) active_node_edges = update_active_edges( @@ -169,10 +167,7 @@ def __call__(self, key: chex.PRNGKey) -> State: return state - def _generate_graph( - self, key: chex.PRNGKey - ) -> Tuple[chex.Array, chex.Array, chex.Array]: - + def _generate_graph(self, key: chex.PRNGKey) -> Tuple[chex.Array, chex.Array, chex.Array]: nodes = jnp.arange(self._num_nodes, dtype=jnp.int32) graph, nodes_per_sub_graph = multi_random_walk( nodes, self._num_edges, self._num_agents, self._max_degree, key @@ -193,9 +188,7 @@ def _initialise_states( ) node_types = EMPTY_NODE * jnp.ones(self._num_nodes, dtype=jnp.int32) - conn_nodes = EMPTY_NODE * jnp.ones( - (self._num_agents, self._max_step), dtype=jnp.int32 - ) + conn_nodes = EMPTY_NODE * jnp.ones((self._num_agents, self._max_step), dtype=jnp.int32) conn_nodes_index = EMPTY_NODE * jnp.ones( (self._num_agents, self._num_nodes), dtype=jnp.int32 ) diff --git a/jumanji/environments/routing/mmst/reward.py b/jumanji/environments/routing/mmst/reward.py index e6318ee5a..c6180f96d 100644 --- a/jumanji/environments/routing/mmst/reward.py +++ b/jumanji/environments/routing/mmst/reward.py @@ -79,7 +79,6 @@ def reward_fun(nodes: chex.Array, action: int, node: int) -> jnp.float_: def __call__( self, state: State, actions: chex.Array, nodes_to_connect: chex.Array ) -> chex.Array: - num_agents = len(actions) rewards = jnp.zeros((num_agents,), dtype=jnp.float32) diff --git a/jumanji/environments/routing/mmst/reward_test.py b/jumanji/environments/routing/mmst/reward_test.py index c6986df13..d39ca833c 100644 --- a/jumanji/environments/routing/mmst/reward_test.py +++ b/jumanji/environments/routing/mmst/reward_test.py @@ -27,9 +27,7 @@ from jumanji.types import TimeStep -def test__mmst_dense_rewards( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] -) -> None: +def test__mmst_dense_rewards(deterministic_mmst_env: Tuple[MMST, State, TimeStep]) -> None: """Test that the default dense reward function works correctly.""" # Default reward values are (10.0, -1.0, -1.0) @@ -39,9 +37,7 @@ def test__mmst_dense_rewards( action = jnp.array([4, 3]) - new_action, next_nodes = env._trim_duplicated_invalid_actions( - state, action, state.key - ) + new_action, next_nodes = env._trim_duplicated_invalid_actions(state, action, state.key) assert new_action[1] == INVALID_CHOICE assert next_nodes[1] == INVALID_NODE @@ -52,9 +48,7 @@ def test__mmst_dense_rewards( assert jnp.array_equal(timestep.reward, expected) action = jnp.array([1, 7]) - new_action, next_nodes = env._trim_duplicated_invalid_actions( - state, action, state.key - ) + new_action, next_nodes = env._trim_duplicated_invalid_actions(state, action, state.key) state, timestep = step_fn(state, action) assert new_action[0] == INVALID_ALREADY_TRAVERSED diff --git a/jumanji/environments/routing/mmst/utils.py b/jumanji/environments/routing/mmst/utils.py index 18f653822..d88e0f77a 100644 --- a/jumanji/environments/routing/mmst/utils.py +++ b/jumanji/environments/routing/mmst/utils.py @@ -102,8 +102,7 @@ def make_action_mask( full_action_mask = node_edges != EMPTY_NODE action_mask = ( - full_action_mask[jnp.arange(num_agents), position] - & ~finished_agents[:, jnp.newaxis] + full_action_mask[jnp.arange(num_agents), position] & ~finished_agents[:, jnp.newaxis] ) return action_mask @@ -194,9 +193,7 @@ def init_graph(nodes: chex.Array, max_degree: int, num_edges: int) -> Graph: return graph -def init_graph_merge( - graph_a: Graph, graph_b: Graph, num_edges: int, max_degree: int -) -> Graph: +def init_graph_merge(graph_a: Graph, graph_b: Graph, num_edges: int, max_degree: int) -> Graph: """Merge two graphs and initialize the setting to add new edges. Args: @@ -230,9 +227,7 @@ def init_graph_merge( node_edges = jnp.ones((total_nodes, total_nodes), dtype=jnp.int32) * EMPTY_NODE node_edges = node_edges.at[0:nodes_a, 0:nodes_a].set(graph_a.node_edges) - node_edges = node_edges.at[nodes_a:total_nodes, nodes_a:total_nodes].set( - graph_b.node_edges - ) + node_edges = node_edges.at[nodes_a:total_nodes, nodes_a:total_nodes].set(graph_b.node_edges) graph = Graph( nodes=nodes, @@ -262,9 +257,7 @@ def correct_graph_offset(graph: Graph, offset: int) -> Graph: nodes = graph.nodes + offset edges = graph.edges + offset - edge_codes = jax.vmap(correct_edge_code_offset, in_axes=(0, None))( - graph.edge_codes, offset - ) + edge_codes = jax.vmap(correct_edge_code_offset, in_axes=(0, None))(graph.edge_codes, offset) node_edges = graph.node_edges zero_mask = node_edges != EMPTY_NODE @@ -358,9 +351,7 @@ def make_random_edge_from_nodes( return edge -def add_random_edges( - graph: Graph, total_edges: jnp.int32, base_key: chex.PRNGKey -) -> Graph: +def add_random_edges(graph: Graph, total_edges: jnp.int32, base_key: chex.PRNGKey) -> Graph: """Add random edges until the number of desired edges is reached.""" def desired_num_edges_not_reach(arg: Any) -> Any: @@ -406,9 +397,7 @@ def update_conected_nodes( return (source, target, graph) -def dummy_add_nodes( - graph: Graph, edge: chex.Array, source: chex.Array, target: chex.Array -) -> Any: +def dummy_add_nodes(graph: Graph, edge: chex.Array, source: chex.Array, target: chex.Array) -> Any: return (source, target, graph) @@ -516,9 +505,7 @@ def merge_graphs( base_key1, base_key2 = jax.random.split(base_key, 2) # Add one edge between both subgraphs to guarentee the new graph is connected. - graph, _ = add_edge( - graph, make_random_edge_from_nodes(graph_a.nodes, graph_b.nodes, base_key1) - ) + graph, _ = add_edge(graph, make_random_edge_from_nodes(graph_a.nodes, graph_b.nodes, base_key1)) # Add remaining edges until the desired number of edges is reached. graph = add_random_edges(graph, num_edges, base_key2) @@ -597,9 +584,9 @@ def multi_random_walk( # Get the total number of edges we need to add when merging the graphs. sum_ratio = np.arange(1, num_agents).sum() - frac = np.cumsum( - total_edges_merge_graph * np.arange(1, num_agents - 1) / sum_ratio - ).astype(np.int32) + frac = np.cumsum(total_edges_merge_graph * np.arange(1, num_agents - 1) / sum_ratio).astype( + np.int32 + ) edges_per_merge_graph = jnp.split(jnp.arange(total_edges_merge_graph), frac) num_edges_per_merge_graph = [len(edges) for edges in edges_per_merge_graph] @@ -612,8 +599,6 @@ def multi_random_walk( for i in range(1, num_agents): total_edges += num_edges_per_sub_graph[i] + num_edges_per_merge_graph[i - 1] graph_i = correct_graph_offset(graphs[i - 1], nodes_offsets[i]) - graph = merge_graphs( - graph, graph_i, total_edges, max_degree, merge_graph_keys[i] - ) + graph = merge_graphs(graph, graph_i, total_edges, max_degree, merge_graph_keys[i]) return graph, nodes_per_sub_graph diff --git a/jumanji/environments/routing/mmst/viewer.py b/jumanji/environments/routing/mmst/viewer.py index 99146edbf..515822fb4 100644 --- a/jumanji/environments/routing/mmst/viewer.py +++ b/jumanji/environments/routing/mmst/viewer.py @@ -152,7 +152,6 @@ def _draw_graph(self, state: State, ax: plt.Axes) -> None: def build_edges( self, adj_matrix: chex.Array, connected_nodes: chex.Array ) -> Dict[Tuple[int, ...], List[Tuple[float, ...]]]: - # Normalize id for either order. def edge_id(n1: int, n2: int) -> Tuple[int, ...]: return tuple(sorted((n1, n2))) @@ -165,7 +164,7 @@ def edge_id(n1: int, n2: int) -> Tuple[int, ...]: row_indices, col_indices = jnp.nonzero(adj_matrix) # Create the edge list as a list of tuples (source, target) edges_list = [ - (int(row), int(col)) for row, col in zip(row_indices, col_indices) + (int(row), int(col)) for row, col in zip(row_indices, col_indices, strict=False) ] for edge in edges_list: @@ -176,9 +175,7 @@ def edge_id(n1: int, n2: int) -> Tuple[int, ...]: for agent in range(self.num_agents): conn_group = connected_nodes[agent] - len_conn = np.where(conn_group != -1)[0][ - -1 - ] # Get last index where node is not -1. + len_conn = np.where(conn_group != -1)[0][-1] # Get last index where node is not -1. for i in range(len_conn): eid = edge_id(conn_group[i], conn_group[i + 1]) edges[eid] = [(conn_group[i], conn_group[i + 1]), self.palette[agent]] @@ -217,9 +214,7 @@ def animate( num_nodes = states[0].adj_matrix.shape[0] node_scale = 5 + int(np.sqrt(num_nodes)) - fig, ax = plt.subplots( - num=f"{self._name}Animation", figsize=(node_scale, node_scale) - ) + fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=(node_scale, node_scale)) plt.close(fig) def make_frame(grid_index: int) -> None: @@ -308,9 +303,7 @@ def _compute_attractive_forces( return attractive_forces - def _spring_layout( - self, graph: chex.Array, seed: int = 42 - ) -> List[Tuple[float, float]]: + def _spring_layout(self, graph: chex.Array, seed: int = 42) -> List[Tuple[float, float]]: """Compute a 2D spring layout for the given graph using the Fruchterman-Reingold force-directed algorithm. @@ -334,9 +327,7 @@ def _spring_layout( temperature = 2.0 # Added a temperature variable for _ in range(iterations): - repulsive_forces = self._compute_repulsive_forces( - np.zeros((num_nodes, 2)), pos, k - ) + repulsive_forces = self._compute_repulsive_forces(np.zeros((num_nodes, 2)), pos, k) attractive_forces = self._compute_attractive_forces( graph, np.zeros((num_nodes, 2)), pos, k ) diff --git a/jumanji/environments/routing/multi_cvrp/conftest.py b/jumanji/environments/routing/multi_cvrp/conftest.py index 2c5a7ac93..1d198db02 100644 --- a/jumanji/environments/routing/multi_cvrp/conftest.py +++ b/jumanji/environments/routing/multi_cvrp/conftest.py @@ -27,8 +27,6 @@ def multicvrp_env() -> MultiCVRP: generator = UniformRandomGenerator(num_vehicles=2, num_customers=20) # Create the reward function - reward_fn = SparseReward( - generator._num_vehicles, generator._num_customers, generator._map_max - ) + reward_fn = SparseReward(generator._num_vehicles, generator._num_customers, generator._map_max) return MultiCVRP(generator=generator, reward_fn=reward_fn) diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 2cd53c46c..ec57b79d5 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -136,8 +136,7 @@ def __init__( self._speed: int = 1 self._max_local_time = ( - max_single_vehicle_distance(self._map_max, self._num_customers) - / self._speed + max_single_vehicle_distance(self._map_max, self._num_customers) / self._speed ) super().__init__() @@ -163,9 +162,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """ Run one timestep of the environment's dynamics. @@ -180,8 +177,7 @@ def step( new_state = self._update_state(state, action) is_done = ( - (new_state.nodes.demands.sum() == 0) - & (new_state.vehicles.positions == DEPOT_IDX).all() + (new_state.nodes.demands.sum() == 0) & (new_state.vehicles.positions == DEPOT_IDX).all() ) | jnp.any(new_state.step_count > self._num_customers * 2) reward = self._reward_fn(state, new_state, is_done) @@ -398,9 +394,7 @@ def _update_state(self, state: State, action: chex.Array) -> State: # Zero node selections where more than one vehicle selected a valid conditional # action to visit the same node. - values, unique_indices = jnp.unique( - next_nodes, return_index=True, size=self._num_vehicles - ) + values, unique_indices = jnp.unique(next_nodes, return_index=True, size=self._num_vehicles) next_nodes = jnp.zeros(len(next_nodes), dtype=next_nodes.dtype) next_nodes = next_nodes.at[unique_indices].set(values) @@ -409,9 +403,7 @@ def _update_state(self, state: State, action: chex.Array) -> State: end_coords = state.nodes.coordinates[next_nodes] step_travel_distances = compute_distance(start_coords, end_coords) vehicle_distances = state.vehicles.distances + step_travel_distances - vehicle_local_times = ( - state.vehicles.local_times + step_travel_distances / self._speed - ) + vehicle_local_times = state.vehicles.local_times + step_travel_distances / self._speed # Update the vehicle time penalties. vehicle_time_penalties = state.vehicles.time_penalties + compute_time_penalties( @@ -489,9 +481,7 @@ def _state_to_observation(self, state: State) -> Observation: action_mask=state.action_mask, ) - def _state_to_timestep( - self, state: State, reward: chex.Numeric, is_done: bool - ) -> TimeStep: + def _state_to_timestep(self, state: State, reward: chex.Numeric, is_done: bool) -> TimeStep: """ Checks if the state is terminal and converts it into a timestep. diff --git a/jumanji/environments/routing/multi_cvrp/env_test.py b/jumanji/environments/routing/multi_cvrp/env_test.py index f3292cf40..15cad20d4 100644 --- a/jumanji/environments/routing/multi_cvrp/env_test.py +++ b/jumanji/environments/routing/multi_cvrp/env_test.py @@ -117,9 +117,7 @@ def test_multicvrp__step(self, multicvrp_env: MultiCVRP) -> None: new_actions.append(node_i) node_i += 1 if node_i >= multicvrp_env._num_customers: - raise ValueError( - "There is not enough customer demand for a second action.." - ) + raise ValueError("There is not enough customer demand for a second action..") new_actions = jax.numpy.array(new_actions, dtype=np.int16) # # Take the same actions again which should now be invalid. @@ -142,9 +140,7 @@ def test_multicvrp__update_state(self, multicvrp_env: MultiCVRP) -> None: """Validates the jitted step of the environment.""" chex.clear_trace_counter() - _update_state_fn = jax.jit( - chex.assert_max_traces(multicvrp_env._update_state, n=1) - ) + _update_state_fn = jax.jit(chex.assert_max_traces(multicvrp_env._update_state, n=1)) key = jax.random.PRNGKey(0) state, _ = multicvrp_env.reset(key) @@ -176,9 +172,7 @@ def test_multicvrp__update_state(self, multicvrp_env: MultiCVRP) -> None: ) # Check that the node coordinates remained the same - assert jax.numpy.array_equal( - state.nodes.coordinates, new_state.nodes.coordinates - ) + assert jax.numpy.array_equal(state.nodes.coordinates, new_state.nodes.coordinates) new_actions = jax.numpy.array([0, 0], dtype=np.int16) @@ -195,15 +189,11 @@ def test_multicvrp__update_state(self, multicvrp_env: MultiCVRP) -> None: state.vehicles.positions, jax.numpy.array([0, 0], dtype=jax.numpy.int16) ) - def test_multicvrp__state_to_observation_timestep( - self, multicvrp_env: MultiCVRP - ) -> None: + def test_multicvrp__state_to_observation_timestep(self, multicvrp_env: MultiCVRP) -> None: """Validates the jitted step of the environment.""" chex.clear_trace_counter() - update_state_fn = jax.jit( - chex.assert_max_traces(multicvrp_env._update_state, n=1) - ) + update_state_fn = jax.jit(chex.assert_max_traces(multicvrp_env._update_state, n=1)) state_to_observation_fn = jax.jit( chex.assert_max_traces(multicvrp_env._state_to_observation, n=1) ) @@ -211,9 +201,7 @@ def test_multicvrp__state_to_observation_timestep( chex.assert_max_traces(multicvrp_env._state_to_timestep, n=1) ) - reward_fn = jax.jit( - chex.assert_max_traces(multicvrp_env._reward_fn.__call__, n=1) - ) + reward_fn = jax.jit(chex.assert_max_traces(multicvrp_env._reward_fn.__call__, n=1)) key = jax.random.PRNGKey(0) state, _ = multicvrp_env.reset(key) @@ -255,19 +243,13 @@ def test_multicvrp__state_to_observation_timestep( assert timestep.mid() # Check that the reward and discount values are correct - assert np.array_equal( - timestep.reward, jax.numpy.array(0.0, dtype=jax.numpy.float32) - ) - assert np.array_equal( - timestep.discount, jax.numpy.array(1.0, dtype=jax.numpy.float32) - ) + assert np.array_equal(timestep.reward, jax.numpy.array(0.0, dtype=jax.numpy.float32)) + assert np.array_equal(timestep.discount, jax.numpy.array(1.0, dtype=jax.numpy.float32)) def test_env_multicvrp__does_not_smoke(self, multicvrp_env: MultiCVRP) -> None: def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array: @jax.vmap # map over the agents - def select_action( - key: chex.PRNGKey, agent_action_mask: chex.Array - ) -> chex.Array: + def select_action(key: chex.PRNGKey, agent_action_mask: chex.Array) -> chex.Array: return jax.numpy.array( jax.random.choice( key, @@ -282,8 +264,6 @@ def select_action( check_env_does_not_smoke(multicvrp_env, select_actions) - def test_env_multicvrp__specs_does_not_smoke( - self, multicvrp_env: MultiCVRP - ) -> None: + def test_env_multicvrp__specs_does_not_smoke(self, multicvrp_env: MultiCVRP) -> None: """Test that we can access specs without any errors.""" check_env_specs_does_not_smoke(multicvrp_env) diff --git a/jumanji/environments/routing/multi_cvrp/generator.py b/jumanji/environments/routing/multi_cvrp/generator.py index e2c0d8780..c063db38b 100644 --- a/jumanji/environments/routing/multi_cvrp/generator.py +++ b/jumanji/environments/routing/multi_cvrp/generator.py @@ -162,9 +162,7 @@ def __call__(self, key: chex.PRNGKey) -> State: distances=jnp.zeros(self._num_vehicles, dtype=jnp.float32), time_penalties=jnp.zeros(self._num_vehicles, dtype=jnp.float32), ), - order=jnp.zeros( - (self._num_vehicles, 2 * self._num_customers), dtype=jnp.int16 - ), + order=jnp.zeros((self._num_vehicles, 2 * self._num_customers), dtype=jnp.int16), action_mask=create_action_mask(node_demands, capacities), step_count=jnp.ones((), dtype=jnp.int16), key=key, diff --git a/jumanji/environments/routing/multi_cvrp/generator_test.py b/jumanji/environments/routing/multi_cvrp/generator_test.py index b87ce7b31..5aea798da 100644 --- a/jumanji/environments/routing/multi_cvrp/generator_test.py +++ b/jumanji/environments/routing/multi_cvrp/generator_test.py @@ -63,9 +63,7 @@ def test_uniform_random_generator__no_retrace( ) -> None: """Checks that generator only traces the function once and works when jitted.""" keys = jax.random.split(key, 2) - jitted_generator = jax.jit( - chex.assert_max_traces((uniform_random_generator.__call__), n=1) - ) + jitted_generator = jax.jit(chex.assert_max_traces((uniform_random_generator.__call__), n=1)) for key in keys: jitted_generator(key) diff --git a/jumanji/environments/routing/multi_cvrp/reward.py b/jumanji/environments/routing/multi_cvrp/reward.py index 43395afbf..82388a0d5 100644 --- a/jumanji/environments/routing/multi_cvrp/reward.py +++ b/jumanji/environments/routing/multi_cvrp/reward.py @@ -55,7 +55,6 @@ def __call__( is_done: bool, ) -> chex.Numeric: def compute_episode_reward(new_state: State) -> float: - return jax.lax.cond( # type: ignore jnp.any(new_state.step_count > self._num_customers * 2), # Penalise for running into step limit. This is not including max time @@ -91,13 +90,11 @@ def __call__( is_done: bool, ) -> chex.Numeric: def compute_reward(state: State, new_state: State) -> float: - step_vehicle_distance_penalty = ( state.vehicles.distances.sum() - new_state.vehicles.distances.sum() ) step_time_penalty = ( - state.vehicles.time_penalties.sum() - - new_state.vehicles.time_penalties.sum() + state.vehicles.time_penalties.sum() - new_state.vehicles.time_penalties.sum() ) return jax.lax.cond( # type: ignore diff --git a/jumanji/environments/routing/multi_cvrp/utils.py b/jumanji/environments/routing/multi_cvrp/utils.py index 53e219029..8cc44ecda 100644 --- a/jumanji/environments/routing/multi_cvrp/utils.py +++ b/jumanji/environments/routing/multi_cvrp/utils.py @@ -22,9 +22,7 @@ from jumanji.environments.routing.multi_cvrp.types import State -def create_action_mask( - node_demands: chex.Array, vehicle_capacities: chex.Array -) -> chex.Array: +def create_action_mask(node_demands: chex.Array, vehicle_capacities: chex.Array) -> chex.Array: # The action is valid if the node has a # non-zero demand and the vehicle has enough capacity. def single_vehicle_action_mask(capacity: chex.Array) -> chex.Array: @@ -69,9 +67,7 @@ def compute_time_penalties( return time_penalties -def max_single_vehicle_distance( - map_max: chex.Array, num_customers: chex.Array -) -> chex.Array: +def max_single_vehicle_distance(map_max: chex.Array, num_customers: chex.Array) -> chex.Array: return 2 * map_max * jnp.sqrt(2) * num_customers @@ -79,10 +75,7 @@ def worst_case_remaining_reward(state: State) -> chex.Array: has_demand = state.nodes.demands > 0 distance_penalty = ( 2 - * ( - compute_distance(state.nodes.coordinates[0], state.nodes.coordinates) - * has_demand - ).sum() + * (compute_distance(state.nodes.coordinates[0], state.nodes.coordinates) * has_demand).sum() ) # Assuming the speed is 1.0. @@ -122,7 +115,6 @@ def generate_uniform_random_problem( early_coef_rand: Tuple[jnp.float32, jnp.float32], late_coef_rand: Tuple[jnp.float32, jnp.float32], ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: - # Generate the node coordinates coord_key, demand_key, window_key, earl_key, late_key = jax.random.split(key, 5) diff --git a/jumanji/environments/routing/multi_cvrp/utils_test.py b/jumanji/environments/routing/multi_cvrp/utils_test.py index f20de2d3c..d68e1a7e6 100644 --- a/jumanji/environments/routing/multi_cvrp/utils_test.py +++ b/jumanji/environments/routing/multi_cvrp/utils_test.py @@ -197,6 +197,4 @@ def test_generate_uniform_random_problem(self) -> None: assert np.array_equal( early_coefs, np.array([0.0] + [0.1] * num_customers, dtype=np.float32) ) - assert np.array_equal( - late_coefs, np.array([0.0] + [0.5] * num_customers, dtype=np.float32) - ) + assert np.array_equal(late_coefs, np.array([0.0] + [0.5] * num_customers, dtype=np.float32)) diff --git a/jumanji/environments/routing/multi_cvrp/viewer.py b/jumanji/environments/routing/multi_cvrp/viewer.py index 705d545a0..ea32ec8b4 100644 --- a/jumanji/environments/routing/multi_cvrp/viewer.py +++ b/jumanji/environments/routing/multi_cvrp/viewer.py @@ -179,9 +179,7 @@ def _group_tour(self, tour: chex.Array) -> list: depot = tour[0] check_depot_fn = lambda x: (x != depot).all() tour_grouped = [ - np.array([depot] + list(g) + [depot]) - for k, g in groupby(tour, key=check_depot_fn) - if k + np.array([depot, *list(g), depot]) for k, g in groupby(tour, key=check_depot_fn) if k ] if (tour[-1] != tour[0]).all(): tour_grouped[-1] = tour_grouped[-1][:-1] @@ -226,16 +224,13 @@ def _add_tour(self, ax: plt.Axes, state: State) -> None: if state.step_count > 0: # TODO (dries): Can we do this without a for loop? for i in range(len(state.order)): - coords = ( - state.nodes.coordinates[state.order[i, : state.step_count]] - / self._map_max - ) + coords = state.nodes.coordinates[state.order[i, : state.step_count]] / self._map_max coords_grouped = self._group_tour(coords) # Draw each route in different colour for coords_route, _ in zip( - coords_grouped, np.arange(0, len(coords_grouped)) + coords_grouped, np.arange(0, len(coords_grouped)), strict=False ): self._draw_route(ax, coords_route, i) diff --git a/jumanji/environments/routing/multi_cvrp/viewer_test.py b/jumanji/environments/routing/multi_cvrp/viewer_test.py index 98b030060..606db00b0 100644 --- a/jumanji/environments/routing/multi_cvrp/viewer_test.py +++ b/jumanji/environments/routing/multi_cvrp/viewer_test.py @@ -40,9 +40,7 @@ def test_render(multicvrp_env: MultiCVRP) -> None: ) # Starting position is depot, new action to visit first node - new_actions = jnp.array( - jnp.arange(1, multicvrp_env._num_vehicles + 1), dtype=np.int16 - ) + new_actions = jnp.array(jnp.arange(1, multicvrp_env._num_vehicles + 1), dtype=np.int16) new_state, next_timestep = step_fn(state, new_actions) @@ -58,9 +56,7 @@ def test_animation(multicvrp_env: MultiCVRP) -> None: def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array: @jax.vmap # map over the agents - def select_action( - key: chex.PRNGKey, agent_action_mask: chex.Array - ) -> chex.Array: + def select_action(key: chex.PRNGKey, agent_action_mask: chex.Array) -> chex.Array: return jnp.array( jax.random.choice( key, diff --git a/jumanji/environments/routing/pac_man/constants.py b/jumanji/environments/routing/pac_man/constants.py index 9042a5e79..6bccd93f1 100644 --- a/jumanji/environments/routing/pac_man/constants.py +++ b/jumanji/environments/routing/pac_man/constants.py @@ -14,9 +14,7 @@ import jax.numpy as jnp -MOVES = jnp.array( - [[0, -1], [-1, 0], [0, 1], [1, 0], [0, 0]] -) # Up, Right, Down, Left, No-op +MOVES = jnp.array([[0, -1], [-1, 0], [0, 1], [1, 0], [0, 0]]) # Up, Right, Down, Left, No-op # Default Maze design diff --git a/jumanji/environments/routing/pac_man/env.py b/jumanji/environments/routing/pac_man/env.py index 3007042b2..ad696e63c 100644 --- a/jumanji/environments/routing/pac_man/env.py +++ b/jumanji/environments/routing/pac_man/env.py @@ -245,9 +245,7 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. If an action is invalid, the agent does not move, i.e. the episode does not @@ -320,9 +318,7 @@ def _update_state(self, state: State, action: chex.Array) -> Tuple[State, int]: ghost_paths, ghost_actions, key = ghost_move(state, self.x_size, self.y_size) # Check for collisions with ghosts - state, done, ghost_col_rewards = check_ghost_collisions( - ghost_paths, next_player_pos, state - ) + state, done, ghost_col_rewards = check_ghost_collisions(ghost_paths, next_player_pos, state) state = state.replace(player_locations=next_player_pos) # type: ignore state = state.replace(dead=done) @@ -349,9 +345,7 @@ def tick_frightened_time() -> chex.Array: return jnp.array(state.frightened_state_time - 1, jnp.int32) # Check if frightened state is active and decrement timer - state.frightened_state_time = jax.lax.cond( - eat > 0, powerup_collected, tick_frightened_time - ) + state.frightened_state_time = jax.lax.cond(eat > 0, powerup_collected, tick_frightened_time) # Update power-up locations state.power_up_locations = power_up_locations @@ -424,9 +418,7 @@ def player_step(self, state: State, action: int, steps: int = 1) -> Position: new_pos = Position(x=new_pos_col % self.x_size, y=new_pos_row % self.y_size) return new_pos - def check_power_up( - self, state: State - ) -> Tuple[chex.Array, chex.Numeric, chex.Numeric]: + def check_power_up(self, state: State) -> Tuple[chex.Array, chex.Numeric, chex.Numeric]: """ Check if the player is on a power-up location and update the power-up locations array accordingly. @@ -498,9 +490,9 @@ def is_move_valid(agent_position: Position, move: chex.Array) -> chex.Array: return grid[x][y] # vmap over the moves. - action_mask = jax.vmap(is_move_valid, in_axes=(None, 0))( - player_pos, MOVES - ) * jnp.array([True, True, True, True, False]) + action_mask = jax.vmap(is_move_valid, in_axes=(None, 0))(player_pos, MOVES) * jnp.array( + [True, True, True, True, False] + ) return action_mask diff --git a/jumanji/environments/routing/pac_man/generator.py b/jumanji/environments/routing/pac_man/generator.py index 4bf305753..6fa369b12 100644 --- a/jumanji/environments/routing/pac_man/generator.py +++ b/jumanji/environments/routing/pac_man/generator.py @@ -124,7 +124,6 @@ def __init__(self, maze: List) -> None: self.y_size = self.numpy_maze.shape[1] def __call__(self, key: chex.PRNGKey) -> State: - grid = self.numpy_maze pellets = self.pellet_spaces.shape[0] frightened_state_time = jnp.array(0, jnp.int32) diff --git a/jumanji/environments/routing/pac_man/utils.py b/jumanji/environments/routing/pac_man/utils.py index ef49f7741..ac21852ac 100644 --- a/jumanji/environments/routing/pac_man/utils.py +++ b/jumanji/environments/routing/pac_man/utils.py @@ -83,16 +83,12 @@ def move( jnp.array_equal(valids, vert_col), jnp.array_equal(valids, hor_col) ) - def is_tunnel( - inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int] - ) -> int: + def is_tunnel(inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int]) -> int: """Repeat old action if in tunnel""" _, _, ghost_action, _, _ = inputs return ghost_action - def no_tunnel( - inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int] - ) -> Any: + def no_tunnel(inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int]) -> Any: """Chose new action when at intersection""" logits, actions, _, ghost_tunnel_key, _ = inputs _, ghost_tunnel_key = jax.random.split(ghost_tunnel_key) @@ -113,24 +109,18 @@ def no_tunnel( ghost_num, ) - def start_over( - inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int] - ) -> Any: + def start_over(inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int]) -> Any: """If not in waiting mode then pick new action""" chosen_action = jax.lax.cond(is_in_tunnel, is_tunnel, no_tunnel, inputs) return jnp.squeeze(chosen_action) - def no_start( - inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int] - ) -> int: + def no_start(inputs: Tuple[chex.Array, chex.Array, int, chex.PRNGKey, int]) -> int: """If in waiting mode then use no-op""" return 4 position = ghost_pos - chosen_action = jax.lax.cond( - ghost_start < 0, start_over, no_start, inputs_no_tunnel - ) + chosen_action = jax.lax.cond(ghost_start < 0, start_over, no_start, inputs_no_tunnel) # Use chosen action move_left = lambda position: (position[1], position[0] - 1) @@ -179,7 +169,6 @@ def check_ghost_wall_collisions( x_size: int, y_size: int, ) -> Tuple[bool, chex.Array, chex.Array]: - """ Determine valid move for the ghost. @@ -227,9 +216,7 @@ def red_ghost(pacman_pos: Position) -> Tuple[chex.Array, chex.Array]: return distance_list, ghost_dist # For ghost 1: move 4 steps ahead of pacman - def pink_ghost( - pacman_pos: Position, steps: int = 4 - ) -> Tuple[chex.Array, chex.Array]: + def pink_ghost(pacman_pos: Position, steps: int = 4) -> Tuple[chex.Array, chex.Array]: """ Select targets for pink ghost as distance from the tile 4 steps ahead of the current position of pacman. @@ -268,9 +255,7 @@ def orange_ghost(pacman_pos: Position) -> Tuple[chex.Array, chex.Array]: distance_list = jax.vmap(get_directions, in_axes=(None, 0))(pacman_pos, ghost_p) - _, ghost_dist = jax.lax.cond( - distance_pacman > 8, red_ghost, scared_behaviors, pacman_pos - ) + _, ghost_dist = jax.lax.cond(distance_pacman > 8, red_ghost, scared_behaviors, pacman_pos) return distance_list, ghost_dist @@ -283,9 +268,7 @@ def general_behaviors(pacman_pos: Position) -> Tuple[chex.Array, chex.Array]: inky = lambda pacman_pos: pink_ghost(pacman_pos) pinky = lambda pacman_pos: blue_ghost(pacman_pos) clyde = lambda pacman_pos: orange_ghost(pacman_pos) - _, ghost_dist = jax.lax.switch( - ghost_num, [blinky, inky, pinky, clyde], pacman_pos - ) + _, ghost_dist = jax.lax.switch(ghost_num, [blinky, inky, pinky, clyde], pacman_pos) return _, ghost_dist def scared_behaviors(pacman_pos: Position) -> Tuple[chex.Array, chex.Array]: @@ -296,9 +279,7 @@ def scared_behaviors(pacman_pos: Position) -> Tuple[chex.Array, chex.Array]: def behaviors() -> Tuple[chex.Array, chex.Array]: """Select scatter or normal targets""" - _, ghost_dist = jax.lax.cond( - is_scared > 0, scared_behaviors, general_behaviors, pacman_pos - ) + _, ghost_dist = jax.lax.cond(is_scared > 0, scared_behaviors, general_behaviors, pacman_pos) return _, ghost_dist def init_behaviors() -> Tuple[chex.Array, chex.Array]: @@ -307,9 +288,7 @@ def init_behaviors() -> Tuple[chex.Array, chex.Array]: _, ghost_dist = red_ghost(pacman_pos=target) return _, ghost_dist - _, ghost_dist = jax.lax.cond( - ghost_init_steps[ghost_num] > 0, init_behaviors, behaviors - ) + _, ghost_dist = jax.lax.cond(ghost_init_steps[ghost_num] > 0, init_behaviors, behaviors) def get_valid_positions(pos: chex.Array) -> Any: """Get values of surrounding positions""" @@ -401,9 +380,7 @@ def no_col_fn() -> Tuple[chex.Array, chex.Numeric, chex.Numeric, chex.Numeric]: def col_fn() -> Tuple[chex.Array, chex.Numeric, chex.Numeric, chex.Numeric]: reset_true = lambda: (jnp.array(og_pos), False, 200.0, False) reset_false = lambda: (ghost_pos, True, 0.0, edible) - path, done, col_reward, ghost_eaten = jax.lax.cond( - ghost_reset, reset_true, reset_false - ) + path, done, col_reward, ghost_eaten = jax.lax.cond(ghost_reset, reset_true, reset_false) return path, done, col_reward, ghost_eaten # First check for collision @@ -441,9 +418,7 @@ def get_directions(pacman_position: Position, ghost_position: chex.Array) -> che return direction -def player_step( - state: State, action: int, x_size: int, y_size: int, steps: int = 1 -) -> Position: +def player_step(state: State, action: int, x_size: int, y_size: int, steps: int = 1) -> Position: """ Compute the new position of the player based on the given state and action. diff --git a/jumanji/environments/routing/pac_man/viewer.py b/jumanji/environments/routing/pac_man/viewer.py index bf9ed5174..6dac26660 100644 --- a/jumanji/environments/routing/pac_man/viewer.py +++ b/jumanji/environments/routing/pac_man/viewer.py @@ -123,9 +123,7 @@ def make_frame(state_index: int) -> None: return self._animation - def _add_grid_image( - self, state: Union[Observation, State], ax: Axes - ) -> image.AxesImage: + def _add_grid_image(self, state: Union[Observation, State], ax: Axes) -> image.AxesImage: img = create_grid_image(state) ax.set_axis_off() return ax.imshow(img) diff --git a/jumanji/environments/routing/pac_man/viewer_test.py b/jumanji/environments/routing/pac_man/viewer_test.py index aee71a3f9..66d20dda9 100644 --- a/jumanji/environments/routing/pac_man/viewer_test.py +++ b/jumanji/environments/routing/pac_man/viewer_test.py @@ -29,9 +29,7 @@ def pac_man() -> PacMan: return PacMan() -def test_pacman_viewer__render( - pac_man: PacMan, monkeypatch: pytest.MonkeyPatch -) -> None: +def test_pacman_viewer__render(pac_man: PacMan, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(plt, "show", lambda fig: None) key = jax.random.PRNGKey(0) state, _ = pac_man.reset(key) @@ -57,9 +55,7 @@ def test_robot_warehouse_viewer__animate(pac_man: PacMan) -> None: viewer.close() -def test_robot_warehouse_viewer__save_animation( - pac_man: PacMan, tmpdir: py.path.local -) -> None: +def test_robot_warehouse_viewer__save_animation(pac_man: PacMan, tmpdir: py.path.local) -> None: key = jax.random.PRNGKey(0) state, _ = jax.jit(pac_man.reset)(key) diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index 8ab107bc4..7e1346832 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -178,17 +178,13 @@ def __init__( self.agent_ids = jnp.arange(self.num_agents) self.directions = jnp.array([d.value for d in Direction]) - self.num_obs_features = utils.calculate_num_observation_features( - self.sensor_range - ) + self.num_obs_features = utils.calculate_num_observation_features(self.sensor_range) self.goals = self._generator.goals self.time_limit = time_limit super().__init__() # create viewer for rendering environment - self._viewer = viewer or RobotWarehouseViewer( - self.grid_size, self.goals, "RobotWarehouse" - ) + self._viewer = viewer or RobotWarehouseViewer(self.grid_size, self.goals, "RobotWarehouse") def __repr__(self) -> str: return ( @@ -258,9 +254,7 @@ def update_state_scan( carry_info: Tuple[chex.Array, chex.Array, chex.Array, int], action: int ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array, int], None]: grid, agents, shelves, agent_id = carry_info - grid, agents, shelves = self._update_state( - grid, agents, shelves, action, agent_id - ) + grid, agents, shelves = self._update_state(grid, agents, shelves, action, agent_id) return (grid, agents, shelves, agent_id + 1), None (grid, agents, shelves, _), _ = jax.lax.scan( @@ -268,9 +262,7 @@ def update_state_scan( ) # check for agent collisions - collisions = jax.vmap(functools.partial(utils.is_collision, grid))( - agents, self.agent_ids - ) + collisions = jax.vmap(functools.partial(utils.is_collision, grid))(agents, self.agent_ids) collision = jnp.any(collisions) # compute shared reward for all agents and update request queue @@ -278,13 +270,9 @@ def update_state_scan( reward = jnp.array(0, dtype=jnp.float32) def update_reward_and_request_queue_scan( - carry_info: Tuple[ - chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array - ], + carry_info: Tuple[chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array], goal: chex.Array, - ) -> Tuple[ - Tuple[chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array], None - ]: + ) -> Tuple[Tuple[chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array], None]: key, reward, request_queue, grid, shelves = carry_info ( key, @@ -348,9 +336,7 @@ def observation_spec(self) -> specs.Spec[Observation]: agents_view = specs.Array( (self.num_agents, self.num_obs_features), jnp.int32, "agents_view" ) - action_mask = specs.BoundedArray( - (self.num_agents, 5), bool, False, True, "action_mask" - ) + action_mask = specs.BoundedArray((self.num_agents, 5), bool, False, True, "action_mask") step_count = specs.BoundedArray((), jnp.int32, 0, self.time_limit, "step_count") return specs.Spec( Observation, @@ -497,9 +483,7 @@ def reward_and_update_request_queue_if_shelf_in_goal( return key, reward, request_queue, shelves # check if shelf is at goal position and in request queue - shelf_at_goal = (~jnp.equal(shelf_id, 0)) & jnp.isin( - shelf_id, request_queue + 1 - ) + shelf_at_goal = (~jnp.equal(shelf_id, 0)) & jnp.isin(shelf_id, request_queue + 1) key, reward, request_queue, shelves = jax.lax.cond( shelf_at_goal, @@ -540,9 +524,7 @@ def animate( Returns: Animation object that can be saved as a GIF, MP4, or rendered with HTML. """ - return self._viewer.animate( - states=states, interval=interval, save_path=save_path - ) + return self._viewer.animate(states=states, interval=interval, save_path=save_path) def close(self) -> None: """Perform any necessary cleanup. diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py index cf37e3b2e..d9cbcc6e5 100644 --- a/jumanji/environments/routing/robot_warehouse/env_test.py +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -73,19 +73,13 @@ def test_robot_warehouse__agent_observation( agent1_own_view = jnp.array([3, 4, 0, 0, 0, 1, 0, 1]) agent1_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) agent1_shelf_view = jnp.array(9 * [0, 0]) - agent1_obs = jnp.hstack( - [agent1_own_view, agent1_other_agents_view, agent1_shelf_view] - ) + agent1_obs = jnp.hstack([agent1_own_view, agent1_other_agents_view, agent1_shelf_view]) # agent 2 obs agent2_own_view = jnp.array([1, 7, 0, 0, 0, 0, 1, 0]) agent2_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) - agent2_shelf_view = jnp.array( - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1] - ) - agent2_obs = jnp.hstack( - [agent2_own_view, agent2_other_agents_view, agent2_shelf_view] - ) + agent2_shelf_view = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1]) + agent2_obs = jnp.hstack([agent2_own_view, agent2_other_agents_view, agent2_shelf_view]) assert jnp.all(timestep.observation.agents_view[0] == agent1_obs) assert jnp.all(timestep.observation.agents_view[1] == agent2_obs) @@ -150,7 +144,7 @@ def test_robot_warehouse__step(robot_warehouse_env: RobotWarehouse) -> None: (x - 1, y - 1, 0), # move forward -> move up ] - for action, new_loc in zip(actions, new_locs): + for action, new_loc in zip(actions, new_locs, strict=False): state, timestep = step_fn(state, jnp.array([action, action])) agent1_info = tree_slice(state.agents, 1) agent1_loc = ( diff --git a/jumanji/environments/routing/robot_warehouse/generator.py b/jumanji/environments/routing/robot_warehouse/generator.py index 2f0b79ec0..b27ebfc1d 100644 --- a/jumanji/environments/routing/robot_warehouse/generator.py +++ b/jumanji/environments/routing/robot_warehouse/generator.py @@ -58,9 +58,7 @@ def __init__( given time which remains fixed throughout environment steps. Defaults to 4. """ if shelf_columns % 2 != 1: - raise ValueError( - "Environment argument: `shelf_columns`, must be an odd number." - ) + raise ValueError("Environment argument: `shelf_columns`, must be an odd number.") self._shelf_rows = shelf_rows self._shelf_columns = shelf_columns diff --git a/jumanji/environments/routing/robot_warehouse/utils.py b/jumanji/environments/routing/robot_warehouse/utils.py index 6b48500b9..ca3b0c795 100644 --- a/jumanji/environments/routing/robot_warehouse/utils.py +++ b/jumanji/environments/routing/robot_warehouse/utils.py @@ -238,23 +238,17 @@ def make_agent_observation( ) # function for writing receptive field cells - def write_no_agent( - obs: chex.Array, idx: int, _: int, is_self: bool - ) -> Tuple[chex.Array, int]: + def write_no_agent(obs: chex.Array, idx: int, _: int, is_self: bool) -> Tuple[chex.Array, int]: "Write information for empty agent cell." # if there is no agent we set a 0 and all zeros # for the direction as well, i.e. [0, 0, 0, 0, 0] idx = jax.lax.cond(is_self, lambda i: i, lambda i: move_writer_index(i, 5), idx) return obs, idx - def write_agent( - obs: chex.Array, idx: int, id_agent: int, _: bool - ) -> Tuple[chex.Array, int]: + def write_agent(obs: chex.Array, idx: int, id_agent: int, _: bool) -> Tuple[chex.Array, int]: "Write information for cell containing an agent." obs, idx = write_to_observation(obs, idx, jnp.array([1], dtype=jnp.int32)) - direction = jax.nn.one_hot( - tree_slice(agents, id_agent - 1).direction, 4, dtype=jnp.int32 - ) + direction = jax.nn.one_hot(tree_slice(agents, id_agent - 1).direction, 4, dtype=jnp.int32) obs, idx = write_to_observation(obs, idx, direction) return obs, idx @@ -311,9 +305,7 @@ def shelf_sensor_scan( ) return (obs, idx), None - (obs, idx, _), _ = jax.lax.scan( - agent_sensor_scan, (obs, idx, agent_id), agents_grid - ) + (obs, idx, _), _ = jax.lax.scan(agent_sensor_scan, (obs, idx, agent_id), agents_grid) (obs, _), _ = jax.lax.scan(shelf_sensor_scan, (obs, idx), shelves_grid) return obs diff --git a/jumanji/environments/routing/robot_warehouse/utils_agent.py b/jumanji/environments/routing/robot_warehouse/utils_agent.py index 1e568f7fb..720a29adf 100644 --- a/jumanji/environments/routing/robot_warehouse/utils_agent.py +++ b/jumanji/environments/routing/robot_warehouse/utils_agent.py @@ -50,9 +50,7 @@ def update_agent( return agents -def get_new_direction_after_turn( - action: chex.Array, agent_direction: chex.Array -) -> chex.Array: +def get_new_direction_after_turn(action: chex.Array, agent_direction: chex.Array) -> chex.Array: """Get the correct direction the agent should face given the turn action it took. E.g. if the agent is facing LEFT and turns RIGHT it should now be facing UP, etc. diff --git a/jumanji/environments/routing/robot_warehouse/utils_spawn.py b/jumanji/environments/routing/robot_warehouse/utils_spawn.py index 2da8b91c2..1e92eeb69 100644 --- a/jumanji/environments/routing/robot_warehouse/utils_spawn.py +++ b/jumanji/environments/routing/robot_warehouse/utils_spawn.py @@ -104,16 +104,12 @@ def spawn_random_entities( shape=(num_agents,), replace=False, ) - agent_coords = jnp.transpose( - jnp.asarray(jnp.unravel_index(agent_coords, grid_size)) - ) + agent_coords = jnp.transpose(jnp.asarray(jnp.unravel_index(agent_coords, grid_size))) # random agent directions key, direction_key = jax.random.split(key) - agent_dirs = jax.random.choice( - direction_key, _POSSIBLE_DIRECTIONS, shape=(num_agents,) - ) + agent_dirs = jax.random.choice(direction_key, _POSSIBLE_DIRECTIONS, shape=(num_agents,)) # sample request queue key, queue_key = jax.random.split(key) @@ -155,9 +151,7 @@ def place_entity_on_grid( return grid.at[channel, x, y].set(entity_id + 1) -def place_entities_on_grid( - grid: chex.Array, agents: Agent, shelves: Shelf -) -> chex.Array: +def place_entities_on_grid(grid: chex.Array, agents: Agent, shelves: Shelf) -> chex.Array: """Place agents and shelves on the grid. Args: diff --git a/jumanji/environments/routing/robot_warehouse/utils_test.py b/jumanji/environments/routing/robot_warehouse/utils_test.py index e41036543..47f9f7969 100644 --- a/jumanji/environments/routing/robot_warehouse/utils_test.py +++ b/jumanji/environments/routing/robot_warehouse/utils_test.py @@ -133,33 +133,25 @@ def test_robot_warehouse_utils__entity_update( # test updating agent direction new_direction = 3 - agents_with_new_agent_0_direction = update_agent( - agents, 0, "direction", new_direction - ) + agents_with_new_agent_0_direction = update_agent(agents, 0, "direction", new_direction) agent_0 = tree_slice(agents_with_new_agent_0_direction, 0) assert agent_0.direction == new_direction # test updating agent carrying new_is_carrying = 1 - agents_with_new_agent_0_carrying = update_agent( - agents, 0, "is_carrying", new_is_carrying - ) + agents_with_new_agent_0_carrying = update_agent(agents, 0, "is_carrying", new_is_carrying) agent_0 = tree_slice(agents_with_new_agent_0_carrying, 0) assert agent_0.is_carrying == new_is_carrying # test updating shelf position new_position = Position(x=1, y=3) - shelves_with_new_shelf_0_position = update_shelf( - shelves, 0, "position", new_position - ) + shelves_with_new_shelf_0_position = update_shelf(shelves, 0, "position", new_position) shelf_0 = tree_slice(shelves_with_new_shelf_0_position, 0) assert shelf_0.position == new_position # test updating shelf requested new_is_requested = 1 - shelves_with_new_shelf_0_requested = update_shelf( - shelves, 0, "is_requested", new_is_requested - ) + shelves_with_new_shelf_0_requested = update_shelf(shelves, 0, "is_requested", new_is_requested) shelf_0 = tree_slice(shelves_with_new_shelf_0_requested, 0) assert shelf_0.is_requested == new_is_requested @@ -183,7 +175,7 @@ def test_robot_warehouse_utils__get_new_direction( 3, # turn right -> face left ] - for action, expected_direction in zip(actions, expected_directions): + for action, expected_direction in zip(actions, expected_directions, strict=False): new_direction = get_new_direction_after_turn(action, direction) assert new_direction == expected_direction direction = new_direction @@ -209,7 +201,7 @@ def test_robot_warehouse_utils__get_new_position( Position(3, 3), # facing left move forward ] - for direction, expected_position in zip(directions, expected_positions): + for direction, expected_position in zip(directions, expected_positions, strict=False): new_position = get_new_position_after_forward(grid, position, direction) assert new_position == expected_position @@ -310,9 +302,7 @@ def test_robot_warehouse_utils__get_agent_view( # get agent view with sensor range of 1 sensor_range = 1 - agent_view_of_agents, agent_view_of_shelves = get_agent_view( - grid, agent, sensor_range - ) + agent_view_of_agents, agent_view_of_shelves = get_agent_view(grid, agent, sensor_range) # flattened agent view of other agents and shelves flat_agents = jnp.array([0, 0, 0, 0, 1, 2, 0, 0, 0]) @@ -323,9 +313,7 @@ def test_robot_warehouse_utils__get_agent_view( # get agent view with sensor range of 2 sensor_range = 2 - agent_view_of_agents, agent_view_of_shelves = get_agent_view( - grid, agent, sensor_range - ) + agent_view_of_agents, agent_view_of_shelves = get_agent_view(grid, agent, sensor_range) # flattened agent view of other agents and shelves flat_agents = jnp.array( diff --git a/jumanji/environments/routing/robot_warehouse/viewer.py b/jumanji/environments/routing/robot_warehouse/viewer.py index aaa683ff7..fc4aeeed7 100644 --- a/jumanji/environments/routing/robot_warehouse/viewer.py +++ b/jumanji/environments/routing/robot_warehouse/viewer.py @@ -225,9 +225,7 @@ def _draw_shelves(self, ax: plt.Axes, shelves: chex.Array) -> None: y, x = shelf.position.x, shelf.position.y y = self.rows - y - 1 # pyglet rendering is reversed shelf_color = ( - constants._SHELF_REQ_COLOR - if shelf.is_requested - else constants._SHELF_COLOR + constants._SHELF_REQ_COLOR if shelf.is_requested else constants._SHELF_COLOR ) shelf_padding = constants._SHELF_PADDING @@ -276,9 +274,7 @@ def _draw_agents(self, ax: plt.Axes, agents: chex.Array) -> None: y = y_radius + y_center verts += [[x, y]] facecolor = ( - constants._AGENT_LOADED_COLOR - if agent.is_carrying - else constants._AGENT_COLOR + constants._AGENT_LOADED_COLOR if agent.is_carrying else constants._AGENT_COLOR ) circle = plt.Polygon( verts, diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index 0a1d0451c..6394003c3 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -165,9 +165,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step( - self, state: State, action: chex.Numeric - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -253,9 +251,7 @@ def observation_spec(self) -> specs.Spec[Observation]: dtype=float, name="grid", ) - step_count = specs.DiscreteArray( - self.time_limit, dtype=jnp.int32, name="step_count" - ) + step_count = specs.DiscreteArray(self.time_limit, dtype=jnp.int32, name="step_count") action_mask = specs.BoundedArray( shape=(4,), dtype=bool, @@ -360,9 +356,7 @@ def is_valid(move: chex.Array) -> chex.Array: action_mask = jax.vmap(is_valid)(self.MOVES) return action_mask - def _update_head_position( - self, head_position: Position, action: chex.Numeric - ) -> Position: + def _update_head_position(self, head_position: Position, action: chex.Numeric) -> Position: """Give the new head position after taking an action. Args: diff --git a/jumanji/environments/routing/snake/viewer.py b/jumanji/environments/routing/snake/viewer.py index 60bf3ae9c..f05797158 100644 --- a/jumanji/environments/routing/snake/viewer.py +++ b/jumanji/environments/routing/snake/viewer.py @@ -72,9 +72,7 @@ def animate( """ if not states: raise ValueError(f"The states argument has to be non-empty, got {states}.") - fig, ax = plt.subplots( - num=f"{self._figure_name}Anim", figsize=self._figure_size - ) + fig, ax = plt.subplots(num=f"{self._figure_name}Anim", figsize=self._figure_size) self._draw_board(ax, states[0]) plt.close(fig) @@ -144,7 +142,7 @@ def _create_entities(self, state: State) -> List[matplotlib.patches.Patch]: patches = [] linewidth = ( - min(n * size for n, size in zip((num_rows, num_cols), self._figure_size)) + min(n * size for n, size in zip((num_rows, num_cols), self._figure_size, strict=False)) / 44.0 ) cmap = matplotlib.colors.LinearSegmentedColormap.from_list( diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py index 2433df322..a7f5f5fd6 100644 --- a/jumanji/environments/routing/sokoban/env.py +++ b/jumanji/environments/routing/sokoban/env.py @@ -193,9 +193,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """ Executes one timestep of the environment's dynamics. @@ -331,9 +329,7 @@ def _get_extras(self, state: State) -> Dict: } return extras - def grid_combine( - self, variable_grid: chex.Array, fixed_grid: chex.Array - ) -> chex.Array: + def grid_combine(self, variable_grid: chex.Array, fixed_grid: chex.Array) -> chex.Array: """ Combines the variable grid and fixed grid into one single grid representation of the current Sokoban state required for visual @@ -444,9 +440,9 @@ def detect_noop_action( new_location = agent_location + MOVES[action].squeeze() - valid_destination = self.check_space( - fixed_grid, new_location, WALL - ) | ~self.in_grid(new_location) + valid_destination = self.check_space(fixed_grid, new_location, WALL) | ~self.in_grid( + new_location + ) updated_action = jax.lax.select( valid_destination, @@ -538,10 +534,7 @@ def move_agent( next_variable_grid = jax.lax.select( self.check_space(variable_grid, next_location, BOX), - next_variable_grid.at[tuple(next_location)] - .set(AGENT) - .at[tuple(box_location)] - .set(BOX), + next_variable_grid.at[tuple(next_location)].set(AGENT).at[tuple(box_location)].set(BOX), next_variable_grid.at[tuple(next_location)].set(AGENT), ) diff --git a/jumanji/environments/routing/sokoban/generator.py b/jumanji/environments/routing/sokoban/generator.py index e3ace0a4a..eed6d3682 100644 --- a/jumanji/environments/routing/sokoban/generator.py +++ b/jumanji/environments/routing/sokoban/generator.py @@ -87,9 +87,7 @@ def __init__( self.proportion_of_files = proportion_of_files # Set the cache path to user's home directory's .cache sub-directory - self.cache_path = os.path.join( - os.path.expanduser("~"), ".cache", "sokoban_dataset" - ) + self.cache_path = os.path.join(os.path.expanduser("~"), ".cache", "sokoban_dataset") # Downloads data if not already downloaded self._download_data() @@ -100,9 +98,7 @@ def __init__( if self.difficulty in ["unfiltered", "medium"]: if self.difficulty == "medium" and split == "test": - raise Exception( - "not a valid Deepmind Boxoban difficulty split" "combination" - ) + raise Exception("not a valid Deepmind Boxoban difficulty split" "combination") self.train_data_dir = os.path.join( self.train_data_dir, split, @@ -125,9 +121,7 @@ def __call__(self, rng_key: chex.PRNGKey) -> State: """ key, idx_key = jax.random.split(rng_key) - idx = jax.random.randint( - idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0] - ) + idx = jax.random.randint(idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]) fixed_grid = self._fixed_grids.take(idx, axis=0) variable_grid = self._variable_grids.take(idx, axis=0) @@ -157,9 +151,7 @@ def _generate_dataset( """ all_files = [ - f - for f in listdir(self.train_data_dir) - if isfile(join(self.train_data_dir, f)) + f for f in listdir(self.train_data_dir) if isfile(join(self.train_data_dir, f)) ] # Only keep a few files if specified all_files = all_files[: int(self.proportion_of_files * len(all_files))] @@ -168,7 +160,6 @@ def _generate_dataset( variable_grids_list: List[chex.Array] = [] for file in all_files: - source_file = join(self.train_data_dir, file) current_map: List[str] = [] # parses a game file containing multiple games @@ -178,9 +169,7 @@ def _generate_dataset( fixed_grid, variable_grid = convert_level_to_array(current_map) fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8)) - variable_grids_list.append( - jnp.array(variable_grid, dtype=jnp.uint8) - ) + variable_grids_list.append(jnp.array(variable_grid, dtype=jnp.uint8)) current_map = [] if "#" == line[0]: @@ -214,9 +203,7 @@ def _download_data(self) -> None: if response.status_code != 200: raise Exception("Could not download levels") - path_to_zip_file = os.path.join( - self.cache_path, "boxoban_levels-master.zip" - ) + path_to_zip_file = os.path.join(self.cache_path, "boxoban_levels-master.zip") with open(path_to_zip_file, "wb") as handle: for data in tqdm(response.iter_content()): handle.write(data) @@ -284,9 +271,7 @@ def __call__(self, rng_key: chex.PRNGKey) -> State: """ key, idx_key = jax.random.split(rng_key) - idx = jax.random.randint( - idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0] - ) + idx = jax.random.randint(idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]) fixed_grid = self._fixed_grids.take(idx, axis=0) variable_grid = self._variable_grids.take(idx, axis=0) diff --git a/jumanji/environments/routing/sokoban/generator_test.py b/jumanji/environments/routing/sokoban/generator_test.py index 0c766b632..29e394bb8 100644 --- a/jumanji/environments/routing/sokoban/generator_test.py +++ b/jumanji/environments/routing/sokoban/generator_test.py @@ -41,7 +41,6 @@ def test_sokoban__hugging_generator_creation() -> None: ] for dataset in datasets: - chex.clear_trace_counter() env = Sokoban( @@ -156,7 +155,6 @@ def test_sokoban__deepmind_generator_creation() -> None: ] for dataset in valid_datasets: - chex.clear_trace_counter() env = Sokoban( @@ -183,7 +181,6 @@ def test_sokoban__deepmind_invalid_creation() -> None: ] for dataset in valid_datasets: - chex.clear_trace_counter() with pytest.raises(Exception): diff --git a/jumanji/environments/routing/sokoban/reward_test.py b/jumanji/environments/routing/sokoban/reward_test.py index c8ae5b08b..59c0803d8 100644 --- a/jumanji/environments/routing/sokoban/reward_test.py +++ b/jumanji/environments/routing/sokoban/reward_test.py @@ -41,7 +41,6 @@ def check_correct_reward( num_boxes_on_targets_new: chex.Array, num_boxes_on_targets: chex.Array, ) -> None: - if num_boxes_on_targets_new == jnp.array(4, jnp.int32): assert timestep.reward == jnp.array(10.9, jnp.float32) elif num_boxes_on_targets_new - num_boxes_on_targets > jnp.array(0, jnp.int32): @@ -67,8 +66,6 @@ def check_correct_reward( num_boxes_on_targets_new = sokoban_simple.reward_fn.count_targets(state) - check_correct_reward( - timestep, num_boxes_on_targets_new, num_boxes_on_targets - ) + check_correct_reward(timestep, num_boxes_on_targets_new, num_boxes_on_targets) num_boxes_on_targets = num_boxes_on_targets_new diff --git a/jumanji/environments/routing/sokoban/viewer.py b/jumanji/environments/routing/sokoban/viewer.py index db5716704..beed3e848 100644 --- a/jumanji/environments/routing/sokoban/viewer.py +++ b/jumanji/environments/routing/sokoban/viewer.py @@ -98,9 +98,7 @@ def animate( Returns: Animation that can be saved as a GIF, MP4, or rendered with HTML. """ - fig, ax = plt.subplots( - num=f"{self._name}Animation", figsize=BoxViewer.FIGURE_SIZE - ) + fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=BoxViewer.FIGURE_SIZE) plt.close(fig) def make_frame(state_index: int) -> None: @@ -179,9 +177,7 @@ def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: for row in range(rows): self._draw_grid_cell(grid[row, col], 9 - row, col, ax) - def _draw_grid_cell( - self, cell_value: int, row: int, col: int, ax: plt.Axes - ) -> None: + def _draw_grid_cell(self, cell_value: int, row: int, col: int, ax: plt.Axes) -> None: """ Draw a single cell of the grid. diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index f6d57bf93..3c09ace12 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -135,9 +135,7 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step( - self, state: State, action: chex.Numeric - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -189,9 +187,7 @@ def observation_spec(self) -> specs.Spec[Observation]: dtype=float, name="coordinates", ) - position = specs.DiscreteArray( - self.num_cities, dtype=jnp.int32, name="position" - ) + position = specs.DiscreteArray(self.num_cities, dtype=jnp.int32, name="position") trajectory = specs.BoundedArray( shape=(self.num_cities,), dtype=jnp.int32, diff --git a/jumanji/environments/routing/tsp/generator.py b/jumanji/environments/routing/tsp/generator.py index 712c34939..99a492e29 100644 --- a/jumanji/environments/routing/tsp/generator.py +++ b/jumanji/environments/routing/tsp/generator.py @@ -59,9 +59,7 @@ def __call__(self, key: chex.PRNGKey) -> State: key, sample_key = jax.random.split(key) # Randomly sample the coordinates of the cities. - coordinates = jax.random.uniform( - sample_key, (self.num_cities, 2), minval=0, maxval=1 - ) + coordinates = jax.random.uniform(sample_key, (self.num_cities, 2), minval=0, maxval=1) # Initially, the position is set to -1, which means that the agent is not in any city. position = jnp.array(-1, jnp.int32) diff --git a/jumanji/environments/routing/tsp/generator_test.py b/jumanji/environments/routing/tsp/generator_test.py index f8837bb58..bd17b36dc 100644 --- a/jumanji/environments/routing/tsp/generator_test.py +++ b/jumanji/environments/routing/tsp/generator_test.py @@ -50,9 +50,7 @@ def uniform_generator(self) -> UniformGenerator: num_cities=50, ) - def test_uniform_generator__properties( - self, uniform_generator: UniformGenerator - ) -> None: + def test_uniform_generator__properties(self, uniform_generator: UniformGenerator) -> None: """Validate that the random instance generator has the correct properties.""" assert uniform_generator.num_cities == 50 diff --git a/jumanji/environments/routing/tsp/reward.py b/jumanji/environments/routing/tsp/reward.py index 37ec63b23..a85a4525d 100644 --- a/jumanji/environments/routing/tsp/reward.py +++ b/jumanji/environments/routing/tsp/reward.py @@ -100,9 +100,7 @@ def __call__( return reward -def compute_tour_length( - coordinates: chex.Array, trajectory: chex.Array -) -> chex.Numeric: +def compute_tour_length(coordinates: chex.Array, trajectory: chex.Array) -> chex.Numeric: sorted_coordinates = coordinates[trajectory] # Shift coordinates to compute the distance between neighboring cities. shifted_coordinates = jnp.roll(sorted_coordinates, -1, axis=0) diff --git a/jumanji/environments/routing/tsp/reward_test.py b/jumanji/environments/routing/tsp/reward_test.py index b4c632aaa..eec96151a 100644 --- a/jumanji/environments/routing/tsp/reward_test.py +++ b/jumanji/environments/routing/tsp/reward_test.py @@ -47,9 +47,7 @@ def test_dense_reward(tsp_dense_reward: TSP, dense_reward: DenseReward) -> None: assert reward == penalty -def test_sparse_reward( # noqa: CCR001 - tsp_sparse_reward: TSP, sparse_reward: SparseReward -) -> None: +def test_sparse_reward(tsp_sparse_reward: TSP, sparse_reward: SparseReward) -> None: sparse_reward = jax.jit(sparse_reward) step_fn = jax.jit(tsp_sparse_reward.step) state, timestep = tsp_sparse_reward.reset(jax.random.PRNGKey(0)) diff --git a/jumanji/registration_test.py b/jumanji/registration_test.py index a37e05935..2e954074d 100644 --- a/jumanji/registration_test.py +++ b/jumanji/registration_test.py @@ -37,9 +37,7 @@ def test_parser__name_version(self, env_id: str, expected: Tuple[str, int]) -> N class TestRegistrationRules: - def test_registration__next_version( - self, mocker: pytest_mock.MockerFixture - ) -> None: + def test_registration__next_version(self, mocker: pytest_mock.MockerFixture) -> None: mocker.patch("jumanji.registration._REGISTRY", {}) # Check that the next registrable version is v+1 @@ -51,9 +49,7 @@ def test_registration__next_version( env_spec = registration.EnvSpec(id="Env-v1", entry_point="") registration._check_registration_is_allowed(env_spec) - def test_registration__already_registered( - self, mocker: pytest_mock.MockerFixture - ) -> None: + def test_registration__already_registered(self, mocker: pytest_mock.MockerFixture) -> None: mocker.patch("jumanji.registration._REGISTRY", {}) env_spec = registration.EnvSpec(id="Env-v0", entry_point="") registration.register(env_spec.id, entry_point=env_spec.entry_point) diff --git a/jumanji/specs.py b/jumanji/specs.py index 6cfacd546..d4cd99e86 100644 --- a/jumanji/specs.py +++ b/jumanji/specs.py @@ -111,9 +111,7 @@ def validate(self, value: T) -> T: def generate_value(self) -> T: """Generate a value which conforms to this spec.""" - constructor_kwargs = jax.tree_util.tree_map( - lambda spec: spec.generate_value(), self._specs - ) + constructor_kwargs = jax.tree_util.tree_map(lambda spec: spec.generate_value(), self._specs) return self._constructor(**constructor_kwargs) def replace(self, **kwargs: Any) -> "Spec": @@ -159,7 +157,7 @@ def __init__(self, shape: Iterable, dtype: Union[jnp.dtype, type], name: str = " self._dtype = get_valid_dtype(dtype) def __repr__(self) -> str: - return f"Array(shape={repr(self.shape)}, dtype={repr(self.dtype)}, name={repr(self.name)})" + return f"Array(shape={self.shape!r}, dtype={self.dtype!r}, name={self.name!r})" def __reduce__(self) -> Any: """To allow pickle to serialize the spec.""" @@ -196,21 +194,15 @@ def validate(self, value: chex.Numeric) -> chex.Array: """ value = jnp.asarray(value) if value.shape != self.shape: - self._fail_validation( - f"Expected shape {self.shape} but found {value.shape}" - ) + self._fail_validation(f"Expected shape {self.shape} but found {value.shape}") if value.dtype != self.dtype: - self._fail_validation( - f"Expected dtype {self.dtype} but found {value.dtype}" - ) + self._fail_validation(f"Expected dtype {self.dtype} but found {value.dtype}") return value def _get_constructor_kwargs(self) -> Dict[str, Any]: """Returns constructor kwargs for instantiating a new copy of this spec.""" # Get the names and kinds of the constructor parameters. - params = inspect.signature( - functools.partial(type(self).__init__, self) - ).parameters + params = inspect.signature(functools.partial(type(self).__init__, self)).parameters # __init__ must not accept *args or **kwargs, since otherwise we won't be # able to infer what the corresponding attribute names are. kinds = {value.kind for value in params.values()} @@ -306,20 +298,16 @@ def __init__( try: bcast_minimum = jnp.broadcast_to(minimum, shape=shape) except ValueError as jnp_exception: - raise ValueError( - "`minimum` is incompatible with `shape`" - ) from jnp_exception + raise ValueError("`minimum` is incompatible with `shape`") from jnp_exception try: bcast_maximum = jnp.broadcast_to(maximum, shape=shape) except ValueError as jnp_exception: - raise ValueError( - "`maximum` is incompatible with `shape`" - ) from jnp_exception + raise ValueError("`maximum` is incompatible with `shape`") from jnp_exception if jnp.any(bcast_minimum > bcast_maximum): raise ValueError( f"All values in `minimum` must be less than or equal to their corresponding " - f"value in `maximum`, got: \n\tminimum={repr(minimum)}\n\tmaximum={repr(maximum)}" + f"value in `maximum`, got: \n\tminimum={minimum!r}\n\tmaximum={maximum!r}" ) self._constructor = lambda: jnp.full(shape, minimum, dtype) self._minimum = minimum @@ -327,8 +315,8 @@ def __init__( def __repr__(self) -> str: return ( - f"BoundedArray(shape={repr(self.shape)}, dtype={repr(self.dtype)}, " - f"name={repr(self.name)}, minimum={repr(self.minimum)}, maximum={repr(self.maximum)})" + f"BoundedArray(shape={self.shape!r}, dtype={self.dtype!r}, " + f"name={self.name!r}, minimum={self.minimum!r}, maximum={self.maximum!r})" ) def __reduce__(self) -> Any: @@ -356,7 +344,7 @@ def validate(self, value: chex.Numeric) -> chex.Array: if (value < self.minimum).any() or (value > self.maximum).any(): self._fail_validation( "Values were not all within bounds " - f"{repr(self.minimum)} <= {repr(value)} <= {repr(self.maximum)}" + f"{self.minimum!r} <= {value!r} <= {self.maximum!r}" ) return value @@ -384,9 +372,7 @@ class DiscreteArray(BoundedArray): that accepts discrete actions. """ - def __init__( - self, num_values: int, dtype: Union[jnp.dtype, type] = jnp.int32, name: str = "" - ): + def __init__(self, num_values: int, dtype: Union[jnp.dtype, type] = jnp.int32, name: str = ""): """Initializes a new `DiscreteArray` spec. Args: @@ -398,9 +384,7 @@ def __init__( ValueError: if `num_values` is not positive, if `dtype` is not integer. """ if num_values <= 0 or not jnp.issubdtype(type(num_values), jnp.integer): - raise ValueError( - f"`num_values` must be a positive integer, got {num_values}." - ) + raise ValueError(f"`num_values` must be a positive integer, got {num_values}.") if not jnp.issubdtype(dtype, jnp.integer): raise ValueError(f"`dtype` must be integer, got {dtype}.") @@ -412,9 +396,9 @@ def __init__( def __repr__(self) -> str: return ( - f"DiscreteArray(shape={repr(self.shape)}, dtype={repr(self.dtype)}, " - f"name={repr(self.name)}, minimum={repr(self.minimum)}, maximum={repr(self.maximum)}, " - f"num_values={repr(self.num_values)})" + f"DiscreteArray(shape={self.shape!r}, dtype={self.dtype!r}, " + f"name={self.name!r}, minimum={self.minimum!r}, maximum={self.maximum!r}, " + f"num_values={self.num_values!r})" ) def __reduce__(self) -> Any: @@ -479,9 +463,9 @@ def __init__( def __repr__(self) -> str: return ( - f"MultiDiscreteArray(shape={repr(self.shape)}, dtype={repr(self.dtype)}, " - f"name={repr(self.name)}, minimum={repr(self.minimum)}, maximum={repr(self.maximum)}, " - f"num_values={repr(self.num_values)})" + f"MultiDiscreteArray(shape={self.shape!r}, dtype={self.dtype!r}, " + f"name={self.name!r}, minimum={self.minimum!r}, maximum={self.maximum!r}, " + f"num_values={self.num_values!r})" ) def __reduce__(self) -> Any: diff --git a/jumanji/specs_test.py b/jumanji/specs_test.py index 09b9f48b1..1bb0c7a68 100644 --- a/jumanji/specs_test.py +++ b/jumanji/specs_test.py @@ -93,9 +93,7 @@ def not_jumanji_type_spec() -> specs.Spec: @pytest.fixture -def mixed_spec( - singly_nested_spec: specs.Spec, not_jumanji_type_spec: specs.Spec -) -> specs.Spec: +def mixed_spec(singly_nested_spec: specs.Spec, not_jumanji_type_spec: specs.Spec) -> specs.Spec: """An example of nested Spec whose leaves are a mix of Jumanji and non-Jumanji specs.""" return specs.Spec( namedtuple("mixed_type", ["singly_nested", "not_jumanji_type"]), @@ -113,18 +111,14 @@ def test_spec__type(self, triply_nested_spec: specs.Spec) -> None: def test_spec__generate_value(self, triply_nested_spec: specs.Spec) -> None: assert isinstance(triply_nested_spec.generate_value(), TriplyNested) - assert isinstance( - triply_nested_spec["doubly_nested"].generate_value(), DoublyNested - ) + assert isinstance(triply_nested_spec["doubly_nested"].generate_value(), DoublyNested) assert isinstance( triply_nested_spec["doubly_nested"]["singly_nested"].generate_value(), SinglyNested, ) def test_spec__validate(self, triply_nested_spec: specs.Spec) -> None: - singly_nested = triply_nested_spec["doubly_nested"][ - "singly_nested" - ].generate_value() + singly_nested = triply_nested_spec["doubly_nested"]["singly_nested"].generate_value() assert isinstance(singly_nested, SinglyNested) doubly_nested = DoublyNested( @@ -147,14 +141,12 @@ def test_spec__replace(self, triply_nested_spec: specs.Spec) -> None: modified_specs = [ triply_nested_spec["bounded_array"].replace(name="wrong_name"), triply_nested_spec["doubly_nested"].replace( - discrete_array=triply_nested_spec["doubly_nested"][ - "discrete_array" - ].replace(num_values=2) + discrete_array=triply_nested_spec["doubly_nested"]["discrete_array"].replace( + num_values=2 + ) ), triply_nested_spec["doubly_nested"].replace( - singly_nested=triply_nested_spec["doubly_nested"][ - "singly_nested" - ].replace( + singly_nested=triply_nested_spec["doubly_nested"]["singly_nested"].replace( bounded_array=triply_nested_spec["doubly_nested"]["singly_nested"][ "bounded_array" ].replace(shape=(33, 33)) @@ -162,15 +154,13 @@ def test_spec__replace(self, triply_nested_spec: specs.Spec) -> None: ), triply_nested_spec["discrete_array"].replace(num_values=27), ] - for arg, modified_spec in zip(arg_list, modified_specs): + for arg, modified_spec in zip(arg_list, modified_specs, strict=False): old_spec = triply_nested_spec new_spec = old_spec.replace(**{arg: modified_spec}) assert new_spec != old_spec chex.assert_equal(getattr(new_spec, arg), modified_spec) for attr_name in set(arg_list).difference([arg]): - chex.assert_equal( - getattr(new_spec, attr_name), getattr(old_spec, attr_name) - ) + chex.assert_equal(getattr(new_spec, attr_name), getattr(old_spec, attr_name)) class TestArray: @@ -304,34 +294,24 @@ def test_read_only(self) -> None: def test_equal_broadcasting_bounds(self) -> None: spec_1 = specs.BoundedArray((1, 2), jnp.float32, minimum=0.0, maximum=1.0) - spec_2 = specs.BoundedArray( - (1, 2), jnp.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0] - ) + spec_2 = specs.BoundedArray((1, 2), jnp.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0]) assert jnp.all(spec_1.minimum == spec_2.minimum) assert jnp.all(spec_1.maximum == spec_2.maximum) def test_not_equal_different_minimum(self) -> None: - spec_1 = specs.BoundedArray( - (1, 2), jnp.float32, minimum=[0.0, -0.6], maximum=[1.0, 1.0] - ) - spec_2 = specs.BoundedArray( - (1, 2), jnp.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0] - ) + spec_1 = specs.BoundedArray((1, 2), jnp.float32, minimum=[0.0, -0.6], maximum=[1.0, 1.0]) + spec_2 = specs.BoundedArray((1, 2), jnp.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0]) assert not jnp.all(spec_1.minimum == spec_2.minimum) assert jnp.all(spec_1.maximum == spec_2.maximum) def test_not_equal_different_maximum(self) -> None: spec_1 = specs.BoundedArray((1, 2), jnp.int32, minimum=0.0, maximum=2.0) - spec_2 = specs.BoundedArray( - (1, 2), jnp.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0] - ) + spec_2 = specs.BoundedArray((1, 2), jnp.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0]) assert not jnp.all(spec_1.maximum == spec_2.maximum) assert jnp.all(spec_1.minimum == spec_2.minimum) def test_repr(self) -> None: - as_string = repr( - specs.BoundedArray((1, 2), jnp.int32, minimum=73.0, maximum=101.0) - ) + as_string = repr(specs.BoundedArray((1, 2), jnp.int32, minimum=73.0, maximum=101.0)) assert "73" in as_string assert "101" in as_string @@ -436,9 +416,7 @@ def test_replace(self, arg_name: str, new_value: Any) -> None: new_spec = old_spec.replace(**{arg_name: new_value}) assert old_spec != new_spec assert getattr(new_spec, arg_name) == new_value - for attr_name in {"shape", "dtype", "name", "minimum", "maximum"}.difference( - [arg_name] - ): + for attr_name in {"shape", "dtype", "name", "minimum", "maximum"}.difference([arg_name]): assert getattr(new_spec, attr_name) == getattr(old_spec, attr_name) @pytest.mark.parametrize( @@ -531,9 +509,7 @@ def test_dtype_not_integer(self, dtype: Union[jnp.dtype, type]) -> None: specs.MultiDiscreteArray(num_values=jnp.array([5, 6], int), dtype=dtype) def test_repr(self) -> None: - as_string = repr( - specs.MultiDiscreteArray(num_values=jnp.array([5, 6], dtype=int)) - ) + as_string = repr(specs.MultiDiscreteArray(num_values=jnp.array([5, 6], dtype=int))) assert "5" in as_string def test_properties(self) -> None: @@ -545,9 +521,7 @@ def test_properties(self) -> None: assert (spec.num_values == num_values).all() def test_serialization(self) -> None: - spec = specs.MultiDiscreteArray( - jnp.array([5, 6], dtype=int), jnp.int32, "pickle_test" - ) + spec = specs.MultiDiscreteArray(jnp.array([5, 6], dtype=int), jnp.int32, "pickle_test") loaded_spec = pickle.loads(pickle.dumps(spec)) assert isinstance(loaded_spec, spec.__class__) assert loaded_spec.dtype == spec.dtype @@ -566,16 +540,12 @@ def test_serialization(self) -> None: ], ) def test_replace(self, arg_name: str, new_value: Any) -> None: - old_spec = specs.MultiDiscreteArray( - jnp.array([5, 6], dtype=int), jnp.int32, "test" - ) + old_spec = specs.MultiDiscreteArray(jnp.array([5, 6], dtype=int), jnp.int32, "test") new_spec = old_spec.replace(**{arg_name: new_value}) for attr_name in ["num_values", "dtype", "name"]: # Check that the attribute corresponding to arg_name has been set to new_value, while # the other attributes have remained the same. - target_value = ( - new_value if attr_name == arg_name else getattr(old_spec, attr_name) - ) + target_value = new_value if attr_name == arg_name else getattr(old_spec, attr_name) if attr_name == "num_values": assert (getattr(new_spec, attr_name) == target_value).all() else: @@ -586,9 +556,7 @@ class TestJumanjiSpecsToDmEnvSpecs: def test_array(self) -> None: jumanji_spec = specs.Array((1, 2), jnp.int32) dm_env_spec = dm_env.specs.Array((1, 2), jnp.int32) - converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs( - jumanji_spec - ) + converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs(jumanji_spec) assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype @@ -596,9 +564,7 @@ def test_array(self) -> None: def test_bounded_array(self) -> None: jumanji_spec = specs.BoundedArray((1, 2), jnp.float32, minimum=0.0, maximum=1.0) - dm_env_spec = dm_env.specs.BoundedArray( - (1, 2), jnp.float32, minimum=0.0, maximum=1.0 - ) + dm_env_spec = dm_env.specs.BoundedArray((1, 2), jnp.float32, minimum=0.0, maximum=1.0) converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) @@ -612,8 +578,8 @@ def test_bounded_array(self) -> None: def test_discrete_array(self) -> None: jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32) dm_env_spec = dm_env.specs.DiscreteArray(num_values=5, dtype=jnp.int32) - converted_spec: dm_env.specs.DiscreteArray = ( - specs.jumanji_specs_to_dm_env_specs(jumanji_spec) + converted_spec: dm_env.specs.DiscreteArray = specs.jumanji_specs_to_dm_env_specs( + jumanji_spec ) assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape @@ -657,9 +623,7 @@ def test_mixed_spec(self, mixed_spec: specs.Spec) -> None: converted_spec = specs.jumanji_specs_to_dm_env_specs(mixed_spec) assert isinstance(converted_spec, dict) assert isinstance(converted_spec["singly_nested"], dict) - assert_tree_with_leaves_of_type( - converted_spec["singly_nested"], dm_env.specs.Array - ) + assert_tree_with_leaves_of_type(converted_spec["singly_nested"], dm_env.specs.Array) assert not converted_spec["not_jumanji_type"] assert mixed_spec["not_jumanji_type"] @@ -682,9 +646,7 @@ def test_array(self) -> None: assert converted_spec.dtype == gym_space.dtype def test_bounded_array(self) -> None: - jumanji_spec = specs.BoundedArray( - shape=(1, 2), dtype=jnp.float32, minimum=0.0, maximum=1.0 - ) + jumanji_spec = specs.BoundedArray(shape=(1, 2), dtype=jnp.float32, minimum=0.0, maximum=1.0) gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) assert type(converted_spec) is type(gym_space) @@ -703,9 +665,7 @@ def test_discrete_array(self) -> None: assert converted_spec.n == gym_space.n def test_multi_discrete_array(self) -> None: - jumanji_spec = specs.MultiDiscreteArray( - num_values=jnp.array([5, 6], dtype=jnp.int32) - ) + jumanji_spec = specs.MultiDiscreteArray(num_values=jnp.array([5, 6], dtype=jnp.int32)) gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6]) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) assert type(converted_spec) is type(gym_space) @@ -720,9 +680,7 @@ def test_triply_nested_spec(self, triply_nested_spec: specs.Spec) -> None: converted_spec = specs.jumanji_specs_to_gym_spaces(triply_nested_spec) assert isinstance(converted_spec, gym.spaces.Dict) assert isinstance(converted_spec["doubly_nested"], gym.spaces.Dict) - assert isinstance( - converted_spec["doubly_nested"]["singly_nested"], gym.spaces.Dict - ) + assert isinstance(converted_spec["doubly_nested"]["singly_nested"], gym.spaces.Dict) assert isinstance( converted_spec["doubly_nested"]["singly_nested"]["array"], gym.spaces.Box, @@ -735,9 +693,7 @@ def test_triply_nested_spec(self, triply_nested_spec: specs.Spec) -> None: converted_spec["doubly_nested"]["singly_nested"]["multi_discrete_array"], gym.spaces.MultiDiscrete, ) - assert isinstance( - converted_spec["doubly_nested"]["discrete_array"], gym.spaces.Discrete - ) + assert isinstance(converted_spec["doubly_nested"]["discrete_array"], gym.spaces.Discrete) assert isinstance(converted_spec["bounded_array"], gym.spaces.Box) assert isinstance(converted_spec["discrete_array"], gym.spaces.Discrete) @@ -748,9 +704,7 @@ def test_mixed_spec(self, mixed_spec: specs.Spec) -> None: converted_spec = specs.jumanji_specs_to_gym_spaces(mixed_spec) assert isinstance(converted_spec, gym.spaces.Dict) assert isinstance(converted_spec["singly_nested"], gym.spaces.Dict) - assert_tree_with_leaves_of_type( - converted_spec["singly_nested"].spaces, gym.spaces.Space - ) + assert_tree_with_leaves_of_type(converted_spec["singly_nested"].spaces, gym.spaces.Space) assert not converted_spec["not_jumanji_type"] assert mixed_spec["not_jumanji_type"] diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py index 4a9603c9d..1791617b8 100644 --- a/jumanji/testing/env_not_smoke.py +++ b/jumanji/testing/env_not_smoke.py @@ -27,9 +27,7 @@ def make_random_select_action_fn( - action_spec: Union[ - specs.BoundedArray, specs.DiscreteArray, specs.MultiDiscreteArray - ], + action_spec: Union[specs.BoundedArray, specs.DiscreteArray, specs.MultiDiscreteArray], ) -> SelectActionFn: """Create select action function that chooses random actions.""" @@ -100,10 +98,10 @@ def check_env_does_not_smoke( def access_specs(env: Environment) -> None: """Access specs of the environment.""" - env.observation_spec - env.action_spec - env.reward_spec - env.discount_spec + env.observation_spec # noqa: B018 + env.action_spec # noqa: B018 + env.reward_spec # noqa: B018 + env.discount_spec # noqa: B018 def check_env_specs_does_not_smoke(env: Environment) -> None: diff --git a/jumanji/testing/env_not_smoke_test.py b/jumanji/testing/env_not_smoke_test.py index 8900aaaeb..9e2fa3ec4 100644 --- a/jumanji/testing/env_not_smoke_test.py +++ b/jumanji/testing/env_not_smoke_test.py @@ -41,9 +41,7 @@ def select_action(key: chex.PRNGKey, state: chex.ArrayTree) -> chex.ArrayTree: return select_action -def test_env_not_smoke( - fake_env: FakeEnvironment, invalid_select_action_fn: SelectActionFn -) -> None: +def test_env_not_smoke(fake_env: FakeEnvironment, invalid_select_action_fn: SelectActionFn) -> None: """Test that the test_env_not_smoke function raises not errors for a valid environment, and raises errors for an invalid environment.""" check_env_does_not_smoke(fake_env) diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index a41246e1d..8208a43d9 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -68,9 +68,7 @@ def observation_spec(self) -> specs.Array: observation_spec: a `specs.Array` spec. """ - return specs.Array( - shape=self.observation_shape, dtype=float, name="observation" - ) + return specs.Array(shape=self.observation_shape, dtype=float, name="observation") @cached_property def action_spec(self) -> specs.BoundedArray: @@ -188,9 +186,7 @@ def observation_spec(self) -> specs.Array: observation_spec: a `specs.Array` spec. """ - return specs.Array( - shape=self.observation_shape, dtype=float, name="observation" - ) + return specs.Array(shape=self.observation_shape, dtype=float, name="observation") @cached_property def action_spec(self) -> specs.BoundedArray: @@ -200,9 +196,7 @@ def action_spec(self) -> specs.BoundedArray: action_spec: a `specs.Array` spec. """ - return specs.BoundedArray( - (self.num_agents,), int, 0, self.num_action_values - 1 - ) + return specs.BoundedArray((self.num_agents,), int, 0, self.num_action_values - 1) @cached_property def reward_spec(self) -> specs.Array: diff --git a/jumanji/testing/pytrees.py b/jumanji/testing/pytrees.py index a60dee52e..277b9c188 100644 --- a/jumanji/testing/pytrees.py +++ b/jumanji/testing/pytrees.py @@ -36,12 +36,8 @@ def is_equal_pytree(tree1: MixedTypeTree, tree2: MixedTypeTree) -> bool: Note that this function will block gradients between the input and output, and is created for use in the context of testing rather than for direct use inside RL algorithms.""" - is_equal_func = lambda leaf1, leaf2: np.array_equal( - np.asarray(leaf1), np.asarray(leaf2) - ) - is_equal_leaves = tree_lib.flatten( - tree_lib.map_structure(is_equal_func, tree1, tree2) - ) + is_equal_func = lambda leaf1, leaf2: np.array_equal(np.asarray(leaf1), np.asarray(leaf2)) + is_equal_leaves = tree_lib.flatten(tree_lib.map_structure(is_equal_func, tree1, tree2)) is_equal = np.all(is_equal_leaves) return bool(is_equal) @@ -53,9 +49,7 @@ def assert_trees_are_different(tree1: MixedTypeTree, tree2: MixedTypeTree) -> No This is useful for basic sanity checks, for example checking whether parameters are being updated.""" - assert not is_equal_pytree( - tree1, tree2 - ), "The trees have the same value(s) for all leaves." + assert not is_equal_pytree(tree1, tree2), "The trees have the same value(s) for all leaves." def assert_trees_are_equal(tree1: MixedTypeTree, tree2: MixedTypeTree) -> None: @@ -65,17 +59,13 @@ def assert_trees_are_equal(tree1: MixedTypeTree, tree2: MixedTypeTree) -> None: This is useful for basic sanity checks, for example checking if a checkpoint correctly restores a Learner's state.""" - assert is_equal_pytree( - tree1, tree2 - ), "The trees differ in at least one leaf's value(s)." + assert is_equal_pytree(tree1, tree2), "The trees differ in at least one leaf's value(s)." def is_tree_with_leaves_of_type(input_tree: Any, *leaf_type: Type) -> bool: """Returns true if all leaves in the `input_tree` are of the specified `leaf_type`.""" leaf_is_type_func = lambda leaf: isinstance(leaf, leaf_type) - is_type_leaves = tree_lib.flatten( - tree_lib.map_structure(leaf_is_type_func, input_tree) - ) + is_type_leaves = tree_lib.flatten(tree_lib.map_structure(leaf_is_type_func, input_tree)) tree_leaves_are_all_of_type = np.all(is_type_leaves) return bool(tree_leaves_are_all_of_type) diff --git a/jumanji/testing/pytrees_test.py b/jumanji/testing/pytrees_test.py index cef9e0d02..6f36395d1 100644 --- a/jumanji/testing/pytrees_test.py +++ b/jumanji/testing/pytrees_test.py @@ -109,12 +109,8 @@ def test_is_tree_with_leaves_of_type( """ assert pytree_test_utils.is_tree_with_leaves_of_type(jax_tree, jnp.ndarray) assert pytree_test_utils.is_tree_with_leaves_of_type(np_tree, np.ndarray) - assert not pytree_test_utils.is_tree_with_leaves_of_type( - jax_and_numpy_tree, jnp.ndarray - ) - assert not pytree_test_utils.is_tree_with_leaves_of_type( - jax_and_numpy_tree, np.ndarray - ) + assert not pytree_test_utils.is_tree_with_leaves_of_type(jax_and_numpy_tree, jnp.ndarray) + assert not pytree_test_utils.is_tree_with_leaves_of_type(jax_and_numpy_tree, np.ndarray) def test_assert_tree_with_leaves_of_type( @@ -133,16 +129,12 @@ def test_assert_tree_with_leaves_of_type( AssertionError, match=f"The tree has at least one leaf that is not of type {jnp.ndarray}.", ): - pytree_test_utils.assert_tree_with_leaves_of_type( - jax_and_numpy_tree, jnp.ndarray - ) + pytree_test_utils.assert_tree_with_leaves_of_type(jax_and_numpy_tree, jnp.ndarray) with pytest.raises( AssertionError, match=f"The tree has at least one leaf that is not of type {np.ndarray}.", ): - pytree_test_utils.assert_tree_with_leaves_of_type( - jax_and_numpy_tree, np.ndarray - ) + pytree_test_utils.assert_tree_with_leaves_of_type(jax_and_numpy_tree, np.ndarray) def test_assert_is_jax_array_tree( diff --git a/jumanji/training/agents/a2c/a2c_agent.py b/jumanji/training/agents/a2c/a2c_agent.py index 2392ab05c..7b705e914 100644 --- a/jumanji/training/agents/a2c/a2c_agent.py +++ b/jumanji/training/agents/a2c/a2c_agent.py @@ -90,9 +90,7 @@ def run_epoch(self, training_state: TrainingState) -> Tuple[TrainingState, Dict] training_state.acting_state, ) grad, metrics = jax.lax.pmean((grad, metrics), "devices") - updates, opt_state = self.optimizer.update( - grad, training_state.params_state.opt_state - ) + updates, opt_state = self.optimizer.update(grad, training_state.params_state.opt_state) params = optax.apply_updates(training_state.params_state.params, updates) training_state = TrainingState( params_state=ParamsState( @@ -109,18 +107,14 @@ def a2c_loss( params: ActorCriticParams, acting_state: ActingState, ) -> Tuple[float, Tuple[ActingState, Dict]]: - parametric_action_distribution = ( - self.actor_critic_networks.parametric_action_distribution - ) + parametric_action_distribution = self.actor_critic_networks.parametric_action_distribution value_apply = self.actor_critic_networks.value_network.apply acting_state, data = self.rollout( policy_params=params.actor, acting_state=acting_state, ) # data.shape == (T, B, ...) - last_observation = jax.tree_util.tree_map( - lambda x: x[-1], data.next_observation - ) + last_observation = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation) observation = jax.tree_util.tree_map( lambda obs_0_tm1, obs_t: jnp.concatenate([obs_0_tm1, obs_t[None]], axis=0), data.observation, @@ -157,14 +151,10 @@ def a2c_loss( policy_loss = -jnp.mean(jax.lax.stop_gradient(advantage) * data.log_prob) # Compute the entropy loss, i.e. negative of the entropy. - entropy = jnp.mean( - parametric_action_distribution.entropy(data.logits, acting_state.key) - ) + entropy = jnp.mean(parametric_action_distribution.entropy(data.logits, acting_state.key)) entropy_loss = -entropy - total_loss = ( - self.l_pg * policy_loss + self.l_td * critic_loss + self.l_en * entropy_loss - ) + total_loss = self.l_pg * policy_loss + self.l_td * critic_loss + self.l_en * entropy_loss metrics.update( total_loss=total_loss, policy_loss=policy_loss, @@ -182,28 +172,20 @@ def make_policy( self, policy_params: hk.Params, stochastic: bool = True, - ) -> Callable[ - [Any, chex.PRNGKey], Tuple[chex.Array, Tuple[chex.Array, chex.Array]] - ]: + ) -> Callable[[Any, chex.PRNGKey], Tuple[chex.Array, Tuple[chex.Array, chex.Array]]]: policy_network = self.actor_critic_networks.policy_network - parametric_action_distribution = ( - self.actor_critic_networks.parametric_action_distribution - ) + parametric_action_distribution = self.actor_critic_networks.parametric_action_distribution def policy( observation: Any, key: chex.PRNGKey ) -> Tuple[chex.Array, Tuple[chex.Array, chex.Array]]: logits = policy_network.apply(policy_params, observation) if stochastic: - raw_action = parametric_action_distribution.sample_no_postprocessing( - logits, key - ) + raw_action = parametric_action_distribution.sample_no_postprocessing(logits, key) log_prob = parametric_action_distribution.log_prob(logits, raw_action) else: del key - raw_action = parametric_action_distribution.mode_no_postprocessing( - logits - ) + raw_action = parametric_action_distribution.mode_no_postprocessing(logits) # log_prob is log(1), i.e. 0, for a greedy policy (deterministic distribution). log_prob = jnp.zeros_like( parametric_action_distribution.log_prob(logits, raw_action) @@ -254,8 +236,6 @@ def run_one_step( return acting_state, transition - acting_keys = jax.random.split(acting_state.key, self.n_steps).reshape( - (self.n_steps, -1) - ) + acting_keys = jax.random.split(acting_state.key, self.n_steps).reshape((self.n_steps, -1)) acting_state, data = jax.lax.scan(run_one_step, acting_state, acting_keys) return acting_state, data diff --git a/jumanji/training/agents/random/random_agent.py b/jumanji/training/agents/random/random_agent.py index 904fbe98f..fdf36374c 100644 --- a/jumanji/training/agents/random/random_agent.py +++ b/jumanji/training/agents/random/random_agent.py @@ -89,8 +89,6 @@ def run_one_step( extras = next_timestep.extras return acting_state, extras - acting_keys = jax.random.split(acting_state.key, self.n_steps).reshape( - (self.n_steps, -1) - ) + acting_keys = jax.random.split(acting_state.key, self.n_steps).reshape((self.n_steps, -1)) acting_state, extras = jax.lax.scan(run_one_step, acting_state, acting_keys) return acting_state, extras diff --git a/jumanji/training/evaluator.py b/jumanji/training/evaluator.py index d4e97d8ae..4dbda3278 100644 --- a/jumanji/training/evaluator.py +++ b/jumanji/training/evaluator.py @@ -62,9 +62,7 @@ def _eval_one_episode( policy_params: Optional[hk.Params], key: chex.PRNGKey, ) -> Dict: - policy = self.agent.make_policy( - policy_params=policy_params, stochastic=self.stochastic - ) + policy = self.agent.make_policy(policy_params=policy_params, stochastic=self.stochastic) if isinstance(self.agent, A2CAgent): def acting_policy(observation: Any, key: chex.PRNGKey) -> chex.Array: @@ -87,9 +85,7 @@ def body_fun( lambda x: x[None], acting_state.timestep.observation ) action = acting_policy(observation, action_key) - state, timestep = self.eval_env.step( - acting_state.state, jnp.squeeze(action, axis=0) - ) + state, timestep = self.eval_env.step(acting_state.state, jnp.squeeze(action, axis=0)) return_ += timestep.reward acting_state = ActingState( state=state, @@ -148,9 +144,7 @@ def _generate_evaluations( return eval_metrics - def run_evaluation( - self, params_state: Optional[ParamsState], eval_key: chex.PRNGKey - ) -> Dict: + def run_evaluation(self, params_state: Optional[ParamsState], eval_key: chex.PRNGKey) -> Dict: """Run one batch of evaluations.""" eval_keys = jax.random.split(eval_key, self.num_global_devices).reshape( self.num_workers, self.num_local_devices, -1 diff --git a/jumanji/training/loggers.py b/jumanji/training/loggers.py index a0ad74384..0ad2df8be 100644 --- a/jumanji/training/loggers.py +++ b/jumanji/training/loggers.py @@ -30,9 +30,7 @@ class Logger(AbstractContextManager): - def __init__( - self, save_checkpoint: bool, checkpoint_file_name: str = "training_state" - ): + def __init__(self, save_checkpoint: bool, checkpoint_file_name: str = "training_state"): self.save_checkpoint = save_checkpoint self.checkpoint_file_name = checkpoint_file_name @@ -88,9 +86,7 @@ def _save_and_upload_checkpoint(self) -> None: `training_state`. """ logging.info("Saving checkpoint...") - in_context_variables = dict( - set(self._variables_exit).difference(self._variables_enter) - ) + in_context_variables = dict(set(self._variables_exit).difference(self._variables_enter)) variable_id = in_context_variables.get("training_state", None) if variable_id is not None: training_state = self._variables_exit[("training_state", variable_id)] @@ -132,9 +128,7 @@ def write( class TerminalLogger(Logger): """Logs to terminal.""" - def __init__( - self, name: Optional[str] = None, save_checkpoint: bool = False - ) -> None: + def __init__(self, name: Optional[str] = None, save_checkpoint: bool = False) -> None: super().__init__(save_checkpoint=save_checkpoint) if name: logging.info(f"Experiment: {name}.") @@ -250,6 +244,4 @@ def close(self) -> None: self.run.stop() def upload_checkpoint(self) -> None: - self.run[f"checkpoint/{self.checkpoint_file_name}"].upload( - self.checkpoint_file_name - ) + self.run[f"checkpoint/{self.checkpoint_file_name}"].upload(self.checkpoint_file_name) diff --git a/jumanji/training/networks/bin_pack/actor_critic.py b/jumanji/training/networks/bin_pack/actor_critic.py index c3de93f8f..0190071da 100644 --- a/jumanji/training/networks/bin_pack/actor_critic.py +++ b/jumanji/training/networks/bin_pack/actor_critic.py @@ -178,12 +178,8 @@ def network_fn(observation: Observation) -> chex.Array: ems_embeddings, items_embeddings = torso(observation) # Process EMSs differently from items. - ems_embeddings = hk.Linear(torso.model_size, name="policy_ems_head")( - ems_embeddings - ) - items_embeddings = hk.Linear(torso.model_size, name="policy_items_head")( - items_embeddings - ) + ems_embeddings = hk.Linear(torso.model_size, name="policy_ems_head")(ems_embeddings) + items_embeddings = hk.Linear(torso.model_size, name="policy_items_head")(items_embeddings) # Outer-product between the embeddings to obtain logits. logits = jnp.einsum("...ek,...ik->...ei", ems_embeddings, items_embeddings) @@ -214,9 +210,7 @@ def network_fn(observation: Observation) -> chex.Array: ems_mask = observation.ems_mask ems_embedding = jnp.sum(ems_embeddings, axis=-2, where=ems_mask[..., None]) items_mask = observation.items_mask & ~observation.items_placed - items_embedding = jnp.sum( - items_embeddings, axis=-2, where=items_mask[..., None] - ) + items_embedding = jnp.sum(items_embeddings, axis=-2, where=items_mask[..., None]) joint_embedding = jnp.concatenate([ems_embedding, items_embedding], axis=-1) value = hk.nets.MLP((torso.model_size, 1), name="critic_head")(joint_embedding) diff --git a/jumanji/training/networks/bin_pack/random.py b/jumanji/training/networks/bin_pack/random.py index add32d75a..cf020d28b 100644 --- a/jumanji/training/networks/bin_pack/random.py +++ b/jumanji/training/networks/bin_pack/random.py @@ -22,6 +22,4 @@ def make_random_policy_bin_pack(bin_pack: BinPack) -> RandomPolicy: """Make random policy for BinPack.""" action_spec_num_values = bin_pack.action_spec.num_values - return make_masked_categorical_random_ndim( - action_spec_num_values=action_spec_num_values - ) + return make_masked_categorical_random_ndim(action_spec_num_values=action_spec_num_values) diff --git a/jumanji/training/networks/cleaner/actor_critic.py b/jumanji/training/networks/cleaner/actor_critic.py index 2fedfe289..4d60f0f6b 100644 --- a/jumanji/training/networks/cleaner/actor_critic.py +++ b/jumanji/training/networks/cleaner/actor_critic.py @@ -40,9 +40,7 @@ def make_actor_critic_networks_cleaner( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Cleaner` environment.""" num_values = np.asarray(cleaner.action_spec.num_values) - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) policy_network = make_actor_network( num_conv_channels=num_conv_channels, mlp_units=policy_layers, @@ -93,9 +91,7 @@ def process_obs_for_critic(observation: Observation) -> chex.Array: wall_channel = jnp.where(grid == WALL, 1, 0) agents_channel = all_agents_channel(observation.agents_locations, grid) - return jnp.stack( - [dirty_channel, wall_channel, agents_channel], axis=-1, dtype=float - ) + return jnp.stack([dirty_channel, wall_channel, agents_channel], axis=-1, dtype=float) def make_critic_network( @@ -121,9 +117,7 @@ def network_fn(observation: Observation) -> chex.Array: normalised_step_count = ( jnp.expand_dims(observation.step_count, axis=-1) / time_limit ) # (B, 1) - output = jnp.concatenate( - [embedding, normalised_step_count], axis=-1 - ) # (B, W*H+1) + output = jnp.concatenate([embedding, normalised_step_count], axis=-1) # (B, W*H+1) values = hk.nets.MLP((*mlp_units, 1), activate_final=False)(output) # (B, 1) return jnp.squeeze(values, axis=-1) # (B,) @@ -150,9 +144,7 @@ def process_obs_for_actor(observation: Observation) -> chex.Array: def create_channels_for_one_agent(agent_location: chex.Array) -> chex.Array: dirty_channel = jnp.where(grid == DIRTY, 1, 0) wall_channel = jnp.where(grid == WALL, 1, 0) - agent_channel = ( - jnp.zeros_like(grid).at[agent_location[0], agent_location[1]].set(1) - ) + agent_channel = jnp.zeros_like(grid).at[agent_location[0], agent_location[1]].set(1) agents_channel = all_agents_channel(agents_locations, grid) return jnp.stack( [dirty_channel, wall_channel, agent_channel, agents_channel], @@ -189,9 +181,7 @@ def network_fn(observation: Observation) -> chex.Array: num_agents, axis=1, ) # (B, N, 1) - output = jnp.concatenate( - [embedding, normalised_step_count], axis=-1 - ) # (B, N, W*H+1) + output = jnp.concatenate([embedding, normalised_step_count], axis=-1) # (B, N, W*H+1) head = hk.nets.MLP((*mlp_units, 4), activate_final=False) logits = head(output) # (B, N, 4) return jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) diff --git a/jumanji/training/networks/connector/actor_critic.py b/jumanji/training/networks/connector/actor_critic.py index edcc4946f..19f82998c 100644 --- a/jumanji/training/networks/connector/actor_critic.py +++ b/jumanji/training/networks/connector/actor_critic.py @@ -47,9 +47,7 @@ def make_actor_critic_networks_connector( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Connector` environment.""" num_values = np.asarray(connector.action_spec.num_values) - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) # num_values is of shape (num_agents,) and contains num_actions everywhere. num_agents = num_values.shape[0] num_actions = num_values[0] @@ -89,9 +87,7 @@ def channel_per_agent(agent_grid: chex.Array, agent_id: jnp.int32) -> chex.Array agent_pos = get_position(agent_id) agent_grid = jnp.expand_dims(agent_grid, -1) agent_mask = ( - (agent_grid == agent_path) - | (agent_grid == agent_target) - | (agent_grid == agent_pos) + (agent_grid == agent_path) | (agent_grid == agent_target) | (agent_grid == agent_pos) ) # Only current agent's info as values: 1, 2 or 3 # [G, G, 1] @@ -217,9 +213,7 @@ def network_fn(observation: Observation) -> chex.Array: num_agents=num_agents, ) embeddings = torso(observation) - logits = hk.nets.MLP((*transformer_mlp_units, num_actions), name="policy_head")( - embeddings - ) + logits = hk.nets.MLP((*transformer_mlp_units, num_actions), name="policy_head")(embeddings) logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits diff --git a/jumanji/training/networks/cvrp/actor_critic.py b/jumanji/training/networks/cvrp/actor_critic.py index e5b498a2a..8b61d7a28 100644 --- a/jumanji/training/networks/cvrp/actor_critic.py +++ b/jumanji/training/networks/cvrp/actor_critic.py @@ -39,9 +39,7 @@ def make_actor_critic_networks_cvrp( ) -> ActorCriticNetworks: """Make actor-critic networks for the `CVRP` environment.""" num_actions = cvrp.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_actor_network_cvrp( transformer_num_blocks=transformer_num_blocks, transformer_num_heads=transformer_num_heads, @@ -233,9 +231,7 @@ def make_cvrp_query( if mean_nodes_in_query: mean_nodes = jnp.mean(embeddings, axis=-2) - query = jnp.concatenate( - [current_position, current_capacity, mean_nodes], axis=-1 - ) + query = jnp.concatenate([current_position, current_capacity, mean_nodes], axis=-1) else: query = jnp.concatenate([current_position, current_capacity], axis=-1) return jnp.expand_dims(query, axis=-2) diff --git a/jumanji/training/networks/distribution.py b/jumanji/training/networks/distribution.py index 481790d76..03262136d 100644 --- a/jumanji/training/networks/distribution.py +++ b/jumanji/training/networks/distribution.py @@ -13,6 +13,7 @@ # limitations under the License. """Adapted from Brax.""" + from __future__ import annotations import abc @@ -83,6 +84,4 @@ def kl_divergence( # type: ignore[override] log_probs = jax.nn.log_softmax(self.logits) probs = jax.nn.softmax(self.logits) log_probs_other = jax.nn.log_softmax(other.logits) - return jnp.sum( - jnp.where(probs == 0, 0.0, probs * (log_probs - log_probs_other)), axis=-1 - ) + return jnp.sum(jnp.where(probs == 0, 0.0, probs * (log_probs - log_probs_other)), axis=-1) diff --git a/jumanji/training/networks/flat_pack/actor_critic.py b/jumanji/training/networks/flat_pack/actor_critic.py index 40aeaa037..345717d52 100644 --- a/jumanji/training/networks/flat_pack/actor_critic.py +++ b/jumanji/training/networks/flat_pack/actor_critic.py @@ -91,9 +91,7 @@ def __call__(self, grid_observation: chex.Array) -> chex.Array: grid_observation = grid_observation[..., jnp.newaxis].astype(float) # Down colvolve with strided convolutions - down_1 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")( - grid_observation - ) + down_1 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")(grid_observation) down_1 = jax.nn.relu(down_1) # (B, 6, 6, 32) down_2 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")(down_1) down_2 = jax.nn.relu(down_2) # (B, 3, 3, 32) @@ -105,13 +103,9 @@ def __call__(self, grid_observation: chex.Array) -> chex.Array: up_2 = hk.Conv2DTranspose(32, kernel_shape=3, stride=2, padding="SAME")(up_1) up_2 = jax.nn.relu(up_2) # (B, 12, 12, 32) up_2 = up_2[:, :-1, :-1] - up_2 = jnp.concatenate( - [up_2, grid_observation], axis=-1 - ) # (B, num_rows, num_cols, 33) + up_2 = jnp.concatenate([up_2, grid_observation], axis=-1) # (B, num_rows, num_cols, 33) - output = hk.Conv2D(self.hidden_size, kernel_shape=1, stride=1, padding="SAME")( - up_2 - ) + output = hk.Conv2D(self.hidden_size, kernel_shape=1, stride=1, padding="SAME")(up_2) # Crop the upconvolved output to be the same size as the action mask. output = output[:, 1:-1, 1:-1] # (B, num_rows-2, num_cols-2, hidden_size) @@ -123,9 +117,7 @@ def __call__(self, grid_observation: chex.Array) -> chex.Array: ) # Linear mapping to transformer model size. - grid_conv_encoding = hk.Linear(self.model_size)( - grid_conv_encoding - ) # (B, model_size) + grid_conv_encoding = hk.Linear(self.model_size)(grid_conv_encoding) # (B, model_size) return grid_conv_encoding, output @@ -156,15 +148,11 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]: # Flatten the blocks # (B, num_blocks, 9) - flattened_blocks = jnp.reshape( - observation.blocks, (-1, self.num_blocks, 9) - ).astype(float) + flattened_blocks = jnp.reshape(observation.blocks, (-1, self.num_blocks, 9)).astype(float) # Encode the blocks with an MLP block_encoder = hk.nets.MLP(output_sizes=[self.model_size]) - blocks_embedding = jax.vmap(block_encoder)( - flattened_blocks - ) # (B, num_blocks, model_size) + blocks_embedding = jax.vmap(block_encoder)(flattened_blocks) # (B, num_blocks, model_size) unet = UNet(hidden_size=self.hidden_size, model_size=self.model_size) grid_conv_encoding, grid_encoding = unet( @@ -210,9 +198,7 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]: # Map blocks embedding from (num_blocks, 128) to (num_blocks, num_rotations, hidden_size) blocks_head = hk.nets.MLP(output_sizes=[4 * self.hidden_size]) blocks_embedding = jax.vmap(blocks_head)(blocks_embedding) - blocks_embedding = jnp.reshape( - blocks_embedding, (-1, self.num_blocks, 4, self.hidden_size) - ) + blocks_embedding = jnp.reshape(blocks_embedding, (-1, self.num_blocks, 4, self.hidden_size)) return blocks_embedding, grid_encoding @@ -236,13 +222,9 @@ def network_fn(observation: Observation) -> chex.Array: name="policy_torso", ) blocks_embedding, grid_embedding = torso(observation) - outer_product = jnp.einsum( - "...ijh,...klh->...ijkl", blocks_embedding, grid_embedding - ) + outer_product = jnp.einsum("...ijh,...klh->...ijkl", blocks_embedding, grid_embedding) - logits = jnp.where( - observation.action_mask, outer_product, jnp.finfo(jnp.float32).min - ) + logits = jnp.where(observation.action_mask, outer_product, jnp.finfo(jnp.float32).min) logits = logits.reshape(*logits.shape[:-4], -1) return logits diff --git a/jumanji/training/networks/flat_pack/random.py b/jumanji/training/networks/flat_pack/random.py index 7c8c09463..007ea3c31 100644 --- a/jumanji/training/networks/flat_pack/random.py +++ b/jumanji/training/networks/flat_pack/random.py @@ -23,6 +23,4 @@ def make_random_policy_flat_pack(flat_pack: FlatPack) -> RandomPolicy: """Make random policy for FlatPack.""" action_spec_num_values = flat_pack.action_spec.num_values - return make_masked_categorical_random_ndim( - action_spec_num_values=action_spec_num_values - ) + return make_masked_categorical_random_ndim(action_spec_num_values=action_spec_num_values) diff --git a/jumanji/training/networks/game_2048/actor_critic.py b/jumanji/training/networks/game_2048/actor_critic.py index caa4351dd..8e0ddbf6a 100644 --- a/jumanji/training/networks/game_2048/actor_critic.py +++ b/jumanji/training/networks/game_2048/actor_critic.py @@ -38,9 +38,7 @@ def make_actor_critic_networks_game_2048( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Game2048` environment.""" num_actions = game_2048.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_network_cnn( num_outputs=num_actions, mlp_units=policy_layers, @@ -82,9 +80,7 @@ def network_fn(observation: Observation) -> chex.Array: return jnp.squeeze(head(embedding), axis=-1) else: logits = head(embedding) - masked_logits = jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + masked_logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return masked_logits init, apply = hk.without_apply_rng(hk.transform(network_fn)) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 62187d709..ce846bc6c 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -39,9 +39,7 @@ def make_actor_critic_networks_graph_coloring( ) -> ActorCriticNetworks: """Make actor-critic networks for the `GraphColoring` environment.""" num_actions = graph_coloring.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_actor_network_graph_coloring( num_actions=num_actions, num_transformer_layers=num_transformer_layers, @@ -210,9 +208,7 @@ def __call__(self, observation: Observation) -> chex.Array: node_embeddings = new_node_embeddings - current_node_embeddings = jnp.take( - node_embeddings, observation.current_node_index, axis=1 - ) + current_node_embeddings = jnp.take(node_embeddings, observation.current_node_index, axis=1) new_embedding = TransformerBlock( num_heads=self.transformer_num_heads, key_size=self.transformer_key_size, @@ -241,9 +237,7 @@ def network_fn(observation: Observation) -> chex.Array: name="policy_torso", ) embeddings = torso(observation) # (B, N, H) - logits = hk.nets.MLP((torso.model_size, 1), name="policy_head")( - embeddings - ) # (B, N, 1) + logits = hk.nets.MLP((torso.model_size, 1), name="policy_head")(embeddings) # (B, N, 1) logits = jnp.squeeze(logits, axis=-1) # (B, N) logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits diff --git a/jumanji/training/networks/job_shop/actor_critic.py b/jumanji/training/networks/job_shop/actor_critic.py index 77d070a17..fde796381 100644 --- a/jumanji/training/networks/job_shop/actor_critic.py +++ b/jumanji/training/networks/job_shop/actor_critic.py @@ -43,9 +43,7 @@ def make_actor_critic_networks_job_shop( ) -> ActorCriticNetworks: """Create an actor-critic network for the `JobShop` environment.""" num_values = np.asarray(job_shop.action_spec.num_values) - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) policy_network = make_actor_network_job_shop( num_layers_machines=num_layers_machines, num_layers_operations=num_layers_operations, @@ -94,34 +92,24 @@ def __call__(self, observation: Observation) -> chex.Array: m_remaining_times = observation.machines_remaining_times.astype(float)[ ..., None ] # (B, M, 1) - machine_embeddings = self.self_attention_machines( - m_remaining_times - ) # (B, M, D) + machine_embeddings = self.self_attention_machines(m_remaining_times) # (B, M, D) # Job encoder o_machine_ids = observation.ops_machine_ids # (B, J, O) o_durations = observation.ops_durations.astype(float) # (B, J, O) o_mask = observation.ops_mask # (B, J, O) - job_embeddings = jax.vmap( - self.job_encoder, in_axes=(-2, -2, -2, None), out_axes=-2 - )( + job_embeddings = jax.vmap(self.job_encoder, in_axes=(-2, -2, -2, None), out_axes=-2)( o_durations, o_machine_ids, o_mask, machine_embeddings, ) # (B, J, D) # Add embedding for no-op - no_op_emb = hk.Linear(self.model_size)( - jnp.ones((o_mask.shape[0], 1, 1)) - ) # (B, 1, D) - job_embeddings = jnp.concatenate( - [job_embeddings, no_op_emb], axis=-2 - ) # (B, J+1, D) + no_op_emb = hk.Linear(self.model_size)(jnp.ones((o_mask.shape[0], 1, 1))) # (B, 1, D) + job_embeddings = jnp.concatenate([job_embeddings, no_op_emb], axis=-2) # (B, J+1, D) # Joint (machines & jobs) self-attention - embeddings = jnp.concatenate( - [machine_embeddings, job_embeddings], axis=-2 - ) # (M+J+1, D) + embeddings = jnp.concatenate([machine_embeddings, job_embeddings], axis=-2) # (M+J+1, D) embeddings = self.self_attention_joint_machines_ops(embeddings) return embeddings @@ -263,16 +251,10 @@ def network_fn(observation: Observation) -> chex.Array: ) embeddings = torso(observation) # (B, M+J+1, D) num_machines = observation.machines_remaining_times.shape[-1] - machine_embeddings, job_embeddings = jnp.split( - embeddings, (num_machines,), axis=-2 - ) - machine_embeddings = hk.Linear(32, name="policy_head_machines")( - machine_embeddings - ) + machine_embeddings, job_embeddings = jnp.split(embeddings, (num_machines,), axis=-2) + machine_embeddings = hk.Linear(32, name="policy_head_machines")(machine_embeddings) job_embeddings = hk.Linear(32, name="policy_head_jobs")(job_embeddings) - logits = jnp.einsum( - "...mk,...jk->...mj", machine_embeddings, job_embeddings - ) # (B, M, J+1) + logits = jnp.einsum("...mk,...jk->...mj", machine_embeddings, job_embeddings) # (B, M, J+1) logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits @@ -318,9 +300,7 @@ def __init__(self, d_model: int, max_len: int = 5000, name: Optional[str] = None # for an input sequence of length max_len pos_enc = jnp.zeros((self.max_len, self.d_model)) position = jnp.arange(0, self.max_len, dtype=np.float32)[:, None] - div_term = jnp.exp( - jnp.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model) - ) + div_term = jnp.exp(jnp.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model)) pos_enc = pos_enc.at[:, 0::2].set(jnp.sin(position * div_term)) pos_enc = pos_enc.at[:, 1::2].set(jnp.cos(position * div_term)) pos_enc = pos_enc[None] # (1, max_len, d_model) diff --git a/jumanji/training/networks/knapsack/actor_critic.py b/jumanji/training/networks/knapsack/actor_critic.py index b8a676e23..3a1eae4fd 100644 --- a/jumanji/training/networks/knapsack/actor_critic.py +++ b/jumanji/training/networks/knapsack/actor_critic.py @@ -37,9 +37,7 @@ def make_actor_critic_networks_knapsack( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Knapsack` environment.""" num_actions = knapsack.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_actor_network_knapsack( transformer_num_blocks=transformer_num_blocks, transformer_num_heads=transformer_num_heads, diff --git a/jumanji/training/networks/lbf/actor_critic.py b/jumanji/training/networks/lbf/actor_critic.py index bea7a75dd..650f1c217 100644 --- a/jumanji/training/networks/lbf/actor_critic.py +++ b/jumanji/training/networks/lbf/actor_critic.py @@ -40,9 +40,7 @@ def make_actor_critic_networks_lbf( ) -> ActorCriticNetworks: """Make actor-critic networks for the `LevelBasedForaging` environment.""" num_values = np.asarray(lbf_env.action_spec.num_values) - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) policy_network = make_actor_network( time_limit=lbf_env.time_limit, transformer_num_blocks=transformer_num_blocks, @@ -111,9 +109,7 @@ def __call__(self, observation: Observation) -> chex.Array: model_size=self.model_size, name=f"self_attention_block_{block_id}", ) - embeddings = transformer_block( - query=embeddings, key=embeddings, value=embeddings - ) + embeddings = transformer_block(query=embeddings, key=embeddings, value=embeddings) return embeddings # (B, N, H) diff --git a/jumanji/training/networks/masked_categorical_random.py b/jumanji/training/networks/masked_categorical_random.py index 23ff5eb19..072efea64 100644 --- a/jumanji/training/networks/masked_categorical_random.py +++ b/jumanji/training/networks/masked_categorical_random.py @@ -47,9 +47,7 @@ def make_masked_categorical_random_ndim( def policy(observation: ObservationWithActionMask, key: chex.PRNGKey) -> chex.Array: """Sample uniformly at random from a joint distribution with masking""" n = action_spec_num_values.shape[0] - action_mask = observation.action_mask.reshape( - (observation.action_mask.shape[0], -1) - ) + action_mask = observation.action_mask.reshape((observation.action_mask.shape[0], -1)) flatten_logits = jnp.where( action_mask, jnp.zeros_like(action_mask), diff --git a/jumanji/training/networks/maze/actor_critic.py b/jumanji/training/networks/maze/actor_critic.py index 8d236ef5a..068e7b148 100644 --- a/jumanji/training/networks/maze/actor_critic.py +++ b/jumanji/training/networks/maze/actor_critic.py @@ -38,9 +38,7 @@ def make_actor_critic_networks_maze( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Maze` environment.""" num_actions = np.asarray(maze.action_spec.num_values) - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_network_maze( maze=maze, critic=False, @@ -99,9 +97,7 @@ def network_fn(observation: Observation) -> chex.Array: normalised_step_count = ( jnp.expand_dims(observation.step_count, axis=-1) / maze.time_limit ) # (B, 1) - output = jnp.concatenate( - [embedding, normalised_step_count], axis=-1 - ) # (B, H+1) + output = jnp.concatenate([embedding, normalised_step_count], axis=-1) # (B, H+1) if critic: head = hk.nets.MLP((*mlp_units, 1), activate_final=False) @@ -109,9 +105,7 @@ def network_fn(observation: Observation) -> chex.Array: else: head = hk.nets.MLP((*mlp_units, num_actions), activate_final=False) logits = head(output) - return jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + return jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) init, apply = hk.without_apply_rng(hk.transform(network_fn)) return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index 593673f9f..8dea55279 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -90,9 +90,7 @@ def make_network_cnn( def network_fn(observation: Observation) -> chex.Array: conv_layers = [ [ - hk.Conv2D( - output_channels=output_channels, kernel_shape=board_kernel_shape - ), + hk.Conv2D(output_channels=output_channels, kernel_shape=board_kernel_shape), jax.nn.relu, ] for output_channels in board_conv_channels @@ -105,9 +103,9 @@ def network_fn(observation: Observation) -> chex.Array: ) x = board_embedder(observation.board + 1) num_mines_embedder = hk.Linear(num_mines_embed_dim) - y = num_mines_embedder( - observation.num_mines[:, None] / (board_num_rows * board_num_cols) - )[:, None, None, :] + y = num_mines_embedder(observation.num_mines[:, None] / (board_num_rows * board_num_cols))[ + :, None, None, : + ] y = jnp.tile(y, [1, board_num_rows, board_num_cols, 1]) output = jnp.concatenate([x, y], axis=-1) final_layers = hk.nets.MLP((*final_layer_dims, 1)) diff --git a/jumanji/training/networks/minesweeper/random.py b/jumanji/training/networks/minesweeper/random.py index c7194091f..ebdf33433 100644 --- a/jumanji/training/networks/minesweeper/random.py +++ b/jumanji/training/networks/minesweeper/random.py @@ -24,6 +24,4 @@ def make_random_policy_minesweeper(minesweeper: Minesweeper) -> RandomPolicy: """Make random policy for Minesweeper.""" action_spec_num_values = minesweeper.action_spec.num_values - return make_masked_categorical_random_ndim( - action_spec_num_values=action_spec_num_values - ) + return make_masked_categorical_random_ndim(action_spec_num_values=action_spec_num_values) diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index 18b21dec3..f3edc5082 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -39,9 +39,7 @@ def make_actor_critic_networks_mmst( ) -> ActorCriticNetworks: """Make actor-critic networks for the `MMST` environment.""" num_values = mmst.action_spec.num_values - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) policy_network = make_actor_network_mmst( num_transformer_layers=num_transformer_layers, transformer_num_heads=transformer_num_heads, diff --git a/jumanji/training/networks/multi_cvrp/actor_critic.py b/jumanji/training/networks/multi_cvrp/actor_critic.py index 3300b7835..c54e904af 100644 --- a/jumanji/training/networks/multi_cvrp/actor_critic.py +++ b/jumanji/training/networks/multi_cvrp/actor_critic.py @@ -271,13 +271,9 @@ def network_fn(observation: Observation) -> chex.Array: ) vehicle_embeddings, customer_embeddings = torso(observation) # (B, V+C+1, D) - vehicle_embeddings = hk.Linear(32, name="policy_head_vehicles")( - vehicle_embeddings - ) + vehicle_embeddings = hk.Linear(32, name="policy_head_vehicles")(vehicle_embeddings) - customer_embeddings = hk.Linear(32, name="policy_head_customers")( - customer_embeddings - ) + customer_embeddings = hk.Linear(32, name="policy_head_customers")(customer_embeddings) logits = jnp.einsum( "...vk,...ck->...vc", vehicle_embeddings, customer_embeddings diff --git a/jumanji/training/networks/pac_man/actor_critic.py b/jumanji/training/networks/pac_man/actor_critic.py index 59ca353ad..2a8325def 100644 --- a/jumanji/training/networks/pac_man/actor_critic.py +++ b/jumanji/training/networks/pac_man/actor_critic.py @@ -38,9 +38,7 @@ def make_actor_critic_networks_pacman( ) -> ActorCriticNetworks: """Make actor-critic networks for the `PacMan` environment.""" num_actions = np.asarray(pac_man.action_spec.num_values) - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_network_pac_man( pac_man=pac_man, critic=False, @@ -157,9 +155,7 @@ def network_fn(observation: Observation) -> chex.Array: obs = rgb_observation.astype(float) # Get player position, scatter_time and ghost locations - player_pos = jnp.array( - [observation.player_locations.x, observation.player_locations.y] - ) + player_pos = jnp.array([observation.player_locations.x, observation.player_locations.y]) player_pos = jnp.stack(player_pos, axis=-1) scatter_time = observation.frightened_state_time / 60 scatter_time = jnp.expand_dims(scatter_time, axis=-1) @@ -181,9 +177,7 @@ def network_fn(observation: Observation) -> chex.Array: else: head = hk.nets.MLP((*mlp_units, num_actions), activate_final=False) logits = head(output) - return jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + return jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) init, apply = hk.without_apply_rng(hk.transform(network_fn)) return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/parametric_distribution.py b/jumanji/training/networks/parametric_distribution.py index 108825104..325c17e9b 100644 --- a/jumanji/training/networks/parametric_distribution.py +++ b/jumanji/training/networks/parametric_distribution.py @@ -69,9 +69,7 @@ def postprocess(self, event: chex.Array) -> chex.Array: def inverse_postprocess(self, event: chex.Array) -> chex.Array: return self._postprocessor.inverse(event) - def sample_no_postprocessing( - self, parameters: chex.Array, seed: chex.PRNGKey - ) -> Any: + def sample_no_postprocessing(self, parameters: chex.Array, seed: chex.PRNGKey) -> Any: """Returns a sample of the distribution before postprocessing it.""" return self.create_dist(parameters).sample(seed=seed) @@ -105,9 +103,7 @@ def entropy(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array: entropy = jnp.sum(entropy, axis=-1) return entropy - def kl_divergence( - self, parameters: chex.Array, other_parameters: chex.Array - ) -> chex.Array: + def kl_divergence(self, parameters: chex.Array, other_parameters: chex.Array) -> chex.Array: """KL divergence is invariant with respect to transformation by the same bijector.""" if not isinstance(self._postprocessor, IdentityBijector): raise ValueError( @@ -173,9 +169,7 @@ def __init__(self, action_spec_num_values: chex.ArrayNumpy): posprocessor = FactorisedActionSpaceReshapeBijector( action_spec_num_values=action_spec_num_values ) - super().__init__( - param_size=num_actions, postprocessor=posprocessor, event_ndims=0 - ) + super().__init__(param_size=num_actions, postprocessor=posprocessor, event_ndims=0) def create_dist(self, parameters: chex.Array) -> CategoricalDistribution: return CategoricalDistribution(logits=parameters) diff --git a/jumanji/training/networks/postprocessor.py b/jumanji/training/networks/postprocessor.py index 319036269..e97fe3a19 100644 --- a/jumanji/training/networks/postprocessor.py +++ b/jumanji/training/networks/postprocessor.py @@ -59,9 +59,7 @@ def forward(self, x: chex.Array) -> chex.Array: flat_action = x n = self.action_spec_num_values.shape[0] for i in range(n - 1, 0, -1): - flat_action, remainder = jnp.divmod( - flat_action, self.action_spec_num_values[i] - ) + flat_action, remainder = jnp.divmod(flat_action, self.action_spec_num_values[i]) action_components.append(remainder) action_components.append(flat_action) action = jnp.stack( @@ -76,9 +74,7 @@ def inverse(self, y: chex.Array) -> chex.Array: action_components = jnp.split(y, n, axis=-1) flat_action = action_components[0] for i in range(1, n): - flat_action = ( - self.action_spec_num_values[i] * flat_action + action_components[i] - ) + flat_action = self.action_spec_num_values[i] * flat_action + action_components[i] return flat_action def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: diff --git a/jumanji/training/networks/robot_warehouse/actor_critic.py b/jumanji/training/networks/robot_warehouse/actor_critic.py index 965caf397..9045a86e4 100644 --- a/jumanji/training/networks/robot_warehouse/actor_critic.py +++ b/jumanji/training/networks/robot_warehouse/actor_critic.py @@ -40,9 +40,7 @@ def make_actor_critic_networks_robot_warehouse( ) -> ActorCriticNetworks: """Make actor-critic networks for the `RobotWarehouse` environment.""" num_values = np.asarray(robot_warehouse.action_spec.num_values) - parametric_action_distribution = MultiCategoricalParametricDistribution( - num_values=num_values - ) + parametric_action_distribution = MultiCategoricalParametricDistribution(num_values=num_values) policy_network = make_actor_network( time_limit=robot_warehouse.time_limit, transformer_num_blocks=transformer_num_blocks, @@ -111,9 +109,7 @@ def __call__(self, observation: Observation) -> chex.Array: model_size=self.model_size, name=f"self_attention_block_{block_id}", ) - embeddings = transformer_block( - query=embeddings, key=embeddings, value=embeddings - ) + embeddings = transformer_block(query=embeddings, key=embeddings, value=embeddings) return embeddings # (B, N, H) diff --git a/jumanji/training/networks/rubiks_cube/actor_critic.py b/jumanji/training/networks/rubiks_cube/actor_critic.py index 53a2643ac..281ca671a 100644 --- a/jumanji/training/networks/rubiks_cube/actor_critic.py +++ b/jumanji/training/networks/rubiks_cube/actor_critic.py @@ -71,15 +71,11 @@ def make_torso_network_fn( def torso_network_fn(observation: Observation) -> chex.Array: # Cube embedding cube_embedder = hk.Embed(vocab_size=len(Face), embed_dim=cube_embed_dim) - cube_embedding = cube_embedder(observation.cube).reshape( - *observation.cube.shape[:-3], -1 - ) + cube_embedding = cube_embedder(observation.cube).reshape(*observation.cube.shape[:-3], -1) # Step count embedding step_count_embedder = hk.Linear(step_count_embed_dim) - step_count_embedding = step_count_embedder( - observation.step_count[:, None] / time_limit - ) + step_count_embedding = step_count_embedder(observation.step_count[:, None] / time_limit) embedding = jnp.concatenate([cube_embedding, step_count_embedding], axis=-1) return embedding diff --git a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py index 5c4a6752c..6187a2fce 100644 --- a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py +++ b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py @@ -41,9 +41,7 @@ def make_actor_critic_networks_sliding_tile_puzzle( ) -> ActorCriticNetworks: """Make actor-critic networks for the `SlidingTilePuzzle` environment.""" num_actions = sliding_tile_puzzle.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_mlp_network( num_outputs=num_actions, mlp_units=policy_layers, @@ -86,9 +84,7 @@ def network_fn(observation: Observation) -> chex.Array: return jnp.squeeze(head(embedding), axis=-1) else: logits = head(embedding) - masked_logits = jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + masked_logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return masked_logits init, apply = hk.without_apply_rng(hk.transform(network_fn)) diff --git a/jumanji/training/networks/snake/actor_critic.py b/jumanji/training/networks/snake/actor_critic.py index 0be42e223..6cc26ee4e 100644 --- a/jumanji/training/networks/snake/actor_critic.py +++ b/jumanji/training/networks/snake/actor_critic.py @@ -37,9 +37,7 @@ def make_actor_critic_networks_snake( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Snake` environment.""" num_actions = snake.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_snake_cnn( num_outputs=num_actions, mlp_units=policy_layers, @@ -84,9 +82,7 @@ def network_fn(observation: Observation) -> chex.Array: return value else: logits = head(embedding) - logits = jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits init, apply = hk.without_apply_rng(hk.transform(network_fn)) diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py index 37942b7dd..5d32f874a 100644 --- a/jumanji/training/networks/sokoban/actor_critic.py +++ b/jumanji/training/networks/sokoban/actor_critic.py @@ -37,9 +37,7 @@ def make_actor_critic_networks_sokoban( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sokoban` environment.""" num_actions = sokoban.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_sokoban_cnn( num_outputs=num_actions, @@ -68,7 +66,6 @@ def make_sokoban_cnn( time_limit: int, ) -> FeedForwardNetwork: def network_fn(observation: Observation) -> chex.Array: - # Iterate over the channels sequence to create convolutional layers layers = [] for i, conv_n_channels in enumerate(channels): @@ -101,14 +98,9 @@ def network_fn(observation: Observation) -> chex.Array: def preprocess_input( input_array: chex.Array, ) -> chex.Array: + one_hot_array_fixed = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype(jnp.float32) - one_hot_array_fixed = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype( - jnp.float32 - ) - - one_hot_array_variable = jnp.equal(input_array[..., 1:2], jnp.array([1, 2])).astype( - jnp.float32 - ) + one_hot_array_variable = jnp.equal(input_array[..., 1:2], jnp.array([1, 2])).astype(jnp.float32) total = jnp.concatenate((one_hot_array_fixed, one_hot_array_variable), axis=-1) diff --git a/jumanji/training/networks/sudoku/actor_critic.py b/jumanji/training/networks/sudoku/actor_critic.py index 8ba664e65..97c7737b8 100644 --- a/jumanji/training/networks/sudoku/actor_critic.py +++ b/jumanji/training/networks/sudoku/actor_critic.py @@ -120,9 +120,7 @@ def network_fn(observation: Observation) -> chex.Array: logits = head(embedding) logits = logits.reshape(-1, BOARD_WIDTH, BOARD_WIDTH, BOARD_WIDTH) - logits = jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits.reshape(observation.action_mask.shape[0], -1) @@ -162,13 +160,9 @@ def network_fn(observation: Observation) -> chex.Array: logits = jnp.transpose(logits, (0, 2, 1)) - logits = logits.reshape( - board.shape[0], BOARD_WIDTH, BOARD_WIDTH, BOARD_WIDTH - ) + logits = logits.reshape(board.shape[0], BOARD_WIDTH, BOARD_WIDTH, BOARD_WIDTH) - logits = jnp.where( - observation.action_mask, logits, jnp.finfo(jnp.float32).min - ) + logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) return logits.reshape(observation.action_mask.shape[0], -1) diff --git a/jumanji/training/networks/sudoku/random.py b/jumanji/training/networks/sudoku/random.py index e2cce1fde..b0452708e 100644 --- a/jumanji/training/networks/sudoku/random.py +++ b/jumanji/training/networks/sudoku/random.py @@ -25,6 +25,4 @@ def make_random_policy_sudoku(sudoku: Sudoku) -> RandomPolicy: action_spec_num_values = sudoku.action_spec.num_values - return make_masked_categorical_random_ndim( - action_spec_num_values=action_spec_num_values - ) + return make_masked_categorical_random_ndim(action_spec_num_values=action_spec_num_values) diff --git a/jumanji/training/networks/tetris/actor_critic.py b/jumanji/training/networks/tetris/actor_critic.py index 4e37052fd..44cd5087f 100644 --- a/jumanji/training/networks/tetris/actor_critic.py +++ b/jumanji/training/networks/tetris/actor_critic.py @@ -82,9 +82,7 @@ def network_fn(observation: Observation) -> chex.Array: jax.nn.relu, ] ) - grid_embeddings = grid_net( - observation.grid.astype(float)[..., None] - ) # [B, 2, 10, 64] + grid_embeddings = grid_net(observation.grid.astype(float)[..., None]) # [B, 2, 10, 64] grid_embeddings = jnp.transpose(grid_embeddings, [0, 2, 1, 3]) # [B, 10, 2, 64] grid_embeddings = jnp.reshape( grid_embeddings, [*grid_embeddings.shape[:2], -1] @@ -101,9 +99,7 @@ def network_fn(observation: Observation) -> chex.Array: tetromino_embeddings[:, None], (grid_embeddings.shape[1], 1) ) norm_step_count = observation.step_count / time_limit - norm_step_count = jnp.tile( - norm_step_count[:, None, None], (grid_embeddings.shape[1], 1) - ) + norm_step_count = jnp.tile(norm_step_count[:, None, None], (grid_embeddings.shape[1], 1)) embedding = jnp.concatenate( [grid_embeddings, tetromino_embeddings, norm_step_count], axis=-1 diff --git a/jumanji/training/networks/tetris/random.py b/jumanji/training/networks/tetris/random.py index eef995792..3236fed42 100644 --- a/jumanji/training/networks/tetris/random.py +++ b/jumanji/training/networks/tetris/random.py @@ -22,6 +22,4 @@ def make_random_policy_tetris(tetris: Tetris) -> RandomPolicy: """Make random policy for `Tetris`.""" action_spec_num_values = tetris.action_spec.num_values - return make_masked_categorical_random_ndim( - action_spec_num_values=action_spec_num_values - ) + return make_masked_categorical_random_ndim(action_spec_num_values=action_spec_num_values) diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index 1d720743b..97263ac5b 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -39,9 +39,7 @@ def make_actor_critic_networks_tsp( ) -> ActorCriticNetworks: """Make actor-critic networks for the `TSP` environment.""" num_actions = tsp.action_spec.num_values - parametric_action_distribution = CategoricalParametricDistribution( - num_actions=num_actions - ) + parametric_action_distribution = CategoricalParametricDistribution(num_actions=num_actions) policy_network = make_actor_network_tsp( transformer_num_blocks=transformer_num_blocks, transformer_num_heads=transformer_num_heads, @@ -83,9 +81,7 @@ def __init__( self.model_size = transformer_num_heads * transformer_key_size def __call__(self, coordinates: chex.Array, mask: chex.Array) -> chex.Array: - embeddings = hk.Linear(self.model_size, name="coordinates_projection")( - coordinates - ) + embeddings = hk.Linear(self.model_size, name="coordinates_projection")(coordinates) for block_id in range(self.transformer_num_blocks): transformer_block = TransformerBlock( num_heads=self.transformer_num_heads, diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index 4c743497f..d8612bed9 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -71,9 +71,7 @@ def setup_logger(cfg: DictConfig) -> Logger: if jax.process_index() != 0: return NoOpLogger() if cfg.logger.type == "tensorboard": - logger = TensorboardLogger( - name=cfg.logger.name, save_checkpoint=cfg.logger.save_checkpoint - ) + logger = TensorboardLogger(name=cfg.logger.name, save_checkpoint=cfg.logger.save_checkpoint) elif cfg.logger.type == "neptune": logger = NeptuneLogger( name=cfg.logger.name, @@ -82,9 +80,7 @@ def setup_logger(cfg: DictConfig) -> Logger: save_checkpoint=cfg.logger.save_checkpoint, ) elif cfg.logger.type == "terminal": - logger = TerminalLogger( - name=cfg.logger.name, save_checkpoint=cfg.logger.save_checkpoint - ) + logger = TerminalLogger(name=cfg.logger.name, save_checkpoint=cfg.logger.save_checkpoint) else: raise ValueError( f"logger expected in ['neptune', 'tensorboard', 'terminal'], got {cfg.logger}." @@ -133,15 +129,11 @@ def setup_agent(cfg: DictConfig, env: Environment) -> Agent: l_en=cfg.env.a2c.l_en, ) else: - raise ValueError( - f"Expected agent name to be in ['random', 'a2c'], got {cfg.agent}." - ) + raise ValueError(f"Expected agent name to be in ['random', 'a2c'], got {cfg.agent}.") return agent -def _setup_random_policy( # noqa: CCR001 - cfg: DictConfig, env: Environment -) -> RandomPolicy: +def _setup_random_policy(cfg: DictConfig, env: Environment) -> RandomPolicy: assert cfg.agent == "random" if cfg.env.name == "bin_pack": assert isinstance(env.unwrapped, BinPack) @@ -169,14 +161,10 @@ def _setup_random_policy( # noqa: CCR001 random_policy = networks.make_random_policy_multicvrp() elif cfg.env.name == "rubiks_cube": assert isinstance(env.unwrapped, RubiksCube) - random_policy = networks.make_random_policy_rubiks_cube( - rubiks_cube=env.unwrapped - ) + random_policy = networks.make_random_policy_rubiks_cube(rubiks_cube=env.unwrapped) elif cfg.env.name == "minesweeper": assert isinstance(env.unwrapped, Minesweeper) - random_policy = networks.make_random_policy_minesweeper( - minesweeper=env.unwrapped - ) + random_policy = networks.make_random_policy_minesweeper(minesweeper=env.unwrapped) elif cfg.env.name == "game_2048": assert isinstance(env.unwrapped, Game2048) random_policy = networks.make_random_policy_game_2048() @@ -223,9 +211,7 @@ def _setup_random_policy( # noqa: CCR001 return random_policy -def _setup_actor_critic_neworks( # noqa: CCR001 - cfg: DictConfig, env: Environment -) -> ActorCriticNetworks: +def _setup_actor_critic_neworks(cfg: DictConfig, env: Environment) -> ActorCriticNetworks: assert cfg.agent == "a2c" if cfg.env.name == "bin_pack": assert isinstance(env.unwrapped, BinPack) @@ -457,9 +443,7 @@ def setup_evaluators(cfg: DictConfig, agent: Agent) -> Tuple[Evaluator, Evaluato return stochastic_eval, greedy_eval -def setup_training_state( - env: Environment, agent: Agent, key: chex.PRNGKey -) -> TrainingState: +def setup_training_state(env: Environment, agent: Agent, key: chex.PRNGKey) -> TrainingState: params_key, reset_key, acting_key = jax.random.split(key, 3) # Initialize params. @@ -479,9 +463,7 @@ def setup_training_state( ) ) reset_keys_per_worker = reset_keys[jax.process_index()] - env_state, timestep = jax.pmap(env.reset, axis_name="devices")( - reset_keys_per_worker - ) + env_state, timestep = jax.pmap(env.reset, axis_name="devices")(reset_keys_per_worker) # Initialize acting states. acting_key_per_device = jax.random.split(acting_key, num_global_devices).reshape( diff --git a/jumanji/training/timer.py b/jumanji/training/timer.py index 3d03d55d3..4d650ee5b 100644 --- a/jumanji/training/timer.py +++ b/jumanji/training/timer.py @@ -13,7 +13,7 @@ # limitations under the License. # Inspired from https://stackoverflow.com/questions/51849395/how-can-we-associate-a-python-context-m -# anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables +# anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables # noqa: E501 from __future__ import annotations import inspect @@ -56,16 +56,12 @@ def __exit__(self, *exc: Any) -> Literal[False]: self._variables_exit = self._get_variables() self.data = {"time": elapsed_time} if self.num_steps_per_timing is not None: - self.data.update( - steps_per_second=int(self.num_steps_per_timing / elapsed_time) - ) + self.data.update(steps_per_second=int(self.num_steps_per_timing / elapsed_time)) self._write_in_variable(self.data) return False def _write_in_variable(self, data: Dict[str, float]) -> None: - in_context_variables = dict( - set(self._variables_exit).difference(self._variables_enter) - ) + in_context_variables = dict(set(self._variables_exit).difference(self._variables_enter)) metrics_id = in_context_variables.get(self.out_var_name, None) if metrics_id is not None: self._variables_exit[("metrics", metrics_id)].update(data) diff --git a/jumanji/training/train.py b/jumanji/training/train.py index 4d01d7785..ed14fbeab 100644 --- a/jumanji/training/train.py +++ b/jumanji/training/train.py @@ -54,9 +54,7 @@ def train(cfg: omegaconf.DictConfig, log_compiles: bool = False) -> None: * cfg.env.training.num_learner_steps_per_epoch ) eval_timer = Timer(out_var_name="metrics") - train_timer = Timer( - out_var_name="metrics", num_steps_per_timing=num_steps_per_epoch - ) + train_timer = Timer(out_var_name="metrics", num_steps_per_timing=num_steps_per_epoch) @functools.partial(jax.pmap, axis_name="devices") def epoch_fn(training_state: TrainingState) -> Tuple[TrainingState, Dict]: diff --git a/jumanji/tree_utils.py b/jumanji/tree_utils.py index d7a05300a..9d5e46496 100644 --- a/jumanji/tree_utils.py +++ b/jumanji/tree_utils.py @@ -64,7 +64,5 @@ def tree_add_element(tree: T, i: chex.Numeric, element: T) -> T: tree whose elements are the same as before but with the ith value being set to that of the given element. """ - new_tree: T = jax.tree_util.tree_map( - lambda array, value: array.at[i].set(value), tree, element - ) + new_tree: T = jax.tree_util.tree_map(lambda array, value: array.at[i].set(value), tree, element) return new_tree diff --git a/jumanji/tree_utils_test.py b/jumanji/tree_utils_test.py index 758b2e548..ce7110193 100644 --- a/jumanji/tree_utils_test.py +++ b/jumanji/tree_utils_test.py @@ -55,7 +55,5 @@ def test_tree_slice() -> None: ), ], ) -def test_tree_add_element( - tree: T, i: chex.Numeric, element: T, expected_tree: T -) -> None: +def test_tree_add_element(tree: T, i: chex.Numeric, element: T, expected_tree: T) -> None: assert_trees_are_equal(tree_add_element(tree, i, element), expected_tree) diff --git a/jumanji/types_test.py b/jumanji/types_test.py index 3c9d3c92f..82d264e49 100644 --- a/jumanji/types_test.py +++ b/jumanji/types_test.py @@ -67,12 +67,8 @@ def get_termination_transition() -> TimeStep: """Returns either a termination or transition TimeStep.""" timestep_termination_transition: TimeStep = lax.cond( done, - lambda _: termination( - reward=jnp.zeros((), float), observation=jnp.zeros((), float) - ), - lambda _: transition( - reward=jnp.zeros((), float), observation=jnp.zeros((), float) - ), + lambda _: termination(reward=jnp.zeros((), float), observation=jnp.zeros((), float)), + lambda _: transition(reward=jnp.zeros((), float), observation=jnp.zeros((), float)), None, ) return timestep_termination_transition @@ -86,9 +82,7 @@ def get_restart_truncation() -> TimeStep: timestep_restart_truncation: TimeStep = lax.cond( done, lambda _: restart(observation=jnp.zeros((), float)), - lambda _: truncation( - reward=jnp.zeros((), float), observation=jnp.zeros((), float) - ), + lambda _: truncation(reward=jnp.zeros((), float), observation=jnp.zeros((), float)), None, ) return timestep_restart_truncation diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 04a87db45..3a2c6880f 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -31,9 +31,7 @@ GymObservation = Any -class Wrapper( - Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] -): +class Wrapper(Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]): """Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 """ @@ -43,7 +41,7 @@ def __init__(self, env: Environment[State, ActionSpec, Observation]): super().__init__() def __repr__(self) -> str: - return f"{self.__class__.__name__}({repr(self._env)})" + return f"{self.__class__.__name__}({self._env!r})" def __getattr__(self, name: str) -> Any: if name == "__setstate__": @@ -67,9 +65,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """ return self._env.reset(key) - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -125,9 +121,7 @@ def __exit__(self, *args: Any) -> None: self.close() -class JumanjiToDMEnvWrapper( - dm_env.Environment, Generic[State, ActionSpec, Observation] -): +class JumanjiToDMEnvWrapper(dm_env.Environment, Generic[State, ActionSpec, Observation]): """A wrapper that converts Environment to dm_env.Environment.""" def __init__( @@ -150,9 +144,9 @@ def __init__( self._jitted_reset: Callable[[chex.PRNGKey], Tuple[State, TimeStep]] = jax.jit( self._env.reset ) - self._jitted_step: Callable[ - [State, chex.Array], Tuple[State, TimeStep] - ] = jax.jit(self._env.step) + self._jitted_step: Callable[[State, chex.Array], Tuple[State, TimeStep]] = jax.jit( + self._env.step + ) def __repr__(self) -> str: return str(self._env.__repr__()) @@ -248,9 +242,7 @@ def __init__( self._reward_aggregator = reward_aggregator self._discount_aggregator = discount_aggregator - def _aggregate_timestep( - self, timestep: TimeStep[Observation] - ) -> TimeStep[Observation]: + def _aggregate_timestep(self, timestep: TimeStep[Observation]) -> TimeStep[Observation]: """Apply the reward and discount aggregator to a multi-agent timestep object to create a new timestep object that consists of a scalar reward and discount value. @@ -283,9 +275,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = self._aggregate_timestep(timestep) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. The rewards are aggregated into a single value based on the given reward aggregator. @@ -305,9 +295,7 @@ def step( return state, timestep -class VmapWrapper( - Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] -): +class VmapWrapper(Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]): """Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: @@ -337,9 +325,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: state, timestep = jax.vmap(self._env.reset)(key) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. The first dimension of the state will dictate the number of concurrent environments. @@ -450,9 +436,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = self._maybe_add_obs_to_extras(timestep) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Step the environment, with automatic resetting if the episode terminates.""" state, timestep = self._env.step(state, action) @@ -522,9 +506,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = self._maybe_add_obs_to_extras(timestep) return state, timestep - def step( - self, state: State, action: chex.Array - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of all environments' dynamics. It automatically resets environment(s) in which episodes have terminated. @@ -544,9 +526,7 @@ def step( # Vmap homogeneous computation (parallelizable). state, timestep = jax.vmap(self._env.step)(state, action) # Map heterogeneous computation (non-parallelizable). - state, timestep = jax.lax.map( - lambda args: self._maybe_reset(*args), (state, timestep) - ) + state, timestep = jax.lax.map(lambda args: self._maybe_reset(*args), (state, timestep)) return state, timestep def _auto_reset( @@ -626,9 +606,7 @@ def __init__( self._key = jax.random.PRNGKey(seed) self.backend = backend self._state = None - self.observation_space = specs.jumanji_specs_to_gym_spaces( - self._env.observation_spec - ) + self.observation_space = specs.jumanji_specs_to_gym_spaces(self._env.observation_spec) self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec) def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]: @@ -676,9 +654,7 @@ def reset( else: return obs # type: ignore - def step( - self, action: chex.ArrayNumpy - ) -> Tuple[GymObservation, float, bool, Optional[Any]]: + def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]]: """Updates the environment according to the action and returns an `Observation`. Args: @@ -744,9 +720,7 @@ def jumanji_to_gym_obs(observation: Observation) -> GymObservation: return np.asarray(observation) elif hasattr(observation, "__dict__"): # Applies to various containers including `chex.dataclass` - return { - key: jumanji_to_gym_obs(value) for key, value in vars(observation).items() - } + return {key: jumanji_to_gym_obs(value) for key, value in vars(observation).items()} elif hasattr(observation, "_asdict"): # Applies to `NamedTuple` container. return { diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 53d914f3a..375085d05 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -123,7 +123,7 @@ def test_wrapper__observation_spec( wrapped_fake_environment = mock_wrapper_class(fake_environment) mock_obs_spec.assert_called_once() - wrapped_fake_environment.observation_spec + wrapped_fake_environment.observation_spec # noqa: B018 mock_obs_spec.assert_called_once() def test_wrapper__action_spec( @@ -140,7 +140,7 @@ def test_wrapper__action_spec( wrapped_fake_environment = mock_wrapper_class(fake_environment) mock_action_spec.assert_called_once() - wrapped_fake_environment.action_spec + wrapped_fake_environment.action_spec # noqa: B018 mock_action_spec.assert_called_once() def test_wrapper__repr(self, wrapped_fake_environment: FakeWrapper) -> None: @@ -156,9 +156,7 @@ def test_wrapper__render( ) -> None: """Checks `Wrapper.render` calls the render method of the underlying env.""" - mock_action_spec = mocker.patch.object( - fake_environment, "render", autospec=True - ) + mock_action_spec = mocker.patch.object(fake_environment, "render", autospec=True) mock_state = mocker.MagicMock() wrapped_fake_environment.render(mock_state) @@ -187,9 +185,7 @@ def test_wrapper__getattr( assert wrapped_fake_environment.time_limit == fake_environment.time_limit -FakeJumanjiToDMEnvWrapper = JumanjiToDMEnvWrapper[ - FakeState, specs.BoundedArray, chex.Array -] +FakeJumanjiToDMEnvWrapper = JumanjiToDMEnvWrapper[FakeState, specs.BoundedArray, chex.Array] class TestJumanjiEnvironmentToDeepMindEnv: @@ -258,9 +254,7 @@ def fake_gym_env(self, time_limit: int = 10) -> FakeJumanjiToGymWrapper: """Creates a fake environment wrapped as a gym.Env.""" return JumanjiToGymWrapper(FakeEnvironment(time_limit=time_limit)) - def test_jumanji_environment_to_gym_env__init( - self, fake_environment: FakeEnvironment - ) -> None: + def test_jumanji_environment_to_gym_env__init(self, fake_environment: FakeEnvironment) -> None: """Validates initialization of the gym wrapper.""" gym_environment = JumanjiToGymWrapper(fake_environment) assert isinstance(gym_environment, gym.Env) @@ -312,9 +306,7 @@ def test_jumanji_environment_to_gym_env__render( mocker: pytest_mock.MockerFixture, fake_gym_env: FakeJumanjiToGymWrapper, ) -> None: - mock_render = mocker.patch.object( - fake_gym_env.unwrapped, "render", autospec=True - ) + mock_render = mocker.patch.object(fake_gym_env.unwrapped, "render", autospec=True) mock_state = mocker.MagicMock() with pytest.raises(ValueError): @@ -342,9 +334,7 @@ def test_jumanji_environment_to_gym_env__unwrapped( assert isinstance(fake_gym_env.unwrapped, Environment) -FakeMultiToSingleWrapper = MultiToSingleWrapper[ - FakeState, specs.BoundedArray, chex.Array -] +FakeMultiToSingleWrapper = MultiToSingleWrapper[FakeState, specs.BoundedArray, chex.Array] class TestMultiToSingleEnvironment: @@ -397,8 +387,7 @@ def test_multi_env__step( assert next_timestep.reward.shape == () assert ( next_timestep.reward - == fake_multi_environment.reward_per_step - * fake_multi_environment.num_agents + == fake_multi_environment.reward_per_step * fake_multi_environment.num_agents ) assert next_timestep.discount.shape == () assert next_timestep.observation.shape[0] == fake_multi_environment.num_agents @@ -464,9 +453,7 @@ def test_multi_env__unwrapped( class TestVmapWrapper: @pytest.fixture - def fake_vmap_environment( - self, fake_environment: FakeEnvironment - ) -> FakeVmapWrapper: + def fake_vmap_environment(self, fake_environment: FakeEnvironment) -> FakeVmapWrapper: return VmapWrapper(fake_environment) def test_vmap_wrapper__init(self, fake_environment: FakeEnvironment) -> None: @@ -491,9 +478,7 @@ def test_vmap_env__step( ) -> None: """Validates step function of the vmap environment.""" state, timestep = fake_vmap_environment.reset(keys) - action = jax.vmap(lambda _: fake_vmap_environment.action_spec.generate_value())( - keys - ) + action = jax.vmap(lambda _: fake_vmap_environment.action_spec.generate_value())(keys) state, next_timestep = jax.jit(fake_vmap_environment.step)(state, action) @@ -548,9 +533,7 @@ def test_auto_reset_wrapper__auto_reset( ) -> None: """Validates the auto_reset function of the AutoResetWrapper.""" state, timestep = fake_state_and_timestep - _, reset_timestep = jax.jit(fake_auto_reset_environment._auto_reset)( - state, timestep - ) + _, reset_timestep = jax.jit(fake_auto_reset_environment._auto_reset)(state, timestep) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) # Expect that non-reset timestep obs and extras are the same. assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) @@ -566,9 +549,7 @@ def test_auto_reset_wrapper__step_no_reset( # Generate an action action = fake_auto_reset_environment.action_spec.generate_value() - state, timestep = jax.jit(fake_auto_reset_environment.step)( - state, action - ) # type: Tuple[FakeState, TimeStep[chex.Array]] + state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) # type: Tuple[FakeState, TimeStep[chex.Array]] assert timestep.step_type == StepType.MID assert_trees_are_different(timestep, first_timestep) @@ -593,27 +574,19 @@ def test_auto_reset_wrapper__step_reset( for _ in range(fake_environment.time_limit - 1): action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) - assert jnp.all( - timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) state, final_timestep = jax.jit(fake_auto_reset_environment.step)(state, action) assert final_timestep.step_type == StepType.LAST - chex.assert_trees_all_equal( - final_timestep.observation, first_timestep.observation - ) + chex.assert_trees_all_equal(final_timestep.observation, first_timestep.observation) assert not jnp.all( final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] ) - assert jnp.all( - (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all((timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) -FakeVmapAutoResetWrapper = VmapAutoResetWrapper[ - FakeState, specs.BoundedArray, chex.Array -] +FakeVmapAutoResetWrapper = VmapAutoResetWrapper[FakeState, specs.BoundedArray, chex.Array] class TestVmapAutoResetWrapper: @@ -629,14 +602,10 @@ def action( fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, ) -> chex.Array: - generate_action_fn = ( - lambda _: fake_vmap_auto_reset_environment.action_spec.generate_value() - ) + generate_action_fn = lambda _: fake_vmap_auto_reset_environment.action_spec.generate_value() return jax.vmap(generate_action_fn)(keys) - def test_vmap_auto_reset_wrapper__init( - self, fake_environment: FakeEnvironment - ) -> None: + def test_vmap_auto_reset_wrapper__init(self, fake_environment: FakeEnvironment) -> None: """Validates initialization of the wrapper.""" vmap_auto_reset_env = VmapWrapper(fake_environment) assert isinstance(vmap_auto_reset_env, Environment) @@ -670,9 +639,7 @@ def test_vmap_auto_reset_wrapper__auto_reset( ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) # expect rest timestep.extras to have the same obs as the original timestep - assert jnp.all( - timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all(timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__maybe_reset( self, @@ -687,9 +654,7 @@ def test_vmap_auto_reset_wrapper__maybe_reset( ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) # expect rest timestep.extras to have the same obs as the original timestep - assert jnp.all( - timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all(timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__step_no_reset( self, @@ -709,9 +674,7 @@ def test_vmap_auto_reset_wrapper__step_no_reset( # no reset so expect extras and obs to be the same. # and the first timestep should have different obs in extras. - assert not jnp.all( - first_timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert not jnp.all(first_timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__step_reset( @@ -729,26 +692,16 @@ def test_vmap_auto_reset_wrapper__step_reset( # Loop across time_limit so auto-reset occurs for _ in range(fake_vmap_auto_reset_environment.time_limit - 1): - state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)( - state, action - ) - assert jnp.all( - timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) - - state, final_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( - state, action - ) + state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)(state, action) + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) + + state, final_timestep = jax.jit(fake_vmap_auto_reset_environment.step)(state, action) assert jnp.all(final_timestep.step_type == StepType.LAST) - chex.assert_trees_all_equal( - final_timestep.observation, first_timestep.observation - ) + chex.assert_trees_all_equal(final_timestep.observation, first_timestep.observation) assert not jnp.all( final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] ) - assert jnp.all( - (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all((timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__step( self, @@ -758,9 +711,7 @@ def test_vmap_auto_reset_wrapper__step( ) -> None: """Validates step function of the vmap environment.""" state, timestep = fake_vmap_auto_reset_environment.reset(keys) - state, next_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( - state, action - ) + state, next_timestep = jax.jit(fake_vmap_auto_reset_environment.step)(state, action) assert_trees_are_different(next_timestep, timestep) chex.assert_trees_all_equal(next_timestep.reward, 0) @@ -768,9 +719,7 @@ def test_vmap_auto_reset_wrapper__step( assert next_timestep.discount.shape == (keys.shape[0],) assert next_timestep.observation.shape[0] == keys.shape[0] # expect observation and extras to be the same, since no reset - assert jnp.all( - next_timestep.observation == next_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] - ) + assert jnp.all(next_timestep.observation == next_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__render( self, @@ -803,9 +752,7 @@ def test_jumanji_to_gym_obs__correct_observation(self) -> None: """Check that a NamedTuple containing a JAX array and a chex dataclass of two JAX arrays is converted correctly into a nested dictionary of numpy arrays. """ - NestedObservation = namedtuple( - "NestedObservation", ["jax_array", "chex_dataclass"] - ) + NestedObservation = namedtuple("NestedObservation", ["jax_array", "chex_dataclass"]) array = jnp.zeros((2, 2)) data_class = self.DummyChexDataclass(x=array, y=array) # type: ignore nested_obs = NestedObservation(array, data_class) @@ -827,9 +774,7 @@ def test_jumanji_to_gym_obs__wrong_observation(self) -> None: """Check that a NotImplementedError is raised when the wrong datatype is passed to one of the two attributes of the chex dataclass. """ - NestedObservation = namedtuple( - "NestedObservation", ["jax_array", "chex_dataclass"] - ) + NestedObservation = namedtuple("NestedObservation", ["jax_array", "chex_dataclass"]) array = jnp.zeros((10, 10)) # Pass in the wrong datatype diff --git a/pyproject.toml b/pyproject.toml index 636dcae26..ea4fabe5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,5 +91,22 @@ module = [ ] ignore_missing_imports = true -[tool.isort] -profile = "black" +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = ["A", "B", "E", "F", "I", "N", "W", "RUF", "ANN"] +ignore = [ + "E731", # Allow lambdas to be assigned to variables. + "ANN101", # no need to type self + "ANN102", # no need to type cls + "ANN204", # no need for return type for special methods + "ANN401", # can use Any type + "A002", # Argument shadowing a Python builtin. + "A003", # Class attribute shadowing a Python builtin. + "A005", # Module shadowing a Python builtin. + "B017", # assertRaises(Exception): or pytest.raises(Exception) should be considered evil. +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] # Ignore `F401` (import violations) in all `__init__.py` files. diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 03d70be87..a9dbbcfa6 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,7 +1,4 @@ -black==22.3.0 coverage -flake8 -isort==5.11.5 livereload mkdocs==1.2.3 mkdocs-git-revision-date-plugin==0.3.2 @@ -11,9 +8,8 @@ mkdocs-mermaid2-plugin==0.6.0 mkdocs_autorefs<1.0 mkdocstrings==0.18.0 mknotebooks==0.7.1 -mypy==0.991 -nbmake -pre-commit==2.17.0 +mypy +pre-commit promise pymdown-extensions pytest==7.0.1 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 47a19d216..000000000 --- a/setup.cfg +++ /dev/null @@ -1,40 +0,0 @@ -[flake8] -select = A,B,C,D,E,F,G,I,N,T,W # Specify list of error codes to report. -exclude = - .tox, - .git, - __pycache__, - build, - dist, - proto/*, - *.pyc, - *.egg-info, - .cache, - .eggs -max-line-length=100 -max-cognitive-complexity=14 -import-order-style = google -application-import-names = jumanji -doctests = True -docstring-convention = google -per-file-ignores = __init__.py:F401 - -ignore = -# Argument shadowing a Python builtin. - A002 -# Class attribute shadowing a Python builtin. - A003 -# Module shadowing a Python builtin. - A005 -# Do not require docstrings for __init__. - D107 -# Do not require block comments to only have a single leading #. - E266 -# Do not assign a lambda expression, use a def. - E731 -# Line break before binary operator (not compatible with black). - W503 -# assertRaises(Exception): or pytest.raises(Exception) should be considered evil. - B017 -# black and flake8 disagree on whitespace before ':'. - E203