diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1227b03dc..54e5fd1ea 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -21,9 +21,9 @@ name: build on: # Triggers the workflow on push or pull request events but only for the master branch push: - branches: [ master ] + branches: [ "master" ] pull_request: - branches: [ master ] + branches: [ "master" ] # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -37,55 +37,103 @@ jobs: strategy: matrix: - python-version: [3.7] - # tf-nightly has some pip version conflicts, so can't be installed. - # Use only numbered TF as of now. - # tf-version: ["2.4.*", "tf-nightly"] - tf-version: ["2.4.*"] + python-version: [ '3.10' ] + # Which tf-version run. + tf-version: [ '2.13.0' ] # Which set of tests to run. - trax-test: ["lib", "research"] + trax-test: [ 'lib','research' ] # Steps represent a sequence of tasks that will be executed as part of the job steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -q -U setuptools numpy - python -m pip install flake8 pytest - if [[ ${{matrix.tf-version}} == "tf-nightly" ]]; then python -m pip install tf-nightly; else python -m pip install -q "tensorflow=="${{matrix.tf-version}}; fi - pip install -e .[tests,t5] - # # Lint with flake8 - # - name: Lint with flake8 - # run: | - # # stop the build if there are Python syntax errors or undefined names - # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - # Test out right now with only testing one directory. - - name: Test with pytest - run: | - TRAX_TEST=" ${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v3 + - name: Set up Python ${{matrix.python-version}} + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.python-version}} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 + python -m pip install setuptools==69.0.2 + python -m pip install numpy==1.23.5 + python -m pip install pytest==7.4.2 + python -m pip install tensorflow_datasets==4.2.0 + python -m pip install tensorflow_metadata==1.6.0 + python -m pip install tensorflow-text==2.13.0 + python -m pip install pandas==2.1.1 + python -m pip install matplotlib==3.8.0 + python -m pip install multimethod==1.10 + python -m pip install natsort==8.4.0 + python -m pip install omegaconf==2.3.0 + python -m pip install nltk==3.8.1 + python -m pip install pytest==7.4.2 + python -m pip install hydra-core==1.3.2 + python -m pip install scikit-learn==1.3.1 + python -m pip 
install pympler==1.0.1 + python -m pip install IPython==8.16.1 + python -m pip install mypy==1.5.1 + python -m pip install pylint==2.17.7 + python -m pip install black==23.9.1 + python -m pip install openpyxl==3.1.2 + python -m pip install numba==0.58.0 + python -m pip install parameterized==0.9.0 + python -m pip install mock==5.1.0 + python -m pip install tfds-nightly==4.6.0.dev202210050045 + python -m pip install editdistance==0.6.2 + python -m pip install pyglove==0.4.3 + python -m pip install sentencepiece==0.1.99 + python -m pip install babel==2.13.0 + python -m pip install rouge-score==0.1.2 + python -m pip install sacrebleu==2.3.1 + python -m pip install transformers==4.33.3 + python -m pip install mesh-tensorflow==0.1.21 + python -m pip install torch==2.0.1+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html + python -m pip install protobuf==3.20.3 + python -m pip install t5==0.9.2 --no-dependencies tensorflow + python -m pip install seqio==0.0.18 --no-deps tensorflow + python -m pip install funcsigs==1.0.2 + python -m pip install absl-py==1.4.0 + python -m pip install gym==0.26.2 + python -m pip install gin-config==0.5.0 + python -m pip install jax==0.4.20 + python -m pip install jaxlib==0.4.20 + python -m pip install psutil==5.9.5 + python -m pip install scipy==1.11.3 + python -m pip install six==1.14.0 + python -m pip install attrs==23.1.0 + python -m pip install mock==5.1.0 + python -m pip install parameterized==0.9.0 + python -m pip install pylint==2.17.7 + python -m pip install pytest==7.4.2 + python -m pip install wrapt==1.15.0 + python -m pip install tensor2tensor==1.15.7 + python -m pip install orbax-checkpoint==0.4.4 + python -m pip install clu==0.0.10 + python -m pip install flax==0.7.5 + # Test out right now with only testing one directory. + - name: Install trax package + run: | + python -m pip install -e . + - name: Test with pytest + working-directory: . 
+ run: | + TRAX_TEST="${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh + # The below step just reports the success or failure of tests as a "commit status". + # This is needed for copybara integration. + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{github.repository}}/statuses/${{github.sha}} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{github.repository}}/actions/runs/${{github.run_id}}", + "description": "'$status'", + "context": "github-actions/build" + }' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..5bc0ab766 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,82 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ "1.5.1" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "1.5.1" ] + schedule: + - cron: '31 4 * * 1' + +jobs: + analyze: + name: Analyze + # Runner size impacts CodeQL analysis time. 
To learn more, please see: + # - https://gh.io/recommended-hardware-resources-for-running-codeql + # - https://gh.io/supported-runners-and-hardware-resources + # - https://gh.io/using-larger-runners + # Consider using larger runners for possible analysis time improvements. + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' ] + # Use only 'java-kotlin' to analyze code written in Java, Kotlin or both + # Use only 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both + # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). 
+ # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines. + # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. + + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d9de2c1be..f6b5ff3c5 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -28,6 +28,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 + version: "3.10" install: - requirements: docs/requirements.txt diff --git a/.travis.yml b/.travis.yml index 0251cb069..50cfa6499 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ git: depth: 3 quiet: true python: - - "3.6" + - "3.10" env: global: - - TF_VERSION="2.4.*" + - TF_VERSION="2.11.0" matrix: - TRAX_TEST="lib" - TRAX_TEST="research" diff --git a/README.md b/README.md index 33884979a..0716b7269 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ version](https://badge.fury.io/py/trax.svg)](https://badge.fury.io/py/trax) [![GitHub Issues](https://img.shields.io/github/issues/google/trax.svg)](https://github.com/google/trax/issues) -![GitHub Build](https://github.com/google/trax/actions/workflows/build.yaml/badge.svg) +![GitHub Build](https://github.com/mmarcinmichal/trax/actions/workflows/build.yaml/badge.svg) [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) 
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) diff --git a/docs/.readthedocs.yaml b/docs/.readthedocs.yaml index d9de2c1be..f6b5ff3c5 100644 --- a/docs/.readthedocs.yaml +++ b/docs/.readthedocs.yaml @@ -28,6 +28,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 + version: "3.10" install: - requirements: docs/requirements.txt diff --git a/oss_scripts/oss_pip_install.sh b/oss_scripts/oss_pip_install.sh index 839d3d9fc..652fa447d 100755 --- a/oss_scripts/oss_pip_install.sh +++ b/oss_scripts/oss_pip_install.sh @@ -15,7 +15,7 @@ #!/bin/bash set -v # print commands as they're executed -set -e # fail and exit on any command erroring +set -e # fail and exit on any command error : "${TF_VERSION:?}" diff --git a/oss_scripts/oss_release.sh b/oss_scripts/oss_release.sh index 9d913ba8f..7b2944fb9 100755 --- a/oss_scripts/oss_release.sh +++ b/oss_scripts/oss_release.sh @@ -15,18 +15,18 @@ #!/bin/bash set -v # print commands as they're executed -set -e # fail and exit on any command erroring +set -e # fail and exit on any command error GIT_COMMIT_ID=${1:-""} [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1 TMP_DIR=$(mktemp -d) -pushd $TMP_DIR +pushd "$TMP_DIR" echo "Cloning trax and checking out commit $GIT_COMMIT_ID" git clone https://github.com/google/trax.git cd trax -git checkout $GIT_COMMIT_ID +git checkout "$GIT_COMMIT_ID" python3 -m pip install wheel twine pyopenssl @@ -42,4 +42,4 @@ python3 -m twine upload dist/* # Cleanup rm -rf build/ dist/ trax.egg-info/ popd -rm -rf $TMP_DIR +rm -rf "$TMP_DIR" diff --git a/oss_scripts/oss_tests.sh b/oss_scripts/oss_tests.sh index ee3bf428f..c359ed573 100755 --- a/oss_scripts/oss_tests.sh +++ b/oss_scripts/oss_tests.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # Copyright 2022 The Trax Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -#!/bin/bash - set -v # print commands as they're executed # aliases aren't expanded in non-interactive shells by default. @@ -46,98 +46,62 @@ set_status # # Run pytest with coverage. # alias pytest='coverage run -m pytest' -# Check tests, separate out directories for easy triage. - +# Check tests, check each directory of tests separately. if [[ "${TRAX_TEST}" == "lib" ]] then + echo "Testing all framework packages..." + ## Core Trax and Supervised Learning + pytest tests/data + set_status - # Disabled the decoding test for now, since it OOMs. - # TODO(afrozm): Add the decoding_test.py back again. - - # training_test and trainer_lib_test parse flags, so can't use with --ignore - pytest \ - --ignore=trax/supervised/callbacks_test.py \ - --ignore=trax/supervised/decoding_test.py \ - --ignore=trax/supervised/decoding_timing_test.py \ - --ignore=trax/supervised/trainer_lib_test.py \ - --ignore=trax/supervised/training_test.py \ - trax/supervised + pytest tests/fastmath set_status - # Testing these separately here. - pytest \ - trax/supervised/callbacks_test.py \ - trax/supervised/trainer_lib_test.py \ - trax/supervised/training_test.py + pytest tests/layers set_status - pytest trax/data + pytest tests/models set_status - # Ignoring acceleration_test's test_chunk_grad_memory since it is taking a - # lot of time on OSS. 
- pytest \ - --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_grad_memory \ - --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_memory \ - --ignore=trax/layers/initializers_test.py \ - --ignore=trax/layers/test_utils.py \ - trax/layers + pytest tests/optimizers set_status - pytest trax/layers/initializers_test.py + pytest tests/supervised set_status - pytest trax/fastmath + pytest tests/tf_numpy/extensions set_status - pytest trax/optimizers + pytest tests/tf_numpy/jax set_status - # Catch-all for futureproofing. - pytest \ - --ignore=trax/trax2keras_test.py \ - --ignore=trax/data \ - --ignore=trax/fastmath \ - --ignore=trax/layers \ - --ignore=trax/models \ - --ignore=trax/optimizers \ - --ignore=trax/rl \ - --ignore=trax/supervised \ - --ignore=trax/tf_numpy + pytest tests/tf_numpy/numpy_impl set_status -else - # Models, RL and misc right now. - ## Models - # Disabled tests are quasi integration tests. - pytest \ - --ignore=trax/models/reformer/reformer_e2e_test.py \ - --ignore=trax/models/reformer/reformer_memory_test.py \ - --ignore=trax/models/research/terraformer_e2e_test.py \ - --ignore=trax/models/research/terraformer_memory_test.py \ - --ignore=trax/models/research/terraformer_oom_test.py \ - trax/models + pytest tests/tf_numpy/public_symbol_test.py set_status - ## RL Trax - pytest trax/rl + pytest tests/import_test.py set_status - ## Trax2Keras - # TODO(afrozm): Make public again after TF 2.5 releases. - # pytest trax/trax2keras_test.py - # set_status + pytest tests/shapes_test.py + set_status + + pytest tests/trax2keras_test.py + set_status + +else + echo "No testing ..." + # Models, RL and misc right now. # Check notebooks. # TODO(afrozm): Add more. 
- jupyter nbconvert --ExecutePreprocessor.kernel_name=python3 \ - --ExecutePreprocessor.timeout=600 --to notebook --execute \ - trax/intro.ipynb; - set_status + # jupyter nbconvert --ExecutePreprocessor.kernel_name=python3 \ + # --ExecutePreprocessor.timeout=600 --to notebook --execute \ + # trax/intro.ipynb; + # set_status fi -# TODO(traxers): Test tf-numpy separately. - exit $STATUS diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..f284563f6 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . trax \ No newline at end of file diff --git a/trax/data/testdata/bert_uncased_vocab.txt b/resources/data/testdata/bert_uncased_vocab.txt similarity index 100% rename from trax/data/testdata/bert_uncased_vocab.txt rename to resources/data/testdata/bert_uncased_vocab.txt diff --git a/trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 b/resources/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 rename to resources/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 b/resources/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 rename to resources/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/c4/en/2.3.0/dataset_info.json b/resources/data/testdata/c4/en/2.3.0/dataset_info.json similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/dataset_info.json rename to resources/data/testdata/c4/en/2.3.0/dataset_info.json diff --git a/trax/data/testdata/corpus-1.txt b/resources/data/testdata/corpus-1.txt similarity index 100% rename from trax/data/testdata/corpus-1.txt rename to resources/data/testdata/corpus-1.txt diff --git a/trax/data/testdata/corpus-2.txt 
b/resources/data/testdata/corpus-2.txt similarity index 100% rename from trax/data/testdata/corpus-2.txt rename to resources/data/testdata/corpus-2.txt diff --git a/trax/data/testdata/en_8k.subword b/resources/data/testdata/en_8k.subword similarity index 100% rename from trax/data/testdata/en_8k.subword rename to resources/data/testdata/en_8k.subword diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json b/resources/data/testdata/para_crawl/ende/1.2.0/dataset_info.json similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json rename to resources/data/testdata/para_crawl/ende/1.2.0/dataset_info.json diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/features.json b/resources/data/testdata/para_crawl/ende/1.2.0/features.json similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/features.json rename to resources/data/testdata/para_crawl/ende/1.2.0/features.json diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 b/resources/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 rename to resources/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/sentencepiece.model b/resources/data/testdata/sentencepiece.model similarity index 100% rename from trax/data/testdata/sentencepiece.model rename to resources/data/testdata/sentencepiece.model diff --git a/resources/data/testdata/squad/v1.1/3.0.0/dataset_info.json b/resources/data/testdata/squad/v1.1/3.0.0/dataset_info.json new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 b/resources/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 similarity index 100% rename from 
trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 rename to resources/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 b/resources/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 rename to resources/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/vocab-1.txt b/resources/data/testdata/vocab-1.txt similarity index 100% rename from trax/data/testdata/vocab-1.txt rename to resources/data/testdata/vocab-1.txt diff --git a/trax/data/testdata/vocab-2.txt b/resources/data/testdata/vocab-2.txt similarity index 100% rename from trax/data/testdata/vocab-2.txt rename to resources/data/testdata/vocab-2.txt diff --git a/resources/examples/ipynb/Attention_Visualization_in_Trax.ipynb For more information see the [tensor2tensor](https://trax-ml.readthedocs.io/en/latest/) visualization colab. All js tools are taken from the tensor2tensor version along with attention processing methods. The "viz" mode for a Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with the attention visualization in mind. The colab re-uses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab. All js tools are taken from the tensor2tensor version along with attention processing methods. The "viz" mode for a Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with the attention visualization in mind. 
The colab re-uses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BIl27504La0G" + }, + "source": [ + "**General Setup**\n", + "\n", + "Execute the following few cells (once) before running the visualization code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import json\n", + "import numpy as np\n", + "import os\n", + "import IPython.display as display\n", + "import gin" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 466 + }, + "colab_type": "code", + "id": "vlGjGoGMTt-D", + "outputId": "28f4556b-caef-47a1-bddd-7f51ecc064d8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001B[K |████████████████████████████████| 368kB 2.8MB/s \n", + "\u001B[K |████████████████████████████████| 1.5MB 13.0MB/s \n", + 
"\u001B[K |████████████████████████████████| 2.6MB 20.1MB/s \n", + "\u001B[K |████████████████████████████████| 163kB 33.1MB/s \n", + "\u001B[K |████████████████████████████████| 194kB 19.4MB/s \n", + "\u001B[K |████████████████████████████████| 983kB 30.6MB/s \n", + "\u001B[K |████████████████████████████████| 655kB 56.6MB/s \n", + "\u001B[K |████████████████████████████████| 81kB 11.7MB/s \n", + "\u001B[K |████████████████████████████████| 5.3MB 45.0MB/s \n", + "\u001B[K |████████████████████████████████| 368kB 57.1MB/s \n", + "\u001B[K |████████████████████████████████| 307kB 55.8MB/s \n", + "\u001B[K |████████████████████████████████| 358kB 58.6MB/s \n", + "\u001B[K |████████████████████████████████| 1.1MB 59.0MB/s \n", + "\u001B[K |████████████████████████████████| 3.5MB 58.4MB/s \n", + "\u001B[K |████████████████████████████████| 778kB 59.4MB/s \n", + "\u001B[K |████████████████████████████████| 51kB 8.7MB/s \n", + "\u001B[K |████████████████████████████████| 51kB 8.6MB/s \n", + "\u001B[K 
|████████████████████████████████| 235kB 54.2MB/s \n", + "\u001B[K |████████████████████████████████| 3.0MB 62.4MB/s \n", + "\u001B[K |████████████████████████████████| 890kB 58.2MB/s \n", + "\u001B[?25h Building wheel for bz2file (setup.py) ... \u001B[?25l\u001B[?25hdone\n", + " Building wheel for pypng (setup.py) ... \u001B[?25l\u001B[?25hdone\n", + " Building wheel for sacremoses (setup.py) ... \u001B[?25l\u001B[?25hdone\n", + "\u001B[31mERROR: kfac 0.2.2 has requirement tensorflow-probability==0.8, but you'll have tensorflow-probability 0.7.0 which is incompatible.\u001B[0m\n", + "INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 \n" + ] + } + ], + "source": [ + "#@title\n", + "# Import Trax\n", + "\n", + "!pip install -q -U trax\n", + "import trax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "VCBjVMrZRS6q" + }, + "outputs": [], + "source": [ + "#@title Some cool tooling for attention (make sure that you run the cell)\n", + "def resize(att_mat, max_length=None):\n", + " \"\"\"Normalize attention matrices and reshape as necessary.\"\"\"\n", + " for i, att in enumerate(att_mat):\n", + " # Add extra batch dim for viz code to work.\n", + " if att.ndim == 3:\n", + " att = np.expand_dims(att, axis=0)\n", + " if max_length is not None:\n", + " # Sum across different attention values for each token.\n", + " att = att[:, :, :max_length, :max_length]\n", + " row_sums = np.sum(att, axis=2)\n", + " # Normalize\n", + " att /= row_sums[:, :, np.newaxis]\n", + " att_mat[i] = att\n", + " return att_mat\n", + "\n", + "\n", + "def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):\n", + " \"\"\"Compute 
representation of the attention ready for the d3 visualization.\n", + "\n", + " Args:\n", + " inp_text: list of strings, words to be displayed on the left of the vis\n", + " out_text: list of strings, words to be displayed on the right of the vis\n", + " enc_atts: numpy array, encoder self-attentions\n", + " [num_layers, batch_size, num_heads, enc_length, enc_length]\n", + " dec_atts: numpy array, decoder self-attentions\n", + " [num_layers, batch_size, num_heads, dec_length, dec_length]\n", + " encdec_atts: numpy array, encoder-decoder attentions\n", + " [num_layers, batch_size, num_heads, dec_length, enc_length]\n", + "\n", + " Returns:\n", + " Dictionary of attention representations with the structure:\n", + " {\n", + " 'all': Representations for showing all attentions at the same time.\n", + " 'inp_inp': Representations for showing encoder self-attentions\n", + " 'inp_out': Representations for showing encoder-decoder attentions\n", + " 'out_out': Representations for showing decoder self-attentions\n", + " }\n", + " and each sub-dictionary has structure:\n", + " {\n", + " 'att': list of inter attentions matrices, one for each attention head\n", + " 'top_text': list of strings, words to be displayed on the left of the vis\n", + " 'bot_text': list of strings, words to be displayed on the right of the vis\n", + " }\n", + " \"\"\"\n", + " def get_full_attention(layer):\n", + " \"\"\"Get the full input+output - input+output attentions.\"\"\"\n", + " enc_att = enc_atts[layer][0]\n", + " dec_att = dec_atts[layer][0]\n", + " encdec_att = encdec_atts[layer][0]\n", + " enc_att = np.transpose(enc_att, [0, 2, 1])\n", + " dec_att = np.transpose(dec_att, [0, 2, 1])\n", + " encdec_att = np.transpose(encdec_att, [0, 2, 1])\n", + " # [heads, query_length, memory_length]\n", + " enc_length = enc_att.shape[1]\n", + " dec_length = dec_att.shape[1]\n", + " num_heads = enc_att.shape[0]\n", + " first = np.concatenate([enc_att, encdec_att], axis=2)\n", + " second = np.concatenate(\n", 
+ " [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2)\n", + " full_att = np.concatenate([first, second], axis=1)\n", + " return [ha.T.tolist() for ha in full_att]\n", + "\n", + " def get_inp_inp_attention(layer):\n", + " att = np.transpose(enc_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_out_inp_attention(layer):\n", + " att = np.transpose(encdec_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_out_out_attention(layer):\n", + " att = np.transpose(dec_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_attentions(get_attention_fn):\n", + " num_layers = len(enc_atts)\n", + " return [get_attention_fn(i) for i in range(num_layers)]\n", + "\n", + " attentions = {\n", + " 'all': {\n", + " 'att': get_attentions(get_full_attention),\n", + " 'top_text': inp_text + out_text,\n", + " 'bot_text': inp_text + out_text,\n", + " },\n", + " 'inp_inp': {\n", + " 'att': get_attentions(get_inp_inp_attention),\n", + " 'top_text': inp_text,\n", + " 'bot_text': inp_text,\n", + " },\n", + " 'inp_out': {\n", + " 'att': get_attentions(get_out_inp_attention),\n", + " 'top_text': inp_text,\n", + " 'bot_text': out_text,\n", + " },\n", + " 'out_out': {\n", + " 'att': get_attentions(get_out_out_attention),\n", + " 'top_text': out_text,\n", + " 'bot_text': out_text,\n", + " },\n", + " }\n", + "\n", + " return attentions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "47lzWIH5THcw" + }, + "outputs": [], + "source": [ + "#@title Some cool HTML and js stuff (make sure that you run the cell)\n", + "vis_html = \"\"\"\n", + " \n", + " Layer: \n", + " Attention: \n", + " \n", + "
\n", + "\"\"\"\n", + "def call_html():\n", + " import IPython\n", + " display.display(display.HTML('''\n", + " \n", + " \n", + " '''))\n", + "vis_js = \"\"\"\n", + "/**\n", + " * @fileoverview Transformer Visualization D3 javascript code.\n", + " */\n", + "\n", + "requirejs(['jquery', 'd3'],\n", + "function($, d3) {\n", + "\n", + "var attention = window.attention;\n", + "\n", + "const TEXT_SIZE = 15;\n", + "const BOXWIDTH = TEXT_SIZE * 8;\n", + "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", + "const WIDTH = 2000;\n", + "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", + "const MATRIX_WIDTH = 150;\n", + "const head_colours = d3.scale.category10();\n", + "const CHECKBOX_SIZE = 20;\n", + "\n", + "function lighten(colour) {\n", + " var c = d3.hsl(colour);\n", + " var increment = (1 - c.l) * 0.6;\n", + " c.l += increment;\n", + " c.s -= increment;\n", + " return c;\n", + "}\n", + "\n", + "function transpose(mat) {\n", + " return mat[0].map(function(col, i) {\n", + " return mat.map(function(row) {\n", + " return row[i];\n", + " });\n", + " });\n", + "}\n", + "\n", + "function zip(a, b) {\n", + " return a.map(function (e, i) {\n", + " return [e, b[i]];\n", + " });\n", + "}\n", + "\n", + "\n", + "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", + " $(id).empty();\n", + " var svg = d3.select(id)\n", + " .append('svg')\n", + " .attr(\"width\", WIDTH)\n", + " .attr(\"height\", HEIGHT);\n", + "\n", + " var att_data = [];\n", + " for (var i=0; i < attention_heads.length; i++) {\n", + " var att_trans = transpose(attention_heads[i]);\n", + " att_data.push(zip(attention_heads[i], att_trans));\n", + " }\n", + "\n", + " renderText(svg, top_text, true, att_data, 0);\n", + " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", + "\n", + " renderAttentionHighlights(svg, att_data);\n", + "\n", + " svg.append(\"g\").classed(\"attention_heads\", true);\n", + "\n", + " renderAttention(svg, attention_heads);\n", + "\n", + " 
draw_checkboxes(config, 0, svg, attention_heads);\n", + "}\n", + "\n", + "\n", + "function renderText(svg, text, is_top, att_data, left_pos) {\n", + " var id = is_top ? \"top\" : \"bottom\";\n", + " var textContainer = svg.append(\"svg:g\")\n", + " .attr(\"id\", id);\n", + "\n", + " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", + " .selectAll(\"g\")\n", + " .data(att_data)\n", + " .enter()\n", + " .append(\"g\")\n", + " .selectAll(\"rect\")\n", + " .data(function(d) {return d;})\n", + " .enter()\n", + " .append(\"rect\")\n", + " .attr(\"x\", function(d, i, j) {\n", + " return left_pos + box_offset(j);\n", + " })\n", + " .attr(\"y\", function(d, i) {\n", + " return (+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .attr(\"fill\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .style(\"opacity\", 0.0);\n", + "\n", + "\n", + " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", + " .data(text)\n", + " .enter()\n", + " .append(\"g\");\n", + "\n", + " tokenContainer.append(\"rect\")\n", + " .classed(\"background\", true)\n", + " .style(\"opacity\", 0.0)\n", + " .attr(\"fill\", \"lightgray\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH)\n", + " .attr(\"height\", BOXHEIGHT);\n", + "\n", + " var theText = tokenContainer.append(\"text\")\n", + " .text(function(d) { return d; })\n", + " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", + " .style(\"cursor\", \"default\")\n", + " .style(\"-webkit-user-select\", \"none\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " });\n", + "\n", + " if (is_top) {\n", + " theText.style(\"text-anchor\", \"end\")\n", + " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " } else {\n", + " 
theText.style(\"text-anchor\", \"start\")\n", + " .attr(\"dx\", + TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " }\n", + "\n", + " tokenContainer.on(\"mouseover\", function(d, index) {\n", + " textContainer.selectAll(\".background\")\n", + " .style(\"opacity\", function(d, i) {\n", + " return i == index ? 1.0 : 0.0;\n", + " });\n", + "\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", + "\n", + " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", + " .selectAll(\".att_lines\")\n", + " .attr(\"stroke-opacity\", function(d) {\n", + " return 1.0;\n", + " })\n", + " .attr(\"y1\", function(d, i) {\n", + " if (is_top) {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x1\", BOXWIDTH)\n", + " .attr(\"y2\", function(d, i) {\n", + " if (is_top) {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", + " .attr(\"stroke-width\", 2)\n", + " .attr(\"stroke\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .attr(\"stroke-opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j]) {\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " });\n", + "\n", + "\n", + " function updateAttentionBoxes() {\n", + " var id = is_top ? \"bottom\" : \"top\";\n", + " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", + " svg.select(\"#\" + id)\n", + " .selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", + " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .style(\"opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j])\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " else\n", + " return 0.0;\n", + "\n", + " });\n", + " }\n", + "\n", + " updateAttentionBoxes();\n", + " });\n", + "\n", + " textContainer.on(\"mouseleave\", function() {\n", + " d3.select(this).selectAll(\".background\")\n", + " .style(\"opacity\", 0.0);\n", + "\n", + " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", + " svg.selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .style(\"opacity\", 0.0);\n", + " });\n", + "}\n", + "\n", + "function renderAttentionHighlights(svg, attention) {\n", + " var line_container = svg.append(\"g\");\n", + " line_container.selectAll(\"g\")\n", + " .data(attention)\n", + " .enter()\n", + " .append(\"g\")\n", + " .classed(\"line_heads\", true)\n", + " .selectAll(\"line\")\n", + " .data(function(d){return d;})\n", + " .enter()\n", + " .append(\"line\").classed(\"att_lines\", true);\n", + "}\n", + "\n", + "function renderAttention(svg, attention_heads) {\n", + " var line_container = svg.selectAll(\".attention_heads\");\n", + " line_container.html(null);\n", + " for(var h=0; h\").val(i).text(i));\n", + "}\n", + "\n", + "$(\"#layer\").on('change', function(e) {\n", + " config.layer = +e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + 
"$(\"#att_type\").on('change', function(e) {\n", + " config.att_type = e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + "$(\"button\").on('click', visualize);\n", + "\n", + "visualize();\n", + "\n", + "});\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-LQ89rFFsEdk" + }, + "source": [ + "## 1. Run a pre-trained Transformer\n", + "\n", + "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n", + "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n", + "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n", + "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n", + "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "djTiSLcaNFGa", + "outputId": "b5ad2955-5e1d-47aa-97bb-5d72a25ed76d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Es ist schΓΆn, heute neue Dinge zu lernen!\n" + ] + } + ], + "source": [ + "# Create a Transformer model.\n", + "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", + "model = trax.models.Transformer(\n", + " input_vocab_size=33300,\n", + " d_model=512, d_ff=2048,\n", + " n_heads=8, n_encoder_layers=6, 
n_decoder_layers=6,\n", + " max_len=2048, mode='predict')\n", + "\n", + "# Initialize using pre-trained weights.\n", + "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", + " weights_only=True)\n", + "\n", + "# Tokenize a sentence.\n", + "sentence = 'It is nice to learn new things today!'\n", + "tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword'))[0]\n", + "\n", + "# Decode from the Transformer.\n", + "tokenized = tokenized[None, :] # Add batch dimension.\n", + "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n", + " model, tokenized, temperature=0.0) # Higher temperature: more diverse results.\n", + "\n", + "# De-tokenize.\n", + "tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.\n", + "translation = trax.data.detokenize(tokenized_translation,\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword')\n", + "print(translation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "pWDPwZfSJeD3", + "outputId": "050d40bf-f28d-49ea-b69a-af2886cf92a4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[ 118, 16, 1902, 9, 3197, 141, 1059, 420, 207]]),\n", + " array([ 168, 24, 9358, 2, 352, 367, 2427, 18, 3580, 207]))" + ] + }, + "execution_count": 6, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized, tokenized_translation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lu6URNjbXIHv" + }, + "source": [ + "## 2. 
Prepare the tokens for visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kqNWMpNdMg9z" + }, + "outputs": [], + "source": [ + "def decode(single_token):\n", + " return trax.data.detokenize(single_token,\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "H2fbJB_BMeRw" + }, + "outputs": [], + "source": [ + "def get_tokens_str(integers):\n", + " token_strs = []\n", + " for i in range(integers.shape[1]):\n", + " token_strs.append(decode(integers[:,i]))\n", + " return token_strs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YkNT8rbgKM5-" + }, + "outputs": [], + "source": [ + "tokenized_translation_with_start = np.array([0]+list(tokenized_translation), dtype=np.int64)\n", + "tokenized_translation_with_start = tokenized_translation_with_start[np.newaxis, ...]\n", + "tokenized_translation = np.array(tokenized_translation, dtype=np.int64)\n", + "tokenized_translation = tokenized_translation[np.newaxis, ...]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "r-FVdSZPKQhs" + }, + "outputs": [], + "source": [ + "tokenized_str = get_tokens_str(tokenized)\n", + "tokenized_translation_str = get_tokens_str(tokenized_translation_with_start)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 223 + }, + "colab_type": "code", + "id": "Cy7edKBuKash", + "outputId": "c1e00dbe-f467-48df-eaaf-579f68ef788f" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(['It', 'is', 'nice', 'to', 'learn', 'new', 'things', 'today', '!'],\n", + " ['',\n", + " 'Es',\n", + " 'ist',\n", + " 'schön',\n", + " 
', ',\n", + " 'heute',\n", + " 'neue',\n", + " 'Dinge',\n", + " 'zu',\n", + " 'lernen',\n", + " '!'])" + ] + }, + "execution_count": 11, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized_str, tokenized_translation_str" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1XxJSqAsOTBe" + }, + "outputs": [], + "source": [ + "max_len = max(tokenized.shape[1], tokenized_translation.shape[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Qju-9pPHOV6G" + }, + "outputs": [], + "source": [ + "tokenized_translation_pad = np.zeros((1,max_len), dtype=np.int64)\n", + "tokenized_translation_pad[:,:tokenized_translation.shape[1]] = tokenized_translation\n", + "\n", + "tokenized_pad = np.zeros((1,max_len), dtype=np.int64)\n", + "tokenized_pad[:,:tokenized.shape[1]] = tokenized" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "zGxBSk0gOfYi", + "outputId": "d83328fa-eec8-4631-d2b6-4fffc3f0b933" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((1, 10), (1, 10))" + ] + }, + "execution_count": 14, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized_translation_pad.shape, tokenized_pad.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WqvjmRaCXign" + }, + "source": [ + "## 3. Create the same pre-trained model in the \"viz\" mode." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Qb2F4Pj_OLMZ" + }, + "outputs": [], + "source": [ + "# Create a Transformer model in the \"viz\" mode\n", + "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", + "model_viz = trax.models.Transformer(\n", + " input_vocab_size=33300,\n", + " d_model=512, d_ff=2048,\n", + " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", + " max_len=2048, mode='viz')\n", + "\n", + "# Initialize using pre-trained weights.\n", + "model_viz.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", + " weights_only=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "AxcrAfprO0rD" + }, + "outputs": [], + "source": [ + "# We run the viz model because later we want to inspect its state\n", + "_ = model_viz((tokenized_pad, tokenized_translation_pad))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "lVCYSQSuXw6f" + }, + "source": [ + "## 4. 
Find the attention weights (aka dots)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "dsGuqdgnO2Lf" + }, + "outputs": [], + "source": [ + "attention_weights = []\n", + "def attention_sublayers(layer):\n", + " if 'Attention' in layer.name:\n", + " print(\"Found layer {}\".format(layer.name))\n", + " attention_weights.append(layer.state)\n", + " if layer.sublayers:\n", + " for sublayer in layer.sublayers:\n", + " attention_sublayers(sublayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 326 + }, + "colab_type": "code", + "id": "FA3ba2-DO5l4", + "outputId": "f66756b1-fa86-4582-bd04-9b464ae132eb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found layer PureAttention\n", + "Found layer PureAttention\n", + "Found layer PureAttention\n", + "Found layer PureAttention\n", + "Found layer PureAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n", + "Found layer DotProductCausalAttention\n", + "Found layer PureAttention\n" + ] + } + ], + "source": [ + "attention_sublayers(model_viz)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "q36-o98QO7HC", + "outputId": "445fe1ce-f1fa-484a-9db4-b37f56915d7c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "18" + ] + }, + "execution_count": 19, + "metadata": { + "tags": [] + }, + "output_type": 
"execute_result" + } + ], + "source": [ + "len(attention_weights)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LahOE6q6PB1B" + }, + "outputs": [], + "source": [ + "# Manually identification of layers would be difficult, hence we rely on attention_sublayers function\n", + "enc_atts = attention_weights[:6]\n", + "dec_atts = attention_weights[6::2] # these are the DotProductCausalAttention layers\n", + "encdec_atts = attention_weights[7::2] # these are the PureAttention layers starting from the 6th layer on\n", + "\n", + "# Here we use a number of python utils inherited from tensor2tensor\n", + "enc_atts_res = resize(enc_atts)\n", + "dec_atts_res = resize(dec_atts)\n", + "encdec_atts_res = resize(encdec_atts)\n", + "attention_dict = _get_attention(tokenized_str, tokenized_translation_str, enc_atts_res, dec_atts_res, encdec_atts_res)\n", + "attention_json = json.dumps(attention_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1DgBBfg-X6-d" + }, + "source": [ + "## 5. 
Display attention" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "resources": { + "http://localhost:8080/static/components/requirejs/require.js": { + "data": 0.07958999276161194, 0.34062448143959045, 0.0068156360648572445, 0.000798556546214968, 0.0009541919571347535, 0.00023223790049087256, 0.05767740309238434]], [[0.007143852766603231, 0.26111796498298645, 0.053768061101436615, 0.022731401026248932, 0.014146089553833008, 0.012985849753022194, 0.007359612733125687, 0.0043042986653745174, 0.6164429783821106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6188079118728638, 0.11004422605037689, 0.07541824132204056, 0.010463211685419083, 0.003863272722810507, 0.016659650951623917, 0.028880171477794647, 0.010046081617474556, 0.1258174628019333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1673126220703125, 0.6515482068061829, 0.016748156398534775, 0.042502570897340775, 0.016912223771214485, 0.011716129258275032, 0.04548521339893341, 0.0008787817787379026, 0.04689598083496094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6625117659568787, 0.049922335892915726, 0.2738172709941864, 0.004228150937706232, 0.0033112652599811554, 0.001177642960101366, 0.0005330604617483914, 0.00011132613872177899, 0.0043872324749827385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0040039438754320145, 0.00112480903044343, 0.04353015124797821, 0.9313303232192993, 0.010056668892502785, 0.0007567661814391613, 0.0006773694767616689, 0.00016374654660467058, 0.008356312289834023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015825676964595914, 0.001574154943227768, 0.001225732616148889, 0.27774307131767273, 0.47191065549850464, 0.041899941861629486, 0.10331469774246216, 0.0047262245789170265, 0.0960230901837349, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0012044805334880948, 0.001744594657793641, 
0.010911357589066029, 0.035235531628131866, 0.12406003475189209, 0.49639585614204407, 0.02129644714295864, 0.07618547230958939, 0.23296628892421722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006665610242635012, 0.008957373909652233, 0.0028928713873028755, 0.7268922924995422, 0.10707614570856094, 0.01201178040355444, 0.013845101930201054, 0.022992080077528954, 0.09866661578416824, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04134861007332802, 0.0526767373085022, 0.04131396487355232, 0.023087071254849434, 0.04077164828777313, 0.027765633538365364, 0.05679082125425339, 0.025407245382666588, 0.6908383369445801, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0021411015186458826, 0.012145284563302994, 0.008635377511382103, 0.004571457393467426, 0.009789393283426762, 0.022923681885004044, 0.019266795367002487, 0.15913596749305725, 0.7613908052444458, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9029706120491028, 0.09702935069799423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.909331738948822, 0.09066825360059738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0243820920586586, 0.026593990623950958, 0.9490237236022949, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3865880072116852, 0.017979737371206284, 0.5954321622848511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002445725491270423, 0.01137782447040081, 0.2685152590274811, 0.7176609635353088, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5815560817718506, 0.15706834197044373, 0.052335821092128754, 0.2090395838022232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013372139073908329, 0.017163371667265892, 0.023703746497631073, 0.029362313449382782, 0.916398286819458, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7268465161323547, 0.0363004133105278, 0.07873083651065826, 0.06576839834451675, 0.09235385805368423, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009483089670538902, 0.0015323974657803774, 0.016186771914362907, 0.02369842305779457, 
0.15252061188220978, 0.7965786457061768, 0.0, 0.0, 0.0, 0.0, 0.6550694704055786, 0.019533857703208923, 0.042362816631793976, 0.07321250438690186, 0.06519921869039536, 0.14462217688560486, 0.0, 0.0, 0.0, 0.0], [0.9718639850616455, 0.0004310244112275541, 0.00011954548244830221, 0.007853196933865547, 0.005200029816478491, 0.0034086843952536583, 0.011123435571789742, 0.0, 0.0, 0.0, 0.48597243428230286, 0.05253118649125099, 0.06572883576154709, 0.06831242144107819, 0.06681334227323532, 0.09225586801767349, 0.168385848402977, 0.0, 0.0, 0.0], [0.24338993430137634, 0.0009381886338815093, 0.001691973302513361, 0.004991883412003517, 0.06480661034584045, 0.02667633630335331, 0.5911706686019897, 0.06633439660072327, 0.0, 0.0, 0.2283225953578949, 0.01085133571177721, 0.0076954541727900505, 0.03403906524181366, 0.05505141243338585, 0.11318682134151459, 0.23008716106414795, 0.3207661509513855, 0.0, 0.0], [0.9644694328308105, 0.00020931981271132827, 0.00022034233552403748, 0.001116775325499475, 0.0005140798166394234, 0.011200232431292534, 0.006607241928577423, 0.015303434804081917, 0.00035933865001425147, 0.0, 0.31019407510757446, 0.01576145552098751, 0.006604246329516172, 0.1025082990527153, 0.11805430799722672, 0.0999068170785904, 0.17944715917110443, 0.09494999051094055, 0.07257375121116638, 0.0], [6.446504994528368e-05, 5.223282641964033e-05, 4.761212403536774e-05, 0.0026887860149145126, 0.9879595041275024, 8.169181819539517e-05, 3.4432316169841215e-05, 0.00022215544595383108, 0.008540215902030468, 0.0003085293574258685, 0.028495613485574722, 0.00728303287178278, 0.028978589922189713, 0.21746259927749634, 0.0312367994338274, 0.01134485937654972, 0.002138715935871005, 0.0005697175511159003, 0.00012198994954815134, 0.6723678112030029]], [[0.1481553614139557, 0.14691436290740967, 0.5575758218765259, 0.02441403828561306, 0.058879025280475616, 0.011832842603325844, 0.01016098354011774, 0.015505112707614899, 0.026562504470348358, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], [0.20737145841121674, 0.2658809721469879, 0.4251604974269867, 0.03998560830950737, 0.012661930173635483, 0.003662273520603776, 0.0006891252705827355, 0.004390099551528692, 0.040197838097810745, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09591562300920486, 0.13717111945152283, 0.23219715058803558, 0.020156029611825943, 0.031411558389663696, 0.04842779412865639, 0.003137993859127164, 0.03623202070593834, 0.3953508734703064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.039530280977487564, 0.0770052894949913, 0.4637334942817688, 0.3284752666950226, 0.018390586599707603, 0.021701356396079063, 0.0038800504989922047, 0.01712900958955288, 0.03015456721186638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11945435404777527, 0.064958855509758, 0.07506249845027924, 0.03312050923705101, 0.045947931706905365, 0.21168209612369537, 0.1585550606250763, 0.15941208600997925, 0.13180643320083618, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.027811916545033455, 0.005367752630263567, 0.022701909765601158, 0.02928026206791401, 0.042085714638233185, 0.23124846816062927, 0.3448639512062073, 0.17164644598960876, 0.12499356269836426, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04207267239689827, 0.0036673955619335175, 0.013221505098044872, 0.04823020473122597, 0.018784234300255775, 0.15617071092128754, 0.5762590169906616, 0.08817830681800842, 0.053416069597005844, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016374358907341957, 0.01844160258769989, 0.0564517118036747, 0.008724928833544254, 0.031119121238589287, 0.08068697899580002, 0.028958600014448166, 0.08762799203395844, 0.6716147661209106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01953587494790554, 0.025442129001021385, 0.033712055534124374, 0.054231878370046616, 0.046861547976732254, 0.038379911333322525, 0.03105914779007435, 0.027592265978455544, 0.723185122013092, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.13706038892269135, 0.05296454578638077, 0.06056801974773407, 0.10271193832159042, 0.10989244282245636, 0.11971112340688705, 0.10623226314783096, 0.11503037810325623, 0.19582894444465637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9361864924430847, 0.06381344050168991, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9973775148391724, 0.0026223897002637386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8988155722618103, 0.08033642917871475, 0.020848000422120094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8086934685707092, 0.08078567683696747, 0.11052089184522629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9107179045677185, 0.0414138063788414, 0.01669401116669178, 0.031174303963780403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14967022836208344, 0.05171789228916168, 0.3914002478122711, 0.40721163153648376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7451518774032593, 0.09319417923688889, 0.038068220019340515, 0.07867664098739624, 0.04490913450717926, 0.0, 0.0, 0.0, 0.0, 0.0, 0.40780311822891235, 0.04434635117650032, 0.05232110992074013, 0.3448564112186432, 0.15067294239997864, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7448095083236694, 0.053781699389219284, 0.02206255868077278, 0.051568105816841125, 0.050503022968769073, 0.07727508991956711, 0.0, 0.0, 0.0, 0.0, 0.3417491614818573, 0.023165758699178696, 0.008621969260275364, 0.03819064050912857, 0.566249430179596, 0.022023199126124382, 0.0, 0.0, 0.0, 0.0], [0.49530187249183655, 0.08432565629482269, 0.024537190794944763, 0.03536847233772278, 0.04101351276040077, 0.2942921817302704, 0.025161173194646835, 0.0, 0.0, 0.0, 0.13283461332321167, 0.0027981544844806194, 0.001892031985335052, 0.057958006858825684, 0.4807162284851074, 0.22431829571723938, 0.09948258846998215, 0.0, 0.0, 0.0], [0.33749768137931824, 0.0470278225839138, 0.025994539260864258, 0.11184448003768921, 0.035708073526620865, 0.3288814127445221, 0.052594561129808426, 0.06045151129364967, 
0.0, 0.0, 0.5702553391456604, 0.005225116387009621, 0.0014312443090602756, 0.028526127338409424, 0.15899939835071564, 0.05284468084573746, 0.022491520270705223, 0.16022635996341705, 0.0, 0.0], [0.5827996730804443, 0.037185750901699066, 0.025691334158182144, 0.040444474667310715, 0.032313525676727295, 0.15237869322299957, 0.02532070316374302, 0.06300554424524307, 0.040860243141651154, 0.0, 0.005209033377468586, 6.901475717313588e-05, 5.760595013271086e-05, 0.006149875931441784, 0.006613760255277157, 0.010193211026489735, 0.013639912940561771, 0.9578513503074646, 0.00021631908020935953, 0.0], [0.17712005972862244, 0.06858222186565399, 0.023361776024103165, 0.06553570926189423, 0.015878353267908096, 0.40178799629211426, 0.03335757926106453, 0.09457883983850479, 0.06679747253656387, 0.05300001800060272, 0.5839820504188538, 0.007275882177054882, 0.03890826180577278, 0.2169828861951828, 0.02285575494170189, 0.0033320344518870115, 0.0027764069382101297, 0.032896872609853745, 0.038299210369586945, 0.05269058048725128]], [[0.1995955854654312, 0.1365700364112854, 0.06844333559274673, 0.10430964082479477, 0.06450515240430832, 0.046256136149168015, 0.1181989535689354, 0.11867640167474747, 0.14344464242458344, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11168574541807175, 0.19221879541873932, 0.09424428641796112, 0.10450402647256851, 0.06917304545640945, 0.0600862056016922, 0.14199501276016235, 0.09375911951065063, 0.1323338896036148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1609925925731659, 0.1396498680114746, 0.177944153547287, 0.03334498405456543, 0.016808874905109406, 0.10536731034517288, 0.1187783032655716, 0.03365077078342438, 0.2134632021188736, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1381153017282486, 0.04422784969210625, 0.04791303351521492, 0.16848880052566528, 0.14531251788139343, 0.08485772460699081, 0.03650972992181778, 0.08906612545251846, 0.24550898373126984, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], [0.05420248210430145, 0.02658572979271412, 0.05446610227227211, 0.10749125480651855, 0.22097598016262054, 0.16638338565826416, 0.0331658273935318, 0.035041358321905136, 0.30168798565864563, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13837963342666626, 0.02727973647415638, 0.1299334168434143, 0.0896206796169281, 0.11551950126886368, 0.13963927328586578, 0.07841819524765015, 0.02172034978866577, 0.25948914885520935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08648376911878586, 0.022083481773734093, 0.023758457973599434, 0.19388236105442047, 0.1724909394979477, 0.02776450663805008, 0.04756799340248108, 0.03839871287345886, 0.38756975531578064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06476524472236633, 0.0030945511534810066, 0.016289785504341125, 0.013512126170098782, 0.007217712234705687, 0.047962453216314316, 0.03755675256252289, 0.02134101279079914, 0.788260281085968, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06007339805364609, 0.038077425211668015, 0.01732070930302143, 0.04335314407944679, 0.09849875420331955, 0.07123422622680664, 0.12536978721618652, 0.17402620613574982, 0.37204641103744507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.31619662046432495, 0.10629935562610626, 0.051193755120038986, 0.08206456899642944, 0.08056272566318512, 0.05727463215589523, 0.13476009666919708, 0.03839832916855812, 0.13325001299381256, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5274816155433655, 0.47251835465431213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.050016310065984726, 0.9499835968017578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0672292709350586, 0.8769893646240234, 0.05578138306736946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1501082479953766, 0.3363426625728607, 0.5135491490364075, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13304242491722107, 0.3016340434551239, 
0.1132093071937561, 0.45211419463157654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3278971016407013, 0.3615517318248749, 0.08450257778167725, 0.22604861855506897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3050541281700134, 0.12257003039121628, 0.15977424383163452, 0.12758392095565796, 0.2850176990032196, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25667694211006165, 0.40780505537986755, 0.15422186255455017, 0.10097997635602951, 0.08031607419252396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22356735169887543, 0.03928647190332413, 0.007754397578537464, 0.009327426552772522, 0.12143179029226303, 0.5986325740814209, 0.0, 0.0, 0.0, 0.0, 0.06381947547197342, 0.026328660547733307, 0.009785240516066551, 0.001955200219526887, 0.6504424810409546, 0.24766895174980164, 0.0, 0.0, 0.0, 0.0], [0.21462960541248322, 0.06222677230834961, 0.03770677372813225, 0.020617984235286713, 0.1298619657754898, 0.16450734436511993, 0.3704494833946228, 0.0, 0.0, 0.0, 0.20312389731407166, 0.04875075817108154, 0.01619674637913704, 0.028123438358306885, 0.08579143136739731, 0.20417368412017822, 0.41383999586105347, 0.0, 0.0, 0.0], [0.03640613704919815, 0.010346643626689911, 0.00673291739076376, 0.007102567236870527, 0.047351155430078506, 0.07502260059118271, 0.1735789030790329, 0.6434589624404907, 0.0, 0.0, 0.0028631098102778196, 0.00043423930765129626, 0.00021322182146832347, 0.0004598038794938475, 0.009248310700058937, 0.010348621755838394, 0.14874719083309174, 0.8276853561401367, 0.0, 0.0], [0.10655857622623444, 0.005401281639933586, 0.008467022329568863, 0.004935698118060827, 0.02920999936759472, 0.06761414557695389, 0.11367721855640411, 0.6410130262374878, 0.02312297187745571, 0.0, 0.0018436884274706244, 0.00014077842934057117, 0.00019285999587737024, 0.0006260477821342647, 0.010675687342882156, 0.007219970691949129, 0.05410425364971161, 0.9211149215698242, 0.0040815602988004684, 0.0], [0.022663118317723274, 0.01638328656554222, 0.016234278678894043, 0.013438239693641663, 0.13762539625167847, 0.1316443532705307, 0.10126813501119614, 
0.1815005987882614, 0.10885223746299744, 0.27039045095443726, 0.0996592566370964, 0.00012745348794851452, 0.0005518114776350558, 0.00026576913660392165, 0.00016320105351042002, 9.30886235437356e-05, 0.00024295282491948456, 0.000212193961488083, 0.0026789114344865084, 0.8960053324699402]]], [[[0.0077307759784162045, 0.013184988871216774, 0.016869038343429565, 0.013336911797523499, 0.01304439827799797, 0.013718237169086933, 0.0296618789434433, 0.02448520064353943, 0.8679684996604919, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1374177634716034, 0.018056754022836685, 0.029763542115688324, 0.004862301517277956, 0.00231130956672132, 0.006278112530708313, 0.012106452137231827, 0.033879589289426804, 0.755324125289917, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1539316475391388, 0.23461903631687164, 0.033691998571157455, 0.026462335139513016, 0.0030949951615184546, 0.0038835303857922554, 0.009438932873308659, 0.0025479686446487904, 0.5323294997215271, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.29215019941329956, 0.05790534242987633, 0.18934868276119232, 0.018473153933882713, 0.002999690594151616, 0.004652327857911587, 0.010374259203672409, 0.0072145056910812855, 0.4168816804885864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16539201140403748, 0.05819307267665863, 0.12084146589040756, 0.1738077849149704, 0.004504370968788862, 0.006831282749772072, 0.02180996537208557, 0.012287246994674206, 0.4363327622413635, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017602156847715378, 0.026805447414517403, 0.07671570032835007, 0.5152483582496643, 0.21202509105205536, 0.041201505810022354, 0.02207496576011181, 0.00952092744410038, 0.07880578190088272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017750855535268784, 0.01654047518968582, 0.07482129335403442, 0.23223723471164703, 0.3542158007621765, 0.16141267120838165, 0.02249749004840851, 0.044686269015073776, 0.07583795487880707, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01755565032362938, 0.008823209442198277, 0.013083218596875668, 0.5061533451080322, 0.02344801276922226, 0.0075739468447864056, 0.07187878340482712, 0.029167035594582558, 0.3223167061805725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07037408649921417, 0.05453738570213318, 0.05508268624544144, 0.02769530564546585, 0.038050826638936996, 0.20446287095546722, 0.19980187714099884, 0.19835616648197174, 0.15163862705230713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003375247586518526, 0.008793321438133717, 0.001001630094833672, 0.002094975672662258, 0.0032946632709354162, 0.01792662777006626, 0.10988471657037735, 0.21093924343585968, 0.6426896452903748, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5484673976898193, 0.11615661531686783, 0.018765881657600403, 0.05580288916826248, 0.007166780531406403, 0.0032917018979787827, 0.014802427962422371, 0.009165346622467041, 0.2263808399438858, 0.0, 0.11433855444192886, 0.04686390608549118, 0.05050662159919739, 0.01692698895931244, 0.014815520495176315, 0.0768260732293129, 0.017432983964681625, 0.6101956367492676, 0.052093639969825745, 0.0], [0.05529153719544411, 0.9114729762077332, 0.02160518430173397, 0.004567866213619709, 0.0020856212358921766, 0.002110318513587117, 7.9427394666709e-05, 0.0003330546023789793, 0.002454147208482027, 0.0, 0.38972416520118713, 0.22944767773151398, 0.07904881238937378, 0.052096493542194366, 0.007011168636381626, 0.006784902419894934, 0.0014207185013219714, 0.13426420092582703, 0.10020176321268082, 0.0], [0.05249117314815521, 0.7437799572944641, 0.20206855237483978, 0.000259216787526384, 5.300459815771319e-05, 0.0010736108524724841, 1.3093422239762731e-05, 5.472580232890323e-05, 0.00020689083612523973, 0.0, 0.11186019331216812, 0.017331605777144432, 0.04338392615318298, 0.012996343895792961, 0.0037596735637634993, 0.04186789691448212, 0.010188347660005093, 0.40130615234375, 0.35730576515197754, 0.0], 
[0.00011124516458949074, 0.00014391898002941161, 0.002859711181372404, 0.9545366168022156, 0.03663090988993645, 0.0007743826135993004, 9.755761857377365e-05, 0.004461521748453379, 0.00038416660390794277, 0.0, 0.10591340810060501, 0.20223784446716309, 0.2886512279510498, 0.06966178864240646, 0.009836114011704922, 0.0229511596262455, 0.008547846227884293, 0.14636646211147308, 0.14583423733711243, 0.0], [0.00743946572765708, 0.0006693374016322196, 0.030706975609064102, 0.7693524360656738, 0.13636630773544312, 0.010656076483428478, 0.00020547783060465008, 0.04430907592177391, 0.0002948205510620028, 0.0, 0.07172418385744095, 0.2936038672924042, 0.03526498004794121, 0.13891401886940002, 0.06139945611357689, 0.03925776481628418, 0.05349786579608917, 0.1478489190340042, 0.15848881006240845, 0.0], [0.34355977177619934, 0.05352164804935455, 0.28276264667510986, 0.0013746530748903751, 0.09131773561239243, 0.14369648694992065, 0.012657100334763527, 0.0018211203860118985, 0.06928855180740356, 0.0, 0.06896000355482101, 0.07670743763446808, 0.04746192321181297, 0.0077508557587862015, 0.0064402432180941105, 0.0690828412771225, 0.07239065319299698, 0.3534849286079407, 0.2977212965488434, 0.0], [0.00012065855844411999, 1.9836104911519215e-05, 1.09456195787061e-05, 1.777518991730176e-05, 5.799566861242056e-05, 0.00024955562548711896, 0.9993764758110046, 2.622428155518719e-06, 0.00014374956663232297, 0.0, 0.07854402810335159, 0.05315924435853958, 0.006970829796046019, 0.01197806466370821, 0.15678070485591888, 0.059328123927116394, 0.0844779834151268, 0.058106619864702225, 0.4906543493270874, 0.0], [0.07937772572040558, 0.06896578520536423, 0.046331409364938736, 0.0006753505440428853, 0.10991709679365158, 0.07862550020217896, 0.015291865915060043, 0.024935398250818253, 0.575879693031311, 0.0, 0.03681005910038948, 0.0140389958396554, 0.007565703243017197, 0.004180380143225193, 0.03550711274147034, 0.08768045157194138, 0.04672156274318695, 0.05331620201468468, 0.7141793966293335, 0.0], 
[0.0013021372724324465, 0.0014324801741167903, 0.001721968175843358, 0.0011953436769545078, 0.025524066761136055, 0.0017154657980427146, 0.004751213360577822, 0.06856247782707214, 0.8937948942184448, 0.0, 0.011276278644800186, 0.0043571279384195805, 0.0015699869254603982, 0.009309794753789902, 0.5466312766075134, 0.06633520126342773, 0.012565904296934605, 0.036605555564165115, 0.3113488256931305, 0.0], [0.19237911701202393, 0.08248087018728256, 0.005975060164928436, 0.0005637910799123347, 0.0032617340330034494, 0.08159960061311722, 0.10672623664140701, 0.06415636837482452, 0.4628572165966034, 0.0, 0.06600549072027206, 0.025541657581925392, 0.15132947266101837, 0.052793603390455246, 0.07684693485498428, 0.05682613328099251, 0.01590665802359581, 0.05306769907474518, 0.501682460308075, 0.0]], [[0.07041527330875397, 0.15367093682289124, 0.3963199257850647, 0.03077671490609646, 0.0928598940372467, 0.04086732864379883, 0.018142100423574448, 0.012120239436626434, 0.18482762575149536, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0367966964840889, 0.022482391446828842, 0.35830163955688477, 0.02875097654759884, 0.03547174483537674, 0.026731541380286217, 0.005365677177906036, 0.038472291082143784, 0.4476269483566284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.035722482949495316, 0.013633755035698414, 0.09877835214138031, 0.0896211713552475, 0.16777706146240234, 0.10725134611129761, 0.05053357034921646, 0.10712091624736786, 0.32956135272979736, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026592494919896126, 0.002020884770900011, 0.010739283636212349, 0.015951883047819138, 0.18538028001785278, 0.16766443848609924, 0.03731367364525795, 0.3853055238723755, 0.16903170943260193, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0038862666115164757, 0.0004870722477789968, 0.0013956124894320965, 0.001421120367012918, 0.013834443874657154, 0.18579153716564178, 0.18213258683681488, 0.5258954763412476, 0.08515587449073792, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002724156714975834, 0.004415807779878378, 0.003638928523287177, 0.00862019695341587, 0.010569852776825428, 0.18068262934684753, 0.2256886065006256, 0.3616458773612976, 0.20201392471790314, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004798010922968388, 0.006960091646760702, 0.005558326840400696, 0.015252271667122841, 0.010058294981718063, 0.16163121163845062, 0.19844789803028107, 0.089196115732193, 0.508097767829895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006148195825517178, 0.009931370615959167, 0.0022139709908515215, 0.003481896361336112, 0.00199966412037611, 0.011451391503214836, 0.018514955416321754, 0.10389390587806702, 0.8423647284507751, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1474374532699585, 0.1116698756814003, 0.10040155798196793, 0.07832593470811844, 0.051569730043411255, 0.103182852268219, 0.0987909808754921, 0.06264416873455048, 0.24597744643688202, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0029639359563589096, 0.0008504446013830602, 0.002215511864051223, 0.0016108372947201133, 0.001786046545021236, 0.003435377962887287, 0.000923731888178736, 0.009203500114381313, 0.9770104885101318, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02611132338643074, 0.03929625079035759, 0.13154497742652893, 0.0362381711602211, 0.41634607315063477, 0.056842975318431854, 0.07852181792259216, 0.04581860452890396, 0.16927982866764069, 0.0, 0.02032626047730446, 0.020203230902552605, 0.6020291447639465, 0.030009541660547256, 0.018654465675354004, 0.18802858889102936, 0.05143868178129196, 0.02303317002952099, 0.04627683013677597, 0.0], [0.02255222387611866, 0.10361888259649277, 0.8158259987831116, 0.004235901869833469, 0.006766538135707378, 0.000986771541647613, 0.0032120062969624996, 0.001885716337710619, 0.04091576859354973, 0.0, 0.0019176346249878407, 0.016540443524718285, 0.9535443782806396, 0.023291945457458496, 0.00010982116509694606, 
0.0009689715225249529, 0.00013850984396412969, 0.00020587266772054136, 0.0032822038047015667, 0.0], [0.017373425886034966, 0.13537107408046722, 0.8077713847160339, 0.0038886782713234425, 0.000276644917903468, 0.0006381930434145033, 0.00045153097016736865, 0.00025059780455194414, 0.0339784249663353, 0.0, 0.0021921033039689064, 0.01942657120525837, 0.9639623165130615, 0.007353355176746845, 5.936318120802753e-05, 0.0048356493934988976, 6.20722203166224e-05, 0.00038220148417167366, 0.0017266407376155257, 0.0], [0.005007221829146147, 0.01780957728624344, 0.01267488207668066, 0.04065018519759178, 0.30516332387924194, 0.0026367236860096455, 0.0019572318997234106, 0.004150545224547386, 0.6099502444267273, 0.0, 0.009460263885557652, 0.051493529230356216, 0.3948231041431427, 0.009338779374957085, 0.006950095761567354, 0.48816555738449097, 0.01867900975048542, 0.0028053568676114082, 0.018284155055880547, 0.0], [0.00225767120718956, 0.0031865497585386038, 0.001125291339121759, 0.016497144475579262, 0.8971690535545349, 0.00343481102026999, 0.003961168695241213, 0.006191920954734087, 0.0661763921380043, 0.0, 0.003424277761951089, 0.009652957320213318, 0.10005363076925278, 0.026136387139558792, 0.09182560443878174, 0.6324647068977356, 0.07658552378416061, 0.005459657870233059, 0.05439731851220131, 0.0], [0.009628614410758018, 0.010054895654320717, 0.001336919842287898, 0.0704738199710846, 0.6877674460411072, 0.03301373869180679, 0.05187760666012764, 0.005273953080177307, 0.1305730938911438, 0.0, 0.0026005429681390524, 0.0029943317640572786, 0.26352012157440186, 0.02426978573203087, 0.05801504850387573, 0.49633610248565674, 0.11849544942378998, 0.009708826430141926, 0.02405967190861702, 0.0], [0.004103087354451418, 0.0010862533235922456, 0.0006940921302884817, 0.005870609078556299, 0.43826234340667725, 0.030803751200437546, 0.2956492602825165, 0.002342070685699582, 0.2211885154247284, 0.0, 0.0027453943621367216, 0.0021119171287864447, 0.0030521987937390804, 0.09308812767267227, 
0.28145554661750793, 0.015254919417202473, 0.530491828918457, 0.007408978417515755, 0.06439103186130524, 0.0], [0.0007035931921564043, 0.0015657383482903242, 0.0003329406026750803, 0.025085464119911194, 0.8715798258781433, 0.006046876311302185, 0.002586639951914549, 0.00011169366916874424, 0.09198720753192902, 0.0, 0.0012973180273547769, 0.002199073787778616, 0.0031004296615719795, 0.024488963186740875, 0.8535729050636292, 0.016068320721387863, 0.029179612174630165, 0.012250186875462532, 0.05784311145544052, 0.0], [0.0021814475767314434, 0.0018482444575056434, 0.02461252734065056, 0.02290530502796173, 0.17733190953731537, 0.007551506161689758, 0.026218494400382042, 0.1859409213066101, 0.5514096021652222, 0.0, 0.0010010729311034083, 0.0008253253763541579, 0.0028483583591878414, 0.028342707082629204, 0.860925018787384, 0.0038871155120432377, 0.006998666562139988, 0.01413769368082285, 0.08103384077548981, 0.0], [0.002560347318649292, 0.0069580040872097015, 0.0021583843044936657, 0.002428637584671378, 0.010794135741889477, 0.002866419730708003, 0.010929176583886147, 0.004671781323850155, 0.9566330909729004, 0.0, 0.007383578456938267, 0.056256651878356934, 0.5807297825813293, 0.01667044125497341, 0.03810223564505577, 0.07880110293626785, 0.009197888895869255, 0.12926581501960754, 0.08359251171350479, 0.0]], [[0.022387586534023285, 0.045972827821969986, 0.05835629999637604, 0.22869053483009338, 0.010770916007459164, 0.006216464098542929, 0.018148910254240036, 0.006308646872639656, 0.6031478047370911, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0035997454542666674, 0.002674269489943981, 0.016009783372282982, 0.05554450675845146, 0.0013587778666988015, 0.0032801039051264524, 0.00560772093012929, 0.00799081102013588, 0.9039342403411865, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0018151046242564917, 0.001049908110871911, 0.0005912692868150771, 0.005136367864906788, 0.0005621784366667271, 0.00844560656696558, 0.017937110736966133, 
0.008342047221958637, 0.9561205506324768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004429707303643227, 0.005516116041690111, 0.003033371875062585, 0.012963998131453991, 0.0034379358403384686, 0.003276604227721691, 0.0140963364392519, 0.005416945554316044, 0.9478288888931274, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013230006210505962, 0.011804360896348953, 0.009972047992050648, 0.004975683055818081, 0.008386109955608845, 0.18977868556976318, 0.1806434541940689, 0.03204761818051338, 0.549161970615387, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006766628473997116, 0.008349079638719559, 0.003925195895135403, 0.0006033667596057057, 0.006175691727548838, 0.2236345112323761, 0.03405819088220596, 0.07976362109184265, 0.6367236971855164, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03860252723097801, 0.01646261475980282, 0.02104821614921093, 0.0021387943997979164, 0.005319601856172085, 0.2400989532470703, 0.03188503161072731, 0.005558657925575972, 0.6388856768608093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0738457515835762, 0.018826894462108612, 0.0069308048114180565, 0.0074225678108632565, 0.004789229016751051, 0.046955253928899765, 0.11907684803009033, 0.18744726479053497, 0.5347052812576294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09479796141386032, 0.11939200013875961, 0.0752992108464241, 0.061374519020318985, 0.08638977259397507, 0.12459041178226471, 0.16023214161396027, 0.0879756435751915, 0.1899482160806656, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.027537798509001732, 0.06296242028474808, 0.014751194976270199, 0.0011882808757945895, 0.016387099400162697, 0.15830224752426147, 0.03707461059093475, 0.028470970690250397, 0.6533253788948059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10015174746513367, 0.09576243162155151, 0.03824060782790184, 0.020538825541734695, 0.027732321992516518, 0.017240623012185097, 0.011470633558928967, 
0.4632270634174347, 0.22563566267490387, 0.0, 0.023356424644589424, 0.012650059536099434, 0.05017145350575447, 0.05590398982167244, 0.05159280076622963, 0.01602507382631302, 0.014807065948843956, 0.654244601726532, 0.12124844640493393, 0.0], [0.018642960116267204, 0.04822106286883354, 0.0140958521515131, 0.13092586398124695, 0.031955283135175705, 0.05195324495434761, 0.024238048121333122, 0.35591286420822144, 0.32405486702919006, 0.0, 0.030835414305329323, 0.04180247709155083, 0.029645785689353943, 0.20071062445640564, 0.010328685864806175, 0.03208288922905922, 0.026780622079968452, 0.457701712846756, 0.17011170089244843, 0.0], [0.04842180013656616, 0.06019889563322067, 0.016854524612426758, 0.166769877076149, 0.016003064811229706, 0.024013301357626915, 0.03686072677373886, 0.44016721844673157, 0.1907106339931488, 0.0, 0.02856343612074852, 0.03459611535072327, 0.15441730618476868, 0.04662194848060608, 0.0013040672056376934, 0.017847269773483276, 0.02464178577065468, 0.5969575643539429, 0.09505032747983932, 0.0], [0.0036558371502906084, 0.0033082098234444857, 0.0009605223312973976, 0.017093589529395103, 0.019939379766583443, 0.08280718326568604, 0.031923551112413406, 0.703184187412262, 0.1371275782585144, 0.0, 0.03350958973169327, 0.02514287829399109, 0.027676144614815712, 0.11052078753709793, 0.15496152639389038, 0.08862635493278503, 0.027723105624318123, 0.24766287207603455, 0.28417670726776123, 0.0], [0.0018930931109935045, 0.002881440566852689, 0.00019882648484781384, 0.00406575808301568, 0.0021070034708827734, 0.011610278859734535, 0.0074381148442626, 0.9341073036193848, 0.03569793701171875, 0.0, 0.014355039224028587, 0.005383878946304321, 0.002517768880352378, 0.09422861039638519, 0.06622537225484848, 0.046315327286720276, 0.08473969250917435, 0.4999735355377197, 0.18626095354557037, 0.0], [0.00691588269546628, 0.02838265150785446, 0.015397720038890839, 0.031874921172857285, 0.04765379801392555, 0.22744230926036835, 0.06624653190374374, 0.10724947601556778, 
0.4688366651535034, 0.0, 0.002460801973938942, 0.0016284199664369226, 0.005857668351382017, 0.006880565080791712, 0.7626023292541504, 0.025456121191382408, 0.021016357466578484, 0.06090177595615387, 0.11319592595100403, 0.0], [0.009144916199147701, 0.012914983555674553, 0.0114166010171175, 0.010616455227136612, 0.03852293640375137, 0.11398687958717346, 0.23996756970882416, 0.03855413943529129, 0.5248754620552063, 0.0, 0.007633878383785486, 0.002682786202058196, 0.0008938225219026208, 0.006808742880821228, 0.17231638729572296, 0.049100711941719055, 0.32851701974868774, 0.0061601921916007996, 0.4258863925933838, 0.0], [0.0165528766810894, 0.08396174013614655, 0.03695421293377876, 0.012792840600013733, 0.05054211989045143, 0.004681664984673262, 0.006349458359181881, 0.0059485542587935925, 0.7822163701057434, 0.0, 0.003303236560896039, 0.0015338786179199815, 0.0017581325955688953, 0.0052335225045681, 0.24177710711956024, 0.09136255830526352, 0.06603478640317917, 0.0047843558713793755, 0.5842124223709106, 0.0], [0.01941962167620659, 0.0720844566822052, 0.06703408807516098, 0.0024893011432141066, 0.09017500281333923, 0.01547347940504551, 0.011082785204052925, 0.036743972450494766, 0.6854971051216125, 0.0, 0.011694137938320637, 0.0015430846251547337, 0.00043408613419160247, 0.005433904007077217, 0.03723231703042984, 0.1666216105222702, 0.04878358170390129, 0.024785596877336502, 0.7034717798233032, 0.0], [0.010559813119471073, 0.7681021094322205, 0.01782229356467724, 0.0007385257631540298, 0.000383153063012287, 0.00014055910287424922, 0.00037340103881433606, 0.000453647633548826, 0.2014264017343521, 0.0, 0.028515880927443504, 0.0183264147490263, 0.011487613432109356, 0.03205259144306183, 0.06179385632276535, 0.041277043521404266, 0.014015565626323223, 0.06198226660490036, 0.7305486798286438, 0.0]], [[0.044569190591573715, 0.00917287077754736, 0.004391324240714312, 0.8386606574058533, 0.06130588799715042, 0.003870139131322503, 0.007488539442420006, 0.028126200661063194, 
0.002415221417322755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08009635657072067, 0.016815535724163055, 0.012093844823539257, 0.1592065542936325, 0.5643750429153442, 0.02920410968363285, 0.0919446051120758, 0.036902546882629395, 0.009361499920487404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08603464066982269, 0.01933746039867401, 0.05900268629193306, 0.2806539237499237, 0.22094620764255524, 0.08643656224012375, 0.026435989886522293, 0.1974046230316162, 0.023747902363538742, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011109462939202785, 0.008883769623935223, 0.006091873627156019, 0.9400036931037903, 0.020445559173822403, 0.0056496066972613335, 0.0019461432239040732, 0.005268549080938101, 0.000601345207542181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12766019999980927, 0.021774157881736755, 0.08726640790700912, 0.0718328207731247, 0.053083695471286774, 0.3031027019023895, 0.06321869790554047, 0.2611844837665558, 0.010876962915062904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06407223641872406, 0.11627303808927536, 0.1807759404182434, 0.0054795523174107075, 0.026687098667025566, 0.09637009352445602, 0.052303463220596313, 0.4456423819065094, 0.012396130710840225, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09725438803434372, 0.15948796272277832, 0.05173082649707794, 0.01153761800378561, 0.0721999853849411, 0.059252724051475525, 0.11923323571681976, 0.05380275845527649, 0.3755004107952118, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0974307581782341, 0.23218762874603271, 0.06967660784721375, 0.031012043356895447, 0.04906507954001427, 0.31767621636390686, 0.08231117576360703, 0.07159094512462616, 0.04904941841959953, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16388627886772156, 0.17526376247406006, 0.07081529498100281, 0.17886894941329956, 0.07944575697183609, 0.07640470564365387, 0.0757102444767952, 0.04333823174238205, 
0.13626690208911896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09265941381454468, 0.11633585393428802, 0.04908691346645355, 0.0062498715706169605, 0.07016508281230927, 0.012818480841815472, 0.0484321266412735, 0.015437646768987179, 0.5888146162033081, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15394526720046997, 0.03155883774161339, 0.013263775035738945, 0.6156436800956726, 0.07578981667757034, 0.016770629212260246, 0.041555847972631454, 0.008898256346583366, 0.042573899030685425, 0.0, 0.10071786493062973, 0.017111245542764664, 0.07246935367584229, 0.01480931881815195, 0.14864948391914368, 0.20273517072200775, 0.054981958121061325, 0.25890761613845825, 0.12961813807487488, 0.0], [0.0076131573878228664, 0.0034139170311391354, 0.013692762702703476, 0.9489110708236694, 0.0023491643369197845, 0.004504398908466101, 0.018228650093078613, 6.88576401444152e-05, 0.0012180121848359704, 0.0, 0.049787674099206924, 0.02108882926404476, 0.20989678800106049, 0.006962155923247337, 0.21569682657718658, 0.1622857302427292, 0.016771212220191956, 0.2403237521648407, 0.07718709856271744, 0.0], [0.002593724289909005, 0.0010450187837705016, 0.004842442460358143, 0.9895163178443909, 0.0010734394891187549, 6.863682210678235e-05, 0.00040535334846936166, 0.00029785462538711727, 0.00015763564442750067, 0.0, 0.024618864059448242, 0.010488898493349552, 0.4834355115890503, 0.015693388879299164, 0.07393413037061691, 0.055557433515787125, 0.007495412603020668, 0.27800077199935913, 0.05077548325061798, 0.0], [0.007949098013341427, 0.0007930149440653622, 0.0010613474296405911, 0.913150429725647, 0.017281265929341316, 0.01414033118635416, 0.03622613847255707, 0.0036093932576477528, 0.0057889497838914394, 0.0, 0.006324393209069967, 0.0006586945382878184, 0.02188086323440075, 0.003439908614382148, 0.055277179926633835, 0.5423230528831482, 0.1656835526227951, 0.12264314293861389, 0.08176910877227783, 0.0], [0.002646723063662648, 0.0005160199943929911, 
0.0002488561731297523, 0.0025351925287395716, 0.0016247049206867814, 0.09429789334535599, 0.8856176137924194, 0.005861275363713503, 0.006651633884757757, 0.0, 0.008246216922998428, 0.000647101376671344, 0.018551276996731758, 0.0031310885678976774, 0.04379039630293846, 0.34376823902130127, 0.2999532222747803, 0.13205647468566895, 0.1498558670282364, 0.0], [0.010305220261216164, 0.0041244118474423885, 0.0009454450337216258, 0.011387528851628304, 0.006450551562011242, 0.09920497238636017, 0.7582080960273743, 0.0005519646219909191, 0.1088215708732605, 0.0, 0.001800144906155765, 0.00032634654780849814, 0.02560480497777462, 0.0014933178899809718, 0.04328969866037369, 0.48067817091941833, 0.22867664694786072, 0.008819987997412682, 0.20931090414524078, 0.0], [0.02700764685869217, 0.004230276681482792, 0.0004602614790201187, 0.0022337904665619135, 0.001628970610909164, 0.01760227419435978, 0.5739604234695435, 0.0034094173461198807, 0.3694668710231781, 0.0, 0.0023069612216204405, 0.0018136217258870602, 0.006447605788707733, 0.005140945315361023, 0.046570103615522385, 0.045606330037117004, 0.3236173987388611, 0.014286459423601627, 0.5542104840278625, 0.0], [0.008519366383552551, 0.005846879445016384, 0.00031929058604873717, 0.00022687541786581278, 0.0001488836423959583, 0.0012441301951184869, 0.007195098325610161, 0.000364138453733176, 0.9761351943016052, 0.0, 0.0025225167628377676, 0.000774701707996428, 0.0168449804186821, 0.0014132045907899737, 0.1692919135093689, 0.21547472476959229, 0.19468647241592407, 0.00621472392231226, 0.392776757478714, 0.0], [0.10429845005273819, 0.062129467725753784, 0.0009245545952580869, 0.00015166438242886215, 0.00031537580071017146, 0.00040291156619787216, 0.006900690030306578, 0.009933815337717533, 0.8149431943893433, 0.0, 0.006893941201269627, 0.0026040272787213326, 0.036687299609184265, 0.0016275923699140549, 0.13132861256599426, 0.15552441775798798, 0.23651301860809326, 0.023025648668408394, 0.4057953953742981, 0.0], [0.021318454295396805, 
0.024646490812301636, 0.0006273255567066371, 3.4892458643298596e-05, 0.0002248335222247988, 0.001184952794574201, 0.005942351184785366, 0.01648845337331295, 0.9295321702957153, 0.0, 0.004257934633642435, 0.008543262258172035, 0.05716743320226669, 0.0024442216381430626, 0.027526315301656723, 0.08828678727149963, 0.025276461616158485, 0.2843557894229889, 0.5021417737007141, 0.0]], [[0.005850312765687704, 0.017421673983335495, 0.004798548296093941, 0.008814580738544464, 0.00403921864926815, 0.015260725282132626, 0.03377071022987366, 0.009620469063520432, 0.9004237651824951, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013436811044812202, 0.03272867575287819, 0.005969575606286526, 0.02213078737258911, 0.008325905539095402, 0.015314633026719093, 0.027177294716238976, 0.017041552811861038, 0.8578747510910034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012305106967687607, 0.03383316844701767, 0.010593314655125141, 0.027156231924891472, 0.00306991720572114, 0.004844812210649252, 0.018964877352118492, 0.05307865887880325, 0.8361539244651794, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01255274098366499, 0.03600494936108589, 0.010369472205638885, 0.019021298736333847, 0.0032906190026551485, 0.0037067385856062174, 0.017627976834774017, 0.01037716306746006, 0.8870489597320557, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0020879805088043213, 0.013872731477022171, 0.0018324662232771516, 0.006437606178224087, 0.013170951046049595, 0.011930068954825401, 0.0030771365854889154, 0.018353432416915894, 0.9292376041412354, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022865016013383865, 0.0674654096364975, 0.00996339786797762, 0.01914660632610321, 0.014956261031329632, 0.026097828522324562, 0.018910687416791916, 0.06562207639217377, 0.7549726963043213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019515078514814377, 0.07521340996026993, 0.03206341341137886, 0.0070005785673856735, 
0.0066195218823850155, 0.03877842426300049, 0.01228683814406395, 0.032381508499383926, 0.7761411666870117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0022571254521608353, 0.00383052253164351, 0.0012509305961430073, 0.005982697941362858, 0.001252268673852086, 0.0028570422437042, 0.00556317949667573, 0.7337145805358887, 0.24329175055027008, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.132036030292511, 0.1683972030878067, 0.13758207857608795, 0.14189518988132477, 0.03147142380475998, 0.047566916793584824, 0.07834812998771667, 0.12177446484565735, 0.1409287303686142, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038695693016052246, 0.06063547730445862, 0.020152689889073372, 0.1101006418466568, 0.021127784624695778, 0.02848564088344574, 0.03705665469169617, 0.108894944190979, 0.5748504996299744, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.18210574984550476, 0.33179324865341187, 0.17356830835342407, 0.09634877741336823, 0.13200198113918304, 0.013823754154145718, 0.003925282042473555, 0.0030049949418753386, 0.06342781335115433, 0.0, 0.2613511085510254, 0.024888625368475914, 0.11462423205375671, 0.021279124543070793, 0.1065509021282196, 0.23139707744121552, 0.07117345929145813, 0.09822205454111099, 0.07051338255405426, 0.0], [0.0020119324326515198, 0.018535200506448746, 0.9658629894256592, 0.007450602483004332, 0.004517300054430962, 0.0010996937053278089, 0.00011890578753082082, 8.778373739914969e-05, 0.0003156243183184415, 0.0, 0.08973123878240585, 0.07256940752267838, 0.3644520342350006, 0.09313907474279404, 0.10501276701688766, 0.026235496625304222, 0.035534195601940155, 0.05646198242902756, 0.15686386823654175, 0.0], [0.00025491390260867774, 0.010184208862483501, 0.989091694355011, 0.0003856455150526017, 4.125349732930772e-05, 1.630889528314583e-05, 4.754766450787429e-06, 9.06635705177905e-06, 1.2339231943769846e-05, 0.0, 0.12105944007635117, 0.03531542792916298, 0.18099160492420197, 0.04576702043414116, 
0.03264385089278221, 0.04934798926115036, 0.0072426870465278625, 0.2739674150943756, 0.25366437435150146, 0.0], [0.0005263744969852269, 0.0021082928869873285, 0.03339724615216255, 0.9491472840309143, 0.010056117549538612, 0.0008323733345605433, 0.00041247360059060156, 0.003171485150232911, 0.0003482940956018865, 0.0, 0.049101557582616806, 0.02436317875981331, 0.1119280532002449, 0.019082490354776382, 0.23333144187927246, 0.12024182081222534, 0.09606382250785828, 0.03866123780608177, 0.3072265088558197, 0.0], [0.004013302735984325, 0.003607808379456401, 0.10117900371551514, 0.3848154544830322, 0.4549750089645386, 0.02184862084686756, 0.015023248270154, 0.013938427902758121, 0.0005991325015202165, 0.0, 0.011761177331209183, 0.004259902983903885, 0.019396282732486725, 0.010304590687155724, 0.5410462021827698, 0.1548439860343933, 0.1577453315258026, 0.022628072649240494, 0.07801424711942673, 0.0], [0.006630904506891966, 0.001555793103761971, 0.01566290855407715, 0.005377574823796749, 0.0545264296233654, 0.7578195929527283, 0.15542279183864594, 0.00011175184772582725, 0.002892365213483572, 0.0, 0.011012338101863861, 0.006456742994487286, 0.03514476120471954, 0.01111147552728653, 0.3646441400051117, 0.06045660004019737, 0.22725869715213776, 0.030072104185819626, 0.2538430392742157, 0.0], [0.0013706677127629519, 0.0003565592342056334, 0.0006504033226519823, 0.0008717189775779843, 0.023110924288630486, 0.16852477192878723, 0.8020843863487244, 0.0004564319388009608, 0.0025740356650203466, 0.0, 0.0009554855059832335, 0.0010365764610469341, 0.000539954868145287, 0.013481645844876766, 0.6702913641929626, 0.013201623223721981, 0.06565960496664047, 0.008186675608158112, 0.2266470193862915, 0.0], [0.0005740747437812388, 0.0018384596332907677, 0.015691960230469704, 0.0004515495093073696, 0.04004881531000137, 0.8668573498725891, 0.03566786274313927, 0.01278533972799778, 0.0260846596211195, 0.0, 0.007978711277246475, 0.0019918852485716343, 0.0007363414042629302, 
0.010062554851174355, 0.10717969387769699, 0.01258536335080862, 0.08278501033782959, 0.02946571074426174, 0.7472147941589355, 0.0], [0.0008355869795195758, 0.003608973463997245, 0.04490630701184273, 0.009341607801616192, 0.007649072911590338, 0.10034366697072983, 0.06446904689073563, 0.7009655237197876, 0.06788014620542526, 0.0, 0.014369996264576912, 0.00412968173623085, 0.002898097038269043, 0.0381503589451313, 0.28382056951522827, 0.03412872180342674, 0.2624143660068512, 0.04523473232984543, 0.3148534893989563, 0.0], [0.00805425550788641, 0.039243537932634354, 0.05003930628299713, 0.0007152591715566814, 0.00863983016461134, 0.4756345748901367, 0.24407540261745453, 0.1204291433095932, 0.053168926388025284, 0.0, 0.0025127469561994076, 0.0030011499766260386, 0.0036209137178957462, 0.0006047216593287885, 0.01094596553593874, 0.0023283734917640686, 0.003409643191844225, 0.009625249542295933, 0.9639512896537781, 0.0]], [[0.0842718631029129, 0.33930304646492004, 0.1421334594488144, 0.18528752028942108, 0.05815916135907173, 0.022830937057733536, 0.01860896497964859, 0.009871570393443108, 0.13953347504138947, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1295168399810791, 0.08884051442146301, 0.06592670828104019, 0.2686370015144348, 0.02522267960011959, 0.03633918985724449, 0.021549394354224205, 0.051057688891887665, 0.31290990114212036, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20177747309207916, 0.039902716875076294, 0.053595658391714096, 0.09988140314817429, 0.01657777465879917, 0.07154539972543716, 0.024320384487509727, 0.12353017926216125, 0.36886897683143616, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20979411900043488, 0.30265647172927856, 0.20804975926876068, 0.007371237967163324, 0.0033807901199907064, 0.02442527562379837, 0.017248263582587242, 0.022337088361382484, 0.2047368586063385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20739784836769104, 0.05559583380818367, 0.12535981833934784, 
0.009768493473529816, 0.015522800385951996, 0.024528708308935165, 0.03864477947354317, 0.08712086826562881, 0.4360608160495758, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.054182324558496475, 0.0038395673036575317, 0.019914912059903145, 0.014234894886612892, 0.012772555463016033, 0.019022708758711815, 0.04023807495832443, 0.36886361241340637, 0.46693122386932373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10943998396396637, 0.007069440558552742, 0.019821595400571823, 0.012627309188246727, 0.016869045794010162, 0.05302179232239723, 0.05124732851982117, 0.14304620027542114, 0.5868573188781738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23190192878246307, 0.060554418712854385, 0.05880776792764664, 0.00438718032091856, 0.00454165181145072, 0.15464532375335693, 0.09585105627775192, 0.02281157113611698, 0.3664989471435547, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13208958506584167, 0.15704363584518433, 0.07176639884710312, 0.08554346114397049, 0.0733223557472229, 0.0956358015537262, 0.07472448796033859, 0.09573546797037125, 0.21413877606391907, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.28021925687789917, 0.045415956526994705, 0.04812552034854889, 0.00880114920437336, 0.012029618956148624, 0.04001859948039055, 0.0577121265232563, 0.02487611398100853, 0.4828015863895416, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7289455533027649, 0.07586073875427246, 0.09869885444641113, 0.029881592839956284, 0.013988524675369263, 0.006547953933477402, 0.011115974746644497, 0.01961168274283409, 0.01534893549978733, 0.0, 0.023785226047039032, 0.024275904521346092, 0.6168470978736877, 0.01581703871488571, 0.026939542964100838, 0.1783975064754486, 0.04853774979710579, 0.02762567065656185, 0.03777410835027695, 0.0], [0.0068419137969613075, 0.9909167289733887, 0.0013876587618142366, 0.00011136479588458315, 6.257302447920665e-05, 6.40047510387376e-05, 1.7212767488672398e-05, 0.00025805848417803645, 
0.0003405954339541495, 0.0, 0.006065902300179005, 0.04932599142193794, 0.8176359534263611, 0.018976736813783646, 0.008159944787621498, 0.011068272404372692, 0.010428683832287788, 0.014124251902103424, 0.06421414017677307, 0.0], [0.0458948090672493, 0.09657034277915955, 0.8496794104576111, 0.005955036263912916, 0.00011324687511660159, 0.000537818530574441, 0.00024974963162094355, 6.34682146483101e-05, 0.0009358474053442478, 0.0, 0.0037236923817545176, 0.007844064384698868, 0.9502744674682617, 0.0048003061674535275, 0.00022506865207105875, 0.004834793973714113, 0.0015490480000153184, 0.0026021157391369343, 0.024146683514118195, 0.0], [0.0026473035104572773, 0.007075308356434107, 0.0509142205119133, 0.9043685793876648, 0.01687121018767357, 0.0027590824756771326, 0.0040096985176205635, 0.004973408300429583, 0.006381142418831587, 0.0, 0.0007564057596027851, 0.0017460802337154746, 0.15768493711948395, 0.004074132069945335, 0.015430302359163761, 0.7368869781494141, 0.028010869398713112, 0.013945921324193478, 0.04146439954638481, 0.0], [0.00491793267428875, 0.0006714555202051997, 0.0047717769630253315, 0.09624139964580536, 0.8609471917152405, 0.020327016711235046, 0.008984015323221684, 0.0013932499568909407, 0.0017457373905926943, 0.0, 0.0008445779676549137, 0.0015138997696340084, 0.17073306441307068, 0.0074179465882480145, 0.08121992647647858, 0.5853323936462402, 0.09402737021446228, 0.024092217907309532, 0.034818582236766815, 0.0], [0.003878936870023608, 0.005480882711708546, 0.00011314810399198905, 0.0003485401102807373, 0.006120527163147926, 0.0029893070459365845, 0.0006264422554522753, 0.004414959345012903, 0.9760271906852722, 0.0, 0.0006014688406139612, 0.0016882645431905985, 0.16094569861888885, 0.003698966233059764, 0.034668561071157455, 0.5876308679580688, 0.09562253206968307, 0.05209798738360405, 0.06304588913917542, 0.0], [2.4277944248751737e-05, 4.029595402244013e-06, 3.533453991622082e-07, 0.0002488488098606467, 2.0782925275852904e-05, 0.0004858894390054047, 
0.9990906119346619, 4.08113919547759e-05, 8.422240352956578e-05, 0.0, 0.00015283364336937666, 0.0004516944463830441, 0.003205003682523966, 0.0049727726727724075, 0.10853080451488495, 0.03262018784880638, 0.6125266551971436, 0.005719948559999466, 0.2318202555179596, 0.0], [0.0008877408690750599, 0.00558891985565424, 8.855570922605693e-05, 8.779557902016677e-06, 0.00010457105963723734, 0.00017662049503996968, 0.002778601599857211, 0.09916532039642334, 0.8912010192871094, 0.0, 9.933842375176027e-05, 0.00020211786613799632, 0.0037883264012634754, 0.0051808832213282585, 0.6936825513839722, 0.10089477151632309, 0.023457802832126617, 0.011726793833076954, 0.16096755862236023, 0.0], [0.0001188771057059057, 0.0008492054184898734, 9.383026917930692e-05, 5.974515715934103e-06, 0.002050562761723995, 8.90250812517479e-05, 8.933644130593166e-05, 0.002921103034168482, 0.9937818646430969, 0.0, 0.0002526468597352505, 0.0010056017199531198, 0.003837066935375333, 0.034950658679008484, 0.5882559418678284, 0.029549231752753258, 0.030938459560275078, 0.01461110170930624, 0.2965993583202362, 0.0], [0.013618292286992073, 0.005976812914013863, 3.6079491110285744e-05, 4.805085336556658e-05, 0.00010178168304264545, 0.03545643016695976, 0.04239484667778015, 0.35667654871940613, 0.5456912517547607, 0.0, 0.0029390468262135983, 0.005815382581204176, 0.06488344818353653, 0.008705642074346542, 0.010130577720701694, 0.012970774434506893, 0.019612692296504974, 0.007819950580596924, 0.8671225309371948, 0.0]], [[0.057871319353580475, 0.09845025092363358, 0.03600643575191498, 0.06401734054088593, 0.07263048738241196, 0.014885936863720417, 0.07473781704902649, 0.1193607747554779, 0.4620397090911865, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04736582189798355, 0.06951819360256195, 0.039210546761751175, 0.040616557002067566, 0.05645532160997391, 0.01900673843920231, 0.063181072473526, 0.23291724920272827, 0.43172842264175415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.05375257506966591, 0.04091374948620796, 0.01263821218162775, 0.04125160351395607, 0.014244006015360355, 0.012229752726852894, 0.029117466881871223, 0.07314542680978775, 0.7227071523666382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04062311723828316, 0.2312910407781601, 0.20085060596466064, 0.03848586603999138, 0.04763459786772728, 0.013425372540950775, 0.027237186208367348, 0.03882591798901558, 0.3616262376308441, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04558584839105606, 0.04037311673164368, 0.043737076222896576, 0.02740027941763401, 0.005366531666368246, 0.014126299880445004, 0.07268305867910385, 0.014923120848834515, 0.7358046770095825, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04308110475540161, 0.037618398666381836, 0.054927192628383636, 0.045146394520998, 0.02157701551914215, 0.014024189673364162, 0.03546718508005142, 0.04130468890070915, 0.7068538665771484, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049693867564201355, 0.040712278336286545, 0.011129319667816162, 0.08677691221237183, 0.24132831394672394, 0.028864668682217598, 0.04710082337260246, 0.028962818905711174, 0.46543097496032715, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0256647989153862, 0.026453843340277672, 0.1064542606472969, 0.07867259532213211, 0.03285365179181099, 0.056291256099939346, 0.026517342776060104, 0.014768523164093494, 0.6323237419128418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07950135320425034, 0.04906205087900162, 0.09099037200212479, 0.10450085997581482, 0.06846266984939575, 0.21755923330783844, 0.14818403124809265, 0.14456483721733093, 0.0971745029091835, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0029191188514232635, 0.0026039828080683947, 0.000987313687801361, 0.027727283537387848, 0.007311245426535606, 0.0033244043588638306, 0.023969389498233795, 0.00596341909840703, 0.9251939058303833, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.5123088955879211, 0.04897892847657204, 0.0054785809479653835, 0.037157464772462845, 0.033040400594472885, 0.0287709329277277, 0.020658228546380997, 0.005767439026385546, 0.3078390061855316, 0.0, 0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.02116605080664158, 0.45384567975997925, 0.5185156464576721, 0.002682638820260763, 0.000782136688940227, 0.0011598592391237617, 9.265153494197875e-05, 0.0013774167746305466, 0.00037764458102174103, 0.0, 0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.06292606890201569, 0.18043853342533112, 0.7547035217285156, 0.0011577572440728545, 6.746899998688605e-06, 0.00012623713701032102, 8.539375994587317e-05, 0.0005039210664108396, 5.15973697474692e-05, 0.0, 0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.0030907110776752234, 0.0035500964149832726, 0.4530283808708191, 0.5221376419067383, 0.009557867422699928, 0.008033420890569687, 0.0001996932114707306, 0.0003745300055015832, 2.7415442673373036e-05, 0.0, 0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 0.1017894372344017, 0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], [0.00426989421248436, 0.0004077394842170179, 0.010691642761230469, 0.016011668369174004, 0.5530933737754822, 0.12423280626535416, 0.0053755901753902435, 0.28551235795021057, 0.00040478314622305334, 0.0, 0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 
0.0], [0.00216855201870203, 0.009595326147973537, 0.007803121581673622, 0.04625817388296127, 0.24702508747577667, 0.2669595181941986, 0.024053409695625305, 0.24639348685741425, 0.14974308013916016, 0.0, 0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [1.953603486981592e-06, 5.726202516598278e-07, 5.0084551617146644e-08, 6.896097602293594e-06, 0.0001788837107596919, 0.0027895711828023195, 0.9969833493232727, 1.1296948287053965e-05, 2.7187752493773587e-05, 0.0, 0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.00041725789196789265, 0.0019265476148575544, 6.523763295263052e-05, 0.00018337361689191312, 0.011946662329137325, 0.04555974155664444, 0.15744170546531677, 0.025624049827456474, 0.7568355202674866, 0.0, 0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.0004706757317762822, 0.0027324894908815622, 0.0007427418022416532, 0.00934627279639244, 0.17134670913219452, 0.030644211918115616, 0.08413954824209213, 0.2513456642627716, 0.4492316246032715, 0.0, 0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 0.5316829681396484, 0.0], [0.0002758087939582765, 0.0016877831658348441, 2.4452297111565713e-06, 0.0004533462051767856, 0.001545731327496469, 0.008134560659527779, 0.010873721912503242, 0.026235109195113182, 0.950791597366333, 0.0, 0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 
0.5926627516746521, 0.0]], [[0.3116825819015503, 0.20666195452213287, 0.1363646388053894, 0.07141851633787155, 0.029045483097434044, 0.04730900749564171, 0.037391580641269684, 0.10128220915794373, 0.05884409323334694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16883184015750885, 0.1785622239112854, 0.23318006098270416, 0.05258537083864212, 0.10740725696086884, 0.09185276180505753, 0.022670285776257515, 0.08943870663642883, 0.05547139048576355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13844002783298492, 0.030681077390909195, 0.12107612937688828, 0.007168712094426155, 0.05214103311300278, 0.10045275092124939, 0.006991118658334017, 0.2955043315887451, 0.2475447952747345, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17306901514530182, 0.19963578879833221, 0.148520827293396, 0.2046978771686554, 0.06720028817653656, 0.019652126356959343, 0.05792365223169327, 0.07665500044822693, 0.05264541134238243, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15115797519683838, 0.06555280834436417, 0.02683498151600361, 0.20794028043746948, 0.17434173822402954, 0.11980342864990234, 0.04239796847105026, 0.03961418569087982, 0.1723567098379135, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022970011457800865, 0.003322059055790305, 0.03333919122815132, 0.03161727264523506, 0.3007935583591461, 0.20675496757030487, 0.037206847220659256, 0.30415406823158264, 0.059842076152563095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029802093282341957, 0.006723356898874044, 0.00844306219369173, 0.09911961853504181, 0.2257867157459259, 0.22737178206443787, 0.28318148851394653, 0.05687837302684784, 0.0626935064792633, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008192392997443676, 0.002076511736959219, 0.010627061128616333, 0.01573592983186245, 0.01893553137779236, 0.042316026985645294, 0.02403445728123188, 0.868257999420166, 0.009823988191783428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.08033955097198486, 0.18780502676963806, 0.02817567251622677, 0.041370414197444916, 0.02824225462973118, 0.038844767957925797, 0.11732563376426697, 0.06162348762154579, 0.4162730574607849, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0563269779086113, 0.007248359732329845, 0.008029782213270664, 0.003040844574570656, 0.007221699226647615, 0.01730128563940525, 0.050128430128097534, 0.3587413430213928, 0.491961270570755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23423711955547333, 0.3770143389701843, 0.036408282816410065, 0.05249679461121559, 0.12246986478567123, 0.07398416101932526, 0.040104977786540985, 0.05500312149524689, 0.008280999027192593, 0.0, 0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.032638370990753174, 0.0975130945444107, 0.07180608063936234, 0.1075734794139862, 0.025216424837708473, 0.0218534916639328, 0.0376754030585289, 0.19710108637809753, 0.40862247347831726, 0.0, 0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.06665100902318954, 0.01976630836725235, 0.13041609525680542, 0.1772802174091339, 0.07561768591403961, 0.0061133550480008125, 0.022857116535305977, 0.3516288995742798, 0.14966930449008942, 0.0, 0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.007724895142018795, 0.007534464355558157, 0.020593322813510895, 0.32147932052612305, 0.059838466346263885, 0.017819387838244438, 0.13181470334529877, 0.15524767339229584, 0.27794769406318665, 0.0, 0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 0.004741673823446035, 0.46527597308158875, 
0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0028419073205441236, 0.0006648481939919293, 0.0018234169110655785, 0.01609039306640625, 0.0009005893371067941, 0.09726841002702713, 0.11035522073507309, 0.6978457570075989, 0.07220931351184845, 0.0, 0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.0027223427314311266, 0.011157790198922157, 0.0052745467983186245, 0.00438346853479743, 0.010011360049247742, 0.38358205556869507, 0.2294219732284546, 0.057655833661556244, 0.29579076170921326, 0.0, 0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.0007371494430117309, 0.00023907626746222377, 8.450529276160523e-05, 0.002850313438102603, 0.002168564358726144, 0.02191324159502983, 0.9514637589454651, 0.0002505093871150166, 0.02029278129339218, 0.0, 0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.004062887746840715, 0.007024500984698534, 0.007399421185255051, 0.011675640009343624, 0.0680280476808548, 0.02557964250445366, 0.07043837755918503, 0.01946563646197319, 0.7863259315490723, 0.0, 0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 0.010398894548416138, 0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.0007200208492577076, 0.0024253681767731905, 0.029840486124157906, 0.00014908696175552905, 0.11268167197704315, 0.03336171433329582, 0.007834793999791145, 0.08127990365028381, 0.7317068576812744, 0.0, 0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 
0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.008707311004400253, 0.08642490208148956, 0.11372587829828262, 0.004973042756319046, 0.13256150484085083, 0.16557356715202332, 0.0817113071680069, 0.006879410706460476, 0.3994430899620056, 0.0, 0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}, \"inp_inp\": {\"att\": [[[[0.05334341153502464, 0.025828205049037933, 0.062369391322135925, 0.043252814561128616, 0.4045393764972687, 0.06697215139865875, 0.09001608937978745, 0.14983074367046356, 0.10384786874055862, 0.0], [0.11816457659006119, 0.03106253407895565, 0.01979171112179756, 0.16624291241168976, 0.3321376442909241, 0.020051123574376106, 0.08730963617563248, 0.18211135268211365, 0.04312858730554581, 0.0], [0.05936884880065918, 0.02174757793545723, 0.016160180792212486, 0.010601435787975788, 0.43925121426582336, 0.03876951336860657, 0.19815810024738312, 0.07065817713737488, 0.14528508484363556, 0.0], [0.15478025376796722, 0.16446512937545776, 0.0578744001686573, 0.21637752652168274, 0.03835854306817055, 0.09130414575338364, 0.11191156506538391, 0.08360221982002258, 0.08132638782262802, 0.0], [0.2183060646057129, 0.1704275906085968, 0.0827711746096611, 0.1202380359172821, 0.05203341320157051, 0.05958092212677002, 0.12280035018920898, 0.09366822242736816, 0.08017415553331375, 0.0], [0.05084313824772835, 0.026207493618130684, 0.13631564378738403, 0.012270472943782806, 
0.16236551105976105, 0.02548854425549507, 0.03909383341670036, 0.03172134608030319, 0.5156941413879395, 0.0], [0.03615221381187439, 0.04799472168087959, 0.04255519434809685, 0.04762651398777962, 0.5117892622947693, 0.016304347664117813, 0.005770198069512844, 0.10897397249937057, 0.18283340334892273, 0.0], [0.03243544325232506, 0.025252558290958405, 0.11733424663543701, 0.0250592939555645, 0.20289097726345062, 0.08240236341953278, 0.18285907804965973, 0.011341268196702003, 0.3204246759414673, 0.0], [0.22355543076992035, 0.1260528564453125, 0.03741241991519928, 0.16813479363918304, 0.09858733415603638, 0.035831648856401443, 0.16361697018146515, 0.07236126810312271, 0.07444748282432556, 0.0], [0.08996112644672394, 0.0921943336725235, 0.22672457993030548, 0.12702998518943787, 0.05907799303531647, 0.10712798684835434, 0.16789256036281586, 0.055181413888931274, 0.07481010258197784, 0.0]], [[0.040477100759744644, 0.20988762378692627, 0.4869004786014557, 0.03505674749612808, 0.0558856800198555, 0.025423096492886543, 0.12231241166591644, 0.007062799762934446, 0.016993943601846695, 0.0], [0.8996549844741821, 0.02599872276186943, 0.049097247421741486, 0.0040262676775455475, 0.0039152717217803, 0.0049644638784229755, 0.010553319938480854, 0.001352570834569633, 0.0004369009402580559, 0.0], [0.33065715432167053, 0.2687782049179077, 0.03312753140926361, 0.22958999872207642, 0.01851547136902809, 0.046473052352666855, 0.053183481097221375, 0.007113412953913212, 0.012561764568090439, 0.0], [0.1589452475309372, 0.47470128536224365, 0.12878550589084625, 0.14158962666988373, 0.04442765936255455, 0.022274963557720184, 0.013780632056295872, 0.0024951419327408075, 0.012999956496059895, 0.0], [0.2559169828891754, 0.033451542258262634, 0.15095548331737518, 0.024318046867847443, 0.10824166238307953, 0.03234097361564636, 0.36475417017936707, 0.012823408469557762, 0.017197895795106888, 0.0], [0.021462664008140564, 0.010474847629666328, 0.007213775999844074, 0.02227940410375595, 
0.21737068891525269, 0.4960675537586212, 0.014628118835389614, 0.20502059161663055, 0.005482145119458437, 0.0], [0.06734316051006317, 0.09532227367162704, 0.1127309575676918, 0.009542002342641354, 0.0678786113858223, 0.12933993339538574, 0.03809814900159836, 0.44453269243240356, 0.035212237387895584, 0.0], [0.10458365827798843, 0.02846018597483635, 0.029760979115962982, 0.014774680137634277, 0.022077379748225212, 0.1553817093372345, 0.3539015054702759, 0.19523507356643677, 0.09582491964101791, 0.0], [0.021077070385217667, 0.010932122357189655, 0.05088815093040466, 0.028641115874052048, 0.0881260335445404, 0.12014731019735336, 0.3900885581970215, 0.09544514119625092, 0.1946544349193573, 0.0], [0.02552945166826248, 0.05594164505600929, 0.045791901648044586, 0.093170166015625, 0.03584437444806099, 0.0969511866569519, 0.18585819005966187, 0.17433671653270721, 0.28657644987106323, 0.0]], [[0.18220090866088867, 0.25508272647857666, 0.2721964120864868, 0.04886331781744957, 0.010257811285555363, 0.07344724237918854, 0.08866558223962784, 0.037977367639541626, 0.0313086174428463, 0.0], [0.5722172260284424, 0.09567929804325104, 0.1448327898979187, 0.033306267112493515, 0.0031244128476828337, 0.020944159477949142, 0.012691132724285126, 0.061001092195510864, 0.05620381608605385, 0.0], [0.049244701862335205, 0.5266616344451904, 0.27518483996391296, 0.09334208071231842, 0.005858665332198143, 0.005467486567795277, 0.02565312758088112, 0.005746132228523493, 0.012841282412409782, 0.0], [0.13445906341075897, 0.13356590270996094, 0.6041688919067383, 0.01878039538860321, 0.06342840194702148, 0.03677675500512123, 0.008389262482523918, 0.0002739423362072557, 0.00015757972141727805, 0.0], [0.03273050859570503, 0.0697193592786789, 0.19719526171684265, 0.41500693559646606, 0.13721567392349243, 0.05743291601538658, 0.06517775356769562, 0.010865128599107265, 0.014656689018011093, 0.0], [0.031571000814437866, 0.014337136410176754, 0.06860436499118805, 0.09357307106256485, 0.10011686384677887, 
0.07827721536159515, 0.5866308212280273, 0.011440092697739601, 0.015449290163815022, 0.0], [0.006158333271741867, 0.001533387927338481, 0.05427416041493416, 0.005477452650666237, 0.02694696933031082, 0.8134917616844177, 0.02643686905503273, 0.050265438854694366, 0.015415593050420284, 0.0], [0.008847472257912159, 0.0066053420305252075, 0.036443497985601425, 0.021455924957990646, 0.019254589453339577, 0.11543811857700348, 0.1138116791844368, 0.20307059586048126, 0.4750728905200958, 0.0], [0.017603449523448944, 0.008448019623756409, 0.004260394722223282, 0.006066101603209972, 0.013470137491822243, 0.01876576989889145, 0.16350960731506348, 0.1980665624141693, 0.5698099732398987, 0.0], [0.10490093380212784, 0.014168650843203068, 0.0247807614505291, 0.018330294638872147, 0.009348674677312374, 0.02287398651242256, 0.032268356531858444, 0.10571902245283127, 0.6676092147827148, 0.0]], [[0.2071455419063568, 0.637531578540802, 0.06835082173347473, 0.011966697871685028, 0.0017193991225212812, 0.04911382868885994, 0.009478496387600899, 0.008040529675781727, 0.00665308628231287, 0.0], [0.07411027699708939, 0.15093472599983215, 0.2656005620956421, 0.05758262053132057, 0.05194409564137459, 0.23625947535037994, 0.019166678190231323, 0.04010465368628502, 0.10429693013429642, 0.0], [0.1540999412536621, 0.10598444193601608, 0.22474077343940735, 0.32441702485084534, 0.1116243302822113, 0.054135363548994064, 0.008848286233842373, 0.004088098648935556, 0.012061581946909428, 0.0], [0.019440434873104095, 0.00560638727620244, 0.0035774046555161476, 0.0888679027557373, 0.7120485901832581, 0.14891275763511658, 0.011600993573665619, 0.008666431531310081, 0.0012791723711416125, 0.0], [0.08580154180526733, 0.02444172091782093, 0.08060747385025024, 0.05198557302355766, 0.2700504660606384, 0.34216371178627014, 0.11280739307403564, 0.006445358972996473, 0.02569655328989029, 0.0], [0.0424385629594326, 0.029667967930436134, 0.006252861116081476, 0.020168066024780273, 0.03000665083527565, 
0.2812231779098511, 0.49279165267944336, 0.09351769089698792, 0.003933228086680174, 0.0], [0.006467411294579506, 0.0076894015073776245, 0.008325580507516861, 0.0010907554533332586, 0.01040297094732523, 0.19462232291698456, 0.013263629749417305, 0.24681615829467773, 0.5113216042518616, 0.0], [0.028696376830339432, 0.014982450753450394, 0.011884906329214573, 0.0011242942418903112, 0.01692844182252884, 0.12885364890098572, 0.028225399553775787, 0.6451764106750488, 0.12412811070680618, 0.0], [0.16117365658283234, 0.06794824451208115, 0.06173194944858551, 0.00451233983039856, 0.05306624248623848, 0.0510348416864872, 0.04402391240000725, 0.12432018667459488, 0.4321887195110321, 0.0], [0.1690559983253479, 0.043453093618154526, 0.036818861961364746, 0.017293656244874, 0.11775903403759003, 0.07970321178436279, 0.043801818042993546, 0.06849095970392227, 0.4236232340335846, 0.0]], [[0.03085354156792164, 0.12322185933589935, 0.13651973009109497, 0.050716523081064224, 0.2999139726161957, 0.09802427887916565, 0.06620478630065918, 0.0782310962677002, 0.11631430685520172, 0.0], [0.06789751350879669, 0.058182138949632645, 0.3129631578922272, 0.04353875666856766, 0.09142065048217773, 0.10271093249320984, 0.026392055675387383, 0.09630800783634186, 0.2005866914987564, 0.0], [0.07152411341667175, 0.3454192876815796, 0.11299439519643784, 0.18012462556362152, 0.07151429355144501, 0.052652161568403244, 0.0567985400557518, 0.09459780901670456, 0.014374655671417713, 0.0], [0.10420235246419907, 0.21845531463623047, 0.19832336902618408, 0.022119704633951187, 0.13572701811790466, 0.07722532749176025, 0.0508468933403492, 0.045597679913043976, 0.14750221371650696, 0.0], [0.07030870020389557, 0.10706955939531326, 0.02791348285973072, 0.02260597050189972, 0.12725059688091278, 0.07336997240781784, 0.26662203669548035, 0.16957008838653564, 0.13528966903686523, 0.0], [0.05156806856393814, 0.04327721148729324, 0.07664787024259567, 0.06931594759225845, 0.1889398992061615, 0.09515503793954849, 
0.07227510958909988, 0.2641449272632599, 0.13867592811584473, 0.0], [0.02184019424021244, 0.11184182018041611, 0.36672860383987427, 0.013787303119897842, 0.07600502669811249, 0.0389828234910965, 0.040494974702596664, 0.12485849112272263, 0.20546066761016846, 0.0], [0.013738485053181648, 0.05187288299202919, 0.03463537245988846, 0.03627979755401611, 0.048659998923540115, 0.02440205216407776, 0.07256433367729187, 0.024731382727622986, 0.6931155323982239, 0.0], [0.02671198360621929, 0.4013687074184418, 0.01132842618972063, 0.14022575318813324, 0.026275552809238434, 0.08107840269804001, 0.04189194366335869, 0.25432130694389343, 0.0167979933321476, 0.0], [0.14228780567646027, 0.07866450399160385, 0.08390624076128006, 0.09396661072969437, 0.087954580783844, 0.14498625695705414, 0.13517630100250244, 0.1169552430510521, 0.11610251665115356, 0.0]], [[0.02165721170604229, 0.018354326486587524, 0.6383510828018188, 0.042513273656368256, 0.10956817120313644, 0.10717540234327316, 0.030344119295477867, 0.015826348215341568, 0.01621006615459919, 0.0], [0.4647374749183655, 0.07284841686487198, 0.28081396222114563, 0.014013433828949928, 0.03169411048293114, 0.02214456908404827, 0.058711059391498566, 0.036629818379879, 0.01840737834572792, 0.0], [0.07372704148292542, 0.12858736515045166, 0.4501189887523651, 0.054217785596847534, 0.07096204906702042, 0.05748127028346062, 0.06541819125413895, 0.04703349620103836, 0.05245373025536537, 0.0], [0.04684445261955261, 0.019098779186606407, 0.008431704714894295, 0.0010175607167184353, 0.9129327535629272, 0.004866998642683029, 0.006678053177893162, 8.096762758214027e-05, 4.903498847852461e-05, 0.0], [0.08239725232124329, 0.02813413366675377, 0.16611848771572113, 0.1532817929983139, 0.07408729940652847, 0.10856874287128448, 0.047752734273672104, 0.02563621662557125, 0.31402355432510376, 0.0], [0.17959792912006378, 0.02262653037905693, 0.10724494606256485, 0.022216446697711945, 0.1862414926290512, 0.14705143868923187, 0.15912717580795288, 
0.15293282270431519, 0.02296125516295433, 0.0], [0.038375359028577805, 0.0038853511214256287, 0.06201936677098274, 0.005828780122101307, 0.22059503197669983, 0.36631014943122864, 0.020396992564201355, 0.20976856350898743, 0.07282061129808426, 0.0], [0.014258276671171188, 0.005652762018144131, 0.025611618533730507, 0.15294744074344635, 0.06760217249393463, 0.2498260736465454, 0.1669282466173172, 0.2265811711549759, 0.09059228003025055, 0.0], [0.15833799540996552, 0.1228356659412384, 0.10147804021835327, 0.0284584891051054, 0.27955442667007446, 0.06763719022274017, 0.08874277770519257, 0.1152903363108635, 0.037665050476789474, 0.0], [0.09844867885112762, 0.0919492095708847, 0.028445947915315628, 0.03726689890027046, 0.035665158182382584, 0.06817072629928589, 0.29930955171585083, 0.09819743037223816, 0.2425464242696762, 0.0]], [[0.02519470639526844, 0.006357265170663595, 0.14269335567951202, 0.023629529401659966, 0.3124701976776123, 0.13565225899219513, 0.2595662772655487, 0.07959114015102386, 0.014845297671854496, 0.0], [0.04550129547715187, 0.011541971005499363, 0.1165909469127655, 0.02512240968644619, 0.01843150518834591, 0.05711649730801582, 0.44489097595214844, 0.033205363899469376, 0.24759893119335175, 0.0], [0.13528011739253998, 0.06777236610651016, 0.14429129660129547, 0.04697401076555252, 0.1738385707139969, 0.014099549502134323, 0.38417065143585205, 0.01158357597887516, 0.02199004776775837, 0.0], [0.21356959640979767, 0.1638900637626648, 0.10595463216304779, 0.06925727427005768, 0.167257159948349, 0.04259340837597847, 0.10967854410409927, 0.03570139408111572, 0.09209771454334259, 0.0], [0.20140984654426575, 0.04755665361881256, 0.15174560248851776, 0.11619894206523895, 0.21928974986076355, 0.07600340992212296, 0.05828682705760002, 0.10010629147291183, 0.029402663931250572, 0.0], [0.024259669706225395, 0.02116699516773224, 0.21201731264591217, 0.019622934982180595, 0.4893963038921356, 0.021304504945874214, 0.16948339343070984, 0.022949064150452614, 
0.01979990489780903, 0.0], [0.022248759865760803, 0.01183647196739912, 0.0633181631565094, 0.029095010831952095, 0.07090882211923599, 0.4614315629005432, 0.020150773227214813, 0.18720205128192902, 0.1338084638118744, 0.0], [0.003461656626313925, 0.01603432185947895, 0.009874427691102028, 0.014947548508644104, 0.2953553795814514, 0.3502987027168274, 0.08878874033689499, 0.036094941198825836, 0.18514421582221985, 0.0], [0.005101516842842102, 0.022985950112342834, 0.007523353211581707, 0.026773063465952873, 0.01009095273911953, 0.014858697541058064, 0.15149906277656555, 0.028601571917533875, 0.7325656414031982, 0.0], [0.12995873391628265, 0.07769863307476044, 0.02032659947872162, 0.13720010221004486, 0.011713794432580471, 0.054615918546915054, 0.23920413851737976, 0.13190706074237823, 0.19737498462200165, 0.0]], [[0.21207179129123688, 0.11920439451932907, 0.4251355528831482, 0.014464439824223518, 0.20776884257793427, 0.01428140513598919, 0.0027938869316130877, 0.001743048895150423, 0.002536489861086011, 0.0], [0.046175818890333176, 0.026793524622917175, 0.8552185297012329, 0.04517081379890442, 0.010388500988483429, 0.004191457759588957, 0.0036751439329236746, 0.0013485046802088618, 0.007037981878966093, 0.0], [0.013186579570174217, 0.020899420604109764, 0.6900137662887573, 0.0480119027197361, 0.15360434353351593, 0.02344118244946003, 0.03952033817768097, 0.0038994532078504562, 0.007422822527587414, 0.0], [0.006273405160754919, 0.00015674144378863275, 0.000751359446439892, 0.00447711581364274, 0.9859057664871216, 0.002212332095950842, 0.00014360185014083982, 4.957199053023942e-05, 2.9913859179941937e-05, 0.0], [0.001047183177433908, 0.0003636489564087242, 0.009283728897571564, 0.016805388033390045, 0.42387446761131287, 0.4776095747947693, 0.06253702938556671, 0.005590841174125671, 0.002888289513066411, 0.0], [0.0018647151300683618, 0.0002549054042901844, 2.6050107408082113e-05, 2.586200753285084e-05, 0.0024472770746797323, 0.006814199965447187, 0.9776560664176941, 
0.010138182900846004, 0.000773087958805263, 0.0], [0.047241877764463425, 0.006076885852962732, 0.04534892365336418, 0.00081661093281582, 0.087706059217453, 0.41394293308258057, 0.21876952052116394, 0.17005810141563416, 0.0100388890132308, 0.0], [0.0019138919888064265, 0.006189406383782625, 0.010115097276866436, 8.508542669005692e-05, 0.008424345403909683, 0.003492203773930669, 0.13495568931102753, 0.4890870749950409, 0.34573695063591003, 0.0], [0.016032341867685318, 0.005025702994316816, 0.009520799852907658, 0.0008855267078615725, 0.026489384472370148, 0.0020503124687820673, 0.032939448952674866, 0.09461060166358948, 0.8124459385871887, 0.0], [0.25683313608169556, 0.02960006147623062, 0.11211041361093521, 0.09736908972263336, 0.17546677589416504, 0.032068025320768356, 0.017857572063803673, 0.025635067373514175, 0.25305992364883423, 0.0]]], [[[0.10487863421440125, 0.7106320858001709, 0.1635318249464035, 0.011256101541221142, 0.0012767312582582235, 0.00310636218637228, 0.0013001860352233052, 0.0012553841806948185, 0.002762428717687726, 0.0], [0.021650908514857292, 0.0030605364590883255, 0.6595932245254517, 0.2987315356731415, 0.012945608235895634, 0.0028472936246544123, 7.557096250820905e-05, 0.00029089683084748685, 0.0008047237643040717, 0.0], [0.014272261410951614, 0.040512338280677795, 0.8595607280731201, 0.038314104080200195, 0.037397123873233795, 0.006795509252697229, 0.001303989440202713, 0.001011757180094719, 0.0008321924251504242, 0.0], [0.031783342361450195, 0.007319662719964981, 0.7663278579711914, 0.0010118860518559813, 0.1672297865152359, 0.02513650804758072, 0.000853335193824023, 0.0002817189379129559, 5.600590884569101e-05, 0.0], [0.002136597875505686, 0.00037253598566167057, 0.07588302344083786, 0.2252500057220459, 0.33551687002182007, 0.35751965641975403, 0.0027331046294420958, 0.00018122239271178842, 0.0004068210837431252, 0.0], [0.0004353485128376633, 0.0003557991876732558, 0.0003262429090682417, 0.003819868667051196, 0.33603885769844055, 
0.2681770920753479, 0.3838857412338257, 0.0068349516950547695, 0.00012614508159458637, 0.0], [6.71677480568178e-05, 3.9912600186653435e-05, 0.00047830803669057786, 5.937727837590501e-05, 0.0014537296956405044, 0.6413838863372803, 0.29047340154647827, 0.06565171480178833, 0.0003929881495423615, 0.0], [0.00047039391938596964, 0.0007891620043665171, 0.0007817292353138328, 0.0010076714679598808, 0.00965806283056736, 0.003733346238732338, 0.35330116748809814, 0.5722718238830566, 0.05798657611012459, 0.0], [0.006178696174174547, 0.009340841323137283, 0.0005589249776676297, 0.005146770738065243, 0.0033258567564189434, 0.0016933922888711095, 0.06414961069822311, 0.3291752338409424, 0.5804308652877808, 0.0], [0.006624103523790836, 0.001978900283575058, 0.0081730792298913, 0.0030846702866256237, 0.0018904987955465913, 0.0014340116176754236, 0.005187559872865677, 0.029854312539100647, 0.9417726993560791, 0.0]], [[0.17277710139751434, 0.13871003687381744, 0.020699918270111084, 0.04190761595964432, 0.17760643362998962, 0.1702892780303955, 0.16168300807476044, 0.10000763088464737, 0.01631900854408741, 0.0], [0.9987638592720032, 0.0011447033612057567, 1.5495901607209817e-05, 2.3805538096333123e-10, 1.1166920899086108e-07, 4.81009180930414e-07, 2.3257289285538718e-05, 3.4320622944505885e-05, 1.812833215808496e-05, 0.0], [0.029870687052607536, 0.9668734669685364, 0.0031853404361754656, 3.7420595617732033e-06, 1.0481591772304455e-07, 4.711453893690987e-09, 4.051101996083162e-07, 1.359390239485947e-06, 6.518688314827159e-05, 0.0], [2.9839180569979362e-05, 0.0008244949858635664, 0.9990562796592712, 6.778111855965108e-05, 2.14482715819031e-05, 5.3428358959273226e-11, 7.202954205309808e-11, 7.697720239008277e-11, 1.422941551254553e-07, 0.0], [9.680035873316228e-05, 4.205659934086725e-05, 0.0021876851096749306, 0.9926192164421082, 0.0050464412197470665, 7.330636890401365e-06, 4.7689670878980905e-08, 8.238330573284713e-10, 9.979119397485192e-08, 0.0], [5.136659183335723e-06, 
6.750806136324172e-08, 8.17252839624416e-06, 0.008817464113235474, 0.9640147089958191, 0.027066770941019058, 8.771067950874567e-05, 3.571775764044105e-09, 3.5257423647294672e-09, 0.0], [5.115869043947896e-07, 1.0059281407848175e-08, 1.3136859422502312e-07, 9.641905052149013e-08, 0.001335342414677143, 0.9957214593887329, 0.0029362423811107874, 7.136273325158982e-06, 1.1521567699901425e-08, 0.0], [3.561131961760111e-06, 2.727877870256634e-07, 8.369554507225985e-07, 1.214864764342849e-09, 4.873449597653234e-06, 0.024909861385822296, 0.9680997133255005, 0.006879042834043503, 0.00010210835171164945, 0.0], [0.00021467455371748656, 9.040503209689632e-05, 3.369562909938395e-05, 1.9265097961351785e-08, 9.727973520057276e-07, 2.4095537810353562e-05, 0.0040859803557395935, 0.8618475794792175, 0.1337023377418518, 0.0], [2.289768872287823e-06, 6.284429400693625e-05, 0.0001214230724144727, 2.809870807141124e-07, 1.092972157223926e-09, 1.0671180605825725e-09, 1.2438744079190656e-06, 0.024907555431127548, 0.9749038219451904, 0.0]], [[0.058097392320632935, 0.00935883168131113, 0.04822169989347458, 0.0048278868198394775, 0.191309854388237, 0.28154584765434265, 0.09391050785779953, 0.24126385152339935, 0.07146408408880234, 0.0], [0.10414423793554306, 0.027566324919462204, 0.021727869287133217, 0.033647697418928146, 0.026882247999310493, 0.17782779037952423, 0.05685214698314667, 0.45095938444137573, 0.10039239376783371, 0.0], [0.44215551018714905, 0.049670565873384476, 0.014098896645009518, 0.029011834412813187, 0.01834075152873993, 0.1358453929424286, 0.04072042554616928, 0.2330295443534851, 0.03712712228298187, 0.0], [0.10425814986228943, 0.06979154050350189, 0.036334071308374405, 0.028995294123888016, 0.015532439574599266, 0.1330128014087677, 0.063407763838768, 0.23157192766666412, 0.3170958459377289, 0.0], [0.3384562134742737, 0.055937401950359344, 0.038792647421360016, 0.00819220207631588, 0.03063569962978363, 0.09386011958122253, 0.07227522879838943, 0.30926018953323364, 
0.05259038880467415, 0.0], [0.3519401550292969, 0.1823827177286148, 0.06509842723608017, 0.030452275648713112, 0.08377533406019211, 0.09469012171030045, 0.04247477278113365, 0.11751312017440796, 0.03167306259274483, 0.0], [0.3634622097015381, 0.14048337936401367, 0.08374395966529846, 0.038946691900491714, 0.03473563492298126, 0.06442954391241074, 0.019375532865524292, 0.22685663402080536, 0.027966352179646492, 0.0], [0.18070067465305328, 0.04645215719938278, 0.0992647334933281, 0.005799622740596533, 0.47514480352401733, 0.12094692885875702, 0.030788421630859375, 0.025236092507839203, 0.015666494145989418, 0.0], [0.5453059673309326, 0.10054859519004822, 0.01722547970712185, 0.06704734265804291, 0.007780902087688446, 0.07263857871294022, 0.022086072713136673, 0.1394840031862259, 0.027883058413863182, 0.0], [0.15028028190135956, 0.17163224518299103, 0.06043723225593567, 0.10140684247016907, 0.10512865334749222, 0.06778015196323395, 0.06512691080570221, 0.23085294663906097, 0.04735487326979637, 0.0]], [[0.11086989939212799, 0.14517885446548462, 0.17419463396072388, 0.060936953872442245, 0.08783368766307831, 0.11005676537752151, 0.03251044824719429, 0.07983692735433578, 0.19858187437057495, 0.0], [0.16660544276237488, 0.29352903366088867, 0.1008867621421814, 0.023942291736602783, 0.15022507309913635, 0.06581585109233856, 0.02344084158539772, 0.05208655819296837, 0.12346797436475754, 0.0], [0.1683349758386612, 0.22478938102722168, 0.06976605206727982, 0.1032773107290268, 0.16255290806293488, 0.08890064060688019, 0.03925151377916336, 0.023706944659352303, 0.11942004412412643, 0.0], [0.19914905726909637, 0.1368866264820099, 0.178489089012146, 0.11241752654314041, 0.06187256798148155, 0.0768556222319603, 0.01627686619758606, 0.07274915277957916, 0.14530348777770996, 0.0], [0.08000901341438293, 0.20181676745414734, 0.21235129237174988, 0.05340588092803955, 0.12758778035640717, 0.11278047412633896, 0.06906574964523315, 0.08596791326999664, 0.05701539292931557, 0.0], 
[0.14153669774532318, 0.10432923585176468, 0.09881750494241714, 0.08603313565254211, 0.10391980409622192, 0.06189347058534622, 0.06772381067276001, 0.08503933250904083, 0.25070688128471375, 0.0], [0.06525713205337524, 0.07869093865156174, 0.11366366595029831, 0.044226594269275665, 0.05455174669623375, 0.23646420240402222, 0.09933798015117645, 0.1198185384273529, 0.1879890412092209, 0.0], [0.09450254589319229, 0.027017319574952126, 0.06480545550584793, 0.10929621011018753, 0.11382008343935013, 0.17441418766975403, 0.11898359656333923, 0.06495486199855804, 0.23220552504062653, 0.0], [0.07681684195995331, 0.0671391412615776, 0.0905177965760231, 0.06064317002892494, 0.06652072072029114, 0.09855856746435165, 0.07360702753067017, 0.13956283032894135, 0.3266339898109436, 0.0], [0.12179998308420181, 0.07977079600095749, 0.08405954390764236, 0.1456507444381714, 0.14551174640655518, 0.07862778753042221, 0.09882251918315887, 0.14300917088985443, 0.1027478501200676, 0.0]], [[0.0261031873524189, 0.9575563073158264, 0.006272038444876671, 0.0037288309540599585, 0.0038619006518274546, 0.0007324732141569257, 0.0005133527447469532, 0.0003637235495261848, 0.0008679544553160667, 0.0], [0.02134888991713524, 0.08473973721265793, 0.6753177642822266, 0.028721673414111137, 0.14432094991207123, 0.027568204328417778, 0.0057298606261610985, 0.004451636224985123, 0.007801060564815998, 0.0], [0.03883299231529236, 0.030284319072961807, 0.5620493292808533, 0.09062989801168442, 0.17362907528877258, 0.08253934979438782, 0.010801085270941257, 0.00978847872465849, 0.0014453904004767537, 0.0], [0.002180949319154024, 0.003013473702594638, 0.16569769382476807, 0.008050205186009407, 0.7580646276473999, 0.061441101133823395, 0.001020166208036244, 0.0001067533012246713, 0.0004249440098647028, 0.0], [0.004150479566305876, 0.00034606645931489766, 0.3802972435951233, 0.06855826079845428, 0.29045602679252625, 0.1767650991678238, 0.06603583693504333, 0.0014808314153924584, 0.011909942142665386, 0.0], 
[0.006170187145471573, 0.0012396957026794553, 0.0354800671339035, 0.0032299698796123266, 0.03240001201629639, 0.5543311238288879, 0.30418315529823303, 0.051339369267225266, 0.01162647269666195, 0.0], [0.0035115755163133144, 0.0011483307462185621, 0.017956364899873734, 0.003783614607527852, 0.030611976981163025, 0.3673596978187561, 0.20627115666866302, 0.3506667912006378, 0.01869054324924946, 0.0], [0.0021685126703232527, 0.0006909942603670061, 0.010240452364087105, 0.01958688348531723, 0.004634156823158264, 0.11485372483730316, 0.04815557599067688, 0.7050773501396179, 0.0945921242237091, 0.0], [0.049201104789972305, 0.02397306263446808, 0.02337191067636013, 0.31066185235977173, 0.06433572620153427, 0.12544430792331696, 0.0786852017045021, 0.25179895758628845, 0.07252778857946396, 0.0], [0.010841209441423416, 0.0041772774420678616, 0.01548130251467228, 0.036074474453926086, 0.033387064933776855, 0.08192819356918335, 0.04784044623374939, 0.10195028781890869, 0.668319821357727, 0.0]], [[0.005738695617765188, 0.0068999892100691795, 0.4274883270263672, 0.08288666605949402, 0.1445126235485077, 0.04382907599210739, 0.10957401990890503, 0.05347184091806412, 0.1255987584590912, 0.0], [0.0025263649877160788, 0.00471830926835537, 0.13454590737819672, 0.4177793860435486, 0.28839975595474243, 0.029358303174376488, 0.017654288560152054, 0.0047735795378685, 0.10024390369653702, 0.0], [0.009192855097353458, 0.007133236154913902, 0.03149157017469406, 0.1856081485748291, 0.5691666603088379, 0.07386670261621475, 0.029819192364811897, 0.03683711960911751, 0.05688462406396866, 0.0], [0.00297820963896811, 0.0015070328954607248, 0.0025649494491517544, 0.0011051844339817762, 0.04088710993528366, 0.1953955888748169, 0.34000417590141296, 0.3367410898208618, 0.07881659269332886, 0.0], [0.003951869439333677, 0.009354526177048683, 0.007010620087385178, 0.0025927696842700243, 0.09962604194879532, 0.10909298062324524, 0.4455967843532562, 0.15358439087867737, 0.16918975114822388, 0.0], 
[0.0038829154800623655, 0.0036434896755963564, 0.006399825215339661, 0.000760377966798842, 0.010139851830899715, 0.038725122809410095, 0.10014155507087708, 0.48370444774627686, 0.35260239243507385, 0.0], [0.001297087874263525, 0.0014563009608536959, 0.013839880004525185, 0.0004286184557713568, 0.012207024730741978, 0.028704902157187462, 0.046600911766290665, 0.26406532526016235, 0.6313998103141785, 0.0], [0.0033481158316135406, 0.0038099782541394234, 0.0031049775425344706, 0.00033546099439263344, 0.0031272985506802797, 0.008788534440100193, 0.021183660253882408, 0.12157405912876129, 0.8347280025482178, 0.0], [0.3364367187023163, 0.17456969618797302, 0.051038213074207306, 0.006790165323764086, 0.024106895551085472, 0.0694134384393692, 0.02184627763926983, 0.061508405953645706, 0.25429028272628784, 0.0], [0.10536088049411774, 0.07750789821147919, 0.0850178673863411, 0.08725376427173615, 0.2586125433444977, 0.16756391525268555, 0.054291605949401855, 0.030132828280329704, 0.13425879180431366, 0.0]], [[0.034539882093667984, 0.0018589550163596869, 0.9604092836380005, 1.3120608855388127e-05, 2.1815638319822028e-05, 0.00012517283903434873, 8.019943197723478e-05, 0.0021589084062725306, 0.0007928607519716024, 0.0], [7.048832912914804e-07, 1.7815009414334781e-06, 0.9998455047607422, 0.0001518452918389812, 4.1070780554264275e-08, 2.7954746156799715e-11, 9.231376947582692e-12, 9.901777175969073e-09, 2.5545642756696907e-07, 0.0], [6.695767496012195e-08, 2.089915795977504e-07, 0.005368041805922985, 0.9945066571235657, 0.0001248170156031847, 2.304766155702964e-09, 2.762512718579302e-10, 3.973758211373024e-09, 9.372820954922645e-07, 0.0], [5.018761014413675e-13, 1.4841802622529476e-16, 4.663825770023777e-09, 3.820862737313746e-09, 0.9999942183494568, 4.988648925063899e-06, 4.967477167452938e-13, 1.416252587396787e-16, 2.1775358895380023e-16, 0.0], [4.666895758731471e-09, 7.292542437975502e-12, 2.898993545219497e-11, 4.2817244194637283e-10, 0.00027504604076966643, 
0.9995728731155396, 0.00015239788626786321, 1.9082661839586734e-10, 2.232514032581706e-13, 0.0], [1.7137297136926577e-10, 5.3312285142048665e-12, 2.2368220760327594e-14, 4.904942142678549e-17, 8.726878775178193e-09, 0.004644036293029785, 0.9953435659408569, 1.324965796811739e-05, 6.982896899598856e-12, 0.0], [4.877224735189145e-10, 1.5497924055196677e-09, 6.021576987036426e-11, 8.955144165463396e-19, 1.7180077889825118e-13, 6.163505759104737e-07, 0.001256544259376824, 0.9987285733222961, 1.4209075743565336e-05, 0.0], [3.25698863434809e-08, 7.313030323530256e-07, 1.412931510458293e-06, 1.1662047555981733e-16, 8.495708612521816e-14, 1.1933978653379251e-13, 1.3303619539328793e-07, 0.01294001005589962, 0.9870572686195374, 0.0], [1.6884889646462398e-06, 2.6281904865754768e-05, 0.001122217159718275, 6.101166945882142e-06, 4.424501298672112e-08, 5.172042264953158e-13, 5.508820136168602e-11, 5.942968346062116e-05, 0.9987838268280029, 0.0], [4.288114359951578e-05, 6.015944563841913e-06, 0.004432132933288813, 0.025997335091233253, 0.000731422973331064, 6.87844434188456e-11, 8.199346692057408e-13, 7.098316245901515e-08, 0.9687905311584473, 0.0]], [[0.02526121959090233, 0.9527671933174133, 0.014345486648380756, 0.0014051493490114808, 0.003839265089482069, 0.00014350644778460264, 0.0006356940139085054, 0.00025237957015633583, 0.0013501241337507963, 0.0], [0.004122408106923103, 0.023777475580573082, 0.9002965688705444, 0.0682864859700203, 0.0017659803852438927, 0.0001271881628781557, 0.00011044178245356306, 0.0001890352723421529, 0.0013242338318377733, 0.0], [8.841444650897756e-05, 0.0002895947836805135, 0.06307922303676605, 0.9069769978523254, 0.028407124802470207, 0.000558151863515377, 0.00022284295118879527, 0.00018588549573905766, 0.00019132612214889377, 0.0], [1.889026179924258e-06, 3.9712713260087185e-06, 0.001210480579175055, 0.003201226470991969, 0.8290116786956787, 0.16640713810920715, 0.00015829727635718882, 4.0429063119518105e-06, 9.256136763724498e-07, 0.0], 
[0.000399262469727546, 5.1438626542221755e-05, 0.0001944842515513301, 0.0007700449787080288, 0.4879837930202484, 0.4847603738307953, 0.025640420615673065, 0.00018376839580014348, 1.6383723050239496e-05, 0.0], [4.30414620495867e-05, 1.017293288896326e-05, 8.407413588429336e-06, 5.451946094581217e-07, 0.000544070964679122, 0.021075371652841568, 0.9573339819908142, 0.0208626389503479, 0.00012169074034318328, 0.0], [0.00043880229350179434, 0.0004488519043661654, 0.000600603292696178, 1.4583132212919736e-07, 3.6701523640658706e-05, 0.010162030346691608, 0.37363454699516296, 0.559087336063385, 0.0555914081633091, 0.0], [0.0010709260823205113, 0.0006920771556906402, 0.0016655249055474997, 0.00010216240480076522, 1.0821948308148421e-05, 2.6151516067329794e-05, 0.01446994487196207, 0.2987785339355469, 0.6831837296485901, 0.0], [0.0002485924051143229, 0.00016839140153024346, 0.019545644521713257, 0.016785046085715294, 0.005671702325344086, 0.00014030851889401674, 0.001185068627819419, 0.04272715002298355, 0.9135279655456543, 0.0], [0.0039028520695865154, 0.0008621322922408581, 0.02400260791182518, 0.35541704297065735, 0.048350416123867035, 0.00013779231812804937, 0.00015075977717060596, 0.0015127401566132903, 0.5656636953353882, 0.0]]], [[[0.09929531812667847, 0.3125585615634918, 0.26699960231781006, 0.036189958453178406, 0.01689508929848671, 0.05626463145017624, 0.014853590168058872, 0.021625356748700142, 0.17531771957874298, 0.0], [0.6598999500274658, 0.04883529245853424, 0.24573534727096558, 0.008949915878474712, 0.008034803904592991, 0.0058951652608811855, 0.001835338887758553, 0.0024289200082421303, 0.018385181203484535, 0.0], [0.28377673029899597, 0.4307016134262085, 0.19275489449501038, 0.05968217924237251, 0.007509235758334398, 0.00627214927226305, 0.0010254314402118325, 0.0010938378982245922, 0.017183959484100342, 0.0], [0.00751571636646986, 0.01881357654929161, 0.9318985342979431, 0.014481762424111366, 0.02105659246444702, 0.0032304797787219286, 
0.00013498679618351161, 2.4857494281604886e-05, 0.0028432777617126703, 0.0], [0.08691340684890747, 0.01259385235607624, 0.21131311357021332, 0.15839329361915588, 0.3931293189525604, 0.10845079272985458, 0.004768806044012308, 0.0032348930835723877, 0.021202562376856804, 0.0], [0.029192518442869186, 0.06438057869672775, 0.033022571355104446, 0.04279496520757675, 0.6011855006217957, 0.17385539412498474, 0.03754284232854843, 0.006468524225056171, 0.011557108722627163, 0.0], [0.006125382613390684, 0.006982659921050072, 0.004575703293085098, 0.0037440320011228323, 0.36007580161094666, 0.5409486889839172, 0.0626324936747551, 0.00843171589076519, 0.006483553443104029, 0.0], [0.0017123871948570013, 0.017555760219693184, 0.012620777823030949, 0.00947127677500248, 0.08178496360778809, 0.2538650631904602, 0.19189175963401794, 0.255443274974823, 0.17565478384494781, 0.0], [0.02615528553724289, 0.002552631078287959, 0.01957615464925766, 0.021708596497774124, 0.008856788277626038, 0.021813882514834404, 0.052812058478593826, 0.19690369069576263, 0.6496209502220154, 0.0], [0.004899451043456793, 0.005663626827299595, 0.012920243665575981, 0.007757777348160744, 0.014441648498177528, 0.021742597222328186, 0.05050418898463249, 0.35952994227409363, 0.5225404500961304, 0.0]], [[0.8470081686973572, 0.043761640787124634, 0.000660977209918201, 0.00018918802379630506, 0.01478277612477541, 0.00942840613424778, 0.06798462569713593, 0.011217072606086731, 0.004967056680470705, 0.0], [0.9998846054077148, 9.298400982515886e-05, 7.557733283647394e-08, 4.2952964861113496e-13, 4.9295836510032665e-12, 3.2098330660090824e-09, 5.042555585532682e-06, 1.7450745872338302e-05, 2.33268380611662e-07, 0.0], [2.118646625604015e-05, 0.9999122619628906, 6.629392737522721e-05, 1.312590147684034e-09, 2.7011800782239526e-11, 6.488713510726871e-14, 1.250517189799183e-10, 3.650779589747799e-08, 2.9122876554765753e-08, 0.0], [1.1949000816580124e-11, 3.2456850362905243e-07, 1.0, 3.0732459777027543e-07, 
4.943382370115046e-10, 1.2582140899967535e-17, 7.485076299292317e-18, 2.998638596002183e-14, 1.3861908843004755e-10, 0.0], [5.382360668271247e-10, 8.056646905174603e-09, 0.00035429277340881526, 0.9995232820510864, 0.00012279135989956558, 1.6631793720023325e-09, 1.8857353897253244e-14, 9.284229879032505e-15, 1.8321206097376974e-12, 0.0], [8.614902194392648e-12, 3.5818106835540375e-13, 4.029543365646759e-09, 3.1193526410788763e-06, 0.9959417581558228, 0.004055640660226345, 2.0883923923520342e-08, 1.5150488692381933e-14, 1.8145465705242968e-17, 0.0], [2.3006167283734502e-12, 4.150501252094593e-15, 2.9068709245239077e-12, 2.726213081238188e-13, 1.0724114645199734e-06, 0.9999104142189026, 8.954491204349324e-05, 3.77386955019432e-10, 8.537545242676776e-16, 0.0], [8.656632632941808e-10, 2.8593680201360883e-10, 4.910126749635424e-10, 3.37084723469553e-15, 1.3075121541028523e-10, 0.0003027402563020587, 0.999218225479126, 0.00047932929010130465, 1.4258912273135138e-08, 0.0], [1.0133464911632473e-07, 1.7307414168499236e-07, 2.3342326471720298e-07, 4.688030020606748e-13, 1.5028331227032177e-12, 5.3876938466146385e-09, 0.00158107269089669, 0.994592010974884, 0.0038271904923021793, 0.0], [2.33300490037891e-10, 1.2628836998374027e-07, 1.2948551102454076e-06, 3.169647599943204e-10, 1.5141217069741288e-14, 8.21656009561151e-15, 2.347289251858342e-09, 0.0025180077645927668, 0.9974797964096069, 0.0]], [[0.011770328506827354, 0.014021093025803566, 0.10656744986772537, 0.04667313024401665, 0.13704808056354523, 0.04681243374943733, 0.08347266167402267, 0.3310377299785614, 0.22259721159934998, 0.0], [0.009583584032952785, 0.010384900495409966, 0.09424954652786255, 0.09874095767736435, 0.2214881330728531, 0.08727390319108963, 0.09998933970928192, 0.16299772262573242, 0.21529172360897064, 0.0], [0.040493443608284, 0.05296378955245018, 0.12471148371696472, 0.04822944849729538, 0.2201310694217682, 0.13458549976348877, 0.16853223741054535, 0.12866733968257904, 0.08168572932481766, 0.0], 
[0.014574799686670303, 0.015747353434562683, 0.011357909068465233, 0.008449763990938663, 0.024292636662721634, 0.06141809746623039, 0.10683716088533401, 0.6414783596992493, 0.1158437430858612, 0.0], [0.0041047134436666965, 0.010159346275031567, 0.006441198755055666, 0.009530052542686462, 0.061682768166065216, 0.07391326874494553, 0.3019707202911377, 0.45178085565567017, 0.08041701465845108, 0.0], [0.013634801842272282, 0.03774101287126541, 0.015713637694716454, 0.01436087116599083, 0.06650711596012115, 0.06899012625217438, 0.1819150745868683, 0.376579225063324, 0.2245580554008484, 0.0], [0.03166442736983299, 0.07015468180179596, 0.1104653850197792, 0.016236137598752975, 0.18190902471542358, 0.08141329884529114, 0.15690769255161285, 0.22899281978607178, 0.12225660681724548, 0.0], [0.10994787514209747, 0.08447018265724182, 0.05270976573228836, 0.013435273431241512, 0.06919412314891815, 0.04981343820691109, 0.24833135306835175, 0.2721446752548218, 0.09995320439338684, 0.0], [0.39435869455337524, 0.21061576902866364, 0.1085209921002388, 0.004411425907164812, 0.06908565759658813, 0.04562678933143616, 0.02559957653284073, 0.06842028349637985, 0.0733608528971672, 0.0], [0.2682938873767853, 0.18270419538021088, 0.12741044163703918, 0.03156330808997154, 0.10574271529912949, 0.0955348014831543, 0.052997197955846786, 0.0821281224489212, 0.05362524837255478, 0.0]], [[8.027511648833752e-05, 0.0010475717717781663, 0.9977908730506897, 0.0002747455728240311, 0.000536168459802866, 9.231048170477152e-05, 0.00010586588905425742, 1.1979215742030647e-05, 5.969347330392338e-05, 0.0], [0.00012679747305810452, 5.715776205761358e-05, 0.922791600227356, 0.07177212089300156, 0.002934361109510064, 0.0005548547487705946, 0.001313770073466003, 2.2278460164670832e-05, 0.0004267726035322994, 0.0], [0.0063565499149262905, 0.0009426671313121915, 0.23976103961467743, 0.6402719020843506, 0.019077658653259277, 0.04590805247426033, 0.0423574335873127, 0.00055616011377424, 0.0047685266472399235, 0.0], 
[0.00012164804502390325, 1.1780298336816486e-05, 0.0001827587402658537, 0.00020120454428251833, 0.9978508353233337, 0.0014421044616028666, 6.411068170564249e-05, 4.628768147085793e-05, 7.896547322161496e-05, 0.0], [0.03763079643249512, 0.00208932813256979, 0.0006042887107469141, 0.5138440728187561, 0.19755180180072784, 0.029773280024528503, 0.15554653108119965, 0.015671545639634132, 0.0472884401679039, 0.0], [3.8805592339485884e-05, 1.2464041901694145e-05, 9.030352521222085e-05, 1.7544094589538872e-05, 0.0006991567788645625, 0.039246365427970886, 0.9305517077445984, 0.02403487078845501, 0.005308609921485186, 0.0], [0.003011370776221156, 0.005974559113383293, 0.003425326431170106, 0.001937237335368991, 0.01794668287038803, 0.06517820060253143, 0.25853174924850464, 0.28359606862068176, 0.3603990077972412, 0.0], [0.0019687232561409473, 0.0019828693475574255, 0.0009621239732950926, 0.0017320939805358648, 0.008526722900569439, 0.012685983441770077, 0.060781437903642654, 0.38653799891471863, 0.524821937084198, 0.0], [0.06319467723369598, 0.3812802731990814, 0.07775641977787018, 0.0546053946018219, 0.0410320870578289, 0.010218034498393536, 0.022281788289546967, 0.04868403077125549, 0.30094724893569946, 0.0], [0.06465335935354233, 0.0841824859380722, 0.028003698214888573, 0.01470992248505354, 0.013160775415599346, 0.006258893292397261, 0.003528257366269827, 0.022525515407323837, 0.7629771828651428, 0.0]], [[0.00496841873973608, 0.010829150676727295, 0.03283568099141121, 0.009884797036647797, 0.047239795327186584, 0.06476759165525436, 0.11417313665151596, 0.6207002401351929, 0.09460126608610153, 0.0], [0.014457895420491695, 0.06253711134195328, 0.10527490824460983, 0.051058270037174225, 0.04873393103480339, 0.058862265199422836, 0.13390113413333893, 0.44425415992736816, 0.0809202790260315, 0.0], [0.09337731450796127, 0.22848238050937653, 0.11594945937395096, 0.04185759648680687, 0.012283656746149063, 0.1264774352312088, 0.19395124912261963, 0.16978387534618378, 
0.017837027087807655, 0.0], [0.7125841975212097, 0.21987739205360413, 0.020619483664631844, 0.02881826087832451, 0.009833384305238724, 0.004124533850699663, 0.0008098671096377075, 0.0004809961246792227, 0.0028517041355371475, 0.0], [0.029080189764499664, 0.33611080050468445, 0.12628716230392456, 0.0817737877368927, 0.1908877044916153, 0.0943109318614006, 0.05712011829018593, 0.06781000643968582, 0.016619542613625526, 0.0], [0.07309448719024658, 0.07739713788032532, 0.0567743182182312, 0.03291132301092148, 0.16455504298210144, 0.1779973953962326, 0.2714528441429138, 0.13868720829486847, 0.007130389101803303, 0.0], [0.2111189365386963, 0.06559138745069504, 0.041267942637205124, 0.009358389303088188, 0.20342323184013367, 0.1869427114725113, 0.19775718450546265, 0.07797932624816895, 0.006560905836522579, 0.0], [0.08770362287759781, 0.12808790802955627, 0.023038268089294434, 0.17453545331954956, 0.09798892587423325, 0.11677049100399017, 0.09396524727344513, 0.26174578070640564, 0.01616443321108818, 0.0], [0.35409674048423767, 0.0420590415596962, 0.00930203776806593, 0.3349112272262573, 0.03967892378568649, 0.15319538116455078, 0.022175630554556847, 0.0432865284383297, 0.0012946304632350802, 0.0], [0.10030248761177063, 0.08145220577716827, 0.053510215133428574, 0.08076464384794235, 0.07446140050888062, 0.13495147228240967, 0.2503055930137634, 0.17467214167118073, 0.04957977309823036, 0.0]], [[0.140123188495636, 0.010056160390377045, 0.0845566838979721, 0.03108036518096924, 0.16015855967998505, 0.30321791768074036, 0.04101235046982765, 0.0719088688492775, 0.1578858345746994, 0.0], [0.6134085655212402, 0.1547522246837616, 0.03818102553486824, 0.001013039844110608, 0.013297338038682938, 0.008754062466323376, 0.005134810693562031, 0.0324203222990036, 0.13303862512111664, 0.0], [0.6891250014305115, 0.17779399454593658, 0.09809523820877075, 0.006996517535299063, 0.007719202898442745, 0.0016296659596264362, 0.010662317276000977, 0.004304768517613411, 0.0036729834973812103, 
0.0], [0.04376668110489845, 0.09640005975961685, 0.8100467324256897, 0.018579678609967232, 0.017539000138640404, 0.0008903089328669012, 0.0009985471842810512, 0.003613307373598218, 0.008165487088263035, 0.0], [0.03085213713347912, 0.025543441995978355, 0.6937543153762817, 0.17392684519290924, 0.03124413825571537, 0.02177071012556553, 0.007475809659808874, 0.003389933379366994, 0.012042560614645481, 0.0], [0.020024498924613, 0.002941351616755128, 0.05481509119272232, 0.183584526181221, 0.4182366132736206, 0.25923243165016174, 0.05362166836857796, 0.0045484029687941074, 0.002995501272380352, 0.0], [0.006091661751270294, 0.0012010806240141392, 0.008193010464310646, 0.009258490055799484, 0.15450483560562134, 0.7388086915016174, 0.06675267219543457, 0.01373466569930315, 0.0014547830214723945, 0.0], [0.0014694302808493376, 0.0017220929730683565, 0.005703628528863192, 0.0032696493435651064, 0.01713697426021099, 0.49356934428215027, 0.3729664385318756, 0.05505490303039551, 0.04910748079419136, 0.0], [0.0052343131974339485, 0.004969605710357428, 0.005609327927231789, 0.0007064095698297024, 0.005421568639576435, 0.045942794531583786, 0.22256441414356232, 0.43683722615242004, 0.27271413803100586, 0.0], [0.011939328163862228, 0.019054703414440155, 0.010745645500719547, 0.006908759940415621, 0.009522099047899246, 0.006889646407216787, 0.12289831787347794, 0.2292226105928421, 0.5828191637992859, 0.0]], [[0.0014003654941916466, 0.00935011450201273, 0.8996742963790894, 0.029868578538298607, 0.05752851441502571, 0.0008847691351547837, 0.0005429417942650616, 0.0004143548430874944, 0.00033632174017839134, 0.0], [0.0005502321291714907, 0.003854800947010517, 0.8475468754768372, 0.06876953691244125, 0.07909266650676727, 5.498397149494849e-05, 2.1647396351909265e-05, 6.648269391007489e-06, 0.00010276718239765614, 0.0], [0.0025599629152566195, 0.010113149881362915, 0.21385346353054047, 0.26065483689308167, 0.44287386536598206, 0.0458405464887619, 0.013329384848475456, 
0.0076821851544082165, 0.0030928871128708124, 0.0], [0.0002600199659354985, 3.3608048397582024e-05, 0.0020931970793753862, 0.007768034934997559, 0.9780486822128296, 0.011327453888952732, 0.00041993538616225123, 4.125805935473181e-05, 8.07127889856929e-06, 0.0], [0.0010751935187727213, 0.00017567894246894866, 0.004301255568861961, 0.0010412797564640641, 0.012584774754941463, 0.5903621912002563, 0.36841556429862976, 0.021853862330317497, 0.00019013854034710675, 0.0], [0.00036065353197045624, 0.00041391997365280986, 0.00018344201089348644, 1.21664334074012e-05, 0.0008204621262848377, 0.02300320193171501, 0.7380199432373047, 0.23411831259727478, 0.0030676021706312895, 0.0], [0.0007766868220642209, 0.00179819215554744, 0.0031821478623896837, 1.569229607412126e-05, 0.001023828866891563, 0.004582487046718597, 0.04412461444735527, 0.8326310515403748, 0.11186514794826508, 0.0], [0.002560202032327652, 0.0021961459424346685, 0.0012966376962140203, 3.874531466863118e-05, 0.00012789985339622945, 0.00017348439723718911, 0.06046983227133751, 0.07663179188966751, 0.856505274772644, 0.0], [0.05078713223338127, 0.09524610638618469, 0.03648101165890694, 0.050540339201688766, 0.009611092507839203, 0.0027538249269127846, 0.009690326638519764, 0.015156174078583717, 0.7297340035438538, 0.0], [0.017420543357729912, 0.009016300551593304, 0.008660875260829926, 0.04713813588023186, 0.042011067271232605, 0.003162879729643464, 0.00040178498602472246, 0.005153133533895016, 0.8670352697372437, 0.0]], [[0.22553573548793793, 0.2680850327014923, 0.019470686092972755, 0.14175784587860107, 0.053468361496925354, 0.02777918614447117, 0.05628729239106178, 0.04874898120760918, 0.15886712074279785, 0.0], [0.28905513882637024, 0.12247822433710098, 0.046002231538295746, 0.1958596557378769, 0.10771062225103378, 0.06661061197519302, 0.07628067582845688, 0.02713944762945175, 0.06886337697505951, 0.0], [0.04905243590474129, 0.05268532782793045, 0.11285670101642609, 0.09091109782457352, 0.24185867607593536, 
0.20752739906311035, 0.04222555831074715, 0.05885446071624756, 0.14402832090854645, 0.0], [0.06971512734889984, 0.14066818356513977, 0.05942149832844734, 0.21028849482536316, 0.10966084897518158, 0.08002462983131409, 0.10722756385803223, 0.1377343237400055, 0.08525940030813217, 0.0], [0.1429702192544937, 0.26978883147239685, 0.12360350787639618, 0.05825580656528473, 0.022957824170589447, 0.2193503975868225, 0.0713224932551384, 0.06461618840694427, 0.02713468112051487, 0.0], [0.07554306834936142, 0.051579318940639496, 0.2103901356458664, 0.03246254473924637, 0.12347473949193954, 0.20594589412212372, 0.10415074229240417, 0.14436782896518707, 0.05208563804626465, 0.0], [0.10752540081739426, 0.08459899574518204, 0.07340764254331589, 0.019914846867322922, 0.048802055418491364, 0.2628321945667267, 0.23049965500831604, 0.11754198372364044, 0.05487721040844917, 0.0], [0.054300110787153244, 0.03522595763206482, 0.19028180837631226, 0.11526520550251007, 0.043804410845041275, 0.1941872388124466, 0.12765192985534668, 0.19942660629749298, 0.03985673561692238, 0.0], [0.13462598621845245, 0.09648311138153076, 0.08205218613147736, 0.241444393992424, 0.024601474404335022, 0.03336581960320473, 0.09252338856458664, 0.0673752948641777, 0.22752824425697327, 0.0], [0.1438782811164856, 0.15257491171360016, 0.11015111207962036, 0.2259429395198822, 0.11582648009061813, 0.06522659957408905, 0.06865230947732925, 0.07465960830450058, 0.04308782145380974, 0.0]]], [[[0.008583037182688713, 0.007665919605642557, 0.023932937532663345, 0.013663848862051964, 0.00724611384794116, 0.01780843734741211, 0.04220886155962944, 0.035630952566862106, 0.8432599306106567, 0.0], [0.005249040201306343, 0.006725347600877285, 0.022601336240768433, 0.004061485640704632, 0.003380684182047844, 0.05792760103940964, 0.08571713417768478, 0.017759306356310844, 0.796578049659729, 0.0], [0.014741344377398491, 0.08626628667116165, 0.11416944116353989, 0.06755448132753372, 0.010767532512545586, 0.037519536912441254, 
0.13943251967430115, 0.03284287825226784, 0.4967060387134552, 0.0], [0.8946033120155334, 0.07520093768835068, 0.007621173746883869, 0.004705401603132486, 0.005715447012335062, 0.0016736779361963272, 0.0011882666731253266, 0.0005322583019733429, 0.008759708143770695, 0.0], [0.17331360280513763, 0.32618802785873413, 0.1865183413028717, 0.12219864875078201, 0.08427056670188904, 0.017049826681613922, 0.027256622910499573, 0.011689829640090466, 0.05151442065834999, 0.0], [0.024287043139338493, 0.22289688885211945, 0.2742122411727905, 0.1883603185415268, 0.1339159905910492, 0.04209006950259209, 0.04496186599135399, 0.03600992262363434, 0.033265650272369385, 0.0], [0.01142946071922779, 0.05564042925834656, 0.055694323033094406, 0.5140662789344788, 0.1435396671295166, 0.038738954812288284, 0.06230159476399422, 0.07060025632381439, 0.047988954931497574, 0.0], [0.03956271708011627, 0.0978141501545906, 0.053332336246967316, 0.4993227422237396, 0.15091775357723236, 0.05724353715777397, 0.05616844817996025, 0.014285729266703129, 0.03135249391198158, 0.0], [0.04081583395600319, 0.017569201067090034, 0.031049959361553192, 0.07860688865184784, 0.1978374421596527, 0.3013133406639099, 0.2561938464641571, 0.010236106812953949, 0.06637723743915558, 0.0], [0.005346705671399832, 0.017637349665164948, 0.01670711860060692, 0.027819450944662094, 0.014111858792603016, 0.15744496881961823, 0.29349666833877563, 0.10989060997962952, 0.357545405626297, 0.0]], [[0.14326919615268707, 0.06937730312347412, 0.4621289074420929, 0.06899607926607132, 0.20691490173339844, 0.03204977884888649, 0.010433961637318134, 0.001572124194353819, 0.005257652141153812, 0.0], [0.7372201681137085, 0.03819188475608826, 0.19263039529323578, 0.00509582320228219, 0.014029700309038162, 0.004338367842137814, 0.0016640998655930161, 0.0023727945517748594, 0.004456941969692707, 0.0], [0.6392468810081482, 0.09436309337615967, 0.23124097287654877, 0.009032140485942364, 0.016629014164209366, 0.004053707234561443, 
0.0011662752367556095, 0.0013368013314902782, 0.0029307324439287186, 0.0], [0.15959776937961578, 0.060010410845279694, 0.6323540210723877, 0.04208587482571602, 0.09941276162862778, 0.001314919558353722, 0.0003186642425134778, 0.00045829309965483844, 0.004447522107511759, 0.0], [0.06331828236579895, 0.03697410970926285, 0.6882537603378296, 0.04094800353050232, 0.1500014215707779, 0.014815385453402996, 0.0006663103122264147, 0.0014023728435859084, 0.0036205528303980827, 0.0], [0.02740752510726452, 0.007235638331621885, 0.2575177550315857, 0.2825733423233032, 0.26921361684799194, 0.13694509863853455, 0.012512636370956898, 0.00419765617698431, 0.0023968773894011974, 0.0], [0.026527998968958855, 0.0014296816661953926, 0.0034867397043854, 0.11850380897521973, 0.15826237201690674, 0.4342584013938904, 0.21162042021751404, 0.04376554489135742, 0.0021449460182338953, 0.0], [0.0008783259545452893, 0.0010965524706989527, 0.006981557235121727, 0.007060014642775059, 0.27200379967689514, 0.45634904503822327, 0.1935150921344757, 0.03130912408232689, 0.030806703492999077, 0.0], [0.012816469185054302, 0.004784241784363985, 0.007290879264473915, 0.0027244724333286285, 0.0388973169028759, 0.12052476406097412, 0.3920805752277374, 0.10759556293487549, 0.3132855296134949, 0.0], [0.0021361028775572777, 0.003133963793516159, 0.003311034757643938, 0.0013810866512358189, 0.004479007329791784, 0.007041627541184425, 0.09507600963115692, 0.5596640706062317, 0.32377713918685913, 0.0]], [[0.001748488168232143, 0.011698327027261257, 0.047558922320604324, 0.7770814299583435, 0.15215088427066803, 0.0056790816597640514, 0.0010312696686014533, 0.0011229184456169605, 0.0019287114264443517, 0.0], [0.000820137036498636, 0.0007328591891564429, 0.012266330420970917, 0.94822758436203, 0.02221596986055374, 0.006038068328052759, 0.0018012026557698846, 0.002194090047851205, 0.0057037402875721455, 0.0], [0.0017187671037390828, 0.0012595502194017172, 0.00971528235822916, 0.8996129631996155, 0.03184645250439644, 
0.026646586135029793, 0.01671759784221649, 0.005960865877568722, 0.006522092968225479, 0.0], [0.010048117488622665, 0.003920346032828093, 0.01464000903069973, 0.028398782014846802, 0.047600653022527695, 0.6803404688835144, 0.07394693046808243, 0.046145662665367126, 0.09495888650417328, 0.0], [0.0020061242394149303, 0.0010488562984392047, 0.0021137045696377754, 0.03403143212199211, 0.040159616619348526, 0.4656003415584564, 0.16990402340888977, 0.16164875030517578, 0.12348736822605133, 0.0], [0.0023888982832431793, 0.0010238748509436846, 0.0031129145063459873, 0.00400560162961483, 0.005227341782301664, 0.050918273627758026, 0.28773385286331177, 0.5181463956832886, 0.12744267284870148, 0.0], [0.0057381619699299335, 0.0037375285755842924, 0.006655727047473192, 0.0010085925459861755, 0.005980721674859524, 0.02943945676088333, 0.05893365666270256, 0.6100658774375916, 0.2784405052661896, 0.0], [0.003593636676669121, 0.0024473541416227818, 0.002264569513499737, 0.00914584007114172, 0.0013253247598186135, 0.010908454656600952, 0.07958614826202393, 0.12585432827472687, 0.7648744583129883, 0.0], [0.031058229506015778, 0.02174283377826214, 0.012145284563302994, 0.010826506651937962, 0.01352943666279316, 0.021966811269521713, 0.055832888931035995, 0.11603516340255737, 0.7168627977371216, 0.0], [0.20383700728416443, 0.06762446463108063, 0.042199794203042984, 0.021983252838253975, 0.11625738441944122, 0.013579235412180424, 0.025292381644248962, 0.08914806693792343, 0.4200783669948578, 0.0]], [[0.022736268118023872, 0.02286626398563385, 0.14116300642490387, 0.13108347356319427, 0.23994718492031097, 0.1924150437116623, 0.01816762052476406, 0.04976898059248924, 0.18185211718082428, 0.0], [0.05882957577705383, 0.028569074347615242, 0.23305171728134155, 0.053790394216775894, 0.18451730906963348, 0.2002667486667633, 0.015585620887577534, 0.052768219262361526, 0.17262138426303864, 0.0], [0.09136874228715897, 0.08459936082363129, 0.05023255571722984, 0.21660202741622925, 
0.1335863471031189, 0.10654665529727936, 0.02717875875532627, 0.06888726353645325, 0.22099831700325012, 0.0], [0.04131297022104263, 0.05848437175154686, 0.3077566921710968, 0.040097035467624664, 0.16343727707862854, 0.11984208226203918, 0.06441103667020798, 0.0850440189242363, 0.11961443722248077, 0.0], [0.06447532773017883, 0.05503746494650841, 0.11529060453176498, 0.13719302415847778, 0.0843825414776802, 0.22279226779937744, 0.11870565265417099, 0.05292103812098503, 0.14920207858085632, 0.0], [0.061820220202207565, 0.03663187846541405, 0.08412205427885056, 0.386857271194458, 0.1083698719739914, 0.1462787538766861, 0.03903358429670334, 0.026668915525078773, 0.11021733283996582, 0.0], [0.08746915310621262, 0.025642354041337967, 0.16437062621116638, 0.19346435368061066, 0.10867251455783844, 0.12237238138914108, 0.06722743809223175, 0.0922309011220932, 0.13855047523975372, 0.0], [0.10294228792190552, 0.07313423603773117, 0.18607352674007416, 0.09769721329212189, 0.1089077964425087, 0.26933327317237854, 0.06555335968732834, 0.061070602387189865, 0.03528755530714989, 0.0], [0.12094805389642715, 0.14730192720890045, 0.09877816587686539, 0.21085986495018005, 0.06241541728377342, 0.22994481027126312, 0.04595630243420601, 0.04531335458159447, 0.0384821854531765, 0.0], [0.11032164841890335, 0.07897982746362686, 0.08231978863477707, 0.2677886188030243, 0.1231643408536911, 0.0929633229970932, 0.08270144462585449, 0.06097007542848587, 0.10079105943441391, 0.0]], [[0.008687321096658707, 0.012162125669419765, 0.02774685248732567, 0.0013578477082774043, 0.052177976816892624, 0.027187975123524666, 0.05590689554810524, 0.020962538197636604, 0.7938104867935181, 0.0], [0.005042325239628553, 0.015503124333918095, 0.010042164474725723, 0.0008876739302650094, 0.011308688670396805, 0.010491759516298771, 0.03130592033267021, 0.04934320226311684, 0.8660751581192017, 0.0], [0.013016406446695328, 0.03886239603161812, 0.027493299916386604, 0.029101338237524033, 0.009947741404175758, 
0.00769558921456337, 0.035501737147569656, 0.023772817105054855, 0.8146085143089294, 0.0], [0.018851714208722115, 0.05105733126401901, 0.8005384206771851, 0.01116525661200285, 0.09583853930234909, 0.0015093896072357893, 0.005055624525994062, 0.0006665397086180747, 0.015317671000957489, 0.0], [0.01609102450311184, 0.023716216906905174, 0.5135837197303772, 0.10603100061416626, 0.26668840646743774, 0.019648341462016106, 0.01755940169095993, 0.01368130836635828, 0.023000601679086685, 0.0], [0.01718730293214321, 0.02692273259162903, 0.05480796471238136, 0.010818017646670341, 0.7150712013244629, 0.0585104264318943, 0.04717297852039337, 0.030360547825694084, 0.039148781448602676, 0.0], [0.006439396180212498, 0.012697076424956322, 0.014188298024237156, 0.000897688849363476, 0.7481768727302551, 0.15047557651996613, 0.03333613649010658, 0.01207506563514471, 0.021714046597480774, 0.0], [0.009459104388952255, 0.022298788651823997, 0.013802104629576206, 0.011955137364566326, 0.03879927098751068, 0.1585427075624466, 0.07075291126966476, 0.329448938369751, 0.3449409306049347, 0.0], [0.04810584336519241, 0.017975708469748497, 0.025123968720436096, 0.023182567209005356, 0.020010611042380333, 0.04571577161550522, 0.1801854819059372, 0.06764508783817291, 0.5720548629760742, 0.0], [0.026153914630413055, 0.0356404148042202, 0.10573611408472061, 0.06201518699526787, 0.06006328761577606, 0.09286139905452728, 0.2927103638648987, 0.20419549942016602, 0.12062377482652664, 0.0]], [[0.02415475994348526, 0.0027711745351552963, 0.003856832394376397, 0.0957413911819458, 0.02159286104142666, 0.03336814045906067, 0.009564127773046494, 0.03954486921429634, 0.7694058418273926, 0.0], [0.9052021503448486, 0.02053658291697502, 0.0014916026266291738, 0.00022646080469712615, 4.7710393118904904e-05, 0.000383042759494856, 0.014123834669589996, 0.0205638837069273, 0.03742456063628197, 0.0], [0.37607336044311523, 0.6030705571174622, 0.0068079219199717045, 0.0036466827150434256, 9.876023250399157e-05, 
2.0246809071977623e-05, 0.0007042856304906309, 0.002560489112511277, 0.007017510011792183, 0.0], [5.0091031880583614e-05, 0.00024915943504311144, 0.9895205497741699, 0.006273698527365923, 0.0016484790248796344, 4.1711446101544425e-05, 7.522702958340233e-07, 1.2660359971050639e-05, 0.002202932955697179, 0.0], [8.009441080503166e-05, 9.311464236816391e-05, 0.006593613885343075, 0.9913647770881653, 0.0018261962104588747, 1.6436462829005904e-05, 8.038865075832291e-07, 1.0318336762793479e-06, 2.3524326024926268e-05, 0.0], [3.1561212381348014e-05, 1.8178753862230224e-06, 0.00011904581333510578, 0.027105441316962242, 0.8800897598266602, 0.09253741800785065, 0.00010895416926359758, 5.953493655397324e-06, 1.9602707368449046e-07, 0.0], [1.7160528553716858e-09, 1.4191656530493368e-11, 3.274841375855431e-08, 2.1219284462858923e-07, 1.9925082597183064e-05, 0.9999751448631287, 3.130498271275428e-06, 1.9788064946624218e-06, 3.1215499074477293e-09, 0.0], [1.2861962204624433e-05, 5.737682045037218e-07, 2.0471109110076213e-06, 1.0477544492459856e-05, 6.581651632586727e-06, 0.02534269355237484, 0.16125597059726715, 0.5878354907035828, 0.22553342580795288, 0.0], [0.0009172551217488945, 7.270056084962562e-05, 2.2026280930731446e-05, 4.6261970965133514e-06, 4.921669642499182e-06, 4.060195351485163e-05, 0.027831047773361206, 0.33271971344947815, 0.6383873224258423, 0.0], [1.3075091374048498e-05, 6.147480598883703e-05, 4.768987855641171e-05, 2.045959490715177e-06, 1.1152823553572944e-08, 3.07468525306831e-07, 0.0007055726600810885, 0.02803119830787182, 0.9711382985115051, 0.0]], [[0.060361556708812714, 0.015829458832740784, 0.05784451961517334, 0.3351474404335022, 0.06477320939302444, 0.04427827522158623, 0.09356044977903366, 0.03362266346812248, 0.2945823669433594, 0.0], [0.051239900290966034, 0.0459107868373394, 0.10656695812940598, 0.4080160856246948, 0.16381530463695526, 0.044977184385061264, 0.05972094088792801, 0.009804679080843925, 0.10994797199964523, 0.0], [0.019088272005319595, 
0.05349855497479439, 0.4389742910861969, 0.022328443825244904, 0.03395729511976242, 0.20592069625854492, 0.007582489866763353, 0.08437496423721313, 0.13427504897117615, 0.0], [0.03275543451309204, 0.01311502419412136, 0.038520246744155884, 0.47789818048477173, 0.04586595296859741, 0.01380465179681778, 0.03337283805012703, 0.07212045043706894, 0.27254730463027954, 0.0], [0.04071904346346855, 0.043366871774196625, 0.1190471276640892, 0.18268215656280518, 0.2763146162033081, 0.029253922402858734, 0.017268449068069458, 0.0670313611626625, 0.22431644797325134, 0.0], [0.04853136092424393, 0.0034203159157186747, 0.17822766304016113, 0.005087696481496096, 0.02670232392847538, 0.5734196305274963, 0.06478680670261383, 0.04684215411543846, 0.05298209935426712, 0.0], [0.016102498397231102, 0.0006646174006164074, 0.00315408268943429, 0.003398373955860734, 0.01210782676935196, 0.07864897698163986, 0.743419349193573, 0.023116787895560265, 0.11938738822937012, 0.0], [0.0031801864970475435, 0.0032259617000818253, 0.027063841000199318, 0.0018325509736314416, 0.006064774002879858, 0.017839375883340836, 0.05006564408540726, 0.8002738952636719, 0.0904538482427597, 0.0], [0.02500138245522976, 0.016465606167912483, 0.02692888118326664, 0.01824249140918255, 0.047875918447971344, 0.06556686758995056, 0.15585453808307648, 0.21941381692886353, 0.42465049028396606, 0.0], [0.07641319185495377, 0.017753547057509422, 0.039497166872024536, 0.014236720278859138, 0.03872253745794296, 0.1210501492023468, 0.17305448651313782, 0.2333979308605194, 0.28587427735328674, 0.0]], [[0.15564993023872375, 0.3264511823654175, 0.08247561007738113, 0.04047680273652077, 0.04636594280600548, 0.03705644607543945, 0.05653020739555359, 0.08808662742376328, 0.16690711677074432, 0.0], [0.6047166585922241, 0.08402378112077713, 0.11650887131690979, 0.004807815421372652, 0.02726476825773716, 0.0609126091003418, 0.02905944734811783, 0.012920884415507317, 0.059785205870866776, 0.0], [0.5938906669616699, 0.07300958037376404, 
0.08890929818153381, 0.008111076429486275, 0.04038470610976219, 0.07353192567825317, 0.03085281327366829, 0.08706387132406235, 0.004246041644364595, 0.0], [0.2591831088066101, 0.17658700048923492, 0.44177621603012085, 0.01689036749303341, 0.0653892457485199, 0.01502177957445383, 0.02055797167122364, 0.0024378441739827394, 0.0021566858049482107, 0.0], [0.33400091528892517, 0.03927909955382347, 0.27614372968673706, 0.009977479465305805, 0.12025652825832367, 0.1713484674692154, 0.04292818158864975, 0.004225345328450203, 0.00184013566467911, 0.0], [0.06147114187479019, 0.019044799730181694, 0.059415291994810104, 0.05198045074939728, 0.12181691080331802, 0.419679194688797, 0.1140735000371933, 0.14551687240600586, 0.00700181070715189, 0.0], [0.006845483556389809, 0.002091927919536829, 0.01196279563009739, 0.014390786178410053, 0.02692629024386406, 0.8455513715744019, 0.07174734026193619, 0.017689114436507225, 0.0027949714567512274, 0.0], [0.00039940490387380123, 0.00013551976007875055, 0.020663700997829437, 0.008696838282048702, 0.021915050223469734, 0.1381293535232544, 0.0347108468413353, 0.7650054097175598, 0.010343861766159534, 0.0], [0.02615724503993988, 0.0051858089864254, 0.038734134286642075, 0.021585455164313316, 0.19684533774852753, 0.17548950016498566, 0.1665634661912918, 0.2796759307384491, 0.08976294845342636, 0.0], [0.043001022189855576, 0.016749290749430656, 0.04958483204245567, 0.06659381091594696, 0.0702962800860405, 0.27735820412635803, 0.14212922751903534, 0.20686522126197815, 0.12742231786251068, 0.0]]], [[[0.13086311519145966, 0.049477167427539825, 0.10100015252828598, 0.03843620419502258, 0.27287009358406067, 0.20078831911087036, 0.16546384990215302, 0.03368193656206131, 0.007419050205498934, 0.0], [0.1137659102678299, 0.11250672489404678, 0.21935509145259857, 0.09974226355552673, 0.22245454788208008, 0.11022598296403885, 0.0977952778339386, 0.010162456892430782, 0.013991687446832657, 0.0], [0.09118296205997467, 0.0991944894194603, 
0.31555840373039246, 0.16625922918319702, 0.1399575173854828, 0.0926588773727417, 0.021735703572630882, 0.056496523320674896, 0.016956249251961708, 0.0], [0.35773080587387085, 0.19870112836360931, 0.026073846966028214, 0.07347559928894043, 0.09251826256513596, 0.0859094187617302, 0.06421677768230438, 0.06334269791841507, 0.0380314365029335, 0.0], [0.02230222336947918, 0.0210218857973814, 0.024334343150258064, 0.36442241072654724, 0.2750929892063141, 0.13295342028141022, 0.06824173033237457, 0.0036951478105038404, 0.0879359245300293, 0.0], [0.018942566588521004, 0.011805560439825058, 0.04696377366781235, 0.09440026432275772, 0.39890599250793457, 0.17608429491519928, 0.10613365471363068, 0.10454639047384262, 0.04221746698021889, 0.0], [0.0475851334631443, 0.008668179623782635, 0.011950161308050156, 0.0786907747387886, 0.09432563930749893, 0.07653870433568954, 0.4287588894367218, 0.13403372466564178, 0.1194487139582634, 0.0], [0.008243327029049397, 0.006908380892127752, 0.04044030234217644, 0.08380357921123505, 0.1593569815158844, 0.1858288198709488, 0.0890916958451271, 0.40247857570648193, 0.02384827472269535, 0.0], [0.09753390401601791, 0.04787491634488106, 0.10570236295461655, 0.09989321976900101, 0.07242950052022934, 0.16000299155712128, 0.13195638358592987, 0.12870465219020844, 0.15590202808380127, 0.0], [0.3338638246059418, 0.05386793985962868, 0.15485166013240814, 0.05483235418796539, 0.052468191832304, 0.12754301726818085, 0.13515245914459229, 0.06475869566202164, 0.022661946713924408, 0.0]], [[0.011833908967673779, 0.03545977920293808, 0.03510122373700142, 0.06200635805726051, 0.09438431262969971, 0.06055876612663269, 0.053256530314683914, 0.30701303482055664, 0.3403860926628113, 0.0], [0.03663749620318413, 0.06511621922254562, 0.05716057866811752, 0.07533077895641327, 0.10846659541130066, 0.037432827055454254, 0.04480022192001343, 0.18166707456111908, 0.39338818192481995, 0.0], [0.06557667255401611, 0.03966936469078064, 0.008358842693269253, 
0.06794404983520508, 0.05668830871582031, 0.02720261737704277, 0.07913517951965332, 0.20437636971473694, 0.45104852318763733, 0.0], [0.044038429856300354, 0.07477934658527374, 0.10143070667982101, 0.16204005479812622, 0.06265459954738617, 0.10170722752809525, 0.08676454424858093, 0.0699862688779831, 0.2965989410877228, 0.0], [0.06005045771598816, 0.046840403228998184, 0.06629239022731781, 0.04125581681728363, 0.007815167307853699, 0.20412082970142365, 0.1083299070596695, 0.04942404478788376, 0.41587093472480774, 0.0], [0.03666035085916519, 0.028792625293135643, 0.06887229532003403, 0.18481910228729248, 0.15058831870555878, 0.048441674560308456, 0.0780390277504921, 0.13469383120536804, 0.26909276843070984, 0.0], [0.03408746421337128, 0.026394939050078392, 0.05409233644604683, 0.06951043754816055, 0.1446777582168579, 0.09970070421695709, 0.05472328141331673, 0.16119606792926788, 0.35561704635620117, 0.0], [0.12936006486415863, 0.04621516913175583, 0.10149524360895157, 0.14774896204471588, 0.45855623483657837, 0.033130910247564316, 0.031401973217725754, 0.02012830227613449, 0.031963150948286057, 0.0], [0.1214270144701004, 0.04088712856173515, 0.05250505730509758, 0.07924661785364151, 0.05337269604206085, 0.10527284443378448, 0.08820997178554535, 0.17732012271881104, 0.28175854682922363, 0.0], [0.13074854016304016, 0.06475767493247986, 0.07325490564107895, 0.0625966489315033, 0.14061231911182404, 0.07830052822828293, 0.12438739091157913, 0.21453101933002472, 0.11081094294786453, 0.0]], [[0.0022766904439777136, 0.00227623013779521, 0.027263110503554344, 0.7988243699073792, 0.12335250526666641, 0.012830986641347408, 0.008179515600204468, 0.004631126299500465, 0.020365260541439056, 0.0], [0.022365765646100044, 0.0197063609957695, 0.08540411293506622, 0.7100865840911865, 0.10288897156715393, 0.023861246183514595, 0.009303209371864796, 0.012690575793385506, 0.013693095184862614, 0.0], [0.023093748837709427, 0.013999207876622677, 0.09048538655042648, 0.10519850999116898, 
0.12126202881336212, 0.34847554564476013, 0.057331401854753494, 0.0919070839881897, 0.14824725687503815, 0.0], [0.03627682104706764, 0.0323517769575119, 0.06003699079155922, 0.04609783738851547, 0.3189731240272522, 0.3202785551548004, 0.06900984793901443, 0.021341597661376, 0.0956336110830307, 0.0], [0.026664189994335175, 0.018690558150410652, 0.01473171729594469, 0.003785684471949935, 0.012891196645796299, 0.6301508545875549, 0.1024516150355339, 0.10377107560634613, 0.08686315268278122, 0.0], [0.010066811926662922, 0.005272349342703819, 0.019913937896490097, 0.005584465805441141, 0.0479762889444828, 0.06466472148895264, 0.2978198528289795, 0.22872935235500336, 0.31997203826904297, 0.0], [0.054553788155317307, 0.011876759119331837, 0.005296430550515652, 0.008171333000063896, 0.17499762773513794, 0.29638832807540894, 0.22286026179790497, 0.017016055062413216, 0.20883934199810028, 0.0], [0.03061697818338871, 0.020777547731995583, 0.27117541432380676, 0.010558649897575378, 0.16651615500450134, 0.3011224865913391, 0.026109976693987846, 0.048922766000032425, 0.12420005351305008, 0.0], [0.16545239090919495, 0.03877135366201401, 0.007565324194729328, 0.015141250565648079, 0.03747279569506645, 0.3241279125213623, 0.26990416646003723, 0.043362975120544434, 0.09820175170898438, 0.0], [0.22949647903442383, 0.0972394198179245, 0.02905140444636345, 0.03182214871048927, 0.025490015745162964, 0.08278947323560715, 0.15009135007858276, 0.031098822131752968, 0.3229208290576935, 0.0]], [[0.023217031732201576, 0.015444980934262276, 0.33269768953323364, 0.4809305965900421, 0.08491171896457672, 0.027504485100507736, 0.007655052933841944, 0.015150148421525955, 0.012488299049437046, 0.0], [0.003814368275925517, 0.0054845609702169895, 0.005400203168392181, 0.34217125177383423, 0.010647634975612164, 0.00044525362318381667, 0.00011972449283348396, 0.00042839962407015264, 0.6314883828163147, 0.0], [0.013448912650346756, 0.01028169970959425, 0.4982297718524933, 0.3182436525821686, 
0.01780710555613041, 0.024587348103523254, 0.0009282209794037044, 0.11607228964567184, 0.0004009671974927187, 0.0], [0.0027270291466265917, 0.01338754128664732, 0.019254636019468307, 0.11856623739004135, 0.0025901400949805975, 0.0012062221067026258, 0.0006161375786177814, 0.0012282256502658129, 0.8404240608215332, 0.0], [1.802536098693963e-05, 0.0005015733768232167, 2.3977232558536343e-05, 0.00012258262722752988, 0.00013862864580005407, 1.9367420463822782e-05, 1.2695372788584791e-05, 2.8395381377777085e-05, 0.9991349577903748, 0.0], [0.045823611319065094, 0.0060311248525977135, 0.11489683389663696, 0.011397628113627434, 0.14236140251159668, 0.31853923201560974, 0.18707275390625, 0.16781283915042877, 0.006064609158784151, 0.0], [0.031908370554447174, 0.0013231962220743299, 0.03774190694093704, 0.014869065955281258, 0.08836144208908081, 0.662682056427002, 0.1095389723777771, 0.05017231032252312, 0.0034025281202048063, 0.0], [0.0061959377489984035, 0.012075785547494888, 0.28881579637527466, 0.0719127431511879, 0.08756363391876221, 0.0848873034119606, 0.027471251785755157, 0.404219388961792, 0.016858302056789398, 0.0], [0.0946543961763382, 0.0623893216252327, 0.18748056888580322, 0.1788652539253235, 0.03208017721772194, 0.1587594598531723, 0.05469479411840439, 0.17047303915023804, 0.06060296297073364, 0.0], [0.019481608644127846, 0.068674735724926, 0.13537795841693878, 0.2137300968170166, 0.031131863594055176, 0.02376358024775982, 0.030956387519836426, 0.04989796131849289, 0.4269856810569763, 0.0]], [[0.00896595511585474, 0.001820763573050499, 0.0036846648436039686, 0.8942996859550476, 0.002699120668694377, 0.0018430916825309396, 0.00023619653075002134, 0.0008667120710015297, 0.08558366447687149, 0.0], [0.011139868758618832, 0.00517098605632782, 0.03486357256770134, 0.92783522605896, 0.010794212110340595, 0.0029791113920509815, 0.0008399260113947093, 0.0003134821599815041, 0.006063643377274275, 0.0], [0.07888396829366684, 0.0272236131131649, 0.0322146937251091, 
0.791079044342041, 0.03133838623762131, 0.009372375905513763, 0.002263500588014722, 0.0005359782953746617, 0.02708848938345909, 0.0], [0.008838528767228127, 0.0009813528740778565, 0.014693140052258968, 0.00012726498243864626, 0.013269715011119843, 0.06431703269481659, 0.0039668334648013115, 0.8607616424560547, 0.0330444760620594, 0.0], [0.028727378696203232, 0.001701394678093493, 0.0009593431605026126, 0.0036824517883360386, 0.009683175943791866, 0.2589351236820221, 0.040837112814188004, 0.01649528741836548, 0.6389787197113037, 0.0], [0.009239337407052517, 0.0011580593418329954, 0.0009623299702070653, 0.000996780814602971, 0.00493139773607254, 0.04319336265325546, 0.859686553478241, 0.012395362369716167, 0.06743697822093964, 0.0], [0.024199873208999634, 0.007249501068145037, 0.02041051909327507, 0.008800184354186058, 0.02760438062250614, 0.1116553395986557, 0.030366744846105576, 0.03851965814828873, 0.7311937808990479, 0.0], [0.06881897896528244, 0.21671976149082184, 0.02303808182477951, 0.0017656114650890231, 0.09897635877132416, 0.04207116737961769, 0.012660021893680096, 0.25307658314704895, 0.2828734517097473, 0.0], [0.09324429929256439, 0.059572815895080566, 0.021969754248857498, 0.008625463582575321, 0.022502752020955086, 0.07016356289386749, 0.033860694617033005, 0.03514377400279045, 0.6549169421195984, 0.0], [0.04541633278131485, 0.01696496643126011, 0.003866765182465315, 0.00941139180213213, 0.006640681531280279, 0.024550199508666992, 0.009012367576360703, 0.009869653731584549, 0.8742677569389343, 0.0]], [[0.007143852766603231, 0.26111796498298645, 0.053768061101436615, 0.022731401026248932, 0.014146089553833008, 0.012985849753022194, 0.007359612733125687, 0.0043042986653745174, 0.6164429783821106, 0.0], [0.6188079118728638, 0.11004422605037689, 0.07541824132204056, 0.010463211685419083, 0.003863272722810507, 0.016659650951623917, 0.028880171477794647, 0.010046081617474556, 0.1258174628019333, 0.0], [0.1673126220703125, 0.6515482068061829, 
0.016748156398534775, 0.042502570897340775, 0.016912223771214485, 0.011716129258275032, 0.04548521339893341, 0.0008787817787379026, 0.04689598083496094, 0.0], [0.6625117659568787, 0.049922335892915726, 0.2738172709941864, 0.004228150937706232, 0.0033112652599811554, 0.001177642960101366, 0.0005330604617483914, 0.00011132613872177899, 0.0043872324749827385, 0.0], [0.0040039438754320145, 0.00112480903044343, 0.04353015124797821, 0.9313303232192993, 0.010056668892502785, 0.0007567661814391613, 0.0006773694767616689, 0.00016374654660467058, 0.008356312289834023, 0.0], [0.0015825676964595914, 0.001574154943227768, 0.001225732616148889, 0.27774307131767273, 0.47191065549850464, 0.041899941861629486, 0.10331469774246216, 0.0047262245789170265, 0.0960230901837349, 0.0], [0.0012044805334880948, 0.001744594657793641, 0.010911357589066029, 0.035235531628131866, 0.12406003475189209, 0.49639585614204407, 0.02129644714295864, 0.07618547230958939, 0.23296628892421722, 0.0], [0.006665610242635012, 0.008957373909652233, 0.0028928713873028755, 0.7268922924995422, 0.10707614570856094, 0.01201178040355444, 0.013845101930201054, 0.022992080077528954, 0.09866661578416824, 0.0], [0.04134861007332802, 0.0526767373085022, 0.04131396487355232, 0.023087071254849434, 0.04077164828777313, 0.027765633538365364, 0.05679082125425339, 0.025407245382666588, 0.6908383369445801, 0.0], [0.0021411015186458826, 0.012145284563302994, 0.008635377511382103, 0.004571457393467426, 0.009789393283426762, 0.022923681885004044, 0.019266795367002487, 0.15913596749305725, 0.7613908052444458, 0.0]], [[0.1481553614139557, 0.14691436290740967, 0.5575758218765259, 0.02441403828561306, 0.058879025280475616, 0.011832842603325844, 0.01016098354011774, 0.015505112707614899, 0.026562504470348358, 0.0], [0.20737145841121674, 0.2658809721469879, 0.4251604974269867, 0.03998560830950737, 0.012661930173635483, 0.003662273520603776, 0.0006891252705827355, 0.004390099551528692, 0.040197838097810745, 0.0], [0.09591562300920486, 
0.13717111945152283, 0.23219715058803558, 0.020156029611825943, 0.031411558389663696, 0.04842779412865639, 0.003137993859127164, 0.03623202070593834, 0.3953508734703064, 0.0], [0.039530280977487564, 0.0770052894949913, 0.4637334942817688, 0.3284752666950226, 0.018390586599707603, 0.021701356396079063, 0.0038800504989922047, 0.01712900958955288, 0.03015456721186638, 0.0], [0.11945435404777527, 0.064958855509758, 0.07506249845027924, 0.03312050923705101, 0.045947931706905365, 0.21168209612369537, 0.1585550606250763, 0.15941208600997925, 0.13180643320083618, 0.0], [0.027811916545033455, 0.005367752630263567, 0.022701909765601158, 0.02928026206791401, 0.042085714638233185, 0.23124846816062927, 0.3448639512062073, 0.17164644598960876, 0.12499356269836426, 0.0], [0.04207267239689827, 0.0036673955619335175, 0.013221505098044872, 0.04823020473122597, 0.018784234300255775, 0.15617071092128754, 0.5762590169906616, 0.08817830681800842, 0.053416069597005844, 0.0], [0.016374358907341957, 0.01844160258769989, 0.0564517118036747, 0.008724928833544254, 0.031119121238589287, 0.08068697899580002, 0.028958600014448166, 0.08762799203395844, 0.6716147661209106, 0.0], [0.01953587494790554, 0.025442129001021385, 0.033712055534124374, 0.054231878370046616, 0.046861547976732254, 0.038379911333322525, 0.03105914779007435, 0.027592265978455544, 0.723185122013092, 0.0], [0.13706038892269135, 0.05296454578638077, 0.06056801974773407, 0.10271193832159042, 0.10989244282245636, 0.11971112340688705, 0.10623226314783096, 0.11503037810325623, 0.19582894444465637, 0.0]], [[0.1995955854654312, 0.1365700364112854, 0.06844333559274673, 0.10430964082479477, 0.06450515240430832, 0.046256136149168015, 0.1181989535689354, 0.11867640167474747, 0.14344464242458344, 0.0], [0.11168574541807175, 0.19221879541873932, 0.09424428641796112, 0.10450402647256851, 0.06917304545640945, 0.0600862056016922, 0.14199501276016235, 0.09375911951065063, 0.1323338896036148, 0.0], [0.1609925925731659, 0.1396498680114746, 
0.177944153547287, 0.03334498405456543, 0.016808874905109406, 0.10536731034517288, 0.1187783032655716, 0.03365077078342438, 0.2134632021188736, 0.0], [0.1381153017282486, 0.04422784969210625, 0.04791303351521492, 0.16848880052566528, 0.14531251788139343, 0.08485772460699081, 0.03650972992181778, 0.08906612545251846, 0.24550898373126984, 0.0], [0.05420248210430145, 0.02658572979271412, 0.05446610227227211, 0.10749125480651855, 0.22097598016262054, 0.16638338565826416, 0.0331658273935318, 0.035041358321905136, 0.30168798565864563, 0.0], [0.13837963342666626, 0.02727973647415638, 0.1299334168434143, 0.0896206796169281, 0.11551950126886368, 0.13963927328586578, 0.07841819524765015, 0.02172034978866577, 0.25948914885520935, 0.0], [0.08648376911878586, 0.022083481773734093, 0.023758457973599434, 0.19388236105442047, 0.1724909394979477, 0.02776450663805008, 0.04756799340248108, 0.03839871287345886, 0.38756975531578064, 0.0], [0.06476524472236633, 0.0030945511534810066, 0.016289785504341125, 0.013512126170098782, 0.007217712234705687, 0.047962453216314316, 0.03755675256252289, 0.02134101279079914, 0.788260281085968, 0.0], [0.06007339805364609, 0.038077425211668015, 0.01732070930302143, 0.04335314407944679, 0.09849875420331955, 0.07123422622680664, 0.12536978721618652, 0.17402620613574982, 0.37204641103744507, 0.0], [0.31619662046432495, 0.10629935562610626, 0.051193755120038986, 0.08206456899642944, 0.08056272566318512, 0.05727463215589523, 0.13476009666919708, 0.03839832916855812, 0.13325001299381256, 0.0]]], [[[0.0077307759784162045, 0.013184988871216774, 0.016869038343429565, 0.013336911797523499, 0.01304439827799797, 0.013718237169086933, 0.0296618789434433, 0.02448520064353943, 0.8679684996604919, 0.0], [0.1374177634716034, 0.018056754022836685, 0.029763542115688324, 0.004862301517277956, 0.00231130956672132, 0.006278112530708313, 0.012106452137231827, 0.033879589289426804, 0.755324125289917, 0.0], [0.1539316475391388, 0.23461903631687164, 0.033691998571157455, 
0.026462335139513016, 0.0030949951615184546, 0.0038835303857922554, 0.009438932873308659, 0.0025479686446487904, 0.5323294997215271, 0.0], [0.29215019941329956, 0.05790534242987633, 0.18934868276119232, 0.018473153933882713, 0.002999690594151616, 0.004652327857911587, 0.010374259203672409, 0.0072145056910812855, 0.4168816804885864, 0.0], [0.16539201140403748, 0.05819307267665863, 0.12084146589040756, 0.1738077849149704, 0.004504370968788862, 0.006831282749772072, 0.02180996537208557, 0.012287246994674206, 0.4363327622413635, 0.0], [0.017602156847715378, 0.026805447414517403, 0.07671570032835007, 0.5152483582496643, 0.21202509105205536, 0.041201505810022354, 0.02207496576011181, 0.00952092744410038, 0.07880578190088272, 0.0], [0.017750855535268784, 0.01654047518968582, 0.07482129335403442, 0.23223723471164703, 0.3542158007621765, 0.16141267120838165, 0.02249749004840851, 0.044686269015073776, 0.07583795487880707, 0.0], [0.01755565032362938, 0.008823209442198277, 0.013083218596875668, 0.5061533451080322, 0.02344801276922226, 0.0075739468447864056, 0.07187878340482712, 0.029167035594582558, 0.3223167061805725, 0.0], [0.07037408649921417, 0.05453738570213318, 0.05508268624544144, 0.02769530564546585, 0.038050826638936996, 0.20446287095546722, 0.19980187714099884, 0.19835616648197174, 0.15163862705230713, 0.0], [0.003375247586518526, 0.008793321438133717, 0.001001630094833672, 0.002094975672662258, 0.0032946632709354162, 0.01792662777006626, 0.10988471657037735, 0.21093924343585968, 0.6426896452903748, 0.0]], [[0.07041527330875397, 0.15367093682289124, 0.3963199257850647, 0.03077671490609646, 0.0928598940372467, 0.04086732864379883, 0.018142100423574448, 0.012120239436626434, 0.18482762575149536, 0.0], [0.0367966964840889, 0.022482391446828842, 0.35830163955688477, 0.02875097654759884, 0.03547174483537674, 0.026731541380286217, 0.005365677177906036, 0.038472291082143784, 0.4476269483566284, 0.0], [0.035722482949495316, 0.013633755035698414, 0.09877835214138031, 
0.0896211713552475, 0.16777706146240234, 0.10725134611129761, 0.05053357034921646, 0.10712091624736786, 0.32956135272979736, 0.0], [0.026592494919896126, 0.002020884770900011, 0.010739283636212349, 0.015951883047819138, 0.18538028001785278, 0.16766443848609924, 0.03731367364525795, 0.3853055238723755, 0.16903170943260193, 0.0], [0.0038862666115164757, 0.0004870722477789968, 0.0013956124894320965, 0.001421120367012918, 0.013834443874657154, 0.18579153716564178, 0.18213258683681488, 0.5258954763412476, 0.08515587449073792, 0.0], [0.002724156714975834, 0.004415807779878378, 0.003638928523287177, 0.00862019695341587, 0.010569852776825428, 0.18068262934684753, 0.2256886065006256, 0.3616458773612976, 0.20201392471790314, 0.0], [0.004798010922968388, 0.006960091646760702, 0.005558326840400696, 0.015252271667122841, 0.010058294981718063, 0.16163121163845062, 0.19844789803028107, 0.089196115732193, 0.508097767829895, 0.0], [0.006148195825517178, 0.009931370615959167, 0.0022139709908515215, 0.003481896361336112, 0.00199966412037611, 0.011451391503214836, 0.018514955416321754, 0.10389390587806702, 0.8423647284507751, 0.0], [0.1474374532699585, 0.1116698756814003, 0.10040155798196793, 0.07832593470811844, 0.051569730043411255, 0.103182852268219, 0.0987909808754921, 0.06264416873455048, 0.24597744643688202, 0.0], [0.0029639359563589096, 0.0008504446013830602, 0.002215511864051223, 0.0016108372947201133, 0.001786046545021236, 0.003435377962887287, 0.000923731888178736, 0.009203500114381313, 0.9770104885101318, 0.0]], [[0.022387586534023285, 0.045972827821969986, 0.05835629999637604, 0.22869053483009338, 0.010770916007459164, 0.006216464098542929, 0.018148910254240036, 0.006308646872639656, 0.6031478047370911, 0.0], [0.0035997454542666674, 0.002674269489943981, 0.016009783372282982, 0.05554450675845146, 0.0013587778666988015, 0.0032801039051264524, 0.00560772093012929, 0.00799081102013588, 0.9039342403411865, 0.0], [0.0018151046242564917, 0.001049908110871911, 
0.0005912692868150771, 0.005136367864906788, 0.0005621784366667271, 0.00844560656696558, 0.017937110736966133, 0.008342047221958637, 0.9561205506324768, 0.0], [0.004429707303643227, 0.005516116041690111, 0.003033371875062585, 0.012963998131453991, 0.0034379358403384686, 0.003276604227721691, 0.0140963364392519, 0.005416945554316044, 0.9478288888931274, 0.0], [0.013230006210505962, 0.011804360896348953, 0.009972047992050648, 0.004975683055818081, 0.008386109955608845, 0.18977868556976318, 0.1806434541940689, 0.03204761818051338, 0.549161970615387, 0.0], [0.006766628473997116, 0.008349079638719559, 0.003925195895135403, 0.0006033667596057057, 0.006175691727548838, 0.2236345112323761, 0.03405819088220596, 0.07976362109184265, 0.6367236971855164, 0.0], [0.03860252723097801, 0.01646261475980282, 0.02104821614921093, 0.0021387943997979164, 0.005319601856172085, 0.2400989532470703, 0.03188503161072731, 0.005558657925575972, 0.6388856768608093, 0.0], [0.0738457515835762, 0.018826894462108612, 0.0069308048114180565, 0.0074225678108632565, 0.004789229016751051, 0.046955253928899765, 0.11907684803009033, 0.18744726479053497, 0.5347052812576294, 0.0], [0.09479796141386032, 0.11939200013875961, 0.0752992108464241, 0.061374519020318985, 0.08638977259397507, 0.12459041178226471, 0.16023214161396027, 0.0879756435751915, 0.1899482160806656, 0.0], [0.027537798509001732, 0.06296242028474808, 0.014751194976270199, 0.0011882808757945895, 0.016387099400162697, 0.15830224752426147, 0.03707461059093475, 0.028470970690250397, 0.6533253788948059, 0.0]], [[0.044569190591573715, 0.00917287077754736, 0.004391324240714312, 0.8386606574058533, 0.06130588799715042, 0.003870139131322503, 0.007488539442420006, 0.028126200661063194, 0.002415221417322755, 0.0], [0.08009635657072067, 0.016815535724163055, 0.012093844823539257, 0.1592065542936325, 0.5643750429153442, 0.02920410968363285, 0.0919446051120758, 0.036902546882629395, 0.009361499920487404, 0.0], [0.08603464066982269, 0.01933746039867401, 
0.05900268629193306, 0.2806539237499237, 0.22094620764255524, 0.08643656224012375, 0.026435989886522293, 0.1974046230316162, 0.023747902363538742, 0.0], [0.011109462939202785, 0.008883769623935223, 0.006091873627156019, 0.9400036931037903, 0.020445559173822403, 0.0056496066972613335, 0.0019461432239040732, 0.005268549080938101, 0.000601345207542181, 0.0], [0.12766019999980927, 0.021774157881736755, 0.08726640790700912, 0.0718328207731247, 0.053083695471286774, 0.3031027019023895, 0.06321869790554047, 0.2611844837665558, 0.010876962915062904, 0.0], [0.06407223641872406, 0.11627303808927536, 0.1807759404182434, 0.0054795523174107075, 0.026687098667025566, 0.09637009352445602, 0.052303463220596313, 0.4456423819065094, 0.012396130710840225, 0.0], [0.09725438803434372, 0.15948796272277832, 0.05173082649707794, 0.01153761800378561, 0.0721999853849411, 0.059252724051475525, 0.11923323571681976, 0.05380275845527649, 0.3755004107952118, 0.0], [0.0974307581782341, 0.23218762874603271, 0.06967660784721375, 0.031012043356895447, 0.04906507954001427, 0.31767621636390686, 0.08231117576360703, 0.07159094512462616, 0.04904941841959953, 0.0], [0.16388627886772156, 0.17526376247406006, 0.07081529498100281, 0.17886894941329956, 0.07944575697183609, 0.07640470564365387, 0.0757102444767952, 0.04333823174238205, 0.13626690208911896, 0.0], [0.09265941381454468, 0.11633585393428802, 0.04908691346645355, 0.0062498715706169605, 0.07016508281230927, 0.012818480841815472, 0.0484321266412735, 0.015437646768987179, 0.5888146162033081, 0.0]], [[0.005850312765687704, 0.017421673983335495, 0.004798548296093941, 0.008814580738544464, 0.00403921864926815, 0.015260725282132626, 0.03377071022987366, 0.009620469063520432, 0.9004237651824951, 0.0], [0.013436811044812202, 0.03272867575287819, 0.005969575606286526, 0.02213078737258911, 0.008325905539095402, 0.015314633026719093, 0.027177294716238976, 0.017041552811861038, 0.8578747510910034, 0.0], [0.012305106967687607, 0.03383316844701767, 
0.010593314655125141, 0.027156231924891472, 0.00306991720572114, 0.004844812210649252, 0.018964877352118492, 0.05307865887880325, 0.8361539244651794, 0.0], [0.01255274098366499, 0.03600494936108589, 0.010369472205638885, 0.019021298736333847, 0.0032906190026551485, 0.0037067385856062174, 0.017627976834774017, 0.01037716306746006, 0.8870489597320557, 0.0], [0.0020879805088043213, 0.013872731477022171, 0.0018324662232771516, 0.006437606178224087, 0.013170951046049595, 0.011930068954825401, 0.0030771365854889154, 0.018353432416915894, 0.9292376041412354, 0.0], [0.022865016013383865, 0.0674654096364975, 0.00996339786797762, 0.01914660632610321, 0.014956261031329632, 0.026097828522324562, 0.018910687416791916, 0.06562207639217377, 0.7549726963043213, 0.0], [0.019515078514814377, 0.07521340996026993, 0.03206341341137886, 0.0070005785673856735, 0.0066195218823850155, 0.03877842426300049, 0.01228683814406395, 0.032381508499383926, 0.7761411666870117, 0.0], [0.0022571254521608353, 0.00383052253164351, 0.0012509305961430073, 0.005982697941362858, 0.001252268673852086, 0.0028570422437042, 0.00556317949667573, 0.7337145805358887, 0.24329175055027008, 0.0], [0.132036030292511, 0.1683972030878067, 0.13758207857608795, 0.14189518988132477, 0.03147142380475998, 0.047566916793584824, 0.07834812998771667, 0.12177446484565735, 0.1409287303686142, 0.0], [0.038695693016052246, 0.06063547730445862, 0.020152689889073372, 0.1101006418466568, 0.021127784624695778, 0.02848564088344574, 0.03705665469169617, 0.108894944190979, 0.5748504996299744, 0.0]], [[0.0842718631029129, 0.33930304646492004, 0.1421334594488144, 0.18528752028942108, 0.05815916135907173, 0.022830937057733536, 0.01860896497964859, 0.009871570393443108, 0.13953347504138947, 0.0], [0.1295168399810791, 0.08884051442146301, 0.06592670828104019, 0.2686370015144348, 0.02522267960011959, 0.03633918985724449, 0.021549394354224205, 0.051057688891887665, 0.31290990114212036, 0.0], [0.20177747309207916, 0.039902716875076294, 
0.053595658391714096, 0.09988140314817429, 0.01657777465879917, 0.07154539972543716, 0.024320384487509727, 0.12353017926216125, 0.36886897683143616, 0.0], [0.20979411900043488, 0.30265647172927856, 0.20804975926876068, 0.007371237967163324, 0.0033807901199907064, 0.02442527562379837, 0.017248263582587242, 0.022337088361382484, 0.2047368586063385, 0.0], [0.20739784836769104, 0.05559583380818367, 0.12535981833934784, 0.009768493473529816, 0.015522800385951996, 0.024528708308935165, 0.03864477947354317, 0.08712086826562881, 0.4360608160495758, 0.0], [0.054182324558496475, 0.0038395673036575317, 0.019914912059903145, 0.014234894886612892, 0.012772555463016033, 0.019022708758711815, 0.04023807495832443, 0.36886361241340637, 0.46693122386932373, 0.0], [0.10943998396396637, 0.007069440558552742, 0.019821595400571823, 0.012627309188246727, 0.016869045794010162, 0.05302179232239723, 0.05124732851982117, 0.14304620027542114, 0.5868573188781738, 0.0], [0.23190192878246307, 0.060554418712854385, 0.05880776792764664, 0.00438718032091856, 0.00454165181145072, 0.15464532375335693, 0.09585105627775192, 0.02281157113611698, 0.3664989471435547, 0.0], [0.13208958506584167, 0.15704363584518433, 0.07176639884710312, 0.08554346114397049, 0.0733223557472229, 0.0956358015537262, 0.07472448796033859, 0.09573546797037125, 0.21413877606391907, 0.0], [0.28021925687789917, 0.045415956526994705, 0.04812552034854889, 0.00880114920437336, 0.012029618956148624, 0.04001859948039055, 0.0577121265232563, 0.02487611398100853, 0.4828015863895416, 0.0]], [[0.057871319353580475, 0.09845025092363358, 0.03600643575191498, 0.06401734054088593, 0.07263048738241196, 0.014885936863720417, 0.07473781704902649, 0.1193607747554779, 0.4620397090911865, 0.0], [0.04736582189798355, 0.06951819360256195, 0.039210546761751175, 0.040616557002067566, 0.05645532160997391, 0.01900673843920231, 0.063181072473526, 0.23291724920272827, 0.43172842264175415, 0.0], [0.05375257506966591, 0.04091374948620796, 0.01263821218162775, 
0.04125160351395607, 0.014244006015360355, 0.012229752726852894, 0.029117466881871223, 0.07314542680978775, 0.7227071523666382, 0.0], [0.04062311723828316, 0.2312910407781601, 0.20085060596466064, 0.03848586603999138, 0.04763459786772728, 0.013425372540950775, 0.027237186208367348, 0.03882591798901558, 0.3616262376308441, 0.0], [0.04558584839105606, 0.04037311673164368, 0.043737076222896576, 0.02740027941763401, 0.005366531666368246, 0.014126299880445004, 0.07268305867910385, 0.014923120848834515, 0.7358046770095825, 0.0], [0.04308110475540161, 0.037618398666381836, 0.054927192628383636, 0.045146394520998, 0.02157701551914215, 0.014024189673364162, 0.03546718508005142, 0.04130468890070915, 0.7068538665771484, 0.0], [0.049693867564201355, 0.040712278336286545, 0.011129319667816162, 0.08677691221237183, 0.24132831394672394, 0.028864668682217598, 0.04710082337260246, 0.028962818905711174, 0.46543097496032715, 0.0], [0.0256647989153862, 0.026453843340277672, 0.1064542606472969, 0.07867259532213211, 0.03285365179181099, 0.056291256099939346, 0.026517342776060104, 0.014768523164093494, 0.6323237419128418, 0.0], [0.07950135320425034, 0.04906205087900162, 0.09099037200212479, 0.10450085997581482, 0.06846266984939575, 0.21755923330783844, 0.14818403124809265, 0.14456483721733093, 0.0971745029091835, 0.0], [0.0029191188514232635, 0.0026039828080683947, 0.000987313687801361, 0.027727283537387848, 0.007311245426535606, 0.0033244043588638306, 0.023969389498233795, 0.00596341909840703, 0.9251939058303833, 0.0]], [[0.3116825819015503, 0.20666195452213287, 0.1363646388053894, 0.07141851633787155, 0.029045483097434044, 0.04730900749564171, 0.037391580641269684, 0.10128220915794373, 0.05884409323334694, 0.0], [0.16883184015750885, 0.1785622239112854, 0.23318006098270416, 0.05258537083864212, 0.10740725696086884, 0.09185276180505753, 0.022670285776257515, 0.08943870663642883, 0.05547139048576355, 0.0], [0.13844002783298492, 0.030681077390909195, 0.12107612937688828, 
0.007168712094426155, 0.05214103311300278, 0.10045275092124939, 0.006991118658334017, 0.2955043315887451, 0.2475447952747345, 0.0], [0.17306901514530182, 0.19963578879833221, 0.148520827293396, 0.2046978771686554, 0.06720028817653656, 0.019652126356959343, 0.05792365223169327, 0.07665500044822693, 0.05264541134238243, 0.0], [0.15115797519683838, 0.06555280834436417, 0.02683498151600361, 0.20794028043746948, 0.17434173822402954, 0.11980342864990234, 0.04239796847105026, 0.03961418569087982, 0.1723567098379135, 0.0], [0.022970011457800865, 0.003322059055790305, 0.03333919122815132, 0.03161727264523506, 0.3007935583591461, 0.20675496757030487, 0.037206847220659256, 0.30415406823158264, 0.059842076152563095, 0.0], [0.029802093282341957, 0.006723356898874044, 0.00844306219369173, 0.09911961853504181, 0.2257867157459259, 0.22737178206443787, 0.28318148851394653, 0.05687837302684784, 0.0626935064792633, 0.0], [0.008192392997443676, 0.002076511736959219, 0.010627061128616333, 0.01573592983186245, 0.01893553137779236, 0.042316026985645294, 0.02403445728123188, 0.868257999420166, 0.009823988191783428, 0.0], [0.08033955097198486, 0.18780502676963806, 0.02817567251622677, 0.041370414197444916, 0.02824225462973118, 0.038844767957925797, 0.11732563376426697, 0.06162348762154579, 0.4162730574607849, 0.0], [0.0563269779086113, 0.007248359732329845, 0.008029782213270664, 0.003040844574570656, 0.007221699226647615, 0.01730128563940525, 0.050128430128097534, 0.3587413430213928, 0.491961270570755, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\"], \"bot_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\"]}, \"inp_out\": {\"att\": [[[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9198169708251953, 0.0801829993724823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8846490979194641, 0.10308036208152771, 0.012270578183233738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9307316541671753, 
0.03309628367424011, 0.027538668364286423, 0.008633385412395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9335180521011353, 0.020782457664608955, 0.008113296702504158, 0.029529055580496788, 0.008057110011577606, 0.0, 0.0, 0.0, 0.0, 0.0], [0.923790454864502, 0.01269624661654234, 0.004588128533214331, 0.020286502316594124, 0.018672045320272446, 0.019966628402471542, 0.0, 0.0, 0.0, 0.0], [0.5214514136314392, 0.051599469035863876, 0.007387364283204079, 0.04305899888277054, 0.0632161945104599, 0.07775087654590607, 0.2355356514453888, 0.0, 0.0, 0.0], [0.9122877717018127, 0.007671441417187452, 0.0012418286642059684, 0.005250561982393265, 0.001960531808435917, 0.032091617584228516, 0.03012256510555744, 0.009373520500957966, 0.0, 0.0], [0.012450892478227615, 0.0001350480888504535, 0.0001820741599658504, 0.0018266986589878798, 0.00022605709091294557, 0.0032795630395412445, 0.005876350682228804, 0.012136856094002724, 0.9638864398002625, 0.0], [0.907938539981842, 0.003707215888425708, 0.003004483412951231, 0.0008324749651364982, 0.0015859504928812385, 0.008079104125499725, 0.010460118763148785, 0.005838368553668261, 0.038938846439123154, 0.019614921882748604]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4050312936306, 0.5949686765670776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2333158701658249, 0.39531010389328003, 0.37137407064437866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.52278733253479, 0.11893566697835922, 0.28584957122802734, 0.07242746651172638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23179638385772705, 0.09258762001991272, 0.103512242436409, 0.19472002983093262, 0.37738385796546936, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3839746117591858, 0.05338669568300247, 0.09416119009256363, 0.09689370542764664, 0.24871283769607544, 0.12287086993455887, 0.0, 0.0, 0.0, 0.0], [0.5838866233825684, 0.02439245954155922, 0.042716383934020996, 0.03342103213071823, 0.08018141984939575, 0.15234005451202393, 0.08306187391281128, 0.0, 0.0, 0.0], [0.639571487903595, 0.016348807141184807, 
0.038869310170412064, 0.02800355665385723, 0.0377902127802372, 0.0529697984457016, 0.07620508968830109, 0.11024164408445358, 0.0, 0.0], [0.5836893320083618, 0.011862898245453835, 0.02550557814538479, 0.009363977238535881, 0.0196645837277174, 0.018125057220458984, 0.07040998339653015, 0.2077602595090866, 0.053618304431438446, 0.0], [0.49946048855781555, 0.04904361814260483, 0.04135226085782051, 0.015084759332239628, 0.018269173800945282, 0.020069265738129616, 0.05080949887633324, 0.09452320635318756, 0.06869905441999435, 0.14268863201141357]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9956012964248657, 0.00439875153824687, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8920916318893433, 0.017498359084129333, 0.09041006118059158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8103601336479187, 0.011479738168418407, 0.14884205162525177, 0.029318034648895264, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9073429107666016, 0.017702236771583557, 0.0008831396116875112, 0.017153160646557808, 0.05691858008503914, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7007134556770325, 0.00013011474220547825, 0.0017889889422804117, 0.00429273396730423, 0.20973503589630127, 0.08333952724933624, 0.0, 0.0, 0.0, 0.0], [0.8020992279052734, 0.0005838978104293346, 0.0002877263759728521, 0.000665249943267554, 0.00924165453761816, 0.10947777330875397, 0.07764454185962677, 0.0, 0.0, 0.0], [0.936653733253479, 0.00026242269086651504, 0.0004762547614518553, 0.000683068297803402, 0.0005867508007213473, 0.008624686859548092, 0.044821251183748245, 0.00789186917245388, 0.0, 0.0], [0.638530433177948, 0.00012756754586007446, 2.6267471184837632e-05, 0.035790614783763885, 0.00038457714254036546, 0.0026843701489269733, 0.0740678533911705, 0.21536435186862946, 0.03302408382296562, 0.0], [0.9069857597351074, 0.0010905838571488857, 0.0003166680980939418, 0.0021527763456106186, 0.00019805191550403833, 0.0004849489778280258, 0.025774035602808, 0.02642407827079296, 0.01662513054907322, 0.01994791068136692]], [[1.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0], [0.9964158535003662, 0.0035840808413922787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.603236198425293, 0.29069802165031433, 0.10606581717729568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7401933073997498, 0.005742713809013367, 0.18690980970859528, 0.06715414673089981, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9087624549865723, 0.0078224902972579, 0.003505129599943757, 0.0673881471157074, 0.012521738186478615, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7394620180130005, 0.0234938096255064, 0.009907918982207775, 0.01616108976304531, 0.1237591803073883, 0.08721596747636795, 0.0, 0.0, 0.0, 0.0], [0.9526587724685669, 0.007287254091352224, 0.0013716809917241335, 0.0023222684394568205, 0.007607423700392246, 0.009167732670903206, 0.01958492584526539, 0.0, 0.0, 0.0], [0.9270981550216675, 0.004809631034731865, 0.0030887839384377003, 0.005205564666539431, 0.018441975116729736, 0.006030889227986336, 0.03003735840320587, 0.0052877976559102535, 0.0, 0.0], [0.603268563747406, 0.009098237380385399, 0.00021995518181938678, 0.07179546356201172, 0.0017328117974102497, 0.01055157370865345, 0.020978767424821854, 0.2736198902130127, 0.008734744042158127, 0.0], [0.6497007608413696, 0.0906025841832161, 0.0100435521453619, 0.007925360463559628, 0.013416239991784096, 0.0018666544929146767, 0.02140365168452263, 0.08128199726343155, 0.04188578948378563, 0.08187359571456909]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9857779741287231, 0.014221975579857826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9197340607643127, 0.07413885742425919, 0.0061270855367183685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8673564195632935, 0.016403868794441223, 0.1017053872346878, 0.014534366317093372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044595908373594284, 0.010755550116300583, 0.002565854461863637, 0.9345642328262329, 0.007518457714468241, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4605148434638977, 0.007289387751370668, 0.009601963683962822, 0.08598940074443817, 0.4091304838657379, 0.027473902329802513, 0.0, 
0.0, 0.0, 0.0], [0.8714936971664429, 0.002528996206820011, 0.0021269593853503466, 0.0052809687331318855, 0.02593054249882698, 0.07010670751333237, 0.022532090544700623, 0.0, 0.0, 0.0], [0.507957398891449, 0.003823956474661827, 0.004157013725489378, 0.018131878226995468, 0.06916838884353638, 0.047881923615932465, 0.2798653542995453, 0.06901402771472931, 0.0, 0.0], [0.4575899839401245, 0.005646431352943182, 0.0004441867640707642, 0.03129462152719498, 0.014414624311029911, 0.0058625745587050915, 0.09207130968570709, 0.34311652183532715, 0.04955975338816643, 0.0], [0.8105311393737793, 0.0010255038505420089, 0.0001402802881784737, 0.0005781117943115532, 0.00122542935423553, 0.000594198820181191, 0.02804729714989662, 0.01081023644655943, 0.13665232062339783, 0.010395429097115993]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8512031435966492, 0.14879685640335083, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10041537135839462, 0.8953256011009216, 0.0042589944787323475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6295948624610901, 0.2121732085943222, 0.10306572169065475, 0.055166181176900864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9503376483917236, 0.007425909396260977, 0.0019253676291555166, 0.025024304166436195, 0.015286784619092941, 0.0, 0.0, 0.0, 0.0, 0.0], [0.24298420548439026, 0.06981680542230606, 0.030552756041288376, 0.020666545256972313, 0.46177101135253906, 0.1742086559534073, 0.0, 0.0, 0.0, 0.0], [0.8132306933403015, 0.003601218806579709, 0.01019350253045559, 0.009439423680305481, 0.040081463754177094, 0.07570415735244751, 0.04774952307343483, 0.0, 0.0, 0.0], [0.6454712152481079, 0.006356438156217337, 0.006696825381368399, 0.0020169378258287907, 0.11416922509670258, 0.11139311641454697, 0.07912010699510574, 0.03477614000439644, 0.0, 0.0], [0.22032444179058075, 0.0006508066435344517, 0.006827942095696926, 0.028858821839094162, 0.0022757677361369133, 0.006474251858890057, 0.09447979182004929, 0.6212162375450134, 0.018891895189881325, 0.0], [0.03250038996338844, 
0.0005526043241843581, 2.807211239996832e-05, 0.00014761221245862544, 0.00482193985953927, 7.781770545989275e-05, 0.00014718669990543276, 0.0008632297394797206, 0.959712028503418, 0.0011490467004477978]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9700191020965576, 0.029980869963765144, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7072298526763916, 0.2173422873020172, 0.07542789727449417, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5017270445823669, 0.10517530888319016, 0.32087045907974243, 0.07222715020179749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39005738496780396, 0.2261916995048523, 0.1838584840297699, 0.10916081070899963, 0.09073163568973541, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11122927069664001, 0.04386316239833832, 0.023478534072637558, 0.07375308126211166, 0.5692906379699707, 0.17838534712791443, 0.0, 0.0, 0.0, 0.0], [0.16762810945510864, 0.030268238857388496, 0.015392551198601723, 0.05242612585425377, 0.21519990265369415, 0.34948840737342834, 0.16959665715694427, 0.0, 0.0, 0.0], [0.15348000824451447, 0.03554287180304527, 0.008979924954473972, 0.07115276902914047, 0.08698276430368423, 0.24143245816230774, 0.28553345799446106, 0.11689584702253342, 0.0, 0.0], [0.09456975758075714, 0.010759694501757622, 0.0067994119599461555, 0.01042863354086876, 0.05627141892910004, 0.11228546500205994, 0.14361944794654846, 0.3204572796821594, 0.2448090761899948, 0.0], [0.057867951691150665, 0.02229062095284462, 0.016399098560214043, 0.02521427348256111, 0.047808028757572174, 0.03428687900304794, 0.05170976370573044, 0.19979508221149445, 0.41991233825683594, 0.12471600621938705]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9535994529724121, 0.04640045389533043, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8665578961372375, 0.09402694553136826, 0.03941517323255539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8201385140419006, 0.07587680220603943, 0.05075912922620773, 0.053225547075271606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6245242953300476, 0.093341164290905, 
0.11281723529100418, 0.1092497780919075, 0.06006752699613571, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5755861401557922, 0.0864969864487648, 0.10001320391893387, 0.12654373049736023, 0.06871193647384644, 0.04264802858233452, 0.0, 0.0, 0.0, 0.0], [0.6500274538993835, 0.06470640748739243, 0.047299426048994064, 0.08855419605970383, 0.06197808310389519, 0.04487667977809906, 0.04255769029259682, 0.0, 0.0, 0.0], [0.5771223902702332, 0.0491044707596302, 0.09411156177520752, 0.06903567165136337, 0.04109871760010719, 0.06523709744215012, 0.06637011468410492, 0.03792000934481621, 0.0, 0.0], [0.4695849120616913, 0.017787985503673553, 0.06290572881698608, 0.06516575813293457, 0.09894091635942459, 0.03647425398230553, 0.051347069442272186, 0.08907806128263474, 0.10871540009975433, 0.0], [0.18501408398151398, 0.040740884840488434, 0.10466982424259186, 0.07660976052284241, 0.17033715546131134, 0.05819392204284668, 0.0898737907409668, 0.09184892475605011, 0.10470453649759293, 0.0780070349574089]]], [[[0.10875418037176132, 0.15107707679271698, 0.07560893893241882, 0.11182637512683868, 0.051575273275375366, 0.1800614595413208, 0.13901139795780182, 0.11257244646549225, 0.06951297074556351, 0.0], [0.04530828073620796, 0.11530135571956635, 0.03132164478302002, 0.12301183491945267, 0.01339547149837017, 0.009322633035480976, 0.0069213854148983955, 0.181557297706604, 0.47386014461517334, 0.0], [0.08671615272760391, 0.21926835179328918, 0.11249969899654388, 0.05250205472111702, 0.044286634773015976, 0.006910341326147318, 0.004434189759194851, 0.00961831770837307, 0.4637643098831177, 0.0], [0.016148164868354797, 0.08668603748083115, 0.1414848268032074, 0.024200299754738808, 0.018711188808083534, 0.02537006139755249, 0.017450006678700447, 0.039331331849098206, 0.6306182146072388, 0.0], [0.024489276111125946, 0.03301851078867912, 0.03003605268895626, 0.03562680631875992, 0.06981870532035828, 0.022592445835471153, 0.025447512045502663, 0.03545365110039711, 0.7235170006752014, 0.0], [0.05760658532381058, 
0.08793947100639343, 0.053903114050626755, 0.0679689273238182, 0.007038408424705267, 0.007889931090176105, 0.010035911574959755, 0.019540006294846535, 0.6880777478218079, 0.0], [0.045610494911670685, 0.042210742831230164, 0.14248158037662506, 0.03233090415596962, 0.03048519603908062, 0.011738738045096397, 0.014284060336649418, 0.006383211817592382, 0.6744750738143921, 0.0], [0.096277616918087, 0.030696624889969826, 0.10220203548669815, 0.04915016517043114, 0.047845132648944855, 0.05814794450998306, 0.06954183429479599, 0.028650736436247826, 0.5174878835678101, 0.0], [0.009306053631007671, 0.02153283730149269, 0.009718294255435467, 0.005953253246843815, 0.011703923344612122, 0.017902903258800507, 0.011090915650129318, 0.01645584963262081, 0.8963360786437988, 0.0], [0.009895006194710732, 0.026821313425898552, 0.16079027950763702, 0.01761648990213871, 0.01726638339459896, 0.08361288905143738, 0.039622098207473755, 0.14411716163158417, 0.5002583861351013, 0.0]], [[0.0543275885283947, 0.01742306910455227, 0.05347726121544838, 0.18824619054794312, 0.09003543108701706, 0.08433128148317337, 0.1953076422214508, 0.206686869263649, 0.11016455292701721, 0.0], [0.00859006680548191, 0.02184058353304863, 0.02418440766632557, 0.03131486475467682, 0.03273439407348633, 0.06774082779884338, 0.1731010377407074, 0.09275981038808823, 0.5477339029312134, 0.0], [0.02145911566913128, 0.046526145190000534, 0.014734850265085697, 0.026213468983769417, 0.04904777929186821, 0.08567024767398834, 0.13810616731643677, 0.03392839804291725, 0.5843138694763184, 0.0], [0.019245177507400513, 0.01515401341021061, 0.027409562841057777, 0.0068243746645748615, 0.07997982203960419, 0.0921224057674408, 0.04510754346847534, 0.04373685643076897, 0.670420229434967, 0.0], [0.04381020739674568, 0.06711422652006149, 0.07609888166189194, 0.021496189758181572, 0.05042967572808266, 0.15614424645900726, 0.11071597784757614, 0.14296749234199524, 0.3312230408191681, 0.0], [0.04100082442164421, 0.030313873663544655, 
0.032653506845235825, 0.0695231482386589, 0.12672685086727142, 0.12515434622764587, 0.08855390548706055, 0.05835743993520737, 0.4277162253856659, 0.0], [0.14112897217273712, 0.06592341512441635, 0.06986766308546066, 0.06311382353305817, 0.12678426504135132, 0.04950721934437752, 0.08025017380714417, 0.03467738255858421, 0.36874714493751526, 0.0], [0.02841436117887497, 0.022568009793758392, 0.014519155025482178, 0.019271234050393105, 0.018120555207133293, 0.036434635519981384, 0.014109926298260689, 0.24622198939323425, 0.6003400683403015, 0.0], [0.05730762332677841, 0.07724729180335999, 0.030861826613545418, 0.04063780978322029, 0.08539344370365143, 0.029541905969381332, 0.02964094467461109, 0.028206804767251015, 0.6211622953414917, 0.0], [0.20915710926055908, 0.193747878074646, 0.11181499063968658, 0.07680925726890564, 0.04479793831706047, 0.03787367418408394, 0.04819086939096451, 0.11330965161323547, 0.1642986238002777, 0.0]], [[0.038908280432224274, 0.07760688662528992, 0.062413811683654785, 0.0023113787174224854, 0.0021746077109128237, 0.015095214359462261, 0.003646473865956068, 0.038165315985679626, 0.759678065776825, 0.0], [0.015742339193820953, 0.029524141922593117, 0.0550379604101181, 0.16926467418670654, 0.035933610051870346, 0.03279981389641762, 0.03188418969511986, 0.5383173227310181, 0.09149592369794846, 0.0], [0.022741766646504402, 0.013864121399819851, 0.06161126494407654, 0.06985131651163101, 0.03954875469207764, 0.02864447981119156, 0.036658816039562225, 0.05774570629000664, 0.6693336963653564, 0.0], [0.06077639013528824, 0.053226571530103683, 0.05544588342308998, 0.08368532359600067, 0.04779139161109924, 0.028960514813661575, 0.03463221713900566, 0.42419588565826416, 0.21128588914871216, 0.0], [0.03320460394024849, 0.07872876524925232, 0.0791814923286438, 0.008506255224347115, 0.010383618995547295, 0.021636927500367165, 0.009444555267691612, 0.026183925569057465, 0.7327298521995544, 0.0], [0.14095324277877808, 0.17195045948028564, 
0.04960065335035324, 0.02801741287112236, 0.02789357118308544, 0.0246508177369833, 0.027228642255067825, 0.008449538610875607, 0.521255612373352, 0.0], [0.01678302139043808, 0.02193976752460003, 0.13912786543369293, 0.05168221518397331, 0.06239692494273186, 0.008615943603217602, 0.037501659244298935, 0.02482585795223713, 0.6371266841888428, 0.0], [0.03396642208099365, 0.07778684049844742, 0.18657010793685913, 0.11281172931194305, 0.019890569150447845, 0.012303605675697327, 0.0494060292840004, 0.11448060721158981, 0.39278414845466614, 0.0], [0.02684134803712368, 0.03310805931687355, 0.163743257522583, 0.014529252424836159, 0.10077258199453354, 0.044357266277074814, 0.04152251034975052, 0.10173188894987106, 0.4733937382698059, 0.0], [0.01862592063844204, 0.022009190171957016, 0.028925148770213127, 0.006837732624262571, 0.006956242956221104, 0.010202805511653423, 0.015325144864618778, 0.11640346795320511, 0.7747144103050232, 0.0]], [[0.0830092504620552, 0.0839436799287796, 0.10106679797172546, 0.11154499650001526, 0.045070260763168335, 0.1284436285495758, 0.1161414161324501, 0.19574469327926636, 0.1350351870059967, 0.0], [0.0006529411766678095, 0.0018492193194106221, 0.018439743667840958, 0.004895282443612814, 0.0036929987836629152, 0.05041775107383728, 0.03271673619747162, 0.4425412714481354, 0.4447941780090332, 0.0], [0.015919672325253487, 0.02172437310218811, 0.013682822696864605, 0.028371846303343773, 0.017258556559681892, 0.014516759663820267, 0.033475372940301895, 0.45419326424598694, 0.40085726976394653, 0.0], [0.006064589135348797, 0.006147248670458794, 0.06902536749839783, 0.011021673679351807, 0.0062199062667787075, 0.17622654139995575, 0.00982236210256815, 0.46262383460998535, 0.25284844636917114, 0.0], [0.018328940495848656, 0.034908927977085114, 0.027539005503058434, 0.04494883120059967, 0.03695090860128403, 0.18224696815013885, 0.04204700142145157, 0.09570277482271194, 0.5173265337944031, 0.0], [0.06838149577379227, 0.025893883779644966, 
0.06412170827388763, 0.11039282381534576, 0.12848982214927673, 0.09953469038009644, 0.09056522697210312, 0.12723064422607422, 0.28538966178894043, 0.0], [0.07893572002649307, 0.0734885111451149, 0.06503137946128845, 0.04291535168886185, 0.08502060174942017, 0.04846649244427681, 0.07035838067531586, 0.14812934398651123, 0.38765427470207214, 0.0], [0.007445929106324911, 0.004103729501366615, 0.05411284416913986, 0.006074799690395594, 0.07146289199590683, 0.5494692921638489, 0.05009504780173302, 0.058794084936380386, 0.1984413117170334, 0.0], [0.0037151367869228125, 0.005083263851702213, 0.02171880006790161, 0.01245985459536314, 0.012914983555674553, 0.14437292516231537, 0.026943473145365715, 0.17420484125614166, 0.5985866785049438, 0.0], [0.02579679898917675, 0.0645768865942955, 0.03225725144147873, 0.044467855244874954, 0.04297630116343498, 0.06060377135872841, 0.030930038541555405, 0.03278812766075134, 0.6656030416488647, 0.0]], [[0.13460709154605865, 0.15298102796077728, 0.06546170264482498, 0.14220191538333893, 0.11837887763977051, 0.09888823330402374, 0.10630416870117188, 0.08867054432630539, 0.09250646829605103, 0.0], [0.9316296577453613, 0.016095036640763283, 0.0020372711587697268, 0.0019596514757722616, 2.8437656510504894e-05, 6.708989531034604e-05, 0.0004955903859809041, 3.0113247703411616e-05, 0.047657083719968796, 0.0], [0.043201129883527756, 0.9419298768043518, 0.0003410913050174713, 0.003313146298751235, 7.506452675443143e-06, 1.9570916265365668e-05, 2.5470235414104536e-05, 2.1080213628010824e-05, 0.011141069233417511, 0.0], [3.7581870856229216e-05, 0.00022979748609941453, 0.9982534646987915, 8.70372386998497e-05, 5.87535805607331e-06, 2.5239218302886002e-05, 6.597588708245894e-06, 2.193619138779468e-06, 0.001352491439320147, 0.0], [0.0019612079486250877, 0.011641290038824081, 0.010358362458646297, 0.8346317410469055, 0.00641160923987627, 0.0007435380248352885, 0.0018172020791098475, 7.255822129081935e-05, 0.1323624849319458, 0.0], 
[4.077299308846705e-05, 0.00016088274423964322, 3.1180113637674367e-06, 5.9685276937671006e-05, 6.661444786004722e-06, 0.0006764131248928607, 5.4107837058836594e-05, 0.9797272086143494, 0.01927126571536064, 0.0], [2.7792530090664513e-06, 1.1777839063142892e-05, 1.0386434951215051e-05, 0.0006807934259995818, 0.00028749846387654543, 0.9563493728637695, 2.4335316993528977e-05, 0.001297356327995658, 0.041335828602313995, 0.0], [0.00033864984288811684, 0.00016234541544690728, 0.00011107163300039247, 7.639558316441253e-05, 9.851753566181287e-05, 0.00046863980242051184, 0.9855522513389587, 0.00012009339843643829, 0.013071970082819462, 0.0], [0.001446103909984231, 0.0026176422834396362, 0.0005430445889942348, 0.5833504796028137, 0.08298782259225845, 0.01277364045381546, 0.008405186235904694, 0.028461067005991936, 0.2794148921966553, 0.0], [8.301706202473724e-07, 1.612889263924444e-06, 3.859615389956161e-06, 0.0015496612759307027, 0.9884966611862183, 0.0003321043332107365, 1.1829011782538146e-05, 3.7258676002238644e-06, 0.00959983840584755, 0.0]], [[0.03624086081981659, 0.008591840974986553, 0.01890810765326023, 0.010947922244668007, 0.5211313366889954, 0.04890615865588188, 0.13394898176193237, 0.08554741740226746, 0.13577744364738464, 0.0], [0.09101090580224991, 0.15663929283618927, 0.2008313536643982, 0.13744188845157623, 0.16349081695079803, 0.01479706447571516, 0.04576689749956131, 0.05515507981181145, 0.1348666250705719, 0.0], [0.10898119956254959, 0.19741322100162506, 0.12774543464183807, 0.07097428292036057, 0.033309608697891235, 0.016726871952414513, 0.019306309521198273, 0.09155051410198212, 0.3339925706386566, 0.0], [0.051247891038656235, 0.06952031701803207, 0.3243081271648407, 0.04820195212960243, 0.05462171137332916, 0.04280935227870941, 0.03801479935646057, 0.07710513472557068, 0.2941707372665405, 0.0], [0.22540897130966187, 0.04426601901650429, 0.13483746349811554, 0.09052211791276932, 0.036632657051086426, 0.06078784167766571, 0.09962243586778641, 
0.04597063735127449, 0.2619517743587494, 0.0], [0.08315062522888184, 0.10649015009403229, 0.15254046022891998, 0.0728936716914177, 0.10388997197151184, 0.04998103529214859, 0.0675109326839447, 0.17524446547031403, 0.18829864263534546, 0.0], [0.09407053142786026, 0.04335644096136093, 0.04757237061858177, 0.023308007046580315, 0.14141318202018738, 0.017728488892316818, 0.02331509254872799, 0.07266414165496826, 0.5365718007087708, 0.0], [0.08477651327848434, 0.026448125019669533, 0.013684368692338467, 0.1331702470779419, 0.16824185848236084, 0.007634431589394808, 0.025501158088445663, 0.035930439829826355, 0.5046128630638123, 0.0], [0.03296202793717384, 0.01823815330862999, 0.025750160217285156, 0.08325016498565674, 0.1596710979938507, 0.010502922348678112, 0.01792057603597641, 0.05097610503435135, 0.6007286906242371, 0.0], [0.04370357468724251, 0.02250431850552559, 0.016271278262138367, 0.019842427223920822, 0.12028838694095612, 0.03933797404170036, 0.043740611523389816, 0.08045370131731033, 0.6138576865196228, 0.0]], [[0.1783323585987091, 0.3813028037548065, 0.2072289139032364, 0.06766574084758759, 0.053963109850883484, 0.030795719474554062, 0.023536406457424164, 0.03921645134687424, 0.01795845478773117, 0.0], [0.8837893009185791, 0.07202983647584915, 0.03646722435951233, 0.0004511935112532228, 0.0007272462244145572, 0.0008432198665104806, 0.0031319037079811096, 0.0004143840924371034, 0.0021455709356814623, 0.0], [0.3973897695541382, 0.14911939203739166, 0.3486334979534149, 0.012645252980291843, 0.00675938231870532, 0.00483374297618866, 0.010028100572526455, 0.012036854401230812, 0.058554183691740036, 0.0], [0.005409032106399536, 0.005906772334128618, 0.13379110395908356, 0.15247586369514465, 0.06559418141841888, 0.15356750786304474, 0.04085409641265869, 0.029147597029805183, 0.41325387358665466, 0.0], [0.0013326199259608984, 0.0014979635598137975, 0.011986319907009602, 0.7730216383934021, 0.06901827454566956, 0.05895080044865608, 0.016383536159992218, 
0.015771687030792236, 0.052037257701158524, 0.0], [0.0012038598069921136, 0.0033955213148146868, 0.025528373196721077, 0.03136582672595978, 0.10901585966348648, 0.3851255178451538, 0.0182026457041502, 0.13982580602169037, 0.2863365411758423, 0.0], [0.008065885864198208, 0.004362722393125296, 0.06363680213689804, 0.023311397060751915, 0.06106392294168472, 0.1357712298631668, 0.03965916484594345, 0.06073852628469467, 0.6033903956413269, 0.0], [0.0003142715140711516, 0.0005578870768658817, 0.0015481057344004512, 0.0887022390961647, 0.06383900344371796, 0.2639910578727722, 0.049384135752916336, 0.12241825461387634, 0.40924492478370667, 0.0], [0.0003916181158274412, 0.0003099135938100517, 0.0024421222042292356, 0.016801349818706512, 0.18835966289043427, 0.025843605399131775, 0.08458039909601212, 0.20884136855602264, 0.4724300503730774, 0.0], [5.865378989255987e-05, 7.253760122694075e-05, 0.0007906460668891668, 0.025103986263275146, 0.0753612071275711, 0.04038592055439949, 0.011871143244206905, 0.05808362737298012, 0.7882723212242126, 0.0]], [[0.01597539149224758, 0.027860743924975395, 0.08824922889471054, 0.011547067202627659, 0.02896539680659771, 0.03845160827040672, 0.011409634724259377, 0.043791815638542175, 0.7337491512298584, 0.0], [0.0371943861246109, 0.014876782894134521, 0.02253115549683571, 0.10164438933134079, 0.029471710324287415, 0.040005166083574295, 0.020577073097229004, 0.07326765358448029, 0.6604316830635071, 0.0], [0.06676606088876724, 0.1320837438106537, 0.02368331328034401, 0.09289334714412689, 0.06407851725816727, 0.007657648529857397, 0.014540987089276314, 0.018603011965751648, 0.5796933174133301, 0.0], [0.029496638104319572, 0.013616771437227726, 0.030488401651382446, 0.021259615197777748, 0.13049498200416565, 0.06418323516845703, 0.050123173743486404, 0.1609034240245819, 0.4994336664676666, 0.0], [0.010230573825538158, 0.015954630449414253, 0.007779641076922417, 0.018425902351737022, 0.021085364744067192, 0.0588817335665226, 0.013979516923427582, 
0.0252523310482502, 0.828410267829895, 0.0], [0.02648993395268917, 0.0214377511292696, 0.03494586795568466, 0.05471349507570267, 0.09140968322753906, 0.04952282831072807, 0.05564551055431366, 0.11169540882110596, 0.5541394948959351, 0.0], [0.03231878578662872, 0.018621357157826424, 0.05183127149939537, 0.03979233279824257, 0.13804322481155396, 0.03567919135093689, 0.047386858612298965, 0.13114488124847412, 0.505182147026062, 0.0], [0.04592716693878174, 0.010993612930178642, 0.01772226020693779, 0.05332585424184799, 0.15264220535755157, 0.22139224410057068, 0.048004403710365295, 0.12396018952131271, 0.3260320723056793, 0.0], [0.03168570622801781, 0.026294516399502754, 0.025469979271292686, 0.03026771917939186, 0.058515094220638275, 0.13361068069934845, 0.026259208098053932, 0.0612059161067009, 0.6066910624504089, 0.0], [0.07492455840110779, 0.06428299844264984, 0.07022737711668015, 0.0507473424077034, 0.0447908453643322, 0.060839906334877014, 0.14463475346565247, 0.054812539368867874, 0.4347396492958069, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9642227292060852, 0.035777393728494644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9523521065711975, 0.027811188250780106, 0.019836684688925743, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.849480152130127, 0.03536543622612953, 0.019422976300120354, 0.09573143720626831, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.741925060749054, 0.05566684901714325, 0.024736514315009117, 0.08595114946365356, 0.09172046929597855, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6503966450691223, 0.0582728385925293, 0.0236701387912035, 0.0691222995519638, 0.0758395791053772, 0.12269847840070724, 0.0, 0.0, 0.0, 0.0], [0.4914315342903137, 0.11739180237054825, 0.02309434488415718, 0.07889512181282043, 0.05101678892970085, 0.12367808818817139, 0.11449223756790161, 0.0, 0.0, 0.0], [0.4262734055519104, 0.07066749036312103, 0.024391667917370796, 0.04879573732614517, 0.051445234566926956, 0.1276569813489914, 0.11843930184841156, 0.13233007490634918, 0.0, 0.0], 
[0.589878499507904, 0.026613032445311546, 0.020459800958633423, 0.028271155431866646, 0.03679497539997101, 0.07860217243432999, 0.08500825613737106, 0.09285575151443481, 0.04151623696088791, 0.0], [0.2743179202079773, 0.06089583784341812, 0.03565794974565506, 0.044920988380908966, 0.03933599591255188, 0.18495218455791473, 0.09192009270191193, 0.13160176575183868, 0.04121606424450874, 0.09518115967512131]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9842625260353088, 0.015737490728497505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8382691144943237, 0.11647694557905197, 0.04525385797023773, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4638526439666748, 0.1585947573184967, 0.3189436197280884, 0.0586090050637722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2375488132238388, 0.07284080982208252, 0.20766110718250275, 0.3110494017601013, 0.1708998829126358, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20615516602993011, 0.03705071657896042, 0.05929475650191307, 0.08692343533039093, 0.5564662218093872, 0.05410974845290184, 0.0, 0.0, 0.0, 0.0], [0.31913095712661743, 0.011343744583427906, 0.01675090566277504, 0.013238506391644478, 0.06746862828731537, 0.3789318799972534, 0.19313538074493408, 0.0, 0.0, 0.0], [0.4113273322582245, 0.003934106323868036, 0.003564919577911496, 0.005882325116544962, 0.018547017127275467, 0.18534934520721436, 0.3216978907585144, 0.04969710111618042, 0.0, 0.0], [0.07648876309394836, 0.0013769177021458745, 0.001890459912829101, 0.006597061175853014, 0.007926206104457378, 0.013261871412396431, 0.15683594346046448, 0.7190074324607849, 0.016615279018878937, 0.0], [0.08104224503040314, 0.00045554721145890653, 0.00038501128437928855, 0.0009405335295014083, 0.005597654264420271, 0.0034990713465958834, 0.009850292466580868, 0.0463707260787487, 0.7366765141487122, 0.11518235504627228]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9800853133201599, 0.019914645701646805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9159882068634033, 0.02969631738960743, 
0.05431551858782768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6467475295066833, 0.08892705291509628, 0.19796258211135864, 0.06636285036802292, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9833061099052429, 0.004010406322777271, 0.004914217162877321, 0.0015858567785471678, 0.006183335091918707, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9524497389793396, 0.0022862900514155626, 0.000848656112793833, 0.00408557103946805, 0.028177350759506226, 0.012152665294706821, 0.0, 0.0, 0.0, 0.0], [0.1907505989074707, 0.026542214676737785, 0.01945381611585617, 0.029287727549672127, 0.057166602462530136, 0.11766232550144196, 0.5591367483139038, 0.0, 0.0, 0.0], [0.4022328555583954, 0.017193131148815155, 0.01565318927168846, 0.01915702596306801, 0.01739031821489334, 0.16459040343761444, 0.18205313384532928, 0.18172988295555115, 0.0, 0.0], [0.9652498960494995, 0.0010482663055881858, 0.0012260396033525467, 0.0009098293376155198, 0.0013901795027777553, 0.0028189055155962706, 0.007343438919633627, 0.018731823191046715, 0.0012814495712518692, 0.0], [0.18471455574035645, 0.018054824322462082, 0.08812589198350906, 0.00762907462194562, 0.018057269975543022, 0.05247756093740463, 0.03497685119509697, 0.5025416612625122, 0.052323222160339355, 0.04109897091984749]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9911633133888245, 0.008836665190756321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9641951322555542, 0.023474374786019325, 0.012330451980233192, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6152319312095642, 0.28041696548461914, 0.04906271770596504, 0.05528838559985161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6057276725769043, 0.1235719844698906, 0.06170117110013962, 0.11151555925607681, 0.0974835753440857, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6386814713478088, 0.07927443087100983, 0.06004401296377182, 0.06398510187864304, 0.06341437995433807, 0.09460049122571945, 0.0, 0.0, 0.0, 0.0], [0.13321073353290558, 0.0565485954284668, 0.20425985753536224, 0.10307760536670685, 0.17957380414009094, 0.26328328251838684, 
0.06004612147808075, 0.0, 0.0, 0.0], [0.19694660604000092, 0.027736904099583626, 0.05790374055504799, 0.10621010512113571, 0.15510229766368866, 0.2214440256357193, 0.18680275976657867, 0.04785352945327759, 0.0, 0.0], [0.08537944406270981, 0.033881768584251404, 0.03968465328216553, 0.08240006119012833, 0.15350975096225739, 0.23219235241413116, 0.22240297496318817, 0.11620921641588211, 0.034339725971221924, 0.0], [0.06051333248615265, 0.012086840346455574, 0.028373999521136284, 0.07542525231838226, 0.10199770331382751, 0.15039192140102386, 0.20426926016807556, 0.16016273200511932, 0.06537677347660065, 0.14140206575393677]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5400503277778625, 0.4599496126174927, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04321815073490143, 0.9357689023017883, 0.02101275697350502, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48035699129104614, 0.12913382053375244, 0.27151036262512207, 0.11899882555007935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6920371055603027, 0.019891848787665367, 0.1885785609483719, 0.06273186951875687, 0.036760613322257996, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8527964949607849, 0.08059625327587128, 0.0037265238352119923, 0.008582950569689274, 0.042790722101926804, 0.01150701567530632, 0.0, 0.0, 0.0, 0.0], [0.900881826877594, 0.012710069306194782, 0.000794807099737227, 0.00424413476139307, 0.02110898308455944, 0.01962616853415966, 0.04063420742750168, 0.0, 0.0, 0.0], [0.713775098323822, 0.003081131726503372, 0.000918463512789458, 0.009338468313217163, 0.013423318043351173, 0.019161174073815346, 0.10174864530563354, 0.13855360448360443, 0.0, 0.0], [0.4800099730491638, 0.0009553784620948136, 0.00013007478264626116, 0.020002998411655426, 0.0032414987217634916, 0.002101779682561755, 0.028948260471224785, 0.46123453974723816, 0.0033754503820091486, 0.0], [0.7501513361930847, 0.019767694175243378, 0.0020619838032871485, 0.0038300605956465006, 0.0023455689661204815, 0.023803891614079475, 0.011456847190856934, 0.045016106218099594, 
0.08813992142677307, 0.05342674255371094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03494315221905708, 0.965056836605072, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.020348060876131058, 0.8944171071052551, 0.08523476868867874, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015979396412149072, 0.6347042918205261, 0.09008561074733734, 0.27361196279525757, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01025437843054533, 0.17247439920902252, 0.3664330542087555, 0.4087805449962616, 0.04205762594938278, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012186901643872261, 0.3028968572616577, 0.12117700278759003, 0.3522109389305115, 0.06255244463682175, 0.14897578954696655, 0.0, 0.0, 0.0, 0.0], [0.010822800919413567, 0.2333739995956421, 0.11113002151250839, 0.15861180424690247, 0.11286703497171402, 0.2766783833503723, 0.0965159684419632, 0.0, 0.0, 0.0], [0.00965114776045084, 0.19982098042964935, 0.054301097989082336, 0.13056904077529907, 0.03828747197985649, 0.4827912747859955, 0.05511533096432686, 0.029463520273566246, 0.0, 0.0], [0.014548483304679394, 0.07520423084497452, 0.1090526208281517, 0.14237697422504425, 0.030428709462285042, 0.5021095275878906, 0.026151562109589577, 0.04390878602862358, 0.05621904134750366, 0.0], [0.000422637298470363, 0.17123113572597504, 0.04347287863492966, 0.10408183932304382, 0.013075248338282108, 0.5476951003074646, 0.020964276045560837, 0.019243689253926277, 0.0612923838198185, 0.018520813435316086]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9947329163551331, 0.005267037078738213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7284466028213501, 0.21829284727573395, 0.05326057970523834, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7024527192115784, 0.0454108789563179, 0.10381712764501572, 0.14831924438476562, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2374107390642166, 0.04589728266000748, 0.2683154046535492, 0.3902822434902191, 0.0580943301320076, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7228419780731201, 0.007619804237037897, 0.013993922621011734, 0.04429992660880089, 
0.020430808886885643, 0.19081364572048187, 0.0, 0.0, 0.0, 0.0], [0.4783930778503418, 0.005506142508238554, 0.008406496606767178, 0.012424511834979057, 0.04335693642497063, 0.17542317509651184, 0.27648961544036865, 0.0, 0.0, 0.0], [0.056768160313367844, 0.001066300319507718, 0.0015203694347292185, 0.004650356248021126, 0.004999558907002211, 0.17368057370185852, 0.7387632131576538, 0.018551528453826904, 0.0, 0.0], [0.14709600806236267, 0.007261540275067091, 0.001291902968659997, 0.012605146504938602, 0.005232691299170256, 0.08098926395177841, 0.5304067134857178, 0.207069993019104, 0.00804678164422512, 0.0], [0.15080930292606354, 0.014301316812634468, 0.002821019385010004, 0.02008463814854622, 0.004475536290556192, 0.05297520384192467, 0.27036672830581665, 0.407105028629303, 0.007729486562311649, 0.06933178007602692]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9945669174194336, 0.005433134268969297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9554939270019531, 0.02177131362259388, 0.0227347444742918, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19059398770332336, 0.7459079623222351, 0.05105874687433243, 0.012439398095011711, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.062025006860494614, 0.7277394533157349, 0.13110491633415222, 0.028790757060050964, 0.050339892506599426, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7678350806236267, 0.007377212401479483, 0.020054306834936142, 0.11815592646598816, 0.07254840433597565, 0.014029012061655521, 0.0, 0.0, 0.0, 0.0], [0.8187481760978699, 0.009394909255206585, 0.015446240082383156, 0.012167787179350853, 0.10175905376672745, 0.02721206098794937, 0.01527167297899723, 0.0, 0.0, 0.0], [0.7012083530426025, 0.12151088565587997, 0.03808446228504181, 0.01883355714380741, 0.0837249755859375, 0.006598148960620165, 0.006499246694147587, 0.023540453985333443, 0.0, 0.0], [0.5152325630187988, 0.054241329431533813, 0.17093418538570404, 0.020541386678814888, 0.17657014727592468, 0.012641755864024162, 0.01802964322268963, 0.023539982736110687, 
0.008269038051366806, 0.0], [0.9131196141242981, 0.0010915634920820594, 0.006193474866449833, 0.006082434672862291, 0.03542511910200119, 0.006826554890722036, 0.0028478680178523064, 0.004068343434482813, 0.014553201384842396, 0.009791722521185875]]], [[[0.16448259353637695, 0.17219680547714233, 0.09987642616033554, 0.09012344479560852, 0.06534503400325775, 0.08456553518772125, 0.06690192222595215, 0.08019057661294937, 0.17631761729717255, 0.0], [0.49537378549575806, 0.03979916125535965, 0.09498286247253418, 0.0017974335933104157, 0.028368383646011353, 0.0015277893980965018, 0.014851069077849388, 0.0003722719266079366, 0.3229270279407501, 0.0], [0.0031106590759009123, 0.8318147659301758, 0.0329316072165966, 0.00014872441533952951, 0.000739947019610554, 0.0009879706194624305, 0.0012947155628353357, 0.00040531408740207553, 0.128566175699234, 0.0], [3.727031798916869e-05, 0.00033458907273598015, 0.9051278829574585, 0.014809494838118553, 0.0013665216974914074, 0.0009820980485528708, 0.0004274636448826641, 0.0006300737150013447, 0.07628484070301056, 0.0], [2.789895370369777e-05, 7.413508137688041e-05, 0.00011113573418697342, 0.9593441486358643, 0.023210706189274788, 0.00043970797560177743, 0.00011651179374894127, 0.0001221746060764417, 0.016553271561861038, 0.0], [5.518151283467887e-06, 4.040239218738861e-06, 4.706911568064243e-06, 0.0001475349417887628, 0.0011833186727017164, 0.007331210654228926, 0.0003812467912212014, 0.7072276473045349, 0.28371480107307434, 0.0], [2.1062598989374237e-06, 1.0153020184588968e-06, 9.153064297606761e-07, 2.3557351596537046e-05, 0.0019158869981765747, 0.9726926684379578, 0.0003360892878845334, 0.008161749690771103, 0.01686590164899826, 0.0], [1.876308124337811e-05, 3.1762643629917875e-05, 7.612020908709383e-06, 4.369785983726615e-06, 0.00035698129795491695, 0.006292039528489113, 0.9372867941856384, 0.0028216273058205843, 0.0531802624464035, 0.0], [0.00017082327394746244, 0.0008267413941211998, 0.0010992212919518352, 0.016357675194740295, 
0.03317699581384659, 0.013446258381009102, 0.022417983040213585, 0.0993492603302002, 0.813154935836792, 0.0], [2.095436911986326e-06, 1.0510404990782263e-06, 8.745904779061675e-06, 9.465758921578526e-05, 0.9096792936325073, 0.004888555034995079, 0.00019891942793037742, 0.00012723646068479866, 0.08499950170516968, 0.0]], [[0.09510962665081024, 0.13984361290931702, 0.01835908181965351, 0.05623754486441612, 0.05484192445874214, 0.02751241996884346, 0.023350151255726814, 0.02046714909374714, 0.5642784833908081, 0.0], [0.32246580719947815, 0.12212380021810532, 0.0033711090218275785, 0.41883695125579834, 0.0010050723794847727, 0.00026374190929345787, 0.00840060692280531, 0.0003199145139660686, 0.12321317940950394, 0.0], [0.1343918889760971, 0.42756012082099915, 0.03016146458685398, 0.27197346091270447, 0.0008738918695598841, 0.00041738885920494795, 0.0011337834876030684, 0.0017680631717666984, 0.13172008097171783, 0.0], [4.970032023265958e-05, 0.0002945268643088639, 0.9929893612861633, 0.006102537736296654, 1.304412307945313e-06, 7.552243459940655e-06, 2.0433815279830014e-06, 1.4308750905911438e-05, 0.0005390164442360401, 0.0], [0.0006735534407198429, 0.0037932321429252625, 0.014864870347082615, 0.9520841240882874, 0.0031083461362868547, 0.0014454165939241648, 0.000881638377904892, 0.00042032121564261615, 0.02272843010723591, 0.0], [1.054488166118972e-06, 5.819076250190847e-06, 3.686256491164386e-07, 5.7184315664926544e-05, 1.600286668690387e-05, 0.0002979082928504795, 5.8259040088159963e-05, 0.997514009475708, 0.0020495890639722347, 0.0], [1.2081607110303594e-06, 1.8248301785206422e-06, 3.5412674037615943e-07, 0.00017610432405490428, 0.0004308871575631201, 0.9919483065605164, 0.001251595327630639, 0.004008213523775339, 0.002181792864575982, 0.0], [1.3394396773946937e-06, 1.858925656961219e-06, 8.99223309147601e-08, 5.498410246218555e-06, 4.1167979361489415e-05, 0.003499603597447276, 0.9961592555046082, 8.322765097545926e-06, 0.0002831367892213166, 0.0], 
[0.0011697824811562896, 0.00207342766225338, 0.0001985222043003887, 0.24218614399433136, 0.2580603361129761, 0.03422079235315323, 0.3017951250076294, 0.0700761154294014, 0.09021952003240585, 0.0], [4.897859540164973e-08, 1.9182496657776937e-07, 1.6890984966266842e-07, 0.00012898082786705345, 0.9986647963523865, 0.0003688811557367444, 8.465539576718584e-05, 1.2611121746886056e-05, 0.0007397857843898237, 0.0]], [[0.008738831616938114, 0.010689073242247105, 0.010104849003255367, 0.025418052449822426, 0.008787600323557854, 0.018541773781180382, 0.01414045225828886, 0.009587875567376614, 0.8939914107322693, 0.0], [0.050771377980709076, 0.08173098415136337, 0.03076810948550701, 0.6816214919090271, 0.04326915368437767, 0.0030209666583687067, 0.006032166071236134, 0.007633579429239035, 0.09515213221311569, 0.0], [0.04749365150928497, 0.07148067653179169, 0.018722670152783394, 0.5845115184783936, 0.03816590458154678, 0.003933309111744165, 0.006466464139521122, 0.021205652505159378, 0.20802012085914612, 0.0], [0.021572547033429146, 0.11727327853441238, 0.03622674569487572, 0.4274545907974243, 0.05620160698890686, 0.01161592174321413, 0.010393376462161541, 0.014363090507686138, 0.30489882826805115, 0.0], [0.015270093455910683, 0.10013995319604874, 0.006727923639118671, 0.19538360834121704, 0.1119888573884964, 0.027630485594272614, 0.0700199231505394, 0.01868581771850586, 0.4541531801223755, 0.0], [0.00540963327512145, 0.07916348427534103, 0.01957465149462223, 0.49324244260787964, 0.10871188342571259, 0.02422497235238552, 0.008650544099509716, 0.16292543709278107, 0.0980970561504364, 0.0], [0.027941647917032242, 0.005471521522849798, 0.006384703796356916, 0.03924928605556488, 0.22657036781311035, 0.21837352216243744, 0.3372570872306824, 0.05897291377186775, 0.07977905124425888, 0.0], [0.009049936197698116, 0.005020579323172569, 0.014692768454551697, 0.15799382328987122, 0.4401932656764984, 0.1766415536403656, 0.03136269003152847, 0.12063619494438171, 0.044409021735191345, 
0.0], [0.0007816475699655712, 0.0003147682291455567, 0.0032215022947639227, 0.4467180669307709, 0.3918246924877167, 0.00227341428399086, 0.004370422102510929, 0.14414219558238983, 0.006353371310979128, 0.0], [0.0005489268223755062, 0.016601460054516792, 0.01341363787651062, 0.2753817141056061, 0.13981539011001587, 0.04711242765188217, 0.08167178928852081, 0.11951272189617157, 0.30594193935394287, 0.0]], [[0.11438923329114914, 0.12380287796258926, 0.23573537170886993, 0.19010169804096222, 0.15611350536346436, 0.031749427318573, 0.02482231892645359, 0.05017237365245819, 0.07311322540044785, 0.0], [0.002549531403928995, 0.03178577870130539, 0.17347589135169983, 0.2232668697834015, 0.49775105714797974, 0.018238944932818413, 0.005651220679283142, 0.03368452191352844, 0.013595964759588242, 0.0], [0.0032994491048157215, 0.026504727080464363, 0.41210347414016724, 0.24245016276836395, 0.18897436559200287, 0.012874660082161427, 0.006452939473092556, 0.10089367628097534, 0.00644671730697155, 0.0], [0.002998506650328636, 0.048583757132291794, 0.28224417567253113, 0.0846971943974495, 0.013445784337818623, 0.02188579924404621, 0.017656570300459862, 0.5155076384544373, 0.012980557046830654, 0.0], [0.004188622813671827, 0.028234833851456642, 0.022820167243480682, 0.058492597192525864, 0.19205521047115326, 0.08343320339918137, 0.07119973003864288, 0.4843534827232361, 0.0552222914993763, 0.0], [0.0038351663388311863, 0.015353971160948277, 0.01755588687956333, 0.06245748698711395, 0.1218588799238205, 0.07207991182804108, 0.02867230959236622, 0.5455195903778076, 0.13266700506210327, 0.0], [0.004144841339439154, 0.0048835063353180885, 0.0035110898315906525, 0.06276324391365051, 0.04069552943110466, 0.3603023290634155, 0.1472603678703308, 0.2116946280002594, 0.16474448144435883, 0.0], [0.024624889716506004, 0.016127971932291985, 0.0073340879753232, 0.023849278688430786, 0.042295511811971664, 0.5078635215759277, 0.2884303331375122, 0.011452756822109222, 0.07802165299654007, 0.0], 
[0.00880166981369257, 0.002673782641068101, 0.001370548619888723, 0.0061265453696250916, 0.02490534819662571, 0.2073771357536316, 0.3818575143814087, 0.1663341522216797, 0.20055335760116577, 0.0], [0.012253189459443092, 0.02221212349832058, 0.002282155444845557, 0.10455729067325592, 0.4111727774143219, 0.08308815956115723, 0.045707643032073975, 0.03711223974823952, 0.2816142141819, 0.0]], [[0.5821239352226257, 0.14550858736038208, 0.031251534819602966, 0.030760297551751137, 0.02147754468023777, 0.013665237464010715, 0.009087015874683857, 0.01557532325387001, 0.15055041015148163, 0.0], [0.12817564606666565, 0.33913177251815796, 0.07241326570510864, 0.41213902831077576, 0.0326012559235096, 0.0031606394331902266, 0.0006341012776829302, 0.007317711599171162, 0.0044263736344873905, 0.0], [0.08047150820493698, 0.06199575960636139, 0.5555182099342346, 0.2858560383319855, 0.008700164034962654, 0.003758196486160159, 0.001155794132500887, 0.0007424709619954228, 0.0018020549323409796, 0.0], [0.010044030845165253, 0.018482256680727005, 0.6269924640655518, 0.32439544796943665, 0.01023165788501501, 0.007641270756721497, 0.0008933563949540257, 0.0010311403311789036, 0.00028844154439866543, 0.0], [0.0007911038701422513, 0.0008549468475393951, 0.015090622939169407, 0.8270009160041809, 0.11969847232103348, 0.032614268362522125, 0.0024233118165284395, 0.0011481117689982057, 0.0003779604157898575, 0.0], [0.017773190513253212, 0.008623103611171246, 0.0020072387997061014, 0.08177924901247025, 0.13816505670547485, 0.6801413297653198, 0.02186667174100876, 0.024107687175273895, 0.025536518543958664, 0.0], [0.000318053673254326, 5.6540200603194535e-05, 1.071194674295839e-05, 0.0009494975674897432, 0.0034297029487788677, 0.032661326229572296, 0.9588278532028198, 0.003185966284945607, 0.0005602877936325967, 0.0], [0.0017862697131931782, 0.0002347631088923663, 2.1297884813975543e-05, 0.0004797980946023017, 0.0018031852087005973, 0.024247879162430763, 0.45456385612487793, 0.5099425911903381, 
0.006920217536389828, 0.0], [0.0006541880429722369, 0.0009561541373841465, 7.73017163737677e-05, 0.00942671112716198, 0.04198922589421272, 0.04971348121762276, 0.32961171865463257, 0.4513629972934723, 0.11620841920375824, 0.0], [0.017209511250257492, 0.004475452937185764, 3.128392927465029e-05, 0.00047953161993063986, 0.00448839133605361, 0.03360708802938461, 0.11509764194488525, 0.5398797988891602, 0.2847314178943634, 0.0]], [[0.20143046975135803, 0.41116827726364136, 0.09215858578681946, 0.10672477632761002, 0.06125285103917122, 0.017610367387533188, 0.01457523088902235, 0.02514597773551941, 0.06993352621793747, 0.0], [0.026864346116781235, 0.037146128714084625, 0.08411292731761932, 0.02904331497848034, 0.0955604761838913, 0.05886658653616905, 0.08584483712911606, 0.4076027572154999, 0.17495866119861603, 0.0], [0.073190838098526, 0.07998740673065186, 0.05594569817185402, 0.03243006020784378, 0.10037493705749512, 0.13878461718559265, 0.15250830352306366, 0.25721096992492676, 0.10956726223230362, 0.0], [0.0438627265393734, 0.04628896340727806, 0.4038660526275635, 0.005475929472595453, 0.03436022624373436, 0.11165640503168106, 0.02260321006178856, 0.28233063220977783, 0.04955587536096573, 0.0], [0.2377929538488388, 0.08882997930049896, 0.12371516227722168, 0.08651548624038696, 0.015416872687637806, 0.04211122542619705, 0.16403844952583313, 0.11833071708679199, 0.12324906885623932, 0.0], [0.023254310712218285, 0.0034057339653372765, 0.036038532853126526, 0.009054891765117645, 0.0329253226518631, 0.05284882336854935, 0.15671837329864502, 0.6067742109298706, 0.07897992432117462, 0.0], [0.015282228589057922, 0.008608018048107624, 0.08339564502239227, 0.032651614397764206, 0.21303850412368774, 0.22661514580249786, 0.21832069754600525, 0.1323210895061493, 0.06976725161075592, 0.0], [0.019424932077527046, 0.008587736636400223, 0.014951083809137344, 0.01159222237765789, 0.2890152633190155, 0.2543036639690399, 0.2561561167240143, 0.0882645845413208, 0.05770434811711311, 
0.0], [0.020595766603946686, 0.015824340283870697, 0.008689227513968945, 0.03796549141407013, 0.3004503846168518, 0.16956602036952972, 0.10506420582532883, 0.05004280060529709, 0.2918018400669098, 0.0], [0.18154361844062805, 0.0977708026766777, 0.20556335151195526, 0.05251142755150795, 0.13640889525413513, 0.06629360467195511, 0.06030320003628731, 0.08172836154699326, 0.11787670105695724, 0.0]], [[0.07673492282629013, 0.03585591912269592, 0.0804624855518341, 0.05707075819373131, 0.16190174221992493, 0.1288135051727295, 0.1235240250825882, 0.06807681918144226, 0.2675597667694092, 0.0], [0.005086997989565134, 0.014635499566793442, 0.013461720198392868, 0.6349815726280212, 0.14714521169662476, 0.015218403190374374, 0.01605474203824997, 0.018318237736821175, 0.1350976973772049, 0.0], [0.03515003249049187, 0.049813926219940186, 0.04029693454504013, 0.4151618778705597, 0.24873343110084534, 0.009437951259315014, 0.008381601423025131, 0.020832136273384094, 0.17219208180904388, 0.0], [0.06722414493560791, 0.13528113067150116, 0.06224377825856209, 0.18915168941020966, 0.17580503225326538, 0.07229694724082947, 0.012536793015897274, 0.09137610346078873, 0.19408434629440308, 0.0], [0.09099949151277542, 0.09548961371183395, 0.04829362779855728, 0.1739831268787384, 0.06667517125606537, 0.05157051607966423, 0.05465595796704292, 0.06177656352519989, 0.3565560579299927, 0.0], [0.09822985529899597, 0.05441536381840706, 0.039150238037109375, 0.06369251012802124, 0.05292840674519539, 0.050128646194934845, 0.044398434460163116, 0.04042055085301399, 0.5566359758377075, 0.0], [0.012019939720630646, 0.0076602306216955185, 0.02716030552983284, 0.03984800726175308, 0.09776019304990768, 0.05175628885626793, 0.08536165207624435, 0.0944109782576561, 0.5840223431587219, 0.0], [0.036716632544994354, 0.021969007328152657, 0.010507079772651196, 0.012404722161591053, 0.040125522762537, 0.010736462660133839, 0.018730206415057182, 0.030387653037905693, 0.8184227347373962, 0.0], [0.04769879952073097, 
0.19333122670650482, 0.02803504839539528, 0.016029207035899162, 0.11119306832551956, 0.03845509514212608, 0.011404097080230713, 0.0836206004023552, 0.4702327847480774, 0.0], [0.05245642364025116, 0.013315027579665184, 0.012056763283908367, 0.004825723823159933, 0.015483945608139038, 0.032884638756513596, 0.027794960886240005, 0.07057305425405502, 0.7706093788146973, 0.0]], [[0.05745904520153999, 0.06613133102655411, 0.11319872736930847, 0.031750500202178955, 0.0641264021396637, 0.07090476900339127, 0.053613319993019104, 0.1108509749174118, 0.4319649040699005, 0.0], [0.12783250212669373, 0.16847258806228638, 0.08126984536647797, 0.10575822740793228, 0.03301985561847687, 0.2111520618200302, 0.10687874257564545, 0.06316707283258438, 0.10244929045438766, 0.0], [0.1413263976573944, 0.38601601123809814, 0.16798537969589233, 0.14611834287643433, 0.015951359644532204, 0.042198505252599716, 0.016183707863092422, 0.06246974319219589, 0.021750787273049355, 0.0], [0.020376645028591156, 0.008152640424668789, 0.04579228535294533, 0.022974595427513123, 0.007921000011265278, 0.11700868606567383, 0.010826223529875278, 0.7216546535491943, 0.04529344290494919, 0.0], [0.04728184640407562, 0.041129130870103836, 0.12847241759300232, 0.038289085030555725, 0.07389654964208603, 0.11478690057992935, 0.04442784935235977, 0.41169247031211853, 0.1000237911939621, 0.0], [0.016180921345949173, 0.005130380857735872, 0.21081623435020447, 0.00797765702009201, 0.04691680520772934, 0.052309177815914154, 0.2947923243045807, 0.34133997559547424, 0.02453651838004589, 0.0], [0.006579844746738672, 0.001606129459105432, 0.206822007894516, 0.017204096540808678, 0.13898226618766785, 0.09910376369953156, 0.4235020577907562, 0.05497713387012482, 0.051222700625658035, 0.0], [0.00896216370165348, 0.0023249718360602856, 0.0226416178047657, 0.05458173528313637, 0.07694459706544876, 0.29436299204826355, 0.36870595812797546, 0.12525610625743866, 0.046219732612371445, 0.0], [0.027829669415950775, 
0.014619122259318829, 0.014550572261214256, 0.048137370496988297, 0.15001901984214783, 0.11716196686029434, 0.34159788489341736, 0.1513865739107132, 0.13469791412353516, 0.0], [0.0014273751294240355, 0.003807784290984273, 0.3760293126106262, 0.002253596903756261, 0.11343870311975479, 0.12883712351322174, 0.04242479428648949, 0.28902071714401245, 0.042760640382766724, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9917634725570679, 0.008236419409513474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.711856484413147, 0.20838035643100739, 0.07976315170526505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6327172517776489, 0.1227935329079628, 0.21565596759319305, 0.028833283111453056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3586137592792511, 0.038762304931879044, 0.08015953004360199, 0.4233120083808899, 0.09915236383676529, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7095601558685303, 0.03453405201435089, 0.02220289036631584, 0.009008818306028843, 0.201883926987648, 0.022810086607933044, 0.0, 0.0, 0.0, 0.0], [0.5828825831413269, 0.02795644849538803, 0.054448600858449936, 0.01975347101688385, 0.11504233628511429, 0.08908692002296448, 0.11082970350980759, 0.0, 0.0, 0.0], [0.4315364956855774, 0.020537925884127617, 0.01659376546740532, 0.014654956758022308, 0.13063199818134308, 0.27319464087486267, 0.08869150280952454, 0.024158723652362823, 0.0, 0.0], [0.26020547747612, 0.014821716584265232, 0.01224969606846571, 0.0724530965089798, 0.10939211398363113, 0.19152909517288208, 0.10495918244123459, 0.1680101454257965, 0.06637949496507645, 0.0], [0.6687084436416626, 0.04345089942216873, 0.009689688682556152, 0.0018685735994949937, 0.0738394483923912, 0.12735962867736816, 0.025320274755358696, 0.026545442640781403, 0.020931225270032883, 0.0022863498888909817]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9482711553573608, 0.051728855818510056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8711318373680115, 0.04994085431098938, 0.07892734557390213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.7221198678016663, 0.040686361491680145, 0.06532222777605057, 0.17187155783176422, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5948007702827454, 0.036634139716625214, 0.02264709398150444, 0.035541336983442307, 0.3103766441345215, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6650473475456238, 0.01644211634993553, 0.019737746566534042, 0.0375308021903038, 0.10231779515743256, 0.15892422199249268, 0.0, 0.0, 0.0, 0.0], [0.36675524711608887, 0.04118315875530243, 0.02765432558953762, 0.03228116035461426, 0.11875578761100769, 0.12892943620681763, 0.2844408452510834, 0.0, 0.0, 0.0], [0.19659309089183807, 0.015950728207826614, 0.02453998662531376, 0.039237309247255325, 0.037656329572200775, 0.34599894285202026, 0.23759640753269196, 0.10242718458175659, 0.0, 0.0], [0.3881740868091583, 0.012267092242836952, 0.01897304505109787, 0.013982790522277355, 0.030991200357675552, 0.10819684714078903, 0.20157809555530548, 0.14642520248889923, 0.07941170781850815, 0.0], [0.11410266160964966, 0.03479800745844841, 0.043540675193071365, 0.021180409938097, 0.03197954222559929, 0.2248576581478119, 0.12852585315704346, 0.2089216560125351, 0.039846520870923996, 0.1522471308708191]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.993086576461792, 0.0069133141078054905, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9852874875068665, 0.011381878517568111, 0.0033306065015494823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4834398031234741, 0.011301998049020767, 0.48758530616760254, 0.017672834917902946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9851425886154175, 0.0010397545993328094, 0.00470126885920763, 0.0012236799811944366, 0.007892588153481483, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6588926315307617, 0.005506628658622503, 0.021607331931591034, 0.010738613083958626, 0.07747143507003784, 0.2257833182811737, 0.0, 0.0, 0.0, 0.0], [0.13557791709899902, 0.018924091011285782, 0.02187344618141651, 0.015362304635345936, 0.11512601375579834, 0.14739760756492615, 0.5457385182380676, 0.0, 0.0, 0.0], [0.38992705941200256, 0.021535715088248253, 
0.005403842777013779, 0.0032997699454426765, 0.4358868896961212, 0.06306594610214233, 0.03204012289643288, 0.04884066432714462, 0.0, 0.0], [0.81478351354599, 0.022238636389374733, 0.0008386021945625544, 0.01924033649265766, 0.06109088659286499, 0.020853841677308083, 0.014834966510534286, 0.028932424262166023, 0.017186695709824562, 0.0], [0.011323019862174988, 0.004743177909404039, 0.004908193834125996, 0.04389021545648575, 0.9175272583961487, 0.008399821817874908, 0.00010120288789039478, 0.0007724545430392027, 0.001946530188433826, 0.006388010922819376]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9621535539627075, 0.037846412509679794, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5398231148719788, 0.4385344386100769, 0.021642372012138367, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6502059698104858, 0.16868625581264496, 0.04876677691936493, 0.13234086334705353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5965072512626648, 0.06637387722730637, 0.1054789125919342, 0.1866345852613449, 0.04500538855791092, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3253602683544159, 0.03396952152252197, 0.02178867906332016, 0.07780158519744873, 0.04822142422199249, 0.49285849928855896, 0.0, 0.0, 0.0, 0.0], [0.2524598240852356, 0.04065639525651932, 0.06012948602437973, 0.022925280034542084, 0.0371418297290802, 0.17370767891407013, 0.41297948360443115, 0.0, 0.0, 0.0], [0.03411499038338661, 0.003937003668397665, 0.005961195565760136, 0.01710909977555275, 0.011033114977180958, 0.7081340551376343, 0.13750500977039337, 0.08220544457435608, 0.0, 0.0], [0.42400264739990234, 0.02131979539990425, 0.017963027581572533, 0.01083337515592575, 0.019156770780682564, 0.14712399244308472, 0.1343262642621994, 0.19853995740413666, 0.02673417516052723, 0.0], [0.010900852270424366, 0.01643177680671215, 0.007438827771693468, 0.037741534411907196, 0.0038807683158665895, 0.513563871383667, 0.17121337354183197, 0.14364023506641388, 0.04466766491532326, 0.050521109253168106]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], [0.4730486273765564, 0.5269513726234436, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39858773350715637, 0.07930062711238861, 0.5221116542816162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5825604200363159, 0.08404675871133804, 0.15067298710346222, 0.182719886302948, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.29498350620269775, 0.03899451717734337, 0.00506106112152338, 0.006130008026957512, 0.6548308730125427, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13055028021335602, 0.007264712825417519, 0.014658198691904545, 0.03852052241563797, 0.6908979415893555, 0.11810839176177979, 0.0, 0.0, 0.0, 0.0], [0.6701509952545166, 0.016114505007863045, 0.009837295860052109, 0.013812566176056862, 0.10121432691812515, 0.04637172445654869, 0.14249859750270844, 0.0, 0.0, 0.0], [0.15980258584022522, 0.02680308185517788, 0.03885137289762497, 0.01341771800071001, 0.16442187130451202, 0.12716332077980042, 0.3698134124279022, 0.09972671419382095, 0.0, 0.0], [0.5671898722648621, 0.0029452391900122166, 0.0006932761170901358, 0.0009682640084065497, 0.008882325142621994, 0.018135691061615944, 0.19489231705665588, 0.1878870278596878, 0.01840599626302719, 0.0], [0.10793960839509964, 0.02733222208917141, 0.05983218923211098, 0.007959540002048016, 0.012123869732022285, 0.0992540642619133, 0.031409986317157745, 0.1074245497584343, 0.5389924645423889, 0.007731476798653603]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9029706120491028, 0.09702935069799423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0243820920586586, 0.026593990623950958, 0.9490237236022949, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002445725491270423, 0.01137782447040081, 0.2685152590274811, 0.7176609635353088, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013372139073908329, 0.017163371667265892, 0.023703746497631073, 0.029362313449382782, 0.916398286819458, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009483089670538902, 0.0015323974657803774, 0.016186771914362907, 0.02369842305779457, 0.15252061188220978, 0.7965786457061768, 0.0, 0.0, 0.0, 0.0], [0.9718639850616455, 
0.0004310244112275541, 0.00011954548244830221, 0.007853196933865547, 0.005200029816478491, 0.0034086843952536583, 0.011123435571789742, 0.0, 0.0, 0.0], [0.24338993430137634, 0.0009381886338815093, 0.001691973302513361, 0.004991883412003517, 0.06480661034584045, 0.02667633630335331, 0.5911706686019897, 0.06633439660072327, 0.0, 0.0], [0.9644694328308105, 0.00020931981271132827, 0.00022034233552403748, 0.001116775325499475, 0.0005140798166394234, 0.011200232431292534, 0.006607241928577423, 0.015303434804081917, 0.00035933865001425147, 0.0], [6.446504994528368e-05, 5.223282641964033e-05, 4.761212403536774e-05, 0.0026887860149145126, 0.9879595041275024, 8.169181819539517e-05, 3.4432316169841215e-05, 0.00022215544595383108, 0.008540215902030468, 0.0003085293574258685]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9361864924430847, 0.06381344050168991, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8988155722618103, 0.08033642917871475, 0.020848000422120094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9107179045677185, 0.0414138063788414, 0.01669401116669178, 0.031174303963780403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7451518774032593, 0.09319417923688889, 0.038068220019340515, 0.07867664098739624, 0.04490913450717926, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7448095083236694, 0.053781699389219284, 0.02206255868077278, 0.051568105816841125, 0.050503022968769073, 0.07727508991956711, 0.0, 0.0, 0.0, 0.0], [0.49530187249183655, 0.08432565629482269, 0.024537190794944763, 0.03536847233772278, 0.04101351276040077, 0.2942921817302704, 0.025161173194646835, 0.0, 0.0, 0.0], [0.33749768137931824, 0.0470278225839138, 0.025994539260864258, 0.11184448003768921, 0.035708073526620865, 0.3288814127445221, 0.052594561129808426, 0.06045151129364967, 0.0, 0.0], [0.5827996730804443, 0.037185750901699066, 0.025691334158182144, 0.040444474667310715, 0.032313525676727295, 0.15237869322299957, 0.02532070316374302, 0.06300554424524307, 0.040860243141651154, 0.0], [0.17712005972862244, 0.06858222186565399, 
0.023361776024103165, 0.06553570926189423, 0.015878353267908096, 0.40178799629211426, 0.03335757926106453, 0.09457883983850479, 0.06679747253656387, 0.05300001800060272]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5274816155433655, 0.47251835465431213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0672292709350586, 0.8769893646240234, 0.05578138306736946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13304242491722107, 0.3016340434551239, 0.1132093071937561, 0.45211419463157654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3050541281700134, 0.12257003039121628, 0.15977424383163452, 0.12758392095565796, 0.2850176990032196, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22356735169887543, 0.03928647190332413, 0.007754397578537464, 0.009327426552772522, 0.12143179029226303, 0.5986325740814209, 0.0, 0.0, 0.0, 0.0], [0.21462960541248322, 0.06222677230834961, 0.03770677372813225, 0.020617984235286713, 0.1298619657754898, 0.16450734436511993, 0.3704494833946228, 0.0, 0.0, 0.0], [0.03640613704919815, 0.010346643626689911, 0.00673291739076376, 0.007102567236870527, 0.047351155430078506, 0.07502260059118271, 0.1735789030790329, 0.6434589624404907, 0.0, 0.0], [0.10655857622623444, 0.005401281639933586, 0.008467022329568863, 0.004935698118060827, 0.02920999936759472, 0.06761414557695389, 0.11367721855640411, 0.6410130262374878, 0.02312297187745571, 0.0], [0.022663118317723274, 0.01638328656554222, 0.016234278678894043, 0.013438239693641663, 0.13762539625167847, 0.1316443532705307, 0.10126813501119614, 0.1815005987882614, 0.10885223746299744, 0.27039045095443726]]], [[[0.5484673976898193, 0.11615661531686783, 0.018765881657600403, 0.05580288916826248, 0.007166780531406403, 0.0032917018979787827, 0.014802427962422371, 0.009165346622467041, 0.2263808399438858, 0.0], [0.05529153719544411, 0.9114729762077332, 0.02160518430173397, 0.004567866213619709, 0.0020856212358921766, 0.002110318513587117, 7.9427394666709e-05, 0.0003330546023789793, 0.002454147208482027, 0.0], [0.05249117314815521, 
0.7437799572944641, 0.20206855237483978, 0.000259216787526384, 5.300459815771319e-05, 0.0010736108524724841, 1.3093422239762731e-05, 5.472580232890323e-05, 0.00020689083612523973, 0.0], [0.00011124516458949074, 0.00014391898002941161, 0.002859711181372404, 0.9545366168022156, 0.03663090988993645, 0.0007743826135993004, 9.755761857377365e-05, 0.004461521748453379, 0.00038416660390794277, 0.0], [0.00743946572765708, 0.0006693374016322196, 0.030706975609064102, 0.7693524360656738, 0.13636630773544312, 0.010656076483428478, 0.00020547783060465008, 0.04430907592177391, 0.0002948205510620028, 0.0], [0.34355977177619934, 0.05352164804935455, 0.28276264667510986, 0.0013746530748903751, 0.09131773561239243, 0.14369648694992065, 0.012657100334763527, 0.0018211203860118985, 0.06928855180740356, 0.0], [0.00012065855844411999, 1.9836104911519215e-05, 1.09456195787061e-05, 1.777518991730176e-05, 5.799566861242056e-05, 0.00024955562548711896, 0.9993764758110046, 2.622428155518719e-06, 0.00014374956663232297, 0.0], [0.07937772572040558, 0.06896578520536423, 0.046331409364938736, 0.0006753505440428853, 0.10991709679365158, 0.07862550020217896, 0.015291865915060043, 0.024935398250818253, 0.575879693031311, 0.0], [0.0013021372724324465, 0.0014324801741167903, 0.001721968175843358, 0.0011953436769545078, 0.025524066761136055, 0.0017154657980427146, 0.004751213360577822, 0.06856247782707214, 0.8937948942184448, 0.0], [0.19237911701202393, 0.08248087018728256, 0.005975060164928436, 0.0005637910799123347, 0.0032617340330034494, 0.08159960061311722, 0.10672623664140701, 0.06415636837482452, 0.4628572165966034, 0.0]], [[0.02611132338643074, 0.03929625079035759, 0.13154497742652893, 0.0362381711602211, 0.41634607315063477, 0.056842975318431854, 0.07852181792259216, 0.04581860452890396, 0.16927982866764069, 0.0], [0.02255222387611866, 0.10361888259649277, 0.8158259987831116, 0.004235901869833469, 0.006766538135707378, 0.000986771541647613, 0.0032120062969624996, 0.001885716337710619, 
0.04091576859354973, 0.0], [0.017373425886034966, 0.13537107408046722, 0.8077713847160339, 0.0038886782713234425, 0.000276644917903468, 0.0006381930434145033, 0.00045153097016736865, 0.00025059780455194414, 0.0339784249663353, 0.0], [0.005007221829146147, 0.01780957728624344, 0.01267488207668066, 0.04065018519759178, 0.30516332387924194, 0.0026367236860096455, 0.0019572318997234106, 0.004150545224547386, 0.6099502444267273, 0.0], [0.00225767120718956, 0.0031865497585386038, 0.001125291339121759, 0.016497144475579262, 0.8971690535545349, 0.00343481102026999, 0.003961168695241213, 0.006191920954734087, 0.0661763921380043, 0.0], [0.009628614410758018, 0.010054895654320717, 0.001336919842287898, 0.0704738199710846, 0.6877674460411072, 0.03301373869180679, 0.05187760666012764, 0.005273953080177307, 0.1305730938911438, 0.0], [0.004103087354451418, 0.0010862533235922456, 0.0006940921302884817, 0.005870609078556299, 0.43826234340667725, 0.030803751200437546, 0.2956492602825165, 0.002342070685699582, 0.2211885154247284, 0.0], [0.0007035931921564043, 0.0015657383482903242, 0.0003329406026750803, 0.025085464119911194, 0.8715798258781433, 0.006046876311302185, 0.002586639951914549, 0.00011169366916874424, 0.09198720753192902, 0.0], [0.0021814475767314434, 0.0018482444575056434, 0.02461252734065056, 0.02290530502796173, 0.17733190953731537, 0.007551506161689758, 0.026218494400382042, 0.1859409213066101, 0.5514096021652222, 0.0], [0.002560347318649292, 0.0069580040872097015, 0.0021583843044936657, 0.002428637584671378, 0.010794135741889477, 0.002866419730708003, 0.010929176583886147, 0.004671781323850155, 0.9566330909729004, 0.0]], [[0.10015174746513367, 0.09576243162155151, 0.03824060782790184, 0.020538825541734695, 0.027732321992516518, 0.017240623012185097, 0.011470633558928967, 0.4632270634174347, 0.22563566267490387, 0.0], [0.018642960116267204, 0.04822106286883354, 0.0140958521515131, 0.13092586398124695, 0.031955283135175705, 0.05195324495434761, 0.024238048121333122, 
0.35591286420822144, 0.32405486702919006, 0.0], [0.04842180013656616, 0.06019889563322067, 0.016854524612426758, 0.166769877076149, 0.016003064811229706, 0.024013301357626915, 0.03686072677373886, 0.44016721844673157, 0.1907106339931488, 0.0], [0.0036558371502906084, 0.0033082098234444857, 0.0009605223312973976, 0.017093589529395103, 0.019939379766583443, 0.08280718326568604, 0.031923551112413406, 0.703184187412262, 0.1371275782585144, 0.0], [0.0018930931109935045, 0.002881440566852689, 0.00019882648484781384, 0.00406575808301568, 0.0021070034708827734, 0.011610278859734535, 0.0074381148442626, 0.9341073036193848, 0.03569793701171875, 0.0], [0.00691588269546628, 0.02838265150785446, 0.015397720038890839, 0.031874921172857285, 0.04765379801392555, 0.22744230926036835, 0.06624653190374374, 0.10724947601556778, 0.4688366651535034, 0.0], [0.009144916199147701, 0.012914983555674553, 0.0114166010171175, 0.010616455227136612, 0.03852293640375137, 0.11398687958717346, 0.23996756970882416, 0.03855413943529129, 0.5248754620552063, 0.0], [0.0165528766810894, 0.08396174013614655, 0.03695421293377876, 0.012792840600013733, 0.05054211989045143, 0.004681664984673262, 0.006349458359181881, 0.0059485542587935925, 0.7822163701057434, 0.0], [0.01941962167620659, 0.0720844566822052, 0.06703408807516098, 0.0024893011432141066, 0.09017500281333923, 0.01547347940504551, 0.011082785204052925, 0.036743972450494766, 0.6854971051216125, 0.0], [0.010559813119471073, 0.7681021094322205, 0.01782229356467724, 0.0007385257631540298, 0.000383153063012287, 0.00014055910287424922, 0.00037340103881433606, 0.000453647633548826, 0.2014264017343521, 0.0]], [[0.15394526720046997, 0.03155883774161339, 0.013263775035738945, 0.6156436800956726, 0.07578981667757034, 0.016770629212260246, 0.041555847972631454, 0.008898256346583366, 0.042573899030685425, 0.0], [0.0076131573878228664, 0.0034139170311391354, 0.013692762702703476, 0.9489110708236694, 0.0023491643369197845, 0.004504398908466101, 
0.018228650093078613, 6.88576401444152e-05, 0.0012180121848359704, 0.0], [0.002593724289909005, 0.0010450187837705016, 0.004842442460358143, 0.9895163178443909, 0.0010734394891187549, 6.863682210678235e-05, 0.00040535334846936166, 0.00029785462538711727, 0.00015763564442750067, 0.0], [0.007949098013341427, 0.0007930149440653622, 0.0010613474296405911, 0.913150429725647, 0.017281265929341316, 0.01414033118635416, 0.03622613847255707, 0.0036093932576477528, 0.0057889497838914394, 0.0], [0.002646723063662648, 0.0005160199943929911, 0.0002488561731297523, 0.0025351925287395716, 0.0016247049206867814, 0.09429789334535599, 0.8856176137924194, 0.005861275363713503, 0.006651633884757757, 0.0], [0.010305220261216164, 0.0041244118474423885, 0.0009454450337216258, 0.011387528851628304, 0.006450551562011242, 0.09920497238636017, 0.7582080960273743, 0.0005519646219909191, 0.1088215708732605, 0.0], [0.02700764685869217, 0.004230276681482792, 0.0004602614790201187, 0.0022337904665619135, 0.001628970610909164, 0.01760227419435978, 0.5739604234695435, 0.0034094173461198807, 0.3694668710231781, 0.0], [0.008519366383552551, 0.005846879445016384, 0.00031929058604873717, 0.00022687541786581278, 0.0001488836423959583, 0.0012441301951184869, 0.007195098325610161, 0.000364138453733176, 0.9761351943016052, 0.0], [0.10429845005273819, 0.062129467725753784, 0.0009245545952580869, 0.00015166438242886215, 0.00031537580071017146, 0.00040291156619787216, 0.006900690030306578, 0.009933815337717533, 0.8149431943893433, 0.0], [0.021318454295396805, 0.024646490812301636, 0.0006273255567066371, 3.4892458643298596e-05, 0.0002248335222247988, 0.001184952794574201, 0.005942351184785366, 0.01648845337331295, 0.9295321702957153, 0.0]], [[0.18210574984550476, 0.33179324865341187, 0.17356830835342407, 0.09634877741336823, 0.13200198113918304, 0.013823754154145718, 0.003925282042473555, 0.0030049949418753386, 0.06342781335115433, 0.0], [0.0020119324326515198, 0.018535200506448746, 0.9658629894256592, 
0.007450602483004332, 0.004517300054430962, 0.0010996937053278089, 0.00011890578753082082, 8.778373739914969e-05, 0.0003156243183184415, 0.0], [0.00025491390260867774, 0.010184208862483501, 0.989091694355011, 0.0003856455150526017, 4.125349732930772e-05, 1.630889528314583e-05, 4.754766450787429e-06, 9.06635705177905e-06, 1.2339231943769846e-05, 0.0], [0.0005263744969852269, 0.0021082928869873285, 0.03339724615216255, 0.9491472840309143, 0.010056117549538612, 0.0008323733345605433, 0.00041247360059060156, 0.003171485150232911, 0.0003482940956018865, 0.0], [0.004013302735984325, 0.003607808379456401, 0.10117900371551514, 0.3848154544830322, 0.4549750089645386, 0.02184862084686756, 0.015023248270154, 0.013938427902758121, 0.0005991325015202165, 0.0], [0.006630904506891966, 0.001555793103761971, 0.01566290855407715, 0.005377574823796749, 0.0545264296233654, 0.7578195929527283, 0.15542279183864594, 0.00011175184772582725, 0.002892365213483572, 0.0], [0.0013706677127629519, 0.0003565592342056334, 0.0006504033226519823, 0.0008717189775779843, 0.023110924288630486, 0.16852477192878723, 0.8020843863487244, 0.0004564319388009608, 0.0025740356650203466, 0.0], [0.0005740747437812388, 0.0018384596332907677, 0.015691960230469704, 0.0004515495093073696, 0.04004881531000137, 0.8668573498725891, 0.03566786274313927, 0.01278533972799778, 0.0260846596211195, 0.0], [0.0008355869795195758, 0.003608973463997245, 0.04490630701184273, 0.009341607801616192, 0.007649072911590338, 0.10034366697072983, 0.06446904689073563, 0.7009655237197876, 0.06788014620542526, 0.0], [0.00805425550788641, 0.039243537932634354, 0.05003930628299713, 0.0007152591715566814, 0.00863983016461134, 0.4756345748901367, 0.24407540261745453, 0.1204291433095932, 0.053168926388025284, 0.0]], [[0.7289455533027649, 0.07586073875427246, 0.09869885444641113, 0.029881592839956284, 0.013988524675369263, 0.006547953933477402, 0.011115974746644497, 0.01961168274283409, 0.01534893549978733, 0.0], [0.0068419137969613075, 
0.9909167289733887, 0.0013876587618142366, 0.00011136479588458315, 6.257302447920665e-05, 6.40047510387376e-05, 1.7212767488672398e-05, 0.00025805848417803645, 0.0003405954339541495, 0.0], [0.0458948090672493, 0.09657034277915955, 0.8496794104576111, 0.005955036263912916, 0.00011324687511660159, 0.000537818530574441, 0.00024974963162094355, 6.34682146483101e-05, 0.0009358474053442478, 0.0], [0.0026473035104572773, 0.007075308356434107, 0.0509142205119133, 0.9043685793876648, 0.01687121018767357, 0.0027590824756771326, 0.0040096985176205635, 0.004973408300429583, 0.006381142418831587, 0.0], [0.00491793267428875, 0.0006714555202051997, 0.0047717769630253315, 0.09624139964580536, 0.8609471917152405, 0.020327016711235046, 0.008984015323221684, 0.0013932499568909407, 0.0017457373905926943, 0.0], [0.003878936870023608, 0.005480882711708546, 0.00011314810399198905, 0.0003485401102807373, 0.006120527163147926, 0.0029893070459365845, 0.0006264422554522753, 0.004414959345012903, 0.9760271906852722, 0.0], [2.4277944248751737e-05, 4.029595402244013e-06, 3.533453991622082e-07, 0.0002488488098606467, 2.0782925275852904e-05, 0.0004858894390054047, 0.9990906119346619, 4.08113919547759e-05, 8.422240352956578e-05, 0.0], [0.0008877408690750599, 0.00558891985565424, 8.855570922605693e-05, 8.779557902016677e-06, 0.00010457105963723734, 0.00017662049503996968, 0.002778601599857211, 0.09916532039642334, 0.8912010192871094, 0.0], [0.0001188771057059057, 0.0008492054184898734, 9.383026917930692e-05, 5.974515715934103e-06, 0.002050562761723995, 8.90250812517479e-05, 8.933644130593166e-05, 0.002921103034168482, 0.9937818646430969, 0.0], [0.013618292286992073, 0.005976812914013863, 3.6079491110285744e-05, 4.805085336556658e-05, 0.00010178168304264545, 0.03545643016695976, 0.04239484667778015, 0.35667654871940613, 0.5456912517547607, 0.0]], [[0.5123088955879211, 0.04897892847657204, 0.0054785809479653835, 0.037157464772462845, 0.033040400594472885, 0.0287709329277277, 0.020658228546380997, 
0.005767439026385546, 0.3078390061855316, 0.0], [0.02116605080664158, 0.45384567975997925, 0.5185156464576721, 0.002682638820260763, 0.000782136688940227, 0.0011598592391237617, 9.265153494197875e-05, 0.0013774167746305466, 0.00037764458102174103, 0.0], [0.06292606890201569, 0.18043853342533112, 0.7547035217285156, 0.0011577572440728545, 6.746899998688605e-06, 0.00012623713701032102, 8.539375994587317e-05, 0.0005039210664108396, 5.15973697474692e-05, 0.0], [0.0030907110776752234, 0.0035500964149832726, 0.4530283808708191, 0.5221376419067383, 0.009557867422699928, 0.008033420890569687, 0.0001996932114707306, 0.0003745300055015832, 2.7415442673373036e-05, 0.0], [0.00426989421248436, 0.0004077394842170179, 0.010691642761230469, 0.016011668369174004, 0.5530933737754822, 0.12423280626535416, 0.0053755901753902435, 0.28551235795021057, 0.00040478314622305334, 0.0], [0.00216855201870203, 0.009595326147973537, 0.007803121581673622, 0.04625817388296127, 0.24702508747577667, 0.2669595181941986, 0.024053409695625305, 0.24639348685741425, 0.14974308013916016, 0.0], [1.953603486981592e-06, 5.726202516598278e-07, 5.0084551617146644e-08, 6.896097602293594e-06, 0.0001788837107596919, 0.0027895711828023195, 0.9969833493232727, 1.1296948287053965e-05, 2.7187752493773587e-05, 0.0], [0.00041725789196789265, 0.0019265476148575544, 6.523763295263052e-05, 0.00018337361689191312, 0.011946662329137325, 0.04555974155664444, 0.15744170546531677, 0.025624049827456474, 0.7568355202674866, 0.0], [0.0004706757317762822, 0.0027324894908815622, 0.0007427418022416532, 0.00934627279639244, 0.17134670913219452, 0.030644211918115616, 0.08413954824209213, 0.2513456642627716, 0.4492316246032715, 0.0], [0.0002758087939582765, 0.0016877831658348441, 2.4452297111565713e-06, 0.0004533462051767856, 0.001545731327496469, 0.008134560659527779, 0.010873721912503242, 0.026235109195113182, 0.950791597366333, 0.0]], [[0.23423711955547333, 0.3770143389701843, 0.036408282816410065, 0.05249679461121559, 
0.12246986478567123, 0.07398416101932526, 0.040104977786540985, 0.05500312149524689, 0.008280999027192593, 0.0], [0.032638370990753174, 0.0975130945444107, 0.07180608063936234, 0.1075734794139862, 0.025216424837708473, 0.0218534916639328, 0.0376754030585289, 0.19710108637809753, 0.40862247347831726, 0.0], [0.06665100902318954, 0.01976630836725235, 0.13041609525680542, 0.1772802174091339, 0.07561768591403961, 0.0061133550480008125, 0.022857116535305977, 0.3516288995742798, 0.14966930449008942, 0.0], [0.007724895142018795, 0.007534464355558157, 0.020593322813510895, 0.32147932052612305, 0.059838466346263885, 0.017819387838244438, 0.13181470334529877, 0.15524767339229584, 0.27794769406318665, 0.0], [0.0028419073205441236, 0.0006648481939919293, 0.0018234169110655785, 0.01609039306640625, 0.0009005893371067941, 0.09726841002702713, 0.11035522073507309, 0.6978457570075989, 0.07220931351184845, 0.0], [0.0027223427314311266, 0.011157790198922157, 0.0052745467983186245, 0.00438346853479743, 0.010011360049247742, 0.38358205556869507, 0.2294219732284546, 0.057655833661556244, 0.29579076170921326, 0.0], [0.0007371494430117309, 0.00023907626746222377, 8.450529276160523e-05, 0.002850313438102603, 0.002168564358726144, 0.02191324159502983, 0.9514637589454651, 0.0002505093871150166, 0.02029278129339218, 0.0], [0.004062887746840715, 0.007024500984698534, 0.007399421185255051, 0.011675640009343624, 0.0680280476808548, 0.02557964250445366, 0.07043837755918503, 0.01946563646197319, 0.7863259315490723, 0.0], [0.0007200208492577076, 0.0024253681767731905, 0.029840486124157906, 0.00014908696175552905, 0.11268167197704315, 0.03336171433329582, 0.007834793999791145, 0.08127990365028381, 0.7317068576812744, 0.0], [0.008707311004400253, 0.08642490208148956, 0.11372587829828262, 0.004973042756319046, 0.13256150484085083, 0.16557356715202332, 0.0817113071680069, 0.006879410706460476, 0.3994430899620056, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", 
\"today\", \"!\"], \"bot_text\": [\"\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}, \"out_out\": {\"att\": [[[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9412446618080139, 0.05875528231263161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7461972832679749, 0.18569768965244293, 0.06810508668422699, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4299372434616089, 0.16845084726810455, 0.2029547393321991, 0.19865721464157104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5215166807174683, 0.16121163964271545, 0.19463112950325012, 0.09347883611917496, 0.029161658138036728, 0.0, 0.0, 0.0, 0.0, 0.0], [0.26405569911003113, 0.04358615726232529, 0.10687251389026642, 0.1710020899772644, 0.4105237126350403, 0.0039598336443305016, 0.0, 0.0, 0.0, 0.0], [0.29189321398735046, 0.19170531630516052, 0.11295431852340698, 0.08274418860673904, 0.12850242853164673, 0.09739833325147629, 0.09480219334363937, 0.0, 0.0, 0.0], [0.3496137857437134, 0.03085259348154068, 0.0195528082549572, 0.45414459705352783, 0.09152030944824219, 0.008845902979373932, 0.02992299199104309, 0.01554702315479517, 0.0, 0.0], [0.4675538241863251, 0.03941410034894943, 0.05400091037154198, 0.17985978722572327, 0.20104949176311493, 0.030323797836899757, 0.010615098290145397, 0.015154700726270676, 0.002028239192441106, 0.0], [0.053565241396427155, 0.029699191451072693, 0.0156599972397089, 0.016939852386713028, 0.04015244543552399, 0.21933501958847046, 0.1449035257101059, 0.4037321209907532, 0.019583676010370255, 0.056428998708724976]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5249735116958618, 0.4750264883041382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3563348054885864, 0.5701623558998108, 0.07350286096334457, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3398579955101013, 0.23167477548122406, 0.1957632154226303, 0.23270410299301147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4351256191730499, 0.09737284481525421, 0.08845506608486176, 0.06574707478284836, 
0.31329941749572754, 0.0, 0.0, 0.0, 0.0, 0.0], [0.360861599445343, 0.02136792428791523, 0.005633710417896509, 0.009215844795107841, 0.15762653946876526, 0.4452943205833435, 0.0, 0.0, 0.0, 0.0], [0.009015758521854877, 0.0013937305193394423, 0.00017763266805559397, 0.00016997012426145375, 0.010879353620111942, 0.0024589570239186287, 0.9759047627449036, 0.0, 0.0, 0.0], [0.014776602387428284, 0.0001805058855097741, 1.6896785382414237e-05, 0.0003442507586441934, 0.006220621056854725, 0.0012393802171573043, 0.9433164596557617, 0.033905431628227234, 0.0, 0.0], [0.005810329224914312, 0.002043980173766613, 0.0003433740057516843, 0.001522325212135911, 0.0030212807469069958, 0.00817712489515543, 0.5456522107124329, 0.10564129799604416, 0.32778817415237427, 0.0], [0.3754594326019287, 0.030579065904021263, 0.028458155691623688, 0.035943739116191864, 0.28040432929992676, 0.0202159583568573, 0.0396210215985775, 0.05075624957680702, 0.13473623991012573, 0.0038258912973105907]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9630448818206787, 0.036955028772354126, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8940342664718628, 0.015322646126151085, 0.09064316004514694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4866876006126404, 0.028273453935980797, 0.4569007158279419, 0.028138065710663795, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7252220511436462, 0.10817205905914307, 0.07890959084033966, 0.017715180292725563, 0.06998112797737122, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8598019480705261, 0.012843498960137367, 0.014502299018204212, 0.004056263715028763, 0.10580158233642578, 0.0029942472465336323, 0.0, 0.0, 0.0, 0.0], [0.8686293363571167, 0.024889284744858742, 0.013860221020877361, 0.00703870365396142, 0.07120370119810104, 0.003939351066946983, 0.010439489968121052, 0.0, 0.0, 0.0], [0.8572709560394287, 0.018014011904597282, 0.008267350494861603, 0.0022140766959637403, 0.1038530021905899, 0.004275611136108637, 0.0009780752006918192, 0.005126776173710823, 0.0, 0.0], [0.35013046860694885, 
0.0037752145435661077, 0.0071558705531060696, 0.01608894392848015, 0.6097922325134277, 0.002463925164192915, 0.0005387101555243134, 0.005540961865335703, 0.004513624589890242, 0.0], [0.1888049989938736, 0.12293454259634018, 0.5947631597518921, 0.009457849897444248, 0.07291270792484283, 0.008950368501245975, 0.0004109511792194098, 0.000914009811822325, 0.0006959570455364883, 0.00015547229850199074]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.91131192445755, 0.08868805319070816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.786292314529419, 0.09286607056856155, 0.1208416074514389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1722075194120407, 0.10747934877872467, 0.1462225317955017, 0.5740904808044434, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1893281787633896, 0.1733204573392868, 0.06838839501142502, 0.47577211260795593, 0.09319086372852325, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08935888856649399, 0.012517428956925869, 0.017112966626882553, 0.08479276299476624, 0.7640082240104675, 0.03220977261662483, 0.0, 0.0, 0.0, 0.0], [0.824190616607666, 0.008810147643089294, 0.002143737394362688, 0.002297793049365282, 0.11996792256832123, 0.005709697026759386, 0.036880046129226685, 0.0, 0.0, 0.0], [0.1513449102640152, 0.015725232660770416, 0.02784004621207714, 0.01800909824669361, 0.6534391641616821, 0.016422629356384277, 0.09054289758205414, 0.026676079258322716, 0.0, 0.0], [0.1625923067331314, 0.016224535182118416, 0.06514906883239746, 0.003223034320399165, 0.6737184524536133, 0.014129054732620716, 0.036937959492206573, 0.023035621270537376, 0.004990031942725182, 0.0], [0.06836045533418655, 0.01236770860850811, 0.008784784935414791, 0.014186863787472248, 0.09790214896202087, 0.046204064041376114, 0.1703491061925888, 0.1878211945295334, 0.0703599750995636, 0.32366377115249634]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.961704432964325, 0.038295578211545944, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37462106347084045, 0.2157517969608307, 0.40962719917297363, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0], [0.48521965742111206, 0.031020229682326317, 0.3760664165019989, 0.10769358277320862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.914044201374054, 0.004715718794614077, 0.006151301320642233, 0.005079128313809633, 0.07000966370105743, 0.0, 0.0, 0.0, 0.0, 0.0], [0.060511741787195206, 0.006127620115876198, 0.00728148128837347, 0.013585635460913181, 0.9084653854370117, 0.004028240218758583, 0.0, 0.0, 0.0, 0.0], [0.23348243534564972, 0.03748093172907829, 0.055222347378730774, 0.014132470823824406, 0.27614685893058777, 0.017582375556230545, 0.3659524619579315, 0.0, 0.0, 0.0], [0.06461911648511887, 0.003781915409490466, 0.002705940278246999, 0.016099220141768456, 0.8774597644805908, 0.012668337672948837, 0.0088069261983037, 0.013858767226338387, 0.0, 0.0], [0.05451222136616707, 0.014412143267691135, 0.00208102585747838, 0.011283651925623417, 0.02552390843629837, 0.02239326573908329, 0.031104939058423042, 0.20777365565299988, 0.630915105342865, 0.0], [0.5451503992080688, 0.014764615334570408, 0.2503703534603119, 0.037022024393081665, 0.0935375839471817, 0.022694993764162064, 0.0037449353840202093, 0.0053339023143053055, 0.007315538357943296, 0.020065704360604286]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9904667735099792, 0.009533224627375603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9818503260612488, 0.007338901981711388, 0.010810752399265766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9738979935646057, 0.007647394668310881, 0.015154722146689892, 0.0032999368850141764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6611008644104004, 0.04138284549117088, 0.1119912639260292, 0.0262944046407938, 0.15923058986663818, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9380988478660583, 0.005562208592891693, 0.01078465860337019, 0.004562946502119303, 0.033130958676338196, 0.007860423997044563, 0.0, 0.0, 0.0, 0.0], [0.9377894997596741, 0.003691342193633318, 0.002771170577034354, 0.0017416415503248572, 0.04246653988957405, 0.002464305842295289, 0.009075501933693886, 0.0, 0.0, 0.0], 
[0.9083399176597595, 0.005597027484327555, 0.02609928511083126, 0.005710097029805183, 0.017865832895040512, 0.0029857312329113483, 0.002900469582527876, 0.030501706525683403, 0.0, 0.0], [0.8338009119033813, 0.00436164066195488, 0.006190306507050991, 0.0008050849428400397, 0.015337309800088406, 0.00863864365965128, 0.010715007781982422, 0.1143304780125618, 0.005820483900606632, 0.0], [0.9085996747016907, 0.00676243519410491, 0.02013525180518627, 0.009278967045247555, 0.02104269526898861, 0.009343095123767853, 0.0009470531367696822, 0.0018253516172990203, 0.003784958738833666, 0.018280424177646637]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.972051739692688, 0.027948210015892982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7552067041397095, 0.17251533269882202, 0.0722779706120491, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6455309987068176, 0.23265127837657928, 0.10187581926584244, 0.01994187943637371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.470674991607666, 0.26442891359329224, 0.14268451929092407, 0.03363766148686409, 0.08857394009828568, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6457618474960327, 0.011289404705166817, 0.008832731284201145, 0.01570025272667408, 0.2588561475276947, 0.059559762477874756, 0.0, 0.0, 0.0, 0.0], [0.4916176497936249, 0.07200384140014648, 0.0701020285487175, 0.019148536026477814, 0.0833231583237648, 0.12199999392032623, 0.14180481433868408, 0.0, 0.0, 0.0], [0.11119699478149414, 0.002801541704684496, 0.0021932011004537344, 0.0016493132570758462, 0.06827285885810852, 0.22499483823776245, 0.5049597024917603, 0.08393163233995438, 0.0, 0.0], [0.13208742439746857, 0.0035411729477345943, 0.0015305017586797476, 0.002489483682438731, 0.06612236052751541, 0.213859423995018, 0.5324232578277588, 0.03503565117716789, 0.012910734862089157, 0.0], [0.20209012925624847, 0.05223073810338974, 0.03088257648050785, 0.036374326795339584, 0.014660456217825413, 0.03045688569545746, 0.03597142919898033, 0.16862399876117706, 0.022359324619174004, 0.40635016560554504]], 
[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9218347668647766, 0.0781652107834816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4189925193786621, 0.4865715503692627, 0.09443587809801102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48251789808273315, 0.34758540987968445, 0.13321316242218018, 0.036683470010757446, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8504839539527893, 0.033341050148010254, 0.053517427295446396, 0.012789242900907993, 0.049868300557136536, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4515743553638458, 0.03267433121800423, 0.019386781379580498, 0.024256065487861633, 0.17900733649730682, 0.29310107231140137, 0.0, 0.0, 0.0, 0.0], [0.5910289883613586, 0.0027754076290875673, 0.004533650353550911, 0.0023315453436225653, 0.08002334088087082, 0.06913208961486816, 0.2501751184463501, 0.0, 0.0, 0.0], [0.1626552939414978, 0.0011573631782084703, 0.00017211545491591096, 0.0007665579323656857, 0.03241841867566109, 0.34369325637817383, 0.2890424132347107, 0.17009468376636505, 0.0, 0.0], [0.10835989564657211, 0.0007107920246198773, 0.00030798258376307786, 0.005807099863886833, 0.04662986099720001, 0.1659584492444992, 0.3522194027900696, 0.30094781517982483, 0.019058646634221077, 0.0], [0.5449283123016357, 0.01310307253152132, 0.008020865730941296, 0.006764447782188654, 0.16009773313999176, 0.06950337439775467, 0.0024397175293415785, 0.014089844189584255, 0.013654321432113647, 0.1673980951309204]]], [[[0.03246883675456047, 0.020431363955140114, 0.06294436007738113, 0.08282972872257233, 0.047490958124399185, 0.03976213559508324, 0.01868664100766182, 0.5054241418838501, 0.18996170163154602, 0.0], [0.0334412157535553, 0.45350977778434753, 0.23828978836536407, 0.07703227549791336, 0.02545342594385147, 0.019935714080929756, 0.007961008697748184, 0.08864670246839523, 0.05572996661067009, 0.0], [0.008816813118755817, 0.009350132197141647, 0.09488566964864731, 0.022458655759692192, 0.001578008639626205, 0.01768183708190918, 0.0012928039068356156, 0.7889453768730164, 0.05499071627855301, 
0.0], [0.0037117439787834883, 0.00603569345548749, 0.019362367689609528, 0.06632085889577866, 0.02251342497766018, 0.048607613891363144, 0.00711278198286891, 0.7890322804450989, 0.03730323165655136, 0.0], [0.0017165049212053418, 0.0031809706706553698, 0.00569736585021019, 0.027958940714597702, 0.001130971242673695, 0.006313299294561148, 0.004051794297993183, 0.9312260150909424, 0.018723946064710617, 0.0], [0.0028915719594806433, 0.007050157990306616, 0.004614752251654863, 0.0017270235111936927, 0.0016248916508629918, 0.06901240348815918, 0.005150379613041878, 0.13293159008026123, 0.7749972939491272, 0.0], [0.005032604560256004, 0.005055313929915428, 0.0030569147784262896, 0.0010687477188184857, 0.012304573319852352, 0.013984610326588154, 0.3489484190940857, 0.012370014563202858, 0.5981789827346802, 0.0], [0.0019784842152148485, 0.009333183988928795, 0.005381024908274412, 0.0002465381403453648, 0.0013898308388888836, 0.005461550783365965, 0.0012134313583374023, 0.001065099611878395, 0.9739308953285217, 0.0], [0.005657540168613195, 0.006781480740755796, 0.00696007814258337, 0.0009338636882603168, 0.02429838851094246, 0.03842600807547569, 0.00286328443326056, 0.03579647094011307, 0.8782829642295837, 0.0], [0.007395321968942881, 0.012293249368667603, 0.006963892374187708, 0.00022730379714630544, 0.0005401583621278405, 0.005707587581127882, 0.0028992195148020983, 0.0027063635643571615, 0.9612669944763184, 0.0]], [[0.02470340207219124, 0.02512546442449093, 0.11353036016225815, 0.35132649540901184, 0.20412008464336395, 0.027150044217705727, 0.015305055305361748, 0.05760098248720169, 0.1811380535364151, 0.0], [0.009894105605781078, 0.02192404493689537, 0.3007009029388428, 0.13983333110809326, 0.03682582825422287, 0.08908118307590485, 0.27657952904701233, 0.026430398225784302, 0.09873086214065552, 0.0], [0.011459765024483204, 0.044317521154880524, 0.5289616584777832, 0.19549138844013214, 0.03426412120461464, 0.017797794193029404, 0.030613277107477188, 0.0163635965436697, 
0.12073105573654175, 0.0], [0.011578483507037163, 0.0029169816989451647, 0.00455811433494091, 0.01625976897776127, 0.018393559381365776, 0.11749742925167084, 0.32938554883003235, 0.41049671173095703, 0.08891336619853973, 0.0], [0.0033444140572100878, 0.0011373214656487107, 0.0019445078214630485, 0.02781311236321926, 0.0049105980433523655, 0.05221953243017197, 0.09222303330898285, 0.3644186854362488, 0.45198866724967957, 0.0], [0.002199131529778242, 0.0006913270917721093, 0.002652444876730442, 0.017487458884716034, 0.18746966123580933, 0.39171290397644043, 0.26989367604255676, 0.017002178356051445, 0.11089123785495758, 0.0], [0.01051913108676672, 0.003755246289074421, 0.0008555634994991124, 0.002675057854503393, 0.0025919810868799686, 0.02418649010360241, 0.018060903996229172, 0.003447937313467264, 0.9339075684547424, 0.0], [0.029951948672533035, 0.006547479424625635, 0.030934682115912437, 0.0036260345950722694, 0.1420958936214447, 0.19529034197330475, 0.1491098254919052, 0.009723717346787453, 0.43272000551223755, 0.0], [0.017757408320903778, 0.006832967512309551, 0.028906390070915222, 0.00921954121440649, 0.054915353655815125, 0.028632348403334618, 0.03646676614880562, 0.01978384144604206, 0.7974854707717896, 0.0], [0.06588920205831528, 0.05552517622709274, 0.18546447157859802, 0.007839588448405266, 0.020484987646341324, 0.01699826307594776, 0.01947665773332119, 0.017759086564183235, 0.6105626821517944, 0.0]], [[0.14391662180423737, 0.11156481504440308, 0.4162432849407196, 0.07845085859298706, 0.04067624360322952, 0.016916701570153236, 0.012291320599615574, 0.10670017451047897, 0.07323983311653137, 0.0], [0.0171683169901371, 0.03512553498148918, 0.4936983287334442, 0.18945446610450745, 0.020571058616042137, 0.011469473131000996, 0.04002959281206131, 0.08968089520931244, 0.10280223935842514, 0.0], [0.2093620002269745, 0.11281707882881165, 0.25891542434692383, 0.14515942335128784, 0.0042000748217105865, 0.006485591176897287, 0.005525505635887384, 0.14364667236804962, 
0.11388827115297318, 0.0], [0.0109701631590724, 0.0007525839027948678, 0.011503712274134159, 0.03920656442642212, 0.2449047565460205, 0.048431187868118286, 0.12996943295001984, 0.4081973731517792, 0.10606419295072556, 0.0], [0.004995591007173061, 0.0001893905719043687, 0.0009439413552172482, 0.03207648918032646, 0.08267047256231308, 0.015983520075678825, 0.02033340558409691, 0.8191123604774475, 0.023694908246397972, 0.0], [0.0022357299458235502, 0.000793653482105583, 0.0010144039988517761, 0.2958794832229614, 0.3394852876663208, 0.07495945692062378, 0.06856833398342133, 0.06118563562631607, 0.15587811172008514, 0.0], [0.0020441634114831686, 0.00032311712857335806, 0.0006899640429764986, 0.03996479511260986, 0.38782593607902527, 0.05503879860043526, 0.24750953912734985, 0.004524962045252323, 0.26207876205444336, 0.0], [0.0012333561899140477, 0.0002747838443610817, 0.0023864947725087404, 0.10253860056400299, 0.4721597135066986, 0.04103615880012512, 0.03782818093895912, 0.026908699423074722, 0.31563398241996765, 0.0], [0.004791810177266598, 0.0015037101693451405, 0.004669447895139456, 0.38809871673583984, 0.13379721343517303, 0.024320820346474648, 0.03647102415561676, 0.013309511356055737, 0.3930378258228302, 0.0], [0.00849083997309208, 0.003579143201932311, 0.0033037925604730844, 0.006032468285411596, 0.017621049657464027, 0.0234503336250782, 0.018282314762473106, 0.02657976746559143, 0.8926602602005005, 0.0]], [[0.8417463898658752, 0.05951714888215065, 0.012198105454444885, 0.03180553764104843, 0.02919766865670681, 0.0096508814021945, 0.003031272441148758, 0.0009100366733036935, 0.011942943558096886, 0.0], [0.00569154741242528, 0.979739785194397, 0.012030904181301594, 0.0001143000990850851, 9.368032624479383e-05, 0.0008171445806510746, 0.00012590458209160715, 0.0005024938145652413, 0.0008843241375871003, 0.0], [0.005223963409662247, 0.005622355733066797, 0.9848889708518982, 0.002582893241196871, 0.0003334738139528781, 0.0005618981667794287, 3.256636409787461e-05, 
0.00024550766102038324, 0.0005086653982289135, 0.0], [0.0032260464504361153, 0.007557107135653496, 0.0651315227150917, 0.6094849109649658, 0.008782745338976383, 0.2748804986476898, 0.015592943876981735, 0.008143502287566662, 0.007200630847364664, 0.0], [0.01683628372848034, 0.0020552987698465586, 0.00783018209040165, 0.008005303330719471, 0.0011927365558221936, 0.9284406900405884, 0.03478293865919113, 0.00030738895293325186, 0.0005490221083164215, 0.0], [0.0004254023951943964, 7.111614831956103e-05, 0.0008891545585356653, 1.880968193290755e-05, 6.570573896169662e-05, 0.9941434860229492, 0.0025632327888160944, 9.733852493809536e-06, 0.0018130606040358543, 0.0], [7.936867405078374e-06, 1.8136512153432705e-05, 4.5569290705316234e-06, 1.071940641850233e-05, 3.808495648627286e-06, 0.0008168917265720665, 0.9974388480186462, 1.4373016711033415e-05, 0.0016848900122568011, 0.0], [0.0014213839313015342, 0.003971228376030922, 0.008488249033689499, 2.0282970581320114e-05, 8.774230809649453e-05, 0.030342059209942818, 0.010436602868139744, 0.013138609007000923, 0.9320940375328064, 0.0], [9.058997966349125e-05, 0.0009022729936987162, 0.0017266678623855114, 1.3629892237077001e-05, 0.000727150880265981, 0.002379553159698844, 0.0010508937994018197, 0.012508089654147625, 0.9806011319160461, 0.0], [0.0003429521748330444, 0.001905322540551424, 0.0005013775080442429, 1.1471392099338118e-05, 0.00017356597527395934, 0.0029742273036390543, 0.003938945475965738, 0.028075864538550377, 0.9620763063430786, 0.0]], [[0.23634016513824463, 0.09021607041358948, 0.12040459364652634, 0.01354933436959982, 0.0019137230701744556, 0.009001325815916061, 0.028688833117485046, 0.2612648904323578, 0.23862121999263763, 0.0], [0.2307557761669159, 0.2812652289867401, 0.30346915125846863, 0.05031246319413185, 0.006193350534886122, 0.01668362505733967, 0.012607063166797161, 0.07951408624649048, 0.019199388101696968, 0.0], [0.29960742592811584, 0.20819564163684845, 0.27825382351875305, 0.007396433036774397, 
0.0007608149899169803, 0.0260151494294405, 0.012685009278357029, 0.12934625148773193, 0.03773954138159752, 0.0], [0.035675279796123505, 0.035874202847480774, 0.007117687724530697, 0.018771182745695114, 0.010206644423305988, 0.06527784466743469, 0.03775254264473915, 0.7770709991455078, 0.012253628112375736, 0.0], [0.012017791159451008, 0.0028583300299942493, 0.0024127706419676542, 0.002610970288515091, 0.001820205245167017, 0.04092223569750786, 0.016621166840195656, 0.9115477800369263, 0.009188669733703136, 0.0], [0.03447290509939194, 0.013388306833803654, 0.08488336205482483, 0.015237652696669102, 0.19176845252513885, 0.3472833037376404, 0.10885429382324219, 0.192628413438797, 0.011483324691653252, 0.0], [0.0005363536183722317, 0.0001964608090929687, 0.0017719777533784509, 0.003164003835991025, 0.27662715315818787, 0.05286016687750816, 0.648875892162323, 0.007890382781624794, 0.00807751715183258, 0.0], [0.001257028547115624, 0.00020761204359587282, 0.0024441492278128862, 0.003374723019078374, 0.9062062501907349, 0.0712839737534523, 0.0032159662805497646, 0.009974849410355091, 0.0020355340093374252, 0.0], [0.0008205634076148272, 0.00019305139721836895, 0.002098840195685625, 0.004588909447193146, 0.9688709378242493, 0.01628950424492359, 0.0038415545132011175, 0.0016231476329267025, 0.0016735766548663378, 0.0], [0.03610469028353691, 0.046298399567604065, 0.04650943726301193, 0.02111651562154293, 0.06683006882667542, 0.37146270275115967, 0.174205482006073, 0.15773150324821472, 0.07974111288785934, 0.0]], [[0.03425053879618645, 0.026130978018045425, 0.3080751299858093, 0.027706336230039597, 0.12989944219589233, 0.29902005195617676, 0.0305496696382761, 0.03879137709736824, 0.1055762991309166, 0.0], [0.004509713500738144, 0.02305547706782818, 0.939035952091217, 0.006188178434967995, 0.020785806700587273, 0.00040150884888134897, 0.00018676061881706119, 0.00013036451127845794, 0.005706076975911856, 0.0], [0.0005241778562776744, 0.009561678394675255, 0.988527774810791, 
2.2495760276797228e-05, 4.7274414100684226e-05, 0.00013538387429434806, 4.543165232462343e-06, 6.27172994427383e-05, 0.001113483915105462, 0.0], [0.06551901996135712, 0.0800878182053566, 0.06342226266860962, 0.00974376779049635, 0.5160938501358032, 0.02204274758696556, 0.004013149533420801, 0.0735243633389473, 0.1655530482530594, 0.0], [0.0013552415184676647, 0.0004213388019707054, 0.002606122987344861, 0.0010090378345921636, 0.24638326466083527, 0.6568374633789062, 0.01604411192238331, 0.04806208983063698, 0.027281243354082108, 0.0], [0.0002145337639376521, 0.00018796027870848775, 0.0008407118148170412, 0.0029629908967763186, 0.28427600860595703, 0.6725634336471558, 0.023870857432484627, 0.00339014851488173, 0.011693413369357586, 0.0], [0.0009873382514342666, 0.0005485343281179667, 6.628077971981838e-05, 0.0029302756302058697, 0.23183174431324005, 0.05256076529622078, 0.5701138377189636, 0.005792138632386923, 0.13516920804977417, 0.0], [2.471696279826574e-05, 2.0868348656222224e-05, 4.437468305695802e-05, 0.002024284563958645, 0.9655042886734009, 0.024176988750696182, 0.001284845289774239, 0.00018083618488162756, 0.006738840136677027, 0.0], [0.0007289832574315369, 7.746354822302237e-05, 0.00018428664770908654, 0.014176051132380962, 0.9112405180931091, 0.013280178420245647, 0.003417921019718051, 0.02014165185391903, 0.03675319626927376, 0.0], [0.00874137319624424, 0.03438721224665642, 0.17507928609848022, 0.007159235887229443, 0.0029199302662163973, 0.023628318682312965, 0.007933209650218487, 0.004559694789350033, 0.7355918884277344, 0.0]], [[0.01947755739092827, 0.007096209097653627, 0.03225293010473251, 0.0123430285602808, 0.10373923927545547, 0.44083938002586365, 0.04899014160037041, 0.25500863790512085, 0.08025286346673965, 0.0], [0.018974049016833305, 0.05092930048704147, 0.38670486211776733, 0.05532746762037277, 0.02096201851963997, 0.23439037799835205, 0.029592081904411316, 0.06233520433306694, 0.1407845914363861, 0.0], [0.009641589596867561, 
0.009545106440782547, 0.19981582462787628, 0.009672220796346664, 0.003704657079651952, 0.04582780599594116, 0.006998295895755291, 0.5789687037467957, 0.13582585752010345, 0.0], [0.00450306897982955, 0.0034239809028804302, 0.012258612550795078, 0.005700208712369204, 0.04511384665966034, 0.4419432282447815, 0.12840862572193146, 0.13075105845928192, 0.22789721190929413, 0.0], [0.00048664878704585135, 0.00010348611976951361, 0.0010980216320604086, 0.0006185582024045289, 0.028226494789123535, 0.37447214126586914, 0.09456676244735718, 0.48241522908210754, 0.018012629821896553, 0.0], [8.0467427324038e-05, 3.9275117160286754e-05, 0.00016763176245149225, 0.00013412459520623088, 0.009092556312680244, 0.7851189374923706, 0.16675172746181488, 0.0029041438829153776, 0.03571125119924545, 0.0], [0.0007275060634128749, 0.00015159584290813655, 0.00037383963353931904, 0.0005468691233545542, 0.01837681420147419, 0.03491391986608505, 0.7517433166503906, 0.00028147027478553355, 0.19288486242294312, 0.0], [0.0005560970166698098, 0.0002987806510645896, 0.0021934551186859608, 0.00023410467838402838, 0.023030919954180717, 0.05263887345790863, 0.01838914304971695, 0.0007265828317031264, 0.9019319415092468, 0.0], [0.007445591501891613, 0.0020796440076082945, 0.012208829633891582, 0.001590645289979875, 0.09274771064519882, 0.017371611669659615, 0.04761578515172005, 0.004260089714080095, 0.8146799802780151, 0.0], [0.014990360476076603, 0.004210897721350193, 0.002848376054316759, 0.0006518716691061854, 0.0007818753365427256, 0.0019951288122683764, 0.0036728696431964636, 0.0004030312702525407, 0.9704453349113464, 0.0]], [[0.21779413521289825, 0.08220235258340836, 0.04201545566320419, 0.07069981843233109, 0.041075702756643295, 0.13784317672252655, 0.1975526064634323, 0.04344295710325241, 0.16737376153469086, 0.0], [0.23605762422084808, 0.07441659271717072, 0.04143041744828224, 0.05435749515891075, 0.0077708023600280285, 0.0960790365934372, 0.4399828016757965, 0.006641789805144072, 
0.04326343908905983, 0.0], [0.06337786465883255, 0.03357791155576706, 0.03929098695516586, 0.5017232298851013, 0.0066258725710213184, 0.009236367419362068, 0.1690734624862671, 0.0422079935669899, 0.13488635420799255, 0.0], [0.006272959988564253, 0.0007428607787005603, 0.0011506476439535618, 0.007357995491474867, 0.0006080326274968684, 0.05679970234632492, 0.8685706257820129, 0.03271445259451866, 0.025782890617847443, 0.0], [0.041861388832330704, 0.004794578067958355, 0.0024879220873117447, 0.015253551304340363, 0.0005973980878479779, 0.08281483501195908, 0.814189076423645, 0.006639576051384211, 0.03136153519153595, 0.0], [0.010862020775675774, 0.0008270516409538686, 0.00023008826246950775, 0.006298262160271406, 0.0022151959128677845, 0.09469958394765854, 0.8416994214057922, 0.0006256845663301647, 0.04254243150353432, 0.0], [0.00024508681963197887, 3.835038296529092e-05, 2.0304802092141472e-05, 0.00012946058996021748, 0.0003255259362049401, 0.0026247953064739704, 0.9805192947387695, 0.00014136231038719416, 0.01595580205321312, 0.0], [0.001919803791679442, 0.0005674636922776699, 0.0002780239738058299, 0.0008655164856463671, 0.0013816945720463991, 0.010561172850430012, 0.05357982590794563, 0.0009362901910208166, 0.9299100637435913, 0.0], [0.00319756381213665, 0.0005108749028295279, 0.00043022894533351064, 0.005312783177942038, 0.005197612568736076, 0.008492776192724705, 0.05858352780342102, 0.01401757076382637, 0.9042569398880005, 0.0], [0.00021474930690601468, 0.0004951281007379293, 0.00032367443782277405, 0.0001866286911536008, 6.129321263870224e-05, 0.00016246296581812203, 0.0016925180098041892, 0.000427676277467981, 0.996435821056366, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9262088537216187, 0.07379112392663956, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2983383536338806, 0.576672375202179, 0.12498921155929565, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3100782334804535, 0.1274886280298233, 0.5286650061607361, 0.033768050372600555, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0], [0.3118414282798767, 0.11087317764759064, 0.12077098339796066, 0.10916762799024582, 0.34734681248664856, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1361667662858963, 0.0034004957415163517, 0.00320720998570323, 0.0056303562596440315, 0.013746269047260284, 0.8378488421440125, 0.0, 0.0, 0.0, 0.0], [0.9168469905853271, 0.009582683444023132, 0.002923850901424885, 0.009140468202531338, 0.0233402531594038, 0.01968987099826336, 0.01847577467560768, 0.0, 0.0, 0.0], [0.4528708755970001, 0.012551077641546726, 0.013286955654621124, 0.003301329677924514, 0.024005549028515816, 0.0439622700214386, 0.03865182027220726, 0.41137006878852844, 0.0, 0.0], [0.06380993872880936, 0.0008893097401596606, 0.0011801879154518247, 0.0013187900185585022, 0.0034512828569859266, 0.0014297974994406104, 0.0023058890365064144, 0.041651248931884766, 0.8839635848999023, 0.0], [0.5330018997192383, 0.012773798778653145, 0.01854255609214306, 0.022641947492957115, 0.1288023591041565, 0.01178218238055706, 0.020595960319042206, 0.08756020665168762, 0.09921147674322128, 0.06508753448724747]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9422653913497925, 0.057734500616788864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37070432305336, 0.2449311465024948, 0.3843645751476288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5423898100852966, 0.11884469538927078, 0.1850128471851349, 0.15375272929668427, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7452426552772522, 0.024770371615886688, 0.025099167600274086, 0.014617366716265678, 0.19027042388916016, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4940005838871002, 0.026306116953492165, 0.014163044281303883, 0.022562485188245773, 0.43185216188430786, 0.011115492321550846, 0.0, 0.0, 0.0, 0.0], [0.8323472142219543, 0.005361876450479031, 0.001218354911543429, 0.0017811520956456661, 0.06672050058841705, 0.0179598405957222, 0.07461105287075043, 0.0, 0.0, 0.0], [0.5900163650512695, 0.0016051119891926646, 0.00041884748497977853, 0.002425695303827524, 0.09076588600873947, 
0.005809221416711807, 0.03928956016898155, 0.2696692943572998, 0.0, 0.0], [0.14191001653671265, 0.0026981914415955544, 0.000433926354162395, 0.0025318085681647062, 0.0752185806632042, 0.041030533611774445, 0.10226735472679138, 0.6134982705116272, 0.020411266013979912, 0.0], [0.9951959252357483, 0.000172812317032367, 0.0011272057890892029, 0.0002565488684922457, 0.001650187186896801, 0.0010172545444220304, 3.585639569791965e-05, 0.00030177918961271644, 2.7251116989646107e-05, 0.00021514984837267548]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9959792494773865, 0.004020644351840019, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8763805031776428, 0.06819441169500351, 0.05542506277561188, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6675543785095215, 0.035431310534477234, 0.2554236948490143, 0.04159051924943924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8250302076339722, 0.013232334516942501, 0.10887149721384048, 0.016031241044402122, 0.03683457896113396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14042839407920837, 0.005938003305345774, 0.04128086566925049, 0.01834655925631523, 0.7866368293762207, 0.007369248662143946, 0.0, 0.0, 0.0, 0.0], [0.3567042350769043, 0.0165000781416893, 0.015264611691236496, 0.010309864766895771, 0.38396307826042175, 0.025359012186527252, 0.1918991357088089, 0.0, 0.0, 0.0], [0.03735272213816643, 0.0005555232055485249, 0.0009066119673661888, 0.003488750196993351, 0.4253699481487274, 0.039391178637742996, 0.3313658535480499, 0.1615692675113678, 0.0, 0.0], [0.0020103107672184706, 0.0002689870889298618, 0.0004340466111898422, 0.0009705349220894277, 0.03535917028784752, 0.014057940803468227, 0.07802704721689224, 0.8683921694755554, 0.0004796571738552302, 0.0], [0.21001528203487396, 0.008917403407394886, 0.08127831667661667, 0.6020672917366028, 0.0504239983856678, 0.01106872595846653, 0.002271559089422226, 0.009885885752737522, 0.013363776728510857, 0.010707534849643707]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8274853825569153, 
0.1725146621465683, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39722761511802673, 0.5465205311775208, 0.05625181272625923, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7089572548866272, 0.12511004507541656, 0.08669630438089371, 0.0792364850640297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9339975714683533, 0.013466393575072289, 0.00928713008761406, 0.00507207540795207, 0.03817704692482948, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7470325231552124, 0.0030789184384047985, 0.0006101431790739298, 0.009402818977832794, 0.23476918041706085, 0.005106179974973202, 0.0, 0.0, 0.0, 0.0], [0.21711143851280212, 0.003716376842930913, 0.00037448908551596105, 0.0019620254170149565, 0.018900232389569283, 0.009617134928703308, 0.7483181953430176, 0.0, 0.0, 0.0], [0.010075456462800503, 5.468959716381505e-05, 5.17756825502147e-06, 5.762913860962726e-05, 0.0005752856959588826, 0.0004235330270603299, 0.004707484506070614, 0.9841007590293884, 0.0, 0.0], [0.0014721885090693831, 9.766960283741355e-05, 9.390318155055866e-06, 9.01468301890418e-05, 0.00026504675042815506, 0.0001477079640608281, 0.0007441531051881611, 0.9970147013664246, 0.00015886487381067127, 0.0], [0.9506397247314453, 0.010028047487139702, 0.0004243685398250818, 0.012790095992386341, 0.006212451495230198, 0.0008045415161177516, 0.0008908100426197052, 0.0004145564162172377, 0.0002187698701163754, 0.01757662557065487]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9158000946044922, 0.0841999277472496, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9424960017204285, 0.02535107545554638, 0.032153017818927765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22060541808605194, 0.18997374176979065, 0.08500542491674423, 0.5044154524803162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7531844973564148, 0.02070058509707451, 0.008920542895793915, 0.016695866361260414, 0.20049844682216644, 0.0, 0.0, 0.0, 0.0, 0.0], [0.759453296661377, 0.0056156679056584835, 0.008695651777088642, 0.014426307752728462, 0.16163751482963562, 0.05017174035310745, 0.0, 0.0, 0.0, 0.0], 
[0.2527230679988861, 0.0006535803549923003, 0.00037003192119300365, 0.00041730765951797366, 0.057080648839473724, 0.06757333129644394, 0.6211821436882019, 0.0, 0.0, 0.0], [0.6996693015098572, 0.00526623846963048, 0.003115275641903281, 0.001864676014520228, 0.019210346043109894, 0.022201303392648697, 0.16487717628479004, 0.08379579335451126, 0.0, 0.0], [0.01643717661499977, 0.001304203411564231, 0.00015219511988107115, 8.364384120795876e-05, 0.0027460975106805563, 0.005807426758110523, 0.02910688892006874, 0.054244525730609894, 0.8901176452636719, 0.0], [0.03737838938832283, 0.0008823095704428852, 0.00013810240488965064, 0.0003819032572209835, 0.0009168537217192352, 0.017434338107705116, 0.0524771511554718, 0.5634113550186157, 0.05003770440816879, 0.27694204449653625]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9822245836257935, 0.017775410786271095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9783667922019958, 0.004186260513961315, 0.01744689606130123, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8277915120124817, 0.0035995396319776773, 0.1268300712108612, 0.04177885130047798, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9593387246131897, 0.001320014358498156, 0.002763292985036969, 0.002305841539055109, 0.03427214175462723, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5380056500434875, 0.00011044789425795898, 0.001150083844549954, 0.002725756261497736, 0.45681822299957275, 0.0011898496886715293, 0.0, 0.0, 0.0, 0.0], [0.16147758066654205, 0.001678255619481206, 0.004225697834044695, 0.012547606602311134, 0.4120558202266693, 0.030565770342946053, 0.37744930386543274, 0.0, 0.0, 0.0], [0.07655133306980133, 0.00011485892173368484, 0.0004792730906046927, 0.0037317569367587566, 0.9091346859931946, 0.005207230802625418, 0.003226343309506774, 0.0015543886693194509, 0.0, 0.0], [0.0006837816908955574, 6.692374881822616e-05, 3.2170661143027246e-05, 0.017242103815078735, 0.9703013896942139, 0.0009919245494529605, 0.00010187587758991867, 0.00012404048175085336, 0.01045528706163168, 0.0], 
[0.8681296706199646, 0.004244405776262283, 0.0034055972937494516, 0.0032342004124075174, 0.11890427023172379, 0.00032322408515028656, 1.7166490579256788e-05, 8.356601756531745e-05, 0.00016651467012707144, 0.0014914675848558545]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9673911333084106, 0.032608743757009506, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8945506811141968, 0.048047225922346115, 0.05740200728178024, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8226539492607117, 0.025171183049678802, 0.033602889627218246, 0.1185719221830368, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7488189339637756, 0.022310951724648476, 0.03220387548208237, 0.05049983412027359, 0.14616648852825165, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5947939157485962, 0.009725339710712433, 0.01194794476032257, 0.06678443402051926, 0.22137242555618286, 0.09537594765424728, 0.0, 0.0, 0.0, 0.0], [0.5493549704551697, 0.010730843059718609, 0.013811847195029259, 0.01375968661159277, 0.13386781513690948, 0.031593821942806244, 0.2468811273574829, 0.0, 0.0, 0.0], [0.44999176263809204, 0.0022518665064126253, 0.007128801662474871, 0.06941325962543488, 0.11436374485492706, 0.06527625769376755, 0.25339174270629883, 0.038182370364665985, 0.0, 0.0], [0.6273319125175476, 0.0019851899705827236, 0.014608433470129967, 0.053566914051771164, 0.10037831962108612, 0.05395424738526344, 0.09709113836288452, 0.020020073279738426, 0.031063806265592575, 0.0], [0.13732852041721344, 0.005784862674772739, 0.011142567731440067, 0.3659982979297638, 0.03412118926644325, 0.191008523106575, 0.02493627928197384, 0.01782877929508686, 0.005097466055303812, 0.2067534178495407]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9590145349502563, 0.0409853532910347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13186156749725342, 0.7104970812797546, 0.15764127671718597, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1307007521390915, 0.4791290760040283, 0.2198515087366104, 0.1703186184167862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25735223293304443, 
0.03605807572603226, 0.08834479749202728, 0.21978884935379028, 0.398455947637558, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014754761941730976, 0.016280202195048332, 0.010505245067179203, 0.26496851444244385, 0.6780229210853577, 0.015468388795852661, 0.0, 0.0, 0.0, 0.0], [0.0561433881521225, 0.00821017101407051, 0.013592599891126156, 0.04250938817858696, 0.20505541563034058, 0.637790322303772, 0.03669866546988487, 0.0, 0.0, 0.0], [0.02288638986647129, 0.0031705975998193026, 0.0010986417764797807, 0.1258203089237213, 0.13997967541217804, 0.6275703310966492, 0.004779829643666744, 0.07469423860311508, 0.0, 0.0], [0.04480466619133949, 0.007826470769941807, 0.0012622721260413527, 0.18829701840877533, 0.1579897105693817, 0.4087865948677063, 0.0030938636045902967, 0.17715193331241608, 0.010787548497319221, 0.0], [0.2647387683391571, 0.0023117128293961287, 0.5836825370788574, 0.022214042022824287, 0.05302866920828819, 0.05609899014234543, 0.0002153095556423068, 0.0012429821072146297, 0.012765316292643547, 0.0037017168942838907]]], [[[0.18620921671390533, 0.0449230894446373, 0.15743261575698853, 0.0027164025232195854, 0.000954743183683604, 0.10880818217992783, 0.004260051064193249, 0.4840531051158905, 0.010642877779901028, 0.0], [0.10068266838788986, 0.8361198902130127, 0.05278307944536209, 0.003077939385548234, 0.0006954235723242164, 0.001363753923214972, 0.00026539582177065313, 0.004202431067824364, 0.0008096573874354362, 0.0], [0.012129311449825764, 0.01155073568224907, 0.9600933194160461, 8.282387716462836e-05, 1.0725593710958492e-05, 0.0005505315493792295, 8.825069380691275e-05, 0.015057343989610672, 0.00043726651347242296, 0.0], [8.100323611870408e-05, 0.0004598332743626088, 0.004657193087041378, 0.000634010590147227, 0.00027469659107737243, 0.005632649641484022, 0.000647437758743763, 0.9867796301841736, 0.0008332319557666779, 0.0], [0.00010327257041353732, 8.895192149793729e-05, 0.0004001102061010897, 3.5898548958357424e-05, 8.903054549591616e-06, 0.002168947132304311, 
0.0003314291825518012, 0.9968016743659973, 6.082480831537396e-05, 0.0], [0.0006819640402682126, 0.0025551444850862026, 0.029635878279805183, 0.0007182788685895503, 0.0009121407056227326, 0.9391846656799316, 0.0023257755674421787, 0.020892569795250893, 0.0030933902598917484, 0.0], [0.0006610184791497886, 0.004029686562716961, 0.03350083529949188, 0.0028945906087756157, 0.06891647726297379, 0.0361749529838562, 0.6805889010429382, 0.0015104033518582582, 0.17172299325466156, 0.0], [0.00011510718468343839, 0.00041600633994676173, 0.007651225198060274, 0.0003919293521903455, 0.048794399946928024, 0.12390702962875366, 0.005600529722869396, 0.0008058404200710356, 0.8123176097869873, 0.0], [0.0003188557457178831, 0.0017433647299185395, 0.0013032852439209819, 0.008202485740184784, 0.26753997802734375, 0.1699969321489334, 0.02015369012951851, 0.026912324130535126, 0.5038290619850159, 0.0], [0.020566454157233238, 0.12752646207809448, 0.13235142827033997, 8.515831723343581e-05, 0.0007726486655883491, 0.005525102838873863, 0.002064254367724061, 0.0015006973408162594, 0.7096077799797058, 0.0]], [[0.08830718696117401, 0.003260435536503792, 0.007942354306578636, 0.007197668310254812, 0.023230358958244324, 0.6884769797325134, 0.13524922728538513, 0.013760159723460674, 0.03257569298148155, 0.0], [0.01410764642059803, 0.011476421728730202, 0.655226469039917, 0.029443562030792236, 0.17404575645923615, 0.04738258570432663, 0.035108331590890884, 0.004049936309456825, 0.02915901131927967, 0.0], [0.006112441886216402, 0.010383019223809242, 0.9739192724227905, 0.0017695348942652345, 0.0007649966282770038, 0.001380802714265883, 0.0003705607377924025, 0.00034036929719150066, 0.004958811681717634, 0.0], [0.025388794019818306, 0.006199578754603863, 0.10192698240280151, 0.0023500584065914154, 0.009979050606489182, 0.5388055443763733, 0.29305511713027954, 0.002850176068022847, 0.0194447822868824, 0.0], [0.0011180925648659468, 3.349311737110838e-05, 0.00020844468963332474, 0.00016400347521994263, 
0.001158660277724266, 0.5398337244987488, 0.4514371454715729, 0.00012239665375091136, 0.005924074444919825, 0.0], [4.934398384648375e-05, 6.905893883413228e-07, 5.809057256556116e-06, 1.44853029269143e-05, 0.0013859024038538337, 0.62599116563797, 0.3719564974308014, 0.0002632574178278446, 0.00033293903106823564, 0.0], [1.8935834305011667e-05, 5.593590231001144e-06, 9.02482042874908e-06, 4.666295353672467e-05, 0.00140501803252846, 0.0024830379988998175, 0.9939435124397278, 0.00030495785176754, 0.0017833412857726216, 0.0], [0.00015082204481586814, 9.979225069400854e-06, 0.00013493606820702553, 0.0006857623811811209, 0.9507938623428345, 0.013522839173674583, 0.004887807182967663, 0.001293701701797545, 0.028520429506897926, 0.0], [0.00021830093464814126, 1.1190621080459096e-05, 0.0010014179861173034, 0.0016852812841534615, 0.9693949818611145, 0.003066261066123843, 0.002616706071421504, 0.006246546749025583, 0.015759343281388283, 0.0], [0.033513687551021576, 0.047761499881744385, 0.1371326446533203, 0.027179328724741936, 0.07905351370573044, 0.04665757715702057, 0.017991477623581886, 0.0258343443274498, 0.5848759412765503, 0.0]], [[0.3675236701965332, 0.22013956308364868, 0.3048599064350128, 0.045011524111032486, 0.013697491027414799, 0.012050136923789978, 0.009531261399388313, 0.0020223394967615604, 0.025163909420371056, 0.0], [0.013416368514299393, 0.7244334816932678, 0.22923606634140015, 0.004823721945285797, 0.0007022434147074819, 0.0012150612892583013, 0.001360778696835041, 0.00021415007358882576, 0.024598030373454094, 0.0], [0.03640636429190636, 0.024720389395952225, 0.8944843411445618, 0.0018058173591271043, 0.00014742508938070387, 0.002046161564067006, 0.0012721297098323703, 0.0010774562833830714, 0.0380399152636528, 0.0], [0.032080236822366714, 0.02157183177769184, 0.017530914396047592, 0.21374234557151794, 0.5176447033882141, 0.021586988121271133, 0.06124785542488098, 0.004810539539903402, 0.10978466272354126, 0.0], [0.16469916701316833, 0.0144515885040164, 
0.007452514488250017, 0.029052020981907845, 0.2643658220767975, 0.1970161497592926, 0.2818319797515869, 0.016781603917479515, 0.024349281564354897, 0.0], [0.025996195152401924, 0.005627068690955639, 0.007119623012840748, 0.004898787476122379, 0.5349600911140442, 0.05678911507129669, 0.3094601333141327, 0.008422048762440681, 0.04672713205218315, 0.0], [0.004280757624655962, 0.0006373892538249493, 9.946383943315595e-05, 0.00030879577388986945, 0.02805289998650551, 0.008433223702013493, 0.9252934455871582, 0.001439885818399489, 0.03145414590835571, 0.0], [0.04426492750644684, 0.0032368048559874296, 0.0014763016952201724, 0.0021763627883046865, 0.5636131763458252, 0.010265699587762356, 0.08146306872367859, 0.003517861943691969, 0.289985716342926, 0.0], [0.012160537764430046, 0.00020874926121905446, 0.0005602578166872263, 0.0007960868533700705, 0.9389106035232544, 0.005963308271020651, 0.005384649150073528, 0.0009963578777387738, 0.035019390285015106, 0.0], [0.006462599150836468, 0.006167746149003506, 0.00141435069963336, 0.00035615835804492235, 0.0002947094908449799, 0.002378113567829132, 0.011835698038339615, 0.0024426754098385572, 0.968647837638855, 0.0]], [[0.013161101378500462, 0.01350532379001379, 0.39494189620018005, 0.007352527230978012, 0.12711142003536224, 0.14605116844177246, 0.03487401455640793, 0.15623201429843903, 0.10677067190408707, 0.0], [0.021876059472560883, 0.4906902313232422, 0.4596463143825531, 0.004091671667993069, 0.004464378114789724, 0.001156727666966617, 0.000353646173607558, 0.000146497564855963, 0.017574656754732132, 0.0], [0.005734701175242662, 0.026843877509236336, 0.9321272969245911, 0.00021884289162699133, 0.00045866103027947247, 0.0010309598874300718, 0.00017261962057091296, 0.003054215107113123, 0.030358724296092987, 0.0], [0.0482722632586956, 0.14050070941448212, 0.4546079635620117, 0.0072937230579555035, 0.023873258382081985, 0.09857403486967087, 0.0516686774790287, 0.11766187101602554, 0.05754747614264488, 0.0], 
[0.0020078516099601984, 0.002228439087048173, 0.111594557762146, 0.0033910104539245367, 0.08423032611608505, 0.17691271007061005, 0.14758752286434174, 0.4346924424171448, 0.037355244159698486, 0.0], [0.0008274781284853816, 0.0016531302826479077, 0.047970183193683624, 0.0006053023971617222, 0.22220103442668915, 0.6234129071235657, 0.05364101752638817, 0.012585645541548729, 0.03710317984223366, 0.0], [2.7583497285377234e-05, 1.1631378583842888e-05, 4.4259006244828925e-05, 0.0006730516324751079, 0.599366307258606, 0.006597205530852079, 0.3886081576347351, 0.0003169252013321966, 0.004354946780949831, 0.0], [2.752073669398669e-06, 2.0648456029448425e-06, 8.536147106497083e-06, 6.34281532256864e-05, 0.9992840886116028, 0.00028667543665505946, 7.951273437356576e-05, 3.5721727726922836e-06, 0.00026920961681753397, 0.0], [3.3996084312093444e-06, 2.1497796751646092e-06, 7.304265182028757e-06, 0.00018760550301522017, 0.99969482421875, 2.4790026145637967e-05, 3.4293629141757265e-05, 6.942725121916737e-06, 3.892222957802005e-05, 0.0], [0.0005689842510037124, 0.002939490834251046, 0.019829533994197845, 0.0003717679646797478, 0.01646142266690731, 0.011912180110812187, 0.001234701368957758, 0.0013870754046365619, 0.945294976234436, 0.0]], [[0.00632825493812561, 0.011520092375576496, 0.08263711631298065, 0.006356080062687397, 0.022936103865504265, 0.03108564019203186, 0.013897407799959183, 0.697504997253418, 0.12773430347442627, 0.0], [0.008715116418898106, 0.015272715128958225, 0.10463730990886688, 0.08011683076620102, 0.13045108318328857, 0.05373600497841835, 0.015578814782202244, 0.4212273955345154, 0.1702648103237152, 0.0], [0.004959889687597752, 0.007777809165418148, 0.14492008090019226, 0.02459821291267872, 0.014704479835927486, 0.016136664897203445, 0.008129375986754894, 0.7319321036338806, 0.0468413271009922, 0.0], [0.005315575283020735, 0.0021190166007727385, 0.007080279756337404, 0.006970370654016733, 0.010002117604017258, 0.007610250264406204, 0.004703941754996777, 
0.8570073246955872, 0.09919113665819168, 0.0], [0.0016317280242219567, 0.0005414763581939042, 0.004523266106843948, 0.0019645043648779392, 0.010821727104485035, 0.008883371017873287, 0.00927714817225933, 0.920802652835846, 0.041554201394319534, 0.0], [0.002020488725975156, 0.0007793906843289733, 0.022791940718889236, 0.005821499973535538, 0.1932065784931183, 0.30031588673591614, 0.08197023719549179, 0.12508654594421387, 0.2680076062679291, 0.0], [0.007396090775728226, 0.0032474161125719547, 0.00692824088037014, 0.007240207865834236, 0.42384257912635803, 0.04473983123898506, 0.013007782399654388, 0.007779541425406933, 0.4858182966709137, 0.0], [0.0026900237426161766, 0.0007204422145150602, 0.005861051380634308, 0.003422616282477975, 0.46744993329048157, 0.10402297228574753, 0.05837857723236084, 0.0177029799669981, 0.3397515118122101, 0.0], [0.005906206555664539, 0.002057044068351388, 0.0031123505905270576, 0.008901549503207207, 0.43650564551353455, 0.08504725992679596, 0.0923796221613884, 0.009556618519127369, 0.3565336763858795, 0.0], [0.013360978104174137, 0.04520300775766373, 0.09048072248697281, 0.012179902754724026, 0.030064363032579422, 0.023480970412492752, 0.008669134229421616, 0.03746046498417854, 0.7391002178192139, 0.0]], [[0.023652182891964912, 0.008639940991997719, 0.08203616738319397, 0.035750582814216614, 0.050224509090185165, 0.3533262312412262, 0.03081362321972847, 0.28302860260009766, 0.1325281411409378, 0.0], [0.016670020297169685, 0.1283574253320694, 0.836423397064209, 0.0042742472141981125, 0.0022883012425154448, 0.00297459471039474, 0.00022807312780059874, 0.0012588471872732043, 0.007524838205426931, 0.0], [0.031559381633996964, 0.02045642025768757, 0.8176267743110657, 0.006169404834508896, 0.0014412011951208115, 0.0069603933952748775, 0.0010916722239926457, 0.011522608809173107, 0.10317197442054749, 0.0], [0.004598122555762529, 0.004610949195921421, 0.01865001954138279, 0.020574036985635757, 0.0137012405321002, 0.7973257303237915, 
0.01646837778389454, 0.023596635088324547, 0.1004747673869133, 0.0], [0.0005213705007918179, 0.00018707667186390609, 0.0016978917410597205, 0.019619440659880638, 0.009308884851634502, 0.8590161800384521, 0.024511896073818207, 0.06970686465501785, 0.015430280938744545, 0.0], [0.0001481063081882894, 2.072651477647014e-05, 0.00035672096419148147, 0.00033358228392899036, 0.00040588833508081734, 0.9861487746238708, 0.00651955883949995, 0.00443643843755126, 0.0016300288261845708, 0.0], [0.0010996124474331737, 0.0011850595474243164, 0.0075045316480100155, 0.004539311397820711, 0.05570072680711746, 0.18870605528354645, 0.23963898420333862, 0.013960372656583786, 0.487665593624115, 0.0], [0.0003884119214490056, 0.0004658032557927072, 0.028157439082860947, 0.0002352961164433509, 0.1278570294380188, 0.08260466903448105, 0.02582997828722, 0.022790132090449333, 0.7116712927818298, 0.0], [0.0015414542285725474, 0.0007310948567464948, 0.010464987717568874, 0.0012846259633079171, 0.45206302404403687, 0.029316790401935577, 0.04706822335720062, 0.018986493349075317, 0.4385431706905365, 0.0], [0.0005072542116977274, 0.0011837932979688048, 0.01220926083624363, 8.532252832083032e-05, 0.0018606879748404026, 0.010199862532317638, 0.0016309961210936308, 0.010775143280625343, 0.9615475535392761, 0.0]], [[0.29744189977645874, 0.04770943149924278, 0.09888078272342682, 0.19768767058849335, 0.048243775963783264, 0.12058595567941666, 0.05976371467113495, 0.03847452625632286, 0.09121233224868774, 0.0], [0.04126456007361412, 0.6604095697402954, 0.028894882649183273, 0.20104490220546722, 0.0014044500421732664, 0.0009343607816845179, 0.00244489056058228, 0.007453228812664747, 0.05614929273724556, 0.0], [0.008357543498277664, 0.0022072584833949804, 0.9876156449317932, 8.841200906317681e-05, 1.4883004041621462e-05, 0.00011741811613319442, 2.7020510970032774e-05, 0.00016062626673374325, 0.001411277218721807, 0.0], [0.06216944754123688, 0.48559242486953735, 0.042546145617961884, 0.034007471054792404, 
0.047574639320373535, 0.12490913271903992, 0.07922931015491486, 0.013364763930439949, 0.11060672253370285, 0.0], [0.05222959443926811, 0.025416702032089233, 0.02865077182650566, 0.17457211017608643, 0.03144511207938194, 0.3907364010810852, 0.19607771933078766, 0.05274118855595589, 0.04813018813729286, 0.0], [0.0037726862356066704, 0.0031579534988850355, 0.0029440780635923147, 0.0017320584738627076, 0.060473062098026276, 0.761774480342865, 0.1523173600435257, 0.0058823637664318085, 0.007945872843265533, 0.0], [0.0020738786552101374, 0.0012752892216667533, 0.0004058163322042674, 0.020963717252016068, 0.39340031147003174, 0.012434415519237518, 0.4783190190792084, 0.011497312225401402, 0.0796302929520607, 0.0], [5.31752230017446e-05, 1.4492364243778866e-05, 7.312332309084013e-05, 0.0023682843893766403, 0.9866323471069336, 0.0009243910317309201, 0.0011850211303681135, 0.0017622504383325577, 0.0069872229360044, 0.0], [4.074166645295918e-05, 1.823456841520965e-05, 0.0001418270985595882, 0.007263784296810627, 0.9604514241218567, 0.0001852070417953655, 0.00034164052340202034, 0.0018497714772820473, 0.029707150533795357, 0.0], [0.0133396340534091, 0.03136875480413437, 0.6319980621337891, 0.0033722908701747656, 0.04728742688894272, 0.03541773557662964, 0.009523973800241947, 0.03100484237074852, 0.1966874897480011, 0.0]], [[0.03367111459374428, 0.018932543694972992, 0.09506545215845108, 0.04718795791268349, 0.028798582032322884, 0.33658939599990845, 0.02586139366030693, 0.29842811822891235, 0.11546547710895538, 0.0], [0.006203038617968559, 0.0906001627445221, 0.6977949738502502, 0.018352899700403214, 0.06787873804569244, 0.04403599724173546, 0.001631368650123477, 0.024296771734952927, 0.049206044524908066, 0.0], [0.006243667099624872, 0.010453532449901104, 0.7879610657691956, 0.004093538969755173, 0.0008473669877275825, 0.027760563418269157, 0.0003080451278947294, 0.14831961691379547, 0.014012438245117664, 0.0], [0.004387176129966974, 0.023410169407725334, 0.17247918248176575, 
0.03958609700202942, 0.023799436166882515, 0.43659475445747375, 0.014754846692085266, 0.2318120151758194, 0.05317622795701027, 0.0], [0.0020952164195477962, 0.0024118656292557716, 0.028229335322976112, 0.007075420115143061, 0.019164882600307465, 0.5397294163703918, 0.034580815583467484, 0.3465326428413391, 0.020180128514766693, 0.0], [0.00020744462381117046, 0.00036016973899677396, 0.004934145137667656, 0.0004664760490413755, 0.008187839761376381, 0.9661812782287598, 0.009987047873437405, 0.003882928751409054, 0.005792597308754921, 0.0], [3.4081476769642904e-05, 1.7181657312903553e-05, 5.4824478866066784e-05, 0.00045897584641352296, 0.0043338024988770485, 0.001544477418065071, 0.9909620881080627, 2.356152981519699e-05, 0.0025708049070090055, 0.0], [0.0001047314508468844, 0.0001599654060555622, 0.001310097286477685, 0.001540280063636601, 0.833267331123352, 0.044754061847925186, 0.0028599577490240335, 0.0006454077665694058, 0.11535807698965073, 0.0], [8.819431968731806e-05, 6.364465662045404e-05, 0.00022057128080632538, 0.001112746773287654, 0.9560981392860413, 0.003599100047722459, 0.0002217600413132459, 0.0006697923527099192, 0.03792598471045494, 0.0], [0.0018130787648260593, 0.022020958364009857, 0.12822051346302032, 0.0005810249131172895, 0.03168048337101936, 0.014293116517364979, 0.002500524278730154, 0.0212943647056818, 0.7775959372520447, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6252409815788269, 0.3747589886188507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8520486354827881, 0.010580658912658691, 0.13737063109874725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05910082906484604, 0.011589597910642624, 0.877491295337677, 0.051818281412124634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3626183867454529, 0.026959313079714775, 0.07612177729606628, 0.13077552616596222, 0.4035249352455139, 0.0, 0.0, 0.0, 0.0, 0.0], [0.21979263424873352, 0.001410112832672894, 0.007092535495758057, 0.13166557252407074, 0.626970648765564, 0.013068560510873795, 0.0, 0.0, 0.0, 
0.0], [0.08148042857646942, 0.001490423921495676, 0.004908325150609016, 0.01383854728192091, 0.7959722876548767, 0.05201547220349312, 0.05029459297657013, 0.0, 0.0, 0.0], [0.03934427723288536, 5.908778257435188e-05, 0.00014962907880544662, 0.005592166446149349, 0.7025003433227539, 0.1675100177526474, 0.03920353576540947, 0.04564077779650688, 0.0, 0.0], [0.4660189151763916, 0.00034756408422254026, 9.701005183160305e-05, 0.008154522627592087, 0.08121690154075623, 0.15592943131923676, 0.11426379531621933, 0.17044323682785034, 0.0035288764629513025, 0.0], [0.3707294762134552, 0.0020887483842670918, 0.23984688520431519, 0.07748916745185852, 0.18109895288944244, 0.03584783151745796, 0.005205830093473196, 0.005058187525719404, 0.0050886403769254684, 0.0775463655591011]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15256483852863312, 0.8474349975585938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08618302643299103, 0.30268052220344543, 0.6111364364624023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6251113414764404, 0.14608541131019592, 0.21724094450473785, 0.011562197469174862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.31851068139076233, 0.11805614084005356, 0.02926168404519558, 0.0854775682091713, 0.44869405031204224, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23099647462368011, 0.015003926120698452, 0.0028121687937527895, 0.025386620312929153, 0.5829272270202637, 0.14287345111370087, 0.0, 0.0, 0.0, 0.0], [0.2648485600948334, 0.01456066407263279, 0.008421574719250202, 0.01653379574418068, 0.25845009088516235, 0.35933130979537964, 0.07785411924123764, 0.0, 0.0, 0.0], [0.21031156182289124, 0.00652333116158843, 0.005756322760134935, 0.019128819927573204, 0.2526819407939911, 0.49096593260765076, 0.008809886872768402, 0.00582215515896678, 0.0, 0.0], [0.11555754393339157, 0.00475481478497386, 0.0013921409845352173, 0.045808907598257065, 0.29882168769836426, 0.3024459183216095, 0.0483231395483017, 0.18265680968761444, 0.0002390409354120493, 0.0], [0.8451279401779175, 0.021679740399122238, 
0.035543736070394516, 0.005811640061438084, 0.04445958510041237, 0.018052000552415848, 0.0015424924204126, 0.013668404892086983, 0.012673787772655487, 0.0014405279653146863]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9927853345870972, 0.007214863318949938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011021426878869534, 0.007158290129154921, 0.9818204641342163, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.007071706000715494, 0.026167649775743484, 0.19316613674163818, 0.773594319820404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.320003479719162, 0.03976304829120636, 0.22334550321102142, 0.24320250749588013, 0.17368540167808533, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10932182520627975, 0.001151762087829411, 0.007792286574840546, 0.18981949985027313, 0.6517421007156372, 0.04017229378223419, 0.0, 0.0, 0.0, 0.0], [0.02538878843188286, 0.005211540497839451, 0.03069700486958027, 0.13252338767051697, 0.4279623329639435, 0.0899164006114006, 0.28830063343048096, 0.0, 0.0, 0.0], [0.010537173599004745, 0.0007831656257621944, 0.0007035965682007372, 0.015162549912929535, 0.9050821661949158, 0.05248205363750458, 0.01132790744304657, 0.00392116466537118, 0.0, 0.0], [0.005222301464527845, 0.003575690556317568, 0.0029950442258268595, 0.00018454395467415452, 0.0012630765559151769, 0.01364975143224001, 0.09376595914363861, 0.853415846824646, 0.02592780999839306, 0.0], [0.14979584515094757, 0.0004723063320852816, 0.4970340430736542, 0.03214645013213158, 0.022075939923524857, 0.006538126152008772, 0.0013381451135501266, 0.0030305178370326757, 0.0008045822032727301, 0.28676414489746094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9691458940505981, 0.03085414692759514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9338735938072205, 0.02144204080104828, 0.04468445107340813, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4091326594352722, 0.1788463294506073, 0.3530478775501251, 0.058973249047994614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8083640336990356, 0.0245783980935812, 0.02959858626127243, 
0.02002020739018917, 0.11743883788585663, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6256738901138306, 0.03313886746764183, 0.03255102410912514, 0.015011090785264969, 0.27659764885902405, 0.017027597874403, 0.0, 0.0, 0.0, 0.0], [0.2970131039619446, 0.01776941865682602, 0.015323061496019363, 0.014444534666836262, 0.2387886643409729, 0.36828577518463135, 0.048375438898801804, 0.0, 0.0, 0.0], [0.16347570717334747, 0.01386126596480608, 0.012116431258618832, 0.006670618429780006, 0.5951986312866211, 0.1577492356300354, 0.024585027247667313, 0.02634291537106037, 0.0, 0.0], [0.1568753868341446, 0.002166055142879486, 0.0014692704426124692, 0.009539359249174595, 0.7249224781990051, 0.0696585550904274, 0.02269914373755455, 0.010646837763488293, 0.0020231890957802534, 0.0], [0.6687246561050415, 0.003988182172179222, 0.00992897991091013, 0.00877397134900093, 0.07160260528326035, 0.14080072939395905, 0.01739262230694294, 0.04941429942846298, 0.01782085746526718, 0.011553076095879078]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.497504860162735, 0.502495288848877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028444888070225716, 0.01678420603275299, 0.9547709822654724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02853180095553398, 0.022399114444851875, 0.7835201025009155, 0.1655489057302475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.023048963397741318, 0.055082567036151886, 0.3371332883834839, 0.25099456310272217, 0.33374062180519104, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013693265616893768, 0.057373203337192535, 0.02566814236342907, 0.11711565405130386, 0.13761301338672638, 0.6485366225242615, 0.0, 0.0, 0.0, 0.0], [0.5831283926963806, 0.0857725590467453, 0.06227085366845131, 0.03169894590973854, 0.06183577701449394, 0.01752074435353279, 0.15777261555194855, 0.0, 0.0, 0.0], [0.0033312023151665926, 0.003545752028003335, 0.0018331086030229926, 0.05265560373663902, 0.047756411135196686, 0.045255228877067566, 0.20667387545108795, 0.6389486193656921, 0.0, 0.0], [0.02047032117843628, 0.03542931377887726, 
0.01270933635532856, 0.46998995542526245, 0.035482652485370636, 0.015606570988893509, 0.1128709465265274, 0.03180817514657974, 0.26563259959220886, 0.0], [0.027955254539847374, 0.024354776367545128, 0.4609973132610321, 0.07958999276161194, 0.34062448143959045, 0.0068156360648572445, 0.000798556546214968, 0.0009541919571347535, 0.00023223790049087256, 0.05767740309238434]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.909331738948822, 0.09066825360059738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3865880072116852, 0.017979737371206284, 0.5954321622848511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5815560817718506, 0.15706834197044373, 0.052335821092128754, 0.2090395838022232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7268465161323547, 0.0363004133105278, 0.07873083651065826, 0.06576839834451675, 0.09235385805368423, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6550694704055786, 0.019533857703208923, 0.042362816631793976, 0.07321250438690186, 0.06519921869039536, 0.14462217688560486, 0.0, 0.0, 0.0, 0.0], [0.48597243428230286, 0.05253118649125099, 0.06572883576154709, 0.06831242144107819, 0.06681334227323532, 0.09225586801767349, 0.168385848402977, 0.0, 0.0, 0.0], [0.2283225953578949, 0.01085133571177721, 0.0076954541727900505, 0.03403906524181366, 0.05505141243338585, 0.11318682134151459, 0.23008716106414795, 0.3207661509513855, 0.0, 0.0], [0.31019407510757446, 0.01576145552098751, 0.006604246329516172, 0.1025082990527153, 0.11805430799722672, 0.0999068170785904, 0.17944715917110443, 0.09494999051094055, 0.07257375121116638, 0.0], [0.028495613485574722, 0.00728303287178278, 0.028978589922189713, 0.21746259927749634, 0.0312367994338274, 0.01134485937654972, 0.002138715935871005, 0.0005697175511159003, 0.00012198994954815134, 0.6723678112030029]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9973775148391724, 0.0026223897002637386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8086934685707092, 0.08078567683696747, 0.11052089184522629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.14967022836208344, 0.05171789228916168, 0.3914002478122711, 0.40721163153648376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.40780311822891235, 0.04434635117650032, 0.05232110992074013, 0.3448564112186432, 0.15067294239997864, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3417491614818573, 0.023165758699178696, 0.008621969260275364, 0.03819064050912857, 0.566249430179596, 0.022023199126124382, 0.0, 0.0, 0.0, 0.0], [0.13283461332321167, 0.0027981544844806194, 0.001892031985335052, 0.057958006858825684, 0.4807162284851074, 0.22431829571723938, 0.09948258846998215, 0.0, 0.0, 0.0], [0.5702553391456604, 0.005225116387009621, 0.0014312443090602756, 0.028526127338409424, 0.15899939835071564, 0.05284468084573746, 0.022491520270705223, 0.16022635996341705, 0.0, 0.0], [0.005209033377468586, 6.901475717313588e-05, 5.760595013271086e-05, 0.006149875931441784, 0.006613760255277157, 0.010193211026489735, 0.013639912940561771, 0.9578513503074646, 0.00021631908020935953, 0.0], [0.5839820504188538, 0.007275882177054882, 0.03890826180577278, 0.2169828861951828, 0.02285575494170189, 0.0033320344518870115, 0.0027764069382101297, 0.032896872609853745, 0.038299210369586945, 0.05269058048725128]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.050016310065984726, 0.9499835968017578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1501082479953766, 0.3363426625728607, 0.5135491490364075, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3278971016407013, 0.3615517318248749, 0.08450257778167725, 0.22604861855506897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25667694211006165, 0.40780505537986755, 0.15422186255455017, 0.10097997635602951, 0.08031607419252396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06381947547197342, 0.026328660547733307, 0.009785240516066551, 0.001955200219526887, 0.6504424810409546, 0.24766895174980164, 0.0, 0.0, 0.0, 0.0], [0.20312389731407166, 0.04875075817108154, 0.01619674637913704, 0.028123438358306885, 0.08579143136739731, 0.20417368412017822, 0.41383999586105347, 0.0, 0.0, 0.0], [0.0028631098102778196, 
0.00043423930765129626, 0.00021322182146832347, 0.0004598038794938475, 0.009248310700058937, 0.010348621755838394, 0.14874719083309174, 0.8276853561401367, 0.0, 0.0], [0.0018436884274706244, 0.00014077842934057117, 0.00019285999587737024, 0.0006260477821342647, 0.010675687342882156, 0.007219970691949129, 0.05410425364971161, 0.9211149215698242, 0.0040815602988004684, 0.0], [0.0996592566370964, 0.00012745348794851452, 0.0005518114776350558, 0.00026576913660392165, 0.00016320105351042002, 9.30886235437356e-05, 0.00024295282491948456, 0.000212193961488083, 0.0026789114344865084, 0.8960053324699402]]], [[[0.11433855444192886, 0.04686390608549118, 0.05050662159919739, 0.01692698895931244, 0.014815520495176315, 0.0768260732293129, 0.017432983964681625, 0.6101956367492676, 0.052093639969825745, 0.0], [0.38972416520118713, 0.22944767773151398, 0.07904881238937378, 0.052096493542194366, 0.007011168636381626, 0.006784902419894934, 0.0014207185013219714, 0.13426420092582703, 0.10020176321268082, 0.0], [0.11186019331216812, 0.017331605777144432, 0.04338392615318298, 0.012996343895792961, 0.0037596735637634993, 0.04186789691448212, 0.010188347660005093, 0.40130615234375, 0.35730576515197754, 0.0], [0.10591340810060501, 0.20223784446716309, 0.2886512279510498, 0.06966178864240646, 0.009836114011704922, 0.0229511596262455, 0.008547846227884293, 0.14636646211147308, 0.14583423733711243, 0.0], [0.07172418385744095, 0.2936038672924042, 0.03526498004794121, 0.13891401886940002, 0.06139945611357689, 0.03925776481628418, 0.05349786579608917, 0.1478489190340042, 0.15848881006240845, 0.0], [0.06896000355482101, 0.07670743763446808, 0.04746192321181297, 0.0077508557587862015, 0.0064402432180941105, 0.0690828412771225, 0.07239065319299698, 0.3534849286079407, 0.2977212965488434, 0.0], [0.07854402810335159, 0.05315924435853958, 0.006970829796046019, 0.01197806466370821, 0.15678070485591888, 0.059328123927116394, 0.0844779834151268, 0.058106619864702225, 0.4906543493270874, 0.0], 
[0.03681005910038948, 0.0140389958396554, 0.007565703243017197, 0.004180380143225193, 0.03550711274147034, 0.08768045157194138, 0.04672156274318695, 0.05331620201468468, 0.7141793966293335, 0.0], [0.011276278644800186, 0.0043571279384195805, 0.0015699869254603982, 0.009309794753789902, 0.5466312766075134, 0.06633520126342773, 0.012565904296934605, 0.036605555564165115, 0.3113488256931305, 0.0], [0.06600549072027206, 0.025541657581925392, 0.15132947266101837, 0.052793603390455246, 0.07684693485498428, 0.05682613328099251, 0.01590665802359581, 0.05306769907474518, 0.501682460308075, 0.0]], [[0.02032626047730446, 0.020203230902552605, 0.6020291447639465, 0.030009541660547256, 0.018654465675354004, 0.18802858889102936, 0.05143868178129196, 0.02303317002952099, 0.04627683013677597, 0.0], [0.0019176346249878407, 0.016540443524718285, 0.9535443782806396, 0.023291945457458496, 0.00010982116509694606, 0.0009689715225249529, 0.00013850984396412969, 0.00020587266772054136, 0.0032822038047015667, 0.0], [0.0021921033039689064, 0.01942657120525837, 0.9639623165130615, 0.007353355176746845, 5.936318120802753e-05, 0.0048356493934988976, 6.20722203166224e-05, 0.00038220148417167366, 0.0017266407376155257, 0.0], [0.009460263885557652, 0.051493529230356216, 0.3948231041431427, 0.009338779374957085, 0.006950095761567354, 0.48816555738449097, 0.01867900975048542, 0.0028053568676114082, 0.018284155055880547, 0.0], [0.003424277761951089, 0.009652957320213318, 0.10005363076925278, 0.026136387139558792, 0.09182560443878174, 0.6324647068977356, 0.07658552378416061, 0.005459657870233059, 0.05439731851220131, 0.0], [0.0026005429681390524, 0.0029943317640572786, 0.26352012157440186, 0.02426978573203087, 0.05801504850387573, 0.49633610248565674, 0.11849544942378998, 0.009708826430141926, 0.02405967190861702, 0.0], [0.0027453943621367216, 0.0021119171287864447, 0.0030521987937390804, 0.09308812767267227, 0.28145554661750793, 0.015254919417202473, 0.530491828918457, 0.007408978417515755, 
0.06439103186130524, 0.0], [0.0012973180273547769, 0.002199073787778616, 0.0031004296615719795, 0.024488963186740875, 0.8535729050636292, 0.016068320721387863, 0.029179612174630165, 0.012250186875462532, 0.05784311145544052, 0.0], [0.0010010729311034083, 0.0008253253763541579, 0.0028483583591878414, 0.028342707082629204, 0.860925018787384, 0.0038871155120432377, 0.006998666562139988, 0.01413769368082285, 0.08103384077548981, 0.0], [0.007383578456938267, 0.056256651878356934, 0.5807297825813293, 0.01667044125497341, 0.03810223564505577, 0.07880110293626785, 0.009197888895869255, 0.12926581501960754, 0.08359251171350479, 0.0]], [[0.023356424644589424, 0.012650059536099434, 0.05017145350575447, 0.05590398982167244, 0.05159280076622963, 0.01602507382631302, 0.014807065948843956, 0.654244601726532, 0.12124844640493393, 0.0], [0.030835414305329323, 0.04180247709155083, 0.029645785689353943, 0.20071062445640564, 0.010328685864806175, 0.03208288922905922, 0.026780622079968452, 0.457701712846756, 0.17011170089244843, 0.0], [0.02856343612074852, 0.03459611535072327, 0.15441730618476868, 0.04662194848060608, 0.0013040672056376934, 0.017847269773483276, 0.02464178577065468, 0.5969575643539429, 0.09505032747983932, 0.0], [0.03350958973169327, 0.02514287829399109, 0.027676144614815712, 0.11052078753709793, 0.15496152639389038, 0.08862635493278503, 0.027723105624318123, 0.24766287207603455, 0.28417670726776123, 0.0], [0.014355039224028587, 0.005383878946304321, 0.002517768880352378, 0.09422861039638519, 0.06622537225484848, 0.046315327286720276, 0.08473969250917435, 0.4999735355377197, 0.18626095354557037, 0.0], [0.002460801973938942, 0.0016284199664369226, 0.005857668351382017, 0.006880565080791712, 0.7626023292541504, 0.025456121191382408, 0.021016357466578484, 0.06090177595615387, 0.11319592595100403, 0.0], [0.007633878383785486, 0.002682786202058196, 0.0008938225219026208, 0.006808742880821228, 0.17231638729572296, 0.049100711941719055, 0.32851701974868774, 
0.0061601921916007996, 0.4258863925933838, 0.0], [0.003303236560896039, 0.0015338786179199815, 0.0017581325955688953, 0.0052335225045681, 0.24177710711956024, 0.09136255830526352, 0.06603478640317917, 0.0047843558713793755, 0.5842124223709106, 0.0], [0.011694137938320637, 0.0015430846251547337, 0.00043408613419160247, 0.005433904007077217, 0.03723231703042984, 0.1666216105222702, 0.04878358170390129, 0.024785596877336502, 0.7034717798233032, 0.0], [0.028515880927443504, 0.0183264147490263, 0.011487613432109356, 0.03205259144306183, 0.06179385632276535, 0.041277043521404266, 0.014015565626323223, 0.06198226660490036, 0.7305486798286438, 0.0]], [[0.10071786493062973, 0.017111245542764664, 0.07246935367584229, 0.01480931881815195, 0.14864948391914368, 0.20273517072200775, 0.054981958121061325, 0.25890761613845825, 0.12961813807487488, 0.0], [0.049787674099206924, 0.02108882926404476, 0.20989678800106049, 0.006962155923247337, 0.21569682657718658, 0.1622857302427292, 0.016771212220191956, 0.2403237521648407, 0.07718709856271744, 0.0], [0.024618864059448242, 0.010488898493349552, 0.4834355115890503, 0.015693388879299164, 0.07393413037061691, 0.055557433515787125, 0.007495412603020668, 0.27800077199935913, 0.05077548325061798, 0.0], [0.006324393209069967, 0.0006586945382878184, 0.02188086323440075, 0.003439908614382148, 0.055277179926633835, 0.5423230528831482, 0.1656835526227951, 0.12264314293861389, 0.08176910877227783, 0.0], [0.008246216922998428, 0.000647101376671344, 0.018551276996731758, 0.0031310885678976774, 0.04379039630293846, 0.34376823902130127, 0.2999532222747803, 0.13205647468566895, 0.1498558670282364, 0.0], [0.001800144906155765, 0.00032634654780849814, 0.02560480497777462, 0.0014933178899809718, 0.04328969866037369, 0.48067817091941833, 0.22867664694786072, 0.008819987997412682, 0.20931090414524078, 0.0], [0.0023069612216204405, 0.0018136217258870602, 0.006447605788707733, 0.005140945315361023, 0.046570103615522385, 0.045606330037117004, 
0.3236173987388611, 0.014286459423601627, 0.5542104840278625, 0.0], [0.0025225167628377676, 0.000774701707996428, 0.0168449804186821, 0.0014132045907899737, 0.1692919135093689, 0.21547472476959229, 0.19468647241592407, 0.00621472392231226, 0.392776757478714, 0.0], [0.006893941201269627, 0.0026040272787213326, 0.036687299609184265, 0.0016275923699140549, 0.13132861256599426, 0.15552441775798798, 0.23651301860809326, 0.023025648668408394, 0.4057953953742981, 0.0], [0.004257934633642435, 0.008543262258172035, 0.05716743320226669, 0.0024442216381430626, 0.027526315301656723, 0.08828678727149963, 0.025276461616158485, 0.2843557894229889, 0.5021417737007141, 0.0]], [[0.2613511085510254, 0.024888625368475914, 0.11462423205375671, 0.021279124543070793, 0.1065509021282196, 0.23139707744121552, 0.07117345929145813, 0.09822205454111099, 0.07051338255405426, 0.0], [0.08973123878240585, 0.07256940752267838, 0.3644520342350006, 0.09313907474279404, 0.10501276701688766, 0.026235496625304222, 0.035534195601940155, 0.05646198242902756, 0.15686386823654175, 0.0], [0.12105944007635117, 0.03531542792916298, 0.18099160492420197, 0.04576702043414116, 0.03264385089278221, 0.04934798926115036, 0.0072426870465278625, 0.2739674150943756, 0.25366437435150146, 0.0], [0.049101557582616806, 0.02436317875981331, 0.1119280532002449, 0.019082490354776382, 0.23333144187927246, 0.12024182081222534, 0.09606382250785828, 0.03866123780608177, 0.3072265088558197, 0.0], [0.011761177331209183, 0.004259902983903885, 0.019396282732486725, 0.010304590687155724, 0.5410462021827698, 0.1548439860343933, 0.1577453315258026, 0.022628072649240494, 0.07801424711942673, 0.0], [0.011012338101863861, 0.006456742994487286, 0.03514476120471954, 0.01111147552728653, 0.3646441400051117, 0.06045660004019737, 0.22725869715213776, 0.030072104185819626, 0.2538430392742157, 0.0], [0.0009554855059832335, 0.0010365764610469341, 0.000539954868145287, 0.013481645844876766, 0.6702913641929626, 0.013201623223721981, 
0.06565960496664047, 0.008186675608158112, 0.2266470193862915, 0.0], [0.007978711277246475, 0.0019918852485716343, 0.0007363414042629302, 0.010062554851174355, 0.10717969387769699, 0.01258536335080862, 0.08278501033782959, 0.02946571074426174, 0.7472147941589355, 0.0], [0.014369996264576912, 0.00412968173623085, 0.002898097038269043, 0.0381503589451313, 0.28382056951522827, 0.03412872180342674, 0.2624143660068512, 0.04523473232984543, 0.3148534893989563, 0.0], [0.0025127469561994076, 0.0030011499766260386, 0.0036209137178957462, 0.0006047216593287885, 0.01094596553593874, 0.0023283734917640686, 0.003409643191844225, 0.009625249542295933, 0.9639512896537781, 0.0]], [[0.023785226047039032, 0.024275904521346092, 0.6168470978736877, 0.01581703871488571, 0.026939542964100838, 0.1783975064754486, 0.04853774979710579, 0.02762567065656185, 0.03777410835027695, 0.0], [0.006065902300179005, 0.04932599142193794, 0.8176359534263611, 0.018976736813783646, 0.008159944787621498, 0.011068272404372692, 0.010428683832287788, 0.014124251902103424, 0.06421414017677307, 0.0], [0.0037236923817545176, 0.007844064384698868, 0.9502744674682617, 0.0048003061674535275, 0.00022506865207105875, 0.004834793973714113, 0.0015490480000153184, 0.0026021157391369343, 0.024146683514118195, 0.0], [0.0007564057596027851, 0.0017460802337154746, 0.15768493711948395, 0.004074132069945335, 0.015430302359163761, 0.7368869781494141, 0.028010869398713112, 0.013945921324193478, 0.04146439954638481, 0.0], [0.0008445779676549137, 0.0015138997696340084, 0.17073306441307068, 0.0074179465882480145, 0.08121992647647858, 0.5853323936462402, 0.09402737021446228, 0.024092217907309532, 0.034818582236766815, 0.0], [0.0006014688406139612, 0.0016882645431905985, 0.16094569861888885, 0.003698966233059764, 0.034668561071157455, 0.5876308679580688, 0.09562253206968307, 0.05209798738360405, 0.06304588913917542, 0.0], [0.00015283364336937666, 0.0004516944463830441, 0.003205003682523966, 0.0049727726727724075, 
0.10853080451488495, 0.03262018784880638, 0.6125266551971436, 0.005719948559999466, 0.2318202555179596, 0.0], [9.933842375176027e-05, 0.00020211786613799632, 0.0037883264012634754, 0.0051808832213282585, 0.6936825513839722, 0.10089477151632309, 0.023457802832126617, 0.011726793833076954, 0.16096755862236023, 0.0], [0.0002526468597352505, 0.0010056017199531198, 0.003837066935375333, 0.034950658679008484, 0.5882559418678284, 0.029549231752753258, 0.030938459560275078, 0.01461110170930624, 0.2965993583202362, 0.0], [0.0029390468262135983, 0.005815382581204176, 0.06488344818353653, 0.008705642074346542, 0.010130577720701694, 0.012970774434506893, 0.019612692296504974, 0.007819950580596924, 0.8671225309371948, 0.0]], [[0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 0.1017894372344017, 0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], [0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 0.0], [0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 
0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 0.5316829681396484, 0.0], [0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 0.5926627516746521, 0.0]], [[0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 0.004741673823446035, 0.46527597308158875, 0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 
0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 0.010398894548416138, 0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}}" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + "/**\n", + " * @fileoverview Transformer Visualization D3 javascript code.\n", + " */\n", + "\n", + "requirejs(['jquery', 'd3'],\n", + "function($, d3) {\n", + "\n", + "var attention = window.attention;\n", + "\n", + "const TEXT_SIZE = 15;\n", + "const BOXWIDTH = TEXT_SIZE * 8;\n", + "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", + "const WIDTH = 2000;\n", + "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", + "const MATRIX_WIDTH = 150;\n", + "const head_colours = d3.scale.category10();\n", + "const CHECKBOX_SIZE = 20;\n", + "\n", + "function lighten(colour) {\n", + " var c = d3.hsl(colour);\n", + " var increment = (1 - c.l) * 0.6;\n", + " c.l += increment;\n", + " c.s -= increment;\n", + " return c;\n", + "}\n", + "\n", + "function transpose(mat) {\n", + " return mat[0].map(function(col, i) {\n", + " return 
mat.map(function(row) {\n", + " return row[i];\n", + " });\n", + " });\n", + "}\n", + "\n", + "function zip(a, b) {\n", + " return a.map(function (e, i) {\n", + " return [e, b[i]];\n", + " });\n", + "}\n", + "\n", + "\n", + "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", + " $(id).empty();\n", + " var svg = d3.select(id)\n", + " .append('svg')\n", + " .attr(\"width\", WIDTH)\n", + " .attr(\"height\", HEIGHT);\n", + "\n", + " var att_data = [];\n", + " for (var i=0; i < attention_heads.length; i++) {\n", + " var att_trans = transpose(attention_heads[i]);\n", + " att_data.push(zip(attention_heads[i], att_trans));\n", + " }\n", + "\n", + " renderText(svg, top_text, true, att_data, 0);\n", + " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", + "\n", + " renderAttentionHighlights(svg, att_data);\n", + "\n", + " svg.append(\"g\").classed(\"attention_heads\", true);\n", + "\n", + " renderAttention(svg, attention_heads);\n", + "\n", + " draw_checkboxes(config, 0, svg, attention_heads);\n", + "}\n", + "\n", + "\n", + "function renderText(svg, text, is_top, att_data, left_pos) {\n", + " var id = is_top ? 
\"top\" : \"bottom\";\n", + " var textContainer = svg.append(\"svg:g\")\n", + " .attr(\"id\", id);\n", + "\n", + " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", + " .selectAll(\"g\")\n", + " .data(att_data)\n", + " .enter()\n", + " .append(\"g\")\n", + " .selectAll(\"rect\")\n", + " .data(function(d) {return d;})\n", + " .enter()\n", + " .append(\"rect\")\n", + " .attr(\"x\", function(d, i, j) {\n", + " return left_pos + box_offset(j);\n", + " })\n", + " .attr(\"y\", function(d, i) {\n", + " return (+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .attr(\"fill\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .style(\"opacity\", 0.0);\n", + "\n", + "\n", + " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", + " .data(text)\n", + " .enter()\n", + " .append(\"g\");\n", + "\n", + " tokenContainer.append(\"rect\")\n", + " .classed(\"background\", true)\n", + " .style(\"opacity\", 0.0)\n", + " .attr(\"fill\", \"lightgray\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH)\n", + " .attr(\"height\", BOXHEIGHT);\n", + "\n", + " var theText = tokenContainer.append(\"text\")\n", + " .text(function(d) { return d; })\n", + " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", + " .style(\"cursor\", \"default\")\n", + " .style(\"-webkit-user-select\", \"none\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " });\n", + "\n", + " if (is_top) {\n", + " theText.style(\"text-anchor\", \"end\")\n", + " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " } else {\n", + " theText.style(\"text-anchor\", \"start\")\n", + " .attr(\"dx\", + TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " }\n", + "\n", + " tokenContainer.on(\"mouseover\", 
function(d, index) {\n", + " textContainer.selectAll(\".background\")\n", + " .style(\"opacity\", function(d, i) {\n", + " return i == index ? 1.0 : 0.0;\n", + " });\n", + "\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", + "\n", + " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", + " .selectAll(\".att_lines\")\n", + " .attr(\"stroke-opacity\", function(d) {\n", + " return 1.0;\n", + " })\n", + " .attr(\"y1\", function(d, i) {\n", + " if (is_top) {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x1\", BOXWIDTH)\n", + " .attr(\"y2\", function(d, i) {\n", + " if (is_top) {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", + " .attr(\"stroke-width\", 2)\n", + " .attr(\"stroke\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .attr(\"stroke-opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j]) {\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " });\n", + "\n", + "\n", + " function updateAttentionBoxes() {\n", + " var id = is_top ? \"bottom\" : \"top\";\n", + " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", + " svg.select(\"#\" + id)\n", + " .selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", + " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .style(\"opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j])\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " else\n", + " return 0.0;\n", + "\n", + " });\n", + " }\n", + "\n", + " updateAttentionBoxes();\n", + " });\n", + "\n", + " textContainer.on(\"mouseleave\", function() {\n", + " d3.select(this).selectAll(\".background\")\n", + " .style(\"opacity\", 0.0);\n", + "\n", + " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", + " svg.selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .style(\"opacity\", 0.0);\n", + " });\n", + "}\n", + "\n", + "function renderAttentionHighlights(svg, attention) {\n", + " var line_container = svg.append(\"g\");\n", + " line_container.selectAll(\"g\")\n", + " .data(attention)\n", + " .enter()\n", + " .append(\"g\")\n", + " .classed(\"line_heads\", true)\n", + " .selectAll(\"line\")\n", + " .data(function(d){return d;})\n", + " .enter()\n", + " .append(\"line\").classed(\"att_lines\", true);\n", + "}\n", + "\n", + "function renderAttention(svg, attention_heads) {\n", + " var line_container = svg.selectAll(\".attention_heads\");\n", + " line_container.html(null);\n", + " for(var h=0; h\").val(i).text(i));\n", + "}\n", + "\n", + "$(\"#layer\").on('change', function(e) {\n", + " config.layer = +e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + 
"$(\"#att_type\").on('change', function(e) {\n", + " config.att_type = e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + "$(\"button\").on('click', visualize);\n", + "\n", + "visualize();\n", + "\n", + "});\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "call_html()\n", + "display.display(display.HTML(vis_html))\n", + "display.display(display.Javascript('window.attention = %s' % attention_json))\n", + "display.display(display.Javascript(vis_js))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lydjSs3hgDVF" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Attention_Visualization_in_Trax.ipynb", + "provenance": [ + { + "file_id": "1bJu3Qx37FY9UpHqVMyXCTNb64v4Iw_v7", + "timestamp": 1598692842045 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/trax/examples/Deep_N_Gram_Models.ipynb b/resources/examples/ipynb/Deep_N_Gram_Models.ipynb similarity index 100% rename from trax/examples/Deep_N_Gram_Models.ipynb rename to resources/examples/ipynb/Deep_N_Gram_Models.ipynb diff --git a/trax/examples/Fashion_MNIST_with_Trax.ipynb b/resources/examples/ipynb/Fashion_MNIST_with_Trax.ipynb similarity index 100% rename from trax/examples/Fashion_MNIST_with_Trax.ipynb rename to resources/examples/ipynb/Fashion_MNIST_with_Trax.ipynb diff --git a/trax/examples/Knowledge_Tracing_Transformer.ipynb b/resources/examples/ipynb/Knowledge_Tracing_Transformer.ipynb similarity index 100% rename from trax/examples/Knowledge_Tracing_Transformer.ipynb rename to resources/examples/ipynb/Knowledge_Tracing_Transformer.ipynb diff --git a/trax/examples/MathQA_Python_generation_notebook.ipynb 
b/resources/examples/ipynb/MathQA_Python_generation_notebook.ipynb similarity index 100% rename from trax/examples/MathQA_Python_generation_notebook.ipynb rename to resources/examples/ipynb/MathQA_Python_generation_notebook.ipynb diff --git a/trax/examples/NER_using_Reformer.ipynb b/resources/examples/ipynb/NER_using_Reformer.ipynb similarity index 100% rename from trax/examples/NER_using_Reformer.ipynb rename to resources/examples/ipynb/NER_using_Reformer.ipynb diff --git a/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb b/resources/examples/ipynb/NMT_with_Transformers_Reformers_using_Trax.ipynb similarity index 100% rename from trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb rename to resources/examples/ipynb/NMT_with_Transformers_Reformers_using_Trax.ipynb diff --git a/trax/examples/README.md b/resources/examples/ipynb/README.md similarity index 100% rename from trax/examples/README.md rename to resources/examples/ipynb/README.md diff --git a/trax/examples/Terraformer_from_scratch.ipynb b/resources/examples/ipynb/Terraformer_from_scratch.ipynb similarity index 100% rename from trax/examples/Terraformer_from_scratch.ipynb rename to resources/examples/ipynb/Terraformer_from_scratch.ipynb diff --git a/trax/examples/earlystopping.ipynb b/resources/examples/ipynb/earlystopping.ipynb similarity index 100% rename from trax/examples/earlystopping.ipynb rename to resources/examples/ipynb/earlystopping.ipynb diff --git a/trax/examples/illustrated_wideresnet.ipynb b/resources/examples/ipynb/illustrated_wideresnet.ipynb similarity index 100% rename from trax/examples/illustrated_wideresnet.ipynb rename to resources/examples/ipynb/illustrated_wideresnet.ipynb diff --git a/trax/intro.ipynb b/resources/examples/ipynb/intro.ipynb similarity index 100% rename from trax/intro.ipynb rename to resources/examples/ipynb/intro.ipynb diff --git a/trax/layers/intro.ipynb b/resources/examples/ipynb/layers_intro.ipynb similarity index 100% rename from 
trax/layers/intro.ipynb rename to resources/examples/ipynb/layers_intro.ipynb diff --git a/resources/examples/ipynb/models/reformer/image_generation.ipynb b/resources/examples/ipynb/models/reformer/image_generation.ipynb new file mode 100644 index 000000000..28d2487cb --- /dev/null +++ b/resources/examples/ipynb/models/reformer/image_generation.ipynb @@ -0,0 +1,412 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Reformer: Image Generation", + "provenance": [], + "collapsed_sections": [ + "udDs_biH0n5U" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "udDs_biH0n5U", + "colab_type": "text" + }, + "source": [ + "#### Copyright 2020 Google LLC." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WPY-OyyM0pSs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "psnUF-8c02o_", + "colab_type": "text" + }, + "source": [ + "# Reformer: Image Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1lnRd_IoERdk", + "colab_type": "text" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8PluCmWbZIpJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install JAX. This custom build raises the TPU timeout threshold, because the\n", + "# default limit of 2 minutes is too short for sampling very long sequences.\n", + "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", + "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", + "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", + "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", + "\n", + "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", + " resp = requests.post(url)\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + 
"execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yiPdBenoZwH6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", + "\n", + "from tensorflow.compat.v1.io.gfile import GFile\n", + "import gin\n", + "import os\n", + "import jax\n", + "import trax\n", + "from trax.models.beam_search import Search\n", + "from trax.supervised import inputs\n", + "\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "from scipy.special import softmax" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yyxRk75iaAap", + "colab_type": "code", + "colab": {} + }, + "source": [ + "%matplotlib inline\n", + "from matplotlib import pyplot as plt" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FQ89jHCYfhpg" + }, + "source": [ + "## Load example data and model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qBvuw2h85WXE", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Normally we train on the full imagenet64 training set, which is quite large so\n", + "# we won't be loading it from this notebook. 
Instead, let's just load a few PNG\n", + "# images to use in our data pipeline.\n", + "DATA = []\n", + "for i in range(8):\n", + " img = plt.imread(GFile('gs://trax-ml/reformer/img{}.png'.format(i), 'rb'))\n", + " # Convert from RGBA floating-point to RGB integer representation.\n", + " img = np.asarray(img[:, :, :3] * 255, dtype=np.int32)\n", + " DATA.append(img)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "oBZh0Q2UEiaB", + "colab_type": "code", + "outputId": "d5adcac0-6f76-4c56-e6ef-74becaca87be", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 130 + } + }, + "source": [ + "# We can examine one of the images to make sure we've loaded it correctly.\n", + "plt.figure(figsize=(1.5, 1.5))\n", + "plt.axis('off')\n", + "plt.imshow(DATA[0])" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 5 + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAF8AAABfCAYAAACOTBv1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO29eaxk2X3f9znn3KXq1v72pdfp6Vk4\nC3tmyBluoklKpixSlk0pcbzJQhIkkaXAiWUpSOTEie0IiOEkBmwkCBAHhg1FSGzSlkjKDCVRpLgP\nZzjDmeF0z0z36/3129+rV3vde885+eN3q17T0jQ9DQQtIH2ABrrq3brL7/zOb/n+vud3lfee++Pe\nDH2vb+D/z+O+8O/huC/8ezjuC/8ejvvCv4cjuNMf/4+/9Re9wwHgnEZpj3MK4yRCyrVHe4V3coz3\n4HAYZ7HFvGqlOH/+Im988xWSigFAeYdTClNcxwIKj0eR2+JLL/+0kfMoLV8EPiBTclA7s6RDeODX\nPwtAZTlm41aXkw8sYm9clntkQDL4PsumDUCvvU8zewNIcPkYgDTz9IYZaVexsyXPEgaas6djMiu/\nG3RLHB7mHF6LKPcrcq5RjjKKB463ADA5ZB6+8p03+NDTpwCoJyF/+/98Xr1j4Vtl8E5+pxTgFRpP\n8RXay39cEa56QKGwKkBNJkQ5bJ4TBIDyhbAVyikwvjgPODTGO1Ivwg60xzqPLc4TBeCsJjOWsJih\nuoXmv/gXtMMhAMN2Sm25wnD9IsuJnLvX26MZ3CLubcr95BuMDkZcvbHF7uZAvhs/wGzyDGdOP0Zr\nVp5tc32DnUspYeMlAPYP1mlv1bi53uZULRKZaMVBd8wfvHQLgDj0rMyXeO7xFZK4UDQTvq1875ud\nezjuqPmB9+SFlis8hRIWy0A033uHLj5775jkbNN1pjRZnlFYjeKiigyPniZ4Bo9j7DRm8kvlMcqQ\nOdFyoxUOR+QduSg6x/7Zb3HDtinbKgBltUV9Y4t6Y4Z6+n0A0rZmPILdXgpAKQh48flNnnv4V/n5\nH38GgPm5OplJuHZzm+u3RIvPnDnF5vYup1Y/AcC1Nz7Dh37iU1y9dInf/eK/AmD7cBtGKXYogun0\nNWk6IptJ6I7lu3oluzvh57cJ0XtV2OWj75xy8mFih9ByjPfTCVIefGpxSmw7gHeTvxezqRzeycFa\nTa4HWnkCM/lCodCk1x0P/eN/BMCVwFH3muXahkzQ/qs0yOm+uYaviK1uGkdS2sLq7wIQph/nv/6r\nf4/V+ZNcu/o1AP7mf/s5fvEX/iK7Gylzhd2pNpt465hZmAHgodM/z/rFN3juxz7CtWuvifDOH5IR\nk/VFwN1xSr/f59Z6B63E7Ayz/t0JXwFq4nDRKOVxqImpBpQIuvhsvMK6XFbE5BCtyNOcQIN1E8Eq\ncuuJYrl87hxGabz15MVq0AZy5wkKr+xx0LHM/Dd/i/yxh+SYzZs08gvosTx8KT9AVc8ys7JP3O/I\nHY5eJ7F/gp/46G/Ic+SQlGP2Ll/BeLn+uSdO4dIRTz55ju+/dQmA2TDi+LFVgqgIHDwsnz5OHkac\nfebDAFx68wXqcRVdEQWZVR5nZxinOWkmqpal47sTfiFSeQjlUV6jsfjCKYoThomtcYDXCuudOFTA\nKMhthlIQFEqcA9oolJ9ESWKTPBAUBzmnQHnc5Fo+wycJJz/0LJ2OCLYZ3KLVvUqpslJM2AkWVYc0\nfYXlkWj+2XzAuHeF73z12wAsLZ7kyuWL1GoVlhdOA/CpP/soPqnzjedfZHHhGACDfofXXnyen/0P\n/wMAtq9fZn17l+6Vq+hcTFh/aOiPB0zwsdAYSqEmCkNKsTha6ytvK9n7Dvcejh+q+ZMwUikxMUop\nXBEyKjTgsEc+Eu8UGo2f2m5PlmcozXSqIw8aP/UBaIVRkkMEx
bnHTkyP1sXquAFP/M+/xltlg+ls\nAdDK9ghqyygt2rU8/h3ePdjg0VKXqHB0eZrz5sENLp8Xmx+VyqycOkZmHSqR+HxtZ5evffUL/Kmf\n+tPkA/HmX/nil/jYJ3+My1evAPC3/+avknvNw+96FxUt97jQ8OggYmJ3rfUMh552b0xWRCeTYOSu\nhK8mkQy+iNPVbcvF45VDTeJ9BQaFxaPcZIIUfpjKxEyOm7jtyXJF4RUY5cTcMIluQNtcHuJnPsXN\nH3uOK9fXWRrK71qt4/TJWfaHADxCxLl6H22H5L3i/sc5c5Ut3rv0IAAf/uk/iwlCxsMeX/29LwPQ\nnKvziU/8BLtbGyS1BID/9Jd+gd/67Of5L//aXwNg9vgCS4sn2b5ykbnSjtxTnmOUIVAiRm0slYqh\nisarifmchIjvUPhOia0H8A6MEoFMbJxxFtC44kLGWazyGCt2XSYPxi4j9RBPbLxSZNZL4oWsLu81\nTmsmWuScRQUB2TU55uFf+HmujFIeW1omUiKg9OB1FnWVObsOwJOsobMe6cCQ9cUut/cq/N5bi1wu\nvQLA73znl1mYn2fY7tJNRwA8+4EP8MwHP8gHVj9CurcNwE/+1CepJiXOPfu03EA+5NRixHh3DWdl\nzaqghPMON9FGb0B5PBr9b4Tlf9S4b/Pv4bhzkoWaLhsBFo7sP4AzQaEFogkOCU8z5THFzDtvSccD\nkuAozs8dxFozwY0AXK5QgcYVq0Mph9rJqPwXvwJA7dy7SK68SKOySHtX7PIxdYYkfY1nh2sAlPw2\ndqQhs1NT0EsjvnztOH/nf5HznDpzCus8SoEJi4jEOrTN6e9u8iu/8ssAtJIaOSFnapJ0NQPFW5df\npJwYFOJjPB7t7TRqQ2kU4J0jKBIW696+UnhH4XtAKTP97HyRYN3mTJUWkyQCU1jnUMpNbZ73Dj8C\nHwVoJk4IlJGUDMQHBKFnaC2mCC197jBPvpvV/+wXAfj25i3y9pDWeMQwbopgwwEPre9yxv42IFBR\nThN8D1uEg7+/NsfH//zPcuLBUwCM05RARzjvyYbDqRC8tfzK3/hVPvcb/zcASw+e4sMf/hgXXrwO\nQDI/wOsyJmCaG3qX4ZxCqeK5AgUeBqOcJBLDO1t5e+Ny5yRLqamgrRMYQcx9YfPw4BR+ik/maAzW\nH4FvzgssYa2/zchJLjDJZjWKwEseYIoHSbdg4X/7O7StaM78wT7naHJzdZ6VfXF4ycEOy/oaRYBC\nIzLYPCNwKU6LVn/7fIc/96eb06hDEeByh/MWClQzKpf4n/7+3+dLX/gsKw8/DEC5UuPmzavMnX4c\ngK3rXyQbDrAGOkXeVE7m0cGRWQ8kFaRSDklzkcm1g7uNdiZJFBAojXUOoxSTONLi8UodZcFOS/Tj\nQE2QTu/RFlykpriN8x6fa8KgMEReMfZOVs4N+e7c//D36K4ssteTsKWysUtt6QS0D9nfFKd4OtxD\nlxVqAp8oSznqieMO5cuFxohKkpBnoolZmmO0IR2PSALRhm989Wv8r3/312iePIErnHCne0iiF3ng\nKUngWmf/CuW4TKNeZ2dTENIL3/wNRuMIncu5U6VRoSYOA6KSCD+I3l689x3uPRw/BNtR0zVlXSr4\nCx6vJniHwgo6Lz/w4ApzMvEzucvILSTakbmjwkiExvsj6MKiCDOL/VMfA2D9Q8+Rr7RYvXJVTt1s\nsjcYkAURp5uCwy/tvsAp/30CNS/3Y/bETzlHruoAvLg5h/30Z3j47LsAaM7P0e+PMHbMMJOb/K9+\n+ZdoHFshIMBq0WITKcaDIV//g28A8N6PfoSPPvEMLz7/LQ5vXQCgXK8RmQRd2E+nFc46RsNsaq6d\nv0uzI1anALp0UAhUkEs5QEzFtJKFBa9wePQkAsodGrFUxSon9QqtISvu0GuBivub8PD/+Gvy3fiA\nE99fp312DoDLVy8SLZ0mH
d4kOhQj3/RDgrCO82IabP51jDboCqhDuadGvcJL332Jv/HXxXF/8Ec+\njDIljh9b5dqVNwC4dfMyQZwQGD8NFBSek48eZ3FOJrZVCvjcP/1HmPwaSdgQ4UVV0LpIGgEnGW2S\nmCnSO5HfOxa+vw32VV6jdA5YnJvYc4EcpuGod3ibg/PkxepIxzmDQzAlRxSa6aRmgpPKhWyAu5HS\n+Gf/hH0lWr39ygXU4x9gqyP2NQkStjsd7CjmrVQmZHZ8nGbnJrOVr8v9lMv4NMX2mab3T54J+dZa\nxN6+AG2//dnPkeUjtnd2prHv7EyTU0shEUp8GlApa5ZbfdKewNWvXtjCK8PqwhLuKNyTRL2IfjRS\n3XKFXORh7zLa8d7jVZGruhDwAjNPYnjtMCYiKpfk4oFBB2WiJKFaEwEd7G/y8f+4ypULF7jxhsC1\ncQjKmClG5Dsp0S/95zy6tMjogmSrs9XjKDKWOhK1DCslgjwlz9rEsTjTNg121AlS+6aclxyvPWSg\njTzaRx8f8fyVEO8LiFcHuCigXmmRu8JT+5hTczmRzbGTe/I53b1rAjAh8XoQG7x3KHWE6lrv0BPX\nqTzDUcpSyzBM5bthOsn1//C473Dv4fghoabBhDEAppRgggAdRuigLH/WHoXBFlplrccEHu8DTCin\n9qbM6cfOcfrRJ+jvS7b45X/1WQYH+5hQ5r63D3/moz/C9nyJN8Xq8J4PvZ/zl7/NfGkRgD0/oDzy\n+Laj48TmX2s7Hqts0S5LLO663yMuGyIDg26h6eOI5QXH9r6UGo23jK3HGTUNP8slQ57nGHWbffaA\nNtMCvs0dldji3BiFxI/OaLTzR6bAK3QQcH1zTFIU8FfmkrsTfn35QUlFEbvmvJXkpBjWOtBqavMM\nKdlAk9shm/2u3KDNKUUVUueYhOO9zj6BiXAF+LX4V/99vvPUA3QurlFZEef52sYGcS/moC02N7CW\nuBJjqxnNkRRTZs8sUt02dGJxiuw1aLBLWFbUyjJBy6MxT52q86/b8jkKK+h+Sqg12YTeUqqg1Bjl\nBZEF0GgCHGmBlztrGWdD9vdyjBYNiZKYUmDQgZhGrUNCownrCXlRtruyeZcO11qLK8IojcdZi7Vj\nUitCw2lwYHRRDvQ5SivKcYh1ReLhIXdjXArpKCvOC0GQkkqiiv+Zv8S181doLc4xzGVVNXb36Xzn\n2/T3DwA4e+YsG7ZOTERpRlbDYTbgm7P/Hu8ZSUFbVR9nvrKD1Y5SLOHgI6e2+f52ix/9UA2A9f2U\nV75rCSOYK4sGzwYdYlJAHUHovqg754XwvCcMAsplBb7AhHLoeVAF7B0HniiIQXnCovZsjtCZPzTu\n2/x7OO6o+YP+cPp/rQ2B1linCEPRTo0jzVLGmRQzlI8wYUQ6Gk39QKgTgrhEmvemKyQuVbAbfar/\n/X8nJz//XcLHniJVM5T6Eu3s7dwiWqmjHjgJwG5jCZPtk4YV9iNRp9JoDCj20+MAPDvfQ6ffQY0H\nZKFEYGF5j0899TJr7WcBeP1YTKmmOdgfsxLKCjZX98gxKO3QhQ11WmF0wMRY6rBENYnQpQCjZMUY\nZcBP0Ray3NEdDX/gu+AO+MKdIeXAoApq3sHWJuVGgyCI8Lnc4Mh60CFjLxfQTpH2utTiEma63sbY\nPEahyDPBTfprfUo/9xfQD0kB+8yZIZeCOq3OW9hMHjZspozcLP0dEdDmy6/TeO8DVDo3sPtC59Bl\nRWheYa6yC0C+8wUy06SaDHCDIskzitJsyEOzki+cGW3xZPkMaTrim1fkvm84h4kMJkiICmGpwBDq\ngL6Tey77EXG5KJYUyKvFSwI5wb+MolqO8TiUE9FOWAzvWPhprhn3xebur62x+vRzpGnKqLCLoXZo\nH1IORMs8llK5iVEBFBpjswxPiveW7p4IqfTX/yNqn/hxGscEF99MQwYuxdiArMDY48iQbbf
xuayq\n5LGTxMaS12bo2uK7/pCazqkENwE46DdYXszZS+eZj8ShdHstSjom7Ug221iBBxe/x1B/mFdfvQqA\nQWMAYxS6iNJ0WBLuqBfhG6Uko0dPgUSLx2s1FT5eSbkUmCSnd2AL3ln4h70e/bY86DhqEEYxzmpU\n4d3DwACOvED1tMtxNifNh2QFFBxHISooQ9Zm8MRzctyf+AjxMGSzJFFKMNilsn4d62Hu5Am5+MZr\n+GRMqkXLoyTkYGdMFEK9qNENjeayOsFuXyb1TLfGI0GfWrZBQ2rjxIsOZxyRFlOZ5ylYg+q+QJhL\nIuiM4EsedVTM8RLJ5dkEXgGtFdbpKRSuvEPfFszkCJzgUUd81jvI977DvYfjjpqflOt853vfA6Bq\nSjyMRRkwrqBWjyF3ClOogjcx1mY4DFHhFLPcYvtDvq5CnvtLPwfAIWOO54pLh4LVl3yf8twA26lS\nLRhie84wCldoHJewMk23mKtCevU1RolUslr+IpkrMYjkmFcrs1w+zHhP7GmNBNVcMtch9ISzslrT\n60OM9lQjw1ZbTIpyIcobKNgXIPUI4/QUgggMBWXRC90RpujtBLdU0zKfnxCypxDKOxb+Gzc3WN8v\n6Nglx2CsiEsRuRInqHOI9JHDGac5oAmjEJfJBFnn+M7738/euaf59YN9AD4wv8C4v89MqUhOQo0e\ntFg/UKg14eQEgwrdEDQyQXNhD/SIuBFQNSKQzD9AMF7HjiQLP6Y9zXLE74zeR70rcf5C5xIqSEHm\nAm0g23WMzYATDcl6d7oxQRgThjHKFA5XCaVxkuGWjEZ5gybFTUqrhamawO7eiRly6ojRqt1dAms3\ndnrs9OTBNtZ3+dSPOlym8boAloISqIy8EHSII3Cefn/MtRlJq7daq7jHz+HeWuPRZKIOnkvZkOG2\nZK8zpsTC7CKPrGxz+Xe/CUDr7DniGpTCglfvcvygj9kbE8fih3quykJ+QHcok/Hs6SF5ZZlu+xhf\nqT0BwPdKfxkXlliKRWHOlL/GA+PfRKURplhlXml0EGB8gFbhVLCpzadFf6U93jtQ+kibvSroR0fE\nAqc8yukpNWDiQ/6ocd/m38PxQzT/kPmmLMPz+7DXSWk1ckxRMNfjDpn1pCXR8n6rwjYD2t1Dbg0F\n/1h99t1cOf8qS9mABx//KADf2LrFUgBLi6sArAcxmUs5mRp+9Kf/PABfu/ga49xSGYldHt+6SPvC\nTZ4LDlheEU3vJE0eXC2xPnpAzpNU2LgJ26niMJXVMVpoYULLsPIIAK+WHyVY+nf5udf/AZ2XvyPP\noRO8AmscalIQ9prM5tNsySiH0x68meJoaIVyfppQOS+lE40XwA2mhLJ3LvxbXRYKvvpBJ+PSlZu8\n772PkhbbctaXZthujjksWADd7as4X6WjqyycOiO/23yLeVMi+cCH+Obr50WQzRbxKGR/LODb8Xib\nNKpQXl3lextX5bnoEW91p0I8Vq8QzVT5Sn6c+UVRiLnqPDeXHqJXOEU7PGR8RhMrSxKLQnQGPWpp\nzvUNKbrXW7OUlOI3z/4VVr56Ua7V0Wh0kbGqiVyxmZ9WqZQuODKe6fYm7cTm3wZqisA52p/A28v+\nh6CaIdMMN6qGnFhokKcpX3zsLADGDYn31okLwr6ZO015XCNqzTGel/i8PNylfPxdjLfXqV+Vhw1V\nyPGzdfq6gIbdAn1dQu9eZy4pqv7xPLulPmpWoIM03aR0pkL1Sk7FiELoqIkfHOIKfr4qNYj8EGMM\nfiAKUUnm2d3dZaEqSKhPIUtTuuUmN899XM7z+S+hUDivMbcZ4pG1mCnKqdAukOJVcYxTRVI1JQ4f\nlQ2nROE7yPfO8EKs0VqWfaV3QFKqMAxiakUE0N7eJjCLjPblwQbff5m9Rx+m2tHoumSvo8VTDLIR\npfXrtA5Ei+vPvYduNKS7K5pfCQ7IMweNJd54Qbbz+JM
tllaWCdMCvlU1Bm7A4hPzuIHs9qiYLn7Y\nx0VSUx2T4rTBWodJJOsedIfM4qmMpYy4HyU0SyGmv44992MA2M98Flepo7yfbgD0GrLMipMFlDbk\nHiEFTzZwFBK30xRXTfclTN3sHXpb3He493DcUfOz4YBwVRKYxf/kk3z+yWPsZT2a1wQnaV68yTAs\nkRYAWfz4Y8zbnHQ5Io1F87i5waOrc6SDEVfPyHaeua0N0uEGw/ISAPrkAstpTvviGtFpccKDSoVR\n74BRYTQjbYjLIVWV4aoSn6u4ycAobGUBAJulKJtRDkP2C/JTvrtFuLLC7ljw/Pl8n3Sk6FeWWSo2\njeyXa8KYw0GxVcjhSa0jLBJI5RQGNwXVQLYqOXdkYhRCOvC3cVBvP/4dCb9WqxKcEPsaxHXmtnss\nvnWd/p6Yi3h2lujMaezrwmccPHaMwbFjqDyiWxIsJZqJWV/fZen0aT5s3wJg1OmyveXQPUmEfLXO\ndmrpzSzTjMXHJDnkPqBZFaGFY41xJTZtwFIsDndvOECX68ThxAhDbiO8tjQKa5vPNaA7JE0kX9mt\nL2Nu7VHVkFkBDX1UQpOjVHxkv60iTy1xgc5aJRs/UFN/i/KglMYUpsUqfjDlhaMS4zsV/vb6Fvam\noIPReJs/87Of4iqazceEz+jrdRpXXuHwOSE6NecVw+0+1g9Z6Iitbr+5R/PkAg+GPXqHYndt1sCV\n21NK3cXc0lxappZ5jBcnPDPq09El9ouKWDkqUXcZy9mQdsEIyOeWcSNHfyQPfywO0GmfnazEwUgm\nZDaMsblnqTjvMW154UsvE/30jxMuS/hZXj6DX7+Ei/wUKslzR+Yd5cIDey9O1/rbmDheFZPlp8eg\nJp8mvJ37SdYfy3HnGq4xjArayYPvepxxFBGfmefcoIiArr7Etfd9gsaizO7++VvYSo3l1XkSL3H1\nMOgzaB1jbf8mVSNw8aX9DnpmBhufA6C1sMxcbxNVbdHrTgrWcn1yCSMrYZnD1FIKI8bFuk96bfoq\nIayIadpQHpdoWtmQaFAU8GsxvdzQmRPf1c0Ve/GIoHOJWlf8i4/LgEQ6qgibc+dRqKM9YVMO3m1w\ngpbI53a6pBeC/pQuqe4Q6N9R+ItPP4hqSxi5WK8TDm4xPKjxWksepPTsaWZblv3LYpqau9vw2GlU\n3dO3kuRUT1fJ1tbo1xN6+wKS7Y0dw3SGxpL4k4Vxn3G5SXXQ5cGK2PN+rvEuJI4L5ttgwIEN0eWY\ngn1NyQSUAArbHVFFdTP6lIiKPKNuLIclTb8rZsfbLnNnKxzOPETckGMqvS0iZ3HkWCcnH+cWr/yU\nWm4oGnS4I2FOMJ7ppkEv/1dKTTF/9/aFrDsL/+mVKsvPSlXic40neKU2Q0yflYJ/k7814HquODUW\nLT/2yZ/gYKaFjmKGm1cBWM1GZLVZRkNH51BWzCOVCv0TZ2kUe2WbaYdtX8VVEyb7fkvVFicV1Atm\nwM64z/FQkZYr6KEoxLLR7I/HDJ2ELW1lOWlGbGSarNihuAssNg3HuwLQrQUNFpIHGISO+UC0unv5\nJr6mcC6fopp5Jpu5j1obUBjzfMpY05ZiP9YPQglCp7xDavtvI/z0mSf43KKEh/F+m+buLZLEc1jM\n+G6Ws6hTlv4d2SisGoaNcYljt26QZ/IQ6+Um+QjKgz5JS3D4w/kVtjv7JDW5vFERD+Qjrm4dcLMI\nIxtJhUGqOCzg47g0Q9V3yYY9fFUqUONsQKDH1PsCVadJnQNVIlcWWyCvjWxEd2i4ZZbloWoxm1rR\nCebYX5MVWxqvQfVh2TlZmAtJnCy6CD29mtACdUGaBK8Ft7fTXhTgnCALfspvut915I/luKPmv9TO\nCC7+gRxYKTN6eIWg22WwXYRfKwtUdjpcjiWmDza20Yc77Ngyh3VJsip7HWpJBe89dlaSKhM6ymPL\nWipaMWMCwhBqWtE
cil8IUk1l1MfGshI65RLlYJmIjL2BHLNYqdAfj8niyQY1RVaq0KLHQSoqvNSs\nkhvFYVOuHd66TN5cJsoGuOfl2XQ4J30enEcVxOAsywt8nuLcmsBDpuQTFHvSUNO4XzaJC/Lp3aTg\ncpfsBWZalMsioFKtQd4/JLpwjfTUowCcuXmdlZOrRF5ArKBaY7NSpdvfJXhdGMm1dz9Ca5xzo7HI\nXiSJzrn2derecGMkmMxNDccTQ2YzKBflv1LMqFRlPpGJ3tzdpDcYMAyr1Iud471uDxUHLJcKjlA/\n48rBmNmlEp3vS1H90tIT7O3sYiUgotJqsJXH1OszJBvCEbImIPQe78AW6eooy2TTw1SyhcNV4Ow0\nlBGaSGFAPEo2ens//d1kA8g7Fv5MluOrRdTS30Ot3eDazpCTg9cBmH/30yx+4GFu9iSSaA/GMOxQ\nXtuk8ZyEkWzvsD5S7C1XeeSa/K5SC8mTGk9bcYIXXYMbOxYbxcRGNG9kDLrvuVrEuk3nqJfFHYeN\nwsGu71M9fpor+1IRO39rB3vmNGFnh9GsrMberVvUg5yDDcnCcxtRmlOUD/sMnv+/AEhmTh7R1QvB\nWusJOQoVZUfOJNmSMdmT42/ffaKcAHQTX3E/yfrjOe4MrM3N0EolWTGf/l0WZyCbOUd1VTD2+Jl3\n8aby5J2CAXywi9rvkz7+EFcvijb2nzzD2a1rvG/vdbpV4emocsB+rgiLDRTLeweYHPZ0SLNolVLL\nU7TqkRoBzWqNBczhHldR1NdkM0QyO8+t3S3WhwIxN2cquJ0txlkP3RIb3wgNajzgsSVJ8C6+fINy\ndIB5/XWuS3mByqkBS7M1yvWjBnu580SRm2LDmiMG8+3FcadzVNHswLkAlMNhyJW77ci7EL4dD6i8\nJQjmzKkGW2ffT32oOFiQZb83PGC2nVK5LvZ1YDT9uUUaL3yP8H3CjWy4MeWyJjZldvpy82utBcrj\nAde6EvcvRRW8GRPYHkMnIWoQheg44eDKFQA6K/P0rCHt7DGzIILdc2P2PdSLsDIoRSSHB7BynO2i\nGliqBNSbi1y7KPc4zCEeG07/P/+cF8vFc3ZHDHo51caYWk2e7fqNXeZn65Tn40Ia0uDD3VY8QaUo\nb44g+4LNYL0/2vJ0tzvQZ9sDDgNxgPaBj9IbDll76glmFuSuq7t7tF9ap3FCNHpnt81iZOGRM+wX\n3fXy3OBKLcL2iNlCsP7KNVZnI0aZMAp6yQJV3ed4SdMbykrbzeukozYLzSLD1AOSxhytSs4rI3Hw\nw7hE4By1omVkK8+4UK0y4z2HRTFlM8uor/fpdguWm+vxjJ6nVeqyelLue2OzTeYSBt2MdlsmqTMY\ns9HeZe2m3PPqXMLyXImkHFuFgA8AAA7TSURBVEwZCd4Lb2ZScHFMGkIddT67a82PUvAFNXDv5DE2\nkjorowGlL79aCKhMOj+L78vFK6dXKe9epufniS6Kg3v6RImNNKZ3fZ/qKbncvB2QVGepdUTQpaRC\nrBS622Zc3O3YdcjzEenyKRGGtZyNyjA4pN8X4TeCEuXITCHmvOdolsoc1qq0ilU22uwx7GbTaCe3\nLVZvdVg89gB5KsvjPU8+wc5hj/Nv3GR/Q4TvrCMOK7iiicT6wYD17UMq1YCVlijkTLOECW4jTVmF\n1x7rHUdNiN7erd53uPdw3FHzt06dpjMr4Fd964DTVrH1jecZTeiC73k/s+Mcl8jqCK5dovLQInZk\nCbV4rsbcPOUbWxyszOJ7ggkFMytcdGVqWjRvr9NmPlKouMlhYYoqWUqjkjAeyOeRj9hKUuilHBYF\njvfPNOmMumwVzex0r0s8u8RONIP69vMArJaXua4zcoHgyMclzvTf4PqNm5giOTRhxKljxzl94ji7\nbVmN59+6ySuvXiCw8rtyHJOakNHAcaknmFR+0zLbqLJcEMSatRBw0qiy0Pj8DqSpO
2M7RtPcl2Wo\nN9Y5vHyZaK6CP/6YTIhSDPM+1V1BFZsPHGMjConckBOrckPXtlLo7DPyJepFUT10OcHhLvWCRt6P\nG2weDrF+SK1eUP+WZti7eovasizfza0N6uESb7S7nC0il26e09VQLwouB60mb3YtJ8cHXK+K0sz0\n+5iyolN0Vjq5vofb/z2GvQOa85OMO8R78M4zW5f7/sizD/O+9zzIjetCX3z1/BX2d7ZJoioTvGZk\nPQeHfXb3Cs5poJmdr7HQLFOOi/Lj3XI14zTHviaJ0ZUAji20qOsydlYclb25RjPukS5KyJj2PKXZ\nEnkfhiW5aG3jRcauykFnxPqqUPjifEQjH+IK+nWSHeArdUoaRonACbFO6JiY8Vg0/+FGwog+fZdC\nkQjtpCmnyxV2vEy+6Q3YHVXIzl8iOSHZc7djyUoRUUlAvTPBW2xcuYTWRzvntfJFkyg13QCYeWk3\n+eBp2aD30OkT7B4ccP7SDd54/c1CeIYoiBnriZbDre091jc09UT80Px87e6Ev/3iC1On8ODWPkmp\nTudjT+NelSU9v9xkuLTMY1uyVBs31vjm4Bid62ucft+TAPRHDlsKaR6bRRU7C7U2uCjnVtGua5xb\nas2I0miIHohgX965TmlukUFB+WiEit3RCB0GbGSiaatpnW3XoVGYoes2ZvbaHj5JaBdtWXo2JSrP\n0hnLMadKu+x2tonD2tTjOTSRNnjnpECC7LJRRYEcwKmcVqvOh599kg++R/o4nF9b5/z5y2xvi3WI\n4wRtyozHnvZIlGb3xt7byve+w72H446aby5coBXJ8jl8+Al6T72b0td/n+GymJ2tlSVObHW5PF/s\n8Pjffx375z5CZWWFziQPacyztdnl9NIyUUe27wz1LPNBjC9KlINqndF4TGxzMi3EqjTQDJWjVmDl\nN7yj7y0HI095UZbyMNZ093foFTa4eyNjMa2xX4OsKyfv1iMCp3lsIBo4XruIVgFRqUpcoKFhEEm3\nFK0whYO0SssWoCnXUqopzjtU0RLgiUdO8tSjJ6Y2/9W3rvD6y28SmpCFgr0xuluHm7z4KtmflEbO\nwbvfzfDLn4dHTlOfFbvcczmDvUOyVJZd/4PnqM3NUentcLMjNjfvj6lGMUp7xrkIKauUGQWKdlFw\n2Oz2KUeGUQZJSRZjpV6hZ9y0++Z+WGJ3b8Cwt8d4QxRif2mVVe/ovS6VtNJggf5ii8P2LUZVEVC/\nnBAN4diBZOrp61/GRHXKlTqlAjQMAiM9/b3CTfr1e4/CQrHZz3uL1j9oKJzzZECriPs/9v6n+ZFn\nH+etq+u89LKgup1bu3cn/N7cKvFHPgDA4NZ5Sk8+RrWqOVgSbEf969+h/+gjZBuCTkbvfZK4f5Vm\nM6K9d00e3kPQWmYvzdGJTIga9Hk5mMUMRWiVTHPy5CJ+0GW/4OC4KCYa9YgKcKVaSvjulVcIZ+v4\nolYQH94imatx4YpEJKWlId3DPmHF4iIRbITmPQsBM9+9CoDFEEclojhChyJY6x2hNqA0pphsoQ1G\nxUsTZEe69zlOKbw96mxulZqSpqzNMDrgsQeP8eRDIqNbe4dvK9/7Nv8ejjtqfvVD76OrBTH0SYm8\nHjHeOkTPS3Sjv/ES2bNP0esU6KRJiZKInp4jjCXxMaZOHA/ZPwhZrYnGxPN1rrx5DWxBZHr8XWzd\nuk6yuMqw6FlQH2UMcku1qPu+9dobuASyZoVSkdCX6yX6FzbIC7TxgCHdSkhSbhIV52lozanf/Me4\n81+Qh4pqhKUKQTkhKDZle2+lV8JtVSqtZWvrtFOUUgTa4JzHmMnvJlx8OcRrieudYxolLTdbdyf8\nwXILWyzDWqgY5DmkKZXrYlIWZ2e42j8kiIpKUjqgmcxwbXtIXLAAqM1QiXdx5YgqEv9pEzEkl8YZ\nQCOuc33s0HFMrYCH8+4eG8kMvatS5O5lfcZJD
ZMsYBE2XLNcZudqG1/8ZhDUqA0d1u/hTkg38Pde\n+zbx938LXxIFCeOYWr1JqVw+4lhajTEBjnQqEu89yvjpq0nwORaDVky3Remiu+Lt6NmkH8/ET9s7\nsJTvLPzFhKzYVZhf3ufhfI/umUfZuSJRy+Hx4/ir2wwWC45MnNDOHd3UkxY9iW/s7dKYy0kqMcMJ\njz+HQTdlYUG04nBtnXB2Dpdm6MKetqqKfH+DjZHYzEE5glqdmc+/wtonJc6ONofYPYUKJRGq7Wvy\nlSHJ8gLvtzJpK1/4JxDPYoouV0mjRVxOUJhpn0+MJFzGG9l9AtOXNUxJTzoi955A+yPat5cVMtk0\nh5IcRpLAH8T+37HwjQ7onRct18tNtj/9DdJnnqVVvOhlpdNh71vfYvTTPwPA1tIit8YDZlSX/YI/\nGdkRo7RGkOSoWELErd6ITm9IswgZ3/znn2b+F38OM3SMJhyYpMXG1ov0Z0Rje5WEmS9c52p4ktk1\nyWjDzNLWEd2KhIy2PqDZmqHpco594Z/KM9gBPkxozgo7rbawgjEB1uXT9pNaadBidoy/DY38gWK5\nI/TgpGR+2yFHxCqlFBZHYJh2a8G9vfDvO9x7OO4MrO3s00qLziMvb7Nz7ifxlw9ZWpPexbvnX2f0\nF/4y5aI43X7oFOVuD6/BFVxJpXNcskJSrmCKrGpjc4M532dcnLt2fIa9L1+k9MgcutD0vYtXGdZK\njCO5xfqr22xvz+CfydDLckz2rbc4CA15Sa7VOtGiVlJ8/NXfJ7z2glw/apA050iK85pAg/PShbDQ\n8kwpAqewKPSkkK7k9SR2ykyQ1sUa92/YeI26zeOqoh29spPtoXcp/OSwDd+SZOH6yU/wyImA65tt\n+q8XLwL42CdYtwNay2Lz+84y53N6xDAsYm+lsSomMgZfgGRB1iOvlqkVgu0MctbrA07EY0Y3hLM/\n3yrT9iWSwnF3X9AMziaUV0vEa5If7Kyvc3h6lYViMpqjER9Z+yrBN/8lLhZ/UirXmJlfJi4yTuXA\nGVPsIhTJBFqyW+kTOpGqRRk9BfGMDvAuxSk9LREqr7Aqm86FUcUuROWZNmW4W+rI+FvrdE7+lJwj\nsbwxSjm2cwFTNJoKEk9crzGeExg4thloj8tTTLExelyts3vYx820MDviqHPtaAzG7F0UfL9barGd\nGObDhJlZ+d2eN6hmg8OvSOq+MVOhfgaiUcDWV0SrD04vs1KPOBaJ5p9746s0X/pdXNwgjCQRay2u\nEJWr05BRa1A4jIG8YLEqL90HtTIE0x7/Hu3dtEu097k4U9RUm713hATYIkiQFzxQvGWj+OEdSFP3\nbf49HHcONY/9CK2qRDb9JCeIUpL1AYunBMvYaMyRd9r4R4XBFo4HWG1I+23yYtkN21B2W2AzGgWc\n0GnNEf3275EuCNV8o7XISj2SuLsA2w4rCfkNw+4NOU/lqRwqi7Sfv0S0UtATFxJWF+DMC/K+q/k3\nv4gLqygdMrMoBZdya17YZkVbGqc9RkeyTd9MXl1RRI+G6VshvDcoE0x9AF7hfC7U79sp4dofYT7K\n4b3CFf35OTr0nQu/p0ccluUsse8yHrVgY4vqqaKvQi9lbrlEu3iIik1RNidPx6RFA6T9gePxsEsc\nKIbzMmmbL1zjzPY2W3PCHC6vVKklmlI2YLMuoN1hO6bzpTb6QTEpdu4YnZduUF+pMLNV9OB5+Sbl\ntuLExS+LDKMKQRAxt3qSWuFghTF8FB56K2ZHaY0rNl54kCTLMaX+TfdDT6gf3smRzsuGaJjSBaeH\nFK2Qpq+wgj8Exv1bC9+WR2g/4a3UYbdP2e3ST+TBgryLXl3BOXECNks5GFsIPIOunDqql7BVTe4D\ntrfEns9/7it0HjnG4YoIv5aU8NWQYalMuyMONvuXe3ByDKdE0OZCm0HT4bpQuSDR1rlFzfLVF8Qx\nAjouM3/8D
M3GHOnEUXqN0RpfcPFDpfCqeKFMsVNQGgG6H7TVTppc+CJeN55pows38RXFb6cvp9Hg\nbdGmeNIA/G57KatyCiNZ4q40ItZdopoiVZLU6MVZulGAKXrcZ5llPExRukS/yPoWKwE2H5P7KvzD\nz8gNrZ7k+uoS0bysoFYZwlqF1w7LLH1GSnTDuRi3HFO6KTD0brbLM43jLK2/wjCQ5Mxc+jRmBlwR\nEbUWVqg3i92Tk67ZKJz3hMXqdBZQHusySa4AcsiNx0xeLTiRrM+ZbNawThpyO+cmc4Z30tQ1nxQm\n8sLZHvX6mvbm+aPGfYd7D8ed4YVI44puf9m4RxQPqZZDLgvLj3huARuG6K7Y5V5uMQwZdDVUC6qG\nSSF37D7/Bk8/9ScBiFZ3eGNukbAkS7NWjtjqRAS/v84QMTNu8RDXn8VclwTu/Y0RC1/6Kje7mudi\nKYysPvWTjHYusCRlAmqzyyiMNNi+zQ57xbRjlDGa3DpCZaYvk7HFNk/nme6lst5hNLjb3o6cadAF\nlXwyrM+mWi70cYdz7o62fjKUv5M7vj/+Px33zc49HPeFfw/HfeHfw3Ff+Pdw3Bf+PRz3hX8Px/8L\nmha/p4Qii9cAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VXjtCPxl3I82", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# We'll be using a pre-trained 12-layer Reformer model.\n", + "# First, load the config (which sets all needed hyperparameters).\n", + "!gsutil cp gs://trax-ml/reformer/imgnet64/config.gin ./config.gin\n", + "gin.parse_config_file('./config.gin')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NhiTshPPbvLY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Now we construct a ReformerLM instance and load the pre-trained weights.\n", + "# The 'predict' mode configures the model to accept single tokens at a time,\n", + "# instead of feeding in a complete image all at once.\n", + "model_infer = trax.models.ReformerLM(mode='predict')\n", + "model_infer.init_from_file(\n", + " 'gs://trax-ml/reformer/imgnet64/model.pkl', weights_only=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zY3hpgnI5Rgn", + "colab_type": "text" + }, + "source": [ + "## Sample from the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PnzRPCzFqIVi", + "colab_type": "text" + }, + "source": [ + "Now we're ready to sample from the pre-trained Reformer model. Unlike during training, sampling processes the images one pixel and channel value at a time. The TPU colab runtime has 8 cores so we can sample 8 images in parallel." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "W9ZetV91PujO", + "colab_type": "code", + "colab": {} + }, + "source": [ + "sampling_decoder = Search(\n", + " trax.models.ReformerLM,\n", + " model_infer.weights,\n", + " temperature=1.0,\n", + " max_decode_len=32*64*3,\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HOLawc5dB7QV", + "colab_type": "text" + }, + "source": [ + "Sampling is an inherently serial process and will take up to 9 minutes to run. A good chunk of that time will be spent on JIT-compiling the code, though, so the code cell below will finish faster when re-run for a second time." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "We9Jj9Rap3cB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 214 + }, + "outputId": "10b6142b-11f1-414d-9b63-353f721a6a82" + }, + "source": [ + "flat_prompt = []\n", + "for i, img in enumerate(DATA[:trax.fastmath.device_count()]):\n", + " img = img.reshape((-1, 64, 3))[:32, :, :]\n", + " flat_prompt.append(img.reshape((-1,)))\n", + "prompt = np.stack(flat_prompt, 0)\n", + "\n", + "print(\"Prompt:\")\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " plt.imshow(prompt[i].reshape((-1, 64, 3)), aspect='equal')\n", + "plt.show()\n", + "\n", + "seqs, scores = sampling_decoder.decode(targets_prefix=prompt, batch_size=8)\n", + "\n", + "print(\"Sampled completions:\")\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " plt.imshow(seqs[i, -1].reshape((-1, 64, 3)), aspect='equal')\n", + "\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " img = jnp.concatenate([prompt[i], seqs[i, -1]], -1)\n", + " plt.imshow(img.reshape((-1, 64, 3)), 
aspect='equal')" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prompt:\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAsCAYAAABhRmIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8eZRlV3Xm+dvn3OGNMUdkRM6DlKkh\nUxISCBCDEGCEwWAwduPC5bmqGozbhbtNuQu77Ha5DMvluapcVe62jSdsbFw2BmMMhUACoQFSQkNK\nqVTOGZkZkTG+ePMdztn9x30RKcal8Fq9ei1WfGtpKfIN97633z7nfOfb3z6iqmxhC1vYwha2sIUt\nfDvD/P/9AbawhS1sYQtb2MIW/r/GFuHZwha2sIUtbGEL3/bYIjxb2MIWtrCFLWzh2x5bhGcLW9jC\nFrawhS1822OL8GxhC1vYwha2sIVve2wRni1sYQtb2MIWtvBtj+BbPfkHv/AO9Xi8N4hRvBesV3Kj\niFNUweOx3uEwGBGefvokzzzwOJWqRdTjRbCAAwRFEfIcUDC2uC4ogQZk4mhkjrQH+//sY1RnYuYu\nt9izfxtu9gw5XQ4k99NurDCSPQNU8HlCmintXkbaEhaveMLAcO2+mMw16LZKrK3lrJ2PcMsxYoX9\nu0axOWQK937pGV5+616GKiG/9KGH5fkG7s4T5zRAMCJ4MTgRqkbIEJwJcMYxjKVhhNBYBEHE0MUR\nilAFVIQcsGIQASNCLEIPQECKMGEwZCgWg4jHqgcxRCheDF6VFDAIXkFVEVFyBVElFRDvEQUPJN6D\ngFclQvBAnueUMOTeY0VQVYwqHgjUc8/ebc87NgAPzGVqg5A/+/Vf4/Dhmzl89+tI+w7nhMwLKDgP\n6hWv4L3gvKIevC+uYQffHyAQAVM8r8og98BpkVVGipgZo4gIYoq8M0YRAyIQWUGMYkSxBsTIIO4e\nwRfX0Kv3lEEceM7jKlLksQ5uCIDyih3h847PK1/zcm2srHDo4AEaK8sM1Wp0W022TY5w+4tupNts\nUqtU2b1zL6+5624ee+RRTKycnr3IleUVQjUcO/0sf/OxewjDEcJSwr5DY4yMVhFRvDPkuafZalIp\n1VhZbdHr9jES4r0DySmVQ+JSSL1WoRxFbNteI6yME9gqWepBHMZk9JOEytDt7Nl1F8uNZzl79l4C\nGmT9RSqVnUTxfrq+yo1HXk2302Zp7gFWlo8xs+s13HnX23njCyfQIljPCz/7m2/T1cUuzliWly9y\n5tnTRBVL6pW5C02WOvCHH/hPXLN/ih95z/dDUieoZAR5lUo9pi3zlMeUzqqgudLtClEAzimBEWrb\nIEuUsckIu3SISydXEZ/SaLaZmBzmrte+iIPXHGL3rkNIXOIrxz/N7/7WnxD1lNf90Cv57u/6fpaW\nT/Inf/aXHL3/MkFJCMqKZDBhyrTbKaUZJfWe0SnD6iKM7ILOGrg2jE1BJOD6wlpLOX3UPe/Y/MgP\n7tMrjRSLYsXQbHbo93p4p0QlIYosy3mXP7r9Vdxyx6043UGWXKY7VOFif436noMENka9J0vXSPoN\n2mtLfPz+x1hb6/HI4xcgVH7+p9+JsauMDteIKxWsGhBDril/9/cP8ME/vQ8JTDGn4QjjiDe8+S7G\nhsuExlMKDHgHRlBvMXiyLEMUFEHFk+eet976aiIT8lu/90GG+6dp2WEe65R49omzmAguLj7
/vAH4\nnff/hCZJk9Ba4iiimCGELPVcWrhML03odLu0kg69tAveERhBtUkt7FKNAypRndhGxVyAFvODxqS5\n0s8cqYIJY0rlCuVomDiuUYpiRALyzJHnOVnuaLd7NNbWmL18kU4/oZ8N5jh1pE5p9RMEC2LBWKJA\nCEyItUJgLNYKobUIivcg3oAWc5IVQ2CF+x85+bzjs7BwVkUVsASBLeYBoFQdx4iQ9BuIhARhBICx\nFhEBVYwJsdZSzHdXbyny1bf/2n8D9JMOF44/SHN+juHhIXYduYNyffKrrgOg6nFPfhgrLdj1chSD\nX75I84mHWF1p0bI1/Pg0k3sPEg2Pgjii0ghhqY6YEFDCqII1IemTH6F8y9u/YWy+JeFxYlFfLAqo\nYFC8gFHBqUehSHkJEO9R8bg8JwgAURyCeAGrGAWPwarHm2Jxc94TBeCdIbOOMHcMORj5yEdohD16\njZT6TJXepZPMVJR2e5m4cwbN5+iv9jk3e4Wl+S6a7Ge8chsH9t3I6DjMX5pj8VRKOPwoK6uXaFyp\nc/FSgz2Vbay2Eu579DJxqGyfLPHiw9upxBax4bcKxddDlVwFa8AKhN6DsYQUREO9YdUW5CN3DmMs\nRjJiLLkqXYEICpIxWFQjDLl6YmMQ9WQYUgrSGKqAOEIvZCLFxKHQQwm0mERVlQwIEDJVrAoBQugc\noYGeGJx3uAEJNUDbe8oiqFj66hiSggjmKGZAUS1+c7GhIBHqMn7wZ9/LJ/7v3+PMb/9H7n7Xewkj\nQTKHV4M1oCp4D06VwIP6gsRIkVwF11AAxYjgBt+TAQ+x6jdeWJCXgjyKASMFEbIDYmOtYkxxnXWi\ngxQEqBgdxfeFb0x2RAoCWbyieM/6KNgMZrZNUauUCAPDgf17aawuE8UBQWjodruEYYXdu/bz2te+\nmiceO0ovaWF8xNNPPcULX3oHcxcuEYchB6/Zx6lT85SqFovFuZRKtYzLDc73qFYN3ncJwh71YUu3\nreRZhjHQ76UEYYDgqIZC1usRhE2iMCSKI5J+Rp7lBMYSRiGVkRGCkQOk+UXaSydo5ys0W1dwzYyw\nPMajj36YWsWztjJPEG9j14FDjEyObjpvvvPF/5JXvez1fPSe/4d/+6v/irEZeOub3sM9X/gfXFpo\n8iv/6heY2FbmX//s9/PyW76X1HdoN5dZWVkgCVeoVz3Li9BdFeJIKJcgS2FoSrCBkuVCdVjodhNq\nQZPxaUskNZqdFlNTYwRRCZtcIbhyAr/jjST9nNwJsVW6rQWePvEZnjo2S7NZ/OpRCHsOxAgJrbMp\n7bZjeqJMP+2Rpp5e2zAFpA6cQq8BrTa0VhWXbi42KoYoFEIxtHsOYy3eeZxTDBHOQQgcW1rmlgtr\ndCZCTkmbZt5mrdMmfOZJGs0mIzXDQ48dx6qn0/d0Us/liy1cpvSSJqn2uXCmx+//8Qd55UsOs7jU\nZdtMjbtediPX7BmnXotIk5Q4hJmd0xx5wXVsmyozXI6AjMwZRAXvDVme4xHCUoRzHnWeUikiTRxP\nLz6LX/LYzjxjwxHPXF6iPLKf0bGA5Ua+6dwBS5p4JBKseESKzY/3xaYwd0ruldwZvA9JcynmCC1h\n1BGYiDiI8RKBCM47AgyIRTQgMAYjJeJSjXK5TlwqE8cxcRRhjMU5T5Z5kn6KupgsC4jiHr3eGiGO\nVD24AK8O5xy5E1QUsZDmBms81lishcga4qiYkQTBeIfFbsxtTjdXnOl3VpGsRxCVkVKVMKqAGFRd\n8f3EYIxBZHBdLe6rBeVCB/crOI0MXqLfkORcfVwwYrGBJev3cKUIzZPBvFoIHcUqNBADZu5g/qkv\nMD0REtXHycZqLO3NmcuPkTnDiAgu62KlTpbngw1ugLUW9Q5BwRh6k7dR/iZx+JaEJ1AlFzZY5vo3\nNlqwdFXP+rmFRRwMWZ4xEG2KhZdCKQCL4km8wRjFiiX
zDmsEjydST96DnX/yd8y6BmVXoyxXGJq7\nwtDwGEPpMdKGYWkppRQEHH14nhcfeh/vvPs2JieGyGyF8xcXuHD5MgcO7GV+YYm9O97A+Wf+By//\nzrdy7tQpPvTnvwf9FNfzNDuGNO2TjVVoJZ6havY8U6eAxaCmIIGhV7AGdQ4VgxhHIAa8p2/sgNR4\nYlUSFDWQekENZKIEKKEKGY4RI6Qe0sHSG4vBKwieQBWHEnpIjMOpYHBYAlItJtQizRSvHotskKmy\nQoAnFiFTUDyOgiAEIoDDqZJq8VkThFAVCySb22gVUEA8Pk148zv/Vz7xB3/M7777+/nn/9d/YWRy\nrCDIWmSOGrAq5AP1xvgB6diQWgbXG8REtFClioFjNl4iAipXiY4xYEWK/xvF2mIgGqMY3EAVKvap\nxQGc+tybbXyGDUqjxfWLnVbxaURls3yHMDSMjw7RbjZpra5yzbW7GanvoVIp022vMT5UY2pqis9/\n/l4uzZ7l5iNHCOIhvudtb2P/wYOcePIZtu/ezuT0Tn7nmf+KS2tkSaFiuFzp9xN6HSXPoiIPZBhj\nhHpdCUNLEASICLnL6fUSdu/dQz9PWFlrUomHMFjwDpelSFyl12+z1rxCaUiZnBwhzIYIzTCNZhMC\nx/btB1lenePcyUeIokle/Yo3cvsdd1Apb75iPjE5CcDdr/hnfOJzf0Z1bIq/+uhvs7IEK5fg4aNf\nYeYa2HH99Rx54S3cfOBGvuPOtzC/dILdd1zHCw6NMhY42r0WYUVwfkBKRVi6INQnIYhgbUlYaZ4j\nTcD3wZZKjI4PEWVL7F66h+0pPH3yfh454YhKCh1otFY5fuox7rvvDHNXILBCXFb27t1LJcp4fP4i\nYeToNxK+70ffyT987I9oxQm9ppDlIB66C4oGggQwdc3mYhObgH6vQ8crViDtZTjnAIP3inolzWF6\nz3Wc3zPOQytPMb84z/U7DvGr/+1DnL8CIxW4sgAzY/Dmu1/C9FSZrzx5gTRLSbKMeq1Gc+UKu3dt\n54ljOU8ce4xaDaIIPvH3jzI9BjfsHWNqxw7GxuqEwyVqlRHCIKKXZsUGAkcUGILQUy6FqBp6/QQj\nFmMDjDGUSxHzrRWaF+cYkmVCKREFyrZaj5tfvZ0vHG1sOndUIXOOIA/wAVgtlHdFcD4A8VhTIrBC\nLx8oJ1Ko4i73OBfiXIxKiA7UDYfgvEUkwpoIE8QEpkwcVqlX6tg4xiJ4VQILgQFrIrwL6KdCHFQJ\nbILmCcZCF4+mildL7hWPwXiwBnJjgOI6mYVMPZEtFHYrxZxtjUG0EBA2A5d06J07hs8SSqMzDO++\nntLQJLgUNWXwKd5niDEIHh2sISBYG6JhQbaKeVI2iI4qz/n76mZx8IuACFFUIjDKMD3i3iwaOLzL\nwSXgE8T3IBqhs7ZKOD7D0MxBAIK0RCwpExXLmgtRn6IuxQQRtFdJxKLqKVlL6JNCtY8qlLX9TePw\nLQlPzvoOWzb2vgJ48QOqN9AAim9dLA6pw0tRwlI/WBnwIJ4ifsUKYkQJbLFCCYb0gufg7/9nzgae\nITXM1OewK08wTE7rxGm02mDEeuKhRwjT1/Hz7/pVdkzu4fy5L/Bzv/hx3v0T72BpLmVidJzayAjq\nPGNTYxzc904unXyGF7/2VXzuk39ERkzWyWglKZ1Oh8uXmhix9LLO5jLIKwOVEW8c1hfBTEULyRlH\ngiVWV6glIqgaAlG8g9RCmYJ8RFLEKhKlj0V9MQIsgOYYAlI8BiFW6BplGKFnoeYNoq7QG0RRNWSq\nhGoQih1KIkpjQCYsHqNCn6IkaRBy78kHyoUTqCMY9agY1HtiNjm6YJ0BA0reT/muf/HDNAn4m3//\nL/nJP/goSd8Xk/TgvkYHMQLUy4DfFO9fV3kKYUdxqli9en0FzAbZWf9PBmWrAfkRxRgw4gpFR77u\nYw7IT/HA+uB97jf
X5z6gz6FGmzytPEv6lKKQ0AovfNGtOE0olQz1epWpsTFGh8pkeZulpQUmJ6dZ\nWmpSHYLFtRYnTp9m/sIc23ZO8rKXvojHH38lTzz5OK3VDplaWs2U+bllmmtd6vUhxBQEq16vEsWW\nUjnCmJDARHgjBD6lakrUaqOsrC7SS0NEPL2sS+b6DMczTG0bpSRtTHuNwK8xMVGhVhonigKaPYjK\nNcbtLi6eOcvkyLUcvu42to2Vn8NYnz8azRUAem6Nmw69ifFto1yav0KQLjMfPMPnH/o4/doz7Jm6\nnsnhYYKgBMD52ePsGq9wzYE7eOALn2BifIhm2sQjRCUIQs/k3kIGXDyvNOYUa4SwBJ1MKfmQKFni\nVv8su0eh5wzX1Jb5jgguTBuSC0qvmXC606G1BqEVAhHmF5RLZ9vcfffr+PI/fJgoiuiser7y5b8n\nyfqUx0ACiJ1lbcnRawn1Mcv+w4oZdpuLTTshywtFMs0dYSz0+x4xhszlhEFISWDkroMslC3TtX2I\nVx768hmWlqEMuD7cdM0kb3nDzVxeWeXhJy9z+tQynTRHA0Op1+PYo49w5xsmeNubb+b++x5n+zRM\nj8fM7Jxi7/7d1IbriLVAQGN1lU7epFItI2IJpFBROg6cyykFnnolZKgOvZ6jn3isqZJnSikMeel1\nN/DA7CJRxXPy1FnS8xfYdV3Inl3jm84dURA/UIc9hXqCIKoIIZEYiCy5F5IsLxQNA4EarAiBDYEY\nlQART2gKwmxtDARgS1gbEcUVKpUKpbiENSEeQXSgMCBYG4CHNMkohQFl6/HpMjYeIsibeNcprAeA\nMSFhVMcaAwYUQ2AAo0VZ3hYleQuE1hCJxYgF8y2X7q+DFUMcVsgTjzaXSRbOE1WHEZ8gYQmftsnX\nVrBxCc0SXNLEtxcLhXdiL/HMTdgwHgS2UGXWoSobis6GMk5BfATI8WT9FuOjq5h+A6ovx+LQs3+M\nzJ2G2WMsm1tZu/77GL/2JrzrIjbE+x4SVxg68ALKTlmbPUOeZqRpl6nsDG5lhVwisuoOOvEEoSi1\nbI5S/yTsvOUbxuFbRq34Cr5YakXxyNWFZlAycT4vFB8AI+RpTmAG/gwVcqdEcTDwhhjUKWqKQRHY\ngrXSdIz9u18gv/EgZv4iw/lxTJJRyleR2rWMbV8h7jSR/lO87VUn8DlUyjHLZ85iNeCWI3vxaZ+b\nbrqFY8+eYjyM2LVzB0FkMAoz+3aRhxHeeobiGqYqjIvi3RhJmpNmjixNNpVAqkroDVlRsSMb1HtF\nhY56Im8R4xGESDyBt/TFgzEYhMgrLYFQPE6FUChYv3pUit0JeIxCb1DKKgd5wfTV0336ONnKCjI9\nAwcP4t2gRIjBU3h4rBaqWnnwm63gCWTglRGw6wv7QOXwWpTnmqqEYojUFGTtn8R3riY+oqS9lB/4\nFz/Ax+pD/OEHfocf+OmfIgjB5W6QT+uqCYVXhwGP0Ks7CVXBIAW50XWFZTBxFMEvatxGB3K2bpCf\novadb3h9njswN5Jdr5anNnYtzzH1bPAinkOE5Op7ni/q9SGeevJxDt9wPb1+j+XVOcbG6gzV69xz\nz2d4/d13USmX8F5ZXV1jrdFkaHSSRrvFth3bictlvHM89OB93H77YTK6PH3iGSZH9uHzPlOTo2zf\nPkUUW1RT+v0uxuSEEoMEDNfHmJyYZmJigjKelfPPsrh8npGDhzBxjU7SLsqexhT+kDCm05onbZ0l\n6V+kVq/hNcJITESXIExotRuMT+/jhiOvZGpmqvBN6SZrNsCH/+I/c/+zf8H9932QyanX8Nhf38Py\nEtx551u5fOE0nSShvdxlz80Hef13vIOHj36MhfYsf/PpD/NDP/Ie/u6vf59zC7B/f5vhcUOnBd0V\nz75rx1hbXuPMcce2qYgX3LaDY4+dBQM37NlGvpTyum3Pcv12sIEwFEHXGO44DH95y
tATw003v5Iv\nPnQfXiGOBJ9CyQjHj11idfGjJPmg7J97cmkws2uEfteycqFHc7ZLlgr1aagOZ7TWIMg3p4AFxlCK\nQ/IsQY3QaafkDmqRkKmQZV2u2b+L3KZkPaVSG2Fk2zif+9PPsLIExsLth7fx+te+gNnFBdY6Kc7b\nwnqAILmSmBLnn13g/pGHeMNrDjJcabJn7w7GxkcwQbHtjYISWR4iGEpBk3arg08dRjxhYMnynNx7\nwijCRAFLSYIVpVYTdkzX6XegkaSMlGfoN1a47cUv4+BNh5HD9/HiFx1k1+4xqpV407mTu3wwHh3q\ni9g6zfHYYtwbQ4QhDi1RWPhjLCCDclFgI2wwKGeJIQgsOCUwESqGIAzBRsSliFKpRGgDrISFCqRF\necggGGvwVUNZQ7a9YD8vuPY2hoeGERE67SbHnzlGrAFJ2uXy2ae5cOnLXG6epxTUsSYq5nijlIzF\nDDYNoREqAYTGFD5G2RxZro/vQSpD+PYa2ZXHyVZncTuuRVwHWxqHXgN99mO4zKFZH8lTAgNBXEE1\nw49dg9hgUOrSQgGDDaJztczFxnoCppg/JaCfZ6S+Rrw0i0anwIaQT9BniBWzl/bYfkZ23EBgLUm/\nRRiE+CwhKpcxQZnI56yhuDyhtTDLtnP/BVP1hHu+h3J1CPIu2liE2YfpzS9SeeE3jsPzoIm2WDzU\nYHComoGnpzC0qin8POIFK5C7DBEIpFCIjBVE/UbJoChJGBDFq0E0QysV9rz8dprNJiPBZUZb5yhV\nt2PsbrZJkzR9nJl+g2vzLl/6/ENMb9vD2TMnqderzEzt461vuR6tDPHFh4+ybWon3U6TJ48+zA/+\n+I+xcOEMlxaWaJ09R6dn6STdgqxYSyk0RGFIKQ5xWt1UAhk8ESAY8gGhSMmJgSE19HDEanE4RAyZ\n97jAIt5jxKAihWFYBY+wosqwKVQcBHJ1GDVk4hlWoSfQVYNbXCD/4IdI8hx1jr6D0vV7Cd/+vUW9\nWXMsQqqQSk6IxXhwKA4IVXFGMN7jKRSeSAvZX6VQl0pi6Kpi1JGr4/nbca9C1yUQ1geBkvQS3vz2\nN/H60Yinjj7Ie/7jb7L7wHbSJB3kBaiaQtPRor6rqhtkR9UTitkQVHSwpzIwIDtFOa8wLlPUeKXw\nPhmKEtfX15x1oDIONEy5SnCuDl5Fv/odX32FTSo8Z86eZ3hsnAPXXsNjXznKjYcPMj05Tr/X48A1\n17DW6GNo0u526bba3HnXnRx/5gJHjtzEDTcfpt1pk/dzLv7t3xLamJHhUXbu3FPsBAPHxHiFzBdK\nbBSUqFbqGK2g/QBUGK/NsLbQpeRWKFcTaqNVxqerXGo26WlGHhpCE2B6XTrLlzgvI1RrQzQXTqL9\nBp3WCJQquFwxzrFy+SStbo/h8euwtSkoGdIMOs3Ne78+9rGPI4/CUABh5QJhJaRSdXzxwb9lrQlZ\nBlme8aVjH+UrP3sf1x7Zw2/8wTs5PZewd+dOLjYWqATC3HllZBhsKDiFcT0MXCbPTxHGnubaMkEF\nDkzv5kC9xu27n+ZFhwRjCzobVyHJCo/Oj96g/PFRw/33f5YL59uQmw0/kA0hLBnmllaZimJSVfpN\nT70yzdhMyqVjU5y9dJTMCcN7hMqkZ6Q8wsRUicRvTlWOrdDredaz1TtPLQ7JfeFXwcA/+7434H0J\nl7fITEoUl/mNX3w3n7//KY4fn+XIjbtpt/usLvdYaXXp9yFJc3RQYs69cK5jWXvyAlF1nCMvvAEr\nFqeF0dhIQDftEscVJBiiUo2olaC3do6+nSbNAnKXEZiQfjeh3+kX446A1aWMuaDJTYe2U45KdPtL\nJGvwkpftZmTvAq8bqXDq7FFOPLFCp5fyI+9776bi00syHBnOF17FXME7h8ejxhfzgxOcdZiwULvx\nHjHFAi6RwQRm3cWHwxbPBRawSBiAGGwYIkYwY
jBBUCgc3qJkBPEQmasT9RJ21CpMTV5LYJWZoTrT\n06MElRqvetGL6Pf6zJ48h9x4G0tnbubUlcs8PPsQp1efpmojxIRExhdz3KCaIKGgpIXvcZM2g6zX\nQOceQa+cQNOUcOIwvttA0gZ+9BpMYCkFOUYTlB4SerABGii+t4DvrSCl6iAy6zo4Az4gz1F5WK9z\nbdSExBjyqMwpP0Z5OWaqto2oNoy5/mWsjj7F8aVPMT4xw/ahOr2F45QroxiJi6YmIGsuQX+NvN/G\npS2Glx7D7/8Z9Oy9SDdERncj9WFUFmlluzjHCt9Y33kehMcPDEjrRiQvimBwMlgbvGAw6KC+l+VZ\nYaswEGnB8RyAEawUnV7eK8aCMR6dhSO/+Ss8W7bY5hVGs2WC+gxiqswkn+bm7hzXl1pE1Yw8zfnH\nxx8hKpXZvncnmfNIZZTTi0t84fOf5PVvfhN5t8e9n7qHV7/xtZw5d5Zf+rn3kavh0A03MDWsRf0P\nxTml19NCJva+WBg3AdFBbIwnY6AiYIouKBQD9FyONwF171FriNRvlFf6XgmM0lchMxAZIfPFjtEP\nSoh9yTc6qzKU2DuW/vBPcb0uGlhCl6CBof/UCfS3f5fof/sJMjGoFSK1pCgl5+hLYV6uqOBUyRgo\nTYOF2g9q1Tk58WByC1CSgWel+09QePS5g0LWhRAl7SX827/9Ij/9lttZecdJfuD//DW+83tfTdLP\nBsTGb0w4Xt2A6FwlPmw8v3719T/9hhmZAfEpLDqFV0fka8nOern1OR6dr7ruIO+/QTlvXeFZn3M2\nmTq43FEervGlL32Zaw/soVyq0G536LQ6jI2M0en2WFpa4Jr9e4hsgPeOiW2jfOGLn+P85dPcetut\ndBo5tXiE2dnL7Jvex/6d1/Kl4w+Re0O3m+ICxUYBuQ8pmTqVaJhyyVApl8lzR95f48C+a0m7c0zW\nh1hZWWH/vp08/MQTRLUyEnoqJQNJjysXHsXaGMlWca5L1OtRG5skCMvktgauTdbvE8Vr+KTF0twa\n5VKVC+cWNxcYoBzXSPpt1rxw7NhJKpUqoQ1ZazQIawbj4PzpBc6eWQAHDz72ZbIMwlB45thFjAjT\n+4S4orRXBd+BZsMQVSpErQgxcHk2p1xrcsu1B9gzPMYL7CNcPw3HTxuOHCp+zwuzymhNEAO7JuDH\nXuT489N9WghBDmkOJiw6DHvdQj3BC1Fg6TRzuk3hlptv5eP3fRQfCNG4MHEttNcgI6WTZTTbmyM8\nq82EgS8Ar0I1DnF5gnphaCjgv37gl1nud1lpLhKFAaJQr40zFo9wcNciQ7Ual+cXSZOMIIpoNBqs\nrLQL9d7nlOpjDNUCwmwJ7cVonhKKpRRHhe8LoVoZpp8WHY82HGIlXyXPLWFlO4EpkeYeY1zR4Wn9\nc0zVBs0M3W7OvfefZu/OYcLIcfedt3Bl+fP84x+u8v7/9iRvvHOIiYk6eWlk07nTyZuopoRI4a3z\nWjRpkOFJQJTceDJSnEnAesTmg5gqmfFkJkfFAA5jDQaLNQKDTSpGCwIFRSkLQARvhdVGTOv4LPHi\nWWKXQXWCZXeUtLVM8oLbyTyJux8AACAASURBVA4dZMdttxFFEZU4IZU1Tn3p7xlrX2afH+KWQy9j\ntnOEv5v9IivZJQiiwjSihauGEEQC0KLEtRm45gqSlfHhTqRSxffbZE/dh9GMYMftXFW2czAgPgWf\nI07Q7gq+eQWpTyPGDjpVZcNKgCrWJ4V4oZ6iG8yjPsPkCVWrDFVqrDU7rLbn2f2y7ybLu5QqNUra\nZZttsTR7nIWhKsNj04iYjVnXBhFOuviki01biBHq6UNcuhRzubudPY0y5RqYWo6WZljzSxw/+g/c\n8tZ3f8M4fEvC44vfGe+LNl7PwDDlc4wYUAcC1rHhU+l3OqQeYik6ujIHQVCUfTwGsYqIhyDAnc84\n9LlPcbweM
2ICovo+0tUusamyw53gbb1ZquVl0q6l30xpLFdpNDs8+MX7mZqcpNdo0Ur73H7HHbzz\np9/Dzh0zpMsLvP+X38cTx75Iq51A3mPvzBhrxz7FUC3EIAS2mPgqNaVCUCggsrndaCkvGsXFGCoG\ncuMJJKA7KFEFGIwRSi5HrVDTnMAZEhFScRgxWCdYHJEvSoZOhASLNVB2GW52Frk4S2thkaTX5ekv\nH4PhiL/+Nz9ItSqQ9sgTg0Zlfu/TX+SpRx5n+EW3kmaORAqSpAPVJMHjFZKBh8eipDhQoasQeI8b\nkIyuKqFCSR2JGhLd/E7duYHQuU76pVDDFM9LXnyER1ccT504zw++9NV8+NerVKZu4z3v/w9cd3gn\nSVIMnI2Sm/qNPjG/7hreoCjF0FgvDqx7edbv/XVlKF0nOrJxBRlc5mpn1oaEtDGhrctK8tx6ln7N\nv58nev0enY5F85RWq02r3ebMyZPMTG/jRbfeSm91lb1791Kt1QnEsHBlkWp9Gwf2HSQqlzn6pSfo\npp4Lq4tM7dhOlqa4LGVyKKSVCDkG5zwhIbsnd3H93v20Vhc5d+YSy4tnuf7GQxzYvoPG4mni6ihz\nC02yNKCM4/vvfjtO4DMP/E8uLM1SiiJEHZ1mk4nJfQT9jHf987dz/0Of5czCHFe6XerRHly7z2rz\nGGe6TVoXHyZzhosLFzYXGCAvtQm7kA9J4XVzPVprPXodg3aETtszXDGEAvE49JeFLPPEk4LrQ3NV\nSdaE0XGL9Y5oTNk7Bq3lPgE1ggC+743fxU37d3Hl2Ge5e/rLdHuGl/9CMMiiQdcfAYVGXSyG33Mz\nfOC7PZ8/JnzgIciN4DLFBoKNio1cMdcZ4qrhoU+f5PhXTuIzqO9RahNC1ofKiEFdn7mLRYlqM1hu\n9LGhZWK4TmOtwVriEC1y9v3/7idp9ZZYvbTCqTNPYwKQ3HHp3CJDY7sZHa+j3jO9bYR7v/Q0H/rI\nKSKB0Ql41csPkC2dZnpnmd2H9hCVbqTTyymHIc12n9wZSpU6qpa1lS55lmCtYE2bnBgTlokmxsBn\nRJphdFBuUcWrKwh2ltFrKbMXGqSp47aJ3Ry5ZogrrQan5keZ3nct3/n6PmtLF+ilOZeba5vOnYa/\nQGgElTKRBIg1eONwPicNu3j1pC4liTpkkmJLDMrbFu+hazJSYwiMEJqAvvSohEU3pWpClzYQ4Oij\nPqFEldZ8iUtPzFLRNttCxx5SRkpKPQ7AXGa1K+R1xc1+kXi6jj9xP9WDN3Hpyw+wduJJJpqPEpse\nh6ZiOs1HucYIR0bKLMtr+E/p3yHY4ngRlL61RAPCs1noZ36VvJdB5rCFExJbqSGjE7jlC5goBNeF\nrIPkffBpsbnUDtge+dkHkPoMZmSmaDgp5HPCc78L4/uRkVsK9//8n0BrDlmeRc7cC9ktBNk09WQE\nt/OF7LrlDlRTrFHU9ajMHGHnnUNMd9eQzgru9JO05h2loEt1YoZxey1ahV51hnzmBuLWIn1/mF3l\nM+z2BpdEpKc+giYp3nnq849y+w2v/KZx+NZdWgje+8FCddWI5G0ALtuYGjJRrAevjjTpUgkK03Lu\nITaG9eXK51Kc3yCKLGZU/817qd9yA5WzRxmubqOx1GOnHKCSPsntvdOUdAHXN5A5AglopxE//lPv\nZe+BvThflC1sGOKcx7icztI8733vzzBaqZMTcqB+mZFAePbMUcoVizKGUYcMzrERQL0nGLTJbwbO\nQ2gU8QMPigckJxZwgUVxqC8Yuc09BAYjEKnDiNCRHKuGmkAsDgc4DPrscdLHHmeumREd2I/dt5fy\n6Cjn/uaj/Pr7XsG+6e3Ebhkny9htbcJ2Fdcf4V1vO8zPHa8zjGFNlEwUgyUUJcGQ+ZxU1826Bbkx\naogGpcl8IPd6CrITeA9iQR32n9Cl5ZFCZh+YBov+Ax20PXrSNOGGg3v4yFc
e4Kfe8mZ04RF+8Z0/\nzo/9zC9x91teWpS51snJwCc1KHIBXM3HIilZV2QK21/x4qKIWviqdL0N8mvPjuCrrDtf13HwNY6d\njeuu31a02BhsBp12l5HhOnEcc/SRR6jXa1x36DqSpM/c8gLnnpnlhutu4sSzpxidGOLkxUtcf0ON\ndgKXzpwmimJOnLuIBCGEder1GlF5hD0Hr+fpZ76M5AlhEBNplRPHZ1mba/DiW68hrgpRFrHSaCAN\nz/DIEJHJqYxUuHRllSwLePTxx9m5axevf9lryZIO9x59kDPnjlGvTbJ9ch8uafL3n/mfHLnuANt2\nHOBjn7qXtWyV8alJgjDg0uwFOs02Ya1Onvc2FxjgJ3/85/iDP/wVMjy5M6Qr0FqBfg/qkcEbBa+U\nx4WoDv0GBJGQOsi6yraRmNVewtKi5+ANu1AajFSmSJIO1juMQnt1jsuzGeeXhUcTeOkh5XM/78mc\nx3nIcgiilHYP+m3FqVArw2pLeHq+MORGAxM8gDqhVldK4pB2hAi4VGhdFib31zj8kiGayUVWlwTv\nheFJ8KmwcmVzG4lqNaJeiVluNMnynDgKyPIc8TBe3cZf/e3HGRsdpz6yg9FqlcWFecamhTRt4PMO\nYXWE8/MJ23fv4Vd++SChVYLCSEcv2UcQRajLWGv2CIOAJE+JwxjnApJejtcUAUIRkl4OEZTE00kd\nfaXYyDrHwmKLdidhba1DnilXFlaZX1xgYmKCH/ruF3LzdeM8e/4yc1dgredotHMudU6wY+848bVT\nHH/yBOKbm84dDRt4E5Ebh7FRodp7wfkUifuI5ohLCW0forRYP8SCFJYLrw4VwRVyME4go0MUCNZG\nRRec9sjo4JyjsTDM5ccvUcmbDNcCJvMlKr6HqMXYiNRbqtUaUaVKEOeMjhny3hrp6hwxPXYPreJ2\nhpDmmDBnanuJvJ9i3RVK3R43l1/MUf9wcVwHkOERm21a3QEIywEmL9qQrCiYgniT9snnnoHaMEE7\nwWZtEAd5BurwPiUHXHIB05jHDG1DTbF2ikJnbYxa60Ek/St8OFWUA+xhkt4+svh21vpXWKVKoz6K\nSbuEjfP4bTvIvaNUGiEIQmz7HP7kZ2HlNOXlL4HmRCMH4MAoGrZg4mVUKjvZ4+rkEpG3r8Az/0iy\nNsRCA8rXHUIqU1CdIt99J6W4/k3j8C0JT2FhKMLrN3a7VxcDEcF5j4hHpfBYaB80CjCDg9zEDg7E\nQwlCpeccJvPYm25mx79+Nw/NXyZv9BhN+vTiEdphl4OXljjgPoFVyBkBbePylM+enuBd1+wlSdOi\ny0S1OD8EUOd47//xPj7+53/J9DV7eeUrX83xoxeoTHZRU8YGoD7De0HEYwIBhW4/pxLljFc3t9sy\neNQ5gkF8FENOjqKEqnhjSKyhKxkGQ8UbclOcLeERylgCgU6vS2AtLjBUTj7NylyXymvfQDhSpkxA\neuo4Nzz4ed7y5sM02wmPPTBHffkKLuzxwncM8+QDC7QuLWMnxnjrzX0+eTyjYg29VhvmF2m22oxO\njJG/4hVYG2IVnFe6CKH4DVOy8UpPlKr3WKAjQh3FqmI3W7MBnDOD2n1hkPYKVrXoPpCi5JVmKft2\nTfL7n/o0P/HW7yFbeoY//62f4YHPvon3/of/nUrNkmXuOQyDDcKyrsqs/yHr9Gbgu9H1s3X0ardG\n8UPpoBT1nHKWDMjR13pxnqP0XH3sOQ8Pyo2bRafbJwhiev02d9xxByh88pOf4uDBA/STPjt3T7G8\nusKFS3Ps2LuP7sUlHnn6aZaXlqlVh1hud6nUaxgb0Fhr4LyjVCpjzAhZf5TcJWR5n8uNRegFLFxZ\nIvMJteowmDKLyytUywH9fod2pUOaZ3QzhXyeLAfNuuzf9hImZyZ5yY//L1y48nJq47v40Ic/w/zi\nGhfPneIz9zzC9h27GZnczcXZM7AstLo
dFpcXeNn1+9Ewprf6zdtDvxne9a5foJcu8/7f+e+ghToc\n1wqC22znjIwI07vK9LMeQVUwcfEj5GueajDEW77vbv7qwx9hYQV8NsubX/8uJkd28plP/QNR7Bmb\nggcef4TkwCG0b/jo5RmMznHXrYINIEvgLz4Bb3od7K8qrZaQO1hZho88oNx7wTBiHWkg+BxqdQUj\njM9AvjJo46UocflMIerw6EM9RnYJ4oROE4bGYGib0lzb3LiqhpbZSw0kKM6PsWKwJiBzGYmBN735\nO9i74wCffeheHjtxhivzy1ycnefUqXniGF54807GxiepVYdYutJhdHyUWhhgSyVqZYN6Jc/61KwF\nB1nu6fcs6ruI8UQo3VRRJ2QuIcky1la7NJp9et2UTqdPu5NgAktgDEFgyDPPxOQQ73jzazmwu8Yz\nZ+f463vOk2WGaqXBgV0ziE/JndBPPUnmuOXGAzz2xMObzp0g7oCkeJPgpFR0mSI4n4HpF2e1+Jwo\nTLDqEDxiAkRzvOY4Xyh6InYwPYQ4EVYiB1QYs+OU0iG6aU63W2HlsVMESYt6JWIqaxD7DoFJ6SYh\n+IxqPWZy5za8c5CuUa948rhCb+4MfvEk/bP3EEUlslaDuGYwPqY+PEllaDvl+ct8x/IOqL6YU/Zh\n3OAcJ2sKr+xm4bME8rQ4hiR0hCEFiU3apCfvJyuNEl2+ggZ9rCnOUFMjQEyuPbxeQi4+hR3fhalP\nUoxIRW/4bq6s3kze7jGx7whhuUK/12D26S/Q4ArB5LVcevooQWyoJSlrj34Ggjp2ejfUU9Sn2KyJ\n685jQ0cc5RgFpyV86wCm7vB6E6STuLU5TBhRSto05Q7avslaLWZo7hPYsT3kzRJ+LUFWFuHWt3/j\nHPlWQZLB7toNPC7FerDefmeBHIPFaVG+8lqcbeCcrjtJceu+FYRACzOzuwJT//3f03DK5OoKtzDC\nxR2TbF9ZpLK6yIw5T7cHw5HF5RmBT/Em5KGnm7xbBCHA5x6vDvKEqFziN37t17jnkx9j+6FDlKt1\nLl48x8S+w1y58CmyXhdnIbOThdQrEFAoG9VySJpbzq9u0gTmcqyxiM9I1OA1IRu04ds8pxSEoJa+\nFRIrdNXjcyUTKFkh8IKSI3lCqdvlykc/xdqBg4zffB3Gd0gfP0k2MUrj0UepTu4lKXfp9Rw325DH\nGj32v7LC4pMTyK7LXHi6y0zJ0D03z3LrBWT3foGs3aZ0x21sv+kmzn/pUQKKLoKKConxZFq0qZat\noZU5nEDklUwhFqGsHq+FPynwm+sIAMhzIQgKRUcGhnVvC7M0dr0CXJzwPDVe44P/+A/8zA//KAun\nH8E88Tf88Bsf45d+5wMcuW0fSb8/IBqF2rJOUgZJunE2BBsHBepXqS9mvTQmMugu+JrfevD2r/L4\n6HPcO+vP6XPI/nOe26Q4iJiQZ06cJrBKq9lieKjGXXe9muXlRcIwwpuUo08eZXR0mi8+9BgmCCHv\ncHb2IgcPXk/fOcqlKnv27OHBBx8sTPhhxMLCFbZNbefipbOosSROWF5e5saD13PszEUunfoKo2MR\n+/eMUysZfO4Io2UmJicYHhtltFymUq4TxhUW584xEmxDG/b/5ey9oyy77jrfz94n31w5dFV1Tmp1\nS2pFS7JlgbOxMDAegzFjDx7SwGOB4b034DfADDCD15Ae89YwYwawCTYDtmVjjCzZVrCCZcVudc6p\ncrh184l77/fHudWSGdtQ7LW6+96qWn1v7bvPPr/9/X0DbhLzpc9/jiQWuL5FZXQIY8CpBjx35Bkm\nRye5fPUKC4sL3Hnv/QwMb2e52cTxRzY3McCP/utbefnocUZquVzcDQRWYOg2BFoZpAupFWMywfoy\nue2AgW4PFC1u2vc6ih+weMNd72L7zA6Myvh/P/bLHLihypkzq8gs37mOnb/E1uFJIuXwl8cG8Kx1\n7rl
F4npw2374vU8J7tojuOcmQ7sFXzyqeXgBfGkjLVhTijQVOfnVzlutymgsacj6NycUdOLceLR9\nCooVKJUhiqDbAGlvbuFcXmwwWK0RRRFJnJCJnBtjWS6s1tk6NsbymZd46IkXqflT7Jrcx90HRxkY\nGKBSDZCWxPYEWHDm7BXOn7/MK0dP0+y2mJoYZWZqEM+1SWJoNFq5SaUlSJKYNDFESW72lqaKNO3L\npS0Hv2Dh+h6VWok0VSQqI0kyhioB99++hbHA4rnzczx/UhEEDu1Oyu03jDG3mlFvhFgWOBYEjqDZ\ny9Aln9/8N+/a9NqR7jpCuGjhkGIhRb7La61BhhiVIIzGAdz+IUlriRYJ0uTimo3WtSVdMh2DSUjM\nDg7KOxEL6zzx9DMUtuzFt5oM0kbIDBGF9HRC0TEUXcFASYEf0Isyav46zYVZmusd5MJV/OFJotnz\nrDz/BcT6CkFNYjuS9UuGsbEm4wVFefIgSTTIRHeN+9xp4tKNzMXH8fSriQX/cAv7x4aOkrwbYbKc\nLuAJHEuRRG1M/QpZeBpaa1B0cYpF5PBeilt3k6zNk107DVFIcvEZ5NAk7vbbEX4FhMEi5Yk/+RXG\nBiapDxZxtx6gvOUmqsMzeN1F0qOfYWk2xd5VxAqGsB2XzsJV0gwGRmaQTglZHseyNcL2uV5a+Cnr\n0++gsPNNFGpjqM4i1s5BVp78c0i3sTQ5jZmA8oUHEW2XxtHzJK0eNb+N4dsLkL4zabmP5ttCovqR\nAxiBwvSJbhIjNLlhbe6pIhVoV1w3YzKZxLFzrkhsNEIIbv6tj9KeHGOt06G4sEp5fAYaTeqLy2x3\n1pCBQKSAUARuBy0lOCmj1YgszUiTvNhI4oiCLXn6a0/y3379N6ltnUEnEa12k4IcY8ctkwzs/lcE\nXkC1UuFrn/1dothFZhmJkAhH4jk2rm9hu5tbQKlRpJkhFgKZxsRKIywLSxpiCYExiMzBKziI1NCy\nNyhuuWwVwPciOHmECy9eZurdb8snu91Btzus/dnfUDl0gIEbdjO6ZYr6/Es4pTnOrmsO76xSGyug\n0wMMJiGHDoesL8L05F4G/uZlBraOMOdOo184zjVjU/ngDxLHKY6CTFqgcyfnBJErbcjbBJJc2ZNo\nQaoVgZA4/XbUZoeAPppm+kqp/IaFJfr+QnlFYgClNOWCxe9/8s/5Pz/0E5x+9kvMTGs+/IEf4f3/\n9hd4349/H2maIoXCthy0NjiuTZoY0gwKBYFKM5TW3+yO3C96jDCv8nCu83pe3THMN8NF/V/gVY6P\n6ZsAXUeDNtpar616NjHeeP+beeyxr+TITFAiSTVXr1zj0qVLTM9M4DoT7D9wkIXlOo3uOqVShayT\nsHPHLhrrTcqlCkEQ0Gq1GBkZoVQqYVmSTq+BUlneWnUcysUiUanH0eOvcOvNdzAzvYNW/RJGdNDK\nRRgbzwuo1+skImGwuI04TZGOZnB4iFK1xqW5y8yv97gyt04kfBrtBoduvplemrIwd4WDB2e4eH6d\n1XqdHbv3sG//PnqdBEv4OHLzbrlrcpaZGwZZ6NVxQ0Hcy5We5SGwOwKvCFdPaAYLgso02MOCXh1i\nJYhTzece/yi1kSX+9itNPvAvfo5jp75Eu/0CfrAHSxaw3VyJZLkJS+FlHDVMhs9/e6aITY+7bxUc\n2A9DLxl+5VOCv98Pj57WLBUmObTH5ciRywSuS9FoMmGIeoKgYEibMDomWI5DtACTgbYgSfPuAEIQ\n9cArQGdF0KhrvGBz62Z0aDBXeCKRlo3KFGncodOGbSNb8cduwBm+i9//d+9FZBFJ2EJFLVTUQaUh\nQtpI28eybW687SDOvXdhFcqcml3k0a89wy//yu/S6sHP//hhCoUqvdjQabcBCKMYneZ7nrQlhcDO\nuYFxkneKpSZKFc1OyOSgz523TiEswRNHLrDDH8HziyiadLop46MFx
oerzC6v0+6ElAKbRjvFCIUl\nDZ2uxtsyuOm1I7xuzjmBvnK2X7/YeVyNkBu+vv0hc3NYaTbuszagUcYgSbGMQcpR7k7eQnzlJKcv\nL1D2QWZt9hYF3WaDRmJRCBxMpujpmHXVo1yUDPsGWSpw+dhLrC9kGNfHLM9jeT6NMy8xf26OSuDg\n2xpv0MIfgTiUdNbqeNUVytWA7Ow8Y6ng7cEBni3VOJ0+hWPMN+1d/9ShuiF2/1CoyVCuyT2LYoXv\nDTNw6EbsqYMkta0Iu8CKM4jQLSrUKd1pMO3TdF58hujkVxB+GWfqRoQd0Fo+y65dhxkvV7jy8MeJ\nmj0OH3gLfqnCwtFPUeudZfeBtzDf7YAcxqqMIoanSY1GxXWwHfBrCK+I6K7ka8mA2zxL49Hf5Suf\n+iPuOLwX2VnC9Dp0FpfAL7HcTji77nDnjVtx2h7h3EK+T/s2VtH/tvPwj/jw9HuZOsk9dDAYIZFG\n5D69BnQfwdEGMp2SKShITdrP33KRGJNL2xUCJ1XM3Xsn2eQAWy5dxtRqrPV6pLbL9lqP8dXn2WaO\nY4sRpLWW32S0JhMVXlgcZunaCrWRYbrdCEvFhKnhl37xw1SnJrGxUTLDcgVxL+SpJ57m9vvfyP0H\nb+WFb3ydoFLGtQpIKdBSoJUmCtP+KX1zi8hLc+WV0imFTGMbTYShoEAHHj2jiU2KneayPJSFkjmy\n40pDz1Gox5+ktZIy9cB3kfV6SCkoKoMsFNjyI+/l1Cf+EnPkRe564C3Mzy1zy8ECO+8ex2o2WXsu\nZdU7R2st4sa9Du6tVX7v9x+hZxw6toNz33ayn/hRSsMDmF6EEZJM5kRQQ4o0eVKVFJpBrZg1OYLn\nmtwd2zWSjtEUtMayNs+S0zpHcfKsMfqyWXG9SEaIPh8n/6OUwLcFf/AXf8J/+Ll/xxOf/iOmt2/l\nbz/+UV545ll+7bc/QmmowlNPvIAXlHj2S5/hlSMnsFHsO3Qfb/q+tzIxOcLwRIUszlBCI64bX268\nTn9N58QeXsWZ8u/8w/EPfXo2ECZe819qxHXb9X/q+Jfvex+f+cLnmR4dIk4UUdhjdbXO6Ng4WZag\ntMsrx89SqlTZMjXJ6lqdanmAWm2AbdMl2q0OR08co1KpMDo6ilKKTqdDuThKpVJG6WkuXDoJKqFU\n8il4RU6dPEmxVGD39mFs45P2eti4dHs9lEmpWlUSragUPRRwdX6RXq9Lmikyu0Y7Br/iMz4yRdTW\nZMZCRZIkTqjX17jr7ruYmtlCt7NKvdWkm2bIf8a6oRKRNiWmC5YLMoHuOiSRwWSCbtsQ90AX8riG\nzipUhmGoCKorMd4SozMwIq4QnniaYaNxgnHi2KVSrTDWBS18VldiFpY1BVap+MP0lMtvPZzy/9gJ\ntx+S/PDb4P6b8hbXi13wSzaNuINXBhXmPLfhKvgBIHIfrTCE5mpuaJghcMowtgWWZzW9UFAuCjwB\nqTJYDhQqm5uaNDMImRJGMUnSZX4u4/bbbubnf+YXYHg3q/OzxJ010qiDVimoNG9NaA06xXJ8LNvv\nt3JyUYVjOUwXivzkO9/AT7//PTx99By//Z/+C0dfeIztN04yMFhEWBau5+AXHQpoXNelFyekaUaW\nmnz9hTEjNZd3fNdOVtbbPPLSFbR2saVkVnTYPlDFCiWVQYtb9o7y9LHLWKKCo23SLMWVklDl6iff\nd/nkw0/zln+7ufkR9gYKyzfT7jYQXItvPsRgkHbfMNcYDBnGiPwwnWUYA4OlG3HPXaDVW8cSMZPT\nM7zhhnGm1AUeXMk9cTrtDoHdxpZtWjYEDiw3JKmVsjaXcXXVY9SNGNURZAn1xUW6IYxXcqWwE2r8\nwDA6rHFLNmmvTTcx9DprlAcd9tgNJoMD/F1gcz56ikxluX3JJkbWSZEWCFvm5O2Wzt3Uw4TRH31P\nLr1XMS6LoDPsrmCVKSLLpVJ/i
eLap3FqNxKev0B66lGEX8Ye3YVXnKAwNkGaxpRueRve+Haay2fp\nNRy8HffD6I0UgiHG68uo1SWGx8dxp4cQRpOlMXG3gddaxs402do1TApoiI3Dte4kd73p7UwWVzCX\nzpB11hibGSbrrOCHBr19O421BrXWGp5tMK6XAzDRP9NpeaOBIKXdh+1lH07T/QWSIzea3E5cZzon\np4rcYjsxua1/isDIPD6iuwhTQY+Z43M0dg9z8fI53PHtJOEsbjOkZkJsp4I2k6jsKSxpIYsgmppq\npcgv/PxPc8/r34CwfKantnDl0mnmZy9iewVsy1wPd9y6f5qx4REGfJsvfOK/YmVXKBbHQObGfOhc\nSl4oWKC/tfz4Ow3fKJJE4WeGyOTWXW5+9oIoRdl59ISlDLaUCK2xTd5SqWlNcOEi9YWImXsPY4Ux\nQa2A0gK13sFxJE7ZJ884sXjslXPs0gkPf3KZn/reEpcaHaarRcJTZ3B9h+MvQmp1WVprUfED1sOE\nQuE2pseGaUUxoQAPg0kNtp1/Hpkx+edmoNkPhRUijwxJ9cbnn5tEtjff0ULpfLXQ5/HkeTC5e+iG\nLcF1TEXkCKE2BplF/Ic/+C3+6/AAl57/Mu1eSP3q03zgDW/g7R/8Sb7+0F+QdNeYmCoRrXYpFhVf\n+cIzPP/VX6U6NMhNh9/Dv/ypn2ZgbIJMZRj1D1g2140W/8En/i0iIv53c8IcNtrgJRnAaLHpltbd\n97+ehx/5Eu9+1zswqkSmBUFpkChW7N4zTZIo6o02rU4MWGzftZsoiplfWqHbvUyhUODee+/hscce\no1qt0Gq2KBQKZCpDhKWqqQAAIABJREFUSM3pE+eYmpokzSKE5yKUhSccGu0Ox45d49ZbDqLsNXpR\nGxUnWJ6m3ogoBC0Gh4ZxHY8kgXonYc+2LTzx3DFKlRpeUCSKQ4xIyMKI+mqd2dl5ur0WtmNYb8yT\nZjHtRos0TQiCwuYmBlicjUkjg12UtGZzwnJtFIoDgm4j71HWJgzRiqHXNThFSDqSmf0GXTY01uDF\n5+DAnpM4ye+wd/oD6I6HPV7DEgEyCGjWNa02BIFAZLDWXaVWqNIIC/z3F0oEXp1tM4LxUc3XzwpU\nMUeGAtdmZHCQpfkWjhaoLvRSg2ML7Cr0mnnkjJACrRReSTBaEkzs9ZlbTiiPGGRP0OxAGguaq5tb\nOLnDcka90WSkVuNzf/3H3H7PWwivvMz6tdMYnZFH+rlIKUHaaCEwKkNYNpYb4NgOeZme+4MJSxK1\nV4h6dazZoxyuDPDg3/x3nnrpNL/x678AysV2c/GD61goDa12jyzTdKKEVitm/7Yqt+4ao9EO+cyT\n53EsB60dbAuGBooEBYtCUGbHRIDthywu9+h2JdWSTZJpqkWPRtwmNRmu7aC15Kkvn9/02nGc/BJV\n5jXdbXiNS/urIoXrKInJkZ8NYY7EQRiP6eBWqnKMkhikJ0KMEujiIO+69xA31y4TdT3a9Q6JsPGt\nKM+RlEUaWUa3rtk2nLEW5mrMdmzTWoq5Ierhiohr5y8zNVKikYLnKtxYMbso6PQUW2Y8LBnTXu/S\n7sZUBiIqhYxqOePd8m4e9gc50/1bhNoceqrjFGVz3Yk6DmOkUDiDO9HNP4fgrpwwp7qgG7iiyFTz\nzyFU6DWbbDkjPPP36KGdMP8K2cguZGUcrzhALXBwB4coz+xEjuxCSpssS1F+BacyQLF7hS0HdiPk\nVozngj2OjiM61hS2NUdQmkSPfwT32J+RPPmHGBtMllJMVyiVJMm55+HyMUQwgtE9rOoIDjAwuYtg\n7jjBwjKWV0VbNiopoJNvH0vynUnLQgAaYSRCZoDqxwH01VtGY1QG2pAJSRJn+UXv676TZc4sl4g8\nNO1aQvXP/pQzR08hbrybpdYiBbvAcquFijzOJsMMxdPUWrMMFZ/CBAEmSVBdSLXm0E6Hb1xo8MW
/\n/QJpFrG8sgIKhgZrbBt3cBFYQlAMJBMDXZLOAq+cWsIIiy2j4/mNSYDRfWiz73nz2kymf+rIlMbJ\nDCpTWJaFFuBoQywMrlbYyhAZQ+zYCJXiSotYaDwEa7ZAHL1AcOdN+FGGX/BodnuozJBlCYVMkkYJ\nez74Q5QCn+6jXyGaXSbsdfnGpQWmtu/lyU7K8M33cuH8JdTyJQq+RCYJsZDgOax+8itYL13Cfuub\nCccmKSpFz8qjIkKtrwe5KmNIEGTCkKFxtek7Z+cZaJHJ40M3O7TKJ1sYcst0ARtQoOlbJYsNwnHf\nqVsYk/s8JSE/+6v/N7/+cy1Ky2fIsJkYS3jswY8xVJL4JR9PFBgbtBidsLhth6DjFGnHGa3uc/zO\nR57lwIH7uOvN72Rix0Ecz8qzW+gXMfnLXmcfi421fv0o2F//bKBAr70xvfpM9zfOzar2syTjpkM3\n8/jjT/Hhn/0ZTp04RtF1KFcqzM43KQY227Zvo9FqMzo6jmPbtJI2ruMyOjrK6uoqH//4J3jHO97B\nxQsXOHTTTX1PlTUajSb79u+jXl+kXCmRdRqUC0Ookk3QbRP2DM88d56De3fi2QHrnYvUSgVGh4cJ\n3CE6LY0xLTzPZXF+lrWVJVq9DNcLcG2LKALPDWist2g2WriOh1OrEicd4gxcx8FxXWzH+RYmj//4\n6KwKem2BtCEMDTqB+rxAWgbHE5Rr4HYgcwUH7hxnrR0yv9DA8uBNB29m0m8icMGTuEM1vvTkQxx/\ntsG+uyaYHA9YvBISJfm1P71NkjUtFhczItFkbHyGnlvir85V+f7oEg3b4X+eSjmwYxw7LlEREOsW\ntpOhsh6uEGRdQ2Z5RInG9lMsB1SiQYCnYHXJUAo0diborYATgBvAcMmg083NT6OVsLi0zI+9/4P8\n+1/6CMn6GuvHH0VrhbTsvrAky1WDKleROU6G5WhEvEqhMEPcOIURAVpnSLuCMS2EXUHKAhpBr71G\nfORL3LllN3/914/w0z/zk7Ray6jApxvmflhpmpJkGq0V3//GbVyer/P8xTppZnAshyzVeL5NqWDl\nxY7rsNxpYGOzd7TMkXOr+K6NJSRxlmB5Dq0oxHMcfMvh5aOzXNi8wA/PMf34nA3UJv+6Ma/y7K6b\nv7NRAOWPFKbP+Uuw9CghC5R1Fa1XifytNGrw5u0DOM2z+COGuLVCxesxt6SxfEFB2KjSIAXZYWzU\nJU0bVEs2bgEaHc2pDkg3IM00rguBm1JvGXZO24Q9TZJlXF0tIoIaI6UZEl9geJrF+SWmbjyEcCyG\nVZvXewdY1SeIkjObmhuTphgtURqEyXDGdxHsex3Zxa8gD3wOTQ8ZXcE0T8D6LIy/l8xZZP3FL+Je\nO0L32BGujDzAuLVKMWyhls6Rbb0NHJve818kvPA8Ztth7Ps+iDMwCtIi7nWoNU9Q3Gahk0UwLUxc\nQDiTdOIZEq9Iof4NdPJV6Cwg0imEO4aJl/AE7Hnjd9M0NnLfAzC+D91aIVpdxrgDsGOc8V13EBQz\nrOazyNokxvJxfEFw6/d823n4zgWPMRiRgc7j1wUgNGip8UsDSNvKbZ8LBUrlYdbri7zlx0tcOnWK\na6fP4zkgLAstDKaV4H7459g/PkajUUSQMt5yCIs+dpaQpQ08L6VBlRUxQ6LO4JFhpIEUpGVz/40R\nz150SKWNdm0qxQEynYLx2Dac4aoMJQzGZLTXroC0UNpge3nImBQyJ1GTOz2HUcL4gEWYSMJkcxVz\nlihkppFSIHQ/yoCcD5M5FiiFbcAyCqnz26QlLBIyCklGmGWUmy3U6AArUYjKDEblKFCcZaQSTLvN\n6c9+Frm4TMPz8ITkq6fXGNXXqJKRzF5m5417+dTLLTqL1+hkGVUswjglkxaXj5/FfuEEI7/yf9Eb\nGsakikxu3MI1icnzgIzO878E0DMKY2y00fhmg56+eS2STnM
5urbyOAghN2DjDdtxcrSk3yq9zqsx\nAqUFWS/lJz7yH/nUv/8xPCvlwnLCzbcc5sKZo+hMsbLeplD0uHSuzfFuyt6tKcMDKWnLp1S0iVYe\nJ158kRcv7WJk5gPsOnAApVOuR0W8tkWFyBf2BoJ5/W/Rr39eG5OR/17KACb379Cb1KVblkOqNFMz\nW/nTv/wk/+tTf8Vvf/SjCGFjpKAYVFhYWCTJMmZnF5icnqFUKtNqt+gu9ti5cwff8z0PcPXKNW6/\n43U0m00s6TA0NIRlSxYX5/EDj1phJM/jscs41Qlumj7M6PgkX33kQS6ffIQdU0M4fgUhBUXHZd+e\nvZQrFRYWFsiyBMd1SXGIsxQ7cOh02mzbup1Gq82pk2e5dOkqxWKZ2+/aS5J2sB0faQUUCjbdbrcf\nbLm50esBmaEXQq8rUImhMgRxS5A2DCIFrSCzNUtLLVY6IT0FY4Vb+eAPf5z04sMk9QsgLK5dnWPf\nAKztMATCxnE8wgikC2FTgFKkSDxPstbUTE141Ea2Uxv0eV7sY2zLNsKv/CFZO6Adxaw3I1bW29TX\nQ2yhqBTyZOt2JyROwfGh6OWOs0rlN9ChLZDGGZOD0OmB4+avjxSb5sadOL/M//jPv8L7P/ABuvOX\nSZMULAejDVIKLCvD98ro5gI6i7Ecl3D+AlnrLGm0xrV5C8tqULzhR1CXH8fyixSqGbLgI6SFXdmJ\n9EsYb4jeyhW8Xos//ZNP8GM/9sM02x2SLEEbQZgm7Joqs2XI49Nfv8KA7yGkRmW5atFxbDzPwvUk\nUdgjCm186bB/psCLZ+aJE5sw7OKOV0iVoRfFBAWHLFFkqeHRR84wvcl2H9APoqbPydkg2tEvgnLa\nQp4a8M2HGAv6Bqe5kMWRHTy1m0XaoGfo1HvcMjXChYsXuanSobWm8DyL+dU1WnGR8S2TmFKZkekJ\n9uwZYTw6S9yAbn2ZXkGgB1POXYbM8VhaXqVlHLraMDpYIE57LLUFEzvGmNw2hDO0n/LkboYcm0sn\nztJqrmBSje61aDQSyr7ihvIBjpnNFTxCgJYuwmRYts/gu38JESRcPfkc2W+9iYHdu1h8+Rjulmmi\nSx28LV2SHePUVx1WO1to7PsQY3KV6toVUqcM7SVUr4Uol0mlj+cOIObOorttTG0YCztvnyqNaT8P\ng/swwetBljFLn6AcNKgsfwp6VyG4m/RCl97poySxTa3QL1CvvER5h4Oz9jLi0qOojoVILTICNJI0\nbaOuHcdvNSDqIKwM8a/+I1cmPsj2bzMP/whp2cJyPCy/gGXbSMdF2gFC5g6jShks22CMjeXYGCtg\n+4Gb2b7/IN36PI89+Lf01utYjqRTh++9//Usj/hMHz7IyYvPMuKPsWZ6BJHBNDQtHXKloTlQXKIR\n3IhuH8ELLFwLem1FGrvYTgHLKGKV30yzNCPwLbIsy/0FIF/o0kJpjco0RU+hdQyWnydxCw1GIG2b\nq4sxhYJhcnhz8LubxigjSRQ4dr/UEQK0QlsCpOgTuaFvuYhRGsdkpEISHLqR9ZePsPjYCo4lUb5D\n4LikWYatNYnRhLPzFKKQUEA3jFjPUt542ySPPPQoE9UiExMTXFtcJol6hG1DsWgRGYPMUoTlEmUZ\nluWjjr2CfOP9GJEnZFv9WIu4vyn4wiIxitTkQXoCjRaGngEHk6sXNjmyLOfsCJkjghs2OAj6LdAN\np1LTd+vcaG/lhWmqDGEieeA976N1+WG8Y9e4sNzED4pcW2+yZXwU23ZZaK4yv7DKfD1iZHiQg1sz\nZsZtOoOav394mYIzxz2lGja7URv8HfpFTL9F9c005tfG5L7mUui//dxMDYyRKNO3fd8k/0v0XzPV\nKUGpzPt+4ieZOXQrP/+hDzFZc4njiGqtComkXM7RNaUUnuexuLjI4tIi27buZO/e/cRRyuzVOdqd\nFnfceZh2p8ny8iJ79h7
I1UNphgxG2XXgfu69+y5qI2N0ki5Xz3yd5aUmxpIUiy5Ga+YXL7C3to/D\ntx7khedeZs/eAzx35BhK2JBl+L5HsVhkvdnh9OkzedSAjICMJO1iux5aGYzQeJ636cgNyDk7rVZO\nLBaJwRcQLhmENCglWZ0z2EXDjpESh2rDnIln6QjNA2/4XtbWlzl1ZZnVM2fpdFJGPMPA2AC7vSEK\nlSFaPUmxCt0OlMqaoVKVlbRHXMxImyBsi8rAAOPjNXZunQYtKXhw5Pw52msJWQyuDSUBy12BY2kK\ngUXg5ER/nRmMk2/WxYJAJjmIWAjytHFPwEDZptPNaLfACjY3N7/+iz/D+z/4o/TmrmBMntmkTN6e\ncizonHiS6Mpj6Oo9MHIDlb2HaJw4weUXzpJhs/tdv0znyjHqYpyqW6O+tsz6YoilVwiqRaR7lOL4\nHopTt2CVpsmiFva1l/it3/kYh2+/i127J5mUDveObmNCFKi0KuzfNcYXFs5wrWvwbPAch8CXuI6k\nGyY4QuKg6YqI9U4TnQjedhOsegf4+nNXqPgucZzvCOWSx3PPXKYOlP4Z6CD0jyobSD79h330diOX\n6nrkQf8n9PUWtwHjk5pJMrmfJN1D2PQYKZ/HdWzKQZF2awkha8w3uiQioDI4QmXHreybqrFrqspA\nQeG2G1w6fgURd/CFy85Jn2vzLTKVMTpawxocZD5RXF7rsnUQDt0gGdwyQFTcSm1kkOpQSq+5zq0/\n9OM8/+n/idKKpLVGMRjhyLNPcOit7+K8M7CpebEDH+GWcT2J0AY5PEV07BN4l57j4vv+EEe2Wb31\njVSGJtj/vv1cfPIhku4K8wtz1LbNMPfI5xi67200mwOU4nWyXhMTtpFpAqUK7vgYOoowhTKyb/ro\n+D5JO6R7soXv/BrWfR+Dmbchpj4C3Xmyyy/S+epDiOghlt0HiEcOw4RmYOGTICE79TTyypNQqiBF\nDaNjrF4PVIiQCeELX8AxIU4RjM4Q4zfg1J9k6/xvwlujbz0P32mSKhO7QFiYvqmg7jtoKpUvUIuE\ntCfJVMhit41WGb5bJNG5nLnTqmNbLrqbMPZT/5rnbtlB69wFXHsBr+Ox3ljAVgqv6KFKKbWoxdDO\nMUrLFi1vBNaqVFnFCQTlIGQiilEyxHWKyG6CIyUpEukXESJG9BVkEomNJlF5Lz1OQ+prGUFJ4dsW\n0nbylFpL4lQKZEpzaXFzm3OQZmhhIWwboTSpyFtDSIGnAWPIXAsXQ6IV0lg5kTZLMGGGZ6UEtx4k\n0wYrVqRJiIpSbAkiU0THjhMurdFKYvYd3M/C/DWSMOXLX/4GtjAkMsNPe6hQYZa7VPxRGtE1Bioe\nvTTFMhk6yxGNxQf/nvKRk5Te/4PooWGUys2rJILYCKTIUNoQidx3JyAPXbVN3vPdvNYG0BqjrW86\nZV13Q95AVsyrmxEbHSUEQmpsAQMFQ8NxmZ4aZWU9Znb5DK6dMTk+glIJa/U2a6urbN15C2nSIola\nXFwpsdBssW3XIIM7PbxSjVb4DU6+9PvsvPnD3+y5c12m3i92+jla4lsVPf2fM7rvvWvyYE2tc5+W\nzY18o7UtC2NAZRl33naYF4+8xL13HETiMb+wQG1wiFKpiCPzvJowjNi7dy+tZotut0u1UmNubo79\nN+zHsiRz81eo1Wq8+c3v4OSps1RKZaqFLQyNTUK4yje+/nne8tYf4ND+XXyxWKXb6tBrNhkbnmZ1\naZY4aTAzUaVbrTC9dydLVy8zu7jG9NbdtDttxsaqtMMm6/U6cSfDKfgYWxD2UoJSCQsFogcCfK+E\n42xS+kjO/SqVBWli0AVBFoPRkrGRMrOLTZCGqgP33n4nd04Ncqixj7X1Ot0XnuDvjjzOqvIQXU06\ne5kdd+/lsVN13GKFpHmR6Yl9FKTPehJRMD4q9rB1hmVpZKxIe7nAIE0iVJZSKlWxHJskT
BiSIMuC\nNCjgRT3WIggTQ6eXYTtQCERe5GORaY0rBUEA0ZIBFyInJz7HoUPcy/A82Kwv44d/8ReJlhcQ0uqj\nyhIp81if5ac+TyyHuTY7gli4SO9yTONv/j+C6gjISSZHB2md/iwXT5xB2F9CpYZCIMAqMThwN62V\nqwSsosI50ladytab8IZ2E61dZmzLjfz2r/4CZ/76c8xUx9m2dZyBQpXBbXsZtlu8d8+9/PHnv8B/\neegse8aHMDolS8EWgjQzZCbl5t1DHLu8jhGS9dlldt93gBPHfIxKAQehYGmpwb4Du7Go8cgLm0Mw\nALLrBqcCrV/l7+SXrugrLTeUSvSx5j4HtX/tp+zGNe9F9aZpr67QXl/lzl2DFOMVQq1RtqRnyjx/\nsQFD27nntjvZsaXC5KCN62ouXjjHi6+cZ4ctGa2OIrMGjhezZwxGRstMbN/F7TftobmygLe6yr5b\nt3Ph0hxfvKi4Uj9CVrnGdx8+yA+8YRIrs6lsPQw0yLSPl6wRWCH2+jK3jHz7ts23GtJ1cQYncEsO\navYqILCHdyNdwZnH/oLWhXPcNewigjJnP9fD1hnutoOM7Zim6GruPriTVZVyZiXlNi/C9FrosIVU\nKaLTQPR6GJUinZwjZtIUnaasVbaRTRQZDA9T+tM/hC0PsRoLaDeovP6nkB/8UZbTkJcef5BpN+b1\n6Sfh8Nsxpx6C0EWoDLs0ggjKqGtXMUmYG0YGLrayMLIMdPN9Pe5iwh1Y06//tvPwHQsepRRa9gMB\nlEKpmEQloCUWNpnJEFIQeE5u320g0zE6gSRKUQpsOyFZAfMDP8yVk5cYGBumsFqn9dyzdOvr7N65\nmwVVwcPFHxyjmfZ4Zui93BY9iCjdyEhxBSU1vneKfduW+e57p5irJxx9UeG4MBy4DNktPJJ8CQuB\nMQotQGV5I9exbYJAoDLoGBAqw7MNru2BMDiWwNokTSVWCmVrrFQTk7srKyHR2oDOsHwXKSDQGgtD\npFMSAYctiw+M+fz2pTbNOMZRCmUErnTJAolsNpn/6uOEjSYYxT33vY4Xv/EcrTjBM5qh0m66WRs3\n0SytJSSZTVnU6BkbO07RcUSv18VyXSKjsQ2EwiI8f4mlX/vPbHvPA8jX30s3M0hL4BhDqKGHRmuB\nNKBFnmYv0bm8fvMH9TyHxUjQfT5wv/axBH05url+CruuCt/4VwiUlkS9DF8kCG+cm28uMlXTfPnZ\na7xy8hJYFq1mg2K5hK0XCQIb7dgYZeFbis5Kl93jRYLpkCiUKHU0L3CMfs0maF5D2elTq/tffzVH\nK/+u6bMeNWB07j1ljCDTgkz974jQdx79k2a/uLIFYDSW7fCeD/0sD//lH2ESQ2O9zfhYAc/1cH2f\nK1euAAbP84jCiMmJIkNDQywuLrKyuoJScd/JWjI2MkxjbY1tO3ax3qiTNVdZadVZuLyM0BZbp0d5\n5cg5MDbdTsrM1lFGywHx3DyyNsFXH/kar7vlBh649zAnTl1lS2mQwwdv5cS5YzTWlhFRD+E7KGXw\nfQ/HBs8T4Ci0yhEZKTe/cIQAy4NeAmEjt7kQSpOmhjTJ5znRhsdf+CqZ+S5Ov3CaW6eH6KQpwwdf\nR8ktU47mCMYzXphtc34xIosbjI4NsGPGB+2RNiLwq5w7k9BOuygj8YqABMsSmH7r27IlYxM1ulfr\njBegg6AtJZYUuMJcp3xpRa46ycDyIcugFymyroB1gSjAUqSZGfeIPEOWGkRd4Jc3OT/dEKNFzoHT\n+cWjwxXoKepmkt7sRVy5TjcxxN119t1yCE9kNKIS1y6fZKQaMrrjAGcuXSVsdDm05y6ihVO02rOE\nvQQTlxmv7ECeO4FIPo1M78UZvYd09hXuvfvNtD72O2w7NIFU52m3qyRnFykevAe9IvjQu+5n19A4\nv/HIESpuQBInKC1IEs34kM9yo4tB4NqS09cKbFtaJ
U1iSgWPTOXWJqVije9+99088NaE4+/bfMGj\ndR4UuzE2AEa9AfVsXMv9FnQeGtJvpwtQYgyZfT/N+Rpr155FpyG2MQh2o6MO0sRMbt9NdVCy3lih\nvd5gbXGOPeM2vUgyMrqV85bP+eV1vrbW49CAYJ/UbIszdr/hEKPjW7DKgxy8700sHX2Uhc41MmHz\n7Nw0jy/ELNj7uGt6io/88cPo5G7uPbSFsLGMt3uClZMnkYNVRmuGUmAzxcSm5kZIB9VbJV5eRWRA\ndwV77CCVN76XnVEJSyaUrhwjTSwaa22StEtpp4teWCTZthVnbAoZa2bKhqSZkHWbyO46RmXYaUjW\naaMsDx2FKGFwvICwsUDQvMzCy49xvDPAzjf/LDN79hItr+LKCCHaLL3wRZJzL7N6qou1Y5roru+i\n+IYPYyrjtP/mT6kNekjHwmQppt1GxAk4BpEaHBGRaa9PT4C2v435pUHs41e54cC3nofvWPD0uvkR\nRMrcOVNpgeMESDTdXgNhXCzHJYkilFE4soDt+SRZB0vauffCQpfSb/wanHwR58AtJGKQtZXncScr\niB1bWa2OY6V1EqdI3bXwoxgQ1JNp7hjpIJPnEHGP1PFxgjXeuSPhxJSHX5as12MmnQTr8hoZFkJq\npNFoKbCkjSZFOj6lgov0bYoyyIEFkbuItqPw+nN7k0Y8wrIwlotIY2whMdImUSmlJEPVyigBxUwR\nofEMOMKilcS8PV3jufmIt4yOoaKUTy22KBtNKjW2Fpz9zOewpMBWCa7j8uxjT+FJKGUZQkpa4RVs\newilHAIryHOOhEBni2gjaLa67N6xlfVml2StTiYTHMfGMYKu6zP76b+jfOIE1R95PzgBYT8IzzJ5\nMSK1JhV5H9xGkKLAbJ60bNIYbBtpSay++YV4DZryavxcPv/SbJCHTZ9IDpWKTVoawFo/wfriOTAR\n9+4S/ND3vZcrCw0++9BLPHvkFKrZpug5jNTKVALNYr3LpdkWS6seb1wdZOedkqFwDiEcjInzvC24\nzsfZcH4W11Ef+o9f+wvl6g+trbylpQWZyr2FMrXJltY3xVfk0n1p5Zv197zn/ezeMsbv/adfx7UF\na2t1CuUySml2bN/BkaNHuOGGGxgaHqDba/H8C99gdHSE6aktHDv+Cnv37qVUKqIyh6TTYG1pjpGx\nMVpRj+bSKrVgmFKpTNlzqPglYtdjdbVDc3CAsZEC0oOlpTmanRZHT55ifHSUgSEXXItzFy+zNL/I\n2+6/G99oHvrSlynVRgjDLkZK/GIBy5Z0wi62FeC6zqbXjS0NcSjoroOt82Kwi2H2WgNcgfRBuDA8\nOUqhGrB/tMSWisPUjgHC5AoDIqDipxyb17xwdpVuq4OlYtyaRWpydcpIFdrNdQYqI2jTxfETYkUu\n01U6j5kxeZxA4FQpVlpU0MSJodNNKbh5JLLoL5FUgaMFgWvIVJb7/AiBbYFwDLEGT0LWTug2XEzR\nUCkIkk12bdK4h7Ry3piwDapxkdbsIisnvk6WhKwuzTNx8Hu58NSDlIqScK3OWlbALktGR8oIfwer\niwssLzQ4sLvG6nKdxkKDGw/fTBB2abZjenaN6vBhrl48zoy8xlBlgaw9xPjewwRbYLn5LNLycAKf\nSrHAhTNP0WltZdu+Q9x38w5OX1ziwfPzOW9Pg7Thhu1VXr7QxnMspGWzrgJWzy3g+xGDA1UanYQ4\n7VEtVNg6HJBUA5557KObXjtK8Wo7q1/kbBxFrj/vPzH9NrQRG+IJUNk76V2waDeOUQwkzW4IwjAg\nu7R6TXZOegwP+5hyQJjEaL/M0OAgp6+u8+4ffBu//Is/wyePD/Mb/8c7+cZfPcza+lVGthbYM1ll\n+o7XMXjTzTTbGrdYpjp7mnjLFo7Wx3HGNeeOnSWY7jA6qLih1KI1vp2lhXlkb56CXaXTu8bgcAmv\nFODRY+fI5kw9t
bSgF5KuNbGziObf/QH+rgOYOMY0OtSGy7RPRZjuAuWBIVrrXZzlOUbveSelvbeC\nSKiefhwxZ5EmM8m1AAAgAElEQVQlCh32MN0mMkvyRIVM0TMJ+uRnsFnDymLsTgtPOSTpAI25qyR/\n9WtY0yFbChXQEqd5mR3+FhKzTmfHLuzRCs8/fZ77y/8DvXiFRsMlbMdMiUV0nJKuxxApTLWAiTNE\n1kTFDmYgP1BX559grdPDf//vfdt5+M5ZWraFEIr1pUWCahXbdjGZJlKG2LhILUg6bcqej2VZQIzK\nPASCLI3oXujif+CHkHu2s3NnyHm7wkDrLLKWEOkhuisJiy+foHr7Doqta6j6IDIQONZRhourZCsP\nkVo1SoUeupdzY/bsXGRntMShYCdJEvHMJZdrWmO5FpZdwLVdhG3hSJuujghMhBdYIEy/3ZUvftsS\nlAIPg0ZomyTdZF9CJQRCENkCLSRKGrwwRagMtxeC7RD3ughpkdXKpDrGXW+yUNQMl32KYZP5C3OU\nijWCLKTjVrHCNiXLJk6iPIW528EohUCSKY0vLDITkcXXkF6RtUSiMk1HdZFCosm4+dBBlldX2Ld7\nG2EUU19fJzOQWLkTNp6PuLRA62Mfx/+x9+M7BZI+3BtJhUee75Uog9VPvvU2GayaD43Q+YxvtIsc\nwbdIpc9xlA1VuDZ5XpZtaYigXT8Piy/jF4coVxxWOqtkCyeZDAb4yR+8g/vu2Mpn/+5Z5tc7LLR7\ndFPJntECwoZYZ3RMTK8+TOop4m4Xr+Bi+pyk/H1tqPQ2vmBek70lrr9HTe4hpA19N/E8fy1Rfb7S\nJhCe16qXrrtEAxjN2ECZbQ+8m6999XGOfv0ZZoanSLKMgcEC9Xqdm266iSzLuHbtKnEccc89r+OF\n519gy5YJ7rjjDlqtFuPj45w7O4slbNq9GL8TMrNzB5XqGLYbsLK8yOpqHc8tksYZ0nVZaba49tI8\nh7ZWUWKWhhGcOX6OCxcfZmpqK29569vpLs2Rdtt4MuF73nYvt9y0jz/7X5+j3QqpVEevI4QITRh1\n8b3Nt7R6sSBuQVAAoww2AtnNzUyrFWi2DHt3TfADb/p+1s++zMzUMCNTVUzJJewlRHg88/IlHvn8\nizRiOHCgmitKOxkiU1TKNqkF44UiW3YNcuWiYnTQZjFdJ+z0GVxGoozCKI20bISt8YTEVRrXcxBS\nI2TOm3Hy+whpCtWCRAtJN8soOrkhodaAnZsnKmmoOA6ptskihbXJ6ckVsnnUy/rxr3Dt0b+Abe+g\npwZpr5yhMlDh9Nc+y8j4JMKGuUaCKAZYq8t0u5qyfInMGuemQzfSSiWe6jF24z3MXTiGxyVq276L\nTKUc///be9MgS6+zzvN3zrve/ebNfa2sylpUKklV2mXZ8m7wAtjGTTdg0zTGHQ66GWCiGSCIYWa6\niQbCDZhmbGgGQ8PYdNtsxrslS7YWa9+lUlWpstas3DPvfu973+2cMx/eLMnD2B6nvzWR/4iMyg8Z\nlZknz/ue5zzPf/nmVznxw++n13iMYWGh0hQpDI0BOM0p/OIwtfIY/bU2l6MrbHZeZLVXR4sh3jKf\n4zMnTZY/ZcOxQ0OcutJiuFwkShS24+JVbfJmFNXfoCHOUq2MEIiIr3zjEre+boWxqTJ1Zdi3y72j\nDaBfjX+F7BNtzCshq2bnZSP4llEXhkH8DqKzw6j+GXzfp765iYgl+bEJtpttJobLFH1FeXSS9mCb\nwlAZlta48tLj/Mjbb+U//eeP8WDrAExYTIWr3PeZ3+Y1b34/W/2A0uR15GcOIfffSqm+TWt9m9zk\nQQr1LbwrHV48s87RuRqdrVP85Z8/yY/987fRbnRZPPdN0o01ykM3ot0hlJPHuCG1MgSD3cW26P4A\nWShBbQzVqtP75hfpPfgZbL/I7E/+NvaBg5RvehtpfR3bshgXKbrTRC09zub2JSx
P4538PLEq4SQp\nKuhDp45MEmRnCT24SHlqlMLm3Vj2EMLNY/kunWZAKyhz5503UK2fpxg8jJVuZpy52IIZH2kkw9MT\nrG6uM5E2SB/8W5SAkm3h+CXCrR4Ig/Y8hOcjbBDVGiIIobuNUa8eHu3I4qu//AF+8e8vftt1+K4F\nT5xKon6TxvnzTN90O3EcEwqBIzU528eg8HNVLGEDCSpJMMQYo+jWt/H/539N6Z0/SGWmwHrsEOgY\nS9nI3DDJZguTtskf24dnKdJSja5qk+8PKMmUgr1Ms19hcjylHo8y6m7R7Q0he2eoTMHB8ecYyNfz\nwguXsJBYZO1o6dhIx8cgSE2YuUOLLAAVNEaKVze9uTpc0Fi7vIyaRCFMhGtZBLZFrDSOLbETQ9rp\n4wCRJbBERNqG94+XuZRYnA0SfiANKVmCK2Gfg0OjnNEWOdew9vmvE0chaZpgCXClJBWZ7N82EqEU\ntjAUXIcwDnAtm16aZpwqW6CV4PSpM4TdgHOnLyJzHtIYvEGE9jWRSolNyk+8841sFIr8zV9+mpkP\nfRBXKTQCXwsSkaJU1nvRZK3z4Ptg8UiRjRev8nYssfOCkZkl/VVJOlc7KwACLCF2ih5BqANUf5VU\nVqi6FsnGCqWyhxkM0L0eoRZMOy6//ct/SL0R84nPPs3JS+ucNobhsuTokEvaDmmlfZRbwlXPA7fv\nDK8y8rS4Ogbl6s/Dq1fCq50YTWaYaMQrxU6qyD5SQZx8H6Obne9/9btqBBaaip3SjzX//F/8BE/c\n9yDtdpdG0CEKI5rNJlubmxw/cQLHsYmjiM2NdfbNz2HbNoPBgJGREer1On6+RLPRot/qceHKGo0o\nJu9XcHMJjp9D2x5GOGjdR+kUxysxNXENOatPPwiQqcXKUgfXOsyJY7cStQbEnYCFuUmaW2sM14ao\nlnPcdOJ6nnv5AkkiGM4NY+yIVtLcIa1+e+Lgd0OqAQt6q+BUsiJhdFQwGIBKIZ8acirk0//3HzE3\ncoB6I6S4VOIH3/MOBqbP2OwCc4HFG+cWeawZcf0Nh6jXA/JKIyxJTdVYXelRqLkYy6FQKlG0JFVt\n0U9ShNFE4YCgNyDIBfiyiKtzIAYU8x6TRpAMDCUbUm0yW4y8oBEZEi0wWpEkgGUI+2BssCqCgqsp\nGAfX82gHBaIwYKSyu0vWVSPYuHWFxfvuoRfWkIuPg51nfXUVP2czun+BxOTpbS3hVSaZGonZ3hLU\n1+pURkcQGOqXnqE6fpgwCigFl4gq06StBLF2N5E8ysF3fZDWxgXy0zdjVIyVq7JxeYVHvgFH7nCp\n9WMWSi32LSxQGn4jF88/wcW4x4MPPsGNx6a5oZrj2UaX2aE8/X5CFAMixJYurpWwvyaYOPBuPqQ1\n975wH/c+/Q8kHZ/nXg750Q98gk/9l5/m2qO793CSAvRO4P0rPD2TPb+v+BNf5RSSFT+egIG8i/jC\ntcTdVaQNjlGQCkq1Ca7ZP8eJQz7DyWUaWxv41TxXlpvMTFcIHjjPwnUT/OWn/poHkmPcYJ9ieRka\n2wEVO6U2P83cTVOM3/Im8kfuBO0jvBGq88OYiUn6jTa1+jOMTMyyvrHNVtfiXT/5HkaHi8xd/jy5\n1Kbfh9baCsLNIb08Q47HoLPOYODtam10vw+OjfBymHwB02ogUx+VJOjnH0CGF0lPPoSDJFg8CX6F\nROdJCiXi2xfID1ZwREIUbmPiBvmZm7D3HaBcjnDe/+/QeOjkCk6aolaeRp99FtKUIdHjYGUM0biI\nTJvEgcEvgSz4GDePGnRQXpVCocL0O3+E+Yd/B70NRgrKeYUs2zjDI1jVoSwaQ0p0p4NqNHBVjF/d\n+QVTw3YXngkdfuD2a7/jOnzXgidRkgefPU3RyjHv2hlR2WhSc1UI5aKShETFWJZFmoLU8M005fbf\n+SOmiJhNBefaffykzv74AmpQJFIdKFUZ3ne
ION7AaifEl17EzVcxZpGz2mfJfQ1G30rxQsIt3vPc\nUqwwMbeEkysTLnWwpKYkH6OxXcEI5xWVlDAGtEIKizCKcR2y7ocRWEZmxly8MgLnqgR5t0QVE4WE\nqUVi2YSOpmpZhEKgCz5OlIBOkdIgLYtxV/O1bpfy+gYfPjaHrWIa568wnfepBS3eduNx/v1vfZRg\nbQsv1rgqQduZedPsvjk2t7Yp5iU6CAnDAaFS5D0bkhhShS8E6SAkjUJsUSJNY3zPoRf0yFsWxrah\nn2K7NoQxn/rC1/iZn/hRDpdrXP7Tv2DoX70f20CsDakx9CVIBUiBpaH/fcjSbSmwZfav3CnAs8nW\nzu2Kq/wZnRUg/6hBIqTGtotstWosP/13mG5AbbTK0kqbkg9uoYSRgs0AWr2Y3nabD7x+hvSN8/S7\nA+4/vcnlnqKbas7dN+D0mRy/tPAEdvk1GKMwr9wD/7FVu3hFMWYgC1Q0WYGjNahUEClDmkIUZx9h\ntDuBsTLgpRA7kma/y1iuhJCQ2hb1pMNQrsyhE9czefwQ9ZWzjPkloqCDSQfMz02SRn2aWwHDwyPo\nRFEq5sk5Hto3tNttkiSl0WiQhCEH5heoVUcYHhrBKRbo9ftcvnyZ17zubSTJgAfu+VvefNtRThyZ\nYW39Eq2NlM4gxiv53PnaOzE7YxllOXSTlHOLp7ju4BTjY4e5tN6iF4Q0NhoM+gnNeo/b77iV4sx+\nUsvLRkN8eVdrM2TD5DWC1rjhwvMQ6qywEEKQbGgmFxzOdZr8+Ps+wMzEGCfv/Tz19YRKYQJDj6Cn\nOHLkBAdufobN57eZHB/Gtnyiep+861CoTbBgVxgeLtIOFZb2uP+5S5y/1OP4TfsZxDEFz8HxLFwv\nTzE/zEZXs2kVOTDiMFMrsr1VZ63T4Q13nmBkpMY9X/06FV/QCwy+a5FzFZV9Bm90Z+8XoRNKNpMB\nedEnHABSsL62q6XBLfmEGyssfvaPaTS2GTSXqczsp728ROzPYeiQ3z6Dqt3E/iPjbC2eZr3Vp7jv\nVibKb8b2LTprG5QPHWK9GVJLXyZe+DFKzYuMXVOjZd6HiAasLS8z4vTpnv1rvNd9lLQ0w1uvfRf+\nUJnlc8tMFAucWs8z7m3x+kPnGB0b57pClc6Ra5kfjnnn/HmCmmBjq0+oDbbjUHQcXnvUxRf7mUkO\nc+6xh5gvVTmeFHnvW34VYzs88JY6v/Z//TG/97uf4qUVRRT85q7W52qUDDvFjNHZBYWMypmNAtl5\nrkUmUQ+Tt6DPHoR4FSmKlMujlHM+agCuFZOTmpe//MeMH7qB0WoZW3Q4POuw/tBDXO67/OWXzvK2\n41V+qpzy8mqZ181YfOHZgE+/7xfIlUq84YP/CyulfTiDMvlmD4sUJ5dDFsfY/zO/wdiLD1O4527y\negnfHUetfpVcQRKM3kj7pYfZrFncf+9JOlpxZP8Ydm6E08sR/a2n+Q40lW+PNEa3mgjbQbgOlGvo\nIMYqlQlOPoa7dZKzz10i7EpGRgsIEyBJyR0twcVTVN98C+V3fwi/t4FrF3GiM6TuAolqI8rDJGkX\nW2sGWiALo3jH309yehnnof/K9CAm2P8WVD8mnrwT6idx4g62E0JZ4LiT+ME3WfnTv6IlYGxYolPI\nl8CyW4h2E1oXszuoAUuBWwIxdRtieBZZzBG4R6jk9vFOr0axfuk7LsN3LXjOLK+x0lCUfU0QCTzf\nJRUxMs0OryhOAYnjOugkQmnNE695DfUTN/GpZoM7R8eI+g1qvoN0JDIYYqUpcHoFug5Ieow4PZAh\nXsWmaCUk5gB2tIIKPWakoZpzuSe8g3L3NGOdc2DHSAuSbU1kBcxVimx1PWzHw3E8hOUiREYkVVrj\nWxJhLCTxK+m5CIHZyVDRO6MLqXd3aBmjiJXARmCJlCYKmYLvSvpGoS2JTZZM3u72yXs2p/B4bqvD\njYT0Uk1
sSa69boEvfeFzbF5eo2JZJDuOWUalpCqlP+gzUS2ztL6BrVOCNCEnLAZpSm5Heq9llt5e\nKRRIohjfkkijMa5LmCbIJMG3HERqSI2m3mrzax//c8YqNYqzU0QXXiaYPUReCpKsIUaAIaezl4bk\n++lgZG8arbN2smRnXHWVSMhV/sw/+r9NlgRtuw5P3XMf//X3f5Z88QAnDo7y2btPc+O+Ks+swdn6\nJpHJukGDCF5/qMrbKzkGVo5mvcMd1+/nwZMr0GgyVHEplUK2t15g+mBmsojJUof1Tnvb7BCVd36E\nbPhpTCY93/HbUUqQaINKIU6znKRECc7e98fwhp/7ntdG6xSwUUA/AVXUGJ2glEfULdMrwgMPPM/L\np5aYGR2lP0golwzT09P0ej2KxQI5v4jv57h8+TLz8/O0u23CdEC5VGF2dh+piqjMTOBaNmE0oB/0\nGS1XQBlmp6bxPJ+zi6vYMmVra43CjQe47YbDLJ5e4tRaEyuXY6veZGp6Et/3CQcDVtfWcOMGMu1T\nHqpge1UsRBZmqTRh0geTUixWmdp/DcVScdf7xnMEUShB+tiVPvGmQA0gVwR3RGCXUsIYwqBBvRXi\nFyyGKwpXajbWzhEKh/NJSqup6ZWLPPj8Mq12wHF9EefUBsrMceK2m1g4ssDTzz/N9UfHWW0HLG/2\nXvHUwnGIowilUqTMJLYryy3inoW10WXEUyQh3PPoc8wetChfB811kEFm7CYlkBf0bIh6BrsPUShI\nY+jbmbDBpALx3U1B/j9ofOP3KN/xYaLmWXJCERbHGEQ5xo7fRP+pb2CP3sx6VMXrdFhZewk5cRtD\nfht78gauPPs0YzfcTnV7kWT1JFV3H0Gax5z6IpE9zpWz2xR4lGoux9iJ96HXDPnpHyV36C7+3Q/d\nwjIw1e0gc5KGdOhowUBYHExybC9tMTczSrezwqmNFB/N1kaIEiCU4boDFe66YxrVy1PeOkTa2ma4\nOkKaDChVqjR6A6o1l7smJnj0Y7/Pv/j13yY2W7veO5hvFRpctQR5lbwsZHaBwQjQWeC1ak+ikoC0\nG2DlHKyoy+VmkygWXH/oAG86NsRKc5LhoSrlikvS2CZOEkrT+2gEl6nhsbY5YF8+5NDsPEOrL9CW\no6ixOfYfPsrLwRAzJUmqEuI46xp32hvYpJTHJ3HnjjJ/0wb9c5Lxa/ejencwWHqebstDKouxEjx8\nxTA2nKPRSnHsPnGjgeft7rwSRmGCIFsjvwCOjTtzAH/fUZL6Iqqzhp0v4Pf6WYGYJMRhhDh7lqn3\nHGJy5F709uPk3EOYdAgsgaufzS6u0UWceAWwMNYkJArd3KLTrZGzhjFhG178OuVcE9fTyNo8Ymgc\nEoXpt9GNLkk3wrVkNlVI9CthzVlnJfNxE8ZAAPKud2G97YPI/gW29Bw96wB4Y1x1edPzU99xHb7r\nI3dlq8dWz2NtZZv3vkWjE4mREmH7pGkPB42tDf1+xOVano2hafR1J9Bnz3M0b0AbziUDBptr1Cyf\nseFxrpnapHkJvBL4zjpGp5igj1WP8Lw2PV1kLG3SHSTctn9AWpik25rh/tL1POd/gDFPspB7iAPR\nPyBiF8uVWb6XbWMZGykyf/FYpVmyrDQZZ0NkvikZoS073LQwCC3RZJlgu0GapLg2CClxkqzAshKN\niCWp46IlGGnwpCHRknYvJI/h7s0Ox6crTI4ZLNfHtgzVQo6ZoSJho0diVKaIixVawr6ZKosnVxDR\nAK0lvpTINCXVCca2iRKNY0EiHZJeF69SQESCVBt0kmIrg29LemkKQmPZLmWl8KQgGrQIzoZ4h2ax\npw7QJZNZy6vDHiMYGDK58S6hjcoyfHYIgtoIpDGZb83VQ0Fk/kVCvKqXECLriq0s9/jNn/8ZJsfm\n2AhivvnSFTSSL59r48ss6TtvZ3+7Sglebsac/foFBv0B0xWfsYt1brgmz
ze3DP2OYCKR5KoRcS/A\nLtpozI7zdlaQ/eMWkzEZeTobZYFWAqVAp4JEQRxnvjBLp5/kwoN/AHzvBQ8q+32VEQSxIpaSxHg8\n8fhTXFo8T7lU494v/z0VN6JYGKbVDSnk88RxFtRYr2+jUkmhUMR1XRr1BoevOUy9vcH6+jpKK4Zr\nY/iOS7fT4cTxBVzb5vCRBZ58+in6vQB0QhSHhKki0YKtVpswVDilGtotkQD5HfdnCYSDgPX1TfaN\nFIiVphe0KFg5bCEplUoYadHptwh6HQbdDpFKmZqa2fW+SWKF5+XwPIcrsg95wAEVglMyDA17zEzP\ncHC0xMblp3ndNQeJgpAnv/4Z7nniDGOTM/R6EanRBL2AMErpR4q75iFut1naavL1R77Ihz7wdrY7\nCQ0RUq+3iQIQRqJUgtaaKLaIopQ01UxMjdJpdYn6Csc2WHmL6hSYcUGAImpIoqbBc7KRujGK7qZB\nbwmMBZbe4Y8kgjQF42YBurtV7X/uk3/Oe5wK02/+t5z+0n8jabUxdsqFJ+5Gjt2BpZpsrHQYyvXI\nHX4N3eWzlPcf4MIjXydUksZ9/wmvdoh0EFGdP0ihd540MlRnh7myepraXT9Jb/UCrS98hGvf+n6O\n/uS/4Td+4g08/BIMAVpItJYM0gQ3EqzFLb5xOuLo1BgbGwMO75tj8dSTiAIoIxGW5g3H53nTm/fR\nafZpni8xNuqQ5PIIpSFfJUoiiiOjxNEA33Vw2wF/+7/+Mu/8zY/teu+onUwJbdjh7PBKhyfr7rxa\nDkmZkZxNZ0ASe6SWi0bSbG7hS8MtN9zEwbEyfnARJ19C6i469gjrCaOvew3ysecoc5kYQ6oSNhpN\n5mYLXCoeodkc8Ka3vpX9199IbPm0m3W++dwicn2TtcUlnti6hO16/OzP/hiv++F3U7nt7VQPH6Px\n0tOUJm5AmyGs3hNsrGsOTgmc1FAtKgaRJkxiiBTV0bFdrY0IA0g0JtVopRFejrS+SHh+Dfe6SYxt\nUau6BM0eBVJ6xRyJl2JGKlQne+jqD2HiKPNS0C3QIagWBgfcBczw25GNryGXH0eceYYrZ29lo2/w\nZm5mNrkHVwzQAxfLCbHnr8W67V9hWltEX/0D9NIiQkmkY6FTAylQyBSzWfj0zi8RgfX+D8PCKNtd\nxXL8OsomRlbHyaUDnPBFhj0Btg/c+G3X4f+n4GkzWnU51YB6J2aokmJhIaMOAydHf6jAJgGtbpvV\nQcD0bce5eOoFJpKAg9e9iYc3VpmwYWJ8mhXbI9Ex+2KLvp8QpYpCGBKtLtI6vcztdpPJqYROvsrB\naZ+V8AAr+QJry7AZC9pxm3BsiH7hBl7IHcWe+DF++qWP0nn2CaTMZ7JCSyNEAkaSqHTnFq/R0oCx\nECK7wZkdIpshI8hKnZlP7QqpBjSRVKQ6q0hLRjDQKRYSV8lM2aMNOVKIIrRwaHkez6yscXyogrZS\nLpw6R6ffp7XdIictkkGMQGAdmGNeDPiV/3Ajd38q4uOfaJD3IG/biFSw78gBls6v8SNvu4mtZg/L\ntnjwqRewjSDRGpkmpFGCkIJ+rIgB47j4RpOEMcp2sI0msQ3x0hIpKV5q4YksNE9j00fjf59p6Vor\nlNZIrTFaZpESgBLgGPHKtSsbcclMLUHmGFtvRPzrH7yLoZrFaN7i5e2EsgcTNZfGekitUqSEYqAV\n22GCJSwcqXAcQdyVBEKythWSyxtkztATKeW8prfu4Ezb2N/K1cGwY7ywsx/EjrFgJo03JuPqJCp7\noUYqI6imSlJfX6fz0G+i7V0mOzuaTmJoDwx/9PGP8bP/8ucII83f/M2f0ms+iyfHCLYblCoxkdsl\nND3qTY9iySOIBhQLZYTILP6PHz/OCy+8mAXhBgPm52ZZXDyPjU1uYpKR4SGWr1zk9ttu4cxLz+EJ\nxVpzkzBR+J7L2PAoqeVxaWmdg7Pj4
JZYrbfIlUdIkxhJSnOzQ6c/YNBoMXtsnnp9idX1DsUgT7k6\nxL5DR1jfWKYTJAThgInxMbZXN2lu1He9b3IeWCLH1PQUcWR45vEWri2J+oa8C2kSMTkxQ9+a5Eoz\nz3p9k35vwOKFdcIg5NKZc0yNF7NRbtolUTAuNc0eLLgDtDFsrNbpY1MbG2P/sZt56ZHHuXCujtaK\noNVCjo5hCUMcB5RLBYwReK5LHMbk8jZewaabJnQagkEvS0P3EvBklk3nuhadvkIpUFJgGYNtC0hA\nkTn7RglYu/Th+dqiS+FPPsrtP/ReTvzMr/D0332WzoWvkqvdQL3bx4pbzFU6aGVTiC9i77+esHOS\n4ZmbaK0tYo2/g81Tj1KpVolXHmQrPsD8TIW1M+eI7BFO//ffpVyDO3/+E4yM2/zqu17LUjvPagK2\nI3FtCdIiNTYqVGgPLm006EYhE4mkXV8hTCFqtrkch5iggbrlGK1ul6WXA7aXBPkISmqAlStDv0dx\nfJzm2hqzo1Wifg9KRTzH5c9+4YO73jv6qibd7MS+GF7x1Mo4rTty9J0vSxMfTBFFgvGqCAGd+joj\nM/PcfGSKo5OC9Re7/MPnvsIv/uJP092+gj06zeDCJfqlA3R5ngNlUEKy3e6QWA36kWTmujfS9qfp\nlfdxcTuh8bkv8Z61bRYcifBs3j12E18tbvB//O4n+dztN1MYnsDMXstIbYTNL/8FqRboqM/immQg\nBE+tSA7MxBTsmJdfrnP89huJvlV//70gSrLLpjLoKEZaDsH6FsGLlxixQ3K37MNdXWUwP0J/CHQl\nR0n2cU2HzskGo+VnkPtfC7kFdBgi/DzGriIL4xjpIRrnSHrXU//sx1k5C8Htd3KmOOAu7zxGg+XJ\n7OVvQJ35MvadH0YcfDOW/+dkd/HMC2mzLZgcA+FKSFQmMtIgjUFMjcBkm03zK3j1T+NM3Iqqn2Sh\n8WdQXgDPxqg2pE0E7/y2y/DdC57VLmNDwzQ7CecuLnPHrUeJU8XKRI3FYofu5iW0KdKRRcbmF2iu\nn2XU8snf+ToeeekUUXUIL3RoRF1mvU1it0BuehrZuIK30aUdt5kpF3BrRe5PZxkddxkpjrI8cZie\nTlCDNtGCxBOKvJenE/Rorm1SHhrGF4J/OPQvmXpwEdmRSCSWsHZGIqCS7AgTUl5lrWVjnh255FVz\nXM0OOXWX9Y6TJiQobGPwhGFgBJ4E26SkQuKZENspoJAII7FVyqBQwFjwpPKpbG7R6XXoRCF3P/IC\nqcm8KM0FG44AABPMSURBVGw7o6+Ovf1tNP74Uyze20DEilo+h2NnstqC7TAxVGH29VNYfhGjesSu\nxVCxSHu7jmsJEgO+bRGhkdqQA+I0JTUKIS3yQjBQYJHQO7PIcJJgSOlZgqK2cE2MBjwtvq+0dEeC\nMjobaSmzo8LKRluplNgiGykakR0KAp3J4i2H//KR38FN10jdMssDxcGZPBVpqGAoeBbNfp/UsgmV\nJhqklEdhEGryfp6D15WQwrC0btH3bCQOo0nMl7/Z5cRtCb6dZLcSXi249I7nM4AiuxkqnWVkKZ2R\nwdPUEKdZsZMkknY7YOvz/xP1Tpvdmjj1U4f1Zp/zF86zeulF/vITf8iRhVs4MjdDODZJqiVPXfwK\na/UWtUKOfM5hY6vO8nqP6687BsYlGvRxXZcXXniBUqlEs9nEEhbbG1uMj46wb98MQT9gZGIKWwou\nnj/PeG2G6YX9bKxeIUwVaRjjqATPcXEsl62NDVZ6K+SLeTzbpjg6Qtjv4GLRaDa5684T+HaK4zr0\nI8kzjz7NyPgEI2MTdMM2U/kjTEwusLJ8ERNHDILdKUkA3CIEgwGdbh2SjCyolQHbULEcupcMp8Ul\n7r3wAO/4oXdREgl//6V7GOQcFuZHCBqao/MjbDQDPMvGSQTlqM9qdZKV6RvJNS4A8NADT3Lw0D5y\n5
RW6rU0KRUhVil8ooI1GyATfy1MbGeX8mfOkIdgCVKpRSpO0oZVknEUwWDLbK9I2eJZNEibEcdaK\nTwCZGEyaxUroQCBSg9ylUOKei5rrKx7ii59l8um7ueMDH2Gt+WbOfPmvKEhBp/kihbnbKIiEeP5t\nbH/jExgRUx1+iJx3gCAIqE5M009dLCvH2NQQ3d4WcaipTU5x3Qf+lP0Hp7j7kx/n8//9yxSmDpKr\nwfalcwxbGq0zr64oifFcjzDR5IShGQ+wty/RtTXb9TZBbYIfee8bqVU9Xjx5gf1zmpce38YxEe2i\nR9Dv45Q0s0NFoq1lKpbF9uIKhZwkWDG4E4ep1HYnu4bMh+cVns7VV9ZVry2xc3EzIosfSi0kknay\niNbjhN2QYt6nWBvjwPgI46WU7uWzeKLPD7zzbZh0wEtnznPXyCi52ijnzj/GwWGHVGk2+ploRtsd\n/JFDtLsBuptQbMWcf/Zlrl3Zor12lpX9k9Sm5hi7eYKfLp3g3LlP8NRjT/KGd7072xiWTengDaw/\n9HmkFIzkFf+w6DJWzsb2g8I083Mt/FKJQbu9q7UxqcZo80rRYxxN6nms2wG1zS2swQRxwaFQaDEk\nGoiBjesU8BzF08tlXvi9e5m5+Tyj0+OUDh0j1gpfpWxdXCdfzHPvpz8Ga3DgPb/K+msOs3zucQ6v\nP8a19gu0XBfbM5hB1q1R2xD96c8gp29H1TfABpEIXG3oDAxKC0Sssk5P5iSZndfl/Vjri0T69/Gq\nNzK3/FuU821Mcg66xwANzgykne+4Dt/dadkBIRRu0WFurEIax9x97BCWHlBuLWON7CcXlXCHRohG\na+QG2+RmryXaXKF8aRFHOMweKtOXAy7rMfrSR24vUfFG2fb7iOFZ4ngdf6FA8WJKwRpGulVM0EZH\nCcKv4JoBlmVhgohCfhQ7bGNiSOKYbq7K8okfQH7xvuzgMjueL0CoFNaOKFpqO+tayB1llria3bSj\nknmFwPy9wzKKRNtkqcMGT0NsGWg30ZM5/HaArnhoKYgNuGmKMSlOanFmfQvvyjlGfI8LW1tsrG0j\nMSS2jQkVtbe8nmhphbt+ZJZ9N8Tc93CHfqoYsi1EqlBC8NBjJ7nm2PV0o3WGKxUunV7Hsh1QWdq5\nsEAmKfZO98SyJcJxEEmMIQsKFcrgCYuw3sseBGHIxRrfKAbC4CJJjUKa3fd4HGmQZL4mSMPVkD5p\nRHbFRe4Y02ksbZDSAguee3GTZ7/0F1SHKiAlUz4cr/k8fLnDhjYcHnNYbCe0+hHzwyWK01USA11v\nwOyMxC5pOudi3jTnkpMwd42HZ3zoWTz8wAbH39qkzTivyjV2eDvwiv9PuhOyp3eKHaUzrk6SGtJE\n0O6EXP7cL1PMSUQgcWWyq7XpRRYPP/g4Lz13PyUn4vzpe6n5EMcNrrvpJ9GWy1MPvIAwqwzlXHK+\nzeqldWZnphgdnUBgsRrFSCEJwxDHcajVali2j1KKXC6H7/nk/Cwd+5FvPsi73vVOfNfjG9+4jzte\ncxt9JXn0kUepC3BJCQd9WnGL3PA0hw4e4MyZJWpDw3ieS6fdwRAwPFYkbreZ3TfN2tYGlaEi9eY2\nlm+hkew7eIwDB+/iltdaPPiNTxLHXRYXd8fMrW9JXKtPu90DA56URF0DMczfci333/c8b3jLtahG\nm5uPv4ZqET73R/cQkXDtkSnOv7TN5MQQzSDBtSRxamOrAdr1sfJjFHOXmByGpx8/xYXHT/HFv/sK\nt94wjGNDmiToJCTodWk2NBurj/Hyiy+ju2A7WUyK1IZgoBEpuBJild2XjAItBIosYkIrIM1UWkKD\nic2OVHHnc3YURbuAtgxfu6AoHPKJlGb9Iz/P5MHD3PG+f8PmWouVh/usrb5IcXaB9tOfozJ3iOZ2\nnV6UUHNbjI7naW7lcPOzJIMO2yvrzO2v8Iaf+0XioMujn/skH/u
FL0IR8mNzREGTM40+A8Aam6W5\ndoV2lDAKFHIGY0uMVFQdh0a3halM8csf/zC337FAqiNSLH5c3sX25hZ/9Tt/wMKBPN3BANlsUg17\ndOMew0UwsSZOFIWxGYJ2ndWTz1CaXeDg7pYnU1Mq+H9fZjLl56tfI3ZGXAqjDaWqha3zNJOQ+ZlZ\niq7gxusWEJ0rrC8vUSumXLlwliNzN/Hgc5e55ZZjmDTk4kaMXyiyttQi52Zd9/FRD2lrakM+uWCJ\nUN3M2nNPc52VMPmmt2A7CXHOpbUW4qUD3nn8BL4AJQSWUqgIHN+nPH+QpSceYSBtHluXvPeaEOWN\nUUYgx8axPAdp7U7FliaZzcLO3T8jIPoOlguxaaMHdcZHXXy/hp0fxVgOwraQpIxPVXjQuo2ZcItv\n/sZfc8sb4cXNEmvWKPc+doEffuM8J37hYTatDi9deJHS8klYXqJdm6HfeQFLJuDYJH2JTkAzhBlE\n9P/+K1gjFmJ0HJGEWI02YzmZuXSH2Tmtkmz8KAQEpkJue4T5IxpT05jacYx3A0TboLsgdm4Q8jur\nir+7D48nkTKk0GuS9wsMbI+S1rQ2Nwl7eYKTz1I/eoRiRyLLBcLxeYIkxF9ZYqjZpnz7LXTdAd3t\nLgW7SZpoqEyw1l9hYmoSJw5wRYlAB4xfP4oO+hSsLmbQR7sVImK0tFBKY+V9gu6ACd2i4eap+g5W\nfwV14q2ov/s8ulDOsqu0wEhIEoUxOvPB2Znpyp1qX+14n4hvcdx89UrwvSFRCkskaOmQhhFKSmSs\nEFrhx4okjUmUztJxpSA1JrNg1QZdr/PYs4tUXYvtdh+hFBpDmiqEBeO33cilP/kES/tcvvH5iPNn\n++TtLE/LE2CkpGAbTp85y6/9/E9xcb3JS089RRSn2I4kNlkxE+kUS9oIJCqFtokoXR0txQNSHJBg\nWxZCKXLaEEqBnSryFoRGI4Dw+1BpSa0xtkLplDQVYNyMuGwbLCV2Zu074aHSYKGxhctn/vA/UPIT\njJFMiZRHL/ZZ7kXcPp/HuJJuKCmGCQU35S3H5+imkshyefT0MhubitsnLDZnDEHPUKmlfPWBdRJl\nSLG487YRnrr/bo6962eJo+hb/vYCvRMIqo3cuSmaTH6eGpSCNIY4EcRKcPKL/5FicJmOk6MfhFTK\nu0s6vPsrn+XUc0/Sqy9TKRWpb6+wfOVFcr4HJsHLV5nfvw/bu0Kh4GCFhutvWODsmcvUt5tIS5Cm\nKUZDLpcnCAasr6+TL7jEccShQ4dYWV3hwP4F2p0OBxYWCPoBKlQsLMzzpS9/kYVrb8B2PYRbZGV1\nhUG5hPQdiLo88cxpDh69Bte3iOOIfr9PGA3w8z7byz1i4VFvdoiVhV+osNFcR3pyJzSyzOT0HG94\n04+yuXmFR7761K7WptsylIYFnhT4ecPELFw+mz2b+/dNcj/PM+IX6I+OsPXy8+A5FIZBFzx69QGN\nVsD2SpvqUImw02eqWqTd3ebRe0+RL9ToNTtUgbojcIvQbBqSZo+yAwjJ8HAVbUks22VmZpLNlU1C\nso6x5wusSGMShVQCP68ZpAIiiD3I+wZHClKdHSyYnUJfQnaygY4z/qB4RVG0CxjFswMYX5FUXcNU\nJUfEEhc++ksced07OP6+D3KdU0AnmoESmDRgbblJHEQMun3SnMORY0P45RFwLHKeJOj3+ZOffyeb\nm5C4oMsOQvqsNerYVkrt8Ov43z/wWkqlCkUH1q6c55Fnn+Vrjz7K4QJYvsXSep/5W67hj/7Pd5DP\nQat9EcsuISwX2fDxPIfYHpAMeqytrjDmOERRSqsxIOgqbFcSSUHr4oDtdkSn06cQGO7c5fIkO/YQ\nxoC4mmVI9mxfdXhHGCzhYrCQTg7FOs0tm/kDd6Kkw4lDY1w/YTi/GCOsPJWxCo2Vz6D0zQQdQGk2\nVmICcpSKRTbtJrEwaJ3i54q
4hTJTE6O858ffzYV4kmfIEVTKbKeSm47O0+r0SYwmjmKGpsZZunKZ\n9splisNjCMvDGj+Mu3Ka3uKj/NkZi1JO0ejDzZUC1ZFhet0BrWabaq2yq7VRsXolQV5pIIwxjoPt\nW+BIfHWapG9h/ApEVrZ3tYI05GDyEnFyiK/Yt+H+0q/QyklWtjtshpqf+uBRXj77Er3Ok0ysv8jE\npRdpRy6XBoZJbfO5C/O8b/YS2ovRqUW4BXHaxPEsmPUQjsDoLtZolThJ8aMIIxVktFtQO7xl4EI4\nw/SlmNOf/DRv+N+mMGkLZB3hDmXPZHEcY3x0+zR8B4rTd5elDwKc6XHGP/wuvnjDDPWkR/XyGaqL\nyzSPHcW77hijKiWedIk9H5bXODo9QhyEXFo4zMjGGvFgjUFuArlvjMk4pbV4HjUzS9hrEiJwpYWX\ncyiKBF0sIrwqgSVQhTFUEiNUQs5xaMQh6fYG28MTjKYN4lDQL0wyUYBGroRtwEaDyQipsdI4Mitq\nLDTGZDwNfTXqgEyaaHYOc7PLLoYgxU5gAPhGoFSmFEljhTRpNvcPu3jSxSCxbRu706NvewjLod0L\naIssHLGXRlRzHkSa4h23svqpv0J3Oyw+q3npyXVynoN2JR6ZRFpqjbRt8rrH7//hn+P6HiXfYZAk\n6DiBVGG5DjYCV0BqVOaZpC08x6MXhUhj4ckE21j0jQJliNIEXwoSY3C0IW8sukZjyd13eLIcLysr\naHTWfVNYWAhSmXV6rG/psliWyyP3P4WJEgKj+aHpHEXLpudIbjxUJohSigJOrbUp5QvcPD/BTE3y\n9IUul1shQXvAQEoeeKhPvmCYqjn0tcfMfJ4Xrwja9Q7ttYC//fR/Zt8Nb8Id34c2eif8Myt0jLna\n2QGtJKkyJCqrU6MEkliwdvob+J172R7koN+mVnSZGNpdCuQ9X/kojnZQA4EZmeWa6+9g9eIZOp0N\nvviFjzE0doROa5Hh8Rz9NEIZC2lp/JzP1ladYsnHsi08x6XdblOr1VBKUSjkCfp9tra26PUGXLmy\nhBSGQwsHeOmlkxy/7jj5fJ75+TkajTrVkTFSJ08UwOmLq1y6skW3FXL98RvQqQKR0mw26fUDktjQ\nrA8IQ0GjV6dcGSVMJKOT07RWLhAnbeKwydbmBabnphmtHkNHhV3vG2TGlUotQRIKvDxUyoLYGC4/\n/SgAW4/fT+ponn3oy+yPLQ7ddoKT6+cJ15q0t1tsr0uGJ4ZptTv4UQCJoTYCGo1XyJEvSmSkSSwL\ncor2DoFdpSmDfhevNIYUCs93qVRKzM14NIII6Wf7wVhQdh0m8zatQYAxAt/OChnflkTKoEJ2pKxk\nNuY7nYVX+PHiajfie8dbpx1e7sP9rZgfrNmsdRLakcWR0RwvP/IVzj78FYQE14NyZZbc8CTV8gjS\nzWPyEA0GrD+5Tmv7Chvrm3Q6ECkIgcC16Qw0lbxmqxux1IxhaIa3X3OM9vIFClMTDO+fY+jIYU4c\nPcJ73/wmfuk//hZxX3HTjQf4yL+/haC5iFFDKJFjoLs4TgGZuJx/4utYuSEubbYISwkhglreIYl6\npOkA23cZqhap1xv0RI52J8FrNHe9dVIld9h4ZAcmAil3GDwW2NImRaOVAqMQRjOIUo4ffg/7a+Pc\ncf1+Vl94gs/f3WRfvsv+2SEKQ1X6Goo5m9m5YVR5P8+fbpIYg+VoHBfaQUpxfJJer8NQJeDydsRT\nV+DM9hX+2fwUYb9NkR52muAV8rhpSsSAQ/v20Q0u0106hxSG0tgsJlfFnbmeoaNv4tTf3cMPTGRF\n5fbmRVYWR5k6PE398jrVodKu1iaUDl4Uv2q1EaWABTmJTDOer8wp0A1ktjxZN9IC40A/mWCudY5i\nOWbf2hL54RLGLyAvNrnLbGHu/RxWQWNkREU7iJnb8GoFXvjCc/yzWRAqRY+OY+k+xUIpc2ken
0W6\nFlpJjEpxaxInXEHmQEciC3PcgRbgNs5y+bU/zbHbP8RXPvLrvP1Dt7P2zKOU5g6QqxbZuudvUSZP\nX0kOHf/1b7sO4vtJNN7DHvawhz3sYQ97+B8J348AZw972MMe9rCHPezhfyjsFTx72MMe9rCHPezh\nnzz2Cp497GEPe9jDHvbwTx57Bc8e9rCHPexhD3v4J4+9gmcPe9jDHvawhz38k8dewbOHPexhD3vY\nwx7+yeP/AaEKzhrzt/cfAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Sampled completions:\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAsCAYAAABhRmIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8eYyl2Xne9zvnfPvdqu6tvaur1+np\n6ekZkkMOyeHIpESKpmRttmQkRrRAlgIoieMEiuVYjuNEDgwhAWTYiZ0/EgOKbUTRFkmwSVGkqXCT\nOBLF2dee7um9a7/79m1nyR9fzUh/iAOWgCAAUQ9QqOq6BfR33/ue8z7v8z7nCOccJzjBCU5wghOc\n4ATfzpD/fz/ACU5wghOc4AQnOMH/1zghPCc4wQlOcIITnODbHieE5wQnOMEJTnCCE3zb44TwnOAE\nJzjBCU5wgm97nBCeE5zgBCc4wQlO8G2PE8JzghOc4AQnOMEJvu3hvduLP/sHv+3K1gIdL2T3+ptM\nhwPipSYzL2ZfKNovvURSq+MuXMCLWyTNOo/FQ76YNlB5xEX/FodFSRCtQCEo9oeI9SYyihh3ezSb\nCotPljriwEfNumzIEdolZPE6o8GIWjRExg2GNsHdegP9yOO423tsPfows/4tYqMZf/5zNF5/jbBW\nRymwzuP6vT6hZ2gmCmHBOoETFoHCCMCCEQ6cBQcOyb/43WviWw3cz/6zX3SlUHRWVtntDphaiJ1g\nNpmRr6wgR2NMLaKe1DClZlxqikyz0Aq489JLZM9eI0sntKKQUApKB7k2bP3wX2X26d9lMB6hpMA5\ngS9hqKEuJdo5AiWx1mKVj5MegQJjNGmhmZc5i3HMLM2Q1mCdQAOxFMwsFIDyfTypKJ1ECij9mOX/\n8Rew1uFhcdainCTEUeCYCsm17//4txwbgM/+wW0nRYCWERqPUnjgh0gl8UKBUOBL8DxBEMKd195g\nvn+DL/7el7jzJ7/Gpy4uc2ua01mug81pL9bYP5yxPTTMrGUw1mTWkijFSqy43S8hrhGFivWlJmWR\nc+vWIWdPBbS2Ev74jyb8lfcs4gLYevQH+b6/9Q8oyhKHwNrqy1je+bnUDmuh0FAUkKaSLLc89798\nksxO8KOIjdVl0nTMfh7zq7/y77713PnH/6UrZ5JOc53dwztM84w4cNx843mEbNJeO0M6OKC9lGGk\nYjzIKbKShcY6d27v0Om0aC906B30cM7SWlig02kjlePatTe4fPkRrBVkaUoc+Tz1oSeZjsZ09w/Z\nOLVOs9Xg2ZdfZ/P8Zb70lS9QTnd48fnbaCfYWNlkpbOIF0LYsOjSMOumzA3U5YzlZo1xlqNlQpIs\nUhhLsrzEQfcOzfoSteYFzl+4ikkNaTrgf/6f/h7OuW85Np0t5QLPIRwoITAGZhOBlzjOJpbnb8CP\nXQa3GrK2GDLv13jsB/8aX3z+ee78yasc9KY8/eRpOg2fN1+/S6uV0OsX3N3NOXP1PIPxkGy/z7Up\ndGqCpUByruPzxmGGizt8/498gloSE4c+y0trvPKNr6OvvYDzBC/embHy6EW8JCDUGXce3Of5myMi\nI2n4oJQiDn3KsuTBnkYocD5IHAiBtSAcVZvpHE4LnLbfet68r+ZKLK8PDTfGhr+0IIl9R+grrp7y\n8D2J50vqkY8QJX5QQyiPIpvjXE5WeJSuQbT0ELZImVOj6Xb56lsx1+69wWAyI/NqXOvOAHh
4fZlz\nnVWW2m1On7vIYj3hcDhhc2ONtaaPCCI+98bX+C9+ep10vku93mBuHLV4jVRrEApPBcymQ/73f/SH\n3L05IlaQJDFFmpMwpxkHGCHJ8hQ/iuilGaawRGHA79zoHWvP+al/rJx0DgcoJZAShBAIHJ60OEf1\nJUQVf1vjseBTPNb6ILvb93jpwYxb/ZLvvxRzcdWjLsY89JEP8cv/9P/kOz94mms3e7znqaf4/Fdf\n5plX98mmA8amxtKlpzh9+T2M/t1/Tm31w2x+73/M2uWrPD24S30v485en+79OxxMd/mV117jf/ux\nH6eIIAlq2IuOvrMsXnqMlbMPI1SImm+z8+u/xM/+D79GMxEYpfmOyw02l9c5+8hZhv0ptU6bR//u\nMfacjzXcjxpNiMM4gfUCtOeRDnM2WzNOfw+YGshZSLX4a5CVoA1eNuB6+WFU5zzPXtvDFimnQkOc\nzbj48CP0tg+pj79Os5khY5hOY/ZaHyDwSn7zV/6Yn/8R2P7u/5WVx5/E3/kitrTIm19CXvlO5Pnv\nROx9Edd/Bfugy/6vfpHmmkDkILy3P0OH5xxv1j9E8fOfYTXy+b2//d1sbD/Lpb/yg2ycW6L/5j6z\nnQe0ojr5/Qds/cqdPzc270p4dkrwXnmFrL1GY6mDvvgw01qN+WjG6d/7Hbh0EW91ncOlVR6ZD7mz\ntM5z9TXkXo/B4S0O2mdYi/YY+DFReZ+eLCmmddaDPqdWWvRNm0k+I7Y99GSEkiV9tYj06nh5j3AR\n7FyQqhqB1YSPv4/uNCNTPqNySM2WTMMYc2oT8/LzmNIH4WOMwViNJyVgEXhoz6GswDmHdnD0EhYQ\nVLznONiZCTwcw9ku1g8RWcYUg9CW3FhCY/A0jDKHCgLSWYkMPPanOe210+zoF6kJhTSWkXEIa9EO\nls19rvzlLfrbGZPJnOHhmMm8xA2mWBlilSS1FicqcoLNSa0kweJ0SegkeZZjrSO31QqXEnKnsA6s\n1SjPw1oDOGQQYJyhnM8RnkcuqvclhGXiwDgInTlecACsoxQOi8EicEKidYlHABoCIXEe9A4Oef6z\nv86zX/63nH/4Efbv3aW90uTAwsV2jYG1RPjcuDWg8DzONh1JEtBPJBNjGY4t49LiPAPaMpxlHO6P\nePRKk+aKT2d1Ba9ueP+ThkmrQEn4zd/6VX7gP/05rPNwDsyfITzGCIxxGC0ojaPUUJSCorTsvfUc\nSTPHFRFeVEM4RxgkJPP58XJnO+dn/uZ/S6Pe4BvP/hZf+Pf/NyZs0WisY7Ipdjaj152hfB+VlKT5\nDGk99vd3abdb1Gs1+r0BSnlk2Rzf9+l2e6xvLPPkkx+k3+9XuTMYcPr0Br1eH19Izp+/yOvXXuH8\n+XNYq8FoonqLUvcJ4zrbN+/Q6SzTXKpjTMbG2ir3HtxHKcmp1U36+29Q8VaftbVTzLKSzZVl3rq7\nB6VkPh0h5H32d+ssLy4RBsfPG3lUjCIpWF2BQQpe7FAW3jgQ1BcE11PHyl7BRl1x7VaX4kufZX9/\nRnulxZnzC1x87AqD7Xvo0uAHBonFiwSHOw8YFyUCSS1xqAicdVitwYHWmhuvXOPUmdNMJ2M87xaT\nbpdEW5TyyQ/mNJ+q01lbwrMFOgq43XseNQVFVVytE1Q7CiAcnhIIWeWWOCq4wgGuIv3HyptCsCAl\n5xNHagzPjx2P1sEvLS8/MJxZtPgyI15dIUzajGZjgppCFxLtb1BrzVlYvMJ0+zXqa1exd79KufAE\nH7qyyNWzS7xw4x6TUnNmOeWFN+7jEMy1ZWc4ovfaG2idc2p5ndt3b+E5w+b6En//v/oOutNXCeOY\n3EIUJuTGsNxqMy8KkAs06xtc/cg2r3z189iVFlmWYpDYWkRaCNLpBJQilJq7I0jHOaut4ti5Y43D\niYooKwnKk1TdrMRYBzikEAQKPJlwlR+g7T3OZz7/DC/
sTGnUY65sdRAixsar1JMednqXqF7DOcuV\nh1d54Y++QX1hncDcwPkxFy8/QXzuPRA0udH4UR7veLzvY3+Jdjlj5e4IS8HZ5Qg7XmRtIeGTH/kI\nWTpDGoXVBQvJKqmdUc5nSClBGPBjChew1obxzFGLIM9zXFxn0C+RQUw5TY+XOwPHOJJ0pMGK6u49\n58A6h7UCjEMu/xh84uO4a59GZLvIT/0iYu93cb/9T+i+9YDTjRZi3KeeDrnw8aeYDubM9w5xwx5z\n12JRZCglUNLDTgeUzZjpHNDghjv4X/pH2OEYpn2Ev4/Yew79hX+ACGsQBAgX4VLACKyoyImrPjYs\nQABy2qcgQS6fZ7OVsvPqTaSxXL+zxxNrHTwtERff803j8K6Ex2vFJLKNnox5KwNSgevts9Q/wDxy\nAd1ZJ22tMB9MGI/GnFt33OuO2Wg0mbZWkLN73G+t4q+sU7w1oLFZ4yD3GFNnWLRIbMpyeYByM9Jo\nCW/WJUzqWGexnodLLdNwmViERA2PaZriZYbIU+xlPssrl+gM7rFf9yjzjDBpghPkxlA6iy8lOIVF\n4FuHRWGdrd70USAR4JxAHKuXAM8YhHX4zpAVJc4YjFJ46QxtLZ6xeLZkbnJqJqBZTpkSEjqL12qR\nTjIC30N6EqUdxloWI583P/cKX7rXJYp9lHMIUTWEYeBjrEWWBmT1XpAWg8ETioEzSGdASqwBJRxS\nQOYcpRDgDM4apFVIY3BKYZ3BGUvgCfSgD+1lfGGQQiIExIA2gD0mG6RaSDiDddUzGuFw1lEiEH7E\n/t4Dvvabv8y9l79ELUhpRiGzwW1GszENzyKNz/1hip8oCimZpo5unnL24ZhYggwUk14OXkiZQ+gL\naouWNJOM+xF37qR8/HLCS3cnrOqE5vmYnbcM585LHv+o4htf/jqPf+yj5IXFmCOVx4C2R2THOooS\nSi0oNRhr0Nd/C3yBzQSNRoulZsib93eIk+h4uSN9wqROnkNRFmg7oMwdSyun2HnrFVoNhy8l87mg\nFiuajQbTQUpZFDTqkjwv3tkJGs0m0+kUpRTj0ZgoipBSMplMAYFUksXFBW5cu87mxiYXL5wjKwqm\nszlFkZPOUzJTcPr0Eh/96AfJdMpsNkQKD1NaLm1dYEfucfdgl2YtZmllkYP+mCyf4Xke48kBYZyg\nXUoQl+wdvsFg1OOR7/tJnGkdO2+EhMiXtFsgfEctBD2CYiJIRw4ix58MAQTPvjnnzBbMend4/+Yq\n1+8NWdp4mEYtYq+wGAdOwFwqljo1droj+imcqUl2tKGpYNyDsdEUU9BNQ6kdYb3FQqfF5Yce5Stf\n+AI7t97g3NlFZBxw59YdXnrzVdpJg9kwpxhBcCQUIzW+UlhrAYc8IjqU4CRIV205iKOtxx7v0tfn\nM8FFpVnzYcmTFL7jK334QNPhK81bA8lqLSDd6bMQHKD8mASJ1ZJGzTIYrxCnt1Fqgf6D10kal+g+\nuEYWnsF5Po88cobbb1zjY+9/gu/72NM06g1u3N+j25/woYfWWG6vsrSwyHw65a37d/HrdbTXR1iD\nI8balFbjEsL3GI53iPyYVqPDPCv4nh/4GP/yFz+Pl89QcRPp+wwmc5w1CN/D5iW9WUZROtLCkhbH\nJ8tlUeUP0iEAaw1CgpQSQaWKCwlZ6bBmzmuzGfdeeoa1jTN43ZssNWs0Qh+nZ1x+7BEWliXXP/cr\nvO9Ch4P9Lu1mjSCpsdc7QHoeH//xn+PGNKG7d8iHPvgQT33s/SwFmnmvx/msixMOnY4JQ5+rVxZA\nxOyMuwjP4QmJsILxzSmL711koDUiUFBa8CTh+kPMclhbEGy0fRqtFvPhHs1WE5cXKHU8N4o3E+ig\nWhNGCIRzYG3VDBsJU4PwCkQzwkbL8Ce/g+8+Xu2JRcRwbFmeFzx9WtPsbLLwvd/B9c99EfPaLrVm\nwEjXsWYfz0EgS+q
RpcgmjKbQK1dpvfZvwLuPzMCU4DzIRqB8ULUZBBY8UwkQDpzinb4BK6rP1BQY\nISnwiX2L0zEv/vGz3C0afOLDF9GDEflMkyrLyjeLw7sFKRARZbCEOCtZtB7m7g1WkhB97jSeiNhx\nJd3ZhHP7t1lb2mBsBCqo03vzdTaXT1FEgqC1gLm/Q2oiTNnEL4boaR3mz6LOnoVeTiYVyegW0itJ\npyX1UOL8VdyCYp7DkmcwWUo3BTtLiRIf8pR4lHHgtVg+fZ6D0idxVVEtykq+FFJWvEa8vXiqDuCI\n4GJwOCv4i1w2HWRTKqXe4juHsaIiL3mO0I5Ya+xoRGvBx+gU3e1SeD61zhLkGdZTRIGPsBZjNdIJ\nsjSjJiEMHZ40KCvRHniiKsCZtSSeJDWCSFmsUkgtKI0mkBLpBNoYAk9WBVuIo/FX9b4NEiUs1hqk\nscRRRKZL8DzcwSFhu43SjlIZpKnyrXQw4/iERwqJEwLloHDV5mOFxQGvvvIi//LvfpKHz29ybmOF\nL7y8S+B7nFpus7GYsCljClPw8Pk2X7nWIzWafm/OqCi4dl/QqimmxsMpR5RojBG0lxzWOJYuJtSa\nUE4tr90q+P73tHl1W7P3jZILTwUEZQRFj9t3+jz6UYE2AmuqwqRNNcbKtUNr0CVoXS3Qspzh8tfx\nvCZrLY9Op8b9B9u0mnUGw+N1WwEln//sr9KIz3M47BJGglI6SiyDYcHiYo4fGAb9MeOsYLWzSJGX\nLLQWwcHy8jIHe4cYbRkNh9QbDaIoQmvDzs4u0+kEZ+07v2s2GmxsrJOmKZ7nc/f6DZyFPCuIw4jx\nzHDh/FmC0IfSEEdLxFHCUtIgFD7zVotukRF6isj3iaMQv17DCNBOE9QCMuswYsj6aZhODrl/7wVO\nb106dt5QVhLIrBTsHzi8QDAfwcGhBR/W12D3DoDDAz7yXR9gdWMZr5wx1w8oyqJSD5w4UlQqIusL\nTaig9MCzlo6BdV/QuGx54sm/xNLdAc+8cJP2+lmc57HUWSBJEmpxxGBQ8sgVD08XfNcHLjANYpaW\nFnn55Ru8st1FSAc5RHHVOMW+B2isoepWBAjnsAaEEjhTjbiEOl6XNcoLrgmFNiV1HIlwnArhpbFj\nMxecb9p3lOqthYBsmrOYd/GUR6lnRP4hon4VVU/wJhHz0T1qrQUmD95kafMM+e6bPHG6xfz2v2Xt\nwodRRZsPX9qkFj3MdNAlu30LubFJLH0uLS4SNRLeeO4uZy5ptPBpJB0sKbZQtFunUQHsHtwiipq0\n1pf5D/72p/i1f/55lsKC6SRHOIvyQE8ztHBIYxEypBbkDPN3LU1/LrQDoatiajQoK1DSIoRFqSoP\njAFtBDp3zPSnUa2/QW+yx+bmJsXsECMWUXENaXKyIRTGJwpKJqWmGbapJRp7MOG7f+q/R61eYOOw\ny9XLZykNZGmBjB2XdBd/0KW0GqksQjq8qEF/NMCWBbYAGUU4Z1EzQ2Yknq8QKJwnwSREZy4BAqct\nN7vw9HqIsCXp3LBxqo2eT44VmyAvsEZhVNWMKjgi5lAWAjcD/Y3fwLv3NaTeRrbBDAADCI/Uxhz2\nC7bQSKfJvvw5tva/zL37jmTTg1qHLIegDn5UUC/GzLIhY8/nCzdzfqy+j8kr1u954LRD1kCUAqdB\nihThHGVZrQn39ujlqEbjAFMghEAZzRoPWFur89Cj51i5tIKYDylEwnxyH93ofNM4vGtW7QmPc1vr\nDLsj6uMhrG9SKMGgO2ceZaRZxoKFpSc+zIGWzOox+VtvcbETMwg0QbRBfXmN3s5nmQVNlucTimyA\n8PaJF9p416+RyAdYU+C1HsWPIozXwMUxByhqFLTsISJeoSwNVvgUCwm+nSKlwpqUIIuYJ010vYFz\nDmU9LHOwDk9YnJMIIZBHC0ICpbB4TuCUxGlbkaBjKjx7uzt4nk8UxSwGAbk1BDIi1
5a4KLA6R2Oh\nTMFYRJkTpxmHMoQ8RZUFJvLxHFghwDpyY6mVDmkMnpRMTUHgFBoolcBHYLQlFA7rPDAFUnmkxqKM\nxRmDJ2FSGgJVCTPWWRCSwhoCJIFSZM5S4jBFiRKS0mqYDBGFoRSW0AqsAqstoYOYYwaHP+2unHU4\nazDOoQEI+bV/+AkefughDiczXv76NYIwYDH2WN1Y4fa1Nzn0FB9Yi+mOCmazgnbk0bWOehQyKufM\nJx5XLzRJxAJW+Xh1x/2Ro1abM91P2VhfxJQa1xZ84eY2W2tNzgWKojdnr6tY22hz7dnf5BN/469h\n9dvenaog5RpKDVpXSpnWYHRJNuojXclBL+PK1jL37r5FHK9z/c59WguLx8ude28SBw3SJGKUpczn\nJatnN8lnmnHe42AEh70eDp/INemmc06vLnP7wQ5In9ObZxBCMptPaLcXKYqCOI44ODhgaanDma1z\nJEkNaw3bO/d56ZWXMYVmc2Mdaz3CMKTlJQhXKUlxkFAUJXsHu6xtnKLb7VILc5IFwU5/iA0jeuMJ\nj37gPfTvv4kR0Ns/QIYRXhySiQHj7JCtMzWybEhrUfLya7/H7dt3jp03pXFkJfi5IPKrfc4TAnCg\nIEiq7qQdC37gPZDeeJFnxw9x8ZHHEI0ZgRH4UmHw0Q6cFExmJe2kUlhyAzMB7RC0hDSDpNVmqkd4\nnmQyPsQJjZn3yeYFO9v75D7cGzkmBq69+TK1lVWSSNEbTMCANUdWQCdwzpKLP9NBvf2jEAjlcM5V\nREeAE8frtN4fw9dnhrul4pQoCQTEQtDxYKzhlYHj4QbUAsHB2OCQFKUj9DWjWQlC0xp8nU67AX6L\nnJh5d5d6u045vIGTAePZhNbGkzTOPoGxmjI3TAb7WCs4ffkKvhM0FWgZ4JzlzV+9zuiHGlx+r0E1\nOpS6IEnWKIoZ8+mQKOnQH/Xpp47/8Ee/g3/9zz/PbJ5V4xTlo3KDcQYcyCCgTEtQqlIxjwmrqxop\nJSCPVGZbkVBj3jbwVD49J0Aqj3k25+LmBQb9HoEf0+31+KFPPMF0XvDmH/wBl8+uY4qU/P6MuNak\n3tDcydZobJ7HWLh66SzDucZqw/pqAz+dMP2jZ8jOnSNI6sz2BoTWUboxs9kEa330kcJS83yENvgm\npH35PK7q0kEImlsXWA5KhC84GEnevH2fpz/4PpTUGA3J8jfTMP58hEd5qgEhHc4eeZqOJgLOgVBg\nt3cQAVjfQ1mDKx3kCpcERK2Y+e0MPylp2B7T+AmS915DDAb4RYOhaNBUE6QFX6ZE5ZS9VNJIAAPO\nOETNIc89gvG2CL/rp3F6ivmtn4HDEiczdHb0PByZ3Y78tkhQLifCYYRg0+1hpx3OX1gjKrfJZzXm\nGKxzJFtr3zQO76qLNVROdmebZnZAUuzjBY6uisgCB5M5siyxrToHu0N2NIwmU9RoTGOhzeLCArLZ\nYnawy3KrSWv1FK7pky6dI1rdJLYRtTNX2GtdZrD4KKmAnlpk6vnMywyRTZjPC2r1BvN5wTRZoxkK\nammJXjmD8GoYr0naCAgbq/ibW5hSgzJkReVBQSgQAke1ESEExjlUZSOsJGVn35mtHweNsM5y3MC3\nDqM1aMM8S8lMCrpkmuVYY5C5RszmKCFY1I5iMmfuwE9ifAHWWAIL0lqUktSXa3zPDz9B6iRPf+Qq\n61srLJ/q4FmH7yzWWAoNnrOQO4wxhAL8o/FceaTMOOOqPJEC4zTCWFJnGFtDqTVGG5wHKRbyAntw\ngLEGZao2SBXVuKswoP8CIy1jNMZYjDE4ZzHWUWiDlRKZLHPrcMRsntOqJazFitI5lO9R6JL5vEQC\ns1lJPRD0pilLiUJFAY+eTXjvxRYLYcDeIOX5nZLhqOTymWVQFr/tOLw3w2sGrJx2LK41mdShecZj\n+1XNC8/uIkpwvT9iOsgrcmMc5ZFB2WgwWlAUj
qKw6LKkLKcwO+D63X2SWkhaZuQ6ZDRPQTm2dx4c\nL3cWamTzEuHqJLUmToTMs4KsLHnk6iZREjLLhggyrM4ZDecoFVPkBolCIphOp9TrdYQQzOdz0jTl\nzJktlpeXWVhosbf3gOef+zr7O3tsbZ0hiWK6B4cg4ODggFocMZ+MmQz7GAPb2/dZX9tEIHHO0UwS\nnn35ZQ6zKff2t2ksLJBpyWgwYTrOmExSjJFY6yNUxtraKkm4yHQYMBvFuNJwuHPj2HkjBCAFTjoM\njtnUcdA/WpwlOA1nT9f5gY+uEcZ1dp/VqL0ujcYiixunkL6HEg5jDcKCwENJR70WISIf3zpyHOsN\nWK7B3MIzf3KNfD4iTQvOX7jEysoilx66zKVLl7l0+QIbMawDRQjOSpzyefbFa/zBH9+ATGBL8H1B\nEAjCwH9HZXHWvTPCQgKeqDpmqIjQMZfVeiB4r2+5lxnulpKZgUA6IgGTUjDT8OIAdmeOnYljd2LZ\nmTr6M0tpHd2Z4u7A8frdCbdu32XUu8usSGnECXLlaRY2rnL+8U9g+rfovfYV9m7cw4wHDLdvQTqn\n39tlPLzH3b3bjG68xGT7De5sv8x/83e+xuvPjwFBHNRwpkDIAGMkRelhtKMoUmZK81M/90nSsaa0\nljwvybVBKYV2itk8Bxypc6TH33Kq8b86MiQfpYxzVaG3BkojeNsj7knJ6DBB+QGNJGRx7SzSD0g6\na7z2ym1ckRK3mgjpoV3AQj0m1YrSenzfT/w07eUlgiRBeAFJ6LO50WahFhHUY+4ub3DwzPPIWYYX\nBWht6QpNZh02jAgbDYRUCCEIlI83mBGuXKpmcghcMSMfz2h6cL0nSXxDEARk8znNRh1dZkcezG8d\nYdVTU9rqezUQrxoJqwXOgK/qqPoa1q5QjhOK2xKzDZQCfIUnLCx3cLUY01pBDHbRLkEqCERGqRJc\nCVhB6GlM7pGXsBkbpF/VIrHchKd+CPnIFYwNcMUYKcrqeY68O1a/3TxUD+pcpbxrGTE00Jul7F/4\nSaRnufDDT7P1C79O8P7vJE81ur7IxkrwTePwroQn6R1Q69/FHfYxnUtMMkPZ61FkOYP/5/fZDD1q\n05ICTT7sE+/vs7p1ltvE3Hch6yrD4nij68OdXcjGFM4h0hKnJsyzQ7TWOAL6QYuJGVPD4kufhVrE\nYmwJshzhCaaTDCMUpefQE00966K1JtEK8hxqDaSqEjwvDL6oTkeIyiGIQyCtrUyRAK6SmaUTCCeQ\nx2Q8RpeM8wzlLOO0xBaaUmuKvEQXOTadM3npVYb3b5OPRpRBzJ3JFD2fI2YTPOHIS402lsKWKCnB\nOq5d32atnaDHObosEQJOnVoiKzWZsVgsvnAo5ZOZkqK0mFyTWY0sNViNZy2lFUgHWIcpNQpBTcqq\nGACRcOg0JTz6+8mDXZQxSG3RWpMX+p15jirLY8UGwDiDtpbCGkptMaXGWMe9Gy+zslDD6JISSJSl\nFviks4w7D+6hneSJrQbrrRDrCZTvc/5Ui0Ip2oHH51/IePnalG9cG6C9AjEfcKqhIR3z1NmzPLG1\nxaX1JRaTgmTTZ2MdWqOcO5TSVi8AACAASURBVLdnpFmJcwX1SHF2I2Ha28MejTxMWZHFvBQUBWCg\nKA1FMaEopph0By8AV2TsHIxZ9AX7BwcESrDYbBwvNqqk1m4TNRTNhQ5Qo9SWIp8zGN9j8/Qpzp29\nCsKnlkSk84w7O9toC5cuPcz6+jqe56GUotlsUhQFBwcHCFFtEqPRGK2PcmdzkxvXbzKbp9y4fZev\nP/sci51lgtBHSIPEEYd10lnG7vY29XqDs+fOsru/R4HhYDrG+ArtCnZ3brG4ukijtUhRlAwGXW7d\nvE4nruEZj927QxK/QzNaphEsEVv/2HlTWIEXgC1gOBHMcoHWR4bTmsBoOHd1k2L5Ib6x73ixgDv7\nGbdvbnP/Y
Mb6xQvgDIPRsCL/ErCW8STDc45SwMyBkIJTNUFDSW7cukfkCTzP47mvfplXv/ECd29v\no3WJ0Ya0hF3pqHUUZ69c4aELD1FoAylHZudKSTUGCmvR+qgYvV10jxRkAchY4N4+rWWOOdLS0A4k\nD0WOoRE8sJKxgUDBSugwtipo18fQzaGfQ28O2xPH/aFgmlcK5jy3DDPFKPdo1D10mRG5A6LmOvv3\n7pCsXGEyBpGOmRweoPMZe/s3Odh5nb3D++hScmM052sv3OWlssku8PJrJaQeg8kOxmp0kVMPlkmz\nnCCqs1BrUotbfPJ7P8A+oLVFWINxglQLSm2wuErpzkqWT586du4YK7C46hQWR6qtqwq7fdssfvTB\nOGeJl2D1dAMcxJ5lvb1Aux5z7sIqN/ZTLl55hOFoxMHeAc3IIYo+fZ3QtXVu3uuR1GPAsbhYJw49\nAl/hKcl7P/HdfGXUIx0PkKZgWCuZqxykV40z/YAwjLBS4XkSlR4deghjUA6nDSoI6CyHeM7gnODR\nC2cJdEFeljQaTfJcHys2wlYWjrdVLmdd9SzOoU01/pU1gdp6CP/9HyD81Pfi/fWfgI98Py6fcv1w\niPIEdnWd0cZH0Ebi0jH+d/40niuIVYH1E8qS6qBMNaAjErCUaIw8mjakDvvl38L8/i/j/s1fRXz6\n50FValV1NubIoez+1MLD0XoRRlZWCecI1x6jP+whpg/43f/uH/JPP/MqeXuJRt1jcLf3TePwriMt\nP2qiT50m2x1QZI7Y5chmiJooGv/JTzKzIeNhnyJKEN1DlLXslSWeJ7mwtcFh95C8n7OxtoXwh7j+\ngDOj1yjb5yjGKdHiGmvlmDxIkEDdSWZxC2VKUpez5AJ2jMHu9Wk0WxTZlKXOGomcMvJblLnCTjSR\nf4A+s4l761UK7ZhlBfVAgZQ4V42sXKVxYmUVUInF8mdmWe54JjAfh8MhtSGRkhCBtI7SCYbdA+x8\nRlJPKPYPsFKRtRewozFu1Idai0YtJp+lSClAhcxnGdLB3/yxD3HlQ1f4pX91Ceeg3x/yG//6mUrg\nM64yGnuSeZ4RSonCopUAXTIXVEZnq6vOBktxdCJYOktWGjxPoaQgdQJPWKQFbTPcXpegyPDLEi0U\nTio8HFMH87+Ah2eWFhgshbNoq9AICgcPXvkK97oTlushYVhJ10MPzm018Vsxo70h55uC3jQnzw2X\nliJmpWOtFnJ6JSaIchCgAstkrlhfS9gbpmSHU7KDCBlFLK4J8kyTBBEMAtqJR52SSx8K+Z7mEjcP\nuzRW2uTzlCyvTmFpDWUJpnSUZUFRjtDFlLLIMCZntv8WG4sd+qOMcWYolWBro83t7R7t2vHk9/Za\nwivXPkMSfA1rfNIsx4kuGxstGuEmWiUki8tsBDFCp3zH049zMOpz4dIZbt+6xu3b91hZ7lCWJXEc\nkyQJnU6HNM24efMmQRCglGF9Y429/QM6nSVWlxbxA8lbb92g1Wyxs3Mfa+HM6S3ubO+ytrJBaUv6\nvR6NWszpzVNIP0AEIfN0RqcWoJxAq5i9/UPanU20mHHx4U263T7bD0b4KmFjfZ3nXnyJjspY7hxP\ndgco5w4tBYV1DLqg/swOZXHMZpDUE7r9MUMv4m48Ix9MuJDlXHjoAouLbUzvFm+8uUNSQm4Eh3NJ\nmpcsJ5LvuxihZEhLRuxd38evh3zkiYcqJbS8w+Pvv0qy0OHq5UtVZ1nmxCF88PFHKYIHGCXoj8bc\nuX5YzQfU20fOBaWxkObMCgdBVUQqtZV3ZHl3pMTao+nFcfDixBBKS2LhlDCkDvapxmmnJSwHVWuX\nGrg/h9IKFgNHKGGQW3COQAo8CaHnCJTj1sASqwNqUZfo+psIHApDqxnS3QOb5WydWibWc4bGIy1m\nPJgqnnmQ82IP/uu/9UH+xY+sEQYNrJcSeAnOzYjjFtoG1Bw0mw9hLWTZgHp
nkd//6t/hwx/9J5wN\nASkptCUQIIVinmqe+sufoh3Ex84dnEMgEOKI6DiB4E/JJkffrRNHR9en7Jn/g66IudD5Z6yc8ol7\n13n+5bs8/f6LBJ5gNhjhshl70Xn0lZ/AhTWuvb7HYO+Qb/zRc9iiZOviOdbX13jysTUatUrh+qFf\n+iU++xu/zZPLBjspyYoQz6uUB5fnWGOphQEqCoiaMRYD2QybTpFRggwNl68+SX/yh1zbdxxs3+P0\nuXNEgUdpSvJZdqzQJFRXjRjkO2THVvwSbR3OKTATGH8VBkcemgLEPGZwqDicl1gn+Pozb/HcquUX\nL2Ts5DHLb3waHS3gFzlSCNIZyLrDi1MWOh6rS46thZTbh0u8tfAU39X9NGI4QQYSV1Z+ZBcUoKu1\nbo54nDl6DQHSCZxwCJeB0QgkxDG7H/z73PiNf8XtXo/3Lizy2gtdnn5ojT/+zB/yH/3cnx+HdyU8\nunQUqcFRIsZDBsvLDNIxC50mtvToDw9ZKKuFVydnob2IDhSdmgMzIW+uIRghBndIXIkfQB+PNTVm\nXG+QzjLqUQtPgZERraRBNx2hw0XkaAfRWafTWmSmdyjjBkmjSek36alV6tMHTL2cTstH2w0WtoaV\nxG40tjR4UXUkUSIQFhwGIxXKvC2Xy2qUdSTumWNuPtZVxmTnwJOWHIGSglIIojBgPjZ4rho5+b5H\nMZ/jrKNmDSqbM5nMCZUkL0tkAb6sxjnlbMbowW3iRkwUBZxeT+h0PO5dc2jfoiwYU90fg6s64sCT\nRzK6wViLLxVaOJwFicIJjRWu6nhsdUIg8T1AkOsS4Rwi8rFaExiDwVJ6GoHAtxL/uDszYHVJ7hSF\nq8YLGjAostEe7zvfoT+Y4ZCUiURPCwokDz+yTv/2Hp9/a8x71xqkFlYaC+xPJjy03mCnP+FD5xcY\n93MGJkNLzXQ+p9NU9EvLg1lOMC/ZGfosNUL6gSHyPYw09A4c+Z0xq+dTvEZIVDgMIeUR0anELEep\nC4piRFlMKMuMMp8dHeGe4HzBre6cK5fOotIx03SCJxWD7Hidup63aTVSth/cY6HZYtgf4wenyeaQ\npwW7e3dYaLZITUo99sjyEUtLMXkxRJuMdKrZPLXBfD5nNBqRpinT6RTfVyRJjWazyXQ6ZDrN6LSX\naDVbNBoNOp0WS51Fbt68RSdqM5tl3L23zWg8JR/3WF1fqY6/mpLpaIDwInw/JNKamh8i8RlNp2Rl\nynZ3SFyTbJ1ZIlnw2fTWGPQ0K6fOUbt7G5OXPOgfHjtvPFV5XZSiuiurgHekkqP7kbwoIh1P0Lq6\nV6JAMOj2eWUyZmVpka2OwvMlOrfgLG0PhKdYWYpoJqCNZDrN8CKIWzEPDscEQhD4AcPBFC0SprOU\nTjshqdcpDPQebHPn+i4mHVMIj52dyqgubHWM0lPQCiQEkqnWR3ftVD46AeDzzghLUB0f+9ZvJ6rQ\n1ZaakqTW4VlHjeq/yaxjiiAW1ZCi7oE4Uo8GhSOQUPMECkmoHOFRkSsU+AZyzzEtDYlfqWmZkUST\nEicUTgumdw+JPMnu1HJtCK8P/lRd+J5PnmE6nRG0HUakKCtJU8N4NqReO4UxAdlsn1RXqtp83mN9\nfYG/9599F7/+y18lLwzaOjyv8gaZCxdZatXoHXzzLv2b4W2y886/RfW7txV+QXXq1UF1WMWCkgIh\nUl65/zOsrP4CW63H8Lweo7njlLXM0zHTGYw++OPEXkTkSU5vdDjc7tJqtVleb9NZXmDz1DIWSxT4\nZDNH5gyf/Os/yPQz/xezucVJgSkNKgwwZUkAqFoIvoVhisCC0bhSQyRQUUT70qPsf+YPuTUIuTwX\n1Ho9WotL4DcIwuOdDHVH5Oao9T9i3Eem/rdDJj0wBhtJCCoPpydT9rJFclvxgf5ck9QbKDdmPi9o\n1VqUxRApQpQtMdWQAWksyvOohSmy0eK
c12VRP0PXnWbF3kccLQynDU7Zqkbb6mQjR55Wd1Sdofqs\n/LyPcoJ67ONqEYcPXqOQCfPte3g+PLbeIssF7eVv7ql8V8Izmw+pr9ZxrQ5Zy8ft7pH7glrnNAMf\nhPIR9Ro2y6lduspwNuWUnBL4IeloSiAK9NxQtNcgnSHzObQ2GThFLZ5jsh6ZW8SFEsZDekZTqyVk\nkz1UrUaZ5lDkTIVPYCR5mZJLHzubUfohXlxjpFNiQuYrp4niuBpv4fCVOLpjpwqZO9qBjKjGWMoe\nmYWrke+Rq+cYMIYSkBiyHKJAkmpJ4AzTLEWUmshKMinQRUnkhcxCiT+2dEddgkKTxAGhlBSi8org\nHP/+c7f57GfewqaOXEFZWuJEonxx5N+ySCfwqS4RdM6RGVuNxIxjLgTG2Io5G4uU1dhMW4snvcpT\nIx2mrI6tW2PxfR8yTTabITyvujAssxgFwsl3VKJjhUdbtHOUxlC66j4fqQR+GDNFksqAjZqPnDpq\ngY8fSFpBjBAB3VlBUAuop45cwlyX9IYls9zh9aasLERcrS3yYFjghIcpFC/3xrQWNMK35POc0HN0\nx5a7U0tRahaaCZcuR4yBXtfSeRisk0dkx1GWlrLI0OWUPJ9QFnN0kR0pPJpsPmcyz1hq1JjPJyxE\nHosu5sKy5ebu6FixmXUHzCZDOott5tMxtVpE/v9y9ibBliX3ed8vM8945zcPNXX1iO4GQACEg5RM\nB2UNDMumltLKG9vhrVb2xhtFaGHvHN544/DWEbLDCzFkmwxZIimTJgmCAwg0Gg10dVXX9F698c5n\nyNGLPPdVk3R38OFEVNd79Ya+N0+ezC+///d9/8qRyhErLalXM7YGCu/nQEmRl9TNirqxFHnO4rrB\nWsN0OuXw8BAp5U0pa7GYo5Rie3uHsixxFvqDIZPxmCRRfPbZCz799BFvPHyT1WpNmqW0TUN/0EcI\nwdXlJYN+yfHREWcXU9q6RSnFyfMXTCaHTFdLimFCYTzW1kjlSXOYTht2946ompYkTXEWVtXtTqEA\nKhGUvUic8Fefya5EIVSCDx7rJMMhpErw+7/1pyyBr7094td/7QNGvQLdaedQ0C9SBB5jPcYa1p2w\n/+lFxfj+GJGAlIG8lzMc5PR7JUY3tE1D28DJySlNs+CnjxacrwAbc3c2V79UGCkopML7jvHrSmfQ\nnZgdf5lquOWSE4Jg7QItkIdoaOsJGCC49mCEYNw5W0YqkAhBZSNbNmsDiYTCCwoZSCXk0QlM5iCR\ngmUTUAicCCxbcMHTGkHlBSdLODeScb/Ht9/eR3vLZK9H055TlCVJ3idR0HpLlvQo831kMib3axCO\nVGY07ZThYMLp9Jr/9D//D/hf/qffwQcZbeMhcJ1k/OK77/P24QQRbhlSRMxw6vTKQHh9fwI3YOcm\ni2TzdRFdunkaWCz/ex7bD3hw8E+Z3Btx/fT3WNeBk2/91wyEJE0TpFJkueLDX3yfLJPMF52sYWU4\n7OdoY+kP+7RVhQ2OK5mTBkviYxZc0AZpAxQ5vq4hpNirCnF9SpomUcjeVAirmaQtx/dKntaGQZHQ\nGIH3BiladHM71j2ETt4hvqCjDwJJPBgHK8BaxN4AuZVDliKkgsJxVnmyoodQijYIcAJFy7oGefwu\n7gf/GqGiucCLCCZjlIlnoEAIh0MyVleo4GltTpG0iOMDyAp49jQeDqwALwgdAgthA1hjWbhOBmgJ\nhVSoJOHyL/6CK7NkUC/YbST9411cI9l/ePSl4/DVtvQ7D2iCJDl9SZjsIoaK+3pBpme0155JmZKG\nnONBn6xasVVfwe4RYvuA0n/C2UKS7WbUs4qenpEGSbZ6jt17E7925K1lYB/TZN+kLjxpb4KqZ4gs\nwyQDWmPIyjGpmpNf/ozVes1s/+tsT59TH72DCbAlIUklKj+kGh0Snn5CIiBT6gbMROAT/6s69Og7\n54c
ghobdspxO5hxeCaSxOJXhrEd6SyslxXqNUhInLMfjLV5eXyGbml5esG6usIsF4yLduO5QQmBF\nXGBEqqh0S5kLshCLmd5bPNFGaEO3KMn4uiWB1rqoE0gSUmdwUtI55uNER+AINN6RJwrv40naeg8q\nPgFCeMTlJcnhYXQ4bcbO25/DowWN1ZiQYJxDO4kDlJdkecHjV9coIVimfY4nBS9rx9wGfu/ffY+D\ncZ/z2ZKV9WR5hvExZ+jyesrB7oDFSuCswXrD1jDl0YuK66XHt4aTp2C9I8sVVJ7do4REOgY9z86g\nprI9SikY7ikWly1BSLQOWBMwpka3S7RZYZoVWte4tsUaHYXUi1OkC1iRkPiU3ckus8U15yvYG2/f\nbu4oR9A1QYwo8x3miyUP3zgmUxlFvk1TfcZykXJ8dMhiOqNpNCYEpFRYZyl7JScnJxweHrJarUiS\nhDRNSRLFgwdvcHV1xXqdUtcNvd4Ao1t+/NGPGI16vDo9ZXtrQtO2TGczWu3Z298lV5GJHI6HGK05\nP7tge2ePk4sZB3fvsDMecX5+zmjc5+TVI7a2jiD0efHsFTZtOD58i8urhlfnJ7TakogCqW6fpeI7\nnr0xm1n3BVTgQdfgjEO3Dm0twzGoWpLuK5bnlnI4YrHQLOsofLfW4oRiXObYUBMQpJnCB4P14POU\nv/MPfoXRoMfTk/+V7cNjsjQjUZI8H1DkOW99+D7buwmjq5+SFimLumKVbDQQMdrgYu7Ym8iYI5Qo\nmtp3vEL0TgTXuWBs3GhFAsltWWUfEAScFKyFoCVgApQExggqDxfAXhq1LIWKDHftQHtwITDXgbWE\nXAnySEhReKJ9G5AiCn5dEMy05+Va0FhBlkGeQVWvODlfsVrAP3j7ASoIiiKlbedYN4zaMEoSSpp2\nTggW51NMEGyPJ9SNYWd0F2ta5hbKqDSgNo69d97izTu7nK8a9sZfLjz9ssv5QLCRLUJJ5E2sLEix\nId1kpxsPsUwiI+BRqcDZBXX1Rzz6/I/4H5/+XX6lPyTzBclgTK/Xw0M8PIaU1niKImU0KPDGkkgw\nrSEfDDDGdnloLjplfRJFKNaQZj1EIiF4nHe4VUvzyROGIolltt4QVEqorjn5/DF3797FfvIps7XE\nhRaVp9imohzeTjdINxKbw70KEPBIKbEenIewVARdwfPV68fOwdNmh51JQQhgnEU7Dd5TVUCz7IJ7\nYwnXdCGuXntUkVDKFuzqhgkdlVNMk8fnwUsIZdTC+S++yPh52NBxImqQEhGihMN7qus5yf13ufj9\nf8kv7xSUWU5bRbdf2f/yufOVgMdVNbme4e5sk0jJWGzh/Ijq4pLDvUMuL1dM8hY1SMlbw3D/kDKp\nuLo4Y+V2GY80Zr1g1C4xyRhlpoitO+QukPYLZsO36FUz5npOnibkrmKVlbTWUsqU0fJz2vaSbHxI\nur2FGO3gFxV2vMNWWXJ18RwjMnQ/wdZr5Ptvsvjoe+RFL05kIopVCAIWGeKU76qAUQV+M7C329Zd\n8OQIXKJIgkMKQZAC5T0i+FhaaxoutGZ7e4vp1TVpr0DagLCWoldEIOIdMkBBQiUMynuc1TiVYAQU\nQlJ5Q9I9sEIQ8x1QBBFhXColIThWTkda12hQyU25LoSADKCkRCBIpMAaS8CiUJAqggdXLbH+ACfj\nQpkEiZbh5woedNZinMZ6ifUSFwIuSFTRIxEpk0Ly+HyKfHCf7V6FWLb4Xo/aet45GPPTkzVlafnk\n3COTwLLMCTUc7ARqArVTXJxafvLCUbWeug0MS4VDYLTHJZJ1JdgbSIyXCFWyuLLIPGeUJaQubkfO\nOoxpMabC2Rara6xp0U2NMxprDdZYnFmSyQmuvSTbznnx6oy7ByOm84bFLfVfLmhkgMvLKf3BDqC4\nvj5htSqYjLfYGm8xvVogUQx6JVfn13gpSdKcshgQjKYse8znc6SUl
GXJfD7n4GD/xq21Wq346KMf\n8c7b77FcrijznF7ZQ0nFaDTger5iuVxgnGQwnrBYLJBSMBqPSaVkd2eHsj+gXLZIIblaTEFYtnfG\nrOshw16GFAn90ZDr9QWLZYP38OaDuzx//DN8EBTF7RdlqWKZ1ru/SoIIktDZagEfHASHUqAkpIMU\nzh0/e/qKrVFOpaFwIKXicun5hQcKLwoEmjRVXFk4SGDdWn7jN34XZ+HxowXf+90/ZLA1JkMzHA75\n9LPPaVYzRvuHZJOSb779BmfXH7O6Mp0GJ77Cfi5IVTxQqeQLDEJnx4UO9GxYh9svOd3WLTrtRUBI\nWAHWw7YM9AKYIHjVwF4OiqjXGSCoBWjfBfIFWNlAIyCVgr6Phy7RhZUKoLKB0wpM8AQFqza+BiWh\n0oI1AaGg198hSzXGgleaQuQkacl8fY4VgiIZsTY1xlYEPyBVLdYOmC+vWRIrfdZaevsH3L9/n8mw\nz3TuWITbC96de+22FT7EsEfJTXgrIgKf1EfjEUHg4CbnJQhwNr6vuftt/pUHVx3wHwtP3TakSU6a\np9SNZTjIAAmuJUsVqYrrf9vaKO4l4OdTamNQSYo0AZGkeDyZkmChzApUdY3b3UFt3cHPniNVSgge\nkffYv3PA1XSOc/Bi5vj2lsc0GpUXrKufw8a2YU39jS44Otp8wLWe0EKYA04g0w4Ilh4nJHnicc5j\ntYluN6dQ0rJ68jGq61og8QQUxjjSNI6v31BuKpZ/gwOl2jjmL06B08jqdILqjTsrdG7RzXwMHrwR\ngGTVGjJTMdk/5tXjK+QHB6hU0SwbirJgffXlyfdfuVIHIZjkS9JqzrBZI2YXtEmfUBaYtCDvO+ph\nStbf5s5uhq+X6AX0lSJtLsiaJaZeI1XCTtpgsj4qz8jQNHIASE4MbPX7IAV52sPowK7UeFtR10vS\ntsKHnEruIHtHhMufMtAWMV8w7o8ZiTW+LFCFx3/zA66nS4pMAQERQidY9vguhDACoI1zYoMibl3Q\nIrgoeo5pzpKl952lF4Ru8QbKBHqJoqljqSjxHkcgSyW60RjraYwF49DORooxkQyShJZ4gqxDGyeC\n8zEGnIASMkK44PB4GmcJPran8C6eWggBGWJPIikFUkBwFms0tdUEEeiVA1SWdjVci18usd7gnScx\nBmNrgulCRm551a2l1pq61bRtg9Ytuq0pxwf0tyTXreNga8TbOzLGCGQZj5+dsyawbi2zeUNpDdPZ\nmucnC3RjuVi2nJwbPnvZ8i9+d8m/+3jNqmlBeD54d0i/p/AWVCpx1tHzgY+errlaaD56vKJpNcNB\nQnVZ41YOpTKc11hbY/Ua3S6xek1brzG6xdjouNNGU+uG0bhgNMqY9DOUCiAMeb9g75b6yhfnZxiv\nmE3XEAJGtySJZ70+xxjF9niPLO2RJyOadUCQ8fjJC9Zrw7qqKMuSsiyZTqdsbW3hu5DBk5MTvPc8\nfvyEpmn5+te/gbGWPM9J0wzvPe+99y77+3tMJhOc91xdX/HRRz+iyHOGwyHVehUddNbig+fg4AAf\nPMtqDsKxWk+ZbA9ZracxNM5YVmuL1prRcMhHH/0pu7sj0jSh17u98DQVnXAxhL/yTHafN/GUHUIE\ntst5TAPfzNDtcUHZyxgPs5jo6qPbrpdv5nkEJmnockiEoj8s6Q8LigLuP7zH2++9y4MHb/Lwza/x\n3rtv0eiKpRXoVeDPf/KC6+tOwxLipiAiodCdCzpK+QtHySBFZFKlYKNkDi7mPt3q6s4e7sZdE8FL\nJaDqmOKeCAyTwFUjaDpGSEooFRRSUG4AYiecrm1gZgJTE1homOvo7jprBEEqiizDWCjyPt/+4EMO\ntoZsjXaAIe+9MYmljGSEZ02RDZm37ub+9ESKcwbdLhHK0Zg52uUs1+fsjCb83b99n2vvUbngrXfe\n48HuBIQg70/YK/PbTp14P72IQ
as+RAdm1GrjQszfsd3YbZQ9oRtHF6LLyzlF00jWRtBWUOydoY3H\n+0Cvl6MU5FmCEhKCi+G2wSOVREhFkiqkSkjzgubJ5wihEEHirYkOIx8IrYZOmiDyHulej2A1Mklj\nmK3VWONx9KMVvS+Y5ApnLNfXs2jGcbd0aX0BXAc6PU8HnEOIImEpAnIQkPtlbGPUxDLgLChyJQne\n4xBcXFyzXDfIvI+cv0SIJM7Jjo6xrhv31kTgI6Ij2nvwbSxbdXREVFZ1z6XvdFWv72F8na5DZ8Fb\nnPfs5Ak7x3dpFnMqAUJ62spyfTGHTBK+omfLVzI82+s5vpeSDnfplSn1fEaoK8a5ZtbWZIMx272U\nfn1FI+NJuswzqvUpAxNwIcVnI7RoWbYVaZqROU2QBWNzSuZ6WJaYfItJmTGTfZQ9IzRLfL6Hv/N1\nivEIu6gwwbEKkuLf+/ssm5pRO6VgwEKnzNsWuc7ZPXoTefQGmTBd4GBcX4KINVwrYn1adguTDZ1W\nQHRhVbe4QvBRRIlDADke6Sw6KKRQZMFQJSnWeFTwjNKU6+Ua3TTYRmOSCFAyAa33N/Sd1iZOBOdR\nwZGrmLLsfWwR4XxACo8LggKBEZC4gMwVofXI4NDOoxTYTrmUipQklTTG4EMgEwopJU5rXAggFQHJ\nlmmx2iCCwBA1ATmaW/Q3vLn2BoLpsmKuA41TMc5cOpJygrCBDx4eU60q+mXC3Hj2Mvjue4e4IOkP\nclpt2e4pJqOCi0VD3VoGw4SgYG09w96SXMJwOMT4QOYSBpnDFy1FkpKVkidXNe/fy5ivLFILRJ7z\n4qxmMko43hL8xR//bazHFwAAIABJREFU78zPTvnm3/svMLrGmLoDZw1Wa4zROGOw1vLOW/dYvjzl\n3TfuoGuDlJrpyjCUt5882mWY9ZrDw2MW8znDfkZoLMpGC2qW9vBO4hzMF0v29rb57sF3cVbSNoZ+\nNmQ2m3N4eMh6vSbPc6y1DAbDmKw8GsVTUoDj42Nmsxn3DncYDPsYrVmvVhRFDgSyLEPWDUIIBv0B\nx8fHKBFYzKacn58j8wFSeIaDIUWhqM2aPE/oDfq4EEhVwv7WFnu7uzSN42h/RLVumIyO0XV763nT\ntDGYT4buMHKDtUO3qNI11Q0Y45hVIFOL6+bo9cWKz7KXNJUmEWDxiASqpsFah8oEMlHY1mNzaKqK\nhw/fZDLu8+MffMb9N+6ztTOh3++jpIxlZusYDHqgYLmq0KaTwH4h7CV0S3jsUMdrqUg8e930BBI+\n/lvwX/j5v+HlO6Fp3BzitNvc5ysRtTrbKjJhozxQ28iWDVTEWmUSreuJjPZ1E+J6omMlgLYDQcZF\nR5wQlmkF3/2FX6BICzJamizhfp5yfJxTMmc43MOLFYkqma2fU+TbOK8p8h5KBJKkoHEBY1O2xgfM\nlpdsj+8jZeCXf/UDfvsPnjHav8vRnXsc7Y6YriqGvV5sXnnLy28scC4KwpN4wr2hCf2G894Me3dP\nVJQ/dmMZf4cPUdw+vfD837//n/HBO/+cb3/wDcp+jzxTGOsRQtA6TS+T9DJJbRyNCxQEVCJoXz5h\nKDK8N92G3TFsRoMqsM6Qb02QZSDYGpRCJJJQB4RKGB8fk0nBu4egHRFwJRkqFZh6fbvBCdzolXyA\npKO0RPdv9TKQLWNIo0gUYdmgVJxbRqQoXPzYB9ZVS23i5ipWJwRShJIdIM9xrLBWkPgY/EsHgNGB\nkIiYf9dJKW4eIfeFrgdfuD9RxxMz5YTTyM5lvDae33n0gncPBzf3zdmYZWe+4oD+lYCnaZ8zOv4l\nDoc9ripNPT5EXl3iR8eEZUXpA4PRDn7Qwy9eYfoTliHQ2zmifvkZl65kq5TMQ0kvLfHNFK0tl16y\nn24zHPe5vhT0tEMl4O0FvTKnrj2lX5KaBCUd5wQGaQ9bVTTVNT3tCNs72JCyGjjyVlMkNZXNSN5
6\nH3nyAwQKEUS0oSMQXpJ0iD50fUykCPggEUFsJOx/88v5ri+XxIZ4AnBBUihoQ6B1gTQY0u7hWUtF\nv0gpRXRUJHmCCpGOTl1sHJorifIBLyTGa+7d2eXR8zOyLn1TESiUoDWeVAlM8GgTA9aSEAt3KSCV\nxIdOx9QlHdfaRIpRCGwwYBOEgCyJJ9/gHOuqIqsbGilwIdLNrVSkP4eI59GjH9FUCwbDCY3pY9U2\nQuYkQrNYW8rjwP0HIzyK/UHJujUom/JqtqRMBFu5ZLLVZ+4bPnx7wr/9k1OqtuXD93MSAou54o3j\nAeXdPWRZcf50QaMsJo3Wy9XCRLF0I/ng4ZCs5/jDH62YrxKSN1Ns22f14vs8vf4+78x/HS8Vpq0x\nbY01Da1uI9hpW5z1TFfXLKqGiWmx2nN5teR4Z8RZfcnDw71bjY3M+qyvZhg9RQFF0kNXNQ8f3mVu\nLWU5QgjJarVi/3CH9XqJcgWvTq8piwFJP2O1WvHw4UO01jx+/Jg0Tfnwww+5vr7GWstqtUJrTZpm\nPLh3j8ViCcKyXq9pmobecNI5CGPpoGnamIkVAtYZVssVdd2Ahb29bRbLC1bLhryXMJr0MUays7PP\n6emUUsD16QVae3qjXaSH5Tr+vtte1sT+Ov7G6f+F5zIRYAMegQ8B0/VGyvOcg8M9Ts9eIIJgOBrg\ntUYHg/SCXEqKtKAya6wP8fnoMlush7OXr/j4oylNA7/7W7/DeHePr3/4kP29PT7//AVXFyse/x+/\nz+dnmxcSefbNqTlRsUyMlAjAWA+24+I3LE7noNq0mmDDMN/i2mwGoSuXed+tbV3vqAUxqXovmmwY\npnGDmprI7PRU1OxI4t+hK4EtiUDIdcxHogQ2wKKF73z4Ndp1xaNHf8EbWxn/1d87YtLvo/oFv/W8\nQsgabxuEKnCuxnmBdoZhknM6uyZVQ7a23mC2eIn0liwb0tanCJHxn/zaN/lv/7vf4vD4Dr1M8eNP\nn3G8U3I2u8Do27VrgRguKLqEZYmITTK7ZrSbWoYg2p29j99PBwDw3UzbyEZEzDFzTmLsT/iTj/8x\n947/iLxq2d2dkKcpxhrSJEWHOI96uULKBONjWxFnHYIcb+J7US7KFwgBKWQEaBLE/WNU0SOsmwg4\nVIISUO7cRaUZwQVSqej1B9jWUDWBejG71dhsJByv8cTruedCFAzbijhPwxIlwBhB0gbqAMNuD1pV\nmt5bE4y+wnROZ+8FwcW2JsaKyOo4CDesZ1diJO6zQcQIlyDETWRD8LHFxM1hoYsOEB68iAd35y2J\nh8Z5Vuua65ef8639Mc5b2tYhhSeEwPr6y0taX+3S2nqP5mxN1h/TsqRYOabLBaK/x9o6jkoBraNa\nLuh//TvY1YrxckpzvaZMTQxpu/RkZYmxNZoE0Xju+mdI28fqFpFper6l9ttMlqfMekOGSWBJghcV\nue5TZgl6ren1C4Ztg88GBG0w6wsmMmo41NYBfbtkuT1CvYi31IvwWpMfy+mobtDtTVG383LdUoeB\nCHhrYjR3gMQLrHSIoHABslTRekXdtvQSQVFXhDyDJCXBk0iFs5ZMglYCbWI/LSGjRkc4z08evaQs\n0o5+9ogQ6+5BdLEem/ekwLk4czwy1u4TsNZFrZyU0c1EQAZH8CICzBBorEEh8N7RrCpSY8B3zUdl\ntHTfmv4CHv3Jv0ZJuLsz4vR6TbL9HWT/Dpcn3+fBwTbzlWWvL/nR00vOrtbc3xsxyAyvtMYLyaVO\nOVk0ZCl89njG3qhAKcn0zPB87tguBOsAW7uBxWXGwds7+HbNbNbi1rCYOUYEViLhyWnAmoaMAuUa\nvvPOMWYxZSIu8fNAszrHiQLnWrRu0E2L0S1WdwyP0eA016s15SynyDKClGRFn1oVnFe3y+G5f+9r\nPDx6wN2DHT776ccsZufs7Ix5/OinFHsPCblk0B+ilIXgKfK
cV+dTIOCd5fLykrfffpv5fE6SJOzt\n7aGUYjqdorXGe8/Ozi7eO05PT7l355j79++zXs+RsuHOnbtczhYopUhC1DEdHh6ytTUhBI+xlslk\nDCrBkjIajVkuL8mLjLxMMdZzeTkjzXLKPGOSDfn82ROslVxPz9jeHrKYVlxd3d5arDIR++w48de1\nY90a3TQt61WDsQaBoNUt1/Oa/hCKMuHtt+7yuK15edZgutgG7UEbj5SBtvVkWULdxpP3eGtMUAJ4\nzvvfeIdiOOYXvvVNtrf3yZXn8z//HipkQB11Ox0bsNkhkzxuqmkSn31tOyZGBEgEwhHTo0PoDlfx\n1HrbMjqvyYrYCLPT3dxE8csYqmi9YFcEBiIm7OZpzOa50hH0FB3jEwgxgTeJombTsWTaQ9MEvvH2\nm4DkR48+AeC//NX77I9KZGIYb6W88azh7OwlB7vbGKHwvmBQ3keKlFpb9idvkhYlrVkjRU7T1iya\ncwbFiOASdnfG/MP/6F2sHTMqMmrpyIoRtYKJvPXoEEKXpC86b27oSiRdIGcXhdSxLdyAXlwHBG7y\nIrtdoyNvrYU8hX/z+/8Nf/sX/xk7O+OooRISmQDe44PEWocPDumj/KBtPalwBCUQxsWmzkmCsCZq\nJFOJKEvEG/eB6FRDBIJMECHyUTt7R4wGnzLs9Th5ds4v/tIvYZsVvdFtG/PGktIGbEftavc+O1bP\nmmgLj76eWPRrKtAyapSa2tDagHexBrjp1sEm70gpWgO95DWIVCIyPJtQHWE9QYnY6y6J94JNTlVX\nfgw3qKe7V93T4r1FSIH2kmZ2xfUP/4zsb73FslpQDgXj/W2sdnj/5czyVwKeUE4Y7k1weUkpSox+\nRZoXzC+ec3D3AfXsKVMUxc4Rl9MF6eoSVfZI21c0o/fZbRc8tTMKuYIkwS/P6DlN/tb7XBhBcfmc\nkZXoYkQVHGuZs93bR+WW7ZCyqBfMxC6ZfoVU0T4c8n28NMyXDYM8w6kxSlyw1pp20VA0K7xUSNEZ\n2kIAIZFhgxIlUgQIEkHU3VgCStzW5kfnAJMk0uGcJ0VhhCBYg5eKHI+XgdZ6ijyjtZYkSUgQaG3R\n3tLL0ljbJWBswKQB73yMw09ltMv7eFIJCHLi4dGEqCEaoFg6fZPcqkVMmdYmkISAlLJLMHWIEHN6\nkKprYCijWl9JVBBIbZC2RlpBKpPYc0xIWnl7iif4wGzVUirJ5bMTJlenvJqukDLwYuX4+oMJ52c1\nD7f7bE9SDvKCy+s1eV9xeVHx4bsjXJ5ileJlvebb7+/y2aNr3jkc8eCNPkVjeOks26M+qmoYbEmq\nVUKWNFyKaz64O4TKILUjAK8uClRWczAZMKs8zUyw0j+jv0h59K/+Odvf+SfY5gqb7iFkjnOCtm2w\nRmONZrFecDDssT/oYYTku197wGfPX0EiyEa3o99DCPzxH/8hH2WKrWGfyXDAYn7NYFDwyaMnfP39\nbyCkJM97nQNGY82UPMvJs4LJzjbOOebzOU3T8O1vf5vPPvsM5xzvv/8+T548QQg6IJRwenpKu15w\ndLTP0dERWmtm0yl5XoD03Lt3l+vra87OTjk4OKCp16wWU1aVYbx7yHw+o8hzrIs7gtWGoiyYzRbs\n7dzj1atL5vOKujbs7N8FlaKSFP1zdLxWEkaFoN2IFW++Im4qXLHzddyWvI+J/EWScHiwx4cf3sPp\nhpdnS0wAKwPSe/JUIaXCuQbtFKvG0U+gaWEw6JEVUW8UgMnWsHO3abwNKCk5mAh0UvL8UYxXFkKg\nVMytCZ2GyBpNIlKsj44sIQVBv65uiU2Nvdt1btu/z7MBWpsfDJ2GSNy4MgUCTeBMCFoX2JMCJaFP\noFSdRV0LchXzeZSAnhJkMqB9/LoXgopAnuf8yY8/BiQHpWJvpBC5JMmGUU8YDGlimdcLtvbuULuE\ndX1Gmgwp8y20WSETSSI
zdrYGWO1JBCRFybyqqN2Sf/j33+M3/01CmQe+++GbfPb5K0iGZKPbu7R8\niCF1LgAuhsAmaXjdksnHeRS/3LH9fqPz2ST1xLU2rraiK4FJWh1ozG/zB3+aI+U/Y3tryGAQy1tS\ninjo9HFejXKBNw2+MbgkwXfGmVhLNBGoEkXnvH2IPNiL9kOVEIKLbIhzyGxEWhTsDPucXK1JESxW\nC44OdxHhduxpxy9G3BE2IuH4FR86wXfXSFncsJddynkISBHdwD4IqtWS3r5CiuiENETwKIHaSSbh\n9f6olL8pHYZOTxYssdRLBJk3zGV4/dAHH/PyQkfJxS85TPAkBLSzzC/AaEeOpF5rtg8ThAg3tvb/\nv+srAY/Pe+hByaJfYmbPwUicbhhmJVwvMIMxKu+T6ilH9ZLry2vYuofYOkKkgmubcbAtMStFtTT0\nkoLhYExVSwbzp/T2jzHrcyRRDDX2M3rXF5h7f4vpxXMmRZ/1/AUhz0iCJ5EB7ypma420FSuRkrg5\nSX+XZHlKkymy5Qr8ZrK6eOe8v5nISoAVArWp1QpP4sOts4R98KigaJ2JOQ4hZvwoF9tXaGvZEQqZ\npMyMpW1bhlnOMhUsuhqqUoqcCHATIQgitnXwIsQgJikIQSCFxwaPJD7MCiBElsqGGDQYOkidAhZJ\nKgKVdyTWgYwCbktE8HGee5SI2hofHMoZKmsp1wakohEOJSRWKZS6fSaG94J+pliuFuwfjPnmO/sM\nJts8fnnBb/7bP+Xz85pvvrXLvKqZLRqKHYVNISNB5Rl1Y3h5taLIUlIJP/t8yr39PmtjmD6b0u/n\nyJ2UJ58+x6uMoh1TrWts1fLW4YTKN1TaYrWi1wuMj3L2R2OG2oDRPHpVM+wrBgcS73JO/uw3CE6j\n8iHrumXn4bdZmX50r9mWrWEf5zxaBEzreDI/RwvFcNAnSW8XAuZtTaYUEmLQX78kS3us5heYNlBV\nNffvv0FVrZhdXyElvPvO1zg5OSXLcuq65uXLl2xtxbLUz372M+7du0fbtjx//pzxeIwQ8OTJE7wP\n3Dk+5s7hLk275gc/eMRka4KQEu8dWhvG4y1yFdjd2QIBRVGQyG0QK7xzNHVNr19gq4oXL16wczBi\nPJ5weTEnz0sW66eUgz4ycxhv2dk74POnp8wXt8snAqibwGIl8HZj6v5iQb9jIEyDbrt2JxbI4OBw\nm7PLOZ988oxhEZ8bgKbxrK1HEMizFOdapIRUKVrnURL+/Ps/IcnjUlgva5787Bn7kx36vYLHjx7H\nhoSDnHcORyxX10irub5s8d2eU6SQpcSQUxnLKqEDQqJjbUOIIIiNNunnuIJ/rXmIJa3XFZvIPIWb\nwEPrAlMhaB3shkCpujVHQC4DrRfUNq6JAwVp19yUIGg6Zk3ebByef/TdNzEu6Q5WLYI+3/yFXYq8\nJB8e0jQtk3KbJC0h6WFMi1SSJB1Rtw3GrCnSPo1cYkxLIjPqVvGtb93lf/uXP0XLBzx5foIWPYaD\nkqQ3uPX4eN8543y4ie3wTtyUAENn+tiIzb2LDJAnxJZDN1Ot2xvobOsibsAyBOaL3+T//N1P+fe/\n+z/wjfe/jreBNJMMctkdVEVcV6slSkqMjS16pAgEZ5GtQWYpMnTRKHvbiDwn1FWcK0HQ2fnI+kOC\nzBiOt9nt9GveaB598pSvffjGzzmJeB1ZcoObBdZwoyn7wjRC3di5wGsPIrBqzE33edH9UikkLgiM\nVx17BsF31QHrI3vWAUxBHE827kW/AVzx4xsG1cdviSaczeuQJCEgRMI1MFtqjkpIEglK8uj7P+Jr\nf+eXvvTtf7VLqwL/4nOaJyfYa01YvaKXZbjRiHYwJMxr8tWKEsmqv8twqGgnAyotqJ+ecmg1sveA\ndPE544Fja+uIfDigNZoLsUNYVWi1Qyszes6QDnZY9e+xOn2EkDlLmdKb7ECzhiyWXCrrK
PWMO9u7\n7G7vE3p76HSAUSN2dye0rmNDhO9uRqw9+iCQHoyIziVPIHamChASNgFHf+N5EwLGGRIvyIVAJQrV\nCaekCKRKssRFC6mLjoHWWyb9Iv4sEYnWNhaSlfcoAso58hDPIcIHBB7T0YZJiHofQ2ReVCculAGE\n2li/QXpHY7tOmEoSiH1qEhFTk1V3mtHORBstEuc9rtYIpxHdz3prUdYS9O17aS0WU4JwzJYrrs9m\n/NkPn/LxRz8j9Uu+/uYR33jriDwXPLs0eCd4Y3vIpFdyd6fkeJyRZoGLpaWvACmQSH74dIrTnt1S\nMEqgJ/skeY/jnQLjaoKqKMeBpxc1L55ahMooR5DniiLXhFBzJWvOg+LgjSGWlKoCL3NIFdoLrLME\nAs8++j2WL79Ps/gcW1+RqRSVZtSrlvVyTZqUeATKW67Wt9MbPHv0MaNeyf7ePv3BEJWkIFICJWUx\n4OLikjwvefbsOd5HKvnk5Iy7dx9gjSPNUqSSFB0rIaXk5cuXlGXJ7u4uWmt6vR5JknB8fESaJKgk\nIYTAnTvHLJcrnjx5Qln2uHv3Hp8/fUYIgaura+7cucvu7i4+wO7eHt4HrqdXtG1L07S8+867VKsa\n76BposDe+IplPYPEUzUVP/74x1hneeutN289b1wd+z017i/zO/BaZudNwAV3s3hJCdVqyfnZCbqd\ncbG8pm0sxgNE0YpDxEyezrljbKBMYuPPd792j6998DYAw8mEO/fvcPfeXQ6P7/P2u29RrywhSZCJ\nYDZfMppEu/3Glp4k8ZUmSpInCdZ0WTsbNAJx5daho4K6P7fVDQa6DfH15rT5eNOt/aZzu4+5NMsQ\neBng0gqMj6UsJaFUgX4ChQzMbWDtopBZSEiVRAIfPXp0879+ezcjeEvwPrILIlD0DpAqQ0lJ1b7C\nhiXeLxFuQZEP8N7jzBpnZgz6e6RlSZJsQQhoM6UoJty/MwKvWa9XpPmYfpmivOdqdnuwTKftsp0m\nxLm49gYXx8KHeBCLVvSurNXpljYMBHyhrCK4ccYGIXAheqyFe8Tv/cGv8+zZc/Iio9GOurYMhzmp\nUljvaK/OSZSKgNQ6WDegLd5ZZJIipSA9mMRIEOOiw1d0zEta4IVECMH24RHbo5IyVzRaMt7e5s13\nHrJarW49Nq+3uE1z1ddlVWs3LrWobXLd+BgbS4LWe+ym86izgEQbhyPpuq1HQGm9imxa9ztEJyCO\n4YbEykKkKm/m7WbObkqzIUTw4zf3oGNznbUIAjp4mtYyArJg4r1xgfVc8+Z3vslq9uWC7q8EPJl+\nhU4EbvWCNmlpJmOMUPSLLYIINFqTppZK5WQiYT18E6VbQlsjegV+ssvs9Dm93XdZhRQpA/P5mtyc\ncTdfkuyOsUWBG2+TNa/AJvT9iqAKiiTFt4aqdfQGfWamT1ZKEmFI+rtMRR/tNam5ppg/QWYpl1aR\n9wrCTd5mJMGF39SsJTLE2uWNhCdIEP72gMc7UiFAOFrvaLRGG4sIniRI0hDQITIwibQQAsJ5GmtZ\ne0cKKDyFFFjvMCEQRDxBOSlIhMBIMCEunpbY+DDWowXGRYFpQWdk8Y5URCu6J24CqYxiMhHi6VPF\nEjFKyk47KQjeda/ZI5ynXS7BW5SLGiVlLdLfHvBcmhVpIrmqK/qjgixPOF2uWM0FgwyadsUPny44\n3Ck4OpiwNoaPn81IgV98cw/XeL715oBKBqarlp1Rj1/7zj36vZS7R2Men9esZyt2J0OmVbRFl3mf\ntglIJXjjwZisdMgkofUK1c8ZZ57rc49wKUk54PBgi8EwYV1dc3D3Q6wQrNYxcHBQCKSeU198zsWz\nj5g1KwqV0ViJUZL+IKUs+lBresVXEqV/7bImakGurq7o9wbUjcYYh/cZw/4QIQQ//OEPOTw8Yr2u\naFoLQXF5ec14soNznqPDI2azGcfHx+zsbOO9Z7FYM
J1OWSwWrNdrdnd3mU6nbJoJHR0fRlHfekWW\n5VxfX/P8+XP6/T7jrS3eePgw6oCMwTjL9fWUdbXm4PAgltPyDOc8TWPI0oKD/UPWq4q8t01vOOK9\nDz5ApYKj40OyXNA0t3SSAIRusQ3AX30mN3u8CNAtzpt1fLFY0StGvPfOAf/hL78TSzMBtPPUOgr5\nlZSxi7oQGBcFviRQ9odsbcfwyLKfc3Cwi1LRDQICbeJi3C8yEHD26hqIDieAxsYDBEpivCNTrzeX\nzYYiNkFaX3xbt0w7jSfnL8TtbzbmjVYldsp87XzpdEbawZWHV16wdvH7BVEDnkliGi7QeFjaQGsg\nTwSViRTWYb+kUAJJA97hQhLLHNJTDg5jxEU+Jk16nbZFQdAcHX4XVI/tyZsYfcZqcUmrK7K0T54f\n0Mv7aHHIew8NptH0x0NsSEFJesXtbenWhhshsvUB5zsxtt+wOVHjIzZjttGzhE3pamOC22hGvjj4\nAReihd06iZTwf/3OP+XsbIrWhsaD1hbvHUmSsf6LP8HJTnDp4s4dulBNCch+huz3EPu70C6I0d0S\nZBobNmcZIQR0W2FFj1nleDWdMZ0uEViEul1OkejcfBsrfvw7HiqCDzgLQQe8DTcg0bvIvAgRDwu1\nsQREFzsQ4yFwG3AZEEhMNwd9iG9JSl5bF70gOHEDNm/m7V9idjZMZrixqEN8HQSHCIFcSer1kh6R\nh2tah8pSVKIQSvDq0cmXjsNXAh7bHzCoZsjJDrLM2PNrXJLTXM0xTUXrFZUfkVQXSL9g4M7R/TFS\nKQrzAy5PnnFYtqSzM+5kmsG4j0wzDvfukfZ2kNkA5mvSqoZigiwz5skIFzzBGco+BD3l0hTsDgKN\nGOFlhuj1WYcU7TJEOYJ8zEh5RldXmLNLhIoqqC/E7OCUwKkNsOlq6BtqE9GZy//mlwoBEzwy+JtA\nwFQKrAgY72i9J8OTSk8uEgrho6Auwi6cC2gfotCto6ad0zgCZQjUwZF5GV1ezsXFy3mMd6Q3AWIO\nI6KYy/n4INsQOwGnIS6AInhcAOM9XgRciIK6ICRJl02kiAhetga9XFI4B9ogvac1FmFu77bZ2uoz\n2BozHI0oxwV7O2O++/4DnlxcYVygrT39NIlx7SLw6fka4wLf+3zKTy+WrFXGJ4/nnJ03fOuNfU6X\nK2QSeO/uLmmesTfM2BaOV1cLVBop1YuziuGwZHs7xYYW4VOUEJQjhQ2KpetxfL/HpKxIK4NAUqox\nGZIffv//Qa8arNN4o2l0XDCraonwNUlW8PnLUzIMpZJcXS0YpYHhKIvHoFuNTQQow0HsdC6FQIgE\nJWM68t07d+j3e6RpRn84ZP/wEGMDz56d4JxnvV5jjCHLMtI05fT0FePxmCzLODk5YW9vj6urK169\neoVSCU+ePOHq6prZdM719ZTt7R2EFCRJitaaEALLxYJXr14xn885efESgaAsS7I8o6oqjI2hlufn\n57z77gfM5yuW8zWrVcP5VcXu/h2sN1hfMd7qs1hekme3L4WCoOhBkf/153HTJsnZqEVTgmjxJgIa\nKSXrynJ+folz8SQvBdQaqnUbGZAQN8E0lViAABfnF3z608hm/PDPfswf/b/f4+Mf/5hPf/oT/vzP\nfkRtIVjPso6Na+vVDb8ORNNA/FXx35zgJroqtpbo9Bqd1iZsdtbbDs+m7Xp4vUH4zSax0T188XNE\n1zE8RLeWDbwKgrmPTjchIuhJRMwMGyrBKBFk6nXgG8C9UUqOIRUpQUR2QvsUXbfd4VLgXEWiBhAU\nQeQ0RrNaPUEJyenFY0CQl9uMBnt4kbIzPiBLeyzWFb/yq2/x7OVnXJ0+Y5RahtKCuX1bEtEx+f5m\nHOLnzkWg6B03QDmE2K9Nbu5Jp9mJPx+63/F6PP0GJBHXVeclzvw5//O/+Mc8fX5KpaNLdzwo8PNL\n9OkpiBTXGrwx0
bHlPd7FTVumCrHVjwdOY0FG/YlvKpx1YA0hJAwnE0bDHkc7Q4a9HuNBgXUB095u\nfMINmiDqlIMHfqZZAAAgAElEQVS4eZ4iewLOCrwVOBM/9y72sgvEJa7RFilidIq1vmPDo7Vqk33U\n+AgUA5uDS/hL8xU27rjXZarwBcC5GXPvYmnSuw6QOfDWYkPgsjEs5xW7dLqsEFiuWozRWOOYXn85\n+/WVR9PEe4zPEa1kK9XM7IT+AFaLFf9fe2cOY9mRpecvIu767tsy82Vm7WSR7G72Oj0aoTE9gkae\nLNkCZAiQMDJly5YlT7YMWXLkyhEgDATMoCFDmhn1MtMryWazKllVub3Mt7+7xCYj7n1ZZJOcTnpq\n5AESWZXLy3fjxo3445z//0+W9ujdn6BVRloMSPMcnVfE2RC1+Bir/pBjfcH+cMIsH6J9wqpeMTjc\nY2YLfLQim1+xQTLwFT5OcCIi0xafKMT2OZV+kyQXCHNJWR+hBhnm5Bl1b5+0OSPKPVodIOKMjavQ\n5QIhQmpQtMSsgOI9kRXomyJXuNGdQ6S/PYfHOE9sg5cPOCIZTnxKBBKX8FBah3A2HLBbByzhHMoZ\nrI8DilXB/8JYSyYlwjtCwlAS4dDeE/lQioukak8sDiOCwMwS5JOO0HYC4ZEu9NuRsltXg6LDuKBq\nkSiMCP23HJIah7Qaaw1isaEebHEqIbKB7xTsMm8XRwdH/OY3J0SpYDpbMpttGPVS8lzx/HLNHz4Y\nUxjN/b2U80XJ6WzLvKwZZBF1aYkiwfFeAd7ys+dTjvZ7/PmPT/mz779BXRuuNg2DImF7uWIw7DMs\nQNiEsoF+njM+yrm+XlMZSyE8qIr1NsaZCF/U9I4kH/38kiLqE6uYPNU03pNJIFWkiefjizWpinES\ndOlQUYxwgahXa4tznvmiRg7HtxscL+n3C4SXaNOEViHekecFqdtydnbG/v6Ys7Mz6rpmsVjibPAc\nquuGKI53Ke3Oh0dKSVVVHB0dcXV1xdOnT/nVr37FYDBkOBxSbreMhj0eP37Mh7/5Ddvtltl8jooz\nmrrm+PgQYwxXV1N6RUGepVgL66rBOUuaFvTygrSX8+GHH4GPePLmO5y9uiAfHqKyAYgSFW05+fhv\nOTocslndTr3WxWoeTm5d7Mik7TSUsjVw84AFqyBKIlRVEcWSvBeRZYpmHUpYAEophLVEKpBInbE4\nGQDR4dERaZYA/5c//v4fkfR7fPVrX6NXDBDW8MO/+F9467mcLVuZ+SfZRbJda6z1of+iaxMHHRH2\ntY0F1/HzOmLo7x6esEG717gWu9d9/Yd8i42g/SMdQJOstaeSMJKhLU+xcyK+aayZScEoDgeotYEn\nkz6xCHwnSQkuxdsty3XNPVNjRUUapxi7wgqJb5bEcRZUjmKGiipEPMGYBTio6iW4hihOiKKC737r\nmH/7g//KX7aXEEcJeVHwH//Tf77dAH2qZOOdD/xI2WZX2i7qOwJ5u/kHvkoYMNEByTZjJgShy31b\nXwm4R+AaT6RA+J/zg7/+DxD/e/aHbzAyjuriFUhFow1GV0jnAvjUGunasY4jeOcN/HwGaRbeum5C\n2wknsLoiGubko32yNMOYmvGwoDGaB+MRLz5e3W5odvNCtNLw8EWPbDk07f7FzXUigiGhpfW9cqGF\nkRXB+LDRpi0XOpwL1ZPG3GRtnISuR5F3hPYSrS9S4LfRytJ96xXQgsuOQL1TRIb7IVpK0ONezkc4\nFMF2QdcWOYxoSk1vLIm/IOP+hYCnt7mgzB9QvXpGEjUkB2+x1bB/fJ9q/orh4QGrRpF6iynn5PGA\nl1fn7NUrXNPA0VM2xuOufsEiOmYgPGex5qj5CN/fw0rJIE+5rqb0kERx4FMoVyPSx8R7PcxyRS+O\nMXbJ9tlzhpMjit4B9eIZZZSQrF+gYk+9jrhal6RnJ8j9NNw26RFOYIVESIcQEmG
DjZGTYaEKGTPR\n9l353cM5h5Wmtdh2NEaQqojSWJwNJawMSSNEILMhcFKQpBmT0ZB6u8XJkC73KnBrgqOkIzCKAkx2\ngJMSrKHxoU+WEMFUy7YlLIMLaquW1ehdIC93c8kKggpLhKnsRSv0cwInDEooMBbblJSbNVxP6fWH\nmKAtCE3kbhl+vmG51TzOhrz5tXu899ErTl7NSLxgr8gY9hJELUilIEcwHuQMUsV8U3FdNmRxzCBW\nnC5LDoY9/vidY64WS/7bj0/ZG0cMJilnqy3f+eZX6EnN2XJK1EvJnGC2XrNYzYAIZz2lSoijGCEa\nZCRwNubsrKR/HNOsNqyuFXhDESVczrYYC3vDgsN+TpZHlKVmvt1wOBiRRBpHjPKSJI1ZG8k7e7cj\nWB72Mz786CMev/lVlssFcZqyXC5pnGdvb8TZ+SlFkZEkCbPZjPv372GNYD6fs1yuONwfB8fssiSO\nY4bDIZvNBmMMe3t7DAZ9Xrx40RJbLU1dMxgc4b1nvV7xne98m0obLi6vmM8XHB4eEUUxh5MJ+3tj\nTk9PKXoFF9NZ2LSlZLPdgFdEaY80HZGmBVUJ+5NHfPA3f81+P6ZJNoyGKVlqKTdrvoS4D4Ej60uc\nFtStff6OadC+njbBzbfW7BbPzXzJRsOLV3AxC9kda9mZFfbSmMVmG4zhXGiP0E8lSSI5e/kKorAU\nPn/+krRXcDAYEcUxJ8+et414PV//2kOmrzynH84/ATJctyCLAKCi1siu61jeNRAVkiBPt6/tzbcI\n59yOjN3F6+8jgKyb/3Qwp/WzDXxFwt5yiedSQBLBN9KWACBAekiVp4gEeSxg4bjeWkqjcInBGpCq\ngUbz8sWSR9trNME7ylHTmC3Hk29R1hLjIvrFkNVqynp+ikwKirggiwpqVyHFMUl8xf6jJ7z7Rwf8\n8IdXIATGala39JkJF9Z6uxDWOteVSVrXZdEq5KQMjTSVCuoq15Fsvce27Q92GR0PkQw/61SwS9BN\n+Du1FqHEdfnn/Pf/8ef8xQ++xz/7k3/H93/5l+joXjCdrSym0SRxDJstqj8gynKKf/o9GBaB0JwV\neCK88ki9DW86ycFDXVeM7t8n/iDD+hohI6qyYr2+HeDpUo67DuTt9dEeHLR1mErio/BN5UTgG/uw\nj2hr2XiNEpKrdYWpPVaLXSbIWYnbCGrrsHXoV2YjiQqE2TCm5kYJJwjPlNOBd+bMDbBxbWmySwo5\n2u8nAZDW1lLXdRD8GIsR4DSIWKErzfjg81vafOHRPRscYOorxuMerijYOEOeZERGMMj32DQ19fKU\nsi7ZxGNeXF6iBBwPYlSWwmZJNMjY7n+N/t6EOK4Ybl7hlCTeLIL5XjMlywpsNMaVK7xe0rdL8s05\n5WlNGqfkyRjnItK9Q8q6YbGaoZ/8Ab2DN0gO3qR4+A9ZOU/y7ARl2g1dgHCKTmLoRTBJEiJMVrhp\nO4F3re/A7x7WWxQShcO70KOqthbrHco6GgSRCEaAuQBhPU4bBIZG12xaYGKcCwThcAYIKNlbEgJ/\nRwqPsSGVGON3lviyPa147wLfB79zZA5lqu6E191kt1v0jDVBlhmwO03rQSq1xtcNotLUdYU2DVrX\neHP7k/r5csv+0ZDptuRnP/mQPo48jXk8Loil5Ocv59TOURrNlbPkWcLDyYC3jwc02tLgWVUNX3uw\nT5Ipfvnskrp2PBjF5ElC5Wr0KCZLPWeLFddTzXpVYmwJtiFPU7x1DPsxHk1VO4SPcdZRl4ZYSFIk\ncZTQm6TEqeR6WTHa67E3HhApKBuNd5ZGa9I4Zb4quaoiNqahJx2rTYUSnnJ9O4Ll+cUV+0f3mF5d\nIpVivliwLUuu5teslzMePnrI4yeP2Ww2fO2rXyHv9Vitl/R6GdvthnK7YbVacXx8zGq14vr6mtFo\nxIMHD8jznKqq8Z5QEmvdkwN/Nnjs/OhHPwY
heeedt0mShHJbBpn7conWmgcPHlBWVQA7IoB0FUU0\nRmOsZ7GseX5yznyx5b33PsBVS/qppZ9H7I/3cUYSqxz3JTKDRIKmhKaBT8OCHam0baHyejbDWMNw\nkPDGkwO+8nRCEon2+Qi/U9U6lDcIi7sk8Ia194wP9nn06AEADx8f8/Sdxzx68oivfvWrfP1bXwcB\nWlvqxqI+wx9GyGBwsUvZdGen7kH07R9rU/ti93DeLjpiaPfSvq23vP5599GWGHYn5LaEILrsRfte\nGi04bRu1ClrQIyCWUCjYL+BHJzMqF8bKE9YfJeCNp5PQMy0pEHiSKEZFQzbbBVV1FqwJ6jW9Yh/t\nFXl2gIzH9Hr76EaH8ojesKkF//pf/Gm4/RGoqCOE3y5cR1Jur60rR4W2ES1nxbcTprsF3e7fKog6\n8ncYs278ROt7JogVxInosALCw6aWVCVsFn/Nrz/8AddXK2rjgjrJmKCAM5ZExERRHNbmcR9fN8g4\nbRuLa/AGohRMg3cWJyRxOqReLRjvH1CkSXAZl57jR/duOTqvI2Cxu2xoidk+0Cxcy8nxLkjUnQ1z\npzaOsi3dW2txjaU9n+NaLo61Htv+23Vy847L5ggcHk/bu+zmo5vXHdHZOdG+Tlue7F7LhV1MO89g\nb0xFsF/p5r8SHhVLyurzOadfOK1kVfNANcQ+QcqcB0WEd7ClptQ15caS1zVxr2C9WsCwx0A6PuaI\n2bZi7VIq1Sf3Jb5aY4tH1HtfIUkTmsGY5XSOTA/wzRa/eUkvUuSiB/sPaUZDjvg1pGMsMdZbyAYM\nc0UalejnH7JuUiqZMRU56tFj4hcvUIlAydBYE0G7ELVgBx/cN9s7Ebx5ZJue/u2F7AunjxNUxgRk\n7CzGORwW7wwuiii8ZWks2jqstzjpkVHwY+gVPXzLOA/ydkHSnkQkDuk81jki4Ul9yMYkxlGKUPpQ\nwmK9JcUSdyuD8GgczodSS+U8ot0UhGPn6iy9bxuR+iDbdJ7Eh5Odcx5XN0G91ZjwYVp6/S2jX6R8\n/fiIP/3OY7733cc82O8xGWVstWd/lFKMM96/WPBXH1wym254/+SCHz9fYKzk7YMel+cznjyY8ORg\nwCSJEJGisZ4oVVgL9yZjnuwPWJWOJRtc3uDWjuuZpawiXr5oyNOENEkQTmG0pXGeyjRUtUVIwXxR\nUaQFOZLZwpKnMbX2JCrYAGjrWS8bhoOMnnMcHQ4ZxZrYKl4t1tRVw6QYcXZ1u9Nov+iFurh3zJYL\naqMRkUJGEfPlHO8d77//ASpSzOZz3n/vPcpyy+HhIb084969ezx58oTVasVkMmEymSCEYDabYW0w\nEVwul/R6PZbLJefn53jvGY2G5FnG/fv32Gy3vHz5ktFwiJQCbTTb7ZbNdst6vW4dmUuUksRRRN1o\nmkZzcvIx8/kcrTUff/yC4WDI8fGIk+e/Zr1acX52RZr2SNM+St1Org+A86iI0PH6U9FxDkS7YXVn\nFGchTSSHBwPuHY3oJVkoG3kwVjDMIUpT4kiFBobhxIDWnohgFpckgQS6WFRs1iXWGrQO2THnoKoM\nVRUyhO27uHlj3oUyhQoL+o6L3BkH4UP5vP36bjO/Nejx7Sm4XcPaDWy3pH3WR0uU7T524Kgt2Xjn\nuTTwIgjudu14IGwgvVix0A2zVYM1Fd5YgoQCvvH1+2yrC5J8xLq6wtgmlPbchkH/MXGckiQD9gb3\nSOIBziqECmtULx1T19dkSZ+yrPiT778d7mVLzOaW6zHctIdwzt2AunaDdU60/ned1w67Pk3Oth2+\nO7VQe/2+LUcGQNKOswygLEk8KiYoSAVsS0ndwIuf/hfWySQcspt6VyITTY1QMgDK+0PEYAhZD5Ks\nBTseGYd2Lw4LaR+BQ6QSW9comTKfT0myHuvVmji6nVACbsajUwd2VDLvOxVbB3Jk4PPYwOWRUtAY\nS6WDglU
72wJLj+kyPG2fMk/gjHkbvI66FI1v+2t1nBzXtpu4Aedt/yz8LhtkO+J5m3ELwMfSU5LG\ne06AygafNWODItpq+4WtH78YR6cp2+yArashyWhMhHcVtvbM1ltiVSLvP2a9bcgzQTp9xUrHTBJN\nlMaMfv4/mb44QZZzuHqGbirs1Xvo/IAMh7QzembDSMYUaY+Lyduo4/uopkblD5APv0MkNY0y5MrT\nMyt0sYdXA6KsoTn/CV4vUKc/Q5YLxE//DhlFrTnZjbNmR0gLh7MAcQSiBUVtku+W+XfTUs3jzu/B\nW4wFhSR2Gm09qXdYZzBe4DWtE6dHZAm4Vj3mw+ZaiVA/td4FPx8haQjKiQiPVhB5R0zosSSFxIrg\nYGRbcpdsT3fmNRWady74HHlPQgA+ljARjfekIii2jLNBrm41Zl2y2WwRukLbBqFvy3CC1abhVy9f\n4pIe/YMR/fuHXGy3vPdqSrWu2S5K3j3a47tvHHK432OYJ0hXc7IqGQwK/tE3H3N+teTkes3RaIi3\njtJY1pVlqjWnesOmJ9HJkr1Bwl6/z+i4TzGISWTKYKRIC0WjG+IkwhqJ14EnNRpmVBbyPGZdr5hu\ntvT7Cuk9faXQ2uC8Ylyk5IOUdQX7x4eMe5KPztf0Rwn9oo+1FR++OiNYht5mbEqkVIzHe8RJwmg8\nZnI4Ybw35pvf/g6Xl5f0ej3KssQ5x3A4ZH9/n5OPT3j8+DEgOD8/YzAYYFtCuzGGPM+ZTqecnp7y\n9OlTkiRhb2+PouhhTFBdxXHgjvWynEcPH7Fer4njmCSOqaqK87NzmqYhSRIePX7M9OqK9WpFvyiw\n1qK1ppdljEcjrNFMpxccjMcMih5NbTg6vI8nRhBTlbcHyqiwMRnz23Oua8XgCMT7DgA5H1pSXE7X\nPHt2wfnZjG2lcQQJdpYEg7KyttSdR5f1LTcDZtNrPvrwGQCzq2tmV0ENc352ysmzk1CecjDqZ+Rp\nd6/9628MpYIIQMpuEb/JMkDnxeNveAuvnX5/5+gyF7+laLnFS+3AUJct8xjtmRp41YjdVQkgEdBT\njqMEfvqqRGuNcR4nIqyF5mTBd0cDrK4ZZhMGvQmJKhAMUQp0U3J6/WvOph8yGh4hlWCzuqaslwjh\nSKIcIRUqGTE+GPHwUbHL3Hn/Oe//CyIcXEPWoCPR0krPOyAVOCcCZ9qms+6mJGm9D9fXZsREV3rx\nfocdRTg/E8eCOAUZudb5RLBYQ09+hTgbgANb1eE2W4NwDiEF0lnibzwCGSOcw2/XgTIQp2A0ZnWN\nUK1CzWqkSFBpjzxLONgfc3H2ijhWlOvPb5/wufe9A3rc4GUIGSzrOkDSgc72a11pyXq07UYZjLbB\n4sEEno/TnbIrjLUNhYuwt7Z2Cfhwdg6E5OCo/gmSvQ8jbl13v3gt6wPtxo7FE7eCG63dbr1YXq1b\nQcDnPw1fCHhWS1g0EhdnZEkoFSW6wixfcv/eAWNpiZWh3hg2l2eIZsmmWXJux4wPjjD/5J+jzBRl\nNc34IbGpSNMj4nqG1hVFcZ9KeLyPsVIzuvgVs+WKRiq82bLxfayPSZsNMs7QIsVePqeqNenem0wG\nI+JygRk9ZiAjVtOr19QFqn0I/K5BqBXBk8OJm6fJO4GVis84UH7x/GlXHd0aFjggtgbhLHUTFFHa\nOhKhsMYipUM6ENogjaYRAWR4OkJzm1Vomfzah0yPFC50UW9JX85bEu8xzoXNrpXtKedDH67u6fU+\nSOIJpxcA4TVVy/dRbauKxvtgh46jdjpI9psG2YRMT2YczZfI8Lx5b0JiPKJp0JXjarbhw5dz3nx0\nSB1JvvnWfYZpRG08S+M4ngzJs5RUJMzKip++mmFc6D2jpCXPUpSSTLcV+4c5w9EA7zXWGeIoZIDW\nqxJvPP39iMlen/OLNVmRYrUnU5ZBP8ZLh7GOSGk2VYNTkCtBHKWIKEZFg
gcH/eCuXGrWlcYIzS9+\n84ym8oz7Oe+9OAdvGOURvVQyu17eamzW2y11XWOtZW9vj6Zpdv+vq4r1es3JyQlJHKOiiOPj41Cz\nThKqquSnP/07jDFIKanrmul0ilKK6XTK/v7+DgQVRUEcx4xGI46ODsmyjOvrK1arFSfPTzg/vyDL\nMpRSXFxMkVLy7rvv4r2nqkp+8fOfs91sQQimF1M2my0PHjwKi8tyhXOOKIqx2pLlfZrGsy0dL04u\nmF7O6eW9W88bZDhRquh1WnCI7uTmndtt9kC7SDqcs6w2Fca6mzkfBZl1FqvAw9CByeLaRdUamBxO\nODgM/dAmh3s8fuMBx/cOefj4CV//5rth87Jh3vhPLZmCNjPbWvIjQumjbT20Mwa8kea2QEN+qQTP\n7nc+TVjuNo1dZgM+tZm8dsJ/7fUgvCer4bTxvGhEew3hW5GAgww+vCrRLqY2ZrcTzZZLDocZDyWk\n+UOm84+xdsOmWaCtQKohaZQjRUJZrUBAv39EGhfk6aQ1lJMoGUN8zL/5V8EwznF7QvfNmPiWLP7J\n0t3OVbnNpGtL27E+ZBA6ECm5ATqdfH83lu1nKUO/tzj2xIlARR4Rhfl0sliG/k9Vg9QGYTSqapC2\nJUQ7T/b4Md668CJR6yhtGrxzRHGOSHPCg6CI8xwlFcvliqZ2DEcTVJKQF7d7ttrOVLtMoG8vZqcs\ndK35X5tpcdYHMNNmapwPZrpBrm9xrWLYm/ZZsiJYAXCTpbMu9DQLcvQwx5xtQU8LfnwLfujKYC70\n49oBL9uSmYN/IwIYxBHLzTY4O/tAmm6MxWiDaTSm+fzWEl8IeJrmCmE1cdqDeJ/YWXySocbHlJdT\nqmiAc5ZYXVGMxkRPvknBhnp2wXI2Q16+JE4G9Mf3mMQW4dYkqUQm++hSU8cFfnCEzCUqHqPimHj5\niroyRFlOZM8ReokcHTJfNyh9TdI/RtqS+fUaV0yoZUE2e8a6rsCGUkVYeIIyaZfCDMWt1pMnZHYC\nx0UQAf6WfAPtglLLOYe2FuUctXdoG8wMpQ5ZFWM0CZ5GOGIZZOJFJIlVoElr64KiopWaRwKylmwa\ni+B0WRH4RlqEJqNWQGeVWDuL96GRqceTEtpFGEJX+NpbhA+GYY0TCBdKKdpZmlD03p1gIkLbAGMM\n2mhc1WCdubXsGuDjkxfEzvLyfMHZ1YJYRDw9ntAYzziPqRrNyXTD//7glMTAtx8ccjRMOB4K+nHC\nu/f3eDTO2QrPj0+v8QJW64ZvPJwQ5Z5NWVFtt1Sl5Wph2G7KQLKznmGSYMWWg72M9cIgnKQ/zEky\nSSQiitQirSCyBfE2QwlBGsHeIKPoRaysIU9ihq3k3FaGWEbMVhVFntCLUiovqFqFwDq+HYdntLfH\ner1hdn3Nq5cv+Qd/+N0d4Lm4uODdd9/l6dOnrNZr+kVB0zRorTm+d4+6rjk8OmJ/f5/tdstyueTR\no0esViu+8Y2vE0URSZJQVRVXV1dcXV3xy1/8gtVqxXq9DgDFWg4ODnDOkmUZy+WCwWDAeDRmPl/s\nMkPj8R5HR4cMBwO01hRFn9Vy1fbsquj3eyyXoadWudWoqEev2OPp02/gnGI2vx0QBPB16Omjm89I\ngbT/tdZ++ogKhFLY4aRgfy8hSlqDOR88U9I0IlKKPFOBryYgVtDLJFIJ+v2ifW2D0cE93VqLbjTW\nC2wD68U2kDBff7+wU4tDeJ7dblMJGQYhAAVCifCAd9ya244Nr/1O+49PAIMOBP09FaHPyp6EsgJM\ntees4/S0wCeL4MXlgum6Ai9p6uBSGA9zOLvmXm9FT09Jkj693pC90RtIqbi4/Dv62QH90RFlvUZ4\nzeXVS2SUMl+dI71DmyqASa350z/59u46vkyGx3c3ozWXdG0WG8ROsm9tS0xuS1y7zAa729LWA0J2\nUMiwae+AAiHDIxVEbZYnigVKCqIEz
uwH1NUGtlXgbDa6Fe6HDLzcK5C9HkgZJOlKgnN4GYxBfRzt\nuHN4UGmBNQ1Zr49KU1abFeWm5oOPTm83NogbAChaANjyZ/CB7tDokPXyrYzfdmDQB0NW01ocWOva\nA0BLLjYiGBcaF36/BSq2bXTQtazYlcss0AKaTxoPBrC1K5MZ2r/XyuRNuMnzKuxRMVC2hpll4/FK\nEqUJsfr8vfwLd3mVpozHB8RCBMXO3oD1/ALlIyJvkLZCN4rh+B7IHKoVvcnbDPoJPkkxQjC0axbW\nEQ3fwMoxa7nHxXxOfviQ9bpBWE3lYpSoaRqJGgzwMiIzG66qGJMOOb+cMUxLbDHBqZh4f0LPXXCy\nsLi8h+ntIV98RNQuRh6QrlNggRCSzne560nlnQw4vi1peWFvNYGwEDmLcIZYeIzxRM5DWxpqhCDC\ntUDDIozBGxcclq3b+TIIJW68GVzYsGvnSRCB/4MPUjof5OkJoJ3FWIPz3fdcK8Vn57ysXCAldxbm\ngrb1BUEaH3lBZj3O6dAjqc0MeatxRiMaA42lKSv8FyDmz4u3H004nAxZrbdcTa8RZYl2no3WLLaW\neakRieKNeyPiLGK2XeC9ZL7VPDjqcbFYsawN0jRsF5rpuuT+pCBKPWnhOdjzCGXYbkuqckmSe5IE\n9lPJq+mSQb6PlCll3bDVTXCS1R7dwKvLksnxiPEoQg08w35EZRyrTYnxsFpWZP2UTdmQFxl5HFP0\nEi4WW3oKJqMeWRRxMbPsTybs9W8nS5+t1hjnODg4oCgKfvKTvyWJE5pGc3x8zOXFBZv1mjeePOHk\n5ITT01MePLjPfD4LHAljMMYipWS73TKdTrl//z4/+tGPWa1WHBwc7Jx++/0+WZaFBSKOybKM7bZE\nKkW/P2A+n7O/f4CUiuVqRa+XM5td45yjqkqGwwHb7YZiMMAYy8XlFBVF5HmPPE+RUlD0+xTDISpJ\nmBw+5tXpFU/efIdHj5/eet6IlusiPqPEfINvbvgVYRMSiChms9ah2asLJeGwiYXTYiRDixbXgnsc\nRFJSNo7ZdMbpyzMAnj9/xdnpJR8/P+HlyUf88me/wrbPTlB7/vaSKQinTEQg88o2FSOEuFEPuJBR\n8DbsOB2341Zj033uwJL/5MeuHPXpofOf/NzxNj4dvt3AXtWeuQmbviT4i0368PEcrGvaEoOh1vCT\nH7yC5yc8UhXWbJhe/pI0KhBScbD/JnEyYLM1KCmYraYIpanWJ0SRwDqDVD2sL2mc4lvf+irDvfT2\nSLAN9xhZ1UwAAAUjSURBVPrmGW5HS9K+KTHuwJAL/K4AgF4rdXU/R8c9F+29bJtluq60JZAqNGGO\nU1CxJ4qhVPBqeYXTJqyjpl3bfVh3kzcPYdjHV+XNjRDBjiQIbDpAbEHFCDz9yX3wJUWeYustkfL0\ne7fr3xeux7ccp84rpy03dSW9djyMuyEPGxtSBdb5nZmjdxZtQwVDtyDJttkgZz3WSLQWWCODt5Dx\nbTanuz8tX8qIXcPWbtz9a2A1lB5boLQrkbUtyRoT7EFM+JvbxqLiCNM4+oPPz359MeBxjrWOkEmK\n2syppxt8ccB2dk0zOqCO++jVlNx71nVJ1TK9dXGMJ2IbFWwl9MyWbXnN1q/J9BV721Ocrcn7MbJa\ncTAaEg/v0RsPyOOMYeEox28y0guEsBR+Rd1/BKIhtQ3GpER7b3BfnJG3KLJ6dhomoWBn8hW7YKrn\nCQuR2HlSgFBBoRTtOD236+wsZEC52oTSUmQtjQ1GhJU1OFOFmmzbOVdYT+kM3jqSKKKDV8I4VDv5\nKhd6ETkXskNRJNsJ5oJplXfUzoaMFCFbhAugKN4pCsLq571DtmRt57ufCSee8PuBK4QP91k4G1xo\nnSNyGudCc1O0xX4Gn+Lvi2a+Yr0oKRRU1nFytWQ2nzFAoVTDs8s1e2nE9x7ts9pqau05GKRYLzm7\n3
pD0UoaZJJaKw37GW/sFk37GS2+JWq6SijNEbNk7GGJshiVlriXjfsJ6s6ZpSrJMYLXA1LBee+JI\nMBnn2KZCiwisZLBXMM5TojwjSyXaCUzVkAtJ1Ri086QqZdSLWJQWKRQphidHMdeLJWwvbzU2UX/C\nVguWywVNU3NwcIC1ltFoyPn5eehiHimMMRRFwf7+/o6QnCQJSZIwHA5ZLBYcHh7y1ltv8ezZMx4+\nfMi9e/faklSFEAJrLYNBH601m82WsqwYjUY0TcX19TVlWXF1dY2UAucc17MZh4dHWGsZj8cM+n3K\nbUnRy0OPrSjm4YMH5FnGcrlkMOiT5ikOx3A45q/+6v9QFEPef/8jfv7TX9x63uyIlJ8x5bp93BnH\nJzgwHqw29LKIzUZzPatuXHJdMJgzhMOENiHz21WYIgmTowlvvv0EgDefPuLBowc8efqEN956hz/4\no2/v/pYQnjj+DDJ1q4AJqqj237bNmgQP1PBzUgQg1yGJLzM2HWj5FNj5RPrHfcbXXvv8Otj5dJmr\nM3l7WQnq9n1HEgYxvH8NtXFY3YRNTmtq2efX769wJ88ZX00ZDd9gU53jnKbcznk5/QDvSmqzppf1\nSbIDBqN3SJIU6zOiqCBSY4yp8DLmz/7lP27H9EuMT6D5YttSlNspfTpD1puN1bSIMWz2olVc3Qxa\n1wncA1K2QmrflW/C60gliOOQ2YnjUKFKYvjh+Qd4bfGN2WVqpFJIKcnefhiaN+sKbw2uafAqDkpY\n3UCU4p0BROiY3tToasto/4gIwfG9h8Rpn6Ojv7nt4LzOc3+t5BkeONe6JduWsGw7kLHLfgUhjfOh\nW7ox4VkyBpyRGAPGORrrW26UwLTs/a6U5doS4utOzrtMTpfV6cCXuWlv4TtOT1toSKTE6AYJNCZk\neLT11JXGaAM8/9xhEP6zoP5d3MVd3MVd3MVd3MXvUXzJvr13cRd3cRd3cRd3cRf//8Qd4LmLu7iL\nu7iLu7iL3/u4Azx3cRd3cRd3cRd38Xsfd4DnLu7iLu7iLu7iLn7v4w7w3MVd3MVd3MVd3MXvfdwB\nnru4i7u4i7u4i7v4vY//B8JgNN16YgdXAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAABJCAYAAAA5f/zBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8ebBlV3Xm+Vt7n+GObx5zHqRMDakB\nCQSIQQgwwmAwGLtx4fJcVQ3G7cLdptyFXXa7XIZwea4qV5W7beMJGxuXzWCMoRBIIDRASmhIZSqV\nc77MfJlvvO/OZ9h79R/nvpdiDD1HdHSE430RCr28wzn3rrv23t/+1re2qCpb2MIWtrCFLWxhC/+U\nYf7//gBb2MIWtrCFLWxhC/9fY4vwbGELW9jCFrawhX/y2CI8W9jCFrawhS1s4Z88tgjPFrawhS1s\nYQtb+CePLcKzhS1sYQtb2MIW/skj+HZP/sEvvEM9Hu8NYhTvBeuV3CjiFFXweKx3OAxGhKNHT/DM\ng09QqVpEPV4ECzhAUBQhzwEFY4vrghJoQCaORuZIe7Dvzz5OdTZm/lKL3fumcXOnyemyP3mAdmOF\nkewZoILPE9JMafcy0paweMUTBoZr98ZkrkG3VWJtLWftXIRbjhEr7Ns5is0hU7jvy8/w8tv2MFQJ\n+aUPPSLPN3B3HT+rAYIRwYvBiVA1QobgTIAzjmEsDSOExiIIIoYujlCEKqAi5IAVgwgYEWIRegAC\nUoQJgyFDsRhEPFY9iCFC8WLwqqSAQfAKqoqIkiuIKqmAeI8oeCDxHgS8KhGCB/I8p4Qh9x4rgqpi\nVPFAoJ5790w/79gAPDifqQ1C/uzXf41Dh27h0D2vI+07nBMyL6DgPKhXvIL3gvOKevC+uIYdfH+A\nQARM8bwqg9wDp0VWGSliZowiIogp8s4YRQyIQGQFMYoRxRoQI4O4ewRfXEOv3lMGceA5j6tIkcc6\nuCEAyiu2h887Pq98zcu1sbLCwQP7aawsM1Sr0W01mZ4c4Y4X3Ui32aRWqbJrxx5ec/c9PP7oY5hY\nOTV3gSvLK4RqOHLqWf7m4/cShiOEpYS9B8cYGa0ionhnyHNPs9WkUqqxstqi1+1jJMR7B5JTKofE\npZB6rUI5ipjeViOsjBPYKlnqQRzGZPSThMrQHezeeTfLjWc5c+Y+Ahpk/UUqlR1E8T66vsqNN72a\nbqfN0vyDrCwfYXbna7jr7rfzxhdOoEWwnhd+9jffpquLXZyxLC9f4PSzp4gqltQr8+ebLHXgDz/w\nn7hm3xQ/8p7vh6ROUMkI8iqVekxbLlMeUzqrguZKtytEATinBEaoTUOWKGOTEXbpIBdPrCI+pdFs\nMzE5zN2vfREHrjnIrp0HkbjEV499ht/9rT8h6imv+6FX8t3f9f0sLZ/gT/7sLzn8wCWCkhCUFclg\nwpRpt1NKs0rqPaNThtVFGNkJnTVwbRibgkjA9YW1lnLqsHvesfmRH9yrVxopFsWKodns0O/18E6J\nSkIUWZbzLn90x6u49c7bcLqdLLlEd6jChf4a9d0HCGyMek+WrpH0G7TXlvjEA4+zttbj0SfOQ6j8\n/E+/E2NXGR2uEVcqWDUghlxTPvZ3D/LBP70fCUwxp+EI44g3vPluxobLhMZTCgx4B0ZQbzF4sixD\nFBRBxZPnnrfe9moiE/Jbv/dBhvunaNlhHu+UePbJM5gILiw+/7wB+J33/4QmSZPQWuIoopghhCz1\nXFy4RC9N6HS7tJIOvbQL3hEYQbVJLexSjQMqUZ3YRsVcgBbzg8akudLPHKmCCWNK5QrlaJg4rlGK\nYkQC8syR5zlZ7m
i3ezTW1pi7dIFOP6GfDeY4daROafUTBAtiwViiQAhMiLVCYCzWCqG1CIr3IN6A\nFnOSFUNghQcePfG847OwcEZFFbAEgS3mAaBUHceIkPQbiIQEYQSAsRYRAVWMCbHWUsx3V28p8rW3\n//p/A/STDuePPUTz8jzDw0PsvOlOyvXJr7kOgKrHPfVhrLRg58tRDH75As0nH2Z1pUXL1vDjM0zu\nOUA0PAriiEojhKU6YkJACaMK1oSkT32E8q1v/6ax+baEx4lFfbEooIJB8QJGBacehSLlJUC8R8Xj\n8pwgAERxCOIFrGIUPAarHm+Kxc15TxSAd4bMOsLcMeRg5CMfoRH26DVS6rNVehdPMFtR2u1l4s5p\nNJ+nv9rn7NwVli530WQf45Xb2b/3RkbH4fLFeRZPpoTDj7GyepHGlToXLjbYXZlmtZVw/2OXiENl\n22SJFx/aRiW2iA2/XSi+EarkKlgDViD0HowlpCAa6g2rtiAfuXMYYzGSEWPJVekKRFCQjMGiGmHI\n1RMbg6gnw5BSkMZQBcQReiETKSYOhR5KoMUkqqpkQICQqWJVCBBC5wgN9MTgvMMNSKgB2t5TFkHF\n0lfHkBREMEcxA4pq8ZuLDQWJUJfxgz/7Xj75f/8ep3/7P3LPu95LGAmSObwarAFVwXtwqgQe1Bck\nRorkKriGAihGBDf4ngx4iFW/8cKCvBTkUQwYKYiQHRAbaxVjiuusEx2kIEDF6Ci+L3xzsiNSEMji\nFcV71kfBZjA7PUWtUiIMDPv37aGxukwUBwShodvtEoYVdu3cx2tf+2qefPwwvaSF8RFHn36aF770\nTubPXyQOQw5cs5eTJy9TqlosFudSKtUyLjc436NaNXjfJQh71Ict3baSZxnGQL+XEoQBgqMaClmv\nRxA2icKQKI5I+hl5lhMYSxiFVEZGCEb2k+YXaC8dp52v0GxdwTUzwvIYjz32YWoVz9rKZYJ4mp37\nDzIyObrpvPnOF/9LXvWy1/PRe/8f/u2v/ivGZuGtb3oP937xf3Bxocmv/KtfYGK6zL/+2e/n5bd+\nL6nv0G4us7KyQBKuUK96lhehuyrEkVAuQZbC0JRgAyXLheqw0O0m1IIm4zOWSGo0Oy2mpsYIohI2\nuUJw5Th++xtJ+jm5E2KrdFsLHD3+WZ4+MkezWfzqUQi798cICa0zKe22Y2aiTD/tkaaeXtswBaQO\nnEKvAa02tFYVl24uNiqGKBRCMbR7DmMt3nmcUwwRzkEIHFla5tbza3QmQk5Km2beZq3TJnzmKRrN\nJiM1w8OPH8Oqp9P3dFLPpQstXKb0kiap9jl/usfv//EHeeVLDrG41GV6tsbdL7uRa3aPU69FpElK\nHMLsjhluesF1TE+VGS5HQEbmDKKC94Ysz/EIYSnCOY86T6kUkSaOo4vP4pc8tnOZseGIZy4tUR7Z\nx+hYwHIj33TugCVNPBIJVjwixebH+2JTmDsl90ruDN6HpLkUc4SWMOoITEQcxHiJQATnHQEGxCIa\nEBiDkRJxqUa5XCculYnjmDiKMMbinCfLPEk/RV1MlgVEcY9eb40QR6oeXIBXh3OO3AkqilhIc4M1\nHmss1kJkDXFUzEiCYLzDYjfmNqebK870O6tI1iOIykipShhVQAyqrvh+YjDGIDK4rhb31YJyoYP7\nFZxGBi/Rb0pyrj4uGLHYwJL1e7hShObJYF4thI5iFRqIAbN3cvnpLzIzERLVx8nGaiztyZnPj5A5\nw4gILutipU6W54MNboC1FvUOQcEYepO3U/4Wcfi2hCdQJRc2WOb6NzZasHRVz/oxPkUcDFmeMRBt\nioWXQikAi+JJvMEYxYol8w5rBI8nUk/egx1/8jHmXIOyq1GWKwzNX2FoeIyh9Ahpw7C0lFIKAg4/\ncpkXH3wf77zndiYnhshshXMXFjh/6RL79+/h8sISe7a/gXPP/A9e/p1v5ezJk3zo
z38P+imu52l2\nDGnaJxur0Eo8Q9XseaZOAYtBTUECQ69gDeocKgYxjkAMeE/f2AGp8cSqJChqIPWCGshECVBCFTIc\nI0ZIPaSDpTcWg1cQPIEqDiX0kBiHU8HgsASkWkyoRZopXj0W2SBTZYUATyxCpqB4HAVBCEQAh1Ml\n1eKzJgihKhZINrfRKqCAeHya8OZ3/q988g/+mN999/fzz/+v/8LI5FhBkLXIHDVgVcgH6o3xA9Kx\nIbUMrjeIiWihShUDx2y8RARUrhIdY8CKFP83irXFQDRGMbiBKlTsU4vzqPS5N9v4DBuURovrFzut\n4tOIymb5DmFoGB8dot1s0lpd5ZprdzFS302lUqbbXmN8qMbU1BRf+MJ9XJw7wy033UQQD/E9b3sb\n+w4c4PhTz7Bt1zYmZ3bwO8/8V1xaI0sKFcPlSr+f0OsoeRYVeSDDGCPU60oYWoIgQETIXU6vl7Br\nz276ecLKWpNKPITBgne4LEXiKr1+m7XmFUpDyuTkCGE2RGiGaTSbEDi2bTvA8uo8Z088ShRN8upX\nvJE77ryTSnnzFfOJyUkA7nnFP+OTn/8zqmNT/NVHf5uVJVi5CI8c/iqz18D266/nphfeyi37b+Q7\n7noLl5eOs+vO63jBwVHGAke71yKsCM4PSKkIS+eF+iQEEawtCSvNs6QJ+D7YUonR8SGibIldS/ey\nLYWjJx7g0eOOqKTQgUZrlWMnH+f++08zfwUCK8RlZc+ePVSijCcuXyCMHP1Gwvf96Dv5+4//Ea04\nodcUshzEQ3dB0UCQAKau2VxsYhPQ73XoeMUKpL0M5xxg8F5Rr6Q5zOy+jnO7x3l45WkuL17m+u0H\n+dX/9iHOXYGRClxZgNkxePM9L2FmqsxXnzpPmqUkWUa9VqO5coVdO7fx5JGcJ488Tq0GUQSf/LvH\nmBmDG/aMMbV9O2NjdcLhErXKCGEQ0UuzYgOBIwoMQegpl0JUDb1+ghGLsQHGGMqliMutFZoX5hmS\nZUIpEQXKdK3HLa/exhcPNzadO6qQOUeQB/gArBbKuyI4H4B4rCkRWKGXD5QTKVRxl3ucC3EuRiVE\nB+qGQ3DeIhJhTYQJYgJTJg6r1Ct1bBxjEbwqgYXAgDUR3gX0UyEOqgQ2QfMEY6GLR1PFqyX3isdg\nPFgDuTFAcZ3MQqaeyBYKu5VizrbGIFoICJuBSzr0zh7BZwml0VmGd11PaWgSXIqaMvgU7zPEGASP\nDtYQEKwN0bAgW8U8KRtER5Xn/H11szj4RUCEKCoRGGWYHnFvDg0c3uXgEvAJ4nsQjdBZWyUcn2Vo\n9gAAQVoilpSJimXNhahPUZdiggjaqyRiUfWUrCX0SaHaRxXK2v6Wcfi2hCdnfYctG3tfAbz4AdUb\naADFty4Wh9ThpShhqR+sDHgQTxG/YgUxogS2WKEEQ3rec+D3/zNnAs+QGmbr89iVJxkmp3X8FFpt\nMGI98dCjhOnr+Pl3/SrbJ3dz7uwX+blf/ATv/ol3sDSfMjE6Tm1kBHWesakxDux9JxdPPMOLX/sq\nPv+pPyIjJutktJKUTqfDpYtNjFh6WWdzGeSVgcqINw7ri2CmooXkjCPBEqsr1BIRVA2BKN5BaqFM\nQT4iKWIVidLHor4YARZAcwwBKR6DECt0jTKM0LNQ8wZRV+gNoqgaMlVCNQjFDiURpTEgExaPUaFP\nUZI0CLn35APlwgnUEYx6VAzqPTGbHF2wzoABJe+nfNe/+GGaBPzNv/+X/OQffJSk74tJenBfo4MY\nAeplwG+K96+rPIWwozhVrF69vgJmg+ys/yeDstWA/IhiDBhxhaIj3/AxB+SneGB98D73m+tzH9Dn\nUKNNHt6ZJX1KUUhohRe+6DacJpRKhnq9ytTYGKNDZbK8zdLSApOTMywtNakOweJai+OnTnH5/DzT\nOyZ52UtfxBNPvJInn3qC1mqHTC2tZsrl+WWa
a13q9SHEFASrXq8SxZZSOcKYkMBEeCMEPqVqStRq\no6ysLtJLQ0Q8vaxL5voMx7NMTY9SkjamvUbg15iYqFArjRNFAc0eROUa43YnF06fYXLkWg5ddzvT\nY+XnMNbnj0ZzBYCeW+Pmg29ifHqUi5evEKTLXA6e4QsPf4J+7Rl2T13P5PAwQVAC4NzcMXaOV7hm\n/508+MVPMjE+RDNt4hGiEgShZ3JPIQMunlMa84o1QliCTqaUfEiULHGbf5Zdo9Bzhmtqy3xHBOdn\nDMl5pddMONXp0FqD0AqBCJcXlItn2txzz+v4yt9/mCiK6Kx6vvqVvyPJ+pTHQAKInWVtydFrCfUx\ny75Dihl2m4tNOyHLC0UyzR1hLPT7HjGGzOWEQUhJYOTuAyyULTO1vYhXHv7KaZaWoQy4Ptx8zSRv\necMtXFpZ5ZGnLnHq5DKdNEcDQ6nX48hjj3LXGyZ425tv4YH7n2DbDMyMx8zumGLPvl3UhuuItUBA\nY3WVTt6kUi0jYgmkUFE6DpzLKQWeeiVkqA69nqOfeKypkmdKKQx56XU38ODcIlHFc+LkGdJz59l5\nXcjuneObzh1RED9Qhz2FeoIgqgghkRiILLkXkiwvFA0DgRqsCIENgRiVABFPaArCbG0MBGBLWBsR\nxRUqlQqluIQ1IR5BdKAwIFgbgIc0ySiFAWXr8ekyNh4iyJt41ymsB4AxIWFUxxoDBhRDYACjRVne\nFiV5C4TWEInFiAXzbZfub4AVQxxWyBOPNpdJFs4RVYcRnyBhCZ+2yddWsHEJzRJc0sS3FwuFd2IP\n8ezN2DAeBLZQZdahKhuKzoYyTkF8BMjxZP0W46OrmH4Dqi/H4tAzf4zMn4K5Iyyb21i7/vsYv/Zm\nvOsiNsT7HhJXGNr/AspOWZs7TZ5mpGmXqew0bmWFXCKy6nY68QShKLVsnlL/BOy49ZvG4dtGrfgK\nvlhqRfHI1YVmUDJxPi8UHwAj5GlOYAb+DBVyp0RxMPCGGNQpaopBEdiCtdJ0jP27XyC/8QDm8gWG\n82OYJKOUryK1axnbtkLcaSL9p3nbq47jc6iUY5ZPn8FqwK037cGnfW6++VaOPHuS8TBi547tBJHB\nKMzu3UkeRnjrGYprmKowLop3YyRpTpo5sjTZVAKpKqE3ZEXFjmxQ7xUVOuqJvEWMRxAi8QTe0hcP\nxmAQIq+0BELxOBVCoWD96lEpdifgMQq9QSmrHOQF01dP9+gxspUVZGYWDhzAu0GJEIOn8PBYLVS1\n8uA3W8ETyMArI2DXF/aByuG1KM81VQnFEKkpyNo/iu9cTXxESXspP/AvfoCP14f4ww/8Dj/w0z9F\nEILL3SCf1lUTCq8OAx6hV3cSqoJBCnKj6wrLYOIogl/UuI0O5GzdID9F7Tvf8Po8d2BuJLteLU9t\n7FqeY+rZ4EU8hwjJ1fc8X9TrQzz91BMcuuF6ev0ey6vzjI3VGarXuffez/L6e+6mUi7hvbK6usZa\no8nQ6CSNdovp7duIy2W8czz80P3cccchMrocPf4MkyN78XmfqclRtm2bIootqin9fhdjckKJQQKG\n62NMTswwMTFBGc/KuWdZXD7HyIGDmLhGJ2kXZU9jCn9IGNNpXSZtnSHpX6BWr+E1wkhMRJcgTGi1\nG4zP7OWGm17J1OxU4ZvSTdZsgA//xX/mgWf/ggfu/yCTU6/h8b++l+UluOuut3Lp/Ck6SUJ7ucvu\nWw7w+u94B48c/jgL7Tn+5jMf5od+5D187K9/n7MLsG9fm+FxQ6cF3RXP3mvHWFte4/Qxx/RUxAtu\n386Rx8+AgRt2T5Mvpbxu+lmu3wY2EIYi6BrDnYfgL08aemK4+ZZX8qWH78crxJHgUygZ4diRi6wu\nfpQkH5T9c08uDWZ3jtDvWlbO92jOdclSoT4D1eGM1hoE+eYUsMAYSnFIniWoETrtlNxBLRIyFbKs\nyzX7dpLb
lKynVGojjEyP8/k//SwrS2As3HFomte/9gXMLS6w1klx3hbWAwTJlcSUOPfsAg+MPMwb\nXnOA4UqT3Xu2MzY+ggmKbW8UlMjyEMFQCpq0Wx186jDiCQNLlufk3hNGESYKWEoSrCi1mrB9pk6/\nA40kZaQ8S7+xwu0vfhkHbj6EHLqfF7/oADt3jVGtxJvOndzlg/HoUF/E1mmOxxbj3hgiDHFoicLC\nH2MBGZSLAhthg0E5SwxBYMEpgYlQMQRhCDYiLkWUSiVCG2AlLFQgLcpDBsFYg68ayhoy/YJ9vODa\n2xkeGkZE6LSbHHvmCLEGJGmXS2eOcv7iV7jUPEcpqGNNVMzxRikZixlsGkIjVAIIjSl8jLI5slwf\n341UhvDtNbIrT5CtzuG2X4u4DrY0Dr0G+uzHcZlDsz6SpwQGgriCaoYfuwaxwaDUpYUCBhtE52qZ\ni431BEwxf0pAP89IfY14aQ6NToINIZ+gzxArZg/tsX2MbL+BwFqSfoswCPFZQlQuY4Iykc9ZQ3F5\nQmthjumz/wVT9YS7v4dydQjyLtpYhLlH6F1epPLCbx6H50ETbbF4qMHgUDUDT09haFVT+HnEC1Yg\ndxkiEEihEBkriPqNkkFRkjAgileDaIZWKux++R00m01GgkuMts5Sqm7D2F1MS5M0fYLZfoNr8y5f\n/sLDzEzv5szpE9TrVWan9vLWt1yPVob40iOHmZ7aQbfT5KnDj/CDP/5jLJw/zcWFJVpnztLpWTpJ\ntyAr1lIKDVEYUopDnFY3lUAGTwQIhnxAKFJyYmBIDT0csVocDhFD5j0usIj3GDGoSGEYVsEjrKgy\nbAoVB4FcHUYNmXiGVegJdNXgFhfIP/ghkjxHnaPvoHT9HsK3f29Rb9Yci5AqpJITYjEeHIoDQlWc\nEYz3eAqFJ9JC9lcp1KWSGLqqGHXk6nj+dtyr0HUJhPVBoCS9hDe//U28fjTi6cMP8Z7/+Jvs2r+N\nNEkHeQGqptB0tKjvquoG2VH1hGI2BBUd7KkMDMhOUc4rjMsUNV4pvE+GosT1jTVnHaiMAw1TrhKc\nq4NX0a99x9deYZMKz+kz5xgeG2f/tdfw+FcPc+OhA8xMjtPv9dh/zTWsNfoYmrS7XbqtNnfdfRfH\nnjnPTTfdzA23HKLdaZP3cy787d8S2piR4VF27Nhd7AQDx8R4hcwXSmwUlKhW6hitoP0AVBivzbK2\n0KXkVihXE2qjVcZnqlxsNulpRh4aQhNgel06yxc5JyNUa0M0F06g/Qad1giUKrhcMc6xcukErW6P\n4fHrsLUpKBnSDDrNzXu/Pv7xTyCPwVAAYeU8YSWkUnV86aG/Za0JWQZZnvHlIx/lqz97P9fetJvf\n+IN3cmo+Yc+OHVxoLFAJhPlzysgw2FBwCuN6CLhEnp8kjD3NtWWCCuyf2cX+eo07dh3lRQcFYws6\nG1chyQqPzo/eoPzxYcMDD3yO8+fakJsNP5ANISwZ5pdWmYpiUlX6TU+9MsPYbMrFI1OcuXiYzAnD\nu4XKpGekPMLEVInEb05Vjq3Q63nWs9U7Ty0OyX3hV8HAP/u+N+B9CZe3yExKFJf5jV98N1944GmO\nHZvjpht30W73WV3usdLq0u9DkubooMSce+Fsx7L21Hmi6jg3vfAGrFicFkZjIwHdtEscV5BgiEo1\nolaC3tpZ+naGNAvIXUZgQvrdhH6nX4w7AlaXMuaDJjcf3EY5KtHtL5GswUtetouRPQu8bqTCyTOH\nOf7kCp1eyo+8772bik8vyXBkOF94FXMF7xwejxpfzA9OcNZhwkLtxnvEFAu4RAYTmHUXHw5bPBdY\nwCJhAGKwYYgYwYjBBEGhcHiLkhHEQ2SuTtRL2F6rMDV5LYFVZofqzMyMElRqvOpFL6Lf6zN34ixy\n4+0snb6Fk1cu8cjcw5xaPUrVRogJiYwv5rhBNUFCQUkL3+MmbQZZr4HOP4
peOY6mKeHEIXy3gaQN\n/Og1mMBSCnKMJig9JPRgAzRQfG8B31tBStVBZNZ1cAZ8QJ6j8rBe59qoCYkx5FGZk36M8nLMVG2a\nqDaMuf5lrI4+zbGlTzM+Mcu2oTq9hWOUK6MYiYumJiBrLkF/jbzfxqUthpcex+/7GfTMfUg3REZ3\nIfVhVBZpZTs5ywrfXN95HoTHDwxI60YkL4pgcDJYG7xgMOigvpflWWGrMBBpwfEcgBGsFJ1e3ivG\ngjEenYObfvNXeLZssc0rjGbLBPVZxFSZTT7DLd15ri+1iKoZeZrzD088SlQqs23PDjLnkcoopxaX\n+OIXPsXr3/wm8m6P+z59L69+42s5ffYMv/Rz7yNXw8EbbmBqWIv6H4pzSq+nhUzsfbEwbgKig9gY\nT8ZARcAUXVAoBui5HG8C6t6j1hCp3yiv9L0SGKWvQmYgMkLmix2jH5QQ+5JvdFZlKLF3LP3hn+J6\nXTSwhC5BA0P/6ePob/8u0f/2E2RiUCtEaklRSs7Rl8K8XFHBqZIxUJoGC7Uf1KpzcuLB5BagJAPP\nSvcfofDocweFrAshStpL+Ld/+yV++i13sPKOE/zA//lrfOf3vpqknw2Ijd+YcLy6AdG5SnzYeH79\n6ut/+g0zMgPiU1h0Cq+OyNeTnfVy63M8Ol9z3UHef5Ny3rrCsz7nbDJ1cLmjPFzjy1/+Ctfu3025\nVKHd7tBpdRgbGaPT7bG0tMA1+3YT2QDvHRPTo3zxS5/n3KVT3Hb7bXQaObV4hLm5S+yd2cu+Hdfy\n5WMPk3tDt5viAsVGAbkPKZk6lWiYcslQKZfJc0feX2P/3mtJu/NM1odYWVlh394dPPLkk0S1MhJ6\nKiUDSY8r5x/D2hjJVnGuS9TrURubJAjL5LYGrk3W7xPFa/ikxdL8GuVSlfNnFzcXGKAc10j6bda8\ncOTICSqVKqENWWs0CGsG4+DcqQXOnF4ABw89/hWyDMJQeObIBYwIM3uFuKK0VwXfgWbDEFUqRK0I\nMXBpLqdca3LrtfvZPTzGC+yjXD8Dx04ZbjpY/J7n55TRmiAGdk7Aj73I8een+rQQghzSHExYdBj2\nuoV6gheiwNJp5nSbwq233MYn7v8oPhCicWHiWmivQUZKJ8totjdHeFabCQNfAF6Fahzi8gT1wtBQ\nwH/9wC+z3O+y0lwkCgNEoV4bZywe4cDORYZqNS5dXiRNMoIootFosLLSLtR7n1OqjzFUCwizJbQX\no3lKKJZSHBW+L4RqZZh+WnQ82nCIlXyVPLeElW0EpkSae4xxRYen9c8xVRs0M3S7Ofc9cIo9O4YJ\nI8c9d93KleUv8A9/uMr7/9tTvPGuISYm6uSlkU3nTidvopoSIoW3zmvRpEGGJwFRcuPJSHEmAesR\nmw9iqmTGk5kcFQM4jDUYLNYIDDapGC0IFBSlLAARvBVWGzGtY3PEi2eIXQbVCZbdYdLWMskL7iA7\neIDtt99OFEVU4oRU1jj55VQFs2QAACAASURBVL9jrH2JvX6IWw++jLnOTXxs7kusZBchiArTiBau\nGkIQCUCLEtdm4JorSFbGhzuQShXfb5M9fT9GM4Ltd3BV2c7BgPgUfI44Qbsr+OYVpD6DGDvoVJUN\nKwGqWJ8U4oV6im4wj/oMkydUrTJUqbHW7LDavsyul303Wd6lVKlR0i7TtsXS3DEWhqoMj80gYjZm\nXRtEOOniky42bSFGqKcPc/FizKXuNnY3ypRrYGo5WpplzS9x7PDfc+tb3/1N4/BtCY8vfme8L9p4\nPQPDlM8xYkAdCFjHhk+l3+mQeoil6OjKHARBUfbxGMQqIh6CAHcu4+DnP82xesyICYjqe0lXu8Sm\nynZ3nLf15qiWl0m7ln4zpbFcpdHs8NCXHmBqcpJeo0Ur7XPHnXfyzp9+Dzu2z5IuL/D+X34fTx75\nEq12AnmPPbNjrB35NEO1EIMQ2GLiq9
SUCkGhgMjmdqOlvGgUF2OoGMiNJ5CA7qBEFWAwRii5HLVC\nTXMCZ0hESMVhxGCdYHFEvigZOhESLNZA2WW4uTnkwhythUWSXpejXzkCwxF//W9+kGpVIO2RJwaN\nyvzeZ77E048+wfCLbiPNHIkUJEkHqkmCxyskAw+PRUlxoEJXIfAeNyAZXVVChZI6EjUkuvmdunMD\noXOd9Euhhimel7z4Jh5bcTx9/Bw/+NJX8+Ffr1KZup33vP8/cN2hHSRJMXA2Sm7qN/rE/LpreIOi\nFENjvTiw7uVZv/c3lKF0nejIxhVkcJmrnVkbEtLGhLYuK8lz61n6df9+nuj1e3Q6Fs1TWq02rXab\n0ydOMDszzYtuu43e6ip79uyhWqsTiGHhyiLV+jT79x4gKpc5/OUn6aae86uLTG3fRpamuCxlciik\nlQg5Buc8ISG7Jndy/Z59tFYXOXv6IsuLZ7j+xoPs37adxuIp4uoo8wtNsjSgjOP773k7TuCzD/5P\nzi/NUYoiRB2dZpOJyb0E/Yx3/fO388DDn+P0wjxXul3q0W5cu89q8winu01aFx4hc4YLC+c3Fxgg\nL7UJu5APSeF1cz1aaz16HYN2hE7bM1wxhALxOPSXhSzzxJOC60NzVUnWhNFxi/WOaEzZMwat5T4B\nNYIAvu+N38XN+3Zy5cjnuGfmK3R7hpf/QjDIokHXHwGFRl0sht9zC3zguz1fOCJ84GHIjeAyxQaC\njYqNXDHXGeKq4eHPnODYV0/gM6jvVmoTQtaHyohBXZ/5C0WJajNYbvSxoWViuE5jrcFa4hAtcvb9\n/+4nafWWWL24wsnTRzEBSO64eHaRobFdjI7XUe+ZmR7hvi8f5UMfOUkkMDoBr3r5frKlU8zsKLPr\n4G6i0o10ejnlMKTZ7pM7Q6lSR9WyttIlzxKsFaxpkxNjwjLRxBj4jEgzjA7KLap4dQXBzjJ6LWXu\nfIM0ddw+sYubrhniSqvBycujzOy9lu98fZ+1pfP00pxLzbVN507Dnyc0gkqZSALEGrxxOJ+Thl28\nelKXkkQdMkmxJQblbYv30DUZqTEERghNQF96VMKim1I1oUsbCHD0UZ9QokrrcomLT85R0TbToWM3\nKSMlpR4HYC6x2hXyuuLmvkQ8U8cff4DqgZu5+JUHWTv+FBPNx4hNj4NTMZ3mY1xjhJtGyizLa/hP\n6ccQbHG8CErfWqIB4dks9LO/St7LIHPYwgmJrdSQ0Qnc8nlMFILrQtZB8j74tNhcagdsj/zMg0h9\nFjMyWzScFPI54dnfhfF9yMithfv/8p9Aax5ZnkNO3wfZrQTZDPVkBLfjhey89U5UU6xR1PWozN7E\njruGmOmuIZ0V3KmnaF12lIIu1YlZxu21aBV61Vny2RuIW4v0/SF2lk+zyxtcEpGe/AiapHjnqV9+\njDtueOW3jMO379JC8N4PFqqrRiRvA3DZxtSQiWI9eHWkSZdKUJiWcw+xMawvVz6X4vwGUWQxo/pv\n3kv91huonDnMcHWaxlKPHbKfSvoUd/ROUdIFXN9A5ggkoJ1G/PhPvZc9+/fgfFG2sGGIcx7jcjpL\nl3nve3+G0UqdnJD99UuMBMKzpw9TrliUMYw6ZHCOjQDqPcGgTX4zcB5Co4gfeFA8IDmxgAssikN9\nwcht7iEwGIFIHUaEjuRYNdQEYnE4wGHQZ4+RPv4E882MaP8+7N49lEdHOfs3H+XX3/cK9s5sI3bL\nOFnGTrcJ21Vcf4R3ve0QP3eszjCGNVEyUQyWUJQEQ+ZzUl036xbkxqghGpQm84Hc6ynITuA9iAV1\n2H9El5ZHCpl9YBos+g900PboSdOEGw7s5iNffZCfesub0YVH+cV3/jg/9jO/xD1veWlR5lonJwOf\n1KDIBXA1H4ukZF2RKWx/xYuLImrhq9L1NsivPzuCr7HufEPHwdc5djauu35b0WJjsBl02l1GhuvE\ncc
zhRx+lXq9x3cHrSJI+88sLnH1mjhuuu5njz55kdGKIExcucv0NNdoJXDx9iiiKOX72AhKEENap\n12tE5RF2H7ieo898BckTwiAm0irHj82xNt/gxbddQ1wVoixipdFAGp7hkSEik1MZqXDxyipZFvDY\nE0+wY+dOXv+y15IlHe47/BCnzx6hXptk2+ReXNLk7z77P7npuv1Mb9/Pxz99H2vZKuNTkwRhwMW5\n83SabcJanTzvbS4wwE/++M/xB3/4K2R4cmdIV6C1Av0e1CODNwpeKY8LUR36DQgiIXWQdZXpkZjV\nXsLSoufADTtRGoxUpkiSDtY7jEJ7dZ5LcxnnloXHEnjpQeXzP+/JnMd5yHIIopR2D/ptxalQK8Nq\nSzh6uTDkRgMTPIA6oVZXSuKQdoQIuFRoXRIm99U49JIhmskFVpcE74XhSfCpsHJlcxuJajWiXolZ\nbjTJ8pw4CsjyHPEwXp3mr/72E4yNjlMf2c5otcriwmXGZoQ0beDzDmF1hHOXE7bt2s2v/PIBQqsE\nhZGOXrKXIIpQl7HW7BEGAUmeEocxzgUkvRyvKQKEIiS9HCIoiaeTOvpKsZF1joXFFu1OwtpahzxT\nriyscnlxgYmJCX7ou1/ILdeN8+y5S8xfgbWeo9HOudg5zvY948TXTnHsqeOIb246dzRs4E1EbhzG\nRoVq7wXnUyTuI5ojLiW0fYjSYv0QC1JYLrw6VARXyME4gYwOUSBYGxVdcNojo4NzjsbCMJeeuEgl\nbzJcC5jMl6j4HqIWYyNSb6lWa0SVKkGcMzpmyHtrpKvzxPTYNbSK2xFCmmPCnKltJfJ+inVXKHV7\n3FJ+MYf9I8VxHUCGR2y2aXUHICwHmLxoQ7KiYAriTdonn38GasME7QSbtUEc5Bmow/uUHHDJeUzj\nMmZoGjXF2ikKnbUxaq2HkPSv8OFUUQ6wh0h6e8niO1jrX2GVKo36KCbtEjbO4ae3k3tHqTRCEITY\n9ln8ic/ByinKy18GzYlG9sP+UTRswcTLqFR2sNvVySUib1+BZ/6BZG2IhQaUrzuIVKagOkW+6y5K\ncf1bxuHbEp7CwlCE12/sdq8uBiKC8x4Rj0rhsdA+aBRgBge5iR0ciIcShErPOUzmsTffwvZ//W4e\nvnyJvNFjNOnTi0doh10OXFxiv/skViFnBLSNy1M+d2qCd12zhyRNiy4T1eL8EECd473/x/v4xJ//\nJTPX7OGVr3w1xw6fpzLZRU0ZG4D6DO8FEY8JBBS6/ZxKlDNe3dxuy+BR5wgG8VEMOTmKEqrijSGx\nhq5kGAwVb8hNcbaERyhjCQQ6vS6BtbjAUDlxlJX5LpXXvoFwpEyZgPTkMW546Au85c2HaLYTHn9w\nnvryFVzY44XvGOapBxdoXVzGTozx1lv6fOpYRsUaeq02XF6k2WozOjFG/opXYG2IVXBe6SKE4jdM\nycYrPVGq3mOBjgh1FKuK3WzNBnDODGr3hUHaK1jVovtAipJXmqXs3TnJ73/6M/zEW7+HbOkZ/vy3\nfoYHP/cm3vsf/ncqNUuWuecwDDYIy7oqs/6HrNObge9G18/W0avdGsUPpYNS1HPKWTIgR1/vxXmO\n0nP1sec8PCg3bhadbp8giOn129x5552g8KlPfZoDB/bTT/rs2DXF8uoK5y/Os33PXroXlnj06FGW\nl5apVYdYbnep1GsYG9BYa+C8o1QqY8wIWX+U3CVkeZ9LjUXoBSxcWSLzCbXqMJgyi8srVMsB/X6H\ndqVDmmd0M4X8MlkOmnXZN/0SJmcnecmP/y+cv/JyauM7+dCHP8vlxTUunD3JZ+99lG3bdzEyuYsL\nc6dhWWh1OywuL/Cy6/ehYUxv9Vu3h34rvOtdv0AvXeb9v/PfQQt1OK4VBLfZzhkZEWZ2lulnPYKq\nYOLiR8jXPNVgiLd83z381Yc/wsIK+GyON7/+XUyO7OCzn/57otgz
NgUPPvEoyf6DaN/w0UuzGJ3n\n7tsEG0CWwF98Et70OthXVVotIXewsgwfeVC577xhxDrSQPA51OoKRhifhXxl0MZLUeLymULU4bGH\ne4zsFMQJnSYMjcHQtNJc29y4qoaWuYsNJCjOj7FisCYgcxmJgTe9+TvYs30/n3v4Ph4/fporl5e5\nMHeZkycvE8fwwlt2MDY+Sa06xNKVDqPjo9TCAFsqUSsb1Ct51qdmLTjIck+/Z1HfRYwnQummijoh\ncwlJlrG22qXR7NPrpnQ6fdqdBBNYAmMIAkOeeSYmh3jHm1/L/l01njkzz1/fe44sM1QrDfbvnEV8\nSu6EfupJMsetN+7n8Scf2XTuBHEHJMWbBCelossUwfkMTL84q8XnRGGCVYfgERMgmuM1x/lC0ROx\ng+khxImwEjmgwpgdp5QO0U1zut0KK4+fJEha1CsRU1mD2HcITEo3CcFnVOsxkzum8c5Buka94snj\nCr350/jFE/TP3EsUlchaDeKawfiY+vAklaFtlC9f4juWt0P1xZy0j+AG5zhZU3hlNwufJZCnxTEk\noSMMKUhs0iY98QBZaZTo0hU06GNNcYaaGgFicu3h9SJy4Wns+E5MfZJiRCp6w3dzZfUW8naPib03\nEZYr9HsN5o5+kQZXCCav5eLRwwSxoZakrD32WQjq2JldUE9Rn2KzJq57GRs64ijHKDgt4Vv7MXWH\n15shncStzWPCiFLSpil30vZN1moxQ/OfxI7tJm+W8GsJsrIIt739m+fItwuSDHbXbuBxKdaD9fY7\nC+QYLE6L8pXX4mwD53TdSYpb960gBFqYmd0VmPrv/56GUyZXV7iVES5sn2TbyiKV1UVmzTm6PRiO\nLC7PCHyKNyEPH23ybhGEAJ97vDrIE6Jyid/4tV/j3k99nG0HD1Ku1rlw4SwTew9x5fynyXpdnIXM\nThZSr0BAoWxUyyFpbjm3ukkTmMuxxiI+I1GD14Rs0IZv85xSEIJa+lZIrNBVj8+VTKBkhcALSo7k\nCaVulysf/TRr+w8wfst1GN8hfeIE2cQojcceozq5h6Tcpddz3GJDHm/02PfKCotPTSA7L3H+aJfZ\nkqF79jLLrReQ3fdFsnab0p23s+3mmzn35ccIKLoIKiokxpNp0aZatoZW5nACkVcyhViEsnq8Fv6k\nwG+uIwAgz4UgKBQdGRjWvS3M0tj1CnBxwvPUeI0P/sPf8zM//KMsnHoU8+Tf8MNvfJxf+p0PcNPt\ne0n6/QHRKNSWdZIySNKNsyHYOChQv0Z9MeulMZFBd8HX/daDt3+Nx0ef495Zf06fQ/af89wmxUHE\nhDxz/BSBVVrNFsNDNe6++9UsLy8ShhHepBx+6jCjozN86eHHMUEIeYczcxc4cOB6+s5RLlXZvXs3\nDz30UGHCDyMWFq4wPbWNCxfPoMaSOGF5eZkbD1zPkdMXuHjyq4yORezbPU6tZPC5I4yWmZicYHhs\nlNFymUq5ThhXWJw/y0gwjTYsUZrwDx/7KP8vZ+8dZdl11/l+9j755sqhq6pzUqtbUssKlmTLAmdj\nYeB5DMaMPfiRBh4LDO+9Ab9hmAFm8BrSY94aZswANsFmwLYMxsiSbQUrWJZaUrc651Q53Lr5xL33\n++PcasmMbSj2Wt19b1Wtvrf23Wef3/7+viGJBa5vURkdwhhwqgHPH32WydFJrly7ysLiAnfd9wAD\nw9tZbjZx/JHNTQzwo//qdl4+doKRWi4XdwOBFRi6DYFWBulCasWYTLC+TG47YKDbA0WLW/a9nuIH\nLd5497vZPrMDozL+34//MgduqnL27Coyy3eu4xcus3V4kkg5/MXxATxrnXtvk7gevG4//O6nBXfv\nEdx7i6Hdgi8e0zyyAL60kRasKUWaipz8auetVmU0ljRk/ZsTCjpxbjzaPg3FCpTKEEXQbYC0N7dw\nriw2GKzWiKKIJE7IRM6NsSwX
VutsHRtj+exLPPzki9T8KXZN7uOeg6MMDAxQqQZIS2J7Aiw4e+4q\nFy5c4ZVjZ2h2W0xNjDIzNYjn2iQxNBqt3KTSEiRJTJoYoiQ3e0tTRZr25dKWg1+wcH2PSq1EmioS\nlZEkGUOVgAfu2MJYYPH8hTleOKUIAod2J+WOm8aYW82oN0IsCxwLAkfQ7GXoks9v/O/v3vTake46\nQrho4ZBiIUW+y2utQYYYlSCMxgHc/iFJa4kWCdLk4pqN1rUlXTIdg0lIzA4OyrsQC+s8+cyzFLbs\nxbeaDNJGyAwRhfR0QtExFF3BQEmBH9CLMmr+Os2FWZrrHeTCNfzhSaLZC6y88AXE+gpBTWI7kvXL\nhrGxJuMFRXnyIEk0yER3jfvdaeLSzczFJ/D0q4kF/3AL+8eGjpK8G2GynC7gCRxLkURtTP0qWXgG\nWmtQdHGKReTwXopbd5OszZNdPwNRSHLpWeTQJO72OxB+BYTBIuXJP/4VxgYmqQ8WcbceoLzlFqrD\nM3jdRdJjn2VpNsXeVcQKhrAdl87CNdIMBkZmkE4JWR7HsjXC9rlRWvgp69PvpLDzzRRqY6jOItbO\nQVae+jNIt7E0OY2ZgPLFhxBtl8axCyStHjW/jeHbC5C+M2m5j+bbQqL6kQMYgcL0iW4SIzS5YW3u\nqSIVaFfcMGMymcSxc65IbDRCCG79zY/RnhxjrdOhuLBKeXwGGk3qi8tsd9aQgUCkgFAEbgctJTgp\no9WILM1Ik7zYSOKIgi155mtP8V9/7TeobZ1BJxGtdpOCHGPHbZMM7P6XBF5AtVLha5/7HaLYRWYZ\niZAIR+I5Nq5vYbubW0CpUaSZIRYCmcbESiMsC0saYgmBMYjMwSs4iNTQsjcobrlsFcD3Ijh1lIsv\nXmHqPW/PJ7vdQbc7rP3pX1M5dICBm3YzumWK+vxLOKU5zq1rDu+sUhsroNMDDCYhhw6HrC/C9ORe\nBv76ZQa2jjDnTqOPnOC6sal86AeJ4xRHQSYt0LmTc4LIlTbkbQJJruxJtCDVikBInH47arNDQB9N\nM32lVH7DwhJ9f6G8IjGAUppyweL3PvVn/J8f/gnOPPclZqY1H/ngj/CBf/0LvP/Hv480TZFCYVsO\nWhsc1yZNDGkGhYJApRlK6292R+4XPUaYV3k4N3g9r+4Y5pvhov4v8CrHx/RNgG6gQRttrddWPZsY\nb3rgLTz++FdyZCYokaSaa1evc/nyZaZnJnCdCfYfOMjCcp1Gd51SqULWSdi5YxeN9SblUoUgCGi1\nWoyMjFAqlbAsSafXQKksb606DuVikajU49iJV7j91juZmd5Bq34ZIzpo5SKMjecF1Ot1EpEwWNxG\nnKZIRzM4PESpWuPy3BXm13tcnVsnEj6NdoNDt95KL01ZmLvKwYMzXLqwzmq9zo7de9i3fx+9ToIl\nfBy5ebfcNTnLzE2DLPTquKEg7uVKz/IQ2B2BV4RrJzWDBUFlGuxhQa8OsRLEqebzT3yM2sgSf/uV\nJh/8336O46e/RLt9BD/YgyUL2G6uRLLchKXwCo4aJsPnvz5bxKbHPbcLDuyHoZcMv/Jpwd/vh8fO\naJYKkxza43L06BUC16VoNJkwRD1BUDCkTRgdEyzHIVqAyUBbkKR5dwAhiHrgFaCzImjUNV6wuXUz\nOjSYKzyRSMtGZYo07tBpw7aRrfhjN+EM383v/Zv3IbKIJGyhohYq6qDSECFtpO1j2TY3v+4gzn13\nYxXKnJ5d5LGvPcsv/8rv0OrBz//4YQqFKr3Y0Gm3AQijGJ3me560JYXAzrmBcZJ3iqUmShXNTsjk\noM9dt08hLMGTRy+ywx/B84somnS6KeOjBcaHq8wur9PuhJQCm0Y7xQiFJQ2drsbbMrjptSO8bs45\ngb5ytl+/2HlcjZAbvr79IXNzWGk27rM2oFHGIEmxjEHKUe5J3kp89RRnrixQ9kFmbfYWBd1mg0
Zi\nUQgcTKbo6Zh11aNclAz7BlkqcOX4S6wvZBjXxyzPY3k+jbMvMX9+jkrg4Nsab9DCH4E4lHTW6njV\nFcrVgOzcPGOp4B3BAZ4r1TiTPo1jzDftXf/Uobohdv9QqMlQrsk9i2KF7w0zcOhm7KmDJLWtCLvA\nijOI0C0q1CndZTDtM3RefJbo1FcQfhln6maEHdBaPseuXYcZL1e4+sgniJo9Dh94K36pwsKxT1Pr\nnWP3gbcy3+2AHMaqjCKGp0mNRsV1sB3wawiviOiu5GvJgNs8R+Ox3+Ern/5D7jy8F9lZwvQ6dBaX\nwC+x3E44t+5w181bcdoe4dxCvk/7NlbR/7bz8I/48PR7mTrJPXQwGCGRRuQ+vQZ0H8HRBjKdkiko\nSE3az99ykRiTS9sVAidVzN13F9nkAFsuX8HUaqz1eqS2y/Zaj/HVF9hmTmCLEaS1lt9ktCYTFY4s\nDrN0fYXayDDdboSlYsLU8Eu/+BGqU5PY2CiZYbmCuBfy9JPPcMcDb+KBg7dz5BtfJ6iUca0CUgq0\nFGilicK0f0rf3CLy0lx5pXRKIdPYRhNhKCjQgUfPaGKTYqe5LA9loWSO7LjS0HMU6omnaK2kTD34\nXWS9HlIKisogCwW2/Mj7OP3Jv8AcfZG7H3wr83PL3HawwM57xrGaTdaeT1n1ztNai7h5r4N7e5Xf\n/b1H6RmHju3g3L+d7Cd+lNLwAKYXYYQkkzkR1JAiTZ5UJYVmUCtmTY7guSZ3x3aNpGM0Ba2xrM2z\n5LTOUZw8a4y+bFbcKJIRos/Hyf8oJfBtwe//+R/z73/u3/DkZ/6Q6e1b+dtPfIwjzz7Hr/7WRykN\nVXj6ySN4QYnnvvRZXjl6EhvFvkP38+bvexsTkyMMT1TI4gwlNOKG8eXG6/TXdE7s4VWcKf/OPxz/\n0KdnA2HiNf+lRtywXf+njn/x/vfz2S/8DdOjQ8SJIgp7rK7WGR0bJ8sSlHZ55cQ5SpUqW6YmWV2r\nUy0PUKsNsG26RLvV4djJ41QqFUZHR1FK0el0KBdHqVTKKD3NxcunQCWUSj4Fr8jpU6colgrs3j6M\nbXzSXg8bl26vhzIpVatKohWVoocCrs0v0ut1STNFZtdox+BXfMZHpojamsxYqEiSxAn1+hp333M3\nUzNb6HZWqbeadNMM+c9YN1Qi0qbEdMFyQSbQXYckMphM0G0b4h7oQh7X0FmFyjAMFUF1JcZbYnQG\nRsRVwpPPMGw0TjBOHLtUqhXGuqCFz+pKzMKypsAqFX+YnnL5zUdS/h874Y5Dkh9+OzxwS97ierEL\nfsmmEXfwyqDCnOc2XAU/AETuoxWG0FzNDQ0zBE4ZxrbA8qymFwrKRYEnIFUGy4FCZXNTk2YGIVPC\nKCZJuszPZdzxulv5+Z/5BRjezer8LHFnjTTqoFUKKs1bE1qDTrEcH8v2+62cXFThWA7ThSI/+a43\n8tMfeC/PHDvPb/3H/8yxI4+z/eZJBgaLCMvC9Rz8okMBjeu69OKENM3IUpOvvzBmpObyzu/aycp6\nm0dfuorWLraUzIoO2weqWKGkMmhx295Rnjl+BUtUcLRNmqW4UhKqXP3k+y6feuQZ3vqvNzc/wt5A\nYflm2t0GgmvxzYcYDNLuG+YagyHDGJEfprMMY2CwdDPu+Yu0eutYImZyeoY33jTOlLrIQyu5J06n\n3SGw29iyTcuGwIHlhiS1UtbmMq6teoy6EaM6giyhvrhIN4TxSq4UdkKNHxhGhzVuySbttekmhl5n\njfKgwx67wWRwgL8LbC5ET5OpLLcv2cTIOinSAmHLnLzd0rmbepgw+qPvzaX3KsZlEXSG3RWsMkVk\nuVTqL1Fc+wxO7WbCCxdJTz+G8MvYo7vwihMUxiZI05jSbW/HG99Oc/kcvYaDt+MBGL2ZQjDEeH0Z\ntbrE8Pg47vQQwmiyNCbuNvBay9iZJlu7jkkBDbFxuN6d5O
43v4PJ4grm8lmyzhpjM8NknRX80KC3\nb6ex1qDWWsOzDcb1cgAm+mc6LW80EKS0+7C97MNpur9AcuRGk9uJ60zn5FSRW2wnJrf1TxEYmcdH\ndBdhKugxc2KOxu5hLl05jzu+nSScxW2G1EyI7VTQZhKVPY0lLWQRRFNTrRT5hZ//ae59wxsRls/0\n1BauXj7D/OwlbK+AbZkb4Y5b908zNjzCgG/zhU/+F6zsKsXiGMjcmA+dS8kLBQv0t5Yff6fhG0WS\nKPzMEJncusvNz14QpSg7j56wlMGWEqE1tslbKjWtCS5eor4QMXPfYawwJqgVUFqg1js4jsQp++QZ\nJxaPv3KeXTrhkU8t81PfW+Jyo8N0tUh4+iyu73DiRUitLktrLSp+wHqYUCi8jumxYVpRTCjAw2BS\ng23nn0dmTP65GWj2Q2GFyCNDUr3x+ecmke3Nd7RQOl8t9Hk8eR5M7h66YUtwA1MROUKojUFmEf/+\n93+T/zI8wOUXvky7F1K/9gwffOMbeceHfpKvP/znJN01JqZKRKtdikXFV77wLC989d9RHRrklsPv\n5V/81E8zMDZBpjKM+gcsmxtGi//gE/8WERH/qzlhDhtt8JIMYLTYdEvrngfewCOPfon3vPudGFUi\n04KgNEgUK3bvmSZJFPVGm1YnBiy279pNFMXML63Q7V6hUChw33338vjjj1OtVmg1WxQKBTKVIaTm\nzMnzTE1NkmYRwnMRSKngdQAAIABJREFUysITDo12h+PHr3P7bQdR9hq9qI2KEyxPU29EFIIWg0PD\nuI5HkkC9k7Bn2xaefP44pUoNLygSxSFGJGRhRH21zuzsPN1eC9sxrDfmSbOYdqNFmiYEQWFzEwMs\nzsakkcEuSlqzOWG5NgrFAUG3kfcoaxOGaMXQ6xqcIiQdycx+gy4bGmvw4vNwYM8pnOS32Tv9QXTH\nwx6vYYkAGQQ065pWG4JAIDJY665SK1RphAX+25ESgVdn24xgfFTz9XMCVcyRocC1GRkcZGm+haMF\nqgu91ODYArsKvWYeOSOkQCuFVxKMlgQTe33mlhPKIwbZEzQ7kMaC5urmFk7usJxRbzQZqdX4/F/9\nEXfc+1bCqy+zfv0MRmfkkX4uUkqQNloIjMoQlo3lBji2Q16m5/5gwpJE7RWiXh1r9hiHKwM89Nf/\njadfOsOv/9ovgHKx3Vz84DoWSkOr3SPLNJ0oodWK2b+tyu27xmi0Qz771AUcy0FrB9uCoYEiQcGi\nEJTZMRFg+yGLyz26XUm1ZJNkmmrRoxG3SU2GaztoLXn6yxc2vXYcJ79ElXlNdxte49L+qkjhBkpi\ncuRnQ5gjcRDGYzq4naocoyQG6YkQowS6OMi77zvErbUrRF2Pdr1DImx8K8pzJGWRRpbRrWu2DWes\nhbkasx3btJZibop6uCLi+oUrTI2UaKTguQo3VswuCjo9xZYZD0vGtNe7tLsxlYGISiGjWs54j7yH\nR/xBznb/FqE2h57qOEXZ3HCijsMYKRTO4E50888guDsnzKku6AauKDLV/DMIFXrNJlvOCM/+PXpo\nJ8y/QjayC1kZxysOUAsc3MEhyjM7kSO7kNImy1KUX8GpDFDsXmXLgd0IuRXjuWCPo+OIjjWFbc0R\nlCbR4x/FPf6nJE/9AcYGk6UU0xVKJUly/gW4chwRjGB0D6s6ggMMTO4imDtBsLCM5VXRlo1KCujk\n28eSfGfSshCARhiJkBmg+nEAffWW0RiVgTZkQpLEWX7R+7rvZJkzyyUiD027nlD90z/h7LHTiJvv\nYam1SMEusNxqoSKPc8kwQ/E0tdYsQ8WnMUGASRJUF1KtObTT4RsXG3zxb79AmkUsr6yAgqHBGtvG\nHVwElhAUA8nEQJeks8Arp5cwwmLL6Hh+YxJgdB/a7HvevDaT6Z86MqVxMoPKFJZloQU42hALg6sV\ntjJExhA7NkKluNIiFh
oPwZotEMcuEtx1C36U4Rc8mt0eKjNkWUIhk6RRwp4P/RClwKf72FeIZpcJ\ne12+cXmBqe17eaqTMnzrfVy8cBm1fJmCL5FJQiwkeA6rn/oK1kuXsd/2FsKxSYpK0bPyqIhQ6xtB\nrsoYEgSZMGRoXG36ztl5Blpk8vjQzQ6t8skWhtwyXcAGFGj6Vslig3Dcd+oWxuQ+T0nIz/67/5tf\n+7kWpeWzZNhMjCU8/tDHGSpJ/JKPJwqMDVqMTli8boeg4xRpxxmt7vP89kef48CB+7n7Le9iYsdB\nHM/Ks1voFzH5y95gH4uNtX7jKNhf/2ygQK+9Mb36TPc3zs2q9rMk45ZDt/LEE0/zkZ/9GU6fPE7R\ndShXKszONykGNtu2b6PRajM6Oo5j27SSNq7jMjo6yurqKp/4xCd55zvfyaWLFzl0yy19T5U1Go0m\n+/bvo15fpFwpkXUalAtDqJJN0G0T9gzPPn+Bg3t34tkB651L1EoFRoeHCdwhOi2NMS08z2Vxfpa1\nlSVavQzXC3BtiygCzw1orLdoNlq4jodTqxInHeIMXMfBcV1sx/kWJo//+OisCnptgbQhDA06gfq8\nQFoGxxOUa+B2IHMFB+4aZ60dMr/QwPLgzQdvZdJvInDBk7hDNb701MOceK7BvrsnmBwPWLwaEiX5\ntT+9TZI1LRYXMyLRZGx8hp5b4i/PV/n+6DIN2+F/nE45sGMcOy5RERDrFraTobIerhBkXUNmeUSJ\nxvZTLAdUokGAp2B1yVAKNHYm6K2AE4AbwHDJoNPNzU+jlbC4tMyPfeBD/Ntf+ijJ+hrrJx5Da4W0\n7L6wJMtVgypXkTlOhuVoRLxKoTBD3DiNEQFaZ0i7gjEthF1BygIaQa+9Rnz0S9y1ZTd/9VeP8tM/\n85O0WsuowKcb5n5YaZqSZBqtFd//pm1cma/zwqU6aWZwLIcs1Xi+Talg5cWO67DcaWBjs3e0zNHz\nq/iujSUkcZZgeQ6tKMRzHHzL4eVjs1zcvMAPzzH9+JwN1Cb/ujGv8uxumL+zUQDljxSmz/lLsPQo\nIQuUdRWtV4n8rTRq8JbtAzjNc/gjhri1QsXrMbeksXxBQdio0iAF2WFs1CVNG1RLNm4BGh3N6Q5I\nNyDNNK4LgZtSbxl2TtuEPU2SZVxbLSKCGiOlGRJfYHiGxfklpm4+hHAshlWbN3gHWNUniZKzm5ob\nk6YYLVEahMlwxncR7Hs92aWvIA98Hk0PGV3FNE/C+iyMv4/MWWT9xS/iXj9K9/hRro48yLi1SjFs\noZbOk219HTg2vRe+SHjxBcy2w9j3fwhnYBSkRdzrUGuepLjNQieLYFqYuIBwJunEMyRekUL9G+jk\nq9BZQKRTCHcMEy/hCdjzpu+maWzkvgdhfB+6tUK0uoxxB2DHOOO77iQoZljN55C1SYzl4/iC4Pbv\n+bbz8J0LHmMwIgOdx68LQGjQUuOXBpC2lds+FwqUysOs1xd564+XuHz6NNfPXMBzQFgWWhhMK8H9\nyM+xf3yMRqOIIGW85RAWfewsIUsbeF5KgyorYoZEncUjw0gDKUjL5oGbI5675JBKG+3aVIoDZDoF\n47FtOMNVGUoYjMlor10FaaG0wfbykDEpZE6iJnd6DqOE8QGLMJGEyeYq5ixRyEwjpUDofpQBOR8m\ncyxQCtuAZRRS57dJS1gkZBSSjDDLKDdbqNEBVqIQlRmMylGgOMtIJZh2mzOf+xxycZmG5+EJyVfP\nrDGqr1MlI5m9ws6b9/Lpl1t0Fq/TyTKqWIRxSiYtrpw4h33kJCO/8n/RGxrGpIpMbtzCNYnJ84CM\nzvO/BNAzCmNstNH4ZoOevnktkk5zObq28jgIITdg4w3bcXK0pN8qvcGrMQKlBVkv5Sc++h/49L/9\nMTwr5eJywq23Hebi2WPoTLGy3qZQ9Lh8vs2JbsrerSnDAylpy6dUtIlWniBefJEXL+9i
ZOaD7Dpw\nAKVTbkRFvLZFhcgX9gaCeeNv0a9/XhuTkf9eygAm9+/Qm9SlW5ZDqjRTM1v5k7/4FP/z03/Jb33s\nYwhhY6SgGFRYWFgkyTJmZxeYnJ6hVCrTarfoLvbYuXMH3/M9D3Lt6nXuuPP1NJtNLOkwNDSEZUsW\nF+fxA49aYSTP47HLONUJbpk+zOj4JF999CGunHqUHVNDOH4FIQVFx2Xfnr2UKxUWFhbIsgTHdUlx\niLMUO3DodNps27qdRqvN6VPnuHz5GsVimTvu3kuSdrAdH2kFFAo23W63H2y5udHrAZmhF0KvK1CJ\noTIEcUuQNgwiBa0gszVLSy1WOiE9BWOF2/nQD3+C9NIjJPWLICyuX5tj3wCs7TAEwsZxPMIIpAth\nU4BSpEg8T7LW1ExNeNRGtlMb9HlB7GNsyzbCr/wBWTugHcWsNyNW1tvU10NsoagU8mTrdickTsHx\noejljrNK5TfQoS2QxhmTg9DpgePmr48Um+bGnbywzH//T7/CBz74QbrzV0iTFCwHow1SCiwrw/fK\n6OYCOouxHJdw/iJZ6xxptMb1eQvLalC86UdQV57A8osUqhmy4COkhV3ZifRLGG+I3spVvF6LP/nj\nT/JjP/bDNNsdkixBG0GYJuyaKrNlyOMzX7/KgO8hpEZluWrRcWw8z8L1JFHYIwptfOmwf6bAi2fn\niRObMOzijldIlaEXxQQFhyxRZKnhsUfPMr3Jdh/QD6Kmz8nZINrRL4Jy2kKeGvDNhxgL+ganuZDF\nkR08tZtF2qBn6NR73DY1wsVLl7il0qG1pvA8i/nVNVpxkfEtk5hSmZHpCfbsGWE8OkfcgG59mV5B\noAdTzl+BzPFYWl6lZRy62jA6WCBOeyy1BRM7xpjcNoQztJ/y5G6GHJvLJ8/Raq5gUo3utWg0Esq+\n4qbyAY6bzRU8QoCWLsJkWLbP4Ht+CREkXDv1PNlvvpmB3btYfPk47pZpossdvC1dkh3j1FcdVjtb\naOz7MGNyleraVVKnDO0lVK+FKJdJpY/nDiDmzqG7bUxtGAs7b58qjWm/AIP7MMEbQJYxS5+kHDSo\nLH8aetcguIf0YpfemWMksU2t0C9Qr75EeYeDs/Yy4vJjqI6FSC0yAjSSNG2jrp/AbzUg6iCsDPEv\n/wNXJz7E9m8zD/8IadnCcjwsv4Bl20jHRdoBQuYOo0oZLNtgjI3l2BgrYPuBW9m+/yDd+jyPP/S3\n9NbrWI6kU4fvfeANLI/4TB8+yKlLzzHij7FmegSRwTQ0LR1ytaE5UFyiEdyMbh/FCyxcC3ptRRq7\n2E4Byyhild9MszQj8C2yLMv9BSBf6NJCaY3KNEVPoXUMlp8ncQsNRiBtm2uLMYWCYXJ4c/C7m8Yo\nI0kUOHa/1BECtEJbAqToE7mhb7mIURrHZKRCEhy6mfWXj7L4+AqOJVG+Q+C4pFmGrTWJ0YSz8xSi\nkFBAN4xYz1Le9LpJHn34MSaqRSYmJri+uEwS9QjbhmLRIjIGmaUIyyXKMizLRx1/BfmmBzAiT8i2\n+rEWcX9T8IVFYhSpyYP0BBotDD0DDiZXL2xyZFnO2REyRwQ3bHAQ9FugG06lpu/WudHeygvTVBnC\nRPLge99P68ojeMevc3G5iR8Uub7eZMv4KLbtstBcZX5hlfl6xMjwIAe3ZsyM23QGNX//yDIFZ457\nSzVsdqM2+Dv0i5h+i+qbacyvjcl9zaXQf/u5mRoYI1Gmb/u+Sf6X6L9mqlOCUpn3/8RPMnPodn7+\nwx9msuYSxxHVWhUSSbmco2tKKTzPY3FxkcWlRbZt3cnevfuJo5TZa3O0Oy3uvOsw7U6T5eVF9uw9\nkKuH0gwZjLLrwAPcd8/d1EbG6CRdrp39OstLTYwlKRZdjNbML15kb20fh28/yJHnX2bP3gM8f/Q4\nStiQZfi+R7FYZL3Z4cyZs3nUgIyAjCTtYrseWhmM
0Hiet+nIDcg5O61WTiwWicEXEC4ZhDQoJVmd\nM9hFw46REodqw5yNZ+kIzYNv/F7W1pc5fXWZ1bPn6HRSRjzDwNgAu70hCpUhWj1JsQrdDpTKmqFS\nlZW0R1zMSJsgbIvKwADj4zV2bp0GLSl4cPTCedprCVkMrg0lActdgWNpCoFF4OREf50ZjJNv1sWC\nQCY5iFgI8rRxT8BA2abTzWi3wAo2Nze/9os/wwc+9KP05q5iTJ7ZpEzennIs6Jx8iujq4+jqvTBy\nE5W9h2icPMmVI+fIsNn97l+mc/U4dTFO1a1RX1tmfTHE0isE1SLSPUZxfA/FqduwStNkUQv7+kv8\n5m9/nMN33M2u3ZNMSof7RrcxIQpUWhX27xrjCwtnud41eDZ4jkPgS1xH0g0THCFx0HRFxHqniU4E\nb78FVr0DfP35q1R8lzjOd4RyyeP5Z69QB0r/DHQQ+keVDSSf/sM+eruRS3Uj8qD/E/pGi9uA8UnN\nJJncT5LuIWx6jJQv4Do25aBIu7WEkDXmG10SEVAZHKGy43b2TdXYNVVloKBw2w0un7iKiDv4wmXn\npM/1+RaZyhgdrWENDjKfKK6sddk6CIdukgxuGSAqbqU2Mkh1KKXXXOf2H/pxXvjM/0BpRdJaoxiM\ncPS5Jzn0tndzwRnY1LzYgY9wy7ieRGiDHJ4iOv5JvMvPc+n9f4Aj26ze/iYqQxPsf/9+Lj31MEl3\nhfmFOWrbZph79PMM3f92ms0BSvE6Wa+JCdvINIFSBXd8DB1FmEIZ2Td9dHyfpB3SPdXCd34V6/6P\nw8zbEVMfhe482ZUX6Xz1YUT0MMvug8Qjh2FCM7DwKZCQnX4GefUpKFWQoobRMVavBypEyITwyBdw\nTIhTBKMzxPhNOPWn2Dr/G/C26FvPw3eapMrELhAWpm8qqPsOmkrlC9QiIe1JMhWy2G2jVYbvFkl0\nLmfutOrYlovuJoz91L/i+dt20Dp/EddewOt4rDcWsJXCK3qoUkotajG0c4zSskXLG4G1KlVWcQJB\nOQiZiGKUDHGdIrKb4EhJikT6RYSIEX0FmURio0lU3kuP05D6WkZQUvi2hbSdPKXWkjiVApnSXF7c\n3OYcpBlaWAjbRihNKvLWEFLgacAYMtfCxZBohTRWTqTNEkyY4Vkpwe0HybTBihVpEqKiFFuCyBTR\n8ROES2u0kph9B/ezMH+dJEz58pe/gS0Miczw0x4qVJjlLhV/lEZ0nYGKRy9NsUyGznJEY/Ghv6d8\n9BSlD/wgemgYpXLzKokgNgIpMpQ2RCL33QnIQ1dtk/d8N6+1AbTGaOubTlk33JA3kBXz6mbERkcJ\ngZAaW8BAwdBwXKanRllZj5ldPotrZ0yOj6BUwlq9zdrqKlt33kaatEiiFpdWSiw0W2zbNcjgTg+v\nVKMVfoNTL/0eO2/9yDd77tyQqfeLnX6OlvhWRU//54zue++aPFhT69ynZXMj32hty8IYUFnGXa87\nzItHX+K+Ow8i8ZhfWKA2OESpVMSReV5NGEbs3buXVrNFt9ulWqkxNzfH/pv2Y1mSufmr1Go13vKW\nd3Lq9DkqpTLVwhaGxiYhXOUbX/8b3vq2H+DQ/l18sVil2+rQazYZG55mdWmWOGkwM1GlW60wvXcn\nS9euMLu4xvTW3bQ7bcbGqrTDJuv1OnEnwyn4GFsQ9lKCUgkLBaIHAnyvhONsUvpIzv0qlQVpYtAF\nQRaD0ZKxkTKzi02QhqoD991xF3dNDXKosY+19TrdI0/yd0efYFV5iK4mnb3Cjnv28vjpOm6xQtK8\nxPTEPgrSZz2JKBgfFXvYOsOyNDJWpL1cYJAmESpLKZWqWI5NEiYMSZBlQRoU8KIeaxGEiaHTy7Ad\nKAQiL/KxyLTGlYIggGjJgAuRkxOf49Ah7mV4HmzWl/Ejv/iLRMsLCGn1UWWJlHmsz/LTf0Msh7k+\nO4JYuETvSkzj
r/8/guoIyEkmRwdpnfkcl06eRdhfQqWGQiDAKjE4cA+tlWsErKLCOdJWncrWW/CG\ndhOtXWFsy8381r/7Bc7+1eeZqY6zbes4A4Uqg9v2Mmy3eN+e+/ijv/kC//nhc+wZH8LolCwFWwjS\nzJCZlFt3D3H8yjpGSNZnl9l9/wFOHvcxKgUchIKlpQb7DuzGosajRzaHYABkNwxOBVq/yt/JL13R\nV1puKJXoY819Dmr/2k/ZjWveh+pN015dob2+yl27BinGK4Rao2xJz5R54VIDhrZz7+vuYseWCpOD\nNq6ruXTxPC++coEdtmS0OorMGjhezJ4xGBktM7F9F3fcsofmygLe6ir7bt/OxctzfPGS4mr9KFnl\nOt99+CA/8MZJrMymsvUw0CDTPl6yRmCF2OvL3Dby7ds232pI18UZnMAtOajZa4DAHt6NdAVnH/9z\nWhfPc/ewiwjKnPt8D1tnuNsOMrZjmqKruefgTlZVytmVlNd5EabXQoctpEoRnQai18OoFOnkHDGT\npug0Za2yjWyiyGB4mNKf/AFseZjVWEC7QeUNP4X80I+ynIa89MRDTLsxb0g/BYffgTn9MIQuQmXY\npRFEUEZdv4ZJwtwwMnCxlYWRZaCb7+txFxPuwJp+w7edh+9Y8Cil0LIfCKAUSsUkKgEtsbDJTIaQ\ngsBzcvtuA5mO0QkkUYpSYNsJyQqYH/hhrp66zMDYMIXVOq3nn6NbX2f3zt0sqAoeLv7gGM20x7ND\n7+N10UOI0s2MFFdQUuN7p9m3bZnvvm+KuXrCsRcVjgvDgcuQ3cIjyZewEBij0AJUljdyHdsmCAQq\ng44BoTI82+DaHgiDYwmsTdJUYqVQtsZKNTG5u7ISEq0N6AzLd5ECAq2xMEQ6JRFw2LL44JjPb11u\n04xjHKVQRuBKlyyQyGaT+a8+QdhoglHce//refEbz9OKEzyjGSrtppu1cRPN0lpCktmURY2esbHj\nFB1H9HpdLNclMhrbQCgswguXWfrV/8S29z6IfMN9dDODtASOMYQaemi0FkgDWuRp9hKdy+s3f1DP\nc1iMBN3nA/drH0vQl6ObG6ewG6rwjX+FQGlJ1MvwRYLwxrn11iJTNc2Xn7vOK6cug2XRajYolkvY\nepEgsNGOjVEWvqXorHTZPV4kmA6JQolSx/ICx+jXbILmNZSdPrW6//VXc7Ty75o+61EDRufeU8YI\nMi3I1P+KCH3n0T9p9osrWwBGY9kO7/3wz/LIX/whJjE01tuMjxXwXA/X97l69Spg8DyPKIyYnCgy\nNDTE4uIiK6srKBX3nawlYyPDNNbW2LZjF+uNOllzlZVWnYUrywhtsXV6lFeOngdj0+2kzGwdZbQc\nEM/NI2sTfPXRr/H6227iwfsOc/L0NbaUBjl88HZOnj9OY20ZEfUQvoNSBt/3cGzwPAGOQqsckZFy\n8wtHCLA86CUQNnKbC6E0aWpIk3yeE2144shXycx3cebIGW6fHqKTpgwffD0lt0w5miMYzzgy2+bC\nYkQWNxgdG2DHjA/aI21E4Fc5fzahnXZRRuIVAQmWJTD91rdlS8YmanSv1RkvQAdBW0osKXCFuUH5\n0opcdZKB5UOWQS9SZF0B6wJRgKVIMzPuEXmGLDWIusAvb3J+uiFGi5wDp/OLR4cr0FPUzSS92Uu4\ncp1uYoi76+y77RCeyGhEJa5fOcVINWR0xwHOXr5G2OhyaM/dRAunabVnCXsJJi4zXtmBPH8SkXwG\nmd6HM3ov6ewr3HfPW2h9/LfZdmgCqS7QbldJzi1SPHgvekXw4Xc/wK6hcX790aNU3IAkTlBakCSa\n8SGf5UYXg8C1JWeuF9i2tEqaxJQKHpnKrU1KxRrf/Z57ePBtCSfev/mCR+s8KHZjbACMegPq2biW\n+y3oPDSk304XoMQYMvt+mvM11q4/h05DbGMQ7EZHHaSJmdy+m+qgZL2xQnu9wd
riHHvGbXqRZGR0\nKxcsnwvL63xtrcehAcE+qdkWZ+x+4yFGx7dglQc5eP+bWTr2GAud62TC5rm5aZ5YiFmw93H39BQf\n/aNH0Mk93HdoC2FjGW/3BCunTiEHq4zWDKXAZoqJTc2NkA6qt0q8vIrIgO4K9thBKm96HzujEpZM\nKF09TppYNNbaJGmX0k4XvbBIsm0rztgUMtbMlA1JMyHrNpHddYzKsNOQrNNGWR46ClHC4HgBYWOB\noHmFhZcf50RngJ1v+Vlm9uwlWl7FlRFCtFk68kWS8y+zerqLtWOa6O7vovjGj2Aq47T/+k+oDXpI\nx8JkKabdRsQJOAaRGhwRkWmvT0+Atr+N+aVB7BPXuOnAt56H71jw9Lr5EUTK3DlTaYHjBEg03V4D\nYVwsxyWJIpRROLKA7fkkWQdL2rn3wkKX0q//Kpx6EefAbSRikLWVF3AnK4gdW1mtjmOldRKnSN21\n8KMYENSTae4c6SCT5xFxj9TxcYI13rUj4eSUh1+WrNdjJp0E68oaGRZCaqTRaCmwpI0mRTo+pYKL\n9G2KMsiBBZG7iLaj8MZze5NGPMKyMJaLSGNsITHSJlEppSRD1cooAcVMEaHxDDjCopXEvCNd4/n5\niLeOjqGilE8vtigbTSo1thac++znsaTAVgmu4/Lc40/jSShlGUJKWuFVbHsIpRwCK8hzjoRAZ4to\nI2i2uuzesZX1ZpdkrU4mExzHxjGCrusz+5m/o3zyJNUf+QA4AWE/CM8yeTEitSYVeR/cRpCiwGye\ntGzSGGwbaUmsvvmFeA2a8mr8XD7/0myQh02fSA6Vik1aGsBaP8n64nkwEfftEvzQ972PqwsNPvfw\nSzx39DSq2aboOYzUylQCzWK9y+XZFkurHm9aHWTnXZKhcA4hHIyJ87wtuMHH2XB+FjdQH/qPX/sL\n5eoPra28paUFmcq9hTK1yZbWN8VX5NJ9aeWb9fe89wPs3jLG7/7HX8O1BWtrdQrlMkppdmzfwdFj\nR7npppsYGh6g22vxwpFvMDo6wvTUFo6feIW9e/dSKhVRmUPSabC2NMfI2BitqEdzaZVaMEypVKbs\nOVT8ErHrsbraoTk4wNhIAenB0tIczU6LY6dOMz46ysCQC67F+UtXWJpf5O0P3INvNA9/6cuUaiOE\nYRcjJX6xgGVLOmEX2wpwXWfT68aWhjgUdNfB1nkx2MUwe70BrkD6IFwYnhylUA3YP1piS8VhascA\nYXKVARFQ8VOOz2uOnFul2+pgqRi3ZpGaXJ0yUoV2c52BygjadHH8hFiRy3SVzmNmTB4nEDhVipUW\nFTRxYuh0UwpuHoks+kskVeBoQeAaMpXlPj9CYFsgHEOswZOQtRO6DRdTNFQKgmSTXZs07iGtnDcm\nbINqXKI1u8jKya+TJSGrS/NMHPxeLj79EKWiJFyrs5YVsMuS0ZEywt/B6uICywsNDuyusbpcp7HQ\n4ObDtxKEXZrtmJ5dozp8mGuXTjAjrzNUWSBrDzG+9zDBFlhuPoe0PJzAp1IscPHs03RaW9m27xD3\n37qDM5eWeOjCfM7b0yBtuGl7lZcvtvEcC2nZrKuA1fML+H7E4ECVRichTntUCxW2Dgck1YBnH//Y\npteOUrzazuoXORtHkRvP+09Mvw1txIZ4AlT2LnoXLdqN4xQDSbMbgjAMyC6tXpOdkx7Dwz6mHBAm\nMdovMzQ4yJlr67znB9/OL//iz/CpE8P8+v/xLr7xl4+wtn6Nka0F9kxWmb7z9QzecivNtsYtlqnO\nniHesoVj9XGccc354+cIpjuMDipuKrVojW9naWEe2ZunYFfp9K4zOFzCKwV49Ng5sjlTTy0t6IWk\na03sLKL5d7+Pv+sAJo4xjQ614TLt0xGmu0B5YIjWehdneY7Re99Fae/tIBKqZ55AzFlkiUKHPUy3\nicySPFEhU/SEJAiJAAAgAElEQVRMgj71WW
zWsLIYu9PCUw5JOkBj7hrJX/4q1nTIlkIFtMRpXmGH\nv4XErNPZsQt7tMILz1zggfJ/Ry9epdFwCdsxU2IRHaek6zFEClMtYOIMkTVRsYMZyA/U1fknWev0\n8D/wu992Hr5zlpZtIYRifWmRoFrFtl1MpomUITYuUguSTpuy52NZFhCjMg+BIEsjuhe7+B/8IeSe\n7ezcGXLBrjDQOoesJUR6iO5KwuLLJ6nesYNi6zqqPogMBI51jOHiKtnKw6RWjVKhh+7l3Jg9OxfZ\nGS1xKNhJkkQ8e9nlutZYroVlF3BtF2FbONKmqyMCE+EFFgjTb3fli9+2BKXAw6AR2iZJN9mXUAmB\nEES2QAuJkgYvTBEqw+2FYDvEvS5CWmS1MqmOcdebLBQ1w2WfYthk/uIcpWKNIAvpuFWssE3JsomT\nKE9h7nYwSiGQZErjC4vMRGTxdaRXZC2RqEzTUV2kkGgybj10kOXVFfbt3kYYxdTX18kMJFbuhI3n\nIy4v0Pr4J/B/7AP4ToGkD/dGUuGR53slymD1k2+9TQar5kMjdD7jG+0iR/AtUulzHGVDFa5Nnpdl\nWxoiaNcvwOLL+MUhyhWHlc4q2cIpJoMBfvIH7+T+O7fyub97jvn1DgvtHt1Usme0gLAh1hkdE9Or\nD5N6irjbxSu4mD4nKX9fGyq9jS+Y12RviRvvUZN7CGlD3008z19LVJ+vtAmE57XqpRsu0QBGMzZQ\nZtuD7+FrX32CY19/lpnhKZIsY2CwQL1e55ZbbiHLMq5fv0YcR9x77+s58sIRtmyZ4M4776TVajE+\nPs75c7NYwqbdi/E7ITM7d1CpjmG7ASvLi6yu1vHcImmcIV2XlWaL6y/Nc2hrFSVmaRjB2RPnuXjp\nEaamtvLWt72D7tIcabeNJxO+5+33cdst+/jT//l52q2QSnX0BkKI0IRRF9/bfEurFwviFgQFMMpg\nI5Dd3My0WoFmy7B31wQ/8ObvZ/3cy8xMDTMyVcWUXMJeQoTHsy9f5tG/eZFGDAcOVHNFaSdDZIpK\n2Sa1YLxQZMuuQa5eUowO2iym64SdPoPLSJRRGKWRlo2wNZ6QuErjeg5CaoTMeTNOfh8hTaFakGgh\n6WYZRSc3JNQasHPzRCUNFcch1TZZpLA2OT25QjaPelk/8RWuP/bnsO2d9NQg7ZWzVAYqnPna5xgZ\nn0TYMNdIEMUAa3WZbldTli+RWePccuhmWqnEUz3Gbr6XuYvH8bhMbdt3kamUE09/iVvf/cN06v8/\nZ28eJdl113l+7n1r7JGR+1pZlbWoVJKqtC+2vBtsyY1t3HQDNk1j3MeHbgaY0wxwmOmZ6WYaOG7A\nNGObZjA0jE23zWbkRbZkydZi7btUmyprzco9M/blrffe+eNlSR7G1jh9z4mT8UecjIhf3Pfub/ku\nTzIsLFSaIoWhEYDTnMIvDlMrj9Ffa3Mpusxm5xVWe3W0GOKd8zm+eNxk/lM2HDkwxMnLLYbLRaJE\nYTsuXtUmb0ZR/Q0a4gzVyggDEfH1b1/k5jevMDZVpq4Me3a5d7QB9Ov2r5A90ca8ZrJqdm42gu8a\ndWEI4vcSnRlG9U/j+z71zU1ELMmPTbDdbDMxXKboK8qjk7SDbQpDZVha4/KJp/ix99zMf/rPn+KR\n1j6YsJgKV3nwi7/D7e/4MFv9AaXJa8jPHEDuvZlSfZvW+ja5yf0U6lt4lzu8cnqdw3M1Olsn+cs/\nf4af+Gfvpt3osnj2O6Qba5SHrke7Qygnj3FDamUYBLuzbdH9AFkoQW0M1arT+85X6T3yRWy/yOxP\n/w72vv2Ub3g3aX0d27IYFym600QtPcXm9kUsT+Md/zKxKuEkKWrQh04dmSTIzhI6uEB5apTC5n1Y\n9hDCzWP5Lp3mgNagzB13XEe1fo7i4DGsdDPDzMUWzPhIIxmenmB1c52JtEH6yN+iBJRsC8cvEW71\nQBi05y
E8H2GDqNYQgxC62xj1+uHRjiy+8asf4Zf//sL3jMMbJjxxKon6TRrnzjF9w63EcUwoBI7U\n5Gwfg8LPVbGEDSSoJMEQY4yiW9/G/x//FaW7fpTKTIH12GGgYyxlI3PDJJstTNomf2QPnqVISzW6\nqk2+H1CSKQV7mWa/wuR4Sj0eZdTdotsbQvZOU5mC/eMvEsi38PLLF7GQWGTtaOnYSMfHIEhNmKlD\ni8wAFTRGitc3vbkyXNBYuyxGTaIQJsK1LAa2Raw0ji2xE0Pa6eMAkSWwRETahg+Pl7mYWJwZJPxI\nGlKyBJfDPvuHRjmtLXKuYe3L3yKOQtI0wRLgSkkqMtq/bSRCKWxhKLgOYTzAtWx6aZphqmyBVoJT\nJ08TdgecPXUBmfOQxuAFEdrXRColNik/ddfb2CgU+Zu//AIzH/sorlJoBL4WJCJFqaz3osla54Mf\nAsUjRTZevILbscTODUZmkvRXKOlc6awACLCE2El6BKEeoPqrpLJC1bVINlYolT1MEKB7PUItmHZc\nfudXf5R6I+azX3qO4xfXOWUMw2XJ4SGXtB3SSvsot4SrXgJu3RleZeBpcWUMypXPw+sl4ZVOjCYT\nTDTitWQnVWSPVBAnP8ToZuf9r7yrRmChqdgp/Vjzz/75T/H0g4/QbndpDDpEYUSz2WRrc5Ojx47h\nODZxFLG5sc6e+Tls2yYIAkZGRqjX6/j5Es1Gi36rx/nLazSimLxfwc0lOH4ObXsY4aB1H6VTHK/E\n1MRV5Kw+/cEAmVqsLHVwrYMcO3IzUSsg7gxYmJukubXGcG2IajnHDceu5cVXz5MkguHcMMaOaCXN\nHdDq9wYOvtFKNWBBbxWcSpYkjI4KggBUCvnUkFMhX/i/P8PcyD7qjZDiUokf/cB7CUyfsdkF5gYW\nb5tb5MlmxLXXHaBeH5BXGmFJaqrG6kqPQs3FWA6FUomiJalqi36SIowmCgMGvYBBboAvi7g6ByKg\nmPeYNIIkMJRsSLXJZDHygkZkSLTAaEWSAJYh7IOxwaoICq6mYBxcz6M9KBCFA0YquyuyrgjBxq3L\nLD54P72whlx8Cuw866ur+Dmb0b0LJCZPb2sJrzLJ1EjM9pagvlanMjqCwFC/+DzV8YOE0YDS4CJR\nZZq0lSDW7iOSh9l/90dpbZwnP30jRsVYuSobl1Z4/Ntw6DaXWj9modRiz8ICpeG3ceHc01yIezzy\nyNNcf2Sa66o5Xmh0mR3K0+8nRDEgQmzp4loJe2uCiX3v52Na88DLD/LAc/9A0vF58dWQH//IZ/n8\nf/lZrj68ew0nKUDvGN6/htMz2fX7mj7xFUwhWfLjCQjkncTnryburiJtcIyCVFCqTXDV3jmOHfAZ\nTi7R2NrAr+a5vNxkZrrC4OFzLFwzwV9+/q95ODnCdfZJlpehsT2gYqfU5qeZu2GK8ZveTv7QHaB9\nhDdCdX4YMzFJv9GmVn+ekYlZ1je22epa3P3TH2B0uMjcpS+TS236fWitrSDcHNLLM+R4BJ11gsDb\nVWx0vw+OjfBymHwB02ogUx+VJOiXHkaGF0iPP4qDZLB4HPwKic6TFErEty6QD1ZwREIUbmPiBvmZ\nG7D37KNcjnA+/G/ReOjkMk6aolaeQ595AdKUIdFjf2UM0biATJvEA4NfAlnwMW4eFXRQXpVCocL0\nXT/G/GO/i94GIwXlvEKWbZzhEazqUGaNISW600E1Grgqxq/ufMHUsN2F50OHH7n16u8bhzdMeBIl\neeSFUxStHPOunQGVjSY1V4hQLipJSFSMZVmkKUgN30lTbv3dzzBFxGwqONvu4yd19sbnUUGRSHWg\nVGV4zwHieAOrnRBffAU3X8WYRc5onyX3doy+meL5hJu8l7ipWGFibgknVyZc6mBJTUk+SWO7ghHO\naywpYQxohRQWYRTjOmTdDyOwjMyEuXhtBM4VCvJugSomCglTi8SyCR1N
1bIIhUAXfJwoAZ0ipUFa\nFuOu5pvdLuX1DT5+ZA5bxTTOXWY671MbtHj39Uf597/9SQZrW3ixxlUJ2s7Em2b3zLG5tU0xL9GD\nkDAMCJUi79mQxJAqfCFIg5A0CrFFiTSN8T2H3qBH3rIwtg39FNu1IYz5/Fe+yc/91I9zsFzj0p/+\nBUP/8sPYBmJtSI2hL0EqQAosDf0fgpZuS4Ets79yJwHPJls71RVX8DM6S0D+UYNESI1tF9lq1Vh+\n7u8w3QG10SpLK21KPriFEkYKNgfQ6sX0ttt85C0zpG+bp98NeOjUJpd6im6qOftgwKnTOX5l4Wns\n8u0YozCv1YH/WKpdvMYYM5AZKposwdEaVCqIlCFNIYqzRxjtjmCsDHgpxI6k2e8ylishJKS2RT3p\nMJQrc+DYtUwePUB95Qxjfolo0MGkAfNzk6RRn+bWgOHhEXSiKBXz5BwP7Rva7TZJktJoNEjCkH3z\nC9SqIwwPjeAUC/T6fS5dusTtb343SRLw8P1/yztuOcyxQzOsrV+ktZHSCWK8ks8db7oDszOWUZZD\nN0k5u3iSa/ZPMT52kIvrLXqDkMZGg6Cf0Kz3uPW2mynO7CW1vGw0xL27is2QDZNXCVrjhvMvQaiz\nxEIIQbKhmVxwONtp8pMf+ggzE2Mcf+DL1NcTKoUJDD0GPcWhQ8fYd+PzbL60zeT4MLblE9X75F2H\nQm2CBbvC8HCRdqiwtMdDL17k3MUeR2/YSxDHFDwHx7NwvTzF/DAbXc2mVWTfiMNMrcj2Vp21Toe3\n3nGMkZEa93/jW1R8QW9g8F2LnKuo7DF4ozt7vwidULKZBORFnzAApGB9bVehwS35hBsrLH7pj2k0\ntgmay1Rm9tJeXiL25zB0yG+fRtVuYO+hcbYWT7He6lPcczMT5Xdg+xadtQ3KBw6w3gyppa8SL/wE\npeYFxq6q0TIfQkQBa8vLjDh9umf+Gu/NnyQtzfCuq+/GHyqzfHaZiWKBk+t5xr0t3nLgLKNj41xT\nqNI5dDXzwzF3zZ9jUBNsbPUJtcF2HIqOw5sOu/hiLzPJQc4++SjzpSpHkyIffOevY2yHh99Z5zf+\nrz/m93/v85xYUUSD39pVfK5YybCTzBidFShkUM5sFMjOdS0yinqYvBN9Zj/Eq0hRpFwepZzzUQG4\nVkxOal69948ZP3Ado9UytuhwcNZh/dFHudR3+cuvneHdR6v8TDnl1dUyb56x+MoLA77woV8iVyrx\n1o/+T6yU9uAEZfLNHhYpTi6HLI6x9+f+HWOvPEbh/vvI6yV8dxy1+g1yBclg9HraJx5js2bx0APH\n6WjFob1j2LkRTi1H9Lee4/vAVL73SmN0q4mwHYTrQLmGHsRYpTKD40/ibh3nzIsXCbuSkdECwgyQ\npOQOl+DCSarvuIny+z+G39vAtYs40WlSd4FEtRHlYZK0i601gRbIwije0Q+TnFrGefS/Mh3EDPa+\nE9WPiSfvgPpxnLiD7YRQFjjuJP7gO6z86V/REjA2LNEp5Etg2S1EuwmtC1kNasBS4JZATN2CGJ5F\nFnMM3ENUcnu4y6tRrF/8vmF4w4Tn9PIaKw1F2dcMIoHnu6QiRqbZ4RXFKSBxXAedRCitefr226kf\nu4HPNxvcMTpG1G9Q8x2kI5GDIVaaAqdXoOuApMeI0wMZ4lVsilZCYvZhRyuo0GNGGqo5l/vD2yh3\nTzHWOQt2jLQg2dZE1oC5SpGtrofteDiOh7BchMiApEprfEsijIUkfs09FyEwOx4qemd0IfXuDi1j\nFLES2AgskdJEIVPwXUnfKLQlscmcydvdPnnP5iQeL251uJ6QXqqJLcnV1yzwta/cw+alNSqWRbKj\nmGVUSqpS+kGfiWqZpfUNbJ0ySBNywiJIU3I71HstM/f2SqFAEsX4lkQajXFdwjRBJgm+5SBSQ2o0\n9Vab3/j0nzNWqVGcnSI6/yqD2QPk
pSDJGmIMMOR0dtOQ/DAdjOxOo3XWTpbsjKuuAAm5gp/5R//b\nZE7Qtuvw7P0P8l//4OfJF/dxbP8oX7rvFNfvqfL8GpypbxKZrBsURPCWA1XeU8kRWDma9Q63XbuX\nR46vQKPJUMWlVArZ3nqZ6f2ZyCImcx3WO+1tswNU3vkI2fDTmIx6vqO3o5Qg0QaVQpxmPkmJEpx5\n8I/hrb/wA8dG6xSwUUA/AVXUGJ2glEfULdMrwsMPv8SrJ5eYGR2lHySUS4bp6Wl6vR7FYoGcX8T3\nc1y6dIn5+Xna3TZhGlAuVZid3UOqIiozE7iWTRgF9Ad9RssVUIbZqWk8z+fM4iq2TNnaWqNw/T5u\nue4gi6eWOLnWxMrl2Ko3mZqexPd9wiBgdW0NN24g0z7loQq2V8VCZGaWShMmfTApxWKVqb1XUSwV\nd71vPEcQhRKkj13pE28KVAC5IrgjAruUEsYQDhrUWyF+wWK4onClZmPtLKFwOJektJqaXrnIIy8t\n02oPOKov4JzcQJk5jt1yAwuHFnjupee49vA4q+0By5u91zS1cBziKEKpFCkziu3Kcou4Z2FtdBnx\nFEkI9z/xIrP7LcrXQHMd5CATdpMSyAt6NkQ9g92HKBSkMfTtjNhgUoF4Y1GQ/89qfPv3Kd/2caLm\nGXJCERbHCKIcY0dvoP/st7FHb2Q9quJ1OqysnUBO3MKQ38aevI7LLzzH2HW3Ut1eJFk9TtXdwyDN\nY05+lcge5/KZbQo8QTWXY+zYh9Brhvz0j5M7cCf/9n03sQxMdTvInKQhHTpaEAiL/UmO7aUt5mZG\n6XZWOLmR4qPZ2ghRAoQyXLOvwp23TaN6ecpbB0hb2wxXR0iTgFKlSqMXUK253DkxwROf+gP++W/+\nDrHZ2vXewXw30eCKJMjr4GUhswIGI0BnhteqPYlKBqTdAVbOwYq6XGo2iWLBtQf28fYjQ6w0Jxke\nqlKuuCSNbeIkoTS9h8bgEjU81jYD9uRDDszOM7T6Mm05ihqbY+/Bw7w6GGKmJElVQhxnXeNOewOb\nlPL4JO7cYeZv2KB/VjJ+9V5U7zaCpZfotjykshgrwWOXDWPDORqtFMfuEzcaeN7uzithFGYwyGLk\nF8CxcWf24e85TFJfRHXWsPMF/F4/SxCThDiMEGfOMPWBA0yOPIDefoqcewCTDoElcPULWeEaXcCJ\nVwALY01CotDNLTrdGjlrGBO24ZVvUc41cT2NrM0jhsYhUZh+G93oknQjXEtmU4VEv2bWnHVWMh03\nYQwMQN55N9a7P4rsn2dLz9Gz9oE3xhWVNz0/9X3j8IaX3OWtHls9j7WVbT74To1OJEZKhO2Tpj0c\nNLY29PsRl2p5Noam0dccQ585x+G8AW04mwQEm2vULJ+x4XGumtqkeRG8EvjOOkanmEEfqx7heW16\nushY2qQbJNyyNyAtTNJtzfBQ6Vpe9D/CmCdZyD3KvugfELGL5crM38u2sYyNFJm+eKzSzFlWmgyz\nITLdlAzQlh1uWhiElmgyT7DdrDRJcW0QUuIkWYJlJRoRS1LHRUsw0uBJQ6Il7V5IHsN9mx2OTleY\nHDNYro9tGaqFHDNDRcJGj8SojBEXK7SEPTNVFo+vIKIArSW+lMg0JdUJxraJEo1jQSIdkl4Xr1JA\nRIJUG3SSYiuDb0t6aQpCY9kuZaXwpCAKWgzOhHgHZrGn9tElo1nLK8MeIwgMGd14l0sblXn47AAE\ntRFIYzLdmiuHgsj0i4R4nS8hRNYVW1nu8Vu/+HNMjs2xMYj5zonLaCT3nm3jy8zpO29nv12lBK82\nY8586zxBP2C64jN2oc51V+X5zpah3xFMJJJcNSLuDbCLNhqzo7ydJWT/uMVkTAaezkZZoJVAKdCp\nIFEQx5kuzNKpZzj/yB8CP3jCg8q+rzKCQayIpSQxHk8/9SwXF89RLtV44N6/p+JGFAvDtLohhXye\n
OM6MGuv1bVQqKRSKuK5Lo97g4FUHqbc3WF9fR2nFcG0M33HpdjocO7qAa9scPLTAM889S783AJ0Q\nxSFhqki0YKvVJgwVTqmGdkskQH5H/VkCYTBgfX2TPSMFYqXpDVoUrBy2kJRKJYy06PRbDHodgm6H\nSKVMTc3set8kscLzcniew2XZhzzggArBKRmGhj1mpmfYP1pi49JzvPmq/USDkGe+9UXuf/o0Y5Mz\n9HoRqdEMegPCKKUfKe6ch7jdZmmrybce/yof+8h72O4kNERIvd4mGoAwEqUStNZEsUUUpaSpZmJq\nlE6rS9RXOLbByltUp8CMCwYoooYkaho8JxupG6Pobhr0lsBYYOkd/EgiSFMwbmagu1vW/j2f+3M+\n4FSYfse/4dTX/htJq42xU84/fR9y7DYs1WRjpcNQrkfu4O10l89Q3ruP849/i1BJGg/+J7zaAdIg\nojq/n0LvHGlkqM4Oc3n1FLU7f5re6nlaX/kEV7/rwxz+6X/Nv/upt/LYCRgCtJBoLQnSBDcSrMUt\nvn0q4vDUGBsbAQf3zLF48hlEAZSRCEvz1qPzvP0de+g0+zTPlRgbdUhyeYTSkK8SJRHFkVHiKMB3\nHdz2gL/9X36Vu37rU7veO2rHU0IbdjA7vNbhybo7r6dDUmYgZ9MJSGKP1HLRSJrNLXxpuOm6G9g/\nVsYfXMDJl5C6i449wnrC6JtvRz75ImUuEWNIVcJGo8ncbIGLxUM0mwFvf9e72Hvt9cSWT7tZ5zsv\nLiLXN1lbXOLprYvYrsfP//xP8OZ/8n4qt7yH6sEjNE48R2niOrQZwuo9zca6Zv+UwEkN1aIiiDRh\nEkOkqI6O7So2IhxAojGpRiuN8HKk9UXCc2u410xibIta1WXQ7FEgpVfMkXgpZqRCdbKHrr4PE0eZ\nloJugQ5BtTA44C5ght+DbHwTufwU4vTzXD5zMxt9gzdzI7PJ/bgiQAculhNiz1+Ndcu/xLS2iL7x\nh+ilRYSSSMdCpwZSoJAxZjPz6Z0vEYH14Y/DwijbXcVy/GbKJkZWx8mlAU74CsOeANsHrv+ecfj/\nSXjajFZdTjag3okZqqRYWMioQ+Dk6A8V2GRAq9tmNRgwfctRLpx8mYlkwP5r3s5jG6tM2DAxPs2K\n7ZHomD2xRd9PiFJFIQyJVhdpnVrmVrvJ5FRCJ19l/7TPSriPlXyBtWXYjAXtuE04NkS/cB0v5w5j\nT/wEP3vik3ReeBop8xmt0NIIkYCRJCrdqeI1WhowFkJkFZzZAbIZMoCs1Jn41K5WqgFNJBWpzjLS\nkhEEOsVC4iqZMXu0IUcKUYQWDi3P4/mVNY4OVdBWyvmTZ+n0+7S2W+SkRRLECATWvjnmRcCv/Yfr\nue/zEZ/+bIO8B3nbRqSCPYf2sXRujR979w1sNXtYtsUjz76MbQSJ1sg0IY0ShBT0Y0UMGMfFN5ok\njFG2g200iW2Il5ZISfFSC09kpnkamz4a/4d0S9daobRGao3RMrOUAJQAx4jXyq5sxCUztgSZYmy9\nEfGvfvROhmoWo3mLV7cTyh5M1Fwa6yG1SpESikArtsMES1g4UuE4grgrGQjJ2lZILm+QOUNPpJTz\nmt66gzNtY383VgfDjvDCzn4QO8KCGTXemAyrk6jshhqpDKCaKkl9fZ3Oo7+Ftnfp7OxoOomhHRg+\n8+lP8fP/4hcII83f/M2f0mu+gCfHGGw3KFViIrdLaHrUmx7FkscgCigWygiRSfwfPXqUl19+JTPC\nHQTMz82yuHgOG5vcxCQjw0MsX77ArbfcxOkTL+IJxVpzkzBR+J7L2PAoqeVxcWmd/bPj4JZYrbfI\nlUdIkxhJSnOzQ6cfEDRazB6Zp15fYnW9Q3GQp1wdYs+BQ6xvLNMZJAzCgInxMbZXN2lu1He9b3Ie\nWCLH1PQUcWR4/qkWri2J+oa8C2kSMTkxQ9+a5HIzz3p9k34vYP
H8OuEg5OLps0yNF7NRbtolUTAu\nNc0eLLgB2hg2Vuv0samNjbH3yI2cePwpzp+to7Vi0GohR8ewhCGOB5RLBYwReK5LHMbk8jZewaab\nJnQagqCXuaF7CXgy86ZzXYtOX6EUKCmwjMG2BSSgyJR9owSsXerwfHPRpfAnn+TW932QYz/3azz3\nd1+ic/4b5GrXUe/2seIWc5UOWtkU4gvYe68l7BxneOYGWmuLWOPvZfPkE1SqVeKVR9iK9zE/U2Ht\n9Fkie4RT//33KNfgjl/8LCPjNr9+95tYaudZTcB2JK4tQVqkxkaFCu3BxY0G3ShkIpG06yuEKUTN\nNpfiEDNooG46QqvbZenVAdtLgnwEJRVg5crQ71EcH6e5tsbsaJWo34NSEc9x+bNf+uiu946+wkk3\nO7Yvhtc0tTJM6w4dfedlaeKDKaJIMF4VIaBTX2dkZp4bD01xeFKw/kqXf7jn6/zyL/8s3e3L2KPT\nBOcv0i/to8tL7CuDEpLtdofEatCPJDPXvI22P02vvIcL2wmNe77GB9a2WXAkwrN5/9gNfKO4wf/+\ne5/jnltvpDA8gZm9mpHaCJv3/gWpFuioz+KaJBCCZ1ck+2ZiCnbMq6/WOXrr9UTfzb//QVaUZMWm\nMugoRloOg/UtBq9cZMQOyd20B3d1lWB+hP4Q6EqOkuzjmg6d4w1Gy88j974JcgvoMET4eYxdRRbG\nMdJDNM6S9K6l/qVPs3IGBrfeweliwJ3eOYwGy5PZzd+AOn0v9h0fR+x/B5b/52S1eKaFtNkWTI6B\ncCUkKiMZaZDGIKZGYLLNpvk1vPoXcCZuRtWPs9D4MygvgGdjVBvSJoK7vmcY3jjhWe0yNjRMs5Nw\n9sIyt918mDhVrEzUWCx26G5eRJsiHVlkbH6B5voZRi2f/B1v5vETJ4mqQ3ihQyPqMuttErsFctPT\nyMZlvI0u7bjNTLmAWyvyUDrL6LjLSHGU5YmD9HSCCtpECxJPKPJens6gR3Ntk/LQML4Q/MOBf8HU\nI4vIjkQisYS1MxIBlWRHmJDyCmotG/Ps0CWviONqdsCpu8x3nDQhQWEbgycMgRF4EmyTkgqJZ0Js\np4BCIozEVilBoYCx4BnlU9ncotPr0IlC7nv8ZVKTaVHYdgZfHXvPu2n88edZfKCBiBW1fA7Hzmi1\nBdthYjIt6GYAACAASURBVKjC7FumsPwiRvWIXYuhYpH2dh3XEiQGfNsiQiO1IQfEaUpqFEJa5IUg\nUGCR0Du9yHCSYEjpWYKitnBNjAY8LX4ot3RHgjI6G2kps8PCykZbqZTYIhspGpEdCgKd0eIth//y\nid/FTddI3TLLgWL/TJ6KNFQwFDyLZr9PatmEShMFKeVRCEJN3s+z/5oSUhiW1i36no3EYTSJufc7\nXY7dkuDbSVaV8HrCpXc0nwEUWWWodOaRpXQGBk9TQ5xmyU6SSNrtAVtf/h+od9rsVsSpnzqsN/uc\nO3+O1Yuv8Jef/SMOLdzEobkZwrFJUi159sLXWau3qBVy5HMOG1t1ltd7XHvNETAuUdDHdV1efvll\nSqUSzWYTS1hsb2wxPjrCnj0zDPoDRiamsKXgwrlzjNdmmF7Yy8bqZcJUkYYxjkrwHBfHctna2GCl\nt0K+mMezbYqjI4T9Di4WjWaTO+84hm+nOK5DP5I8/8RzjIxPMDI2QTdsM5U/xMTkAivLFzBxRDDY\nHZMEwC3CIAjodOuQZGBBrQzYhorl0L1oOCUu8sD5h3nv++6mJBL+/mv3E+QcFuZHGDQ0h+dH2GgO\n8CwbJxGUoz6r1UlWpq8n1zgPwKMPP8P+A3vIlVfotjYpFCFVKX6hgDYaIRN8L09tZJRzp8+RhmAL\nUKlGKU3ShlaSYRbBYMlsr0jb4Fk2SZgQx1krPgFkYjBpZiuhBwKRGuQuiRL3X9BcW/EQX/0Sk8/d\nx20f+QRrzXdw+t6/oiAFne
YrFOZuoSAS4vl3s/3tz2JETHX4UXLePgaDAdWJafqpi2XlGJsaotvb\nIg41tckprvnIn7J3/xT3fe7TfPm/30thaj+5GmxfPMuwpdE60+qKkhjP9QgTTU4YmnGAvX2Rrq3Z\nrrcZ1Cb4sQ++jVrV45Xj59k7pznx1DaOiWgXPQb9Pk5JMztUJNpapmJZbC+uUMhJBisGd+Igldru\naNeQ6fC8htO5csu6orUldgo3IzL7odRCImkni2g9TtgNKeZ9irUx9o2PMF5K6V46gyf6/Mhd78ak\nASdOn+POkVFytVHOnnuS/cMOqdJs9DPSjLY7+CMHaHcH6G5CsRVz7oVXuXpli/baGVb2TlKbmmPs\nxgl+tnSMs2c/y7NPPsNb735/tjEsm9L+61h/9MtIKRjJK/5h0WWsnI3tg8I083Mt/FKJoN3eVWxM\nqjHavJb0GEeTeh7r9oDa5hZWMEFccCgUWgyJBiKwcZ0CnqN4brnMy7//ADM3nmN0epzSgSPEWuGr\nlK0L6+SLeR74wqdgDfZ94NdZv/0gy2ef4uD6k1xtv0zLdbE9gwmybo3ahuhPfw45fSuqvgE2iETg\nakMnMCgtELHKOj2ZkmR2Xpf3Yq0vEuk/wKtez9zyb1POtzHJWegeATQ4M5B2vm8c3lhp2QEhFG7R\nYW6sQhrH3HfkAJYOKLeWsUb2kotKuEMjRKM1csE2udmriTZXKF9cxBEOswfK9GXAJT1GX/rI7SUq\n3ijbfh8xPEscr+MvFCheSClYw0i3ihm00VGC8Cu4JsCyLMwgopAfxQ7bmBiSOKabq7J87EeQX30w\nO7jMjuYLECqFtUOKltrOuhZyh5klrng37bBkXgMw/+DLMopE22SuwwZPQ2wZaDfRkzn89gBd8dBS\nEBtw0xRjUpzU4vT6Ft7ls4z4Hue3tthY20ZiSGwbEypq73wL0dIKd/7YLHuui3nwsQ79VDFkW4hU\noYTg0SePc9WRa+lG6wxXKlw8tY5lO6Ayt3NhgUxS7J3uiWVLhOMgkhhDZhQqlMETFmG9l10IwpCL\nNb5RBMLgIkmNQprd93gcaZBkuiZIwxWTPmlEVuIid4TpNJY2SGmBBS++sskLX/sLqkMVkJIpH47W\nfB671GFDGw6OOSy2E1r9iPnhEsXpKomBrhcwOyOxS5rO2Zi3z7nkJMxd5eEZH3oWjz28wdF3NWkz\nzut0jR3cDrym/5PumOzpnWRH6Qyrk6SGNBG0OyGX7vlVijmJGEhcmewqNr3I4rFHnuLEiw9RciLO\nnXqAmg9x3OCaG34abbk8+/DLCLPKUM4l59usXlxndmaK0dEJBBarUYwUkjAMcRyHWq2GZfsopcjl\ncvieT87P3LEf/84j3H33Xfiux7e//SC33X4LfSV54vEnqAtwSQmDPq24RW54mgP793H69BK1oWE8\nz6XT7mAYMDxWJG63md0zzdrWBpWhIvXmNpZvoZHs2X+Effvv5KY3WTzy7c8Rx10WF3eHzK1vSVyr\nT7vdAwOelERdAzHM33Q1Dz34Em9959WoRpsbj95OtQj3fOZ+IhKuPjTFuRPbTE4M0RwkuJYkTm1s\nFaBdHys/RjF3kclheO6pk5x/6iRf/buvc/N1wzg2pEmCTkIGvS7NhmZj9UlefeVVdBdsJ7NJkdow\nCDQiBVdCrLJ6ySjQQqDILCa0AtKMpSU0mNjsUBV3nrPDKNrF0pbhm+cVhQM+kdKsf+IXmdx/kNs+\n9K/ZXGux8liftdVXKM4u0H7uHipzB2hu1+lFCTW3xeh4nuZWDjc/SxJ02F5ZZ25vhbf+wi8TD7o8\ncc/n+NQvfRWKkB+bIxo0Od3oEwDW2CzNtcu0o4RRoJAzGFtipKLqODS6LUxlil/99Me59bYFUh2R\nYvGT8k62N7f4q9/9Qxb25ekGAbLZpBr26MY9hotgYk2cKApjMwzadVaPP09pdoH9uwtPxqZU
8P8u\nZjLm5+uvETsjLoXRhlLVwtZ5mknI/MwsRVdw/TULiM5l1peXqBVTLp8/w6G5G3jkxUvcdNMRTBpy\nYSPGLxRZW2qRc7Ou+/ioh7Q1tSGf3GCJUN3I2ovPcY2VMPn2d2I7CXHOpbUW4qUBdx09hi9ACYGl\nFCoCx/cpz+9n6enHCaTNk+uSD14Vorwxygjk2DiW5yCt3bHY0iSTWdip/TMAou9guRCbNjqoMz7q\n4vs17PwoxnIQtoUkZXyqwiPWLcyEW3zn3/01N70NXtkssWaN8sCT5/knb5vn2C89xqbV4cT5Vygt\nH4flJdq1Gfqdl7FkAo5N0pfoBDRDmCCi//dfxxqxEKPjiCTEarQZy8lMpTvMzmmVZONHIWBgKuS2\nR5g/pDE1jakdxXjXQbQNugtip4KQ359V/MY6PJ5EypBCr0neLxDYHiWtaW1uEvbyDI6/QP3wIYod\niSwXCMfnGSQh/soSQ8025VtvousGdLe7FOwmaaKhMsFaf4WJqUmceIArSgz0gPFrR9GDPgWriwn6\naLdCRIyWFkpprLzPoBswoVs03DxV38Hqr6COvQv1d19GF8qZd5UWGAlJojBGZzo4OzNduZPtqx3t\nE/FdipuvlwQ/2EqUwhIJWjqkYYSSEhkrhFb4sSJJYxKlM3dcKUiNySRYtUHX6zz5wiJV12K73Uco\nhcaQpgphwfgt13PxTz7L0h6Xb3854tyZPnk789PyBBgpKdiGU6fP8Bu/+DNcWG9y4tlnieIU25HE\nJktmIp1iSRuBRKXQNhGlK6OlOCDFAQm2ZSGUIqcNoRTYqSJvQWg0Agh/CJaW1BpjK5ROSVMBxs2A\ny7bBUmJn1r5jHioNFhpbuHzxj/4DJT/BGMmUSHniQp/lXsSt83mMK+mGkmKYUHBT3nl0jm4qiSyX\nJ04ts7GpuHXCYnPGMOgZKrWUbzy8TqIMKRZ33DLCsw/dx5G7f544ir7rtxfoHUNQbeROpWgy+nlq\nUArSGOJEECvB8a/+R4qDS3ScHP1BSKW8O6fD+77+JU6++Ay9+jKVUpH69grLl18h53tgErx8lfm9\ne7C9yxQKDlZouPa6Bc6cvkR9u4m0BGmaYjTkcnkGg4D19XXyBZc4jjhw4AArqyvs27tAu9Nh38IC\ng/4AFSoWFub52r1fZeHq67BdD+EWWVldISiXkL4DUZennz/F/sNX4foWcRzR7/cJowA/77O93CMW\nHvVmh1hZ+IUKG811pCd3TCPLTE7P8da3/zibm5d5/BvP7io23ZahNCzwpMDPGyZm4dKZ7Nrcu2eS\nh3iJEb9Af3SErVdfAs+hMAy64NGrBzRaA7ZX2lSHSoSdPlPVIu3uNk88cJJ8oUav2aEK1B2BW4Rm\n05A0e5QdQEiGh6toS2LZLjMzk2yubBKSdYw9X2BFGpMopBL4eU2QCogg9iDvGxwpSHV2sGB2En0J\n2ckGOs7wg+I1RtEullG8EMD4iqTqGqYqOSKWOP/JX+HQm9/L0Q99lGucAjrRBEpg0gFry03iQUTQ\n7ZPmHA4dGcIvj4BjkfMkg36fP/nFu9jchMQFXXYQ0metUce2UmoH38z/9pE3USpVKDqwdvkcj7/w\nAt984gkOFsDyLZbW+8zfdBWf+T/fSz4HrfYFLLuEsFxkw8fzHGI7IAl6rK2uMOY4RFFKqxEw6Cps\nVxJJQetCwHY7otPpUxgY7thleJIdeQhjQFzxMiS7tq8ovCMMlnAxWEgnh2Kd5pbN/L47UNLh2IEx\nrp0wnFuMEVaeyliFxsoXUfpGBh1AaTZWYgbkKBWLbNpNYmHQOsXPFXELZaYmRvnAT76f8/Ekz5Nj\nUCmznUpuODxPq9MnMZo4ihmaGmfp8iXaK5coDo8hLA9r/CDuyil6i0/wZ6ctSjlFow83VgpUR4bp\ndQNazTbVWmVXsVGxes1BXmkgjDGOg+1b4Eh8dYqkb2H8
CkRWtne1gjRkf3KCODnA1+1bcH/l12jl\nJCvbHTZDzc989DCvnjlBr/MME+uvMHHxFdqRy8XAMKlt7jk/z4dmL6K9GJ1ahFsQp00cz4JZD+EI\njO5ijVaJkxQ/ijBSQQa7BbWDWwbOhzNMX4w59bkv8Nb/dQqTtkDWEe5Qdk0WxzHGR7dPwfeBOL0x\nLT0Y4EyPM/7xu/nqdTPUkx7VS6epLi7TPHIY75ojjKqUeNIl9nxYXuPw9AjxIOTiwkFGNtaIgzWC\n3ARyzxiTcUpr8RxqZpaw1yRE4EoLL+dQFAm6WER4VQaWQBXGUEmMUAk5x6ERh6TbG2wPTzCaNohD\nQb8wyUQBGrkStgEbDSYDpMZK48gsqbHQGJPhNPQVqwMyaqLZOczNLrsYghQ7gQDwjUCpjCmSxgpp\n0mzuH3bxpItBYts2dqdH3/YQlkO7N6AtMnPEXhpRzXkQaYq33czq5/8K3e2w+ILmxDPr5DwH7Uo8\nMoq01Bpp2+R1jz/4oz/H9T1KvkOQJOg4gVRhuQ42AldAalSmmaQtPMejF4VIY+HJBNtY9I0CZYjS\nBF8KEmNwtCFvLLpGY8ndd3gyHy8rS2h01n1TWFgIUpl1eqzv6rJYlsvjDz2LiRIGRvO+6RxFy6bn\nSK4/UGYQpRQFnFxrU8oXuHF+gpma5LnzXS61QgbtgEBKHn60T75gmKo59LXHzHyeVy4L2vUO7bUB\nf/uF/8ye696OO74HbfSO+WeW6BhzpbMDWklSZUhUlqdGCSSxYO3Ut/E7D7Ad5KDfplZ0mRjanQvk\n/V//JI52UIHAjMxy1bW3sXrhNJ3OBl/9yqcYGjtEp7XI8HiOfhqhjIW0NH7OZ2urTrHkY9kWnuPS\nbrep1WoopSgU8gz6fba2tuj1Ai5fXkIKw4GFfZw4cZyj1xwln88zPz9Ho1GnOjJG6uSJBnDqwioX\nL2/RbYVce/Q6dKpApDSbTXr9AUlsaNYDwlDQ6NUpV0YJE8no5DStlfPESZs4bLK1eZ7puWlGq0fQ\nUWHX+waZYaVSS5CEAi8PlbIgNoZLzz0BwNZTD5E6mhcevZe9scWBW45xfP0c4VqT9naL7XXJ8MQw\nrXYHPxpAYqiNgEbjFXLkixIZaRLLgpyivQNgV2lK0O/ilcaQQuH5LpVKibkZj8YgQvrZfjAWlF2H\nybxNKxhgjMC3s0TGtyWRMqiQHSormYz5TmfhNXy8uNKN+MHXu6YdXu3DQ62YH63ZrHUS2pHFodEc\nrz7+dc489nWEBNeDcmWW3PAk1fII0s1j8hAFAevPrNPavszG+iadDkQKQmDg2nQCTSWv2epGLDVj\nGJrhPVcdob18nsLUBMN75xg6dJBjhw/xwXe8nV/5j79N3FfccP0+PvHvb2LQXMSoIZTIEegujlNA\nJi7nnv4WVm6Ii5stwlJCiKCWd0iiHmkaYPsuQ9Ui9XqDnsjR7iR4jeaut06q5A4aj+zARCDlDoLH\nAlvapGi0UmAUwmiCKOXowQ+wtzbObdfuZfXlp/nyfU325LvsnR2iMFSlr6GYs5mdG0aV9/LSqSaJ\nMViOxnGhPUgpjk/S63UYqgy4tB3x7GU4vX2Zfzo/RdhvU6SHnSZ4hTxumhIRcGDPHrqDS3SXziKF\noTQ2i8lVcWeuZejw2zn5d/fzIxNZUrm9eYGVxVGmDk5Tv7ROdai0q9iE0sGL4telNqIUsCAnkWmG\n85U5BbqBzMKTdSMtMA70kwnmWmcplmP2rC2RHy5h/ALyQpM7zRbmgXuwChojIyraQczcglcr8PJX\nXuSfzoJQKXp0HEv3KRZKmUrz+CzStdBKYlSKW5M44QoyBzoSmZnjztIC3MYZLr3pZzly68f4+id+\nk/d87FbWnn+C0tw+ctUiW/f/Lcrk6SvJgaO/+T3j8IYJz/s+82sklSrDtsfamVfphYrc/Dj9/dfQ\nExbFl16CQhHHd8kN
BuRrVfawzbempnAjnzGnyVZcxveL0OzS22jhTA4zgqYTppQrBSKqtANDznWw\n+ttM6TZ5kSc0gkE3pOC3kFYJR+SxzYBktMjKhR5zRw6hG+cJeinJNdfjnzyBcVyQKVrZxJHCszPF\nXKGzUYURIIWVTTH0DiX9CvVtlxieBfok0mJ4eJy17SY9DTktMs2aNEbGMcqRFPMuKonpJClxnDKS\nE1zsbOP5OcKgi+t7THqlTPbeVVTm5+m/coIoSTKLCUcAKcEALCnRJmuX636CtpyM552GhErjSJuB\nihgqFegHIY7M1IAVkJOCVGl6KsByHGwpSIxERTGWk8NWaQbmU4pEa7SReCT4GHq77b0DCRqpUrS0\ngRSVAHhZYikFwsrAy7YtkQ6cO34KN1ilVC5TtQUF1+ZcL+LY3homjpiuFdnY6jNcLdLXmm+eXOcr\nxzV5y2IsZ6FUAm6ByCtRq5bZCCOeeGyL+WmX6T15Lq9axNqhlnN48O//G3f/m/85a/MidthYWayy\n7k42vtIakhTiGNJUEieaS/f+H4Ra4fgpU+PjBEGHc53ddcAO7b+epC8ZLk+ytnWRXrvP+MQM5049\njwjWcFyXKA4pCE3R8+kMIuIw4sD+A1y8sIptQ606TH2zjjGaKIoYHq6RJCnrG+v/D2dvFqRJdp7n\nPeec3P+t/tqrunrv6enpaQyBGSwEQAIkQXAxSdGmFbbCXIIiL2jLti4syrZsy5YdDkY4QpLtkH1h\nK0xTCtMUSVMKmRAEEBSxkAAJYgYzmLWn9+ru2v99yfUsvsi/GwiLgFTMi96qIjr/rJOZ3/m+931e\nlrpdnBM8fPiIOPK5dOEc62vrvPH662yf2aLT6XB79212GkuM85yqsrxx6xjtBNs7V3CE9I7HTDOL\nrgx5kRP6be7dfcxau0luCyZzTZJ0OTocsLN5hePeA6w2zCbH7O++i8kMWXb6l5aUAps7isyhRe2M\nc0oSJ45RWusWRNqjvRGy2Q2ZDhp83yc+gv16xIM/fZNKQdBtEASG2XCA6iTkZQkFpMeHDCcj8rkl\n19BwlmeXFDuBZKqgKA2d7jKNxK9dnNZQFXM+fr6F89q89mDO+nuv4CUBoc558PgRDojCGk9hraMy\nFlNp0PUad/4iNNdbOIWeOEUX1vTTHGcCn/XA8raAf3Fi+N4lSVFpXi0UN84kNfhUSQLPZ5qOybVm\n1DugzFOcK8hLj8q1iFZvsHY2o0GDtjvgS3diHp28w3A6J582uNmb1+u0WfDml/+A1eVl4iCAPOdk\nNGVne5PNts//9T/893zmnS/zV39xiyw9oNlsMUlzGvESaE1lUjylWbt2jaVOj93emKIYcZLElP0x\nCSntOMDkBfnRCD+K6Gd9TGn/XJTuJ5lZjgWEVoIQsuZ8OYvRumZsSYEnHM5GfGTlB3lPZ4WDvZv8\nH6+8wr1BxY9fjVnrRjA/Il66wLnL72c2r3jphecJojYP9l8nTwsCArYu3eA9Vz/M2Wvfxfj//Y9I\nvJCd932QMGnyV1q7NN97lgeHDXqPHvAHX/sav/7WW/xvP/Oz5ALkoxGXr7QZDI/x1zZoYhEyJNo4\nx8rmMj+8Du1I1HqZZpvpdEyRdgmVZD4dnera/Fqu+WljCRfuUCtAY/CiACHnyBa4Bsh5CF6ICBuQ\n15tnmQ9ZomLl2mVevnlIr5ScmU6I8wOuPPsce/uG5mRM2+TIGEQWEwQz/P6QZDYGBQcf+V9Zf+ED\n+Pt/gK0s8u7nkde/D3np+xCHf4AbvEH4uMfR248JonoCIJ5oTcwizzEzeO//KcrIZ19oPvPf/c9c\n/Tf+Al2Rc/Snu8z3Z3QiCB89/rbX4TsWPPsVeG+8Qb68SWt1BX3lWWaNBul4ztl//k/g6hW8jS1O\nVjd4Lh3xYHWLV5qbyMM+w5N7HC+fZzM6ZOjHRNUj+rKinDXZCgacWe8wMMtMizmx7a
OnY5SsGKgu\n0mviFX3Cbi3wy1SDwGrCF95Hb5aTK59xNaJhK2ZhjDmzg3n965jKB+FjTD1K8WSN3RR4aM89tYdq\nt7BFfwuG/JSudPbnAg/HaH6A9UNEnjPDILSlMJbQGDwN49yhgoBsXiEDj6NZwfLmWfb1azSEQhrL\n2DiErYGOa+YR13/oHIO9nOk0ZXQyYZpWuOEMK0OskmTW4oSrO1q2ILOSBIvTFaGTFHmBtY5iYVWQ\nEgqnFrEIGuV5WFsHz8ggwDhDlaYIz6MQ9ecSwjJ1C0ieO70tHeuohMNSU5ydkGhd4RGAhkBInAf9\n4xO+/unf5OUv/FMuPfscRw93WV5vc2zhynKDobVE+Ny+N6T0PC60HUkSMEgkU2MZTSyTqh6foS2j\nec7J0Zjnr7dpr/usbKzjNQ0vfcAw7ZQoCb/9O7/BT/wHv4x13lOBcl3wLCIjjMNoQWVcXfBUgrKy\nHN55haRd4MoIL2rUsRtBQpKmp1s7ewW/9Jf/K1rNFl97+Xf43O/9P5iwQ6u1hcln2Pmcfm+O8n1U\nUpEVc6T1ODo6YHm5Q7PRYNAfopRHnqf4vk+v12dre40PfOCDDAaDeu0Mh5w9u02/P8AXkkuXrvD2\nzTe4dOlizQIymqjZodIDwrjJ3t0HrKys0V5tYkzO9uYGDx8/QinJmY0dBkfvIDzwQ5/NzTPM84qd\n9TXu7B5CJUlnY4R8xNFBk7XuKmFw+nUjRf2wi6RgYx2GGXixQ1l451jQXBLcyhzrhyXbTcXNez3K\nz3+ao6M5y+sdzl9a4sp7rjPce4iuDH5gkFi8SHCy/5hJWSGQNBKHiurCwy5ehFprbr9xkzPnzzKb\nTvC8e0x7PRJtUcqnOE5pf7jJyuYqni3RUcD9/tdRs1qWJoRYOIEWhYxweEogZL22nowUhANcXRCd\nat2UgiUpuZQ4MmP4+sTxfBP8yvL6Y8P5rsWXOfHGOmGyzHg+IWgodCnR/jaNTspS9zqzvbdobt7A\n7n6JaulFPnS9y40Lq7x6+yHTSnN+LePVdx7hEKTasj8a03/rHbQuOLO2xf3de3jOsLO1yt/4T76H\n3uxNwjimsBCFCYUxrHWWScsS5BLt5jY3PrLHG1/6LHa9Q55nGCS2EZGVgmw2BaUIpWZ3DNmkYKNT\nnnrt2IU5QgmBkqC8J4pXuYBg1rl1gQJPJtzgJ1j2XuBTn/0Kr+7PaDVjrp9bQYgYG2/QTPrY2S5R\ns4FzluvPbvDqH3+N5tIWgbmN82OuXHuR+OJ3QdDmduuneWHF430f/16Wqznru2MsJRfWIuyky+ZS\nwic/8hHybI40CqtLlpINMjunSudIKUEY8GNKF7C5DJO5oxFBURS4uMlwUCGDmGp2Oovf/tAxiSQr\n0izYc/VatK7e5GEccu1n4BM/gLv5u4j8APnDv4I4/Ge4f/x36N15zNlWBzEZ0MxGXP6BDzMbpqSH\nJ7hRn9R16IocpRY5lrMhVTtmllJr2Ub7+J//b7GjCcwGCP8IcfgK+nP/ZV1cBQHCRbgMMHVB5i3O\nEbeQHgQgZwNKEuTaJXY6Gftv3kUay60Hh7y4uYKnJeLKd33b6/CdNTydmEQuo6cT7uRAJnD9I1YH\nx5jnLqNXtsg666TDKZPxhItbjoe9CdutNrPOOnL+kEedDfz1Lco7Q1o7DY4LjwlNRmWHxGasVcco\nNyeLVvHmPcKkiXUW63m4zDIL14hFSNTymGUZXm6IPMVh7rO2fpWV4UOOmh5VkRMmbXCCwhgqZ/Gl\nBKewCHzrsCiss/WH/pbOjnPiXyL9/qsOzxiEdfjOkJcVzhiMUnjZHG0tnrF4tiI1BQ0T0K5mzAgJ\nncXrdMimOYHvIT2J0jUksRv5vPuZN/j8wx5R7KMWwZYSCAO/tnlXBmT9WZAWg8ETiqEzSLdAbxtQ\nC35B7hzVQlXprEFahTQGp1TNyjGWwBPo4QCW1/
CFQQqJEBBTW7JrmMXpDuscOIN19TkaUTsEKgTC\njzg6fMyXf/tXefj652kEGe0oZD68z3g+oeVZpPF5NMrwE0UpJbPM0SsyLjwbE0uQgWLaL8ALqYoa\nWNfoWrJcMhlEPHiQ8QPXEr6xO2VDJ7QvxezfMVy8JHnhY4qvfeGrvPDxj1GUFmPEQqBcU32NrgGD\nZQWVFlQajDXoW78DvsDmglarw2o75N1H+8RJdLq1I33CpElRQFmVaDukKhyr62fYv/MGnZbDl5I0\nFTRiRbvVYjbMqMqSVlNSFOXTJ0Gr3WY2m6GUYjKeEEURUkqm0xkgkErS7S5x++YtdrZ3uHL5InlZ\nMpunlGVBlmbkpuTs2VU+9rEPkuuM+XyEFB6mslw9d5l9ecju8QHtRszqepfjwYS8mON5HpPpMWGc\n2wNtaQAAIABJREFUoF1GEFccnrzDcNznuR/7eZw5nc4A6rl95EuWOzXhuRGCHkM5FWRjB5HjT0cA\ngpffTTl/Dub9B7y0s8GthyNWt5+l1Yg4LC1m0dVNpWJ1pcF+b8wgg/MNyb42tBVM+jAxmnIGum2o\ntCNsdlha6XDtmef54uc+x/69d7h4oYuMAx7ce8A33n2T5aTFfFRQjiFwiw2T1PhKYRfwF7kodKhJ\nGbUrsT71+tHzxB/9r3l8PRdcUZpNH1Y9Sek7vjiA97cdvtLcGUo2GgHZ/oCl4BjlxyRIrJa0Gpbh\nZJ04u49SSwwev03Sukrv8U3y8DzO83nuufPcf+cmH3/pRX7s4x+l1Wxx+9EhvcGUDz2zydryBqtL\nXdLZjDuPdvGbTbQ3WGTmxVib0WldRfgeo8k+kR/Taa2Q5iU/8hMf5+//ymfxijkqbiN9n+E0xVmD\n8D1sUdGf55SVIystWXn6YrkqF7oPubChW4OQIKVEIOtMNQl55bAm5a35nIff+Aqb2+fxendZbTdo\nhT5Oz7n2nudYWpPc+syv877LKxwf9VhuNwiSBof9Y6Tn8QM/+8vcniX0Dk/40Aef4cMff4nVQJP2\n+1zKezjh0NmEMPS5cX0JRMz+pIfwHJ6QCCuY3J3RfW+XodaIQEFlwZOEW88wL2BzSbC97NPqdEhH\nh7Q7bVxRotTpuu7eXKCD+p4wTxMJbL0ZNhJmBuGViHaEjdbgT/8JvvuB+plYRowmlrW05KNnNe2V\nHZZ+9Hu49Zk/wLx1QKMdMNZNrDnCcxDIimZkKfMp4xn0qw06b/1D8B4hczBVLebPx6B8UI05BBY8\nUzcgaorMN6cuVtQ/U1NihKTEJ/YtTse89icvs1u2+MR3X0EPxxRzTabst5PwfOeCJxARVbCKuCDp\nWg+ze5v1JERfPIsnIvZdRW8+5eLRfTZXt5kYgQqa9N99m521M5SRIOgsYR7tk5kIU7XxyxF61oT0\nZdSFC9AvyKUiGd9DehXZrKIZSpy/gVtSpAWsegaTZ/SyOgQtSnwoMuJxzrHXYe3sJY4rn2QBtisr\nuxCuyQVG/MnN474FPw6Ghcj59K5rgny2SNy1+M4tLKkSVRQI7Yi1xo7HdJZ8jM7QvR6l59NYWYUi\nx3qKKPAR1mKsRjpBnuU0JIShw5MGZSXaA0/UL+DcWhJPkhlBpCxWKaQWVEYTSIl0Am0MgSfrF7YQ\nNb1ywR4yyJpLZA3SWOIoItcVeB7u+IRweRmlHZUySFOvt8r9+aIlpKijPpSDcuGKsKJm3bz5xmv8\n/b/+SZ69tMPF7XU+9/oBge9xZm2Z7W7CjowpTcmzl5b54s0+mdEM+injsuTmI0GnoZgZD6ccUVJn\nfy2v1nbL1SsJjTZUM8tb90p+/LuWeXNPc/i1issfDgiqCMo+9x8MeP5jTxLP6xeTNvUYq9B1dISu\nav2OqaCq5rjibTyvzWbHY2WlwaPHe3TaTYaj0+22Aio+++nfoBVf4mTUI4wElXRUWIajkm63wA8M\nw8GESV6ysd
KlLCqWOl1wsLa2xvHhCUZbxqMRzVaLKIrQ2rC/f8BsNsVZ+/Tf2q0W29tbZFmG5/ns\n3rqNs1DkJXEYMZkbLl+6QBD6UBniaJU4SlhNWoTCJ+106JU5oaeIfJ84CvGbDYwA7TRBIyC3DiNG\nbJ2F2fSERw9f5ey5q6deN1R1C2ReCY6OHV4gSMdwfGLBh61NOHgA4PCAj3z/+9nYXsOr5qT6MWVV\n1t2DJ1l5ri5kfaEJFVQeeNayYmDLF7SuWV78wPeyujvkK6/eZXnrAs7zWF1ZIkkSGnHEcFjx3HUP\nT5d8//svMwtiVle7vP76bd7Y6yGkgwKiuN44xX49xrWGJ3kqiMUoWiiBMzUKQ6jT7bLGRclNodCm\nookjEY4zIXxj4tgpBJfa9mmn+txSQD4r6BY9POVR6TmRf4Jo3kA1E7xpRDp+SKOzxPTxu6zunKc4\neJcXz3ZI7/9TNi9/N6pc5ruv7tCInmU27JHfv4fc3iGWPle7XaJWwjuv7HL+qkYLn1aygiXDlorl\nzllUAAfH94iiNp2tNf6d//iH+Ud/77OshiWzaYFwFuWBnuVo4ZCmHuk0goJRcUoMNdTmlIUzzuga\n+KikRQiLUvU6MAa0EejCMde/i+r8JfrTQ3Z2dijnJxjRRcUNpCnIR1AanyiomFaadrhMI9HY4yk/\n+Av/DWrjMtsnPW5cu0BlIM9KZOy4qnv4wx6V1UhlEdLhRS0G4yG2KrElyCjCOYuaG3Ij8XyFQOE8\nCSYhOn8VEDhtuduDj26FCFuRpYbtM8vodHqqaxMUJdYozCJEWsGiMK+1iW4O+mu/hffwy0i9h1wG\nM6QGRwmPzMacDErOoZFOk3/hM5w7+gIPHzmSHQ8aK+RFjZXwo5JmOWGej5h4Pp+7W/AzzSNMUVf9\nngdOO2SjtqM7DVJkCOeoqoVr9mkCrHvqrMOUCCFQRrPJYzY3mzzz/EXWr64j0hGlSEinj9CtlW97\nHb7jqjoUHhfPbTHqjWlORrC1Q6kEw15KGuVkec6ShdUXv5tjLZk3Y4o7d7iyEjMMNEG0TXNtk/7+\np5kHbdbSKWU+RHhHxEvLeLduksjHWFPidZ7HjyKM18LFMccoGpR07AkiXqeqDFb4lEsJvp0hpcKa\njCCPSJM2utmqk4SthyWtAzdFLVZ+EmKpXf38qYTFcwKnZM0n+JdBu//K4/BgH8/ziaKYbhBQWEMg\nIwpticsSqws0FqoMjEVUBXGWcyJDKDJUVWIiH8/VdlasozCWRuWQxuBJycyUBE6hgUoJfARGW0Lh\nsM4DUyKVR2YsylicMXgSppUhUAvx7YIyXVpDgCRQitxZKhymrFBCUlkN0xGiNFTCElqBVWC1JXQQ\nn1bgxDd3VzX7wWCcW0SQhvyjv/kJnn3mGU6mc17/6k2CMKAbe2xsr3P/5ruceIr3b8b0xiXzecly\n5NGzjmYUMq5S0qnHjcttErGEVT5e0/Fo7Gg0UmZHGdtbXUylccuCz93d49xmm4uBouynHPYUm9vL\n3Hz5t/nEX/q3sFoshMr1C6nQtW5H67pTpjUYXZGPB0hXcdzPuX5ujYe7d4jjLW49eERnqXu6tfPw\nXeKgRZZEjPOMNK3YuLBDMddMij7HYzjp93H4RK5NL0s5u7HG/cf7IH3O7pxHCMk8nbK83KUsS+I4\n4vj4mNXVFc6fu0iSNLDWsLf/iG+88Tqm1Oxsb2GtRxiGdLwE4epOUhwklGXF4fEBm9tn6PV6NMKC\nZEmwPxhhw4j+ZMrz7/8uBo/exQjoHx0jwwgvDsnFkEl+wrnzDfJ8RKcref2tf879+w9OvW4q48gr\n8AtB5NfPOa+OtQYFQVLvTpZjwU98F2S3X+PlyTNcee49iNacwAh8qTD4deafFEznFctJ3WEpDMwF\nLIegJWQ5JJ1lZnqM50mmkxOc0Jh0QJ6W7O8dUfjwcOyYGrj57us01jdIIkV/
OK2dV2aB+nKi1lR9\nq/3qyR+FQKg67Vw8saef0qb1UgxfnRt2K8UZUREIiIVgxYOJhjeGjmdb0AgExxODQ1JWjtDXjOcV\nCE1n+FVWllvgdyiISXsHNJebVKPbOBkwmU/pbH+A1oUXMVZTFYbp8AhrBWevXcd3grYCLQOcs7z7\nG7cY/2SLa+81qNYKlS5Jkk3Kck46GxElKwzGAwaZ49/96e/hH/y9zzJP83qconxUYTCutuLIIKDK\nKlCq7mKe8rD6mxRl5KLLbOsi9AmFubaki1rPqTzSPOXKzmWGgz6BH9Pr9/nJT7zILC159w//kGsX\ntjBlRvFoTtxo02xpHuSbtHYuYSzcuHqBUaqx2rC10cLPpsz++CvkFy8SJE3mh0NC66jchPl8irU+\netFhaXg+Qht8E7J87VIddrpIVm6fu8xaUCF8wfFY8u79R3z0g+9DSV2H6J6StBwu1qlmkT5g66aA\nW0wEnKtRJnZvHxGA9T2UNbjKQaFwSUDUiUnv5/hJRcv2mcUvkrz3JmI4xC9bjESLtpoiLfgyI6pm\nHGaSVkLttjIO0XDIi89hvHOE3/+LOD3D/M4vwUmFkzk6X5zPE2/WQm+LBOUKIhxGCHbcIXa2wqXL\nm0TVHsW8QYrBOkdybvPbXofv2BdrqYL8wR7t/JikPMILHD0VkQcOpimyqrCdJscHI/Y1jKcz1HhC\na2mZ7tISst1hfnzAWqdNZ+MMru2TrV4k2tghthGN89c57Fxj2H2eTEBfdZl5PmmVI/IpaVrSaLZI\n05JZskk7FDSyCr1+HuE1MF6brBUQtjbwd87VYkFlyMtag4JQi/gCsYiXEBjnUIt0KGHrXvST2fpp\njlbYZC1u4VtXi+G0Ic0zcpOBrpjlBdYYZKER8xQlBF3tKKcpqQM/ifEFWGMJbG3jVkrSXGvwIz/1\nIpmTfPQjN9g6t87amRU86/BdjQUvNXjOQuEwxhAK8BfjuWrRmXGmditIKTBOI4wlc4aJNVRaY7TB\neZBhoSixx8cYa1Cm3gapsh53lQb0n2OkZUyd4m5MjQcw1lFqg5USmaxx72TMPC3oNBI2Y7VwPXiU\nuiJNKyQwn1c0A0F/lrGaKFQU8PyFhPde6bAUBhwOM76+XzEaV1w7vwbK4i87Th7O8doB62cd3c02\n0ya0z3vsval59eUDRAWu/8fMhkVd3BhHZet8LKPrgNCydJSlRVcVVTWD+TG3do9IGiFZlVPokHGa\ngXLs7X97kdyfuXaWGuRphXBNkkYbJ0LSvCSvKp67sUOUhMzzEYIcqwvGoxSlYsrCIFFIxCJTq4kQ\ngjRNybKM8+fPsba2xtJSh8PDx3z9la9ytH/IuXPnSaKY3vEJCDg+PqYRR6TTCdPRAGNgb+8RW5s7\nNXXaOdpJwsuvv85JPuPh0R6tpSVyLRkPp8wmOdNphjESa32Eytnc3CAJu8xGAfNxjKsMJ/u3T71u\nhACkwEmHwTGfOY4Hi5uzAqfhwtkmP/GxTcK4ycHLGnXYo9Xq0t0+g/Q9lHAYaxAWBB5KOpqNCBH5\n+NZR4NhqwVoDUgtf+dObFOmYLCu5dPkq6+tdrj5zjatXr3H12mW2Y9iitp47K3HK5+XXbvKHf3Ib\ncoGtwPcFQSAIA/9pl8UtEBD1W5gFuXDxQWsE1amOrUDwXt/yMDfsVpK5gUDWTs9pJZhreG0IB3PH\n/tRxMLXszxyDuaWyjt5csTt0vL075d79Xcb9XeZlRitOkOsfZWn7Bpde+ARmcI/+W1/k8PZDzGTI\naO8eZCmD/gGT0UN2D+8zvv0Npnvv8GDvdf6Lv/Zl3v76BBDEQQNnSoQMMEZSVh5GO8oyY640v/DL\nnySbaCprKYqauK+UQjvFPC0AR+Yc2ekfOfX4X9UasG/N0bKLl31lBHqBIvGkZHySoPyAVhLS3byA\n9AOSlU3eeuM+rsyIO22E9NAuYKkZk2lF
ZT1+7Od+keW1VYIkQXgBSeizs73MUiMiaMbsrm1z/JWv\nI+c5XhSgtaUnNLl12DAibLUQsib/B8rHG84J16/WMzkErpxTTOa0PbjVlyS+IQgC8jSl3Wqiq3yh\nwfzXP8J6T01lFxRq4AmNzmqBM+CrJqq5ibXrVJOE8r7E7AGVAF/hCQtrK7hGjOmsI4YHaJcgFQQi\np1IJrgKsIPQ0pvAoKtiJDdKv30VirQ0f/knkc9cxNsCVE6So6vNZaHesfrJ5qE/UubrzrmXEyEB/\nnnF0+eeRnuXyT32Uc3/rNwle+j6KTKObXbbXv73g/TsWPEn/mMZgF3cywKxcZZobqn6fMi8Y/ovf\nZyf0aMwqSjTFaEB8dMTGuQvcJ+aRC9lSORbHOz0fHhxAPqF0DpFVODUlzU9qnggBg6DD1ExoYPGl\nz1IjohtbgrxAeILZNMcIReU59FTTzHtorUm0gqKARgtZ4wMoSoMv6tiIOpyytiZKa2tRJNRqfVfb\no58yek5xGF0xKXKUs0yyCltqKq0piwpdFtgsZfqNNxk9uk8xHlMFMQ+mM3SaIuZTPOEoKo02ltJW\ntfXbOm7e2mNzOUFPCnRVIQScObNKXmlyY7FYfOFQyic3FWVlMYUmtxpZabAaz1oqK+qgTuswlUYh\naEhZvwyASDh0lhEuvn/6+ABlDFJbtNZ1MOxinqOq04H1AIyrIzdKa6h07Vwx1vHw9uusLzUwuqrz\nmpSlEfhk85wHjx+ineTFcy22OiHWEyjf59KZDqVSLAcen3015/WbM752c4j2SkQ65ExLQzbhwxcu\n8OK5c1zdWqWblCQ7Pttb0BkXPLg/J8srnCtpRooL2wmz/mHt0DKLsZWBohKUJWCgrAxlOaUsZ5hs\nHy8AV+bsH0/o+oKj42MCJei2T2cRNaqisbxM1FK0l1aABpW2lEXKcPKQnbNnuHjhBgifRhKRpTkP\n9vfQFq5efZatrS08z0MpRbvdpixLjo+PEaJ+SIzHE7RerJ2dHW7fuss8zbh9f5evvvwK3ZU1gtBH\nSIPEEYdNsnnOwd4ezWaLCxcvcHB0SInheDbB+ArtSg7279Hd6NLqdCnLiuGwx727t1iJG3jG42B3\nROKv0I7WaAWrxPaUKGGgtAIvAFvCaCqYFzULCRyyITAaLt7YoVx7hq8dOV4r4cFRzv27ezw6nrN1\n5TI4w3A8qot/CVjLZJrjOUclYL5w6pxpCFpKcvveQyJP4Hker3zpC7z5tVfZvb+H1hVGG7IKDqSj\nsaK4cP06z1x+hlKbmknhnuC1LMZAaS1aL15G36QuPIW5y7gmeAtHjdo/xTHWsBxInokcIyN4bCUT\nA4GC9dDVgEwLtybQK2BQQD+Fvanj0UgwK+oOZlpYRrliXHi0mh66yoncMVF7i6OHD0jWrzOdgMgm\nTE+O0cWcw6O7HO+/zeHJI3QluT1O+fKru3yjanMAvP5WBZnHcLpfs7fKgmawRpYXBFGTpUabRtzh\nkz/6fo4ArS3CGowTZFpQ6ZpFhnUUecXa2TOnXjvGCiwOt3jG2wVywi3iYr7ZUKvzFeNV2DjbAgex\nZ9laXmK5GXPx8ga3jzKuXH+O0XjM8eEx7cghygEDndCzTe4+7JM0Y8DR7TaJQ4/AV3hK8t5P/CBf\nHPfJJkOkKRk1KlJVgPTqcaYfEIYRVio8T6KyhekhjEE5nDaoIGBlLcRzBucEz1++QKBLiqqi1WpT\nFN8ervdnHcLWEo4nXS5nXX0uzqFNPf6VDYE69wz+S+8n/OEfxfuLPwcf+XFcMePWyQjlCezGFuPt\nj6CNxGUT/O/7RTxXEqsS6ydUFbVRph7QEQlYTTRGLqYNmcN+4Xcwv/+ruH/4byJ+9z8HVXer3BNo\n5CL/7OndsbhfhJG1VMI5ws33MBj1EbPH/LP/+m/yP37qTYrlVVpNj+Hut4+0+Y4jLT9qo8+cJT8Y\nUuaO
2BXIdoiaKlr//s8ztyGT0YAyShC9E5S1HFYVnie5fG6bk94JxaBge/Mcwh/hBkPOj9+iWr5I\nOcmIuptsVhOKIEECTSeZxx2UqchcwaoL2DcGezig1e5Q5jNWVzZJ5Iyx36EqFHaqifxj9Pkd3J03\nKbVjnpc0AwVS4pz9Zvq1dXWulq3TxS3fMss6JYfHx+FwSG1IpCREIK2jcoJR7xibzkmaCeXRMVYq\n8uUl7HiCGw+g0aHViCnmGVIKUCHpPEc6+Ms/8yGuf+g6f/vXruIcDAYjfusffKVu8BlXC409SVrk\nhFKisGglQFekglrobJ/kYVnKBUlYOkteGTxPoaQgc6KOd7CgbY477BGUOX5VoYXCSYWHY+Yg/XNo\neOZZicFSOou2Ck1NnH78xhd52Juy1gwJw7p1PfLg4rk2fidmfDjiUlvQnxUUheHqasS8cmw2Qs6u\nxwRRAQJUYJmmiq3NhMNRRn4yIz+OkFFEd1NQ5JokiGAYsJx4NKm4+qGQH2mvcvekR2t9mSLNyIva\nhaUXsRGmclRVSVmN0eWMqswxpmB+dIft7gqDcc4kN1RKcG57mft7fZYbp2u/L28mvHHzUyTBl7HG\nJ8sLnOixvd2hFe6gVULSXWM7iBE643s++gLH4wGXr57n/r2b3L//kPW1FaqqIo5jkiRhZWWFLMu5\ne/cuQRCglGFre5PDo2NWVlbZWO3iB5I7d27TaXfY33+EtXD+7Dke7B2wub5NZSsG/T6tRszZnTNI\nP0AEIWk2Z6URoJxAq5jDoxOWV3bQYs6VZ3fo9QbsPR7jq4TtrS1eee0brKictZXTtd0BqtShpaC0\njmEP1Lc8oSyO+RySZkJvMGHkRezGc4rhlMt5weVnLtPtLmP693jn3X2SCgojOEklWVGxlkh+7EqE\nkiEdGXF46wi/GfKRF5+pO6HVA1546QbJ0go3rl2td5ZVQRzCB194njJ4jFGCwXjCg1sn9XxA1Xoc\nawWVsZAVzEsHQf0SqbutPG3LL5in9Qv4lLSH16aGUFoSC2eEIXNwRD1OOythLai3dpmBRylUVtAN\nHKGEYVEH+QZS4EkIPUegHPeGllgd04h6RLfeReBQGDrtkN4h2Lzg3Jk1Yp0yMh5ZOefxTPGVxwWv\n9eE//Q8/yP/yb28SBi2slxF4Cc7NieMO2gY0HLTbz2At5PmQ5kqX3//SX+O7P/Z3uBACUlJqS7BA\nhqSZ5sM/9MMsB6djWwE8AT0K4Z7m4Qm+JTlo8bt1YmFdn3Fo/k96Iubyyv/E+hmfuH+Lr7++y0df\nukLgCebDMS6fcxhdQl//OVzY4ObbhwwPT/jaH7+CLSvOXbnI1tYmH3jPJq1G3eH6yb/9t/n0b/1j\nPrBmsNOKvAzxvLrz4IoCayyNMEBFAVE7xmIgn2OzGTJKkKHh2o0PMJj+ETePHMd7Dzl78SJR4FGZ\nimKen+rSJFiUkxjk02Kn5o/VRg3nFJgpTL4Ew4WGpgSRxgxPFCdphXWCr37lDq9sWH7lcs5+EbP2\nzu+ioyX8skAKQTYH2XR4ccbSisfGquPcUsb9k1XuLH2Y7+/9LmI0RQYSVy3wVEEJur7XzaKOM4uv\nIUC6BV7G5WB0HVAUxxx88G9w+7d+jfv9Pu9d6vLWqz0++swmf/KpP+Lf++U/+zp8x4JHV44yMzgq\nxGTEcG2NYTZhaaWNrTwGoxOWqvrGa1KwtNxFB4qVhgMzpWhvIhgjhg9IXIUfwACPTTVh0myRzXOa\nUQdPgZERnaRFLxujwy5yvI9Y2WKl02Wu96niFkmrTeW36asNmrPHzLyClY6PttssnRvVLXajsZXB\ni2pLonySnYXBSIUyT9rlsh5lLZp75rSYd1cLk50DT1oKBEoKKiGIwoB0UudsGWPwfY8yTXHW0bAG\nladMpymhkhRVhSzBl/U4p5rPGT++T9yKiaKAs1sJKyseD286tG9RFo
yp6nakq3fEgScXbfQ6sNOX\nCi1cHbqGwgmNFa7e8djaIZD4HiAodIVwDhH5WK0JjMFgqTyNQOBbiX/aJzM1eLBwitLV4wUNGBT5\n+JD3XVphMJzjkFSJRM9KSiTPPrfF4P4hn70z4b2bLTIL660ljqZTntlqsT+Y8qFLS0wGBUOTo6Vm\nlqastBWDyvJ4XhCkFfsjn9VWyCAwRL6HkYb+saN4MGHjUobXColKhyGkWhQ6dTPLUemSshxTlVOq\nKqcq5gsL9xTnC+71Uq5fvYDKJsyyKZ5UDPPT7dR1ukynlbH3+CFL7Q6jwQQ/OEueQpGVHBw+YKnd\nITMZzdgjL8asrsYU5QhtcrKZZufMNmmaMh6PybKM2WyG7yuSpEG73WY2GzGb5awsr9Jpd2i1Wqys\ndFhd6XL37j1WomXm85zdh3uMJzOKSZ+NrfXa/moqZuMhwovw/ZBIaxp+iMRnPJuRVxl7vRFxQ3Lu\n/CrJks+Ot8mwr1k/c5HG7n1MUfF4cHLqdeOpWuuiFGDBlPC0VbLgI3lRRDaZonXNlSgRDHsD3phO\nWF/tcm5F4fkSXVhwlmUPhKdYX41oJzWgbjbL8SKIOzGPTyYEQhD4AaPhDC0SZvOMleWEpNmkNNB/\nvMeDWweYbEIpPPb3a6G6sGJBK4dOICGQzLR+mgH0JNsJn6cjLLEgELrTLRt62tJQksw6POtoUP83\nuXXMEMSLyJymB2LRPRqWjkBCwxMoJKFyhIuXXKnAN1B4jlllSPy6m5YbSTStcELhtGC2e0LkSQ5m\nlpsjeHv4ze7Cj3zyPLPZnGDZYUSGspIsM0zmI5qNMxgTkM+PyHTdVUvTPltbS/xnf+X7+c1f/RJF\nadC2DldtKzCXr7DaadA/Pn3w7JNi5+nfBU8jbepSqHa9OqjNKhaUFAiR8cajX2J9429xrvMePK/P\nOHWcsZY0mzCbw/iDP0vsRUSe5Oz2Cid7PTqdZda2lllZW2LnzBoWSxT45HNH7gyf/It/gdmn/m/m\nqcVJgakMKgwwVUUAqEYIvoVRVodgGI2rNEQCFUUsX32eo0/9EfeGIddSQaPfp9NdBb9FEJ7OGerc\nk9DURaaYdU8dyk/NgtIDY7CRhKDWcHoy4zDvUti6HhikmqTZQrkJaVrSaXSoyhFShChbYeohA9JY\nlOfRCDNkq8NFr0dXf4WeO8u6fYRY3BhOG5yqcySdrZ2NLDStjifhT/XPyi8GKCdoxj6uEXHy+C1K\nmZDuPcTz4T1bHfJCsLz27TWV37HgmacjmhtNXGeFvOPjDg4pfEFj5SxDH4TyEc0GNi9oXL3BaD7j\njJwR+CHZeEYgSnRqKJc3IZsjixQ6OwydohGnmLxP7rq4UMJkRN9oGo2EfHqIajSosgLKgpnwCYyk\nqDIK6WPncyo/xIsbjHVGTEi6fpYojuvxFg5fiQVjp75kbvEEMmJB+bULsXA98l2oek5xGFOHAmLI\nC4gCSaYlgTPM8gxRaSIryaVAlxWRFzIPJf7E0hv3CEpNEgeEUlKKWiuCc/zeZ+7z6U/dwWaOQkFV\nWeJEonyx0G9ZpBP4UCegO0duFjRk40iFwBhbV87GImU9NtPW4kmv1tRIh6lq27o1Ft/3IdcOkMF7\nAAAgAElEQVTk8znC80AovNxiFAgnn3aJTnV5tEU7R2UMlat5PlIJ/DBmhiSTAdsNHzlzNAIfP5B0\nghghAnrzkqAR0MwchYRUV/RHFfPC4fVnrC9F3Gh0eTwqccLDlIrX+xM6SxrhW4q0IPQcvYlld2Yp\nK81SO+HqtYgJ0O9ZVp6t6cp1seOoKktV5uhqRlFMqcoUXeaLDo8mT1Omac5qq0GaTlmKPLou5vKa\n5e7B6YL85r0h8+mIle4y6WxCoxFRpAZftpmVkmw2ottUWDsGYqIwJstnZLkmCkMmgxytK4bDIZub\nm0gpn46yJpMxSimWl1eI4xijod
FssdTp4HmKu3cfc/v2HS5cvMRsNscPfIo8p9FsIISg3+vRbMRs\nb21xdDKkyAqUUuw/eszS0ibD2ZSo5RFVFq0zpLL4IQyHOatrW6R5gef7GA2z9HS7UADlCeKkbpzw\n/78nFyMKoTyss2gjabXAV4I/+swrTIFrV9r8+A9dp51ElAvtHAoakY+gjnqpdMV8IezfPUnpnOsg\nPJDSESYhrWZII4mpypwizyly2N8/IM8nvHtnwvEM0DV358nRiBWVFERSYe2i47cYncFix7yg1z5t\nNZzykeOcYG4cBRC62tCWCGgiGFiohKCzcLa0lcMTglTX3bJR4fAkRFYQSYcvIaydwAQGPCmY5g5F\nHa45Lerw36ISpFawP4XjStJpJLzvyjql1SytJeTFMVEc44UNPAWF1QReQhyuI70OoZ2DMPgyIC+G\ntJpLHAwH/MwvfC+//r9/HutkbRt3joEX8NLV57iyuYRwp4QUUTOcFnplwH3z5+N4Wuw8ZZE8+bqo\nXbqh75hM/y739HXOb/xVls62Gez+IfPMsf/ev05T1NEpUimCUPH8S88RBJLxZCFrmFVsNkLKStNo\nNSjSFO0MfRniO41naxacKyukdhCF2CwD56P7KWJwgO97tZA9TxG6ZMkv2D4bs5tVNCOPvBJYWyFF\nQZmfruvu3ELeIb5FR+8Eknpj7LQArRFrTWQ3hMBHSAWR4Si1BFGCUIrCCTACRcE8A7l9FfPa7yFU\nbS6woi4ma5SJpalACINB0lF9lLMUOiTyCsT2BgQRPNytNwdagBW4RQX2JO0e6rFw5jUpJURSoTyP\n3je+Qb+a0swmrOaSxvYqJpesX9z6ttfhO9vSz5wndxLvYA+3tIpoKc6VE4JyRDGwLMU+vgvZbjYI\n0hndrA+rW4jlDWJ7k6OJJFgNyEYpSTnCd5Jg9gi9dgk7N4SFpqnvkQcvkEUWP1lCZSNEEFB5TYqq\nIog7+GpM2LvFbD5ntH6D5eEjsq1nqBx0JXi+RIWbpO1N3O5NPAGBUk+LmbrwqX9Vi+rRLpwfghoa\ndspxOoExWCWQlcaoAKMt0moKKYnmc5SSGKHZ7nTZG/SReUYSRszzPnoyoRP5T1x3KCHQon7ACF+R\nlgVxKAjqQB6s1VhqG6F2i4eSrM9b4ii0qXUCnodvKoyUT2nSbtHmNThyawg9hbX1TlpbWwcaOhDC\nIno9vM3N2uH05NpZ/efwaEGuSyrnURlDaSQGUFYShBH3DgcoIZj6DbaXIvYyw1g7/vCLX2Wj0+B4\nNGWmLUEYUNmaM9QbDNlYbTKZCYyu0Lai2/K58zhlMLXYomJ/F7Q1BKGC1LK65eFJQzOxrDQzUp0Q\nS0FrTTHpFTghKUuHrhxVlVEWU8pqRpXPKMsMUxToqqyF1JMDpHFo4eFZn9WlVUaTAcczWOssn27t\nKIMrM5xoE4crjCdTLl7YJlABUbhMnt5lOvHZ3tpkMhyR5yWVqwNWtdHEScz+/j6bm5vMZjM8z8P3\nfTxPcf78Bfr9PvO5T5blJEmTqix46803aLcTDg8OWO4ukRcFw9GIorSsra8SqroT2eq0qMqS46MT\nllfW2D8ZsbFzhpVOm+PjY9qdBvuHd+h2t8A1ePzwEO3nbG9eptfPOTzerzPdRIRUp2ep2EWfPa+e\nrLpvqQoslBmYylAWhlJrWp3/j7M3ibUsOe/8fhFxxju/+eXLOWtiVZEUSdGQ1Bagdk/ottvL9t5o\neGd45ZU3bfSivTLgjeGFvTbahgG30LYl290aLFkTKVEii8UimZVz5pvfnc8Qoxdx7suSZBb0eIBE\n5hvz3jhxIr74f/8BVC1J9xXLM0s5HLFYaJZ1JL5ba3FCMS5zbKgJCNJM4YPBevB5yt/++7/KaNDj\n+Zv/ie3DI7I0I1GSPB9Q5DnvfPwh27sJo8sfkxYpi7pilWw4ENHa4Hzu2JvI6COUKJrad7hC1E4E
\n16lgbJellUByU1TZBwQBJwVrIWgJmAAlgTGCysM5sJdGLkuhIsJdO9A+ZgjOdWAtIVeCPAJSFJ4o\n3wakiIRfFwQz7Xm9FjRWkGWQZ1DVK96crVgt4O+/ex8VBEWR0rZzrBtGbhglCSVNOycEi/MpJgi2\nxxPqxrAzuoM1LXMLZWQaUBvH3nvv8Oj2Lmerhr3xzZ2WnQ8EG9EilERe28qCFBvQTXa88RDbJDIW\nPCoVOLugrv6Ix8/+iP/2+d/hV/tDMl+QDMb0ej08xMNjSGmNpyhSRoMCbyyJBNMa8sEAY2znh+ai\nUtYnkYRiDWnWQyQSgsd5h1u1NJ89ZSiS2GbrDUGlhOqKN8+ecOfOHexnP2W2lrjQovIU21SUw5vx\nBulGYnO4VwECHikl1sd8rbBUBF3By9Xbx87B82aHnUlBCGCcRTsN3lNVQLPsjHtjC9d0Jq5ee1SR\nUMoW7OoaCR2VU0yTx+fBSwhl5ML5L77I+HHYwHEicpASETMXrfdUV3OSe+9z/vv/il/eKSiznLaK\nar+y/7PnzpcWPK6qyfUMd3ubRErGYgvnR1TnFxzuHXJxsWKSt6hBSt4ahvuHlEnF5fkpK7fLeKQx\n6wWjdolJxigzRWzdJneBtF8wG75Dr5ox13PyNCF3FauspLWWUqaMls9o2wuy8SHp9hZitINfVNjx\nDltlyeX5S4zI0P0EW6+RHz5i8ckfkxe9OJGJVaxCELBd6neg6wJGFvj1wN5sW3fBkyNwiSIJDikE\nQQqU94jgY2utaTjXmu3tLaaXV6S9AmkDwlqKXhELEe+QAQoSKmFQ3uOsxqkEI6AQksobku6BFYLo\n74AiiFjGpVISgmPldIR1jQaVXLfrQgjIEKMpBIJECqyxxEhNBakieHDVEusPcDIulEmQaBl+LuNB\nZy3GaayXMX08BFyQqKJHIlImheTJ2RR5/x7bvQqxbPG9HrX1vHcw5sdv1pSl5bMzj0wCyzIn1HCw\nE6gJ1E5xfmz50StH1XrqNjAsFQ6B0R6XSNaVYG8gMV4iVMni0iLznFGWkLq4HTnrMKbFmApnW6yu\nsaZFNzXOaKw1WGNxZkkmJ7j2gmw759XJKXcORkznDYsb8r9c0MgAFxdT+oMdQHF19YbVqmAy3mJr\nvMX0coFEMeiVXJ5d4aUkSXPKYkAwmrLsMZ/PkVJSliXz+ZyDg/1rtdZqteKTT37Ae+9+wHK5osxz\nemUPJRWj0YCr+YrlcoFxksF4wmKxQErBaDwmlZLdnR3K/oBy2SKF5HIxBWHZ3hmzrocMexlSJPRH\nQ67W5yyWDd7Do/t3ePnkJ/ggKIqbL8pSxTatd38VBBEkoZPVAj7EwB+lQElIBymcOX7y/IStUU6l\noXAgpeJi6fmF+wovCgSaNFVcWjhIYN1afv3Xfwdn4cnjBX/8O3/IYGtMhmY4HPLTz5/RrGaM9g/J\nJiVff/cBp1efsro0HQcnvsJ+LkhVPFCp5AsIQifHha7o2aAON19yuq1bdNyLgJCwAqyHbRnoBTBB\ncNLAXg6KyNcZIKgFaN8Z8gVY2UAjIJWCvo+HLtGZlQqgsoHjCkzwBAWrmLWLklBpwZqAUNDr75Cl\nGmPBK00hcpK0ZL4+wwpBkYxYmxpjK4IfkKoWawfMl1csiZ0+ay29/QPu3bvHZNhnOncsws0J7869\nVdsKH6LZY5e0LbsBlCJ2kYwAgsDBtc9LEOBsfF9z91v8aw+uOuDfF566bUiTnDRPqRvLcJARg99a\nslSRqrj+t62N5F4Cfj6lNgaVpEgTEEmKx5MpCRbKrEBVV7jdHdTWbfzsJVKlMfQ677F/+4DL6Rzn\n4NXM8c0tj2k0Ki9YVz+HjG2DmvprXnBUtPmAaz2hhTAHnECmXSFYepyQ5InHOY/VJqrdnEJJy+rp\np6gutUDiCSiMcaRpHF+/gdxUbP8GB0q1ccxfHQPHEdXpCNUb
dVbo1KKb+Rg8eCMAyao1ZKZisn/E\nyZNL5EcHqFTRLBuKsmB9+bOd7790pQ5CMMmXpNWcYbNGzM5pkz6hLDBpQd531MOUrL/N7d0MXy/R\nC+grRdqckzVLTL1GqoSdtMFkfVSekaFp5ACQvDGw1e+DFORpD6MDu1LjbUVdL0nbCh9yKrmD7N0i\nXPyYgbaI+YJxf8xIrPFlgSo8/usfcTVdUmQKCDEtV8RK1ncmhLEA2ignNlXEjRtaBBdJz9HNWbL0\nvpP0gtAt3kCZQC9RNHVsFSXe4whkqUQ3GmM9jbFgHNrZCDEmkkGS0BJPkHVo40RwPtqAE1BdPowN\nDo+ncZbgYzyFd/HUQgjIEDNIpBRIAcFZrNHUVhNEoFcOUFna9XAtfrnEeoN3nsQYjK0JpjMZueFV\nt5Zaa+pW07YNWrfotqYcH9Dfkly1joOtEe/uyGgjkGU8eXHGmsC6tczmDaU1TGdrXr5ZoBvL+bLl\nzZnh89ct//J3lvzup2tWTQvC89H7Q/o9hbegUomzjp4PfPJ8zeVC88mTFU2rGQ4Sqosat3IoleG8\nxtoaq9fodonVa9p6jdEtxkbFnTaaWjeMxgWjUcakn6FUAGHI+wV7N+RXvjo7xXjFbLqGEDC6JUk8\n6/UZxii2x3tkaY88GdGsA4KMJ09fsV4b1lVFWZaUZcl0OmVrawvfmQy+efMG7z1PnjylaVq++tWv\nYawlz3PSNMN7zwcfvM/+/h6TyQTnPZdXl3zyyQ8o8pzhcEi1XkUFnbX44Dk4OMAHz7Kag3Cs1lMm\n20NW62k0jTOW1dqitWY0HPLJJ3/K7u6INE3o9W5OPE1FR1wM4a88k93HTTxlhxAL2+U8uoFvZuj2\nuKDsZYyHWXR09VFt18s38zwWJmnofEiEoj8s6Q8LigLuPbzLux+8z/37j3j46Ct88P47NLpiaQV6\nFfjej15xddVxWELcFEQEFLpzQQcpf+EoGaSISKoUbJjMwYUbh4duOEHuWl0Ti5dKQNUhxT0RGCaB\ny0bQdIiQlFAqKKSg3BSIHXG6toGZCUxNYKFhrqO667QRBKkosgxjocj7fPOjjznYGrI12gGGfPBg\nElsZyQjPmiIbMm/d9f3piRTnDLpdIpSjMXO0y1muz9gZTfg7f+seV96jcsE7733A/d0JCEHen7BX\n5jedOvF++phpFkOBRQyL7lrq3gtsN3YbZk/oxtGFqPJyTtE0krURtBUUe6do4/E+0OvlKAV5lqCE\nhOCiuW3wSCURUpGkCqkS0rygefoMIRQiSLw1UWHkA6HV0FETRN4j3esRrEYmaTSztRprPI5+lKL3\nBZNc4Yzl6moWxTjuhiqtLxTXgbe5Y76T8DsT0T05CMj9MsYYNbENOAuKXEmC9zgE5+dXLNcNMu8j\n568RIolzsoNjrOvGvTWx8BFREe09MRTUi45qIiKzqnsufcerensP4+t0XXUWvMV5z06esHN0h2Yx\npxIgpKetLFfnc8gk4UsyW74U4dlez/G9lHS4S69MqeczQl0xzjWztiYbjNnupfTrSxoZT9JlnlGt\njxmYgAspPhuhRcuyrUjTjMxpgiwYm2My18OyxORbTMqMmeyj7CmhWeLzPfztr1KMR9hFhQmOVZAU\n/87fY9nUjNopBQMWOmXetsh1zu6tR8hbD8iE6QwH4/oS8/q6MEIRkN3CZEPHFRCdWdUNrhB8JFHi\nEEDehWXqoJBCkQVDlaRY41HBM0pTrpZrdNNgG41JYoGSCWi9v4bvtDZxIjiPCo5cRZdl72NEhPMB\nKTwuCAoERkDiAjJXhNYjg0M7j1JgO+ZSKlKSVNIYgw+BTCiklDitcSGAVAQkW6bFaoMIAkPkBOTo\na++Km1x7A8F0WTHXgcapaGcuHUk5QdjARw+PqFYV/TJhbjx7GXz7g0NckPQHOa22bPcUk1HB+aKh\nbi2DYUJQsLaeYW9JLmE4
HGJ8IHMJg8zhi5YiSclKydPLmg/vZsxXFqkFIs95dVozGSUcbQn+4k/+\nF+anx3z97/5TjK4xpu6KswarNcZonDFYa3nvnbssXx/z/oPb6NogpWa6MgzlzSePdhlmvebw8IjF\nfM6wnxEai7JRgpqlPbyTOAfzxZK9vW2+ffBtnJW0jaGfDZnN5hweHrJer8nzHGstg8EwOiuPRvGU\nFODo6IjZbMbdwx0Gwz5Ga9arFUWRA4Esy5B1gxCCQX/A0dERSgQWsylnZ2fIfIAUnuFgSFEoarMm\nzxN6gz4uBFKVsL+1xd7uLk3juLU/olo3TEZH6Lq98bxp2mjMJ0N3GLmutUO3qEaiaQgBYxyzCmRq\ncd0cvTpf8Xn2mqbSJAIsHpFA1TRY61CZQCYK23psDk1V8fDhIybjPj/888+59+AeWzsT+v0+SsrY\nZraOwaAHCparCm06CuwXzF5Ct4THhDreUkXi2es6E0j4+Lngv/Dzf8PLd0TTuDnEabe5z5cicnW2\nVUTCRnmgthEtG6hYa5VJlK4nMsrXTYjriY6dANquCDIuKuKEsEwr+PYv/AJFWpDR0mQJ9/KUo6Oc\nkjnD4R5erEhUyWz9kiLfxnlNkfdQIpAkBY0LGJuyNT5gtrxge3wPKQO//Gsf8Vt/8ILR/h1u3b7L\nrd0R01XFsNeDvH/jueM3EjgXCeFJPOFew4R+g3lvhr27JyrSH7uxjL/Dh0hun557/u/f/4/56L1/\nzjc/+hplv0eeKYz1CCFonaaXSXqZpDaOxgUKAioRtK+fMhQZ3ptuw+4QNqNBFVhnyLcmyDIQbA1K\nIRJJqANCJYyPjsik4P1D0I5YcCUZKhWYen2zwQlc85V8gKSDtET3uXoZyJbRpFEkirBsUCrOLSNS\nFC7+2wfWVUtt4uYqVm8IpAglu4I8x7HCWkHio/EvXQGMDoRERP+7jkpx/Qi5L6QefOH+RB5P9JQT\nTiM7lfHaeH778SvePxxc3zdno5ed+ZID+pcWPE37ktHRL3E47HFZaerxIfLyAj86IiwrSh8YjHbw\ngx5+cYLpT1iGQG/nFvXrz7lwJVulZB5KemmJb6Zobbnwkv10m+G4z9WFoKcdKgFvz+mVOXXtKf2S\n1CQo6TgjMEh72Kqiqa7oaUfY3sGGlNXAkbeaIqmpbEbyzofIN3+OQCGCiDJ0BMJLkq6iD12OiRQB\nH2RMZb1hrg3Od7lcEhviCcAFSaGgDYHWBdJgSLuHZy0V/SKlFFFRkeQJKkQ4OnUxODRXEuUDXkiM\n19y9vcvjl6dknfumIlAoQWs8qRKY4NEmGqwlITbuUkAqiQ8dj6lzOq61iRCjENhgwCYIAVkST77B\nOdZVRVY3NFLgQoSbW6lIfw4Sz+PHP6CpFgyGExrTx6pthMxJhGaxtpRHgXv3R3gU+4OSdWtQNuVk\ntqRMBFu5ZLLVZ+4bPn53wr/97jFV2/LxhzkJgcVc8eBoQHlnD1lWnD1f0CiLSaP0crUwkSzdSD56\nOCTrOf7wByvmq4TkUYpt+6xefYfnV9/hvfk/xkuFaWtMW2NNQ6vbWOy0Lc56pqsrFlXDxLRY7bm4\nXHK0M+K0vuDh4d6NxkZmfdaXM4yeooAi6aGrmocP7zC3lrIcIYRktVqxf7jDer1EuYKT4yvKYkDS\nz1itVjx8+BCtNU+ePCFNUz7++GOurq6w1rJardBak6YZ9+/eZbFYgrCs12uapqE3nHQKwtg6aJo2\nemKFgHWG1XJFXTdgYW9vm8XynNWyIe8ljCZ9jJHs7OxzfDylFHB1fI7Wnt5oF+lhuY6/76aXNTFf\nx18r/b/wXCYCbMAj8CFgumykPM85ONzj+PQVIgiGowFea3QwSC/IpaRICyqzxvoQn4/Os8V6OH19\nwqefTGka+J3f/G3Gu3t89eOH7O/t8ezZKy7PVzz5336fZ6ebFxJx9s2pOVGxTYyUCMBYD7
bD4jco\nTqeg2kRNsEGYb3BtNoPQtcu879a2LjtqQXSq3osiG4Zp3KCmJiI7PRU5O5L4d+haYEtiIeQ65CNR\nAhtg0cK3Pv4K7bri8eO/4MFWxn/+d28x6fdR/YLffFkhZI23DUIVOFfjvEA7wzDJOZ5dkaohW1sP\nmC1eI70ly4a09TFCZPwH/+Dr/Iv/6jc5PLpNL1P88KcvONopOZ2dY/TN4logmguKzmFZImJIZhdG\nu+llCKLc2fv4/XQFAL6baRvaiIg+Zs5JjP0R3/30n3D36I/Iq5bd3Ql5mmKsIU1SdIjzqJcrpEww\nPsaKOOsQ5HgT34tykb5ACEghY4EmQdw7QhU9wrqJBYdKUALKnTuoNCO4QCoVvf4A2xqqJlAvbpaW\nvqFwvK0n3s49FyJh2FbEeRqWKAHGCJI2UAcYdnvQqtL03plg9CWmUzp7LwguxpoYKyKq4yBco55d\ni5G4zwYRLVyCENeWDcHHiInrw0JnHSA8eBEP7s5bEg+N86zWNVevn/GN/THOW9rWIYUnhMD66me3\ntL5cpbX1Ac3pmqw/pmVJsXJMlwtEf4+1ddwqBbSOarmg/9VvYVcrxsspzdWaMjXRpO3Ck5UlxtZo\nEkTjueNfIG0fq1tEpun5ltpvM1keM+sNGSaBJQleVOS6T5kl6LWm1y8Ytg0+GxC0wazPmcjI4VBb\nB/TtkuX2CPUq3lIvwltOfmyno7pBt9dN3U7LdUMeBiLgrYnW3AESL7DSIYLCBchSResVddvSSwRF\nXRHyDJKUBE8iFc5aMglaCbSJeVpCRo6OcJ4fPX5NWaQd/OwRIfbdg+hsPTbvSYFzceZ4ZOzdJ2Ct\ni1w5KaOaiYAMjuBFLDBDoLEGhcB7R7OqSI0B34WPyijpvjH8BTz+7v+FknBnZ8Tx1Zpk+1vI/m0u\n3nyH+wfbzFeWvb7kB88vOL1cc29vxCAznGiNF5ILnfJm0ZCl8PmTGXujAqUk01PDy7ljuxCsA2zt\nBhYXGQfv7uDbNbNZi1vDYuYYEViJhKfHAWsaMgqUa/jWe0eYxZSJuMDPA83qDCcKnGvRukE3LUa3\nWN0hPEaD01yt1pSznCLLCFKSFX1qVXBW3cyH597dr/Dw1n3uHOzw+Y8/ZTE7Y2dnzJPHP6bYe0jI\nJYP+EKUsBE+R55ycTYGAd5aLiwveffdd5vM5SZKwt7eHUorpdIrWGu89Ozu7eO84Pj7m7u0j7t27\nx3o9R8qG27fvcDFboJQiCZHHdHh4yNbWhBA8xlomkzGoBEvKaDRmubwgLzLyMsVYz8XFjDTLKfOM\nSTbk2YunWCu5mp6yvT1kMa24vLy5tFhlIubsOPHXuWPdGt00LetVg7EGgaDVLVfzmv4QijLh3Xfu\n8KSteX3aYDrbBu1BG4+Ugbb1ZFlC3caT93hrTFACeMmHX3uPYjjmF77xdba398mV59n3/hgVMqCO\nvJ0ODdjskEkeN9U0ic++th0SIwIkAuGI7tEhdIereGq9aRudt2BFDMLseDfXVvwymipaL9gVgYGI\nDrt5Gr15LnUseooO8QmE6MCbRFKz6VAy7aFpAl979xEg+cHjzwD4T37tHvujEpkYxlspD140nJ6+\n5mB3GyMU3hcMyntIkVJry/7kEWlR0po1UuQ0bc2iOWNQjAguYXdnzD/6h+9j7ZhRkVFLR1aMqBVM\n5I1HhxA6J33RaXND1yLpDDk7K6QObeG66MV1hcC1X2S3a3TgrbWQp/Bvfv+/4G/94j9jZ2ccOVRC\nIhPAe3yQWOvwwSF9pB+0rScVjqAEwrgY6pwkCGsiRzKViLJEPLgHRKUaIhBkgggRj9rZu8Vo8FOG\nvR5vXpzxi7/0S9hmRW9002De2FLaFNuRu9q9zw7VsybKwqOuJzb9mgq0jBylpja0NuBd7AFu0jrY\n+B0pRWugl7wtIpWICM/GVEdYT1AiZt0l8V6w8anq2o
/huurp7lX3tHhvEVKgvaSZXXL1/T8j+5V3\nWFYLyqFgvL+N1Q7vfzay/KUFTygnDPcmuLykFCVGn5DmBfPzlxzcuU89e84URbFzi4vpgnR1gSp7\npO0JzehDdtsFz+2MQq4gSfDLU3pOk7/zIedGUFy8ZGQluhhRBcda5mz39lG5ZTukLOoFM7FLpk+Q\nKsqHQ76Pl4b5smGQZzg1Rolz1lrTLhqKZoWXCik6QVsIICQybKpEiRQBgkQQeTeWgBI3lfnRKcAk\niXQ450lRGCEI1uClIsfjZaC1niLPaK0lSRISBFpbtLf0sjT2dgkYGzBpwDsf7fBTGeXyPp5UAoKc\neHg0IXKIBiiWTl87t2oRXaa1CSQhIKXsHEwdIkSfHqTqAgxlZOsriQoCqQ3S1kgrSGUSM8eEpJU3\nh3iCD8xWLaWSXLx4w+TymJPpCikDr1aOr96fcHZa83C7z/Yk5SAvuLhak/cVF+cVH78/wuUpVile\n12u++eEunz++4r3DEfcf9Ckaw2tn2R71UVXDYEtSrRKypOFCXPHRnSFUBqkdATg5L1BZzcFkwKzy\nNDPBSv+E/iLl8b/+52x/6z/CNpfYdA8hc5wTtG2DNRprNIv1goNhj/1BDyMk3/7KfT5/eQKJIBvd\nDH4PIfAnf/KHfJIptoZ9JsMBi/kVg0HBZ4+f8tUPv4aQkjzvdQoYjTVT8iwnzwomO9s455jP5zRN\nwze/+U0+//xznHN8+OGHPH36FCHoCqGE4+Nj2vWCW7f2uXXrFlprZtMpeV6A9Ny9e4erqytOT485\nODigqdesFlNWlWG8e8h8PqPIc6yLO4LVhqIsmM0W7O3c5eTkgvm8oq4NO/t3QKWoJF3xNygAACAA\nSURBVEX/HInXSsKoELQbsuL1V8R1hysmX8dtyfvoyF8kCYcHe3z88V2cbnh9usQEsDIgvSdPFVIq\nnGvQTrFqHP0EmhYGgx5ZEflGAZhsDTt1m8bbgJKSg4lAJyUvH0d7ZSEESkXfmtBxiKzRJCLF+qjI\nElIQ9Nvultj02Ltd56b5fZ5NobX5wdBxiMS1KlMg0AROhaB1gT0pUBL6BErVSdS1IFfRn0cJ6ClB\nJgPax697IagI5HnOd3/4KSA5KBV7I4XIJUk2jHzCYEgTy7xesLV3m9olrOtT0mRImW+hzQqZSBKZ\nsbM1wGpPIiApSuZVRe2W/KO/9wG/8W8Syjzw7Y8f8fmzE0iGZKObq7R8iCZ1LgAumsAmaXgbyeTj\nPIpf7tB+v+H5bJx64lobV1vRtcAkrQ405rf4gz/NkfKfsb01ZDCI7S0pRTx0+jivRrnAmwbfGFyS\n4DvhTOwlmlioEknnvHuIPNiL8kOVEIKLaIhzyGxEWhTsDPu8uVyTIlisFtw63EWEm6GnHb4Y646w\nIQnHr/jQEb67IGVxjV52LuchIEVUA/sgqFZLevsKKaIS0hCLRwnUTjIJb/dHpfx16zB0fLJgia1e\nYpF5jVyGtw998NEvL3SQXPySwwRPQkA7y/wcjHbkSOq1ZvswQYhwLWv//7u+tODxeQ89KFn0S8zs\nJRiJ0w3DrISrBWYwRuV9Uj3lVr3k6uIKtu4itm4hUsGVzTjYlpiVoloaeknBcDCmqiWD+XN6+0eY\n9RmSSIYa+xm9q3PM3V9hev6SSdFnPX9FyDOS4ElkwLuK2VojbcVKpCRuTtLfJVke02SKbLkCv5ms\nLt45768nshJghUBterXCk/hwYy9hHzwqKFpnoo9DiB4/ysX4Cm0tO0Ihk5SZsbRtyzDLWaaCRddD\nVUqREwvcRAiCiLEOXoRoxCQFIQik8NjgkcSHWQGEiFLZEI0GQ1dSp4BFkopA5R2JdSAjgdsSK/g4\nzz1KRG6NDw7lDJW1lGsDUtEIhxISqxRK3dwTw3tBP1MsVwv2D8Z8/b19BpNtnrw+5zf+7Z/y7Kzm\n6+/sMq9qZouGYk
dhU8hIUHlG3RheX64ospRUwk+eTbm732dtDNMXU/r9HLmT8vSnL/Eqo2jHVOsa\nW7W8czih8g2Vtlit6PUC41s5+6MxQ23AaB6f1Az7isGBxLucN3/26wSnUfmQdd2y8/CbrEw/qtds\ny9awj3MeLQKmdTydn6GFYjjok6Q3MwHztiZTCgnR6K9fkqU9VvNzTBuoqpp79x5QVStmV5dICe+/\n9xXevDkmy3Lquub169dsbcW21E9+8hPu3r1L27a8fPmS8XiMEPD06VO8D9w+OuL24S5Nu+bP//wx\nk60JQkq8d2htGI+3yFVgd2cLBBRFQSK3QazwztHUNb1+ga0qXr16xc7BiPF4wsX5nDwvWayfUw76\nyMxhvGVn74Bnz4+ZL27mTwRQN4HFSuDtRtT9xYZ+h0CYBt12cScWyODgcJvTizmfffaCYRGfG4Cm\n8aytRxDIsxTnWqSEVCla51ESvvedH5HkcSmslzVPf/KC/ckO/V7Bk8dPYiDhIOe9wxHL1RXSaq4u\nWny35xQpZCnR5FTGtkroCiHRobYhxCKIDTfp57iCf8t5iC2ttx2biDyFa8ND6wJTIWgd7IZAqbo1\nR0AuA60X1DauiQMFaRduShA0HbImrzcOz3/47UcYl3QHqxZBn6//wi5FXpIPD2malkm5TZKWkPQw\npkUqSZKOqNsGY9YUaZ9GLjGmJZEZdav4xjfu8D//qx+j5X2evnyDFj2Gg5KkN7jx+HjfKeN8uLbt\n8E5ctwBDJ/rYkM29iwiQJ8TIoeup1u0NdLJ1ETdgGQLzxW/wv//OT/l3v/3f8LUPv4q3gTSTDHLZ\nHVRFXFerJUpKjI0RPVIEgrPI1iCzFBk6a5S9bUSeE+oqzpUg6OR8ZP0hQWYMx9vsdvw1bzSPP3vO\nVz5+8HNOIt5allzXzQJruOaUfWEaoa7lXOC1BxFYNeY6fV50v1QKiQsC41WHnkHwXXfA+oiedQWm\nII4nG/Wi3xRc8d/XCKqP3xJFOJvXIUlCQIiEK2C21NwqIUkkKMnj7/yAr/ztX/qZb//LVVoV+FfP\naJ6+wV5pwuqEXpbhRiPawZAwr8lXK0okq/4uw6GinQyotKB+fsyh1cjefdLFM8YDx9bWLfLhgNZo\nzsUOYVWh1Q6tzOg5QzrYYdW/y+r4MULmLGVKb7IDzRqy2HKprKPUM25v77K7vU/o7aHTAUaN2N2d\n0LoODRG+uxmx9+iDQHowIiqXPIGYTBUgJGwMjv7G8yYEjDMkXpALgUoUqiNOSRFIlWSJixJSFxUD\nrbdM+kX8WWIlWtvYSFbeowgo58hDPIcIHxB4TAcbJiHyfQwReVEduVAGEGoj/QbpHY3tkjCVJBBz\nahIRXZNVd5rRzkQZLRLnPa7WCKcR3c96a1HWEvTNs7QWiylBOGbLFVenM/7s+8/59JOfkPolX310\ni6+9c4s8F7y4MHgneLA9ZNIrubNTcjTOSLPA+dLSV4AUSCTffz7Fac9uKRgl0JN9krzH0U6BcTVB\nVZTjwPPzmlfPLUJllCPIc0WRa0KouZQ1Z0Fx8GCIJaWqwMscUoX2AussgcCLT36P5evv0CyeYetL\nMpWi0ox61bJerkmTEo9Aecvl+mZ8gxePP2XUK9nf26c/GKKSFERKoKQsBpyfX5DnJS9evMT7CCW/\neXPKnTv3scaRZilSSYoOlZBS8vr1a8qyZHd3F601vV6PJEk4OrpFmiSoJCGEwO3bRyyXK54+fUpZ\n9rhz5y7Pnr8ghMDl5RW3b99hd3cXH2B3bw/vA1fTS9q2pWla3n/vfapVjXfQNJFgb3zFsp5B4qma\nih9++kOss7zzzqMbzxtXx7ynxv1lfAfe0uy8CbjgrhcvKaFaLTk7fYNuZ5wvr2gbi/EAkbTiENGT\np1PuGBsokxj8+f5X7vKVj94FYDiZcPvebe7cvcPh0T3eff8d6pUlJAkyEczmS0aT
KLffyNKTJL7S\nREnyJMGazmtnU41AXLl16KCg7s9NeYOBbkN8uzlt/r1Ja79ObvfRl2YZAq8DXFiB8bGVpSSUKtBP\noJCBuQ2sXSQyCwmpkkjgk8ePr//rd3czgrcE7yO6IAJF7wCpMpSUVO0JNizxfolwC4p8gPceZ9Y4\nM2PQ3yMtS5JkC0JAmylFMeHe7RF4zXq9Is3H9MsU5T2Xs5sXy3TcLttxQpyLa29wcSx8iAexKEXv\n2lodb2mDQMAX2iqCa2VsEAIXosZauMf83h/8Y168eEleZDTaUdeW4TAnVQrrHe3lGYlSsSC1DtYN\naIt3FpmkSClIDybREsS4qPAVHfKSFnghEUKwfXiL7VFJmSsaLRlvb/PovYesVqsbj83bLW4Trvq2\nrWrtRqUWuU2uGx9jY0vQeo/dJI86C0i0cTiSLm09FpTWq4imdb9DdATiaG5I7CxEqPJ63m7m7KY1\nG0IsfvzmHnRorrMWQUAHT9NaRkAWTLw3LrCeax596+usZj+b0P2lBU+mT9CJwK1e0SYtzWSMEYp+\nsUUQgUZr0tRSqZxMJKyHj1C6JbQ1olfgJ7vMjl/S232fVUiRMjCfr8nNKXfyJcnuGFsUuPE2WXMC\nNqHvVwRVUCQpvjVUraM36DMzfbJSkghD0t9lKvpor0nNFcX8KTJLubCKvFcQrv02Iwgu/KZnLZEh\n9i6vKTxBgvA3L3i8IxUChKP1jkZrtLGI4EmCJA0BHSICk0gLISCcp7GWtXekgMJTSIH1DhMCQcQT\nlJOCRAiMBBPi4mmJwYexHy0wLhJMCzohi3ekIkrRPXETSGUkk4kQT58qtohRUnbcSUHwrnvNHuE8\n7XIJ3qJc5Cgpa5H+5gXPhVmRJpLLuqI/KsjyhOPlitVcMMigaVd8//mCw52CWwcT1sbw6YsZKfCL\nj/ZwjecbjwZUMjBdteyMevyDb92l30u5c2vMk7Oa9WzF7mTItIqy6DLv0zYBqQQP7o/JSodMElqv\nUP2ccea5OvMIl5KUAw4PthgME9bVFQd3PsYKwWodDQcHhUDqOfX5M85ffMKsWVGojMZKjJL0Byll\n0Yda0yu+FCj9a5c1kQtyeXlJvzegbjTGOLzPGPaHCCH4/ve/z+HhLdbriqa1EBQXF1eMJzs457l1\neIvZbMbR0RE7O9t471ksFkynUxaLBev1mt3dXabTKZswoVtHh5HUt16RZTlXV1e8fPmSfr/PeGuL\nBw8fRh6QMRhnubqasq7WHBwexHZanuGcp2kMWVpwsH/IelWR97bpDUd88NFHqFRw6+iQLBc0zQ2V\nJAChW2wD8Fefyc0eLwJ0i/NmHV8sVvSKER+8d8C/98vvxdZMAO08tY5EfiVlTFEXAuMiwZcEyv6Q\nre1oHln2cw4OdlEqqkFAoE1cjPtFBgJOT66AqHACaGw8QKAkxjsy9XZz2WwoYmOk9cW3dUO303hy\n/oLd/mZj3nBVYlLmW+VLxzPSDi49nHjB2sXvF0QOeCaJbrhA42FpA62BPBFUJkJYh/2SQgkkDXiH\nC0lsc0hPOTiMFhf5mDTpddwWBUFz6/DboHpsTx5h9CmrxQWtrsjSPnl+QC/vo8UhHzw0mEbTHw+x\nIQUl6RU3l6VbG66JyNYHnO/I2H6D5kSOj9iM2YbPEjatq40IbsMZ+eLgB1yIEnbrJFLC//Hb/xmn\np1O0NjQetLZ470iSjPVffBcnO8Klizt36Ew1JSD7GbLfQ+zvQrsgWndLkGkMbM4yQgjotsKKHrPK\ncTKdMZ0uEViEuplPkejUfBspfvw7HiqCDzgLQQe8DddFoncReREiHhZqYwmIznYg2kPgNsVlQCAx\n3Rz0Ib4lKXkrXfSC4MR1sXk9b/8SsrNBMsO1RB3i6yA4RAjkSlKvl/SIOFzTOlSWohKFUIKTx29+\n5jh8acFj+wMG1Qw52UGWGXt+jUtymss5pqlo
vaLyI5LqHOkXDNwZuj9GKkVh/pyLNy84LFvS2Sm3\nM81g3EemGYd7d0l7O8hsAPM1aVVDMUGWGfNkhAue4AxlH4KecmEKdgeBRozwMkP0+qxDinYZohxB\nPmakPKPLS8zpBUJFFtQXbHZwSuDUprDpeugbaBPRicv/5pcKARM8MvhrQ8BUCqwIGO9ovSfDk0pP\nLhIK4SOhLpZdOBfQPkSiWwdNO6dxBMoQqIMj8zKqvJyLi5fzGO9Irw3EHEZEMpfz8UG2ISYBpyEu\ngCJ4XADjPV4EXIiEuiAkSedNpIgVvGwNermkcA60QXpPayzC3Fxts7XVZ7A1ZjgaUY4L9nbGfPvD\n+zw9v8S4QFt7+mkS7dpF4Kdna4wL/PGzKT8+X7JWGZ89mXN61vCNB/scL1fIJPDBnV3SPGNvmLEt\nHCeXC1QaIdXz04rhsGR7O8WGFuFTlBCUI4UNiqXrcXSvx6SsSCuDQFKqMRmS73/n/0GvGqzTeKNp\ndFwwq2qJ8DVJVvDs9TEZhlJJLi8XjNLAcJTFY9CNxiYWKMNBTDqXQiBEgpLRHfnO7dv0+z3SNKM/\nHLJ/eIixgRcv3uCcZ71eY4whyzLSNOX4+ITxeEyWZbx584a9vT0uLy85OTlBqYSnT59yeXnFbDrn\n6mrK9vYOQgqSJEVrTQiB5WLByckJ8/mcN69eIxCUZUmWZ1RVhbHR1PLs7Iz33/+I+XzFcr5mtWo4\nu6zY3b+N9QbrK8ZbfRbLC/Ls5q1QEBQ9KPK//jxuYpKcjVw0JYgSb2JBI6VkXVnOzi5wLp7kpYBa\nQ7VuIwIS4iaYphILEOD87Jyf/jiiGd//sx/yR//vH/PpD3/IT3/8I773Zz+gthCsZ1nH4Np6dY2v\nA1E0EH9V/JwTXFtXxWiJjq/RcW3CZme96fBsYtfD2w3CbzaJDe/hix8jusTwENVaNnASBHMflW5C\nxKInEdEzbKgEo0SQqbeGbwB3Ryk5hlSkBBHRCe1TdN12h0uBcxWJGkBQBJHTGM1q9RQlJMfnTwBB\nXm4zGuzhRcrO+IAs7bFYV/zqr73Di9efc3n8glFqGUoL5uaxJKJD8v31OMSPnYuFondcF8ohxLw2\nubknHWcn/nzofsfb8fSbIom4rjovceZ7/A//8p/w/OUxlY4q3fGgwM8v0MfHIFJca/DGRMWW93gX\nN22ZKsRWPx44jQUZ+Se+qXDWgTWEkDCcTBgNe9zaGTLs9RgPCqwLmPZm4xOuqwkiTzmI6+cpoifg\nrMBbgTPxY+9ill0gLnGNtkgRrVOs9R0aHqVVG++jxsdCMbA5uIS/NF9ho45726YKXyg4N2PuXWxN\netcVZA68tdgQuGgMy3nFLh0vKwSWqxZjNNY4plc/G/360qNp4j3G54hWspVqZnZCfwCrxYoi79G7\ntYtRBXl/SF6WmLIhLUao+Uuc+iYH5ozt0S7TcoQJGct2yXBvi6nrE5IlxeySNZJhaAhphhcJhXGE\nTCGq5zTmAVkpEPacut1HDQvsi2e0vW1yfUJSBozaQaQFa99g6jlCRGhQdMSsWMUHEicwb5tc8UZv\nHCLDzTk81gdSF718wJPIeOJTIpK4RIDaeYR38YDdOWAJ71He4kIaq1gV/S+scxRSIoInAoaSBI8J\ngSTEVlwiVXdi8VgRBWaOKJ/0xNgJRED6mLcj5WZdjYoO66OqRaKwIuZveSQtHukMzlnEfE07rPAq\nI3GR7xTtMm927e/s8+TJC5JccDFdMJ2uGfdyylLx/HzFN48m9K3h1lbO6bzmeFoxq1uGRUJbO5JE\ncLDVh+D45PkF+9s9/s/vHfNPf+U+bWu5XGuG/YzqfMlwNGDUB+Eyag2DsmSyX3J1taKxjr4IoBpW\nVYq3CaHf0tuXPP3hOf1kQKpSytygQ6CQQK7Is8DLsxW5SvESTO1RSYrwkajXGof3gdm8RY4mNxuc\nIBkM+ogg
MVbHqJDgKcs+ua84OTlhe3vCyckJbdsyny/wLnoOta0mSdNrSHvjwyOlpGka9vf3uby8\n5OHDh3z22WcMhyNGoxF1VTEe9bh79y6fP3lCVVVMZzNUWqDbloODPay1XF5e0Ov3KYsc52DVaLx3\n5HmfXtkn75V8/vlTCAn3HrzLyZszytEeqhiCqFFJxYuXf8H+3oj18mbqtc21nMWT2+a6JpN201DK\nzsAtAA6cgiRLUE1DkkrKXkJRKPQqtrAAlFII50hUJJF66/AyFkR7+/vkRQZ8l1/+lV8kG/R4/4MP\n6PWHCGf509/6PYILnE8Xncz8L7OLZLfWOBdi/qLvgIMNEfYLGwt+w8/bEEP/5lcgbtD+C1yL69/7\nxW8KXW0E3X+yKdAkKxNoJIxljOXpXzsRvw3WLKRgnMYD1MrCvd0BqYh8J0kNPie4isWq5dC2ONGQ\npznWLXFCEvSCNC2iylFMUUmDSHexdg4emnYBXpOkGUnS5xtfPeA//d3/kd/u3kKaZJT9Pv/1f/ff\n32yA/krLJvgQ+ZGyQ1e6FPVrAnm3+Ue+ShwwsSkkO8RMCGLKfddfiXWPwOtAokCEH/K7f/IvIP0v\n2R7dZ2w9zdkbkAptLNY0SO9j8WkM0ndjnSbw7n3CbAp5EV+60TF2wgucaUhGJeV4myIvsLZlMuqj\nreFoMubVy+XNhuZ6XohOGh4/GZAdh6bbv3j7PhHRkNDR+V75GGHkRDQ+1MZ27UKP97F7ou1b1MZL\n2GQUBU+Ml+h8kSK/jU6WHjqvgK643BCorxWR8X6IjhJ0t1fyFI8i2i6Y1iFHCbo29CaS9EsQ9y8t\neHrrM+ryiObNM7JEk+08ojKwfXCLZvaG0d4OS63Ig8PWM8p0yOvLU7baJV5r2H/I2gb85afMkwOG\nInCSGvb1U8JgCyclwzLnqrmghyRJI59C+RaR3yXd6mEXS3ppinULqmfPGe3u0+/t0M6fUScZ2eoV\nKg20q4TLVU1+8gK5ncfbJgPCC5yQCOkRQiJctDHyMi5UETETXe7K3/zy3uOk7Sy2PdoKcpVQW4d3\nsYVVINFCRDIbAi8FWV6wOx7RVhVeRrg8qMitiY6SnsgoimWyB7yU4Cw6xJwsIaKplutaWBYf1VYd\nqzH4SF7ezCUniCosEadyEJ3Qzwu8sCihwDqcrqnXK7i6oDcYYaO2IIbI3fAKszWLynC3GPHgg0N+\n/PQNL95MyYJgq18w6mWIVpBLQYlgMiwZ5orZuuGq1hRpyjBVHC9qdkY9fvndAy7nC/7X7x2zNUkY\n7uacLCu+/vF79KThZHFB0sspvGC6WjFfToEE7wK1ykiTFCE0MhF4l3JyUjM4SNHLNcsrBcHSTzLO\npxXWwdaoz96gpCgT6towq9bsDcdkicGTooIky1NWVvLu1s0IlnuDgs+fPuXug/dZLOakec7i/2vv\nzHojO9L0/EScNfPkSiaXYrFKqlYvknqZtsdoYMbw+M5XvjbgCwM27N/g/+P/YMAYGLAxV/bMYLrb\no+nukdRyV1FVxSKZZO55lth8EXGSJbWkGerOjfyABPdkZpw4EW983/u+33JJYx3j8ZA3V5cURU6a\npsxmMx49OsVowXw+Z7lccXQw8o7ZZUmSJAwGAzabDVprxuMx/X6Ply9fBmKroalr+v1jnHOs1yt+\n8pMfUynN9c0t8/mCo6Nj4jjhaDLhYDzi8vKSoltwPZ35TVtKNtsNuIg465JlQ7KsoCrhYHLOp3/9\nVxz0Epp0w3CQkWeGcrPmW4j7EFjynsQqQR3s83dMg/B8Sns331qxWzw38yUbBS9fw/XMZ3eMYWdW\n2M0SFputN4azvj1CL5OkqeTNq9cQ+6XwxYtXZN2Cw/6QOEm4eP4iNOJ1fPCDx0xfOy4/m38BZNh2\nQRYeQMXByK7tWN42EBUSL083b+3NDwhr7Y6M3cbbr8ODrPsvWpgT/Gw9Xx
G/t9zguBGQxvBhFggA\nAqSDLHIUsaCTCFhY7raGUkfYVGM0yKiBRvHq5ZLz7R0K7x1lqWn0lpPJjyhribYxvWLAajVlPb9E\npgVFUpDHBbWtkOKENLnl4Pwp7//xIX/zN7cgBNooVg/0mfFvLHi74Nc625ZJguuyCAo5KX0jzSjy\n6irbkmydw4T2B7uMjoNY+t+1kbdLUI3/P7USvsR18+f81//25/yPv/gZ//pP/zN/8pv/iYpPvels\nZdCNIk0S2GyJen3ivEPxr34Gg8ITmvMCR4yLHFJt/YtOO+CgriuGjx6RfJpjXI2QMVVZsV4/DPC0\nKcddB/Lw/ggHB2UsupK42P8wssLzjZ3fR5QxbJwiEpLbdYWuHUaJXSbIGondCGpjMbXvV2ZiSeQJ\ns35M9b0STuDvKas878zqe2BjQ2myTQpZws9TD0hrY6jr2gt+tEELsApEEqEqxejw61vafOPRPe8f\noutbRqMutijYWE0nzYm1oN8Zs2lq6uUlZV2ySUa8vLkhEnDST4jyDDZL4n7O9uAH9MYTkqRisHmN\njSTJZuHN95opeV5g4hG2XOHUkp5Z0tlcUV7WZElGJx1hbUw2PqKsGxarGerpH9E9fIf08F2Kx/+M\nlXWkzy+IdNjQBQgb0UoMnfAmSUL4yQr3bSdwNvgO/OPDOEOEJMLirO9RVRuDcZbIWBoEsfBGgB0B\nwjis0gg0jarZBGCirfUEYX8G8CjZGVI8f0cKhzY+lZjgdpb4MpxWnLOe74PbOTL7MlV7wmsvst0t\netpoL8v02J0meJBKpXB1g6gUdV2hdINSNU4//KR+tdxycDxgui35u19+Rg9LJ0t4MipIpORXr+bU\n1lJqxa01dPKUx5M+7530aZShwbGqGn5wdkCaR/zm+Q11bTkbJnTSlMrWqGFCnjneLFbcTRXrVYk2\nJZiGTpbhjGXQS3AoqtoiXII1lrrUJEKSIUnilO4kI8kkd8uK4bjLeNQnjqBsFM4aGqXIkoz5quS2\nitnohq60rDYVkXCU64cRLK+ubzk4PmV6e4OMIuaLBduy5HZ+x3o54/H5Y548fcJms+EH3/8enW6X\n1XpJt5uz3W4otxtWqxUnJyesVivu7u4YDoecnZ3R6XSoqhrn8CWx4J7s+bPeY+fnP/8FCMl3v/se\naZpSbksvc18uUUpxdnZGWVUe7AgP0qM4ptEKbRyLZc2Liyvmiy0ff/wptlrSywy9TszB6ACrJUnU\nwX6LzCCxoCmhaeDLsGBHKg0tVN7OZmijGfRT3nl6yPeeTUhjEe4P/zdVrXx5A7+4SzxvWDnH6PCA\n8/MzAB4/OeHZd59w/vSc73//+3zwow9AgFKGujFEX+EPI6Q3uNilbNqzU3sjuvDPQmpf7G7Oh0VL\nDG2f2oV6y9sfd49QYtidkEMJQbTZi/BaGiW4DI1aBQH0CEgkFBEcFPDzixmV9WPl8OtPJOCdZxPf\nMy0tEDjSOCGKB2y2C6rqjbcmqNd0iwOUi+jkh8hkRLd7gGqUL4+oDZta8B/+7Z/5yx9DFLeE8IeF\nbUnK4b215SjfNiJwVlyYMO0laHf/oCBqyd9+zNrxE8H3TJBEkKSixQoIB5taUpWwWfwVv/3sL7i7\nXVFr69VJWnsFnDakIiGOE782j3q4ukEmWWgsrsBpiDPQDc4arJAk2YB6tWB0cEiRpd5lXDpOzk8f\nODpvI2Cxe9sQiNnO0yxs4OQ46yXq1vi5U2tLGUr3xhhsYwjnc2zg4hjjMOFz28rNWy6bxXN4HKF3\n2f2jndct0dlaEZ4nlCfb57J+F1PW0R+PqPD2K+38j4QjSiRl9fWc02+cVrKqOYsaEpciZYezIsZZ\n2FJTqppyY+jUNUm3YL1awKBLX1o+55jZtmJtM6qoR8eVuGqNKc6px98jzVKa/ojldI7MDnHNFrd5\nRTeO6IguHDymGQ445reQjTAkGGcg7z
PoRGRxiXrxGesmo5I5U9EhOn9C8vIlUSqIpG+siSAsRAHs\n4Lz7ZrgS3ptHhvT07y9k3zh9rKDS2iNja9DWYjE4q7FxTOEMS21QxmKcwUqHA/CXNwAAF01JREFU\njL0fQ7fo4gLj3MvbBWk4iUgs0jqMtcTCkTmfjUm1pRS+9BEJg3GGDEPSrgzCobBY50stlXWIsCkI\ny87VWToXGpE6L9u0jtT5k521Dls3Xr3VaP/QgV7/wOgVGR+cHPNnP3nCz376hLODLpNhzlY5DoYZ\nxSjnk+sFf/npDbPphk8urvnFiwXaSN477HJzNePp2YSnh30maYyIIxrjiLMIY+B0MuLpQZ9VaVmy\nwXYa7NpyNzOUVcyrlw2dLCVLU4SN0MrQWEelG6raIKRgvqgosoIOktnC0MkSauVII28DoIxjvWwY\n9HO61nJ8NGCYKBIT8Xqxpq4aJsWQN7cPO432iq6vizvLbLmg1goRR8g4Zr6c45zlk08+JYojZvM5\nn3z8MWW55ejoiG4n5/T0lKdPn7JarZhMJkwmE4QQzGYzjPEmgsvlkm63y3K55OrqCuccw+GATp7z\n6NEpm+2WV69eMRwMkFKgtGK73bLZblmv18GRuSSKJEkcUzeKplFcXHzOfD5HKcXnn79k0B9wcjLk\n4sVvWa9WXL25Jcu6ZFmPKHqYXB8A64hifMfrL0XLORBhw2rPKNZAlkqODvucHg/pprkvGznQRjDo\nQJxlJHHkGxj6EwNKOWK8WVyaehLoYlGxWZcYo1HKZ8esharSVJXPEIZXcf/CnPVlisgv6Dsucmsc\nhPPl8/D93Wb+YNDjwik4rGFhA9staV/1CETZ9rEDR6Fk46zjRsNLL7jbteMBv4F0k4iFapitGoyu\ncNrgJRTw4QeP2FbXpJ0h6+oWbRpf2rMb+r0nJElGmvYZ909Jkz7WRIjIr1HdbERd35GnPcqy4k//\n5D1/LQMxmweux3DfHsJaew/qwgZrrQj+d63XDrs+TdaEDt+tWii8fxfKkR6QhHGWHpSlqSNK8ApS\nAdtSUjfw8qP/wjqd+EN2U+9KZKKpEZH0gPLRANEfQN6FNA9gxyET3+7FYiDrIbCITGLqmkhmzOdT\n0rzLerUmiR8mlID78WjVgS2VzLlWxdaCHOn5PMZzeaQUNNpQKa9gVdYEYOnQbYYn9ClzeM6YM97r\nqE3RuNBfq+Xk2NBu4h6ch/5ZuF02yLTE85Bx88DH0I0kjXNcAJXxPmvaeEW0UeYbWz9+M47OMrb5\nIVtbQ5rT6BhnK0ztmK23JFGJfPSE9bahkwuy6WtWKmGSKuIsYfir/8705QWynMPtc1RTYW4/RnUO\nybFIM6OrNwxlQpF1uZ68R3TyiKipiTpnyMc/IZaKJtJ0IkdXr1DFGBf1ifOG5uqXOLUguvw7ZLlA\nfPS3yDgO5mT3zpotIc0fzjzEEYgAikKS74H5dx2o5knr9+AM2kCEJLEKZRyZsxir0U7gFMGJ0yHy\nFGxQjzm/uVbC10+Ns97PR0gavHIixqEiiJ0lwfdYkkJihHcwMoHcJcPpTr+lQnPWep8j50jxwMfg\nJ6J2jkx4xZa2xsvVjUKvSzabLUJVKNMg1EMZTrDaNPz9q1fYtEvvcEjv0RHX2y0fv55SrWu2i5L3\nj8f89J0jjg66DDop0tZcrEr6/YJ//sMnXN0uubhbczwc4Iyl1IZ1ZZgqxaXasOlKVLpk3E8Z93oM\nT3oU/YRUZvSHEVkR0aiGJI0xWuKU50kNBzmVgU4nYV2vmG629HoR0jl6UYRSGusiRkVGp5+xruDg\n5IhRV/K7qzW9YUqv6GFMxWev3+AtQx8yNiVSRoxGY5I0ZTgaMTmaMBqP+OGPf8LNzQ3dbpeyLLHW\nMhgMODg44OLzC548eQIIrq7e0O/3MYHQrrWm0+kwnU65vLzk2bNnpGnKeDymKLpo7VVXSeK5Y928\nw/
njc9brNUmSkCYJVVVx9eaKpmlI05TzJ0+Y3t6yXq3oFQXGGJRSdPOc0XCI0Yrp9JrD0Yh+0aWp\nNcdHj3AkCBKq8uFAmchvTFr//pxrWzFYPPG+BUDW+ZYUN9M1z59fc/VmxrZSWLwEO0+9QVlZG+rW\no8u4wM2A2fSO3332HIDZ7R2zW6+GuXpzycXzC1+esjDs5XSy9lq7t18YUeRFAFK2i/h9lgFaLx53\nz1t46/T7j442c/F7ipYHPNUODLXZModWjqmG143YvSsBpAK6keU4hY9elyil0NZhRYwx0Fws+Omw\nj1E1g3xCvzshjQoEA6IIVFNyefdb3kw/Yzg4RkaCzeqOsl4ihCWNOwgZEaVDRodDHp8Xu8ydc1/z\n+r8h/MHVZw1aEi1Bet4CKc85EVgdms7a+5Kkcc6/v5ARE23pxbkddhT+/EySCJIMZGyD84lgsYau\n/B5J3gcLpqr9ZTYaYS1CCqQ1JB+eg0wQ1uK2a08ZSDLQCr26Q0RBoWYUUqREWZdOnnJ4MOL6zWuS\nJKJcf337hK+97i3Q4x4vg89gGdsCkhZ0hu+1pSXjUKYdZdDKeIsH7Xk+VrXKLj/Wxhcu/N4a7BJw\n/uzsCcneUf0LJHvnR9zY9nrxVtYHwsaOwZEEwY1SdrdeLG/XQRDw9XfDNwKe1RIWjcQmOXnqS0Wp\nqtDLVzw6PWQkDUmkqTeazc0bRLNk0yy5MiNGh8fof/lviPSUyCia0WMSXZFlxyT1DKUqiuIRlXA4\nl2CkYnj998yWKxoZ4fSWjethXELWbJBJjhIZ5uYFVa3Ixu8y6Q9JygV6+IS+jFlNb99SF0ThJnC7\nBqFGeE8OK+7vJmcFRkZ8xYHym+dPWHVUMCywQGI0whrqxiuilLGkIsJog5QWaUEojdSKRniQ4WgJ\nzSGrEJj8yvlMjxTWd1EPpC/rDKlzaGv9Zhdke5F1vg9Xe/c65yXx+NMLgHCKKvB9otCqonHO26Fj\nqa3ykv2mQTY+05NrS/MtMjzvnk5ItUM0Daqy3M42fPZqzrvnR9Sx5IffecQgi6m1Y6ktJ5MBnTwj\nEymzsuKj1zO09b1nImno5BlRJJluKw6OOgyGfZxTGKtJYp8BWq9KnHb0DmIm4x5X12vyIsMoRx4Z\n+r0EJy3aWOJIsakabASdSJDEGSJOiGLB2WHPuyuXinWl0ELx6//7nKZyjHodPn55BU4z7MR0M8ns\nbvmgsVlvt9R1jTGG8XhM0zS7r+uqYr1ec3FxQZokRHHMycmJr1mnKVVV8tFHf4vWGikldV0znU6J\noojpdMrBwcEOBBVFQZIkDIdDjo+PyPOcu7tbVqsVFy8uuLq6Js9zoiji+nqKlJL3338f5xxVVfLr\nX/2K7WYLQjC9nrLZbDk7O/eLy3KFtZY4TjDKkHd6NI1jW1peXlwzvZnT7XQfPG+Q/kQZxW/Tgn20\nJzdn7W6zB8IiabHWsNpUaGPv53zsZdZ5EnkehvJMFhsWVaNhcjTh8Mj3Q5scjXnyzhknp0c8fvKU\nD374vt+8jJ837ktLpiBkZoMlP8KXPkLroZ0x4L00NwAN+a0SPLu/+TJhud00dpkN+NJm8tYJ/63n\nA/+ajILLxvGyEeE9+B/FAg5z+Oy2RNmEWuvdTjRbLjka5DyWkHUeM51/jjEbNs0CZQQyGpDFHaRI\nKasVCOj1jsmSgk42CYZykkgmkJzwn/69N4yzPJzQfT8mLpDFv1i627kqh0y6MoSO9T6D0IJIyT3Q\naeX7u7EMH6X0/d6SxJGkgih2iNjPp4vF0vd/qhqk0gitiKoGaQIh2jryJ09wxvoniYOjtG5w1hIn\nHUTWwd8IEUmnQyQjlssVTW0ZDCdEaUqneNi9FTpT7TKBLryZnbLQBvO/kGmxxnkwEzI11nkzXS/X\nN9igGHY63EtGeCsA7rN0xvqeZl6O7ueYNQH0BPDjAvihLYNZ349r
B7xMIDN7/0YE0E9ilputd3Z2\nnjTdaINWGt0odPP1rSW+EfA0zS3CKJKsC8kBiTW4NCcanVDeTKniPtYakuiWYjgifvpDCjbUs2uW\nsxny5hVJ2qc3OmWSGIRdk2YSmR6gSkWdFLj+MbIjiZIRUZKQLF9TV5o47xCbK4RaIodHzNcNkboj\n7Z0gTcn8bo0tJtSyIJ89Z11XYHypwi88Xpm0S2H64lbw5PGZHc9xEcSAeyDfQFmv1LLWoowhspba\nWZTxZoZS+ayK1ooURyMsifQy8SKWJJGnSStjvaIiSM1jAXkgmybCO11WeL6REr7JqBHQWiXW1uCc\nb2TqcGT4dhEa3xW+dgbhvGFYYwXC+lKKsobGF713J5gY3zZAa43SCls1GKsfLLsG+PziJYk1vLpa\n8OZ2QSJinp1MaLRj1EmoGsXFdMP/+vSSVMOPz444HqScDAS9JOX9R2PORx22wvGLyzucgNW64cPH\nE+KOY1NWVNstVWm4XWi2m9KT7IxjkKYYseVwnLNeaISV9AYd0lwSi5giM0gjiE1Bss2JhCCLYdzP\nKboxK6PppAmDIDk3lSaRMbNVRdFJ6cYZlRNUQSGwTh7G4RmOx6zXG2Z3d7x+9Yp/+k9+ugM819fX\nvP/++zx79ozVek2vKGiaBqUUJ6en1HXN0fExBwcHbLdblssl5+fnrFYrPvzwA+I4Jk1Tqqri9vaW\n29tbfvPrX7NarViv1x6gGMPh4SHWGvI8Z7lc0O/3GQ1HzOeLXWZoNBpzfHzEoN9HKUVR9FgtV6Fn\nV0Wv12W59D21yq0iirt0izHPnn2ItRGz+cOAIICrfU8f1XxFCiR8aYz58hEV8KWwo0nBwTglToPB\nnPOeKVkWE0cRnTzyfDUBSQTdXCIjQa9XhOfWaOXd040xqEZhnMA0sF5sPQnz7dcLO7U4+PvZ7jYV\nn2EQAohARMLf4C235qFjw1t/Ez75AjBoQdA/UBH6quyJLyvAVDnetJyeAHzyGF7eLJiuK3CSpvYu\nhcmgA2/uOO2u6Kopadqj2x0wHr6DlBHXN39LLz+kNzymrNcIp7i5fYWMM+arK6SzKF15MKkUf/an\nP969j2+T4XHtxQjmkjZksUHsJPvGBGJyKHHtMhvsLkuoB/jsoJB+094BBXyGR0YQhyxPnAgiKYhT\neGM+pa42sK08Z7NRQbjvM/ByXCC7XZDSS9IjCdbipDcGdUm8487hIMoKjG7Iuz2iLGO1WVFuaj79\n3eXDxgZxDwBFAICBP4PzdIdG+ayXCzJ+04JB5w1ZdbA4MMaGA0AgF2vhjQu19X8fgIoJjQ7alhW7\ncpkBAqD5ovGgB1u7Mpkm/L8gk9f+Is8rv0clQBkMM8vG4SJJnKUk0dfv5d+4y0dZxmh0SCKEV+yM\n+6zn10QuJnYaaSpUEzEYnYLsQLWiO3mPfi/FpRlaCAZmzcJY4sE7GDliLcdcz+d0jh6zXjcIo6hs\nQiRqmkYS9fs4GZPrDbdVgs4GXN3MGGQlpphgo4TkYELXXnOxMNhOF90dI1/+jjgsRg6QtlVggRCS\n1ne57UnlrPQ4PpS0nDAPmkAYiK1BWE0iHFo7YusglIYaIYixAWgYhNY4bb3DsrE7XwYRiXtvBus3\n7No6UoTn/+C8lM55eXoKKGvQRmNd+zMbpPjsnJcj60nJrYW5ILS+wEvjYyfIjcNa5XskhcyQMwqr\nFaLR0BiassJ9A2L+unjvfMLRZMBqveV2eocoS5R1bJRisTXMS4VII945HZLkMbPtAuck863i7LjL\n9WLFstZI3bBdKKbrkkeTgjhzZIXjcOwQkWa7LanKJWnHkaZwkEleT5f0OwdImVHWDVvVeCdZ5VAN\nvL4pmZwMGQ1jor5j0IuptGW1KdEOVsuKvJexKRs6RU4nSSi6KdeLLd0IJsMueRxzPTMcTCaMew+T\npc9Wa7S1HB4eUhQFv/zl/yFN
UppGcXJyws31NZv1mneePuXi4oLLy0vOzh4xn888R0JrtDZIKdlu\nt0ynUx49esTPf/4LVqsVh4eHO6ffXq9Hnud+gUgS8jxnuy2RUUSv12c+n3NwcIiUEcvVim63w2x2\nh7WWqioZDPpstxuKfh+tDdc3U6I4ptPp0ulkSCkoej2KwYAoTZkcPeH15S1P3/0u50+ePXjeiMB1\nEV9RYr7HN/f8Cr8JCUScsFkr3+zV+pKw38T8aTGWvkWLDeAeC7GUlI1lNp1x+eoNAC9evObN5Q2f\nv7jg1cXv+M3f/T0m3Dte7fn7S6bAnzIRnswrQypGCHGvHrA+o+CM33FabseDxqb92IIl98XHrhz1\n5aFzX/zY8ja+HC5sYK9rx1z7TV/i/cUmPfh8DsY2ocSgqRX88i9ew4sLzqMKozdMb35DFhcIGXF4\n8C5J2mez1URSMFtNEZGiWl8QxwJjNTLqYlxJYyN+9KPvMxhnD0eCIezbm6e/HIGkfV9i3IEh6/ld\nHgC9Vepqf4+Wey7CtQzNMm1b2hLIyDdhTjKIEkecQBnB6+UtVmm/juqwtju/7qbvHsGgh6vK+wsh\nvB2JF9i0gNhAlCBw9CaPwJUUnQxTb4kjR6/7sP59/v24wHFqvXJCuakt6YXx0PaePKyNTxUY63Zm\njs4alPEVDBVAkgnZIGscRkuUEhgtvbeQdiGb016fwJfSYtewtR139xZY9aXHAJR2JbLQkqzR3h5E\n+/+5bQxREqMbS6//9dmvbwY81rJWMTLNiDZz6ukGVxyynd3RDA+pkx5qNaXjHOu6pApMb1Wc4IjZ\nxgVbCV29ZVvesXVrcnXLeHuJNTWdXoKsVhwOBySDU7qjPp0kZ1BYytG7DNUCIQyFW1H3zkE0ZKZB\n64x4/A6PxBs6AUVWzy/9JBTsTL4S6031HH4hEjtPChCRVyjFO07Pwzo7C+lRrtK+tBQbQ2O8EWFl\nNFZXviYbOucK4yitxhlLGse08EpoSxQmX2V9LyJrfXYojmWYYNabVjlLbY3PSOGzRVgPipKdosCv\nfs5ZZCBrW9f+jj/x+L/3XCGcv87CGu9Cay2xVVjrm5uiDOYr+BT/UDTzFetFSRFBZSwXt0tm8xl9\nIqKo4fnNmnEW87PzA1ZbRa0ch/0M4yRv7jak3YxBLklkxFEv5zsHBZNezitniANXKUpyRGIYHw7Q\nJseQMVeSUS9lvVnTNCV5LjBKoGtYrx1JLJiMOpimQokYjKQ/Lhh1MuJOTp5JlBXoqqEjJFWjUdaR\nRRnDbsyiNEgRkaF5epxwt1jC9uZBYxP3JmyVYLlc0DQ1h4eHGGMYDgdcXV35LuZxhNaaoig4ODjY\nEZLTNCVNUwaDAYvFgqOjI77zne/w/PlzHj9+zOnpaShJVQghMMbQ7/dQSrHZbCnLiuFwSNNU3N3d\nUZYVt7d3SCmw1nI3m3F0dIwxhtFoRL/Xo9yWFN2O77EVJzw+O6OT5yyXS/r9Hlknw2IZDEb85V/+\nb4piwCef/I5fffTrB8+bHZHyK6Zcu49bbfkCB8aBUZpuHrPZKO5m1b1LrvUGcxp/mFDaZ37bClMs\nYXI84d33ngLw7rNzzs7PePrsKe9857v80R//ePe/hHAkyVeQqYMCxquiwucmZE28B6r/PSk8kGuR\nxLcZmxa0fAnsfCH9Y7/ie299fBvsfLnM1Zq8vaoEdXjdsYR+Ap/cQa0tRjV+k1OKWvb47Scr7MUL\nRrdThoN32FRXWKsot3NeTT/F2ZJar+nmPdL8kP7wu6RphnE5cVwQRyO0rnAy4T/+u38RxvRbjI+n\n+WJCKcrulD6tIev9xqoDYvSbvQiKq/tBazuBO0DKIKR2bfnGP4+MBEniMztJ4itUaQJ/c/UpThlc\no3eZGhlFSCnJ33vsmzerCmc0tmlwUeKVsKqBOMNZDQjfMb2pUdWW4cExMYKT08ckWY/j479+6O
C8\nzXN/q+Tpbzgb3JJNICybFmTssl9eSGOd75autb+XtAarJVqDtpbGuMCNEujA3m9LWTaUEN92ct5l\nctqsTgu+9H17C9dyekKhIZUSrRok0Gif4VHGUVcKrTTw4muHQbivgvr72Mc+9rGPfexjH39A8S37\n9u5jH/vYxz72sY99/P8Te8Czj33sYx/72Mc+/uBjD3j2sY997GMf+9jHH3zsAc8+9rGPfexjH/v4\ng4894NnHPvaxj33sYx9/8LEHPPvYxz72sY997OMPPv4fdIUDdUTUECYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "olF4PpORpCTK", + "colab_type": "code", + "colab": {} + }, + "source": [], + "execution_count": 0, + "outputs": [] + } + ] +} diff --git a/resources/examples/ipynb/models/reformer/machine_translation.ipynb b/resources/examples/ipynb/models/reformer/machine_translation.ipynb new file mode 100644 index 000000000..8e4745f5a --- /dev/null +++ b/resources/examples/ipynb/models/reformer/machine_translation.ipynb @@ -0,0 +1,380 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Reformer: Machine Translation", + "provenance": [], + "collapsed_sections": [ + "udDs_biH0n5U" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "udDs_biH0n5U", + "colab_type": "text" + }, + "source": [ + "#### Copyright 2020 Google LLC." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WPY-OyyM0pSs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "psnUF-8c02o_", + "colab_type": "text" + }, + "source": [ + "# Reformer: Machine Translation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1lnRd_IoERdk", + "colab_type": "text" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8PluCmWbZIpJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install JAX.\n", + "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", + "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", + "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", + "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", + "\n", + "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", + " resp = requests.post(url)\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yiPdBenoZwH6", + "colab_type": "code", + "colab": {} + 
}, + "source": [ + "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", + "\n", + "from tensorflow.compat.v1.io.gfile import GFile\n", + "import gin\n", + "import os\n", + "import pickle\n", + "import jax\n", + "import trax\n", + "from trax.models.beam_search import Search\n", + "from trax.supervised import inputs\n", + "\n", + "from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder\n", + "\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "from scipy.special import softmax" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "uCX88z9iXB7s", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install sacreBLEU\n", + "!pip install sacrebleu\n", + "import sacrebleu" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FQ89jHCYfhpg" + }, + "source": [ + "## Load WMT14 data" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8S3h28Q9b_9B", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Download the newstest2014 English-to-German translation pairs\n", + "!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src\n", + "!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "CBv2SDnWZEI7", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Load the source text and reference translations into Python\n", + "refs = []\n", + "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):\n", + " if line.endswith('\\n'):\n", + " line = line[:-1]\n", + " refs.append(line)\n", + "srcs = []\n", + "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):\n", + " if line.endswith('\\n'):\n", + " line = line[:-1]\n", + " srcs.append(line)" + ], + "execution_count": 0, + "outputs": [] + }, + { + 
"cell_type": "code", + "metadata": { + "id": "CbYw4eMXZGKa", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up our sub-word tokenizer\n", + "tokenizer = SubwordTextEncoder(\n", + " 'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2NbOslppZGZ0", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Encode source sentences using the tokenizer\n", + "input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)\n", + "for i, x in enumerate(srcs):\n", + " x = tokenizer.encode(x)\n", + " assert len(x) <= 127\n", + " input_ids[i, :len(x)] = x\n", + " input_ids[i, len(x)] = 1" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YwzU64GmZTb2", + "colab_type": "text" + }, + "source": [ + "## Load the pre-trained model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VXjtCPxl3I82", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# We'll be using a pre-trained reversible transformer-base model.\n", + "# First, load the config (which sets all needed hyperparameters).\n", + "!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin\n", + "gin.parse_config_file('./config.gin')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "IediBe8MXyLf", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Now we load the pre-trained model weights.\n", + "with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:\n", + " model_weights = pickle.load(f)['weights']" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zY3hpgnI5Rgn", + "colab_type": "text" + }, + "source": [ + "## Beam search decoding" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fc_VlhrBYW0u", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up beam 
search.\n", + "beam_decoder = Search(\n", + " trax.models.Reformer, model_weights,\n", + " beam_size=4,\n", + " alpha=0.6, # For length normalization, set to 0.6 following Vaswani et al.\n", + " eos_id=1, # The stop token has id 1 in the vocabulary we use.\n", + " max_decode_len=146,\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "bynTpreMYXPs", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 71 + }, + "outputId": "cfd24e01-617b-4beb-a5f2-98a7ce2e1449" + }, + "source": [ + "pred_ids = []\n", + "preds = []\n", + "BATCH_SIZE = 1024\n", + "for start in range(0, input_ids.shape[0], BATCH_SIZE):\n", + " print(start, '/', input_ids.shape[0], flush=True)\n", + " batch = input_ids[start:start+BATCH_SIZE]\n", + " seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)\n", + " # Select highest scoring output.\n", + " batch_pred_ids = seqs[:, -1]\n", + " pred_ids.append(batch_pred_ids)\n", + " preds.extend([\n", + " tokenizer.decode(pred.tolist(), strip_extraneous=True)\n", + " for pred in batch_pred_ids\n", + " ])" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0 / 3003\n", + "1024 / 3003\n", + "2048 / 3003\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "c5Gq4qF_YY2i", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "37a5e24f-9264-4d7a-dd74-065758c9a7e4" + }, + "source": [ + "bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')\n", + "print(bleu)" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "BLEU = 27.86 59.5/33.5/21.3/14.2 (BP = 1.000 ratio = 1.020 hyp_len = 65943 ref_len = 64676)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "olF4PpORpCTK", + "colab_type": "code", + "colab": {} + }, + "source": 
[], + "execution_count": 0, + "outputs": [] + } + ] +} diff --git a/resources/examples/ipynb/models/reformer/text_generation.ipynb b/resources/examples/ipynb/models/reformer/text_generation.ipynb new file mode 100644 index 000000000..a28b37838 --- /dev/null +++ b/resources/examples/ipynb/models/reformer/text_generation.ipynb @@ -0,0 +1,546 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Reformer: Text Generation", + "provenance": [], + "collapsed_sections": [ + "udDs_biH0n5U" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "udDs_biH0n5U", + "colab_type": "text" + }, + "source": [ + "#### Copyright 2020 Google LLC." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WPY-OyyM0pSs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "psnUF-8c02o_", + "colab_type": "text" + }, + "source": [ + "# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1lnRd_IoERdk", + "colab_type": "text" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8PluCmWbZIpJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install JAX.\n", + "!pip install --upgrade jax\n", + "!pip install --upgrade jaxlib\n", + "!pip install --upgrade trax\n", + "\n", + "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", + " resp = requests.post(url)\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yiPdBenoZwH6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!pip install --upgrade -q sentencepiece\n", + "!pip install --upgrade -q gin \n", + "\n", + "from tensorflow.compat.v1.io.gfile import GFile\n", + "import gin\n", + "import os\n", + "import jax\n", + "import trax\n", + 
"from trax.data import inputs\n", + "\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "from scipy.special import softmax\n", + "\n", + "from sentencepiece import SentencePieceProcessor" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FQ89jHCYfhpg" + }, + "source": [ + "## Setting up data and model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9_OCIqghSyfs", + "colab_type": "text" + }, + "source": [ + "In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of \"Crime and Punishment\" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tYSOVGR47LVL", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Import a copy of \"Crime and Punishment\", by Fyodor Dostoevsky\n", + "with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:\n", + " text = f.read()\n", + "\n", + "# The file read above includes metadata and licensing information.\n", + "# For training our language model, we will only use the actual novel text.\n", + "start = text.find('CRIME AND PUNISHMENT') # skip header\n", + "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip header\n", + "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip translator preface\n", + "end = text.rfind('End of Project') # skip extra text at the end\n", + "text = text[start:end].strip()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "mMntV3H-6OR0", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "c8d4386c-cf5d-4dc4-92d9-24391fa2f30e" + }, + "source": [ + "# Load a BPE vocabulary with 320 types.
This mostly consists of single letters\n", + "# and pairs of letters, but it has some common words and word pieces, too.\n", + "!gsutil cp gs://trax-ml/reformer/cp.320.* .\n", + "\n", + "TOKENIZER = SentencePieceProcessor()\n", + "TOKENIZER.load('cp.320.model')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Copying gs://trax-ml/reformer/cp.320.model...\n", + "Copying gs://trax-ml/reformer/cp.320.vocab...\n", + "/ [2 files][239.0 KiB/239.0 KiB] \n", + "Operation completed over 2 objects/239.0 KiB. \n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 4 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HnJzxSi_77zP", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "f8b2050b-0233-40e4-88f1-e546a1541b31" + }, + "source": [ + "# Tokenize\n", + "IDS = TOKENIZER.EncodeAsIds(text)\n", + "IDS = np.asarray(IDS, dtype=np.int32)\n", + "PAD_AMOUNT = 512 * 1024 - len(IDS)\n", + "print(\"Number of tokens:\", IDS.shape[0])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Number of tokens: 513812\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bzQ7G9uGSga5", + "colab_type": "text" + }, + "source": [ + "As we see above, \"Crime and Punishment\" has just over half a million tokens with the BPE vocabulary we have selected.\n", + "\n", + "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. 
after it.\n", + "\n", + "We have 8 TPU cores, so we will separately randomize the amount of padding for each core." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PdAwmpS220ub", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "c0919b3d-4c63-4d2f-db44-3aeccaf4d966" + }, + "source": [ + "# Set up the data pipeline.\n", + "def my_inputs(n_devices):\n", + " while True:\n", + " inputs = []\n", + " mask = []\n", + " pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)\n", + " for i in range(n_devices):\n", + " inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", + " mode='constant'))\n", + " mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),\n", + " (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", + " mode='constant'))\n", + " inputs = np.stack(inputs)\n", + " mask = np.stack(mask)\n", + " yield (inputs, inputs, mask)\n", + "\n", + "print(\"(device count, tokens per device) = \",\n", + " next(my_inputs(trax.fastmath.device_count()))[0].shape)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "(device count, tokens per device) = (8, 524288)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ei90LdK024r_", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Configure hyperparameters.\n", + "gin.parse_config(\"\"\"\n", + "import trax.layers\n", + "import trax.models\n", + "import trax.optimizers\n", + "import trax.data.inputs\n", + "import trax.supervised.trainer_lib\n", + "\n", + "# Parameters that will vary between experiments:\n", + "# ==============================================================================\n", + "train.model = @trax.models.ReformerLM\n", + "# Our model will have 6 layers, alternating between the LSH attention proposed\n", + "# in the Reformer paper and local attention within a certain context window.\n", + "n_layers = 6\n", + "attn_type = [\n", + 
" @trax.layers.SelfAttention,\n", + " @LSHSelfAttention, \n", + " @trax.layers.SelfAttention,\n", + " @LSHSelfAttention,\n", + " @trax.layers.SelfAttention,\n", + " @LSHSelfAttention,\n", + " ]\n", + "share_qk = False # LSH attention ignores this flag and always shares q & k\n", + "n_heads = 2\n", + "attn_kv = 64\n", + "dropout = 0.05\n", + "n_tokens = 524288\n", + "\n", + "# Parameters for multifactor:\n", + "# ==============================================================================\n", + "multifactor.constant = 0.01\n", + "multifactor.factors = 'constant * linear_warmup * cosine_decay'\n", + "multifactor.warmup_steps = 100\n", + "multifactor.steps_per_cycle = 900\n", + "\n", + "# Parameters for Adam:\n", + "# ==============================================================================\n", + "Adam.weight_decay_rate=0.0\n", + "Adam.b1 = 0.86\n", + "Adam.b2 = 0.92\n", + "Adam.eps = 1e-9\n", + "\n", + "# Parameters for SelfAttention:\n", + "# ==============================================================================\n", + "trax.layers.SelfAttention.attention_dropout = 0.05\n", + "trax.layers.SelfAttention.chunk_len = 64\n", + "trax.layers.SelfAttention.n_chunks_before = 1\n", + "trax.layers.SelfAttention.n_parallel_heads = 1\n", + "\n", + "# Parameters for LSHSelfAttention:\n", + "# ==============================================================================\n", + "LSHSelfAttention.attention_dropout = 0.0\n", + "LSHSelfAttention.chunk_len = 64\n", + "LSHSelfAttention.n_buckets = [64, 128]\n", + "LSHSelfAttention.n_chunks_after = 0\n", + "LSHSelfAttention.n_chunks_before = 1\n", + "LSHSelfAttention.n_hashes = 1\n", + "LSHSelfAttention.n_parallel_heads = 1\n", + "LSHSelfAttention.predict_drop_len = 128\n", + "LSHSelfAttention.predict_mem_len = 1024\n", + "\n", + "# Parameters for ReformerLM:\n", + "# ==============================================================================\n", + "ReformerLM.attention_type = %attn_type\n", + 
"ReformerLM.d_attention_key = %attn_kv\n", + "ReformerLM.d_attention_value = %attn_kv\n", + "ReformerLM.d_model = 256\n", + "ReformerLM.d_ff = 512\n", + "ReformerLM.dropout = %dropout\n", + "ReformerLM.ff_activation = @trax.layers.Relu\n", + "ReformerLM.max_len = %n_tokens\n", + "ReformerLM.mode = 'train'\n", + "ReformerLM.n_heads = %n_heads\n", + "ReformerLM.n_layers = %n_layers\n", + "ReformerLM.vocab_size = 320\n", + "ReformerLM.axial_pos_shape = (512, 1024)\n", + "ReformerLM.d_axial_pos_embs= (64, 192)\n", + "\"\"\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RGGt0WaT3a-h", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up a Trainer.\n", + "output_dir = os.path.expanduser('~/train_dir/')\n", + "!rm -f ~/train_dir/model.pkl.gz # Remove old model\n", + "\n", + "trainer = trax.supervised.Trainer(\n", + " model=trax.models.ReformerLM,\n", + " loss_fn=trax.layers.CrossEntropyLoss(),\n", + " optimizer=trax.optimizers.Adam,\n", + " lr_schedule=trax.lr.multifactor(),\n", + " inputs=trax.data.inputs.Inputs(my_inputs),\n", + " output_dir=output_dir)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "y6VQkmKO3a1L", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "outputId": "3c933bab-b49d-4e18-caf6-3dfc3e220938" + }, + "source": [ + "# Run one training step, to make sure the model fits in memory.\n", + "# The first time trainer.train_epoch is called, it will JIT the entire network\n", + "# architecture, which takes around 2 minutes. 
The JIT-compiled model is saved\n", + "# so subsequent runs will be much faster than the first.\n", + "trainer.train_epoch(n_steps=1, n_eval_steps=1)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Step 1: Ran 1 train steps in 155.17 secs\n", + "Step 1: Evaluation\n", + "Step 1: train accuracy | 0.00343633\n", + "Step 1: train loss | 6.36618853\n", + "Step 1: train neg_log_perplexity | -6.36618853\n", + "Step 1: train sequence_accuracy | 0.00000000\n", + "Step 1: train weights_per_batch_per_core | 513812.00000000\n", + "Step 1: eval accuracy | 0.00340154\n", + "Step 1: eval loss | 6.36649418\n", + "Step 1: eval neg_log_perplexity | -6.36649418\n", + "Step 1: eval sequence_accuracy | 0.00000000\n", + "Step 1: eval weights_per_batch_per_core | 513812.00000000\n", + "Step 1: Finished evaluation\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EFnX4G6z3asD", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Train for 600 steps total\n", + "# The first ~20 steps are slow to run, but after that it reaches steady-state\n", + "# speed. This will take at least 30 minutes to run to completion, but can safely\n", + "# be interrupted by selecting \"Runtime > Interrupt Execution\" from the menu.\n", + "# The language model won't be exceptionally good when trained for just a few\n", + "# steps and with minimal regularization. 
However, we can still sample from it to\n", + "# see what it learns.\n", + "trainer.train_epoch(n_steps=9, n_eval_steps=1)\n", + "for _ in range(59):\n", + " trainer.train_epoch(n_steps=10, n_eval_steps=1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zY3hpgnI5Rgn", + "colab_type": "text" + }, + "source": [ + "## Sample from the model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ffeLSbJk35pv", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# As we report in the Reformer paper, increasing the number of hashing rounds\n", + "# helps with quality. We can even increase the number of hashing rounds at\n", + "# evaluation time only.\n", + "\n", + "gin.parse_config(\"\"\"LSHSelfAttention.n_hashes = 4\"\"\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "-BwIjdl6_2tX", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Load the trained Reformer in 'predict' mode\n", + "model = trax.models.ReformerLM(mode='predict')\n", + "model.init_from_file(os.path.join(output_dir,'model.pkl.gz'),\n", + " weights_only=True)\n", + "\n", + "# Sample from ReformerLM\n", + "output_token_ids = trax.supervised.decoding.autoregressive_sample(\n", + " model, temperature=0.0)\n", + "\n", + "# Decode token IDs\n", + "# Reformer outputs a batch with one item; we access it using [0]\n", + "# tolist() converts from int64 to int, the type SentencePiece expects\n", + "TOKENIZER.DecodeIds(output_token_ids[0].tolist()) \n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "s5f5QAmZBgPj", + "colab_type": "code", + "colab": {} + }, + "source": [], + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/trax/models/research/examples/hourglass_downsampled_imagenet.ipynb b/resources/examples/ipynb/models/research/hourglass_downsampled_imagenet.ipynb similarity index 100% rename from
trax/models/research/examples/hourglass_downsampled_imagenet.ipynb rename to resources/examples/ipynb/models/research/hourglass_downsampled_imagenet.ipynb diff --git a/trax/models/research/examples/hourglass_enwik8.ipynb b/resources/examples/ipynb/models/research/hourglass_enwik8.ipynb similarity index 100% rename from trax/models/research/examples/hourglass_enwik8.ipynb rename to resources/examples/ipynb/models/research/hourglass_enwik8.ipynb diff --git a/trax/examples/semantic_segmentation.ipynb b/resources/examples/ipynb/semantic_segmentation.ipynb similarity index 100% rename from trax/examples/semantic_segmentation.ipynb rename to resources/examples/ipynb/semantic_segmentation.ipynb diff --git a/trax/tf_numpy_and_keras.ipynb b/resources/examples/ipynb/tf_numpy_and_keras.ipynb similarity index 100% rename from trax/tf_numpy_and_keras.ipynb rename to resources/examples/ipynb/tf_numpy_and_keras.ipynb diff --git a/trax/examples/trax_data_Explained.ipynb b/resources/examples/ipynb/trax_data_Explained.ipynb similarity index 100% rename from trax/examples/trax_data_Explained.ipynb rename to resources/examples/ipynb/trax_data_Explained.ipynb diff --git a/resources/examples/python/mnist/dataset.py b/resources/examples/python/mnist/dataset.py new file mode 100644 index 000000000..257f93286 --- /dev/null +++ b/resources/examples/python/mnist/dataset.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Load pickled MNIST data.""" +import gzip +import os +import pickle +import random +import urllib.request +import numpy as np + + +def load(): + """Loads the dataset. + + Looks for the dataset at /tmp/mnist.pkl.gz and downloads it if it is not there + already. + + Note: The training data is shuffled. + + Returns: + ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)). + Shapes: + train_x: num_training_examples x image_size + train_y: num_training_examples x num_classes + valid_x: num_validation_examples x image_size + valid_y: num_validation_examples x num_classes + test_x: num_test_examples x image_size + test_y: num_test_examples x num_classes + """ + filepath = _maybe_download() + with gzip.open(os.path.join(filepath), "rb") as f: # NOTE(security): unpickles a file fetched over plain HTTP; use a trusted source only. + training_data, validation_data, test_data = pickle.load(f, encoding="latin1") # The archive is a Python 2 pickle; latin1 is needed to decode it on Python 3. + training_data = (training_data[0], [to_one_hot(x) for x in training_data[1]]) + validation_data = (validation_data[0], [to_one_hot(x) for x in validation_data[1]]) + test_data = (test_data[0], [to_one_hot(x) for x in test_data[1]]) + + def shuffle(data): + zipped = list(zip(*data)) # zip() is a one-shot iterator on Python 3; random.shuffle needs a mutable sequence. + random.shuffle(zipped) + return tuple(zip(*zipped)) # Materialize so callers can use len() and indexing on the result. + + return (shuffle(training_data), validation_data, test_data) + + +def to_one_hot(label, num_classes=10): + vec = np.zeros(num_classes, dtype=np.float32) + vec[label] = 1.0 + return vec + + +def _maybe_download(): + """Downloads the MNIST dataset if it is not there already.""" + data_url = "http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz" + filename = data_url.split("/")[-1] + filepath = os.path.join(_get_data_dir(), filename) + if not os.path.exists(filepath): + + def _progress(count, block_size, total_size): + print( + "\r>> Downloading %s %.1f%%" + % (filename, float(count * block_size) / float(total_size) * 100.0) + ) + + filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) + statinfo = os.stat(filepath) + print("Successfully downloaded %s %d bytes."
% (filename, statinfo.st_size)) + else: + print("Data already present on disk.") + return filepath + + +def _get_data_dir(): + return "/tmp" diff --git a/resources/examples/python/mnist/model.py b/resources/examples/python/mnist/model.py new file mode 100644 index 000000000..346f99568 --- /dev/null +++ b/resources/examples/python/mnist/model.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model for training on MNIST data.""" +from numpy import float32 +from numpy import int32 + +import tensorflow.compat.v2 as tf + +from trax.tf_numpy import numpy as np + + +class Model(object): + """A simple neural network with dense layers and sigmoid non-linearity. + + The network consists of `len(hidden_layers) + 1` dense layers. The sizes of + the hidden layers are specified by the user in `hidden_layers` and the + network takes care of adding layers to match the input and output size. + + Attributes: + weights: A list of 2-d float32 arrays containing the layer weights. + biases: A list of 2-d float32 arrays containing the layer biases. + + Methods: + forward: Can be used to perform a forward pass on a batch of + flattened images. Output is returned as a batch of one-hot vectors of the + classes. + train: method performs a forward and backward pass and updates the + weights and biases. + evaluate: method can be used to evaluate the network on a batch of + examples. 
+ """ + + def __init__(self, hidden_layers, input_size=784, num_classes=10): + """Initializes the neural network. + + Args: + hidden_layers: List of ints specifying the sizes of hidden layers. Could + be empty. + input_size: Length of the input array. The network receives the input + image as a flattened 1-d array. Defaults to 784(28*28), the default + image size for MNIST. + num_classes: The number of output classes. Defaults to 10. + """ + hidden_layers = [input_size] + hidden_layers + [num_classes] + self.weights = [] + self.biases = [] + for i in range(len(hidden_layers) - 1): + # TODO(srbs): This is manually cast to float32 to avoid the cast in + # np.dot since backprop fails for tf.cast op. + self.weights.append( + np.array( + np.random.randn(hidden_layers[i + 1], hidden_layers[i]), + copy=False, + dtype=float32, + ) + ) + self.biases.append( + np.array( + np.random.randn(hidden_layers[i + 1]), copy=False, dtype=float32 + ) + ) + + def forward(self, x): + """Performs the forward pass. + + Args: + x: 2-d array of size batch_size x image_size. + + Returns: + A 2-d array of size batch_size x num_classes. + """ + + def sigmoid(x): + return 1.0 / (1.0 + np.exp(-x)) + + for w, b in zip(self.weights, self.biases): + x = sigmoid(np.dot(w, x.T).T + b) + return x + + def train(self, x, y, learning_rate=0.01): + """Runs a single training pass. + + Args: + x: 2-d array of size batch_size x image_size. + y: 2-d array of size batch_size x num_classes in one-hot notation. + learning_rate: The learning rate. 
+ """ + x = np.array(x, copy=False) + y = np.array(y, copy=False) + + def mean_squared_error(x, y): + diff = x - y + return np.sum(diff * diff) / len(x) + + wb_tensors = self.weights + self.biases + with tf.GradientTape() as g: + g.watch(wb_tensors) + loss = mean_squared_error(self.forward(x), y) + gradients = g.gradient(loss, wb_tensors) + gradients = [np.asarray(grad) for grad in gradients] + + new_weights_and_biases = [] + for v, dv in zip(self.weights + self.biases, gradients): + new_weights_and_biases.append(v - learning_rate * dv) + + total_len = len(new_weights_and_biases) + self.weights = new_weights_and_biases[: total_len // 2] + self.biases = new_weights_and_biases[total_len // 2 :] + + def evaluate(self, x, y): + """Returns the number of correct predictions. + + Args: + x: 2-d array of size batch_size x image_size. + y: 2-d array of size batch_size x num_classes. + + Returns: + A scalar, the number of correct predictions. + """ + y_actual = np.argmax(y, axis=1) + y_predicted = np.argmax(self.forward(x), axis=1) + return int(np.sum(np.array(y_actual == y_predicted, copy=False, dtype=int32))) diff --git a/resources/examples/python/mnist/train.py b/resources/examples/python/mnist/train.py new file mode 100644 index 000000000..246a23462 --- /dev/null +++ b/resources/examples/python/mnist/train.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Perform training.""" +from absl import app +from absl import flags + +from six.moves import range +import tensorflow.compat.v2 as tf + +from resources.examples.python.mnist import model as model_lib, dataset + +FLAGS = flags.FLAGS + +flags.DEFINE_integer("batch_size", 50, "Batch size.") +flags.DEFINE_integer("num_training_iters", 10000, "Number of iterations to train for.") +flags.DEFINE_integer( + "validation_steps", 100, "Validation is performed every these many training steps." +) +flags.DEFINE_float("learning_rate", 5.0, "Learning rate.") + + +def train(batch_size, learning_rate, num_training_iters, validation_steps): + """Runs the training.""" + print("Loading data") + training_data, validation_data, test_data = dataset.load() + print( + "Loaded dataset with {} training, {} validation and {} test examples.".format( + len(training_data[0]), len(validation_data[0]), len(test_data[0]) + ) + ) + + assert len(training_data[0]) % batch_size == 0 + assert len(validation_data[0]) % batch_size == 0 + assert len(test_data[0]) % batch_size == 0 + + def build_iterator(data, infinite=True): + """Build the iterator for inputs.""" + index = 0 + size = len(data[0]) + while True: + if index + batch_size > size: + if infinite: + index = 0 + else: + return + yield data[0][index : index + batch_size], data[1][ + index : index + batch_size + ] + index += batch_size + + train_iter = build_iterator(training_data) + model = model_lib.Model([30]) + + for i in range(num_training_iters): + train_x, train_y = next(train_iter) + model.train(train_x, train_y, learning_rate) + if (i + 1) % validation_steps == 0: + validation_iter = build_iterator(validation_data, infinite=False) + correct_predictions = 0 + for valid_x, valid_y in validation_iter: + correct_predictions += model.evaluate(valid_x, valid_y) + print( + "{}/{} correct validation predictions.".format( + correct_predictions, len(validation_data[0]) + ) + ) + + +def main(unused_argv): + train( + FLAGS.batch_size, + 
FLAGS.learning_rate, + FLAGS.num_training_iters, + FLAGS.validation_steps, + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + app.run(main) diff --git a/resources/examples/python/mnist/train_test.py b/resources/examples/python/mnist/train_test.py new file mode 100644 index 000000000..a4760b498 --- /dev/null +++ b/resources/examples/python/mnist/train_test.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test that the example training script works on fake data.""" +from unittest import mock +import numpy as np +import tensorflow.compat.v2 as tf + +from resources.examples.python.mnist import train, dataset + + +class TFNumpyMnistExampleTest(tf.test.TestCase): + def testRuns(self): + with mock.patch.object(dataset, "load", new=fake_mnist_data): + train.train( + batch_size=1, + learning_rate=0.1, + num_training_iters=10, + validation_steps=5, + ) + train.train( + batch_size=2, + learning_rate=0.1, + num_training_iters=5, + validation_steps=2, + ) + train.train( + batch_size=10, + learning_rate=0.1, + num_training_iters=1, + validation_steps=1, + ) + + +def fake_mnist_data(): + def gen_examples(num_examples): + x = np.array(np.random.randn(num_examples, 784), copy=False, dtype=np.float32) + y = np.zeros((num_examples, 10), dtype=np.float32) + y[:, 0] = 1.0 # One-hot label for class 0 on every example; y[:][0] would fill the whole first row instead. + return (x, y) + + return (gen_examples(100), gen_examples(10), gen_examples(10)) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 b/resources/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 similarity index 100% rename from trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 rename to resources/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 diff --git a/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 b/resources/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 similarity index 100% rename from trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 rename to resources/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 diff --git a/trax/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords b/resources/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords similarity index 100% rename from
trax/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords rename to resources/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords diff --git a/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 b/resources/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 similarity index 100% rename from trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 rename to resources/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 diff --git a/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 b/resources/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 similarity index 100% rename from trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 rename to resources/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 diff --git a/trax/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords b/resources/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords similarity index 100% rename from trax/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords rename to resources/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords diff --git a/trax/supervised/configs/bert.gin b/resources/supervised/configs/bert.gin similarity index 100% rename from trax/supervised/configs/bert.gin rename to resources/supervised/configs/bert.gin diff --git a/trax/supervised/configs/bert_glue_classification.gin b/resources/supervised/configs/bert_glue_classification.gin similarity index 100% rename from trax/supervised/configs/bert_glue_classification.gin rename to resources/supervised/configs/bert_glue_classification.gin diff --git a/trax/supervised/configs/bert_glue_regression.gin b/resources/supervised/configs/bert_glue_regression.gin similarity index 100% rename from trax/supervised/configs/bert_glue_regression.gin rename to resources/supervised/configs/bert_glue_regression.gin diff --git 
a/trax/supervised/configs/bert_glue_sweep_regression_task.yaml b/resources/supervised/configs/bert_glue_sweep_regression_task.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_regression_task.yaml rename to resources/supervised/configs/bert_glue_sweep_regression_task.yaml diff --git a/trax/supervised/configs/bert_glue_sweep_single_sentence.yaml b/resources/supervised/configs/bert_glue_sweep_single_sentence.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_single_sentence.yaml rename to resources/supervised/configs/bert_glue_sweep_single_sentence.yaml diff --git a/trax/supervised/configs/bert_glue_sweep_two_sentences.yaml b/resources/supervised/configs/bert_glue_sweep_two_sentences.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_two_sentences.yaml rename to resources/supervised/configs/bert_glue_sweep_two_sentences.yaml diff --git a/trax/supervised/configs/bert_pretraining.gin b/resources/supervised/configs/bert_pretraining.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining.gin rename to resources/supervised/configs/bert_pretraining.gin diff --git a/trax/supervised/configs/bert_pretraining_onlymlm.gin b/resources/supervised/configs/bert_pretraining_onlymlm.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining_onlymlm.gin rename to resources/supervised/configs/bert_pretraining_onlymlm.gin diff --git a/trax/supervised/configs/bert_pretraining_onlynsp.gin b/resources/supervised/configs/bert_pretraining_onlynsp.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining_onlynsp.gin rename to resources/supervised/configs/bert_pretraining_onlynsp.gin diff --git a/trax/supervised/configs/c4.gin b/resources/supervised/configs/c4.gin similarity index 100% rename from trax/supervised/configs/c4.gin rename to resources/supervised/configs/c4.gin diff --git a/trax/supervised/configs/c4_pretrain_16gb_adafactor.gin 
b/resources/supervised/configs/c4_pretrain_16gb_adafactor.gin similarity index 100% rename from trax/supervised/configs/c4_pretrain_16gb_adafactor.gin rename to resources/supervised/configs/c4_pretrain_16gb_adafactor.gin diff --git a/trax/supervised/configs/c4_trax_data.gin b/resources/supervised/configs/c4_trax_data.gin similarity index 100% rename from trax/supervised/configs/c4_trax_data.gin rename to resources/supervised/configs/c4_trax_data.gin diff --git a/trax/supervised/configs/cond_skipping_transformer_lm1b.gin b/resources/supervised/configs/cond_skipping_transformer_lm1b.gin similarity index 100% rename from trax/supervised/configs/cond_skipping_transformer_lm1b.gin rename to resources/supervised/configs/cond_skipping_transformer_lm1b.gin diff --git a/trax/supervised/configs/gru_copy.gin b/resources/supervised/configs/gru_copy.gin similarity index 100% rename from trax/supervised/configs/gru_copy.gin rename to resources/supervised/configs/gru_copy.gin diff --git a/trax/supervised/configs/hourglass_cifar10.gin b/resources/supervised/configs/hourglass_cifar10.gin similarity index 100% rename from trax/supervised/configs/hourglass_cifar10.gin rename to resources/supervised/configs/hourglass_cifar10.gin diff --git a/trax/supervised/configs/hourglass_enwik8.gin b/resources/supervised/configs/hourglass_enwik8.gin similarity index 100% rename from trax/supervised/configs/hourglass_enwik8.gin rename to resources/supervised/configs/hourglass_enwik8.gin diff --git a/trax/supervised/configs/hourglass_imagenet32.gin b/resources/supervised/configs/hourglass_imagenet32.gin similarity index 100% rename from trax/supervised/configs/hourglass_imagenet32.gin rename to resources/supervised/configs/hourglass_imagenet32.gin diff --git a/trax/supervised/configs/hourglass_imagenet64.gin b/resources/supervised/configs/hourglass_imagenet64.gin similarity index 100% rename from trax/supervised/configs/hourglass_imagenet64.gin rename to 
resources/supervised/configs/hourglass_imagenet64.gin diff --git a/trax/supervised/configs/hourglass_wiki40b.gin b/resources/supervised/configs/hourglass_wiki40b.gin similarity index 100% rename from trax/supervised/configs/hourglass_wiki40b.gin rename to resources/supervised/configs/hourglass_wiki40b.gin diff --git a/trax/supervised/configs/layerdrop_every_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_every_transformer_lm1b.gin similarity index 100% rename from trax/supervised/configs/layerdrop_every_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_every_transformer_lm1b.gin diff --git a/trax/supervised/configs/layerdrop_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_transformer_lm1b.gin similarity index 100% rename from trax/supervised/configs/layerdrop_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_transformer_lm1b.gin diff --git a/trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_ushape_transformer_lm1b.gin similarity index 100% rename from trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_ushape_transformer_lm1b.gin diff --git a/trax/supervised/configs/lstm_lm1b.gin b/resources/supervised/configs/lstm_lm1b.gin similarity index 100% rename from trax/supervised/configs/lstm_lm1b.gin rename to resources/supervised/configs/lstm_lm1b.gin diff --git a/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin b/resources/supervised/configs/lstm_seq2seq_wmt_ende.gin similarity index 100% rename from trax/supervised/configs/lstm_seq2seq_wmt_ende.gin rename to resources/supervised/configs/lstm_seq2seq_wmt_ende.gin diff --git a/trax/supervised/configs/mlp_mnist.gin b/resources/supervised/configs/mlp_mnist.gin similarity index 100% rename from trax/supervised/configs/mlp_mnist.gin rename to resources/supervised/configs/mlp_mnist.gin diff --git a/trax/supervised/configs/reformer_addition.gin 
b/resources/supervised/configs/reformer_addition.gin similarity index 100% rename from trax/supervised/configs/reformer_addition.gin rename to resources/supervised/configs/reformer_addition.gin diff --git a/trax/supervised/configs/reformer_bair_robot_pushing.gin b/resources/supervised/configs/reformer_bair_robot_pushing.gin similarity index 100% rename from trax/supervised/configs/reformer_bair_robot_pushing.gin rename to resources/supervised/configs/reformer_bair_robot_pushing.gin diff --git a/trax/supervised/configs/reformer_c4.gin b/resources/supervised/configs/reformer_c4.gin similarity index 100% rename from trax/supervised/configs/reformer_c4.gin rename to resources/supervised/configs/reformer_c4.gin diff --git a/trax/supervised/configs/reformer_cifar10.gin b/resources/supervised/configs/reformer_cifar10.gin similarity index 100% rename from trax/supervised/configs/reformer_cifar10.gin rename to resources/supervised/configs/reformer_cifar10.gin diff --git a/trax/supervised/configs/reformer_copy.gin b/resources/supervised/configs/reformer_copy.gin similarity index 100% rename from trax/supervised/configs/reformer_copy.gin rename to resources/supervised/configs/reformer_copy.gin diff --git a/trax/supervised/configs/reformer_enwik8.gin b/resources/supervised/configs/reformer_enwik8.gin similarity index 100% rename from trax/supervised/configs/reformer_enwik8.gin rename to resources/supervised/configs/reformer_enwik8.gin diff --git a/trax/supervised/configs/reformer_imagenet64.gin b/resources/supervised/configs/reformer_imagenet64.gin similarity index 100% rename from trax/supervised/configs/reformer_imagenet64.gin rename to resources/supervised/configs/reformer_imagenet64.gin diff --git a/trax/supervised/configs/reformer_imagenet64_testing.gin b/resources/supervised/configs/reformer_imagenet64_testing.gin similarity index 100% rename from trax/supervised/configs/reformer_imagenet64_testing.gin rename to 
resources/supervised/configs/reformer_imagenet64_testing.gin diff --git a/trax/supervised/configs/reformer_pc_enpl.gin b/resources/supervised/configs/reformer_pc_enpl.gin similarity index 100% rename from trax/supervised/configs/reformer_pc_enpl.gin rename to resources/supervised/configs/reformer_pc_enpl.gin diff --git a/trax/supervised/configs/reformer_wmt_ende.gin b/resources/supervised/configs/reformer_wmt_ende.gin similarity index 100% rename from trax/supervised/configs/reformer_wmt_ende.gin rename to resources/supervised/configs/reformer_wmt_ende.gin diff --git a/trax/supervised/configs/reformer_wmt_ende_big.gin b/resources/supervised/configs/reformer_wmt_ende_big.gin similarity index 100% rename from trax/supervised/configs/reformer_wmt_ende_big.gin rename to resources/supervised/configs/reformer_wmt_ende_big.gin diff --git a/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin b/resources/supervised/configs/resnet50_frn_imagenet_8gb.gin similarity index 100% rename from trax/supervised/configs/resnet50_frn_imagenet_8gb.gin rename to resources/supervised/configs/resnet50_frn_imagenet_8gb.gin diff --git a/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin b/resources/supervised/configs/resnet50_imagenet_8gb_testing.gin similarity index 100% rename from trax/supervised/configs/resnet50_imagenet_8gb_testing.gin rename to resources/supervised/configs/resnet50_imagenet_8gb_testing.gin diff --git a/trax/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin b/resources/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin similarity index 100% rename from trax/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin rename to resources/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin diff --git a/trax/supervised/configs/rse_addition.gin b/resources/supervised/configs/rse_addition.gin similarity index 100% rename from trax/supervised/configs/rse_addition.gin rename to resources/supervised/configs/rse_addition.gin diff --git 
a/trax/supervised/configs/rse_addition_sweep.yaml b/resources/supervised/configs/rse_addition_sweep.yaml similarity index 100% rename from trax/supervised/configs/rse_addition_sweep.yaml rename to resources/supervised/configs/rse_addition_sweep.yaml diff --git a/trax/supervised/configs/scientific_papers_reformer_lm.gin b/resources/supervised/configs/scientific_papers_reformer_lm.gin similarity index 100% rename from trax/supervised/configs/scientific_papers_reformer_lm.gin rename to resources/supervised/configs/scientific_papers_reformer_lm.gin diff --git a/trax/supervised/configs/scientific_papers_terraformer.gin b/resources/supervised/configs/scientific_papers_terraformer.gin similarity index 100% rename from trax/supervised/configs/scientific_papers_terraformer.gin rename to resources/supervised/configs/scientific_papers_terraformer.gin diff --git a/trax/supervised/configs/scientific_papers_terraformer_favor.gin b/resources/supervised/configs/scientific_papers_terraformer_favor.gin similarity index 100% rename from trax/supervised/configs/scientific_papers_terraformer_favor.gin rename to resources/supervised/configs/scientific_papers_terraformer_favor.gin diff --git a/trax/supervised/configs/scientific_papers_terraformer_pretrained.gin b/resources/supervised/configs/scientific_papers_terraformer_pretrained.gin similarity index 100% rename from trax/supervised/configs/scientific_papers_terraformer_pretrained.gin rename to resources/supervised/configs/scientific_papers_terraformer_pretrained.gin diff --git a/trax/supervised/configs/skipping_transformer_lm1b.gin b/resources/supervised/configs/skipping_transformer_lm1b.gin similarity index 100% rename from trax/supervised/configs/skipping_transformer_lm1b.gin rename to resources/supervised/configs/skipping_transformer_lm1b.gin diff --git a/trax/supervised/configs/sp_sweep.yaml b/resources/supervised/configs/sp_sweep.yaml similarity index 100% rename from trax/supervised/configs/sp_sweep.yaml rename to 
resources/supervised/configs/sp_sweep.yaml diff --git a/trax/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin b/resources/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin similarity index 100% rename from trax/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin rename to resources/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin diff --git a/trax/supervised/configs/sparse_lm1b_pretrain_16gb.gin b/resources/supervised/configs/sparse_lm1b_pretrain_16gb.gin similarity index 100% rename from trax/supervised/configs/sparse_lm1b_pretrain_16gb.gin rename to resources/supervised/configs/sparse_lm1b_pretrain_16gb.gin diff --git a/trax/supervised/configs/t5_aqua_parallel.gin b/resources/supervised/configs/t5_aqua_parallel.gin similarity index 100% rename from trax/supervised/configs/t5_aqua_parallel.gin rename to resources/supervised/configs/t5_aqua_parallel.gin diff --git a/trax/supervised/configs/t5_drop.gin b/resources/supervised/configs/t5_drop.gin similarity index 100% rename from trax/supervised/configs/t5_drop.gin rename to resources/supervised/configs/t5_drop.gin diff --git a/trax/supervised/configs/t5_glue_classification.gin b/resources/supervised/configs/t5_glue_classification.gin similarity index 100% rename from trax/supervised/configs/t5_glue_classification.gin rename to resources/supervised/configs/t5_glue_classification.gin diff --git a/trax/supervised/configs/t5_glue_classification_mnli.gin b/resources/supervised/configs/t5_glue_classification_mnli.gin similarity index 100% rename from trax/supervised/configs/t5_glue_classification_mnli.gin rename to resources/supervised/configs/t5_glue_classification_mnli.gin diff --git a/trax/supervised/configs/t5_glue_classification_parallel.gin b/resources/supervised/configs/t5_glue_classification_parallel.gin similarity index 100% rename from trax/supervised/configs/t5_glue_classification_parallel.gin rename to resources/supervised/configs/t5_glue_classification_parallel.gin diff --git 
a/trax/supervised/configs/t5_glue_classification_two_constants.gin b/resources/supervised/configs/t5_glue_classification_two_constants.gin similarity index 100% rename from trax/supervised/configs/t5_glue_classification_two_constants.gin rename to resources/supervised/configs/t5_glue_classification_two_constants.gin diff --git a/trax/supervised/configs/t5_mathqa.gin b/resources/supervised/configs/t5_mathqa.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa.gin rename to resources/supervised/configs/t5_mathqa.gin diff --git a/trax/supervised/configs/t5_mathqa_drop_loop.gin b/resources/supervised/configs/t5_mathqa_drop_loop.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_drop_loop.gin rename to resources/supervised/configs/t5_mathqa_drop_loop.gin diff --git a/trax/supervised/configs/t5_mathqa_drop_sweep.yaml b/resources/supervised/configs/t5_mathqa_drop_sweep.yaml similarity index 100% rename from trax/supervised/configs/t5_mathqa_drop_sweep.yaml rename to resources/supervised/configs/t5_mathqa_drop_sweep.yaml diff --git a/trax/supervised/configs/t5_mathqa_multi.gin b/resources/supervised/configs/t5_mathqa_multi.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_multi.gin rename to resources/supervised/configs/t5_mathqa_multi.gin diff --git a/trax/supervised/configs/t5_mathqa_parallel.gin b/resources/supervised/configs/t5_mathqa_parallel.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_parallel.gin rename to resources/supervised/configs/t5_mathqa_parallel.gin diff --git a/trax/supervised/configs/t5_mathqa_parallel_full.gin b/resources/supervised/configs/t5_mathqa_parallel_full.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_parallel_full.gin rename to resources/supervised/configs/t5_mathqa_parallel_full.gin diff --git a/trax/supervised/configs/t5_mathqa_parallel_full_correct_order.gin b/resources/supervised/configs/t5_mathqa_parallel_full_correct_order.gin 
similarity index 100% rename from trax/supervised/configs/t5_mathqa_parallel_full_correct_order.gin rename to resources/supervised/configs/t5_mathqa_parallel_full_correct_order.gin diff --git a/trax/supervised/configs/t5_mathqa_parallel_full_order.gin b/resources/supervised/configs/t5_mathqa_parallel_full_order.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_parallel_full_order.gin rename to resources/supervised/configs/t5_mathqa_parallel_full_order.gin diff --git a/trax/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin b/resources/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin similarity index 100% rename from trax/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin rename to resources/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin diff --git a/trax/supervised/configs/t5_sweep.yaml b/resources/supervised/configs/t5_sweep.yaml similarity index 100% rename from trax/supervised/configs/t5_sweep.yaml rename to resources/supervised/configs/t5_sweep.yaml diff --git a/trax/supervised/configs/t5_sweep_temperature.yaml b/resources/supervised/configs/t5_sweep_temperature.yaml similarity index 100% rename from trax/supervised/configs/t5_sweep_temperature.yaml rename to resources/supervised/configs/t5_sweep_temperature.yaml diff --git a/trax/supervised/configs/terraformer_c4_medium.gin b/resources/supervised/configs/terraformer_c4_medium.gin similarity index 100% rename from trax/supervised/configs/terraformer_c4_medium.gin rename to resources/supervised/configs/terraformer_c4_medium.gin diff --git a/trax/supervised/configs/terraformer_copy.gin b/resources/supervised/configs/terraformer_copy.gin similarity index 100% rename from trax/supervised/configs/terraformer_copy.gin rename to resources/supervised/configs/terraformer_copy.gin diff --git a/trax/supervised/configs/terraformer_copy_self_attn.gin b/resources/supervised/configs/terraformer_copy_self_attn.gin similarity index 100% rename from 
trax/supervised/configs/terraformer_copy_self_attn.gin rename to resources/supervised/configs/terraformer_copy_self_attn.gin diff --git a/trax/supervised/configs/terraformer_purelsh_copy.gin b/resources/supervised/configs/terraformer_purelsh_copy.gin similarity index 100% rename from trax/supervised/configs/terraformer_purelsh_copy.gin rename to resources/supervised/configs/terraformer_purelsh_copy.gin diff --git a/trax/supervised/configs/terraformer_wmt_ende.gin b/resources/supervised/configs/terraformer_wmt_ende.gin similarity index 100% rename from trax/supervised/configs/terraformer_wmt_ende.gin rename to resources/supervised/configs/terraformer_wmt_ende.gin diff --git a/trax/supervised/configs/transformer_big_lm1b_8gb.gin b/resources/supervised/configs/transformer_big_lm1b_8gb.gin similarity index 100% rename from trax/supervised/configs/transformer_big_lm1b_8gb.gin rename to resources/supervised/configs/transformer_big_lm1b_8gb.gin diff --git a/trax/supervised/configs/transformer_finetune_squad_16gb.gin b/resources/supervised/configs/transformer_finetune_squad_16gb.gin similarity index 100% rename from trax/supervised/configs/transformer_finetune_squad_16gb.gin rename to resources/supervised/configs/transformer_finetune_squad_16gb.gin diff --git a/trax/supervised/configs/transformer_imdb_8gb.gin b/resources/supervised/configs/transformer_imdb_8gb.gin similarity index 100% rename from trax/supervised/configs/transformer_imdb_8gb.gin rename to resources/supervised/configs/transformer_imdb_8gb.gin diff --git a/trax/supervised/configs/transformer_imdb_tfds.gin b/resources/supervised/configs/transformer_imdb_tfds.gin similarity index 100% rename from trax/supervised/configs/transformer_imdb_tfds.gin rename to resources/supervised/configs/transformer_imdb_tfds.gin diff --git a/trax/supervised/configs/transformer_lm1b_8gb_testing.gin b/resources/supervised/configs/transformer_lm1b_8gb_testing.gin similarity index 100% rename from 
trax/supervised/configs/transformer_lm1b_8gb_testing.gin rename to resources/supervised/configs/transformer_lm1b_8gb_testing.gin diff --git a/trax/supervised/configs/transformer_lm_cnndailymail.gin b/resources/supervised/configs/transformer_lm_cnndailymail.gin similarity index 100% rename from trax/supervised/configs/transformer_lm_cnndailymail.gin rename to resources/supervised/configs/transformer_lm_cnndailymail.gin diff --git a/trax/supervised/configs/transformer_lm_wmt_ende_16gb.gin b/resources/supervised/configs/transformer_lm_wmt_ende_16gb.gin similarity index 100% rename from trax/supervised/configs/transformer_lm_wmt_ende_16gb.gin rename to resources/supervised/configs/transformer_lm_wmt_ende_16gb.gin diff --git a/trax/supervised/configs/transformer_lm_wmt_ende_8gb.gin b/resources/supervised/configs/transformer_lm_wmt_ende_8gb.gin similarity index 100% rename from trax/supervised/configs/transformer_lm_wmt_ende_8gb.gin rename to resources/supervised/configs/transformer_lm_wmt_ende_8gb.gin diff --git a/trax/supervised/configs/transformer_ptb_16gb.gin b/resources/supervised/configs/transformer_ptb_16gb.gin similarity index 100% rename from trax/supervised/configs/transformer_ptb_16gb.gin rename to resources/supervised/configs/transformer_ptb_16gb.gin diff --git a/trax/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin b/resources/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin similarity index 100% rename from trax/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin rename to resources/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin diff --git a/trax/supervised/configs/transformer_wmt_ende_8gb.gin b/resources/supervised/configs/transformer_wmt_ende_8gb.gin similarity index 100% rename from trax/supervised/configs/transformer_wmt_ende_8gb.gin rename to resources/supervised/configs/transformer_wmt_ende_8gb.gin diff --git a/trax/supervised/configs/wide_resnet_cifar10_8gb.gin 
b/resources/supervised/configs/wide_resnet_cifar10_8gb.gin similarity index 100% rename from trax/supervised/configs/wide_resnet_cifar10_8gb.gin rename to resources/supervised/configs/wide_resnet_cifar10_8gb.gin diff --git a/trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz b/resources/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz rename to resources/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz b/resources/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz rename to resources/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz b/resources/supervised/testdata/terraformer_copy_self_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz rename to resources/supervised/testdata/terraformer_copy_self_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz b/resources/supervised/testdata/terraformer_purelsh_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz rename to resources/supervised/testdata/terraformer_purelsh_copy.pkl.gz diff --git a/trax/supervised/testdata/transformer_copy.pkl.gz b/resources/supervised/testdata/transformer_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/transformer_copy.pkl.gz rename to resources/supervised/testdata/transformer_copy.pkl.gz diff --git a/trax/supervised/testdata/transformerlm_copy.pkl.gz b/resources/supervised/testdata/transformerlm_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/transformerlm_copy.pkl.gz rename to resources/supervised/testdata/transformerlm_copy.pkl.gz diff --git a/setup.py b/setup.py index 
0f1f0ec57..c0e268de2 100644 --- a/setup.py +++ b/setup.py @@ -20,56 +20,58 @@ from setuptools import setup setup( - name='trax', - version='1.4.1', - description='Trax', + name="trax", + version="1.5.1", + description="Trax", long_description=( - 'Trax helps you understand deep learning. We start with basic maths and' - ' go through layers, models, supervised and reinforcement learning. We ' - 'get to advanced deep learning results, including recent papers and ' - 'state-of-the-art models.' + "Trax helps you understand deep learning. We start with basic maths and" + " go through layers, models, supervised and reinforcement learning. We " + "get to advanced deep learning results, including recent papers and " + "state-of-the-art models." ), - author='Google Inc.', - author_email='no-reply@google.com', - url='http://github.com/google/trax', - license='Apache 2.0', + author="Google Inc.", + author_email="no-reply@google.com", + url="http://github.com/google/trax", + license="Apache 2.0", packages=find_packages(), install_requires=[ - 'absl-py', - 'funcsigs', - 'gin-config', - 'gym', - 'jax', - 'jaxlib', - 'matplotlib', - 'numpy', - 'psutil', - 'scipy', - 'six', - 'tensorflow-datasets', - 'tensorflow-text', + "absl-py==1.4.0", + "funcsigs==1.0.2", + "gin-config==0.5.0", + "gym==0.26.2", + "jax==0.4.20", + "jaxlib==0.4.20", + "matplotlib==3.8.0", + "numpy==1.23.5", + "psutil==5.9.5", + "scipy==1.11.3", + "six==1.14.0", + "tensorflow-datasets==4.2.0", + "tensorflow-text==2.13.0", ], extras_require={ - 'tensorflow': ['tensorflow>=1.15.0'], - 'tensorflow_gpu': ['tensorflow-gpu>=1.15.0'], - 't5': ['t5>=0.4.0'], - 'tests': [ - 'attrs', - 'jupyter', - 'mock', - 'parameterized', - 'pylint', - 'pytest', - 'wrapt==1.11.*', + "tensorflow": ["tensorflow==2.13.0"], + "tensorflow_gpu": ["tensorflow-gpu>=2.13.0"], + "t5": ["t5==0.9.2"], + "tests": [ + "attrs==23.1.0", + "jupyter", + "mock==5.1.0", + "parameterized==0.9.0", + "pylint==2.17.7", + "pytest==7.4.2", + "wrapt==1.15.0", + 
], + "t2t": [ + "tensor2tensor==1.15.7", ], - 't2t': ['tensor2tensor',], }, classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - keywords='tensorflow machine learning jax', + keywords="tensorflow machine learning jax", ) diff --git a/tests/data/inputs_test.py b/tests/data/inputs_test.py new file mode 100644 index 000000000..7c71a31a3 --- /dev/null +++ b/tests/data/inputs_test.py @@ -0,0 +1,1100 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.data.inputs.""" + +import itertools +import os + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from trax import data + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../resources/data/testdata")) + + +def _spm_path(): + return os.path.join(_TESTDATA, "sentencepiece.model") + + +class InputsTest(parameterized.TestCase): + @parameterized.named_parameters( + ("zero", 0), + ("negative", -5), + ) + def test_shuffle_data_raises_error_queue_size(self, queue_size): + samples = iter(range(10)) + with self.assertRaises(ValueError): + _ = list(data.shuffle(samples, queue_size)) + + @parameterized.named_parameters( + ("one", 1), + ("two", 2), + ("twenty", 20), + ) + def test_shuffle_data_queue_size(self, queue_size): + samples = iter(range(100, 200)) + shuffled_stream = data.shuffle(samples, queue_size) + first_ten = [next(shuffled_stream) for _ in range(10)] + + # Queue size limits how far ahead/upstream the current sample can reach. + self.assertLess(first_ten[0], 100 + queue_size) + self.assertLess(first_ten[3], 103 + queue_size) + self.assertLess(first_ten[9], 109 + queue_size) + + unshuffled_first_ten = list(range(100, 110)) + if queue_size == 1: # Degenerate case: no shuffling can happen. 
+ self.assertEqual(first_ten, unshuffled_first_ten) + if queue_size > 1: + self.assertNotEqual(first_ten, unshuffled_first_ten) + + @parameterized.named_parameters( + ("qsize_100_n_001", 100, 1), + ("qsize_100_n_099", 100, 99), + ("qsize_100_n_100", 100, 100), + ("qsize_100_n_101", 100, 101), + ("qsize_100_n_199", 100, 199), + ) + def test_shuffle_data_yields_all_samples(self, queue_size, n_samples): + samples = iter(range(n_samples)) + shuffled_stream = data.shuffle(samples, queue_size) + self.assertLen(list(shuffled_stream), n_samples) + + def test_batch_data(self): + dataset = ((i, i + 1) for i in range(10)) + batches = data.batch(dataset, 10) + batch = next(batches) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (10,)) + + def test_batch_data_padding(self): + dataset = (([1] * (10 - i), i + 1) for i in range(10)) + batches = data.batch(dataset, 10) + batch = next(batches) + self.assertEqual(batch[0].shape, (10, 10)) + self.assertTrue(np.array_equal(batch[0][-1], np.asarray([1] + 9 * [0]))) + + def test_batch_exception_size(self): + dataset = ((i, i + 1) for i in range(10)) + with self.assertRaises(ValueError): + batches = data.batch(dataset, 0) + next(batches) + + def test_serial(self): + dataset = lambda _: ((i, i + 1) for i in range(10)) + batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10)) + batch = next(batches()) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (10,)) + + def test_serial_composes(self): + """Check that data.Serial works inside another data.Serial.""" + dataset = lambda _: ((i, i + 1) for i in range(10)) + serial1 = data.Serial(dataset, data.Shuffle(3)) + batches = data.Serial(serial1, data.Batch(10)) + batch = next(batches()) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (10,)) + + def test_count_and_skip(self): + dataset = lambda _: ((i, i + 1) for i in range(10)) + examples = data.Serial(dataset, data.CountAndSkip("toy_data")) + ex_generator = examples() + ex1 = next(ex_generator) + 
self.assertEqual(ex1, (0, 1)) + self.assertEqual(data.inputs.data_counters["toy_data"], 1) + ex2 = next(ex_generator) + self.assertEqual(ex2, (1, 2)) + self.assertEqual(data.inputs.data_counters["toy_data"], 2) + ex3 = next(examples()) # new generator, will skip + self.assertEqual(ex3, (2, 3)) + self.assertEqual(data.inputs.data_counters["toy_data"], 3) + data.inputs.data_counters["toy_data"] = 0 # reset + ex4 = next(examples()) # new generator, was reset + self.assertEqual(ex4, (0, 1)) + self.assertEqual(data.inputs.data_counters["toy_data"], 1) + + def test_parallel(self): + """Basic test of the parallel ccmbinator.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + parallel = data.Parallel([dataset1, dataset2]) + generator = parallel() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 10) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 12) + + def test_parallel_with_gen_not_none(self): + """Test of the parallel ccmbinator with a not none generator.""" + dataset1 = lambda _: (i for i in range(10)) + dataset2 = lambda _: (i for i in range(10, 20)) + parallel = data.Parallel([dataset1, dataset2]) + + def test_generator(): + yield 0 + + generator = parallel(gen=test_generator) + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 10) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 12) + + def test_parallel_with_weights(self): + """Test of the parallel ccmbinator with weights.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + parallel = data.Parallel([dataset1, dataset2], counters=(2, 1)) + generator = parallel() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 10) + 
self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 3) + self.assertEqual(next(generator), 12) + self.assertEqual(next(generator), 4) + self.assertEqual(next(generator), 5) + self.assertEqual(next(generator), 13) + + def test_parallel_with_weights_and_minimum(self): + """Test of the parallel ccmbinator with weights and minimum.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 110)) + parallel = data.Parallel( + [dataset1, dataset2], counters=(10, 100), reweight_by_minimum=True + ) + generator = parallel() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 10) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 12) + self.assertEqual(next(generator), 13) + self.assertEqual(next(generator), 14) + self.assertEqual(next(generator), 15) + self.assertEqual(next(generator), 16) + self.assertEqual(next(generator), 17) + self.assertEqual(next(generator), 18) + self.assertEqual(next(generator), 19) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 20) + self.assertEqual(next(generator), 21) + self.assertEqual(next(generator), 22) + self.assertEqual(next(generator), 23) + self.assertEqual(next(generator), 24) + self.assertEqual(next(generator), 25) + self.assertEqual(next(generator), 26) + self.assertEqual(next(generator), 27) + self.assertEqual(next(generator), 28) + self.assertEqual(next(generator), 29) + self.assertEqual(next(generator), 2) + + def test_parallel_with_gradual_reweighting(self): + """Test of the parallel ccmbinator with weights and minimum.""" + dataset1 = lambda: (i for i in itertools.cycle(range(1))) + dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) + dataset3 = lambda: (i for i in itertools.cycle(range(30, 70))) + parallel = data.Parallel( + [dataset2, dataset1, dataset3], + counters=(20, 1, 40), + gradually_reweight=True, + ) + 
generator = parallel() + + for _ in range(3): + self.assertEqual(next(generator), 0) + for i in range(20): + self.assertEqual(next(generator), 10 + i) + self.assertEqual(next(generator), 30 + 2 * i) + self.assertEqual(next(generator), 30 + 2 * i + 1) + + def test_parallel_with_gradual_reweighting_remainders(self): + """Test of the parallel ccmbinator with weights and minimum.""" + dataset1 = lambda: (i for i in itertools.cycle(range(1))) + dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) + dataset3 = lambda: (i for i in itertools.cycle(range(30, 80))) + parallel = data.Parallel( + [dataset2, dataset1, dataset3], + counters=(20, 1, 50), + gradually_reweight=True, + use_remainders=True, + ) + generator = parallel() + + for _ in range(3): + self.assertEqual(next(generator), 0) + for i in range(20): + self.assertEqual(next(generator), 10 + i) + self.assertEqual(next(generator), 30 + 2 * i) + self.assertEqual(next(generator), 30 + 2 * i + 1) + # Here we process the remainder from dataset 3: + for i in range(10): + self.assertEqual(next(generator), 70 + i) + + def test_parallel_with_gradual_reweighting_remainders_big(self): + """Test of the parallel ccmbinator with weights and minimum.""" + dataset1 = lambda: (i for i in itertools.cycle(range(1))) + dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) + dataset3 = lambda: (i for i in itertools.cycle(range(30, 80))) + dataset4 = lambda: (i for i in itertools.cycle(range(100, 220))) + parallel = data.Parallel( + [dataset2, dataset1, dataset4, dataset3], + counters=(20, 1, 120, 50), + gradually_reweight=True, + use_remainders=True, + ) + generator = parallel() + + for _ in range(3): + self.assertEqual(next(generator), 0) + for i in range(20): + self.assertEqual(next(generator), 10 + i) + for j in range(2): + self.assertEqual(next(generator), 30 + 2 * i + j) + for k in range(2): + self.assertEqual(next(generator), 100 + 2 * 2 * i + 2 * j + k) + # Here we process the remainder from datasets 3 and 4: + 
for i in range(10): + self.assertEqual(next(generator), 70 + i) + for i in range(40): + self.assertEqual(next(generator), 180 + i) + + def test_parallel_with_weights_three_datasets(self): + """Test of the parallel combinator with weights on three datasets.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3)) + generator = parallel() + + self.assertEqual(next(generator), 0) # (1,0,0) + self.assertEqual(next(generator), 10) # (1,1,0) + self.assertEqual(next(generator), 20) # (1,1,1) + self.assertEqual(next(generator), 1) # (2,1,1) + self.assertEqual(next(generator), 21) # (2,1,2) + self.assertEqual(next(generator), 22) # (2,1,3) + self.assertEqual(next(generator), 2) # (1,0,0) + self.assertEqual(next(generator), 11) # (1,1,0) + self.assertEqual(next(generator), 23) # (1,1,1) + self.assertEqual(next(generator), 3) # (2,1,1) + self.assertEqual(next(generator), 24) # (2,1,2) + self.assertEqual(next(generator), 25) # (2,1,3) + self.assertEqual(next(generator), 4) # (1,0,0) + + def test_stack_parallel(self): + """Test of stacked parallel combinators.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel_lev0 = data.Parallel([dataset1, dataset2]) + parallel_lev1 = data.Parallel([parallel_lev0, dataset3]) + generator = parallel_lev1() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 20) + self.assertEqual(next(generator), 10) + self.assertEqual(next(generator), 21) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 22) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 23) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 24) + self.assertEqual(next(generator), 12) + + def test_parallel_with_zero_counters(self): +
"""Test of the parallel combinator with a zero counter.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel = data.Parallel([dataset1, dataset2, dataset3], counters=[1, 0, 1]) + generator = parallel() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 20) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 21) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 22) + self.assertEqual(next(generator), 3) + self.assertEqual(next(generator), 23) + + def test_serial_with_python(self): + dataset = lambda _: ((i, i + 1) for i in range(10)) + batches = data.Serial( + dataset, + lambda g: map(lambda x: (x[0], x[1] + 1), g), + lambda g: filter(lambda x: x[0] % 2 == 1, g), + data.Batch(2), + ) + batch = next(batches()) + self.assertLen(batch, 2) + (xs, ys) = batch + # First tuple after filtering is (1, 3) = (1, 2+1). + self.assertEqual(xs[0], 1) + self.assertEqual(ys[0], 3) + # Second tuple after filtering is (3, 5).
+ self.assertEqual(xs[1], 3) + self.assertEqual(ys[1], 5) + + def test_pad_to_max_dims(self): + tensors1 = [np.zeros((3, 10)), np.ones((3, 10))] + padded1 = data.inputs.pad_to_max_dims(tensors1) + self.assertEqual(padded1.shape, (2, 3, 10)) + tensors2 = [np.zeros((2, 10)), np.ones((3, 9))] + padded2 = data.inputs.pad_to_max_dims(tensors2) + self.assertEqual(padded2.shape, (2, 3, 10)) + tensors3 = [np.zeros((8, 10)), np.ones((8, 9))] + padded3 = data.inputs.pad_to_max_dims(tensors3, 12) + self.assertEqual(padded3.shape, (2, 12, 12)) + tensors4 = [np.zeros((2, 10)), np.ones((3, 9))] + padded4 = data.inputs.pad_to_max_dims(tensors4, 12) + self.assertEqual(padded4.shape, (2, 4, 12)) + + def test_pad_to_length(self): + tensors1 = [(np.zeros((5)), np.ones((3)))] + pad_to_length_function1 = data.inputs.PadToLength( + len_map={0: 10, 1: 11}, pad_value={0: 0, 1: 1} + ) + padded1 = next(pad_to_length_function1(tensors1)) + self.assertEqual(padded1[0].shape, (10,)) + self.assertEqual(padded1[1].shape, (11,)) + + tensors2 = [(np.zeros((15)), np.ones((20)))] + pad_to_length_function2 = data.inputs.PadToLength( + len_map={0: 10, 1: 10}, pad_value={0: 0, 1: 1}, multiple=True + ) + padded2 = next(pad_to_length_function2(tensors2)) + self.assertEqual(padded2[0].shape, (20,)) + self.assertEqual(padded2[1].shape, (20,)) + + def test_concatenate_lm_input(self): + tensors1 = [(np.zeros((5)), np.ones((3)))] + + lm_input_function1 = data.inputs.ConcatenateToLMInput(pad_to_length=10) + lm_input_1 = next(lm_input_function1(tensors1)) + self.assertEqual(lm_input_1[0].shape, (10,)) + self.assertEqual(lm_input_1[1].shape, (10,)) + self.assertEqual(lm_input_1[2].shape, (10,)) + self.assertEqual( + lm_input_1[2].all(), + np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]]).all(), + ) + + tensors2 = [(np.zeros((5)), np.ones((3)))] + lm_input_function2 = data.inputs.ConcatenateToLMInput() + lm_input_2 = next(lm_input_function2(tensors2)) + self.assertEqual(lm_input_2[0].shape, (8,)) + 
self.assertEqual(lm_input_2[1].shape, (8,)) + self.assertEqual(lm_input_2[2].shape, (8,)) + self.assertEqual( + lm_input_2[2].all(), + np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]]).all(), + ) + + def test_truncate_to_length_no_arg(self): + """Tests that a no-arg call leaves shapes unchanged.""" + + def data_stream(): + while True: + yield (np.zeros((1, 5)), np.ones((1, 5))) + + stream_fn = data.inputs.TruncateToLength() + y0, y1 = next(stream_fn(data_stream())) + self.assertEqual(y0.shape, (1, 5)) + self.assertEqual(y1.shape, (1, 5)) + + @parameterized.named_parameters( + ("none", None, ((1, 5), (1, 5))), + ("large_values", {0: (1, 77), 1: (1, 88)}, ((1, 5), (1, 5))), + ("small_values", {0: (1, 3), 1: (1, 2)}, ((1, 3), (1, 2))), + ) + def test_truncate_to_length_len_map(self, len_map, out_shapes): + """Tests that truncation occurs when len_map values are small enough.""" + + def data_stream(): + while True: + yield (np.zeros((1, 5)), np.ones((1, 5))) + + stream_fn = data.inputs.TruncateToLength(len_map=len_map) + y0, y1 = next(stream_fn(data_stream())) + self.assertEqual(y0.shape, out_shapes[0]) + self.assertEqual(y1.shape, out_shapes[1]) + + def test_truncate_to_length_questionable_behavior(self): + # Use of np.reshape in TruncateToLength allows non-truncation results + # without warning. As long as the target shape (len_map value) is + # lexicographically prior to the data shape, then np.reshape can happen, + # even if it results in *adding* values to the overall array. + # + # This test passes as a marker of the questionable behavior, and should + # *fail* -- and then be removed -- when the function is + # clarified/re-implemented. + # + # TODO(jonni): Determine desired behavior, and fit implementation to it. 
+ x = np.arange(21).reshape((1, 21, 1)) + + def data_stream(): + while True: + yield x + + stream_fn = data.inputs.TruncateToLength(len_map={0: (1, 4, 6)}) + (y,) = next(stream_fn(data_stream())) + self.assertEqual(y.shape, (1, 4, 6)) + self.assertEqual(y[0, 3, 1], 19) + self.assertEqual(y[0, 3, 2], 20) # end of original values [0..20] + self.assertEqual(y[0, 3, 3], 0) # added value + self.assertEqual(y[0, 3, 4], 1) # added value + self.assertEqual(y[0, 3, 5], 2) # added value + + def test_filter_empty_examples(self): + tensors1 = [ + (np.zeros((0,)), np.ones((1, 5))), + (np.zeros((1, 5)), np.ones((1, 5))), + ] + + filter_empty_examples_function1 = data.inputs.FilterEmptyExamples() + filtered1 = next(filter_empty_examples_function1(tensors1)) + self.assertEqual(filtered1[0].shape, (1, 5)) + self.assertEqual(filtered1[1].shape, (1, 5)) + + filter_empty_examples_function2 = data.inputs.FilterEmptyExamples(axes=[1]) + filtered2 = next(filter_empty_examples_function2(tensors1)) + self.assertEqual(filtered2[0].shape, (0,)) + self.assertEqual(filtered2[1].shape, (1, 5)) + + def test_append_value(self): + tensors1 = [(np.zeros((1, 5)), np.ones((1, 5)))] + + append_value_function1 = data.inputs.AppendValue() + unmodified = next(append_value_function1(tensors1)) + self.assertEqual(unmodified[0].shape, (1, 5)) + self.assertEqual(unmodified[1].shape, (1, 5)) + + append_value_function2 = data.inputs.AppendValue({0: [[5]], 1: [[4]]}) + appended = next(append_value_function2(tensors1)) + self.assertEqual(appended[0].shape, (1, 6)) + self.assertEqual( + appended[0].all(), np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 5.0]]).all() + ) + self.assertEqual(appended[1].shape, (1, 6)) + self.assertEqual( + appended[1].all(), np.array([[1.0, 1.0, 1.0, 1.0, 1.0, 4.0]]).all() + ) + + def test_pad_to_max_dims_boundary_list(self): + tensors = [np.zeros((1, 15, 31)), np.ones((2, 10, 35)), np.ones((4, 2, 3))] + padded_tensors = data.inputs.pad_to_max_dims(tensors, boundary=(None, 15, 20)) + # no 
boundary, only max in the first dim, 15 is already the max len in + # second dim, last dim padded to multiple of 20. + # The outer dim is the batch here. + self.assertEqual(padded_tensors.shape, (3, 4, 15, 40)) + + def test_pad_to_max_dims_strict_pad_on_len(self): + tensors = [np.ones((15,)), np.ones((12,)), np.ones((14,))] + padded_tensors = data.inputs.pad_to_max_dims( + tensors, boundary=10, strict_pad_on_len=True + ) + self.assertEqual(padded_tensors.shape, (3, 20)) + + def test_bucket_by_length(self): + def fake_generator(length, num_examples=1): + for _ in range(num_examples): + yield (np.ones((length,)), np.ones((length,))) + + def length_function(example): + return max(example[0].shape[0], example[1].shape[0]) + + batches = list( + data.bucket_by_length( + fake_generator(5, 6), length_function, [20], [2], strict_pad_on_len=True + ) + ) + + # We'll get three batches of 2 examples each. + self.assertLen(batches, 3) + self.assertIsInstance(batches[0], tuple) + self.assertLen(batches[0], 2) + self.assertEqual((2, 20), batches[0][0].shape) + self.assertEqual((2, 20), batches[0][1].shape) + + @parameterized.named_parameters( + ("encdec_on", True), + ("encdec_off", False), + ) + def test_addition_inputs_exceptions(self, encdec): + vocab_size = 5 + batch_size = 256 + seq_length = 64 + # Check if max/min lengths are validated for train stream + with self.assertRaises(ValueError): + inputs = data.inputs.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=2, + eval_min_length=1, + eval_max_length=seq_length, + pad_to_multiple=seq_length, + encdec=encdec, + ) + train_stream = inputs.train_stream(n_devices=1) + for _ in range(10): + next(train_stream) + + # Check if max/min lengths are validated for eval stream + with self.assertRaises(ValueError): + inputs = data.inputs.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=seq_length, + eval_min_length=1, + eval_max_length=seq_length, + 
pad_to_multiple=seq_length, + encdec=True, + ) + eval_stream = inputs.eval_stream(n_devices=1) + for _ in range(10): + next(eval_stream) + + def test_addition_inputs_constraints(self): + vocab_size = 5 + batch_size = 256 + seq_length = 64 + inputs = data.inputs.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=seq_length, + eval_min_length=seq_length, + eval_max_length=seq_length, + pad_to_multiple=seq_length, + encdec=True, + ) + + # Check if max length is respected for train stream + train_stream = inputs.train_stream(n_devices=1) + for _ in range(10): + x, y, weights = next(train_stream) + self.assertEqual(x.shape[1], seq_length) + self.assertEqual(y.shape[1], seq_length) + self.assertEqual(weights.shape[1], seq_length) + + # Check if max length is respected for eval stream + eval_stream = inputs.eval_stream(n_devices=1) + for _ in range(10): + x, y, weights = next(eval_stream) + self.assertEqual(x.shape[1], seq_length) + self.assertEqual(y.shape[1], seq_length) + self.assertEqual(weights.shape[1], seq_length) + + def _get_span_lengths(self, x): + span_lengths = [] + curr_len = 0 + for i in range(1, len(x)): + # 1 -> 0 + if x[i] == 0 and x[i - 1] == 1: + span_lengths.append(curr_len) + curr_len = 0 + # 1 -> 1 or 0 -> 1 + elif (x[i] == 1 and x[i - 1] == 1) or (x[i] == 1 and x[i - 1] == 0): + curr_len += 1 + if curr_len != 0: + span_lengths.append(curr_len) + return span_lengths + + def test_random_spans_noise_mask(self): + length = 100 + noise_density = 0.15 + mean_noise_span_length = 3.0 + + # Take 5 random seed1, seed2 values. + for seed in np.random.randint(0, 100, (5, 2)): + is_noise = data.random_spans_noise_mask( + length, + noise_density, + mean_noise_span_length, + seed1=seed[0], + seed2=seed[1], + ) + is_noise = is_noise.astype(np.int32) + # noise_density fraction of tokens are produced + self.assertEqual(np.sum(is_noise), noise_density * length) + # Get span lengths and make sure the average is what we expect. 
+ actual_span_lengths = self._get_span_lengths(is_noise) + average_span_length = sum(actual_span_lengths) / len(actual_span_lengths) + self.assertEqual(mean_noise_span_length, average_span_length) + + @absltest.skip("The version of the dataset you are trying is to old") + def test_process_c4_with_span_corruption(self): + def process_c4_with_span_corruption( + spm_path=None, + extra_ids=0, + train=False, + max_length=100, + noise_density=0.15, + mean_noise_span_length=3.0, + seed1=None, + seed2=None, + ): + return data.Serial( + data.TFDS( + "c4/en:2.3.0", data_dir=_TESTDATA, keys=("text",), train=train + ), + data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids), + data.generate_sequential_chunks(max_length=max_length), + data.generate_random_noise_mask( + noise_density=noise_density, + mean_noise_span_length=mean_noise_span_length, + seed1=seed1, + seed2=seed2, + ), + data.consume_noise_mask(vocab_size=32000 + extra_ids), + data.FilterEmptyExamples(), + data.AppendValue(val={0: [1], 1: [1]}), + data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}), + data.AddLossWeights(id_to_mask=0), + data.Batch(batch_size=2), + ) + + gen = process_c4_with_span_corruption(spm_path=_spm_path(), seed1=0, seed2=1) + + examples = [] + for i, ex in enumerate(gen()): + if i == 100: + break + examples.append(ex) + + self.assertLen(examples, 100) + example = examples[0] + + batched_input, batched_output, batched_loss_weights = example + + self.assertSequenceEqual( + batched_input.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 37, + 2335, + 113, + 3977, + 227, + 7306, + 45, + 3, + 9, + 4716, + 147, + 8, + 71, + 2658, + 65, + 118, + 4313, + 38, + 3, + 9, + 13065, + 32, + 31999, + 9, + 5704, + 26, + 109, + 6, + 6862, + 6, + 4728, + 45, + 8, + 3796, + 24093, + 11834, + 4716, + 30, + 8, + 1379, + 13, + 31998, + 130, + 718, + 12, + 8, + 24124, + 1343, + 300, + 4357, + 1714, + 31997, + 1373, + 47, + 16487, + 3168, + 16, + 321, + 7943, + 5, + 3, 
+ 4868, + 3856, + 5700, + 75, + 7, + 200, + 2231, + 6, + 11163, + 9, + 6, + 113, + 47, + 5330, + 45, + 14354, + 6, + 47, + 31996, + 20721, + 3654, + 44, + 8, + 3112, + 5, + 14599, + 11, + 8067, + 31995, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 277, + 828, + 43, + 5899, + 46, + 16, + 10952, + 139, + 160, + 1687, + 56, + 539, + 30, + 2875, + 41, + 31122, + 2307, + 137, + 2702, + 2780, + 15, + 7, + 31999, + 44, + 8, + 3112, + 11, + 30, + 569, + 783, + 5, + 3, + 17701, + 6, + 2194, + 26, + 23, + 1336, + 6321, + 1694, + 30, + 31998, + 196, + 56, + 1852, + 1423, + 25, + 5, + 27, + 183, + 8032, + 31997, + 217, + 149, + 1513, + 11, + 2238, + 25, + 1800, + 5, + 96, + 2703, + 44, + 3065, + 12537, + 11163, + 9, + 535, + 71, + 9363, + 14886, + 646, + 44, + 8, + 3112, + 243, + 23281, + 12, + 8, + 31996, + 346, + 402, + 17, + 99, + 83, + 11, + 773, + 3668, + 1280, + 31995, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ] + # pylint: enable=bad-continuation,bad-whitespace + ) + + self.assertSequenceEqual( + batched_output.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 31999, + 1639, + 7, + 15480, + 5, + 11163, + 31998, + 2083, + 9997, + 5076, + 31997, + 265, + 11, + 8, + 31996, + 3, + 31995, + 1343, + 2487, + 106, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 31999, + 12, + 8, + 15480, + 130, + 646, + 31998, + 1376, + 10, + 96, + 31997, + 62, + 410, + 59, + 31996, + 96, + 31995, + 94, + 608, + 10, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ] + # pylint: enable=bad-continuation,bad-whitespace + ) + + self.assertSequenceEqual( + batched_loss_weights.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 
1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + ] + # pylint: enable=bad-continuation,bad-whitespace + ) + + def test_prefix_lm_last_output_batch_is_short(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7, 8]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertSequenceEqual(([6, 7], [8]), examples[1]) + self.assertLen(examples, 2) + + def test_prefix_lm_last_input_batch_is_short(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertLen(examples, 1) + + def test_prefix_lm_last_input_batch_exists_but_no_output(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertLen(examples, 1) + + def test_unbatch(self): + unbatch_fn = data.UnBatch() + batched_inputs = [ + # First batch - 3 examples + ( + np.arange(3 * 2).reshape(3, -1), + np.arange(3 * 3).reshape(3, -1), + np.arange(3 * 4).reshape(3, -1), + ), + # Second batch - 4 examples + ( + np.arange(4 * 2).reshape(4, -1), + np.arange(4 * 3).reshape(4, -1), + np.arange(4 * 4).reshape(4, -1), + ), + ] + examples = list(unbatch_fn(batched_inputs)) + self.assertLen(examples, 3 + 4) + + def test_sine_shape(self): + inputs = data.sine_inputs(batch_size=3, length=5) + train_batch = next(inputs.train_stream(n_devices=1)) + eval_batch = next(inputs.eval_stream(n_devices=1)) + # (observations, actions, observations, mask) + self.assertLen(train_batch, 4) + self.assertLen(eval_batch, 4) + for (x, y) in zip(train_batch, eval_batch): + self.assertEqual(x.shape, (3, 5)) + self.assertEqual(y.shape, (3, 5)) + + +if __name__ == "__main__": + 
absltest.main() diff --git a/tests/data/text_encoder_test.py b/tests/data/text_encoder_test.py new file mode 100644 index 000000000..2f45e65a9 --- /dev/null +++ b/tests/data/text_encoder_test.py @@ -0,0 +1,386 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.data.text_encoder.""" + +import collections +import io +import os +import random +import shutil +import string + +import mock +from six.moves import range # pylint: disable=redefined-builtin +import tensorflow.compat.v1 as tf +from trax.data import text_encoder + + +class NativeToUnicodeTest(tf.test.TestCase): + def test_native_to_unicode(self): + s = r"foo bar" + s_unicode = text_encoder.native_to_unicode(s) + self.assertEqual(s_unicode, "foo bar") + + +class EscapeUnescapeTokenTest(tf.test.TestCase): + def test_escape_token(self): + escaped = text_encoder._escape_token( + "Foo! Bar.\nunder_score back\\slash", + set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS, + ) + + self.assertEqual( + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped + ) + + def test_unescape_token(self): + unescaped = text_encoder._unescape_token( + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_" + ) + + self.assertEqual("Foo! 
Bar.\nunder_score back\\slash", unescaped) + + +class TokenTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + tf.gfile.MakeDirs(cls.test_temp_dir) + + def test_save_and_reload(self): + """Test that saving and reloading doesn't change the vocab. + + Note that this test reads and writes to the filesystem, which necessitates + that this test size be "large". + """ + + corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" + vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab") + + # Make text encoder from a list and store vocab to fake filesystem. + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + encoder.store_to_file(vocab_filename) + + # Load back the saved vocab file from the fake_filesystem. + new_encoder = text_encoder.TokenTextEncoder(vocab_filename) + + self.assertEqual(encoder._id_to_token, new_encoder._id_to_token) + self.assertEqual(encoder._token_to_id, new_encoder._token_to_id) + + def test_reserved_tokens_in_corpus(self): + """Test that we handle reserved tokens appearing in the corpus.""" + corpus = "A B {} D E F {} G {}".format( + text_encoder.EOS, text_encoder.EOS, text_encoder.PAD + ) + + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + + all_tokens = encoder._id_to_token.values() + + # If reserved tokens are removed correctly, then the set of tokens will + # be unique. 
+ self.assertEqual(len(all_tokens), len(set(all_tokens))) + + +class SubwordTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + tf.gfile.MakeDirs(cls.test_temp_dir) + + def test_encode_decode(self): + corpus = ( + "This is a corpus of text that provides a bunch of tokens from which " + "to build a vocabulary. It will be used when strings are encoded " + "with a TextEncoder subclass. The encoder was coded by a coder." + ) + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + original = "This is a coded sentence encoded by the SubwordTextEncoder." + token_counts.update(original.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + # Encoding should be reversible. + encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + self.assertEqual(original, decoded) + + # The substrings coded and coder are frequent enough in the corpus that + # they should appear in the vocabulary even though they are substrings + # of other included strings. + subtoken_strings = {encoder.all_subtoken_strings[i] for i in encoded} + self.assertIn("encoded_", subtoken_strings) + self.assertIn("coded_", subtoken_strings) + self.assertIn("TextEncoder", encoder.all_subtoken_strings) + self.assertIn("coder", encoder.all_subtoken_strings) + + # Every character in the corpus should be in the encoders alphabet and + # its subtoken vocabulary. + self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_unicode(self): + corpus = "Cat emoticons. 
\U0001F638 \U0001F639 \U0001F63A \U0001F63B" + token_counts = collections.Counter(corpus.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + self.assertIn("\U0001F638", encoder._alphabet) + self.assertIn("\U0001F63B", encoder.all_subtoken_strings) + + def test_small_vocab(self): + corpus = "The quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 10, token_counts, 2, 10 + ) + + # All vocabulary elements are in the alphabet and subtoken strings even + # if we requested a smaller vocabulary to assure all expected strings + # are encodable. + self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_long_tokens(self): + """Subword tokenization should still run efficiently with long tokens. + + To make it run efficiently, we need to use the `max_subtoken_length` + argument when calling SubwordTextEncoder.build_to_target_size. + """ + token_length = 4000 + num_tokens = 50 + target_vocab_size = 600 + max_subtoken_length = 10 # Set this to `None` to get problems. + max_count = 500 + + # Generate some long random strings. + random.seed(0) + long_tokens = [] + for _ in range(num_tokens): + long_token = "".join( + [random.choice(string.ascii_uppercase) for _ in range(token_length)] + ) + long_tokens.append(long_token) + + corpus = " ".join(long_tokens) + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + target_vocab_size, + token_counts, + 1, + max_count, + num_iterations=1, + max_subtoken_length=max_subtoken_length, + ) + + # All vocabulary elements are in the alphabet and subtoken strings even + # if we requested a smaller vocabulary to assure all expected strings + # are encodable. 
+ self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_custom_reserved_tokens(self): + """Test that we can pass custom reserved tokens to SubwordTextEncoder.""" + corpus = "The quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + start_symbol = "" + end_symbol = "" + reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, end_symbol] + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 10, token_counts, 2, 10, reserved_tokens=reserved_tokens + ) + + # Make sure that reserved tokens appear in the right places. + self.assertEqual(encoder.decode([2]), start_symbol) + self.assertEqual(encoder.decode([3]), end_symbol) + + # Make sure that we haven't messed up the ability to reconstruct. + reconstructed_corpus = encoder.decode(encoder.encode(corpus)) + self.assertEqual(corpus, reconstructed_corpus) + + def test_encodable_when_not_in_alphabet(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + original = "This has UPPER CASE letters that are out of alphabet" + + # Early versions could have an infinite loop when breaking into subtokens + # if there was any out-of-alphabet characters in the encoded string. 
+ encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + + self.assertEqual(original, decoded) + encoded_str = "".join(encoder.all_subtoken_strings[i] for i in encoded) + self.assertIn("\\84;", encoded_str) + + @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579")) + def test_raises_exception_when_not_encodable(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Deliberately exclude some required encoding chars from the alphabet + # and token list, making some strings unencodable. + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + original = "This has UPPER CASE letters that are out of alphabet" + + # Previously there was a bug which produced an infinite loop in this case. + with self.assertRaises(AssertionError): + encoder.encode(original) + + def test_load_from_file(self): + # Test a vocab file with words not wrapped with single quotes + encoder = text_encoder.SubwordTextEncoder() + correct_vocab = ["the", "and", "of"] + vocab = io.StringIO("the\n" "and\n" "of\n") + encoder._load_from_file_object(vocab) + self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) + + # Test a vocab file with words wrapped in single quotes + encoder = text_encoder.SubwordTextEncoder() + vocab = io.StringIO('"the"\n' '"and"\n' '"of"\n') + encoder._load_from_file_object(vocab) + self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) + + def test_reserved_token_chars_not_in_alphabet(self): + corpus = "dog" + token_counts = collections.Counter(corpus.split(" ")) + encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 100 + ) + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder1.store_to_file(filename) + encoder2 = text_encoder.SubwordTextEncoder(filename=filename) + + self.assertEqual(encoder1._alphabet, encoder2._alphabet) + + for t in text_encoder.RESERVED_TOKENS: 
+ for c in t: + # Verify that encoders can encode all reserved token chars. + encoder1.encode(c) + encoder2.encode(c) + + def test_save_and_reload(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Build an encoder from the corpus token counts; it is stored to a file + # and reloaded below so the two encoders can be compared. + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder.store_to_file(filename) + new_encoder = text_encoder.SubwordTextEncoder(filename) + + self.assertEqual(encoder._alphabet, new_encoder._alphabet) + self.assertEqual(encoder.all_subtoken_strings, new_encoder.all_subtoken_strings) + self.assertEqual( + encoder._subtoken_string_to_id, new_encoder._subtoken_string_to_id + ) + self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) + + def test_save_and_reload_no_single_quotes(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Build an encoder from the corpus token counts; it is stored to a file + # without single quotes and reloaded below for comparison.
+ encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder.store_to_file(filename, add_single_quotes=False) + new_encoder = text_encoder.SubwordTextEncoder(filename) + + self.assertEqual(encoder._alphabet, new_encoder._alphabet) + self.assertEqual(encoder.all_subtoken_strings, new_encoder.all_subtoken_strings) + self.assertEqual( + encoder._subtoken_string_to_id, new_encoder._subtoken_string_to_id + ) + self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) + + def test_build_from_generator(self): + corpus = "The quick brown fox jumps over the lazy dog" + + def gen(): + for _ in range(3): + yield corpus + + start_symbol = "" + end_symbol = "" + reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, end_symbol] + encoder = text_encoder.SubwordTextEncoder.build_from_generator( + gen(), 10, reserved_tokens=reserved_tokens + ) + + # Make sure that reserved tokens appear in the right places. + self.assertEqual(encoder.decode([2]), start_symbol) + self.assertEqual(encoder.decode([3]), end_symbol) + + self.assertEqual( + "hi%s" % start_symbol, encoder.decode(encoder.encode("hi") + [2]) + ) + + # Make sure that we haven't messed up the ability to reconstruct. 
+ reconstructed_corpus = encoder.decode(encoder.encode(corpus)) + self.assertEqual(corpus, reconstructed_corpus) + + +class OneHotClassLabelEncoderTest(tf.test.TestCase): + def test_one_hot_encode(self): + encoder = text_encoder.OneHotClassLabelEncoder( + class_labels=["zero", "one", "two"] + ) + self.assertEqual(encoder.encode("zero"), [1, 0, 0]) + self.assertEqual(encoder.encode("one"), [0, 1, 0]) + self.assertEqual(encoder.encode("two"), [0, 0, 1]) + + def test_one_hot_decode(self): + encoder = text_encoder.OneHotClassLabelEncoder( + class_labels=["zero", "one", "two"] + ) + self.assertEqual(encoder.decode([1, 0, 0]), "zero") + self.assertEqual(encoder.decode([0, 1, 0]), "one") + self.assertEqual(encoder.decode([0, 0, 1]), "two") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/tf_inputs_test.py b/tests/data/tf_inputs_test.py new file mode 100644 index 000000000..41d1cdfa7 --- /dev/null +++ b/tests/data/tf_inputs_test.py @@ -0,0 +1,899 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.data.tf_inputs.""" + +import collections +import os +from unittest import mock + +import gin +import numpy as np +from t5.data import assert_dataset +from t5.data import preprocessors as t5_processors +import tensorflow as tf +import tensorflow_datasets as tfds +from trax.data import inputs # pylint: disable=unused-import +from trax.data import tf_inputs + + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../resources/data/testdata")) + + +def _test_dataset_ints(inp_lengths, tgt_lengths): + """Create a test dataset of int64 tensors of given shapes.""" + + def generator(): + for inp_len, tgt_len in zip(inp_lengths, tgt_lengths): + inp = np.ones([inp_len], dtype=np.int64) + tgt = np.ones([tgt_len], dtype=np.int64) + yield {"inputs": inp, "targets": tgt} + + types = {"inputs": tf.int64, "targets": tf.int64} + shapes = {"inputs": tf.TensorShape([None]), "targets": tf.TensorShape([None])} + return tf.data.Dataset.from_generator( + generator, output_types=types, output_shapes=shapes + ) + + +def _load_dataset(name, split="train"): + return tfds.load(name=name, split=split, data_dir=_TESTDATA, shuffle_files=False) + + +def _c4_dataset(split="train"): + return _load_dataset("c4:2.3.0", split=split) + + +def _spm_path(): + return os.path.join(_TESTDATA, "sentencepiece.model") + + +def _t5_gin_config(): + # The following pages worth of gin configuration are required because a lot + # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these + # functions at all without having configured gin. + + noise_density = 0.15 + max_input_length = 50 + + # What preprocessors to apply - we select a random chunk of the document if + # it exceeds a certain lengths (`select_random_chunk`), then split up long + # examples (`split_tokens`) and finally the denoising objective (`denoise`). 
+ # + # In addition to this T5 concates multiple documents together to reduce + # padding (`reduce_concat_tokens`) after `select_random_chunk`, but we skip + # that since we don't do sequence packing. + gin.bind_parameter( + "unsupervised.preprocessors", + [ + t5_processors.select_random_chunk, + t5_processors.split_tokens, + t5_processors.denoise, + ], + ) + + # select_random_chunk + gin.bind_parameter("select_random_chunk.feature_key", "targets") + gin.bind_parameter("select_random_chunk.max_length", max_input_length) + + # reduce_concat_tokens + gin.bind_parameter("random_spans_helper.extra_tokens_per_span_inputs", 1) + gin.bind_parameter("random_spans_helper.extra_tokens_per_span_targets", 1) + gin.bind_parameter("random_spans_helper.inputs_length", max_input_length) + gin.bind_parameter("random_spans_helper.mean_noise_span_length", 3.0) + gin.bind_parameter("random_spans_helper.noise_density", noise_density) + + # split_tokens + gin.bind_parameter( + "split_tokens.max_tokens_per_segment", + t5_processors.random_spans_tokens_length(), + ) + + # denoise + gin.bind_parameter("denoise.inputs_fn", t5_processors.noise_span_to_unique_sentinel) + gin.bind_parameter("denoise.noise_density", noise_density) + gin.bind_parameter("denoise.noise_mask_fn", t5_processors.random_spans_noise_mask) + gin.bind_parameter( + "denoise.targets_fn", t5_processors.nonnoise_span_to_unique_sentinel + ) + + +class TFInputsTest(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_TFDS_single_host_with_eval_holdout(self): + self.skipTest("The version of the dataset you are trying is to old") + train_ds_gen = tf_inputs.TFDS( + "c4/en:2.3.0", + data_dir=_TESTDATA, + train=True, + host_id=0, + keys=("text",), + n_hosts=1, + eval_holdout_size=0.1, + ) + + # Just ensure that this doesn't crash. 
+ for d in train_ds_gen(): + print(f"Train: {d}") + break + + valid_ds_gen = tf_inputs.TFDS( + "c4/en:2.3.0", + data_dir=_TESTDATA, + train=False, + host_id=0, + keys=("text",), + n_hosts=1, + eval_holdout_size=0.1, + ) + + # Just ensure that this doesn't crash. + for d in valid_ds_gen(): + print(f"Eval: {d}") + break + + def test_TFDS_single_host_with_eval_holdout_no_valid_split(self): + train_ds_gen = tf_inputs.TFDS( + "para_crawl/ende", + data_dir=_TESTDATA, + train=True, + host_id=0, + keys=("en", "de"), + n_hosts=1, + eval_holdout_size=0.1, + ) + + # Just ensure that this doesn't crash. + for d in train_ds_gen(): + print(f"Train: {d}") + break + + # para_crawl doesn't have a validation set, see that this still doesn't + # crash because of eval_holdout_set. + valid_ds_gen = tf_inputs.TFDS( + "para_crawl/ende", + data_dir=_TESTDATA, + train=False, + host_id=0, + keys=("en", "de"), + n_hosts=1, + eval_holdout_size=0.1, + ) + + # Just ensure that this doesn't crash. + for d in valid_ds_gen(): + print(f"Eval: {d}") + break + + def test_TFDS_mnli_split_is_eval(self): + with mock.patch("tensorflow_datasets.load") as tfds_load: + with mock.patch( + "trax.data.tf_inputs.download_and_prepare", lambda _, data_dir: data_dir + ): + _ = tf_inputs.TFDS( + "glue/mnli", keys=("premise", "hypothesis"), train=False + ) + call_kwargs = tfds_load.call_args[1] + self.assertEqual(call_kwargs["split"], "validation_matched") + + def test_TFDS_mnli_split_is_alt_eval(self): + with mock.patch("tensorflow_datasets.load") as tfds_load: + with mock.patch( + "trax.data.tf_inputs.download_and_prepare", lambda _, data_dir: data_dir + ): + _ = tf_inputs.TFDS( + "glue/mnli", + keys=("premise", "hypothesis"), + train=False, + use_alt_eval=True, + ) + call_kwargs = tfds_load.call_args[1] + self.assertEqual(call_kwargs["split"], "validation_mismatched") + + def test_convert_to_unicode(self): + def dataset1(): + yield (b"Audentes fortuna iuvat.", b"Fortune favors the bold.") + + def dataset2(): + 
yield (b"\x81aabb", b"Value") + + convert_function1 = tf_inputs.ConvertToUnicode(keys=[0]) + convert_output1 = next(convert_function1(dataset1())) + self.assertEqual(convert_output1[0], "Audentes fortuna iuvat.") + self.assertEqual(convert_output1[1], b"Fortune favors the bold.") + self.assertIsInstance(convert_output1[0], str) + self.assertIsInstance(convert_output1[1], bytes) + + # Contains an invalid bytes array from the point of view of UTF-8. + try: + convert_function2 = tf_inputs.ConvertToUnicode(keys=[0]) + convert_output2 = next(convert_function2(dataset2())) + except UnicodeDecodeError: + self.fail("ConvertToUnicode threw UnicodeDecodeError.") + self.assertEqual(convert_output2[0], "aabb") + self.assertIsInstance(convert_output2[0], str) + + def test_tokenize_detokenize(self): + def dataset(): + yield "I have a cat." + + # Character-level. + tok_char = list(tf_inputs.tokenize(dataset(), vocab_type="char")) + self.assertAllEqual(tok_char[0], np.array([ord(c) for c in "I have a cat."])) + detok = tf_inputs.detokenize(tok_char[0], vocab_type="char") + self.assertEqual(detok, "I have a cat.") + + # Sentencepiece. + tok_spc = list( + tf_inputs.tokenize( + dataset(), + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + ) + self.assertAllEqual(tok_spc[0], np.array([27, 43, 3, 9, 1712, 5])) + detok = tf_inputs.detokenize( + list(tok_spc[0]), + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + self.assertEqual(detok, "I have a cat.") + + # Subword. 
+ tok_sbw = list( + tf_inputs.tokenize( + dataset(), + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + ) + self.assertAllEqual(tok_sbw[0], np.array([139, 96, 12, 2217, 2, 21])) + detok = tf_inputs.detokenize( + tok_sbw[0], + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + self.assertEqual(detok, "I have a cat.") + + # bert-lowercase + tok_sbw = list( + tf_inputs.tokenize( + dataset(), + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + ) + self.assertAllEqual(tok_sbw[0], np.array([1045, 2031, 1037, 4937, 1012])) + detok = tf_inputs.detokenize( + tok_sbw[0], + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + self.assertEqual(detok, "i have a cat .") + # note: BERT tokenizer is not reversible, therefore + # difference between original input + + def test_tokenize_keys_reservedids(self): + def dataset(): + yield ("Cat.", "Dog.") + + tok_char1 = list( + tf_inputs.tokenize(dataset(), vocab_type="char", n_reserved_ids=5) + ) + self.assertAllEqual(tok_char1[0][0], np.array([ord(c) + 5 for c in "Cat."])) + self.assertAllEqual(tok_char1[0][1], np.array([ord(c) + 5 for c in "Dog."])) + + tok_char2 = list( + tf_inputs.tokenize(dataset(), keys=[0], vocab_type="char", n_reserved_ids=2) + ) + self.assertAllEqual(tok_char2[0][0], np.array([ord(c) + 2 for c in "Cat."])) + self.assertEqual(tok_char2[0][1], "Dog.") + + def test_tokenize_dict(self): + def dataset(): + yield {"a": "Cat.", "b": "Dog."} + + tok_char1 = list(tf_inputs.tokenize(dataset(), vocab_type="char")) + self.assertAllEqual(tok_char1[0]["a"], np.array([ord(c) for c in "Cat."])) + self.assertAllEqual(tok_char1[0]["b"], np.array([ord(c) for c in "Dog."])) + + tok_char2 = list(tf_inputs.tokenize(dataset(), keys=["a"], vocab_type="char")) + self.assertAllEqual(tok_char2[0]["a"], np.array([ord(c) for c in "Cat."])) + self.assertEqual(tok_char2[0]["b"], "Dog.") + + 
def test_vocab_size(self): + # Character-level. + char_size = tf_inputs.vocab_size(vocab_type="char", n_reserved_ids=11) + self.assertEqual(char_size, 256 + 11) + # Sentencepiece. + spc_size = tf_inputs.vocab_size( + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + self.assertEqual(spc_size, 32000) + # Subword. + sbw_size = tf_inputs.vocab_size( + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + self.assertEqual(sbw_size, 8183) + # Bert_uncased. + sbw_size = tf_inputs.vocab_size( + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + self.assertEqual(sbw_size, 30522) + + def test_c4_bare_preprocess_fn(self): + dataset = _c4_dataset() + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Targets are NOT in the example. + self.assertNotIn("targets", example) + self.assertIn("text", example) + text = example["text"] + + # This should convert the dataset to an inputs/targets that are tokenized. + dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Earlier text is now stored in targets_pretokenized + self.assertIn("targets_pretokenized", example) + self.assertEqual(example["targets_pretokenized"], text) + + # Targets are now tokenized. + self.assertIn("targets", example) + self.assertIsInstance(example["targets"], np.ndarray) + self.assertEqual(example["targets"].dtype, np.int64) + self.assertGreater(len(example["targets"]), 0) + self.assertEqual(example["targets"][-1], 1) # we add EOS at the end. + + # Inputs exist but is empty because t5 preprocessors' unsupervised wasn't + # gin configured with any. 
+ self.assertIn("inputs", example) + self.assertEqual(len(example["inputs"]), 0) + + def test_c4_preprocess(self): + def load_c4_dataset(split="train"): + dataset = _c4_dataset(split=split) + return dataset.map(lambda example: (example, example["text"])) + + def examine_processed_dataset(proc_dataset): + count = 0 + lengths = [] + for example in tfds.as_numpy(proc_dataset): + count += 1 + ex = example[0] + # Targets are in the example. + self.assertIn("targets", ex) + self.assertEqual(ex["targets"].dtype, np.int64) + lengths.append(len(ex["targets"])) + return count, lengths + + unfiltered_count = 0 + for example in tfds.as_numpy(load_c4_dataset()): + unfiltered_count += 1 + # Targets are NOT in the example. + self.assertNotIn("targets", example[0]) + + proc_dataset = tf_inputs.c4_preprocess(load_c4_dataset(), False, 2048) + + # `examine_processed_dataset` has some asserts in it. + proc_count, char_lengths = examine_processed_dataset(proc_dataset) + + # Both the original and filtered datasets have examples. + self.assertGreater(unfiltered_count, 0) + self.assertGreater(proc_count, 0) + + # Because we filter out some entries on length. + self.assertLess(proc_count, unfiltered_count) + + # Preprocess using the sentencepiece model in testdata. + spc_proc_dataset = tf_inputs.c4_preprocess( + load_c4_dataset(), False, 2048, tokenization="spc", spm_path=_spm_path() + ) + + spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset) + + # spc shortens the target sequence a lot, should be almost equal to + # unfiltered + self.assertLessEqual(proc_count, spc_proc_count) + self.assertEqual(unfiltered_count, spc_proc_count) + + # Assert all spc_lengths are lesser than their char counterparts. 
+ for spc_len, char_len in zip(spc_lengths, char_lengths): + self.assertLessEqual(spc_len, char_len) + + def test_c4(self): + gin.bind_parameter("c4_preprocess.max_target_length", 2048) + gin.bind_parameter("c4_preprocess.tokenization", "spc") + gin.bind_parameter("c4_preprocess.spm_path", _spm_path()) + + # Just make sure this doesn't throw. + _ = tf_inputs.data_streams( + "c4", + data_dir=_TESTDATA, + input_name="targets", + target_name="text", + preprocess_fn=tf_inputs.c4_preprocess, + ) + + def test_c4_bare_preprocess_fn_denoising_objective(self): + _t5_gin_config() + + dataset = _c4_dataset() + dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Assertions now. + + self.assertIn("targets", example) + targets = example["targets"] + self.assertIsInstance(targets, np.ndarray) + self.assertEqual(targets.dtype, np.int64) + self.assertGreater(len(targets), 0) + + self.assertIn("inputs", example) + _inputs = example["inputs"] # pylint: disable=invalid-name + self.assertIsInstance(_inputs, np.ndarray) + self.assertEqual(_inputs.dtype, np.int64) + self.assertGreater(len(_inputs), 0) + + # WHP inputs will have the bulk of the text. + self.assertGreater(len(_inputs), len(targets)) + + # WHP there will be one sentinel token in the inputs and targets. + inputs_counter = collections.Counter(_inputs.tolist()) + targets_counter = collections.Counter(targets.tolist()) + self.assertEqual(1, inputs_counter[31999]) + self.assertEqual(1, targets_counter[31999]) + + def test_c4_pretrain(self): + _t5_gin_config() + + gin.bind_parameter("c4_bare_preprocess_fn.spm_path", _spm_path()) + + gin.bind_parameter("batcher.batch_size_per_device", 8) + gin.bind_parameter("batcher.eval_batch_size", 8) + gin.bind_parameter("batcher.max_eval_length", 50) + gin.bind_parameter("batcher.buckets", ([51], [8, 1])) + + # Just make sure this doesn't throw. 
+ _ = tf_inputs.data_streams( + "c4", + data_dir=_TESTDATA, + input_name="inputs", + target_name="targets", + bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn, + ) + + def test_generic_text_dataset_preprocess_fn(self): + self.skipTest("google.protobuf.json_format.ParseError ...") + dataset = _load_dataset("squad/v1.1:3.0.0") + + (example,) = tfds.as_numpy(dataset.take(1)) + + self.assertNotIn("inputs", example) + self.assertNotIn("targets", example) + + proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn( + dataset, + spm_path=_spm_path(), + text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)], + copy_pretokenized=True, + debug_print_examples=True, + debug_print_examples_rate=1.0, + ) + + (proc_example,) = tfds.as_numpy(proc_dataset.take(1)) + + self.assertIn("inputs", proc_example) + self.assertIn("targets", proc_example) + + self.assertEqual(proc_example["inputs"].dtype, np.int32) + self.assertEqual(proc_example["targets"].dtype, np.int32) + + # TODO(afrozm): Why does this test take so much time? + def test_inputs_using_generic_text_dataset_preprocess_fn(self): + self.skipTest("google.protobuf.json_format.ParseError ...") + gin.bind_parameter("generic_text_dataset_preprocess_fn.spm_path", _spm_path()) + gin.bind_parameter( + "generic_text_dataset_preprocess_fn.text_preprocess_fns", + [lambda ds, training: t5_processors.squad(ds)], + ) + + # Just make sure this doesn't throw. + def data_streams(): + return tf_inputs.data_streams( + "squad", + data_dir=_TESTDATA, + input_name="inputs", + target_name="targets", + bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn, + shuffle_buffer_size=1, + ) + + n_devices = 3 + + squad_inputs = inputs.batcher( + data_streams=data_streams, + max_eval_length=512, + buckets=( + [ + 513, + ], + [n_devices, n_devices], + ), + ) + + eval_stream = squad_inputs.eval_stream(n_devices) + inps, tgts, _ = next(eval_stream) + + # We can only assert that the batch dim gets divided by n_devices. 
+ self.assertEqual(inps.shape[0] % n_devices, 0) + self.assertEqual(tgts.shape[0] % n_devices, 0) + + def test_filter_dataset_on_len(self): + # {1, 2}, {2, 4}, {3, 6} ... {10, 20} + ds = _test_dataset_ints(range(1, 11), range(2, 21, 2)) + + ds1 = tf_inputs.filter_dataset_on_len( + ds, True, {"inputs": [4, 8], "targets": [14, 20]} + ) + # Only {7, 14} and {8, 16} satisfy this. + self.assertLen(list(ds1.as_numpy_iterator()), 2) + + ds2 = tf_inputs.filter_dataset_on_len( + ds, + False, + len_map={"inputs": [4, 8], "targets": [14, 20]}, + filter_on_eval=False, + ) + # This is eval and we aren't supposed to filter it. + self.assertLen(list(ds2.as_numpy_iterator()), 10) + + ds3 = tf_inputs.filter_dataset_on_len( + ds, + False, + len_map={"inputs": [4, 8], "targets": [14, 20]}, + filter_on_eval=True, + ) + # This is eval and we are asked to filter it. + self.assertLen(list(ds3.as_numpy_iterator()), 2) + + def test_truncate_dataset_on_len(self): + ds = _test_dataset_ints([5, 6, 7], [8, 9, 10]) + ds1 = tf_inputs.truncate_dataset_on_len( + ds, True, len_map={"inputs": 6, "targets": 4} + ) + expected_ds = _test_dataset_ints([5, 6, 6], [4, 4, 4]) + + # training, should filter. + assert_dataset(ds1, list(expected_ds.as_numpy_iterator())) + + # not Training, shouldn't filter. + ds2 = tf_inputs.truncate_dataset_on_len( + ds, False, len_map={"inputs": 6, "targets": 4} + ) + assert_dataset(ds2, list(ds.as_numpy_iterator())) + + # not Training, but asked to filter, should filter. 
+ ds3 = tf_inputs.truncate_dataset_on_len( + ds, False, len_map={"inputs": 6, "targets": 4}, truncate_on_eval=True + ) + assert_dataset(ds3, list(expected_ds.as_numpy_iterator())) + + def test_get_t5_preprocessor_by_name(self): + gin.clear_config() + + gin.parse_config( + """ + get_t5_preprocessor_by_name.name = 'rekey' + get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}} + """ + ) + prep_rekey = tf_inputs.get_t5_preprocessor_by_name() + og_dataset = tf.data.Dataset.from_tensors( + {"text": "That is good.", "other": "That is bad."} + ) + training = True + dataset = prep_rekey(og_dataset, training) + assert_dataset(dataset, {"inputs": "That is bad.", "targets": "That is good."}) + + def test_pad_dataset_to_length(self): + ds = _test_dataset_ints([5, 6, 7], [6, 7, 8]) + ds1 = tf_inputs.pad_dataset_to_length( + ds, True, len_map={"inputs": 7, "targets": 10} + ) + + expected_ds = [ + { + "inputs": np.array([1, 1, 1, 1, 1, 0, 0], dtype=np.int64), + "targets": np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=np.int64), + }, + { + "inputs": np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.int64), + "targets": np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64), + }, + { + "inputs": np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int64), + "targets": np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.int64), + }, + ] + + assert_dataset(ds1, expected_ds) + + def test_lm_token_preprocessing(self): + ds = _test_dataset_ints([1, 2, 3], [3, 2, 1]) + ds1 = tf_inputs.lm_token_preprocessing(ds, True) + + # pylint: disable=bad-whitespace + expected_ds = [ + { + "inputs": np.array([1, 0, 1, 1, 1], dtype=np.int64), + "targets": np.array([1, 0, 1, 1, 1], dtype=np.int64), + "mask": np.array([0, 0, 1, 1, 1], dtype=np.int64), + }, + { + "inputs": np.array([1, 1, 0, 1, 1], dtype=np.int64), + "targets": np.array([1, 1, 0, 1, 1], dtype=np.int64), + "mask": np.array([0, 0, 0, 1, 1], dtype=np.int64), + }, + { + "inputs": np.array([1, 1, 1, 0, 1], dtype=np.int64), + 
"targets": np.array([1, 1, 1, 0, 1], dtype=np.int64), + "mask": np.array([0, 0, 0, 0, 1], dtype=np.int64), + }, + ] + # pylint: enable=bad-whitespace + + assert_dataset(ds1, expected_ds) + + def test_create_bert_inputs(self): + inputs_sentences_1 = [np.array([100, 150, 200])] + inputs_sentences_2 = [np.array([300, 500])] + labels = [np.array(1)] + + create_inputs_1 = tf_inputs.CreateBertInputs(False) + create_inputs_2 = tf_inputs.CreateBertInputs(True) + for res in create_inputs_1(zip(inputs_sentences_1, labels)): + values, segment_embs, _, label, weight = res + self.assertAllEqual(values, np.array([101, 100, 150, 200, 102])) + self.assertAllEqual(segment_embs, np.zeros(5)) + self.assertEqual(label, np.int64(1)) + self.assertEqual(weight, np.int64(1)) + + for res in create_inputs_2(zip(inputs_sentences_1, inputs_sentences_2, labels)): + values, segment_embs, _, label, weight = res + self.assertAllEqual( + values, np.array([101, 100, 150, 200, 102, 300, 500, 102]) + ) + exp_segment = np.concatenate((np.zeros(5), np.ones(3))) + self.assertAllEqual(segment_embs, exp_segment) + self.assertEqual(label, np.int64(1)) + self.assertEqual(weight, np.int64(1)) + + def test_mask_random_tokens(self): + """Test only standard tokens. + + This test deals with sentences composed of two parts: [100 CLS tokens, 100 + chosen standard tokens]. CLS is the token that is added at the beginning of + the sentence and there is only one token in standard scenario. It is never + masked because it is not a part of the sentence. 
+ This tests whether mask_random_tokens will: + - mask only standard tokens + - mask expected number of tokens (15 percent candidates for masking) + """ + cls_token = 101 + mask_token = 103 + example_standard_token = 1001 + test_case_row = np.array([cls_token] * 100 + [example_standard_token] * 100) + test_case = [(test_case_row.copy(),)] + + out, original_tokens, token_weights = next( + tf_inputs.mask_random_tokens(test_case) + ) + # test whether original tokens are unchanged + self.assertAllEqual(test_case_row, original_tokens) + + self.assertEqual(1, token_weights.sum()) + self.assertEqual( + 15, (token_weights > 0).sum() + ) # we should have 15 candidates for masking + + # 101 is a special token, so only 1001 should be masked + self.assertAllEqual(out[:100], test_case_row[:100]) + + # Each candidate has 0.8 probability to be masked while others have 0, so + # no more than 15 tokens with MASK + self.assertLessEqual((out == mask_token).sum(), 15) + + def test_bert_next_sentence_prediction_inputs(self): + stream = tf_inputs.BertNextSentencePredictionInputs( + "c4/en:2.3.0", data_dir=_TESTDATA, train=False, shuffle_size=1 + ) + exp_sent1 = "Police were called to the carriageway around 6." + exp_sent2 = "I am sorry we did not see how lost and alone you felt." + sent1, sent2, label = next(stream()) + self.assertEqual(exp_sent1, sent1) + self.assertEqual(exp_sent2, sent2) + self.assertFalse(label) + + def test_process_single_mathqa_example_0(self): + # This is the first problem in the MathQA dataset. + example = { + "Problem": "the banker ' s gain of a certain sum due 3 years hence at 10 % " + "per annum is rs . 36 . what is the present worth ?", + "Rationale": '"explanation : t = 3 years r = 10 % td = ( bg Γ— 100 ) / tr = ( ' + "36 Γ— 100 ) / ( 3 Γ— 10 ) = 12 Γ— 10 = rs . 120 td = ( pw Γ— tr )" + " / 100 β‡’ 120 = ( pw Γ— 3 Γ— 10 ) / 100 β‡’ 1200 = pw Γ— 3 pw = " + '1200 / 3 = rs . 400 answer : option a"', + "options": "a ) rs . 400 , b ) rs . 300 , c ) rs . 
500 , d ) rs . 350 , e ) " + "none of these", + "correct": "a", + "annotated_formula": "divide(multiply(const_100, divide(multiply(36, const_100), " + "multiply(3, 10))), multiply(3, 10))", + "linear_formula": "multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)|", + "category": "gain", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = tf_inputs.process_single_mathqa_example(example) + self.assertEqual(answer_num, 400) # we know it, because correct answer is a) + self.assertEqual(python_result, [3600.0, 30.0, 120.0, 12000.0, 400.0]) + + self.assertEqual( + python_program, + [ + "t0 = n2 * 100.0", + "t1 = n0 * n1", + "t2 = t0 / t1", + "t3 = t2 * 100.0", + "t4 = t3 / t1", + ], + ) + self.assertEqual( + list_op, + [ + "multiply(n2,const_100)", + "multiply(n0,n1)", + "divide(#0,#1)", + "multiply(#2,const_100)", + "divide(#3,#1)", + ], + ) + self.assertEqual(list_num, [3.0, 10.0, 36.0]) + + def test_process_single_mathqa_example_1(self): + # This is the third problem in the MathQA dataset. + example = { + "Problem": "sophia finished 2 / 3 of a book . she calculated that she " + "finished 90 more pages than she has yet to read . how long is her" + " book ?", + "Rationale": "let xx be the total number of pages in the book , then she " + "finished 23 β‹… x 23 β‹… x pages . then she has x βˆ’ 23 β‹… x = " + "13 β‹… xx βˆ’ 23 β‹… x = 13 β‹… x pages left . 23 β‹… x βˆ’ 13 " + "β‹… x = 9023 β‹… x βˆ’ 13 β‹… x = 90 13 β‹… x = 9013 β‹… x = 90 x" + " = 270 x = 270 so the book is 270 pages long . 
answer : b", + "options": "a ) 229 , b ) 270 , c ) 877 , d ) 266 , e ) 281", + "correct": "b", + "annotated_formula": "divide(90, subtract(const_1, divide(2, 3)))", + "linear_formula": "divide(n0,n1)|subtract(const_1,#0)|divide(n2,#1)", + "category": "general", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = tf_inputs.process_single_mathqa_example(example) + self.assertEqual(answer_num, 270) # we know it, because correct answer is b) + self.assertAllClose( + python_result, [0.6666666666666666, 0.33333333333333337, 269.99999999999994] + ) + self.assertEqual( + python_program, ["t0 = n0 / n1", "t1 = 1.0 - t0", "t2 = n2 / t1"] + ) + self.assertEqual( + list_op, ["divide(n0,n1)", "subtract(const_1,#0)", "divide(n2,#1)"] + ) + self.assertEqual(list_num, [2.0, 3.0, 90.0]) + + def test_process_single_mathqa_example_with_import(self): + # This is a training MathQA problem which involve an import. + example = { + "Problem": "the length of a rectangular garden is three times its width . if " + "the area of the rectangular garden is 588 square meters , then " + "what is the width of the rectangular garden ?", + "Rationale": '"let x be the width of the garden . 
3 x ^ 2 = 588 x ^ 2 = 196 x ' + '= 14 the answer is c ."', + "options": "a ) 12 , b ) 13 , c ) 14 , d ) 15 , e ) 16", + "correct": "c", + "annotated_formula": "sqrt(divide(588, const_3))", + "linear_formula": "divide(n0,const_3)|sqrt(#0)|", + "category": "geometry", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = tf_inputs.process_single_mathqa_example(example) + self.assertEqual(answer_num, 14) # we know it, because correct answer is c) + self.assertAllClose(python_result, [196, 14]) + self.assertEqual( + python_program, ["t0 = n0 / 3.0", "t1 = math.sqrt(max(0, t0))"] + ) + self.assertEqual(list_op, ["divide(n0,const_3)", "sqrt(#0)"]) + self.assertEqual(list_num, [588]) + + # Below we execute twice the Python program and once the DSL program. + target_values = "import math\n" + problem = example["Problem"] + for i in range(len(list_num)): + target_values += "n{} = {}\n".format(i, list_num[i]) + problem += " n{} = {}".format(i, list_num[i]) + target_values += "\n".join(python_program[:-1]) + final_line = python_program[-1].split("=")[1] + target_values += "\nanswer ={}".format(final_line) + var_dict = {} + exec(target_values, globals(), var_dict) # pylint: disable=exec-used + self.assertAllClose(var_dict["answer"], 14) + self.assertAllClose( + tf_inputs.execute_mathqa_program(problem, target_values.split("\n")), 14 + ) + self.assertAllClose( + tf_inputs.execute_mathqa_dsl_program(problem, [example["linear_formula"]]), + 14, + ) + + def test_sentencepiece_tokenize(self): + def dataset(): + yield "I have a cat." 
+ + examples = [] + for example in tf_inputs.sentencepiece_tokenize(dataset(), _spm_path()): + examples.append(example) + toks = list(examples[0]) + self.assertSequenceEqual([27, 43, 3, 9, 1712, 5], toks) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/tokenizer_test.py b/tests/data/tokenizer_test.py new file mode 100644 index 000000000..5684324e2 --- /dev/null +++ b/tests/data/tokenizer_test.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.data..tokenizer.""" +import os +import random + +import six +from six.moves import range # pylint: disable=redefined-builtin +import tensorflow.compat.v1 as tf +from trax.data import tokenizer + + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../resources/data/testdata")) + + +class TokenizerTest(tf.test.TestCase): + def test_encode(self): + self.assertListEqual( + ["Dude", " - ", "that", "'", "s", "so", "cool", "."], + tokenizer.encode("Dude - that's so cool."), + ) + self.assertListEqual( + ["Łukasz", "est", "nΓ©", "en", "1981", "."], + tokenizer.encode("Łukasz est nΓ© en 1981."), + ) + self.assertListEqual( + [" ", "Spaces", "at", "the", "ends", " "], + tokenizer.encode(" Spaces at the ends "), + ) + self.assertListEqual(["802", ".", "11b"], tokenizer.encode("802.11b")) + self.assertListEqual(["two", ". \n", "lines"], tokenizer.encode("two. 
\nlines")) + + def test_decode(self): + self.assertEqual( + "Dude - that's so cool.", + tokenizer.decode(["Dude", " - ", "that", "'", "s", "so", "cool", "."]), + ) + + def test_invertibility_on_random_strings(self): + for _ in range(1000): + s = "".join(six.unichr(random.randint(0, 65535)) for _ in range(10)) + self.assertEqual(s, tokenizer.decode(tokenizer.encode(s))) + + +class TestTokenCounts(tf.test.TestCase): + def setUp(self): + super(TestTokenCounts, self).setUp() + self.corpus_path = os.path.join(_TESTDATA, "corpus-*.txt") + self.vocab_path = os.path.join(_TESTDATA, "vocab-*.txt") + + def test_corpus_token_counts_split_on_newlines(self): + token_counts = tokenizer.corpus_token_counts( + self.corpus_path, corpus_max_lines=0, split_on_newlines=True + ) + + expected = { + "'": 2, + ".": 2, + ". ": 1, + "... ": 1, + "Groucho": 1, + "Marx": 1, + "Mitch": 1, + "Hedberg": 1, + "I": 3, + "in": 2, + "my": 2, + "pajamas": 2, + } + self.assertDictContainsSubset(expected, token_counts) + self.assertNotIn(".\n\n", token_counts) + self.assertNotIn("\n", token_counts) + + def test_corpus_token_counts_no_split_on_newlines(self): + token_counts = tokenizer.corpus_token_counts( + self.corpus_path, corpus_max_lines=0, split_on_newlines=False + ) + + if ".\r\n\r\n" in token_counts.keys(): + token_counts.update({"\n\n": token_counts.pop(".\r\n\r\n")}) + + if "\r\n" in token_counts.keys(): + token_counts.update({"\n": token_counts.pop("\r\n")}) + + if ".\n\n" in token_counts.keys(): + token_counts.update({"\n\n": token_counts.pop(".\n\n")}) + + self.assertDictContainsSubset({"\n\n": 2, "\n": 3}, token_counts) + + def test_corpus_token_counts_split_with_max_lines(self): + token_counts = tokenizer.corpus_token_counts( + self.corpus_path, corpus_max_lines=5, split_on_newlines=True + ) + + self.assertIn("slept", token_counts) + self.assertNotIn("Mitch", token_counts) + + def test_corpus_token_counts_no_split_with_max_lines(self): + token_counts = tokenizer.corpus_token_counts( + 
self.corpus_path, corpus_max_lines=5, split_on_newlines=False + ) + + self.assertIn("slept", token_counts) + self.assertNotIn("Mitch", token_counts) + self.assertDictContainsSubset({".\n\n": 1, "\n": 2, ".\n": 1}, token_counts) + + def test_vocab_token_counts(self): + token_counts = tokenizer.vocab_token_counts(self.vocab_path, 0) + + expected = { + "lollipop": 8, + "reverberated": 12, + "kattywampus": 11, + "balderdash": 10, + "jiggery-pokery": 14, + } + self.assertDictEqual(expected, token_counts) + + def test_vocab_token_counts_with_max_lines(self): + # vocab-1 has 2 lines, vocab-2 has 3 + token_counts = tokenizer.vocab_token_counts(self.vocab_path, 5) + + expected = { + "lollipop": 8, + "reverberated": 12, + "kattywampus": 11, + "balderdash": 10, + } + self.assertDictEqual(expected, token_counts) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/fastmath/ops_test.py b/tests/fastmath/ops_test.py new file mode 100644 index 000000000..3f889028d --- /dev/null +++ b/tests/fastmath/ops_test.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.fastmath.ops.""" + +import collections +from absl.testing import parameterized + +import gin +import jax.numpy as jnp +import numpy as onp +from tensorflow import test +from trax import fastmath + + +_TestNamedtuple = collections.namedtuple("_TestNamedtuple", ["x"]) + + +class BackendTest(test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def override_gin(self, bindings): + gin.parse_config_files_and_bindings(None, bindings) + + def test_backend_imports_correctly(self): + backend = fastmath.backend() + self.assertEqual(jnp, backend["np"]) + self.assertNotEqual(onp, backend["np"]) + + self.override_gin("backend.name = 'numpy'") + + backend = fastmath.backend() + self.assertNotEqual(jnp, backend["np"]) + self.assertEqual(onp, backend["np"]) + + def test_backend_can_be_set(self): + self.assertEqual(fastmath.backend_name(), "jax") + fastmath.set_backend("tensorflow-numpy") + self.assertEqual(fastmath.backend_name(), "tensorflow-numpy") + fastmath.set_backend(None) + self.assertEqual(fastmath.backend_name(), "jax") + + def test_numpy_backend_delegation(self): + # Assert that we are getting JAX's numpy backend. + backend = fastmath.backend() + numpy = fastmath.numpy + self.assertEqual(jnp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. + self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(jnp.isinf, numpy.isinf) + self.assertEqual(jnp.inf, numpy.inf) + + # Assert that we will now get the pure numpy backend. + + self.override_gin("backend.name = 'numpy'") + + backend = fastmath.backend() + numpy = fastmath.numpy + self.assertEqual(onp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. 
+ self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(onp.isinf, numpy.isinf) + self.assertEqual(onp.inf, numpy.inf) + + @parameterized.named_parameters( + ("_" + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP) + ) + def test_fori_loop(self, backend): + with fastmath.use_backend(backend): + res = fastmath.fori_loop(2, 5, lambda i, x: x + i, 1) + self.assertEqual(res, 1 + 2 + 3 + 4) + + def test_nested_map(self): + inp = {"a": ([0, 1], 2), "b": _TestNamedtuple(3)} + out = {"a": ([1, 2], 3), "b": _TestNamedtuple(4)} + self.assertEqual(fastmath.nested_map(lambda x: x + 1, inp), out) + + def test_nested_stack(self): + inp = [ + {"a": ([0, 1], 2), "b": _TestNamedtuple(3)}, + {"a": ([1, 2], 3), "b": _TestNamedtuple(4)}, + ] + out = {"a": ([[0, 1], [1, 2]], [2, 3]), "b": _TestNamedtuple([3, 4])} + onp.testing.assert_equal(fastmath.nested_stack(inp), out) + + def test_names_match(self): + # Names match up. + for backend_enum, backend_obj in fastmath.ops._backend_dict.items(): + self.assertEqual(backend_enum.value, backend_obj["name"]) + + # Every backend appears in the dictionary. 
+ for backend_enum in fastmath.ops.Backend: + self.assertIn(backend_enum, fastmath.ops._backend_dict) + + def test_use_backend_str(self): + with fastmath.use_backend("tensorflow-numpy"): + self.assertEqual(fastmath.backend_name(), "tensorflow-numpy") + + def test_use_backend_enum(self): + with fastmath.use_backend(fastmath.Backend.NUMPY): + self.assertEqual(fastmath.backend_name(), "numpy") + + +if __name__ == "__main__": + test.main() diff --git a/trax/import_test.py b/tests/import_test.py similarity index 69% rename from trax/import_test.py rename to tests/import_test.py index 00051b8fb..07920e91d 100644 --- a/trax/import_test.py +++ b/tests/import_test.py @@ -19,18 +19,15 @@ class ImportTest(absltest.TestCase): + def test_import_trax(self): + try: + # Import trax + import trax # pylint: disable=g-import-not-at-top - def test_import_trax(self): - try: - # Import trax - import trax # pylint: disable=g-import-not-at-top - # Access a few symbols. - dir(trax.fastmath) - dir(trax.layers) - dir(trax.models) - except ImportError as e: - raise e + # Access a few symbols. + except ImportError as e: + raise e -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/acceleration_test.py b/tests/layers/acceleration_test.py new file mode 100644 index 000000000..6897937b0 --- /dev/null +++ b/tests/layers/acceleration_test.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acceleration.""" + +from absl.testing import absltest + +from jax.config import config +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes + + +class AccelerationTest(absltest.TestCase): + def test_accelerated_same_result(self): + layer = tl.Dense(2) + x = np.random.uniform(size=(8, 7)) + layer.init(shapes.signature(x)) + y = layer(x) + z = tl.Accelerate(layer)(x) + for i in range(8): + self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) + self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) + + def test_accelerated_pad(self): + layer = tl.Dense(2) + x = np.random.uniform(size=(3, 7)) + layer.init(shapes.signature(x)) + y = layer(x) + z = tl.Accelerate(layer)(x) + self.assertEqual(z.shape, y.shape) + for i in range(3): + self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) + self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) + + def test_accelerated_weighted_category_accuracy(self): + """Test multi-device aggregation of weights.""" + layer = tl.Accelerate(tl.WeightedCategoryAccuracy()) + weights = np.array([1.0, 1.0, 1.0, 0.0]) + targets = np.array([0, 1, 2, 3]) + + model_outputs = np.array( + [ + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(np.mean(accuracy), 1 / 3) + + def test_chunk_memory(self): + """Test chunking here to exercise accelerator memory usage.""" + layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128)) + chunked = tl.Chunk(layer, 256) + x = np.random.uniform(size=(16 * 1024, 16)) + chunked.init(shapes.signature(x)) + y = chunked(x) + z = tl.Accelerate(chunked)(x) + self.assertEqual(y.shape, (16 * 1024, 128)) + self.assertEqual(z.shape, (16 * 1024, 128)) + + def test_chunk_grad_memory(self): + 
"""Test chunking gradient here to exercise accelerator memory usage.""" + layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(24)) + chunked = tl.Chunk(layer, 256) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits, new_state = chunked.pure_fn(x, weights, state, rng) + loss = fastmath.numpy.mean(logits) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + x = np.random.uniform(size=(32 * 1024, 16)) + chunked.init(shapes.signature(x)) + weights, _, logits = mock_training_step( + x, chunked.weights, chunked.state, fastmath.random.get_prng(0) + ) + self.assertEqual(logits.shape, (32 * 1024, 24)) + self.assertEqual(weights[1][0][0][0].shape, (16, 1024 * 1024)) + + +if __name__ == "__main__": + config.config_with_absl() + absltest.main() diff --git a/tests/layers/activation_fns_test.py b/tests/layers/activation_fns_test.py new file mode 100644 index 000000000..1afd16818 --- /dev/null +++ b/tests/layers/activation_fns_test.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for activation function layers.""" + +from absl.testing import absltest +import numpy as np + +import trax.layers as tl + + +class ActivationFnsTest(absltest.TestCase): + def test_relu(self): + layer = tl.Relu() + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]) + + def test_parametric_relu(self): + layer = tl.ParametricRelu(a=0.25) + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 0.5, 0.75, 1.25]) + + def test_leaky_relu(self): + layer = tl.LeakyRelu(a=0.125) + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [-0.25, -0.125, 0.0, 2.0, 3.0, 5.0]) + + def test_hard_sigmoid(self): + layer = tl.HardSigmoid() + x = np.array([-1.5, -0.5, -0.25, 0.0, 0.25, 0.5, 1.5]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]) + + def test_hard_tanh(self): + layer = tl.HardTanh() + x = np.array([-1.5, -0.5, -0.25, 0.0, 0.25, 0.5, 1.5]) + y = layer(x) + self.assertEqual(tl.to_list(y), [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/assert_shape_test.py b/tests/layers/assert_shape_test.py new file mode 100644 index 000000000..7e8c53436 --- /dev/null +++ b/tests/layers/assert_shape_test.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for assert shape layers.""" + +from absl.testing import absltest +import numpy as np + +import trax.layers as tl + + +class AssertFunctionTest(absltest.TestCase): + """Test AssertFunction layer.""" + + def test_simple_pass(self): + layer = tl.AssertFunction("abc->abc", tl.Dropout(rate=0.1)) + x = np.ones((2, 5, 20)) + layer(x) + + def test_simple_fail(self): + layer = tl.AssertFunction("abc->cba", tl.Dropout(rate=0.1)) + x = np.ones((2, 5, 20)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_reduce_rank_ellipsis_pass(self): + layer = tl.AssertFunction("...ab->...c", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_explicit_pass(self): + layer = tl.AssertFunction("xyzab->xyzc", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_to_one_pass(self): + layer = tl.AssertFunction("abcde->x", tl.Flatten(n_axes_to_keep=0)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_explicit_fail1(self): + layer = tl.AssertFunction("abcde->abcde", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_reduce_rank_explicit_fail2(self): + layer = tl.AssertFunction("abcde->abcd", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_two_outputs_pass(self): + layer = tl.AssertFunction( + "...cd->...x,...cd", + tl.Branch( + tl.Flatten(n_axes_to_keep=2), + tl.Dropout(rate=0.1), + ), + ) + x = np.ones((1, 2, 3, 4)) + layer(x) + + def test_numeric_dimensions_pass(self): + layer = tl.AssertFunction( + "...34->1234,...34", + tl.Branch( + tl.Dropout(rate=0.1), + tl.Select([0]), + ), + ) + x = np.ones((1, 2, 3, 4)) + layer(x) + + def test_too_many_outputs_fail(self): + layer = tl.AssertFunction( + 
"...cd->...x,...cd,...cd,...cd", + tl.Branch( + tl.Flatten(n_axes_to_keep=2), + tl.Dropout(rate=0.1), + tl.Serial(), + ), + ) + x = np.ones((1, 2, 3, 4)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_multi_output_rank_fail(self): + layer = tl.AssertFunction( + "...34->...x,...y", + tl.Branch( + tl.Flatten(n_axes_to_keep=3), + tl.Serial(), + ), + ) + x = np.ones((1, 2, 3, 4)) + with self.assertRaises(tl.LayerError): + layer(x) + + +class AssertShapeTest(absltest.TestCase): + """Test AssertShape layer.""" + + def test_simple_pass(self): + layer = tl.AssertShape("aba,ba") + x = [np.ones((10, 5, 10)), np.zeros((5, 10))] + y = layer(x) + self.assertEqual(y, x) + + def test_same_shapes_pass(self): + layer = tl.AssertShape("aba,ba") + x = [np.ones((5, 5, 5)), np.zeros((5, 5))] + y = layer(x) + self.assertEqual(y, x) + + def test_single_arg_pass(self): + layer = tl.AssertShape("a") + x = np.ones((5,)) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_scalar_pass(self): + layer = tl.AssertShape("") + x = np.ones(()) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_square_matrix_pass(self): + layer = tl.AssertShape("aa") + x = np.ones((3, 3)) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_vector_scalar_pass(self): + layer = tl.AssertShape("a,") + x = [np.ones((5,)), np.zeros(())] + y = layer(x) + self.assertEqual(y, x) + + def test_three_args_pass(self): + layer = tl.AssertShape("a,b,a") + x = [np.ones((5,)), np.zeros((2)), np.zeros((5))] + y = layer(x) + self.assertEqual(y, x) + + def test_multiple_matching_dims_pass(self): + layer = tl.AssertShape("a,b,a,ab") + x = [np.ones((5,)), np.zeros((2)), np.zeros((5)), np.zeros((5, 2))] + y = layer(x) + self.assertEqual(y, x) + + def test_numeric_dims_pass(self): + layer = tl.AssertShape("23,1,93") + x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_numeric_dims_fail(self): + layer = 
tl.AssertShape("24,1,93") + x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_ellipsis_middle_pass(self): + layer = tl.AssertShape("a...bc,abc") + x = [np.ones((1, 5, 5, 2, 3)), np.zeros((1, 2, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_ellipsis_prefix_pass(self): + layer = tl.AssertShape("...bc,abc") + x = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_ellipsis_matching_zero_dims_pass(self): + layer = tl.AssertShape("...bc,abc") + x = [np.ones((2, 3)), np.zeros((1, 2, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_ellipsis_matching_ellipsis_pass(self): + layer = tl.AssertShape("...bc,...bc") + x = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_prefix_ellipsis_matching_sufix_ellipsis_pass(self): + layer = tl.AssertShape("bb...,...bb") + x = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))] + y = layer(x) + self.assertEqual(y, x) + + def test_middle_ellipsis_fail(self): + layer = tl.AssertShape("ab...cde,2") + x = [np.ones((2, 3, 4, 5)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_short_middle_ellipsis_fail(self): + layer = tl.AssertShape("b...c,2") + x = [np.ones((2)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_double_ellipsis_fail(self): + layer = tl.AssertShape("b......c,2") + x = [np.ones((2, 3, 4, 5)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_typo_ellipsis_fail(self): + layer = tl.AssertShape("b..c,2") + x = [np.ones((2, 3, 4, 5)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_ellipsis_matching_ellipsis_fail(self): + layer = tl.AssertShape("...a,...b") + x = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_ellipsis_numeric_pass(self): + layer = 
tl.AssertShape("...22,...3") + x = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))] + y = layer(x) + self.assertEqual(y, x) + + def test_prefix_and_sufix_ellipsis_fail(self): + layer = tl.AssertShape("...c...,2") + x = [np.ones((2, 3, 4, 5)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_ellipsis_too_few_dims_fail(self): + layer = tl.AssertShape("...abc,2") + x = [np.ones((4, 5)), np.zeros((2))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_ellipses_matching_dims_fail(self): + layer = tl.AssertShape("...2,...8") + x = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_dims_matching_fail(self): + layer = tl.AssertShape("aba,ab") + x = [np.ones((10, 5, 10)), np.ones((5, 8))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_rank_fail(self): + layer = tl.AssertShape("aba,ab") + x = [np.ones((10, 5, 10)), np.ones((5, 10, 4))] + with self.assertRaises(tl.LayerError): + layer(x) + + def test_square_matrix_fail(self): + layer = tl.AssertShape("aa") + x = np.ones((10, 5)) + with self.assertRaises(tl.LayerError): + layer(x) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/attention_test.py b/tests/layers/attention_test.py new file mode 100644 index 000000000..09ffaa713 --- /dev/null +++ b/tests/layers/attention_test.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.attention.""" + +import functools +from absl.testing import absltest +import numpy as np + +from trax import shapes +import trax.layers as tl +from tests.layers import test_utils + + +class AttentionTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.CausalAttention(d_feature=4, n_heads=2) + x = [ + np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ), + np.array([[[[1, 0, 1]]]]), + ] + _, _ = layer.init(shapes.signature(x)) + + y, mask = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + self.assertEqual(mask.shape, (1, 1, 1, 3)) + + def test_shift_right(self): + # Test shifts right on axis=1 + layer = tl.ShiftRight() + x = np.array( + [ + [[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]], + [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]], + ] + ) + y = layer(x) + self.assertEqual(x.shape, y.shape) + self.assertEqual( + tl.to_list(y), + [ + [[0, 0, 0], [9, 9, 9], [8, 8, 8], [7, 7, 7]], + [[0, 0, 0], [99, 98, 97], [96, 95, 94], [93, 92, 91]], + ], + ) + + def test_shift_right_float(self): + layer = tl.ShiftRight() + x = np.array( + [ + [[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]], + [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]], + ] + ).astype(np.float32) + x /= 2.0 + self.assertEqual(x.dtype, np.float32) + + y = layer(x) + self.assertEqual(y.dtype, np.float32) + self.assertEqual( + tl.to_list(y), + [ + [[0.0, 0.0, 0.0], [4.5, 4.5, 4.5], [4.0, 4.0, 4.0], [3.5, 3.5, 3.5]], + [ + [0.0, 0.0, 0.0], + [49.5, 49.0, 48.5], + [48.0, 47.5, 47.0], + [46.5, 46.0, 45.5], + ], + ], + ) + + def test_padding_mask(self): + layer = tl.PaddingMask() + x = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 0.0], + [1.0, 2.0, 3.0, 0.0, 0.0], + [1.0, 2.0, 0.0, 0.0, 0.0], + ] + ) + y = layer(x) + self.assertEqual(x.shape, (3, 5)) + self.assertEqual(y.shape, (3, 1, 1, 5)) + np.testing.assert_equal( + y, + [ + 
[[[True, True, True, True, False]]], + [[[True, True, True, False, False]]], + [[[True, True, False, False, False]]], + ], + ) + + +class CausalAttentionTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.CausalAttention(d_feature=4, n_heads=2) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + def test_deterministic_eval(self): + d_model = 32 + seq_len = 3 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn = functools.partial( + tl.CausalAttention, + d_feature=d_model, + n_heads=4, + ) + + test_utils.test_eval_is_deterministic(inp, model_fn) + + def test_predict_equals_eval(self): + d_model = 32 + seq_len = 10 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn = functools.partial( + tl.CausalAttention, + d_feature=d_model, + n_heads=4, + ) + + test_utils.test_eval_equals_predict(inp, model_fn) + + +class PositionalEncodingTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.PositionalEncoding(max_len=8) + x = np.array([[[2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0]]]) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 4)) + + def test_predict(self): + layer = tl.PositionalEncoding(max_len=8) + x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) + self.assertEqual(x.shape, (1, 4, 2)) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (1, 4, 2)) + layer = tl.PositionalEncoding(max_len=8, mode="predict") + layer.init(shapes.signature(x[:, :1, :])) + y0 = layer(x[:, :1, :]) # just the first token + self.assertEqual(y0.shape, (1, 1, 2)) + self.assertTrue(np.array_equal(y0, y[:, :1, :])) + y1 = layer(x[:, 1:3, :]) # now the next 2 tokens + self.assertEqual(y1.shape, (1, 2, 2)) + self.assertTrue(np.array_equal(y1, y[:, 1:3, :])) + y2 = layer(x[:, 3:4, :]) # final 
one token + self.assertEqual(y2.shape, (1, 1, 2)) + self.assertTrue(np.array_equal(y2, y[:, 3:4, :])) + + def test_predict_equals_eval(self): + x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) + self.assertEqual(x.shape, (1, 4, 2)) + + layer_eval = tl.PositionalEncoding(max_len=8, d_feature=4, mode="eval") + layer_eval.init(shapes.signature(x)) + + output_eval = layer_eval(x) + + layer_predict = tl.PositionalEncoding(max_len=8, d_feature=4, mode="predict") + layer_predict.init(shapes.signature(x)) + layer_predict.weights = layer_eval.weights + + output_predict = layer_predict(x) + self.assertTrue(np.array_equal(output_eval, output_predict)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/base_test.py b/tests/layers/base_test.py new file mode 100644 index 000000000..a102709b0 --- /dev/null +++ b/tests/layers/base_test.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Trax base layer classes and generic layer-creating functions.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +from trax import fastmath +from trax import shapes +from trax.fastmath import numpy as jnp +import trax.layers as tl + +BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] +CUSTOM_GRAD_BACKENDS = [fastmath.Backend.JAX] # TODO(afrozm): del after TF 2.3 + + +class BaseLayerTest(parameterized.TestCase): + def test_call_raises_error(self): + layer = tl.Layer() + x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) + with self.assertRaisesRegex(tl.LayerError, "NotImplementedError"): + _ = layer(x) + + def test_set_weighs_raises_error(self): + layer = tl.Layer() + layer.weights = 1.0 # can assign weights + with self.assertRaisesRegex(ValueError, "weighs"): + layer.weighs = 1.0 # cannot assign weighs + + def test_forward_raises_error(self): + layer = tl.Layer() + x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) + with self.assertRaises(NotImplementedError): + _ = layer.forward(x) + + def test_init_returns_empty_weights_and_state(self): + layer = tl.Layer() + input_signature = shapes.ShapeDtype((2, 5)) + weights, state = layer.init(input_signature) + self.assertEmpty(weights) + self.assertEmpty(state) + + def test_output_signature_no_weights(self): + shape_2_3_5 = shapes.ShapeDtype((2, 3, 5)) + input_signature = (shape_2_3_5, shape_2_3_5) + layer = tl.Fn("2in1out", lambda x, y: x + y) + output_signature = layer.output_signature(input_signature) + self.assertEqual(output_signature, shape_2_3_5) + + shape_5_7 = shapes.ShapeDtype((5, 7)) + input_signature = shape_5_7 + layer = tl.Fn("1in3out", lambda x: (x, 2 * x, 3 * x), n_out=3) + output_signature = layer.output_signature(input_signature) + self.assertEqual(output_signature, (shape_5_7, shape_5_7, shape_5_7)) + + # TODO(jonni): Define/test behavior of output signature for layers w/weights. 
+ + @parameterized.named_parameters([("_" + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) + def test_custom_zero_grad(self, backend): + class IdWithZeroGrad(tl.Layer): + def forward(self, x): + return x + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + return (jnp.zeros_like(grad), ()) + + with fastmath.use_backend(backend): + layer = IdWithZeroGrad() + rng = fastmath.random.get_prng(0) + input_signature = shapes.ShapeDtype((9, 17)) + random_input = fastmath.random.uniform( + rng, input_signature.shape, minval=-1.0, maxval=1.0 + ) + layer.init(input_signature) + f = lambda x: jnp.mean(layer(x)) + grad = fastmath.grad(f)(random_input) + self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. + self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0. + + @parameterized.named_parameters([("_" + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) + def test_custom_id_grad(self, backend): + class IdWithIdGrad(tl.Layer): + def forward(self, x): + return x + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + return (inputs, ()) + + with fastmath.use_backend(backend): + layer = IdWithIdGrad() + rng = fastmath.random.get_prng(0) + input_signature = shapes.ShapeDtype((9, 17)) + random_input = fastmath.random.uniform( + rng, input_signature.shape, minval=-1.0, maxval=1.0 + ) + layer.init(input_signature) + f = lambda x: jnp.mean(layer(x)) + grad = fastmath.grad(f)(random_input) + self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. + self.assertEqual(sum(sum(grad)), sum(sum(random_input))) # Same as input. 
+ + def test_weights_and_state_signature(self): + class MyLayer(tl.Layer): + def init_weights_and_state(self, input_signature): + self.weights = jnp.zeros((2, 3)) + self.state = jnp.ones(input_signature.shape) + + def forward(self, inputs): + return self.weights + self.state + + layer = MyLayer() + w, s = layer.weights_and_state_signature(jnp.zeros((3, 4))) + self.assertEqual(w.shape, (2, 3)) + self.assertEqual(s.shape, (3, 4)) + + def test_custom_name(self): + layer = tl.Layer() + self.assertIn("Layer", str(layer)) + self.assertNotIn("CustomLayer", str(layer)) + + layer = tl.Layer(name="CustomLayer") + self.assertIn("CustomLayer", str(layer)) + + +class PureLayerTest(absltest.TestCase): + def test_forward(self): + layer = tl.PureLayer( + lambda x: 2 * x[0] + ) # Pure layer cast input to tuple (input,) so x is a tuple + + # Use Layer.__call__. + in_0 = np.array([1, 2]) + out_0 = layer(in_0, weights=jnp.zeros((2, 3))) + self.assertEqual(out_0.tolist(), [2, 4]) + self.assertEmpty(layer.weights) + + # Use PureLayer.forward. 
+ in_1 = np.array([3, 4]) + out_1 = layer.forward(in_1) + self.assertEqual(out_1.tolist(), [6, 8]) + + # Use Layer.pure_fn + in_2 = np.array([5, 6]) + out_2, _ = layer.pure_fn(in_2, tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None) + self.assertEqual(out_2.tolist(), [10, 12]) + + +class FnTest(absltest.TestCase): + def test_bad_f_has_default_arg(self): + with self.assertRaisesRegex(ValueError, "default arg"): + _ = tl.Fn("", lambda x, sth=None: x) + + def test_bad_f_has_keyword_arg(self): + with self.assertRaisesRegex(ValueError, "keyword arg"): + _ = tl.Fn("", lambda x, **kwargs: x) + + def test_bad_f_has_variable_arg(self): + with self.assertRaisesRegex(ValueError, "variable arg"): + _ = tl.Fn("", lambda *args: args[0]) + + def test_forward(self): + layer = tl.Fn( + "SumAndMax", lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2 + ) + + x0 = np.array([1, 2, 3, 4, 5]) + x1 = np.array([10, 20, 30, 40, 50]) + + y0, y1 = layer((x0, x1)) + self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55]) + self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50]) + + y2, y3 = layer.forward((x0, x1)) + self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55]) + self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50]) + + (y4, y5), state = layer.pure_fn( + (x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE, None + ) + self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55]) + self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50]) + self.assertEqual(state, tl.EMPTY_STATE) + + def test_weights_state(self): + layer = tl.Fn( + "2in2out", lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), n_out=2 + ) + layer.init_weights_and_state(None) + self.assertEmpty(layer.weights) + self.assertEmpty(layer.state) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/combinators_test.py b/tests/layers/combinators_test.py new file mode 100644 index 000000000..2fce7ec27 --- /dev/null +++ b/tests/layers/combinators_test.py @@ -0,0 +1,748 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for combinator layers.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +from trax import fastmath +from trax import shapes +import trax.layers as tl + + +def DivideBy(val): # pylint: disable=invalid-name + """Returns a simple division layer with n_in == 1 and n_out == 1.""" + return tl.Fn("DivideBy", lambda x: x / val) + + +def ReturnConst(val): # pylint: disable=invalid-name + """Returns a simple const layer with n_in == 0 and n_out == 1.""" + return tl.Fn("ReturnConst", lambda: val) + + +def SmallerThan(val): # pylint: disable=invalid-name + """Checks if the input is smaller than certain value.""" + return tl.Fn("SmallerThan", lambda x: x < val) + + +# TODO(jonni): Consider a more generic home for this utiliity function. +def as_list(outputs): + """Converts layer outputs to a nested list, for easier equality testing. + + Args: + outputs: A tensor or tuple/list of tensors coming from the forward + application of a layer. Each tensor is NumPy ndarray-like, which + complicates simple equality testing (e.g., via `assertEquals`): + such tensors require equality testing to use either `all` (all + elements match) or `any` (at least one element matches), which is not + directly supported in absltest. + + Returns: + A nested list structure containing all the output values, but now directly + testable using `assertEquals`. 
+ """ + if isinstance(outputs, (list, tuple)): + return [as_list(y) for y in outputs] + else: + return outputs.tolist() + + +class SerialTest(absltest.TestCase): + def test_none_is_no_op(self): + layer = tl.Serial(None) + xs = [np.array([1, 2, 3, 4]), np.array([10, 20, 30])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3, 4], [10, 20, 30]]) + + def test_empty_list_is_no_op(self): + layer = tl.Serial([]) + xs = [np.array([1, 2, 3, 4]), np.array([10, 20, 30])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3, 4], [10, 20, 30]]) + + def test_one_in_one_out(self): + layer = tl.Serial(DivideBy(3)) + x = np.array([3, 6, 9, 12]) + y = layer(x) + self.assertEqual(as_list(y), [1, 2, 3, 4]) + + def test_zero_in_one_out(self): + layer = tl.Serial(ReturnConst(np.array([3, 4, 5, 6]))) + y = layer(()) + self.assertEqual(as_list(y), [3, 4, 5, 6]) + + def test_one_in_two_out(self): + layer = tl.Serial(DivideBy(3), ReturnConst(np.array([3, 4, 5, 6]))) + x = np.array([3, 6, 9, 12]) + y = layer(x) + self.assertEqual(as_list(y), [[3, 4, 5, 6], [1, 2, 3, 4]]) + + def test_const_div(self): + layer = tl.Serial(ReturnConst(np.array([3, 6, 9, 12])), DivideBy(3)) + y = layer(()) + self.assertEqual(as_list(y), [1, 2, 3, 4]) + + def test_div_div(self): + layer = tl.Serial(DivideBy(2.0), DivideBy(5.0)) + x = np.array([10, 20, 30]) + y = layer(x) + self.assertEqual(as_list(y), [1, 2, 3]) + + def test_dup_dup(self): + layer = tl.Serial(tl.Dup(), tl.Dup()) + x = np.array([1, 2, 3]) + ys = layer(x) + self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + def test_default_name(self): + layer = tl.Serial(tl.Dup(), tl.Dup()) + self.assertIn("Serial", str(layer)) + + def test_custom_name(self): + layer = tl.Serial(tl.Dup(), tl.Dup(), name="Branch") + self.assertIn("Branch", str(layer)) + + def test_weights(self): + model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) + self.assertIsInstance(model.weights, tuple) + self.assertLen(model.weights, 3) + + def 
test_flat_weights_and_state(self): + model = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup())) + sample_input_signature = shapes.signature(np.zeros((2, 3))) + model.init(sample_input_signature) + flat_weights, flat_state = tl.flatten_weights_and_state( + model.weights, model.state + ) + # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. + # So after making them flat, there are 4 trainable weights. + self.assertLen(flat_weights, 4) + self.assertEmpty(flat_state) + model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7)) + sig = model2.weights_and_state_signature(sample_input_signature) + weights2, state2 = tl.unflatten_weights_and_state(flat_weights, flat_state, sig) + model2.weights = weights2 + model2.state = state2 + self.assertLen(model2.weights, 3) + self.assertEqual(model.weights[1], model2.weights[0]) + self.assertEqual(model.weights[2][0], model2.weights[2]) + + def test_flat_weights_and_state_shared(self): + shared = tl.Dense(5) + model = tl.Serial(tl.Dense(5), shared, tl.Serial(shared, tl.Dup())) + sample_input_signature = shapes.signature(np.zeros((2, 3))) + model.init(sample_input_signature) + flat_weights, flat_state = tl.flatten_weights_and_state( + model.weights, model.state + ) + # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. + # So after making them flat, there are 4 trainable weights. 
+ self.assertLen(flat_weights, 4) + self.assertEmpty(flat_state) + model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(5)) + sig = model2.weights_and_state_signature(sample_input_signature) + weights2, state2 = tl.unflatten_weights_and_state(flat_weights, flat_state, sig) + model2.weights = weights2 + model2.state = state2 + self.assertLen(model2.weights, 3) + self.assertEqual(model.weights[0], model2.weights[0]) + self.assertEqual(model.weights[1], model2.weights[2]) + + def test_assign_sublayer_weights(self): + layer = tl.Dense(5, use_bias=False) + model = tl.Serial(tl.Serial(layer, tl.Dense(6)), tl.Dense(7)) + sample_input = np.array([1, 2, 3, 4, 5]) + weights, _ = model.init(shapes.signature(sample_input)) + new_layer_weights = np.random.uniform(size=weights[0][0].shape) + layer.weights = new_layer_weights + self.assertIs(model.weights[0][0], new_layer_weights) + + def test_shared_weights(self): + layer = tl.Dense(5) + model = tl.Serial(layer, layer) + sample_input = np.array([1, 2, 3, 4, 5]) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_nested(self): + layer = tl.Dense(5) + model = tl.Serial(layer, tl.Serial(layer)) + sample_input = np.array([1, 2, 3, 4, 5]) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_double_nested(self): + layer = tl.Dense(5) + model = tl.Serial(tl.Serial(layer), tl.Serial(layer)) + sample_input = np.array([1, 2, 3, 4, 5]) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_for_shared_serial(self): + layer = tl.Serial(tl.Dense(5), tl.Dense(5)) + model = tl.Serial(layer, layer) + sample_input = np.array([1, 2, 3, 4, 5]) + # Init gives weights reflecting weight sharing.
+ weights, _ = model.init(shapes.signature(sample_input)) + self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + # Forward pass runs successfully. + y = model(sample_input) + self.assertEqual(y.shape, (5,)) + + def test_state(self): + model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) + self.assertIsInstance(model.state, tuple) + self.assertLen(model.state, 3) + + def test_set_rng_recurse_two_levels(self): + dense_00 = tl.Dense(2) + dense_01 = tl.Dense(2) + dense_10 = tl.Dense(2) + dense_11 = tl.Dense(2) + layer = tl.Serial( + tl.Serial(dense_00, dense_01), + tl.Serial(dense_10, dense_11), + ) + input_signature = shapes.ShapeDtype((1, 2)) + + _, _ = layer.init(input_signature) + weights = layer.weights + dense_00_w, dense_00_b = weights[0][0] + dense_01_w, dense_01_b = weights[0][1] + dense_10_w, dense_10_b = weights[1][0] + dense_11_w, dense_11_b = weights[1][1] + + # Setting rng's recursively during init should yield differing weights. 
+ self.assertFalse(np.array_equal(dense_00_w, dense_01_w)) + self.assertFalse(np.array_equal(dense_00_b, dense_01_b)) + self.assertFalse(np.array_equal(dense_10_w, dense_11_w)) + self.assertFalse(np.array_equal(dense_10_b, dense_11_b)) + + +class ParallelTest(absltest.TestCase): + def test_dup_dup(self): + layer = tl.Parallel(tl.Dup(), tl.Dup()) + xs = [np.array([1, 2, 3]), np.array([10, 20])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [10, 20], [10, 20]]) + + def test_div_div(self): + layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0)) + xs = [np.array([1, 2, 3]), np.array([30, 60])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[2, 4, 6], [10, 20]]) + + def test_two_no_ops(self): + layer = tl.Parallel(tl.Select([0]), tl.Select([0])) + xs = (np.array([1, 2, 3]), np.array([10, 20])) + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3], [10, 20]]) + + def test_default_name(self): + layer = tl.Parallel(tl.Dup(), tl.Dup()) + self.assertIn("Parallel", str(layer)) + + def test_custom_name(self): + layer = tl.Parallel(tl.Dup(), tl.Dup(), name="DupDup") + self.assertIn("DupDup", str(layer)) + + def test_weights(self): + model = tl.Parallel(tl.Dense(3), tl.Dense(5)) + self.assertIsInstance(model.weights, tuple) + self.assertLen(model.weights, 2) + + def test_shared_weights(self): + layer = tl.Dense(5) + model = tl.Parallel(layer, layer) + sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_nested(self): + layer = tl.Dense(5) + model = tl.Parallel([layer, tl.Dense(2)], [layer, tl.Dense(2)]) + sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_for_shared_parallel(self): + layer = tl.Parallel(tl.Dense(5), tl.Dense(7)) + 
model = tl.Parallel(layer, layer) + sample_input = [ + np.array([1, 2, 3]), + np.array([10, 20, 30]), + np.array([100, 200, 300]), + np.array([1000, 2000, 3000]), + ] + # Init gives weights reflecting weight sharing. + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + # Forward pass runs successfully. + y0, y1, y2, y3 = model(sample_input) + self.assertEqual(y0.shape, (5,)) + self.assertEqual(y1.shape, (7,)) + self.assertEqual(y2.shape, (5,)) + self.assertEqual(y3.shape, (7,)) + + def test_state(self): + model = tl.Parallel(tl.Dense(3), tl.Dense(5)) + self.assertIsInstance(model.state, tuple) + self.assertLen(model.state, 2) + + +class ConcatenateTest(absltest.TestCase): + def test_n_in_n_out(self): + layer = tl.Concatenate() + self.assertEqual(layer.n_in, 2) + self.assertEqual(layer.n_out, 1) + + def test_with_defaults(self): + layer = tl.Concatenate() # Default n_items=2, axis=-1 + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]) + + def test_axis_0(self): + layer = tl.Concatenate(axis=0) + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + y = layer(xs) + self.assertEqual(as_list(y), [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]]) + + def test_axis_1(self): + layer = tl.Concatenate(axis=1) + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + y = layer(xs) + self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]) + + def test_n_items_is_not_default(self): + layer = tl.Concatenate(n_items=3) + xs = [ + np.array([[1, 2, 3], [4, 5, 6]]), + np.array([[10, 20, 30], [40, 50, 60]]), + np.array([[100, 200, 300], [400, 500, 600]]), + ] + y = layer(xs) + self.assertEqual(y.shape, (2, 9)) + self.assertEqual( + as_list(y), + [ + [1, 2, 3, 10, 
20, 30, 100, 200, 300], + [4, 5, 6, 40, 50, 60, 400, 500, 600], + ], + ) + + def test_repr(self): + layer = tl.Concatenate() + self.assertEqual(repr(layer), "Concatenate_in2") + + layer = tl.Concatenate(axis=0) + self.assertEqual(repr(layer), "Concatenate_axis0_in2") + + layer = tl.Concatenate(axis=1) + self.assertEqual(repr(layer), "Concatenate_axis1_in2") + + layer = tl.Concatenate(n_items=3) + self.assertEqual(repr(layer), "Concatenate_in3") + + +class BranchTest(absltest.TestCase): + def test_noop_dup(self): + layer = tl.Branch(tl.Select([0]), tl.Dup()) + x = np.array([1, 2, 3]) + ys = layer(x) + self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + def test_add_div(self): + layer = tl.Branch(tl.Add(), DivideBy(0.5)) + xs = [np.array([1, 2, 3]), np.array([10, 20, 30])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[11, 22, 33], [2, 4, 6]]) + + def test_one_sublayer(self): + layer = tl.Branch(DivideBy(0.5)) + x = np.array([1, 2, 3]) + ys = layer(x) + self.assertEqual(as_list(ys), [2, 4, 6]) + + def test_default_name(self): + layer = tl.Branch(tl.Add(), DivideBy(0.5)) + self.assertIn("Branch", str(layer)) + + def test_printing_sublayers(self): + layer = tl.Branch(tl.Add(), tl.Add()) + expected_result = "Branch_in2_out2[\n Add_in2\n Add_in2\n]" + self.assertEqual(expected_result, str(layer)) + + +class SelectTest(absltest.TestCase): + def test_computes_n_in(self): + layer = tl.Select([0, 0]) + self.assertEqual(layer.n_in, 1) + + layer = tl.Select([1, 0]) + self.assertEqual(layer.n_in, 2) + + layer = tl.Select([2]) + self.assertEqual(layer.n_in, 3) + + def test_given_n_in(self): + layer = tl.Select([0], n_in=2) + self.assertEqual(layer.n_in, 2) + + layer = tl.Select([0], n_in=3) + self.assertEqual(layer.n_in, 3) + + def test_first_of_3(self): + layer = tl.Select([0], n_in=3) + xs = [np.array([1, 2, 3]), np.array([10, 20]), np.array([100])] + y = layer(xs) + self.assertEqual(as_list(y), [1, 2, 3]) + + def test_second_of_3(self): + layer = 
tl.Select([1], n_in=3) + xs = [np.array([1, 2, 3]), np.array([10, 20]), np.array([100])] + y = layer(xs) + self.assertEqual(as_list(y), [10, 20]) + + +class DropTest(absltest.TestCase): + def test_drop(self): + layer = tl.Drop() + x = np.array([1, 2, 3]) + y = layer(x) + self.assertEqual(as_list(y), []) + + +class SwapTest(absltest.TestCase): + def test_swap(self): + layer = tl.Swap() + xs = [np.array([1, 2, 3]), np.array([10, 20, 30])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[10, 20, 30], [1, 2, 3]]) + + +class ChunkTest(absltest.TestCase): + def test_chunk(self): + layer = tl.Dense(4) + x = np.array([[1, 2, 3], [4, 5, 6]]) + layer.init(x) + y = layer(x) + z = tl.Chunk(layer, 1)(x) + self.assertLess(np.sum((y - z) ** 2), 1e-5) # y == z up to numerics + + def test_chunk_uneven_numbers(self): + layer = tl.Dense(4) + x = np.array([[1, 2, 3], [4, 5, 6]]) + layer.init(x) + y = layer(x) + z = tl.Chunk(layer, 3)(x) # By default it should just pass + self.assertLess(np.sum((y - z) ** 2), 1e-5) # y == z up to numerics + chunk_with_test = tl.Chunk(layer, 3, pass_unchunkable=False) + self.assertRaises(tl.LayerError, lambda: chunk_with_test(x)) + + +class SerialWithSideOutputsTest(absltest.TestCase): + def test_serial_with_side_outputs_div_div(self): + def some_layer(): + return tl.Parallel(DivideBy(2.0), DivideBy(5.0)) + + layer = tl.SerialWithSideOutputs([some_layer(), some_layer()]) + xs = (np.array([1, 2, 3]), np.array([10, 20, 30, 40, 50]), np.array([100, 200])) + ys = layer(xs) + output_shapes = [y.shape for y in ys] + self.assertEqual(output_shapes, [(3,), (5,), (2,)]) + + +BACKENDS = [fastmath.Backend.JAX] + + +@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS) +class ScanTest(parameterized.TestCase): + def _AddWithCarry(self): # pylint: disable=invalid-name + del self + + def f(x, carry): + res = x + carry + return res, res # output and carry are the same + + return tl.Fn("AddWithCarry", f, n_out=2) + + def test_default_axis(self, backend):
+ with fastmath.use_backend(backend): + layer = tl.Scan(self._AddWithCarry()) + xs = [ + np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]), + np.array([9000, 8000, 7000, 6000]), + ] + ys = layer(xs) + self.assertEqual( + as_list(ys), + [ + [ + [9000, 8001, 7002, 6003], + [9000, 8011, 7022, 6033], + [9000, 8111, 7222, 6333], + ], + [9000, 8111, 7222, 6333], + ], + ) + + def test_axis_1(self, backend): + with fastmath.use_backend(backend): + layer = tl.Scan(self._AddWithCarry(), axis=1) + xs = [ + np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]), + np.array([9000, 8000, 7000]), + ] + ys = layer(xs) + self.assertEqual( + as_list(ys), + [ + [ + [9000, 9001, 9003, 9006], + [8000, 8010, 8030, 8060], + [7000, 7100, 7300, 7600], + ], + [9006, 8060, 7600], + ], + ) + + def test_predict(self, backend): + with fastmath.use_backend(backend): + layer = tl.Scan(self._AddWithCarry(), axis=1, mode="predict") + xs = [np.array([[0, 1, 2]]), np.array([90])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[[90, 91, 93]], [93]]) + xs = [np.array([[3, 4]]), np.array([90])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[[96, 100]], [100]]) + + def test_multi_input(self, backend): + def _MultiInputFn(): # pylint: disable=invalid-name + def f(a, b, carry): + return a + b, b, carry + 1 + + return tl.Fn("MultiInputFn", f, n_out=2) + + with fastmath.use_backend(backend): + layer = tl.Scan(_MultiInputFn(), axis=1) + xs = [ + np.array([[0, 1, 2], [0, 10, 20]]), + np.array([[4, 5, 6], [40, 50, 60]]), + np.array([9000, 8000]), + ] + ys = layer(xs) + self.assertEqual( + as_list(ys), + [[[4, 6, 8], [40, 60, 80]], [[4, 5, 6], [40, 50, 60]], [9003, 8003]], + ) + + def test_no_carry(self, backend): + def _AddOne(): # pylint: disable=invalid-name + return tl.Fn("AddOne", lambda x: x + 1) + + with fastmath.use_backend(backend): + layer = tl.Scan(_AddOne(), n_carry=0) + x = np.array([[1, 3, 7], [10, 30, 70]]) + y = layer(x) + self.assertEqual(as_list(y), [[2, 4, 8], [11, 
31, 71]]) + + +class CondTest(absltest.TestCase): + def test_basic_true(self): + cond = ReturnConst(True) + true = ReturnConst([2]) + false = ReturnConst([5]) + layer = tl.Cond(cond, true, false) + layer.init(()) + xs = tuple() + ys = layer(xs) + self.assertEqual(as_list(ys), 2) + + def test_basic_false(self): + cond = ReturnConst(False) + true = ReturnConst([2]) + false = ReturnConst([5]) + layer = tl.Cond(cond, true, false) + layer.init(()) + xs = tuple() + ys = layer(xs) + self.assertEqual(as_list(ys), 5) + + def test_complex_blocks(self): + cond = ReturnConst(True) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = [np.arange(5).astype(np.float32)] + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [0.0, 0.5, 1.0, 1.5, 2.0]) + + def test_condition_func_true(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = (np.array(2.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [2.0, 6.0]) + + def test_condition_func_false(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [1.0, 3.0]) + + def test_condition_func_default_false(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + layer = tl.Cond(cond, true) + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [4.0, 12.0]) + + def test_exception_n_out(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = tl.Dup() + self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) + + def test_exception_n_in(self): + cond = SmallerThan(3.0) + true = ReturnConst(2.0) + false = DivideBy(2.0) + self.assertRaises(ValueError, lambda: 
tl.Cond(cond, true, false)) + + def test_exception_run1(self): + # We expect exactly one input. + cond = SmallerThan(3.0) + true = ReturnConst(2.0) + false = ReturnConst(5.0) + + def init_and_run(layer, xs): + layer.init(shapes.signature(xs)) + layer(xs) + + # It will pass with one input. + xs = np.array(4.0) + layer = tl.Cond(cond, true, false) + init_and_run(layer, xs) + # It will fail with zero or two inputs. + for xs in ((), (np.array(4.0), np.array([4.0, 12.0]))): + layer = tl.Cond(cond, true, false) + # pylint: disable=cell-var-from-loop + self.assertRaises(Exception, lambda: init_and_run(layer, xs)) + + def test_exception_run2(self): + # We expect exactly two inputs. + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(5.0) + + def init_and_run(layer, xs): + layer.init(shapes.signature(xs)) + layer(xs) + + # It will pass with two inputs. + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer = tl.Cond(cond, true, false) + init_and_run(layer, xs) + # It will fail with zero or one input. 
+ for xs in ((), (np.array(4.0))): + # pylint: disable=cell-var-from-loop + self.assertRaises(Exception, lambda: init_and_run(layer, xs)) + + def test_weights_and_state(self): + cond = SmallerThan(3.0) + true = tl.Dense(5) + false = tl.Dense(5) + different = tl.Dense(5) + layer = tl.Cond(cond, true, false) + xs = (np.array(2.0), np.array([0.0, 1.0, 2.0])) + layer.init(shapes.signature(xs)) + + # weights + self.assertEqual( + as_list(layer.weights), as_list((cond.weights, true.weights, false.weights)) + ) + self.assertNotEqual(as_list(true.weights), as_list(false.weights)) + self.assertNotEqual(as_list(true.weights), as_list(different.weights)) + + false.weights = true.weights + self.assertEqual( + as_list(layer.weights), as_list((cond.weights, true.weights, true.weights)) + ) + + layer.weights = (cond.weights, true.weights, different.weights) + self.assertEqual( + as_list(layer.weights), + as_list((cond.weights, true.weights, different.weights)), + ) + # state + self.assertEqual( + as_list(layer.state), as_list((cond.state, true.state, false.state)) + ) + # just check if simple assignments (setter from base.Layer) work correctly + # with Cond.init_weights_and_state ; all states are empty so there is no + # point in checking equality + false.state = true.state + layer.state = (cond.state, true.state, different.state) + + +class BatchLeadingAxesTest(absltest.TestCase): + def _Id3Dim(self): # pylint: disable=invalid-name + del self + + def f(x): + assert len(x.shape) == 3 + return x + + return tl.Fn("Id3Dim", f, n_out=1) + + def test_2axes(self): + layer = tl.BatchLeadingAxes(self._Id3Dim(), n_last_axes_to_keep=2) + ys = layer(np.zeros((3, 4, 5))) + self.assertEqual(ys.shape, (3, 4, 5)) + ys = layer(np.zeros((2, 3, 4, 5))) + self.assertEqual(ys.shape, (2, 3, 4, 5)) + ys = layer(np.zeros((1, 2, 3, 4, 5))) + self.assertEqual(ys.shape, (1, 2, 3, 4, 5)) + + +class BidirectionalTest(absltest.TestCase): + def test_dimensionality(self): + x = np.ones((2, 3, 8)) + layer = 
tl.Bidirectional(tl.GRU(n_units=8)) + input_signature = shapes.signature(x) + _, _ = layer.init(input_signature) + yhat = layer(x) + + self.assertEqual(yhat.shape, (2, 3, 8 + 8)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/convolution_test.py b/tests/layers/convolution_test.py new file mode 100644 index 000000000..42b9dae85 --- /dev/null +++ b/tests/layers/convolution_test.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for convolution layers.""" + +from absl.testing import absltest +import numpy as np + +from trax import shapes +import trax.layers as tl + + +class ConvolutionTest(absltest.TestCase): + def test_call(self): + layer = tl.Conv(30, (3, 3)) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + + def test_use_bias_true(self): + layer = tl.Conv(30, (3, 3), use_bias=True) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + + self.assertIsInstance(layer.weights, tuple) + self.assertLen(layer.weights, 2) + self.assertEqual(layer.weights[0].shape, (3, 3, 20, 30)) + self.assertEqual(layer.weights[1].shape, (30,)) + + def test_use_bias_false(self): + layer = tl.Conv(30, (3, 3), use_bias=False) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + # With use_bias=False, layer.weights is just 'w' and there is no 'b'. + self.assertEqual(layer.weights.shape, (3, 3, 20, 30)) + + def test_call_rebatch(self): + layer = tl.Conv(30, (3, 3)) + x = np.ones((2, 9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (2, 9, 3, 3, 30)) + + +class CausalConvolutionTest(absltest.TestCase): + def test_causal_conv(self): + layer = tl.CausalConv(filters=30, kernel_width=3) + x = np.ones((9, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 5, 30)) + + # TODO(ddohan): How to test for causality? Gradient check between positions? 
+ + def test_causal_conv_use_bias_false(self): + layer = tl.CausalConv(filters=30, kernel_width=3, use_bias=False) + x = np.ones((9, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 5, 30)) + + self.assertEqual(layer.weights.shape, (3, 20, 30)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/core_test.py b/tests/layers/core_test.py new file mode 100644 index 000000000..2e3e0c23a --- /dev/null +++ b/tests/layers/core_test.py @@ -0,0 +1,476 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for core layers.""" + +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.fastmath import numpy as jnp +import trax.layers as tl +import trax.layers.initializers as init + + +class DenseTest(absltest.TestCase): + """Test Dense layer per se and as a key example of trainable layers.""" + + def test_call_before_init_raises_error(self): + layer = tl.Dense(5) + x = np.array([1, 2, 3]) + + # Without init, layer lacks the weights it needs for forward computation. + with self.assertRaises(tl.LayerError): + _ = layer(x) + + def test_call_uses_and_caches_supplied_weights(self): + layer = tl.Dense(4) + x = np.array([2, 3]) + + # Weights from random initialization are cached in the layer. + _, _ = layer.init(shapes.signature(x)) + w_init, b_init = layer.weights + + # Call the layer with externally specified weights. 
+ w = np.array([[10000, 20000, 30000, 40000], [100, 200, 100, 200]]) + b = np.array([9, 8, 7, 6]) + y = layer(x, weights=(w, b)) + + # Using weights keyword arg overrides any previous cached weights ... + self.assertEqual(y.tolist(), [20309, 40608, 60307, 80606]) + self.assertNotEqual(w.tolist(), w_init.tolist()) + self.assertNotEqual(b.tolist(), b_init.tolist()) + + # ... and do not over-write the old weights. + w_cached, b_cached = layer.weights + self.assertNotEqual(w.tolist(), w_cached.tolist()) + self.assertNotEqual(b.tolist(), b_cached.tolist()) + + def test_separate_instances_have_separate_weights(self): + # Two dense layer instances: each will get its own initial weights (w, b). + model = tl.Serial(tl.Dense(5), tl.Dense(5)) + + sample_input = np.array([1, 2, 3, 4, 5]) + _, _ = model.init(shapes.signature(sample_input)) + weights_0 = model.sublayers[0].weights + weights_1 = model.sublayers[1].weights + + w0, b0 = weights_0 + w1, b1 = weights_1 + self.assertNotEqual(w0.tolist(), w1.tolist()) + self.assertNotEqual(b0.tolist(), b1.tolist()) + + def test_shared_instance_means_shared_weights(self): + # Same dense layer instance in two places --> shared weights. 
+ layer = tl.Dense(5) + model = tl.Serial(layer, layer) + sample_input = np.array([1, 2, 3, 4, 5]) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + + def test_call_no_bias(self): + layer = tl.Dense(4, use_bias=False) + x = np.array([2, 5, 3]) + _, _ = layer.init(shapes.signature(x)) + + w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]]) + y = layer(x, weights=w) + self.assertEqual(y.tolist(), [253, 456, 653, 856]) + + def test_new_weights_use_bias(self): + layer = tl.Dense(4) + x = np.array([1, 2]) + _, _ = layer.init(shapes.signature(x)) + self.assertLen(layer.weights, 2) + self.assertEqual(layer.weights[0].shape, (2, 4)) + self.assertEqual(layer.weights[1].shape, (4,)) + + def test_new_weights_no_bias(self): + layer = tl.Dense(4, use_bias=False) + x = np.array([1, 2]) + _, _ = layer.init(shapes.signature(x)) + self.assertEqual(layer.weights.shape, (2, 4)) + + def test_init_twice_weights_same_shape(self): + layer = tl.Dense(4, use_bias=False) + x = np.array([1, 2]) + w1, _ = layer.init(shapes.signature(x)) + w2, _ = layer.init(shapes.signature(x)) + self.assertEqual(w1.shape, (2, 4)) + self.assertEqual(w2.shape, (2, 4)) + + def test_save_to_file_and_init_to_file(self): + layer1 = tl.Dense(4, use_bias=False) + layer2 = tl.Dense(4, use_bias=False) + x = np.array([1, 2]) + w1, _ = layer1.init(shapes.signature(x)) + layer1.save_to_file("/tmp/dense_weights", input_signature=shapes.signature(x)) + w2, _ = layer2.init_from_file("/tmp/dense_weights") + self.assertEqual(w1.shape, (2, 4)) + self.assertEqual(w2.shape, (2, 4)) + self.assertEqual(w1.tolist(), w2.tolist()) + + +class EmbeddingTest(absltest.TestCase): + def test_forward(self): + layer = tl.Embedding(10, 3) # vocab_size=10, d_feature=3 + _, _ = layer.init(None) # Embedding init doesn't use input signature. 
+ x = np.array([2, 3, 5, 3, 2]) + y = layer(x) + self.assertEqual(y.shape, (5, 3)) + + # For distinct in-domain token IDs, resulting vectors should be distinct. + self.assertNotEqual(y[0].tolist(), y[1].tolist()) + self.assertNotEqual(y[0].tolist(), y[2].tolist()) + self.assertNotEqual(y[1].tolist(), y[2].tolist()) + + # For repeats of a token id, resulting vectors should match. + self.assertEqual(y[0].tolist(), y[4].tolist()) + self.assertEqual(y[1].tolist(), y[3].tolist()) + + def test_negative_inputs_clip_to_zero(self): + layer = tl.Embedding(10, 3) + _, _ = layer.init(None) + x = np.array([0, 2, 3, -2, -3]) + y = layer(x) + self.assertNotEqual(y[0].tolist(), y[1].tolist()) + self.assertNotEqual(y[0].tolist(), y[2].tolist()) + self.assertEqual(y[0].tolist(), y[3].tolist()) + self.assertEqual(y[0].tolist(), y[4].tolist()) + + def test_large_inputs_clip_to_upper_bound(self): + layer = tl.Embedding(10, 3) + _, _ = layer.init(None) + x = np.array([2, 3, 9, 10, 20]) + y = layer(x) + + # vocab_size of 10 means max valid token id is 9. + self.assertNotEqual(y[2].tolist(), y[0].tolist()) + self.assertNotEqual(y[2].tolist(), y[1].tolist()) + self.assertEqual(y[2].tolist(), y[3].tolist()) + self.assertEqual(y[2].tolist(), y[4].tolist()) + + def test_new_weights(self): + layer = tl.Embedding(20, 5) + _, _ = layer.init(None) + + # Default weights sampled from Gaussian, mu = 0, sigma = 1. 
+ w = layer.weights + self.assertEqual(w.shape, (20, 5)) + self.assertLess(np.abs(np.mean(w)), 0.4) # .4 is 4 sigma deviation + + def test_explicit_kernel_initializer(self): + def f(shape, rng): + del rng + n_elements = np.prod(shape) + return np.arange(n_elements).reshape(shape) + + layer = tl.Embedding(5, 2, kernel_initializer=f) + _, _ = layer.init(None) + x = np.array([0, 1, 2, 3, 4]) + y = layer(x) + self.assertEqual(y.tolist(), [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) + + +class DropoutTest(absltest.TestCase): + def test_call_in_train_mode(self): + layer = tl.Dropout(rate=0.1, mode="train") + x = np.ones((2, 5, 1000)) # 10,000 values + y = layer(x) + self.assertEqual(y.shape, (2, 5, 1000)) + + # Dropout is stochastic; test it nonflakily at 4 sigmas (.99994). + n_remaining = np.count_nonzero(y) + mu_of_remaining = 9000 # N * q: 10000 * .9 + sigma_of_remaining = 30 # sqrt(N * p * q): sqrt(10000 * .1 * .9) + self.assertLess(np.abs(n_remaining - mu_of_remaining), 4 * sigma_of_remaining) + + def test_call_in_eval_mode_does_no_dropout(self): + layer = tl.Dropout(rate=0.1, mode="eval") + x = np.ones((2, 5, 1000)) + y = layer(x) + self.assertEqual(np.count_nonzero(y), 10_000) + + def test_new_weights(self): + layer = tl.Dropout(rate=0.1, mode="train") + layer.init(None) + self.assertEmpty(layer.weights) + + +class WeightsTest(absltest.TestCase): + """Test Weights layer.""" + + def test_simple(self): + layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)) + layer.init(()) + y = layer(()) + self.assertEqual(y.tolist(), 0.0) + + def test_shape(self): + layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3)) + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, (5, 10, 3)) + + def test_simple_custom_initializer(self): + layer = tl.Weights(init.RandomNormalInitializer()) + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, ()) + self.assertNotEqual(y.tolist(), 0.0) + + def test_custom_initializer_shape(self): + layer = 
tl.Weights( + lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), (2, 2) + ) + layer.init(()) + y = layer(()) + self.assertEqual(y.tolist(), [[0.0, 0.0], [0.0, 0.0]]) + + layer = tl.Weights(init.RandomNormalInitializer(), (2, 2)) + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, (2, 2)) + self.assertNotEqual(y.tolist(), [[0.0, 0.0], [0.0, 0.0]]) + + +class SummaryScalarTest(absltest.TestCase): + def test_passes(self): + layer = tl.SummaryScalar("test") + x = np.array([[3.0, 5.0], [2.0, 6.0]]) # 10,000 values + y = layer(x) + self.assertEqual(y.tolist(), [[3.0, 5.0], [2.0, 6.0]]) + self.assertEqual(layer.state["summary_test"].tolist(), 4.0) + + +class RandomUniformTest(absltest.TestCase): + """Test Weights layer.""" + + def test_simple(self): + layer = tl.RandomUniform() + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, ()) + self.assertBetween(y, 0.0, 1.0) + + def test_shape(self): + layer = tl.RandomUniform(shape=(5, 10, 3)) + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, (5, 10, 3)) + + def test_simple_range(self): + layer = tl.RandomUniform(1.0, 2.0, shape=(1000,)) + layer.init(()) + y = layer(()) + self.assertEqual(y.shape, (1000,)) + self.assertBetween(min(y.tolist()), 1.0, 2.0) + self.assertBetween(max(y.tolist()), 1.0, 2.0) + self.assertBetween(1.5, min(y.tolist()), max(y.tolist())) + + +class LocallyConnected1dTest(absltest.TestCase): + def test_shape_kernel1(self): + for padding in ["WRAP", "SAME", "VALID"]: + layer = tl.LocallyConnected1d(6, 1, padding=padding) + x = np.array([[0, 1], [2, 3], [4, 5]]) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (3, 6)) + + def test_shape_kernel3(self): + for padding in ["WRAP", "SAME"]: + layer = tl.LocallyConnected1d(6, 3, padding=padding) + x = np.array([[0, 1], [2, 3], [4, 5]]) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (3, 6)) + + for padding in ["VALID"]: + layer = tl.LocallyConnected1d(6, 3, padding=padding) + x 
= np.array([[0, 1], [2, 3], [4, 5]]) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (1, 6)) + + +class FlattenTest(absltest.TestCase): + def test_keep_default(self): + layer = tl.Flatten() + x = np.ones((1, 2, 3, 4, 5)) + y = layer(x) + # Default is leave first axis untouched, flatten the rest. + self.assertEqual(y.shape, (1, 2 * 3 * 4 * 5)) + + def test_keep_3(self): + layer = tl.Flatten(n_axes_to_keep=3) + x = np.ones((1, 2, 3, 4, 5)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 3, 4 * 5)) + + def test_keep_max_number(self): + layer = tl.Flatten(n_axes_to_keep=4) + x = np.ones((1, 2, 3, 4, 5)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 3, 4, 5)) + + def test_keep_too_many_raises_error(self): + layer = tl.Flatten(n_axes_to_keep=5) + with self.assertRaises(tl.LayerError): + x = np.ones((1, 2, 3, 4, 5)) + _ = layer(x) + + +class LogSoftmaxTest(absltest.TestCase): + def test_call(self): + layer = tl.LogSoftmax() + x = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + y = layer(x) + np.testing.assert_allclose( + y, [[-0.313, -1.313, -12.313], [-0.693, -0.693, -11.693]], atol=0.001 + ) + + +class SoftmaxTest(absltest.TestCase): + def test_call(self): + layer = tl.Softmax() + x = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + y = layer(x) + np.testing.assert_allclose( + y, [[0.731, 0.269, 0.00000449], [0.500, 0.500, 0.00000835]], atol=0.001 + ) + + +class CoreFunctionsTest(absltest.TestCase): + def test_one_hot(self): + targets = np.array([2, 0, 1]) + n_categories = 5 + target_distributions = tl.one_hot(targets, n_categories) + self.assertEqual( + tl.to_list(target_distributions), + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ) + + def test_log_softmax(self): + activations = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + log_probabilities = tl.log_softmax(activations) + np.testing.assert_allclose( + log_probabilities, + [[-0.313, -1.313, -12.313], [-0.693, -0.693, 
-11.693]], + atol=0.001, + ) + + def test_log_gaussian_pdf(self): + x = np.zeros((2, 5), dtype=np.float32) + mu = x + dsigma = np.eye(5)[None, :, :] + sigma = np.concatenate([dsigma, 2 * dsigma], axis=0) + prob = tl.log_gaussian_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + + def test_log_gaussian_diag_pdf(self): + x = np.zeros((2, 5), dtype=np.float32) + mu = x + sigma = np.ones((5,))[None, :] + sigma = np.concatenate([sigma, 2 * sigma], axis=0) + prob = tl.log_gaussian_diag_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + + +class StopGradientTest(absltest.TestCase): + def test_passes(self): + layer = tl.StopGradient() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2, 2)) + self.assertEqual(y.tolist(), [[3.0, 5.0], [2.0, 6.0]]) + + +class MinMaxTest(absltest.TestCase): + def test_min(self): + layer = tl.Min() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [3.0, 2.0]) + + layer = tl.Min(axis=0) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [2.0, 5.0]) + + layer = tl.Min(axis=None) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, ()) + self.assertEqual(y.tolist(), 2.0) + + layer = tl.Min(keepdims=True) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2, 1)) + self.assertEqual(y.tolist(), [[3.0], [2.0]]) + + def test_max(self): + layer = tl.Max() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [5.0, 6.0]) + + layer = tl.Max(axis=0) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [3.0, 6.0]) + + layer = 
tl.Max(axis=None) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, ()) + self.assertEqual(y.tolist(), 6.0) + + layer = tl.Max(axis=0, keepdims=True) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (1, 2)) + self.assertEqual(y.tolist(), [[3.0, 6.0]]) + + +class ClassifierLayersTest(absltest.TestCase): + def test_threshold_to_binary(self): + layer = tl.ThresholdToBinary() + x = np.array([0.30, 0.49, 0.50, 0.51, 0.70]) + y = layer(x) + self.assertEqual(y.tolist(), [0, 0, 0, 1, 1]) + + def test_arg_max(self): + layer = tl.ArgMax() + x = np.array([[0.10, 0.90, 0.20, 0.80], [0.22, 0.88, 0.11, 0.99]]) + y = layer(x) + self.assertEqual(y.tolist(), [1, 3]) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/deconvolution_test.py b/tests/layers/deconvolution_test.py similarity index 75% rename from trax/layers/deconvolution_test.py rename to tests/layers/deconvolution_test.py index f1111f21e..759349c4b 100644 --- a/trax/layers/deconvolution_test.py +++ b/tests/layers/deconvolution_test.py @@ -23,15 +23,14 @@ class ConvTransposeTest(absltest.TestCase): + def test_call(self): + layer = tl.ConvTranspose(30, (3, 3)) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) - def test_call(self): - layer = tl.ConvTranspose(30, (3, 3)) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (9, 7, 7, 30)) - y = layer(x) - self.assertEqual(y.shape, (9, 7, 7, 30)) - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/initializers_test.py b/tests/layers/initializers_test.py new file mode 100644 index 000000000..3d248ad80 --- /dev/null +++ b/tests/layers/initializers_test.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for initializers.""" + +from absl.testing import absltest +import numpy as np + +from trax import fastmath +from trax import test_utils +import trax.layers as tl + + +INPUT_SHAPE = (5, 7, 20) + + +def rng(): # Can't be a constant, because JAX has to init itself in main first. + return fastmath.random.get_prng(0) + + +class InitializersTest(absltest.TestCase): + def test_random_normal(self): + f = tl.RandomNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_lecun_uniform(self): + f = tl.LeCunUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_random_uniform(self): + f = tl.RandomUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_glorot_normal(self): + f = tl.GlorotNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_glorot_uniform(self): + f = tl.GlorotUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_lecun_normal(self): + f = tl.LeCunNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_kaiming_normal(self): + f = tl.KaimingNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + 
self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_kaiming_uniform(self): + f = tl.KaimingUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_orthogonal(self): + f = tl.OrthogonalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_from_file(self): + params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]) + # `create_tempfile` needs access to --test_tmpdir, however in the OSS world + # pytest doesn't run `absltest.main`, so we need to manually parse the flags + test_utils.ensure_flag("test_tmpdir") + filename = self.create_tempfile("params.npy").full_path + with open(filename, "wb") as f: + np.save(f, params) + f = tl.InitializerFromFile(filename) + init_value = f(params.shape, rng()) + np.testing.assert_almost_equal( + tl.to_list(init_value), tl.to_list(params), decimal=4 + ) + # self.assertEqual('%s' % init_value, '%s' % params) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/metrics_test.py b/tests/layers/metrics_test.py new file mode 100644 index 000000000..1964ddbb3 --- /dev/null +++ b/tests/layers/metrics_test.py @@ -0,0 +1,441 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for metrics layers.""" + +from absl.testing import absltest +import numpy as np +import trax.layers as tl + + +class MetricsTest(absltest.TestCase): + def test_category_accuracy(self): + layer = tl.CategoryAccuracy() + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets]) + self.assertEqual(accuracy, 1 / 3) + + def test_weighted_category_accuracy_even_weights(self): + layer = tl.WeightedCategoryAccuracy() + weights = np.array([1.0, 1.0, 1.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1 / 3) + + def test_weighted_category_accuracy_uneven_weights(self): + layer = tl.WeightedCategoryAccuracy() + weights = np.array([1.0, 5.0, 2.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.625) + + def test_category_cross_entropy(self): + layer = tl.CategoryCrossEntropy() + targets = np.array([0, 1]) + + # Near-perfect prediction (for both items in batch). 
+ model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.001, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.665, places=3) + + # First item near perfect, second item more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.333, places=3) + + def test_category_cross_entropy_with_label_smoothing(self): + epsilon = 0.01 + layer = tl.CategoryCrossEntropy(label_smoothing=epsilon) + targets = np.array([0, 1]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.069, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.682, places=3) + + # First item near perfect, second item more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.375, places=3) + + def test_weighted_category_cross_entropy(self): + layer = tl.WeightedCategoryCrossEntropy() + targets = np.array([0, 1]) + weights = np.array([30, 10]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.001, places=3) + + # More right than wrong (for both items in batch). 
+ model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.665, places=3) + + # First item (with 75% weight) near perfect, second more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.167, places=3) + + def test_weighted_category_cross_entropy_with_label_smoothing(self): + epsilon = 0.01 + layer = tl.WeightedCategoryCrossEntropy(label_smoothing=epsilon) + targets = np.array([0, 1]) + weights = np.array([30, 10]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.069, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.682, places=3) + + # First item (with 75% weight) near perfect, second more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.222, places=3) + + def test_masked_sequence_accuracy(self): + layer = tl.MaskedSequenceAccuracy() + targets = np.array([[0, 1, 0, 0], [1, 0, 1, 0]]) + weights = np.array([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]]) + + # Model gets both sequences right; output in final position would give + # wrong category but is ignored. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.35, 0.65]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.35, 0.65]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. 
+ model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets second-to-last element of each sequence barely wrong. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.48, 0.52], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.51, 0.49], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_binary_cross_entropy(self): + layer = tl.BinaryCrossEntropy() + targets = np.array([1, 1, 0, 0, 0]) + + # Near-perfect prediction for all five items in batch. + model_outputs = np.array([9.0, 9.0, -9.0, -9.0, -9.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.000123, places=6) + + # More right than wrong for all five items in batch. + model_outputs = np.array([1.0, 1.0, -1.0, -1.0, -1.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.313, places=3) + + # Near-perfect for 2, more right than wrong for 3. + model_outputs = np.array([9.0, 1.0, -1.0, -1.0, -9.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.188, places=3) + + # More wrong than right for all five. 
+ model_outputs = np.array([-1.0, -1.0, 1.0, 1.0, 1.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 1.313, places=3) + + def test_accuracy_even_weights(self): + layer = tl.Accuracy() + weights = np.array([1.0, 1.0, 1.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1 / 3) + + def test_accuracy_uneven_weights(self): + layer = tl.Accuracy() + weights = np.array([1.0, 5.0, 2.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.625) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.7, 0.2, 0.1, 0.0], [0.7, 0.2, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.125) + + def test_accuracy_binary_classifier(self): + layer = tl.Accuracy(classifier=tl.ThresholdToBinary()) + targets = np.array([[0, 0, 1, 1], [1, 1, 1, 0]]) + weights = np.ones_like(targets) + + model_outputs = np.array( + [[0.499, 0.500, 0.501, 0.502], [0.503, 0.502, 0.501, 0.500]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.498, 0.499, 0.500, 0.501], [0.502, 0.501, 0.500, 0.499]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.75) + + def 
test_sequence_accuracy_weights_all_ones(self): + layer = tl.SequenceAccuracy() + targets = np.array([[0, 1, 0, 1], [1, 0, 1, 1]]) + weights = np.ones_like(targets) + + # Model gets both sequences right; for each position in each sequence, the + # category (integer ID) selected by argmax matches the target category. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.4, 0.6]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.4, 0.6]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. + model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.4, 0.6]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.4, 0.6]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets the last element of each sequence barely wrong. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.55, 0.45]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.52, 0.48]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_sequence_accuracy_last_position_zero_weight(self): + layer = tl.SequenceAccuracy() + targets = np.array([[0, 1, 0, 0], [1, 0, 1, 0]]) + weights = np.array([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]]) + + # Model gets both sequences right; output in final position would give + # wrong category but is ignored. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.35, 0.65]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.35, 0.65]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. 
+ model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets second-to-last element of each sequence barely wrong. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.48, 0.52], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.51, 0.49], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_binary_cross_entropy_loss(self): + # TODO(jonni): Clarify desired semantics/naming, then test it. + layer = tl.BinaryCrossEntropyLoss() + xs = [np.ones((9, 1)), np.ones((9, 1)), np.ones((9, 1))] + y = layer(xs) + self.assertEqual(y.shape, ()) + + def test_cross_entropy_loss(self): + # TODO(jonni): Clarify desired semantics/naming, then test it. + layer = tl.CrossEntropyLoss() + xs = [np.ones((9, 4, 4, 20)), np.ones((9, 4, 4)), np.ones((9, 4, 4))] + y = layer(xs) + self.assertEqual(y.shape, ()) + + def test_l2_loss(self): + layer = tl.L2Loss() + + model_outputs = np.array([[1.0, 1.0], [1.0, 1.0]]) + targets = np.array([[1.0, 1.0], [1.0, 0.0]]) + weights = np.array([[1.0, 1.0], [1.0, 0.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.0) + + weights = np.array([[1.0, 0.0], [0.0, 1.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.5) + + def test_smooth_l1_loss(self): + layer = tl.SmoothL1Loss() + + model_outputs = np.array([[1.0, 1.0], [1.0, 2.0]]) + targets = np.array([[1.0, 1.0], [1.0, 0.0]]) + l1_dist = 2 + + weights = np.array([[1.0, 1.0], [1.0, 0.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.0) + + weights = np.array([[1.0, 0.0], [0.0, 1.0]]) + sum_weights = 2 + + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, (l1_dist - 0.5) / sum_weights) + + 
model_outputs = np.array([[1.0, 1.0], [1.0, 1.5]]) + targets = np.array([[1.0, 1.0], [1.0, 1.0]]) + l1_dist = 0.5 + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.5 * l1_dist**2 / sum_weights) + + def test_macro_averaged_f_score(self): + # predictions = [1, 1, 2, 1, 1]. + model_outputs = np.array( + [[0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]] + ) + targets = np.array([1, 2, 2, 3, 1]) + # Category indices starting with `0`. + layer = tl.MacroAveragedFScore() + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.333, places=3) + # Excluding the padding index `0`. + layer = tl.MacroAveragedFScore(initial_category_index=1) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.444, places=3) + + def test_weighted_f_score(self): + # predictions = [1, 1, 2, 1, 1]. + model_outputs = np.array( + [[0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]] + ) + targets = np.array([1, 2, 2, 3, 1]) + # Category indices starting with `0`. + layer = tl.WeightedFScore() + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.533, places=3) + # Excluding the padding index `0`. 
+ layer = tl.WeightedFScore(initial_category_index=1) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.533, places=3) + + def test_names(self): + layer = tl.L2Loss() + self.assertEqual("L2Loss_in3", str(layer)) + layer = tl.Accuracy() + self.assertEqual("Accuracy_in3", str(layer)) + layer = tl.SequenceAccuracy() + self.assertEqual("SequenceAccuracy_in3", str(layer)) + layer = tl.BinaryCrossEntropyLoss() + self.assertEqual("BinaryCrossEntropyLoss_in3", str(layer)) + layer = tl.CrossEntropyLoss() + self.assertEqual("CrossEntropyLoss_in3", str(layer)) + layer = tl.BinaryCrossEntropySum() + self.assertEqual("BinaryCrossEntropySum_in3", str(layer)) + layer = tl.CrossEntropySum() + self.assertEqual("CrossEntropySum_in3", str(layer)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/normalization_test.py b/tests/layers/normalization_test.py new file mode 100644 index 000000000..f5f519634 --- /dev/null +++ b/tests/layers/normalization_test.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for normalization layers.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from trax import fastmath +from trax import shapes +import trax.layers as tl + + +class BatchNormTest(parameterized.TestCase): + def test_forward_shape(self): + layer = tl.BatchNorm() + x = np.ones((30, 20, 70)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + @parameterized.named_parameters( + ("jax32", fastmath.Backend.JAX, np.float32), + ("tf32", fastmath.Backend.TFNP, np.float32), + ("tf64", fastmath.Backend.TFNP, np.float64), + ) + def test_forward_dtype(self, backend, dtype): + with fastmath.use_backend(backend): + layer = tl.BatchNorm() + x = np.ones((3, 2, 7)).astype(dtype) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.dtype, dtype) + + @parameterized.named_parameters( + ("momentum_999", 0.999), + ("momentum_900", 0.900), + ("momentum_800", 0.800), + ) + def test_forward(self, momentum): + layer = tl.BatchNorm(momentum=momentum) + x = np.array( + [ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], + ] + ).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + running_mean, running_var, n_batches = layer.state + + fraction_old = momentum + fraction_new = 1.0 - momentum + mean_of_x = 11.5 # mean of range(24) + var_of_x = 47.9167 # variance of range(24) + np.testing.assert_allclose( + running_mean, 0.0 * fraction_old + mean_of_x * fraction_new + ) + np.testing.assert_allclose( + running_var, 1.0 * fraction_old + var_of_x * fraction_new, rtol=1e-6 + ) + self.assertEqual(n_batches, 1) + eps = 1e-5 + np.testing.assert_allclose( + y, (x - mean_of_x) / np.sqrt(var_of_x + eps), rtol=1e-6 + ) + + def test_new_weights_and_state(self): + layer = tl.BatchNorm() + x = np.ones((3, 2, 7)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + + 
running_mean, running_var, n_batches = layer.state + np.testing.assert_allclose(running_mean, 0.0) + np.testing.assert_allclose(running_var, 1.0) + self.assertEqual(n_batches, 0) + + +class LayerNormTest(parameterized.TestCase): + def test_forward_shape(self): + layer = tl.LayerNorm() + x = np.ones((3, 2, 7)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + @parameterized.named_parameters( + ("jax32", fastmath.Backend.JAX, np.float32), + ("tf32", fastmath.Backend.TFNP, np.float32), + ("tf64", fastmath.Backend.TFNP, np.float64), + ) + def test_forward_dtype(self, backend, dtype): + with fastmath.use_backend(backend): + layer = tl.LayerNorm() + x = np.ones((3, 2, 7)).astype(dtype) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.dtype, dtype) + + +class FilterResponseNormTest(parameterized.TestCase): + @parameterized.named_parameters( + ("learn_epsilon_false", False), + ("learn_epsilon_true", True), + ) + def test_forward_shape(self, learn_epsilon): + layer = tl.FilterResponseNorm(learn_epsilon=learn_epsilon) + + B, H, W, C = 64, 5, 7, 3 # pylint: disable=invalid-name + x = np.ones((B, H, W, C)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/pooling_test.py b/tests/layers/pooling_test.py new file mode 100644 index 000000000..a41651031 --- /dev/null +++ b/tests/layers/pooling_test.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for conv layers.""" + +from absl.testing import absltest +import numpy as np + +import trax.layers as tl + + +class MaxPoolTest(absltest.TestCase): + def test_forward_shape(self): + layer = tl.MaxPool(pool_size=(2, 2), strides=(1, 2)) + x = np.ones((11, 6, 4, 17)) + y = layer(x) + self.assertEqual(y.shape, (11, 5, 2, 17)) + + def test_forward(self): + layer = tl.MaxPool(pool_size=(2, 2), strides=(2, 2)) + x = np.array( + [ + [ + [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], + [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], + ] + ] + ) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[[7, 5, 6], [70, 50, 60]]]]) + + def test_padding_default(self): + layer = tl.MaxPool(pool_size=(3,), strides=(3,)) + + # Discard incomplete window at end: [[3, 6], [4, 5]]. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[2, 9]]]) + + def test_padding_same(self): + layer = tl.MaxPool(pool_size=(3,), strides=(3,), padding="SAME") + + # One padding position needed; add at end. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[2, 9], [4, 6]]]) + + # Two padding positions needed; add one at end and one at start. 
+ x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[1, 9], [3, 7]]]) + + +class SumPoolTest(absltest.TestCase): + def test_forward_shape(self): + layer = tl.SumPool(pool_size=(2, 2), strides=(1, 2)) + x = np.ones((11, 6, 4, 17)) + y = layer(x) + self.assertEqual(y.shape, (11, 5, 2, 17)) + + def test_forward(self): + layer = tl.SumPool(pool_size=(2, 2), strides=(2, 2)) + x = np.array( + [ + [ + [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], + [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], + ] + ] + ) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[[16, 10, 14], [160, 100, 140]]]]) + + def test_padding_same(self): + layer = tl.SumPool(pool_size=(3,), strides=(3,), padding="SAME") + + # One padding position needed; add at end. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[3, 24], [7, 11]]]) + + # Two padding positions needed; add one at end and one at start. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[1, 17], [5, 13]]]) + + +class AvgPoolTest(absltest.TestCase): + def test_forward_shape(self): + layer = tl.AvgPool(pool_size=(2, 2), strides=(1, 2)) + x = np.ones((11, 6, 4, 17)) + y = layer(x) + self.assertEqual(y.shape, (11, 5, 2, 17)) + + def test_forward(self): + layer = tl.AvgPool(pool_size=(2, 2), strides=(2, 2)) + x = np.array( + [ + [ + [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], + [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], + ] + ] + ) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[[4.0, 2.5, 3.5], [40, 25, 35]]]]) + + def test_padding_same(self): + layer = tl.AvgPool(pool_size=(3,), strides=(3,), padding="SAME") + + # One padding position needed; add at end. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[1, 8], [3.5, 5.5]]]) + + # Two padding positions needed; add one at end and one at start. 
+ x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[0.5, 8.5], [2.5, 6.5]]]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/research/efficient_attention_test.py b/tests/layers/research/efficient_attention_test.py new file mode 100644 index 000000000..ae98c4428 --- /dev/null +++ b/tests/layers/research/efficient_attention_test.py @@ -0,0 +1,585 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.layers.research.efficient_attention.""" + +from absl.testing import parameterized +import jax +import numpy as np +from tensorflow import test + +from trax import fastmath +from trax import shapes +from trax.fastmath import numpy as jnp +from trax.layers.research import efficient_attention + + +class EfficientAttentionTest(test.TestCase, parameterized.TestCase): + def test_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.SelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_lsh_ff(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHFF(d_ff=1024 * 8, n_buckets=[16, 8]) + x = np.ones((3, 7, 1024)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_self_attention_tf(self): + with fastmath.use_backend(fastmath.Backend.TFNP): + layer = efficient_attention.SelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_lsh_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = 
layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def _run_forward_and_backward(self, model, inp, weights, state): + def forward(inp, weights): + return model.pure_fn(inp, weights, state, rng=jax.random.PRNGKey(0)) + + out, vjpfun, new_state = jax.vjp(forward, inp, weights, has_aux=True) + inp_grad, weights_grad = vjpfun(fastmath.numpy.ones_like(inp)) + return out, new_state, inp_grad, weights_grad + + def _test_equivalence_to_reference_code( + self, model_cls, inp, input_signature, common_kwargs, *test_kwargs + ): + ref_model = model_cls(use_reference_code=True, **common_kwargs) + rng = fastmath.random.get_prng(123) + weights, state = ref_model.init(input_signature, rng) + + ref_all = self._run_forward_and_backward(ref_model, inp, weights, state) + ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all + + for kwargs in test_kwargs: + test_model = model_cls(**common_kwargs, **kwargs) + state = test_model.init(input_signature, rng)[1] + test_all = self._run_forward_and_backward(test_model, inp, weights, state) + test_out, test_state, test_inp_grad, test_weights_grad = test_all + + self.assertEqual(jax.tree_util.tree_structure(ref_out), jax.tree_util.tree_structure(test_out)) + self.assertEqual( + jax.tree_util.tree_structure(ref_state), jax.tree_util.tree_structure(test_state) + ) + self.assertEqual( + jax.tree_util.tree_structure(ref_inp_grad), jax.tree_util.tree_structure(test_inp_grad) + ) + self.assertEqual( + jax.tree_util.tree_structure(ref_weights_grad), + jax.tree_util.tree_structure(test_weights_grad), + ) + + check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3) + fastmath.nested_map_multiarg(check_close, ref_out, test_out) + fastmath.nested_map_multiarg(check_close, ref_state, test_state) + fastmath.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad) + fastmath.nested_map_multiarg( + check_close, ref_weights_grad, test_weights_grad + ) + + def test_batching_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( +
n_heads=6, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + attention_dropout=0.2, + output_dropout=0.1, + mode="train", + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_equivalence_to_reference_code( + efficient_attention.SelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs + ) + + def test_batching_lsh_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( + n_heads=6, + d_qk=7, + d_v=17, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + attention_dropout=0.2, + output_dropout=0.1, + mode="train", + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_equivalence_to_reference_code( + efficient_attention.LSHSelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs + ) + + def _test_fast_inference( + self, model_cls, x, input_signature, common_kwargs, *test_kwargs + ): + ref_model = model_cls(use_reference_code=True, mode="eval", **common_kwargs) + weights, state = ref_model.init(input_signature) + + ref_out, _ = ref_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + + def get_slice(pytree, i): + def get_slice_for_val(x): + if isinstance(x, shapes.ShapeDtype): + return shapes.ShapeDtype( + shape=x.shape[:1] + (1,) + x.shape[2:], dtype=x.dtype + ) + else: + return x[:, i : i + 
1] + + return jax.tree_util.tree_map(get_slice_for_val, pytree) + + seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1] + + for kwargs in test_kwargs: + test_model = model_cls(mode="predict", **common_kwargs, **kwargs) + cur_state = test_model.init(get_slice(input_signature, 0))[1] + out = [] + for i in range(seqlen): + cur_out, cur_state = test_model.pure_fn( + get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0) + ) + out.append(cur_out) + out = jnp.concatenate(out, axis=1) + + self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3) + + def test_fast_inference_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( + n_heads=6, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + attention_dropout=0.0, + output_dropout=0.0, + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_fast_inference( + efficient_attention.SelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs + ) + + def _test_lsh_self_attention_deterministic_given_seed(self, causal=False): + # Once the initialization and the call seeds are pinned down we have + # deterministic output.
+ with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=causal, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + + def get_output(): + _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0)) + return layer(x, rng=jax.random.PRNGKey(1)) + + ys = [get_output() for _ in range(10)] + + self.assertEqual(ys[0].shape, x.shape) + + for y in ys[1:]: + np.testing.assert_array_almost_equal(ys[0], y, decimal=6) + + def test_lsh_determinism_causal(self): + self._test_lsh_self_attention_deterministic_given_seed(causal=True) + + def test_lsh_determinism_non_causal(self): + self._test_lsh_self_attention_deterministic_given_seed(causal=False) + + def test_lsh_self_attention_masked_non_causal(self): + # Test that when the input that is in the masked area changes the attention + # for the un-masked outputs doesn't change, but the masked region does + # change. + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=False, + masked=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + + batch = 5 + max_len = 32 + hidden = 8 + + x = np.random.uniform(size=(batch, max_len, hidden)) + mask = np.ones((batch, max_len)).astype(bool) + rngs = jax.random.randint( + jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1 + ) + + # Set some suffix of each mask[b] to 0. + for i in range(batch): + mask[i, rngs[i] :] = 0 + + # Fix rngs and get the output for the LSH layer. 
+ def get_output(x, mask): + xs = [x, mask] + _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0)) + return layer(xs, rng=jax.random.PRNGKey(1)) + + # Get the attention output for masked x. + y = get_output(x, mask) + + # Change x, but only in the masked regions. + for i in range(batch): + x[i, rngs[i] :] = np.random.uniform(size=(max_len - rngs[i], hidden)) + + y2 = get_output(x, mask) + + for i in range(batch): + # y and y2 should be identical in the non-masked part. + np.testing.assert_array_almost_equal( + y[i, : rngs[i]], y2[i, : rngs[i]], decimal=6 + ) + + # In the masked out part, they should be different. + self.assertGreater( + np.mean(np.abs(y[i, rngs[i] :] - y2[i, rngs[i] :])), 1e-5 + ) + + @parameterized.named_parameters(("_weights_2", 2), ("_weights_3", 3)) + def test_pure_lsh_wrapper_causal_non_masked(self, num_weights): + with fastmath.use_backend(fastmath.Backend.JAX): + n_heads = 5 + batch, seqlen, d_head = 3, 32, 8 + n_hashes = 2 + d_model = n_heads * d_head + layer = efficient_attention.PureLSHSelfAttentionWrapper( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=True, + masked=False, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=n_hashes, + n_buckets=4, + bias=False, + pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, + mode="train", + num_weights=num_weights, + ) + + rng = jax.random.PRNGKey(0) + rng, x_rng = jax.random.split(rng) + + input_shape = (batch, seqlen, d_model) + x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) + + inp = x + w, s = layer.init(shapes.signature(inp)) + o = layer(inp) + + # Get the actual weights. + weights = fastmath.tree_leaves(w) + # Assert number of weights is as expected, the extra 1 is for output. + self.assertLen(weights, num_weights + 1) + + # Assert each weight is of the expected shape. + for i in range(num_weights + 1): + self.assertEqual(weights[i].shape, (d_model, d_model)) + + # Test that the output and the input shape match. 
+ self.assertEqual(inp.shape, o.shape) + + # Assert state is the shape expected. + state = fastmath.tree_leaves(s) + self.assertLen(state, 2) + # buckets + self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) + # rngs + self.assertEqual(state[1].shape, (batch * n_heads, 2)) + + @parameterized.named_parameters(("_weights_2", 2), ("_weights_3", 3)) + def test_pure_lsh_wrapper_non_causal_masked(self, num_weights): + with fastmath.use_backend(fastmath.Backend.JAX): + n_heads = 5 + batch, seqlen, d_head = 3, 32, 8 + # num_weights is supplied by the parameterization; do not pin it here. + n_hashes = 2 + d_model = n_heads * d_head + layer = efficient_attention.PureLSHSelfAttentionWrapper( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=False, + masked=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=n_hashes, + n_buckets=4, + bias=False, + pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, + mode="train", + num_weights=num_weights, + ) + + rng = jax.random.PRNGKey(0) + rng, x_rng = jax.random.split(rng) + + input_shape = (batch, seqlen, d_model) + x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) + mask = jnp.ones((batch, seqlen), dtype=jnp.int32) + + inp = (x, mask) + w, s = layer.init(shapes.signature(inp)) + o = layer(inp) + + # Get the actual weights. + weights = fastmath.tree_leaves(w) + # Assert number of weights is as expected, the extra 1 is for output. + self.assertLen(weights, num_weights + 1) + + # Assert each weight is of the expected shape. + for i in range(num_weights + 1): + self.assertEqual(weights[i].shape, (d_model, d_model)) + + # Test that the output and the x's shape match. + self.assertEqual(x.shape, o.shape) + + # Assert state is the shape expected.
+ state = fastmath.tree_leaves(s) + self.assertLen(state, 2) + # buckets + self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) + # rngs + self.assertEqual(state[1].shape, (batch * n_heads, 2)) + + def test_lsh_and_pure_lsh_self_attention_equivalence(self): + # Given the same weight matrices and random numbers, do these produce the + # same output. + with fastmath.use_backend(fastmath.Backend.JAX): + n_heads = 4 + d_head = 4 + d_model = n_heads * d_head + pure_lsh_layer = efficient_attention.PureLSHSelfAttention( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=True, + masked=False, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=4, + n_buckets=8, + use_reference_code=False, + attention_dropout=0.0, + use_python_loop=True, + bias=False, + mode="train", + ) + lsh_layer = efficient_attention.LSHSelfAttention( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=True, + masked=False, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=4, + n_buckets=8, + use_reference_code=False, + attention_dropout=0.0, + use_python_loop=True, + mode="train", + ) + + batch, seqlen = 3, 32 + input_shape = (batch, seqlen, d_model) + + x = jax.random.uniform( + jax.random.PRNGKey(0), input_shape, dtype=jnp.float32 + ) + lsh_layer_input = x + + call_rng = jax.random.PRNGKey(42) + + lsh_layer_weights, lsh_layer_state = lsh_layer.init( + shapes.signature(lsh_layer_input) + ) + lsh_layer.rng = call_rng + lsh_layer_output = lsh_layer(lsh_layer_input) + + # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head), + # (n_heads, d_head, d_model) + # Abbreviated as - hmn, hmn, hnm + w_qk, w_v, w_o = lsh_layer_weights + + qk = jnp.einsum("blm,hmn->bhln", x, w_qk) + qk = qk.reshape((-1, qk.shape[2], qk.shape[3])) + + v = jnp.einsum("blm,hmn->bhln", x, w_v) + v = v.reshape((-1, v.shape[2], v.shape[3])) + + pure_lsh_layer_input = (qk, v) + _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input)) + pure_lsh_layer.rng 
= call_rng + pure_lsh_layer.state = lsh_layer_state + pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input) + + # b*h,l,n + pure_lsh_layer_output = pure_lsh_layer_output.reshape( + (batch, -1) + pure_lsh_layer_output.shape[1:] + ) + pure_lsh_layer_output_projected = jnp.einsum( + "bhld,hdm->blm", pure_lsh_layer_output, w_o + ) + + diff = pure_lsh_layer_output_projected - lsh_layer_output + avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff)) + + self.assertLess(avg_diff, 1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tests/layers/research/position_encodings_test.py b/tests/layers/research/position_encodings_test.py new file mode 100644 index 000000000..715c1cbe2 --- /dev/null +++ b/tests/layers/research/position_encodings_test.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.layers.research.position_encodings.""" + +import functools +import absl.testing.absltest as unittest +import numpy as np +import parameterized + +from trax import fastmath +import trax.layers.research.position_encodings as pe + + +@parameterized.parameterized_class( + [ + # {'Encoding': pe.FixedBasePositionalEncoding}, + {"Encoding": pe.InfinitePositionalEncoding}, + {"Encoding": functools.partial(pe.InfinitePositionalEncoding, affine=False)}, + { + "Encoding": functools.partial( + pe.TimeBinPositionalEncoding, time_bin_length=5 + ) + }, + ] +) +class PositionEncodingsTest(unittest.TestCase): + """Position encodings conform to the position encodings protocol.""" + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 1, 8), # short + (1, 100, 1), # narrow + (2, 100, 8), # batched + ] + ) + def test_training(self, n, t, c): + encoding = self.Encoding() + input_ntc = np.random.randn(n, t, c) + encoding.init(input_ntc) + output_ntc = encoding(input_ntc) + self.assertEqual(output_ntc.shape, input_ntc.shape) + self.assertTrue(np.not_equal(output_ntc, input_ntc).any()) + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 100, 1), # narrow + (2, 100, 8), # batched + ] + ) + def test_inference(self, n, t, c): + # Get the eval mode outputs: + encoding = self.Encoding(mode="eval") + input_ntc = np.random.randn(n, t, c) + rng = fastmath.random.get_prng(1234) + encoding.init(input_ntc, rng=rng) + output_ntc = encoding(input_ntc) + + is_random = self.Encoding == pe.InfinitePositionalEncoding + + # Get the predict mode outputs: + encoding_pred = self.Encoding(mode="predict") + encoding_pred.init(input_ntc[:, 0:1, :], rng=rng) + output_ntc0 = encoding_pred(input_ntc[:, 0:1, :]) + if not is_random: + np.testing.assert_allclose(output_ntc0, output_ntc[:, 0:1, :], atol=1e-4) + + output_ntc1 = encoding_pred(input_ntc[:, 1:2, :]) + if not is_random: + np.testing.assert_allclose(output_ntc1, output_ntc[:, 1:2, :], 
atol=1e-4) + + output_ntc2 = encoding_pred(input_ntc[:, 2:3, :]) + if not is_random: + np.testing.assert_allclose(output_ntc2, output_ntc[:, 2:3, :], atol=1e-4) + + +class SinCosEncodingsTest(unittest.TestCase): + """Position encodings conform to the position encodings protocol.""" + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 1, 8), # short + (2, 100, 8), # batched + ] + ) + def test_training(self, n, t, c): + encoding = pe.SinCosPositionalEncoding() + input_ntc = np.random.randn(n, t, c) + encoding.init(input_ntc) + output_ntc = encoding(input_ntc) + self.assertEqual(output_ntc.shape, input_ntc.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/trax/layers/research/rel_attention_test.py b/tests/layers/research/rel_attention_test.py similarity index 60% rename from trax/layers/research/rel_attention_test.py rename to tests/layers/research/rel_attention_test.py index 50918ff78..cfb5a891b 100644 --- a/trax/layers/research/rel_attention_test.py +++ b/tests/layers/research/rel_attention_test.py @@ -37,18 +37,37 @@ class RelAttentionTest(absltest.TestCase): + def test_fast_shift_matrix(self): + layer = ra._fast_matrix_shift + x = np.array( + [ + [ + [ + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + ] + ] + ] + ).astype(np.float32) - def test_fast_shift_matrix(self): - layer = ra._fast_matrix_shift - x = np.array([[[[-3., -2., -1., 0.], [-3., -2., -1., - 0.], [-3., -2., -1., 0.], - [-3., -2., -1., 0.]]]]).astype(np.float32) + y = layer(x) + self.assertEqual(y.dtype, np.float32) + self.assertEqual( + tl.to_list(y), + [ + [ + [ + [0.0, 0.0, -3.0, -2.0], + [-1.0, 0.0, 0.0, -3.0], + [-2.0, -1.0, 0.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + ] + ] + ], + ) - y = layer(x) - self.assertEqual(y.dtype, np.float32) - self.assertEqual( - tl.to_list(y), [[[[0., 0., -3., -2.], [-1., 0., 0., -3.], - [-2., -1., 0., 0.], [-3., -2., -1., 0.]]]]) -if __name__ == '__main__': - 
absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/research/rotary_positional_embedding_test.py b/tests/layers/research/rotary_positional_embedding_test.py similarity index 52% rename from trax/layers/research/rotary_positional_embedding_test.py rename to tests/layers/research/rotary_positional_embedding_test.py index 8e049d11e..8dadf63c6 100644 --- a/trax/layers/research/rotary_positional_embedding_test.py +++ b/tests/layers/research/rotary_positional_embedding_test.py @@ -21,26 +21,25 @@ class RelAttentionTest(absltest.TestCase): + def test_rotary_monotonicity(self): + layer = rotary_pe.Rotate() + batch_size = 1 + seq_len = 32 + d_model = 512 + shape = (batch_size, seq_len, d_model) + q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32) + q, k = layer(q), layer(k) - def test_rotary_monotonicity(self): - layer = rotary_pe.Rotate() - batch_size = 1 - seq_len = 32 - d_model = 512 - shape = (batch_size, seq_len, d_model) - q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32) - q, k = layer(q), layer(k) + self.assertEqual(q.dtype, np.float32) + self.assertEqual(q.shape, shape) - self.assertEqual(q.dtype, np.float32) - self.assertEqual(q.shape, shape) + # Test monotonicity of the resulting dot_product for the two first tokens + # in close proximity + dot_product = np.einsum("bnd, bmd -> bnm", q, k) - # Test monotonicity of the resulting dot_product for the two first tokens - # in close proximity - dot_product = np.einsum('bnd, bmd -> bnm', q, k) + self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all()) + self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all()) - self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all()) - self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all()) - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/research/sparsity_test.py 
b/tests/layers/research/sparsity_test.py new file mode 100644 index 000000000..772095953 --- /dev/null +++ b/tests/layers/research/sparsity_test.py @@ -0,0 +1,513 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.research.sparsity.""" + +import functools +from absl.testing import parameterized +import jax +import numpy as np +from tensorflow import test + +from trax import fastmath +from trax import shapes +import trax.layers as tl +from tests.layers import test_utils +from trax.layers.research import sparsity + + +class EfficientFeedForwardTest(test.TestCase, parameterized.TestCase): + def test_blocksparse_ff_train(self): + d_model = 1024 + n_experts = 64 + d_ff = d_model * 8 + x_shape = (3, 7, d_model) + with fastmath.use_backend(fastmath.Backend.JAX): + layer = sparsity.BlockSparseFF( + d_ff=d_ff, n_experts=n_experts, temperature=0.7, mode="train" + ) + x = np.ones(x_shape).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_blocksparse_ff_predict_equals_eval(self): + d_model = 1024 + n_experts = 64 + d_ff = d_model * 8 + x_shape = (1, 1, d_model) + temperature = 0.7 + with fastmath.use_backend(fastmath.Backend.JAX): + x = np.ones(x_shape).astype(np.float32) + input_signature = shapes.signature(x) + common_kwargs = dict( + d_ff=d_ff, + n_experts=n_experts, + temperature=temperature, + ) + eval_model =
sparsity.BlockSparseFF(mode="eval", **common_kwargs) + weights, state = eval_model.init(input_signature) + eval_out, _ = eval_model.pure_fn( + x, weights, state, rng=jax.random.PRNGKey(0) + ) + pred_model = sparsity.BlockSparseFF(mode="predict", **common_kwargs) + _, _ = pred_model.init(input_signature) + pred_out, _ = pred_model.pure_fn( + x, weights, state, rng=jax.random.PRNGKey(0) + ) + self.assertEqual(eval_out.shape, x.shape) + # eval_out and pred_out should be identical. + np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) + + def test_sparse_ff_predict_equals_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + d_model = 64 + seq_len = 6 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn = functools.partial( + sparsity.SparseFF, + d_ff=256, + temperature=0.7, + n_elements_in_block=8, + ) + + configs = [ + {"multiply_by_controller_output": True}, + {"multiply_by_controller_output": False}, + {"ff_chunk_size": 2}, + ] + + test_utils.test_eval_equals_predict_configs(inp, model_fn, configs) + + @parameterized.named_parameters( + ("_mode_train", "train"), ("_mode_eval", "eval"), ("_mode_predict", "predict") + ) + def test_sparse_ff_with_chunking(self, mode): + d_model = 8 + n_elements_in_block = 2 + d_ff = 16 + x_shape = (2, 8, d_model) + temperature = 0.7 + with fastmath.use_backend(fastmath.Backend.JAX): + x = np.ones(x_shape).astype(np.float32) + input_signature = shapes.signature(x) + model = sparsity.SparseFF( + d_ff=d_ff, + n_elements_in_block=n_elements_in_block, + temperature=temperature, + ff_chunk_size=4, + mode=mode, + ) + weights, state = model.init(input_signature) + out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + self.assertEqual(out.shape, x.shape) + + @parameterized.named_parameters( + ("_mode_train", "train"), ("_mode_eval", "eval"), ("_mode_predict", "predict") + ) + def test_sparse_ff_multiply(self, mode): + d_model = 8 + n_elements_in_block = 2 
+ d_ff = 16 + x_shape = (2, 8, d_model) + temperature = 0.7 + with fastmath.use_backend(fastmath.Backend.JAX): + x = np.ones(x_shape).astype(np.float32) + input_signature = shapes.signature(x) + model = sparsity.SparseFF( + d_ff=d_ff, + n_elements_in_block=n_elements_in_block, + temperature=temperature, + ff_chunk_size=4, + mode=mode, + multiply_by_controller_output=True, + ) + weights, state = model.init(input_signature) + out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + self.assertEqual(out.shape, x.shape) + + def test_sparse_ff_kernel_scaling(self): + d_model = 8 + n_elements_in_block = 2 + d_ff = 16 + x_shape = (2, 8, d_model) + temperature = 0.7 + with fastmath.use_backend(fastmath.Backend.JAX): + x = np.ones(x_shape).astype(np.float32) + input_signature = shapes.signature(x) + model = sparsity.SparseFF( + d_ff=d_ff, + n_elements_in_block=n_elements_in_block, + temperature=temperature, + ff_chunk_size=4, + mode="train", + kernel_scaling=True, + ) + weights, state = model.init(input_signature) + out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + self.assertEqual(out.shape, x.shape) + + def test_switchsparse_ff_train(self): + d_model = 1024 + n_experts = 64 + d_ff = d_model * 8 + x_shape = (3, 7, d_model) + layer = sparsity.SwitchSparseFF(d_ff=d_ff, n_experts=n_experts, mode="train") + x = np.ones(x_shape).astype(np.float32) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_switchsparse_ff_predict_equals_eval(self): + d_model = 1024 + n_experts = 64 + d_ff = d_model * 8 + x_shape = (1, 1, d_model) + x = np.ones(x_shape).astype(np.float32) + input_signature = shapes.signature(x) + eval_model = sparsity.SwitchSparseFF( + mode="eval", d_ff=d_ff, n_experts=n_experts + ) + weights, state = eval_model.init(input_signature) + eval_out, _ = eval_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + pred_model = sparsity.SwitchSparseFF( + mode="predict", d_ff=d_ff, 
n_experts=n_experts + ) + pred_model.init(input_signature) + pred_out, _ = pred_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + self.assertEqual(eval_out.shape, x.shape) + # eval_out and pred_out should be identical. + np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) + + +class ReversibleReshapePermuteTest(test.TestCase): + def test_reversible_permute(self): + layer = sparsity.ReversibleReshapePermute() + x = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7]]) + layer.init(shapes.signature(x)) + ys = layer(x) + self.assertEqual( + tl.to_list(ys), [[1, 3, 5, 7, 2, 4, 6, 8], [0, 2, 4, 6, 1, 3, 5, 7]] + ) + rev_x = layer.reverse(ys, weights=layer.weights) + self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) + + +class ReversibleRandomPermuteTest(test.TestCase): + def test_reversible_permute(self): + layer = sparsity.ReversibleRandomPermute() + x = np.array( + [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], + [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 11, 12, 13], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], + ] + ) + layer.init(shapes.signature(x)) + ys = layer(x) + # this assert will fail once per ~87B runs, but it's okay + self.assertNotEqual(tl.to_list(ys), tl.to_list(x)) + + self.assertEqual(tl.to_list(ys[0]), tl.to_list(ys[2])) + self.assertNotEqual(tl.to_list(ys[0]), tl.to_list(ys[1])) + rev_x = layer.reverse(ys, weights=layer.weights) + self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) + + +class LocallyConnectedDenseTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.LocallyConnectedDense(2, 8) + x = np.array([[2, 5, 3, 4], [0, 1, 2, 3]]) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (2, 16)) + + +class SparseDenseWithOptionsTest(test.TestCase): + def test_simple_call(self): + d_input, d_output = 16, 32 + settings = [ + (None, 0, 0, False), + (None, 0, 0, True), + ("einsum", 0, 0, False), + ("lowrank", 0, 8, False), + ("mult", 2, 0, False), + ("mult", 
2, 0, True), + ("local", 2, 0, False), + ("local3", 2, 0, False), + ] + for stype, sparsity_level, d_lowrank, use_bfloat16 in settings: + layer = sparsity.SparseDenseWithOptions( + d_output, + d_input=d_input, + sparsity_type=stype, + sparsity=sparsity_level, + d_lowrank=d_lowrank, + use_bfloat16=use_bfloat16, + ) + x = np.ones((1, 1, d_input)) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual( + y.shape, + (1, 1, d_output), + msg="[{}->{}] {} - {} - {} - {}".format( + d_input, d_output, stype, sparsity_level, d_lowrank, use_bfloat16 + ), + ) + + +class ModularCausalAttentionTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.ModularCausalAttention(d_feature=4, n_heads=2, sparsity=2) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + +class LowRankCausalAttentionTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.LowRankCausalAttention(d_feature=4, n_heads=2, lowrank=2) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + +class MultiplicativeCausalAttentionTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.MultiplicativeCausalAttention( + d_feature=4, n_heads=2, sparsity=2 + ) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + +class MultiplicativeModularCausalAttentionTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.MultiplicativeModularCausalAttention( + d_feature=4, n_heads=2, sparsity=2 + ) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) 
+ + +class MultiplicativeConvCausalAttentionTest(test.TestCase): + def test_simple_call(self): + layer = sparsity.MultiplicativeConvCausalAttention( + d_feature=4, n_heads=2, sparsity=2 + ) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + def test_various_calls(self): + list_kwargs = [] + for share_qk in [True, False]: + for output in ["none", "mult", "conv", "multconv"]: + for concat in ["original", "fixed", "none"]: + kwargs = { + "share_qk": share_qk, + "output_layer_type": output, + "v_concat_type": concat, + } + list_kwargs.append(kwargs) + for kwargs in list_kwargs: + layer = sparsity.MultiplicativeConvCausalAttention( + d_feature=4, n_heads=2, sparsity=2, **kwargs + ) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + def test_predict_equals_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + d_model = 32 + seq_len = 5 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn = functools.partial( + sparsity.MultiplicativeConvCausalAttention, + d_feature=d_model, + n_heads=4, + sparsity=4, + ) + + list_kwargs = [] + for share_qk in [True, False]: + for output in ["none", "mult", "conv", "multconv"]: + for concat in ["original", "fixed", "none"]: + kwargs = { + "share_qk": share_qk, + "output_layer_type": output, + "v_concat_type": concat, + } + list_kwargs.append(kwargs) + + test_utils.test_eval_equals_predict_configs(inp, model_fn, list_kwargs) + + +class FavorTest(test.TestCase): + def test_call_and_grad(self): + layer_partial = tl.Serial( + tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), + sparsity.Favor(d_feature=4, n_heads=2), + tl.Select([0], n_in=2), + ) + layer = tl.Serial( + tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), + 
sparsity.Favor(d_feature=4, n_heads=2), + tl.Select([0], n_in=2), + tl.WeightedCategoryCrossEntropy(), + ) + x = np.ones((1, 2), dtype=np.int32) + w = np.ones_like(x).astype(np.float32) + x_sig = shapes.signature(x) + w_sig = shapes.signature(w) + layer_partial.init(x_sig) + y = layer_partial(x) + self.assertEqual(y.shape, (1, 2, 4)) + layer.init((x_sig, x_sig, w_sig)) + y = layer((x, x, w)) + self.assertEqual(y.shape, ()) + state = layer.state + rng = fastmath.random.get_prng(0) + fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] + g = fastmath.grad(fwd)(layer.weights, (x, x, w)) + self.assertEqual(g[0][1][0].shape, (3, 4)) + + def test_call_and_grad_approximate_softmax(self): + layer_partial = tl.Serial( + tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), + sparsity.Favor( + d_feature=12, + n_heads=3, + n_random_features=128, + use_approximate_softmax=True, + ), + tl.Select([0], n_in=2), + ) + layer = tl.Serial( + tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), + sparsity.Favor( + d_feature=12, + n_heads=3, + n_random_features=128, + use_approximate_softmax=True, + ), + tl.Select([0], n_in=2), + tl.WeightedCategoryCrossEntropy(), + ) + x = np.ones((3, 5), dtype=np.int32) + w = np.ones_like(x).astype(np.float32) + x_sig = shapes.signature(x) + w_sig = shapes.signature(w) + layer_partial.init(x_sig) + y = layer_partial(x) + self.assertEqual(y.shape, (3, 5, 12)) + layer.init((x_sig, x_sig, w_sig)) + y = layer((x, x, w)) + self.assertEqual(y.shape, ()) + state = layer.state + rng = fastmath.random.get_prng(0) + fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] + g = fastmath.grad(fwd)(layer.weights, (x, x, w)) + self.assertEqual(g[0][1][0].shape, (11, 12)) + + def test_causal_call_and_grad(self): + layer = tl.Serial( + tl.Dense(4), sparsity.CausalFavor(d_feature=4, n_heads=2), tl.L2Loss() + ) + x = np.random.uniform(size=(1, 2, 4)).astype(np.float32) + w = np.ones_like(x) + x_sig = shapes.signature(x) + w_sig = 
shapes.signature(w) + layer.init((x_sig, x_sig, w_sig)) + y = layer((x, x, w)) + self.assertEqual(y.shape, ()) + state = layer.state + rng = fastmath.random.get_prng(0) + fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] + g = fastmath.grad(fwd)(layer.weights, (x, x, w)) + self.assertEqual(g[0][0].shape, (4, 4)) + + +if __name__ == "__main__": + test.main() diff --git a/trax/layers/reversible_test.py b/tests/layers/reversible_test.py similarity index 68% rename from trax/layers/reversible_test.py rename to tests/layers/reversible_test.py index 14fb67eaf..8ca45727b 100644 --- a/trax/layers/reversible_test.py +++ b/tests/layers/reversible_test.py @@ -27,15 +27,14 @@ class ReversibleLayerTest(parameterized.TestCase): + @parameterized.named_parameters([("_" + b.value, b) for b in BACKENDS]) + def test_reversible_swap(self, backend): + with fastmath.use_backend(backend): + layer = tl.ReversibleSwap() + xs = [np.array([1, 2]), np.array([10, 20])] + ys = layer(xs) + self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]]) - @parameterized.named_parameters([('_' + b.value, b) for b in BACKENDS]) - def test_reversible_swap(self, backend): - with fastmath.use_backend(backend): - layer = tl.ReversibleSwap() - xs = [np.array([1, 2]), np.array([10, 20])] - ys = layer(xs) - self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]]) - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/rnn_test.py b/tests/layers/rnn_test.py new file mode 100644 index 000000000..9b8bc8a6c --- /dev/null +++ b/tests/layers/rnn_test.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for rnn layers.""" + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import trax.layers as tl +from trax import fastmath +from trax import shapes + +BACKENDS = [fastmath.Backend.JAX] + + +@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS) +class RnnTest(parameterized.TestCase): + def test_conv_gru_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.ConvGRUCell(9, kernel_size=(3, 3)) + x = np.ones((8, 1, 7, 9)) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_gru_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.GRUCell(9) + xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))] + _, _ = layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)]) + + def test_lstm_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.LSTMCell(9) + xs = [np.ones((8, 9)), np.ones((8, 18))] + _, _ = layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)]) + + def test_sru(self, backend): + with fastmath.use_backend(backend): + layer = tl.SRU(7) + x = np.ones((8, 9, 7), np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_names(self, backend): + with fastmath.use_backend(backend): + layer = tl.LSTM(3) + self.assertEqual("LSTM_3", str(layer)) + layer = tl.GRU(5) + self.assertEqual("GRU_5", str(layer)) + layer = tl.SRU(7) + 
self.assertEqual("SRU_7", str(layer)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/test_utils.py b/tests/layers/test_utils.py new file mode 100644 index 000000000..2b93f4431 --- /dev/null +++ b/tests/layers/test_utils.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for testing.""" + +import copy +import functools +from absl.testing import absltest + +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes + +import pytest + + +# Deliberately not wrapped in absltest.skip: that wrapper raises SkipTest on every call. +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_is_deterministic(inp, model_fn, message=""): + """Utility method for testing if eval mode is deterministic. + + Args: + inp: input fed to the model. It can be a tensor, or a tuple of tensors. + model_fn: function creating a model after calling with `mode` argument. + message: Optional message to show when outputs of eval/predict mode don't + match. 
+ """ + with fastmath.use_backend(fastmath.Backend.JAX): + model_eval1 = model_fn(mode="eval") + model_eval2 = model_fn(mode="eval") + + input_signature = shapes.signature(inp) + model_eval1.init(input_signature) + model_eval2.init(input_signature) + model_eval1.save_to_file("/tmp/unique_weights") + model_eval2.init_from_file( + "/tmp/unique_weights", weights_only=True, input_signature=input_signature + ) + + rng = fastmath.random.get_prng(0) + output_eval1 = model_eval1(inp, rng=rng) + if not isinstance(output_eval1, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval1 = [output_eval1] + + output_eval2 = model_eval2(inp, rng=rng) + if not isinstance(output_eval2, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval2 = [output_eval2] + + np.testing.assert_equal(len(output_eval1), len(output_eval2)) + for out1, out2 in zip(output_eval1, output_eval2): + np.testing.assert_array_almost_equal( + out1, out2, decimal=5, err_msg="Non-deterministic.{}".format(message) + ) + + +# Deliberately not wrapped in absltest.skip: that wrapper raises SkipTest on every call. +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_equals_predict( + inp, model_fn, seq_axis=1, seq_tensor=None, init_tokens=3, message="" +): + """Utility method for testing equivalence of predict and eval modes. + + Args: + inp: input fed to the model. It can be a tensor, or a tuple of tensors. + model_fn: function creating a model after calling with `mode` argument. + seq_axis: axis of sequence_length. In predict mode we iterate over this + axis. By default `1`, which is 2nd dimension. + seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor + in this tuple on which we iterate the sequence. + init_tokens: how many tokens should be passed to the first `predict` call. + message: Optional message to show when outputs of eval/predict mode don't + match. 
+ """ + with fastmath.use_backend(fastmath.Backend.JAX): + model_eval = model_fn(mode="eval") + model_predict = model_fn(mode="predict") + + input_signature = shapes.signature(inp) + model_eval.init(input_signature) + model_predict.init(input_signature) + model_eval.save_to_file("/tmp/unique_weights") + model_predict.init_from_file( + "/tmp/unique_weights", weights_only=True, input_signature=input_signature + ) + + rng = fastmath.random.get_prng(0) + output_eval = model_eval(inp, rng=rng) + if not isinstance(output_eval, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval = [output_eval] + + if seq_tensor is None: + length = inp.shape[seq_axis] + else: + length = inp[seq_tensor].shape[seq_axis] + + assert length >= init_tokens + 2 # Required to properly test predict mode. + indices_list = [(0, init_tokens)] + [ + (i, i + 1) for i in range(init_tokens, length) + ] + + for indices in indices_list: + start, end = indices + if seq_tensor is None: + new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis) + else: + new_inp = list(inp) + new_inp[seq_tensor] = new_inp[seq_tensor].take( + indices=np.arange(start, end), axis=seq_axis + ) + + output_predict = model_predict(new_inp, rng=rng) + if not isinstance(output_predict, (tuple, list)): + # We will automatically check each and every tensor returned. 
+ output_predict = [output_predict] + + np.testing.assert_equal(len(output_predict), len(output_eval)) + for outp, oute in zip(output_predict, output_eval): + np.testing.assert_array_almost_equal( + oute.take(indices=np.arange(start, end), axis=seq_axis), + outp.take(indices=np.arange(0, end - start), axis=seq_axis), + decimal=5, + err_msg="Error on element {} out of {}.{}".format( + indices, length, message + ), + ) + + +# Deliberately not wrapped in absltest.skip: that wrapper raises SkipTest on every call. +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_equals_predict_configs( + inp, model_fn, configs, seq_axis=1, seq_tensor=None, message="" +): + """Utility method for testing equivalence of predict and eval modes. + + This function iterates over a list of dictionaries `configs`, and runs the test + on models with each configuration. + + Args: + inp: input fed to the model. It can be a tensor, or a tuple of tensors. + model_fn: function creating a model after calling with `mode` argument. + configs: List of dictionaries, which contain configs to be fed into + `model_fn`. + seq_axis: axis of sequence_length. In predict mode we iterate over this + axis. By default `1`, which is 2nd dimension. + seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor + in this tuple on which we iterate the sequence. + message: Optional message to show when outputs of eval/predict mode don't + match. 
+ """ + for config in configs: + model_fn_configured = functools.partial(model_fn, **config) + test_eval_equals_predict( + inp, + model_fn_configured, + seq_axis=seq_axis, + seq_tensor=seq_tensor, + message=" Config: {}.{}".format(config, message), + ) + + +# Deliberately not wrapped in absltest.skip: that wrapper raises SkipTest on every call. +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_equals_predict_discrete(model_fn, vocab_size=10, length=5, batch_size=3): + """Tests the equivalence of eval and predict modes for discrete models.""" + with fastmath.use_backend(fastmath.Backend.JAX): + model_slow = model_fn(mode="eval", vocab_size=vocab_size) + model_fast = model_fn(mode="predict", vocab_size=vocab_size) + rng = fastmath.random.get_prng(0) + input_signature = shapes.ShapeDtype((batch_size, 1), np.int32) + # Given the same rng, both models initialize with the same parameters. + model_slow.init(input_signature, rng) + model_fast.init(input_signature, rng) + + buf = np.zeros((batch_size, length), dtype=np.int32) + next_sym = np.zeros((batch_size, 1), dtype=np.int32) + + for index in range(length): + logits_slow = model_slow(buf, rng=rng) + logits_fast = model_fast(next_sym, rng=rng) + np.testing.assert_array_almost_equal( + logits_slow[:, index, :], + logits_fast[:, 0, :], + decimal=5, + ) + next_sym = np.random.randint(vocab_size, size=(batch_size, 1)) + buf[:, index] = next_sym[:, 0] + + +class MockTransformerLM(tl.Layer): + r"""Mock TransformerLM for testing autoregressive sampling routines. + + Mimics the behavior of a perfectly-trained, deterministic TransformerLM. + Allows to specify the \sigma^* -> \sigma function implemented by the model + and to make assertions about the input sequence passed to the model. + + Supports two modes: stateful "predict" for fast inference, and stateless + non-"predict" ("train", "eval" etc). 
+ + Useful for testing any logic that relies on autoregressive sampling, as it + removes the additional layer of complexity related to training a model or + maintaining a pretrained one. Makes the tests run MUCH faster. + + Does not support acceleration. Do not wrap in tl.Accelerate(). + """ + + def __init__(self, sequence_fn, mode, vocab_size): + super().__init__() + + self._sequence_fn = sequence_fn + self._mode = mode + self._vocab_size = vocab_size + + self._prediction_buffers = None + + @property + def state(self): + return copy.deepcopy(self._prediction_buffers) + + @state.setter + def state(self, state): + self._prediction_buffers = copy.deepcopy(state) + + def _output_symbol_predict(self, input_symbols, prediction_buffer): + prediction_buffer.extend(input_symbols) + output_symbol = self._sequence_fn(np.array(prediction_buffer)) + return np.array([output_symbol]) + + def _output_symbols_eval(self, input_symbols, prediction_buffer): + del prediction_buffer + + # Add a leading 0 token to imitate ShiftRight. + input_symbols = np.concatenate(([0], input_symbols)) + + # Call sequence_fn repeatedly along the input sequence. + return np.array( + [ + self._sequence_fn(input_symbols[:end]) + for end in range(1, len(input_symbols)) + ] + ) + + def _symbols_to_logits(self, symbols): + # Assert that symbols are discrete. + assert np.issubdtype(symbols.dtype, np.integer) + # Assert that 0 <= symbols < vocab_size. + np.testing.assert_array_less(-1, symbols) + np.testing.assert_array_less(symbols, self._vocab_size) + + # Return almost-deterministic logits: + # e^1000 / (e^1000 + vocab_size) ~= 1 + return tl.one_hot(symbols, n_categories=self._vocab_size) * 1000.0 + + def __call__(self, inputs, rng=None): + del rng + + assert inputs.ndim == 2, "The input sequences should have exactly two axes." + + if self._prediction_buffers is None: + # Initialize the buffer. 
+ batch_size = inputs.shape[0] + # [[]] * batch_size would create multiple references to the same + # list, and we want separate lists. + self._prediction_buffers = [[] for _ in range(batch_size)] + + if self._mode == "predict": + output_fn = self._output_symbol_predict + else: + output_fn = self._output_symbols_eval + + # Calculate the output separately for each sequence in the batch. + output_symbols = np.array( + [ + output_fn(input_seq, pred_buffer) + for (input_seq, pred_buffer) in zip(inputs, self._prediction_buffers) + ] + ) + return self._symbols_to_logits(output_symbols) + + def assert_prediction_buffers_equal(self, expected_buffers): + if self._prediction_buffers is None: + batch_size = expected_buffers.shape[0] + actual_buffers = np.empty((batch_size, 0)) + else: + actual_buffers = np.array(self._prediction_buffers) + + np.testing.assert_array_equal(actual_buffers, expected_buffers) diff --git a/tests/layers/test_utils_test.py b/tests/layers/test_utils_test.py new file mode 100644 index 000000000..8d8c07299 --- /dev/null +++ b/tests/layers/test_utils_test.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.test_utils.""" + +import functools + +from absl.testing import absltest +import numpy as np + +from tests.layers import test_utils +from trax.supervised import decoding + + +def arithmetic_sequence(input_seq, limit=10): + # Increment the last symbol. 
Wrap to [0, 10). + return (input_seq[-1] + 1) % limit + + +class TestUtilsTest(absltest.TestCase): + def test_mock_transformer_lm_eval_equals_predict(self): + model_fn = functools.partial( + test_utils.MockTransformerLM, + sequence_fn=arithmetic_sequence, + vocab_size=10, + ) + test_utils.test_eval_equals_predict_discrete(model_fn, vocab_size=10) + + def test_mock_transformer_lm_decodes_arithmetic_sequence(self): + model = test_utils.MockTransformerLM( + sequence_fn=arithmetic_sequence, + vocab_size=10, + mode="predict", + ) + output = decoding.autoregressive_sample( + model, max_length=5, start_id=0, eos_id=-1, accelerate=False + ) + + # Sequence including the leading 0 and the last predicted symbol. + full_seq = list(range(6)) + # decoding.autoregressive_sample doesn't return the leading 0. + np.testing.assert_array_equal(output, [full_seq[1:]]) + # The prediction buffers don't include the last predicted symbol. + model.assert_prediction_buffers_equal([full_seq[:-1]]) + + def test_mock_transformer_lm_rewinds(self): + model = test_utils.MockTransformerLM( + sequence_fn=arithmetic_sequence, + vocab_size=10, + mode="predict", + ) + sample_3 = functools.partial( + decoding.autoregressive_sample, + max_length=3, + eos_id=-1, + accelerate=False, + ) + + # Generate the 3 initial symbols. + init_output = sample_3(model, start_id=0) + np.testing.assert_array_equal(init_output, [[1, 2, 3]]) + state = model.state + + # Generate the next 3 symbols. + next_output = sample_3(model, start_id=init_output[0, -1]) + np.testing.assert_array_equal(next_output, [[4, 5, 6]]) + + # Rewind and generate the last 3 symbols again. + model.state = state + next_output = sample_3(model, start_id=init_output[0, -1]) + np.testing.assert_array_equal(next_output, [[4, 5, 6]]) + + # Check the buffers. 
+ model.assert_prediction_buffers_equal([[0, 1, 2, 3, 4, 5]]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/atari_cnn_test.py b/tests/models/atari_cnn_test.py new file mode 100644 index 000000000..95ce8b034 --- /dev/null +++ b/tests/models/atari_cnn_test.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.models.atari_cnn.""" + +import functools +import operator as op + +import numpy as np +from tensorflow import test + +from trax.models import atari_cnn +from trax.shapes import ShapeDtype + + +class AtariCnnTest(test.TestCase): + def test_computes(self): + hidden_size = (4, 4) + output_size = 6 + + model = atari_cnn.AtariCnn(hidden_sizes=hidden_size, output_size=output_size) + + B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name + input_signature = ShapeDtype((1, 1) + OBS) + + _, _ = model.init(input_signature) + x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape( + B, T + 1, *OBS + ) + y = model(x) + self.assertEqual((B, T + 1, output_size), y.shape) + + +class FrameStackMLPTest(test.TestCase): + def test_computes(self): + hidden_size = (4, 4) + output_size = 6 + model = atari_cnn.FrameStackMLP( + hidden_sizes=hidden_size, output_size=output_size + ) + B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name + input_signature = ShapeDtype((1, 1, OBS)) + + _, _ = model.init(input_signature) + x = np.arange(B * (T + 1) * 
OBS).reshape(B, T + 1, OBS) + y = model(x) + + self.assertEqual((B, T + 1, output_size), y.shape) + + +if __name__ == "__main__": + test.main() diff --git a/trax/models/mlp_test.py b/tests/models/mlp_test.py similarity index 71% rename from trax/models/mlp_test.py rename to tests/models/mlp_test.py index 40d335610..befa9371f 100644 --- a/trax/models/mlp_test.py +++ b/tests/models/mlp_test.py @@ -15,24 +15,21 @@ """Tests for MLP.""" -from absl.testing import absltest import numpy as np +from absl.testing import absltest -from trax import fastmath from trax import shapes from trax.models import mlp class MLPTest(absltest.TestCase): - - def test_mlp_forward_shape(self): - model = mlp.MLP(layer_widths=(32, 16, 8)) - x = np.ones((7, 28, 28, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (7, 8)) - + def test_mlp_forward_shape(self): + model = mlp.MLP(layer_widths=(32, 16, 8)) + x = np.ones((7, 28, 28, 3)).astype(np.float32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (7, 8)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/neural_gpu_test.py b/tests/models/neural_gpu_test.py similarity index 71% rename from trax/models/neural_gpu_test.py rename to tests/models/neural_gpu_test.py index 0eaa77dbf..f2919413d 100644 --- a/trax/models/neural_gpu_test.py +++ b/tests/models/neural_gpu_test.py @@ -23,14 +23,13 @@ class NeuralGPUTest(absltest.TestCase): + def test_ngpu(self): + model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) + x = np.ones((3, 5, 7)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, 7, 22)) - def test_ngpu(self): - model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) - x = np.ones((3, 5, 7)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, 7, 22)) - 
-if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/reformer/reformer_e2e_test.py b/tests/models/reformer/reformer_e2e_test.py new file mode 100644 index 000000000..a6fe371c1 --- /dev/null +++ b/tests/models/reformer/reformer_e2e_test.py @@ -0,0 +1,81 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End to end test for Reformer.""" + +import os + +import gin +from absl.testing import absltest + +from trax import test_utils +from trax.supervised import trainer_lib + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath( + os.path.join(pkg_dir, "../../../resources/models/reformer/testdata") +) +_CONFIG_DIR = os.path.normpath( + os.path.join(pkg_dir, "../../../resources/supervised/configs/") +) + + +class ReformerE2ETest(absltest.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + gin.add_config_file_search_path(_CONFIG_DIR) + test_utils.ensure_flag("test_tmpdir") + + def test_reformer_wmt_ende(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + + gin.parse_config_file("reformer_wmt_ende.gin") + + gin.bind_parameter("data_streams.data_dir", _TESTDATA) + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("Reformer.n_encoder_layers", n_layers) + gin.bind_parameter("Reformer.n_decoder_layers", n_layers) + 
gin.bind_parameter("Reformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + def test_reformer_copy(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + d_model = 32 + + gin.parse_config_file("reformer_copy.gin") + + gin.bind_parameter("data_streams.data_dir", _TESTDATA) + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ReformerLM.n_layers", n_layers) + gin.bind_parameter("ReformerLM.d_ff", d_ff) + gin.bind_parameter("ReformerLM.d_model", d_model) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/reformer/reformer_test.py b/tests/models/reformer/reformer_test.py new file mode 100644 index 000000000..485cbb853 --- /dev/null +++ b/tests/models/reformer/reformer_test.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Reformer models.""" + +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import gin +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes +from trax.models.reformer import reformer + + +BACKENDS = [fastmath.Backend.JAX] + + +def short_name(b): + if b == fastmath.Backend.JAX: + return "jax" + else: + return "tf" + + +class ReformerTest(parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=64, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + ) + + def test_reformer_lm_forward_shape(self): + vocab_size = 16 + model = reformer.ReformerLM( + vocab_size, + d_model=32, + d_ff=64, + d_attention_key=16, + d_attention_value=16, + n_layers=1, + n_heads=2, + max_len=16, + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + @absltest.skip + def test_reformer_lm_lsh(self): + """ + Problems with: + - res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)), + probably dropout_shared_axes should be [] + - Scan in Chunk res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) shape assertion is wrong + """ + lsh_self_attention = self._lsh_self_attention_fn() + timebin_self_attention = self._timebin_self_attention_fn() + + model = reformer.ReformerLM( + vocab_size=256, + d_model=256, + d_ff=512, + 
d_attention_key=64, + d_attention_value=64, + n_layers=2, + n_heads=2, + dropout=0.05, + max_len=65536, + attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(256, 256), + pos_d_axial_embs=(64, 192), + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=8192, + mode="train", + ) + x = (np.ones((1, 65536)).astype(np.int32), np.ones((1, 65536)).astype(np.int32)) + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits, new_state = model.pure_fn(x, weights, state, rng) + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + self.assertEqual(logits.shape, (1, 65536, 256)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/configurable_transformer_test.py b/tests/models/research/configurable_transformer_test.py new file mode 100644 index 000000000..ca57761d8 --- /dev/null +++ b/tests/models/research/configurable_transformer_test.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Transformer models.""" + +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes +from tests.layers import test_utils +from trax.models.research import configurable_transformer as ct + + +class ConfigurableTransformerTest(parameterized.TestCase): + def test_transformer_lm_forward_shape(self): + vocab_size = 16 + model = ct.ConfigurableTransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2 + ) + x = np.ones((3, 5)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, vocab_size)) + + def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + model = ct.ConfigurableTransformer( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + y, _ = model(xs) + + vocab_size = output_vocab_size or input_vocab_size + self.assertEqual(y.shape, (3, 5, vocab_size)) + + @parameterized.named_parameters( + ("same_vocab", 16, None), ("same_size", 16, 16), ("different_size", 16, 50) + ) + def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + """Run the Transformer forward and check output shape.""" + self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) + + def test_dot_product_causal_attention_fast_inference(self): + self._test_fast_inference(length=5) + + def _test_fast_inference(self, length): + with fastmath.use_backend(fastmath.Backend.JAX): + model_fn = functools.partial( + ct.ConfigurableTransformerLM, + vocab_size=16, + d_model=4, + d_ff=8, + n_layers=2, + n_heads=2, + ) + batch_size = 2 + inp = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict(inp, model_fn) 
+ + def test_sparse_configurable_transformer_fast_inference(self): + self._test_sparse_fast_inference(length=5) + + def _test_sparse_fast_inference(self, length): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 4 + batch_size = 2 + + encoder_decoder_attention_type = functools.partial( + tl.MultiplicativeConvCausalAttention, + sparsity=2, + length_kernel_size=1, + ) + + model_fn = functools.partial( + ct.ConfigurableTransformer, + input_vocab_size=vocab_size, + d_model=d_model, + d_ff=8, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + loss_sparsity=2, + ff_sparsity=2, + encoder_decoder_attention_type=encoder_decoder_attention_type, + ff_use_sru=(1, 4), + ) + + inp = np.random.randint(vocab_size, size=(batch_size, length)) + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict((inp, out), model_fn, seq_tensor=1) + + @parameterized.named_parameters( + ("positional_encoding", None), + ("fixed_base_positional_encoding", "fixed-base"), + ("infinite_positional_encoding", "infinite"), + ("infinite_affine_positional_encoding", "infinite-affine"), + ("axial_positional_encoding", (2, 16)), + ) + def test_positional_encoder(self, pos_axial_shape): + # dim should divide FixedBasePositionalEncoding.n_digits + batch, length, dim = 2, 32, 8 + input_shape = (batch, length, dim) + vocab_size = 32 + x = np.random.randint(0, vocab_size - 1, input_shape) + # should sum to dim + pos_d_axial_embs = (4, 4) + + positional_encoding = ct.PositionalEncoder( + "train", + dropout=0.1, + max_len=length, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + _, _ = positional_encoding.init(shapes.signature(x)) + y = positional_encoding(x) + self.assertEqual(y.shape, input_shape) + + @parameterized.named_parameters( + ("input_vocab_size_only", 32, None), + ("output_vocab_size_only", None, 32), + ("same_input_output_vocab_size", 32, 32), + ("different_input_output_vocab_size", 32, 16), + ) + 
def test_embedding_and_positional_encodings( + self, input_vocab_size, output_vocab_size + ): + d_model = 16 + max_len = 32 + batch = 2 + input_shape = (batch, max_len) + output_vocab_size_expected = output_vocab_size or input_vocab_size + x_out = np.random.randint(0, output_vocab_size_expected - 1, input_shape) + if input_vocab_size is None: + x_in = np.random.uniform(size=list(input_shape) + [2]) + else: + x_in = np.random.randint(0, input_vocab_size - 1, input_shape) + + ( + in_encoder, + out_encoder, + output_vocab_size_result, + ) = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + "train", + 0.1, + [-2], + max_len, + output_vocab_size=output_vocab_size, + pos_axial_shape=None, + pos_d_axial_embs=None, + ) + + self.assertEqual(output_vocab_size_result, output_vocab_size_expected) + + model_in = tl.Serial(in_encoder) + model_out = tl.Serial(out_encoder) + + model_in.init(shapes.signature(x_in)) + model_out.init(shapes.signature(x_out)) + + y = model_in(x_in) + self.assertEqual(y.shape, input_shape + (d_model,)) + + y = model_out(x_out) + self.assertEqual(y.shape, input_shape + (d_model,)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/hourglass_test.py b/tests/models/research/hourglass_test.py new file mode 100644 index 000000000..09e2381c7 --- /dev/null +++ b/tests/models/research/hourglass_test.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Hourglass model.""" + +from absl.testing import absltest +from absl.testing import parameterized +import gin +import jax +import numpy as np +from trax import fastmath +from trax import layers as tl +from trax import shapes +import trax.layers.research.resampling as resampling +import trax.models.research.hourglass as hourglass + + +class HourglassTest(parameterized.TestCase): + def _check_forward_shape(self, model, input_shape, output_vocab_size): + x = np.ones(input_shape).astype(np.int32) + model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (*input_shape, output_vocab_size)) + + def test_hourglass_lm_forward_shape(self): + d_model = 16 + vocab_size = 7 + model = hourglass.HourglassLM( + vocab_size, + hierarchy="2@3 2@6 2@3", + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + batch_size, seq_len = 3, 24 + self._check_forward_shape( + model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size + ) + + def test_lsh_attention_in_vanilla(self): + d_model = 16 + vocab_size = 7 + + gin.bind_parameter( + "PureLSHSelfAttentionWrapper.pure_lsh_implementation", + tl.PureLSHSelfAttention, + ) + gin.bind_parameter("PureLSHSelfAttention.chunk_len", 2) + + model = hourglass.HourglassLM( + vocab_size, + hierarchy="2@3", + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + vanilla_attn_type=tl.PureLSHSelfAttentionWrapper, + downsampling_fn=resampling.LinearPooling, + upsampling_fn=resampling.LinearUpsampling, + ) + + batch_size, seq_len = 3, 12 + self._check_forward_shape( + model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size + ) + + def _test_autoregressive_property(self, model, input_shape, output_vocab_size): + rng_1 = jax.random.PRNGKey(0) + rng_2 = jax.random.PRNGKey(1) + + def _get_output_logits(unitialized_eval_model: tl.Layer, x): + input_signature 
= shapes.signature(x) + unitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False) + + output_logits, *_ = unitialized_eval_model(x, rng=rng_1) + return output_logits + + def check_autoregressive_property(model): + with fastmath.use_backend(fastmath.Backend.JAX): + x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size) + y_1 = _get_output_logits(model, x_1) + + x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size) + + for i in range(input_shape[1]): + masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1) + + y_2 = _get_output_logits(model, masked_x_2) + self.assertEqual(y_2.shape[0], input_shape[1]) + np.testing.assert_array_almost_equal(y_1[: i + 1], y_2[: i + 1]) + + check_autoregressive_property(model) + + def test_hourglass_lm_autoregressive_property(self): + d_model = 8 + vocab_size = 26 + + model_single_stage = hourglass.HourglassLM( + vocab_size, + hierarchy="2@4", + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + model_multi_stage = hourglass.HourglassLM( + vocab_size, + hierarchy="2@3 2@6 2@3", + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + input_shape = (1, 12) + self._test_autoregressive_property( + model_single_stage, input_shape, output_vocab_size=vocab_size + ) + self._test_autoregressive_property( + model_multi_stage, input_shape, output_vocab_size=vocab_size + ) + + def test_parse_hourglass_hierarchy(self): + self.assertEqual(hourglass._parse_hierarchy("6@3"), ([6], [3])) + self.assertEqual( + hourglass._parse_hierarchy("3@2 2@6 5@24 2@6 3@2"), ([3, 2, 5], [2, 3, 4]) + ) + self.assertRaises(ValueError, hourglass._parse_hierarchy, "1@2 2@3 1@2") + self.assertRaises(ValueError, hourglass._parse_hierarchy, "1@2 2@3") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/layerdrop_transformer_test.py b/tests/models/research/layerdrop_transformer_test.py new file mode 100644 index 000000000..af34d22c4 --- 
/dev/null +++ b/tests/models/research/layerdrop_transformer_test.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Reformer models.""" + +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.models.research import layerdrop_transformer + + +class SkippingTransformerTest(absltest.TestCase): + def test_skipping_transformer_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.SkippingTransformerLM( + vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class LayerDropTransformerTest(absltest.TestCase): + def test_layerdrop_transformer_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.LayerDropTransformerLM( + vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + def test_layerdrop_layerwise_skip_fraction(self): + """Tests that the 
forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.LayerDropTransformerLM( + vocab_size, + d_model=16, + d_ff=32, + n_layers=2, + n_heads=2, + max_len=16, + skip_fraction=[0.2, 0.8], + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class EveryOtherLayerDropTransformerTest(absltest.TestCase): + def test_everyother_layerdrop_transformer_forward(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.EveryOtherLayerDropTransformerLM( + vocab_size, + d_model=16, + d_ff=32, + n_layers=2, + n_heads=2, + max_len=16, + skip_mode="1half", + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/rezero_test.py b/tests/models/research/rezero_test.py new file mode 100644 index 000000000..f935d9eb8 --- /dev/null +++ b/tests/models/research/rezero_test.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ReZero models.""" + +from absl.testing import absltest +import numpy as np + +from trax import layers as tl +from trax import shapes +from trax.models.research import rezero + + +class ResidualZeroTest(absltest.TestCase): + def test_residual_layer_forward(self): + """Tests that the forward pass runs and returns the expected shape.""" + model = rezero.ResidualZero(tl.Dense(5)) + x = [np.arange(5).astype(np.float32)] + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0, 3.0, 4.0]) + + +class ReZeroTransformerLMTest(absltest.TestCase): + def test_rezero_lm_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = rezero.ReZeroTransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class ReZeroTransformerTest(absltest.TestCase): + def test_rezero_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = rezero.ReZeroTransformer( + vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + max_len=16, + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/rse_test.py b/tests/models/research/rse_test.py new file mode 100644 index 000000000..4e79deb1e --- /dev/null +++ b/tests/models/research/rse_test.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Residual Shuffle-Exchange Networks.""" + +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.models.research import rse + + +class RSETest(absltest.TestCase): + def test_rsu_forward_shape(self): + batch_size = 3 + seq_len = 32 + d_model = 17 + model = rse.ResidualSwitchUnit(d_model=d_model, dropout=0.1, mode="train") + x = np.ones((batch_size, seq_len, d_model)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (batch_size, seq_len, d_model)) + + def test_shuffle_layer(self): + shuffle_layer = rse.ShuffleLayer() + x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) + print(x.shape) + _, _ = shuffle_layer.init(shapes.signature(x)) + y = shuffle_layer(x) + expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]]) + self._assert_equal_tensors(y, expected_output) + + def test_shuffle_layer_log_times_is_identity(self): + seq_len = 8 + d_model = 17 + shuffle_layer = rse.ShuffleLayer() + x = _input_with_indice_as_values(seq_len, d_model) + _, _ = shuffle_layer.init(shapes.signature(x)) + y = x + for _ in range(int(np.log2(seq_len))): + y = shuffle_layer(y) + self._assert_equal_tensors(x, y) + + def test_reverse_shuffle_layer(self): + reverse_shuffle_layer = rse.ReverseShuffleLayer() + x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) + print(x.shape) + _, _ = 
reverse_shuffle_layer.init(shapes.signature(x)) + y = reverse_shuffle_layer(x) + expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]]) + self._assert_equal_tensors(y, expected_output) + + def test_reverse_shuffle_layer_log_times_is_identity(self): + seq_len = 8 + d_model = 17 + reverse_shuffle_layer = rse.ReverseShuffleLayer() + x = _input_with_indice_as_values(seq_len, d_model) + _, _ = reverse_shuffle_layer.init(shapes.signature(x)) + y = x + for _ in range(int(np.log2(seq_len))): + y = reverse_shuffle_layer(y) + self._assert_equal_tensors(x, y) + + def test_rse_forward_shape(self): + vocab_size = 12 + seq_len = 32 + model = rse.ResidualShuffleExchange( + vocab_size=vocab_size, + d_model=17, + dropout=0.1, + input_dropout=0.05, + mode="train", + ) + x = np.ones((3, seq_len)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, seq_len, vocab_size)) + + def _assert_equal_tensors(self, x, y): + self.assertEqual(y.shape, x.shape) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + self.assertEqual( + x[i][j][k], + y[i][j][k], + f"Tensors differ on index [{i}][{j}][{k}].", + ) + + +def _input_with_indice_as_values(length, dim): + """Retuns np.array of size (1, length, dim) where x[0, a, b] = a.""" + positions = [] + for i in range(length): + positions.append([i] * dim) + positions_input = np.array(positions) + positions_input = np.expand_dims(positions_input, axis=0) + return positions_input + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/terraformer_e2e_test.py b/tests/models/research/terraformer_e2e_test.py new file mode 100644 index 000000000..70417fa0f --- /dev/null +++ b/tests/models/research/terraformer_e2e_test.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End to end test for Reformer.""" + +import os + +import gin +from absl.testing import absltest + + +from trax import test_utils +from trax.supervised import trainer_lib + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath( + os.path.join(pkg_dir, "../../../resources/models/reformer/testdata") +) +_CONFIG_DIR = os.path.normpath( + os.path.join(pkg_dir, "../../../resources/supervised/configs") +) + + +class TerraformerE2ETest(absltest.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + gin.clear_config() + gin.add_config_file_search_path(_CONFIG_DIR) + + def test_terraformer_wmt_ende(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + + gin.parse_config_file("terraformer_wmt_ende.gin") + + gin.bind_parameter("data_streams.data_dir", _TESTDATA) + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter( + "batcher.buckets", ([512], [batch_size_per_device, batch_size_per_device]) + ) + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + def test_terraformer_copy(self): + batch_size_per_device = 2 + steps = 1 + 
n_layers = 2 + d_ff = 32 + + gin.parse_config_file("terraformer_copy.gin") + + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("batcher.buckets", ([64], [1, 1])) # batch size 1. + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + def test_terraformer_purelsh_copy(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + + gin.parse_config_file("terraformer_purelsh_copy.gin") + + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("batcher.buckets", ([64], [1, 1])) # batch size 1. + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/research/terraformer_memory_test.py b/tests/models/research/terraformer_memory_test.py similarity index 83% rename from trax/models/research/terraformer_memory_test.py rename to tests/models/research/terraformer_memory_test.py index 8c4a78601..cf7750c2b 100644 --- a/trax/models/research/terraformer_memory_test.py +++ b/tests/models/research/terraformer_memory_test.py @@ -23,14 +23,10 @@ from absl.testing import absltest - class TerraformerMemoryTest(absltest.TestCase): + def test_terraformer_memory(self): + pass # TODO(jonni): Figure out an OSS-compatible memory test. - def test_terraformer_memory(self): - pass # TODO(jonni): Figure out an OSS-compatible memory test. 
- - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/terraformer_oom_test.py b/tests/models/research/terraformer_oom_test.py new file mode 100644 index 000000000..3434d2710 --- /dev/null +++ b/tests/models/research/terraformer_oom_test.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for OOM for Terraformer .""" + +import functools +import operator + +from absl.testing import absltest +import gin +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes +from trax.models.research import terraformer + + +class TerraformerOOMTest(absltest.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def test_terraformer_one_step(self): + d_model = 1024 + vocab_size = 14041 + max_len = 16384 + pos_axial = (128, 128) # should multiply to max_len + pos_d_axial_embs = (512, 512) # sum to d model + + assert operator.mul(*pos_axial) == max_len + assert sum(pos_d_axial_embs) == d_model + + d_ff = 4096 + n_heads = 8 + d_attn = d_model // n_heads + + n_buckets 
= 128 + encoder_chunk_len = (2 * max_len) // n_buckets # 256 + decoder_chunk_len = 2 * encoder_chunk_len # 512 + encoder_n_chunks_after = 1 # since its not causal. + + lsh_self_attention = functools.partial( + self._lsh_self_attention_fn(), n_buckets=n_buckets + ) + + encoder_lsh_self_attention = functools.partial( + lsh_self_attention, + n_chunks_after=encoder_n_chunks_after, + chunk_len=encoder_chunk_len, + ) + + decoder_lsh_self_attention = functools.partial( + lsh_self_attention, n_chunks_after=0, chunk_len=decoder_chunk_len + ) + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=d_model, + d_ff=d_ff, + d_attention_key=d_attn, + d_attention_value=d_attn, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=n_heads, + dropout=0.05, + max_len=max_len, + encoder_attention_type=encoder_lsh_self_attention, + encoder_decoder_attention_type=decoder_lsh_self_attention, + pos_axial_shape=pos_axial, + pos_d_axial_embs=pos_d_axial_embs, + ff_activation=tl.Relu, + ff_use_sru=0, + mode="train", + ) + + def random_sentence(): + return np.random.randint( + low=1, high=vocab_size - 1, size=(1, max_len), dtype=np.int32 + ) + + x = [random_sentence(), random_sentence()] + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) + # This returns [logits, decoder tokens] + logits = logits_and_dec_toks[0] + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + +if __name__ == 
"__main__": + absltest.main() diff --git a/tests/models/research/terraformer_test.py b/tests/models/research/terraformer_test.py new file mode 100644 index 000000000..081ea7e36 --- /dev/null +++ b/tests/models/research/terraformer_test.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Terraformer models.""" + +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import gin +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import shapes +from tests.layers import test_utils +from trax.models.research import terraformer + + +BACKENDS = [fastmath.Backend.JAX] + + +def short_name(b): + if b == fastmath.Backend.JAX: + return "jax" + else: + return "tf" + + +class TerraformerTest(parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=64, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + ) + + 
@parameterized.named_parameters( + [ + ("_%s_efficient" % short_name(backend), backend, tl.SelfAttention, False) + for backend in BACKENDS + ] + + [ + ("_%s_causal" % short_name(backend), backend, tl.CausalAttention, False) + for backend in BACKENDS + ] + + + # NOTE: tl.SelfAttention is not currently working for this case. + [ + ("_%s_preembed" % short_name(backend), backend, tl.CausalAttention, True) + for backend in BACKENDS + ] + ) + def test_terraformer_quick(self, backend, encoder_attention_type, preembed): + with fastmath.use_backend(backend): + vocab_size = 2 + input_vocab_size = None if preembed else vocab_size + output_vocab_size = vocab_size if preembed else None + max_len = 2 + + model = terraformer.ConfigurableTerraformer( + input_vocab_size, + d_model=4, + d_ff=4, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=2, + dropout=0.05, + max_len=max_len, + pos_type=None, + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=2, + mode="train", + output_vocab_size=output_vocab_size, + encoder_attention_type=encoder_attention_type, + ) + + if preembed: + model_inputs = [ + np.ones((1, max_len, 3)).astype(np.float32), + np.ones((1, max_len)).astype(bool), + ] + else: + model_inputs = [np.ones((1, max_len)).astype(np.int32)] + x = model_inputs + [np.ones((1, max_len)).astype(np.int32)] + model.init(shapes.signature(x)) + + logits, dec_toks = model(x) + del dec_toks + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + def test_terraformer_deterministic_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 4 + batch_size = 2 + length = 5 + + model_fn = functools.partial( + terraformer.ConfigurableTerraformer, + vocab_size, + d_model=d_model, + d_ff=16, + n_encoder_layers=0, + n_decoder_layers=1, + n_heads=2, + dropout=0.0, + max_len=length * 2, + pos_type=None, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + ) + + inp = np.random.randint(vocab_size, 
size=(batch_size, length)) + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_is_deterministic((inp, out), model_fn) + + def test_terraformer_predict_equals_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 8 + batch_size = 1 + length = 5 + + model_fn = functools.partial( + terraformer.ConfigurableTerraformer, + vocab_size, + d_model=d_model, + d_ff=16, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=2, + ff_use_sru=(1, 8), # ? is SRU working? + dropout=0.0, + max_len=(length + 7) * 2, + pos_type=None, + reversible_encoder=True, + n_decoder_attention_layers=1, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + ) + + # Token id of 0 indicates padding; and predict mode doesn't support it. + inp = np.random.randint(1, vocab_size, size=(batch_size, length)) + inp[:, -2:] = 0 + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict( + (inp, out), model_fn, seq_axis=1, seq_tensor=-1, init_tokens=1 + ) + + def test_terraformer_doubling(self): + vocab_size = 2 + max_len = 2 + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=8, + d_ff=16, + n_encoder_layers=1, + n_decoder_layers=6, + n_heads=2, + dropout=0.05, + max_len=max_len, + pos_type=None, + half_before_layer=2, + double_after_layer=2, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + mode="train", + ) + + x = [ + np.ones((1, max_len)).astype(np.int32), + np.ones((1, max_len)).astype(np.int32), + ] + model.init(shapes.signature(x)) + + logits, dec_toks = model(x) + del dec_toks + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + def test_terraformer_one_step(self): + vocab_size = 32 + max_len = 256 + pos_axial = 16 + assert pos_axial * pos_axial == max_len + + chunk_len = 32 + + # Since 2 * chunk_len * n_buckets should be max_len. 
+ n_buckets = max_len // (2 * chunk_len) + + lsh_self_attention = functools.partial( + self._lsh_self_attention_fn(), chunk_len=chunk_len, n_buckets=n_buckets + ) + + timebin_self_attention = self._timebin_self_attention_fn() + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=32, + d_ff=64, + d_attention_key=64, + d_attention_value=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + dropout=0.05, + max_len=max_len, + encoder_attention_type=lsh_self_attention, + encoder_decoder_attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(pos_axial, pos_axial), + pos_d_axial_embs=(64, 192), + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=64, + ff_sparsity=8, + mode="train", + ) + + x = [ + np.ones((1, max_len)).astype(np.int32), + np.ones((1, max_len)).astype(np.int32), + ] + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) + # This returns [logits, decoder tokens] + logits = logits_and_dec_toks[0] + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/transformer2_test.py b/tests/models/research/transformer2_test.py new file mode 100644 index 000000000..326ada7da --- /dev/null +++ b/tests/models/research/transformer2_test.py @@ -0,0 +1,467 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.models.research import transformer2 + + +class Transformer2Test(absltest.TestCase): + def test_concat_with_padding(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding(mode="train") + inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d + layer.init(shapes.signature(inp)) + y, _, _ = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 
0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_concat_with_padding_predict(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding(mode="predict") + inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d + _, _ = layer.init(shapes.signature(inp)) + y, _, _ = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 
0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + # On subsequent runs however, we should get vec_d only. + for _ in range(2): + y, _, _ = layer(inp) + np.testing.assert_equal(y, vec_d) + + def test_concat_with_padding2(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding2(mode="train") + inp = (vec_e, vec_e, vec_d, mask_e, vec_e, vec_d) + layer.init(shapes.signature(inp)) + y1, y2, _, _ = layer(inp) + + np.testing.assert_equal( + y1, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + np.testing.assert_equal( + y2, + np.array( + 
[ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_strip_from_concatenate_with_padding(self): + enc_dec = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) + tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) + + layer = transformer2.StripFromConcatenateWithPadding(mode="train") + inp = (enc_dec, tok_e, tok_d) + _, _ = layer.init(shapes.signature(inp)) + y = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_strip_from_concatenate_with_padding_predict(self): + enc_dec = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 
7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) + tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) + + layer = transformer2.StripFromConcatenateWithPadding(mode="predict") + inp = (enc_dec, tok_e, tok_d) + _, _ = layer.init(shapes.signature(inp)) + y = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + # On subsequent runs however, we should get enc_dec only. 
+ for _ in range(2): + y = layer(inp) + np.testing.assert_equal(y, enc_dec) + + def test_transformer_noencdec_forward_shape(self): + input_vocab_size = 16 + output_vocab_size = 16 + + model = transformer2.Transformer2( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + + enc_toks = np.array([[6, 2, 0, 0, 0, 0], [6, 3, 7, 0, 0, 0]]) + dec_toks = np.array([[4, 2, 0, 0], [8, 5, 0, 0]]) + + xs = [enc_toks, dec_toks] + _, _ = model.init(shapes.signature(xs)) + + # decoder output, decoder mask + ys = model(xs) + + # (B, L2, H) + self.assertEqual( + ys[0].shape, (dec_toks.shape[0], dec_toks.shape[1], output_vocab_size) + ) + + self.assertEqual(ys[1].shape, dec_toks.shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/resnet_test.py b/tests/models/resnet_test.py similarity index 57% rename from trax/models/resnet_test.py rename to tests/models/resnet_test.py index 3742d67ae..820b1f1cf 100644 --- a/trax/models/resnet_test.py +++ b/tests/models/resnet_test.py @@ -15,31 +15,28 @@ """Tests for Resnet models.""" -from absl.testing import absltest import numpy as np +from absl.testing import absltest -from trax import fastmath from trax import shapes from trax.models import resnet class ResnetTest(absltest.TestCase): - - def test_resnet(self): - model = resnet.Resnet50(d_hidden=8, n_output_classes=10) - x = np.ones((3, 256, 256, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) - - def test_wide_resnet(self): - model = resnet.WideResnet(n_blocks=1, n_output_classes=10) - x = np.ones((3, 32, 32, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) - - - -if __name__ == '__main__': - absltest.main() + def test_resnet(self): + model = resnet.Resnet50(d_hidden=8, n_output_classes=10) + x = np.ones((3, 256, 256, 3)).astype(np.float32) + _, _ = 
model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 10)) + + def test_wide_resnet(self): + model = resnet.WideResnet(n_blocks=1, n_output_classes=10) + x = np.ones((3, 32, 32, 3)).astype(np.float32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 10)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/rl_test.py b/tests/models/rl_test.py new file mode 100644 index 000000000..bf8398378 --- /dev/null +++ b/tests/models/rl_test.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RL.""" + +from unittest import mock +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.models import rl + + +class RLTest(absltest.TestCase): + def test_policy_forward_shape(self): + mock_dist = mock.MagicMock() + mock_dist.n_inputs = 4 + model = rl.Policy(policy_distribution=mock_dist) + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (2, 4)) + + def test_value_forward_shape(self): + model = rl.Value() + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (2, 1)) + + def test_policy_and_value_forward_shape(self): + mock_dist = mock.MagicMock() + mock_dist.n_inputs = 4 + model = rl.PolicyAndValue(policy_distribution=mock_dist) + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + ys = model(x) + self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/rnn_test.py b/tests/models/rnn_test.py new file mode 100644 index 000000000..2a2167b8f --- /dev/null +++ b/tests/models/rnn_test.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RNNs.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from trax import fastmath +from trax.fastmath import numpy as jnp +from trax import shapes +from trax.models import rnn + +BACKENDS = [fastmath.Backend.JAX] + + +@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS) +class RNNTest(parameterized.TestCase): + def test_rnnlm_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.RNNLM(vocab_size=20, d_model=16) + x = (jnp.ones((3, 28)).astype(jnp.int32),) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 28, 20)) + + def test_grulm_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.GRULM(vocab_size=20, d_model=16) + x = jnp.ones((3, 28)).astype(jnp.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 28, 20)) + + def test_lstmseq2seqattn_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.LSTMSeq2SeqAttn( + input_vocab_size=20, target_vocab_size=20, d_model=16 + ) + x = jnp.ones((3, 28)).astype(jnp.int32) + _, _ = model.init([shapes.signature(x), shapes.signature(x)]) + ys = model([x, x]) + self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/transformer_test.py b/tests/models/transformer_test.py new file mode 100644 index 000000000..b694cb89f --- /dev/null +++ b/tests/models/transformer_test.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +import functools + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from trax import shapes +from tests.layers import test_utils +from trax.models import transformer + + +class TransformerTest(parameterized.TestCase): + def test_transformer_lm_forward_shape(self): + vocab_size = 16 + model = transformer.TransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2 + ) + x = np.ones((3, 5)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, vocab_size)) + + def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + model = transformer.Transformer( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + y, _ = model(xs) + + vocab_size = output_vocab_size or input_vocab_size + self.assertEqual(y.shape, (3, 5, vocab_size)) + + @parameterized.named_parameters( + ("same_vocab", 16, None), ("same_size", 16, 16), ("different_size", 16, 50) + ) + def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + """Run the Transformer forward and check output shape.""" + self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) + + def test_dot_product_causal_attention_fast_inference(self): + model_fn = functools.partial( + transformer.TransformerLM, d_model=4, d_ff=8, 
n_layers=2, n_heads=2 + ) + test_utils.test_eval_equals_predict_discrete(model_fn) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/optimizers/optimizers_test.py b/tests/optimizers/optimizers_test.py similarity index 51% rename from trax/optimizers/optimizers_test.py rename to tests/optimizers/optimizers_test.py index 583f73655..fc949b9eb 100644 --- a/trax/optimizers/optimizers_test.py +++ b/tests/optimizers/optimizers_test.py @@ -24,27 +24,26 @@ class OptimizersTest(absltest.TestCase): - - def test_slots(self): - weights_shape = (3, 5) - weight_tree = np.arange(15).reshape(weights_shape) - - # SGD - an optimizer that doesn't use slots. - opt_1 = optimizers.SGD(.01) - self.assertIsNone(opt_1.slots) - opt_1.tree_init(weight_tree) - self.assertIsInstance(opt_1.slots, tuple) - self.assertLen(opt_1.slots, 1) - self.assertIsNone(opt_1.slots[0]) - - # Momentum - an optimizer with slots - opt_2 = momentum.Momentum(.01) - self.assertIsNone(opt_2.slots) - opt_2.tree_init(weight_tree) - self.assertIsInstance(opt_2.slots, tuple) - self.assertLen(opt_2.slots, 1) - self.assertEqual(weights_shape, opt_2.slots[0].shape) - - -if __name__ == '__main__': - absltest.main() + def test_slots(self): + weights_shape = (3, 5) + weight_tree = np.arange(15).reshape(weights_shape) + + # SGD - an optimizer that doesn't use slots. 
+ opt_1 = optimizers.SGD(0.01) + self.assertIsNone(opt_1.slots) + opt_1.tree_init(weight_tree) + self.assertIsInstance(opt_1.slots, tuple) + self.assertLen(opt_1.slots, 1) + self.assertIsNone(opt_1.slots[0]) + + # Momentum - an optimizer with slots + opt_2 = momentum.Momentum(0.01) + self.assertIsNone(opt_2.slots) + opt_2.tree_init(weight_tree) + self.assertIsInstance(opt_2.slots, tuple) + self.assertLen(opt_2.slots, 1) + self.assertEqual(weights_shape, opt_2.slots[0].shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/optimizers/trainer_test.py b/tests/optimizers/trainer_test.py new file mode 100644 index 000000000..8a96e18c4 --- /dev/null +++ b/tests/optimizers/trainer_test.py @@ -0,0 +1,384 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for accelerated optimization of loss layers.""" + +import time +from absl.testing import absltest + +from jax.config import config +import numpy as np + +from trax import fastmath +from trax import layers as tl +from trax import optimizers +from trax import shapes +from trax.layers import base +from trax.models.research import terraformer + + +class TrainerTest(absltest.TestCase): + def _assert_all_equal(self, t1, t2, tol=1e-5): + def eq(x1, x2): + diff = np.maximum(np.abs(x1 - x2) - tol, 0.0) + self.assertLessEqual( + np.sum(diff), 0.0, msg=f"\n{x1}\n !=\n{x2}\n diff:\n{x1-x2}" + ) + + fastmath.nested_map_multiarg(eq, t1, t2) + + def test_run_simple_task(self): + """Runs an accelerated optimizer on a simple task.""" + inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch + targets_batch = np.pi * np.ones_like(inputs_batch) + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss()) + loss_layer.init(labeled_batch) + optimizer = optimizers.SGD(0.01) + optimizer.tree_init(loss_layer.weights) + trainer = optimizers.Trainer(loss_layer, optimizer) + rng = fastmath.random.get_prng(0) + trainer.one_step(labeled_batch, rng) + + def test_run_sharded_terraformer(self): + """Runs Terraformer with sharded weights (only on 2+-device systems).""" + if fastmath.local_device_count() == 1: + return + base.N_WEIGHTS_SHARDS = fastmath.local_device_count() + inputs_batch = np.arange(8).reshape((2, 4)) + 1 + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) + input_sig = (int_sig, int_sig, int_sig) + # We want to test rng propagation too, so adding some dropout layers. 
+ model = terraformer.ConfigurableTerraformer( + 20, + d_model=8, + d_ff=32, + n_heads=1, + dropout=0.0, + n_encoder_layers=2, + n_decoder_layers=2, + ff_sparsity=(4, 8, 0.0, 1.0), + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + pos_type=None, + reversible_encoder=True, + ) + loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) + model_with_loss = tl.Serial(model, loss) + rng_init = fastmath.random.get_prng(12) + model_with_loss.init(input_sig, rng=rng_init) + + # Make a step with the trainer. + optimizer = optimizers.Adafactor(0.01) + split_w = fastmath.nested_map( + lambda x: x[0], tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS) + ) + optimizer.tree_init(split_w) + trainer = optimizers.Trainer(model_with_loss, optimizer) + rng_step1 = fastmath.random.get_prng(7) + trainer.one_step(labeled_batch, rng_step1) + # Reset shards back to default. + base.N_WEIGHTS_SHARDS = 1 + + def test_run_reversible_slots(self): + """Tests that slots can be read and assigned in reversible trainer.""" + layers = [tl.Dense(4), tl.Dup()] + rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()] + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + trainer = optimizers.ReversibleSerialTrainer( + [(layers, rev_layers)], loss_layer, optimizers.Adam + ) + slots = trainer.slots + trainer.slots = slots + self.assertEqual(slots, trainer.slots) + + def test_run_reversible_same_as_default_basic(self): + """Runs the reversible trainer, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + # We want to test rng propagation too, so adding some dropout layers. 
+ first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) + rev_layers = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), + tl.ReversibleSwap(), + ] + loss_layer = tl.Serial( + tl.Concatenate(), + tl.Dense(19), + tl.Dropout(0.3), + tl.LogSoftmax(), + tl.CrossEntropyLoss(), + ) + model = tl.Serial([first_layer] + rev_layers + [loss_layer]) + rng_init = fastmath.random.get_prng(12) + model.init(labeled_batch, rng=rng_init) + optimizer_fn = optimizers.Adam # to test slots + + # Make 2 steps with the original trainer. + optimizer = optimizer_fn() + optimizer.tree_init(model.weights) + trainer = optimizers.Trainer(model, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + first_layer_weights1 = first_layer.weights + rev_layer0_weights1 = rev_layers[0].weights + rev_layer2_weights1 = rev_layers[2].weights + loss_layer_weights1 = loss_layer.weights + + # Now make 2 steps with reversible trainer. + model.init(labeled_batch, rng=rng_init) + trainer = optimizers.ReversibleSerialTrainer( + [(first_layer.sublayers, rev_layers)], loss_layer, optimizer_fn + ) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + + # Check that weights end up the same. 
+ self._assert_all_equal(loss_layer_weights1, loss_layer.weights) + self._assert_all_equal(rev_layer2_weights1, rev_layers[2].weights) + self._assert_all_equal(rev_layer0_weights1, rev_layers[0].weights) + self._assert_all_equal(first_layer_weights1, first_layer.weights) + + def test_run_reversible_same_as_default_extended(self): + """Runs the reversible trainer, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + # We want to test rng propagation too, so adding some dropout layers. + first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) + rev_layers1 = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), + tl.ReversibleSwap(), + ] + mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup()) + rev_layers2 = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)), + tl.ReversibleSwap(), + ] + loss_layer = tl.Serial( + tl.Concatenate(), + tl.Dense(19), + tl.Dropout(0.3), + tl.LogSoftmax(), + tl.CrossEntropyLoss(), + ) + model = tl.Serial( + [first_layer] + rev_layers1 + [mid_layer] + rev_layers2 + [loss_layer] + ) + rng_init = fastmath.random.get_prng(12) + model.init(labeled_batch, rng=rng_init) + optimizer_fn = optimizers.Adam # to test slots + + # Make 3 steps with the original trainer. 
+ optimizer = optimizer_fn() + optimizer.tree_init(model.weights) + trainer = optimizers.Trainer(model, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + rng_step3 = fastmath.random.get_prng(9) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + first_layer_weights1 = first_layer.weights + rev_layer12_weights1 = rev_layers1[2].weights + mid_layer_weights1 = mid_layer.weights + rev_layer20_weights1 = rev_layers2[0].weights + loss_layer_weights1 = loss_layer.weights + + # Now make 3 steps with reversible trainer. + model.init(labeled_batch, rng=rng_init) + # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why? + trainer = optimizers.ReversibleSerialTrainer( + [(first_layer.sublayers, rev_layers1), (mid_layer.sublayers, rev_layers2)], + loss_layer, + optimizer_fn, + memoize_jit=False, + ) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + + # Check that weights end up the same. 
+ self._assert_all_equal(loss_layer_weights1, loss_layer.weights) + self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights) + self._assert_all_equal(mid_layer_weights1, mid_layer.weights) + self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights) + self._assert_all_equal(first_layer_weights1, first_layer.weights) + + def test_run_reversible_same_as_default_terraformer(self): + """Runs the reversible trainer, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + 1 + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) + input_sig = (int_sig, int_sig, int_sig) + # We want to test rng propagation too, so adding some dropout layers. + model = terraformer.ConfigurableTerraformer( + 20, + d_model=8, + d_ff=32, + n_heads=1, + dropout=0.0, + n_encoder_layers=2, + n_decoder_layers=2, + ff_sparsity=(4, 8, 0.0, 1.0), + pos_type=None, + reversible_encoder=True, + ) + loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) + optimizer_fn = optimizers.Adafactor + blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( + [model, loss], loss_chunk_size=4 + ) + blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks] + model_with_loss = tl.Serial(model, loss) + rng_init = fastmath.random.get_prng(12) + model_with_loss.init(input_sig, rng=rng_init) + + # Make 3 steps with the original trainer. 
+ optimizer = optimizer_fn() + optimizer.tree_init(model_with_loss.weights) + trainer = optimizers.Trainer(model_with_loss, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + rng_step3 = fastmath.random.get_prng(9) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + first_weights = blocks_serial[0][0].weights + first_rev_weights = blocks[0][1][0].weights + loss_weights = loss_layer.weights + + # Now make 3 steps with reversible trainer. + model_with_loss.init(input_sig, rng=rng_init) + trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer, optimizer_fn) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + + # Check that weights end up the same. + self._assert_all_equal(loss_weights, loss_layer.weights) + self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights) + self._assert_all_equal(first_weights, blocks_serial[0][0].weights) + + def test_run_reversible_large_weights(self): + """Runs the reversible trainer with a lot of weights to test memory use.""" + # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU + # and CPU when you run it locally, but it's too big for unit-testing. + ram_limited = True # Set to False to run this test locally. + if fastmath.global_device_count() == 1 and ram_limited: + return + + # Create inputs and rngs. + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup()) + rng_init = fastmath.random.get_prng(12) + rng_step = fastmath.random.get_prng(13) + + # Initialize layers. 
+ first_layer.init(labeled_batch, rng=rng_init) + n_layers = 18 # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram + rev_layers = [] + int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32) + shape = shapes.ShapeDtype((2, 4, 16 * 1024)) + sig = (shape, shape) + for _ in range(n_layers): + layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024)) + layer.init(sig, rng=rng_init) + layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory + rev_layers.append(layer) + rev_layers.append(tl.ReversibleSwap()) + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + loss_layer.init((shape, shape, int_shape, int_shape)) + optimizer_fn = optimizers.Adafactor + + # Make a step with reversible trainer. + trainer = optimizers.ReversibleSerialTrainer( + [(first_layer, rev_layers)], loss_layer, optimizer_fn + ) + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + # Set to true to run again, e.g., for profiling. + run_twice = False + if run_twice: + t = time.time() + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + print("Took %.3f seconds to run, loss %s" % (time.time() - t, loss)) + + def test_run_reversible_weights_trainsfer_xprof(self): + """Runs the reversible trainer and profiles weight transfer stats.""" + run_this_test = False # We only run this test manually. + if not run_this_test or fastmath.global_device_count() == 1: # TPU only + return + + # Create inputs and rngs. + inputs_batch = np.ones((1024, 128), dtype=np.int32) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup()) + rng_init = fastmath.random.get_prng(12) + rng_step = fastmath.random.get_prng(13) + + # Initialize layers. 
+ first_layer.init(labeled_batch, rng=rng_init) + n_layers = 6 + rev_layers = [] + int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32) + shape = shapes.ShapeDtype((1024, 128, 1024)) + sig = (shape, shape) + for _ in range(n_layers): + layer = tl.ReversibleHalfResidual(tl.Dense(1024)) + layer.init(sig, rng=rng_init) + layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory + rev_layers.append(layer) + rev_layers.append(tl.ReversibleSwap()) + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + loss_layer.init((shape, shape, int_shape, int_shape)) + optimizer_fn = optimizers.SGD + + # Make a step with reversible trainer. + trainer = optimizers.ReversibleSerialTrainer( + [(first_layer, rev_layers)], loss_layer, optimizer_fn + ) + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + # We profile here. + t = time.time() + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + print("Took %.3f seconds to run, loss %s" % (time.time() - t, loss)) + + +if __name__ == "__main__": + config.config_with_absl() + absltest.main() diff --git a/tests/shapes_test.py b/tests/shapes_test.py new file mode 100644 index 000000000..c7e511707 --- /dev/null +++ b/tests/shapes_test.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.shapes.""" +from absl.testing import absltest +import numpy as np + +from trax import shapes +from trax.shapes import ShapeDtype + + +class ShapesTest(absltest.TestCase): + + def test_constructor_and_read_properties(self): + sd = ShapeDtype((2, 3), np.int32) + self.assertEqual(sd.shape, (2, 3)) + self.assertEqual(sd.dtype, np.int32) + + def test_default_dtype_is_float32(self): + sd = ShapeDtype((2, 3)) + self.assertEqual(sd.shape, (2, 3)) + self.assertEqual(sd.dtype, np.float32) + + def test_signature_on_ndarray(self): + array = np.array([[2, 3, 5, 7], + [11, 13, 17, 19]], + dtype=np.int16) + sd = shapes.signature(array) + self.assertEqual(sd.shape, (2, 4)) + self.assertEqual(sd.dtype, np.int16) + + def test_shape_dtype_repr(self): + sd = ShapeDtype((2, 3)) + repr_string = '{}'.format(sd) + self.assertEqual(repr_string, + "ShapeDtype{shape:(2, 3), dtype:<class 'numpy.float32'>}") + + def test_splice_signatures(self): + sd1 = ShapeDtype((1,)) + sd2 = ShapeDtype((2,)) + sd3 = ShapeDtype((3,)) + sd4 = ShapeDtype((4,)) + sd5 = ShapeDtype((5,)) + + # Signatures can be ShapeDtype instances, tuples of 2+ ShapeDtype instances, + # or empty tuples.
+ sig1 = sd1 + sig2 = (sd2, sd3, sd4) + sig3 = () + sig4 = sd5 + spliced = shapes.splice_signatures(sig1, sig2, sig3, sig4) + self.assertEqual(spliced, (sd1, sd2, sd3, sd4, sd5)) + + def test_len_signature(self): + """Signatures of all sizes should give correct length when asked.""" + x1 = np.array([1, 2, 3]) + x2 = np.array([10, 20, 30]) + inputs0 = () + inputs1 = x1 # NOT in a tuple + inputs2 = (x1, x2) + + sig0 = shapes.signature(inputs0) + sig1 = shapes.signature(inputs1) + sig2 = shapes.signature(inputs2) + + # pylint: disable=g-generic-assert + self.assertEqual(len(sig0), 0) + self.assertEqual(len(sig1), 1) + self.assertEqual(len(sig2), 2) + # pylint: enable=g-generic-assert + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/supervised/callbacks_test.py b/tests/supervised/callbacks_test.py new file mode 100644 index 000000000..0c64cf3c4 --- /dev/null +++ b/tests/supervised/callbacks_test.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.supervised.callbacks.""" + +import functools +import io +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import gym +import numpy as np + +from trax import models +from trax import test_utils +from trax.data import inputs +from tests.layers import test_utils as tl_test_utils +from trax.rl import serialization_utils +from trax.rl import space_serializer +from trax.supervised import callbacks +from trax.supervised import lr_schedules +from trax.supervised import trainer_lib +from trax.supervised import training + + +def random_inputs(seq_len, batch_size): + def stream_fn(num_devices): + del num_devices + while True: + x = np.random.uniform(size=(batch_size, seq_len)) + y = np.random.uniform(size=(batch_size, seq_len)) + mask = np.ones_like(x).astype(np.float32) + yield (x, y, x, mask) + + return inputs.Inputs( + train_stream=stream_fn, + eval_stream=stream_fn, + ) + + +def make_multibonacci_modulo(history_length, limit): + """Creates a function that generates the Multibonacci sequence modulo n.""" + + def sequence_fn(seq): + return np.sum(seq[-history_length:]) % limit + + return sequence_fn + + +def generate_trajectory(sequence_fn, space, n_steps): + """Generates random actions and observations that follow sequence_fn.""" + act = [space.sample() for _ in range(n_steps)] + obs = [space.sample()] + + for (o, a) in zip( + obs, + act[:-1], # Don't generate the last observation. 
+ ): + context = list(np.array([o, a]).flatten()) + symbols = [] + for _ in range(np.array(o).size): + symbol = sequence_fn(context + symbols) + symbols.append(symbol) + obs.append(np.reshape(symbols, space.shape)) + + obs = np.array([obs]) + act = np.array([act]) + return (obs, act) + + +def make_singleton_eval_task(observations, actions): + """Creates an EvalTask with just one example.""" + mask = np.ones(observations.shape[:2]) + + def data(): + while True: + yield (observations, actions, observations, mask) + + return training.EvalTask( + labeled_data=data(), + metrics=[], + ) + + +def make_serialized_model(seq_model, space, vocab_size): + srl = space_serializer.create(space, vocab_size) + return serialization_utils.SerializedModel( + functools.partial(seq_model, vocab_size=vocab_size), + observation_serializer=srl, + action_serializer=srl, + significance_decay=0.7, + ) + + +class CallbacksTest(parameterized.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + @mock.patch("sys.stdout", new_callable=io.StringIO) + def test_serialized_model_evaluation(self, mock_stdout): + precision = 1 + vocab_size = 2 + srl = space_serializer.BoxSpaceSerializer( + space=gym.spaces.Box(shape=(), low=0.0, high=1.0), + vocab_size=vocab_size, + precision=precision, + ) + + def inner_model(mode): + return models.TransformerLM( + mode=mode, + vocab_size=vocab_size, + d_model=2, + d_ff=4, + n_layers=1, + n_heads=1, + ) + + serialized_model_fn = functools.partial( + serialization_utils.SerializedModel, + inner_model, + observation_serializer=srl, + action_serializer=srl, + significance_decay=0.7, + ) + eval_callback = functools.partial( + callbacks.SerializedModelEvaluation, eval_at=5 + ) + + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir=output_dir, + model=serialized_model_fn, + inputs=functools.partial(random_inputs, seq_len=4, batch_size=64), + lr_schedule_fn=functools.partial(lr_schedules.constant, 0.01), + 
callbacks=[eval_callback], + steps=10, + ) + self.assertTrue(_has_metric("pred_error", mock_stdout)) + + @parameterized.product( + context_lengths=((2,), (1, 3)), + horizon_lengths=((1,), (1, 2)), + ) + def test_srl_eval_feeds_correct_sequence(self, context_lengths, horizon_lengths): + vocab_size = 10 + n_steps = 5 + + multibonacci_modulo = make_multibonacci_modulo(2, vocab_size) + space = gym.spaces.Discrete(n=vocab_size) + (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) + eval_task = make_singleton_eval_task(obs, act) + seq_model = functools.partial( + tl_test_utils.MockTransformerLM, + sequence_fn=multibonacci_modulo, + ) + serialized_model = make_serialized_model(seq_model, space, vocab_size) + callback = callbacks.SerializedModelEvaluation( + loop=None, + eval_task=eval_task, + model=serialized_model, + context_lengths=context_lengths, + horizon_lengths=horizon_lengths, + accelerate_model=False, + ) + callback.evaluate(weights=None) + + expected_seq = np.zeros(2 * n_steps + 1) + expected_seq[1::2] = obs + expected_seq[2::2] = act + seen_len = (context_lengths[-1] + horizon_lengths[-1]) * 2 + callback.predict_model.assert_prediction_buffers_equal( + [expected_seq[:seen_len]] + ) + + @parameterized.named_parameters(("one_symbol", 1), ("two_symbols", 2)) + def test_srl_eval_reports_zero_error_for_perfect_model(self, precision): + vocab_size = 100 + n_steps = 5 + + multibonacci_modulo = make_multibonacci_modulo(2 * precision, vocab_size) + space = gym.spaces.MultiDiscrete(nvec=([vocab_size] * precision)) + (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) + eval_task = make_singleton_eval_task(obs, act) + seq_model = functools.partial( + tl_test_utils.MockTransformerLM, + sequence_fn=multibonacci_modulo, + ) + serialized_model = make_serialized_model(seq_model, space, vocab_size) + callback = callbacks.SerializedModelEvaluation( + loop=None, + eval_task=eval_task, + model=serialized_model, + context_lengths=(1,), + 
horizon_lengths=(4,), + accelerate_model=False, + ) + metrics = callback.evaluate(weights=None) + error = next(value for (name, value) in metrics.items() if "pred_error" in name) + assert error == 0 + + +def _has_metric(metric_name, stdout): + log = stdout.getvalue() + metric_logs = [line for line in log.split("\n") if metric_name in line] + return bool(metric_logs) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/supervised/decoding_test.py b/tests/supervised/decoding_test.py new file mode 100644 index 000000000..26eae2960 --- /dev/null +++ b/tests/supervised/decoding_test.py @@ -0,0 +1,536 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for decoding.""" + +import functools +import os + +import gin +from jax.config import config +import numpy as np +from tensorflow.compat.v2 import test + +from trax import fastmath +from trax import layers as tl +from trax import models +from trax import shapes +from trax.supervised import decoding + + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.join(pkg_dir, "../../resources/supervised/testdata") +_CONFIG_DIR = os.path.join(pkg_dir, "../../resources/supervised/configs/") + + +class DecodingTest(test.TestCase): + def test_autoregressive_sample_transformerlm(self): + model = models.TransformerLM( + 10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode="predict" + ) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + batch_per_device = 2 // fastmath.local_device_count() + model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) + s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10) + self.assertEqual(s2.shape[0], 2) + self.assertLess(s2.shape[1], 11) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + prefix = np.array([[1, 2, 3]]) + s3 = decoding.autoregressive_sample( + model, prefix, eos_id=-1, max_length=10, batch_size=1 + ) + self.assertEqual(s3.shape[0], 1) + self.assertEqual(s3.shape[1], 10) + + def test_autoregressive_sample_transformerlm_tfnp(self): + with fastmath.use_backend(fastmath.Backend.TFNP): + model = models.TransformerLM( + 10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode="predict" + ) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + batch_per_device = 2 // fastmath.local_device_count() + model.init(shapes.ShapeDtype((batch_per_device, 1), 
dtype=np.int32)) + s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10) + self.assertEqual(s2.shape[0], 2) + self.assertLess(s2.shape[1], 11) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + prefix = np.array([[1, 2, 3]]) + s3 = decoding.autoregressive_sample( + model, prefix, eos_id=-1, max_length=10, batch_size=1 + ) + self.assertEqual(s3.shape[0], 1) + self.assertEqual(s3.shape[1], 10) + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _pure_lsh_self_attention_fn(self, n_chunks_after=0): + return functools.partial( + tl.PureLSHSelfAttentionWrapper, + attention_dropout=0.0, + chunk_len=16, + n_buckets=[32, 32], + n_chunks_after=n_chunks_after, + n_chunks_before=1, + n_hashes=2, + n_parallel_heads=1, + max_length_for_buckets=1024, + predict_drop_len=128, + predict_mem_len=1024, + num_weights=2, + bias=False, + pure_lsh_implementation=tl.PureLSHSelfAttention, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False, chunk_len=64): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=chunk_len, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def test_autoregressive_sample_reformerlm(self): + lsh_self_attention = self._lsh_self_attention_fn() + timebin_self_attention = self._timebin_self_attention_fn() + + model = models.ReformerLM( + vocab_size=256, + d_model=256, + d_ff=512, + d_attention_key=128, + d_attention_value=128, + n_layers=2, + n_heads=2, + dropout=0.05, + max_len=65536, + attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(256, 256), + pos_d_axial_embs=(128, 128), + ff_activation=tl.Relu, + ff_use_sru=0, + mode="predict", + ) + 
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + + def test_autoregressive_sample_transformer(self): + model = models.Transformer( + 10, + d_model=32, + d_ff=64, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=2, + mode="predict", + ) + inputs = np.ones((1, 3), dtype=np.int32) + model.init( + (shapes.signature(inputs), shapes.ShapeDtype((1, 1), dtype=np.int32)) + ) + s = decoding.autoregressive_sample( + model, inputs=inputs, eos_id=-1, max_length=10 + ) + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_transformerlm_quality(self): + pred_model = models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_TESTDATA, "transformerlm_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape11, shape11) + ) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.autoregressive_sample( + pred_model, inputs, max_length=6, temperature=0.0 + ) + self.assertEqual(str(s[0]), "[3 7 5 3 2 4]") + + def test_autoregressive_sample_transformerlm_quality_eval(self): + eval_model = models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="eval", + ) + model_path = os.path.join(_TESTDATA, "transformerlm_copy.pkl.gz") + eval_model.init_from_file(model_path) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.autoregressive_sample( + eval_model, inputs, eval_mode=True, max_length=6, temperature=0.0 + ) + self.assertEqual(str(s[0]), "[3 7 5 3 2 4]") + + def test_autoregressive_sample_transformerlm_quality_beam(self): + pred_model = 
models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_TESTDATA, "transformerlm_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape11, shape11) + ) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6) + self.assertEqual(len(s), 3) # 3 beams + self.assertEqual(str(s[0][0][0]), "[3 7 5 3 2 4]") + self.assertEqual(str(s[1][0][0]), "[3 7 5 3 2 2]") # different from above + self.assertEqual(str(s[2][0][0]), "[3 7 5 3 3 2]") # different from above + + def test_autoregressive_sample_transformer_quality(self): + pred_model = models.Transformer( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_encoder_layers=2, + n_decoder_layers=2, + input_vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_TESTDATA, "transformer_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape11, shape11) + ) + inputs = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=1, max_length=10, temperature=0.0 + ) + self.assertEqual(str(s[0]), "[3 7 5 3 2 4 1]") + + def test_autoregressive_sample_terraformer_lsh(self): + max_len = 128 + + pred_model = models.ConfigurableTerraformer( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=max_len, + n_heads=4, + n_encoder_layers=1, + n_decoder_layers=1, + ff_use_sru=1, + d_attention_key=64, + d_attention_value=64, + encoder_attention_type=self._lsh_self_attention_fn(), + encoder_decoder_attention_type=self._lsh_self_attention_fn(), + input_vocab_size=256, + pos_axial_shape=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), 
dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + pred_model.init(input_signature=(shape1l, shape11)) + + # 0w0w + inputs = np.array( + [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32 + ) + inputs = np.pad( + inputs, + [(0, 0), (0, max_len - inputs.shape[1])], + mode="constant", + constant_values=0, + ) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0 + ) + + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_terraformer_lsh_attn_quality(self): + gin.add_config_file_search_path(_CONFIG_DIR) + max_len = 32 # 32 is the max length we trained the checkpoint for. + test_lengths = [8, 16, 32] + vocab_size = 13 + # The checkpoint is correct on ~90% sequences, set random seed to deflake. + np.random.seed(0) + for test_len in test_lengths: + gin.clear_config() + gin.parse_config_file("terraformer_copy.gin") + gin.bind_parameter("LSHSelfAttention.predict_mem_len", 2 * max_len) + gin.bind_parameter("LSHSelfAttention.predict_drop_len", 2 * max_len) + + pred_model = models.ConfigurableTerraformer(mode="predict") + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + + model_path = os.path.join(_TESTDATA, "terraformer_copy_lsh_attn.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape1l, shape11) + ) + initial_state = pred_model.state + + for _ in range(2): # Set low to make the test run reasonably fast. + # Pick a length in [1, test_len] at random. + inp_len = np.random.randint(low=1, high=test_len + 1) + inputs = np.random.randint( + low=1, high=vocab_size - 1, size=(1, max_len) + ) + # TODO(jaszczur): properly fix padding in terraformer predict mode, + # and add a test here. 
+ s = decoding.autoregressive_sample( + pred_model, + inputs=inputs, + eos_id=-1, + max_length=inp_len, + temperature=0.0, + ) + np.testing.assert_equal(s[0], inputs[0, :inp_len]) + pred_model.state = initial_state + gin.clear_config() # Make sure to not affect other tests. + + def test_autoregressive_sample_reformerlm_lsh(self): + max_len = 32 + + pred_model = models.ReformerLM( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=2 * max_len, + n_heads=4, + n_layers=3, + ff_use_sru=0, + d_attention_key=64, + d_attention_value=64, + attention_type=functools.partial( + tl.LSHSelfAttention, + chunk_len=16, + n_hashes=2, + n_buckets=[32, 32], + predict_drop_len=max_len, + predict_mem_len=max_len, + max_length_for_buckets=1024, + ), + vocab_size=13, + pos_type="fixed-base", + pos_d_axial_embs=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + pred_model.init(shape11) + + # 0w0 + inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) + inputs = np.pad( + inputs, + [(0, 0), (0, max_len - inputs.shape[1])], + mode="constant", + constant_values=0, + ) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0 + ) + + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_reformerlm_lsh_quality(self): + max_len = 32 + + pred_model = models.ReformerLM( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=2 * max_len, + n_heads=4, + n_layers=3, + ff_use_sru=0, + d_attention_key=64, + d_attention_value=64, + attention_type=functools.partial( + tl.LSHSelfAttention, + chunk_len=16, + n_hashes=2, + n_buckets=[32, 32], + predict_drop_len=max_len, + predict_mem_len=max_len, + max_length_for_buckets=1024, + ), + vocab_size=13, + pos_type="fixed-base", + pos_d_axial_embs=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + + model_path = os.path.join(_TESTDATA, "reformerlm_copy_lsh_attn.pkl.gz") + 
pred_model.init_from_file( + model_path, weights_only=True, input_signature=shape11 + ) + + # 0w0 + inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) + inp_len = inputs.shape[1] + s = decoding.autoregressive_sample( + pred_model, + inputs=inputs, + eos_id=-1, + max_length=inp_len - 2, + temperature=0.0, + ) + + np.testing.assert_equal(s[0], inputs[0, 1 : inp_len - 1]) + # pylint: enable=unreachable + + def test_autoregressive_sample_terraformer_pure_lsh(self): + max_len = 128 + + pred_model = models.ConfigurableTerraformer( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=max_len, + n_heads=4, + n_encoder_layers=1, + n_decoder_layers=1, + ff_use_sru=1, + d_attention_key=64, + d_attention_value=64, + encoder_attention_type=self._pure_lsh_self_attention_fn(n_chunks_after=1), + encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(), + input_vocab_size=256, + pos_axial_shape=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + pred_model.init(input_signature=(shape1l, shape11)) + + # 0w0w + inputs = np.array( + [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32 + ) + inputs = np.pad( + inputs, + [(0, 0), (0, max_len - inputs.shape[1])], + mode="constant", + constant_values=0, + ) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0 + ) + + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self): + gin.add_config_file_search_path(_CONFIG_DIR) + max_len = 32 # 32 is the max length we trained the checkpoint for. + test_lengths = [8, 16, 32] + vocab_size = 13 + # The checkpoint is correct on ~90% sequences, set random seed to deflake. 
+ np.random.seed(0) + for test_len in test_lengths: + gin.clear_config() + gin.parse_config_file("terraformer_purelsh_copy.gin") + gin.bind_parameter("PureLSHSelfAttention.predict_mem_len", 2 * max_len) + gin.bind_parameter("PureLSHSelfAttention.predict_drop_len", 2 * max_len) + gin.bind_parameter("PureLSHSelfAttentionWrapper.bias", False) + gin.bind_parameter("PureLSHSelfAttentionWrapper.num_weights", 2) + + pred_model = models.ConfigurableTerraformer(mode="predict") + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + + model_path = os.path.join(_TESTDATA, "terraformer_purelsh_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape1l, shape11) + ) + initial_state = pred_model.state + + for _ in range(2): # Set low to make the test run reasonably fast. + # Pick a length in [1, test_len] at random. + inp_len = np.random.randint(low=1, high=test_len + 1) + inputs = np.random.randint( + low=1, high=vocab_size - 1, size=(1, max_len) + ) + # TODO(jaszczur): properly fix padding in terraformer predict mode, + # and add a test here. + s = decoding.autoregressive_sample( + pred_model, + inputs=inputs, + eos_id=-1, + max_length=inp_len, + temperature=0.0, + ) + + np.testing.assert_equal(s[0], inputs[0, :inp_len]) + pred_model.state = initial_state + gin.clear_config() # Make sure to not affect other tests. + + +if __name__ == "__main__": + config.config_with_absl() + test.main() diff --git a/tests/supervised/decoding_timing_test.py b/tests/supervised/decoding_timing_test.py new file mode 100644 index 000000000..791e0d9a6 --- /dev/null +++ b/tests/supervised/decoding_timing_test.py @@ -0,0 +1,500 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Timing tests for decoding."""
+
+import copy
+import functools
+import gc
+import os
+import time
+from jax.config import config
+import numpy as np
+import psutil
+from tensorflow.compat.v2 import test
+
+from trax import fastmath
+from trax import layers as tl
+from trax import models
+from trax import shapes
+from trax.supervised import decoding
+
+
+def _size_of_model(model):
+    """Returns the summed `.size` of every leaf in `model.weights`.
+
+    Leaves that do not expose a usable `.size` contribute 0.
+    """
+
+    def _size(x):
+        try:
+            return x.size
+        except Exception:  # pylint: disable=broad-except
+            return 0
+
+    sizes = fastmath.nested_map(_size, model.weights)
+    total_size = sum(fastmath.tree_flatten(sizes))
+    return total_size
+
+
+def _recurrent_delete(w):
+    """Recursively calls `delete()` on every leaf of a nested weights structure.
+
+    Args:
+      w: an object with a `delete` method, or a list/tuple/dict of such objects
+        (nested arbitrarily).
+
+    Raises:
+      ValueError: if a leaf has no `delete` method and is not a list, tuple or
+        dict.
+    """
+    if hasattr(w, "delete"):
+        # Object has a 'delete' method, so it is a DeviceArray or something similar,
+        # so we want to delete it.
+        w.delete()
+    elif isinstance(w, (list, tuple)):
+        for x in w:
+            _recurrent_delete(x)
+    elif isinstance(w, dict):
+        for x in w.values():
+            _recurrent_delete(x)
+    else:
+        raise ValueError("Unknown type encountered in weights: {}".format(type(w)))
+
+
+def _memory_usage():
+    """Returns the resident set size of this process after a forced GC pass."""
+    gc.collect()
+    return psutil.Process(os.getpid()).memory_info().rss
+
+
+class DecodingTimingTest(test.TestCase):
+    """Prints timing and memory measurements for decoding and loss layers."""
+
+    def _terraformer_decoding_time(self, settings):
+        """Builds the model described by `settings` and times 12 decoding steps.
+
+        Args:
+          settings: dict with the model hyperparameters; `settings["model"]` must
+            be "terraformer" or "transformer".
+
+        Returns:
+          A `(model_size, elapsed_times, peak_memory)` tuple: the summed weight
+          sizes, the list of 12 per-step wall-clock times, and the process RSS
+          measured right after the timed run.
+
+        Raises:
+          ValueError: if `settings["model"]` is not a supported model name.
+        """
+        # Garbage collection influences the timing, so we turn it off.
+        gc.disable()
+        try:
+            return self._timed_decoding_run(settings)
+        finally:
+            # Re-enable even when model construction or decoding raises, so that a
+            # failure here does not leave GC switched off for the rest of the run.
+            gc.enable()
+
+    def _timed_decoding_run(self, settings):
+        """Does the actual work of `_terraformer_decoding_time` (GC already off)."""
+        max_len = 16
+
+        def _self_attention_fn():
+            return functools.partial(
+                tl.SelfAttention,
+                predict_drop_len=2 * max_len,
+                predict_mem_len=2 * max_len,
+            )
+
+        def _causal_attention_fn():
+            attn_layer, attn_kwargs = settings["attn"]
+            return functools.partial(
+                attn_layer, max_inference_length=2 * max_len, **attn_kwargs
+            )
+
+        if settings["model"] == "terraformer":
+            pred_model = models.ConfigurableTerraformer(
+                mode="predict",
+                d_model=settings["d_model"],
+                d_ff=settings["d_ff"],
+                dropout=0.1,
+                max_len=max_len,
+                n_heads=settings["n_heads"],
+                n_encoder_layers=settings["encoder_layers"],
+                n_decoder_layers=settings["decoder_layers"],
+                encoder_attention_type=_self_attention_fn(),
+                encoder_decoder_attention_type=_causal_attention_fn(),
+                input_vocab_size=settings["vocab"],
+                ff_sparsity=settings["ff_sparsity"],
+                ff_use_sru=settings["ff_use_sru"],
+                ff_dropout=0.1,
+                # ff_chunk_size=1024,
+                # attention_chunk_size=1,
+                n_decoder_attention_layers=settings["attention_layers"],
+                loss_sparsity=settings["loss_sparsity"],
+                pos_axial_shape=None,
+                use_bfloat16=True,
+            )
+        elif settings["model"] == "transformer":
+            pred_model = models.ConfigurableTransformer(
+                mode="predict",
+                d_model=settings["d_model"],
+                d_ff=settings["d_ff"],
+                dropout=0.1,
+                max_len=max_len,
+                n_heads=settings["n_heads"],
+                n_encoder_layers=settings["encoder_layers"],
+                n_decoder_layers=settings["decoder_layers"],
+                # encoder_attention_type=_self_attention_fn(),
+                encoder_decoder_attention_type=_causal_attention_fn(),
+                input_vocab_size=settings["vocab"],
+                ff_sparsity=settings["ff_sparsity"],
+                ff_use_sru=settings["ff_use_sru"],
+                # ff_dropout=0.1,
+                # ff_chunk_size=1024,
+                # attention_chunk_size=1,
+                # n_decoder_attention_layers=settings['attention_layers'],
+                loss_sparsity=settings["loss_sparsity"],
+                pos_axial_shape=None,
+                # enc_dec_attention_sparsity=settings['enc_dec_sparsity'],
+                # use_bfloat16=True,
+            )
+        else:
+            # An explicit exception survives `python -O` and names the bad value.
+            raise ValueError("Unknown model type: {}".format(settings["model"]))
+        # We put acceleration outside of autoregressive_sample_stream, because
+        # we want to have a separate run (separate input) for model compilation.
+        pred_model = tl.Accelerate(pred_model)
+
+        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
+        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
+        pred_model.init(input_signature=(shape1l, shape11))
+        original_state = copy.deepcopy(pred_model.state)
+
+        inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
+        inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)
+
+        # This is a warm-up run, for compilation.
+        result, current_time = [], time.time()
+        elapsed_warmup_times = []
+        for index, sample in zip(
+            range(0, 4),
+            decoding.autoregressive_sample_stream(
+                pred_model, inputs_warmup, temperature=0.0, accelerate=False
+            ),
+        ):
+            del index  # unused
+            result.append(sample[:, None])  # to be sure that the result is computed
+
+            current_time, start_time = time.time(), current_time
+            elapsed_warmup_times.append(current_time - start_time)
+
+        # This is a real decoding timing run that we measure.
+        pred_model.state = original_state
+        result, current_time = [], time.time()
+        elapsed_times = []
+        for index, sample in zip(
+            range(12),
+            decoding.autoregressive_sample_stream(
+                pred_model, inputs, temperature=0.0, accelerate=False
+            ),
+        ):
+            del index  # unused
+            result.append(sample[:, None])  # to be sure that the result is computed
+
+            current_time, start_time = time.time(), current_time
+            elapsed_times.append(current_time - start_time)
+        peak_memory = _memory_usage()
+
+        if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
+            print(
+                "WARNING! High variance found in elapsed times! Settings: {} ; "
+                "elapsed times: {} ; Probably more warm-up steps should be used, "
+                "or model size should be increased.".format(settings, elapsed_times)
+            )
+        # Check resulting shapes.
+        s = np.concatenate(result, axis=1)
+        self.assertEqual(s.shape[0], 1)
+        self.assertEqual(s.shape[1], 12)
+        model_size = int(_size_of_model(pred_model))
+
+        # We delete the model weights, because in some situations they won't be
+        # deleted automatically.
+        _recurrent_delete(pred_model.weights)
+        return model_size, elapsed_times, peak_memory
+
+    def test_autoregressive_sample_terraformer_timing(self):
+        """Runs the decoding timing sweep for one template and prints a report."""
+        template_to_use = "medium_transformer"
+
+        settings_templates = {
+            # full model
+            # # 54B params
+            # 'full_model': {
+            #     'encoder_layers': 6, 'decoder_layers': 36, 'vocab': 32000,
+            #     'attention_layers': 2,
+            #     'd_ff': 64*1024, 'd_model': 96*96, 'n_heads': 96,
+            #     'ff_use_sru': (1, 64), 'ff_sparsity': (256, 32),
+            #     'loss_sparsity': 8,
+            #     'attn': (tl.MultiplicativeConvCausalAttention,
+            #              {'length_kernel_size': 3, 'sparsity': 64})},
+            # 1/18 of model (1/6 of encoder, 1/18 of decoder, full vocab)
+            # 4B params
+            # 'big_terraformer': {
+            #     'model': 'terraformer',
+            #     'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000,
+            #     'attention_layers': 2,
+            #     'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96,
+            #     'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0,
+            #     'attn': (tl.CausalAttention, {})},
+            # 'big_transformer': {
+            #     'model': 'transformer',
+            #     'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000,
+            #     'attention_layers': 2,
+            #     'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96,
+            #     'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0,
+            #     'attn': (tl.CausalAttention, {})},
+            # medium model
+            # 275M params (only decoder)
+            "medium_transformer": {
+                "model": "transformer",
+                "encoder_layers": 2,
+                "decoder_layers": 24,
+                "vocab": 32000,
+                "attention_layers": 2,
+                "d_ff": 4 * 1024,
+                "d_model": 1024,
+                "n_heads": 16,
+                "ff_use_sru": 0,
+                "ff_sparsity": 0,
+                "loss_sparsity": 0,
+                "attn": (tl.CausalAttention, {}),
+            },
+            # 'medium_terraformer': {
+            #     'model': 'terraformer',
+            #     'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 32000,
+            #     'attention_layers': 2,
+            #     'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16,
+            #     'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0,
+            #     'attn': (tl.CausalAttention, {})},
+        }
+
+        sweep_settings = {
+            # 'big_transformer': [  # for big
+            #     dict(),  # baseline
+            #     {'ff_sparsity': (256, 32)},  # + Sparse FF
+            #     {'attn': (  # + Sparse QKV
+            #         tl.MultiplicativeConvCausalAttention,
+            #         {'length_kernel_size': 3, 'sparsity': 64}),
+            #      'd_ff': 64*1024,
+            #      },
+            #     {'ff_sparsity': (256, 32),
+            #      'attn': (  # + Sparse FF+QKV
+            #         tl.MultiplicativeConvCausalAttention,
+            #         {'length_kernel_size': 3, 'sparsity': 64}),
+            #      'd_ff': 64*1024,
+            #      },
+            # ],
+            "medium_transformer": [  # for medium
+                dict(),  # baseline
+                {
+                    "ff_sparsity": 64,
+                    "attn": (  # Sparse FF+QKV
+                        tl.MultiplicativeConvCausalAttention,
+                        {"length_kernel_size": 3, "sparsity": 16},
+                    ),
+                    "d_ff": 6 * 1024,
+                },
+                # {'ff_sparsity': 64,  # Sparse FF+QKV + Loss
+                #  'attn': (
+                #     tl.MultiplicativeConvCausalAttention,
+                #     {'length_kernel_size': 3, 'sparsity': 16}),
+                #  'd_ff': 6*1024,
+                #  'loss_sparsity': 4,
+                #  },
+                # {'attn': (  # Sparse QKV
+                #     tl.MultiplicativeConvCausalAttention,
+                #     {'length_kernel_size': 3, 'sparsity': 16}),
+                #  'd_ff': 6*1024,
+                #  },
+                # {'loss_sparsity': 4},  # Sparse Loss
+                # {'ff_sparsity': 64},  # Sparse FF
+                # {'ff_sparsity': 128},  # + Sparse FF 128
+                # APPENDIX below
+                # different loss layers
+                # {'loss_sparsity': 8},
+                # {'loss_sparsity': 2},
+                # {'loss_sparsity': 0},
+            ],
+            # 'big_terraformer': [  # for big terraformer
+            #     dict(),  # baseline
+            #     {'ff_sparsity': 64},  # + Sparse FF / Sparse FF 64
+            #     {'ff_sparsity': 64,
+            #      'attn': (  # + Sparse FF+QKV
+            #         tl.MultiplicativeConvCausalAttention,
+            #         {'length_kernel_size': 3, 'sparsity': 16}),
+            #      'd_ff': 6*1024,
+            #      },
+            #     {'ff_sparsity': 64,  # + Sparse FF+QKV+Loss
+            #      'attn': (
+            #         tl.MultiplicativeConvCausalAttention,
+            #         {'length_kernel_size': 3, 'sparsity': 16}),
+            #      'd_ff': 6*1024,
+            #      'loss_sparsity': 4,
+            #      },
+            # ],
+            # 'medium_terraformer': [  # for medium terraformer
+            #     {'ff_sparsity': 64,  # + Sparse FF+QKV+Loss
+            #      'attn': (
+            #         tl.MultiplicativeConvCausalAttention,
+            #         {'length_kernel_size': 3, 'sparsity': 16}),
+            #      'd_ff': 6*1024,
+            #      'loss_sparsity': 4,
+            #      },
+            # ],
+        }
+
+        encoding_times = []
+        decoding_times = []
+        sizes = []
+        memories = []
+        messages = []
+        for override_settings in sweep_settings[template_to_use]:
+            settings = copy.deepcopy(settings_templates[template_to_use])
+            settings.update(override_settings)
+
+            init_memory = _memory_usage()
+            size, elapsed_times, peak_memory = self._terraformer_decoding_time(settings)
+
+            # TODO(jaszczur): Why is elapsed_times[0] always small?
+            encoding_time = elapsed_times[1]
+            decoding_steps = elapsed_times[2:]
+            decoding_time_10 = sum(decoding_steps)
+
+            after_memory = _memory_usage()
+            model_memory_gigabytes = (peak_memory - init_memory) / 1024**3
+            decoding_time_diff = (max(decoding_steps) - min(decoding_steps)) / 2
+            # The spread is taken over the decoding steps only, so it is
+            # normalised by the mean of those same steps; including steps 0 and 1
+            # (which are not decoding steps) would distort the percentage.
+            decoding_time_diff_percent = int(
+                decoding_time_diff / np.mean(decoding_steps) * 100
+            )
+            message = (
+                "\n\n"
+                "Params: {}\n"
+                "Settings: {}\n"
+                "Override: {}\n"
+                "Init memory: {:.1f} GiB\n"
+                "Peak memory: {:.1f} GiB\n"
+                "After memory: {:.1f} GiB\n"
+                "Estimated model memory: {:.1f} GiB\n"
+                "Times for each step: {}\n"
+                "Time for encoding: {:.4f} s\n"
+                "Time for decoding 10 tokens: {:.4f} s +/- {} %\n"
+                "\n\n".format(
+                    size,
+                    settings,
+                    override_settings,
+                    init_memory / 1024**3,
+                    peak_memory / 1024**3,
+                    after_memory / 1024**3,
+                    model_memory_gigabytes,
+                    elapsed_times,
+                    encoding_time,
+                    decoding_time_10,
+                    decoding_time_diff_percent,
+                )
+            )
+            print(message)
+            messages.append(message)
+            encoding_times.append(encoding_time)
+            decoding_times.append(decoding_time_10)
+            sizes.append(size)
+            memories.append(model_memory_gigabytes)
+
+        print("Final results (recap):")
+        for message in messages:
+            print(message)
+
+        # This is useful for copying results into a spreadsheet etc.
+        # for i in range(len(sweep_settings)):
+        #   print('{}\t{}\t{}\t{:.1f}'.format(
+        #       sizes[i], encoding_times[i], decoding_times[i], memories[i]))
+
+    def test_loss_layer_timing(self):
+        """Times 100 calls of `SparseDenseWithOptions` for several configurations."""
+        # Every configuration shares these values; the entries below override
+        # only the keys that differ.
+        base = {
+            "output": 32000,
+            "input": 2048,
+            "prob": None,
+            "type": None,
+            "sparsity": 0,
+            "lowrank": 0,
+            "use_bias": False,
+        }
+        all_settings = [
+            # The first run is sometimes slower, less reliable.
+            dict(base),
+            dict(base),
+            dict(base, type="einsum"),
+            dict(base, type="mult", sparsity=2),
+            dict(base, use_bias=True),
+            dict(base, type="einsum", use_bias=True),
+            dict(base, type="mult", sparsity=2, use_bias=True),
+        ]
+
+        messages = []
+        for settings in all_settings:
+            pred_model = tl.SparseDenseWithOptions(
+                n_units=settings["output"],
+                d_input=settings["input"],
+                sparsity_type=settings["type"],
+                sparsity=settings["sparsity"],
+                d_lowrank=settings["lowrank"],
+                prob_sparse=settings["prob"],
+                use_bias=settings["use_bias"],
+                mode="predict",
+            )
+            pred_model = tl.Accelerate(pred_model)
+
+            shape1l = shapes.ShapeDtype((1, settings["input"]))
+            pred_model.init(input_signature=shape1l)
+            inputs = np.ones((1, settings["input"]))
+
+            total_time = 0.0
+            # Negative counters are warm-up calls whose time is not accumulated.
+            for counter in range(-50, 100):
+                start_time = time.time()
+                y = pred_model(inputs)
+                self.assertEqual(y.shape, (1, settings["output"]))
+                elapsed_time = time.time() - start_time
+                if counter >= 0:
+                    total_time += elapsed_time
+
+            message = (
+                "\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n"
+                % (_size_of_model(pred_model), settings, total_time)
+            )
+            messages.append(message)
+            print(message)
+
+        print("Final results (recap):")
+        for message in messages:
+            print(message)
+
+
+if __name__ == "__main__":
+    config.config_with_absl()
+    test.main()
diff --git a/tests/supervised/history_test.py b/tests/supervised/history_test.py
new file mode 100644
index 000000000..49e25d9e3
--- /dev/null
+++ b/tests/supervised/history_test.py
@@ -0,0 +1,55 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for trax.supervised.history."""
+
+from absl.testing import absltest
+
+from trax.supervised import history as trax_history
+
+
+def _history_with(*records):
+    """Returns a fresh History with each (mode, metric, step, value) appended."""
+    recorded = trax_history.History()
+    for mode, metric, step, value in records:
+        recorded.append(mode, metric, step, value)
+    return recorded
+
+
+class HistoryTest(absltest.TestCase):
+    """Checks lookups, (de)serialization and listings of `History`."""
+
+    def test_unknown_mode(self):
+        # A mode that was never appended yields an empty list.
+        recorded = _history_with(("train", "metric1", 1, 0.1))
+        self.assertEqual(recorded.get("unknown_mode", "metric1"), [])
+
+    def test_unknown_metric(self):
+        # A metric that was never appended yields an empty list.
+        recorded = _history_with(("train", "metric1", 1, 0.1))
+        self.assertEqual(recorded.get("train", "unknown_metric"), [])
+
+    def test_serializer_and_deserializer(self):
+        # A to_dict / from_dict round trip keeps the recorded (step, value) pair.
+        as_dict = _history_with(("train", "metric1", 1, 0.1)).to_dict()
+        restored = trax_history.History.from_dict(as_dict)
+        self.assertEqual(restored.get("train", "metric1"), [(1, 0.1)])
+
+    def test_modes(self):
+        recorded = _history_with(
+            ("train", "metric1", 1, 0.1),
+            ("test", "metric2", 2, 0.2),
+        )
+        self.assertEqual(recorded.modes, ["test", "train"])
+
+    def test_metrics_for_mode(self):
+        recorded = _history_with(
+            ("train", "metric1", 1, 0.1),
+            ("train", "metric2", 2, 0.2),
+        )
+        self.assertEqual(recorded.metrics_for_mode("train"), ["metric1", "metric2"])
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/supervised/lr_schedules_test.py b/tests/supervised/lr_schedules_test.py
new file mode 100644
index 000000000..b76b5de1b
--- /dev/null
+++ b/tests/supervised/lr_schedules_test.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests of learning rate schedules."""
+
+import math
+
+from absl.testing import absltest
+
+from trax.supervised import lr_schedules
+
+
+class LRFunctionsTest(absltest.TestCase):
+    """Checks schedule values at hand-picked steps."""
+
+    def test_warmup(self):
+        schedule = lr_schedules.warmup(9, 0.01)
+
+        # Linear warm-up.
+        for step, expected in ((1, 0.001), (2, 0.002), (5, 0.005), (9, 0.009)):
+            self.assertAlmostEqual(expected, schedule(step))
+
+        # Constant thereafter.
+        for step in (10, 11, 20, 300, 4000):
+            self.assertAlmostEqual(0.01, schedule(step))
+
+    def test_constant(self):
+        schedule = lr_schedules.constant(0.02)
+        # Steps 1, 20, 300, ..., 900000000: one more digit each time.
+        steps = (1, 20, 300, 4000, 50000, 600000, 7000000, 80000000, 900000000)
+        for step in steps:
+            self.assertEqual(0.02, schedule(step))
+
+    def test_warmup_and_rsqrt_decay(self):
+        schedule = lr_schedules.warmup_and_rsqrt_decay(24, 0.25)
+
+        # Warm-up.
+        for step, expected in ((1, 0.01), (2, 0.02), (23, 0.23), (24, 0.24)):
+            self.assertAlmostEqual(expected, schedule(step))
+
+        # Reciprocal square-root decay.
+        for step in (25, 26, 27, 300, 4000, 50000):
+            self.assertAlmostEqual(0.25 * (5 / math.sqrt(step)), schedule(step))
+
+    def test_cosine_sawtooth(self):
+        tail = lr_schedules._CosineSawtoothTail(180, min_value=0.1)
+        schedule = lr_schedules._BodyAndTail(0.3, tail_start=0, tail_fn=tail)
+
+        # Expected value at each offset into a 180-step cycle.
+        within_cycle = (
+            (1, 0.29998477),
+            (30, 0.28660254),
+            (60, 0.25),
+            (90, 0.20),
+            (120, 0.15),
+            (179, 0.10001523),
+        )
+
+        # First cycle
+        for offset, expected in within_cycle:
+            self.assertAlmostEqual(expected, schedule(offset))
+
+        # Second cycle
+        self.assertEqual(0.3, schedule(180))
+        for offset, expected in within_cycle:
+            self.assertAlmostEqual(expected, schedule(180 + offset))
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/supervised/trainer_lib_test.py b/tests/supervised/trainer_lib_test.py
new file mode 100644
index 000000000..02e82756f
--- /dev/null
+++ b/tests/supervised/trainer_lib_test.py
@@ -0,0 +1,580 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for trax.supervised.trainer_lib."""
+
+import functools
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+from jax.config import config
+import tensorflow.compat.v2 as tf
+from trax import fastmath
+from trax import layers as tl
+from trax import models
+from trax import optimizers as trax_opt
+from trax import shapes as trax_shapes
+from trax import test_utils
+from trax.data import inputs as inputs_lib
+from trax.fastmath import numpy as jnp
+from trax.supervised import lr_schedules as lr
+from trax.supervised import trainer_lib
+from trax.tf_numpy import extensions as npe
+from trax.tf_numpy import numpy as tf_np
+
+
+def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)):
+    """Make trainer_lib.inputs.Inputs.
+
+    Args:
+      n_classes: exclusive upper bound of the random integer targets.
+      with_weights: if True, the raw stream also yields a random weights array.
+      input_shape: per-example shape of the random inputs.
+
+    Returns:
+      An `inputs_lib.Inputs` whose stream is passed through `add_loss_weights`.
+    """
+    batch_size = 2 * jax.device_count()
+
+    def input_stream(n_devices):
+        del n_devices
+        key = fastmath.random.get_prng(0)
+        while True:
+            keys = fastmath.random.split(key, 4)
+            key = keys[0]
+            inputs = fastmath.random.uniform(keys[1], [batch_size] + list(input_shape))
+            targets = fastmath.random.randint(
+                keys[2], [batch_size], dtype=jnp.int32, minval=0, maxval=n_classes
+            )
+            weights = fastmath.random.uniform(keys[3], [batch_size])
+            if with_weights:
+                yield inputs, targets, weights
+            else:
+                yield inputs, targets
+
+    def input_stream_masked(n_devices):
+        return inputs_lib.add_loss_weights(input_stream(n_devices))
+
+    return inputs_lib.Inputs(input_stream_masked)
+
+
+def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2):
+    """Make trainer_lib.inputs.Inputs for language model.
+
+    Args:
+      vocab_size: exclusive upper bound of the random token ids.
+      seq_len: length of every generated sequence.
+      per_device_batch_size: batch size per device reported by
+        `jax.device_count()`.
+
+    Returns:
+      An `inputs_lib.Inputs` whose stream is passed through `add_loss_weights`.
+    """
+    batch_size = per_device_batch_size * jax.device_count()
+
+    def input_stream(_):
+        def make_batch(key):
+            return fastmath.random.randint(
+                key, [batch_size, seq_len], dtype=jnp.int32, minval=0, maxval=vocab_size
+            )
+
+        key = fastmath.random.get_prng(0)
+        while True:
+            keys = fastmath.random.split(key, 3)
+            key = keys[0]
+            inputs = make_batch(keys[1])
+            targets = make_batch(keys[2])
+            yield inputs, targets
+
+    def input_stream_masked(n_devices):
+        return inputs_lib.add_loss_weights(input_stream(n_devices))
+
+    return inputs_lib.Inputs(input_stream_masked)
+
+
+BACKENDS = [fastmath.Backend.JAX]
+BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [("Simple", None)])]
+
+
+def short_name(b):
+    """Returns "jax" for the JAX backend and "tf" for anything else."""
+    if b == fastmath.Backend.JAX:
+        return "jax"
+    else:
+        return "tf"
+
+
+def opt_name(opt):
+    """Returns the optimizer's `__name__`, or the string "None" if it is None."""
+    if opt is None:
+        return "None"
+    return opt.__name__
+
+
+def _lsh_self_attention_fn(pure_lsh_implementation, n_chunks_after=0):
+    """Returns a `PureLSHSelfAttentionWrapper` partial shared by the LSH tests.
+
+    Args:
+      pure_lsh_implementation: attention layer class handed to the wrapper.
+      n_chunks_after: forwarded to the wrapper unchanged.
+    """
+    return functools.partial(
+        tl.PureLSHSelfAttentionWrapper,
+        attention_dropout=0.1,
+        chunk_len=16,
+        n_buckets=[32, 32],
+        n_chunks_after=n_chunks_after,
+        n_chunks_before=1,
+        n_hashes=2,
+        n_parallel_heads=1,
+        max_length_for_buckets=1024,
+        predict_drop_len=128,
+        predict_mem_len=1024,
+        num_weights=2,
+        bias=False,
+        pure_lsh_implementation=pure_lsh_implementation,
+    )
+
+
+def _pure_lsh_self_attention_fn(n_chunks_after=0):
+    """Wrapper partial that uses `tl.PureLSHSelfAttention`."""
+    return _lsh_self_attention_fn(tl.PureLSHSelfAttention, n_chunks_after)
+
+
+def _mixed_lsh_self_attention_fn(n_chunks_after=0):
+    """Wrapper partial that uses `tl.MixedLSHSelfAttention`."""
+    return _lsh_self_attention_fn(tl.MixedLSHSelfAttention, n_chunks_after)
+
+
+class TraxTest(parameterized.TestCase):
+    """End-to-end tests of `trainer_lib.train` and `trainer_lib.Trainer`."""
+
+    def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
+        super().__init__(methodName)
+        if npe.tpu_devices():
+            # Initialize TPU for TF
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+            tf.tpu.experimental.initialize_tpu_system(resolver)
+
+    def setUp(self):
+        super().setUp()
+        test_utils.ensure_flag("test_tmpdir")
+        # Remember the global float64 setting so tearDown can restore it.
+        self._old_is_allow_float64 = tf_np.is_allow_float64()
+        tf_np.set_allow_float64(False)
+
+    def tearDown(self):
+        tf_np.set_allow_float64(self._old_is_allow_float64)
+        super().tearDown()
+
+    def _test_train_eval_predict(self, backend, model_name="Simple", optimizer=None):
+        """Trains for 2 steps, then predicts with in-memory and reloaded weights.
+
+        Args:
+          backend: value passed to `fastmath.use_backend`.
+          model_name: one of "Simple", "Resnet50" or "Transformer".
+          optimizer: optional optimizer forwarded to `trainer_lib.train`.
+
+        Raises:
+          ValueError: if `model_name` is not recognized.
+        """
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            steps = 2
+            eval_steps = 2
+
+            if model_name == "Simple":
+                n_classes = 4
+                # Adds Dropout and BatchNorm to test state handling.
+                def model_fn(mode="train"):
+                    return tl.Serial(
+                        tl.Dropout(mode=mode, rate=0.1),
+                        tl.BatchNorm(mode=mode),
+                        models.MLP(layer_widths=(16, 16, n_classes), mode=mode),
+                    )
+
+                inputs = _test_inputs(n_classes)
+                n_in = 1
+            elif model_name == "Resnet50":
+                n_classes = 4
+                model_fn = models.Resnet50
+                inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
+                n_in = 1
+            elif model_name == "Transformer":
+                vocab_size = 32
+                seq_len = 16
+                inputs = _test_inputs_lm(vocab_size, seq_len)
+                model_fn = functools.partial(
+                    models.Transformer, input_vocab_size=vocab_size
+                )
+                n_in = 2
+            else:
+                raise ValueError("Unrecognized model name: " + model_name)
+
+            kwargs = {}
+            if optimizer is not None:
+                kwargs["optimizer"] = optimizer
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,  # eval at every step.
+                **kwargs
+            )
+
+            # Assert total train steps
+            self.assertEqual(steps, loop.step)
+
+            inputs = inputs.train_stream(1)
+
+            # Predict with final weights
+            model = model_fn()
+            weights = loop.model.weights
+            state = loop.model.state
+            model(next(inputs)[:n_in], weights=weights, state=state)
+
+            # Predict with weights loaded from file.
+            model = model_fn()
+            model.init_from_file(os.path.join(output_dir, "model.pkl.gz"))
+            model(next(inputs)[:n_in])
+
+    @parameterized.named_parameters(
+        (
+            "_%s_%s_%s"
+            % (
+                short_name(backend),
+                model_name,
+                opt_name(opt),
+            ),  # pylint: disable=g-complex-comprehension
+            backend,
+            model_name,
+            opt,
+        )
+        for backend, configs in BACKENDS_AND_CONFIGS
+        for model_name, opt in configs
+    )
+    def test_train_eval_predict(self, backend, model_name, opt):
+        self._test_train_eval_predict(backend, model_name, opt)
+
+    @parameterized.parameters(BACKENDS)
+    def test_train_eval_predict_sm3(self, backend):
+        self._test_train_eval_predict(backend, "Simple", trax_opt.SM3)
+
+    @parameterized.parameters(BACKENDS)
+    def test_train_restart(self, backend):
+        """A second `train` call on the same directory continues to 2 * steps."""
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            n_classes = 4
+            steps = 2
+            eval_steps = 2
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes)
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+            )
+
+            # Restart training
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=(2 * steps),
+                eval_steps=eval_steps,
+                eval_frequency=1,
+            )
+
+            # Assert total train steps
+            self.assertEqual(loop.step, 2 * steps)
+
+    @parameterized.parameters(BACKENDS)
+    def test_train_permanent_checkpoints(self, backend):
+        """Checks which `model_<step>.pkl.gz` files exist after two train calls."""
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            n_classes = 4
+            steps = 5
+            eval_steps = 2
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes)
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+
+            # Steps 1 -> 5
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+                permanent_checkpoint_frequency=2,
+            )
+
+            # Steps 6 -> 10
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=(2 * steps),
+                eval_steps=eval_steps,
+                eval_frequency=1,
+                permanent_checkpoints_at=[7, 8, 10],
+            )
+
+            path = os.path.join(output_dir, "model.pkl.gz")
+            self.assertTrue(tf.io.gfile.exists(path))
+
+            for step in range(11):
+                filename = "model_{}.pkl.gz".format(step)
+                path = os.path.join(output_dir, filename)
+                if step in [1, 2, 4, 7, 8, 10]:
+                    self.assertTrue(
+                        tf.io.gfile.exists(path),
+                        msg="No model for step: {} in dir {}.".format(
+                            step, tf.io.gfile.listdir(output_dir)
+                        ),
+                    )
+                else:
+                    self.assertFalse(
+                        tf.io.gfile.exists(path),
+                        msg="Model for step: {} in dir {}.".format(
+                            step, tf.io.gfile.listdir(output_dir)
+                        ),
+                    )
+
+            # Assert total train steps
+            self.assertEqual(loop.step, 10)
+
+    @parameterized.parameters(BACKENDS)
+    def test_train_restart_with_same_steps(self, backend):
+        """A second `train` call with the same `steps` leaves the step count as is."""
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            n_classes = 4
+            steps = 2
+            eval_steps = 2
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes)
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+            )
+
+            # Restart training
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+            )
+
+            # Assert total train steps
+            self.assertEqual(loop.step, steps)
+
+    def _check_train_with_lsh_attention(self, attention_fn, backend):
+        """Trains a tiny Terraformer for one step and reloads it in predict mode.
+
+        Args:
+          attention_fn: zero-argument callable returning the attention type used
+            for both the encoder and the encoder-decoder attention.
+          backend: value passed to `fastmath.use_backend`.
+        """
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            def model(mode="train"):
+                return models.ConfigurableTerraformer(
+                    mode=mode,
+                    d_model=16,
+                    d_ff=16,
+                    n_heads=2,
+                    dropout=0.05,
+                    n_decoder_layers=1,
+                    n_encoder_layers=1,
+                    input_vocab_size=256,
+                    encoder_attention_type=attention_fn(),
+                    encoder_decoder_attention_type=attention_fn(),
+                )
+
+            max_len = 128
+            inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len)
+
+            steps = 1
+            eval_steps = 1
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            trainer_lib.train(
+                output_dir,
+                model=model,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+            )
+
+            # Read checkpoint
+            model_file = os.path.join(output_dir, "model.pkl.gz")
+
+            shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32)
+            shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32)
+
+            model_predict = model(mode="predict")
+            model_predict.init_from_file(
+                model_file, weights_only=True, input_signature=(shape1l, shape11)
+            )
+
+    def test_train_with_pure_lsh_attention(self, backend=fastmath.Backend.JAX):
+        self._check_train_with_lsh_attention(_pure_lsh_self_attention_fn, backend)
+
+    def test_train_with_mixed_lsh_attention(self, backend=fastmath.Backend.JAX):
+        self._check_train_with_lsh_attention(_mixed_lsh_self_attention_fn, backend)
+
+    @parameterized.parameters(BACKENDS)
+    def test_train_fills_in_missing_eval_metrics(self, backend):
+        """An additional eval stream gets the same metrics as the main eval task."""
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            n_classes = 4
+            steps = 2
+            eval_steps = 2
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes)
+            additional_eval_stream = trainer_lib.NamedStream(
+                # deliberately duplicating eval data
+                stream=inputs.eval_stream(1),
+                name="additional_eval_task",
+            )
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+                eval_frequency=1,
+                additional_eval_streams=[additional_eval_stream],
+            )
+
+            self.assertLen(loop.eval_tasks, 2)
+            eval_task_1, eval_task_2 = loop.eval_tasks
+            self.assertCountEqual(eval_task_1.metrics, eval_task_2.metrics)
+            self.assertCountEqual(eval_task_1.metric_names, eval_task_2.metric_names)
+
+    @parameterized.named_parameters(
+        ("_%s" % short_name(backend), backend) for backend in BACKENDS
+    )
+    def test_train_with_weights(self, backend):
+        """Training runs to `steps` when the raw stream also yields weights."""
+        with fastmath.use_backend(backend):
+            # Prepare model and inputs
+            n_classes = 4
+            steps = 2
+            eval_steps = 2
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes, with_weights=True)
+
+            # Train and evaluate
+            output_dir = self.create_tempdir().full_path
+            loop = trainer_lib.train(
+                output_dir,
+                model=model_fn,
+                inputs=inputs,
+                steps=steps,
+                eval_steps=eval_steps,
+            )
+
+            # Assert total train steps
+            self.assertEqual(loop.step, steps)
+
+    @parameterized.parameters(BACKENDS)
+    def test_reset_twice(self, backend):
+        """`Trainer.reset` followed by `evaluate` works for two output dirs."""
+        with fastmath.use_backend(backend):
+            n_classes = 4
+            model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))
+            inputs = _test_inputs(n_classes)
+
+            trainer = trainer_lib.Trainer(
+                model=model_fn,
+                loss_fn=tl.WeightedCategoryCrossEntropy(),
+                optimizer=trax_opt.SM3,
+                lr_schedule=lr.multifactor(),
+                inputs=inputs,
+            )
+
+            output_dir1 = self.create_tempdir(name="output_dir1").full_path
+            trainer.reset(output_dir1)
+            trainer.evaluate(1)
+            output_dir2 = self.create_tempdir(name="output_dir2").full_path
+            trainer.reset(output_dir2)
+            trainer.evaluate(1)
+
+    def test_tf_xla_forced_compile(self):
+        # TODO(wangpeng): re-enable this test
+        self.skipTest("Needs --config=cuda to pass this test")
+        old_flag = fastmath.tf.tf_xla_forced_compile_enabled()
+        fastmath.tf.set_tf_xla_forced_compile(True)
+        try:
+            self._test_train_eval_predict("tf")
+        finally:
+            # Restore the global flag even if the run fails, so that a failure
+            # here cannot change the setting for other tests once re-enabled.
+            fastmath.tf.set_tf_xla_forced_compile(old_flag)
+
+
+class EpochsTest(absltest.TestCase):
+    """Tests for `trainer_lib.epochs`."""
+
+    def test_cuts_epoch_when_total_steps_reached(self):
+        epoch_steps = trainer_lib.epochs(
+            total_steps=5, steps_to_skip=0, epoch_steps=[1, 2, 3]
+        )
+        self.assertEqual(list(epoch_steps), [1, 2, 2])
+
+    def test_skips_full_epoch(self):
+        epoch_steps = trainer_lib.epochs(
+            total_steps=4, steps_to_skip=2, epoch_steps=[2, 2]
+        )
+        self.assertEqual(list(epoch_steps), [2])
+
+    def test_skips_part_of_epoch(self):
+        epoch_steps = trainer_lib.epochs(
+            total_steps=4, steps_to_skip=1, epoch_steps=[2, 2]
+        )
+        self.assertEqual(list(epoch_steps), [1, 2])
+
+
+if __name__ == "__main__":
+    config.config_with_absl()
+    tf.compat.v1.enable_eager_execution()
+    absltest.main()
diff --git a/tests/supervised/training_test.py b/tests/supervised/training_test.py
new file mode 100644
index 000000000..f44dc3cf3
--- /dev/null
+++ b/tests/supervised/training_test.py
@@ -0,0 +1,761 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for supervised training: core classes and flows.""" + +import collections +import os +import time + +from absl.testing import absltest +from jax.config import config +import numpy as np + +from trax import data +from trax import fastmath +from trax import layers as tl +from trax import optimizers +from trax import shapes +from trax import test_utils +from trax.layers import base +from trax.models import transformer +from trax.supervised import callbacks +from trax.supervised import training + + +class TrainingTest(absltest.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + def test_loop_no_eval_task(self): + """Runs a training loop with no eval task(s).""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + training_session = training.Loop(model, [task]) + # Loop should initialize and run successfully, even with no eval task. + training_session.run(n_steps=5) + + def test_loop_checkpoint_low_metric(self): + """Runs a training loop that saves checkpoints for low metric values.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_metric = tl.L2Loss() + eval_task = training.EvalTask( + _very_simple_data(), [eval_metric], metric_names=["l2_loss"] + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + output_dir=tmp_dir, + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + checkpoint_low_metric="l2_loss", + ) + call_counter = collections.Counter() + loop.save_checkpoint = lambda name: call_counter.update([name]) + loop.run(n_steps=10) + + # Eval metric steadily descends, so low checkpoint triggered all 5 times. + # High checkpoint not defined, so never triggered. 
+ self.assertEqual(call_counter["model"], 5) + self.assertEqual(call_counter["lowest_l2_loss"], 5) + self.assertEqual(call_counter["highest_l2_loss"], 0) + + def test_loop_checkpoint_high_metric(self): + """Runs a training loop that saves checkpoints for high metric values.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_metric = tl.L2Loss() + eval_task = training.EvalTask( + _very_simple_data(), [eval_metric], metric_names=["l2_loss"] + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + output_dir=tmp_dir, + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + checkpoint_high_metric="l2_loss", + ) + call_counter = collections.Counter() + loop.save_checkpoint = lambda name: call_counter.update([name]) + loop.run(n_steps=10) + + # Eval metric steadily descends, so high checkpoint triggered only once. + # Low checkpoint not defined, so never triggered. 
+ self.assertEqual(call_counter["model"], 5) + self.assertEqual(call_counter["lowest_l2_loss"], 0) + self.assertEqual(call_counter["highest_l2_loss"], 1) + + def test_train_dense_layer(self): + """Trains a very simple network on a very simple task.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=15) + self.assertEqual(15, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(20, training_session.step) + + def test_loop_with_initialized_model(self): + """Check that loop does not re-initialize an already initialized model.""" + model = tl.Serial(tl.Dense(1)) + example_data = next(_very_simple_data()) + model.init(example_data) + w = model.weights[0][0] + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, loop.step) + self.assertEqual(loop.model.weights[0][0], w) + + def test_train_save_restore_dense(self): + """Saves and restores a checkpoint to check for equivalence.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + train_data = data.Serial( + lambda _: _very_simple_data(), data.CountAndSkip("simple_data") + ) + task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(0.0001)) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + 
[tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = tl.Serial(tl.Dense(1)) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + model, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + self.assertEqual(data.inputs.data_counters["simple_data"], 2) + data.inputs.data_counters["simple_data"] = 0 # reset manually + self.assertEqual(data.inputs.data_counters["simple_data"], 0) # check + model2, training_session2 = _make_model_and_session() + self.assertEqual(data.inputs.data_counters["simple_data"], 2) # restored + + x = np.ones((8, 1)) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertEqual(str(y1), str(y2)) + + training_session2.run(n_steps=1) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertNotEqual(str(y1), str(y2)) + + slots1 = training_session._trainer_per_task[0].slots + slots2 = training_session2._trainer_per_task[0].slots + np.testing.assert_array_equal(slots1, slots2) + + def test_train_save_restore_sharded(self): + """Saves and restores a sharded checkpoint to check for equivalence.""" + if fastmath.local_device_count() < 2: + return # multi-accelerator only + base.N_WEIGHTS_SHARDS = fastmath.local_device_count() + train_data = data.Serial( + lambda _: _very_simple_data(2, 2), data.CountAndSkip("simple_data") + ) + task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(0.0001)) + eval_task = training.EvalTask( + _very_simple_data(2, 2), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = 
tl.Serial(tl.Dense(2)) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + _, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + _, training_session2 = _make_model_and_session() + training_session2.run(n_steps=1) + base.N_WEIGHTS_SHARDS = 1 + + def test_train_save_restore_transformer(self): + """Saves and restores a checkpoint to check for equivalence.""" + vocab_size = 8 + task = training.TrainTask( + _very_simple_transformer_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_transformer_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = transformer.TransformerLM( + vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.0 + ) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + model, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + model2, training_session2 = _make_model_and_session() + + x = np.ones((2, 2)).astype(np.int32) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertEqual(str(y1), str(y2)) + + training_session2.run(n_steps=1) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertNotEqual(str(y1), str(y2)) + + def test_train_dense_layer_with_momentum(self): + """Trains with an optimizer that has slots / requires initialization.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), 
optimizers.Momentum(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["Momentum.L2Loss"], + ) + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=20) + self.assertEqual(20, training_session.step) + + def test_train_dense_layer_evals(self): + """Trains a very simple network on a very simple task, 2 epochs.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), [tl.L2Loss()] # deliberately re-using training data + ) + training_session = training.Loop( + model, [task], eval_tasks=[eval_task], eval_at=lambda step_n: False + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=10) + self.assertEqual(10, training_session.step) + training_session.run_evals() + self.assertEqual(10, training_session.step) # Unchanged + + def test_summaries_are_written(self): + """Training writes down metrics when writing is turned on.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + expected_train_metric_dir = os.path.join(tmp_dir, "train") + expected_eval_metric_dir = os.path.join(tmp_dir, "eval") + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertFalse( + os.path.isdir(directory), "Failed for directory %s." 
% directory + ) + training_session.run(n_steps=15) + time.sleep(1) # wait for the files to be closed + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertTrue( + os.path.isdir(directory), "Failed for directory %s." % directory + ) + self.assertEqual( + 1, _count_files(directory), "Failed for directory %s." % directory + ) + training_session.run(n_steps=5) + time.sleep(1) # wait for the files to be closed + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertEqual( + 2, _count_files(directory), "Failed for directory %s." % directory + ) + + def test_restores_step(self): + """Training restores step from directory where it saved it.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertEqual(4, loop2.step) + + def test_restores_memory_efficient_from_standard(self): + """Memory-efficient training restores from a standard trainer's checkpoint.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model = tl.Serial(tl.Dense(4), tl.Dense(1)) + task_std = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam(0.0001) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task_std], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + task_memeff = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam + ) + loop2 = training.Loop( + model, [task_memeff], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + loop2.run(2) + self.assertEqual(6, loop2.step) + + def test_restores_from_smaller_model(self): + """Training restores from a checkpoint created with smaller model.""" + 
self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model1 = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model1, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(2) + model2 = tl.Serial(tl.Dense(1), tl.Dense(1)) + loop2 = training.Loop(model2, [task], output_dir=tmp_dir) + self.assertEqual(2, loop2.step) + + def test_restore_fails_different_model(self): + """Restoring from a checkpoint created with a different model raises an error.""" + model1 = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model1, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(2) + model2 = tl.Serial(tl.Dense(2)) + with self.assertRaises(IndexError): + training.Loop(model2, [task], output_dir=tmp_dir) + + def test_restores_step_bfloat16(self): + """Training restores step from directory where it saved it, w/ bfloat16.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model = tl.Serial(tl.Dense(1, use_bfloat16=True)) + # We'll also use Adafactor with bfloat16 to check restoring bfloat slots. 
+ opt = optimizers.Adafactor(0.01, do_momentum=True, momentum_in_bfloat16=True) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), opt) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertEqual(4, loop2.step) + loop2.run(2) # check that continued training works + self.assertEqual(6, loop2.step) + + def test_restores_step_sharded(self): + """Training restores step from directory where it saved it, sharded.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), optimizers.SGD) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + use_memory_efficient_trainer=True, + ) + loop.run(4) + loop2 = training.Loop( + model, [task], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + self.assertEqual(4, loop2.step) + + def test_restores_step_sharded_bfloat16(self): + """Training restores step from where it saved it, sharded and bfloat16.""" + model = tl.Serial(tl.Dense(1, use_bfloat16=True)) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), optimizers.SGD) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + use_memory_efficient_trainer=True, + ) + loop.run(4) + loop2 = training.Loop( + model, [task], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + self.assertEqual(4, loop2.step) + loop2.run(2) # check that continued training works + self.assertEqual(6, loop2.step) + + def test_restores_history(self): + """Training restores history from directory where it saved it.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + 
eval_task = training.EvalTask( + _very_simple_data(), [tl.L2Loss()] # deliberately re-using training data + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertLen(loop2.history.modes, 2) + self.assertLen(loop2.history.metrics_for_mode("train"), 6) + self.assertLen(loop2.history.metrics_for_mode("eval"), 1) + for mode, metric in [ + ("train", "metrics/L2Loss"), + ("train", "training/learning_rate"), + ("train", "training/steps per second"), + ("train", "training/gradients_l2"), + ("train", "training/loss"), + ("train", "training/weights_l2"), + ("eval", "metrics/L2Loss"), + ]: + self.assertLen(loop2.history.get(mode, metric), 1) + self.assertEqual(2, loop2.history.get(mode, metric)[0][0]) + + def test_trains_on_two_tasks(self): + """Trains a very simple network on two very simple tasks.""" + model = tl.Serial(tl.Dense(3), tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + ) + training_session = training.Loop( + model, + tasks=(task, task), + eval_tasks=(eval_task, eval_task), + which_task=lambda step_n: step_n % 2, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=15) + self.assertEqual(15, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(20, training_session.step) + + def test_train_one_task_eval_two_tasks(self): + """Trains a very simple network on one task and evaluates on two tasks.""" + model = tl.Serial(tl.Dense(3), tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + export_prefix_1 = "eval_1" + eval_task_1 = training.EvalTask( + 
_very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + export_prefix=export_prefix_1, + ) + export_prefix_2 = "eval_2" + eval_task_2 = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + export_prefix=export_prefix_2, + ) + training_session = training.Loop( + model, + tasks=(task,), + eval_tasks=(eval_task_1, eval_task_2), + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(5, training_session.step) + export_prefixes = [task.export_prefix for task in training_session.eval_tasks] + self.assertCountEqual([export_prefix_1, export_prefix_2], export_prefixes) + + def test_can_predict_with_trained_model(self): + model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2))) + train_tasks, eval_tasks = [], [] + for output_dim in [1, 2]: + # The head we select from the model: 0 for output_dim 1 and 1 for 2. + head_index = output_dim - 1 + train_tasks.append( + training.TrainTask( + _very_simple_data(output_dim), + tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()), + optimizers.SGD(0.01), + ) + ) + eval_tasks.append( + training.EvalTask( + _very_simple_data(output_dim), # deliberately re-use training data + [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())], + ) + ) + tmp_dir = self.create_tempdir().full_path + training_session = training.Loop( + model, + tasks=train_tasks, + eval_tasks=eval_tasks, + checkpoint_at=lambda step_n: step_n == 1, + output_dir=tmp_dir, + which_task=lambda step_n: step_n % 2, + ) + training_session.run(n_steps=2) + + trained_model = training_session.eval_model + inp = next(_very_simple_data())[0] + out = trained_model(inp) + self.assertEqual( + shapes.signature(out), + (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))), + ) + + def test_train_memory_efficient(self): + """Trains a large network in a memory-efficient way.""" + # This test requires > 16GB RAM, only run on TPUs. 
It does pass on GPU + # and CPU when you run it locally, but it's too big for unit-testing. + ram_limited = True # Set to False to run this test locally. + if fastmath.global_device_count() == 1 and ram_limited: + return + + # Create the model. + n_layers = 16 # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram + model = tl.Serial( + tl.Embedding(9, 16 * 1024), + tl.Dup(), + [ + [tl.ReversibleHalfResidual(tl.Dense(16 * 1024)), tl.ReversibleSwap()] + for _ in range(n_layers) + ], + tl.Concatenate(), + tl.Dense(9), + ) + + # Create inputs. + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + + def _data_gen(): + while True: + yield labeled_batch + + # Run training. + loss_layer = tl.WeightedCategoryCrossEntropy() + task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor) + eval_task = training.EvalTask(_data_gen(), [tl.WeightedCategoryCrossEntropy()]) + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n == 2, + use_memory_efficient_trainer=True, + ) + self.assertEqual(0, loop.step) + loop.run(n_steps=2) + self.assertEqual(2, loop.step) + + def test_initializes_step_callbacks_with_loop_instance(self): + """Runs a training loop, asserting that callbacks are initialized.""" + + class ActualLoop: + # Wrapper object to make the Loop reference mutable. 
+ loop = None + + class TestCallback(callbacks.TrainingStepCallback): + def __init__(self, loop): + super().__init__(loop) + ActualLoop.loop = loop + + def call_at(self, step): + return False + + def on_step_begin(self, step): + del step + + def on_step_end(self, step): + del step + + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + expected_loop = training.Loop(model, [task], callbacks=[TestCallback]) + self.assertIs(ActualLoop.loop, expected_loop) + + def test_calls_step_callbacks(self): + """Runs a training loop, asserting that callbacks are called.""" + call_at_steps = [1, 3, 4] + begin_steps = [] + end_steps = [] + test_case = self + + class TestCallback(callbacks.TrainingStepCallback): + def call_at(self, step): + return step in call_at_steps + + def on_step_begin(self, step): + begin_steps.append(step) + + def on_step_end(self, step): + # Assert that on_step_begin() was called before. + test_case.assertIn(step, begin_steps) + end_steps.append(step) + + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + loop = training.Loop(model, [task], callbacks=[TestCallback]) + loop.run(n_steps=5) + + # Assert that the callback has been called at the appropriate steps. 
+ self.assertEqual(begin_steps, call_at_steps) + self.assertEqual(end_steps, call_at_steps) + + +def _very_simple_data(output_dim=1, input_dim=1): + """Returns stream of labeled data that maps small integers to constant pi.""" + inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch + inputs_batch = np.concatenate([inputs_batch] * input_dim, axis=1) + targets_batch = np.pi * np.ones((8, output_dim)) + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + while True: + yield labeled_batch + + +def _very_simple_transformer_data(): + """Returns stream of all-ones int32 batches: inputs (2, 2), targets (2, 2, 8).""" + inputs_batch = np.ones((2, 2)).astype(np.int32) + targets_batch = np.ones((2, 2, 8)).astype(np.int32) + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + while True: + yield labeled_batch + + +def _count_files(path): + """Returns number of files in a given directory.""" + return len( + [ + filename + for filename in os.listdir(path) + if os.path.isfile(os.path.join(path, filename)) + ] + ) + + +if __name__ == "__main__": + config.config_with_absl() + absltest.main() diff --git a/tests/tf_numpy/extensions/extensions_test.py b/tests/tf_numpy/extensions/extensions_test.py new file mode 100644 index 000000000..f49dc1151 --- /dev/null +++ b/tests/tf_numpy/extensions/extensions_test.py @@ -0,0 +1,1170 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tf numpy mathematical methods.""" +import functools +import itertools + +from tests.tf_numpy.jax.config import flags + +import jax +import numpy as np +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +import trax.tf_numpy.numpy as tf_np +from trax.tf_numpy import extensions + +FLAGS = flags.FLAGS + +flags.DEFINE_bool("requires_tpu", False, "Requires TPU.") + + +def generate_params_inputs_targets(num_examples=1000): + params = (tf_np.asarray(tf.constant(5.0)), tf_np.asarray(tf.constant(0.0))) + + params_true = (tf_np.asarray(tf.constant(3.0)), tf_np.asarray(tf.constant(2.0))) + + inputs = tf_np.asarray(tf.random.normal(shape=[num_examples])) + noise = tf_np.asarray(tf.random.normal(shape=[num_examples])) + targets = inputs * params_true[0] + params_true[1] + noise + + return params, params_true, inputs, targets + + +def loss_fn(params, inputs, targets): + predicted = params[0] * inputs + params[1] + loss = tf.reduce_mean(input_tensor=tf.square(predicted - targets)) + return tf_np.asarray(loss) + + +def train_step(params, inputs, targets, learning_rate=0.1): + grad_fn = extensions.grad(loss_fn) + grads = grad_fn(params, inputs, targets) + new_w = params[0] - (grads[0] * learning_rate) + new_b = params[1] - (grads[1] * learning_rate) + + return new_w, new_b + + +def uniform(rng, shape, dtype): + if np.issubdtype(dtype, np.integer): + minval = None + else: + minval = 0 + return tf_np.asarray(rng.uniform(shape=shape, dtype=dtype, minval=minval)) + + +def to_np(a): + return tf.nest.map_structure(tf_np.asarray, a) + + +def to_tf_fn(f): + return lambda *args: f(*to_np(args)) + + +def scan_reference(f, init, xs): + carry = init + ys = [] + for x in xs: + (carry, y) = f(carry, x) + ys.append(tf_np.reshape(y, (1,) + y.shape)) + ys = tf_np.concatenate(ys, 0) + return carry, ys + + +def spec(*args): + return tf.TensorSpec(args, tf.float32) + + +class ExtensionsTest(tf.test.TestCase, parameterized.TestCase): + def __init__(self, 
methodName="runTest"): # pylint: disable=invalid-name + super().__init__(methodName) + physical_devices = tf.config.experimental.list_physical_devices("CPU") + tf.config.experimental.set_virtual_device_configuration( + physical_devices[0], + [ + tf.config.experimental.VirtualDeviceConfiguration(), + tf.config.experimental.VirtualDeviceConfiguration(), + ], + ) + if extensions.tpu_devices(): + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local") + tf.tpu.experimental.initialize_tpu_system(resolver) + + def _hasGPU(self): + physical_devices = tf.config.experimental.list_physical_devices("GPU") + return physical_devices + + def testCustomGrad(self): + """Test for custom_grad.""" + x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) + y_shape = tf.TensorShape([]) + dtype = np.float32 + scale1 = 5.0 + scale2 = 6.0 + + def fwd(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + @extensions.custom_grad + def f(a, b): + y = fwd(a, b) + + def vjp(dy): + return dy * scale1 * a, dy * scale2 * b + + return y, vjp + + rng = tf.random.Generator.from_seed(1234) + x, dy = tf.nest.map_structure( + lambda shape: uniform(rng, shape, dtype), [x_shape, y_shape] + ) + expected_y = fwd(*x) + expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1]) + y, vjp = extensions.vjp(f, *x) + dx = vjp(dy) + self.assertAllClose(expected_y, y) + self.assertAllClose(expected_dx, dx) + + @parameterized.named_parameters( + [ + ( # pylint: disable=g-complex-comprehension + ("_%s_%s_%s" % (decorator_id, x_struct, y_struct)) + .replace(" ", "") + .replace("None", ""), + decorator, + x_struct, + y_struct, + ) + for y_struct in [[None, ()], (None, (), [], (None, ((), None)))] + for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))] + for decorator_id, decorator in enumerate([lambda f: f, extensions.jit]) + ] + ) + def testCustomGradStructure(self, decorator, x_struct, y_struct): + """Tests that custom_grad can handle structured inputs/outputs.""" + + def 
zeros(x): + return tf.nest.map_structure(lambda _: tf_np.zeros([], np.float32), x) + + def get_struct(x): + return tf.nest.map_structure(lambda _: None, x) + + @extensions.custom_grad + def f(*x): + del x + + def vjp(dy): + self.assertEqual(y_struct, get_struct(dy)) + return zeros(x_struct) + + return zeros(y_struct), vjp + + x, dy = zeros([x_struct, y_struct]) + + @decorator + def run(x, dy): + y, vjp = extensions.vjp(f, *x) + dx = vjp(dy) + return dx, y + + dx, y = run(x, dy) + self.assertEqual(x_struct, get_struct(dx)) + self.assertEqual(y_struct, get_struct(y)) + + @parameterized.named_parameters( + [("_%s" % has_aux, has_aux) for has_aux in [True, False]] + ) + def testVjp(self, has_aux): + x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) + y_shape = tf.TensorShape([]) + dtype = np.float32 + + def f(a, b): + y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + if has_aux: + return y, tf_np.asarray(1) + else: + return y + + rng = tf.random.Generator.from_seed(1234) + x, dy_list = tf.nest.map_structure( + lambda shape: uniform(rng, shape, dtype), [x_shape, [y_shape] * 2] + ) + tf_x = x + outputs = extensions.vjp(f, *x, has_aux=has_aux) + if has_aux: + y, vjp, aux = outputs + else: + y, vjp = outputs + with tf.GradientTape(persistent=True) as tape: + tape.watch(tf_x) + outputs = f(*x) + if has_aux: + expected_y, expected_aux = outputs + self.assertAllClose(expected_aux, aux) + else: + expected_y = outputs + self.assertAllClose(expected_y, y) + for dy in dy_list: + expected_dx = tape.gradient(expected_y, tf_x, output_gradients=dy) + self.assertAllClose(expected_dx, vjp(dy)) + + def testGrad(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + g = extensions.grad(f) + + def compare(a, b): + with tf.GradientTape() as tape: + tape.watch(a) + r = f(a, b) + expected = tape.gradient(r, a) + self.assertAllEqual(expected, g(a, b)) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + compare(a, b) + + def 
testGradNonArrayOutput(self): + def f(_): + return 1.0 + + g = extensions.grad(f) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be an ndarray" + ): + g(tf_np.asarray(1.0)) + + def testGradNonScalarOutput(self): + def f(a): + return a + + g = extensions.grad(f) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be a scalar" + ): + g(tf_np.asarray([1.0, 2.0])) + + @extensions.jit + def g_jitted(a): + return extensions.grad(f)(a) + + g_jitted(tf_np.asarray(1.0)) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be a scalar" + ): + g_jitted(tf_np.asarray([1.0, 2.0])) + + def testJit(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_jitted = extensions.jit(f) + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + self.assertAllClose(f(a, b), f_jitted(a, b)) + # Call again since the code path is different on second call + self.assertAllClose(f(a, b), f_jitted(a, b)) + + def testJitNoUnnecessaryTracing(self): + def num_traces(f): + return len(f.tf_function._list_all_concrete_functions_for_serialization()) + + def check_trace_only_once(arg1, arg2): + @extensions.jit + def f(a): + return a + 1 + + self.assertAllEqual(0, num_traces(f)) + f(arg1) + self.assertAllEqual(1, num_traces(f)) + f(arg2) + self.assertAllEqual(1, num_traces(f)) + + check_trace_only_once(1, 2) + check_trace_only_once(1.1, 2.1) + check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) + check_trace_only_once( + tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2) + ) + + def _testEvalOnShapes(self, transformer, allow_static_outputs): + + # A class that's not convertable to tensor + class Thing: + def __init__(self, value): + self.value = value + + def f(a, b, reverse=False): + res = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + res = (res, 10) + if allow_static_outputs: + res = res + (Thing(20),) + if reverse: + res = tuple(reversed(res)) + return res + + f_prime = 
transformer( + f, static_argnums=(2,), allow_static_outputs=allow_static_outputs + ) + shape = [10] + dtype = np.float16 + a = tf_np.zeros(shape=shape, dtype=dtype) + b = tf_np.zeros(shape=shape, dtype=dtype) + expected, *_ = f(a, b) + got = f_prime(a, b) + + def check(got): + self.assertIsInstance(got[0], (tf.TensorSpec, tf_np.ndarray)) + self.assertAllEqual(expected.shape, got[0].shape) + self.assertAllEqual(expected.dtype, got[0].dtype) + if allow_static_outputs: + self.assertIsInstance(got[1], int) + self.assertEqual(10, got[1]) + self.assertIsInstance(got[2], Thing) + self.assertEqual(20, got[2].value) + else: + self.assertIsInstance(got[1], (tf.TensorSpec, tf_np.ndarray)) + self.assertAllEqual((), got[1].shape) + + check(got) + # Call again since the code path is different on second call + got = f_prime(a, b) + check(got) + # Retrace and check again + got = f_prime(a, b, True) + check(tuple(reversed(got))) + got = f_prime(a, b, True) + check(tuple(reversed(got))) + + @parameterized.named_parameters(("_%s" % b, b) for b in [False, True]) + def testEvalOnShapes(self, allow_static_outputs): + self._testEvalOnShapes(extensions.eval_on_shapes, allow_static_outputs) + + def testEvalOnShapesNested(self): + transformer = functools.partial( + extensions.eval_on_shapes, allow_static_outputs=True + ) + + @transformer + def outer(): + @transformer + def inner(): + return 1 + + return inner() + 2 + + r = outer() + self.assertIsInstance(r, int) + self.assertEqual(3, r) + + def testJitOfEvalOnShapes(self): + """Tests that eval_on_shapes can be called within jit.""" + + def transformer(f, **kwargs): + def f_prime(*args): + res = extensions.eval_on_shapes(f, **kwargs)(*args) + return tf.nest.map_structure( + lambda x: tf_np.zeros(x.shape, x.dtype), res + ) + + return extensions.jit(f_prime, kwargs.get("static_argnums", ())) + + self._testEvalOnShapes(transformer, False) + + def testEvalOnShapesNoUnnecessaryTracing(self): + def num_traces(f): + return 
len(f._tf_function._list_all_concrete_functions_for_serialization()) + + def check_trace_only_once(arg1, arg2): + @extensions.eval_on_shapes + def f(a): + return a + 1 + + self.assertAllEqual(0, num_traces(f)) + f(arg1) + self.assertAllEqual(1, num_traces(f)) + f(arg2) + self.assertAllEqual(1, num_traces(f)) + + check_trace_only_once(1, 2) + check_trace_only_once(1.1, 2.1) + check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) + check_trace_only_once( + tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2) + ) + + @parameterized.parameters( + { + "lhs_np": np.ones((5, 3)), + "rhs_np": np.ones((3, 2)), + "dims": (((1,), (0,)), ((), ())), + }, + { + "lhs_np": np.ones((5, 3)), + "rhs_np": np.ones((5, 3)), + "dims": (((0, 1), (0, 1)), ((), ())), + }, + { + "lhs_np": np.ones((5, 3, 2)), + "rhs_np": np.ones((2, 3, 2)), + "dims": (((1, 2), (1, 0)), ((), ())), + }, + { + "lhs_np": np.ones((6, 5, 3)), + "rhs_np": np.ones((6, 3, 2)), + "dims": (((2,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((6, 3, 5)), + "rhs_np": np.ones((6, 3, 2)), + "dims": (((1,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((5, 3, 2, 2)), + "rhs_np": np.ones((5, 2, 2, 6)), + "dims": (((2, 3), (1, 2)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((2, 2, 5, 3)), + "rhs_np": np.ones((2, 2, 3, 2)), + "dims": (((3,), (2,)), ((0, 1), (0, 1))), + }, + { + "lhs_np": np.ones((2, 2, 5, 2)), + "rhs_np": np.ones((2, 2, 3, 2)), + "dims": (((3,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((2, 2, 5, 3, 3)), + "rhs_np": np.ones((2, 3, 2, 3, 2)), + "dims": (((4,), (1,)), ((0,), (0,))), + }, + ) + def test_tf_dot_general(self, lhs_np, rhs_np, dims): + ans = jax.lax.dot_general(lhs_np, rhs_np, dims) + result = extensions.tf_dot_general(lhs_np, rhs_np, dims) + self.assertAllClose(result, np.array(ans)) + + @parameterized.named_parameters( + [ + ( + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" # pylint: disable=g-complex-comprehension + "_lhs_dilation={}_rhs_dilation={}" + 
"_feature_group_count={}_batch_group_count={}_dims={}" + "_perms={}".format( + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + ",".join(dimension_numbers), + perms, + ), + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + dimension_numbers, + perms, + ) + for batch_group_count, feature_group_count in [(1, 1)] + for lhs_shape, rhs_shape in [ + ( + (b * batch_group_count, i * feature_group_count, 9, w), + (j * feature_group_count * batch_group_count, i, 4, 5), + ) + for w in [0, 10] + for b, i, j in itertools.product([2, 3], repeat=3) + ] + for strides in [(1, 1), (2, 1)] + for padding in ["SAME"] + for lhs_dilation, rhs_dilation in [(None, (1, 1))] + for dimension_numbers, perms in [ + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) + ] + ] + ) + def testConvGeneralDilated( + self, + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + dimension_numbers, + perms, + ): + lhs_perm, rhs_perm = perms # permute to compatible shapes + + lhs = np.transpose(np.ones(lhs_shape), lhs_perm) + rhs = np.transpose(np.ones(rhs_shape), rhs_perm) + + jax_conv = jax.lax.conv_general_dilated( + lhs, + rhs, + strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count, + ) + + tf_conv = extensions.tf_conv_general_dilated( + lhs, + rhs, + strides, + padding, + None, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count, + ) + + self.assertAllClose(tf_conv, tf_np.asarray(jax_conv)) + + def testConv(self): + y = extensions.conv( + np.ones([5, 320, 480, 3], dtype=np.float32), + np.ones([3, 4, 3, 11], dtype=np.float32), + [1, 1], + "SAME", + ("NHWC", "HWIO", "NHWC"), + ) + self.assertAllClose(y.shape, [5, 320, 480, 11]) + self.assertAllClose( + y, + tf.nn.conv2d( + 
input=tf.ones([5, 320, 480, 3], dtype=tf.float32), + filters=tf.ones([3, 4, 3, 11], dtype=tf.float32), + strides=1, + padding="SAME", + ), + ) + + def testAvgPool(self): + y = extensions.avg_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") + self.assertAllEqual( + y, + tf.nn.pool( + input=tf.ones([5, 320, 480, 3]), + window_shape=[3, 5], + pooling_type="AVG", + padding="VALID", + strides=[2, 3], + ), + ) + + def testMaxPool(self): + y = extensions.max_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") + self.assertAllEqual( + y, + tf.nn.pool( + input=tf.ones([5, 320, 480, 3]), + window_shape=[3, 5], + pooling_type="MAX", + padding="VALID", + strides=[2, 3], + ), + ) + + def assertDTypesEqual(self, a, b): + get_dtype = lambda t: t.dtype + self.assertEqual( + tf.nest.map_structure(get_dtype, a), tf.nest.map_structure(get_dtype, b) + ) + + @parameterized.named_parameters( + ( + f"_{jit_scan}_{jit_f}", + jit_scan, + jit_f, + ) # pylint: disable=g-complex-comprehension + for jit_f in [False, True] + for jit_scan in ["no", "no_xla", "xla_forced_compile"] + ) + def testScanImpl(self, jit_scan, jit_f): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c, a): + assert a.shape == (3,) + assert c.shape == (4,) + b = tf_np.cos( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.cos(c)) + + tf_np.sum(tf_np.tan(d)) + ) + c = tf_np.sin(c * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + return c, b + + if jit_f: + f = extensions.jit(f) + + if jit_scan == "no_xla": + scan = extensions.jit(extensions.scan, static_argnums=(0,)) + elif jit_scan == "xla_forced_compile": + scan = extensions.jit( + extensions.scan, static_argnums=(0,), xla_forced_compile=True + ) + else: + scan = extensions.scan + + xs = rng.randn(5, 3) + c = rng.randn(4) + + ans = scan(f, c, xs) + expected = scan_reference(f, c, xs) + if jit_scan == "xla_forced_compile": + # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so + # we canonicalize them 
to lists here. + expected = list(expected) + ans = list(ans) + self.assertDTypesEqual(expected, ans) + self.assertAllClose(expected, ans) + + def testScanStruct(self): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c_g_i, a_e_h): + c_g, i = c_g_i + c, g = c_g + a, e_h = a_e_h + e, h = e_h + assert a.shape == (3,) + assert e.shape == () # pylint: disable=g-explicit-bool-comparison + assert c.shape == (4,) + assert g.shape == (2,) + assert i is None + assert h is None + b = tf_np.cos( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.cos(c)) + + tf_np.sum(tf_np.tan(d)) + ) + f = tf_np.cos(a) + c = tf_np.sin(c * b) + g = tf_np.sin(g * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + assert f.shape == (3,) + return [(c, g), i], (b, [f, h]) + + xs = (rng.randn(5, 3), [rng.randn(5), None]) + init = [(rng.randn(4), rng.randn(2)), None] + + c_g_i, b_f_h = extensions.scan(f, init, xs) + self.assertIsInstance(c_g_i, list) + self.assertIsInstance(b_f_h, tuple) + c_g, i = c_g_i + c, g = c_g + self.assertIsInstance(c_g, tuple) + self.assertEqual((4,), c.shape) + self.assertEqual((2,), g.shape) + self.assertIsNone(i) + b, f_h = b_f_h + f, h = f_h + self.assertIsInstance(f_h, list) + self.assertEqual((5,), b.shape) + self.assertEqual((5, 3), f.shape) + self.assertIsNone(h) + + @parameterized.named_parameters( + ( + f"_{jit_grad}_{jit_scan}_{jit_f}", + jit_grad, + jit_scan, + jit_f, + ) # pylint: disable=g-complex-comprehension + for jit_f in [False, True] + for jit_scan in ["no", "no_xla", "xla_forced_compile"] + for jit_grad in ["no", "no_xla", "xla_forced_compile"] + ) + def testScanGrad(self, jit_grad, jit_scan, jit_f): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c, a): + assert a.shape == (3,) + assert c.shape == (4,) + b = ( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.sin(c)) + + tf_np.sum(tf_np.sin(d)) + ) + c = tf_np.sin(c * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + return c, b + + if 
jit_f: + f = extensions.jit(f) + + if jit_scan == "no_xla": + scan = extensions.jit(extensions.scan, static_argnums=(0,)) + elif jit_scan == "xla_forced_compile": + # TODO(b/187107596): Remove `skipTest` + self.skipTest( + "Taking gradients of `jit(scan, experimental_compile=True)` triggers " + "'Support for TensorList crossing the XLA/TF boundary is not " + "implemented' error" + ) + # `xla_forced_compile=True` doesn't support gradients, so we use + # `experimental_compile=True`. + scan = extensions.jit( + extensions.scan, static_argnums=(0,), experimental_compile=True + ) + else: + scan = extensions.scan + + xs = tf_np.asarray(rng.randn(5, 3)) + c = tf_np.asarray(rng.randn(4)) + + def losses(scan, c, xs): + c, ys = scan(f, c, xs) + return tf_np.concatenate( + tf.nest.flatten( + tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]), (c, ys)) + ) + ) + + def loss(scan, c, xs): + return tf_np.sum(losses(scan, c, xs)) + + def grad_origin(c, xs): + return extensions.grad(functools.partial(loss, scan))(c, xs) + + if jit_grad == "no_xla": + grad_jit = extensions.jit(grad_origin) + elif jit_grad == "xla_forced_compile": + grad_jit = extensions.jit(grad_origin, xla_forced_compile=True) + else: + grad_jit = grad_origin + + ans = grad_jit(c, xs) + expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs) + self.assertDTypesEqual(expected, ans) + self.assertAllClose(expected, ans) + + theoretical, numerical = tf.test.compute_gradient( + to_tf_fn(functools.partial(losses, scan)), (c, xs) + ) + self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4) + + @parameterized.named_parameters( + (f"_{i}", *args) # pylint: disable=g-complex-comprehension + for i, args in enumerate( + [ + ( + lambda c, x: (c + 1, tf_np.sum(c + x, 0)), + [spec(2), spec(4, 3, 2)], + [spec(2), spec(4, 2)], + ), + ( + lambda c, x: (c + 1, tf_np.sum(c + x, 0)), + [spec(2), spec(0, 3, 2), 0], + [spec(2), spec(0, 2)], + ), + ] + ) + ) + def testScanShape(self, f, inputs, 
expected_outputs): + outputs = extensions.eval_on_shapes( + functools.partial(extensions.scan, f), static_argnums=(2,) + )(*inputs) + self.assertAllEqual(expected_outputs, outputs) + + def testMap(self): + shape = [2, 3] + dtype = tf_np.int32 + xs1 = tf_np.zeros(shape, dtype) + xs2 = tf_np.ones(shape, dtype) + ys_expected = [xs2 + 10, xs1 + 20] + + def f(x): + self.assertIsInstance(x, tuple) + for a in x: + self.assertEqual(a.shape, shape[1:]) + x1, x2 = x + return [x2 + 10, x1 + 20] + + ys = extensions.tf_map(f, (xs1, xs2)) + self.assertIsInstance(ys, list) + self.assertAllClose(ys, ys_expected) + + def testPrng(self): + self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123)) + + def testUniform(self): + minval = 0.43 + maxval = 3.10 + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.uniform(123, shape, minval=minval, maxval=maxval) + self.assertAllClose((minval + maxval) / 2.0, np.mean(outputs), atol=atol) + + def testNormal(self): + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.normal(123, shape) + self.assertAllClose(0, np.mean(outputs), atol=atol) + self.assertAllClose(1, np.std(outputs), atol=atol) + + def testBernoulli(self): + mean = 0.23 + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.bernoulli(123, mean, shape) + self.assertAllClose(mean, np.mean(outputs), atol=atol) + + def testBernoulliWrongShape(self): + mean = [0.1, 0.2] + shape = [3] + with self.assertRaisesIncompatibleShapesError(): + extensions.bernoulli(123, mean, shape) + + def testDatasetAsNumpy(self): + arrs = extensions.dataset_as_numpy([tf.constant([1, 2]), tf.constant([3, 4])]) + for a in arrs: + self.assertIsInstance(a, tf_np.ndarray) + with self.assertRaisesWithPredicateMatch( + ValueError, + r"dataset_as_numpy must be run in eager mode outside tf.function", + ): + + @tf.function + def f(): + return extensions.dataset_as_numpy([tf.constant([1, 2])]) + + f() + + def _get_two_devices(self, require_same_type=False): + tpus = 
extensions.tpu_devices() + if FLAGS.requires_tpu: + if len(tpus) == 2: + res = tpus + else: + raise ValueError( + "This test requires 2 TPU cores but %s are found" % len(tpus) + ) + else: + if len(tpus) == 2: + res = tpus + elif self._hasGPU() and not require_same_type: + res = ("CPU:0", "GPU:0") + else: + res = ("CPU:0", "CPU:1") + return res + + def testPmap(self): + devices = self._get_two_devices() + + @functools.partial(extensions.pmap, devices=devices) + def return_three(f): + return f, f + 1.0, f + 2.0 + + result = return_three(tf.ones((2, 20))) + # The function returned 3 items, so we got 3 items back. + self.assertLen(result, 3) + + # Each of the items should be a ShardedNdarray that when converted to tensor + # should produce a tensor of shape (2, 20) + converted = tf.nest.map_structure(tf.convert_to_tensor, result) + + self.assertLen(result, 3) + + self.assertAllEqual(converted[0].shape, converted[1].shape) + self.assertAllEqual(converted[0].shape, converted[2].shape) + + self.assertAllEqual(converted[0], tf.ones((2, 20))) + self.assertAllEqual(converted[1], 1 + tf.ones((2, 20))) + self.assertAllEqual(converted[2], 2 + tf.ones((2, 20))) + + @functools.partial(extensions.pmap, devices=devices) + def return_one(f): + return f + 2.0 + + result = return_one(tf.ones((2, 20))) + + # Only a single item is returned, so we can convert it directly. + converted = tf.convert_to_tensor(value=result) + self.assertAllEqual(converted, 2 + tf.ones((2, 20))) + + @functools.partial(extensions.pmap, devices=devices) + def return_list(f): + return [f + 2.0] + + result = return_list(tf.ones((2, 20))) + + # A singleton list is returned. 
+ self.assertLen(result, 1) + converted = tf.convert_to_tensor(value=result[0]) + self.assertAllEqual(converted, 2 + tf.ones((2, 20))) + + def testGradSimpleModel(self): + params, params_true, inputs, targets = generate_params_inputs_targets() + + for _ in range(50): + params = train_step(params, inputs, targets) + + # This is not trained super well, but it usually gets "close". + self.assertAllClose(params[0], params_true[0], atol=1e-1) + self.assertAllClose(params[1], params_true[1], atol=1e-1) + + # NOTE: Compare to testGradSimpleModel to see the differences when pmapping. + def testPmapSimpleModel(self): + devices = self._get_two_devices(require_same_type=True) + n_devices = len(devices) + + params, params_true, inputs, targets = generate_params_inputs_targets() + + def _train_and_reduce(params, inputs, targets, learning_rate=0.1): + new_w, new_b = train_step(params, inputs, targets, learning_rate) + + return ( + extensions.psum(new_w) / n_devices, + extensions.psum(new_b) / n_devices, + ) + + train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices) + + def replicate(x, num_devices=2): + return tf_np.broadcast_to(x, (num_devices,) + x.shape) + + params = tf.nest.map_structure(replicate, params) + + def reshape(x, num_devices=2): + x_shape = list(x.shape) + batch_size = x_shape[0] + batch_size_per_device = batch_size // num_devices + + # New shape. + new_shape_prefix = [num_devices, batch_size_per_device] + return tf_np.reshape(x, new_shape_prefix + x_shape[1:]) + + inputs = tf.nest.map_structure(reshape, inputs) + targets = tf.nest.map_structure(reshape, targets) + + for _ in range(50): + params = train_step_pmapped(params, inputs, targets) + + # PMAP returns sharded tensors. + + # Since the inputs are identical, the returned tensors should be identical + self.assertAllClose(params[0][0], params[0][1]) + self.assertAllClose(params[1][0], params[1][1]) + + # This is not trained super well, but it usually gets "close". 
+ self.assertAllClose(params[0][0], params_true[0], atol=1e-1) + self.assertAllClose(params[1][0], params_true[1], atol=1e-1) + + def testPsum(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f) + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_sum, devices=devices) + result = pmapped(data) + + self.assertAllClose(result[0], 4) + self.assertAllClose(result[1], 4) + + def testPsumStruct(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(a): + a = extensions.psum(a) + tf.nest.map_structure(lambda x: self.assertIsInstance(x, tf_np.ndarray), a) + return a + + data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)] + pmapped = extensions.pmap(reduce_sum, devices=devices) + result = pmapped(data) + + self.assertIsInstance(result[0][0], tf_np.ndarray) + self.assertIsInstance(result[0][1], tf_np.ndarray) + self.assertIsInstance(result[1][0], tf_np.ndarray) + self.assertIsInstance(result[1][1], tf_np.ndarray) + self.assertAllClose(result[0][0], 4) + self.assertAllClose(result[0][1], 4) + self.assertAllClose(result[1][0], 6) + self.assertAllClose(result[1][1], 6) + + def testPmean(self): + if extensions.tpu_devices(): + self.skipTest("pmean for TPU is not supported yet") + devices = self._get_two_devices(require_same_type=True) + + def reduce_mean(f): + return extensions.pmean(f) + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_mean, devices=devices) + result = pmapped(data) + + self.assertAllClose(result[0], 2) + self.assertAllClose(result[1], 2) + + def testAxisName(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f, axis_name="foo") + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) + pmapped(data) + + def testWrongAxisName(self): + 
devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f, axis_name="bar") + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + with self.assertRaisesWithPredicateMatch( + ValueError, r"axis_name (.*) is not equal to that of the surrounding" + ): + pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) + pmapped(data) + + def testNoNestedPmap(self): + devices = self._get_two_devices(require_same_type=True) + + def f(x): + return x + 1.0 + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + with self.assertRaisesWithPredicateMatch( + ValueError, r"Nested pmap is not supported" + ): + f = extensions.pmap(f, devices=devices) + f = extensions.pmap(f, devices=devices) + f(data) + + def testVmap(self): + fn1 = extensions.vmap(lambda z: z * z) + + x = tf_np.arange(10) + self.assertAllClose(x * x, fn1(x)) + + y = tf.range(10) + np_y = tf_np.asarray(y) + output = fn1(y) + self.assertIsInstance(output, tf_np.ndarray) + self.assertAllClose(np_y * np_y, output) + + fn2 = extensions.vmap(lambda x, y: x + y) + x = tf_np.random.randn(10, 3) + y = tf_np.random.randn(10, 2, 3) + self.assertAllClose(tf_np.expand_dims(x, 1) + y, fn2(x, y)) + + def testRemat(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.grad(f_remat)(a, b) + expected = extensions.grad(f)(a, b) + self.assertAllClose(actual, expected) + + def testRematLambdaFunction(self): + f = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.grad(f_remat)(a, b) + expected = extensions.grad(f)(a, b) + self.assertAllClose(actual, expected) + + def testRematJit(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = 
extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.jit(extensions.grad(f_remat))(a, b) + expected = extensions.jit(extensions.grad(f))(a, b) + self.assertAllClose(actual, expected) + + def testRematJitXla(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.jit(extensions.grad(f_remat), xla_forced_compile=True)(a, b) + expected = extensions.jit(extensions.grad(f), xla_forced_compile=True)(a, b) + self.assertAllClose(actual, expected) + + actual = extensions.jit(extensions.grad(f_remat), experimental_compile=True)( + a, b + ) + expected = extensions.jit(extensions.grad(f), experimental_compile=True)(a, b) + self.assertAllClose(actual, expected) + + def testStaticStopGradient(self): + self.assertEqual(extensions.stop_gradient(5.0), 5.0) + self.assertEqual(type(extensions.stop_gradient(5.0)), type(5.0)) + + self.assertEqual(extensions.stop_gradient(tf_np.asarray(5.0)), 5.0) + self.assertNotEqual( + type(extensions.stop_gradient(tf_np.asarray(5.0))), type(5.0) + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/jax/config.py b/tests/tf_numpy/jax/config.py new file mode 100644 index 000000000..4c68441a5 --- /dev/null +++ b/tests/tf_numpy/jax/config.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + + +def bool_env(varname: str, default: bool) -> bool: + """Read an environment variable and interpret it as a boolean. + + True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; + false values are 'n', 'no', 'f', 'false', 'off', and '0'. + + Args: + varname: the name of the variable + default: the default boolean value + Raises: ValueError if the environment variable is anything else. 
+ """ + val = os.getenv(varname, str(default)) + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r for environment %r" % (val, varname)) + + +class Config(object): + def __init__(self): + self.values = {} + self.meta = {} + self.FLAGS = NameSpace(self.read) + self.use_absl = False + + def update(self, name, val): + if self.use_absl: + setattr(self.absl_flags.FLAGS, name, val) + else: + self.check_exists(name) + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + self.values[name] = val + + def read(self, name): + if self.use_absl: + return getattr(self.absl_flags.FLAGS, name) + else: + self.check_exists(name) + return self.values[name] + + def add_option(self, name, default, opt_type, meta_args, meta_kwargs): + if name in self.values: + raise Exception("Config option {} already defined".format(name)) + self.values[name] = default + self.meta[name] = (opt_type, meta_args, meta_kwargs) + + def check_exists(self, name): + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + + def DEFINE_bool(self, name, default, *args, **kwargs): + self.add_option(name, default, bool, args, kwargs) + + def DEFINE_integer(self, name, default, *args, **kwargs): + self.add_option(name, default, int, args, kwargs) + + def DEFINE_string(self, name, default, *args, **kwargs): + self.add_option(name, default, str, args, kwargs) + + def DEFINE_enum(self, name, default, *args, **kwargs): + self.add_option(name, default, "enum", args, kwargs) + + def config_with_absl(self): + # Run this before calling `app.run(main)` etc + from absl import app, flags as absl_flags + + self.use_absl = True + self.absl_flags = absl_flags + absl_defs = { + bool: absl_flags.DEFINE_bool, + int: absl_flags.DEFINE_integer, + str: absl_flags.DEFINE_string, + "enum": 
absl_flags.DEFINE_enum, + } + + for name, val in self.values.items(): + flag_type, meta_args, meta_kwargs = self.meta[name] + absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) + + app.call_after_init(lambda: self.complete_absl_config(absl_flags)) + + def complete_absl_config(self, absl_flags): + for name, _ in self.values.items(): + self.update(name, getattr(absl_flags.FLAGS, name)) + + def parse_flags_with_absl(self): + global already_configured_with_absl + if not already_configured_with_absl: + import absl.flags + + self.config_with_absl() + absl.flags.FLAGS(sys.argv, known_only=True) + self.complete_absl_config(absl.flags) + already_configured_with_absl = True + + +class NameSpace(object): + def __init__(self, getter): + self._getter = getter + + def __getattr__(self, name): + return self._getter(name) + + +config = Config() +flags = config +FLAGS = flags.FLAGS + +already_configured_with_absl = False + +flags.DEFINE_bool( + "jax_enable_checks", + bool_env("JAX_ENABLE_CHECKS", False), + help="Turn on invariant checking (core.skip_checks = False)", +) + +flags.DEFINE_bool( + "tf_numpy_additional_tests", True, "Run tests added specifically for TF numpy" +) diff --git a/tests/tf_numpy/jax/lax_numpy_einsum_test.py b/tests/tf_numpy/jax/lax_numpy_einsum_test.py new file mode 100644 index 000000000..f578604c3 --- /dev/null +++ b/tests/tf_numpy/jax/lax_numpy_einsum_test.py @@ -0,0 +1,372 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class EinsumTest(jtu.TestCase):
    """Compares `jnp.einsum` with `np.einsum` over many subscript specs."""

    def _check(self, s, *ops):
        """Asserts that NumPy and `jnp` agree for subscript spec `s`.

        Args:
          s: einsum subscript string.
          *ops: operand arrays, forwarded unchanged to both implementations.
        """
        reference = np.einsum(s, *ops)
        candidate = jnp.einsum(s, *ops)
        self.assertAllClose(
            reference, candidate, check_dtypes=True, atol=1e-4, rtol=1e-4
        )

    # Operands are always drawn left to right from a single `self.rng()` so the
    # random stream is consumed in the same order as the operand list.

    def test_three_operands_1(self):
        rng = self.rng()
        self._check("i,j,k->ijk", rng.randn(3), rng.randn(4), rng.randn(5))

    def test_three_operands_2(self):
        rng = self.rng()
        self._check("i,j,k->ijk", rng.randn(3), rng.randn(4), rng.randn(5))

    def test_two_operands_1(self):
        rng = self.rng()
        self._check("ij,j->i", rng.randn(3, 4), rng.randn(4))

    def test_two_operands_2(self):
        rng = self.rng()
        self._check("ijk,j->i", rng.randn(3, 4, 5), rng.randn(4))

    def test_two_operands_3(self):
        rng = self.rng()
        self._check("iji,i->j", rng.randn(3, 4, 3), rng.randn(3))

    def test_two_operands_4(self):
        rng = self.rng()
        self._check("ij,ij->", rng.randn(3, 4), rng.randn(3, 4))

    def test_two_operands_5(self):
        rng = self.rng()
        self._check("nij,jk->nik", rng.randn(10, 2, 3), rng.randn(3, 4))

    def test_two_operands_6(self):
        # based on https://github.com/google/jax/issues/37#issuecomment-448572187
        rng = self.rng()
        self._check("sa,shb->shab", rng.randn(2, 1), rng.randn(2, 3, 4))

    def test_one_operand_1(self):
        self._check("ijk->j", self.rng().randn(3, 4, 5))

    def test_one_operand_2(self):
        self._check("ijk->kij", self.rng().randn(3, 4, 5))

    def test_one_operand_3(self):
        self._check("ijk->ki", self.rng().randn(3, 4, 5))

    def test_one_operand_4(self):
        self._check("ijk->ki", self.rng().randn(3, 4, 5))

    def test_one_operand_5(self):
        self._check("...ijk->...ki", self.rng().randn(2, 3, 4, 5))

    def test_one_operand_6(self):
        self._check("...ijk->ki", self.rng().randn(3, 4, 5))

    def test_one_operand_7(self):
        self._check("ii->", self.rng().randn(3, 3))

    def test_one_operand_8(self):
        self._check("ij->", self.rng().randn(3, 3))

    def test_one_operand_9(self):
        self._check("iii->", self.rng().randn(3, 3, 3))

    def test_one_operand_10(self):
        self._check("ii->i", self.rng().randn(3, 3))

    def test_one_operand_11(self):
        self._check("iij->i", self.rng().randn(3, 3, 4))

    def test_one_operand_12(self):
        self._check("iii->i", self.rng().randn(3, 3, 3))

    def test_one_operand_13(self):
        self._check("iijkk->i", self.rng().randn(3, 3, 5, 4, 4))

    def test_one_operand_14(self):
        self._check("iijkk->ik", self.rng().randn(3, 3, 5, 4, 4))

    def test_one_operand_15(self):
        self._check("iijkl->il", self.rng().randn(3, 3, 5, 4, 4))

    def test_one_operand_16(self):
        self._check("ij->ij", self.rng().randn(3, 3))

    def test_tf_unsupported_1(self):
        # from https://www.tensorflow.org/api_docs/python/tf/einsum
        rng = self.rng()
        self._check(
            "ij...,jk...->ik...", rng.randn(2, 3, 5, 1), rng.randn(3, 4, 5, 1)
        )

    def test_tf_unsupported_2(self):
        # from https://www.tensorflow.org/api_docs/python/tf/einsum
        rng = self.rng()
        self._check("ijj,k->ik", rng.randn(2, 3, 3), rng.randn(4))

    def test_tf_unsupported_3(self):
        # from https://www.tensorflow.org/api_docs/python/tf/einsum
        rng = self.rng()
        self._check(
            "ij,ij,jk->ik", rng.randn(2, 3), rng.randn(2, 3), rng.randn(3, 4)
        )

    # these tests are based on https://github.com/dask/dask/pull/3412/files
    @parameterized.named_parameters(
        dict(
            testcase_name="_{}_dtype={}".format(
                einstr, dtype.__name__
            ),  # pylint: disable=g-complex-comprehension
            einstr=einstr,
            dtype=dtype,
        )
        for einstr in [
            "abc,bad->abcd",
            "abcdef,bcdfg->abcdeg",
            "ea,fb,abcd,gc,hd->efgh",
            "ab,b",
            "aa",
            "a,a->",
            "a,a->a",
            "a,a",
            "a,b",
            "a,b,c",
            "a",
            "ba,b",
            "ba,b->",
            "defab,fedbc->defac",
            "ab...,bc...->ac...",
            "a...a",
            "abc...->cba...",
            "...ab->...a",
            "a...a->a...",
            # Following 2 from # https://stackoverflow.com/a/19203475/1611416
            "...abc,...abcd->...d",
            "ab...,b->ab...",
            # https://github.com/dask/dask/pull/3412#discussion_r182413444
            "aa->a",
            "ab,ab,c->c",
            "aab,bc->ac",
            "aab,bcc->ac",
            "fdf,cdd,ccd,afe->ae",
            "fff,fae,bef,def->abd",
        ]
        # TODO(wangpeng): Add jnp.bool_ to dtype list
        for dtype in [jnp.float32, jnp.int32, jnp.complex64]
    )
    def test_from_dask(self, einstr, dtype):
        make_array = jtu.rand_default()
        # Only the input side of the spec decides the operand shapes.
        input_str = einstr.partition("->")[0]

        # Each distinct subscript letter gets the next size from 2, 3, 4, 2, ...
        # in order of first appearance; "..." stands for two extra axes.
        dim_cycle = itertools.cycle([2, 3, 4])
        size_for = defaultdict(lambda: next(dim_cycle))
        input_shapes = []
        for names in input_str.split(","):
            expanded = names.replace("...", "01")
            input_shapes.append(tuple(size_for[axis] for axis in expanded))

        operands = [make_array(shape, dtype) for shape in input_shapes]
        self._check(einstr, *operands)

    def test_ordered_front_batch_dim_case(self):
        lhs = np.ones((1, 8, 20, 4))
        rhs = np.ones((1, 8, 20, 4))
        self._check("ijkl,ijml->ijkm", lhs, rhs)

    # pylint: disable=invalid-name
    def test_einsum_path(self):
        # just check examples from np.einsum_path docstring
        # NOTE(review): this calls NumPy's einsum_path only, so it does not
        # exercise the `jnp` module - confirm that is intended.
        a = self.rng().rand(2, 2)
        b = self.rng().rand(2, 5)
        c = self.rng().rand(5, 2)

        path, report = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")[:2]
        self.assertEqual(str(path), "['einsum_path', (1, 2), (0, 1)]")
        # NOTE(review): NumPy pads this header with repeated spaces; verify the
        # literal below against the committed file, whitespace may have been
        # collapsed in the patch text.
        first_line = report.split("\n")[0]
        self.assertEqual(first_line, " Complete contraction: ij,jk,kl->il")

        # check this doesn't crash
        tensor = self.rng().rand(10, 10, 10, 10)
        matrix = self.rng().rand(10, 10)
        np.einsum_path(
            "ea,fb,abcd,gc,hd->efgh",
            matrix,
            matrix,
            tensor,
            matrix,
            matrix,
            optimize="greedy",
        )

    @jtu.disable
    def test_einsum_kpmurphy_example(self):
        # code from an email with @murphyk
        N, C, D, K, T = 2, 3, 4, 5, 6
        rng = self.rng()
        S = rng.randn(N, T, K)
        W = rng.randn(K, D)
        V = rng.randn(D, C)

        # Reference result by explicit summation; the d, k, t nesting order is
        # kept so floating-point accumulation is unchanged.
        L = np.zeros((N, C))
        for n, c in itertools.product(range(N), range(C)):
            total = 0
            for d, k, t in itertools.product(range(D), range(K), range(T)):
                total += S[n, t, k] * W[k, d] * V[d, c]
            L[n, c] = total

        path = jnp.einsum_path("ntk,kd,dc->nc", S, W, V, optimize="optimal")[0]
        if jtu.device_under_test() == "tpu":
            rtol = 1e-2
        else:
            rtol = None
        self.assertAllClose(
            L,
            jnp.einsum("ntk,kd,dc->nc", S, W, V, optimize=path),
            check_dtypes=False,
            rtol=rtol,
        )

    # pylint: enable=invalid-name

    @jtu.disable
    def test_contraction_broadcasting(self):
        rng = self.rng()
        self._check("cij,cjk->cik", rng.randn(3, 4, 5), rng.randn(3, 1, 6))

    @jtu.disable
    def test_batch_broadcasting(self):
        rng = self.rng()
        self._check("cij,cjk->cik", rng.randn(1, 4, 5), rng.randn(3, 5, 6))

    @jtu.disable
    def test_batch_and_contraction_broadcasting(self):
        rng = self.rng()
        self._check("cij,cjk->cik", rng.randn(1, 4, 5), rng.randn(3, 1, 6))

    @jtu.disable
    def test_broadcasting_issue_2189(self):
        rng = self.rng()
        self._check("...ij,...j", rng.randn(2, 1, 3, 3), rng.randn(2, 4, 3))
def subvals(lst, replace):
    """Returns a tuple copy of `lst` with selected positions overwritten.

    Args:
      lst: any iterable; it is copied, never mutated.
      replace: iterable of `(index, value)` pairs, applied in order, so a later
        pair for the same index wins.

    Returns:
      A tuple with the substitutions applied.
    """
    patched = [*lst]
    for position, value in replace:
        patched[position] = value
    return (*patched,)
IndexSpec(shape=(10, 8), indexer=slice(1, None)), + IndexSpec(shape=(10, 8), indexer=slice(None, 3)), + IndexSpec(shape=(10, 8), indexer=slice(-3, None)), + ], + ), + ( + "OneSliceIndexNegativeStride", + [ + IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10,), indexer=slice(None, None, -1)), + IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result + IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), + ], + ), + ( + "OneSliceIndexNonUnitStride", + [ + IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), + IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), + IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), + IndexSpec(shape=(10,), indexer=slice(1, None, 2)), + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), + IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), + IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), + ], + ), + ( + "OneColonIndex", + [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ], + ), + ( + "MultipleColonIndices", + [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ], + ), + ( + "MixedSliceIndices", + [ + IndexSpec(shape=(10, 4), indexer=(slice(None), 
slice(0, 2))), + IndexSpec(shape=(10, 4), indexer=(1, slice(None))), + ], + ), + ( + "EllipsisIndex", + [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ], + ), + ( + "NoneIndex", + [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ], + ), + ( + "EmptyIndex", + [ + IndexSpec(shape=(), indexer=()), + IndexSpec(shape=(3,), indexer=()), + IndexSpec(shape=(3, 4), indexer=()), + ], + ), +] + +STATIC_INDEXING_GRAD_TESTS = [ + ( + "OneIntIndex", + [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ], + ), + ( + "TwoIntIndices", + [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ], + ), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ( + "OneSliceIndex", + [ + IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5,), indexer=slice(1, -1)), + IndexSpec(shape=(5,), indexer=slice(None, -1)), + IndexSpec(shape=(5,), indexer=slice(None, None, None)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, None)), + IndexSpec(shape=(5, 4), indexer=slice(None, 3)), + IndexSpec(shape=(5, 4), indexer=slice(-3, None)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), + IndexSpec(shape=(5, 4, 3), 
indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), + ], + ), + ( + "OneColonIndex", + [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ], + ), + ( + "MultipleColonIndices", + [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ], + ), + ( + "MixedSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(1, slice(None))), + ], + ), + ( + "EllipsisIndex", + [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ], + ), + ( + "NoneIndex", + [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ], + ), + # TODO(mattjj): these fail for uninteresting dtype reasons + # ("EmptyIndex", + # [IndexSpec(shape=(), indexer=()), + # IndexSpec(shape=(3,), indexer=()), + # IndexSpec(shape=(3, 4), indexer=()), + # ]), +] + +ADVANCED_INDEXING_TESTS = [ + ( + "One1DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([0, 1])), + IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), + IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), + IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), + ], + ), + ( + "One2DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([[0, 
0]])), + IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])), + IndexSpec( + shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]]) + ), + ], + ), + ( + "Two1DIntArrayIndicesNoBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), onp.array([1, 2]))), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])), + ), + ], + ), + ( + "Two1DIntArrayIndicesWithBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), onp.array([1, 2]))), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2])), + ), + ], + ), + ( + "TupleOfListsOfPythonInts", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]])), + ], + ), + ( + "TupleOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, onp.array([[2, 3, 0, 3]]))), + ], + ), + ( + "TupleOfListsOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), + IndexSpec( + shape=(3, 4, 5), indexer=([[0], [-1]], onp.array([[2, 3, 0, 3]])) + ), + ], + ), +] + +ADVANCED_INDEXING_TESTS_NO_REPEATS = [ + ( + "One1DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([0, 1])), + IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 0])), + IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), + # Fails with a TF/XLA error. 
+ # IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), + ], + ), + ( + "One2DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([[0, 1]])), + IndexSpec(shape=(6, 6), indexer=onp.array([[1, 2, 0], [3, 4, -1]])), + ], + ), + ( + "Two1DIntArrayIndicesNoBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), onp.array([1, 2]))), + IndexSpec( + shape=(4, 5, 6), + indexer=(onp.array([0, 2, 1, 3]), onp.array([-1, 0, -2, 1])), + ), + ], + ), + ( + "Two1DIntArrayIndicesWithBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), onp.array([1, 2]))), + IndexSpec( + shape=(4, 5, 6), + indexer=(onp.array([[0, 2, -1, 1]]), onp.array([-1, 0, -2, 2])), + ), + ], + ), + ( + "TupleOfListsOfPythonInts", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]])), + ], + ), + ( + "TupleOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, onp.array([[2, 3, 0]]))), + ], + ), + ( + "TupleOfListsOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], onp.array([[2, 3, 0]]))), + ], + ), +] + +MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ + ( + "SlicesAndOneIntArrayIndex", + [ + IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))), + IndexSpec(shape=(2, 3), indexer=(slice(0, 2), onp.array([0, 2]))), + IndexSpec( + shape=(3, 4, 5), indexer=(Ellipsis, onp.array([0, 2]), slice(None)) + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([[0, 2], [1, 3]]), slice(None)), + ), + ], + ), + ( + "SlicesAndTwoIntArrayIndices", + [ + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([0, 2]), onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), Ellipsis, onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), 
def dynamic_slice_reference(operand, start_indices, slice_sizes):
    """NumPy reference for a fixed-size slice starting at `start_indices`.

    The result always has shape `slice_sizes`. Whatever part of the requested
    window NumPy slicing returns is copied into the leading corner of a
    zero-filled array, so a window running past the end of `operand` is
    zero-padded.

    Args:
      operand: array to slice.
      start_indices: one start offset per axis.
      slice_sizes: one window length per axis.

    Returns:
      A new array of shape `slice_sizes` with `operand`'s dtype.
    """
    result = onp.zeros(slice_sizes, dtype=operand.dtype)
    window = []
    for begin, extent in zip(start_indices, slice_sizes):
        window.append(slice(begin, begin + extent))
    section = operand[tuple(window)]
    # `section` may be smaller than `slice_sizes`; fill only what was read.
    covered = tuple(slice(None, extent) for extent in section.shape)
    result[covered] = section
    return result


def dynamic_update_slice_reference(operand, update, start_indices):
    """NumPy reference that writes `update` into a copy of `operand`.

    Args:
      operand: array to copy; it is not modified.
      update: values to write; its shape fixes the extent of the written region.
      start_indices: one start offset per axis.

    Returns:
      The patched copy of `operand`.
    """
    stops = onp.add(start_indices, update.shape)
    region = tuple(slice(lo, hi) for lo, hi in zip(start_indices, stops))
    patched = onp.copy(operand)
    patched[region] = update
    return patched
IndexingTest(jtu.TestCase): + """Tests for Numpy indexing translation rules.""" + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testStaticIndexing(self, shape, dtype, rng_factory, indexer): + # TODO(rohanj): Revisit passing in self.rng() to this to customize further. + # This would need updating lax_numpy_test as well. + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype)] + onp_fun = lambda x: x[indexer] + jnp_fun = lambda x: jnp.asarray(x)[indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def _ReplaceSlicesWithTuples(self, idx): + """Helper method to replace slices with tuples for dynamic indexing args.""" + if isinstance(idx, slice): + triple = idx.start, idx.stop, idx.step + isnone = [i for i, elt in enumerate(triple) if elt is None] + zeros = itertools.repeat(0) + nones = itertools.repeat(None) + out = subvals(triple, zip(isnone, zeros)) + return out, lambda out: slice(*subvals(out, zip(isnone, nones))) + elif isinstance(idx, (tuple, list)) and idx: + t = type(idx) + elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) + return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) + else: + return idx, lambda x: x + + @parameterized.named_parameters( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in [ + ( + "OneSliceIndex", 
+ [ + IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), + ], + ), + ( + "NonUnitStrides", + [ + IndexSpec(shape=(3,), indexer=slice(None, None, -1)), + IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), + IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)), + ], + ), + ( + "OnlyStartOrStopDynamic", + [ + IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), + ], + ), + ] + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer): + rng = rng_factory() + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + def onp_fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + args_maker = lambda: [rng(shape, dtype), unpacked_indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off because some + # compile-time-constant requirements are violated. Investigate and turn it + # on. 
+ self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_xla_forced_compile=False, + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in [ + ( + "OneIntIndex", + [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ], + ), + ( + "TwoIntIndices", + [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ], + ), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ] + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): + # TODO(rohanj): Revisit passing in self.rng() to this to customize further. + # This would need updating lax_numpy_test as well. 
+ rng = rng_factory() + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + def onp_fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + args_maker = lambda: [rng(shape, dtype), unpacked_indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @parameterized.named_parameters( + { + "testcase_name": "_{}_inshape={}_indexer={}".format( # pylint: disable=g-complex-comprehension + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "name": name, + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in ADVANCED_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, indexer): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), indexer] + onp_fun = lambda x, idx: x[idx] + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off for + # ListOfPythonIntsAndIntArrays because it throws "The number of output + # elements has to equal to number of input elements that are sliced when + # input indices are not constant". Investigate and turn it on. 
+ check_xla = name != "ListOfPythonIntsAndIntArrays" + self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_xla_forced_compile=check_xla, + ) + + @parameterized.named_parameters( + { + "testcase_name": "_{}_inshape={}_indexer={}".format( # pylint: disable=g-complex-comprehension + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "name": name, + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testMixedAdvancedIntegerIndexing( + self, name, shape, dtype, rng_factory, indexer + ): + rng = rng_factory() + indexer_with_dummies = [ + e if isinstance(e, onp.ndarray) else () for e in indexer + ] + substitutes = [ + (i, e) for i, e in enumerate(indexer) if not isinstance(e, onp.ndarray) + ] + args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] + + def np_fun(x, indexer_with_dummies): + idx = type(indexer)(subvals(indexer_with_dummies, substitutes)) + return x[idx] + + jnp_fun = lambda x, idx: np_fun(jnp.asarray(x), idx) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off for + # IntArrayWithInt32Type because it throws "The number of output elements has + # to equal to number of input elements that are sliced when input indices + # are not constant". Investigate and turn it on. 
+ check_xla = name != "IntArrayWithInt32Type" + self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_xla_forced_compile=check_xla, + ) + + def testAdvancedIndexingManually(self): + x = onp.random.RandomState(0).randn(3, 4, 5) + index_array = onp.array([0, 2, -1, 0]) + + op = lambda x, index_array: x[..., index_array, :] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + op = lambda x, index_array: x[..., index_array, :, index_array, None] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + # Note that we don't currently allow __iter__ in graph mode. So this test only + # iterates over eager tensor. 
+ def testUnpacking(self): + def foo(x): + a, b, c = x + return a + b + c + + a1 = foo(onp.arange(3)) + a2 = foo(jnp.arange(3)) + + self.assertAllClose(a1, a2, check_dtypes=True) + + def testBooleanIndexingArray1D(self): + idx = onp.array([True, True, False]) + x = jnp.asarray(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList1D(self): + idx = [True, True, False] + x = jnp.asarray(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2DBroadcast(self): + idx = onp.array([True, True, False, True]) + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList2DBroadcast(self): + idx = [True, True, False, True] + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2D(self): + idx = onp.array([[True, False], [False, True], [False, False], [True, True]]) + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingDynamicShape(self): + x = onp.zeros(3) + i = onp.array([True, True, False]) + ans = x[i] + expected = jnp.asarray(x)[i] + self.assertAllClose(ans, expected, check_dtypes=True) + + def testIssue187(self): + x = jnp.ones((5, 5)) + x[[0, 2, 4], [0, 2, 4]] # doesn't crash + + x = onp.arange(25).reshape((5, 5)) + ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) + expected = x[[0, 2, 4], [0, 2, 4]] + self.assertAllClose(ans, expected, check_dtypes=False) + + # TODO(agarwal): Fix this use case. 
+    @jtu.disable
+    def testIndexingEmptyDimension(self):
+        # Currently switched off through the @jtu.disable decorator above.
+        # Issue 2671: XLA error when indexing into dimension of size 0
+        x = jnp.ones((2, 0))
+        # The following work, even on axis 1 of size 0
+        _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]
+
+        # Scalar indexing into the size-0 axis must raise IndexError for numpy,
+        # for eager indexing and for indexing under npe.jit.
+        with self.assertRaisesRegex(
+            IndexError, "index .* is out of bounds for axis .* with size 0"
+        ):
+            _ = onp.ones((2, 0))[0, 0]  # The numpy error
+        with self.assertRaisesRegex(
+            IndexError, "index is out of bounds for axis .* with size 0"
+        ):
+            _ = x[0, 0]  # JAX indexing
+        with self.assertRaisesRegex(
+            IndexError, "index is out of bounds for axis .* with size 0"
+        ):
+            npe.jit(lambda i: x[0, i])(0)  # JAX indexing under jit
+
+    def testBooleanIndexingWithEmptyResult(self):
+        # An all-False mask selects nothing; the result is compared with numpy's.
+        # based on a TensorFlow Probability test that started failing after #1623
+        x = jnp.array([-1])
+        mask = jnp.array([False])
+        ans = x[mask]  # doesn't crash
+
+        expected = onp.array([-1])[onp.array([False])]
+        self.assertAllClose(ans, expected, check_dtypes=False)
+
+    def testFloatIndexingError(self):
+        # Float indices must raise IndexError matching the same message pattern
+        # for numpy, for eager jnp indexing and for indexing under npe.jit.
+        error_regex = "only integers, slices.*are valid indices"
+        # Verify onp behavior
+        with self.assertRaisesRegex(IndexError, error_regex):
+            _ = onp.zeros((2, 2))[(0, 0.0)]
+        # Test jnp
+        with self.assertRaisesRegex(IndexError, error_regex):
+            jnp.zeros(2)[0.0]
+        with self.assertRaisesRegex(IndexError, error_regex):
+            jnp.zeros((2, 2))[(0, 0.0)]
+        # Test with jit
+        with self.assertRaisesRegex(IndexError, error_regex):
+            npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.0))
+
+    def testIndexOutOfBounds(self):  # https://github.com/google/jax/issues/2245
+        # A slice whose stop (10) exceeds the length (5) returns the whole array.
+        array = jnp.ones(5)
+        self.assertAllClose(array, array[:10], check_dtypes=True)
+
+    # One case per (shape, start_indices, size_indices) row x default dtype.
+    # Start indices are given both as tuples and as numpy arrays.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {
+                "testcase_name": "_shape={}_start_indices={}_size_indices={}".format(  # pylint: disable=g-complex-comprehension
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    start_indices,
+                    size_indices,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "start_indices": start_indices,
+                "size_indices": size_indices,
+                "rng_factory": rng_factory,
+            }
+            for shape, start_indices, size_indices in [
+                [(3,), onp.array((1,)), (1,)],
+                [(5, 3), (1, 1), (3, 1)],
+                [(5, 3), (1, -2), (3, 1)],
+                [(5, 3), onp.array((1, 1)), (3, 1)],
+                [(7, 5, 3), onp.array((4, 1, 0)), (2, 0, 1)],
+                [(), (), ()],
+            ]
+            for dtype in default_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng_factory):
+        # Passes npe.dynamic_slice to self._CompileAndCheck. The operand is
+        # random; the start indices are always converted to a numpy array and
+        # size_indices is captured by the closure rather than passed as an
+        # argument.
+        rng = rng_factory()
+        args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
+        op = lambda x, starts: npe.dynamic_slice(x, starts, size_indices)
+        self._CompileAndCheck(op, args_maker)
+
+    # Same table as testDynamicSlice, with tuple start indices only.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {
+                "testcase_name": "_shape={}_start_indices={}_size_indices={}".format(  # pylint: disable=g-complex-comprehension
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    start_indices,
+                    size_indices,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "start_indices": start_indices,
+                "size_indices": size_indices,
+                "rng_factory": rng_factory,
+            }
+            for shape, start_indices, size_indices in [
+                [(3,), (1,), (1,)],
+                [(5, 3), (1, 1), (3, 1)],
+                [(5, 3), (1, -2), (3, 1)],
+                [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
+                [(), (), ()],
+            ]
+            for dtype in default_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testDynamicSliceAgainstNumpy(
+        self, shape, dtype, start_indices, size_indices, rng_factory
+    ):
+        # Compares npe.dynamic_slice with dynamic_slice_reference (a helper
+        # defined elsewhere in this file) on identical random inputs.
+        rng = rng_factory()
+        args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
+        op = lambda x, s: npe.dynamic_slice(x, s, size_indices)
+        numpy_op = lambda x, s: dynamic_slice_reference(x, s, size_indices)
+        self._CheckAgainstNumpy(numpy_op, op, args_maker)
+
+    def testDynamicSliceInDim(self):
+        # dynamic_slice_in_dim(x, 2, 3) is expected to equal x[2:5].
+        rng = jtu.rand_default()
+        x = rng((6, 7), onp.int32)
+        self.assertAllClose(
+            npe.dynamic_slice_in_dim(x, 2, 3), x[2:5], check_dtypes=True
+        )
+
+
+def _broadcastable_shapes(shape):
+    """Returns all shapes that broadcast to `shape`."""
+
+    # f works on the reversed shape (last dimension first) and yields lists.
+    # The empty list comes first, then every prefix of rshape in which each
+    # dimension is either kept as-is or, when it is not already 1, replaced
+    # by 1.
+    def f(rshape):
+        yield []
+        if rshape:
+            for s in f(rshape[1:]):
+                yield rshape[0:1] + s
+            if rshape[0] != 1:
+                for s in f(rshape[1:]):
+                    yield [1] + s
+
+    # Reverse each candidate back into normal (leading dimension first) order.
+    for x in f(list(reversed(shape))):
+        yield list(reversed(x))
+
+
+# Shape numpy produces when an array of `shape` is indexed with `indexer`.
+def _update_shape(shape, indexer):
+    return onp.zeros(shape)[indexer].shape
+
+
+# Indexed-update operations under test. np_fn is the numpy reference and
+# tfnp_fn dispatches to the matching npe.index_* function. Both are called as
+# plain functions (UpdateOps.np_fn(op, ...)), which is why their first
+# parameter is `op` rather than `self`.
+class UpdateOps(enum.Enum):
+    UPDATE = 0
+    ADD = 1
+    # MUL = 2
+    MIN = 3
+    MAX = 4
+
+    def np_fn(op, indexer, x, y):  # pylint: disable=no-self-argument
+        # Work on a copy so the caller's array is left untouched.
+        x = x.copy()
+        # The lambdas defer evaluation so only the selected op is computed.
+        x[indexer] = {
+            UpdateOps.UPDATE: lambda: y,
+            UpdateOps.ADD: lambda: x[indexer] + y,
+            # UpdateOps.MUL: lambda: x[indexer] * y,
+            UpdateOps.MIN: lambda: onp.minimum(x[indexer], y),
+            UpdateOps.MAX: lambda: onp.maximum(x[indexer], y),
+        }[op]()
+        return x
+
+    def tfnp_fn(op, indexer, x, y):  # pylint: disable=no-self-argument
+        # Note the argument order expected by npe.index_*: (x, indexer, y).
+        return {
+            UpdateOps.UPDATE: npe.index_update,
+            UpdateOps.ADD: npe.index_add,
+            # UpdateOps.MUL: npe.index_mul,
+            UpdateOps.MIN: npe.index_min,
+            UpdateOps.MAX: npe.index_max,
+        }[op](x, indexer, y)
+
+
+# a test to workaround b/123559667
+# True when any slice inside `indexer` has a step other than 1, -1 or None.
+def has_non_trivial_stride(indexer):
+    def has(idx):
+        return isinstance(idx, slice) and idx.step not in (1, -1, None)
+
+    return any(has(idx) for idx in tf.nest.flatten(indexer))
+
+
+# Checks UpdateOps.tfnp_fn (the npe.index_* functions) against the numpy
+# reference UpdateOps.np_fn, and npe.dynamic_update_slice against its
+# reference, over parameterized shapes, indexers and dtypes.
+class IndexedUpdateTest(jtu.TestCase):
+    # Cases: every static indexer x every UpdateOps member x operand dtype
+    # (all_dtypes for UPDATE, default_dtypes otherwise) x every update shape
+    # that broadcasts to the indexed shape x every update dtype.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {  # pylint: disable=g-complex-comprehension
+                "testcase_name": "_{}_{}_{}_{}".format(
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    indexer,
+                    jtu.format_shape_dtype_string(update_shape, update_dtype),
+                    op.name,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "rng_factory": rng_factory,
+                "indexer": indexer,
+                "update_shape": update_shape,
+                "update_dtype": update_dtype,
+                "op": op,
+            }
+            for name, index_specs in STATIC_INDEXING_TESTS
+            for shape, indexer in index_specs
+            for op in UpdateOps
+            for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
+            for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
+            for update_dtype in all_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testStaticIndexing(
+        self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op
+    ):
+        rng = rng_factory()
+        args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
+        np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
+        tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
+        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
+        # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error "Missing
+        # xla_context 0-th output from". Investigate.
+        # The XLA-related checks are turned off for strided slices and for the
+        # slice(_, 8, -1) case described in the TODO above.
+        check_xla = not has_non_trivial_stride(indexer) and not (  # b/123559667
+            isinstance(indexer, slice) and indexer.stop == 8 and indexer.step == -1
+        )
+        self._CompileAndCheck(
+            tfnp_fn,
+            args_maker,
+            check_incomplete_shape=True,
+            check_experimental_compile=check_xla,
+            check_xla_forced_compile=check_xla,
+        )
+
+    # Same case grid as testStaticIndexing, drawn from
+    # ADVANCED_INDEXING_TESTS_NO_REPEATS.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {  # pylint: disable=g-complex-comprehension
+                "testcase_name": "_{}_{}_{}_{}".format(
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    indexer,
+                    jtu.format_shape_dtype_string(update_shape, update_dtype),
+                    op.name,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "rng_factory": rng_factory,
+                "indexer": indexer,
+                "update_shape": update_shape,
+                "update_dtype": update_dtype,
+                "op": op,
+            }
+            for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
+            for shape, indexer in index_specs
+            for op in UpdateOps
+            for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
+            for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
+            for update_dtype in all_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testAdvancedIndexing(
+        self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op
+    ):
+        rng = rng_factory()
+        args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
+        np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
+        tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
+        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
+        self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True)
+
+    # Same case grid, drawn from MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {  # pylint: disable=g-complex-comprehension
+                "testcase_name": "_{}_{}_{}_{}".format(
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    indexer,
+                    jtu.format_shape_dtype_string(update_shape, update_dtype),
+                    op.name,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "rng_factory": rng_factory,
+                "indexer": indexer,
+                "update_shape": update_shape,
+                "update_dtype": update_dtype,
+                "op": op,
+            }
+            for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
+            for shape, indexer in index_specs
+            for op in UpdateOps
+            for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
+            for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
+            for update_dtype in all_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testMixedAdvancedIndexing(
+        self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op
+    ):
+        rng = rng_factory()
+        args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
+        np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
+        tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
+        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
+        # The XLA-related checks are turned off when the indexer has a strided
+        # slice.
+        check_xla = not has_non_trivial_stride(indexer)  # b/123559667
+        self._CompileAndCheck(
+            tfnp_fn,
+            args_maker,
+            check_incomplete_shape=True,
+            check_experimental_compile=check_xla,
+            check_xla_forced_compile=check_xla,
+        )
+
+    # Gradient cases: static indexers only, ADD and UPDATE only, float dtypes
+    # for both the operand and the update.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {  # pylint: disable=g-complex-comprehension
+                "testcase_name": "_{}_{}_{}_{}".format(
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    indexer,
+                    jtu.format_shape_dtype_string(update_shape, update_dtype),
+                    op.name,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "rng_factory": rng_factory,
+                "indexer": indexer,
+                "update_shape": update_shape,
+                "update_dtype": update_dtype,
+                "op": op,
+            }
+            for name, index_specs in STATIC_INDEXING_TESTS
+            for shape, indexer in index_specs
+            for op in [UpdateOps.ADD, UpdateOps.UPDATE]
+            for dtype in float_dtypes
+            for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
+            for update_dtype in float_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testStaticIndexingGrads(
+        self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op
+    ):
+        rng = rng_factory()
+        tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
+        x = rng(shape, dtype)
+        y = rng(update_shape, update_dtype)
+        # check_grads is not defined in the visible part of this file.
+        # NOTE(review): confirm what the tolerances and delta=1.0 control.
+        self.check_grads(tfnp_fn, (x, y), rtol=1e-3, atol=1e-3, delta=1.0)
+
+    # One case per (shape, start_indices, update_shape) row x default dtype.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {
+                "testcase_name": "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    start_indices,
+                    update_shape,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "start_indices": start_indices,
+                "update_shape": update_shape,
+                "rng_factory": rng_factory,
+            }
+            for shape, start_indices, update_shape in [
+                [(3,), (1,), (1,)],
+                [(5, 3), (1, 1), (3, 1)],
+                [(5, 3), (1, -2), (3, 1)],
+                [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
+                [(), (), ()],
+            ]
+            for dtype in default_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testDynamicUpdateSlice(
+        self, shape, dtype, start_indices, update_shape, rng_factory
+    ):
+        rng = rng_factory()
+
+        # Arguments in call order: operand, update (same dtype as the operand),
+        # start indices as a numpy array.
+        def args_maker():
+            return [
+                rng(shape, dtype),
+                rng(update_shape, dtype),
+                onp.array(start_indices),
+            ]
+
+        # update's shape must be fully known.
+        # TODO(wangpeng): Support turning off check_incomplete_shape for individual
+        # arguments.
+        self._CompileAndCheck(
+            npe.dynamic_update_slice, args_maker, check_incomplete_shape=False
+        )
+
+    # Same table as testDynamicUpdateSlice.
+    @parameterized.named_parameters(
+        jtu.cases_from_list(
+            {
+                "testcase_name": "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
+                    jtu.format_shape_dtype_string(shape, dtype),
+                    start_indices,
+                    update_shape,
+                ),
+                "shape": shape,
+                "dtype": dtype,
+                "start_indices": start_indices,
+                "update_shape": update_shape,
+                "rng_factory": rng_factory,
+            }
+            for shape, start_indices, update_shape in [
+                [(3,), (1,), (1,)],
+                [(5, 3), (1, 1), (3, 1)],
+                [(5, 3), (1, -2), (3, 1)],
+                [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
+                [(), (), ()],
+            ]
+            for dtype in default_dtypes
+            for rng_factory in [jtu.rand_default]
+        )
+    )
+    def testDynamicUpdateSliceAgainstNumpy(
+        self, shape, dtype, start_indices, update_shape, rng_factory
+    ):
+        rng = rng_factory()
+
+        def args_maker():
+            return [
+                rng(shape, dtype),
+                rng(update_shape, dtype),
+                onp.array(start_indices),
+            ]
+
+        # dynamic_update_slice_reference is a helper defined elsewhere in this
+        # file.
+        self._CheckAgainstNumpy(
+            dynamic_update_slice_reference, npe.dynamic_update_slice, args_maker
+        )
+
+    def testDynamicUpdateSliceInDim(self):
+        # Expected value: a copy of x whose rows 2..4 are overwritten with y.
+        rng = jtu.rand_default()
+        x = rng((6, 7), onp.int32)
+        y = rng((3, 7), onp.int32)
+        z = x.copy()
+        z[2:5] = y
+        self.assertAllClose(
+            npe.dynamic_update_slice_in_dim(x, y, 2, 0), z, check_dtypes=True
+        )
+
+
+if __name__ == "__main__":
+    # Turn soft device placement off and switch on numpy behavior before the
+    # tests run.
+    tf.config.set_soft_device_placement(False)
+    jnp.enable_numpy_behavior()
+    absltest.main()
diff --git a/tests/tf_numpy/jax/lax_numpy_test.py b/tests/tf_numpy/jax/lax_numpy_test.py
new file mode 100644
index 000000000..576a770a6
--- /dev/null
+++ b/tests/tf_numpy/jax/lax_numpy_test.py
@@ -0,0 +1,4758 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import functools +import itertools +import operator +import unittest +import warnings +from functools import partial +from unittest import SkipTest + +import numpy as onp +import six +import tensorflow.compat.v2 as tf +from absl.testing import absltest +from absl.testing import parameterized + +import tests.tf_numpy.jax.utils as jtu +from tests.tf_numpy.jax.config import config, FLAGS + +import trax.tf_numpy.extensions as npe +import trax.tf_numpy.numpy as lnp + + +config.parse_flags_with_absl() + + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes +empty_array_shapes = [ + (0,), + (0, 4), + (3, 0), +] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +# TODO(wangpeng): float_dtypes = [lnp.bfloat16, onp.float16, onp.float32, +# onp.float64] +float_dtypes = [onp.float16, onp.float32, onp.float64] +complex_dtypes = [onp.complex64, onp.complex128] +int_dtypes = [onp.int32, onp.int64] +unsigned_dtypes = [onp.uint32, onp.uint64] +bool_dtypes = [onp.bool_] +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes +all_dtypes = number_dtypes + bool_dtypes + + +python_scalar_dtypes = [lnp.bool_, 
lnp.int_, lnp.float_, lnp.complex_] + + +def _valid_dtypes_for_shape(shape, dtypes): + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) + if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + + +def _shape_and_dtypes(shapes, dtypes): + for shape in shapes: + for dtype in _valid_dtypes_for_shape(shape, dtypes): + yield (shape, dtype) + + +OpRecord = collections.namedtuple( + "OpRecord", + [ + "name", + "nargs", + "dtypes", + "shapes", + "rng_factory", + "diff_modes", + "test_name", + "check_dtypes", + "tolerance", + "inexact", + "check_incomplete_shape", + ], +) + + +def op_record( + name, + nargs, + dtypes, + shapes, + rng_factory, + diff_modes, + test_name=None, + check_dtypes=True, + tolerance=None, + inexact=False, + check_incomplete_shape=True, +): + test_name = test_name or name + return OpRecord( + name, + nargs, + dtypes, + shapes, + rng_factory, + diff_modes, + test_name, + check_dtypes, + tolerance, + inexact, + check_incomplete_shape, + ) + + +def minus(a, b): + return [x for x in a if x not in b] + + +JAX_ONE_TO_ONE_OP_RECORDS = [ + op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record( + "exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "float_power", + 2, + inexact_dtypes, + all_shapes, + partial(jtu.rand_default, scale=1), + ["rev"], + tolerance={ + # TODO(wangpeng): lnp.bfloat16: 1e-2, + onp.float32: 1e-3, + onp.float64: 1e-12, + onp.complex64: 2e-4, + 
onp.complex128: 1e-12, + }, + check_dtypes=False, + ), + op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "greater", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "greater_equal", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "less", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "less_equal", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record( + "maximum", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf, + [], + ), + op_record( + "minimum", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf, + [], + ), + op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "nextafter", + 2, + [f for f in float_dtypes if f not in (lnp.bfloat16, onp.float16)], + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + tolerance=0, + ), + op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record( + "array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"] + ), + op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "signbit", + 1, + default_dtypes + bool_dtypes, + all_shapes, + 
jtu.rand_some_inf_and_nan, + ["rev"], + ), + op_record( + "sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record( + "cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record( + "tan", + 1, + number_dtypes, + all_shapes, + partial(jtu.rand_uniform, -1.5, 1.5), + ["rev"], + tolerance={onp.complex64: 3e-5, onp.complex128: 4e-14}, + inexact=True, + ), + # TODO(wangpeng): Add float16 support + op_record( + "sinh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + ), + op_record( + "cosh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + ), + # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to + # ~float32 precision. + # TODO(b/143135720): on GPU, tanh has only ~float32 precision. + op_record( + "tanh", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={onp.float64: 1e-7, onp.complex128: 1e-7}, + inexact=True, + ), + op_record( + "arcsin", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arccos", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arctan", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arctan2", + 2, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arcsinh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_positive, + ["rev"], + inexact=True, + ), + op_record( + "arccosh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_positive, + ["rev"], + inexact=True, + ), + op_record( + "arctanh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), +] + 
+JAX_COMPOUND_OP_RECORDS = [ + # angle has inconsistent 32/64-bit return types across numpy versions. + op_record( + "angle", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + [], + check_dtypes=False, + inexact=True, + ), + op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "divide", + 2, + number_dtypes, + all_shapes, + jtu.rand_nonzero, + ["rev"], + inexact=six.PY3, + ), + op_record( + "divmod", + 2, + minus(int_dtypes + float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + ), + op_record( + "exp2", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={ + # TODO(wangpeng): lnp.bfloat16: 2e-2, + onp.float16: 1e-2 + }, + inexact=True, + ), + # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 + # precision. 
+ op_record( + "expm1", + 1, + number_dtypes, + all_shapes, + jtu.rand_positive, + [], + test_name="expm1_large", + tolerance={onp.float64: 1e-8}, + inexact=True, + ), + op_record( + "expm1", + 1, + number_dtypes, + all_shapes, + jtu.rand_small_positive, + [], + tolerance={onp.float64: 1e-8}, + inexact=True, + ), + op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "floor_divide", + 2, + minus(number_dtypes, complex_dtypes), + all_shapes, + jtu.rand_nonzero, + ["rev"], + ), + op_record( + "heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "kron", + 2, + number_dtypes, + nonempty_shapes, + jtu.rand_default, + [], + check_incomplete_shape=False, + ), + op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "isfinite", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record( + "isinf", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record( + "isnan", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "log1p", + 1, + 
number_dtypes, + all_shapes, + jtu.rand_positive, + [], + test_name="log1p_large", + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "log1p", + 1, + number_dtypes, + all_shapes, + jtu.rand_small_positive, + [], + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "logaddexp", + 2, + float_dtypes, + all_shapes, + jtu.rand_some_inf_and_nan, + ["rev"], + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "logaddexp2", + 2, + float_dtypes, + all_shapes, + jtu.rand_some_inf_and_nan, + ["rev"], + tolerance={onp.float16: 1e-2}, + inexact=True, + ), + op_record( + "polyval", + 2, + number_dtypes, + nonempty_nonscalar_array_shapes, + jtu.rand_default, + [], + check_dtypes=False, + tolerance={onp.float16: 1e-2, onp.float64: 1e-12}, + check_incomplete_shape=False, + ), + op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "power", + 2, + number_dtypes, + all_shapes, + jtu.rand_positive, + ["rev"], + tolerance={onp.complex128: 1e-14}, + ), + op_record( + "rad2deg", + 1, + float_dtypes, + all_shapes, + jtu.rand_default, + [], + tolerance={onp.float64: 5e-6}, + ), + op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "remainder", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-2}, + ), + op_record( + "mod", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [] + ), + op_record( + "sinc", + 1, + [t for t in number_dtypes if t != lnp.bfloat16], + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={onp.complex64: 1e-5}, + inexact=True, + check_dtypes=False, + ), + op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "transpose", + 1, + all_dtypes, + 
all_shapes, + jtu.rand_default, + ["rev"], + check_dtypes=False, + ), + op_record( + "true_divide", + 2, + all_dtypes, + all_shapes, + jtu.rand_nonzero, + ["rev"], + inexact=True, + ), + op_record( + "diff", + 1, + number_dtypes, + nonzerodim_shapes, + jtu.rand_default, + ["rev"], + check_incomplete_shape=False, + ), +] + +JAX_BITWISE_OP_RECORDS = [ + op_record( + "bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), +] + +JAX_REDUCER_RECORDS = [ + op_record( + "mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "nanmean", + 1, + minus(inexact_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_nan, + [], + inexact=True, + ), + op_record( + "nanprod", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_nan, + [], + ), + op_record( + "nansum", + 1, + minus(number_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_nan, + [], + ), +] + +JAX_REDUCER_NO_DTYPE_RECORDS = [ + op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record( + "max", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + op_record( + "min", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + op_record( + "var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), 
+] + +JAX_ARGMINMAX_RECORDS = [ + op_record( + "argmin", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "argmax", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_equal, + [], + ), +] + +JAX_OPERATOR_OVERLOADS = [ + op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "__pow__", + 2, + inexact_dtypes, + all_shapes, + jtu.rand_positive, + [], + tolerance={onp.float32: 2e-4, onp.complex64: 2e-4, onp.complex128: 1e-14}, + ), + op_record( + "__mod__", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-1}, + ), + op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record( + "__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], inexact=True + ), + op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 + op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): investigate these failures + # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # 
op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + # TODO(mattjj): lshift, rshift +] + +JAX_RIGHT_OPERATOR_OVERLOADS = [ + op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "__rpow__", + 2, + inexact_dtypes, + all_shapes, + jtu.rand_positive, + [], + tolerance={onp.float32: 2e-4, onp.complex64: 1e-3}, + ), + op_record( + "__rmod__", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-1}, + ), + op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record( + "__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], inexact=True + ), + # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), +] + +numpy_version = tuple(map(int, onp.version.version.split("."))) +if numpy_version >= (1, 15): + JAX_COMPOUND_OP_RECORDS += [ + op_record( + "isclose", + 2, + [t for t in all_dtypes if t != lnp.bfloat16], + all_shapes, + jtu.rand_small_positive, + [], + ), + op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []), + op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []), + ] + JAX_REDUCER_NO_DTYPE_RECORDS += [ + op_record( + "ptp", + 1, + minus(number_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + ] + +if six.PY2: + JAX_OPERATOR_OVERLOADS += [ + op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + ] + JAX_RIGHT_OPERATOR_OVERLOADS += [ + op_record("__rdiv__", 2, 
CombosWithReplacement = itertools.combinations_with_replacement


def _dtypes_are_compatible_for_bitwise_ops(args):
    """Filter predicate for dtype combinations used by the bitwise-op tests.

    Args:
      args: sequence of dtypes. Zero or one entries are always accepted;
        otherwise the sequence is unpacked as exactly two dtypes.

    Returns:
      True if the (width-ordered) pair satisfies the signedness/width rule
      below, False otherwise.
    """
    if len(args) <= 1:
        return True
    is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger)
    width = lambda dtype: lnp.iinfo(dtype).bits
    x, y = args
    # `lnp.iinfo(dtype).bits` can't be called on bools, so we convert bools to
    # ints.
    if x == lnp.bool_:
        x = lnp.int32
    if y == lnp.bool_:
        y = lnp.int32
    # Order the pair so that `x` is never wider than `y`.
    if width(x) > width(y):
        x, y = y, x
    if x == lnp.uint32 and y == lnp.uint64:
        return False
    # The following condition seems a little ad hoc, but seems to capture what
    # numpy actually implements.
    return (
        is_signed(x) == is_signed(y)
        or (width(x) == 32 and width(y) == 32)
        or (width(x) == 32 and width(y) == 64 and is_signed(y))
    )


def _shapes_are_broadcast_compatible(shapes):
    """Returns True if zero arrays of all `shapes` can be added together.

    Compatibility is probed by actually summing `onp.zeros(shape)` arrays; a
    `ValueError` from numpy means the shapes do not broadcast.
    """
    accumulator = onp.zeros([])
    for shape in shapes:
        try:
            accumulator = accumulator + onp.zeros(shape)
        except ValueError:
            return False
    return True


def _shapes_are_equal_length(shapes):
    """Returns True if every shape has the same number of dims as the first."""
    return all(len(shape) == len(shapes[0]) for shape in shapes[1:])


def _promote_like_lnp(fun, inexact=False):
    """Decorator that promotes the arguments of `fun` to `lnp.result_type(*args)`.

    lnp and onp have different type promotion semantics; this decorator allows
    tests to make an onp reference implementation act more like an lnp
    implementation.

    Args:
      fun: the reference function to wrap.
      inexact: if True and no flattened argument has an inexact result type,
        `lnp.float_` is included when computing the promoted dtype.

    Returns:
      A function that converts every (possibly nested) argument with
      `onp.asarray(arg, dtype)` and then calls `fun`.
    """

    def wrapper(*args, **kw):
        flat_args = tf.nest.flatten(args)
        if inexact and not any(
            lnp.issubdtype(lnp.result_type(x).as_numpy_dtype, lnp.inexact)
            for x in flat_args
        ):
            dtype = lnp.result_type(lnp.float_, *flat_args)
        else:
            dtype = lnp.result_type(*flat_args)
        dtype = dtype.as_numpy_dtype
        args = tf.nest.map_structure(lambda a: onp.asarray(a, dtype), args)
        return fun(*args, **kw)

    return wrapper


def new_test(f):
    """Marks `f` as a newly added test that only runs when the flag is set.

    When `FLAGS.tf_numpy_additional_tests` is false the test is skipped through
    `self.skipTest`; otherwise `f` is called with the original arguments.

    The wrapper copies `f`'s metadata with `functools.wraps`. Without it the
    returned function is named "wrapper", so any decorator stacked on top that
    reads `__name__` (for example to build parameterized test names) sees the
    wrong name.
    """

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        if not FLAGS.tf_numpy_additional_tests:
            self.skipTest("Newly added test is disabled, since flag is False.")
        else:
            f(self, *args, **kwargs)

    return wrapper


def named_parameters(ls):
    """A version of `parameterized.named_parameters` allowing an empty list.

    Args:
      ls: the test cases. An empty list/tuple, or an iterator that yields
        nothing (e.g. an exhausted `itertools.chain` or generator), produces a
        decorator whose test just skips itself.

    Returns:
      A decorator for the test method.
    """

    def noop(test_method):
        @functools.wraps(test_method)
        def wrapper(self, *args, **kwargs):
            self.skipTest("Empty parameter list")

        return wrapper

    if isinstance(ls, (list, tuple)) and not ls:
        return noop
    if hasattr(ls, "__next__"):
        # One-shot iterator: peek at the first case to detect emptiness, then
        # put it back in front of the remaining cases.
        try:
            first = next(ls)
        except StopIteration:
            return noop
        ls = itertools.chain([first], ls)
    return parameterized.named_parameters(ls)
_shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes) + ) + ) + for rec in itertools.chain( + JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS + ) + ) + ) + def testOp( + self, + onp_op, + lnp_op, + rng_factory, + shapes, + dtypes, + check_dtypes, + tolerance, + inexact, + check_incomplete_shape, + ): + # TODO(b/147769803): Remove this skipping + if lnp_op.__name__ == "kron" and shapes == ((2, 3, 4), (2, 3, 4)): + self.skipTest("Case disabled because of b/147769803") + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) + tol = functools.reduce( + jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()] + ) + self._CheckAgainstNumpy( + _promote_like_lnp(onp_op, inexact), + lnp_op, + args_maker, + check_dtypes=check_dtypes, + tol=tol, + ) + # tf.math.pow doesn't support int32/int64 on XLA (b/169191476). 
+ check_xla = not ( + lnp_op.__name__ == "power" + and set(dtypes).intersection((onp.int32, onp.int64)) + ) + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=check_incomplete_shape, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtypes": dtypes, + "name": rec.name, + "tol": rec.tolerance, + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes) + ) + ) + for rec in JAX_OPERATOR_OVERLOADS + ) + ) + def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): + rng = rng_factory() + # onp and lnp arrays have different type promotion rules; force the use of + # lnp arrays. 
+ args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + fun = lambda *xs: getattr(operator, name.strip("_"))(*xs) + scalar_arg = ( + jtu.PYTHON_SCALAR_SHAPE in shapes + or jtu.NUMPY_SCALAR_SHAPE in shapes + or () in shapes + ) + empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) + self._CompileAndCheck( + fun, + args_maker, + check_dtypes=True, # not scalar_arg and not empty_shape, + atol=tol, + rtol=tol, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtypes": dtypes, + "name": rec.name, + "op_tolerance": rec.tolerance, + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes) + ) + ) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS + ) + ) + def testRightOperatorOverload( + self, name, rng_factory, shapes, dtypes, op_tolerance + ): + if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: + raise SkipTest() # TODO(mattjj): clean up + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + fun = lambda fst, snd: getattr(snd, name)(fst) + tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) + scalar_arg = ( + jtu.PYTHON_SCALAR_SHAPE in shapes + or jtu.NUMPY_SCALAR_SHAPE in shapes + or () in shapes + ) + empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) + self._CompileAndCheck( + fun, + args_maker, + check_dtypes=True, # not scalar_arg and not empty_shape, + atol=tol, + rtol=tol, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtypes": dtypes, + "onp_op": getattr(onp, rec.name), + 
"lnp_op": getattr(lnp, rec.name), + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in filter( + _dtypes_are_compatible_for_bitwise_ops, + CombosWithReplacement(rec.dtypes, rec.nargs), + ) + ) + for rec in JAX_BITWISE_OP_RECORDS + ) + ) + def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + if onp_op == onp.bitwise_not and has_python_scalar: + # For bitwise_not with a Python `int`, npe.jit may choose a different + # dtype for the `int` from onp's choice, which may result in a different + # result value, so we skip _CompileAndCheck. + return + # Numpy does value-dependent dtype promotion on Python/numpy/array scalars + # which `jit` can't do (when np.result_type is called inside `jit`, tensor + # values are not available), so we skip dtype check in this case. 
+ check_dtypes = not ( + set(shapes) & set([jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE, ()]) + ) + self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + "None" if out_dtype is None else onp.dtype(out_dtype).name, + keepdims, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + "keepdims": keepdims, + "inexact": rec.inexact, + } + for shape in rec.shapes + for dtype in rec.dtypes + for out_dtype in [None] + rec.dtypes + for axis in set(range(-len(shape), len(shape))) | set([None]) + for keepdims in [False, True] + ) + for rec in JAX_REDUCER_RECORDS + ) + ) + def testReducer( + self, + onp_op, + lnp_op, + rng_factory, + shape, + dtype, + out_dtype, + axis, + keepdims, + inexact, + ): + rng = rng_factory() + + def onp_fun(x): + x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32) + t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32 + return onp_op(x_cast, axis, dtype=t, keepdims=keepdims) + + onp_fun = _promote_like_lnp(onp_fun, inexact) + lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = { + onp.float16: 1e-2, + onp.float32: 1e-3, + onp.complex64: 1e-3, + onp.float64: 1e-5, + onp.complex128: 1e-5, + } + tol = jtu.tolerance(dtype, tol_spec) + tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=lnp.bfloat16 not in (dtype, out_dtype), + tol=tol, + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + @named_parameters( + jtu.cases_from_list( + 
{ + "testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + keepdims, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + "keepdims": keepdims, + "inexact": rec.inexact, + } + for rec in JAX_REDUCER_NO_DTYPE_RECORDS + for shape in rec.shapes + for dtype in rec.dtypes + for axis in set(range(-len(shape), len(shape))) | set([None]) + for keepdims in [False, True] + ) + ) + def testReducerNoDtype( + self, onp_op, lnp_op, rng_factory, shape, dtype, axis, keepdims, inexact + ): + rng = rng_factory() + onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims) + onp_fun = _promote_like_lnp(onp_fun, inexact) + lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis + ), + "shape": shape, + "dtype": dtype, + "axis": axis, + } + for shape in all_shapes + for dtype in all_dtypes + for axis in set(range(-len(shape), len(shape))) | set([None]) + ) + ) + def testCountNonzero(self, shape, dtype, axis): + rng = jtu.rand_some_zero() + onp_fun = lambda x: onp.count_nonzero(x, axis) + lnp_fun = lambda x: lnp.count_nonzero(x, axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "shape": shape, + "dtype": dtype, + } + for shape in all_shapes + for dtype in all_dtypes + ) + ) + def 
testNonzero(self, shape, dtype): + rng = jtu.rand_some_zero() + onp_fun = lambda x: onp.nonzero(x) + lnp_fun = lambda x: lnp.nonzero(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) + # The shapes of `nonzero`'s results are value-dependent, so `eval_on_shapes` + # won't return concrete shapes. + # Also, `nonzero` requires a known rank. + # Turns off XLA check because there are no XLA kernels for `Where`, which + # XLA can't support because it's output shape is dynamic. + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_axis={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + } + for rec in JAX_ARGMINMAX_RECORDS + for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) + for axis in range(-len(shape), len(shape)) + ) + ) + def testArgMinMax(self, onp_op, lnp_op, rng_factory, shape, dtype, axis): + rng = rng_factory() + if dtype == onp.complex128 and jtu.device_under_test() == "gpu": + raise unittest.SkipTest("complex128 reductions not supported on GPU") + + def onp_fun(array_to_reduce): + return onp_op(array_to_reduce, axis).astype(lnp.int_) + + def lnp_fun(array_to_reduce): + return lnp_op(array_to_reduce, axis) + + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + 
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes, + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "axes": axes, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for lhs_shape, rhs_shape, axes in [ + [(2,), (2,), (-1, -1, -1, None)], # scalar output + [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors + [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors + [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting + [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes + [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting + [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors + [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting + [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing + [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)], # same as before + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement( + minus(number_dtypes, complex_dtypes), 2 + ) + ) + ) + def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + axisa, axisb, axisc, axis = axes + lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis) + + def onp_fun(a, b): + a = a.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else a + b = b.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else b + out = onp.cross(a, b, axisa, axisb, axisc, axis) + return out.astype(lnp.promote_types(lhs_dtype, rhs_dtype)) + + tol_spec = { + # TODO(wangpeng): dtypes.bfloat16: 3e-1, + onp.float16: 0.15 + } + tol = max( + jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec) + ) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + 
check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("matrix-scalar", (3, 3), ()), + ("scalar-matrix", (), (3, 3)), + ("matrix-vector", (4, 5), (5,)), + ("vector-matrix", (6,), (6, 4)), + ("matrix-matrix", (3, 4), (4, 5)), + ("tensor-vector", (4, 3, 2), (2,)), + ("vector-tensor", (2,), (3, 2, 4)), + ("tensor-matrix", (4, 3, 2), (2, 5)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-tensor", (2, 3, 4), (5, 4, 1)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 1e-5, + onp.float64: 1e-14, + onp.complex128: 1e-14, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = tol[onp.complex64] = 2e-1 + + def onp_dot(x, y): + x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x + y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y + # `onp.dot(x, y).dtype` sometimes differs from `onp.result_type(x, y)` + # (e.g. when x is float64[] and y is complex64[3,3], or when x is + # float16[3,3] and y is int64[]). We ignore this corner case and pretend + # that they agree. + return onp.dot(x, y).astype(onp.result_type(x, y)) + + self._CheckAgainstNumpy( + onp_dot, lnp.dot, args_maker, check_dtypes=True, tol=tol + ) + # We disable dtype check in the following cases because `np.dot` does + # value-dependent type promotion in those cases. 
+ check_dtypes = () not in (lhs_shape, rhs_shape) + # XLA lacks int32/int64 MatMul kernels (b/168657656). + check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp.dot, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("matrix-vector", (3, 3), (3,)), + ("vector-matrix", (3,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("vector-tensor", (3,), (5, 3, 2)), + ("tensor-vector", (5, 3, 2), (2,)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-matrix", (5, 2, 3), (3, 2)), + ("tensor-tensor", (5, 3, 4), (5, 4, 1)), + ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + + def onp_fun(x, y): + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return ( + onp.matmul(x, y).astype(dtype), + onp.array(x).__matmul__(y).astype(dtype), + onp.array(y).__rmatmul__(x).astype(dtype), + ) + + def lnp_fun(x, y): + return ( + lnp.matmul(x, y), + lnp.array(x).__matmul__(y), + lnp.array(y).__rmatmul__(x), + ) + + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 2e-2, + onp.float64: 1e-12, + onp.complex128: 1e-12, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = 
tol[onp.complex64] = 4e-2 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). + check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("vector-matrix", (9,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("tensor-vector", (5, 3, 2), (30,)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + @new_test + def testVDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 2e-2, + onp.float64: 1e-12, + onp.complex128: 1e-12, + } + self._CheckAgainstNumpy( + onp.vdot, lnp.vdot, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). 
+ check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp.vdot, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes, + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "axes": axes, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for lhs_shape, rhs_shape, axes in [ + [(2, 3, 4), (5, 6, 7), 0], # from issue #740 + [(2, 3, 4), (3, 4, 5, 6), 2], + [(2, 3, 4), (5, 4, 3, 6), [1, 2]], + [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], + [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testTensordot( + self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory + ): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + lnp_fun = lambda a, b: lnp.tensordot(a, b, axes) + + def onp_fun(a, b): + a = a if lhs_dtype != lnp.bfloat16 else a.astype(onp.float32) + b = b if rhs_dtype != lnp.bfloat16 else b.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.tensordot(a, b, axes).astype(dtype) + + tol = { + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 1e-12, + onp.complex64: 1e-3, + onp.complex128: 1e-12, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = tol[onp.complex64] = 2e-1 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). 
+ check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + + tol = {onp.float64: 1e-14, onp.float16: 0.04, onp.complex128: 6e-15} + tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol)) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + atol=tol, + rtol=tol, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": jtu.rand_default, + } + # TODO(phawkins): support integer dtypes too. + for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + if len(jtu._dims_of_shape(lhs_shape)) == 0 + or len(jtu._dims_of_shape(rhs_shape)) == 0 + or lhs_shape[-1] == rhs_shape[-1] + ) + ) + def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + + def onp_fun(lhs, rhs): + lhs = lhs if lhs_dtype != lnp.bfloat16 else lhs.astype(onp.float32) + rhs = rhs if rhs_dtype != lnp.bfloat16 else rhs.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.inner(lhs, rhs).astype(dtype) + + lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs) + tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 2e-6} + if jtu.device_under_test() == "tpu": + tol_spec[onp.float32] = tol_spec[onp.complex64] = 2e-1 + tol = max( + jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec) + ) + # TODO(phawkins): there are float32/float64 disagreements for some inputs. 
+ self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max + ), + "shape": shape, + "dtype": dtype, + "a_min": a_min, + "a_max": a_max, + "rng_factory": jtu.rand_default, + } + for shape in all_shapes + for dtype in minus(number_dtypes, complex_dtypes) + for a_min, a_max in [ + (-1, None), + (None, 1), + (-1, 1), + (-onp.ones(1), None), + (None, onp.ones(1)), + (-onp.ones(1), onp.ones(1)), + ] + ) + ) + def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) + lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {onp.float64: 2e-7} + tol = jtu.tolerance(dtype, tol_spec) + is_x32_scalar = dtype in [onp.int32, onp.float32] and shape in [ + jtu.NUMPY_SCALAR_SHAPE, + (), + ] + # Turns check_dtypes off if is_x32_scalar is True because there is + # a weird promotion inconsistency in numpy: + # ``` + # print(np.result_type(np.ones([], np.int32), 1)) + # print(np.result_type(np.ones([1], np.int32), 1)) + # print(np.result_type(np.int32(1), 1)) + # print(np.result_type(np.int32, 1)) + # print(np.result_type(np.ones([], np.float32), 1)) + # print(np.result_type(np.ones([1], np.float32), 1)) + # print(np.result_type(np.float32(1), 1)) + # print(np.result_type(np.float32, 1)) + # ``` + # >>> + # int64 + # int32 + # int64 + # int32 + # float64 + # float32 + # float64 + # float32 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=not is_x32_scalar, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=not is_x32_scalar, + atol=tol, + rtol=tol, + 
check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max + ), + "shape": shape, + "dtype": dtype, + "a_min": a_min, + "a_max": a_max, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + [jtu.NUMPY_SCALAR_SHAPE] + for dtype in minus(number_dtypes, complex_dtypes) + for a_min, a_max in [ + (-1, None), + (None, 1), + (-1, 1), + (-onp.ones(1), None), + (None, onp.ones(1)), + (-onp.ones(1), onp.ones(1)), + ] + ) + ) + @new_test + def testClipAsMethodStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) + lnp_fun = lambda x: lnp.asarray(x).clip(a_min=a_min, a_max=a_max) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {onp.float64: 2e-7} + tol = jtu.tolerance(dtype, tol_spec) + is_x32_scalar = dtype in [onp.int32, onp.float32] and shape in [ + jtu.NUMPY_SCALAR_SHAPE, + (), + ] + # Turns check_dtypes off if is_x32_scalar is True because there is + # a weird promotion inconsistency in numpy: + # ``` + # print(np.result_type(np.ones([], np.int32), 1)) + # print(np.result_type(np.ones([1], np.int32), 1)) + # print(np.result_type(np.int32(1), 1)) + # print(np.result_type(np.int32, 1)) + # print(np.result_type(np.ones([], np.float32), 1)) + # print(np.result_type(np.ones([1], np.float32), 1)) + # print(np.result_type(np.float32(1), 1)) + # print(np.result_type(np.float32, 1)) + # ``` + # >>> + # int64 + # int32 + # int64 + # int32 + # float64 + # float32 + # float64 + # float32 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=not is_x32_scalar, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=not is_x32_scalar, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_decimals={}".format( + 
jtu.format_shape_dtype_string(shape, dtype), decimals + ), + "shape": shape, + "dtype": dtype, + "decimals": decimals, + "rng_factory": jtu.rand_default, + } + for shape, dtype in _shape_and_dtypes( + all_shapes, minus(number_dtypes, complex_dtypes) + ) + for decimals in [0, 1, -2] + ) + ) + def testRoundStaticDecimals(self, shape, dtype, decimals, rng_factory): + rng = rng_factory() + if lnp.issubdtype(dtype, onp.integer) and decimals < 0: + self.skipTest("Integer rounding with decimals < 0 not implemented") + onp_fun = lambda x: onp.round(x, decimals=decimals) + lnp_fun = lambda x: lnp.round(x, decimals=decimals) + args_maker = lambda: [rng(shape, dtype)] + tol = { + # TODO(b/154768983): lnp.bfloat16: 5e-2, + onp.float16: 1e-2 + } + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + def testOperatorRound(self): + self.assertAllClose( + round(onp.float32(7.532), 1), round(lnp.float32(7.5), 1), check_dtypes=True + ) + self.assertAllClose( + round(onp.float32(1.234), 2), + round(lnp.float32(1.234), 2), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234)), round(lnp.float32(1.234)), check_dtypes=False + ) + self.assertAllClose( + round(onp.float32(7.532), 1), + round(lnp.array(7.5, lnp.float32), 1), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234), 2), + round(lnp.array(1.234, lnp.float32), 2), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234)), + round(lnp.array(1.234, lnp.float32)), + check_dtypes=False, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format( + jtu.format_shape_dtype_string(shape, dtype), + mode, + pad_width_rank, + constant_values_rank, + ), + "shape": shape, + 
"dtype": dtype, + "mode": mode, + "pad_width_rank": pad_width_rank, + "constant_values_rank": constant_values_rank, + "rng_factory": jtu.rand_default, + "irng_factory": partial(jtu.rand_int, 3), + } + for mode, constant_values_rank, shapes in [ + ("constant", 0, all_shapes), + ("constant", 1, all_shapes), + ("constant", 2, all_shapes), + ("symmetric", None, nonempty_shapes), + ("reflect", None, nonempty_shapes), + ("wrap", None, nonempty_shapes), + ] + for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) + for pad_width_rank in range(3) + ) + ) + @jtu.disable + def testPad( + self, + shape, + dtype, + mode, + pad_width_rank, + constant_values_rank, + rng_factory, + irng_factory, + ): + rng = rng_factory() + irng = irng_factory() + pad_width = irng([len(shape), 2][2 - pad_width_rank :], onp.int32) + + def onp_fun(x, kwargs): + if pad_width.size == 0: + return x + return onp.pad(x, pad_width, mode=mode, **kwargs) + + def lnp_fun(x, kwargs): + return lnp.pad(x, pad_width, mode=mode, **kwargs) + + def args_maker(): + kwargs = {} + if constant_values_rank: + kwargs["constant_values"] = rng( + [len(shape), 2][2 - constant_values_rank :], dtype + ) + return rng(shape, dtype), kwargs + + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape=[{}]_reps={}".format( + jtu.format_shape_dtype_string(shape, dtype), reps + ), + "shape": shape, + "dtype": dtype, + "reps": reps, + "rng_factory": jtu.rand_default, + } + for reps in [(), (2,), (3, 4), (2, 3, 4)] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + ) + ) + def testTile(self, shape, dtype, reps, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.tile(arg, reps) + lnp_fun = lambda arg: lnp.tile(arg, reps) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {onp.float64: 2e-7} + tol 
= jtu.tolerance(dtype, tol_spec) + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + tol=tol, + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, + ",".join(str(d) for d in base_shape), + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes), + ), + "axis": axis, + "base_shape": base_shape, + "arg_dtypes": arg_dtypes, + "rng_factory": jtu.rand_default, + } + for num_arrs in [3] + for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)) + ) + ) + def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory): + rng = rng_factory() + wrapped_axis = axis % len(base_shape) + shapes = [ + base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1 :] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes) + ] + + def onp_fun(*args): + # TODO(nareshmodi): enable once bfloat16 has better support + # args = [x if x.dtype != bfloat16 else x.astype(onp.float32) + # for x in args] + dtype = functools.reduce(lnp.promote_types, arg_dtypes) + return onp.concatenate(args, axis=axis).astype(dtype) + + lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, + ",".join(str(d) for d in base_shape), + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes), + ), + "axis": axis, + "base_shape": base_shape, + "arg_dtypes": arg_dtypes, + "rng_factory": jtu.rand_default, + } + 
for arg_dtypes in CombosWithReplacement(default_dtypes, 2) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)) + ) + ) + def testAppend(self, axis, base_shape, arg_dtypes, rng_factory): + rng = rng_factory() + wrapped_axis = axis % len(base_shape) + shapes = [ + base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1 :] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes) + ] + + def onp_fun(arr, values): + arr = arr.astype(onp.float32) if lnp.bfloat16 == arr.dtype else arr + values = ( + values.astype(onp.float32) if lnp.bfloat16 == values.dtype else values + ) + out = onp.append(arr, values, axis=axis) + return out.astype(lnp.promote_types(*arg_dtypes)) + + lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape=[{}]_axis={}_repeats={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, repeats + ), + "axis": axis, + "shape": shape, + "dtype": dtype, + "repeats": repeats, + "rng_factory": jtu.rand_default, + } + for repeats in [0, 1, 2] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + for axis in [None] + list(range(-len(shape), len(shape))) + ) + ) + def testRepeat(self, axis, shape, dtype, repeats, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis) + onp_fun = _promote_like_lnp(onp_fun) + lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, 
check_incomplete_shape=False + ) + + def testIssue1233(self): + """ + Following numpy test suite from `test_repeat` at https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py + """ + tol = 1e-5 + + def test_single(m, args_maker, repeats, axis): + lax_ans = lnp.repeat(m, repeats, axis) + numpy_ans = onp.repeat(m, repeats, axis) + + self.assertAllClose( + lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol + ) + + lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) + # Turns off XLA check because there are no XLA kernels for `Where` used by + # tf.repeat (b/169192730). + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + m = lnp.array([1, 2, 3, 4, 5, 6]) + args_maker = lambda: [m] + + for repeats in [ + 2, + [1, 3, 2, 1, 1, 2], + [1, 3, 0, 1, 1, 2], + [2], + lnp.array([1, 3, 2, 1, 1, 2]), + lnp.array([2]), + ]: + test_single(m, args_maker, repeats, None) + + m_rect = m.reshape((2, 3)) + args_maker = lambda: [m_rect] + + for repeats in [2, [2, 1], [2], lnp.array([2, 1]), lnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=0) + + for repeats in [2, [1, 3, 2], [2], lnp.array([1, 3, 2]), lnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=1) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype + ), + "axis": axis, + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + "lnp_op": getattr(lnp, op), + "onp_op": getattr(onp, op), + } + for op in ["cumsum", "cumprod"] + for dtype in default_dtypes + for out_dtype in default_dtypes + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ) + ) + def testCumSumProd( + self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_factory + 
): + rng = rng_factory() + onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype) + lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype) + + args_maker = lambda: [rng(shape, dtype)] + + tol = max(jtu.tolerance(dtype), jtu.tolerance(out_dtype)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378). + check_xla = out_dtype != onp.int64 + rtol = None + if out_dtype == onp.float16: + rtol = 2e-3 + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + rtol=rtol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_dtype={}_m={}_n={}_k={}".format( + onp.dtype(dtype).name, m, n, k + ), + "m": m, + "n": n, + "k": k, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for n in [0, 4] + for m in [None, 0, 1, 3, 4] + for k in list(range(-4, 4)) + ) + ) + def testTri(self, m, n, k, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype) + lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_shape={}_k={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), k + ), + "dtype": dtype, + "shape": shape, + "op": op, + "k": k, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for op in ["tril", "triu"] + for k in list(range(-3, 3)) + ) + ) + def testTriLU(self, dtype, shape, op, k, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: getattr(onp, op)(arg, k=k) + lnp_fun = lambda arg: 
getattr(lnp, op)(arg, k=k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + # Incomplete shape support is not implemented at the moment. + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False + ) + + @named_parameters( + jtu.cases_from_list( + {"testcase_name": "_ndim={}_n={}".format(ndim, n), "ndim": ndim, "n": n} + for ndim in [0, 1, 4] + for n in [0, 1, 7] + ) + ) + def testDiagIndices(self, ndim, n): + onp.testing.assert_equal(onp.diag_indices(n, ndim), lnp.diag_indices(n, ndim)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k + ), + "dtype": dtype, + "shape": shape, + "k": k, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] + for k in list(range(-4, 4)) + ) + ) + def testDiag(self, shape, dtype, k, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.diag(arg, k) + lnp_fun = lambda arg: lnp.diag(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2 + ), + "dtype": dtype, + "shape": shape, + "offset": offset, + "axis1": axis1, + "axis2": axis2, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in [ + a + for a in range(-len(shape), len(shape)) + if a % len(shape) != axis1 % len(shape) + ] + for offset in list(range(-4, 4)) + ) + ) + def testDiagonal(self, shape, dtype, 
offset, axis1, axis2, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2) + lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n), + "dtype": dtype, + "n": n, + } + for dtype in default_dtypes + for n in list(range(4)) + ) + ) + def testIdentity(self, n, dtype): + onp_fun = lambda: onp.identity(n, dtype) + lnp_fun = lambda: lnp.identity(n, dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), + out_dtype, + offset, + axis1, + axis2, + ), + "dtype": dtype, + "out_dtype": out_dtype, + "shape": shape, + "offset": offset, + "axis1": axis1, + "axis2": axis2, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for out_dtype in [None] + number_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in range(-len(shape), len(shape)) + if (axis1 % len(shape)) != (axis2 % len(shape)) + for offset in list(range(-4, 4)) + ) + ) + def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng_factory): + rng = rng_factory() + + def onp_fun(arg): + if out_dtype == lnp.bfloat16: + return onp.trace(arg, offset, axis1, axis2, onp.float32).astype( + lnp.bfloat16 + ) + else: + return onp.trace(arg, offset, axis1, axis2, out_dtype) + + lnp_fun = lambda arg: lnp.trace(arg, offset, 
axis1, axis2, out_dtype) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}".format( + jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis + ), + "shape": shape, + "axis": axis, + "dtypes": dtypes, + "rng_factory": rng_factory, + } + for dtypes in [ + [onp.float32], + [onp.float32, onp.float32], + [onp.float32, onp.int32, onp.float32], + [onp.float32, onp.int64, onp.float32], + [onp.float32, onp.int32, onp.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100)] + for axis in range(-len(shape), len(shape) + 1) + for rng_factory in [jtu.rand_default] + ) + ) + def testStack(self, shape, axis, dtypes, rng_factory): + rng = rng_factory() + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + onp_fun = _promote_like_lnp(partial(onp.stack, axis=axis)) + lnp_fun = partial(lnp.stack, axis=axis) + self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, True, check_incomplete_shape=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_{}".format( + op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes) + ), + "shape": shape, + "op": op, + "dtypes": dtypes, + "rng_factory": rng_factory, + } + for op in ["hstack", "vstack", "dstack"] + for dtypes in [ + [onp.float32], + [onp.float32, onp.float32], + [onp.float32, onp.int32, onp.float32], + [onp.float32, onp.int64, onp.float32], + [onp.float32, onp.int32, onp.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)] + for rng_factory in [jtu.rand_default] + ) + ) + def testHVDStack(self, shape, op, dtypes, rng_factory): + rng = rng_factory() + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + onp_fun = 
_promote_like_lnp(getattr(onp, op)) + lnp_fun = getattr(lnp, op) + self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, True, check_incomplete_shape=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outdtype={}".format( + jtu.format_shape_dtype_string(shape, fill_value_dtype), + onp.dtype(out_dtype).name if out_dtype else "None", + ), + "shape": shape, + "fill_value_dtype": fill_value_dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)] + for fill_value_dtype in default_dtypes + for out_dtype in [None] + default_dtypes + ) + ) + def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype) + lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype) + args_maker = lambda: [rng((), fill_value_dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype), + "onp_op": getattr(onp, op), + "lnp_op": getattr(lnp, op), + "shape": shape, + "dtype": dtype, + } + for op in ["zeros", "ones"] + for shape in [ + 2, + (), + (2,), + (3, 0), + onp.array((4, 5, 6), dtype=onp.int32), + onp.array(4, dtype=onp.int32), + ] + for dtype in all_dtypes + ) + ) + def testZerosOnes(self, onp_op, lnp_op, shape, dtype): + rng = jtu.rand_default() + + def args_maker(): + return [] + + onp_op = partial(onp_op, shape, dtype) + lnp_op = partial(lnp_op, shape, dtype) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + 
jtu.cases_from_list( + { + "testcase_name": "_inshape={}_filldtype={}_outdtype={}".format( + jtu.format_shape_dtype_string(shape, in_dtype), + onp.dtype(fill_value_dtype).name, + onp.dtype(out_dtype).name, + ), + "shape": shape, + "in_dtype": in_dtype, + "fill_value_dtype": fill_value_dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + for in_dtype in default_dtypes + for fill_value_dtype in default_dtypes + for out_dtype in default_dtypes + ) + ) + def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype) + lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype) + args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections + ), + "shape": shape, + "num_sections": num_sections, + "axis": axis, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for shape, axis, num_sections in [ + ((3,), 0, 3), + ((12,), 0, 3), + ((12, 4), 0, 4), + ((12, 4), 1, 2), + ((2, 3, 4), -1, 2), + ((2, 3, 4), -2, 3), + ] + for dtype in default_dtypes + ) + ) + def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.split(x, num_sections, axis=axis) + lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + 
"testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections + ), + "shape": shape, + "num_sections": num_sections, + "axis": axis, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for shape, axis, num_sections in [ + ((12, 4), 0, 4), + ((12, 4), 1, 2), + ((2, 3, 4), 2, 2), + ((4, 3, 4), 0, 2), + ] + for dtype in default_dtypes + ) + ) + def testHVDSplit(self, shape, num_sections, axis, dtype, rng_factory): + rng = rng_factory() + + def fn(module, axis): + if axis == 0: + return module.vsplit + elif axis == 1: + return module.hsplit + else: + assert axis == 2 + return module.dsplit + + onp_fun = lambda x: fn(onp, axis)(x, num_sections) + lnp_fun = lambda x: fn(lnp, axis)(x, num_sections) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outshape={}_order={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + order, + ), + "arg_shape": arg_shape, + "out_shape": out_shape, + "dtype": dtype, + "order": order, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for order in ["C", "F"] + for arg_shape, out_shape in [ + (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), + ((), (1, 1, 1)), + ((7, 0), (0, 42, 101)), + ((3, 4), 12), + ((3, 4), (12,)), + ((3, 4), -1), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)), + ] + ) + ) + def testReshape(self, arg_shape, out_shape, dtype, order, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.reshape(x, out_shape, order=order) + lnp_fun = lambda x: lnp.reshape(x, out_shape, order=order) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, 
args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outshape={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + ), + "arg_shape": arg_shape, + "out_shape": out_shape, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for arg_shape, out_shape in [ + ((7, 0), (0, 42, 101)), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)), + ] + ) + ) + def testReshapeMethod(self, arg_shape, out_shape, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.reshape(x, out_shape) + lnp_fun = lambda x: x.reshape(*out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_expanddim={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), dim + ), + "arg_shape": arg_shape, + "dtype": dtype, + "dim": dim, + "rng_factory": jtu.rand_default, + } + for arg_shape in [(), (3,), (3, 4)] + for dtype in default_dtypes + for dim in range(-len(arg_shape) + 1, len(arg_shape)) + ) + ) + def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.expand_dims(x, dim) + lnp_fun = lambda x: lnp.expand_dims(x, dim) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_axes=({},{})".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2 + ), + "arg_shape": arg_shape, + "dtype": dtype, + "ax1": ax1, + "ax2": ax2, + "rng_factory": 
jtu.rand_default, + } + for arg_shape, ax1, ax2 in [ + ((3, 4), 0, 1), + ((3, 4), 1, 0), + ((3, 4, 5), 1, 2), + ((3, 4, 5), -1, -2), + ((3, 4, 5), 0, 1), + ] + for dtype in default_dtypes + ) + ) + def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) + lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_axes=({},{})".format( + jtu.format_shape_dtype_string(arg_shape, dtype), source, destination + ), + "arg_shape": arg_shape, + "dtype": dtype, + "source": source, + "destination": destination, + "rng_factory": jtu.rand_default, + } + for arg_shape, source, destination in [ + (tuple(range(6)), (0, 2), (3, 5)), + (tuple(range(6)), (0, 2), (-1, -3)), + (tuple(range(6)), (-6, -4), (3, 5)), + (tuple(range(6)), (-6, -4), (-1, -3)), + (tuple(range(6)), 0, 4), + (tuple(range(6)), -6, -2), + (tuple(range(6)), tuple(range(6)), tuple(range(6))), + (tuple(range(6)), tuple(range(6)), tuple(reversed(range(6)))), + (tuple(range(6)), (), ()), + ] + for dtype in default_dtypes + ) + ) + @new_test + def testMoveaxisStaticAxes( + self, arg_shape, dtype, source, destination, rng_factory + ): + rng = rng_factory() + onp_fun = lambda x: onp.moveaxis(x, source, destination) + lnp_fun = lambda x: lnp.moveaxis(x, source, destination) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_axis={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax + ), + 
"arg_shape": arg_shape, + "dtype": dtype, + "ax": ax, + "rng_factory": jtu.rand_default, + } + for arg_shape, ax in [ + ((3, 1), None), + ((3, 1), 1), + ((1, 3, 1), (0, 2)), + ((1, 4, 1), (0,)), + ] + for dtype in default_dtypes + ) + ) + def testSqueeze(self, arg_shape, dtype, ax, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.squeeze(x, ax) + lnp_fun = lambda x: lnp.squeeze(x, ax) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( + jtu.format_shape_dtype_string(shape, dtype), + axis, + ( + None + if weights_shape is None + else jtu.format_shape_dtype_string(weights_shape, dtype) + ), + returned, + ), + "rng_factory": jtu.rand_default, + "shape": shape, + "dtype": dtype, + "axis": axis, + "weights_shape": weights_shape, + "returned": returned, + } + for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes) + for axis in set(range(-len(shape), len(shape))) | set([None]) + # `weights_shape` is either `None`, same as the averaged axis, or same as + # that of the input + for weights_shape in ( + [None, shape] + if axis is None or len(shape) == 1 + else [None, (shape[axis],), shape] + ) + for returned in [False, True] + ) + ) + def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory): + rng = rng_factory() + if weights_shape is None: + onp_fun = lambda x: onp.average(x, axis, returned=returned) + lnp_fun = lambda x: lnp.average(x, axis, returned=returned) + args_maker = lambda: [rng(shape, dtype)] + else: + onp_fun = lambda x, weights: onp.average(x, axis, weights, returned) + lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned) + args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] + onp_fun = 
_promote_like_lnp(onp_fun, inexact=True) + tol = { + # TODO(b/154768983): lnp.bfloat16: 1e-1, + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 2e-7, + onp.complex64: 1e-3, + onp.complex128: 1e-10, + } + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + try: + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol + ) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=check_dtypes, + rtol=tol, + atol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_arg{}_ndmin={}".format(i, ndmin), + "arg": arg, + "ndmin": ndmin, + "dtype": dtype, + } + for i, (arg, dtype) in enumerate( + [ + ([True, False, True], lnp.bool_), + (3.0, lnp.float_), + ([1, 2, 3], lnp.int_), + ([1.0, 2.0, 3.0], lnp.float_), + ([[1, 2], [3, 4], [5, 6]], lnp.int_), + ([[1, 2.0], [3, 4], [5, 6]], lnp.float_), + ([[1.0, 2j], [3.0, 4.0], [5.0, 6.0]], lnp.complex_), + ( + [ + [3, onp.array(2, dtype=lnp.float_), 1], + onp.arange(3.0, dtype=lnp.float_), + ], + lnp.float_, + ), + ] + ) + for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2] + ) + ) + def testArray(self, arg, ndmin, dtype): + args_maker = lambda: [arg] + dtype = lnp.canonicalize_dtype(dtype) + if ndmin is not None: + onp_fun = partial(onp.array, ndmin=ndmin, dtype=dtype) + lnp_fun = partial(lnp.array, ndmin=ndmin) + else: + onp_fun = partial(onp.array, dtype=dtype) + lnp_fun = lnp.array + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + static_argnums=[0], + ) + + def testIssue121(self): + assert not onp.isscalar(lnp.array(3)) + + @jtu.disable + def testArrayMethod(self): + class arraylike(object): + dtype = onp.float32 + + def __array__(self, dtype=None): + return 3.0 + + a = arraylike() 
+ ans = lnp.array(a) + assert ans == 3.0 + + @jtu.skip_on_devices("tpu") # TODO(b/32368900): TPUs don't support uint8 yet. + @jtu.disable + def testMemoryView(self): + ans = lnp.array(bytearray(b"\x2a")) + self.assertAllClose(ans, onp.array([0x2A], dtype=onp.uint8), check_dtypes=True) + + def testAllClose(self): + rng = onp.random.RandomState(0) + x = rng.randn(2, 2) + y = rng.randn(2) + + def same(list1, list2): + allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) + elements_close = list(map(allclose, list1, list2)) + return lnp.all(lnp.array(elements_close)) + + csame = npe.jit(same) + + a1 = same((x, y), (x, y)) + a2 = csame((x, y), (x, y)) + a3 = csame((x, y), (x, 2 * y)) + + self.assertTrue(a1) + self.assertTrue(a2) + self.assertFalse(a3) + + @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure + @jtu.disable + def testOnesBroadcastingConstantHandler(self): + # TODO(mattjj): update this test for jax3 + self.skipTest("test needs jax3 update") + + def fun(x): + ones = lnp.ones((3, 4)) + assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) + + # To check that the constant handler generates a Broadcast for stride-zero + # arrays, we monkey-patch the client instance. + # TODO(mattjj): once we have better HLO dumping and inspecting facilities, + # we can check the HLO more directly. + c = x._node.c + Broadcast = c.Broadcast # pylint: disable=invalid-name + was_called = [] + c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args) + out = x + ones # the ndarray constant handler should call Broadcast here + assert was_called, "Broadcast was not called." 
+ + return out + + fun = api.jit(fun) + out_val = fun(lnp.ones(4)) + self.assertAllClose(out_val, onp.full((3, 4), 2.0), check_dtypes=False) + + def testZeroStridesConstantHandler(self): + raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) + const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) + + def fun(x): + return x * const + + fun = npe.jit(fun) + out_val = fun(3.0) + self.assertAllClose(out_val, 3.0 * const, check_dtypes=False) + + def testIsInstanceNdarrayDuringTracing(self): + arr = onp.ones(3) + + @npe.jit + def f(x): + self.assertIsInstance(x, lnp.ndarray) + return lnp.sum(x) + + f(arr) + + @jtu.disable + def testNonArrayErrorMessage(self): + x = [1.0, 2.0] + y = onp.array([3.0, 4.0]) + + def g(x, y): + return lnp.add(x, y) + + def f(x, y): + return lnp.dot(x, y) + + self.assertRaises(TypeError, lambda: g(x, y)) + self.assertRaises(TypeError, lambda: f(x, y)) + self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) + self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) + + @jtu.disable + def testAbstractionErrorMessage(self): + @api.jit + def f(x, n): + for _ in range(n): + x = x * x + return x + + self.assertRaises(TypeError, lambda: f(3.0, 3)) + + @api.jit + def g(x): + if x > 0.0: + return x * 2 + else: + return x + 2 + + self.assertRaises(TypeError, lambda: g(3.0)) + + @jtu.disable + def testTracingPrimitiveWithNoTranslationErrorMessage(self): + # TODO(mattjj): update this for jax3 + self.skipTest("test needs jax3 update") + foo = lnp._not_implemented(lambda x: x) + + # No error if there's no tracing. 
+ foo(onp.arange(3)) + + cfoo = api.jit(foo) + self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "axis": axis, + } + for shape in [(3,), (2, 3)] + for dtype in default_dtypes + for axis in list(range(-len(shape), len(shape))) + + [None] # Test negative axes + for rng_factory in [jtu.rand_default] + ) + ) + def testFlip(self, shape, dtype, axis, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.flip(x, axis) + onp_op = lambda x: onp.flip(x, axis) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + } + for shape in [(3,), (2, 3), (3, 2, 4)] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testFlipud(self, shape, dtype, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.flipud(x) + onp_op = lambda x: onp.flipud(x) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + } + for shape in [(3, 2), (2, 3), (3, 2, 4)] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testFliplr(self, shape, dtype, rng_factory): + rng 
= rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.fliplr(x) + onp_op = lambda x: onp.fliplr(x) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_k={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, axes + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "k": k, + "axes": axes, + } + for shape, axes in [ + [(2, 3), (0, 1)], + [(2, 3), (1, 0)], + [(4, 3, 2), (0, 2)], + [(4, 3, 2), (2, 1)], + ] + for k in range(-3, 4) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testRot90(self, shape, dtype, k, axes, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.rot90(x, k, axes) + onp_op = lambda x: onp.rot90(x, k, axes) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_k={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, axes + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "k": k, + "axes": axes, + } + for shape, axes in [ + [(2, 3), (-2, -1)], + [(2, 3), (-2, 1)], + [(4, 3, 2), (-1, -2)], + [(4, 3, 2), (2, -2)], + ] + for k in range(-3, 4) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + @new_test + # These tests are only added as a separate test from testRot90 since we would + # like to measure coverage directly against the existing baseline. Once we + # stop measuring that, we can combine this test with the above. 
+ def testRot90Additional(self, shape, dtype, k, axes, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.rot90(x, k, axes) + onp_op = lambda x: onp.rot90(x, k, axes) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + # TODO(mattjj): test infix operator overrides + + def testRavel(self): + rng = onp.random.RandomState(0) + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + self._CompileAndCheck( + lambda x: x.ravel(), + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + ) + + def testAstype(self): + rng = onp.random.RandomState(0) + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + op = lambda x: x.astype(lnp.int32) + self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True) + self._CompileAndCheck( + op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + # TODO(mattjj): test other ndarray-like method overrides + + def testOnpMean(self): + # from https://github.com/google/jax/issues/125 + x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.0) + ans = onp.mean(x) + self.assertAllClose(ans, onp.array(1.0 / 3), check_dtypes=False) + + @jtu.disable + def testArangeOnFloats(self): + # from https://github.com/google/jax/issues/145 + expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_) + ans = lnp.arange(0.0, 1.0, 0.1) + self.assertAllClose(expected, ans, check_dtypes=True) + + def testSortManually(self): + def _test(*args, **kwargs): + + raw_ans = lnp.sort(*args, **kwargs) + fn_ans = npe.jit(lnp.sort, static_argnums=(1,))(*args, **kwargs) + expected = onp.sort(*args, **kwargs) + + self.assertAllClose(expected, raw_ans, check_dtypes=True) + self.assertAllClose(expected, fn_ans, check_dtypes=True) + + # manual tests for sort are nice because we don't have to worry about ties. + # lax.sort is tested combinatorially. 
+ _test(onp.array([16, 15, 23, 42, 8, 4])) + _test(onp.array([[1, 4], [3, 1]]), None) + _test(onp.array([[1, 4], [3, 1]])) + _test(onp.array([[1, 4], [3, 1]]), 0) + + def testArgsortManually(self): + def _test(*args, **kwargs): + + raw_ans = lnp.argsort(*args, **kwargs) + fn_ans = npe.jit(lnp.argsort, static_argnums=(1,))(*args, **kwargs) + expected = onp.argsort(*args, **kwargs) + + self.assertAllClose(expected, raw_ans, check_dtypes=True) + self.assertAllClose(expected, fn_ans, check_dtypes=True) + + _test(onp.array([16, 15, 23, 42, 8, 4])) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), 0) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), 1) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), None) + _test(onp.array([[16, 15, 23], [42, 8, 4]])) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_shifts={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), shifts, axis + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "shifts": shifts, + "axis": axis, + } + for dtype in all_dtypes + for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] + for shifts, axis in [ + (3, None), + (1, 1), + ((3,), (0,)), + ((-2,), (-2,)), + ((1, 2), (0, -1)), + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testRoll(self, shape, dtype, shifts, axis, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), onp.array(shifts)] + lnp_op = partial(lnp.roll, axis=axis) + onp_op = partial(onp.roll, axis=axis) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_index={}_axis={}_mode={}".format( + jtu.format_shape_dtype_string(shape, dtype), + jtu.format_shape_dtype_string(index_shape, index_dtype), + axis, + mode, + ), + "rng_factory": rng_factory, + "rng_indices_factory": rng_indices_factory, + "shape": shape, + "index_shape": 
index_shape, + "dtype": dtype, + "index_dtype": index_dtype, + "axis": axis, + "mode": mode, + } + for shape in [(3,), (3, 4), (3, 4, 5)] + for index_shape in scalar_shapes + [(3,), (2, 1, 3)] + for axis in itertools.chain(range(-len(shape), len(shape)), [None]) + for dtype in all_dtypes + for index_dtype in int_dtypes + for mode in ["wrap", "clip"] + for rng_factory in [jtu.rand_default] + for rng_indices_factory in [partial(jtu.rand_int, -5, 5)] + ) + ) + def testTake( + self, + shape, + dtype, + index_shape, + index_dtype, + axis, + mode, + rng_factory, + rng_indices_factory, + ): + def args_maker(): + x = rng(shape, dtype) + i = rng_indices(index_shape, index_dtype) + return x, i + + rng = rng_factory() + rng_indices = rng_indices_factory() + lnp_op = lambda x, i: lnp.take(x, i, axis=axis, mode=mode) + onp_op = lambda x, i: onp.take(x, i, axis=axis, mode=mode) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_ishape={}_axis={}".format( + jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis + ), + "rng_factory": rng_factory, + "x_shape": x_shape, + "i_shape": i_shape, + "dtype": dtype, + "axis": axis, + } + for x_shape, i_shape in filter( + _shapes_are_equal_length, + filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(nonempty_nonscalar_array_shapes, 2), + ), + ) + for axis in itertools.chain(range(len(x_shape)), [-1], [None]) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng_factory): + rng = rng_factory() + i_shape = onp.array(i_shape) + if axis is None: + i_shape = [onp.prod(i_shape, dtype=onp.int64)] + else: + # Test the case where the size of the axis doesn't necessarily broadcast. 
+ i_shape[axis] *= 3 + i_shape = list(i_shape) + + def args_maker(): + x = rng(x_shape, dtype) + n = onp.prod(x_shape, dtype=onp.int32) if axis is None else x_shape[axis] + i = rng(i_shape, onp.int32) % (2 * n - 1) - (n - 1) + return x, i + + lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis) + + if hasattr(onp, "take_along_axis"): + onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_n={}_increasing={}".format( + jtu.format_shape_dtype_string([shape], dtype), n, increasing + ), + "dtype": dtype, + "shape": shape, + "n": n, + "increasing": increasing, + "rng_factory": jtu.rand_default, + } + for dtype in inexact_dtypes + for shape in [0, 5] + for n in [2, 4] + for increasing in [False, True] + ) + ) + def testVander(self, shape, dtype, n, increasing, rng_factory): + rng = rng_factory() + + def onp_fun(arg): + arg = arg.astype(onp.float32) if dtype == lnp.bfloat16 else arg + return onp.vander(arg, N=n, increasing=increasing) + + lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing) + args_maker = lambda: [rng([shape], dtype)] + # np.vander seems to return float64 for all floating types. We could obey + # those semantics, but they seem like a bug. 
+ self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol={onp.float32: 1e-3} + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=False, + check_incomplete_shape=True, + rtol={onp.complex128: 2e-15}, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + "nan_to_num", [shape], [dtype] + ), + "rng_factory": jtu.rand_some_inf_and_nan, + "shape": shape, + "dtype": dtype, + } + for shape in all_shapes + for dtype in inexact_dtypes + ) + ) + @jtu.disable + def testNanToNum(self, rng_factory, shape, dtype): + rng = rng_factory() + dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type + + def onp_fun(x): + if dtype == lnp.bfloat16: + x = onp.where(onp.isnan(x), dtype(0), x) + x = onp.where(onp.isposinf(x), lnp.finfo(dtype).max, x) + x = onp.where(onp.isneginf(x), lnp.finfo(dtype).min, x) + return x + else: + return onp.nan_to_num(x).astype(dtype) + + args_maker = lambda: [rng(shape, dtype)] + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy( + onp_fun, lnp.nan_to_num, args_maker, check_dtypes=check_dtypes + ) + self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=check_dtypes) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for shapes, dtypes in ( + ((), ()), + (((7,),), (onp.int32,)), + (((3,), (4,)), (onp.int32, onp.int32)), + (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)), + ) + ) + ) + def testIx_(self, rng_factory, shapes, dtypes): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] + self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp.ix_, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": 
"_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format( + op, + jtu.format_shape_dtype_string(a_shape, a_dtype), + jtu.format_shape_dtype_string(q_shape, q_dtype), + axis, + keepdims, + ), + "a_rng": jtu.rand_default(), + "q_rng": q_rng, + "op": op, + "a_shape": a_shape, + "a_dtype": a_dtype, + "q_shape": q_shape, + "q_dtype": q_dtype, + "axis": axis, + "keepdims": keepdims, + } + for (op, q_rng) in ( + ("percentile", jtu.rand_uniform(low=0.0, high=100.0)), + ("quantile", jtu.rand_uniform(low=0.0, high=1.0)), + ("median", jtu.rand_uniform(low=0.0, high=1.0)), + ) + for a_dtype in float_dtypes + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((4, 101), 1), + ) + for q_dtype in [onp.float32] + for q_shape in scalar_shapes + [(4,)] + for keepdims in [False, True] + ) + ) + @jtu.disable + def testQuantile( + self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims + ): + if op == "quantile" and numpy_version < (1, 15): + raise SkipTest("Numpy < 1.15 does not have np.quantile") + if op == "median": + args_maker = lambda: [a_rng(a_shape, a_dtype)] + else: + args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + + def onp_fun(*args): + args = [ + x if lnp.result_type(x) != lnp.bfloat16 else onp.asarray(x, onp.float32) + for x in args + ] + return getattr(onp, op)(*args, axis=axis, keepdims=keepdims) + + lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims) + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. 
+ tol_spec = {onp.float32: 2e-4, onp.float64: 5e-6} + tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "shape": shape, + "dtype": dtype, + } + for shape in all_shapes + for dtype in all_dtypes + ) + ) + def testWhereOneArgument(self, shape, dtype): + rng = jtu.rand_some_zero() + onp_fun = lambda x: onp.where(x) + lnp_fun = lambda x: lnp.where(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) + # Turns off XLA check because there are no XLA kernels for `Where`, which + # XLA can't support because it's output shape is dynamic. + self._CompileAndCheck( + lnp.where, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + "_".join( + jtu.format_shape_dtype_string(shape, dtype) + for shape, dtype in zip(shapes, dtypes) + ) + ), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for shapes in filter( + _shapes_are_broadcast_compatible, CombosWithReplacement(all_shapes, 3) + ) + for dtypes in CombosWithReplacement(all_dtypes, 3) + ) + ) + def testWhereThreeArgument(self, rng_factory, shapes, dtypes): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes) + + def onp_fun(cond, x, y): + return _promote_like_lnp(partial(onp.where, cond))(x, y) + + self._CheckAgainstNumpy(onp_fun, lnp.where, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp.where, args_maker, check_dtypes=True, 
check_incomplete_shape=True + ) + + def testWhereScalarPromotion(self): + x = lnp.where(lnp.array([True, False]), 3, lnp.ones((2,), dtype=lnp.float32)) + self.assertEqual(x.dtype, onp.dtype(onp.float32)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + "", shapes, (onp.bool_,) * n + dtypes + ), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for n in range(0, 3) + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(all_shapes, 2 * n + 1), + ) + for dtypes in CombosWithReplacement(all_dtypes, n + 1) + ) + ) + def testSelect(self, rng_factory, shapes, dtypes): + rng = rng_factory() + n = len(dtypes) - 1 + + def args_maker(): + condlist = [rng(shape, onp.bool_) for shape in shapes[:n]] + choicelist = [ + rng(shape, dtype) for shape, dtype in zip(shapes[n:-1], dtypes[:n]) + ] + default = rng(shapes[-1], dtypes[-1]) + return condlist, choicelist, default + + # TODO(phawkins): float32/float64 type mismatches + def onp_fun(condlist, choicelist, default): + choicelist = [ + x if lnp.bfloat16 != lnp.result_type(x) else x.astype(onp.float32) + for x in choicelist + ] + dtype = lnp.result_type(default, *choicelist).as_numpy_dtype + return onp.select( + condlist, + [onp.asarray(x, dtype=dtype) for x in choicelist], + onp.asarray(default, dtype=dtype), + ) + + self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker, check_dtypes=False) + self._CompileAndCheck( + lnp.select, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + rtol={onp.float64: 1e-7, onp.complex128: 1e-7}, + ) + + @jtu.disable + def testIssue330(self): + x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash + self.assertEqual(x[0, 0], 1) + + @jtu.disable + def testScalarDtypePromotion(self): + orig_numpy_result = (1 + onp.eye(1, dtype=onp.float32)).dtype + jax_numpy_result = (1 + lnp.eye(1, dtype=lnp.float32)).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + 
@jtu.disable + def testSymmetrizeDtypePromotion(self): + x = onp.eye(3, dtype=onp.float32) + orig_numpy_result = ((x + x.T) / 2).dtype + + x = lnp.eye(3, dtype=lnp.float32) + jax_numpy_result = ((x + x.T) / 2).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + @jtu.disable + def testIssue347(self): + # https://github.com/google/jax/issues/347 + def test_fail(x): + x = lnp.sqrt(lnp.sum(x**2, axis=1)) + ones = lnp.ones_like(x) + x = lnp.where(x > 0.5, x, ones) + return lnp.sum(x) + + x = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64) + result = api.grad(test_fail)(x) + assert not onp.any(onp.isnan(result)) + + def testIssue453(self): + # https://github.com/google/jax/issues/453 + a = onp.arange(6) + 1 + ans = lnp.reshape(a, (3, 2), order="F") + expected = onp.reshape(a, (3, 2), order="F") + self.assertAllClose(ans, expected, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_dtype={}".format(op, pytype.__name__), + "pytype": pytype, + "dtype": dtype, + "op": op, + } + for pytype, dtype in [ + (int, lnp.int_), + (float, lnp.float_), + (bool, lnp.bool_), + (complex, lnp.complex_), + ] + for op in ["atleast_1d", "atleast_2d", "atleast_3d"] + ) + ) + def testAtLeastNdLiterals(self, pytype, dtype, op): + # Fixes: https://github.com/google/jax/issues/634 + onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype) + lnp_fun = lambda arg: getattr(lnp, op)(arg) + args_maker = lambda: [pytype(2)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def testLongLong(self): + self.assertAllClose( + onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True + ) + + def testArange(self): + # test cases inspired by dask tests at + # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92 + self.assertAllClose( + lnp.arange(77), onp.arange(77, 
dtype=lnp.int_), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(2, 13), onp.arange(2, 13, dtype=lnp.int_), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(4, 21, 9), + onp.arange(4, 21, 9, dtype=lnp.int_), + check_dtypes=True, + ) + self.assertAllClose( + lnp.arange(53, 5, -3), + onp.arange(53, 5, -3, dtype=lnp.int_), + check_dtypes=True, + ) + # TODO(mattjj): make these tests work when enable_x64=True + self.assertAllClose( + lnp.arange(77, dtype=float), onp.arange(77, dtype=float), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(2, 13, dtype=int), + onp.arange(2, 13, dtype=int), + check_dtypes=True, + ) + self.assertAllClose( + lnp.arange(0, 1, -0.5), + onp.arange(0, 1, -0.5, dtype=lnp.float_), + check_dtypes=True, + ) + + self.assertRaises(TypeError, lambda: lnp.arange()) + + # # The following have been disabled since they test JAX specific behavior + # # test that lnp.arange(N) doesn't instantiate an ndarray + # self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77))) + # self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77))) + + # # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray + # self.assertFalse(type(lnp.arange(77, dtype=lnp.int32)) == + # type(onp.arange(77, dtype=onp.int32))) + # self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == + # type(lax.iota(onp.int32, 77))) + + def testIssue830(self): + a = lnp.arange(4, dtype=lnp.complex64) + self.assertEqual(a.dtype, lnp.complex64) + + def testIssue728(self): + assert lnp.allclose(lnp.eye(5000), onp.eye(5000)) + self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) + + def testIssue746(self): + lnp.arange(12).reshape(3, 4) # doesn't crash + + def testIssue764(self): + x = lnp.linspace(190, 200, 4) + f = npe.grad(lambda x: lnp.sum(lnp.tanh(x))) + # Expected values computed with autograd in float64 precision. 
+ expected = onp.array( + [3.71669453e-165, 4.72999108e-168, 6.01954653e-171, 7.66067839e-174], + onp.float64, + ) + self.assertAllClose(f(x), expected, check_dtypes=False) + + @jtu.disable + def testIssue776(self): + """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" + + def f(u): + y = ( + onp.ones( + 10, + ) + .at[[2, 4, 5]] + .add(u) + ) + # The transpose rule for lax.tie_in returns a symbolic zero for its first + # argument. + return lax.tie_in(y, 7.0) + + self.assertAllClose( + onp.zeros( + 3, + ), + api.grad(f)( + onp.ones( + 3, + ) + ), + check_dtypes=True, + ) + + @jtu.disable + def testIssue777(self): + x = lnp.linspace(-200, 0, 4, dtype=onp.float32) + f = npe.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x)))) + self.assertAllClose( + f(x), onp.array([0.0, 0.0, 0.0, 0.25], dtype=onp.float32), check_dtypes=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), + "dtype": dtype, + "op": op, + } + for dtype in float_dtypes + for op in ( + "sqrt", + "arccos", + "arcsin", + "arctan", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "tanh", + "arccosh", + "arcsinh", + "arctanh", + "exp", + "log", + "expm1", + "log1p", + ) + ) + ) + def testMathSpecialFloatValues(self, op, dtype): + onp_op = getattr(onp, op) + lnp_op = getattr(lnp, op) + dtype = onp.dtype(lnp.canonicalize_dtype(dtype)).type + for x in ( + onp.nan, + -onp.inf, + -100.0, + -2.0, + -1.0, + 0.0, + 1.0, + 2.0, + 100.0, + onp.inf, + lnp.finfo(dtype).max, + onp.sqrt(lnp.finfo(dtype).max), + onp.sqrt(lnp.finfo(dtype).max) * 2.0, + ): + if ( + op in ("sin", "cos", "tan", "arctan") + and jtu.device_under_test() == "tpu" + ): + continue # TODO(b/132196789, b/134175194): fix and reenable. + # TODO(b/158006398): fix and reenable. 
+ if ( + op + in ( + "cosh", + "arccosh", + "arcsinh", + "arcsin", + "sinh", + "arccos", + "arctan", + "arctanh", + ) + and dtype == onp.float16 + ): + continue + x = dtype(x) + expected = onp_op(x) + actual = lnp_op(x) + tol = jtu.tolerance(dtype, {onp.float32: 1e-3, onp.float64: 1e-7}) + self.assertAllClose(expected, actual, check_dtypes=True, atol=tol, rtol=tol) + + def testIssue883(self): + # from https://github.com/google/jax/issues/883 + + @partial(npe.jit, static_argnums=(1,)) + def f(x, v): + return x + + x = lnp.ones((10, 10)) + v = lnp.array([1, 2, 3]) + first_call = f(x, v) + second_call = f(x, v) # doesn't crash + + def testReductionOfOutOfBoundsAxis(self): # Issue 888 + x = lnp.ones((3, 4)) + self.assertRaises(tf.errors.InvalidArgumentError, lambda: lnp.sum(x, axis=2)) + + @jtu.disable + def testIssue956(self): + self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1))) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}".format( + shape, dtype, out_dtype, axis, ddof, keepdims + ), + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "axis": axis, + "ddof": ddof, + "keepdims": keepdims, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True] + for rng_factory in [jtu.rand_default] + ) + ) + def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + + def onp_fun(x): + out = onp.var( + x.astype(lnp.promote_types(onp.float32, dtype)), + axis=axis, + ddof=ddof, + keepdims=keepdims, + ) + return out.astype(out_dtype) + + lnp_fun = partial( + lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims + ) + tol = jtu.tolerance( + out_dtype, + { + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 1e-3, + 
onp.complex128: 1e-6, + }, + ) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + rtol=tol, + atol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( + shape, dtype, rowvar, ddof, bias + ), + "shape": shape, + "dtype": dtype, + "rowvar": rowvar, + "ddof": ddof, + "bias": bias, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5), (5, 10)] + for dtype in all_dtypes + for rowvar in [True, False] + for bias in [True, False] + for ddof in [None, 2, 3] + for rng_factory in [jtu.rand_default] + ) + ) + @jtu.skip_on_devices("gpu") # TODO(b/138003641): test fails on GPU. + @jtu.disable + def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + onp_fun = partial(onp.cov, rowvar=rowvar, ddof=ddof, bias=bias) + lnp_fun = partial(lnp.cov, rowvar=rowvar, ddof=ddof, bias=bias) + tol = {onp.float32: 1e-5, onp.float64: 1e-13, onp.complex128: 1e-13} + tol = 7e-2 if jtu.device_under_test() == "tpu" else tol + tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + def testIssue967(self): + self.assertRaises(TypeError, lambda: lnp.zeros(1.5)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( + shape, dtype, rowvar, ddof, bias + ), + "shape": shape, + "dtype": dtype, + "rowvar": rowvar, + "ddof": ddof, + "bias": bias, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5), (3, 10)] + for dtype in number_dtypes + for rowvar in [True, False] + for bias in [True, False] + for ddof in [None, 2, 3] + for 
rng_factory in [jtu.rand_default] + ) + ) + @jtu.disable + def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + mat = onp.asarray([rng(shape, dtype)]) + onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) + lnp_fun = partial(lnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) + if not onp.any(onp.isclose(onp.std(mat), 0.0)): + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=False, + tol=1e-2 if jtu.device_under_test() == "tpu" else None, + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format( + shapes, jtu.dtype_str(dtype), indexing, sparse + ), + "shapes": shapes, + "dtype": dtype, + "indexing": indexing, + "sparse": sparse, + "rng_factory": rng_factory, + } + for shapes in [(), (5,), (5, 3)] + for dtype in number_dtypes + for indexing in ["xy", "ij"] + for sparse in [False] # TODO(nareshmodi): Make sparse work + for rng_factory in [jtu.rand_default] + ) + ) + def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker( + rng, [(x,) for x in shapes], [dtype] * len(shapes) + ) + onp_fun = partial(onp.meshgrid, indexing=indexing, sparse=sparse) + lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse) + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_retstep={}_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, retstep, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "retstep": retstep, + "dtype": dtype, + 
"rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for retstep in [True, False] + for dtype in number_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testLinspace( + self, start_shape, stop_shape, num, endpoint, retstep, dtype, rng_factory + ): + if not endpoint and onp.issubdtype(dtype, onp.integer): + # TODO(b/157597565): Support all dtypes when the tf op supports endpoint + # Currently, subtracting the step early leads to rounding errors for + # integers. + return + rng = rng_factory() + # relax default tolerances slightly + tol = jtu.tolerance(dtype if dtype else onp.float32) * 10 + args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape], [dtype, dtype]) + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + lnp_op = lambda start, stop: lnp.linspace( + start, + stop, + num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + onp_op = lambda start, stop: onp.linspace( + start, + stop, + num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + # floating-point compute between jitted platforms and non-jit + rounding + # cause unavoidable variation in integer truncation for some inputs. 
+ if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_base={}_dtype={}" + ).format( + start_shape, + stop_shape, + num, + endpoint, + base, + dtype.__name__ if dtype else "None", + ), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "base": base, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for base in [10.0, 2, onp.e] + for dtype in inexact_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testLogspace( + self, start_shape, stop_shape, num, endpoint, base, dtype, rng_factory + ): + if ( + dtype in int_dtypes + and jtu.device_under_test() in ("gpu", "tpu") + and not FLAGS.enable_x64 + ): + raise unittest.SkipTest( + "GPUx32 truncated exponentiation" + " doesn't exactly match other platforms." 
+ ) + rng = rng_factory() + # relax default tolerances slightly + tol = { + onp.float16: 2e-2, + onp.float32: 1e-2, + onp.float64: 1e-6, + onp.complex64: 1e-3, + onp.complex128: 1e-6, + } + args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape], [dtype, dtype]) + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + lnp_op = lambda start, stop: lnp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis + ) + onp_op = lambda start, stop: onp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis + ) + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + # Why do compiled and op-by-op float16 np.power numbers differ + # slightly more than expected? + atol = {onp.float16: 1e-2} + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=atol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" "_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + # NB: numpy's geomspace gives nonsense results on integer types + for dtype in inexact_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testGeomspace(self, start_shape, stop_shape, num, endpoint, dtype, rng_factory): + rng = rng_factory() + # relax default tolerances slightly + tol = {onp.float16: 4e-3, onp.float32: 2e-3, onp.complex128: 1e-14} + + def args_maker(): + """Test the set of inputs onp.geomspace is well-defined on.""" + start, stop = 
self._GetArgsMaker( + rng, [start_shape, stop_shape], [dtype, dtype] + )() + # onp.geomspace can't handle differently ranked tensors + # w. negative numbers! + start, stop = lnp.broadcast_arrays(start, stop) + if dtype in complex_dtypes: + return start, stop + # to avoid NaNs, non-complex start and stop cannot + # differ in sign, elementwise + start = start * lnp.sign(start) * lnp.sign(stop) + return start, stop + + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + + def lnp_op(start, stop): + return lnp.geomspace( + start, stop, num, endpoint=endpoint, dtype=dtype, axis=axis + ) + + def onp_op(start, stop): + start = start.astype(onp.float32) if dtype == lnp.bfloat16 else start + stop = stop.astype(onp.float32) if dtype == lnp.bfloat16 else stop + return onp.geomspace( + start, + stop, + num, + endpoint=endpoint, + dtype=dtype if dtype != lnp.bfloat16 else onp.float32, + axis=axis, + ).astype(dtype) + + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @jtu.disable + def testDisableNumpyRankPromotionBroadcasting(self): + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "allow" + lnp.ones(2) + lnp.ones((1, 2)) # works just fine + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "raise" + self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2))) + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "warn" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + lnp.ones(2) + lnp.ones((1, 2)) + assert len(w) > 0 + msg = 
str(w[-1].message) + expected_msg = ( + "Following NumPy automatic rank promotion for add on " + "shapes (2,) (1, 2)." + ) + self.assertEqual(msg[: len(expected_msg)], expected_msg) + + prev_len = len(w) + lnp.ones(2) + 3 + self.assertEqual(len(w), prev_len) # don't want to warn for scalars + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + def testStackArrayArgument(self): + # tests https://github.com/google/jax/issues/1271 + @npe.jit + def foo(x): + return lnp.stack(x) + + foo(onp.zeros(2)) # doesn't crash + + @npe.jit + def foo(x): + return lnp.concatenate(x) + + foo(onp.zeros((2, 2))) # doesn't crash + + @jtu.disable + def testReluGradientConstants(self): + # This is a regression test that verifies that constants associated with the + # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the + # outermost jaxpr. This was producing some large materialized constants for + # every relu activation in a model. + def body(i, xy): + x, y = xy + y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.0)))(x) + return x, y + + f = lambda y: lax.fori_loop(0, 5, body, (y, y)) + wrapped = linear_util.wrap_init(f) + pv = partial_eval.PartialVal( + (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit) + ) + _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) + self.assertFalse( + any( + onp.array_equal(x, onp.full((3, 4), 2.0, dtype=onp.float32)) + for x in consts + ) + ) + + @named_parameters( + { + "testcase_name": "_from={}_to={}".format(from_shape, to_shape), + "rng_factory": rng_factory, + "from_shape": from_shape, + "to_shape": to_shape, + } + for from_shape, to_shape in [ + [(1, 3), (4, 3)], + [(3,), (2, 1, 3)], + [(3,), (3, 3)], + [(1,), (3,)], + ] + for rng_factory in [jtu.rand_default] + ) + def testBroadcastTo(self, from_shape, to_shape, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [from_shape], [onp.float32]) + onp_op = lambda x: onp.broadcast_to(x, to_shape) + lnp_op = lambda x: lnp.broadcast_to(x, 
to_shape) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def testBroadcastToIssue1522(self): + self.assertRaisesRegex( + Exception, + "Unable to broadcast", + lambda: lnp.broadcast_to(onp.ones((2, 3)), (1, 3)), + ) + + def testBroadcastToIntIssue1548(self): + self.assertAllClose( + lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)), check_dtypes=False + ) + + def testBroadcastToOnScalar(self): + self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray) + self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray) + + @jtu.disable + def testPrecision(self): + + ones_1d = onp.ones((2,)) + ones_2d = onp.ones((2, 2)) + ones_3d = onp.ones((2, 2, 2)) + HIGHEST = lax.Precision.HIGHEST + + jtu.assert_dot_precision(None, lnp.dot, ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.dot, precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.dot, precision=HIGHEST), ones_3d, ones_3d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.matmul, precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.vdot, precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.tensordot, axes=2, precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, + partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST), + ones_1d, + ones_1d, + ) + jtu.assert_dot_precision( + HIGHEST, + partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), + ones_1d, + ones_1d, + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.einsum, "i,i", precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.einsum, "ij,ij", precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.inner, precision=HIGHEST), ones_1d, ones_1d + ) + + @named_parameters( + jtu.cases_from_list( + { 
+ "testcase_name": "_{}_{}_{}_{}".format( + shape, + jtu.dtype_str(key_dtype), + jtu.dtype_str(value_dtype), + dimension, + ).replace(" ", ""), + "shape": shape, + "key_dtype": key_dtype, + "value_dtype": value_dtype, + "dimension": dimension, + "rng_factory": rng_factory, + } + for shape in all_shapes + for key_dtype in minus(number_dtypes, complex_dtypes) + for value_dtype in all_dtypes + for dimension in range(-len(shape), len(shape)) + for rng_factory in [jtu.rand_default] + ) + ) + @new_test + def testSortKeyValue(self, shape, key_dtype, value_dtype, dimension, rng_factory): + def onp_ref(keys, values): + idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape])) + idxs[dimension] = onp.argsort(keys, axis=dimension) + return keys[tuple(idxs)], values[tuple(idxs)] + + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape, shape], [key_dtype, value_dtype]) + op = partial(npe.sort_key_val, dimension=dimension) + self._CheckAgainstNumpy(onp_ref, op, args_maker, check_dtypes=True) + # sort_key_val requires known rank. + # XLA only has TopKV2 (used by tf.argsort) kernels on those dtypes + # (b/169194137). + check_xla = key_dtype in (onp.uint32, onp.int32, onp.float32, lnp.bfloat16) + self._CompileAndCheck( + op, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + +# Most grad tests are at the lax level (see lax_test.py), but we add some here +# as needed for e.g. particular compound ops of interest. 
+ +GradTestSpec = collections.namedtuple( + "GradTestSpec", ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"] +) + + +def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): + return GradTestSpec(op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + + +GRAD_TEST_RECORDS = [ + grad_test_spec( + lnp.arcsinh, + nargs=1, + order=2, + rng_factory=jtu.rand_positive, + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), + grad_test_spec( + lnp.arccosh, + nargs=1, + order=2, + rng_factory=jtu.rand_positive, + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), + grad_test_spec( + lnp.arctanh, + nargs=1, + order=2, + rng_factory=partial(jtu.rand_uniform, -0.9, 0.9), + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), +] + +GradSpecialValuesTestSpec = collections.namedtuple( + "GradSpecialValuesTestSpec", ["op", "values", "order"] +) + +GRAD_SPECIAL_VALUE_TEST_RECORDS = [ + GradSpecialValuesTestSpec(lnp.arcsinh, [0.0, 1000.0], 2), + GradSpecialValuesTestSpec(lnp.arccosh, [1000.0], 2), + GradSpecialValuesTestSpec(lnp.arctanh, [0.0], 2), + # TODO(wangpeng): Add `GradSpecialValuesTestSpec(lnp.sinc, [0.], 1)` +] + + +def num_float_bits(dtype): + return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits + + +class NumpyGradTests(jtu.TestCase): + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.name, shapes, itertools.repeat(dtype) + ), + "op": rec.op, + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtype": dtype, + "order": rec.order, + "tol": rec.tol, + } + for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs) + for dtype in rec.dtypes + ) + for rec in GRAD_TEST_RECORDS + ) + ) + @jtu.disable + def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): + rng = rng_factory() + tol = {onp.float32: 1e-1, onp.complex64: 1e-1} + args = tuple(rng(shape, dtype) for shape in shapes) + check_grads(op, args, order, ["fwd", 
"rev"], tol, tol) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}".format(rec.op.__name__, special_value), + "op": rec.op, + "special_value": special_value, + "order": rec.order, + } + for special_value in rec.values + ) + for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS + ) + ) + @jtu.disable + def testOpGradSpecialValue(self, op, special_value, order): + check_grads( + op, (special_value,), order, ["fwd", "rev"], atol={onp.float32: 3e-3} + ) + + @jtu.disable + def testTakeAlongAxisIssue1521(self): + # https://github.com/google/jax/issues/1521 + idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1)) + + def f(x): + y = x * lnp.arange(3.0).reshape((1, 3)) + return lnp.take_along_axis(y, idx, -1).sum() + + check_grads(f, (1.0,), order=1) + + +if __name__ == "__main__": + tf.enable_v2_behavior() + lnp.enable_numpy_behavior() + absltest.main() diff --git a/tests/tf_numpy/jax/utils.py b/tests/tf_numpy/jax/utils.py new file mode 100644 index 000000000..71a900fdd --- /dev/null +++ b/tests/tf_numpy/jax/utils.py @@ -0,0 +1,992 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import re +import sys +import unittest +import warnings +import zlib +from contextlib import contextmanager +from distutils.util import strtobool +from functools import partial +from typing import Dict, Sequence, Union + +import numpy as onp +import numpy.random as npr +import scipy +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +import trax.tf_numpy.extensions as npe +import trax.tf_numpy.numpy as tf_np +from tests.tf_numpy.jax.config import flags + +tree_map = tf.nest.map_structure +tree_multimap = tf.nest.map_structure + + +FLAGS = flags.FLAGS + + +# TODO(wangpeng): Remove this flag after broken tests are fixed +flags.DEFINE_bool("enable_x64", strtobool("False"), "Enable 64-bit types to be used.") + + +flags.DEFINE_enum( + "test_dut", + "", + enum_values=["", "cpu", "gpu", "tpu"], + help="Describes the device under test in case special consideration is required.", +) + + +flags.DEFINE_integer( + "num_generated_cases", 10, help="Number of generated cases to test" +) + + +EPS = 1e-4 + + +# Default dtypes corresponding to Python scalars. 
+python_scalar_dtypes = { + bool: onp.dtype(onp.bool_), + int: onp.dtype(onp.int_), + float: onp.dtype(onp.float_), + complex: onp.dtype(onp.complex_), +} + + +def _dtype(x): + if isinstance(x, tf.Tensor): + return x.dtype.as_numpy_dtype + return ( + getattr(x, "dtype", None) + or onp.dtype(python_scalar_dtypes.get(type(x), None)) + or onp.asarray(x).dtype + ) + + +def is_sequence(x): + try: + iter(x) + except TypeError: + return False + else: + return True + + +_default_tolerance = { + onp.dtype(onp.bool_): 0, + onp.dtype(onp.int8): 0, + onp.dtype(onp.int16): 0, + onp.dtype(onp.int32): 0, + onp.dtype(onp.int64): 0, + onp.dtype(onp.uint8): 0, + onp.dtype(onp.uint16): 0, + onp.dtype(onp.uint32): 0, + onp.dtype(onp.uint64): 0, + # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-2, + onp.dtype(onp.float16): 1e-3, + onp.dtype(onp.float32): 1e-6, + onp.dtype(onp.float64): 1e-15, + onp.dtype(onp.complex64): 1e-6, + onp.dtype(onp.complex128): 1e-15, +} + + +def default_tolerance(): + return _default_tolerance + + +default_gradient_tolerance = { + # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-1, + onp.dtype(onp.float16): 1e-2, + onp.dtype(onp.float32): 2e-3, + onp.dtype(onp.float64): 1e-5, + onp.dtype(onp.complex64): 1e-3, + onp.dtype(onp.complex128): 1e-5, +} + + +def _assert_numpy_allclose(a, b, atol=None, rtol=None): + # TODO(b/154768983): + # a = a.astype(onp.float32) if a.dtype == dtypes.bfloat16 else a + # b = b.astype(onp.float32) if b.dtype == dtypes.bfloat16 else b + kw = {} + if atol: + kw["atol"] = atol + if rtol: + kw["rtol"] = rtol + onp.testing.assert_allclose(a, b, **kw) + + +def tolerance(dtype, tol=None): + tol = {} if tol is None else tol + if not isinstance(tol, dict): + return tol + tol = {onp.dtype(key): value for key, value in tol.items()} + dtype = onp.dtype(dtype) + return tol.get(dtype, default_tolerance()[dtype]) + + +def _normalize_tolerance(tol): + tol = tol or 0 + if isinstance(tol, dict): + return {onp.dtype(k): v for k, v in 
tol.items()} + else: + return {k: tol for k in _default_tolerance} + + +def join_tolerance(tol1, tol2): + tol1 = _normalize_tolerance(tol1) + tol2 = _normalize_tolerance(tol2) + out = tol1 + for k, v in tol2.items(): + out[k] = max(v, tol1.get(k, 0)) + return out + + +def _assert_numpy_close(a, b, atol=None, rtol=None): + assert a.shape == b.shape + atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol)) + rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol)) + _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) + + +def check_eq(xs, ys): + tree_all(tree_multimap(_assert_numpy_allclose, xs, ys)) + + +def check_close(xs, ys, atol=None, rtol=None): + assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol) + tree_all(tree_multimap(assert_close, xs, ys)) + + +def inner_prod(xs, ys): + def contract(x, y): + return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1))) + + return tree_reduce(onp.add, tree_multimap(contract, xs, ys)) + + +add = partial(tree_multimap, lambda x, y: onp.add(x, y, dtype=_dtype(x))) +sub = partial(tree_multimap, lambda x, y: onp.subtract(x, y, dtype=_dtype(x))) +conj = partial(tree_map, lambda x: onp.conj(x, dtype=_dtype(x))) + + +def scalar_mul(xs, a): + return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs) + + +def rand_like(rng, x): + shape = onp.shape(x) + dtype = _dtype(x) + randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype) + if onp.issubdtype(dtype, onp.complexfloating): + return randn() + dtype.type(1.0j) * randn() + else: + return randn() + + +def numerical_jvp(f, primals, tangents, eps=EPS): + delta = scalar_mul(tangents, eps) + f_pos = f(*add(primals, delta)) + f_neg = f(*sub(primals, delta)) + return scalar_mul(sub(f_pos, f_neg), 0.5 / eps) + + +def _merge_tolerance(tol, default): + if tol is None: + return default + if not isinstance(tol, dict): + return tol + out = default.copy() + for k, v in tol.items(): + out[onp.dtype(k)] = v + return out + + +def 
check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS): + atol = _merge_tolerance(atol, default_gradient_tolerance) + rtol = _merge_tolerance(rtol, default_gradient_tolerance) + rng = onp.random.RandomState(0) + tangent = tree_map(partial(rand_like, rng), args) + v_out, t_out = f_jvp(args, tangent) + v_out_expected = f(*args) + t_out_expected = numerical_jvp(f, args, tangent, eps=eps) + # In principle we should expect exact equality of v_out and v_out_expected, + # but due to nondeterminism especially on GPU (e.g., due to convolution + # autotuning) we only require "close". + check_close(v_out, v_out_expected, atol=atol, rtol=rtol) + check_close(t_out, t_out_expected, atol=atol, rtol=rtol) + + +def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS): + atol = _merge_tolerance(atol, default_gradient_tolerance) + rtol = _merge_tolerance(rtol, default_gradient_tolerance) + _rand_like = partial(rand_like, onp.random.RandomState(0)) + v_out, vjpfun = f_vjp(*args) + v_out_expected = f(*args) + check_close(v_out, v_out_expected, atol=atol, rtol=rtol) + tangent = tree_map(_rand_like, args) + tangent_out = numerical_jvp(f, args, tangent, eps=eps) + cotangent = tree_map(_rand_like, v_out) + cotangent_out = conj(vjpfun(conj(cotangent))) + ip = inner_prod(tangent, cotangent_out) + ip_expected = inner_prod(tangent_out, cotangent) + check_close(ip, ip_expected, atol=atol, rtol=rtol) + + +def device_under_test(): + return FLAGS.test_dut + + +def if_device_under_test(device_type: Union[str, Sequence[str]], if_true, if_false): + """Chooses `if_true` of `if_false` based on device_under_test.""" + if device_under_test() in ( + [device_type] if isinstance(device_type, str) else device_type + ): + return if_true + else: + return if_false + + +def supported_dtypes(): + if device_under_test() == "tpu": + return { + onp.bool_, + onp.int32, + onp.uint32, + dtypes.bfloat16, + onp.float32, + onp.complex64, + } + else: + return { + onp.bool_, + onp.int8, + onp.int16, + onp.int32, + 
onp.int64, + onp.uint8, + onp.uint16, + onp.uint32, + onp.uint64, + dtypes.bfloat16, + onp.float16, + onp.float32, + onp.float64, + onp.complex64, + onp.complex128, + } + + +def skip_if_unsupported_type(dtype): + if dtype not in supported_dtypes(): + raise unittest.SkipTest(f"Type {dtype} not supported on {device_under_test()}") + + +def skip_on_devices(*disabled_devices): + """A decorator for test methods to skip the test on certain devices.""" + + def skip(test_method): + @functools.wraps(test_method) + def test_method_wrapper(self, *args, **kwargs): + device = device_under_test() + if device in disabled_devices: + test_name = getattr(test_method, "__name__", "[unknown test]") + raise unittest.SkipTest( + f"{test_name} not supported on {device.upper()}." + ) + return test_method(self, *args, **kwargs) + + return test_method_wrapper + + return skip + + +def skip_on_flag(flag_name, skip_value): + """A decorator for test methods to skip the test when flags are set.""" + + def skip(test_method): # pylint: disable=missing-docstring + @functools.wraps(test_method) + def test_method_wrapper(self, *args, **kwargs): + flag_value = getattr(FLAGS, flag_name) + if flag_value == skip_value: + test_name = getattr(test_method, "__name__", "[unknown test]") + raise unittest.SkipTest( + f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}" + ) + return test_method(self, *args, **kwargs) + + return test_method_wrapper + + return skip + + +# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432 +# Delete this code after the minimum jaxlib version is 0.1.46 or greater. 
+skip_on_mac_linalg_bug = partial( + unittest.skipIf, + ( + sys.platform == "darwin" + and scipy.version.version > "1.1.0" + and lib.version < (0, 1, 46) + ), + "Test fails on Mac with new scipy (issue #432)", +) + + +def format_test_name_suffix(opname, shapes, dtypes): + arg_descriptions = ( + format_shape_dtype_string(shape, dtype) for shape, dtype in zip(shapes, dtypes) + ) + return "{}_{}".format(opname.capitalize(), "_".join(arg_descriptions)) + + +# We use special symbols, represented as singleton objects, to distinguish +# between NumPy scalars, Python scalars, and 0-D arrays. +class ScalarShape: + def __len__(self): + return 0 + + def __getitem__(self, i): + raise IndexError(f"index {i} out of range.") + + +class _NumpyScalar(ScalarShape): + pass + + +class _PythonScalar(ScalarShape): + pass + + +NUMPY_SCALAR_SHAPE = _NumpyScalar() +PYTHON_SCALAR_SHAPE = _PythonScalar() + + +def _dims_of_shape(shape): + """Converts `shape` to a tuple of dimensions.""" + if type(shape) in (list, tuple): + return shape + elif isinstance(shape, ScalarShape): + return () + else: + raise TypeError(type(shape)) + + +def _cast_to_shape(value, shape, dtype): + """Casts `value` to the correct Python type for `shape` and `dtype`.""" + if shape is NUMPY_SCALAR_SHAPE: + # explicitly cast to NumPy scalar in case `value` is a Python scalar. 
+ return onp.dtype(dtype).type(value) + elif shape is PYTHON_SCALAR_SHAPE: + # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998 + return onp.asarray(value).item() + elif type(shape) in (list, tuple): + assert onp.shape(value) == tuple(shape) + return value + else: + raise TypeError(type(shape)) + + +def dtype_str(dtype): + return onp.dtype(dtype).name + + +def format_shape_dtype_string(shape, dtype): + if shape is NUMPY_SCALAR_SHAPE: + return dtype_str(dtype) + elif shape is PYTHON_SCALAR_SHAPE: + return "py" + dtype_str(dtype) + elif type(shape) in (list, tuple): + shapestr = ",".join(str(dim) for dim in shape) + return "{}[{}]".format(dtype_str(dtype), shapestr) + elif type(shape) is int: + return "{}[{},]".format(dtype_str(dtype), shape) + elif isinstance(shape, onp.ndarray): + return "{}[{}]".format(dtype_str(dtype), shape) + else: + raise TypeError(type(shape)) + + +def _rand_dtype(rand, shape, dtype, scale=1.0, post=lambda x: x): + """Produce random values given shape, dtype, scale, and post-processor. + + Args: + rand: a function for producing random values of a given shape, e.g. a + bound version of either onp.RandomState.randn or onp.RandomState.rand. + shape: a shape value as a tuple of positive integers. + dtype: a numpy dtype. + scale: optional, a multiplicative scale for the random values (default 1). + post: optional, a callable for post-processing the random values (default + identity). + + Returns: + An ndarray of the given shape and dtype using random values based on a call + to rand but scaled, converted to the appropriate dtype, and post-processed. 
+ """ + r = lambda: onp.asarray(scale * rand(*_dims_of_shape(shape)), dtype) + if onp.issubdtype(dtype, onp.complexfloating): + vals = r() + 1.0j * r() + else: + vals = r() + return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype) + + +def rand_default(scale=3): + randn = npr.RandomState(0).randn + return partial(_rand_dtype, randn, scale=scale) + + +def rand_nonzero(): + post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x) + randn = npr.RandomState(0).randn + return partial(_rand_dtype, randn, scale=3, post=post) + + +def rand_positive(): + post = lambda x: x + 1 + rand = npr.RandomState(0).rand + return partial(_rand_dtype, rand, scale=2, post=post) + + +def rand_small(): + randn = npr.RandomState(0).randn + return partial(_rand_dtype, randn, scale=1e-3) + + +def rand_not_small(offset=10.0): + post = lambda x: x + onp.where(x > 0, offset, -offset) + randn = npr.RandomState(0).randn + return partial(_rand_dtype, randn, scale=3.0, post=post) + + +def rand_small_positive(): + rand = npr.RandomState(0).rand + return partial(_rand_dtype, rand, scale=2e-5) + + +def rand_uniform(low=0.0, high=1.0): + assert low < high + rand = npr.RandomState(0).rand + post = lambda x: x * (high - low) + low + return partial(_rand_dtype, rand, post=post) + + +def rand_some_equal(): + randn = npr.RandomState(0).randn + rng = npr.RandomState(0) + + def post(x): + x_ravel = x.ravel() + if len(x_ravel) == 0: + return x + flips = rng.rand(*onp.shape(x)) < 0.5 + return onp.where(flips, x_ravel[0], x) + + return partial(_rand_dtype, randn, scale=100.0, post=post) + + +def rand_some_inf(): + """Return a random sampler that produces infinities in floating types.""" + rng = npr.RandomState(1) + base_rand = rand_default() + + """ + TODO: Complex numbers are not correctly tested + If blocks should be switched in order, and relevant tests should be fixed + """ + + def rand(shape, dtype): + """The random sampler function.""" + if not onp.issubdtype(dtype, onp.floating): + # 
only float types have inf + return base_rand(shape, dtype) + + if onp.issubdtype(dtype, onp.complexfloating): + base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype + out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand( + shape, base_dtype + ) + return _cast_to_shape(out, shape, dtype) + + dims = _dims_of_shape(shape) + posinf_flips = rng.rand(*dims) < 0.1 + neginf_flips = rng.rand(*dims) < 0.1 + + vals = base_rand(shape, dtype) + vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) + vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) + + return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) + + return rand + + +def rand_some_nan(): + """Return a random sampler that produces nans in floating types.""" + rng = npr.RandomState(1) + base_rand = rand_default() + + def rand(shape, dtype): + """The random sampler function.""" + if onp.issubdtype(dtype, onp.complexfloating): + base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype + out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand( + shape, base_dtype + ) + return _cast_to_shape(out, shape, dtype) + + if not onp.issubdtype(dtype, onp.floating): + # only float types have inf + return base_rand(shape, dtype) + + dims = _dims_of_shape(shape) + nan_flips = rng.rand(*dims) < 0.1 + + vals = base_rand(shape, dtype) + vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) + + return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) + + return rand + + +def rand_some_inf_and_nan(): + """Return a random sampler that produces infinities in floating types.""" + rng = npr.RandomState(1) + base_rand = rand_default() + + """ + TODO: Complex numbers are not correctly tested + If blocks should be switched in order, and relevant tests should be fixed + """ + + def rand(shape, dtype): + """The random sampler function.""" + if not onp.issubdtype(dtype, onp.floating): + # only float types have inf + return base_rand(shape, dtype) + + if 
onp.issubdtype(dtype, onp.complexfloating): + base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype + out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand( + shape, base_dtype + ) + return _cast_to_shape(out, shape, dtype) + + dims = _dims_of_shape(shape) + posinf_flips = rng.rand(*dims) < 0.1 + neginf_flips = rng.rand(*dims) < 0.1 + nan_flips = rng.rand(*dims) < 0.1 + + vals = base_rand(shape, dtype) + vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) + vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) + vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) + + return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) + + return rand + + +# TODO(mattjj): doesn't handle complex types +def rand_some_zero(): + """Return a random sampler that produces some zeros.""" + rng = npr.RandomState(1) + base_rand = rand_default() + + def rand(shape, dtype): + """The random sampler function.""" + dims = _dims_of_shape(shape) + zeros = rng.rand(*dims) < 0.5 + + vals = base_rand(shape, dtype) + vals = onp.where(zeros, onp.array(0, dtype=dtype), vals) + + return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) + + return rand + + +def rand_int(low, high=None): + randint = npr.RandomState(0).randint + + def fn(shape, dtype): + return randint(low, high=high, size=shape, dtype=dtype) + + return fn + + +def rand_unique_int(): + randchoice = npr.RandomState(0).choice + + def fn(shape, dtype): + return randchoice( + onp.arange(onp.prod(shape), dtype=dtype), size=shape, replace=False + ) + + return fn + + +def rand_bool(): + rng = npr.RandomState(0) + + def generator(shape, dtype): + return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype) + + return generator + + +def check_raises(thunk, err_type, msg): + try: + thunk() + assert False + except err_type as e: + assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg) + + +def check_raises_regexp(thunk, err_type, pattern): + try: + 
thunk() + assert False + except err_type as e: + assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern) + + +def _iter_eqns(jaxpr): + # TODO(necula): why doesn't this search in params? + for eqn in jaxpr.eqns: + yield eqn + for subjaxpr in core.subjaxprs(jaxpr): + yield from _iter_eqns(subjaxpr) + + +def assert_dot_precision(expected_precision, fun, *args): + jaxpr = api.make_jaxpr(fun)(*args) + precisions = [ + eqn.params["precision"] + for eqn in _iter_eqns(jaxpr.jaxpr) + if eqn.primitive == lax.dot_general_p + ] + for precision in precisions: + msg = "Unexpected precision: {} != {}".format(expected_precision, precision) + assert precision == expected_precision, msg + + +_CACHED_INDICES: Dict[int, Sequence[int]] = {} + + +def cases_from_list(xs): + xs = list(xs) + n = len(xs) + k = min(n, FLAGS.num_generated_cases) + # Random sampling for every parameterized test is expensive. Do it once and + # cache the result. + indices = _CACHED_INDICES.get(n) + if indices is None: + rng = npr.RandomState(42) + _CACHED_INDICES[n] = indices = rng.permutation(n) + return [xs[i] for i in indices[:k]] + + +def cases_from_gens(*gens): + sizes = [1, 3, 10] + cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1 + for size in sizes: + for i in range(cases_per_size): + yield ("_{}_{}".format(size, i),) + tuple(gen(size) for gen in gens) + + +def to_np(a): + return tf.nest.map_structure(tf_np.asarray, a) + + +def to_tf_fn(f): + return lambda *args: f(*to_np(args)) + + +class TestCase(parameterized.TestCase): + """Base class for tests including numerical checks and boilerplate.""" + + # copied from jax.test_util + def setUp(self): + super().setUp() + self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) + + # copied from jax.test_util + def rng(self): + return self._rng + + # TODO(mattjj): this obscures the error messages from failures, figure out how + # to re-enable it + # def tearDown(self) -> None: + # assert core.reset_trace_state() + + def 
assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None): + """Assert that x and y are close (up to numerical tolerances).""" + self.assertEqual(x.shape, y.shape) + atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) + rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + + _assert_numpy_allclose(x, y, atol=atol, rtol=rtol) + + if check_dtypes: + self.assertDtypesMatch(x, y) + + def assertDtypesMatch(self, x, y): + if FLAGS.enable_x64: + self.assertEqual(_dtype(x), _dtype(y)) + + def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None): + """Assert that x and y, either arrays or nested tuples/lists, are close.""" + if isinstance(x, dict): + self.assertIsInstance(y, dict) + self.assertEqual(set(x.keys()), set(y.keys())) + for k in x: + self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol) + elif is_sequence(x) and not hasattr(x, "__array__"): + self.assertTrue(is_sequence(y) and not hasattr(y, "__array__")) + self.assertEqual(len(x), len(y)) + for x_elt, y_elt in zip(x, y): + self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol) + elif hasattr(x, "__array__") or onp.isscalar(x): + self.assertTrue(hasattr(y, "__array__") or onp.isscalar(y)) + if check_dtypes: + self.assertDtypesMatch(x, y) + x = onp.asarray(x) + y = onp.asarray(y) + self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol) + elif x == y: + return + else: + raise TypeError((type(x), type(y))) + + def assertMultiLineStrippedEqual(self, expected, what): + """Asserts two strings are equal, after stripping each line.""" + ignore_space_re = re.compile(r"\s*\n\s*") + expected_clean = re.sub(ignore_space_re, "\n", expected.strip()) + what_clean = re.sub(ignore_space_re, "\n", what.strip()) + self.assertMultiLineEqual( + expected_clean, + what_clean, + msg="Found\n{}\nExpecting\n{}".format(what, expected), + ) + + def _CheckAgainstNumpy( + self, numpy_reference_op, lax_op, args_maker, check_dtypes=True, tol=None 
+ ): + args = args_maker() + lax_ans = lax_op(*args) + numpy_ans = numpy_reference_op(*args) + self.assertAllClose( + numpy_ans, lax_ans, check_dtypes=check_dtypes, atol=tol, rtol=tol + ) + + def _CompileAndCheck( + self, + fun, + args_maker, + check_dtypes=True, + rtol=None, + atol=None, + check_eval_on_shapes=True, + check_incomplete_shape=True, + check_unknown_rank=True, + static_argnums=(), + check_experimental_compile=True, + check_xla_forced_compile=True, + ): + """Compiles the function and checks the results. + + Args: + fun: the function to be checked. + args_maker: a callable that returns a tuple which will be used as the + positional arguments. + check_dtypes: whether to check that the result dtypes from non-compiled + and compiled runs agree. + rtol: relative tolerance for allclose assertions. + atol: absolute tolerance for allclose assertions. + check_eval_on_shapes: whether to run `eval_on_shapes` on the function and + check that the result shapes and dtypes are correct. + check_incomplete_shape: whether to check that the function can handle + incomplete shapes (including those with and without a known rank). + check_unknown_rank: (only has effect when check_incomplete_shape is True) + whether to check that the function can handle unknown ranks. + static_argnums: indices of arguments to be treated as static arguments for + `jit` and `eval_on_shapes`. + check_experimental_compile: whether to check compilation with + experimental_compile=True (in addition to compilation without the flag). + check_xla_forced_compile: whether to check compilation with + forced_compile=True (in addition to compilation without the flag). This + flag is different from experimental_compile because it enforces + whole-function compilation while the latter doesn't. TPU requires + whole-function compilation. 
+ """ + args = args_maker() + + for x in args: + if not hasattr(x, "dtype"): + # If there is a input that doesn't have dtype info, jit and + # eval_on_shapes may pick a different dtype for it than numpy, so we + # skip the dtype check. + check_dtypes = False + + python_ans = fun(*args) + + python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans) + onp_shapes = tf.nest.map_structure( + lambda x: onp.shape(onp.asarray(x)), python_ans + ) + self.assertEqual(python_shapes, onp_shapes) + + def check_compile(**kwargs): + # `wrapped_fun` and `python_should_be_executing` are used to check that + # when the jitted function is called the second time, the original Python + # function won't be executed. + def wrapped_fun(*args): + self.assertTrue(python_should_be_executing) + return fun(*args) + + cfun = npe.jit(wrapped_fun, static_argnums=static_argnums, **kwargs) + python_should_be_executing = True + monitored_ans = cfun(*args) + + python_should_be_executing = False + compiled_ans = cfun(*args) + + self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + # Run `cfun` with a different set of arguments to check that changing + # arguments won't cause recompilation. + + new_args = args_maker() + + skip_retracing_test = False + for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)): + if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new): + # If the old and new arguments result in different dtypes (because + # they fall into different value ranges), tf-numpy will retrace, so we + # skip the no-retrace test. 
+ skip_retracing_test = True + + if not skip_retracing_test: + python_should_be_executing = True + new_python_ans = fun(*new_args) + python_should_be_executing = False + compiled_ans = cfun(*new_args) + self.assertAllClose( + new_python_ans, compiled_ans, check_dtypes, atol, rtol + ) + + check_compile() + if check_experimental_compile: + check_compile(experimental_compile=True) + if check_xla_forced_compile: + check_compile(xla_forced_compile=True) + + if check_eval_on_shapes: + # Check that npe.eval_on_shapes can get complete output shapes given + # complete input shapes. + cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums) + compiled_ans = cfun(*args) + flat_python_ans = tf.nest.flatten(python_ans) + flat_compiled_ans = tf.nest.flatten(compiled_ans) + self.assertEqual(len(flat_python_ans), len(flat_compiled_ans)) + for a, b in zip(flat_python_ans, flat_compiled_ans): + if hasattr(a, "shape"): + self.assertEqual(a.shape, b.shape) + if check_dtypes and hasattr(a, "dtype"): + self.assertEqual(tf.as_dtype(a.dtype), b.dtype) + + # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we + # skip incomplete-shape checks, since shape specs need dtype. It's OK to + # skip since the same incomplete-shape checks will run for []-shaped arrays. + if check_incomplete_shape and all(hasattr(x, "dtype") for x in args): + # Check partial shapes with known ranks. + # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not + # `shape`. + if all(hasattr(x, "shape") for x in args): + specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args] + cfun = npe.jit( + fun, static_argnums=static_argnums, input_signature=specs + ) + compiled_ans = cfun(*args) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + if check_unknown_rank: + # Check unknown ranks. 
+ specs = [tf.TensorSpec(None, x.dtype) for x in args] + cfun = npe.jit( + fun, static_argnums=static_argnums, input_signature=specs + ) + compiled_ans = cfun(*args) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + def check_grads(self, f, args, atol=None, rtol=None, delta=None): + """Check gradients against finite differences. + + Args: + f: function to check at ``f(*args)``. + args: a list or tuple of argument values. + atol: absolute tolerance for gradient equality. + rtol: relative tolerance for gradient equality. + delta: step size used for finite differences. + """ + if delta is None: + # Optimal stepsize for central difference is O(epsilon^{1/3}). + dtype = tf_np.result_type(*args) + epsilon = onp.finfo(dtype).eps + delta = epsilon ** (1.0 / 3.0) + theoretical, numerical = tf.test.compute_gradient( + to_tf_fn(f), args, delta=delta + ) + self.assertAllClose( + theoretical, numerical, check_dtypes=False, atol=atol, rtol=rtol + ) + + +@contextmanager +def ignore_warning(**kw): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", **kw) + yield + + +def disable(_): + def wrapper(self, *args, **kwargs): + self.skipTest("Test is disabled") + + return wrapper diff --git a/tests/tf_numpy/jax/vmap_test.py b/tests/tf_numpy/jax/vmap_test.py new file mode 100644 index 000000000..1acdd2242 --- /dev/null +++ b/tests/tf_numpy/jax/vmap_test.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from absl.testing import parameterized + +import numpy as np +import tensorflow.compat.v2 as tf + +from trax.tf_numpy import extensions +import trax.tf_numpy.numpy as tf_np + +from tensorflow.python.ops.numpy_ops import ( + np_math_ops, +) # pylint: disable=g-direct-tensorflow-import + + +class VmapTest(tf.test.TestCase, parameterized.TestCase): + def test_vmap_in_axes_list(self): + # https://github.com/google/jax/issues/2367 + dictionary = {"a": 5.0, "b": tf_np.ones(2)} + x = tf_np.zeros(3) + y = tf_np.arange(3.0) + + def f(dct, x, y): + return dct["a"] + dct["b"] + x + y + + out1 = extensions.vmap(f, (None, 0, 0))(dictionary, x, y) + out2 = extensions.vmap(f, [None, 0, 0])(dictionary, x, y) + self.assertAllClose(out1, out2) + + def test_vmap_in_axes_tree_prefix_error(self): + # https://github.com/google/jax/issues/795 + self.assertRaisesRegex( + ValueError, + "vmap in_axes specification must be a tree prefix of the corresponding " + r"value, got specification \(0, 0\) for value tree ", + lambda: extensions.vmap(lambda x: x, in_axes=(0, 0))(tf_np.ones(3)), + ) + + def test_vmap_in_axes_leaf_types(self): + with self.assertRaisesRegex( + TypeError, r"vmap in_axes must be an int, None, or .*" + ): + extensions.vmap(lambda x: x, in_axes=(tf_np.array([1.0, 2.0]),))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_out_axes_leaf_types(self): + with self.assertRaisesRegex( + TypeError, r"vmap out_axes must be an int, None, or .*" + ): + extensions.vmap(lambda x: x, out_axes=(tf_np.array([1.0, 2.0]),))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_unbatched_object_passthrough_issue_183(self): + # https://github.com/google/jax/issues/183 + fun = lambda f, x: f(x) + vfun = extensions.vmap(fun, (None, 0)) + ans = vfun(lambda x: x + 1, tf_np.arange(3)) + self.assertAllClose(ans, np.arange(1, 4)) + + def 
test_vmap_mismatched_axis_sizes_error_message_issue_705(self): + # https://github.com/google/jax/issues/705 + with self.assertRaisesRegex( + ValueError, "vmap must have at least one non-None value in in_axes" + ): + # If the output is mapped, there must be a non-None in_axes + extensions.vmap(lambda x: x, in_axes=None)(tf_np.array([1.0, 2.0])) + + # Error is: TypeError: only integer scalar arrays can be converted to a + # scalar index + with self.assertRaisesRegex( + ValueError, + "vmap out_axes specification must be a tree prefix of the " + "corresponding value.*", + ): + extensions.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_structured_in_axes(self): + a, b, c, d = 2, 3, 4, 5 + k = 6 # batch size + x = np.ones((k, a, b)) # batch axis in different locations + y = np.ones((b, k, c)) + z = np.ones((c, d, k)) + + def foo(tree_arg): + x, (y, z) = tree_arg + return tf_np.dot(x, tf_np.dot(y, z)) + + tree = (x, (y, z)) + vfoo = extensions.vmap(foo, in_axes=((0, (1, 2)),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + Point = collections.namedtuple("Point", ["x", "y"]) + tree = (x, Point(y, z)) + vfoo = extensions.vmap(foo, in_axes=((0, Point(1, 2)),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + def foo2(tree_arg): + x, dct = tree_arg + y, z = dct["a"], dct["b"] + return tf_np.dot(x, tf_np.dot(y, z)) + + tree = (x, {"a": y, "b": z}) + vfoo = extensions.vmap(foo2, in_axes=((0, {"a": 1, "b": 2}),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + tree = (x, collections.OrderedDict([("a", y), ("b", z)])) + vfoo = extensions.vmap( + foo2, in_axes=((0, collections.OrderedDict([("a", 1), ("b", 2)])),) + ) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + def test_vmap_out_axes(self): + f = extensions.vmap(lambda x: x, out_axes=0) + inp = tf_np.arange(6).reshape([2, 3]) + self.assertAllClose(inp, f(inp)) + self.assertAllClose([inp, inp], f((inp, inp))) + + f = extensions.vmap(lambda x: x, out_axes=-1) + 
self.assertAllClose(inp.T, f(inp)) + + f = extensions.vmap(lambda x: x, out_axes=None) + self.assertAllClose(inp[0], f(inp)) + + f = extensions.vmap(lambda x: x, out_axes=([0], (-1, None), {"a": 1})) + a, b, c = f(([inp], (inp, inp), {"a": inp})) + self.assertAllClose([inp], a) + self.assertAllClose((inp.T, inp[0]), b) + self.assertAllClose(inp.T, c["a"]) + + def test_negative_axes(self): + x = np.arange(3 * 4 * 5).reshape(3, 4, 5) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-3)(x), tf_np.sum(x, axis=(1, 2)) + ) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-2)(x), tf_np.sum(x, axis=(0, 2)) + ) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-1)(x), tf_np.sum(x, axis=(0, 1)) + ) + + identity = lambda y: y + self.assertAllClose(x, extensions.vmap(identity, in_axes=0, out_axes=-3)(x)) + self.assertAllClose( + x.transpose(1, 0, 2), extensions.vmap(identity, in_axes=0, out_axes=-2)(x) + ) + self.assertAllClose( + x.transpose(1, 2, 0), extensions.vmap(identity, in_axes=0, out_axes=-1)(x) + ) + + self.assertAllClose( + np.full((5,), 7), + extensions.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))( + np.arange(5), 7 + )[1], + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + np_math_ops.enable_numpy_methods_on_tensor() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/array_ops_test.py b/tests/tf_numpy/numpy_impl/array_ops_test.py new file mode 100644 index 000000000..cacaf3d23 --- /dev/null +++ b/tests/tf_numpy/numpy_impl/array_ops_test.py @@ -0,0 +1,1201 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tf numpy array methods.""" +import itertools +import sys +import numpy as np +from six.moves import range +from six.moves import zip +import tensorflow.compat.v2 as tf + +from trax.tf_numpy.numpy_impl import array_ops +from trax.tf_numpy.numpy_impl import arrays + + +class ArrayCreationTest(tf.test.TestCase): + def setUp(self): + super().setUp() + python_shapes = [0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]] + self.shape_transforms = [ + lambda x: x, + lambda x: np.array(x, dtype=int), + lambda x: array_ops.array(x, dtype=int), + tf.TensorShape, + ] + + self.all_shapes = [] + for fn in self.shape_transforms: + self.all_shapes.extend([fn(s) for s in python_shapes]) + + if sys.version_info.major == 3: + # There is a bug of np.empty (and alike) in Python 3 causing a crash when + # the `shape` argument is an arrays.ndarray scalar (or tf.Tensor scalar). 
+ def not_ndarray_scalar(s): + return not (isinstance(s, arrays.ndarray) and s.ndim == 0) + + self.all_shapes = list(filter(not_ndarray_scalar, self.all_shapes)) + + self.all_types = [ + int, + float, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + ] + + source_array_data = [ + 1, + 5.5, + 7, + (), + (8, 10.0), + ((), ()), + ((1, 4), (2, 8)), + [], + [7], + [8, 10.0], + [[], []], + [[1, 4], [2, 8]], + ([], []), + ([1, 4], [2, 8]), + [(), ()], + [(1, 4), (2, 8)], + ] + + self.array_transforms = [ + lambda x: x, + tf.convert_to_tensor, + np.array, + array_ops.array, + ] + self.all_arrays = [] + for fn in self.array_transforms: + self.all_arrays.extend([fn(s) for s in source_array_data]) + + def testEmpty(self): + for s in self.all_shapes: + actual = array_ops.empty(s) + expected = np.empty(s) + msg = "shape: {}".format(s) + self.match_shape(actual, expected, msg) + self.match_dtype(actual, expected, msg) + + for s, t in itertools.product(self.all_shapes, self.all_types): + actual = array_ops.empty(s, t) + expected = np.empty(s, t) + msg = "shape: {}, dtype: {}".format(s, t) + self.match_shape(actual, expected, msg) + self.match_dtype(actual, expected, msg) + + def testEmptyLike(self): + for a in self.all_arrays: + actual = array_ops.empty_like(a) + expected = np.empty_like(a) + msg = "array: {}".format(a) + self.match_shape(actual, expected, msg) + self.match_dtype(actual, expected, msg) + + for a, t in itertools.product(self.all_arrays, self.all_types): + actual = array_ops.empty_like(a, t) + expected = np.empty_like(a, t) + msg = "array: {} type: {}".format(a, t) + self.match_shape(actual, expected, msg) + self.match_dtype(actual, expected, msg) + + def testZeros(self): + for s in self.all_shapes: + actual = array_ops.zeros(s) + expected = np.zeros(s) + msg = "shape: {}".format(s) + self.match(actual, expected, msg) + + for s, t in itertools.product(self.all_shapes, self.all_types): + actual = array_ops.zeros(s, t) + expected = 
np.zeros(s, t) + msg = "shape: {}, dtype: {}".format(s, t) + self.match(actual, expected, msg) + + def testZerosLike(self): + for a in self.all_arrays: + actual = array_ops.zeros_like(a) + expected = np.zeros_like(a) + msg = "array: {}".format(a) + self.match(actual, expected, msg) + + for a, t in itertools.product(self.all_arrays, self.all_types): + actual = array_ops.zeros_like(a, t) + expected = np.zeros_like(a, t) + msg = "array: {} type: {}".format(a, t) + self.match(actual, expected, msg) + + def testOnes(self): + for s in self.all_shapes: + actual = array_ops.ones(s) + expected = np.ones(s) + msg = "shape: {}".format(s) + self.match(actual, expected, msg) + + for s, t in itertools.product(self.all_shapes, self.all_types): + actual = array_ops.ones(s, t) + expected = np.ones(s, t) + msg = "shape: {}, dtype: {}".format(s, t) + self.match(actual, expected, msg) + + def testOnesLike(self): + for a in self.all_arrays: + actual = array_ops.ones_like(a) + expected = np.ones_like(a) + msg = "array: {}".format(a) + self.match(actual, expected, msg) + + for a, t in itertools.product(self.all_arrays, self.all_types): + actual = array_ops.ones_like(a, t) + expected = np.ones_like(a, t) + msg = "array: {} type: {}".format(a, t) + self.match(actual, expected, msg) + + def testEye(self): + n_max = 3 + m_max = 3 + + for n in range(1, n_max + 1): + self.match(array_ops.eye(n), np.eye(n)) + for k in range(-n, n + 1): + self.match(array_ops.eye(n, k=k), np.eye(n, k=k)) + for m in range(1, m_max + 1): + self.match(array_ops.eye(n, m), np.eye(n, m)) + for k in range(-n, m): + self.match(array_ops.eye(n, k=k), np.eye(n, k=k)) + self.match(array_ops.eye(n, m, k), np.eye(n, m, k)) + + for dtype in self.all_types: + for n in range(1, n_max + 1): + self.match(array_ops.eye(n, dtype=dtype), np.eye(n, dtype=dtype)) + for k in range(-n, n + 1): + self.match( + array_ops.eye(n, k=k, dtype=dtype), np.eye(n, k=k, dtype=dtype) + ) + for m in range(1, m_max + 1): + self.match( + 
array_ops.eye(n, m, dtype=dtype), np.eye(n, m, dtype=dtype) + ) + for k in range(-n, m): + self.match( + array_ops.eye(n, k=k, dtype=dtype), + np.eye(n, k=k, dtype=dtype), + ) + self.match( + array_ops.eye(n, m, k, dtype=dtype), + np.eye(n, m, k, dtype=dtype), + ) + + def testIdentity(self): + n_max = 3 + + for n in range(1, n_max + 1): + self.match(array_ops.identity(n), np.identity(n)) + + for dtype in self.all_types: + for n in range(1, n_max + 1): + self.match( + array_ops.identity(n, dtype=dtype), np.identity(n, dtype=dtype) + ) + + def testFull(self): + # List of 2-tuples of fill value and shape. + data = [ + (5, ()), + (5, (7,)), + (5.0, (7,)), + ([5, 8], (2,)), + ([5, 8], (3, 2)), + ([[5], [8]], (2, 3)), + ([[5], [8]], (3, 2, 5)), + ([[5.0], [8.0]], (3, 2, 5)), + ([[3, 4], [5, 6], [7, 8]], (3, 3, 2)), + ] + for f, s in data: + for fn1, fn2 in itertools.product( + self.array_transforms, self.shape_transforms + ): + fill_value = fn1(f) + shape = fn2(s) + self.match( + array_ops.full(shape, fill_value), np.full(shape, fill_value) + ) + for dtype in self.all_types: + self.match( + array_ops.full(shape, fill_value, dtype=dtype), + np.full(shape, fill_value, dtype=dtype), + ) + + def testFullLike(self): + # List of 2-tuples of fill value and shape. 
+ data = [ + (5, ()), + (5, (7,)), + (5.0, (7,)), + ([5, 8], (2,)), + ([5, 8], (3, 2)), + ([[5], [8]], (2, 3)), + ([[5], [8]], (3, 2, 5)), + ([[5.0], [8.0]], (3, 2, 5)), + ] + zeros_builders = [array_ops.zeros, np.zeros] + for f, s in data: + for fn1, fn2, arr_dtype in itertools.product( + self.array_transforms, zeros_builders, self.all_types + ): + fill_value = fn1(f) + arr = fn2(s, arr_dtype) + self.match( + array_ops.full_like(arr, fill_value), np.full_like(arr, fill_value) + ) + for dtype in self.all_types: + self.match( + array_ops.full_like(arr, fill_value, dtype=dtype), + np.full_like(arr, fill_value, dtype=dtype), + ) + + def testArray(self): + ndmins = [0, 1, 2, 5] + for a, dtype, ndmin, copy in itertools.product( + self.all_arrays, self.all_types, ndmins, [True, False] + ): + self.match( + array_ops.array(a, dtype=dtype, ndmin=ndmin, copy=copy), + np.array(a, dtype=dtype, ndmin=ndmin, copy=copy), + ) + + zeros_list = array_ops.zeros(5) + + # TODO(srbs): Test that copy=True when context.device is different from + # tensor device copies the tensor. + + # Backing tensor is the same if copy=False, other attributes being None. + self.assertIs(array_ops.array(zeros_list, copy=False).data, zeros_list.data) + self.assertIs( + array_ops.array(zeros_list.data, copy=False).data, zeros_list.data + ) + + # Backing tensor is different if ndmin is not satisfied. + self.assertIsNot( + array_ops.array(zeros_list, copy=False, ndmin=2).data, zeros_list.data + ) + self.assertIsNot( + array_ops.array(zeros_list.data, copy=False, ndmin=2).data, zeros_list.data + ) + self.assertIs( + array_ops.array(zeros_list, copy=False, ndmin=1).data, zeros_list.data + ) + self.assertIs( + array_ops.array(zeros_list.data, copy=False, ndmin=1).data, zeros_list.data + ) + + # Backing tensor is different if dtype is not satisfied. 
+ self.assertIsNot( + array_ops.array(zeros_list, copy=False, dtype=int).data, zeros_list.data + ) + self.assertIsNot( + array_ops.array(zeros_list.data, copy=False, dtype=int).data, + zeros_list.data, + ) + self.assertIs( + array_ops.array(zeros_list, copy=False, dtype=float).data, zeros_list.data + ) + self.assertIs( + array_ops.array(zeros_list.data, copy=False, dtype=float).data, + zeros_list.data, + ) + + def testAsArray(self): + for a, dtype in itertools.product(self.all_arrays, self.all_types): + self.match(array_ops.asarray(a, dtype=dtype), np.asarray(a, dtype=dtype)) + + zeros_list = array_ops.zeros(5) + # Same instance is returned if no dtype is specified and input is ndarray. + self.assertIs(array_ops.asarray(zeros_list), zeros_list) + # Different instance is returned if dtype is specified and input is ndarray. + self.assertIsNot(array_ops.asarray(zeros_list, dtype=int), zeros_list) + + def testAsAnyArray(self): + for a, dtype in itertools.product(self.all_arrays, self.all_types): + self.match( + array_ops.asanyarray(a, dtype=dtype), np.asanyarray(a, dtype=dtype) + ) + zeros_list = array_ops.zeros(5) + # Same instance is returned if no dtype is specified and input is ndarray. + self.assertIs(array_ops.asanyarray(zeros_list), zeros_list) + # Different instance is returned if dtype is specified and input is ndarray. 
+ self.assertIsNot(array_ops.asanyarray(zeros_list, dtype=int), zeros_list) + + def testAsContiguousArray(self): + for a, dtype in itertools.product(self.all_arrays, self.all_types): + self.match( + array_ops.ascontiguousarray(a, dtype=dtype), + np.ascontiguousarray(a, dtype=dtype), + ) + + def testARange(self): + int_values = np.arange(-3, 3).tolist() + float_values = np.arange(-3.5, 3.5).tolist() + all_values = int_values + float_values + for dtype in self.all_types: + for start in all_values: + msg = "dtype:{} start:{}".format(dtype, start) + self.match(array_ops.arange(start), np.arange(start), msg=msg) + self.match( + array_ops.arange(start, dtype=dtype), + np.arange(start, dtype=dtype), + msg=msg, + ) + for stop in all_values: + msg = "dtype:{} start:{} stop:{}".format(dtype, start, stop) + self.match( + array_ops.arange(start, stop), np.arange(start, stop), msg=msg + ) + # TODO(srbs): Investigate and remove check. + # There are some bugs when start or stop is float and dtype is int. + if not isinstance(start, float) and not isinstance(stop, float): + self.match( + array_ops.arange(start, stop, dtype=dtype), + np.arange(start, stop, dtype=dtype), + msg=msg, + ) + # Note: We intentionally do not test with float values for step + # because numpy.arange itself returns inconsistent results. e.g. 
+ # np.arange(0.5, 3, step=0.5, dtype=int) returns + # array([0, 1, 2, 3, 4]) + for step in int_values: + msg = "dtype:{} start:{} stop:{} step:{}".format( + dtype, start, stop, step + ) + if not step: + with self.assertRaises(ValueError): + self.match( + array_ops.arange(start, stop, step), + np.arange(start, stop, step), + msg=msg, + ) + if not isinstance(start, float) and not isinstance( + stop, float + ): + self.match( + array_ops.arange( + start, stop, step, dtype=dtype + ), + np.arange(start, stop, step, dtype=dtype), + msg=msg, + ) + else: + self.match( + array_ops.arange(start, stop, step), + np.arange(start, stop, step), + msg=msg, + ) + if not isinstance(start, float) and not isinstance( + stop, float + ): + self.match( + array_ops.arange(start, stop, step, dtype=dtype), + np.arange(start, stop, step, dtype=dtype), + msg=msg, + ) + + def testGeomSpace(self): + def run_test(start, stop, **kwargs): + arg1 = start + arg2 = stop + self.match( + array_ops.geomspace(arg1, arg2, **kwargs), + np.geomspace(arg1, arg2, **kwargs), + msg="geomspace({}, {})".format(arg1, arg2), + almost=True, + ) + + run_test(1, 1000, num=5) + run_test(1, 1000, num=5, endpoint=False) + run_test(-1, -1000, num=5) + run_test(-1, -1000, num=5, endpoint=False) + + def testDiag(self): + array_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + + def run_test(arr): + for fn in array_transforms: + arr = fn(arr) + self.match( + array_ops.diag(arr), np.diag(arr), msg="diag({})".format(arr) + ) + for k in range(-3, 3): + self.match( + array_ops.diag(arr, k), + np.diag(arr, k), + msg="diag({}, k={})".format(arr, k), + ) + + # 2-d arrays. 
+ run_test(np.arange(9).reshape((3, 3)).tolist()) + run_test(np.arange(6).reshape((2, 3)).tolist()) + run_test(np.arange(6).reshape((3, 2)).tolist()) + run_test(np.arange(3).reshape((1, 3)).tolist()) + run_test(np.arange(3).reshape((3, 1)).tolist()) + run_test([[5]]) + run_test([[]]) + run_test([[], []]) + + # 1-d arrays. + run_test([]) + run_test([1]) + run_test([1, 2]) + + def testDiagFlat(self): + array_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + + def run_test(arr): + for fn in array_transforms: + arr = fn(arr) + self.match( + array_ops.diagflat(arr), + np.diagflat(arr), + msg="diagflat({})".format(arr), + ) + for k in range(-3, 3): + self.match( + array_ops.diagflat(arr, k), + np.diagflat(arr, k), + msg="diagflat({}, k={})".format(arr, k), + ) + + # 1-d arrays. + run_test([]) + run_test([1]) + run_test([1, 2]) + # 2-d arrays. + run_test([[]]) + run_test([[5]]) + run_test([[], []]) + run_test(np.arange(4).reshape((2, 2)).tolist()) + run_test(np.arange(2).reshape((2, 1)).tolist()) + run_test(np.arange(2).reshape((1, 2)).tolist()) + # 3-d arrays + run_test(np.arange(8).reshape((2, 2, 2)).tolist()) + + def match_shape(self, actual, expected, msg=None): + if msg: + msg = "Shape match failed for: {}. Expected: {} Actual: {}".format( + msg, expected.shape, actual.shape + ) + self.assertEqual(actual.shape, expected.shape, msg=msg) + if msg: + msg = "Shape: {} is not a tuple for {}".format(actual.shape, msg) + self.assertIsInstance(actual.shape, tuple, msg=msg) + + def match_dtype(self, actual, expected, msg=None): + if msg: + msg = "Dtype match failed for: {}. 
Expected: {} Actual: {}.".format( + msg, expected.dtype, actual.dtype + ) + self.assertEqual(actual.dtype, expected.dtype, msg=msg) + + def match(self, actual, expected, msg=None, almost=False): + msg_ = "Expected: {} Actual: {}".format(expected, actual) + if msg: + msg = "{} {}".format(msg_, msg) + else: + msg = msg_ + self.assertIsInstance(actual, arrays.ndarray) + self.match_dtype(actual, expected, msg) + self.match_shape(actual, expected, msg) + if not almost: + if not actual.shape: + self.assertEqual(actual.tolist(), expected.tolist()) + else: + self.assertSequenceEqual(actual.tolist(), expected.tolist()) + else: + self.assertAllClose(actual.tolist(), expected.tolist()) + + def testIndexedSlices(self): + dtype = tf.int64 + iss = tf.IndexedSlices( + values=tf.ones([2, 3], dtype=dtype), + indices=tf.constant([1, 9]), + dense_shape=[10, 3], + ) + a = array_ops.array(iss, copy=False) + expected = tf.scatter_nd([[1], [9]], tf.ones([2, 3], dtype=dtype), [10, 3]) + self.assertAllEqual(expected, a) + + +class ArrayMethodsTest(tf.test.TestCase): + def setUp(self): + super().setUp() + self.array_transforms = [ + lambda x: x, + tf.convert_to_tensor, + np.array, + array_ops.array, + ] + + def testAllAny(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arr = fn(arr) + self.match( + array_ops.all(arr, *args, **kwargs), np.all(arr, *args, **kwargs) + ) + self.match( + array_ops.any(arr, *args, **kwargs), np.any(arr, *args, **kwargs) + ) + + run_test(0) + run_test(1) + run_test([]) + run_test([[True, False], [True, True]]) + run_test([[True, False], [True, True]], axis=0) + run_test([[True, False], [True, True]], axis=0, keepdims=True) + run_test([[True, False], [True, True]], axis=1) + run_test([[True, False], [True, True]], axis=1, keepdims=True) + run_test([[True, False], [True, True]], axis=(0, 1)) + run_test([[True, False], [True, True]], axis=(0, 1), keepdims=True) + run_test([5.2, 3.5], axis=0) + run_test([1, 0], axis=0) + + def 
testCompress(self): + def run_test(condition, arr, *args, **kwargs): + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arg1 = fn1(condition) + arg2 = fn2(arr) + self.match( + array_ops.compress(arg1, arg2, *args, **kwargs), + np.compress( + np.asarray(arg1).astype(bool), arg2, *args, **kwargs + ), + ) + + run_test([True], 5) + run_test([False], 5) + run_test([], 5) + run_test([True, False, True], [1, 2, 3]) + run_test([True, False], [1, 2, 3]) + run_test([False, True], [[1, 2], [3, 4]]) + run_test([1, 0, 1], [1, 2, 3]) + run_test([1, 0], [1, 2, 3]) + run_test([0, 1], [[1, 2], [3, 4]]) + run_test([True], [[1, 2], [3, 4]]) + run_test([False, True], [[1, 2], [3, 4]], axis=1) + run_test([False, True], [[1, 2], [3, 4]], axis=0) + run_test([False, True], [[1, 2], [3, 4]], axis=-1) + run_test([False, True], [[1, 2], [3, 4]], axis=-2) + + def testCopy(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.copy(arg, *args, **kwargs), np.copy(arg, *args, **kwargs) + ) + + run_test([]) + run_test([1, 2, 3]) + run_test([1.0, 2.0, 3.0]) + run_test([True]) + run_test(np.arange(9).reshape((3, 3)).tolist()) + + def testCumProdAndSum(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.cumprod(arg, *args, **kwargs), + np.cumprod(arg, *args, **kwargs), + ) + self.match( + array_ops.cumsum(arg, *args, **kwargs), + np.cumsum(arg, *args, **kwargs), + ) + + run_test([]) + run_test([1, 2, 3]) + run_test([1, 2, 3], dtype=float) + run_test([1, 2, 3], dtype=np.float32) + run_test([1, 2, 3], dtype=np.float64) + run_test([1.0, 2.0, 3.0]) + run_test([1.0, 2.0, 3.0], dtype=int) + run_test([1.0, 2.0, 3.0], dtype=np.int32) + run_test([1.0, 2.0, 3.0], dtype=np.int64) + run_test([[1, 2], [3, 4]], axis=1) + run_test([[1, 2], [3, 4]], axis=0) + run_test([[1, 2], [3, 4]], axis=-1) + run_test([[1, 2], [3, 4]], axis=-2) + + def testImag(self): + 
def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.imag(arg, *args, **kwargs), + # np.imag may return a scalar so we convert to a np.ndarray. + np.array(np.imag(arg, *args, **kwargs)), + ) + + run_test(1) + run_test(5.5) + run_test(5 + 3j) + run_test(3j) + run_test([]) + run_test([1, 2, 3]) + run_test([1 + 5j, 2 + 3j]) + run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]]) + + def testAMaxAMin(self): + def run_test(arr, *args, **kwargs): + axis = kwargs.pop("axis", None) + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arr_arg = fn1(arr) + axis_arg = fn2(axis) if axis is not None else None + self.match( + array_ops.amax(arr_arg, axis=axis_arg, *args, **kwargs), + np.amax(arr_arg, axis=axis, *args, **kwargs), + ) + self.match( + array_ops.amin(arr_arg, axis=axis_arg, *args, **kwargs), + np.amin(arr_arg, axis=axis, *args, **kwargs), + ) + + run_test([1, 2, 3]) + run_test([1.0, 2.0, 3.0]) + run_test([[1, 2], [3, 4]], axis=1) + run_test([[1, 2], [3, 4]], axis=0) + run_test([[1, 2], [3, 4]], axis=-1) + run_test([[1, 2], [3, 4]], axis=-2) + run_test([[1, 2], [3, 4]], axis=(0, 1)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) + + def testMean(self): + def run_test(arr, *args, **kwargs): + axis = kwargs.pop("axis", None) + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arr_arg = fn1(arr) + axis_arg = fn2(axis) if axis is not None else None + self.match( + array_ops.mean(arr_arg, axis=axis_arg, *args, **kwargs), + np.mean(arr_arg, axis=axis, *args, **kwargs), + ) + + run_test([1, 2, 1]) + run_test([1.0, 2.0, 1.0]) + run_test([1.0, 2.0, 1.0], dtype=int) + run_test([[1, 2], [3, 4]], axis=1) + run_test([[1, 2], [3, 4]], axis=0) + 
run_test([[1, 2], [3, 4]], axis=-1) + run_test([[1, 2], [3, 4]], axis=-2) + run_test([[1, 2], [3, 4]], axis=(0, 1)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) + + def testProd(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.prod(arg, *args, **kwargs), np.prod(arg, *args, **kwargs) + ) + + run_test([1, 2, 3]) + run_test([1.0, 2.0, 3.0]) + run_test(np.array([1, 2, 3], dtype=np.int16)) + run_test([[1, 2], [3, 4]], axis=1) + run_test([[1, 2], [3, 4]], axis=0) + run_test([[1, 2], [3, 4]], axis=-1) + run_test([[1, 2], [3, 4]], axis=-2) + run_test([[1, 2], [3, 4]], axis=(0, 1)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) + run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) + + def _testReduce(self, math_fun, np_fun, name): + axis_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + + def run_test(a, **kwargs): + axis = kwargs.pop("axis", None) + for fn1 in self.array_transforms: + for fn2 in axis_transforms: + arg1 = fn1(a) + axis_arg = fn2(axis) if axis is not None else None + self.match( + math_fun(arg1, axis=axis_arg, **kwargs), + np_fun(arg1, axis=axis, **kwargs), + msg="{}({}, axis={}, keepdims={})".format( + name, arg1, axis, kwargs.get("keepdims") + ), + ) + + run_test(5) + run_test([2, 3]) + run_test([[2, -3], [-6, 7]]) + run_test([[2, -3], [-6, 7]], axis=0) + run_test([[2, -3], [-6, 7]], axis=0, 
keepdims=True) + run_test([[2, -3], [-6, 7]], axis=1) + run_test([[2, -3], [-6, 7]], axis=1, keepdims=True) + run_test([[2, -3], [-6, 7]], axis=(0, 1)) + run_test([[2, -3], [-6, 7]], axis=(1, 0)) + + def testSum(self): + self._testReduce(array_ops.sum, np.sum, "sum") + + def testAmax(self): + self._testReduce(array_ops.amax, np.amax, "amax") + + def testRavel(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.ravel(arg, *args, **kwargs), + np.ravel(arg, *args, **kwargs), + ) + + run_test(5) + run_test(5.0) + run_test([]) + run_test([[]]) + run_test([[], []]) + run_test([1, 2, 3]) + run_test([1.0, 2.0, 3.0]) + run_test([[1, 2], [3, 4]]) + run_test(np.arange(8).reshape((2, 2, 2)).tolist()) + + def testReal(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.real(arg, *args, **kwargs), + np.array(np.real(arg, *args, **kwargs)), + ) + + run_test(1) + run_test(5.5) + run_test(5 + 3j) + run_test(3j) + run_test([]) + run_test([1, 2, 3]) + run_test([1 + 5j, 2 + 3j]) + run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]]) + + def testRepeat(self): + def run_test(arr, repeats, *args, **kwargs): + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arr_arg = fn1(arr) + repeats_arg = fn2(repeats) + self.match( + array_ops.repeat(arr_arg, repeats_arg, *args, **kwargs), + np.repeat(arr_arg, repeats_arg, *args, **kwargs), + ) + + run_test(1, 2) + run_test([1, 2], 2) + run_test([1, 2], [2]) + run_test([1, 2], [1, 2]) + run_test([[1, 2], [3, 4]], 3, axis=0) + run_test([[1, 2], [3, 4]], 3, axis=1) + run_test([[1, 2], [3, 4]], [3], axis=0) + run_test([[1, 2], [3, 4]], [3], axis=1) + run_test([[1, 2], [3, 4]], [3, 2], axis=0) + run_test([[1, 2], [3, 4]], [3, 2], axis=1) + run_test([[1, 2], [3, 4]], [3, 2], axis=-1) + run_test([[1, 2], [3, 4]], [3, 2], axis=-2) + + def testAround(self): + def run_test(arr, *args, **kwargs): + for fn 
in self.array_transforms: + arg = fn(arr) + self.match( + array_ops.around(arg, *args, **kwargs), + np.around(arg, *args, **kwargs), + ) + + run_test(5.5) + run_test(5.567, decimals=2) + run_test([]) + run_test([1.27, 2.49, 2.75], decimals=1) + run_test([23.6, 45.1], decimals=-1) + + def testReshape(self): + def run_test(arr, newshape, *args, **kwargs): + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arr_arg = fn1(arr) + newshape_arg = fn2(newshape) + # If reshape is called on a Tensor, it calls out to the Tensor.reshape + # method. + np_arr_arg = arr_arg + if isinstance(np_arr_arg, tf.Tensor): + np_arr_arg = np_arr_arg.numpy() + self.match( + array_ops.reshape(arr_arg, newshape_arg, *args, **kwargs), + np.reshape(np_arr_arg, newshape, *args, **kwargs), + ) + + run_test(5, [-1]) + run_test([], [-1]) + run_test([1, 2, 3], [1, 3]) + run_test([1, 2, 3], [3, 1]) + run_test([1, 2, 3, 4], [2, 2]) + run_test([1, 2, 3, 4], [2, 1, 2]) + + def testExpandDims(self): + def run_test(arr, axis): + self.match(array_ops.expand_dims(arr, axis), np.expand_dims(arr, axis)) + + run_test([1, 2, 3], 0) + run_test([1, 2, 3], 1) + + def testSqueeze(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + # Note: np.squeeze ignores the axis arg for non-ndarray objects. + # This looks like a bug: https://github.com/numpy/numpy/issues/8201 + # So we convert the arg to np.ndarray before passing to np.squeeze. 
+ self.match( + array_ops.squeeze(arg, *args, **kwargs), + np.squeeze(np.array(arg), *args, **kwargs), + ) + + run_test(5) + run_test([]) + run_test([5]) + run_test([[1, 2, 3]]) + run_test([[[1], [2], [3]]]) + run_test([[[1], [2], [3]]], axis=0) + run_test([[[1], [2], [3]]], axis=2) + run_test([[[1], [2], [3]]], axis=(0, 2)) + run_test([[[1], [2], [3]]], axis=-1) + run_test([[[1], [2], [3]]], axis=-3) + + def testTranspose(self): + def run_test(arr, axes=None): + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arr_arg = fn1(arr) + axes_arg = fn2(axes) if axes is not None else None + # If transpose is called on a Tensor, it calls out to the + # Tensor.transpose method. + np_arr_arg = arr_arg + if isinstance(np_arr_arg, tf.Tensor): + np_arr_arg = np_arr_arg.numpy() + self.match( + array_ops.transpose(arr_arg, axes_arg), + np.transpose(np_arr_arg, axes), + ) + + run_test(5) + run_test([]) + run_test([5]) + run_test([5, 6, 7]) + run_test(np.arange(30).reshape(2, 3, 5).tolist()) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 1, 2]) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 2, 1]) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 0, 2]) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 2, 0]) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 0, 1]) + run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 1, 0]) + + def testSetItem(self): + def run_test(arr, index, value): + for fn in self.array_transforms: + value_arg = fn(value) + tf_array = array_ops.array(arr) + np_array = np.array(arr) + tf_array[index] = value_arg + # TODO(srbs): "setting an array element with a sequence" is thrown + # if we do not wrap value_arg in a numpy array. Investigate how this can + # be avoided. 
+ np_array[index] = np.array(value_arg) + self.match(tf_array, np_array) + + run_test([1, 2, 3], 1, 5) + run_test([[1, 2], [3, 4]], 0, [6, 7]) + run_test([[1, 2], [3, 4]], 1, [6, 7]) + run_test([[1, 2], [3, 4]], (0, 1), 6) + run_test([[1, 2], [3, 4]], 0, 6) # Value needs to broadcast. + + def match_shape(self, actual, expected, msg=None): + if msg: + msg = "Shape match failed for: {}. Expected: {} Actual: {}".format( + msg, expected.shape, actual.shape + ) + self.assertEqual(actual.shape, expected.shape, msg=msg) + if msg: + msg = "Shape: {} is not a tuple for {}".format(actual.shape, msg) + self.assertIsInstance(actual.shape, tuple, msg=msg) + + def match_dtype(self, actual, expected, msg=None): + if msg: + msg = "Dtype match failed for: {}. Expected: {} Actual: {}.".format( + msg, expected.dtype, actual.dtype + ) + self.assertEqual(actual.dtype, expected.dtype, msg=msg) + + def match(self, actual, expected, msg=None, check_dtype=True): + msg_ = "Expected: {} Actual: {}".format(expected, actual) + if msg: + msg = "{} {}".format(msg_, msg) + else: + msg = msg_ + self.assertIsInstance(actual, arrays.ndarray) + if check_dtype: + self.match_dtype(actual, expected, msg) + self.match_shape(actual, expected, msg) + if not actual.shape: + self.assertAllClose(actual.tolist(), expected.tolist()) + else: + self.assertAllClose(actual.tolist(), expected.tolist()) + + def testPad(self): + t = [[1, 2, 3], [4, 5, 6]] + paddings = [ + [ + 1, + 1, + ], + [2, 2], + ] + self.assertAllEqual( + array_ops.pad(t, paddings, "constant"), + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 0, 0], + [0, 0, 4, 5, 6, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ], + ) + + self.assertAllEqual( + array_ops.pad(t, paddings, "reflect"), + [ + [6, 5, 4, 5, 6, 5, 4], + [3, 2, 1, 2, 3, 2, 1], + [6, 5, 4, 5, 6, 5, 4], + [3, 2, 1, 2, 3, 2, 1], + ], + ) + + self.assertAllEqual( + array_ops.pad(t, paddings, "symmetric"), + [ + [2, 1, 1, 2, 3, 3, 2], + [2, 1, 1, 2, 3, 3, 2], + [5, 4, 4, 5, 6, 6, 5], + [5, 4, 4, 5, 6, 6, 
5], + ], + ) + + def testTake(self): + a = [4, 3, 5, 7, 6, 8] + indices = [0, 1, 4] + self.assertAllEqual([4, 3, 6], array_ops.take(a, indices)) + indices = [[0, 1], [2, 3]] + self.assertAllEqual([[4, 3], [5, 7]], array_ops.take(a, indices)) + a = [[4, 3, 5], [7, 6, 8]] + self.assertAllEqual([[4, 3], [5, 7]], array_ops.take(a, indices)) + a = np.random.rand(2, 16, 3) + axis = 1 + self.assertAllEqual( + np.take(a, indices, axis=axis), array_ops.take(a, indices, axis=axis) + ) + + def testWhere(self): + self.assertAllEqual( + [[1.0, 1.0], [1.0, 1.0]], + array_ops.where([True], [1.0, 1.0], [[0, 0], [0, 0]]), + ) + + def testShape(self): + self.assertAllEqual((1, 2), array_ops.shape([[0, 0]])) + + def testSwapaxes(self): + x = [[1, 2, 3]] + self.assertAllEqual([[1], [2], [3]], array_ops.swapaxes(x, 0, 1)) + self.assertAllEqual([[1], [2], [3]], array_ops.swapaxes(x, -2, -1)) + x = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + self.assertAllEqual( + [[[0, 4], [2, 6]], [[1, 5], [3, 7]]], array_ops.swapaxes(x, 0, 2) + ) + self.assertAllEqual( + [[[0, 4], [2, 6]], [[1, 5], [3, 7]]], array_ops.swapaxes(x, -3, -1) + ) + + def testMoveaxis(self): + def _test(*args): + expected = np.moveaxis(*args) + raw_ans = array_ops.moveaxis(*args) + + self.assertAllEqual(expected, raw_ans) + + a = np.random.rand(1, 2, 3, 4, 5, 6) + + # Basic + _test(a, (0, 2), (3, 5)) + _test(a, (0, 2), (-1, -3)) + _test(a, (-6, -4), (3, 5)) + _test(a, (-6, -4), (-1, -3)) + _test(a, 0, 4) + _test(a, -6, -2) + _test(a, tuple(range(6)), tuple(range(6))) + _test(a, tuple(range(6)), tuple(reversed(range(6)))) + _test(a, (), ()) + + def testNdim(self): + self.assertAllEqual(0, array_ops.ndim(0.5)) + self.assertAllEqual(1, array_ops.ndim([1, 2])) + + def testIsscalar(self): + self.assertTrue(array_ops.isscalar(0.5)) + self.assertTrue(array_ops.isscalar(5)) + self.assertTrue(array_ops.isscalar(False)) + self.assertFalse(array_ops.isscalar([1, 2])) + + def assertListEqual(self, a, b): + self.assertAllEqual(len(a), 
len(b)) + for x, y in zip(a, b): + self.assertAllEqual(x, y) + + def testSplit(self): + x = array_ops.arange(9) + y = array_ops.split(x, 3) + self.assertListEqual([([0, 1, 2]), ([3, 4, 5]), ([6, 7, 8])], y) + + x = array_ops.arange(8) + y = array_ops.split(x, [3, 5, 6, 10]) + self.assertListEqual([([0, 1, 2]), ([3, 4]), ([5]), ([6, 7]), ([])], y) + + +class ArrayManipulationTest(tf.test.TestCase): + def setUp(self): + super().setUp() + self.array_transforms = [ + lambda x: x, + tf.convert_to_tensor, + np.array, + array_ops.array, + ] + + def testBroadcastTo(self): + def run_test(arr, shape): + for fn in self.array_transforms: + arg1 = fn(arr) + self.match( + array_ops.broadcast_to(arg1, shape), np.broadcast_to(arg1, shape) + ) + + run_test(1, 2) + run_test(1, (2, 2)) + run_test([1, 2], (2, 2)) + run_test([[1], [2]], (2, 2)) + run_test([[1, 2]], (3, 2)) + run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2)) + + def testIx_(self): + possible_arys = [ + [True, True], + [True, False], + [False, False], + list(range(5)), + array_ops.empty(0, dtype=np.int64), + ] + for r in range(len(possible_arys)): + for arys in itertools.combinations_with_replacement(possible_arys, r): + tnp_ans = array_ops.ix_(*arys) + onp_ans = np.ix_(*arys) + for t, o in zip(tnp_ans, onp_ans): + self.match(t, o) + + def match_shape(self, actual, expected, msg=None): + if msg: + msg = "Shape match failed for: {}. Expected: {} Actual: {}".format( + msg, expected.shape, actual.shape + ) + self.assertEqual(actual.shape, expected.shape, msg=msg) + if msg: + msg = "Shape: {} is not a tuple for {}".format(actual.shape, msg) + self.assertIsInstance(actual.shape, tuple, msg=msg) + + def match_dtype(self, actual, expected, msg=None): + if msg: + msg = "Dtype match failed for: {}. 
Expected: {} Actual: {}.".format( + msg, expected.dtype, actual.dtype + ) + self.assertEqual(actual.dtype, expected.dtype, msg=msg) + + def match(self, actual, expected, msg=None): + msg_ = "Expected: {} Actual: {}".format(expected, actual) + if msg: + msg = "{} {}".format(msg_, msg) + else: + msg = msg_ + self.assertIsInstance(actual, arrays.ndarray) + self.match_dtype(actual, expected, msg) + self.match_shape(actual, expected, msg) + if not actual.shape: + self.assertEqual(actual.tolist(), expected.tolist()) + else: + self.assertSequenceEqual(actual.tolist(), expected.tolist()) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/arrays_test.py b/tests/tf_numpy/numpy_impl/arrays_test.py new file mode 100644 index 000000000..f99d38f8d --- /dev/null +++ b/tests/tf_numpy/numpy_impl/arrays_test.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ndarray.""" +from collections import abc + +import numpy as np +import tensorflow.compat.v2 as tf + +from trax.tf_numpy.numpy_impl import arrays + +# Required for operator overloads +from trax.tf_numpy.numpy_impl import math_ops # pylint: disable=unused-import + + +t2a = arrays.tensor_to_ndarray + + +class ArrayTest(tf.test.TestCase): # Covers dtype/astype, arithmetic and comparison dunder methods, and int/float/bool/hash on arrays built via t2a. + def testDtype(self): + a = t2a(tf.zeros(shape=[1, 2], dtype=tf.int64)) + self.assertIs(a.dtype.type, np.int64) + self.assertAllEqual(0, a.dtype.type(0)) + + def testAstype(self): + a = t2a(tf.convert_to_tensor(value=1.1, dtype=tf.float32)).astype(np.int32) + self.assertIs(a.dtype.type, np.int32) + self.assertAllEqual(1, a) + a = t2a(tf.convert_to_tensor(value=[0.0, 1.1], dtype=tf.float32)).astype( + np.bool_ + ) + self.assertIs(a.dtype.type, np.bool_) + self.assertAllEqual([False, True], a) + + def testNeg(self): + a = t2a(tf.convert_to_tensor(value=[1.0, 2.0])) + self.assertAllEqual([-1.0, -2.0], -a) + + def _testBinOp(self, a, b, out, f, types=None): # Applies f to a and b cast to each [a_type, b_type, out_type] row of types and asserts the result dtype and values. + a = t2a(tf.convert_to_tensor(value=a, dtype=np.int32)) + b = t2a(tf.convert_to_tensor(value=b, dtype=np.int32)) + if not isinstance(out, arrays.ndarray): + out = t2a(tf.convert_to_tensor(value=out, dtype=np.int32)) + if types is None: # Default table of operand dtypes and the result dtype asserted for each pair. + types = [ + [np.int32, np.int32, np.int32], + [np.int64, np.int32, np.int64], + [np.int32, np.int64, np.int64], + [np.float32, np.int32, np.float64], + [np.int32, np.float32, np.float64], + [np.float32, np.float32, np.float32], + [np.float64, np.float32, np.float64], + [np.float32, np.float64, np.float64], + ] + for a_type, b_type, out_type in types: + o = f(a.astype(a_type), b.astype(b_type)) + self.assertIs(o.dtype.type, out_type) + self.assertAllClose(out.astype(out_type), o) + + def testAdd(self): + self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: a.__add__(b)) + + def testRadd(self): + self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: b.__radd__(a)) + + def testSub(self): + self._testBinOp([1, 2], [3, 5], [-2, -3], lambda 
a, b: a.__sub__(b)) + + def testRsub(self): + self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: b.__rsub__(a)) + + def testMul(self): + self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: a.__mul__(b)) + + def testRmul(self): + self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: b.__rmul__(a)) + + def testPow(self): + self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: a.__pow__(b)) + + def testRpow(self): + self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: b.__rpow__(a)) + + _truediv_types = [ + [np.int32, np.int32, np.float64], + [np.int64, np.int32, np.float64], + [np.int32, np.int64, np.float64], + [np.float32, np.int32, np.float64], + [np.int32, np.float32, np.float64], + [np.float32, np.float32, np.float32], + [np.float64, np.float32, np.float64], + [np.float32, np.float64, np.float64], + ] + + def testTruediv(self): + self._testBinOp( + [3, 5], + [2, 4], + t2a(tf.convert_to_tensor(value=[1.5, 1.25])), + lambda a, b: a.__truediv__(b), + types=self._truediv_types, + ) + + def testRtruediv(self): + self._testBinOp( + [3, 5], + [2, 4], + t2a(tf.convert_to_tensor(value=[1.5, 1.25])), + lambda a, b: b.__rtruediv__(a), + types=self._truediv_types, + ) + + def _testCmp(self, a, b, out, f): + a = t2a(tf.convert_to_tensor(value=a, dtype=np.int32)) + b = t2a(tf.convert_to_tensor(value=b, dtype=np.int32)) + types = [ + [np.int32, np.int32], + [np.int64, np.int32], + [np.int32, np.int64], + [np.float32, np.int32], + [np.int32, np.float32], + [np.float32, np.float32], + [np.float64, np.float32], + [np.float32, np.float64], + ] + for a_type, b_type in types: + o = f(a.astype(a_type), b.astype(b_type)) + self.assertAllEqual(out, o) + + def testLt(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [True, False, False], lambda a, b: a.__lt__(b) + ) + + def testLe(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [True, True, False], lambda a, b: a.__le__(b) + ) + + def testGt(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [False, False, True], lambda a, b: 
a.__gt__(b) + ) + + def testGe(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [False, True, True], lambda a, b: a.__ge__(b) + ) + + def testEq(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [False, True, False], lambda a, b: a.__eq__(b) + ) + + def testNe(self): + self._testCmp( + [1, 2, 3], [3, 2, 1], [True, False, True], lambda a, b: a.__ne__(b) + ) + + def testInt(self): + v = 10 + u = int(t2a(tf.convert_to_tensor(value=v))) + self.assertIsInstance(u, int) + self.assertAllEqual(v, u) + + def testFloat(self): + v = 21.32 + u = float(t2a(tf.convert_to_tensor(value=v))) + self.assertIsInstance(u, float) + self.assertAllClose(v, u) + + def testBool(self): + b = bool(t2a(tf.convert_to_tensor(value=10))) + self.assertIsInstance(b, bool) + self.assertTrue(b) + self.assertFalse(bool(t2a(tf.convert_to_tensor(value=0)))) + self.assertTrue(bool(t2a(tf.convert_to_tensor(value=0.1)))) + self.assertFalse(bool(t2a(tf.convert_to_tensor(value=0.0)))) + + def testHash(self): + a = t2a(tf.convert_to_tensor(value=10)) + self.assertNotIsInstance(a, abc.Hashable) + with self.assertRaisesWithPredicateMatch(TypeError, r"unhashable type"): + hash(a) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/backprop_test.py b/tests/tf_numpy/numpy_impl/backprop_test.py new file mode 100644 index 000000000..f75a8eacd --- /dev/null +++ b/tests/tf_numpy/numpy_impl/backprop_test.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for backpropagation on tf-numpy functions.""" +import tensorflow.compat.v2 as tf + +from trax.tf_numpy.numpy_impl import array_ops + +# Required for operator overloads +from trax.tf_numpy.numpy_impl import math_ops # pylint: disable=unused-import + + +class BackpropTest(tf.test.TestCase): # Checks gradients through item assignment (a[idx] = value) followed by array_ops.sum. + def test_setitem(self): + # Single integer index. + a = array_ops.array([1.0, 2.0, 3.0]) + b = array_ops.array(5.0) + c = array_ops.array(10.0) + + tensors = [arr.data for arr in [a, b, c]] + with tf.GradientTape() as g: + g.watch(tensors) + a[1] = b + c # Overwrites a[1]; the assertions below expect zero gradient for the original a[1] and 1.0 for b and c. + loss = array_ops.sum(a) + + gradients = g.gradient(loss.data, tensors) + self.assertSequenceEqual( + array_ops.array(gradients[0]).tolist(), [1.0, 0.0, 1.0] + ) + self.assertEqual(array_ops.array(gradients[1]).tolist(), 1.0) + self.assertEqual(array_ops.array(gradients[2]).tolist(), 1.0) + + # Tuple index. + a = array_ops.array( + [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] + ) # 2x2x2 array. + b = array_ops.array([10.0, 11.0]) + + tensors = [arr.data for arr in [a, b]] + with tf.GradientTape() as g: + g.watch(tensors) + a[(1, 0)] = b # Overwrites the a[1, 0] row; the assertions below expect zero gradient for its original values. + loss = array_ops.sum(a) + + gradients = g.gradient(loss.data, tensors) + self.assertSequenceEqual( + array_ops.array(gradients[0]).tolist(), + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]], + ) + self.assertEqual(array_ops.array(gradients[1]).tolist(), [1.0, 1.0]) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/logic_test.py b/tests/tf_numpy/numpy_impl/logic_test.py new file mode 100644 index 000000000..99c1a6732 --- /dev/null +++ b/tests/tf_numpy/numpy_impl/logic_test.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tf numpy logic methods.""" +import numpy as np +import tensorflow.compat.v2 as tf + +from trax.tf_numpy.numpy_impl import array_ops +from trax.tf_numpy.numpy_impl import arrays +from trax.tf_numpy.numpy_impl import math_ops + + +class LogicTest(tf.test.TestCase): # Compares math_ops.equal against np.equal for every pair of input transforms. + def setUp(self): + super().setUp() + self.array_transforms = [ + lambda x: x, # Identity. + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.int32), + lambda x: np.array(x, dtype=np.int64), + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=tf.int32), + lambda x: array_ops.array(x, dtype=tf.int64), + lambda x: array_ops.array(x, dtype=tf.float32), + lambda x: array_ops.array(x, dtype=tf.float64), + ] + + def testEqual(self): + def run_test(x1, x2=None): # When x2 is omitted, x1 is compared against itself. + if x2 is None: + x2 = x1 + for fn1 in self.array_transforms: + for fn2 in self.array_transforms: + arg1 = fn1(x1) + arg2 = fn2(x2) + self.match( + math_ops.equal(arg1, arg2), + np.equal( + make_numpy_compatible(arg1), make_numpy_compatible(arg2) + ), + ) + + run_test(1) + run_test(1, 2) + run_test([1, 2]) + run_test([1, 2, 3], [2]) + run_test([[1, 2], [3, 4]], [1, 2]) + run_test([[1, 2], [1, 4]], [1, 2]) + run_test([1, 2], [[1, 2], [1, 4]]) + run_test([[1, 2], [3, 4]], [[1, 2], [3, 4]]) + run_test([[1, 2], [3, 4]], [[1, 3], [3, 4]]) + + def match_shape(self, actual, expected, msg=None): + if msg: + msg = "Shape match failed for: {}. 
Expected: {} Actual: {}".format( + msg, expected.shape, actual.shape + ) + self.assertEqual(actual.shape, expected.shape, msg=msg) + if msg: + msg = "Shape: {} is not a tuple for {}".format(actual.shape, msg) + self.assertIsInstance(actual.shape, tuple, msg=msg) + + def match_dtype(self, actual, expected, msg=None): + if msg: + msg = "Dtype match failed for: {}. Expected: {} Actual: {}.".format( + msg, expected.dtype, actual.dtype + ) + self.assertEqual(actual.dtype, expected.dtype, msg=msg) + + def match(self, actual, expected, msg=None): + msg_ = "Expected: {} Actual: {}".format(expected, actual) + if msg: + msg = "{} {}".format(msg_, msg) + else: + msg = msg_ + self.assertIsInstance(actual, arrays.ndarray) + self.match_dtype(actual, expected, msg) + self.match_shape(actual, expected, msg) + if not actual.shape: + self.assertEqual(actual.tolist(), expected.tolist()) + else: + self.assertSequenceEqual(actual.tolist(), expected.tolist()) + + +def make_numpy_compatible(s): + return s if not isinstance(s, arrays.ndarray) else s.data.numpy() + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/math_ops_test.py b/tests/tf_numpy/numpy_impl/math_ops_test.py new file mode 100644 index 000000000..2f4556863 --- /dev/null +++ b/tests/tf_numpy/numpy_impl/math_ops_test.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tf numpy mathematical methods.""" +import itertools +from absl.testing import parameterized +import numpy as np +from six.moves import range +import tensorflow.compat.v2 as tf + +from trax.tf_numpy.numpy_impl import array_ops +from trax.tf_numpy.numpy_impl import arrays +from trax.tf_numpy.numpy_impl import math_ops + + +class MathTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.array_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + self.types = [np.int32, np.int64, np.float32, np.float64] + + def _testBinaryOp( + self, + math_fun, + np_fun, + name, + operands=None, + extra_operands=None, + check_promotion=True, + check_promotion_result_type=True, + ): + def run_test(a, b): + for fn in self.array_transforms: + arg1 = fn(a) + arg2 = fn(b) + self.match( + math_fun(arg1, arg2), + np_fun(arg1, arg2), + msg="{}({}, {})".format(name, arg1, arg2), + ) + # Tests type promotion + for type_a in self.types: + for type_b in self.types: + if not check_promotion and type_a != type_b: + continue + arg1 = array_ops.array(a, dtype=type_a) + arg2 = array_ops.array(b, dtype=type_b) + self.match( + math_fun(arg1, arg2), + np_fun(arg1, arg2), + msg="{}({}, {})".format(name, arg1, arg2), + check_dtype=check_promotion_result_type, + ) + + if operands is None: + operands = [ + (5, 2), + (5, [2, 3]), + (5, [[2, 3], [6, 7]]), + ([1, 2, 3], 7), + ([1, 2, 3], [5, 6, 7]), + ] + for operand1, operand2 in operands: + run_test(operand1, operand2) + if extra_operands is not None: + for operand1, operand2 in extra_operands: + run_test(operand1, operand2) + + def testDot(self): + extra_operands = [ + ([1, 2], [[5, 6, 7], [8, 9, 10]]), + ( + np.arange(2 * 3 * 5).reshape([2, 3, 5]).tolist(), + np.arange(5 * 
7 * 11).reshape([7, 5, 11]).tolist(), + ), + ] + return self._testBinaryOp( + math_ops.dot, np.dot, "dot", extra_operands=extra_operands + ) + + def testMinimum(self): + # The numpy version has strange result type when promotion happens, + # so set check_promotion_result_type to False. + return self._testBinaryOp( + math_ops.minimum, np.minimum, "minimum", check_promotion_result_type=False + ) + + def testMaximum(self): + # The numpy version has strange result type when promotion happens, + # so set check_promotion_result_type to False. + return self._testBinaryOp( + math_ops.maximum, np.maximum, "maximum", check_promotion_result_type=False + ) + + def testMatmul(self): + operands = [([[1, 2]], [[3, 4, 5], [6, 7, 8]])] + return self._testBinaryOp( + math_ops.matmul, np.matmul, "matmul", operands=operands + ) + + def testMatmulError(self): + with self.assertRaisesRegex(ValueError, r""): + math_ops.matmul( + array_ops.ones([], np.int32), array_ops.ones([2, 3], np.int32) + ) + with self.assertRaisesRegex(ValueError, r""): + math_ops.matmul( + array_ops.ones([2, 3], np.int32), array_ops.ones([], np.int32) + ) + + def _testUnaryOp(self, math_fun, np_fun, name): + def run_test(a): + for fn in self.array_transforms: + arg1 = fn(a) + self.match( + math_fun(arg1), np_fun(arg1), msg="{}({})".format(name, arg1) + ) + + run_test(5) + run_test([2, 3]) + run_test([[2, -3], [-6, 7]]) + + def testLog(self): + self._testUnaryOp(math_ops.log, np.log, "log") + + def testExp(self): + self._testUnaryOp(math_ops.exp, np.exp, "exp") + + def testTanh(self): + self._testUnaryOp(math_ops.tanh, np.tanh, "tanh") + + def testSqrt(self): + self._testUnaryOp(math_ops.sqrt, np.sqrt, "sqrt") + + def match(self, actual, expected, msg="", check_dtype=True): + self.assertIsInstance(actual, arrays.ndarray) + if check_dtype: + self.assertEqual( + actual.dtype, + expected.dtype, + "Dtype mismatch.\nActual: {}\nExpected: {}\n{}".format( + actual.dtype, expected.dtype, msg + ), + ) + self.assertEqual( + 
actual.shape, + expected.shape, + "Shape mismatch.\nActual: {}\nExpected: {}\n{}".format( + actual.shape, expected.shape, msg + ), + ) + self.assertAllClose(actual.tolist(), expected.tolist()) + + def testArgsort(self): + self._testUnaryOp(math_ops.argsort, np.argsort, "argsort") + + # Test stability + r = np.arange(100) + a = np.zeros(100) + np.testing.assert_equal(math_ops.argsort(a, kind="stable"), r) + + def testArgMaxArgMin(self): + data = [ + 0, + 5, + [1], + [1, 2, 3], + [[1, 2, 3]], + [[4, 6], [7, 8]], + [[[4, 6], [9, 10]], [[7, 8], [12, 34]]], + ] + for fn, d in itertools.product(self.array_transforms, data): + arr = fn(d) + self.match(math_ops.argmax(arr), np.argmax(arr)) + self.match(math_ops.argmin(arr), np.argmin(arr)) + if hasattr(arr, "shape"): + ndims = len(arr.shape) + else: + ndims = array_ops.array(arr, copy=False).ndim + if ndims == 0: + # Numpy flattens the scalar ndarray and treats it as a 1-d array of + # size 1. + ndims = 1 + for axis in range(-ndims, ndims): + self.match(math_ops.argmax(arr, axis=axis), np.argmax(arr, axis=axis)) + self.match(math_ops.argmin(arr, axis=axis), np.argmin(arr, axis=axis)) + + @parameterized.parameters([False, True]) + def testIsCloseEqualNan(self, equal_nan): + a = np.asarray([1, 1, np.nan, 1, np.nan], np.float32) + b = np.asarray([1, 2, 1, np.nan, np.nan], np.float32) + self.match( + math_ops.isclose(a, b, equal_nan=equal_nan), + np.isclose(a, b, equal_nan=equal_nan), + ) + + def testAverageWrongShape(self): + with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError, r""): + math_ops.average(np.ones([2, 3]), weights=np.ones([2, 4])) + with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError, r""): + math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([2, 4])) + with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError, r""): + math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([])) + with 
self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError, r""): + math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([5])) + + def testClip(self): + def run_test(arr, *args, **kwargs): + check_dtype = kwargs.pop("check_dtype", True) + for fn in self.array_transforms: + arr = fn(arr) + self.match( + math_ops.clip(arr, *args, **kwargs), + np.clip(arr, *args, **kwargs), + check_dtype=check_dtype, + ) + + # NumPy exhibits weird typing behavior when a/a_min/a_max are scalars v/s + # lists, e.g., + # + # np.clip(np.array(0, dtype=np.int32), -5, 5).dtype == np.int64 + # np.clip(np.array([0], dtype=np.int32), -5, 5).dtype == np.int32 + # np.clip(np.array([0], dtype=np.int32), [-5], [5]).dtype == np.int64 + # + # So we skip matching type. In tf-numpy the type of the output array is + # always the same as the input array. + run_test(0, -1, 5, check_dtype=False) + run_test(-1, -1, 5, check_dtype=False) + run_test(5, -1, 5, check_dtype=False) + run_test(-10, -1, 5, check_dtype=False) + run_test(10, -1, 5, check_dtype=False) + run_test(10, None, 5, check_dtype=False) + run_test(10, -1, None, check_dtype=False) + run_test([0, 20, -5, 4], -1, 5, check_dtype=False) + run_test([0, 20, -5, 4], None, 5, check_dtype=False) + run_test([0, 20, -5, 4], -1, None, check_dtype=False) + run_test([0.5, 20.2, -5.7, 4.4], -1.5, 5.1, check_dtype=False) + + run_test([0, 20, -5, 4], [-5, 0, -5, 0], [0, 5, 0, 5], check_dtype=False) + run_test([[1, 2, 3], [4, 5, 6]], [2, 0, 2], 5, check_dtype=False) + run_test([[1, 2, 3], [4, 5, 6]], 0, [5, 3, 1], check_dtype=False) + + def testPtp(self): + def run_test(arr, *args, **kwargs): + for fn in self.array_transforms: + arg = fn(arr) + self.match( + math_ops.ptp(arg, *args, **kwargs), np.ptp(arg, *args, **kwargs) + ) + + run_test([1, 2, 3]) + run_test([1.0, 2.0, 3.0]) + run_test([[1, 2], [3, 4]], axis=1) + run_test([[1, 2], [3, 4]], axis=0) + run_test([[1, 2], [3, 4]], axis=-1) + run_test([[1, 2], [3, 4]], axis=-2) + + def 
testLinSpace(self): + array_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + + def run_test(start, stop, **kwargs): + for fn1 in array_transforms: + for fn2 in array_transforms: + arg1 = fn1(start) + arg2 = fn2(stop) + self.match( + math_ops.linspace(arg1, arg2, **kwargs), + np.linspace(arg1, arg2, **kwargs), + msg="linspace({}, {})".format(arg1, arg2), + ) + + run_test(0, 1) + run_test(0, 1, num=10) + run_test(0, 1, endpoint=False) + run_test(0, -1) + run_test(0, -1, num=10) + run_test(0, -1, endpoint=False) + + def testLogSpace(self): + array_transforms = [ + lambda x: x, # Identity, + tf.convert_to_tensor, + np.array, + lambda x: np.array(x, dtype=np.float32), + lambda x: np.array(x, dtype=np.float64), + array_ops.array, + lambda x: array_ops.array(x, dtype=np.float32), + lambda x: array_ops.array(x, dtype=np.float64), + ] + + def run_test(start, stop, **kwargs): + for fn1 in array_transforms: + for fn2 in array_transforms: + arg1 = fn1(start) + arg2 = fn2(stop) + self.match( + math_ops.logspace(arg1, arg2, **kwargs), + np.logspace(arg1, arg2, **kwargs), + msg="logspace({}, {})".format(arg1, arg2), + ) + + run_test(0, 5) + run_test(0, 5, num=10) + run_test(0, 5, endpoint=False) + run_test(0, 5, base=2.0) + run_test(0, -5) + run_test(0, -5, num=10) + run_test(0, -5, endpoint=False) + run_test(0, -5, base=2.0) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/tf_numpy/numpy_impl/random_test.py b/tests/tf_numpy/numpy_impl/random_test.py new file mode 100644 index 000000000..83c7c1a1d --- /dev/null +++ b/tests/tf_numpy/numpy_impl/random_test.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tf numpy random number methods.""" +import numpy as np +import tensorflow.compat.v2 as tf +from six.moves import range + +# Needed for ndarray.reshape. +from trax.tf_numpy.numpy_impl import random + + +class RandomTest(tf.test.TestCase): + def assertNotAllClose(self, a, b, **kwargs): + try: + self.assertAllClose(a, b, **kwargs) + except AssertionError: + return + raise AssertionError("The two values are close at all %d elements" % np.size(a)) + + def testRandN(self): + def run_test(*args): + num_samples = 1000 + tol = 0.1 # High tolerance to keep the # of samples low else the test + # takes a long time to run. + random.seed(10) + outputs = [random.randn(*args) for _ in range(num_samples)] + + # Test output shape. + for output in outputs: + self.assertEqual(output.shape, tuple(args)) + self.assertEqual(output.dtype.type, random.DEFAULT_RANDN_DTYPE) + + if np.prod(args): # Don't bother with empty arrays. + outputs = [output.tolist() for output in outputs] + + # Test that the properties of normal distribution are satisfied. + mean = np.mean(outputs, axis=0) + stddev = np.std(outputs, axis=0) + self.assertAllClose(mean, np.zeros(args), atol=tol) + self.assertAllClose(stddev, np.ones(args), atol=tol) + + # Test that outputs are different with different seeds. 
+ random.seed(20) + diff_seed_outputs = [ + random.randn(*args).tolist() for _ in range(num_samples) + ] + self.assertNotAllClose(outputs, diff_seed_outputs) + + # Test that outputs are the same with the same seed. + random.seed(10) + same_seed_outputs = [ + random.randn(*args).tolist() for _ in range(num_samples) + ] + self.assertAllClose(outputs, same_seed_outputs) + + run_test() + run_test(0) + run_test(1) + run_test(5) + run_test(2, 3) + run_test(0, 2, 3) + run_test(2, 0, 3) + run_test(2, 3, 0) + run_test(2, 3, 5) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/utils_test.py b/tests/tf_numpy/numpy_impl/utils_test.py similarity index 66% rename from trax/tf_numpy/numpy_impl/tests/utils_test.py rename to tests/tf_numpy/numpy_impl/utils_test.py index ca27a9f21..cc32436dc 100644 --- a/trax/tf_numpy/numpy_impl/tests/utils_test.py +++ b/tests/tf_numpy/numpy_impl/utils_test.py @@ -21,16 +21,18 @@ class UtilsTest(tf.test.TestCase): - # pylint: disable=unused-argument - def testNpDoc(self): - def np_fun(x): - """np_fun docstring.""" - return - @utils.np_doc(np_fun) - def f(): - """f docstring.""" - return - expected = """TensorFlow variant of `numpy.np_fun`. + # pylint: disable=unused-argument + def testNpDoc(self): + def np_fun(x): + """np_fun docstring.""" + return + + @utils.np_doc(np_fun) + def f(): + """f docstring.""" + return + + expected = """TensorFlow variant of `numpy.np_fun`. Unsupported arguments: `x`. 
@@ -39,9 +41,9 @@ def f(): Documentation for `numpy.np_fun`: np_fun docstring.""" - self.assertEqual(f.__doc__, expected) + self.assertEqual(f.__doc__, expected) -if __name__ == '__main__': - tf.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.enable_v2_behavior() + tf.test.main() diff --git a/trax/tf_numpy/public_symbol_test.py b/tests/tf_numpy/public_symbol_test.py similarity index 80% rename from trax/tf_numpy/public_symbol_test.py rename to tests/tf_numpy/public_symbol_test.py index 23f3ebe4e..7724c1414 100644 --- a/trax/tf_numpy/public_symbol_test.py +++ b/tests/tf_numpy/public_symbol_test.py @@ -25,14 +25,13 @@ class PublicSymbolTest(tf.test.TestCase): - - def testSimple(self): - a = 0.1 - b = 0.2 - for op in [np1.add, np2.add, np3.add]: - self.assertAllClose(onp.add(a, b), op(a, b)) + def testSimple(self): + a = 0.1 + b = 0.2 + for op in [np1.add, np2.add, np3.add]: + self.assertAllClose(onp.add(a, b), op(a, b)) if __name__ == "__main__": - tf.compat.v1.enable_eager_execution() - tf.test.main() + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/trax2keras_test.py b/tests/trax2keras_test.py new file mode 100644 index 000000000..0d54c3493 --- /dev/null +++ b/tests/trax2keras_test.py @@ -0,0 +1,230 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax2keras.""" + +import os + + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as onp + +import tensorflow.compat.v2 as tf + + +import trax +from trax import fastmath as math_lib +from trax import layers +from trax import trax2keras +from trax.fastmath import numpy as jnp +from trax.models import mlp +from trax.models import transformer +from trax.trax2keras import read_values +from trax.trax2keras import to_arrays +from trax.trax2keras import to_tensors + +tf.enable_v2_behavior() + + +def has_gpu(): + return bool(tf.config.list_physical_devices("GPU")) + + +def dummy_inputs(rng, input_sig): + def f(sig): + shape = sig.shape + if shape and shape[0] is None: + shape = (2,) + tuple(shape[1:]) + if onp.issubdtype(sig.dtype, onp.integer): + minval = 1 + # Must specify maxval for integer dtype. + # TODO(afrozm): Revisit after TF 2.3 + maxval = 10000 + else: + minval = 0 + maxval = 1 + return rng.uniform(shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval) + + return math_lib.nested_map(f, input_sig) + + +def Mod(n): # pylint: disable=invalid-name + return layers.Fn("Mod", lambda x: x % n) + + +# Format: +# (trax-layer maker, input shapes, input dtype, can handle None batch size?) 
+_LAYERS = [ + (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True), + (mlp.MLP, tf.TensorShape([4]), onp.float32, False), + ( + lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)), + tf.TensorShape([4]), + onp.int32, + False, + ), +] + +_RNG_UPDATERS = [ + lambda x: x, + lambda rng: math_lib.random.split(rng, 1)[0], +] + + +# Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir` +class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + [ + { + "testcase_name": "_%s_%s_%s_%s_%s_%s" + % ( # pylint: disable=g-complex-comprehension + layer_id, + rng_updater_id, + batch_size, + trax_has_weights, + explicit_build, + use_model, + ), + "layer_id": layer_id, + "rng_updater_id": rng_updater_id, + "batch_size": batch_size, + "trax_has_weights": trax_has_weights, + "explicit_build": explicit_build, + "use_model": use_model, + } + for use_model in [True, False] + for explicit_build in [True, False] + for trax_has_weights in [True, False] + for batch_size in [2, None] + for rng_updater_id in [1] + for layer_id in range(len(_LAYERS)) + ] + ) + def testTrain( + self, + layer_id, + rng_updater_id, + batch_size, + trax_has_weights, + explicit_build, + use_model, + ): + """Tests training (forward and backward pass) for AsKeras. + + Args: + layer_id: an integer, the index into `_LAYERS`. + rng_updater_id: an integer, the index into `_RNG_UPDATERS`. + batch_size: an integer or `None`, the value for the `batch_size` argument + in `AsKeras.__init__`. + trax_has_weights: bool, whether to make the trax layer contain weights at + the time when `AsKeras.build` is called. + explicit_build: bool, whether to explicitly call `AsKeras.build`. + use_model: bool, whether to build a `tf.keras.Model` out of the + `AsKeras` layer and use the model to do the training instead of + the bare layer. If `True`, we will also test checkpointing and restoring + using the model. 
+ """ + with trax.fastmath.use_backend("tensorflow-numpy"): + make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = _LAYERS[ + layer_id + ] + # We make a fresh trax layer for each test case, so that different test + # cases won't interfere with each other. + trax_layer = make_trax_layer() + if not allow_none_batch and batch_size is None: + self.skipTest("This Trax layer can't handle None batch size.") + rng_updater = _RNG_UPDATERS[rng_updater_id] + input_shapes = math_lib.nested_map( + lambda s: [batch_size] + s, input_shapes_no_batch + ) + input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype) + initializer_rng = math_lib.random.get_prng(765) + weights, state = trax_layer.init(input_sig, rng=initializer_rng) + generator = tf.random.Generator.from_seed(567) + + def get_inputs(): + return dummy_inputs(generator, input_sig) + + if trax_has_weights: + trax_layer(to_arrays(get_inputs()), weights=weights, state=state) + rng = math_lib.random.get_prng(1234) + keras_layer = trax2keras.AsKeras( + trax_layer, + batch_size=batch_size, + initializer_rng=initializer_rng, + rng=rng, + rng_updater=rng_updater, + ) + if explicit_build: + keras_layer.build(input_shapes) + if use_model: + x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype) + y = keras_layer(x) + keras_model = tf.keras.Model(inputs=x, outputs=y) + lr = 0.1 # learning rate + for _ in range(3): + inputs = get_inputs() + with tf.GradientTape() as trax_tape: + trax_tape.watch(tf.nest.flatten(weights)) + trax_outputs, state = trax_layer.pure_fn( + to_arrays(inputs), weights=weights, state=state, rng=rng + ) + trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights])) + # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor` + # before multiplication. 
+ weights = tf.nest.map_structure( + lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype), + weights, + trax_grads, + ) + rng = rng_updater(rng) + with tf.GradientTape() as keras_tape: + if use_model: + keras_outputs = keras_model(inputs) + else: + keras_outputs = keras_layer(inputs) + if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1: + keras_outputs = keras_outputs[0] + self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5) + keras_grads = keras_tape.gradient( + keras_outputs, keras_layer.trainable_variables + ) + tf.nest.map_structure( + lambda v, g: v.assign_add( # pylint: disable=g-long-lambda + tf.cast(lr * tf.convert_to_tensor(g), v.dtype) + ), + keras_layer.trainable_variables, + keras_grads, + ) + self.assertAllClose( + to_tensors(weights), + read_values(keras_layer._weights), + rtol=2e-6, + atol=4.5e-4 if has_gpu() else 1e-6, + ) + self.assertAllClose(to_tensors(state), read_values(keras_layer._state)) + self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng)) + if use_model: + fname = os.path.join(self.get_temp_dir(), "checkpoint") + keras_model.save(fname) + loaded_model = tf.keras.models.load_model(fname) + for _ in range(2): + inputs = get_inputs() + self.assertAllClose(keras_model(inputs), loaded_model(inputs)) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/__init__.py b/trax/__init__.py index 8747ec520..48a3348d2 100644 --- a/trax/__init__.py +++ b/trax/__init__.py @@ -14,13 +14,3 @@ # limitations under the License. 
"""Trax top level import.""" - -from trax import data -from trax import fastmath -from trax import layers -from trax import models -from trax import optimizers -from trax import shapes -from trax import supervised -from trax.supervised import lr_schedules as lr -from trax.trax2keras import AsKeras diff --git a/trax/data/__init__.py b/trax/data/__init__.py index 9f1ed919b..39e506746 100644 --- a/trax/data/__init__.py +++ b/trax/data/__init__.py @@ -101,4 +101,3 @@ from trax.data.tf_inputs import vocab_size from trax.data.tf_inputs import wmt_concat_preprocess from trax.data.tf_inputs import wmt_preprocess - diff --git a/trax/data/debug_data_pipeline.py b/trax/data/debug_data_pipeline.py index 7506149cc..411f4fad8 100644 --- a/trax/data/debug_data_pipeline.py +++ b/trax/data/debug_data_pipeline.py @@ -21,21 +21,22 @@ import gin -@gin.configurable(denylist=['f']) -def debug_pipeline(f, debug=False, method='pow', log_prefix=None): - """Decorator for input pipeline generators that logs examples at intervals.""" - if not debug: - return f - - assert method in ('pow', 'every') - @functools.wraps(f) - def wrapper(*args, **kwargs): - count = 0 - prefix = log_prefix or f.__name__ - for example in f(*args, **kwargs): - count += 1 - if method == 'every' or (method == 'pow' and (count & count - 1 == 0)): - logging.info('%s example[%d] = %r', prefix, count, example) - yield example - - return wrapper +@gin.configurable(denylist=["f"]) +def debug_pipeline(f, debug=False, method="pow", log_prefix=None): + """Decorator for input pipeline generators that logs examples at intervals.""" + if not debug: + return f + + assert method in ("pow", "every") + + @functools.wraps(f) + def wrapper(*args, **kwargs): + count = 0 + prefix = log_prefix or f.__name__ + for example in f(*args, **kwargs): + count += 1 + if method == "every" or (method == "pow" and (count & count - 1 == 0)): + logging.info("%s example[%d] = %r", prefix, count, example) + yield example + + return wrapper diff --git 
a/trax/data/inputs.py b/trax/data/inputs.py index de15497f4..2db869892 100644 --- a/trax/data/inputs.py +++ b/trax/data/inputs.py @@ -90,12 +90,14 @@ def Serial(*fns): # pylint: disable=invalid-name - """Combines generator functions into one that runs them serially.""" - def composed_fns(generator=None): - for f in fastmath.tree_flatten(fns): - generator = f(generator) - return generator - return composed_fns + """Combines generator functions into one that runs them serially.""" + + def composed_fns(generator=None): + for f in fastmath.tree_flatten(fns): + generator = f(generator) + return generator + + return composed_fns # TODO(jonni): Rename to Blend/Merge/Mix/Interleave/...? @@ -104,1080 +106,1205 @@ def Parallel( # pylint: disable=invalid-name counters=None, reweight_by_minimum=False, gradually_reweight=False, - use_remainders=False): - """Combines generator functions into one that runs them in parallel. - - Args: - fns: a sequence of datasets which are combined in parallel. - counters: a sequence of ints with same length as fns, please see comments on - its use below. - reweight_by_minimum: if set to True, then we re-weight every counter by the - minimal counter. E.g. counters (10000, 100000) are translated to (1, 10) - and hence for every 10 examples from the second dataset we are getting - 1 example from the first dataset. Without reweighting first we would see - 20 examples from the first and second dataset and then 90 thousand eamples - only from the first dataset. - gradually_reweight: if set to True, then we loop through the generators - using a recursive rule defined in emit_examples. First we sort generators - by the counters. If we have datasets with counters 1, 20, 40 - (after sorting) then we yield examples (a(b c^2)^20)^*, where examples of - type a come from the first dataset, of type b from the second and of type - c from the third. The exponents are obtained through divisions of - subsequent counters. 
- use_remainders: if set to True as weell as gradually_reweight is set to - True and counters are 1, 20, 45 then after dealing with all examples in - the format (a(b c^2)^20)^*, the generator yields the remaining 5 examples - from the dataset with counter 45. - Returns: - parallel_generator: the generator yields samples according to given; - if counters are not given then samples are genereted uniformly. - - Example 1: - - gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3)) - - defines a generator that yields 33% examples from dataset1, 16% examples from - dataset2 and 50% examples from dataset3. - - Example 2: - - gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30)) - - defines a generator that yields 20% examples from dataset1, 50% examples from - dataset2 and 30% examples from dataset3. - """ - - if counters: - assert len(counters) == len(fns) - # Remove generators with zero counters - counters = list(counters) - fns = list(fns) - non_zeros = [j for j in range(len(counters)) if counters[j] != 0] - counters = [counters[j] for j in non_zeros] - fns = [fns[j] for j in non_zeros] - else: - counters = [1] * len(fns) - - if reweight_by_minimum: - counters = [math.floor(counter / min(counters)) for counter in counters] - - def emit_examples(sorted_counters_with_gens, prev_counter): - if sorted_counters_with_gens: - _, counter, generator = sorted_counters_with_gens[0] - repeats = math.floor(counter / prev_counter) - for _ in range(repeats): - yield next(generator) - yield from emit_examples(sorted_counters_with_gens[1:], counter) - - def parallel_generator(gen=None): - # If gradually_reweight is set to False then - # current_counters are increased step by step; they are reset to 0s when - # current_counters[idx] == counters[idx] for all idx. See - # test_parallel_with_weights_three_datasets for an example of how - # current_counters are changed during computation. 
- # If gradually_reweight is set to False then we loop using a - # recursive rule defined in emit_examples. - - generators = [] - for f in fns: - if gen: - generators.append(f(gen)) - else: - # This handles the case when the function f cannot be - # called on None. - generators.append(f()) - - if gradually_reweight: - counters_with_gens = zip(range(len(generators)), counters, generators) - sorted_counters_with_gens = sorted(counters_with_gens, key=lambda x: x[1]) - while True: - yield from emit_examples(sorted_counters_with_gens, min(counters)) - if use_remainders: - # Below we are dealing with remainders. - fractions = [] - for i in range(len(sorted_counters_with_gens)): - _, counter, generator = sorted_counters_with_gens[i] - processed = 1 - for fraction in fractions: - processed *= fraction - remainder = counter - processed - for _ in range(remainder): - yield next(generator) - if i < len(sorted_counters_with_gens) - 1: - _, next_counter, _ = sorted_counters_with_gens[i + 1] - fractions.append(math.floor(next_counter / counter)) + use_remainders=False, +): + """Combines generator functions into one that runs them in parallel. + + Args: + fns: a sequence of datasets which are combined in parallel. + counters: a sequence of ints with same length as fns, please see comments on + its use below. + reweight_by_minimum: if set to True, then we re-weight every counter by the + minimal counter. E.g. counters (10000, 100000) are translated to (1, 10) + and hence for every 10 examples from the second dataset we are getting + 1 example from the first dataset. Without reweighting first we would see + 20 examples from the first and second dataset and then 90 thousand eamples + only from the first dataset. + gradually_reweight: if set to True, then we loop through the generators + using a recursive rule defined in emit_examples. First we sort generators + by the counters. 
If we have datasets with counters 1, 20, 40 + (after sorting) then we yield examples (a(b c^2)^20)^*, where examples of + type a come from the first dataset, of type b from the second and of type + c from the third. The exponents are obtained through divisions of + subsequent counters. + use_remainders: if set to True as weell as gradually_reweight is set to + True and counters are 1, 20, 45 then after dealing with all examples in + the format (a(b c^2)^20)^*, the generator yields the remaining 5 examples + from the dataset with counter 45. + Returns: + parallel_generator: the generator yields samples according to given; + if counters are not given then samples are genereted uniformly. + + Example 1: + + gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3)) + + defines a generator that yields 33% examples from dataset1, 16% examples from + dataset2 and 50% examples from dataset3. + + Example 2: + + gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30)) + + defines a generator that yields 20% examples from dataset1, 50% examples from + dataset2 and 30% examples from dataset3. 
+ """ + + if counters: + assert len(counters) == len(fns) + # Remove generators with zero counters + counters = list(counters) + fns = list(fns) + non_zeros = [j for j in range(len(counters)) if counters[j] != 0] + counters = [counters[j] for j in non_zeros] + fns = [fns[j] for j in non_zeros] else: - current_counters = [0] * len(generators) - while True: - for idx, generator in enumerate(generators): - if current_counters[idx] < counters[idx]: - current_counters[idx] += 1 - # instead of checking current_counters[idx] == counters[idx] for - # all idx, we check the equivalent condition: - if sum(current_counters) == sum(counters): - current_counters = [0] * len(generators) - yield next(generator) + counters = [1] * len(fns) + + if reweight_by_minimum: + counters = [math.floor(counter / min(counters)) for counter in counters] + + def emit_examples(sorted_counters_with_gens, prev_counter): + if sorted_counters_with_gens: + _, counter, generator = sorted_counters_with_gens[0] + repeats = math.floor(counter / prev_counter) + for _ in range(repeats): + yield next(generator) + yield from emit_examples(sorted_counters_with_gens[1:], counter) + + def parallel_generator(gen=None): + # If gradually_reweight is set to False then + # current_counters are increased step by step; they are reset to 0s when + # current_counters[idx] == counters[idx] for all idx. See + # test_parallel_with_weights_three_datasets for an example of how + # current_counters are changed during computation. + # If gradually_reweight is set to False then we loop using a + # recursive rule defined in emit_examples. + + generators = [] + for f in fns: + if gen: + generators.append(f(gen)) + else: + # This handles the case when the function f cannot be + # called on None. 
+ generators.append(f()) + + if gradually_reweight: + counters_with_gens = zip(range(len(generators)), counters, generators) + sorted_counters_with_gens = sorted(counters_with_gens, key=lambda x: x[1]) + while True: + yield from emit_examples(sorted_counters_with_gens, min(counters)) + if use_remainders: + # Below we are dealing with remainders. + fractions = [] + for i in range(len(sorted_counters_with_gens)): + _, counter, generator = sorted_counters_with_gens[i] + processed = 1 + for fraction in fractions: + processed *= fraction + remainder = counter - processed + for _ in range(remainder): + yield next(generator) + if i < len(sorted_counters_with_gens) - 1: + _, next_counter, _ = sorted_counters_with_gens[i + 1] + fractions.append(math.floor(next_counter / counter)) + else: + current_counters = [0] * len(generators) + while True: + for idx, generator in enumerate(generators): + if current_counters[idx] < counters[idx]: + current_counters[idx] += 1 + # instead of checking current_counters[idx] == counters[idx] for + # all idx, we check the equivalent condition: + if sum(current_counters) == sum(counters): + current_counters = [0] * len(generators) + yield next(generator) - return parallel_generator + return parallel_generator -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def Shuffle(queue_size=1024): # pylint: disable=invalid-name - """Returns a shuffle function with the given queue size.""" - return lambda g: shuffle(g, queue_size) + """Returns a shuffle function with the given queue size.""" + return lambda g: shuffle(g, queue_size) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def Batch(batch_size): # pylint: disable=invalid-name - """Returns a batching function with given batch size.""" - return lambda g: batch(g, batch_size) + """Returns a batching function with given batch size.""" + return lambda g: batch(g, batch_size) -@gin.configurable(module='trax.data') 
+@gin.configurable(module="trax.data") def Dup(): # pylint: disable=invalid-name - """Duplicates (copies) the top element (inputs). + """Duplicates (copies) the top element (inputs). + + The generator stream is augmented in the following way: - The generator stream is augmented in the following way: + - If the stream consists of a single element `(inputs, )`, + the inputs simply get copied to `(inputs, inputs)`. + - If the stream consists of multiple elements, for example + `(inputs, weights)`, the rest of elements get moved toward + the right side `(inputs, inputs, weights)`. - - If the stream consists of a single element `(inputs, )`, - the inputs simply get copied to `(inputs, inputs)`. - - If the stream consists of multiple elements, for example - `(inputs, weights)`, the rest of elements get moved toward - the right side `(inputs, inputs, weights)`. + Returns: + the duplicating function. + """ + + def _copy(xs): + x, *rest = xs + return (x, x, *rest) - Returns: - the duplicating function. - """ - def _copy(xs): - x, *rest = xs - return (x, x, *rest) - return lambda g: map(lambda x: _copy(x), g) # pylint: disable=unnecessary-lambda + return lambda g: map(lambda x: _copy(x), g) # pylint: disable=unnecessary-lambda -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def FilterEmptyExamples(axes=None, debug=False): # pylint: disable=invalid-name - """Filters empty examples. - - Filters any example that has an array of size (0,) (if axes=None). - Alternatively, checks only axes provided in `axes' list. Contrary to - FilterByLength used with several elements with length_axis, here the example - would be filtered if ANY of the dimensions listed in `axes' contains an empty - array. - - Args: - axes: list of indices to check, if None, all of them. - debug: If true, emits a log everytime we filter out an empty example. - - Returns: - Function filtering empty examples. 
- """ - def _filter_examples(generator): - for example in generator: - correct = True - for i, unused_tuple_element in enumerate(example): - if axes is None or i in axes: - if example[i].shape == (0,): - correct = False - break - if correct: - yield example - elif debug: - logging.info('Filtered example: %r', example) - return _filter_examples - - -@gin.configurable(module='trax.data') -def FilterByLength(max_length, min_length=0, # pylint: disable=invalid-name - length_keys=None, length_axis=0): - """Returns a function that filters out examples by length. - - Args: - max_length: int. If not None, indicates maximum length. - min_length: int. If not None, indicates minimum length. - length_keys: (list) which example keys to take into account. - length_axis: which shape axis to take into account. - Returns: - a function that filters out examples by length. - """ - - assert max_length is not None or min_length is not None - length_keys = length_keys or [0, 1] - length_fn = lambda x: _length_fn(x, length_axis, length_keys) - def filtered(gen): - for example in gen: - example_len = length_fn(example) - - # Checking max length boundary. - if max_length is not None: - if example_len > max_length: - continue - # Checking min length boundary. - if min_length is not None: - if example_len < min_length: - continue - # Within bounds. - yield example - return filtered - - -@gin.configurable(module='trax.data') + """Filters empty examples. + + Filters any example that has an array of size (0,) (if axes=None). + Alternatively, checks only axes provided in `axes' list. Contrary to + FilterByLength used with several elements with length_axis, here the example + would be filtered if ANY of the dimensions listed in `axes' contains an empty + array. + + Args: + axes: list of indices to check, if None, all of them. + debug: If true, emits a log everytime we filter out an empty example. + + Returns: + Function filtering empty examples. 
+ """ + + def _filter_examples(generator): + for example in generator: + correct = True + for i, unused_tuple_element in enumerate(example): + if axes is None or i in axes: + if example[i].shape == (0,): + correct = False + break + if correct: + yield example + elif debug: + logging.info("Filtered example: %r", example) + + return _filter_examples + + +@gin.configurable(module="trax.data") +def FilterByLength( + max_length, + min_length=0, # pylint: disable=invalid-name + length_keys=None, + length_axis=0, +): + """Returns a function that filters out examples by length. + + Args: + max_length: int. If not None, indicates maximum length. + min_length: int. If not None, indicates minimum length. + length_keys: (list) which example keys to take into account. + length_axis: which shape axis to take into account. + Returns: + a function that filters out examples by length. + """ + + assert max_length is not None or min_length is not None + length_keys = length_keys or [0, 1] + length_fn = lambda x: _length_fn(x, length_axis, length_keys) + + def filtered(gen): + for example in gen: + example_len = length_fn(example) + + # Checking max length boundary. + if max_length is not None: + if example_len > max_length: + continue + # Checking min length boundary. + if min_length is not None: + if example_len < min_length: + continue + # Within bounds. + yield example + + return filtered + + +@gin.configurable(module="trax.data") def TruncateToLength(len_map=None): # pylint: disable=invalid-name - """Returns a stream function that resizes items as specified by ``len_map``. - - Args: - len_map: Dictionary that specifies maximum shapes for potentially multiple - features per stream item. For example, given a stream of tokenized - string pairs, one could enforce a maximum length of 256 tokens for each - string by using ``len_map={0: (256,), 1: (256,)}``. 
- """ - @debug_data_pipeline.debug_pipeline - def _truncate_to_length(generator): - for example in generator: - if isinstance(example, np.ndarray): - example = (example,) - if isinstance(example, (list, tuple)): - example = list(example) - if len_map is not None: - for key, max_len in len_map.items(): - example_len = example[key].shape - if example_len > max_len: - example[key] = np.resize(example[key], max_len) - output = tuple(example) - else: - output = None - raise ValueError(f'Unknown example type: {example}') - yield output - - return _truncate_to_length - - -@gin.configurable(module='trax.data') + """Returns a stream function that resizes items as specified by ``len_map``. + + Args: + len_map: Dictionary that specifies maximum shapes for potentially multiple + features per stream item. For example, given a stream of tokenized + string pairs, one could enforce a maximum length of 256 tokens for each + string by using ``len_map={0: (256,), 1: (256,)}``. + """ + + @debug_data_pipeline.debug_pipeline + def _truncate_to_length(generator): + for example in generator: + if isinstance(example, np.ndarray): + example = (example,) + if isinstance(example, (list, tuple)): + example = list(example) + if len_map is not None: + for key, max_len in len_map.items(): + example_len = example[key].shape + if example_len > max_len: + example[key] = np.resize(example[key], max_len) + output = tuple(example) + else: + output = None + raise ValueError(f"Unknown example type: {example}") + yield output + + return _truncate_to_length + + +@gin.configurable(module="trax.data") def PadToLength( # pylint: disable=invalid-name - len_map=None, pad_value=0, multiple=False): - """Pads the values to lengths given in `len_map'. - - len_map contains a dictionary of example keys to dimension sizes. - - Args: - len_map: dict of int to int, we pad examples to lengths - given by the values of the dict. If multiple is True, the dimensions are - padded to multiple of this value. 
- pad_value: dict of int to int. The value gets applied to - constant_values on numpy.pad per given dimension. - multiple: boolean. If False, pads to the value of len_map. If True, pads to - closest multiple of value of len_map. - Returns: - Function to pad examples to given lengths. - """ - @debug_data_pipeline.debug_pipeline - def _pad_to_length(generator): - for example in generator: - if isinstance(example, (list, tuple)): - example = list(example) - for key, value in len_map.items(): - array_length = example[key].shape[0] - if multiple: - padding_len = array_length - ((array_length // value) * value) - else: - padding_len = max([0, value-example[key].shape[0]]) - example[key] = np.pad(example[key], - pad_width=(0, padding_len), - mode='constant', - constant_values=pad_value[key]) - output = tuple(example) - else: - if not isinstance(example, np.ndarray): - raise ValueError(f'example isn\'t nparray, but should be: {example}') - array_length = example.shape[0] - if multiple: - padding_len = ( - array_length - ((array_length // len_map[0]) * len_map[0])) - else: - padding_len = max(0, len_map[0] - array_length) - output = np.pad(example, + len_map=None, pad_value=0, multiple=False +): + """Pads the values to lengths given in `len_map'. + + len_map contains a dictionary of example keys to dimension sizes. + + Args: + len_map: dict of int to int, we pad examples to lengths + given by the values of the dict. If multiple is True, the dimensions are + padded to multiple of this value. + pad_value: dict of int to int. The value gets applied to + constant_values on numpy.pad per given dimension. + multiple: boolean. If False, pads to the value of len_map. If True, pads to + closest multiple of value of len_map. + Returns: + Function to pad examples to given lengths. 
+ """ + + @debug_data_pipeline.debug_pipeline + def _pad_to_length(generator): + for example in generator: + if isinstance(example, (list, tuple)): + example = list(example) + for key, value in len_map.items(): + array_length = example[key].shape[0] + if multiple: + padding_len = array_length - ((array_length // value) * value) + else: + padding_len = max([0, value - example[key].shape[0]]) + example[key] = np.pad( + example[key], pad_width=(0, padding_len), - mode='constant', - constant_values=pad_value[0]) - yield output - if len_map is None: - raise ValueError('len_map parameter should be provided.') - return _pad_to_length - - -@gin.configurable(module='trax.data') -def BucketByLength(boundaries, batch_sizes, # pylint: disable=invalid-name - length_keys=None, length_axis=0, strict_pad_on_len=False): - """Returns a function for bucketing inputs, see `bucket_by_length`.""" - length_keys = length_keys or [0, 1] - # In all cases so far, we use a length function of the following form. - length_fn = lambda x: _length_fn(x, length_axis, length_keys) - return lambda g: bucket_by_length( # pylint: disable=g-long-lambda - g, length_fn, boundaries, batch_sizes, strict_pad_on_len) - - -@gin.configurable(module='trax.data') -def MLM(vocab_size=None, # pylint:disable=invalid-name - max_length=None, - noise_density=0.15, - mean_noise_span_length=3.0): - """Pipeline that just does MLM.""" - return Serial( - # Generate sequential chunks. - generate_sequential_chunks(max_length=max_length), - # Generate mask and chunk. - generate_random_noise_mask( - noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length), - # Consume mask and chunk to give (input, targets). 
- consume_noise_mask(vocab_size=vocab_size), - ) + mode="constant", + constant_values=pad_value[key], + ) + output = tuple(example) + else: + if not isinstance(example, np.ndarray): + raise ValueError(f"example isn't nparray, but should be: {example}") + array_length = example.shape[0] + if multiple: + padding_len = array_length - ( + (array_length // len_map[0]) * len_map[0] + ) + else: + padding_len = max(0, len_map[0] - array_length) + output = np.pad( + example, + pad_width=(0, padding_len), + mode="constant", + constant_values=pad_value[0], + ) + yield output + + if len_map is None: + raise ValueError("len_map parameter should be provided.") + return _pad_to_length + + +@gin.configurable(module="trax.data") +def BucketByLength( + boundaries, + batch_sizes, # pylint: disable=invalid-name + length_keys=None, + length_axis=0, + strict_pad_on_len=False, +): + """Returns a function for bucketing inputs, see `bucket_by_length`.""" + length_keys = length_keys or [0, 1] + # In all cases so far, we use a length function of the following form. + length_fn = lambda x: _length_fn(x, length_axis, length_keys) + return lambda g: bucket_by_length( # pylint: disable=g-long-lambda + g, length_fn, boundaries, batch_sizes, strict_pad_on_len + ) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") +def MLM( + vocab_size=None, # pylint:disable=invalid-name + max_length=None, + noise_density=0.15, + mean_noise_span_length=3.0, +): + """Pipeline that just does MLM.""" + return Serial( + # Generate sequential chunks. + generate_sequential_chunks(max_length=max_length), + # Generate mask and chunk. + generate_random_noise_mask( + noise_density=noise_density, mean_noise_span_length=mean_noise_span_length + ), + # Consume mask and chunk to give (input, targets). 
+ consume_noise_mask(vocab_size=vocab_size), + ) + + +@gin.configurable(module="trax.data") def PrefixLM(input_length=128, output_length=512): # pylint:disable=invalid-name - """Chunks examples so as to make inputs/outputs of specified lenghts.""" - def _f(generator): - for example in generator: - n_tokens = len(example) - # Iterate: - # |--------|<---- input_length ---->|<- output_length ->|--------------| - # ^ ^ ^ ^ - # | | | | - # 0 input_begin_idx input_end_idx output_end_idx - input_begin_idx = 0 - # While you can make an input batch, keep going. - while input_begin_idx + input_length < n_tokens: - input_end_idx = input_begin_idx + input_length - output_end_idx = min(input_end_idx + output_length, n_tokens) - yield (example[input_begin_idx:input_end_idx], - example[input_end_idx:output_end_idx]) - # Update the indices. - input_begin_idx = output_end_idx - return _f - - -@gin.configurable(module='trax.data') + """Chunks examples so as to make inputs/outputs of specified lenghts.""" + + def _f(generator): + for example in generator: + n_tokens = len(example) + # Iterate: + # |--------|<---- input_length ---->|<- output_length ->|--------------| + # ^ ^ ^ ^ + # | | | | + # 0 input_begin_idx input_end_idx output_end_idx + input_begin_idx = 0 + # While you can make an input batch, keep going. + while input_begin_idx + input_length < n_tokens: + input_end_idx = input_begin_idx + input_length + output_end_idx = min(input_end_idx + output_length, n_tokens) + yield ( + example[input_begin_idx:input_end_idx], + example[input_end_idx:output_end_idx], + ) + # Update the indices. + input_begin_idx = output_end_idx + + return _f + + +@gin.configurable(module="trax.data") def ConcatenateToLMInput(pad_to_length=None): # pylint: disable=invalid-name - """Prepares the input needed for training of Language Models. - - Each example needs to contain two elements (input and target). - Input is concatenated to target and, if pad_to_length is given, padded to - length provided. 
- The loss_weights indicates only the target, without input nor padding. - - Args: - pad_to_length: int, total length of padding of input and target arrays. - Returns: - Function to return input for a LM. - """ - @debug_data_pipeline.debug_pipeline - def _concatenate_to_lm_input(generator): - for example in generator: - if isinstance(example, (list, tuple)) and (len(example) == 2): - concatenated = np.concatenate((example[0], example[1]), axis=-1) - loss_weights = np.concatenate((np.zeros_like(example[0]), - np.ones_like(example[1]))) - if pad_to_length is not None: - padding_len = pad_to_length - ( - example[0].shape[0] + example[1].shape[0]) - if padding_len < 0: - raise ValueError( - 'Example lengths ' - f'({example[0].shape[0]}, {example[1].shape[0]}) ' - f'longer than pad_to_length ({pad_to_length}).') - loss_weights = np.pad(loss_weights, (0, padding_len), 'constant') - concatenated = np.pad(concatenated, (0, padding_len), 'constant') - output = (concatenated, concatenated, loss_weights) - elif isinstance(example, (list, tuple)) and (len(example) == 1): - # Make x into (x, x) - output = (example[0], example[0]) - elif isinstance(example, np.ndarray): - # Make x into (x, x) - output = (example, example) - else: - output = None - raise ValueError(f'Unknown input to ConcatenateToLMInput: {example}') - yield output - - return _concatenate_to_lm_input - - -@gin.configurable(module='trax.data') -def CastTo(dtype=np.int32, indices=(0, 1,), debug=False): # pylint: disable=invalid-name - """Casts the given indices to the given dtype.""" - def _cast_fn(generator): - debug_count = 0 - for example in generator: - debug_count += 1 - assert isinstance(example, tuple) - example = list(example) - dtype_mismatch = False - original_index_and_dtype = [] - for i in range(len(example)): - if i not in indices: - continue - original_type = example[i].dtype - if original_type != dtype: - if not (original_type == np.int64 and dtype == np.int32): - # Downcasting from np.int64 to 
np.int32 is OK - original_index_and_dtype.append((i, original_type)) - example[i] = example[i].astype(dtype) - dtype_mismatch = True - if debug and dtype_mismatch and original_index_and_dtype: - logging.info('dtype mismatch in example[%d] = %r was earlier: %r', - debug_count, example, original_index_and_dtype) - yield tuple(example) - return _cast_fn - - -@gin.configurable(module='trax.data') + """Prepares the input needed for training of Language Models. + + Each example needs to contain two elements (input and target). + Input is concatenated to target and, if pad_to_length is given, padded to + length provided. + The loss_weights indicates only the target, without input nor padding. + + Args: + pad_to_length: int, total length of padding of input and target arrays. + Returns: + Function to return input for a LM. + """ + + @debug_data_pipeline.debug_pipeline + def _concatenate_to_lm_input(generator): + for example in generator: + if isinstance(example, (list, tuple)) and (len(example) == 2): + concatenated = np.concatenate((example[0], example[1]), axis=-1) + loss_weights = np.concatenate( + (np.zeros_like(example[0]), np.ones_like(example[1])) + ) + if pad_to_length is not None: + padding_len = pad_to_length - ( + example[0].shape[0] + example[1].shape[0] + ) + if padding_len < 0: + raise ValueError( + "Example lengths " + f"({example[0].shape[0]}, {example[1].shape[0]}) " + f"longer than pad_to_length ({pad_to_length})." 
+ ) + loss_weights = np.pad(loss_weights, (0, padding_len), "constant") + concatenated = np.pad(concatenated, (0, padding_len), "constant") + output = (concatenated, concatenated, loss_weights) + elif isinstance(example, (list, tuple)) and (len(example) == 1): + # Make x into (x, x) + output = (example[0], example[0]) + elif isinstance(example, np.ndarray): + # Make x into (x, x) + output = (example, example) + else: + output = None + raise ValueError(f"Unknown input to ConcatenateToLMInput: {example}") + yield output + + return _concatenate_to_lm_input + + +@gin.configurable(module="trax.data") +def CastTo( + dtype=np.int32, + indices=( + 0, + 1, + ), + debug=False, +): # pylint: disable=invalid-name + """Casts the given indices to the given dtype.""" + + def _cast_fn(generator): + debug_count = 0 + for example in generator: + debug_count += 1 + assert isinstance(example, tuple) + example = list(example) + dtype_mismatch = False + original_index_and_dtype = [] + for i in range(len(example)): + if i not in indices: + continue + original_type = example[i].dtype + if original_type != dtype: + if not (original_type == np.int64 and dtype == np.int32): + # Downcasting from np.int64 to np.int32 is OK + original_index_and_dtype.append((i, original_type)) + example[i] = example[i].astype(dtype) + dtype_mismatch = True + if debug and dtype_mismatch and original_index_and_dtype: + logging.info( + "dtype mismatch in example[%d] = %r was earlier: %r", + debug_count, + example, + original_index_and_dtype, + ) + yield tuple(example) + + return _cast_fn + + +@gin.configurable(module="trax.data") def AppendValue(val=None): # pylint: disable=invalid-name - """Appends values provided in 'val` to inputs. - - val are keyed by example keys, its values contain appended tensors. - - Args: - val: dict of int to tensors. Specific keys get the tensors specified in - values appended. - Returns: - Funtion to append tensors to examples. 
- """ - @debug_data_pipeline.debug_pipeline - def _append_value(generator): - for example in generator: - if isinstance(example, tuple): - example = list(example) - if val is not None: - for key, value in val.items(): - example[key] = np.append(example[key], value, -1) - output = tuple(example) - else: - if not isinstance(example, np.ndarray): - raise ValueError(f'example isn\'t nparray, but should be: {example}') - output = np.append(example, val[0]) - yield output - - return _append_value - - -@gin.configurable(module='trax.data') + """Appends values provided in 'val` to inputs. + + val are keyed by example keys, its values contain appended tensors. + + Args: + val: dict of int to tensors. Specific keys get the tensors specified in + values appended. + Returns: + Funtion to append tensors to examples. + """ + + @debug_data_pipeline.debug_pipeline + def _append_value(generator): + for example in generator: + if isinstance(example, tuple): + example = list(example) + if val is not None: + for key, value in val.items(): + example[key] = np.append(example[key], value, -1) + output = tuple(example) + else: + if not isinstance(example, np.ndarray): + raise ValueError(f"example isn't nparray, but should be: {example}") + output = np.append(example, val[0]) + yield output + + return _append_value + + +@gin.configurable(module="trax.data") def AddLossWeights(id_to_mask=None): # pylint: disable=invalid-name - """Returns a function to add loss weights; see `add_loss_weights`.""" - return lambda g: add_loss_weights(g, id_to_mask=id_to_mask) + """Returns a function to add loss weights; see `add_loss_weights`.""" + return lambda g: add_loss_weights(g, id_to_mask=id_to_mask) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def UnBatch(): # pylint: disable=invalid-name - """Returns a function which unbatches.""" - def _unbatch(generator): - for batched_example in generator: - # batched_example is usually like: - # (batched_inputs, batched_outputs) or 
- # (batched_inputs, batched_outputs, batched_weights) - assert isinstance(batched_example, tuple) - # assert all lengths are the same. - batch_sizes = list(set(map(lambda example: example.shape[0], - batched_example))) - assert len(batch_sizes) == 1 - # Now unbatch examples. - for example_idx in range(batch_sizes[0]): - yield tuple(map(lambda x: x[example_idx], batched_example)) # pylint: disable=cell-var-from-loop - return _unbatch - - -@gin.configurable(module='trax.data') + """Returns a function which unbatches.""" + + def _unbatch(generator): + for batched_example in generator: + # batched_example is usually like: + # (batched_inputs, batched_outputs) or + # (batched_inputs, batched_outputs, batched_weights) + assert isinstance(batched_example, tuple) + # assert all lengths are the same. + batch_sizes = list( + set(map(lambda example: example.shape[0], batched_example)) + ) + assert len(batch_sizes) == 1 + # Now unbatch examples. + for example_idx in range(batch_sizes[0]): + yield tuple( + map(lambda x: x[example_idx], batched_example) + ) # pylint: disable=cell-var-from-loop + + return _unbatch + + +@gin.configurable(module="trax.data") def Prefetch(n_prefetch=2): # pylint: disable=invalid-name - """Pre-fetches a number of examples from generator in a separate process.""" - def prefetch(generator): - in_q, out_q = mp.Queue(), mp.Queue() - p = mp.Process(target=_generator_process, args=(generator, in_q, out_q)) - for _ in range(n_prefetch): - in_q.put(None) - p.start() - while True: - yield out_q.get() - in_q.put(None) - return prefetch - - -@gin.configurable(module='trax.data') -def UniformlySeek(name=None, host_id=None, n_hosts=None, dataset_size=None): # pylint: disable=invalid-name - """Sets each host at (dataset_size/n_hosts)-th of the dataset.""" - if not dataset_size: - dataset_size = 2 ** 18 # 512 * 512 - logging.error( - 'No dataset size given to Uniformly seek, assuming: %d', dataset_size) - assert name - host_id = jax.process_index() if host_id is 
None else host_id - n_hosts = n_hosts or jax.host_count() - each_host = int(dataset_size / n_hosts) - def _f(generator): - # Each host seeks to the appropriate point in the dataset. - num_to_seek = int(host_id * each_host) - start_time = time.time() - logging.info('Dataset[%s] host_id[%d] is seeking to position[%d]', - name, host_id, num_to_seek) - for _ in range(num_to_seek): - next(generator) - logging.info('Dataset[%s] host_id[%d] reached position[%d]. ' - 'Time taken [%s] seconds', - name, host_id, num_to_seek, time.time() - start_time) - for example in generator: - yield example - return _f - - -@gin.configurable(module='trax.data') + """Pre-fetches a number of examples from generator in a separate process.""" + + def prefetch(generator): + in_q, out_q = mp.Queue(), mp.Queue() + p = mp.Process(target=_generator_process, args=(generator, in_q, out_q)) + for _ in range(n_prefetch): + in_q.put(None) + p.start() + while True: + yield out_q.get() + in_q.put(None) + + return prefetch + + +@gin.configurable(module="trax.data") +def UniformlySeek( + name=None, host_id=None, n_hosts=None, dataset_size=None +): # pylint: disable=invalid-name + """Sets each host at (dataset_size/n_hosts)-th of the dataset.""" + if not dataset_size: + dataset_size = 2**18 # 512 * 512 + logging.error( + "No dataset size given to Uniformly seek, assuming: %d", dataset_size + ) + assert name + host_id = jax.process_index() if host_id is None else host_id + n_hosts = n_hosts or jax.host_count() + each_host = int(dataset_size / n_hosts) + + def _f(generator): + # Each host seeks to the appropriate point in the dataset. + num_to_seek = int(host_id * each_host) + start_time = time.time() + logging.info( + "Dataset[%s] host_id[%d] is seeking to position[%d]", + name, + host_id, + num_to_seek, + ) + for _ in range(num_to_seek): + next(generator) + logging.info( + "Dataset[%s] host_id[%d] reached position[%d]. 
" "Time taken [%s] seconds", + name, + host_id, + num_to_seek, + time.time() - start_time, + ) + for example in generator: + yield example + + return _f + + +@gin.configurable(module="trax.data") def CountAndSkip(name): # pylint: disable=invalid-name - """Returns a function that counts and skips examples (see above).""" - return lambda g: count_and_skip(g, name) + """Returns a function that counts and skips examples (see above).""" + return lambda g: count_and_skip(g, name) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name - """Creates a logging component of the input pipeline.""" - def log(stream): - counter = 0 - for example in stream: - item_to_log = example - if only_shapes: - item_to_log = fastmath.nested_map(shapes.signature, example) - if counter % n_steps_per_example == 0: - logging.info(str(item_to_log)) - print(item_to_log) - counter += 1 - yield example - return log + """Creates a logging component of the input pipeline.""" + + def log(stream): + counter = 0 + for example in stream: + item_to_log = example + if only_shapes: + item_to_log = fastmath.nested_map(shapes.signature, example) + if counter % n_steps_per_example == 0: + logging.info(str(item_to_log)) + print(item_to_log) + counter += 1 + yield example + + return log def shuffle(samples, queue_size): - """Shuffles a sample stream using a random-out next-in queue of given size. - - Args: - samples: Stream of samples for eventual use as training data or eval data. - queue_size: Minimum number of samples within which the streamed shuffling - takes place. - - Yields: - Shuffled stream of samples, ready for further processing, e.g., grouping - into batches. - """ - if queue_size < 1: - raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.') - if queue_size == 1: - logging.warning('Queue size of 1 results in no shuffling.') - queue = [] - try: - # Prep: fill the queue. 
- for _ in range(queue_size): - queue.append(next(samples)) - - # Core streaming shuffle: yield sample from random location in queue, then - # fill that location with new sample from input stream. - for sample in samples: - i = np.random.randint(queue_size) - yield queue[i] - queue[i] = sample - except StopIteration: - # Only get here if the initial queue fill fails. - logging.warning( - 'Not enough samples (%d) to fill initial queue (size %d).', - len(queue), queue_size) - - # No new samples coming in; shuffle and drain the queue. - np.random.shuffle(queue) - for sample in queue: - yield sample + """Shuffles a sample stream using a random-out next-in queue of given size. + + Args: + samples: Stream of samples for eventual use as training data or eval data. + queue_size: Minimum number of samples within which the streamed shuffling + takes place. + + Yields: + Shuffled stream of samples, ready for further processing, e.g., grouping + into batches. + """ + if queue_size < 1: + raise ValueError(f"Arg queue_size ({queue_size}) is less than 1.") + if queue_size == 1: + logging.warning("Queue size of 1 results in no shuffling.") + queue = [] + try: + # Prep: fill the queue. + for _ in range(queue_size): + queue.append(next(samples)) + + # Core streaming shuffle: yield sample from random location in queue, then + # fill that location with new sample from input stream. + for sample in samples: + i = np.random.randint(queue_size) + yield queue[i] + queue[i] = sample + except StopIteration: + # Only get here if the initial queue fill fails. + logging.warning( + "Not enough samples (%d) to fill initial queue (size %d).", + len(queue), + queue_size, + ) + + # No new samples coming in; shuffle and drain the queue. 
+ np.random.shuffle(queue) + for sample in queue: + yield sample def batch(generator, batch_size): - """Batch and pad generator as in tf.data.Dataset.padded_batch.""" - if batch_size <= 0: - raise ValueError(f'Batch size must be positive, but is {batch_size}.') - buf = [] - i = 0 - for example in generator: - buf.append(example) # Examples are tuples of tensors. - if len(buf) == batch_size: - # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)] - # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3]) - try: - batched_example = tuple( - pad_to_max_dims([np.asarray(tensor) for tensor in x]) - for x in zip(*buf)) - except ValueError as e: - for j in range(len(buf)): - logging.error('Batch[%d][%d] input shape: %r output shape: %r', - i, j, buf[j][0].shape, buf[j][1].shape) - for j in range(len(buf)): - logging.error('Batch[%d][%d] input: %r', i, j, buf[j][0]) - logging.error('Batch[%d][%d] output: %r', i, j, buf[j][1]) - raise e - i += 1 - yield batched_example - buf = [] + """Batch and pad generator as in tf.data.Dataset.padded_batch.""" + if batch_size <= 0: + raise ValueError(f"Batch size must be positive, but is {batch_size}.") + buf = [] + i = 0 + for example in generator: + buf.append(example) # Examples are tuples of tensors. 
+ if len(buf) == batch_size: + # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)] + # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3]) + try: + batched_example = tuple( + pad_to_max_dims([np.asarray(tensor) for tensor in x]) + for x in zip(*buf) + ) + except ValueError as e: + for j in range(len(buf)): + logging.error( + "Batch[%d][%d] input shape: %r output shape: %r", + i, + j, + buf[j][0].shape, + buf[j][1].shape, + ) + for j in range(len(buf)): + logging.error("Batch[%d][%d] input: %r", i, j, buf[j][0]) + logging.error("Batch[%d][%d] output: %r", i, j, buf[j][1]) + raise e + i += 1 + yield batched_example + buf = [] def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False): - """Pad a tuple of tensors to a joint dimension and return their batch. - - For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded - to (3, 10) both and the returned tensor will have shape (2, 3, 10). - - When boundary is specified, we try to pad all unknown dimensions to boundary - if possible, which can help reduce the number of different shapes occurring - in the tensors and speed up XLA compilation. So, for example, a pair of - tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12). - - One special case occurs when boundary is much higher than the padding length - that we'd use without boundary. For example, tensors (2, 10) and (3, 9) with - boundary=12 could end up padded to (12, 12), but this is very wasteful in - the first dimension. In that case, we will use the closest power-of-2 instead - of the boundary, so the we will end up padding to (4, 12) instead of (12, 12). - - Args: - tensors: a tuple or list of tensors to pad - boundary: int or None; if given, expand the padded dimensions to this size - strict_pad_on_len: bool; if true we pad on the length dimension, dim[0] - strictly as a multiple of boundary. 
- - Returns: - a tensor, the tensors padded together - """ - # TODO(afrozm): Unify this later. - if ((boundary is not None) and - (strict_pad_on_len or isinstance(boundary, (list, tuple)))): - ndim = tensors[0].ndim - if not isinstance(boundary, (list, tuple)): - boundary = [boundary] * ndim - - if ndim != len(boundary): - raise ValueError(f'ndim != len(boundary) - ' - f'ndim({ndim}) vs boundary({boundary}) ' - f'len(boundary) = {len(boundary)}.') - - max_len_per_dim = [0] * ndim - for tensor in tensors: - max_len_per_dim = [ - max(e, s) for e, s in zip(tensor.shape, max_len_per_dim)] - - # Round everything up to a multiple of boundary in the respective dimension. - len_per_dim = [ - max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b) - for i, b in enumerate(boundary)] - - padded_tensors = [ - np.pad(t, [(0, len_per_dim[i] - t.shape[i]) for i in range(ndim)], - mode='constant', constant_values=t.dtype.type(0)) - for t in tensors] + """Pad a tuple of tensors to a joint dimension and return their batch. + + For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded + to (3, 10) both and the returned tensor will have shape (2, 3, 10). + + When boundary is specified, we try to pad all unknown dimensions to boundary + if possible, which can help reduce the number of different shapes occurring + in the tensors and speed up XLA compilation. So, for example, a pair of + tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12). + One special case occurs when boundary is much higher than the padding length + that we'd use without boundary. For example, tensors (2, 10) and (3, 9) with + boundary=12 could end up padded to (12, 12), but this is very wasteful in + the first dimension. In that case, we will use the closest power-of-2 instead + of the boundary, so the we will end up padding to (4, 12) instead of (12, 12). 
+ + Args: + tensors: a tuple or list of tensors to pad + boundary: int or None; if given, expand the padded dimensions to this size + strict_pad_on_len: bool; if true we pad on the length dimension, dim[0] + strictly as a multiple of boundary. + + Returns: + a tensor, the tensors padded together + """ + # TODO(afrozm): Unify this later. + if (boundary is not None) and ( + strict_pad_on_len or isinstance(boundary, (list, tuple)) + ): + ndim = tensors[0].ndim + if not isinstance(boundary, (list, tuple)): + boundary = [boundary] * ndim + + if ndim != len(boundary): + raise ValueError( + f"ndim != len(boundary) - " + f"ndim({ndim}) vs boundary({boundary}) " + f"len(boundary) = {len(boundary)}." + ) + + max_len_per_dim = [0] * ndim + for tensor in tensors: + max_len_per_dim = [max(e, s) for e, s in zip(tensor.shape, max_len_per_dim)] + + # Round everything up to a multiple of boundary in the respective dimension. + len_per_dim = [ + max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b) + for i, b in enumerate(boundary) + ] + + padded_tensors = [ + np.pad( + t, + [(0, len_per_dim[i] - t.shape[i]) for i in range(ndim)], + mode="constant", + constant_values=t.dtype.type(0), + ) + for t in tensors + ] + + return np.stack(padded_tensors) + + max_len_to_pad = [] + padding_needed = False + dim = len(tensors[0].shape) + for i in range(dim): + max_len = max([t.shape[i] for t in tensors]) + min_len = min([t.shape[i] for t in tensors]) + if max_len == min_len and max_len == boundary: # No padding needed. 
+ max_len_to_pad.append(max_len) + elif boundary is None: + max_len_to_pad.append(max_len) + padding_needed = True + else: + padding_needed = True + cur_boundary = max(max_len, boundary) + if 2 * max_len < cur_boundary: + cur_boundary = 2 ** int(np.ceil(np.log2(max_len))) + max_len_to_pad.append(cur_boundary) + if not padding_needed: + return np.stack(tensors) + padded_tensors = [] + for t in tensors: + pad_widths = [(0, max_len_to_pad[i] - t.shape[i]) for i in range(dim)] + padded_t = np.pad( + t, pad_widths, mode="constant", constant_values=t.dtype.type(0) + ) + padded_tensors.append(padded_t) return np.stack(padded_tensors) - max_len_to_pad = [] - padding_needed = False - dim = len(tensors[0].shape) - for i in range(dim): - max_len = max([t.shape[i] for t in tensors]) - min_len = min([t.shape[i] for t in tensors]) - if max_len == min_len and max_len == boundary: # No padding needed. - max_len_to_pad.append(max_len) - elif boundary is None: - max_len_to_pad.append(max_len) - padding_needed = True - else: - padding_needed = True - cur_boundary = max(max_len, boundary) - if 2 * max_len < cur_boundary: - cur_boundary = 2**int(np.ceil(np.log2(max_len))) - max_len_to_pad.append(cur_boundary) - if not padding_needed: - return np.stack(tensors) - padded_tensors = [] - for t in tensors: - pad_widths = [(0, max_len_to_pad[i] - t.shape[i]) for i in range(dim)] - padded_t = np.pad(t, pad_widths, mode='constant', - constant_values=t.dtype.type(0)) - padded_tensors.append(padded_t) - return np.stack(padded_tensors) - - -def bucket_by_length(generator, length_fn, boundaries, batch_sizes, - strict_pad_on_len=False): - """Bucket by length, like tf.data.experimental.bucket_by_sequence_length. - - This function draws examples from the provided `generator` and puts an - example into a bucket depending on `l = length_fn(example)`. Which bucket - is used depends on between which `boundaries` is l. 
When a bucket reaches - its batch size, as specified by `batch_sizes`, generates a batch of - padded examples from this bucket. - - Args: - generator: python generator to draw data from. - length_fn: a function taking the example and returning the length. - boundaries: a list of bucket boundaries. - batch_sizes: a list of batch sizes. - strict_pad_on_len: bool; if true we pad on the length dimension, dim[0] - strictly as a multiple of boundary. - - Yields: - An input batch, which comes from one of the buckets. - """ - buckets = [[] for _ in range(len(batch_sizes))] - boundaries = boundaries + [math.inf] # Max boundary is unlimited. - for example in generator: - length = length_fn(example) - # `bucket_idx` will always be < len(boundaries), since boundaries is right - # padded by `math.inf`. - bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b]) - buckets[bucket_idx].append(example) - if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]: - batched = zip(*buckets[bucket_idx]) - boundary = boundaries[bucket_idx] - boundary = None if boundary == math.inf else boundary - padded_batch = tuple( - pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched) - yield padded_batch - buckets[bucket_idx] = [] + +def bucket_by_length( + generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False +): + """Bucket by length, like tf.data.experimental.bucket_by_sequence_length. + + This function draws examples from the provided `generator` and puts an + example into a bucket depending on `l = length_fn(example)`. Which bucket + is used depends on between which `boundaries` is l. When a bucket reaches + its batch size, as specified by `batch_sizes`, generates a batch of + padded examples from this bucket. + + Args: + generator: python generator to draw data from. + length_fn: a function taking the example and returning the length. + boundaries: a list of bucket boundaries. + batch_sizes: a list of batch sizes. 
+ strict_pad_on_len: bool; if true we pad on the length dimension, dim[0] + strictly as a multiple of boundary. + + Yields: + An input batch, which comes from one of the buckets. + """ + buckets = [[] for _ in range(len(batch_sizes))] + boundaries = boundaries + [math.inf] # Max boundary is unlimited. + for example in generator: + length = length_fn(example) + # `bucket_idx` will always be < len(boundaries), since boundaries is right + # padded by `math.inf`. + bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b]) + buckets[bucket_idx].append(example) + if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]: + batched = zip(*buckets[bucket_idx]) + boundary = boundaries[bucket_idx] + boundary = None if boundary == math.inf else boundary + padded_batch = tuple( + pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched + ) + yield padded_batch + buckets[bucket_idx] = [] @debug_data_pipeline.debug_pipeline def add_loss_weights(generator, id_to_mask=None): - """Add weights to inputs without weights and masks by id if requested. - - The generator stream is augmented in the following way: - - - If the stream consists of pairs `(inputs, targets)`, a loss mask is added - that is creates as a tensor of ones of the same shape as targets. - - If `id_to_mask` is not `None`, and the stream (after the previous point) - has triples `(inputs, targets, weights)`, the weights are multiplied by a - 0/1 mask that is 0 iff targets is equal to `id_to_mask` (1 otherwise). - - Args: - generator: Stream of tuples. - id_to_mask: If not None, int-valued id that represents padding, as opposed - to true target IDs. - - Yields: - Examples from the augmented stream. - """ - for example in generator: - if len(example) > 3 or len(example) < 2: - assert id_to_mask is None, 'Cannot automatically mask this stream.' 
- yield example - else: - if len(example) == 2: - weights = np.ones_like(example[1]).astype(np.float32) - else: - weights = example[2].astype(np.float32) - mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32) - weights *= mask - output = (example[0], example[1], weights) - yield output - - -@gin.configurable(module='trax.data') -def generate_random_noise_mask(noise_density=0.15, - mean_noise_span_length=3.0, - seed1=None, - seed2=None): - """Returns a function that generates a random noise mask.""" - def _f(generator): + """Add weights to inputs without weights and masks by id if requested. + + The generator stream is augmented in the following way: + + - If the stream consists of pairs `(inputs, targets)`, a loss mask is added + that is creates as a tensor of ones of the same shape as targets. + - If `id_to_mask` is not `None`, and the stream (after the previous point) + has triples `(inputs, targets, weights)`, the weights are multiplied by a + 0/1 mask that is 0 iff targets is equal to `id_to_mask` (1 otherwise). + + Args: + generator: Stream of tuples. + id_to_mask: If not None, int-valued id that represents padding, as opposed + to true target IDs. + + Yields: + Examples from the augmented stream. + """ for example in generator: - length = len(example) - noise_mask = random_spans_noise_mask( - length, noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length, - seed1=seed1, seed2=seed2, example=example) - yield (example, noise_mask) - return _f + if len(example) > 3 or len(example) < 2: + assert id_to_mask is None, "Cannot automatically mask this stream." 
+ yield example + else: + if len(example) == 2: + weights = np.ones_like(example[1]).astype(np.float32) + else: + weights = example[2].astype(np.float32) + mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32) + weights *= mask + output = (example[0], example[1], weights) + yield output + + +@gin.configurable(module="trax.data") +def generate_random_noise_mask( + noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None +): + """Returns a function that generates a random noise mask.""" + + def _f(generator): + for example in generator: + length = len(example) + noise_mask = random_spans_noise_mask( + length, + noise_density=noise_density, + mean_noise_span_length=mean_noise_span_length, + seed1=seed1, + seed2=seed2, + example=example, + ) + yield (example, noise_mask) + + return _f -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def consume_noise_mask(vocab_size=32100): - """Consumes (tokens, noise mask) and returns (inputs, targets).""" - def _noise_span_to_unique_sentinel(tokens, noise_mask): - prev_token_is_noise = np.pad( - noise_mask[:-1], [1, 0], mode='constant', constant_values=False) - first_noise_tokens = np.logical_and(noise_mask, - np.logical_not(prev_token_is_noise)) - subsequent_noise_tokens = np.logical_and(noise_mask, prev_token_is_noise) - sentinel = vocab_size - np.cumsum(first_noise_tokens) - tokens = np.where(first_noise_tokens, sentinel, tokens) - return tokens[np.logical_not(subsequent_noise_tokens)] - - def _f(generator): - for tokens, noise_mask in generator: - # Returns inputs and targets. 
- yield (_noise_span_to_unique_sentinel(tokens, noise_mask), - _noise_span_to_unique_sentinel(tokens, np.logical_not(noise_mask))) - return _f - - -@gin.configurable(module='trax.data') + """Consumes (tokens, noise mask) and returns (inputs, targets).""" + + def _noise_span_to_unique_sentinel(tokens, noise_mask): + prev_token_is_noise = np.pad( + noise_mask[:-1], [1, 0], mode="constant", constant_values=False + ) + first_noise_tokens = np.logical_and( + noise_mask, np.logical_not(prev_token_is_noise) + ) + subsequent_noise_tokens = np.logical_and(noise_mask, prev_token_is_noise) + sentinel = vocab_size - np.cumsum(first_noise_tokens) + tokens = np.where(first_noise_tokens, sentinel, tokens) + return tokens[np.logical_not(subsequent_noise_tokens)] + + def _f(generator): + for tokens, noise_mask in generator: + # Returns inputs and targets. + yield ( + _noise_span_to_unique_sentinel(tokens, noise_mask), + _noise_span_to_unique_sentinel(tokens, np.logical_not(noise_mask)), + ) + + return _f + + +@gin.configurable(module="trax.data") def generate_sequential_chunks(max_length=None): - """Returns a function that generates chunks of atmost max_length length.""" - def _f(generator): - for example in generator: - n_tokens = len(example) - if n_tokens <= max_length: - yield example - else: - n_segments = int(math.ceil(float(n_tokens) / float(max_length))) - for i in range(n_segments): - start = max_length * i - end = min(start + max_length, n_tokens) - yield example[start:end] - return _f - - -@gin.configurable(module='trax.data') + """Returns a function that generates chunks of atmost max_length length.""" + + def _f(generator): + for example in generator: + n_tokens = len(example) + if n_tokens <= max_length: + yield example + else: + n_segments = int(math.ceil(float(n_tokens) / float(max_length))) + for i in range(n_segments): + start = max_length * i + end = min(start + max_length, n_tokens) + yield example[start:end] + + return _f + + 
+@gin.configurable(module="trax.data") def addition_input_stream( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, min_length=gin.REQUIRED, - max_length=gin.REQUIRED, pad_to_multiple=32, encdec=False): - """Data stream for the add problem: x+y(x+y). - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - min_length: minimal length of w. - max_length: maximal length of w. - pad_to_multiple: int, pad length to be multiple of this number. - encdec: bool, if True return encoder-decoder style inputs (default: False) - - Returns: - python generator of tuples of data examples - """ - base = vocab_size - 3 # We use 0 to pad, base+1 as "+" and base+2 as "". - def single_example(max_length, min_length): - """Generate a stream of random mini-batches.""" - add_len = (min_length - 1) // 2 - l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len - l2 = np.random.randint(max_length - l1 - 1) + 1 - n1 = random_number_lower_endian(l1, base) - n2 = random_number_lower_endian(l2, base) - result = lower_endian_to_number(n1, base) + lower_endian_to_number( - n2, base) - inp = n1 + [base] + n2 - tgt = number_to_lower_endian(result, base) - if encdec: - x = [i + 1 for i in inp] - y = [i + 1 for i in tgt] - weights = [1] * len(tgt) - candidate_example = (np.array(x), np.array(y), np.array(weights)) - if any(len(sample) > max_length for sample in candidate_example): - # sample too long, try again - return single_example(max_length, min_length) - return (np.array(x), np.array(y), np.array(weights)) - else: - x = [base+2] + [i+1 for i in inp] + [base+2] + [i+1 for i in tgt] - weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt)) - return (np.array(x), np.array(x), np.array(weights)) + vocab_size=gin.REQUIRED, + batch_size=gin.REQUIRED, + min_length=gin.REQUIRED, + max_length=gin.REQUIRED, + pad_to_multiple=32, + encdec=False, +): + """Data stream for the add problem: x+y(x+y). 
- def batches(max_length, min_length): - """Batches of examples.""" - if max_length < 3 or min_length < 3: - raise ValueError('Maximum/minimum length must be at least 3.') - while True: - ex = [single_example(max_length, min_length) for _ in range(batch_size)] - padded_batch = [pad_to_max_dims(x, boundary=pad_to_multiple, - strict_pad_on_len=True) - for x in zip(*ex)] - yield tuple(padded_batch) + Args: + vocab_size: how many symbols to use. + batch_size: how large are the batches. + min_length: minimal length of w. + max_length: maximal length of w. + pad_to_multiple: int, pad length to be multiple of this number. + encdec: bool, if True return encoder-decoder style inputs (default: False) + + Returns: + python generator of tuples of data examples + """ + base = vocab_size - 3 # We use 0 to pad, base+1 as "+" and base+2 as "". + + def single_example(max_length, min_length): + """Generate a stream of random mini-batches.""" + add_len = (min_length - 1) // 2 + l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len + l2 = np.random.randint(max_length - l1 - 1) + 1 + n1 = random_number_lower_endian(l1, base) + n2 = random_number_lower_endian(l2, base) + result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base) + inp = n1 + [base] + n2 + tgt = number_to_lower_endian(result, base) + if encdec: + x = [i + 1 for i in inp] + y = [i + 1 for i in tgt] + weights = [1] * len(tgt) + candidate_example = (np.array(x), np.array(y), np.array(weights)) + if any(len(sample) > max_length for sample in candidate_example): + # sample too long, try again + return single_example(max_length, min_length) + return (np.array(x), np.array(y), np.array(weights)) + else: + x = [base + 2] + [i + 1 for i in inp] + [base + 2] + [i + 1 for i in tgt] + weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt)) + return (np.array(x), np.array(x), np.array(weights)) + + def batches(max_length, min_length): + """Batches of examples.""" + if max_length < 3 or min_length < 3: + 
raise ValueError("Maximum/minimum length must be at least 3.") + while True: + ex = [single_example(max_length, min_length) for _ in range(batch_size)] + padded_batch = [ + pad_to_max_dims(x, boundary=pad_to_multiple, strict_pad_on_len=True) + for x in zip(*ex) + ] + yield tuple(padded_batch) - return batches(max_length, min_length) + return batches(max_length, min_length) # This is a straightforward translation of T5's random_spans_noise_mask. -def random_spans_noise_mask(length, - noise_density=0.15, - mean_noise_span_length=3.0, - seed1=None, - seed2=None, - example=None): - """Computes span corruption masks given input parameters.""" - # Passing this in case if we want to use for debugging/logging - del example - orig_length = length - # increase length to avoid degeneracy - length = max(length, 2) - num_noise_tokens = int(round(length * noise_density)) - # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. - num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) - num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) - # avoid degeneracy by ensuring positive number of noise spans - num_noise_spans = max(num_noise_spans, 1) - num_nonnoise_tokens = length - num_noise_tokens - - # Pick the lengths of the noise spans and the non-noise spans - def randomly_segment(num_items, num_segments, seed): - x = np.arange(num_items - 1) < num_segments - 1 - # Set random seed if passed (only in tests for now). 
- if seed is not None: - np.random.seed(seed) - np.random.shuffle(x) - first_in_segment = np.pad(x, (1, 0), mode='constant') - segment_id = np.cumsum(first_in_segment) - - y = np.roll(segment_id, 1) - y[0] = 0 - idxs = np.pad(np.squeeze(np.argwhere(segment_id - y), axis=1), - (1, 0), - mode='constant') - segment_lengths = np.add.reduceat(np.ones_like(segment_id), idxs, axis=0) - return segment_lengths - - noise_span_lengths = randomly_segment( - num_noise_tokens, num_noise_spans, seed1) - nonnoise_span_lengths = randomly_segment( - num_nonnoise_tokens, num_noise_spans, seed2) - interleaved_span_lengths = np.reshape( - np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), - [num_noise_spans * 2]) - span_starts = np.cumsum(interleaved_span_lengths)[:-1] - span_start_indicator = np.zeros(length) # all 0s to begin with - span_start_indicator[span_starts] = 1 - span_num = np.cumsum(span_start_indicator) - is_noise = np.equal(span_num % 2, 1) - return is_noise[:orig_length] +def random_spans_noise_mask( + length, + noise_density=0.15, + mean_noise_span_length=3.0, + seed1=None, + seed2=None, + example=None, +): + """Computes span corruption masks given input parameters.""" + # Passing this in case if we want to use for debugging/logging + del example + orig_length = length + # increase length to avoid degeneracy + length = max(length, 2) + num_noise_tokens = int(round(length * noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. 
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # Pick the lengths of the noise spans and the non-noise spans + def randomly_segment(num_items, num_segments, seed): + x = np.arange(num_items - 1) < num_segments - 1 + # Set random seed if passed (only in tests for now). + if seed is not None: + np.random.seed(seed) + np.random.shuffle(x) + first_in_segment = np.pad(x, (1, 0), mode="constant") + segment_id = np.cumsum(first_in_segment) + + y = np.roll(segment_id, 1) + y[0] = 0 + idxs = np.pad( + np.squeeze(np.argwhere(segment_id - y), axis=1), (1, 0), mode="constant" + ) + segment_lengths = np.add.reduceat(np.ones_like(segment_id), idxs, axis=0) + return segment_lengths + + noise_span_lengths = randomly_segment(num_noise_tokens, num_noise_spans, seed1) + nonnoise_span_lengths = randomly_segment( + num_nonnoise_tokens, num_noise_spans, seed2 + ) + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), + [num_noise_spans * 2], + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros(length) # all 0s to begin with + span_start_indicator[span_starts] = 1 + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + return is_noise[:orig_length] def lower_endian_to_number(l, base): - """Helper function: convert a list of digits in the given base to a number.""" - return sum([d * (base**i) for i, d in enumerate(l)]) + """Helper function: convert a list of digits in the given base to a number.""" + return sum([d * (base**i) for i, d in enumerate(l)]) def number_to_lower_endian(n, base): - """Helper function: convert a number to a list of digits in the given base.""" - if n < base: - return [n] - return [n % base] + 
number_to_lower_endian(n // base, base) + """Helper function: convert a number to a list of digits in the given base.""" + if n < base: + return [n] + return [n % base] + number_to_lower_endian(n // base, base) def random_number_lower_endian(length, base): - """Helper function: generate a random number as a lower-endian digits list.""" - if length == 1: # Last digit can be 0 only if length is 1. - return [np.random.randint(base)] - prefix = [np.random.randint(base) for _ in range(length - 1)] - return prefix + [np.random.randint(base - 1) + 1] # Last digit is not 0. + """Helper function: generate a random number as a lower-endian digits list.""" + if length == 1: # Last digit can be 0 only if length is 1. + return [np.random.randint(base)] + prefix = [np.random.randint(base) for _ in range(length - 1)] + return prefix + [np.random.randint(base - 1) + 1] # Last digit is not 0. data_counters = {} # Used by {load,save}_data_counters and count_and_skip def count_and_skip(generator, name): - """Count the number of items in the generator, skip already counted ones. - - This function counts the number of processed examples and puts it into - the global variable `counters`. This variable can be saved and restored, - and if restored, this function will skip examples until the restored counter - is reached. When the data generator is deterministic, this allows to restore - the data reading process from a checkpoint. - - Args: - generator: generator for examples in the dataset. - name: string, a unique id that we use to count the examples - - Yields: - The examples from generator but first skip the number specified in the - global variable counters[name] and next increment this variable every - time a new example appears. - """ - global data_counters - local_counter = 0 - for example in generator: - local_counter += 1 - # This check must be inside the loop due to asynchronous initializations. 
- if name not in data_counters: - data_counters[name] = 0 - if local_counter > data_counters[name]: - data_counters[name] += 1 - yield example + """Count the number of items in the generator, skip already counted ones. + + This function counts the number of processed examples and puts it into + the global variable `counters`. This variable can be saved and restored, + and if restored, this function will skip examples until the restored counter + is reached. When the data generator is deterministic, this allows to restore + the data reading process from a checkpoint. + + Args: + generator: generator for examples in the dataset. + name: string, a unique id that we use to count the examples + + Yields: + The examples from generator but first skip the number specified in the + global variable counters[name] and next increment this variable every + time a new example appears. + """ + global data_counters + local_counter = 0 + for example in generator: + local_counter += 1 + # This check must be inside the loop due to asynchronous initializations. 
+ if name not in data_counters: + data_counters[name] = 0 + if local_counter > data_counters[name]: + data_counters[name] += 1 + yield example def save_data_counters(output_dir, host_id=None): - """Checkpoint data counters.""" - global data_counters - host_id = jax.process_index() if host_id is None else host_id - fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id) - with tf.io.gfile.GFile(fname, 'wb') as f: - pickle.dump(data_counters, f) + """Checkpoint data counters.""" + global data_counters + host_id = jax.process_index() if host_id is None else host_id + fname = os.path.join(output_dir, "data_counters%d.pkl" % host_id) + with tf.io.gfile.GFile(fname, "wb") as f: + pickle.dump(data_counters, f) def load_data_counters(output_dir, host_id=None): - """Checkpoint data counters.""" - global data_counters - host_id = jax.process_index() if host_id is None else host_id - fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id) - if not tf.io.gfile.exists(fname): - logging.info('Did not load data counters as %s does not exist.', fname) - return - with tf.io.gfile.GFile(fname, 'rb') as f: - obj = pickle.load(f) - data_counters = obj + """Checkpoint data counters.""" + global data_counters + host_id = jax.process_index() if host_id is None else host_id + fname = os.path.join(output_dir, "data_counters%d.pkl" % host_id) + if not tf.io.gfile.exists(fname): + logging.info("Did not load data counters as %s does not exist.", fname) + return + with tf.io.gfile.GFile(fname, "rb") as f: + obj = pickle.load(f) + data_counters = obj def _generator_process(generator, in_q, out_q): - for example in generator: - in_q.get() - out_q.put(example) - - -def _buckets_for_length(bucket_length, batch_size, max_eval_length, n_devices, - training): - """Creates heuristically a set of bucket boundaries and sizes. - - The middle boundary is set to `bucket_length` and the corresponding batch - size is set to `batch_size`. 
We also create buckets of 1/2 and 1/4 length - with 2x and 4x batch size, and buckets of 2x and 4x and larger length with - 1/2 and 1/4 batch size respectively, and batch size 1 for the final one. - - Args: - bucket_length: the length of the middle bucket. - batch_size: the batch size for the middle bucket. - max_eval_length: the longest bucket length if training=False. - n_devices: number of devices, batch sizes are divisible by that. - training: bool, whether we are training or evaluating. - - Returns: - a pair of lists of integers, (bucket_boundaries, bucket_batch_sizes). - """ - bucket_boundaries = [bucket_length // 4, bucket_length // 2, - bucket_length, bucket_length * 2, - bucket_length * 4, bucket_length * 8, - bucket_length * 16] - if not training: - max_eval_length = max_eval_length or bucket_length * 32 - # Set last bucket boundary to be max_eval_length, cut off boundaries - # that are larger than this. - bucket_boundaries = ( - [b for b in bucket_boundaries if b < max_eval_length] + - [max_eval_length] - ) - bucket_boundaries.append(max_eval_length) - bucket_batch_sizes = [batch_size * 4, batch_size * 2, - batch_size, batch_size // 2, - batch_size // 4, batch_size // 8, - batch_size // 16, 1] - if not training: - # The last bucket batch size is always 1, but the one-but-last is - # sized to accommodate the final length = bucket_boundaries[-1], which - # we changed for eval above -- so adjusting here too. - - # Resize if needed, since bucket_batch_sizes may not be the same size - # anymore. - bucket_batch_sizes = bucket_batch_sizes[:len(bucket_boundaries)] + [1] - bucket_batch_sizes[-2] = batch_size // max_eval_length - # Make batch sizes divisible by n_devices. 
- bucket_batch_sizes = [max(b // n_devices, 1) * n_devices - for b in bucket_batch_sizes] - return (bucket_boundaries, bucket_batch_sizes) + for example in generator: + in_q.get() + out_q.put(example) + + +def _buckets_for_length( + bucket_length, batch_size, max_eval_length, n_devices, training +): + """Creates heuristically a set of bucket boundaries and sizes. + + The middle boundary is set to `bucket_length` and the corresponding batch + size is set to `batch_size`. We also create buckets of 1/2 and 1/4 length + with 2x and 4x batch size, and buckets of 2x and 4x and larger length with + 1/2 and 1/4 batch size respectively, and batch size 1 for the final one. + + Args: + bucket_length: the length of the middle bucket. + batch_size: the batch size for the middle bucket. + max_eval_length: the longest bucket length if training=False. + n_devices: number of devices, batch sizes are divisible by that. + training: bool, whether we are training or evaluating. + + Returns: + a pair of lists of integers, (bucket_boundaries, bucket_batch_sizes). + """ + bucket_boundaries = [ + bucket_length // 4, + bucket_length // 2, + bucket_length, + bucket_length * 2, + bucket_length * 4, + bucket_length * 8, + bucket_length * 16, + ] + if not training: + max_eval_length = max_eval_length or bucket_length * 32 + # Set last bucket boundary to be max_eval_length, cut off boundaries + # that are larger than this. + bucket_boundaries = [b for b in bucket_boundaries if b < max_eval_length] + [ + max_eval_length + ] + bucket_boundaries.append(max_eval_length) + bucket_batch_sizes = [ + batch_size * 4, + batch_size * 2, + batch_size, + batch_size // 2, + batch_size // 4, + batch_size // 8, + batch_size // 16, + 1, + ] + if not training: + # The last bucket batch size is always 1, but the one-but-last is + # sized to accommodate the final length = bucket_boundaries[-1], which + # we changed for eval above -- so adjusting here too. 
+ + # Resize if needed, since bucket_batch_sizes may not be the same size + # anymore. + bucket_batch_sizes = bucket_batch_sizes[: len(bucket_boundaries)] + [1] + bucket_batch_sizes[-2] = batch_size // max_eval_length + # Make batch sizes divisible by n_devices. + bucket_batch_sizes = [ + max(b // n_devices, 1) * n_devices for b in bucket_batch_sizes + ] + return (bucket_boundaries, bucket_batch_sizes) def _length_fn(example, length_axis, length_keys): - """Length is the maximum of shape on length_axis over length_keys.""" - if isinstance(example, (list, tuple)): - return max([example[i].shape[length_axis] for i in length_keys]) - return example.shape[length_axis] + """Length is the maximum of shape on length_axis over length_keys.""" + if isinstance(example, (list, tuple)): + return max([example[i].shape[length_axis] for i in length_keys]) + return example.shape[length_axis] # ######################################################################## @@ -1188,359 +1315,449 @@ def _length_fn(example, length_axis, length_keys): class Inputs: - """Inputs bundle. - - Inputs bundle holds input streams and shapes for a training run. - It contains stream-creating functions that return python generators - of (input_batch, target_batch) tuples. - - * train_stream: training data that will be used for training - may include all the augmentation or selection the training wants - the shape of examples is [batch_fn.batch_size, ...] - * train_eval_stream: training data used for evaluation - examples from training data but usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * eval_stream: evaluation data stream - examples from evaluation data, usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * input_shape: the shape of inputs - the [...] above, without batch size - * input_dtype: the data type of inputs - * target_shape: the shape of targets - the [...] 
above, without batch size - * target_dtype: the data type of targets - """ - - def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): - """Initialize a new set of inputs. - - Args: - train_stream: a function taking n_devices (an int) and returning - a python generator of training batches. - eval_stream: a function taking n_devices (an int) and returning - a python generator of validation batches; - if None, then the training generator will be used for evaluation. - train_eval_stream: a function taking n_devices (an int) and returning - a python generator of batches from - the training set used for evaluation (if None, use train_stream). + """Inputs bundle. + + Inputs bundle holds input streams and shapes for a training run. + It contains stream-creating functions that return python generators + of (input_batch, target_batch) tuples. + + * train_stream: training data that will be used for training + may include all the augmentation or selection the training wants + the shape of examples is [batch_fn.batch_size, ...] + * train_eval_stream: training data used for evaluation + examples from training data but usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * eval_stream: evaluation data stream + examples from evaluation data, usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * input_shape: the shape of inputs + the [...] above, without batch size + * input_dtype: the data type of inputs + * target_shape: the shape of targets + the [...] above, without batch size + * target_dtype: the data type of targets """ - if not callable(train_stream): - raise ValueError('Trax Inputs should be initialized with a function. ' - 'Did you forget the n_devices argument? 
If your inputs ' - 'do not use it, try lambda _: [your-inputs].') - self._train_stream = train_stream - self._eval_stream = eval_stream or self._train_stream + def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): + """Initialize a new set of inputs. + + Args: + train_stream: a function taking n_devices (an int) and returning + a python generator of training batches. + eval_stream: a function taking n_devices (an int) and returning + a python generator of validation batches; + if None, then the training generator will be used for evaluation. + train_eval_stream: a function taking n_devices (an int) and returning + a python generator of batches from + the training set used for evaluation (if None, use train_stream). + """ + if not callable(train_stream): + raise ValueError( + "Trax Inputs should be initialized with a function. " + "Did you forget the n_devices argument? If your inputs " + "do not use it, try lambda _: [your-inputs]." + ) + + self._train_stream = train_stream + self._eval_stream = eval_stream or self._train_stream - # TODO(lukaszkaiser): should we get rid of this one day? - self._train_eval_stream = train_eval_stream or self._train_stream + # TODO(lukaszkaiser): should we get rid of this one day? + self._train_eval_stream = train_eval_stream or self._train_stream - # Peek into the train stream to get an example shape. - example_train_batch = next(train_stream(1)) - self._input_shape = tuple(example_train_batch[0].shape)[1:] - self._input_dtype = example_train_batch[0].dtype - self._target_shape = tuple(example_train_batch[-1].shape)[1:] - self._target_dtype = example_train_batch[-1].dtype - self._example_shape = [x.shape for x in example_train_batch] - self._example_dtype = [x.dtype for x in example_train_batch] + # Peek into the train stream to get an example shape. 
+ example_train_batch = next(train_stream(1)) + self._input_shape = tuple(example_train_batch[0].shape)[1:] + self._input_dtype = example_train_batch[0].dtype + self._target_shape = tuple(example_train_batch[-1].shape)[1:] + self._target_dtype = example_train_batch[-1].dtype + self._example_shape = [x.shape for x in example_train_batch] + self._example_dtype = [x.dtype for x in example_train_batch] - def train_stream(self, n_devices): - return self._train_stream(n_devices) + def train_stream(self, n_devices): + return self._train_stream(n_devices) - def eval_stream(self, n_devices): - return self._eval_stream(n_devices) + def eval_stream(self, n_devices): + return self._eval_stream(n_devices) - def train_eval_stream(self, n_devices): - return self._train_stream(n_devices) + def train_eval_stream(self, n_devices): + return self._train_stream(n_devices) - @property - def input_shape(self): - """Example input shape, without batch dimension.""" - return self._input_shape + @property + def input_shape(self): + """Example input shape, without batch dimension.""" + return self._input_shape - @property - def target_shape(self): - """Example target shape, without batch dimension.""" - return self._target_shape + @property + def target_shape(self): + """Example target shape, without batch dimension.""" + return self._target_shape - @property - def input_dtype(self): - """Dtype of the input.""" - return self._input_dtype + @property + def input_dtype(self): + """Dtype of the input.""" + return self._input_dtype - @property - def target_dtype(self): - """Dtype of the target.""" - return self._target_dtype + @property + def target_dtype(self): + """Dtype of the target.""" + return self._target_dtype - @property - def example_shape_dtype(self): - """Shape and Dtype of an example batch.""" - return self._example_shape, self._example_dtype + @property + def example_shape_dtype(self): + """Shape and Dtype of an example batch.""" + return self._example_shape, self._example_dtype # 
Batching and Inputs creation helpers. -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def make_inputs(train_stream=gin.REQUIRED, eval_stream=None): - """Create Inputs from two streams; mostly for use in gin configs.""" - if isinstance(train_stream, (list, tuple)): - train_stream = Serial(train_stream)() - if isinstance(eval_stream, (list, tuple)): - eval_stream = Serial(eval_stream)() - eval_stream_fn = None if eval_stream is None else lambda _: eval_stream - return Inputs(train_stream=lambda _: train_stream, - eval_stream=eval_stream_fn) + """Create Inputs from two streams; mostly for use in gin configs.""" + if isinstance(train_stream, (list, tuple)): + train_stream = Serial(train_stream)() + if isinstance(eval_stream, (list, tuple)): + eval_stream = Serial(eval_stream)() + eval_stream_fn = None if eval_stream is None else lambda _: eval_stream + return Inputs(train_stream=lambda _: train_stream, eval_stream=eval_stream_fn) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def make_additional_stream(stream=gin.REQUIRED): - """Create a stream mostly for use in gin configs for additional tasks.""" - return Serial(stream)() + """Create a stream mostly for use in gin configs for additional tasks.""" + return Serial(stream)() -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def make_parallel_stream(streams=gin.REQUIRED, counters=None): - """Create a parallel stream for use in gin configs for additional tasks.""" - return Parallel(streams, counters=counters)() - - -@gin.configurable(module='trax.data') -def batcher(data_streams=gin.REQUIRED, variable_shapes=True, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - # TODO(afrozm): Unify padding logic. 
- id_to_mask=None, strict_pad_on_len=False): - """Batcher: create trax Inputs from single-example data-streams.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. - # For now leaving the arguments as in batch_fn to reduce gin config changes. - if callable(data_streams): # If we pass a function, e.g., through gin, call. - train_stream, eval_stream = data_streams() - else: - train_stream, eval_stream = data_streams - # pylint: disable=g-long-lambda - batch_train_stream = lambda n_devices: batch_fn( - train_stream(), True, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_eval_stream = lambda n_devices: batch_fn( - eval_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_train_eval_stream = lambda n_devices: batch_fn( - train_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - # pylint: enable=g-long-lambda - return Inputs(train_stream=batch_train_stream, - eval_stream=batch_eval_stream, - train_eval_stream=batch_train_eval_stream) - - -def batch_fn(dataset, training, n_devices, variable_shapes, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - id_to_mask=None, strict_pad_on_len=False): - """Batching function.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. 
- # After that, create a proper doc-string; we may also not need to pass both - # training and eval arguments here, as batcher calls the function separately - # now and it's not under gin-config any more -- consider reducing args. - batch_size = batch_size or batch_size_per_device * n_devices - # If bucketing is not specified, check if target shapes are variable. - cur_batch_size = batch_size if training else eval_batch_size - # Make cur_batch_size divisible by n_devices. - cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices - # Create heuristic buckets if none are specified. - if buckets is None: - logging.info('Heuristically setting bucketing to %s based on shapes ' - 'of target tensors.', variable_shapes) - if variable_shapes: - buckets = _buckets_for_length( - bucket_length, cur_batch_size, max_eval_length, n_devices, training) - - if buckets: - logging.info('Bucketing with buckets %s.', str(buckets)) - def example_length(x): - """The length function used by bucket_by_sequence_length to bucket.""" - # The input x is a tuple to go on the stack, typically either - # (input, target) or (input, target, mask). - example_inputs, target = x[0], x[1] - # Length is the shape of axis 0 here (no batch yet). - other_length = 0 # We include input length only if asked. 
- if buckets_include_inputs_in_length: - other_length = example_inputs.shape[0] - return max(target.shape[0], other_length) - boundaries, batch_sizes = buckets - dataset = bucket_by_length( - dataset, example_length, boundaries, batch_sizes, strict_pad_on_len) - else: - logging.info('Not Bucketing cur_batch_size %d.', cur_batch_size) - dataset = batch(dataset, cur_batch_size) - if training and batch_shuffle_size is not None: - dataset = shuffle(dataset, batch_shuffle_size) - return add_loss_weights(dataset, id_to_mask) + """Create a parallel stream for use in gin configs for additional tasks.""" + return Parallel(streams, counters=counters)() + + +@gin.configurable(module="trax.data") +def batcher( + data_streams=gin.REQUIRED, + variable_shapes=True, + batch_size_per_device=32, + batch_size=None, + eval_batch_size=32, + bucket_length=32, + buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=None, + max_eval_length=None, + # TODO(afrozm): Unify padding logic. + id_to_mask=None, + strict_pad_on_len=False, +): + """Batcher: create trax Inputs from single-example data-streams.""" + # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. + # For now leaving the arguments as in batch_fn to reduce gin config changes. + if callable(data_streams): # If we pass a function, e.g., through gin, call. 
+ train_stream, eval_stream = data_streams() + else: + train_stream, eval_stream = data_streams + # pylint: disable=g-long-lambda + batch_train_stream = lambda n_devices: batch_fn( + train_stream(), + True, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + batch_eval_stream = lambda n_devices: batch_fn( + eval_stream(), + False, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + batch_train_eval_stream = lambda n_devices: batch_fn( + train_stream(), + False, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + # pylint: enable=g-long-lambda + return Inputs( + train_stream=batch_train_stream, + eval_stream=batch_eval_stream, + train_eval_stream=batch_train_eval_stream, + ) + + +def batch_fn( + dataset, + training, + n_devices, + variable_shapes, + batch_size_per_device=32, + batch_size=None, + eval_batch_size=32, + bucket_length=32, + buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=None, + max_eval_length=None, + id_to_mask=None, + strict_pad_on_len=False, +): + """Batching function.""" + # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. + # After that, create a proper doc-string; we may also not need to pass both + # training and eval arguments here, as batcher calls the function separately + # now and it's not under gin-config any more -- consider reducing args. 
+ batch_size = batch_size or batch_size_per_device * n_devices + # If bucketing is not specified, check if target shapes are variable. + cur_batch_size = batch_size if training else eval_batch_size + # Make cur_batch_size divisible by n_devices. + cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices + # Create heuristic buckets if none are specified. + if buckets is None: + logging.info( + "Heuristically setting bucketing to %s based on shapes " + "of target tensors.", + variable_shapes, + ) + if variable_shapes: + buckets = _buckets_for_length( + bucket_length, cur_batch_size, max_eval_length, n_devices, training + ) + + if buckets: + logging.info("Bucketing with buckets %s.", str(buckets)) + + def example_length(x): + """The length function used by bucket_by_sequence_length to bucket.""" + # The input x is a tuple to go on the stack, typically either + # (input, target) or (input, target, mask). + example_inputs, target = x[0], x[1] + # Length is the shape of axis 0 here (no batch yet). + other_length = 0 # We include input length only if asked. + if buckets_include_inputs_in_length: + other_length = example_inputs.shape[0] + return max(target.shape[0], other_length) + + boundaries, batch_sizes = buckets + dataset = bucket_by_length( + dataset, example_length, boundaries, batch_sizes, strict_pad_on_len + ) + else: + logging.info("Not Bucketing cur_batch_size %d.", cur_batch_size) + dataset = batch(dataset, cur_batch_size) + if training and batch_shuffle_size is not None: + dataset = shuffle(dataset, batch_shuffle_size) + return add_loss_weights(dataset, id_to_mask) # Example input functions. -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def random_inputs( - input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255), - output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)): - """Make random Inputs for debugging. - - Args: - input_shape: the shape of inputs (including batch dimension). 
- input_dtype: the type of the inputs (int32 by default). - input_range: the range of inputs (defaults to (0, 255)). - output_shape: the shape of outputs (including batch dimension). - output_dtype: the type of the outputs (int32 by default). - output_range: the range of outputs (defaults to (0, 9)). - - Returns: - trax.inputs.Inputs - """ - def random_minibatches(n_devices): - """Generate a stream of random mini-batches.""" - assert input_range[0] % n_devices == 0 - if input_dtype in [jnp.float16, jnp.float32, jnp.float64]: - rand = np.random.uniform - else: - rand = np.random.random_integers - while True: - inp = rand(input_range[0], input_range[1], input_shape) - inp = inp.astype(input_dtype) - out = rand(output_range[0], output_range[1], output_shape) - out = out.astype(output_dtype) - yield inp, out + input_shape=gin.REQUIRED, + input_dtype=jnp.int32, + input_range=(0, 255), + output_shape=gin.REQUIRED, + output_dtype=jnp.int32, + output_range=(0, 9), +): + """Make random Inputs for debugging. - return Inputs(random_minibatches) + Args: + input_shape: the shape of inputs (including batch dimension). + input_dtype: the type of the inputs (int32 by default). + input_range: the range of inputs (defaults to (0, 255)). + output_shape: the shape of outputs (including batch dimension). + output_dtype: the type of the outputs (int32 by default). + output_range: the range of outputs (defaults to (0, 9)). 
+ + Returns: + trax.inputs.Inputs + """ + def random_minibatches(n_devices): + """Generate a stream of random mini-batches.""" + assert input_range[0] % n_devices == 0 + if input_dtype in [jnp.float16, jnp.float32, jnp.float64]: + rand = np.random.uniform + else: + rand = np.random.random_integers + while True: + inp = rand(input_range[0], input_range[1], input_shape) + inp = inp.astype(input_dtype) + out = rand(output_range[0], output_range[1], output_shape) + out = out.astype(output_dtype) + yield inp, out -@gin.configurable(module='trax.data') + return Inputs(random_minibatches) + + +@gin.configurable(module="trax.data") def sequence_copy_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, reverse=False, - pad_to_multiple=32): - """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*. - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximum length of w for training. - eval_min_length: minimum length of w for eval. - eval_max_length : maximum length of w for eval. - reverse: bool (optional, false by default): reverse the second sequence. - pad_to_multiple: int, pad length to be multiple of this number. 
- - Returns: - trax.inputs.Inputs - """ - def random_minibatches(length_list): - """Generate a stream of random mini-batches.""" - while True: - length = random.choice(length_list) - assert length % 2 == 0 - w_length = (length // 2) - 1 - w = np.random.randint(low=1, high=vocab_size-1, - size=(batch_size, w_length)) - zero = np.zeros([batch_size, 1], np.int32) - loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)), - np.ones((batch_size, w_length))], axis=1) - if reverse: - x = np.concatenate([zero, w, zero, jnp.flip(w, axis=1)], axis=1) - else: - x = np.concatenate([zero, w, zero, w], axis=1) - x = _pad_to_multiple_of(x, pad_to_multiple, 1) - loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) - yield (x, x, loss_weights) # Here inputs and targets are the same. - - train_lengths = [2*(i+2) for i in range(train_length - 1)] - eval_lengths = [2*(i+1) for i in range(eval_min_length, eval_max_length)] - return Inputs( - train_stream=lambda _: random_minibatches(train_lengths), - eval_stream=lambda _: random_minibatches(eval_lengths) - ) + vocab_size=gin.REQUIRED, + batch_size=gin.REQUIRED, + train_length=gin.REQUIRED, + eval_min_length=gin.REQUIRED, + eval_max_length=gin.REQUIRED, + reverse=False, + pad_to_multiple=32, +): + """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*. + Args: + vocab_size: how many symbols to use. + batch_size: how large are the batches. + train_length: maximum length of w for training. + eval_min_length: minimum length of w for eval. + eval_max_length : maximum length of w for eval. + reverse: bool (optional, false by default): reverse the second sequence. + pad_to_multiple: int, pad length to be multiple of this number. 
+ + Returns: + trax.inputs.Inputs + """ + + def random_minibatches(length_list): + """Generate a stream of random mini-batches.""" + while True: + length = random.choice(length_list) + assert length % 2 == 0 + w_length = (length // 2) - 1 + w = np.random.randint( + low=1, high=vocab_size - 1, size=(batch_size, w_length) + ) + zero = np.zeros([batch_size, 1], np.int32) + loss_weights = np.concatenate( + [np.zeros((batch_size, w_length + 2)), np.ones((batch_size, w_length))], + axis=1, + ) + if reverse: + x = np.concatenate([zero, w, zero, jnp.flip(w, axis=1)], axis=1) + else: + x = np.concatenate([zero, w, zero, w], axis=1) + x = _pad_to_multiple_of(x, pad_to_multiple, 1) + loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) + yield (x, x, loss_weights) # Here inputs and targets are the same. + + train_lengths = [2 * (i + 2) for i in range(train_length - 1)] + eval_lengths = [2 * (i + 1) for i in range(eval_min_length, eval_max_length)] + return Inputs( + train_stream=lambda _: random_minibatches(train_lengths), + eval_stream=lambda _: random_minibatches(eval_lengths), + ) -@gin.configurable(module='trax.data') + +@gin.configurable(module="trax.data") def simple_sequence_copy_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, - pad_to_multiple=32): - """Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*. - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximum length of w for training. - eval_min_length: minimum length of w for eval. - eval_max_length : maximum length of w for eval. - pad_to_multiple: int, pad length to be multiple of this number. 
- - Returns: - trax.inputs.Inputs - """ - def random_minibatches(length_list): - """Generate a stream of random mini-batches.""" - while True: - length = random.choice(length_list) - x = np.random.randint(low=1, high=vocab_size-1, - size=(batch_size, length)) - loss_weights = np.ones((batch_size, length)) - x = _pad_to_multiple_of(x, pad_to_multiple, 1) - loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) - yield (x, x, loss_weights) # Here inputs and targets are the same. - - train_lengths = list(range(1, train_length + 1)) - eval_lengths = list(range(eval_min_length, eval_max_length + 1)) - return Inputs( - train_stream=lambda _: random_minibatches(train_lengths), - eval_stream=lambda _: random_minibatches(eval_lengths) - ) + vocab_size=gin.REQUIRED, + batch_size=gin.REQUIRED, + train_length=gin.REQUIRED, + eval_min_length=gin.REQUIRED, + eval_max_length=gin.REQUIRED, + pad_to_multiple=32, +): + """Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*. + + Args: + vocab_size: how many symbols to use. + batch_size: how large are the batches. + train_length: maximum length of w for training. + eval_min_length: minimum length of w for eval. + eval_max_length : maximum length of w for eval. + pad_to_multiple: int, pad length to be multiple of this number. + + Returns: + trax.inputs.Inputs + """ + def random_minibatches(length_list): + """Generate a stream of random mini-batches.""" + while True: + length = random.choice(length_list) + x = np.random.randint(low=1, high=vocab_size - 1, size=(batch_size, length)) + loss_weights = np.ones((batch_size, length)) + x = _pad_to_multiple_of(x, pad_to_multiple, 1) + loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) + yield (x, x, loss_weights) # Here inputs and targets are the same. 
+ + train_lengths = list(range(1, train_length + 1)) + eval_lengths = list(range(eval_min_length, eval_max_length + 1)) + return Inputs( + train_stream=lambda _: random_minibatches(train_lengths), + eval_stream=lambda _: random_minibatches(eval_lengths), + ) -@gin.configurable(module='trax.data') + +@gin.configurable(module="trax.data") def addition_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, - pad_to_multiple=32, encdec=False): - """Inputs for the add problem: x+y(x+y). - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximal length of w for training. - eval_min_length: minimal length of w for eval. - eval_max_length: maximal length of w for eval. - pad_to_multiple: int, pad length to be multiple of this number. - encdec: bool, if True return encoder-decoder style inputs (default: False) - - Returns: - trax.inputs.Inputs - """ - train_stream = addition_input_stream( - vocab_size, batch_size, 3, train_length, pad_to_multiple, encdec) - eval_stream = addition_input_stream( - vocab_size, batch_size, eval_min_length, eval_max_length, pad_to_multiple, - encdec) - return Inputs( - train_stream=lambda _: train_stream, - eval_stream=lambda _: eval_stream - ) + vocab_size=gin.REQUIRED, + batch_size=gin.REQUIRED, + train_length=gin.REQUIRED, + eval_min_length=gin.REQUIRED, + eval_max_length=gin.REQUIRED, + pad_to_multiple=32, + encdec=False, +): + """Inputs for the add problem: x+y(x+y). + + Args: + vocab_size: how many symbols to use. + batch_size: how large are the batches. + train_length: maximal length of w for training. + eval_min_length: minimal length of w for eval. + eval_max_length: maximal length of w for eval. + pad_to_multiple: int, pad length to be multiple of this number. 
+ encdec: bool, if True return encoder-decoder style inputs (default: False) + + Returns: + trax.inputs.Inputs + """ + train_stream = addition_input_stream( + vocab_size, batch_size, 3, train_length, pad_to_multiple, encdec + ) + eval_stream = addition_input_stream( + vocab_size, + batch_size, + eval_min_length, + eval_max_length, + pad_to_multiple, + encdec, + ) + return Inputs( + train_stream=lambda _: train_stream, eval_stream=lambda _: eval_stream + ) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def sine_inputs( batch_size=gin.REQUIRED, length=gin.REQUIRED, @@ -1548,43 +1765,43 @@ def sine_inputs( min_period=0.1, max_period=10.0, ): - """Sinusoids of random period and phase. - - Args: - batch_size (int): Number of examples in a batch. - length (int): Length of each sequence. - max_phase (float): Maximum phase of the sinusoids. - min_period (float): Minimum period of the sinusoids. - max_period (float): Maximum period of the sinusoids. - - Returns: - trax.inputs.Inputs - """ - def random_series(): - while True: - phase = np.random.uniform(0, max_phase) - period = np.exp(np.random.uniform(np.log(min_period), np.log(max_period))) - x = np.arange(length) - yield np.sin((x - phase) / period) - - def random_minibatches(_): - minibatch = [] - for series in random_series(): - minibatch.append(series) - if len(minibatch) == batch_size: - obs = np.stack(minibatch) - minibatch.clear() - act = np.zeros_like(obs, dtype=np.int32) - mask = np.ones_like(obs) - yield (obs, act, obs, mask) - - return Inputs(train_stream=random_minibatches, eval_stream=random_minibatches) + """Sinusoids of random period and phase. + + Args: + batch_size (int): Number of examples in a batch. + length (int): Length of each sequence. + max_phase (float): Maximum phase of the sinusoids. + min_period (float): Minimum period of the sinusoids. + max_period (float): Maximum period of the sinusoids. 
+ + Returns: + trax.inputs.Inputs + """ + + def random_series(): + while True: + phase = np.random.uniform(0, max_phase) + period = np.exp(np.random.uniform(np.log(min_period), np.log(max_period))) + x = np.arange(length) + yield np.sin((x - phase) / period) + + def random_minibatches(_): + minibatch = [] + for series in random_series(): + minibatch.append(series) + if len(minibatch) == batch_size: + obs = np.stack(minibatch) + minibatch.clear() + act = np.zeros_like(obs, dtype=np.int32) + mask = np.ones_like(obs) + yield (obs, act, obs, mask) + + return Inputs(train_stream=random_minibatches, eval_stream=random_minibatches) def _pad_to_multiple_of(x, y, axis): - """Pads x to multiple of y on the given axis.""" - pad_len = np.ceil(x.shape[axis] / float(y)) * y - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = (0, int(pad_len - x.shape[axis])) - return np.pad(x, pad_widths, mode='constant', - constant_values=x.dtype.type(0)) + """Pads x to multiple of y on the given axis.""" + pad_len = np.ceil(x.shape[axis] / float(y)) * y + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (0, int(pad_len - x.shape[axis])) + return np.pad(x, pad_widths, mode="constant", constant_values=x.dtype.type(0)) diff --git a/trax/data/inputs_test.py b/trax/data/inputs_test.py deleted file mode 100644 index e4cf5c0bd..000000000 --- a/trax/data/inputs_test.py +++ /dev/null @@ -1,774 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.inputs.""" - -import itertools -import os - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -from trax import data - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') - - -def _spm_path(): - return os.path.join(_TESTDATA, 'sentencepiece.model') - - -class InputsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('zero', 0), - ('negative', -5), - ) - def test_shuffle_data_raises_error_queue_size(self, queue_size): - samples = iter(range(10)) - with self.assertRaises(ValueError): - _ = list(data.shuffle(samples, queue_size)) - - @parameterized.named_parameters( - ('one', 1), - ('two', 2), - ('twenty', 20), - ) - def test_shuffle_data_queue_size(self, queue_size): - samples = iter(range(100, 200)) - shuffled_stream = data.shuffle(samples, queue_size) - first_ten = [next(shuffled_stream) for _ in range(10)] - - # Queue size limits how far ahead/upstream the current sample can reach. - self.assertLess(first_ten[0], 100 + queue_size) - self.assertLess(first_ten[3], 103 + queue_size) - self.assertLess(first_ten[9], 109 + queue_size) - - unshuffled_first_ten = list(range(100, 110)) - if queue_size == 1: # Degenerate case: no shuffling can happen. 
- self.assertEqual(first_ten, unshuffled_first_ten) - if queue_size > 1: - self.assertNotEqual(first_ten, unshuffled_first_ten) - - @parameterized.named_parameters( - ('qsize_100_n_001', 100, 1), - ('qsize_100_n_099', 100, 99), - ('qsize_100_n_100', 100, 100), - ('qsize_100_n_101', 100, 101), - ('qsize_100_n_199', 100, 199), - ) - def test_shuffle_data_yields_all_samples(self, queue_size, n_samples): - samples = iter(range(n_samples)) - shuffled_stream = data.shuffle(samples, queue_size) - self.assertLen(list(shuffled_stream), n_samples) - - def test_batch_data(self): - dataset = ((i, i+1) for i in range(10)) - batches = data.batch(dataset, 10) - batch = next(batches) - self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (10,)) - - def test_batch_data_padding(self): - dataset = (([1] * (10 - i), i+1) for i in range(10)) - batches = data.batch(dataset, 10) - batch = next(batches) - self.assertEqual(batch[0].shape, (10, 10)) - self.assertTrue(np.array_equal(batch[0][-1], np.asarray([1] + 9 * [0]))) - - def test_batch_exception_size(self): - dataset = ((i, i + 1) for i in range(10)) - with self.assertRaises(ValueError): - batches = data.batch(dataset, 0) - next(batches) - - def test_serial(self): - dataset = lambda _: ((i, i+1) for i in range(10)) - batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10)) - batch = next(batches()) - self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (10,)) - - def test_serial_composes(self): - """Check that data.Serial works inside another data.Serial.""" - dataset = lambda _: ((i, i+1) for i in range(10)) - serial1 = data.Serial(dataset, data.Shuffle(3)) - batches = data.Serial(serial1, data.Batch(10)) - batch = next(batches()) - self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (10,)) - - def test_count_and_skip(self): - dataset = lambda _: ((i, i+1) for i in range(10)) - examples = data.Serial(dataset, data.CountAndSkip('toy_data')) - ex_generator = examples() - ex1 = next(ex_generator) - 
self.assertEqual(ex1, (0, 1)) - self.assertEqual(data.inputs.data_counters['toy_data'], 1) - ex2 = next(ex_generator) - self.assertEqual(ex2, (1, 2)) - self.assertEqual(data.inputs.data_counters['toy_data'], 2) - ex3 = next(examples()) # new generator, will skip - self.assertEqual(ex3, (2, 3)) - self.assertEqual(data.inputs.data_counters['toy_data'], 3) - data.inputs.data_counters['toy_data'] = 0 # reset - ex4 = next(examples()) # new generator, was reset - self.assertEqual(ex4, (0, 1)) - self.assertEqual(data.inputs.data_counters['toy_data'], 1) - - def test_parallel(self): - """Basic test of the parallel ccmbinator.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 20)) - parallel = data.Parallel([dataset1, dataset2]) - generator = parallel() - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 10) - self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 11) - self.assertEqual(next(generator), 2) - self.assertEqual(next(generator), 12) - - def test_parallel_with_gen_not_none(self): - """Test of the parallel ccmbinator with a not none generator.""" - dataset1 = lambda _: (i for i in range(10)) - dataset2 = lambda _: (i for i in range(10, 20)) - parallel = data.Parallel([dataset1, dataset2]) - - def test_generator(): - yield 0 - - generator = parallel(gen=test_generator) - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 10) - self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 11) - self.assertEqual(next(generator), 2) - self.assertEqual(next(generator), 12) - - def test_parallel_with_weights(self): - """Test of the parallel ccmbinator with weights.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 20)) - parallel = data.Parallel([dataset1, dataset2], counters=(2, 1)) - generator = parallel() - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 10) - 
self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 11) - self.assertEqual(next(generator), 2) - self.assertEqual(next(generator), 3) - self.assertEqual(next(generator), 12) - self.assertEqual(next(generator), 4) - self.assertEqual(next(generator), 5) - self.assertEqual(next(generator), 13) - - def test_parallel_with_weights_and_minimum(self): - """Test of the parallel ccmbinator with weights and minimum.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 110)) - parallel = data.Parallel([dataset1, dataset2], - counters=(10, 100), - reweight_by_minimum=True) - generator = parallel() - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 10) - self.assertEqual(next(generator), 11) - self.assertEqual(next(generator), 12) - self.assertEqual(next(generator), 13) - self.assertEqual(next(generator), 14) - self.assertEqual(next(generator), 15) - self.assertEqual(next(generator), 16) - self.assertEqual(next(generator), 17) - self.assertEqual(next(generator), 18) - self.assertEqual(next(generator), 19) - self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 20) - self.assertEqual(next(generator), 21) - self.assertEqual(next(generator), 22) - self.assertEqual(next(generator), 23) - self.assertEqual(next(generator), 24) - self.assertEqual(next(generator), 25) - self.assertEqual(next(generator), 26) - self.assertEqual(next(generator), 27) - self.assertEqual(next(generator), 28) - self.assertEqual(next(generator), 29) - self.assertEqual(next(generator), 2) - - def test_parallel_with_gradual_reweighting(self): - """Test of the parallel ccmbinator with weights and minimum.""" - dataset1 = lambda: (i for i in itertools.cycle(range(1))) - dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) - dataset3 = lambda: (i for i in itertools.cycle(range(30, 70))) - parallel = data.Parallel([dataset2, dataset1, dataset3], - counters=(20, 1, 40), - gradually_reweight=True) - 
generator = parallel() - - for _ in range(3): - self.assertEqual(next(generator), 0) - for i in range(20): - self.assertEqual(next(generator), 10 + i) - self.assertEqual(next(generator), 30 + 2 * i) - self.assertEqual(next(generator), 30 + 2 * i + 1) - - def test_parallel_with_gradual_reweighting_remainders(self): - """Test of the parallel ccmbinator with weights and minimum.""" - dataset1 = lambda: (i for i in itertools.cycle(range(1))) - dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) - dataset3 = lambda: (i for i in itertools.cycle(range(30, 80))) - parallel = data.Parallel([dataset2, dataset1, dataset3], - counters=(20, 1, 50), - gradually_reweight=True, - use_remainders=True) - generator = parallel() - - for _ in range(3): - self.assertEqual(next(generator), 0) - for i in range(20): - self.assertEqual(next(generator), 10 + i) - self.assertEqual(next(generator), 30 + 2 * i) - self.assertEqual(next(generator), 30 + 2 * i + 1) - # Here we process the remainder from dataset 3: - for i in range(10): - self.assertEqual(next(generator), 70 + i) - - def test_parallel_with_gradual_reweighting_remainders_big(self): - """Test of the parallel ccmbinator with weights and minimum.""" - dataset1 = lambda: (i for i in itertools.cycle(range(1))) - dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) - dataset3 = lambda: (i for i in itertools.cycle(range(30, 80))) - dataset4 = lambda: (i for i in itertools.cycle(range(100, 220))) - parallel = data.Parallel([dataset2, dataset1, dataset4, dataset3], - counters=(20, 1, 120, 50), - gradually_reweight=True, - use_remainders=True) - generator = parallel() - - for _ in range(3): - self.assertEqual(next(generator), 0) - for i in range(20): - self.assertEqual(next(generator), 10 + i) - for j in range(2): - self.assertEqual(next(generator), 30 + 2 * i + j) - for k in range(2): - self.assertEqual(next(generator), 100 + 2 * 2 * i + 2 * j + k) - # Here we process the remainder from datasets 3 and 4: - for i in 
range(10): - self.assertEqual(next(generator), 70 + i) - for i in range(40): - self.assertEqual(next(generator), 180 + i) - - def test_parallel_with_weights_three_datasets(self): - """Check that data.Serial works inside another data.Serial.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 20)) - dataset3 = lambda: (i for i in range(20, 30)) - parallel = data.Parallel( - [dataset1, dataset2, dataset3], counters=(2, 1, 3)) - generator = parallel() - - self.assertEqual(next(generator), 0) # (1,0,0) - self.assertEqual(next(generator), 10) # (1,1,0) - self.assertEqual(next(generator), 20) # (1,1,1) - self.assertEqual(next(generator), 1) # (2,1,1) - self.assertEqual(next(generator), 21) # (2,1,2) - self.assertEqual(next(generator), 22) # (2,1,3) - self.assertEqual(next(generator), 2) # (1,0,0) - self.assertEqual(next(generator), 11) # (1,1,0) - self.assertEqual(next(generator), 23) # (1,1,1) - self.assertEqual(next(generator), 3) # (2,1,1) - self.assertEqual(next(generator), 24) # (2,1,2) - self.assertEqual(next(generator), 25) # (2,1,3) - self.assertEqual(next(generator), 4) # (1,0,0) - - def test_stack_parallel(self): - """Test of stacked parallel ccmbinators.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 20)) - dataset3 = lambda: (i for i in range(20, 30)) - parallel_lev0 = data.Parallel([dataset1, dataset2]) - parallel_lev1 = data.Parallel([parallel_lev0, dataset3]) - generator = parallel_lev1() - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 20) - self.assertEqual(next(generator), 10) - self.assertEqual(next(generator), 21) - self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 22) - self.assertEqual(next(generator), 11) - self.assertEqual(next(generator), 23) - self.assertEqual(next(generator), 2) - self.assertEqual(next(generator), 24) - self.assertEqual(next(generator), 12) - - def test_parallel_with_zero_counters(self): - """Test 
of stacked parallel ccmbinators.""" - dataset1 = lambda: (i for i in range(10)) - dataset2 = lambda: (i for i in range(10, 20)) - dataset3 = lambda: (i for i in range(20, 30)) - parallel = data.Parallel([dataset1, dataset2, dataset3], counters=[1, 0, 1]) - generator = parallel() - - self.assertEqual(next(generator), 0) - self.assertEqual(next(generator), 20) - self.assertEqual(next(generator), 1) - self.assertEqual(next(generator), 21) - self.assertEqual(next(generator), 2) - self.assertEqual(next(generator), 22) - self.assertEqual(next(generator), 3) - self.assertEqual(next(generator), 23) - - def test_serial_with_python(self): - dataset = lambda _: ((i, i+1) for i in range(10)) - batches = data.Serial( - dataset, - lambda g: map(lambda x: (x[0], x[1] + 1), g), - lambda g: filter(lambda x: x[0] % 2 == 1, g), - data.Batch(2) - ) - batch = next(batches()) - self.assertLen(batch, 2) - (xs, ys) = batch - # First tuple after filtering is (1, 3) = (1, 2+1). - self.assertEqual(xs[0], 1) - self.assertEqual(ys[0], 3) - # Second tuple after filtering is (3, 5). 
- self.assertEqual(xs[1], 3) - self.assertEqual(ys[1], 5) - - def test_pad_to_max_dims(self): - tensors1 = [np.zeros((3, 10)), np.ones((3, 10))] - padded1 = data.inputs.pad_to_max_dims(tensors1) - self.assertEqual(padded1.shape, (2, 3, 10)) - tensors2 = [np.zeros((2, 10)), np.ones((3, 9))] - padded2 = data.inputs.pad_to_max_dims(tensors2) - self.assertEqual(padded2.shape, (2, 3, 10)) - tensors3 = [np.zeros((8, 10)), np.ones((8, 9))] - padded3 = data.inputs.pad_to_max_dims(tensors3, 12) - self.assertEqual(padded3.shape, (2, 12, 12)) - tensors4 = [np.zeros((2, 10)), np.ones((3, 9))] - padded4 = data.inputs.pad_to_max_dims(tensors4, 12) - self.assertEqual(padded4.shape, (2, 4, 12)) - - def test_pad_to_length(self): - tensors1 = [(np.zeros((5)), np.ones((3)))] - pad_to_length_function1 = data.inputs.PadToLength(len_map={0: 10, - 1: 11}, - pad_value={0: 0, - 1: 1}) - padded1 = next(pad_to_length_function1(tensors1)) - self.assertEqual(padded1[0].shape, (10,)) - self.assertEqual(padded1[1].shape, (11,)) - - tensors2 = [(np.zeros((15)), np.ones((20)))] - pad_to_length_function2 = data.inputs.PadToLength(len_map={0: 10, - 1: 10}, - pad_value={0: 0, - 1: 1}, - multiple=True) - padded2 = next(pad_to_length_function2(tensors2)) - self.assertEqual(padded2[0].shape, (20,)) - self.assertEqual(padded2[1].shape, (20,)) - - def test_concatenate_lm_input(self): - tensors1 = [(np.zeros((5)), np.ones((3)))] - - lm_input_function1 = data.inputs.ConcatenateToLMInput(pad_to_length=10) - lm_input_1 = next(lm_input_function1(tensors1)) - self.assertEqual(lm_input_1[0].shape, (10,)) - self.assertEqual(lm_input_1[1].shape, (10,)) - self.assertEqual(lm_input_1[2].shape, (10,)) - self.assertEqual(lm_input_1[2].all(), - np.array([[0., 0., 0., 0., 0., - 1., 1., 1., 0., 0.]]).all()) - - tensors2 = [(np.zeros((5)), np.ones((3)))] - lm_input_function2 = data.inputs.ConcatenateToLMInput() - lm_input_2 = next(lm_input_function2(tensors2)) - self.assertEqual(lm_input_2[0].shape, (8,)) - 
self.assertEqual(lm_input_2[1].shape, (8,)) - self.assertEqual(lm_input_2[2].shape, (8,)) - self.assertEqual(lm_input_2[2].all(), - np.array([[0., 0., 0., 0., 0., - 1., 1., 1.]]).all()) - - def test_truncate_to_length_no_arg(self): - """Tests that a no-arg call leaves shapes unchanged.""" - def data_stream(): - while True: - yield (np.zeros((1, 5)), np.ones((1, 5))) - stream_fn = data.inputs.TruncateToLength() - y0, y1 = next(stream_fn(data_stream())) - self.assertEqual(y0.shape, (1, 5)) - self.assertEqual(y1.shape, (1, 5)) - - @parameterized.named_parameters( - ('none', None, ((1, 5), (1, 5))), - ('large_values', {0: (1, 77), 1: (1, 88)}, ((1, 5), (1, 5))), - ('small_values', {0: (1, 3), 1: (1, 2)}, ((1, 3), (1, 2))), - ) - def test_truncate_to_length_len_map(self, len_map, out_shapes): - """Tests that truncation occurs when len_map values are small enough.""" - def data_stream(): - while True: - yield (np.zeros((1, 5)), np.ones((1, 5))) - stream_fn = data.inputs.TruncateToLength(len_map=len_map) - y0, y1 = next(stream_fn(data_stream())) - self.assertEqual(y0.shape, out_shapes[0]) - self.assertEqual(y1.shape, out_shapes[1]) - - def test_truncate_to_length_questionable_behavior(self): - # Use of np.reshape in TruncateToLength allows non-truncation results - # without warning. As long as the target shape (len_map value) is - # lexicographically prior to the data shape, then np.reshape can happen, - # even if it results in *adding* values to the overall array. - # - # This test passes as a marker of the questionable behavior, and should - # *fail* -- and then be removed -- when the function is - # clarified/re-implemented. - # - # TODO(jonni): Determine desired behavior, and fit implementation to it. 
- x = np.arange(21).reshape((1, 21, 1)) - def data_stream(): - while True: - yield x - stream_fn = data.inputs.TruncateToLength(len_map={0: (1, 4, 6)}) - (y,) = next(stream_fn(data_stream())) - self.assertEqual(y.shape, (1, 4, 6)) - self.assertEqual(y[0, 3, 1], 19) - self.assertEqual(y[0, 3, 2], 20) # end of original values [0..20] - self.assertEqual(y[0, 3, 3], 0) # added value - self.assertEqual(y[0, 3, 4], 1) # added value - self.assertEqual(y[0, 3, 5], 2) # added value - - def test_filter_empty_examples(self): - tensors1 = [(np.zeros((0,)), np.ones((1, 5))), - (np.zeros((1, 5)), np.ones((1, 5)))] - - filter_empty_examples_function1 = data.inputs.FilterEmptyExamples() - filtered1 = next(filter_empty_examples_function1(tensors1)) - self.assertEqual(filtered1[0].shape, (1, 5)) - self.assertEqual(filtered1[1].shape, (1, 5)) - - filter_empty_examples_function2 = data.inputs.FilterEmptyExamples(axes=[1]) - filtered2 = next(filter_empty_examples_function2(tensors1)) - self.assertEqual(filtered2[0].shape, (0,)) - self.assertEqual(filtered2[1].shape, (1, 5)) - - def test_append_value(self): - tensors1 = [(np.zeros((1, 5)), np.ones((1, 5)))] - - append_value_function1 = data.inputs.AppendValue() - unmodified = next(append_value_function1(tensors1)) - self.assertEqual(unmodified[0].shape, (1, 5)) - self.assertEqual(unmodified[1].shape, (1, 5)) - - append_value_function2 = data.inputs.AppendValue({0: [[5]], - 1: [[4]]}) - appended = next(append_value_function2(tensors1)) - self.assertEqual(appended[0].shape, (1, 6)) - self.assertEqual(appended[0].all(), - np.array([[0., 0., 0., 0., 0., 5.]]).all()) - self.assertEqual(appended[1].shape, (1, 6)) - self.assertEqual(appended[1].all(), - np.array([[1., 1., 1., 1., 1., 4.]]).all()) - - def test_pad_to_max_dims_boundary_list(self): - tensors = [np.zeros((1, 15, 31)), np.ones((2, 10, 35)), np.ones((4, 2, 3))] - padded_tensors = data.inputs.pad_to_max_dims( - tensors, boundary=(None, 15, 20)) - # no boundary, only max in the first 
dim, 15 is already the max len in - # second dim, last dim padded to multiple of 20. - # The outer dim is the batch here. - self.assertEqual(padded_tensors.shape, (3, 4, 15, 40)) - - def test_pad_to_max_dims_strict_pad_on_len(self): - tensors = [np.ones((15,)), np.ones((12,)), np.ones((14,))] - padded_tensors = data.inputs.pad_to_max_dims( - tensors, boundary=10, strict_pad_on_len=True) - self.assertEqual(padded_tensors.shape, (3, 20)) - - def test_bucket_by_length(self): - def fake_generator(length, num_examples=1): - for _ in range(num_examples): - yield (np.ones((length,)), np.ones((length,))) - - def length_function(example): - return max(example[0].shape[0], example[1].shape[0]) - - batches = list(data.bucket_by_length(fake_generator(5, 6), - length_function, - [20], - [2], - strict_pad_on_len=True)) - - # We'll get three batches of 2 examples each. - self.assertLen(batches, 3) - self.assertIsInstance(batches[0], tuple) - self.assertLen(batches[0], 2) - self.assertEqual((2, 20), batches[0][0].shape) - self.assertEqual((2, 20), batches[0][1].shape) - - @parameterized.named_parameters( - ('encdec_on', True), - ('encdec_off', False), - ) - def test_addition_inputs_exceptions(self, encdec): - vocab_size = 5 - batch_size = 256 - seq_length = 64 - # Check if max/min lengths are validated for train stream - with self.assertRaises(ValueError): - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=2, - eval_min_length=1, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=encdec) - train_stream = inputs.train_stream(n_devices=1) - for _ in range(10): - next(train_stream) - - # Check if max/min lengths are validated for eval stream - with self.assertRaises(ValueError): - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=seq_length, - eval_min_length=1, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=True) - eval_stream = 
inputs.eval_stream(n_devices=1) - for _ in range(10): - next(eval_stream) - - def test_addition_inputs_constraints(self): - vocab_size = 5 - batch_size = 256 - seq_length = 64 - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=seq_length, - eval_min_length=seq_length, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=True) - - # Check if max length is respected for train stream - train_stream = inputs.train_stream(n_devices=1) - for _ in range(10): - x, y, weights = next(train_stream) - self.assertEqual(x.shape[1], seq_length) - self.assertEqual(y.shape[1], seq_length) - self.assertEqual(weights.shape[1], seq_length) - - # Check if max length is respected for eval stream - eval_stream = inputs.eval_stream(n_devices=1) - for _ in range(10): - x, y, weights = next(eval_stream) - self.assertEqual(x.shape[1], seq_length) - self.assertEqual(y.shape[1], seq_length) - self.assertEqual(weights.shape[1], seq_length) - - def _get_span_lengths(self, x): - span_lengths = [] - curr_len = 0 - for i in range(1, len(x)): - # 1 -> 0 - if x[i] == 0 and x[i - 1] == 1: - span_lengths.append(curr_len) - curr_len = 0 - # 1 -> 1 or 0 -> 1 - elif ((x[i] == 1 and x[i - 1] == 1) or - (x[i] == 1 and x[i - 1] == 0)): - curr_len += 1 - if curr_len != 0: - span_lengths.append(curr_len) - return span_lengths - - def test_random_spans_noise_mask(self): - length = 100 - noise_density = 0.15 - mean_noise_span_length = 3.0 - - # Take 5 random seed1, seed2 values. - for seed in np.random.randint(0, 100, (5, 2)): - is_noise = data.random_spans_noise_mask(length, - noise_density, - mean_noise_span_length, - seed1=seed[0], - seed2=seed[1]) - is_noise = is_noise.astype(np.int32) - # noise_density fraction of tokens are produced - self.assertEqual(np.sum(is_noise), noise_density * length) - # Get span lengths and make sure the average is what we expect. 
- actual_span_lengths = self._get_span_lengths(is_noise) - average_span_length = ( - sum(actual_span_lengths) / len(actual_span_lengths)) - self.assertEqual(mean_noise_span_length, average_span_length) - - def test_process_c4_with_span_corruption(self): - def process_c4_with_span_corruption(spm_path=None, - extra_ids=0, - train=False, - max_length=100, - noise_density=0.15, - mean_noise_span_length=3.0, - seed1=None, - seed2=None): - return data.Serial( - data.TFDS( - 'c4/en:2.3.0', data_dir=_TESTDATA, keys=('text',), train=train), - data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids), - data.generate_sequential_chunks(max_length=max_length), - data.generate_random_noise_mask( - noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length, - seed1=seed1, seed2=seed2), - data.consume_noise_mask(vocab_size=32000 + extra_ids), - data.FilterEmptyExamples(), - data.AppendValue(val={0: [1], 1: [1]}), - data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}), - data.AddLossWeights(id_to_mask=0), - data.Batch(batch_size=2) - ) - - gen = process_c4_with_span_corruption( - spm_path=_spm_path(), seed1=0, seed2=1) - - examples = [] - for i, ex in enumerate(gen()): - if i == 100: - break - examples.append(ex) - - self.assertLen(examples, 100) - example = examples[0] - - batched_input, batched_output, batched_loss_weights = example - - self.assertSequenceEqual( - batched_input.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[ 37, 2335, 113, 3977, 227, 7306, 45, 3, 9, - 4716, 147, 8, 71, 2658, 65, 118, 4313, 38, - 3, 9, 13065, 32, 31999, 9, 5704, 26, 109, - 6, 6862, 6, 4728, 45, 8, 3796, 24093, 11834, - 4716, 30, 8, 1379, 13, 31998, 130, 718, 12, - 8, 24124, 1343, 300, 4357, 1714, 31997, 1373, 47, - 16487, 3168, 16, 321, 7943, 5, 3, 4868, 3856, - 5700, 75, 7, 200, 2231, 6, 11163, 9, 6, - 113, 47, 5330, 45, 14354, 6, 47, 31996, 20721, - 3654, 44, 8, 3112, 5, 14599, 11, 8067, 31995, - 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0], - [ 
277, 828, 43, 5899, 46, 16, 10952, 139, 160, - 1687, 56, 539, 30, 2875, 41, 31122, 2307, 137, - 2702, 2780, 15, 7, 31999, 44, 8, 3112, 11, - 30, 569, 783, 5, 3, 17701, 6, 2194, 26, - 23, 1336, 6321, 1694, 30, 31998, 196, 56, 1852, - 1423, 25, 5, 27, 183, 8032, 31997, 217, 149, - 1513, 11, 2238, 25, 1800, 5, 96, 2703, 44, - 3065, 12537, 11163, 9, 535, 71, 9363, 14886, 646, - 44, 8, 3112, 243, 23281, 12, 8, 31996, 346, - 402, 17, 99, 83, 11, 773, 3668, 1280, 31995, - 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - self.assertSequenceEqual( - batched_output.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[31999, 1639, 7, 15480, 5, 11163, 31998, 2083, 9997, - 5076, 31997, 265, 11, 8, 31996, 3, 31995, 1343, - 2487, 106, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0], - [31999, 12, 8, 15480, 130, 646, 31998, 1376, 10, - 96, 31997, 62, 410, 59, 31996, 96, 31995, 94, - 608, 10, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - self.assertSequenceEqual( - batched_loss_weights.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], - [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - def test_prefix_lm_last_output_batch_is_short(self): - prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7, 8]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) - self.assertSequenceEqual(([6, 7], [8]), examples[1]) - self.assertLen(examples, 2) - - def test_prefix_lm_last_input_batch_is_short(self): - prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), 
examples[0]) - self.assertLen(examples, 1) - - def test_prefix_lm_last_input_batch_exists_but_no_output(self): - prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) - self.assertLen(examples, 1) - - def test_unbatch(self): - unbatch_fn = data.UnBatch() - batched_inputs = [ - # First batch - 3 examples - (np.arange(3*2).reshape(3, -1), - np.arange(3*3).reshape(3, -1), - np.arange(3*4).reshape(3, -1)), - # Second batch - 4 examples - (np.arange(4*2).reshape(4, -1), - np.arange(4*3).reshape(4, -1), - np.arange(4*4).reshape(4, -1)), - ] - examples = list(unbatch_fn(batched_inputs)) - self.assertLen(examples, 3 + 4) - - def test_sine_shape(self): - inputs = data.sine_inputs(batch_size=3, length=5) - train_batch = next(inputs.train_stream(n_devices=1)) - eval_batch = next(inputs.eval_stream(n_devices=1)) - # (observations, actions, observations, mask) - self.assertLen(train_batch, 4) - self.assertLen(eval_batch, 4) - for (x, y) in zip(train_batch, eval_batch): - self.assertEqual(x.shape, (3, 5)) - self.assertEqual(y.shape, (3, 5)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json b/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json deleted file mode 100644 index 3298dab0a..000000000 --- a/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", - "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed 
by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", - "location": { - "urls": [ - "https://rajpurkar.github.io/SQuAD-explorer/" - ] - }, - "name": "squad", - "schema": { - "feature": [ - { - "name": "answers" - }, - { - "name": "context", - "type": "BYTES" - }, - { - "name": "id", - "type": "BYTES" - }, - { - "name": "question", - "type": "BYTES" - }, - { - "name": "title", - "type": "BYTES" - } - ] - }, - "sizeInBytes": "35142551", - "splits": [ - { - "name": "train", - "numShards": "1", - "shardLengths": [ - "10" - ] - }, - { - "name": "validation", - "numShards": "1", - "shardLengths": [ - "10" - ] - } - ], - "version": "3.0.0" -} diff --git a/trax/data/text_encoder.py b/trax/data/text_encoder.py index 245d9f312..afa610439 100644 --- a/trax/data/text_encoder.py +++ b/trax/data/text_encoder.py @@ -21,7 +21,6 @@ * SubwordTextEncoder: invertible * BertEncoder: for compatible tokenizers with original bert """ - import collections import itertools import math @@ -50,1289 +49,1325 @@ # '\\' is converted to '\' # '\213;' is converted to unichr(213) _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") -_ESCAPE_CHARS = set(u"\\_u;0123456789") +_ESCAPE_CHARS = set("\\_u;0123456789") # Unicode utility functions that work with Python 2 and 3 def native_to_unicode(s): - if is_unicode(s): - return s - try: - return to_unicode(s) - except UnicodeDecodeError: - res = to_unicode(s, ignore_errors=True) - logging.info("Ignoring Unicode error, outputting: %s", res) - return res + if is_unicode(s): + return s + try: + return to_unicode(s) + except UnicodeDecodeError: + res = to_unicode(s, ignore_errors=True) + logging.info("Ignoring Unicode error, outputting: %s", res) + return res def is_unicode(s): - return isinstance(s, six.text_type) + return isinstance(s, six.text_type) def to_unicode(s, ignore_errors=False): - if is_unicode(s): - 
return s - error_mode = "ignore" if ignore_errors else "strict" - return s.decode("utf-8", errors=error_mode) + if is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) def to_unicode_ignore_errors(s): - return to_unicode(s, ignore_errors=True) + return to_unicode(s, ignore_errors=True) def to_unicode_utf8(s): - return s.decode("utf-8") + return s.decode("utf-8") def strip_ids(ids, ids_to_strip): - """Strip ids_to_strip from the end IDs.""" - ids = list(ids) - while ids and ids[-1] in ids_to_strip: - ids.pop() - return ids + """Strip ids_to_strip from the end IDs.""" + ids = list(ids) + while ids and ids[-1] in ids_to_strip: + ids.pop() + return ids class TextEncoder: - """Base class for converting from ints to/from human readable strings.""" + """Base class for converting from ints to/from human readable strings.""" - def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): - self._num_reserved_ids = num_reserved_ids + def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): + self._num_reserved_ids = num_reserved_ids - @property - def num_reserved_ids(self): - return self._num_reserved_ids + @property + def num_reserved_ids(self): + return self._num_reserved_ids - def encode(self, s): - """Transform a human-readable string into a sequence of int IDs. + def encode(self, s): + """Transform a human-readable string into a sequence of int IDs. - The IDs should be in the range [num_reserved_ids, vocab_size). IDs [0, - num_reserved_ids) are reserved. + The IDs should be in the range [num_reserved_ids, vocab_size). IDs [0, + num_reserved_ids) are reserved. - EOS is not appended. + EOS is not appended. - Args: - s: human-readable string to be converted. + Args: + s: human-readable string to be converted. 
- Returns: - ids: list of integers - """ - return [int(w) + self._num_reserved_ids for w in s.split()] + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] - def decode(self, ids, strip_extraneous=False): - """Transform a sequence of int IDs into a human-readable string. + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int IDs into a human-readable string. - EOS is not expected in IDs. + EOS is not expected in IDs. - Args: - ids: list of integers to be converted. - strip_extraneous: bool, whether to strip off extraneous tokens (EOS and - PAD). + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens (EOS and + PAD). - Returns: - s: human-readable string. - """ - if strip_extraneous: - ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) - return " ".join(self.decode_list(ids)) + Returns: + s: human-readable string. + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) - def decode_list(self, ids): - """Transform a sequence of int IDs into a their string versions. + def decode_list(self, ids): + """Transform a sequence of int IDs into a their string versions. - This method supports transforming individual input/output IDs to their - string versions so that sequence to/from text conversions can be visualized - in a human readable format. + This method supports transforming individual input/output IDs to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. - Args: - ids: list of integers to be converted. + Args: + ids: list of integers to be converted. - Returns: - strs: list of human-readable string. 
- """ - decoded_ids = [] - for id_ in ids: - if 0 <= id_ < self._num_reserved_ids: - decoded_ids.append(RESERVED_TOKENS[int(id_)]) - else: - decoded_ids.append(id_ - self._num_reserved_ids) - return [str(d) for d in decoded_ids] + Returns: + strs: list of human-readable string. + """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] - @property - def vocab_size(self): - raise NotImplementedError() + @property + def vocab_size(self): + raise NotImplementedError() class ByteTextEncoder(TextEncoder): - """Encodes each byte to an id. For 8-bit strings only.""" - - def encode(self, s): - numres = self._num_reserved_ids - # Python3: explicitly convert to UTF-8 - return [c + numres for c in s.encode("utf-8")] - - def decode(self, ids, strip_extraneous=False): - if strip_extraneous: - ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) - numres = self._num_reserved_ids - decoded_ids = [] - int2byte = six.int2byte - for id_ in ids: - if 0 <= id_ < numres: - decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) - else: - decoded_ids.append(int2byte(id_ - numres)) - # Python3: join byte arrays and then decode string - return b"".join(decoded_ids).decode("utf-8", "replace") - - def decode_list(self, ids): - numres = self._num_reserved_ids - decoded_ids = [] - int2byte = six.int2byte - for id_ in ids: - if 0 <= id_ < numres: - decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) - else: - decoded_ids.append(int2byte(id_ - numres)) - # Python3: join byte arrays and then decode string - return decoded_ids - - @property - def vocab_size(self): - return 2**8 + self._num_reserved_ids + """Encodes each byte to an id. 
For 8-bit strings only.""" + + def encode(self, s): + numres = self._num_reserved_ids + # Python3: explicitly convert to UTF-8 + return [c + numres for c in s.encode("utf-8")] + + def decode(self, ids, strip_extraneous=False): + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return b"".join(decoded_ids).decode("utf-8", "replace") + + def decode_list(self, ids): + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return decoded_ids + + @property + def vocab_size(self): + return 2**8 + self._num_reserved_ids class ClassLabelEncoder(TextEncoder): - """Encoder for class labels.""" + """Encoder for class labels.""" - def __init__(self, class_labels=None, class_labels_fname=None): - super(ClassLabelEncoder, self).__init__(num_reserved_ids=0) + def __init__(self, class_labels=None, class_labels_fname=None): + super(ClassLabelEncoder, self).__init__(num_reserved_ids=0) - if class_labels_fname: - with tf.io.gfile.GFile(class_labels_fname) as f: - class_labels = [label.strip() for label in f.readlines()] + if class_labels_fname: + with tf.io.gfile.GFile(class_labels_fname) as f: + class_labels = [label.strip() for label in f.readlines()] - assert class_labels - self._class_labels = class_labels + assert class_labels + self._class_labels = class_labels - def encode(self, s): - label_str = s - return self._class_labels.index(label_str) + def encode(self, s): + label_str = s + return self._class_labels.index(label_str) - def 
decode(self, ids, strip_extraneous=False): - del strip_extraneous - label_id = ids - if isinstance(label_id, list): - assert len(label_id) == 1 - label_id, = label_id - if isinstance(label_id, np.ndarray): - label_id = np.squeeze(label_id) - return self._class_labels[label_id] + def decode(self, ids, strip_extraneous=False): + del strip_extraneous + label_id = ids + if isinstance(label_id, list): + assert len(label_id) == 1 + (label_id,) = label_id + if isinstance(label_id, np.ndarray): + label_id = np.squeeze(label_id) + return self._class_labels[label_id] - def decode_list(self, ids): - return [self._class_labels[i] for i in ids] + def decode_list(self, ids): + return [self._class_labels[i] for i in ids] - @property - def vocab_size(self): - return len(self._class_labels) + @property + def vocab_size(self): + return len(self._class_labels) class OneHotClassLabelEncoder(ClassLabelEncoder): - """One-hot encoder for class labels.""" + """One-hot encoder for class labels.""" - def encode(self, label_str, on_value=1, off_value=0): # pylint: disable=arguments-differ - e = np.full(self.vocab_size, off_value, dtype=np.int32) - e[self._class_labels.index(label_str)] = on_value - return e.tolist() + def encode( + self, label_str, on_value=1, off_value=0 + ): # pylint: disable=arguments-differ + e = np.full(self.vocab_size, off_value, dtype=np.int32) + e[self._class_labels.index(label_str)] = on_value + return e.tolist() - def decode(self, ids, strip_extraneous=False): - del strip_extraneous - label_id = ids - if isinstance(label_id, np.ndarray): - label_id = np.squeeze(label_id).astype(np.int8).tolist() - assert isinstance(label_id, list) - assert len(label_id) == self.vocab_size - return self._class_labels[label_id.index(1)] + def decode(self, ids, strip_extraneous=False): + del strip_extraneous + label_id = ids + if isinstance(label_id, np.ndarray): + label_id = np.squeeze(label_id).astype(np.int8).tolist() + assert isinstance(label_id, list) + assert len(label_id) == 
self.vocab_size + return self._class_labels[label_id.index(1)] - @property - def vocab_size(self): - return len(self._class_labels) + @property + def vocab_size(self): + return len(self._class_labels) class TokenTextEncoder(TextEncoder): - """Encoder based on a user-supplied vocabulary (file or list).""" - - def __init__(self, - vocab_filename, - reverse=False, - vocab_list=None, - replace_oov=None, - num_reserved_ids=NUM_RESERVED_TOKENS): - """Initialize from a file or list, one token per line. - - Handling of reserved tokens works as follows: - - When initializing from a list, we add reserved tokens to the vocab. - - When initializing from a file, we do not add reserved tokens to the vocab. - - When saving vocab files, we save reserved tokens to the file. - - Args: - vocab_filename: If not None, the full filename to read vocab from. If this - is not None, then vocab_list should be None. - reverse: Boolean indicating if tokens should be reversed during encoding - and decoding. - vocab_list: If not None, a list of elements of the vocabulary. If this is - not None, then vocab_filename should be None. - replace_oov: If not None, every out-of-vocabulary token seen when encoding - will be replaced by this string (which must be in vocab). - num_reserved_ids: Number of IDs to save for reserved tokens like . 
- """ - super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) - self._reverse = reverse - self._replace_oov = replace_oov - if vocab_filename: - self._init_vocab_from_file(vocab_filename) - else: - assert vocab_list is not None - self._init_vocab_from_list(vocab_list) - - def encode(self, s): - """Converts a space-separated string of tokens to a list of ids.""" - sentence = s - tokens = sentence.strip().split() - if self._replace_oov is not None: - tokens = [ - t if t in self._token_to_id else self._replace_oov for t in tokens - ] - ret = [self._token_to_id[tok] for tok in tokens] - return ret[::-1] if self._reverse else ret - - def decode(self, ids, strip_extraneous=False): - return " ".join(self.decode_list(ids)) - - def decode_list(self, ids): - seq = reversed(ids) if self._reverse else ids - return [self._safe_id_to_token(i) for i in seq] - - @property - def vocab_size(self): - return len(self._id_to_token) - - def _safe_id_to_token(self, idx): - return self._id_to_token.get(idx, "ID_%d" % idx) - - def _init_vocab_from_file(self, filename): - """Load vocab from a file. - - Args: - filename: The file to load vocabulary from. - """ - with tf.io.gfile.GFile(filename) as f: - tokens = [token.strip() for token in f.readlines()] - - def token_gen(): - for token in tokens: - yield token - - self._init_vocab(token_gen(), add_reserved_tokens=False) - - def _init_vocab_from_list(self, vocab_list): - """Initialize tokens from a list of tokens. - - It is ok if reserved tokens appear in the vocab list. They will be - removed. The set of tokens in vocab_list should be unique. - - Args: - vocab_list: A list of tokens. 
- """ - - def token_gen(): - for token in vocab_list: - if token not in RESERVED_TOKENS: - yield token - - self._init_vocab(token_gen()) - - def _init_vocab(self, token_generator, add_reserved_tokens=True): - """Initialize vocabulary with tokens from token_generator.""" - - self._id_to_token = {} - non_reserved_start_index = 0 - - if add_reserved_tokens: - self._id_to_token.update(enumerate(RESERVED_TOKENS)) - non_reserved_start_index = len(RESERVED_TOKENS) - - self._id_to_token.update( - enumerate(token_generator, start=non_reserved_start_index)) - - # _token_to_id is the reverse of _id_to_token - self._token_to_id = dict( - (v, k) for k, v in six.iteritems(self._id_to_token)) - - def store_to_file(self, filename): - """Write vocab file to disk. - - Vocab files have one token per line. The file ends in a newline. Reserved - tokens are written to the vocab file as well. - - Args: - filename: Full path of the file to store the vocab to. - """ - with tf.io.gfile.GFile(filename, "w") as f: - for i in range(len(self._id_to_token)): - f.write(self._id_to_token[i] + "\n") - - -def _escape_token(token, alphabet): - """Escape away underscores and OOV characters and append '_'. - - This allows the token to be expressed as the concatenation of a list - of subtokens from the vocabulary. The underscore acts as a sentinel - which allows us to invertibly concatenate multiple such lists. + """Encoder based on a user-supplied vocabulary (file or list).""" + + def __init__( + self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS, + ): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. 
+ + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when encoding + will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ + super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) - Args: - token: A unicode string to be escaped. - alphabet: A set of all characters in the vocabulary's alphabet. + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + sentence = s + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [ + t if t in self._token_to_id else self._replace_oov for t in tokens + ] + ret = [self._token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret - Returns: - escaped_token: An escaped unicode string. + def decode(self, ids, strip_extraneous=False): + return " ".join(self.decode_list(ids)) - Raises: - ValueError: If the provided token is not unicode. 
- """ - if not isinstance(token, six.text_type): - raise ValueError("Expected string type for token, got %s" % type(token)) + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] - token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") - ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token] - return u"".join(ret) + "_" + @property + def vocab_size(self): + return len(self._id_to_token) + def _safe_id_to_token(self, idx): + return self._id_to_token.get(idx, "ID_%d" % idx) -def _unescape_token(escaped_token): - """Inverse of _escape_token(). + def _init_vocab_from_file(self, filename): + """Load vocab from a file. - Args: - escaped_token: a unicode string + Args: + filename: The file to load vocabulary from. + """ + with tf.io.gfile.GFile(filename) as f: + tokens = [token.strip() for token in f.readlines()] - Returns: - token: a unicode string - """ + def token_gen(): + for token in tokens: + yield token - def match(m): - if m.group(1) is None: - return u"_" if m.group(0) == u"\\u" else u"\\" + self._init_vocab(token_gen(), add_reserved_tokens=False) - try: - return six.unichr(int(m.group(1))) - except (ValueError, OverflowError) as _: - return u"\u3013" # Unicode for undefined character. + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. - trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token - return _UNESCAPE_REGEX.sub(match, trimmed) + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + Args: + vocab_list: A list of tokens. + """ -class SubwordTextEncoder(TextEncoder): - """Class for invertibly encoding text using a limited vocabulary. + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token - Invertibly encodes a native string as a sequence of subtokens from a limited - vocabulary. 
+ self._init_vocab(token_gen()) - A SubwordTextEncoder is built from a corpus (so it is tailored to the text in - the corpus), and stored to a file. See text_encoder_build_subword.py. + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" - It can then be loaded and used to encode/decode any text. + self._id_to_token = {} + non_reserved_start_index = 0 - Encoding has four phases: + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) - 1. Tokenize into a list of tokens. Each token is a unicode string of either - all alphanumeric characters or all non-alphanumeric characters. We drop - tokens consisting of a single space that are between two alphanumeric - tokens. + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index) + ) - 2. Escape each token. This escapes away special and out-of-vocabulary - characters, and makes sure that each token ends with an underscore, and - has no other underscores. + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict((v, k) for k, v in six.iteritems(self._id_to_token)) - 3. Represent each escaped token as a the concatenation of a list of subtokens - from the limited vocabulary. Subtoken selection is done greedily from - beginning to end. That is, we construct the list in order, always picking - the longest subtoken in our vocabulary that matches a prefix of the - remaining portion of the encoded token. + def store_to_file(self, filename): + """Write vocab file to disk. - 4. Concatenate these lists. This concatenation is invertible due to the - fact that the trailing underscores indicate when one list is finished. + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. - """ + Args: + filename: Full path of the file to store the vocab to. 
+ """ + with tf.io.gfile.GFile(filename, "w") as f: + for i in range(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") - def __init__(self, filename=None): - """Initialize and read from a file, if provided. - Args: - filename: filename from which to read vocab. If None, do not load a vocab - """ - self._alphabet = set() - self.filename = filename - if filename is not None: - self._load_from_file(filename) - super(SubwordTextEncoder, self).__init__() - - def encode(self, s): - """Converts a native string to a list of subtoken IDs. - - Args: - s: a native string. - - Returns: - a list of integers in the range [0, vocab_size) - """ - return self._tokens_to_subtoken_ids(tokenizer.encode(native_to_unicode(s))) - - def encode_without_tokenizing(self, token_text): - """Converts string to list of subtoken IDs without calling tokenizer. +def _escape_token(token, alphabet): + """Escape away underscores and OOV characters and append '_'. - This treats `token_text` as a single token and directly converts it - to subtoken IDs. This may be useful when the default tokenizer doesn't - do what we want (e.g., when encoding text with tokens composed of lots of - nonalphanumeric characters). It is then up to the caller to make sure that - raw text is consistently converted into tokens. Only use this if you are - sure that `encode` doesn't suit your needs. + This allows the token to be expressed as the concatenation of a list + of subtokens from the vocabulary. The underscore acts as a sentinel + which allows us to invertibly concatenate multiple such lists. Args: - token_text: A native string representation of a single token. + token: A unicode string to be escaped. + alphabet: A set of all characters in the vocabulary's alphabet. Returns: - A list of subword token IDs; i.e., integers in the range [0, vocab_size). - """ - return self._tokens_to_subtoken_ids([native_to_unicode(token_text)]) + escaped_token: An escaped unicode string. 
- def decode(self, ids, strip_extraneous=False): - """Converts a sequence of subtoken IDs to a native string. - - Args: - ids: a list of integers in the range [0, vocab_size) - strip_extraneous: bool, whether to strip off extraneous tokens (EOS and - PAD). - - Returns: - a native string + Raises: + ValueError: If the provided token is not unicode. """ - if strip_extraneous: - ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) - return tokenizer.decode(self._subtoken_ids_to_tokens(ids)) + if not isinstance(token, six.text_type): + raise ValueError("Expected string type for token, got %s" % type(token)) - def decode_list(self, ids): - return [self._subtoken_id_to_subtoken_string(s) for s in ids] + token = token.replace("\\", "\\\\").replace("_", "\\u") + ret = [c if c in alphabet and c != "\n" else r"\%d;" % ord(c) for c in token] + return "".join(ret) + "_" - @property - def vocab_size(self): - """The subtoken vocabulary size.""" - return len(self._all_subtoken_strings) - def _tokens_to_subtoken_ids(self, tokens): - """Converts a list of tokens to a list of subtoken IDs. +def _unescape_token(escaped_token): + """Inverse of _escape_token(). Args: - tokens: a list of strings. + escaped_token: a unicode string Returns: - a list of integers in the range [0, vocab_size) + token: a unicode string """ - ret = [] - for token in tokens: - ret.extend(self._token_to_subtoken_ids(token)) - return ret - def _token_to_subtoken_ids(self, token): - """Converts token to a list of subtoken IDs. + def match(m): + if m.group(1) is None: + return "_" if m.group(0) == "\\u" else "\\" - Args: - token: a string. + try: + return six.unichr(int(m.group(1))) + except (ValueError, OverflowError) as _: + return "\u3013" # Unicode for undefined character. 
- Returns: - a list of integers in the range [0, vocab_size) - """ - cache_location = hash(token) % self._cache_size - cache_key, cache_value = self._cache[cache_location] - if cache_key == token: - return cache_value - ret = self._escaped_token_to_subtoken_ids( - _escape_token(token, self._alphabet)) - self._cache[cache_location] = (token, ret) - return ret - - def _subtoken_ids_to_tokens(self, subtokens): - """Converts a list of subtoken IDs to a list of tokens. + trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token + return _UNESCAPE_REGEX.sub(match, trimmed) - Args: - subtokens: a list of integers in the range [0, vocab_size) - Returns: - a list of strings. - """ - concatenated = "".join( - [self._subtoken_id_to_subtoken_string(s) for s in subtokens]) - split = concatenated.split("_") - ret = [] - for t in split: - if t: - unescaped = _unescape_token(t + "_") - if unescaped: - ret.append(unescaped) - return ret - - def _subtoken_id_to_subtoken_string(self, subtoken): - """Converts a subtoken integer ID to a subtoken string.""" - if 0 <= subtoken < self.vocab_size: - return self._all_subtoken_strings[subtoken] - return u"" - - def _escaped_token_to_subtoken_strings(self, escaped_token): - """Converts an escaped token string to a list of subtoken strings. +class SubwordTextEncoder(TextEncoder): + """Class for invertibly encoding text using a limited vocabulary. - Args: - escaped_token: An escaped token as a unicode string. + Invertibly encodes a native string as a sequence of subtokens from a limited + vocabulary. - Returns: - A list of subtokens as unicode strings. - """ - # NOTE: This algorithm is greedy; it won't necessarily produce the "best" - # list of subtokens. 
- ret = [] - start = 0 - token_len = len(escaped_token) - while start < token_len: - for end in range( - min(token_len, start + self._max_subtoken_len), start, -1): - subtoken = escaped_token[start:end] - if subtoken in self._subtoken_string_to_id: - ret.append(subtoken) - start = end - break - - else: # Did not break - # If there is no possible encoding of the escaped token then one of the - # characters in the token is not in the alphabet. This should be - # impossible and would be indicative of a bug. - assert False, "Token substring not found in subtoken vocabulary." - - return ret - - def _escaped_token_to_subtoken_ids(self, escaped_token): - """Converts an escaped token string to a list of subtoken IDs. + A SubwordTextEncoder is built from a corpus (so it is tailored to the text in + the corpus), and stored to a file. See text_encoder_build_subword.py. - Args: - escaped_token: An escaped token as a unicode string. + It can then be loaded and used to encode/decode any text. - Returns: - A list of subtoken IDs as integers. - """ - return [ - self._subtoken_string_to_id[subtoken] - for subtoken in self._escaped_token_to_subtoken_strings(escaped_token) - ] - - @classmethod - def build_from_generator(cls, - generator, - target_size, - max_subtoken_length=None, - reserved_tokens=None): - """Builds a SubwordTextEncoder from the generated text. + Encoding has four phases: - Args: - generator: yields text. - target_size: int, approximate vocabulary size to create. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. + 1. Tokenize into a list of tokens. 
Each token is a unicode string of either + all alphanumeric characters or all non-alphanumeric characters. We drop + tokens consisting of a single space that are between two alphanumeric + tokens. - Returns: - SubwordTextEncoder with `vocab_size` approximately `target_size`. - """ - token_counts = collections.defaultdict(int) - for item in generator: - for tok in tokenizer.encode(native_to_unicode(item)): - token_counts[tok] += 1 - encoder = cls.build_to_target_size( - target_size, - token_counts, - 1, - 1e3, - max_subtoken_length=max_subtoken_length, - reserved_tokens=reserved_tokens) - return encoder - - @classmethod - def build_to_target_size(cls, - target_size, - token_counts, - min_val, - max_val, - max_subtoken_length=None, - reserved_tokens=None, - num_iterations=4): - """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`. - - Uses simple recursive binary search to find a minimum token count that most - closely matches the `target_size`. + 2. Escape each token. This escapes away special and out-of-vocabulary + characters, and makes sure that each token ends with an underscore, and + has no other underscores. - Args: - target_size: Desired vocab_size to approximate. - token_counts: A dictionary of token counts, mapping string to int. - min_val: An integer; lower bound for the minimum token count. - max_val: An integer; upper bound for the minimum token count. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. - num_iterations: An integer; how many iterations of refinement. + 3. 
Represent each escaped token as a the concatenation of a list of subtokens + from the limited vocabulary. Subtoken selection is done greedily from + beginning to end. That is, we construct the list in order, always picking + the longest subtoken in our vocabulary that matches a prefix of the + remaining portion of the encoded token. - Returns: - A SubwordTextEncoder instance. + 4. Concatenate these lists. This concatenation is invertible due to the + fact that the trailing underscores indicate when one list is finished. - Raises: - ValueError: If `min_val` is greater than `max_val`. """ - if min_val > max_val: - raise ValueError("Lower bound for the minimum token count " - "is greater than the upper bound.") - if target_size < 1: - raise ValueError("Target size must be positive.") - - if reserved_tokens is None: - reserved_tokens = RESERVED_TOKENS - - def bisect(min_val, max_val): - """Bisection to find the right size.""" - present_count = (max_val + min_val) // 2 - logging.info("Trying min_count %d", present_count) - subtokenizer = cls() - subtokenizer.build_from_token_counts( - token_counts, - present_count, - num_iterations, - max_subtoken_length=max_subtoken_length, - reserved_tokens=reserved_tokens) - - # Being within 1% of the target size is ok. - is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size - # If min_val == max_val, we can't do any better than this. 
- if is_ok or min_val >= max_val or present_count < 2: - return subtokenizer - - if subtokenizer.vocab_size > target_size: - other_subtokenizer = bisect(present_count + 1, max_val) - else: - other_subtokenizer = bisect(min_val, present_count - 1) - - if other_subtokenizer is None: - return subtokenizer - - if (abs(other_subtokenizer.vocab_size - target_size) < - abs(subtokenizer.vocab_size - target_size)): - return other_subtokenizer - return subtokenizer - - return bisect(min_val, max_val) - - def build_from_token_counts(self, - token_counts, - min_count, - num_iterations=4, - reserved_tokens=None, - max_subtoken_length=None): - """Train a SubwordTextEncoder based on a dictionary of word counts. - - Args: - token_counts: a dictionary of Unicode strings to int. - min_count: an integer - discard subtokens with lower counts. - num_iterations: an integer. how many iterations of refinement. - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - Raises: - ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it - is not clear what the space is being reserved for, or when it will be - filled in. - """ - if reserved_tokens is None: - reserved_tokens = RESERVED_TOKENS - else: - # There is not complete freedom in replacing RESERVED_TOKENS. - for default, proposed in zip(RESERVED_TOKENS, reserved_tokens): - if default != proposed: - raise ValueError("RESERVED_TOKENS must be a prefix of " - "reserved_tokens.") - - # Initialize the alphabet. Note, this must include reserved tokens or it can - # result in encoding failures. 
- alphabet_tokens = itertools.chain( - six.iterkeys(token_counts), - [native_to_unicode(t) for t in reserved_tokens]) - - self._init_alphabet_from_tokens(alphabet_tokens) - - # Bootstrap the initial list of subtokens with the characters from the - # alphabet plus the escaping characters. - self._init_subtokens_from_list( - list(self._alphabet), reserved_tokens=reserved_tokens) - - # We build iteratively. On each iteration, we segment all the words, - # then count the resulting potential subtokens, keeping the ones - # with high enough counts for our new vocabulary. - if min_count < 1: - min_count = 1 - for i in range(num_iterations): - logging.info("Iteration %d", i) - - # Collect all substrings of the encoded token that break along current - # subtoken boundaries. - subtoken_counts = collections.defaultdict(int) - for token, count in six.iteritems(token_counts): - iter_start_time = time.time() - escaped_token = _escape_token(token, self._alphabet) - subtokens = self._escaped_token_to_subtoken_strings(escaped_token) + def __init__(self, filename=None): + """Initialize and read from a file, if provided. + + Args: + filename: filename from which to read vocab. If None, do not load a vocab + """ + self._alphabet = set() + self.filename = filename + if filename is not None: + self._load_from_file(filename) + super(SubwordTextEncoder, self).__init__() + + def encode(self, s): + """Converts a native string to a list of subtoken IDs. + + Args: + s: a native string. + + Returns: + a list of integers in the range [0, vocab_size) + """ + return self._tokens_to_subtoken_ids(tokenizer.encode(native_to_unicode(s))) + + def encode_without_tokenizing(self, token_text): + """Converts string to list of subtoken IDs without calling tokenizer. + + This treats `token_text` as a single token and directly converts it + to subtoken IDs. 
This may be useful when the default tokenizer doesn't + do what we want (e.g., when encoding text with tokens composed of lots of + nonalphanumeric characters). It is then up to the caller to make sure that + raw text is consistently converted into tokens. Only use this if you are + sure that `encode` doesn't suit your needs. + + Args: + token_text: A native string representation of a single token. + + Returns: + A list of subword token IDs; i.e., integers in the range [0, vocab_size). + """ + return self._tokens_to_subtoken_ids([native_to_unicode(token_text)]) + + def decode(self, ids, strip_extraneous=False): + """Converts a sequence of subtoken IDs to a native string. + + Args: + ids: a list of integers in the range [0, vocab_size) + strip_extraneous: bool, whether to strip off extraneous tokens (EOS and + PAD). + + Returns: + a native string + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return tokenizer.decode(self._subtoken_ids_to_tokens(ids)) + + def decode_list(self, ids): + return [self._subtoken_id_to_subtoken_string(s) for s in ids] + + @property + def vocab_size(self): + """The subtoken vocabulary size.""" + return len(self._all_subtoken_strings) + + def _tokens_to_subtoken_ids(self, tokens): + """Converts a list of tokens to a list of subtoken IDs. + + Args: + tokens: a list of strings. + + Returns: + a list of integers in the range [0, vocab_size) + """ + ret = [] + for token in tokens: + ret.extend(self._token_to_subtoken_ids(token)) + return ret + + def _token_to_subtoken_ids(self, token): + """Converts token to a list of subtoken IDs. + + Args: + token: a string. 
+ + Returns: + a list of integers in the range [0, vocab_size) + """ + cache_location = hash(token) % self._cache_size + cache_key, cache_value = self._cache[cache_location] + if cache_key == token: + return cache_value + ret = self._escaped_token_to_subtoken_ids(_escape_token(token, self._alphabet)) + self._cache[cache_location] = (token, ret) + return ret + + def _subtoken_ids_to_tokens(self, subtokens): + """Converts a list of subtoken IDs to a list of tokens. + + Args: + subtokens: a list of integers in the range [0, vocab_size) + + Returns: + a list of strings. + """ + concatenated = "".join( + [self._subtoken_id_to_subtoken_string(s) for s in subtokens] + ) + split = concatenated.split("_") + ret = [] + for t in split: + if t: + unescaped = _unescape_token(t + "_") + if unescaped: + ret.append(unescaped) + return ret + + def _subtoken_id_to_subtoken_string(self, subtoken): + """Converts a subtoken integer ID to a subtoken string.""" + if 0 <= subtoken < self.vocab_size: + return self._all_subtoken_strings[subtoken] + return "" + + def _escaped_token_to_subtoken_strings(self, escaped_token): + """Converts an escaped token string to a list of subtoken strings. + + Args: + escaped_token: An escaped token as a unicode string. + + Returns: + A list of subtokens as unicode strings. + """ + # NOTE: This algorithm is greedy; it won't necessarily produce the "best" + # list of subtokens. 
+ ret = [] start = 0 - for subtoken in subtokens: - last_position = len(escaped_token) + 1 - if max_subtoken_length is not None: - last_position = min(last_position, start + max_subtoken_length) - - for end in range(start + 1, last_position): - new_subtoken = escaped_token[start:end] - subtoken_counts[new_subtoken] += count - start += len(subtoken) - iter_time_secs = time.time() - iter_start_time - if iter_time_secs > 0.1: - logging.info( - "Processing token [%s] took {%d} seconds, consider " - "setting Text2TextProblem.max_subtoken_length to a " - "smaller value.", token, iter_time_secs) - - # Array of sets of candidate subtoken strings, by length. - len_to_subtoken_strings = [] - for subtoken_string, count in six.iteritems(subtoken_counts): - lsub = len(subtoken_string) - if count >= min_count: - while len(len_to_subtoken_strings) <= lsub: - len_to_subtoken_strings.append(set()) - len_to_subtoken_strings[lsub].add(subtoken_string) - - # Consider the candidates longest to shortest, so that if we accept - # a longer subtoken string, we can decrement the counts of its prefixes. - new_subtoken_strings = [] - for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1): - subtoken_strings = len_to_subtoken_strings[lsub] - for subtoken_string in subtoken_strings: - count = subtoken_counts[subtoken_string] - if count >= min_count: - # Exclude alphabet tokens here, as they must be included later, - # explicitly, regardless of count. - if subtoken_string not in self._alphabet: - new_subtoken_strings.append((count, subtoken_string)) - for l in range(1, lsub): - subtoken_counts[subtoken_string[:l]] -= count - - # Include the alphabet explicitly to guarantee all strings are encodable. - new_subtoken_strings.extend( - (subtoken_counts.get(a, 0), a) for a in self._alphabet) - new_subtoken_strings.sort(reverse=True) - - # Reinitialize to the candidate vocabulary. 
- new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings] - if reserved_tokens: - escaped_reserved_tokens = [ - _escape_token(native_to_unicode(t), self._alphabet) - for t in reserved_tokens + token_len = len(escaped_token) + while start < token_len: + for end in range(min(token_len, start + self._max_subtoken_len), start, -1): + subtoken = escaped_token[start:end] + if subtoken in self._subtoken_string_to_id: + ret.append(subtoken) + start = end + break + + else: # Did not break + # If there is no possible encoding of the escaped token then one of the + # characters in the token is not in the alphabet. This should be + # impossible and would be indicative of a bug. + assert False, "Token substring not found in subtoken vocabulary." + + return ret + + def _escaped_token_to_subtoken_ids(self, escaped_token): + """Converts an escaped token string to a list of subtoken IDs. + + Args: + escaped_token: An escaped token as a unicode string. + + Returns: + A list of subtoken IDs as integers. + """ + return [ + self._subtoken_string_to_id[subtoken] + for subtoken in self._escaped_token_to_subtoken_strings(escaped_token) ] - new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings - - self._init_subtokens_from_list(new_subtoken_strings) - logging.info("vocab_size = %d", self.vocab_size) - - @property - def all_subtoken_strings(self): - return tuple(self._all_subtoken_strings) - - def dump(self): - """Debugging dump of the current subtoken vocabulary.""" - subtoken_strings = [ - (i, s) for s, i in six.iteritems(self._subtoken_string_to_id) - ] - print(u", ".join( - u"{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings))) - - def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None): - """Initialize token information from a list of subtoken strings. - - Args: - subtoken_strings: a list of subtokens - reserved_tokens: List of reserved tokens. 
We must have `reserved_tokens` - as None or the empty list, or else the global variable `RESERVED_TOKENS` - must be a prefix of `reserved_tokens`. - - Raises: - ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it - is not clear what the space is being reserved for, or when it will be - filled in. - """ - if reserved_tokens is None: - reserved_tokens = [] - - if reserved_tokens: - self._all_subtoken_strings = reserved_tokens + subtoken_strings - else: - self._all_subtoken_strings = subtoken_strings - - # we remember the maximum length of any subtoken to avoid having to - # check arbitrarily long strings. - self._max_subtoken_len = max([len(s) for s in subtoken_strings]) - self._subtoken_string_to_id = { - s: i + len(reserved_tokens) for i, s in enumerate(subtoken_strings) if s - } - # Initialize the cache to empty. - self._cache_size = 2**20 - self._cache = [(None, None)] * self._cache_size - - def _init_alphabet_from_tokens(self, tokens): - """Initialize alphabet from an iterable of token or subtoken strings.""" - # Include all characters from all tokens in the alphabet to guarantee that - # any token can be encoded. Additionally, include all escaping characters. - self._alphabet = {c for token in tokens for c in token} # pylint: disable=g-complex-comprehension - self._alphabet |= _ESCAPE_CHARS - - def _load_from_file_object(self, f): - """Load from a file object. 
- Args: - f: File object to load vocabulary from - """ - subtoken_strings = [] - for line in f: - s = line.rstrip() - # Some vocab files wrap words in single quotes, but others don't - if ((s.startswith("'") and s.endswith("'")) or - (s.startswith("\"") and s.endswith("\""))): - s = s[1:-1] - subtoken_strings.append(native_to_unicode(s)) - self._init_subtokens_from_list(subtoken_strings) - self._init_alphabet_from_tokens(subtoken_strings) - - def _load_from_file(self, filename): - """Load from a vocab file.""" - if not tf.io.gfile.exists(filename): - raise ValueError("File %s not found" % filename) - with tf.io.gfile.GFile(filename) as f: - self._load_from_file_object(f) - - def store_to_file(self, filename, add_single_quotes=True): - with tf.io.gfile.GFile(filename, "w") as f: - for subtoken_string in self._all_subtoken_strings: - if add_single_quotes: - f.write("'" + subtoken_string + "'\n") + @classmethod + def build_from_generator( + cls, generator, target_size, max_subtoken_length=None, reserved_tokens=None + ): + """Builds a SubwordTextEncoder from the generated text. + + Args: + generator: yields text. + target_size: int, approximate vocabulary size to create. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. If this is set, then it is instead + O(max_subtoken_length * length of longest token). + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + + Returns: + SubwordTextEncoder with `vocab_size` approximately `target_size`. 
+ """ + token_counts = collections.defaultdict(int) + for item in generator: + for tok in tokenizer.encode(native_to_unicode(item)): + token_counts[tok] += 1 + encoder = cls.build_to_target_size( + target_size, + token_counts, + 1, + 1e3, + max_subtoken_length=max_subtoken_length, + reserved_tokens=reserved_tokens, + ) + return encoder + + @classmethod + def build_to_target_size( + cls, + target_size, + token_counts, + min_val, + max_val, + max_subtoken_length=None, + reserved_tokens=None, + num_iterations=4, + ): + """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`. + + Uses simple recursive binary search to find a minimum token count that most + closely matches the `target_size`. + + Args: + target_size: Desired vocab_size to approximate. + token_counts: A dictionary of token counts, mapping string to int. + min_val: An integer; lower bound for the minimum token count. + max_val: An integer; upper bound for the minimum token count. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. If this is set, then it is instead + O(max_subtoken_length * length of longest token). + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + num_iterations: An integer; how many iterations of refinement. + + Returns: + A SubwordTextEncoder instance. + + Raises: + ValueError: If `min_val` is greater than `max_val`. + """ + if min_val > max_val: + raise ValueError( + "Lower bound for the minimum token count " + "is greater than the upper bound." 
+ ) + if target_size < 1: + raise ValueError("Target size must be positive.") + + if reserved_tokens is None: + reserved_tokens = RESERVED_TOKENS + + def bisect(min_val, max_val): + """Bisection to find the right size.""" + present_count = (max_val + min_val) // 2 + logging.info("Trying min_count %d", present_count) + subtokenizer = cls() + subtokenizer.build_from_token_counts( + token_counts, + present_count, + num_iterations, + max_subtoken_length=max_subtoken_length, + reserved_tokens=reserved_tokens, + ) + + # Being within 1% of the target size is ok. + is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size + # If min_val == max_val, we can't do any better than this. + if is_ok or min_val >= max_val or present_count < 2: + return subtokenizer + + if subtokenizer.vocab_size > target_size: + other_subtokenizer = bisect(present_count + 1, max_val) + else: + other_subtokenizer = bisect(min_val, present_count - 1) + + if other_subtokenizer is None: + return subtokenizer + + if abs(other_subtokenizer.vocab_size - target_size) < abs( + subtokenizer.vocab_size - target_size + ): + return other_subtokenizer + return subtokenizer + + return bisect(min_val, max_val) + + def build_from_token_counts( + self, + token_counts, + min_count, + num_iterations=4, + reserved_tokens=None, + max_subtoken_length=None, + ): + """Train a SubwordTextEncoder based on a dictionary of word counts. + + Args: + token_counts: a dictionary of Unicode strings to int. + min_count: an integer - discard subtokens with lower counts. + num_iterations: an integer. how many iterations of refinement. + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. 
If this is set, then it is instead + O(max_subtoken_length * length of longest token). + + Raises: + ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it + is not clear what the space is being reserved for, or when it will be + filled in. + """ + if reserved_tokens is None: + reserved_tokens = RESERVED_TOKENS else: - f.write(subtoken_string + "\n") + # There is not complete freedom in replacing RESERVED_TOKENS. + for default, proposed in zip(RESERVED_TOKENS, reserved_tokens): + if default != proposed: + raise ValueError( + "RESERVED_TOKENS must be a prefix of " "reserved_tokens." + ) + + # Initialize the alphabet. Note, this must include reserved tokens or it can + # result in encoding failures. + alphabet_tokens = itertools.chain( + six.iterkeys(token_counts), [native_to_unicode(t) for t in reserved_tokens] + ) + + self._init_alphabet_from_tokens(alphabet_tokens) + + # Bootstrap the initial list of subtokens with the characters from the + # alphabet plus the escaping characters. + self._init_subtokens_from_list( + list(self._alphabet), reserved_tokens=reserved_tokens + ) + + # We build iteratively. On each iteration, we segment all the words, + # then count the resulting potential subtokens, keeping the ones + # with high enough counts for our new vocabulary. + if min_count < 1: + min_count = 1 + for i in range(num_iterations): + logging.info("Iteration %d", i) + + # Collect all substrings of the encoded token that break along current + # subtoken boundaries. 
+ subtoken_counts = collections.defaultdict(int) + for token, count in six.iteritems(token_counts): + iter_start_time = time.time() + escaped_token = _escape_token(token, self._alphabet) + subtokens = self._escaped_token_to_subtoken_strings(escaped_token) + start = 0 + for subtoken in subtokens: + last_position = len(escaped_token) + 1 + if max_subtoken_length is not None: + last_position = min(last_position, start + max_subtoken_length) + + for end in range(start + 1, last_position): + new_subtoken = escaped_token[start:end] + subtoken_counts[new_subtoken] += count + start += len(subtoken) + iter_time_secs = time.time() - iter_start_time + if iter_time_secs > 0.1: + logging.info( + "Processing token [%s] took {%d} seconds, consider " + "setting Text2TextProblem.max_subtoken_length to a " + "smaller value.", + token, + iter_time_secs, + ) + + # Array of sets of candidate subtoken strings, by length. + len_to_subtoken_strings = [] + for subtoken_string, count in six.iteritems(subtoken_counts): + lsub = len(subtoken_string) + if count >= min_count: + while len(len_to_subtoken_strings) <= lsub: + len_to_subtoken_strings.append(set()) + len_to_subtoken_strings[lsub].add(subtoken_string) + + # Consider the candidates longest to shortest, so that if we accept + # a longer subtoken string, we can decrement the counts of its prefixes. + new_subtoken_strings = [] + for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1): + subtoken_strings = len_to_subtoken_strings[lsub] + for subtoken_string in subtoken_strings: + count = subtoken_counts[subtoken_string] + if count >= min_count: + # Exclude alphabet tokens here, as they must be included later, + # explicitly, regardless of count. + if subtoken_string not in self._alphabet: + new_subtoken_strings.append((count, subtoken_string)) + for l in range(1, lsub): + subtoken_counts[subtoken_string[:l]] -= count + + # Include the alphabet explicitly to guarantee all strings are encodable. 
+ new_subtoken_strings.extend( + (subtoken_counts.get(a, 0), a) for a in self._alphabet + ) + new_subtoken_strings.sort(reverse=True) + + # Reinitialize to the candidate vocabulary. + new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings] + if reserved_tokens: + escaped_reserved_tokens = [ + _escape_token(native_to_unicode(t), self._alphabet) + for t in reserved_tokens + ] + new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings + + self._init_subtokens_from_list(new_subtoken_strings) + logging.info("vocab_size = %d", self.vocab_size) + + @property + def all_subtoken_strings(self): + return tuple(self._all_subtoken_strings) + + def dump(self): + """Debugging dump of the current subtoken vocabulary.""" + subtoken_strings = [ + (i, s) for s, i in six.iteritems(self._subtoken_string_to_id) + ] + print( + ", ".join("{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings)) + ) + + def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None): + """Initialize token information from a list of subtoken strings. + + Args: + subtoken_strings: a list of subtokens + reserved_tokens: List of reserved tokens. We must have `reserved_tokens` + as None or the empty list, or else the global variable `RESERVED_TOKENS` + must be a prefix of `reserved_tokens`. + + Raises: + ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it + is not clear what the space is being reserved for, or when it will be + filled in. + """ + if reserved_tokens is None: + reserved_tokens = [] + + if reserved_tokens: + self._all_subtoken_strings = reserved_tokens + subtoken_strings + else: + self._all_subtoken_strings = subtoken_strings + + # we remember the maximum length of any subtoken to avoid having to + # check arbitrarily long strings. 
+ self._max_subtoken_len = max([len(s) for s in subtoken_strings]) + self._subtoken_string_to_id = { + s: i + len(reserved_tokens) for i, s in enumerate(subtoken_strings) if s + } + # Initialize the cache to empty. + self._cache_size = 2**20 + self._cache = [(None, None)] * self._cache_size + + def _init_alphabet_from_tokens(self, tokens): + """Initialize alphabet from an iterable of token or subtoken strings.""" + # Include all characters from all tokens in the alphabet to guarantee that + # any token can be encoded. Additionally, include all escaping characters. + self._alphabet = { + c for token in tokens for c in token + } # pylint: disable=g-complex-comprehension + self._alphabet |= _ESCAPE_CHARS + + def _load_from_file_object(self, f): + """Load from a file object. + + Args: + f: File object to load vocabulary from + """ + subtoken_strings = [] + for line in f: + s = line.rstrip() + # Some vocab files wrap words in single quotes, but others don't + if (s.startswith("'") and s.endswith("'")) or ( + s.startswith('"') and s.endswith('"') + ): + s = s[1:-1] + subtoken_strings.append(native_to_unicode(s)) + self._init_subtokens_from_list(subtoken_strings) + self._init_alphabet_from_tokens(subtoken_strings) + + def _load_from_file(self, filename): + """Load from a vocab file.""" + if not tf.io.gfile.exists(filename): + raise ValueError("File %s not found" % filename) + with tf.io.gfile.GFile(filename) as f: + self._load_from_file_object(f) + + def store_to_file(self, filename, add_single_quotes=True): + with tf.io.gfile.GFile(filename, "w") as f: + for subtoken_string in self._all_subtoken_strings: + if add_single_quotes: + f.write("'" + subtoken_string + "'\n") + else: + f.write(subtoken_string + "\n") class ImageEncoder: - """Encoder class for saving and loading images.""" - - def __init__(self, num_reserved_ids=0, height=None, width=None, channels=3): - assert num_reserved_ids == 0 - self._height = height - self._width = width - self._channels = channels - - 
@property - def num_reserved_ids(self): - return 0 - - def encode(self, s): - """Transform a string with a filename into a list of RGB integers. - - Args: - s: path to the file with an image. - - Returns: - ids: list of integers - """ - try: - import matplotlib.image as im # pylint: disable=g-import-not-at-top - except ImportError as e: - logging.warning( - "Reading an image requires matplotlib to be installed: %s", e) - raise NotImplementedError("Image reading not implemented.") - return im.imread(s) - - def decode(self, ids, strip_extraneous=False): - """Transform a sequence of int IDs into an image file. - - Args: - ids: list of integers to be converted. - strip_extraneous: unused - - Returns: - Path to the temporary file where the image was saved. - - Raises: - ValueError: if the IDs are not of the appropriate size. - """ - del strip_extraneous - _, tmp_file_path = tempfile.mkstemp("_decode.png") - if self._height is None or self._width is None: - size = int(math.sqrt(len(ids) / self._channels)) - length = size * size * self._channels - else: - size = None - length = self._height * self._width * self._channels - if len(ids) != length: - raise ValueError("Length of ids (%d) must be height (%d) x width (%d) x " - "channels (%d); %d != %d.\n Ids: %s" % - (len(ids), self._height, self._width, self._channels, - len(ids), length, " ".join([str(i) for i in ids]))) - with tf.Graph().as_default(): - raw = tf.constant(ids, dtype=tf.uint8) - if size is None: - img = tf.reshape(raw, [self._height, self._width, self._channels]) - else: - img = tf.reshape(raw, [size, size, self._channels]) - png = tf.image.encode_png(img) - op = tf.write_file(tmp_file_path, png) - with tf.Session() as sess: - sess.run(op) - return tmp_file_path - - def decode_list(self, ids): - """Transform a sequence of int IDs into an image file. - - Args: - ids: list of integers to be converted. - - Returns: - Singleton list: path to the temporary file where the image was saved. 
- """ - return [self.decode(ids)] - - @property - def vocab_size(self): - return 256 + """Encoder class for saving and loading images.""" + + def __init__(self, num_reserved_ids=0, height=None, width=None, channels=3): + assert num_reserved_ids == 0 + self._height = height + self._width = width + self._channels = channels + + @property + def num_reserved_ids(self): + return 0 + + def encode(self, s): + """Transform a string with a filename into a list of RGB integers. + + Args: + s: path to the file with an image. + + Returns: + ids: list of integers + """ + try: + import matplotlib.image as im # pylint: disable=g-import-not-at-top + except ImportError as e: + logging.warning( + "Reading an image requires matplotlib to be installed: %s", e + ) + raise NotImplementedError("Image reading not implemented.") + return im.imread(s) + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int IDs into an image file. + + Args: + ids: list of integers to be converted. + strip_extraneous: unused + + Returns: + Path to the temporary file where the image was saved. + + Raises: + ValueError: if the IDs are not of the appropriate size. 
+ """ + del strip_extraneous + _, tmp_file_path = tempfile.mkstemp("_decode.png") + if self._height is None or self._width is None: + size = int(math.sqrt(len(ids) / self._channels)) + length = size * size * self._channels + else: + size = None + length = self._height * self._width * self._channels + if len(ids) != length: + raise ValueError( + "Length of ids (%d) must be height (%d) x width (%d) x " + "channels (%d); %d != %d.\n Ids: %s" + % ( + len(ids), + self._height, + self._width, + self._channels, + len(ids), + length, + " ".join([str(i) for i in ids]), + ) + ) + with tf.Graph().as_default(): + raw = tf.constant(ids, dtype=tf.uint8) + if size is None: + img = tf.reshape(raw, [self._height, self._width, self._channels]) + else: + img = tf.reshape(raw, [size, size, self._channels]) + png = tf.image.encode_png(img) + op = tf.write_file(tmp_file_path, png) + with tf.Session() as sess: + sess.run(op) + return tmp_file_path + + def decode_list(self, ids): + """Transform a sequence of int IDs into an image file. + + Args: + ids: list of integers to be converted. + + Returns: + Singleton list: path to the temporary file where the image was saved. + """ + return [self.decode(ids)] + + @property + def vocab_size(self): + return 256 class RealEncoder: - """Encoder class for saving and loading float values.""" + """Encoder class for saving and loading float values.""" - def encode(self, s): - """Transform a string (space separated float values) into a float array. + def encode(self, s): + """Transform a string (space separated float values) into a float array. - Args: - s: space separated float values. + Args: + s: space separated float values. - Returns: - Array of float values. - """ - return [float(w) for w in s.split()] + Returns: + Array of float values. + """ + return [float(w) for w in s.split()] - def decode(self, ids, strip_extraneous=False): - """Transform sequence of float values into string (float values). 
+ def decode(self, ids, strip_extraneous=False): + """Transform sequence of float values into string (float values). - Args: - ids: array of floats to be converted. - strip_extraneous: unused + Args: + ids: array of floats to be converted. + strip_extraneous: unused - Returns: - String having space separated float values. + Returns: + String having space separated float values. - Raises: - ValueError: if the IDs are not of the appropriate size. - """ - del strip_extraneous - return " ".join([str(i) for i in ids]) + Raises: + ValueError: if the IDs are not of the appropriate size. + """ + del strip_extraneous + return " ".join([str(i) for i in ids]) class BertEncoder: - """Encoder Class that is compatible with models trained in original BERT library.""" - - def __init__(self, vocab_file, do_lower_case=True): - self._vocab = self.load_vocab(vocab_file) - self._inv_vocab = {v: k for k, v in self._vocab.items()} - self._basic_tokenizer = BertBasicEncoder(do_lower_case=do_lower_case) - self._wordpiece_tokenizer = BertWordpieceTokenizer(vocab=self._vocab) - - def load_vocab(self, vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - index = 0 - with tf.io.gfile.GFile(vocab_file, "r") as reader: - while True: - token = native_to_unicode(reader.readline()) - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 - return vocab - - def encode(self, text): - return self._convert_tokens_to_ids(self.tokenize(text)) - - # Note: Because encoding by BertEncoder is not unique text decoded - # from token ids is not unique. 
- def decode(self, ids): - """Returns a text that encoded would yield provided ids.""" - tokens = self._convert_ids_to_tokens(ids) - if not tokens: - return "" - retarr = [tokens[0]] - for token in tokens[1:]: - if token.startswith("##"): - retarr.append(token.lstrip("#")) - else: - retarr.append(" ") - retarr.append(token) - return "".join(retarr) - - @property - def vocab_size(self): - return len(self._vocab) - - def tokenize(self, text): - split_tokens = [] - for token in self._basic_tokenizer.tokenize(text): - for sub_token in self._wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) - - return split_tokens - - def _convert_tokens_to_ids(self, tokens): - return [self._vocab[token] for token in tokens] - - def _convert_ids_to_tokens(self, ids): - return [self._inv_vocab[token_id] for token_id in ids] + """Encoder Class that is compatible with models trained in original BERT library.""" + + def __init__(self, vocab_file, do_lower_case=True): + self._vocab = self.load_vocab(vocab_file) + self._inv_vocab = {v: k for k, v in self._vocab.items()} + self._basic_tokenizer = BertBasicEncoder(do_lower_case=do_lower_case) + self._wordpiece_tokenizer = BertWordpieceTokenizer(vocab=self._vocab) + + def load_vocab(self, vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with tf.io.gfile.GFile(vocab_file, "r") as reader: + while True: + token = native_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + def encode(self, text): + return self._convert_tokens_to_ids(self.tokenize(text)) + + # Note: Because encoding by BertEncoder is not unique text decoded + # from token ids is not unique. 
+ def decode(self, ids): + """Returns a text that encoded would yield provided ids.""" + tokens = self._convert_ids_to_tokens(ids) + if not tokens: + return "" + retarr = [tokens[0]] + for token in tokens[1:]: + if token.startswith("##"): + retarr.append(token.lstrip("#")) + else: + retarr.append(" ") + retarr.append(token) + return "".join(retarr) + + @property + def vocab_size(self): + return len(self._vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self._basic_tokenizer.tokenize(text): + for sub_token in self._wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def _convert_tokens_to_ids(self, tokens): + return [self._vocab[token] for token in tokens] + + def _convert_ids_to_tokens(self, ids): + return [self._inv_vocab[token_id] for token_id in ids] class BertBasicEncoder: - """Part of BertEncoder; tokenization (punctuation splitting, lower casing).""" - - def __init__(self, do_lower_case=True): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. 
- """ - self.do_lower_case = do_lower_case - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = native_to_unicode(text) - text = self._clean_text(text) - - text = self._tokenize_chinese_chars(text) - - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _bert_is_punctuation(char): - output.append([char]) + """Part of BertEncoder; tokenization (punctuation splitting, lower casing).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. 
+ """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = native_to_unicode(text) + text = self._clean_text(text) + + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. 
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _bert_is_control(char): - continue - if _bert_is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) + output = [] + while i < len(chars): + char = chars[i] + if _bert_is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _bert_is_control(char): + continue + if _bert_is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) class BertWordpieceTokenizer: - """Runs WordPiece tokenziation.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer. - - Returns: - A list of wordpiece tokens. 
- """ - - text = native_to_unicode(text) - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. 
+ """ + + text = native_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens def _bert_is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False def _bert_is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. 
+ if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False def _bert_is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+ if ( + (cp >= 33 and cp <= 47) + or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) + or (cp >= 123 and cp <= 126) + ): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens diff --git a/trax/data/text_encoder_build_subword.py b/trax/data/text_encoder_build_subword.py index 1df9d85cc..9ca0de783 100644 --- a/trax/data/text_encoder_build_subword.py +++ b/trax/data/text_encoder_build_subword.py @@ -35,46 +35,53 @@ from trax.data import text_encoder from trax.data import tokenizer -flags.DEFINE_string('output_filename', '/tmp/my.subword_text_encoder', - 'where to store the SubwordTextEncoder') -flags.DEFINE_string('corpus_filepattern', '', - 'Corpus of one or more text files') + +flags.DEFINE_string( + "output_filename", + "/tmp/my.subword_text_encoder", + "where to store the SubwordTextEncoder", +) +flags.DEFINE_string("corpus_filepattern", "", "Corpus of one or more text files") flags.DEFINE_string( - 'vocab_filepattern', '', 'One or more vocabulary files ' - '(one word per line as "word,count")') -flags.DEFINE_integer('min_count', 5, 'Minimum subtoken count in corpus') -flags.DEFINE_integer('corpus_max_lines', 10000, - 'How many lines of corpus to read') -flags.DEFINE_integer('num_iterations', 4, 'Number of iterations') -flags.DEFINE_bool('split_on_newlines', True, 'Break corpus into lines.') + "vocab_filepattern", + "", + "One or more vocabulary files " '(one word per line as "word,count")', +) +flags.DEFINE_integer("min_count", 5, "Minimum subtoken count in corpus") +flags.DEFINE_integer("corpus_max_lines", 10000, "How many lines of corpus to read") 
+flags.DEFINE_integer("num_iterations", 4, "Number of iterations") +flags.DEFINE_bool("split_on_newlines", True, "Break corpus into lines.") FLAGS = flags.FLAGS def main(unused_argv): - if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern: - raise ValueError( - 'Must only provide one of --corpus_filepattern or --vocab_filepattern') - - elif FLAGS.corpus_filepattern: - token_counts = tokenizer.corpus_token_counts( - FLAGS.corpus_filepattern, - FLAGS.corpus_max_lines, - split_on_newlines=FLAGS.split_on_newlines) - - elif FLAGS.vocab_filepattern: - token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern, - FLAGS.corpus_max_lines) - - else: - raise ValueError( - 'Must provide one of --corpus_filepattern or --vocab_filepattern') - - encoder = text_encoder.SubwordTextEncoder() - encoder.build_from_token_counts(token_counts, FLAGS.min_count, - FLAGS.num_iterations) - encoder.store_to_file(FLAGS.output_filename) - - -if __name__ == '__main__': - app.run(main) + if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern: + raise ValueError( + "Must only provide one of --corpus_filepattern or --vocab_filepattern" + ) + + elif FLAGS.corpus_filepattern: + token_counts = tokenizer.corpus_token_counts( + FLAGS.corpus_filepattern, + FLAGS.corpus_max_lines, + split_on_newlines=FLAGS.split_on_newlines, + ) + + elif FLAGS.vocab_filepattern: + token_counts = tokenizer.vocab_token_counts( + FLAGS.vocab_filepattern, FLAGS.corpus_max_lines + ) + + else: + raise ValueError( + "Must provide one of --corpus_filepattern or --vocab_filepattern" + ) + + encoder = text_encoder.SubwordTextEncoder() + encoder.build_from_token_counts(token_counts, FLAGS.min_count, FLAGS.num_iterations) + encoder.store_to_file(FLAGS.output_filename) + + +if __name__ == "__main__": + app.run(main) diff --git a/trax/data/text_encoder_test.py b/trax/data/text_encoder_test.py deleted file mode 100644 index 791f13e9b..000000000 --- a/trax/data/text_encoder_test.py +++ /dev/null @@ -1,376 +0,0 @@ -# 
coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.text_encoder.""" - -import collections -import io -import os -import random -import shutil -import string - -import mock -from six.moves import range # pylint: disable=redefined-builtin -import tensorflow.compat.v1 as tf -from trax.data import text_encoder - - -class NativeToUnicodeTest(tf.test.TestCase): - - def test_native_to_unicode(self): - s = r"foo bar" - s_unicode = text_encoder.native_to_unicode(s) - self.assertEqual(s_unicode, u"foo bar") - - -class EscapeUnescapeTokenTest(tf.test.TestCase): - - def test_escape_token(self): - escaped = text_encoder._escape_token( - "Foo! Bar.\nunder_score back\\slash", - set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS) - - self.assertEqual( - "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped) - - def test_unescape_token(self): - unescaped = text_encoder._unescape_token( - "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_") - - self.assertEqual( - "Foo! 
Bar.\nunder_score back\\slash", unescaped) - - -class TokenTextEncoderTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - """Make sure the test dir exists and is empty.""" - cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") - shutil.rmtree(cls.test_temp_dir, ignore_errors=True) - tf.gfile.MakeDirs(cls.test_temp_dir) - - def test_save_and_reload(self): - """Test that saving and reloading doesn't change the vocab. - - Note that this test reads and writes to the filesystem, which necessitates - that this test size be "large". - """ - - corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" - vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab") - - # Make text encoder from a list and store vocab to fake filesystem. - encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) - encoder.store_to_file(vocab_filename) - - # Load back the saved vocab file from the fake_filesystem. - new_encoder = text_encoder.TokenTextEncoder(vocab_filename) - - self.assertEqual(encoder._id_to_token, new_encoder._id_to_token) - self.assertEqual(encoder._token_to_id, new_encoder._token_to_id) - - def test_reserved_tokens_in_corpus(self): - """Test that we handle reserved tokens appearing in the corpus.""" - corpus = "A B {} D E F {} G {}".format(text_encoder.EOS, - text_encoder.EOS, - text_encoder.PAD) - - encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) - - all_tokens = encoder._id_to_token.values() - - # If reserved tokens are removed correctly, then the set of tokens will - # be unique. 
- self.assertEqual(len(all_tokens), len(set(all_tokens))) - - -class SubwordTextEncoderTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - """Make sure the test dir exists and is empty.""" - cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") - shutil.rmtree(cls.test_temp_dir, ignore_errors=True) - tf.gfile.MakeDirs(cls.test_temp_dir) - - def test_encode_decode(self): - corpus = ( - "This is a corpus of text that provides a bunch of tokens from which " - "to build a vocabulary. It will be used when strings are encoded " - "with a TextEncoder subclass. The encoder was coded by a coder.") - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - original = "This is a coded sentence encoded by the SubwordTextEncoder." - token_counts.update(original.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - # Encoding should be reversible. - encoded = encoder.encode(original) - decoded = encoder.decode(encoded) - self.assertEqual(original, decoded) - - # The substrings coded and coder are frequent enough in the corpus that - # they should appear in the vocabulary even though they are substrings - # of other included strings. - subtoken_strings = {encoder.all_subtoken_strings[i] for i in encoded} - self.assertIn("encoded_", subtoken_strings) - self.assertIn("coded_", subtoken_strings) - self.assertIn("TextEncoder", encoder.all_subtoken_strings) - self.assertIn("coder", encoder.all_subtoken_strings) - - # Every character in the corpus should be in the encoders alphabet and - # its subtoken vocabulary. - self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_unicode(self): - corpus = "Cat emoticons. 
\U0001F638 \U0001F639 \U0001F63A \U0001F63B" - token_counts = collections.Counter(corpus.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - self.assertIn("\U0001F638", encoder._alphabet) - self.assertIn("\U0001F63B", encoder.all_subtoken_strings) - - def test_small_vocab(self): - corpus = "The quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 10, token_counts, 2, 10) - - # All vocabulary elements are in the alphabet and subtoken strings even - # if we requested a smaller vocabulary to assure all expected strings - # are encodable. - self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_long_tokens(self): - """Subword tokenization should still run efficiently with long tokens. - - To make it run efficiently, we need to use the `max_subtoken_length` - argument when calling SubwordTextEncoder.build_to_target_size. - """ - token_length = 4000 - num_tokens = 50 - target_vocab_size = 600 - max_subtoken_length = 10 # Set this to `None` to get problems. - max_count = 500 - - # Generate some long random strings. - random.seed(0) - long_tokens = [] - for _ in range(num_tokens): - long_token = "".join([random.choice(string.ascii_uppercase) - for _ in range(token_length)]) - long_tokens.append(long_token) - - corpus = " ".join(long_tokens) - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - target_vocab_size, token_counts, 1, max_count, num_iterations=1, - max_subtoken_length=max_subtoken_length) - - # All vocabulary elements are in the alphabet and subtoken strings even - # if we requested a smaller vocabulary to assure all expected strings - # are encodable. 
- self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_custom_reserved_tokens(self): - """Test that we can pass custom reserved tokens to SubwordTextEncoder.""" - corpus = "The quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - start_symbol = "" - end_symbol = "" - reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, - end_symbol] - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 10, token_counts, 2, 10, reserved_tokens=reserved_tokens) - - # Make sure that reserved tokens appear in the right places. - self.assertEqual(encoder.decode([2]), start_symbol) - self.assertEqual(encoder.decode([3]), end_symbol) - - # Make sure that we haven't messed up the ability to reconstruct. - reconstructed_corpus = encoder.decode(encoder.encode(corpus)) - self.assertEqual(corpus, reconstructed_corpus) - - def test_encodable_when_not_in_alphabet(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - original = "This has UPPER CASE letters that are out of alphabet" - - # Early versions could have an infinite loop when breaking into subtokens - # if there was any out-of-alphabet characters in the encoded string. 
- encoded = encoder.encode(original) - decoded = encoder.decode(encoded) - - self.assertEqual(original, decoded) - encoded_str = "".join(encoder.all_subtoken_strings[i] for i in encoded) - self.assertIn("\\84;", encoded_str) - - @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579")) - def test_raises_exception_when_not_encodable(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - # Deliberately exclude some required encoding chars from the alphabet - # and token list, making some strings unencodable. - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - original = "This has UPPER CASE letters that are out of alphabet" - - # Previously there was a bug which produced an infinite loop in this case. - with self.assertRaises(AssertionError): - encoder.encode(original) - - def test_load_from_file(self): - # Test a vocab file with words not wrapped with single quotes - encoder = text_encoder.SubwordTextEncoder() - correct_vocab = ["the", "and", "of"] - vocab = io.StringIO("the\n" - "and\n" - "of\n") - encoder._load_from_file_object(vocab) - self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) - - # Test a vocab file with words wrapped in single quotes - encoder = text_encoder.SubwordTextEncoder() - vocab = io.StringIO("\"the\"\n" - "\"and\"\n" - "\"of\"\n") - encoder._load_from_file_object(vocab) - self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) - - def test_reserved_token_chars_not_in_alphabet(self): - corpus = "dog" - token_counts = collections.Counter(corpus.split(" ")) - encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 100) - filename = os.path.join(self.test_temp_dir, "out.voc") - encoder1.store_to_file(filename) - encoder2 = text_encoder.SubwordTextEncoder(filename=filename) - - self.assertEqual(encoder1._alphabet, encoder2._alphabet) - - for t in 
text_encoder.RESERVED_TOKENS: - for c in t: - # Verify that encoders can encode all reserved token chars. - encoder1.encode(c) - encoder2.encode(c) - - def test_save_and_reload(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - # Deliberately exclude some required encoding chars from the alphabet - # and token list, making some strings unencodable. - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - filename = os.path.join(self.test_temp_dir, "out.voc") - encoder.store_to_file(filename) - new_encoder = text_encoder.SubwordTextEncoder(filename) - - self.assertEqual(encoder._alphabet, new_encoder._alphabet) - self.assertEqual(encoder.all_subtoken_strings, - new_encoder.all_subtoken_strings) - self.assertEqual(encoder._subtoken_string_to_id, - new_encoder._subtoken_string_to_id) - self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) - - def test_save_and_reload_no_single_quotes(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - # Deliberately exclude some required encoding chars from the alphabet - # and token list, making some strings unencodable. 
- encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - filename = os.path.join(self.test_temp_dir, "out.voc") - encoder.store_to_file(filename, add_single_quotes=False) - new_encoder = text_encoder.SubwordTextEncoder(filename) - - self.assertEqual(encoder._alphabet, new_encoder._alphabet) - self.assertEqual(encoder.all_subtoken_strings, - new_encoder.all_subtoken_strings) - self.assertEqual(encoder._subtoken_string_to_id, - new_encoder._subtoken_string_to_id) - self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) - - def test_build_from_generator(self): - - corpus = "The quick brown fox jumps over the lazy dog" - - def gen(): - for _ in range(3): - yield corpus - - start_symbol = "" - end_symbol = "" - reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, - end_symbol] - encoder = text_encoder.SubwordTextEncoder.build_from_generator( - gen(), 10, reserved_tokens=reserved_tokens) - - # Make sure that reserved tokens appear in the right places. - self.assertEqual(encoder.decode([2]), start_symbol) - self.assertEqual(encoder.decode([3]), end_symbol) - - self.assertEqual("hi%s" % start_symbol, - encoder.decode(encoder.encode("hi") + [2])) - - # Make sure that we haven't messed up the ability to reconstruct. 
- reconstructed_corpus = encoder.decode(encoder.encode(corpus)) - self.assertEqual(corpus, reconstructed_corpus) - - -class OneHotClassLabelEncoderTest(tf.test.TestCase): - - def test_one_hot_encode(self): - encoder = text_encoder.OneHotClassLabelEncoder( - class_labels=["zero", "one", "two"]) - self.assertEqual(encoder.encode("zero"), [1, 0, 0]) - self.assertEqual(encoder.encode("one"), [0, 1, 0]) - self.assertEqual(encoder.encode("two"), [0, 0, 1]) - - def test_one_hot_decode(self): - encoder = text_encoder.OneHotClassLabelEncoder( - class_labels=["zero", "one", "two"]) - self.assertEqual(encoder.decode([1, 0, 0]), "zero") - self.assertEqual(encoder.decode([0, 1, 0]), "one") - self.assertEqual(encoder.decode([0, 0, 1]), "two") - - -if __name__ == "__main__": - tf.test.main() diff --git a/trax/data/tf_inputs.py b/trax/data/tf_inputs.py index 239b4d3a1..aaca6f9bb 100644 --- a/trax/data/tf_inputs.py +++ b/trax/data/tf_inputs.py @@ -48,244 +48,315 @@ def t5_data(): - """Get the T5 data module if available.""" - module = None - try: - import t5.data # pylint: disable=g-import-not-at-top - module = t5.data - except AttributeError as e: - logging.error('pip install t5') - raise e - return module + """Get the T5 data module if available.""" + module = None + try: + import t5.data # pylint: disable=g-import-not-at-top + + module = t5.data + except AttributeError as e: + logging.error("pip install t5") + raise e + return module def no_preprocess(dataset, training): - del training - return dataset + del training + return dataset + + +class t2t_problems: + """Tensor2Tensor has been discontinued. Furthermore, version 1.15.7 caused many problems during running, + such as problems with dynamically loading all datasets and creating module registries. 
Therefore, we simplified + it by creating our own simple registry of problems.""" + + import tensor2tensor.data_generators.translate_ende + _problems: dict = { + "translate_ende_wmt32k": tensor2tensor.data_generators.translate_ende.TranslateEndeWmt32k(), + } -def t2t_problems(): - # Load t2t problems on request only, this should save some import time. - from tensor2tensor import problems_colab as t2tp # pylint: disable=g-import-not-at-top - return t2tp + def __init__(self): + pass + + @staticmethod + def problem(name): + if name in t2t_problems._problems: + return t2t_problems._problems[name] + + raise Exception(f"There is no registered problem {name}") + + @staticmethod + def available(): + return t2t_problems._problems # TODO(jonni): Rename function to better match its return values. -@gin.configurable(module='trax.data') -def data_streams(dataset_name, - data_dir=None, - preprocess_fn=no_preprocess, - bare_preprocess_fn=None, - shuffle_buffer_size=1024, - eval_holdout_size=0, - input_name=None, - target_name=None): - """Creates `(train, eval)` data sources from ``dataset_name``. - - Args: - dataset_name: Name of dataset belonging to TFDS or T2T. T2T dataset names - must start with ``'t2t_'``. - data_dir: Directory where the data is located. - preprocess_fn: Function to use for pre-processing after appending targets to - inputs. - bare_preprocess_fn: Function to use for pre-processing before appending - targets to inputs. - shuffle_buffer_size: Size of the shuffle buffer. - eval_holdout_size: If greater than 0, specifies a fraction of training data - to siphon off and use as eval data, in place of an separate eval split. - input_name: Name of the inputs from the dictionary. - target_name: Name of the outputs either from the dictionary or as a result - of post-processing. - - Returns: - A pair of functions, `(f, g)` for use as data sources; call `f()` to get an - iterator of training data samples, and call `g()` to get an iterator of eval - data samples. 
- """ - data_dir = download_and_prepare(dataset_name, data_dir) - - cache = [] - - def stream(which): - """Create the stream, cache TF streams if needed.""" - if not cache: - cache.append( - _train_and_eval_streams(dataset_name, data_dir, preprocess_fn, - bare_preprocess_fn, shuffle_buffer_size, - eval_holdout_size, input_name, target_name)) - - (train_ds, eval_ds, input_name_c) = cache[0] - dataset = eval_ds if which == 'eval' else train_ds - return dataset_to_stream(dataset, input_name_c) - - train_stream = lambda: stream('train') - eval_stream = lambda: stream('eval') - return train_stream, eval_stream +@gin.configurable(module="trax.data") +def data_streams( + dataset_name, + data_dir=None, + preprocess_fn=no_preprocess, + bare_preprocess_fn=None, + shuffle_buffer_size=1024, + eval_holdout_size=0, + input_name=None, + target_name=None, +): + """Creates `(train, eval)` data sources from ``dataset_name``. + + Args: + dataset_name: Name of dataset belonging to TFDS or T2T. T2T dataset names + must start with ``'t2t_'``. + data_dir: Directory where the data is located. + preprocess_fn: Function to use for pre-processing after appending targets to + inputs. + bare_preprocess_fn: Function to use for pre-processing before appending + targets to inputs. + shuffle_buffer_size: Size of the shuffle buffer. + eval_holdout_size: If greater than 0, specifies a fraction of training data + to siphon off and use as eval data, in place of an separate eval split. + input_name: Name of the inputs from the dictionary. + target_name: Name of the outputs either from the dictionary or as a result + of post-processing. + + Returns: + A pair of functions, `(f, g)` for use as data sources; call `f()` to get an + iterator of training data samples, and call `g()` to get an iterator of eval + data samples. 
+ """ + data_dir = download_and_prepare(dataset_name, data_dir) + + cache = [] + + def stream(which): + """Create the stream, cache TF streams if needed.""" + if not cache: + cache.append( + _train_and_eval_streams( + dataset_name, + data_dir, + preprocess_fn, + bare_preprocess_fn, + shuffle_buffer_size, + eval_holdout_size, + input_name, + target_name, + ) + ) + + (train_ds, eval_ds, input_name_c) = cache[0] + dataset = eval_ds if which == "eval" else train_ds + return dataset_to_stream(dataset, input_name_c) + + train_stream = lambda: stream("train") + eval_stream = lambda: stream("eval") + return train_stream, eval_stream def dataset_to_stream(dataset, input_name): - """Takes a tf.Dataset and creates a numpy stream of ready batches.""" - # All input-pipeline processing should be on CPU. - for example in fastmath.dataset_as_numpy(dataset): - features = example[0] - inp, out = features[input_name], example[1] - mask = features['mask'] if 'mask' in features else None - # Some accelerators don't handle uint8 well, cast to int. - if isinstance(inp, np.uint8): - inp = inp.astype(np.int32) - if isinstance(out, np.uint8): - out = out.astype(np.int32) - yield (inp, out) if mask is None else (inp, out, mask) - - -def _train_and_eval_streams(dataset, data_dir, preprocess_fn, - bare_preprocess_fn, shuffle_buffer_size, - eval_holdout_size, input_name, target_name): - """Return train and eval batches with input name and shape.""" - (train_data, eval_data, - keys) = _train_and_eval_dataset(dataset, data_dir, eval_holdout_size) - # If provided select input_name/target_name else fall back to keys if that is - # available, else [None]. 
- input_names = ([input_name] if input_name is not None else - keys[0] if keys is not None else [None]) - target_names = ([target_name] if target_name is not None else - keys[1] if keys is not None else [None]) - - train_batches = _shuffle_data(train_data, target_names, True, - shuffle_buffer_size, preprocess_fn, - bare_preprocess_fn) - eval_batches = _shuffle_data(eval_data, target_names, False, - shuffle_buffer_size, preprocess_fn, - bare_preprocess_fn) - return (train_batches, eval_batches, input_names[0]) - - -def _shuffle_data(dataset, target_names, training, shuffle_buffer_size, - preprocess_fn, bare_preprocess_fn): - """Shuffle the given dataset and run pre-processing.""" - - def append_targets(example): - """Append targets to the example dictionary. Needed for Keras.""" - if len(target_names) == 1: - return (example, example[target_names[0]]) - targets = {} - for name in target_names: - targets[name] = example[name] - return (example, targets) - - # `bare_preprocess_fn` is called before appending targets etc. - if bare_preprocess_fn is not None: - dataset = bare_preprocess_fn(dataset, training) - dataset = dataset.map(append_targets) - # TODO(pkozakowski): Repeat both the training and evaluation set, so we don't - # have incomplete batches during evaluation. This will be a problem when we - # add an option to evaluate on the whole dataset, then we'll need to think of - # a different solution. - dataset = dataset.repeat() - if training: - # Skip a random fraction at the beginning of the stream. The skip is - # essential for synchronous highly-parallel training to avoid multiple - # replicas reading the same data in lock-step. 
- dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES)) - dataset = preprocess_fn(dataset, training) - dataset = dataset.shuffle(shuffle_buffer_size) - return dataset.prefetch(8) - - -def _train_and_eval_dataset(dataset_name, - data_dir, - eval_holdout_size, - train_shuffle_files=True, - eval_shuffle_files=False, - use_alt_eval=False, - subsplit=None): - """Return train and evaluation datasets, feature info and supervised keys. - - Args: - dataset_name: a string, the name of the dataset; if it starts with 't2t_' - then we'll search T2T Problem registry for it, otherwise we assume it is a - dataset from TFDS and load it from there. - data_dir: directory where the data is located. - eval_holdout_size: float from 0 to <1; if >0 use this much of training data - for evaluation (instead of looking for a pre-specified VALIDATION split). - train_shuffle_files: Boolean determining whether or not to shuffle the train - files at startup. Set to False if you want data determinism. - eval_shuffle_files: Boolean determining whether or not to shuffle the test - files at startup. Set to False if you want data determinism. - use_alt_eval: If True, use the dataset's alternate/secondary eval split; - else use the dataset's default/only eval split. Currently, only the - `glue/mnli` dataset provides an alternate eval split, and this arg is - ignored for other datasets. - subsplit: a pair of floats (x, y), both in [0, 1], saying which part of the - full training dataset we should return (default: all of it, [0, 1]). - - Returns: - a 4-tuple consisting of: - * the train tf.Dataset - * the eval tf.Dataset - * information about features: a python dictionary with feature names - as keys and an object as value that provides .shape and .n_classes. - * supervised_keys: information what's the input and what's the target, - ie., a pair of lists with input and target feature names. 
- """ - logging.info('Building TF data pipeline for %s', dataset_name) - if dataset_name.startswith('t2t_'): - return _train_and_eval_dataset_v1(dataset_name[4:], data_dir, - train_shuffle_files, eval_shuffle_files) - dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) - info = dataset_builder.info - splits = dataset_builder.info.splits - if dataset_name != 'c4/multilingual' and tfds.Split.TRAIN not in splits: - raise ValueError('To train we require a train split in the dataset.') - train_split = tfds.Split.TRAIN if dataset_name != 'c4/multilingual' else 'en' - eval_split = None - train_examples = info.splits[train_split].num_examples - eval_holdout_examples = int(train_examples * eval_holdout_size) - if eval_holdout_examples > 0 or subsplit is not None: - if subsplit is None: - subsplit = (0, 1) - n_train = train_examples - eval_holdout_examples - train_start = int(n_train * subsplit[0]) - train_end = int(n_train * subsplit[1]) - if train_end - train_start < 1: - raise ValueError('Requested train subsplit has no examples: ' - 'n_train %d subsplit %s' % (n_train, subsplit)) - # Eval holdout examples from the end of the training set. - if eval_holdout_examples > 0: - eval_split = f'{train_split}[-{eval_holdout_examples}:]' - # Shard the training set for this host. 
- train_split = f'{train_split}[{train_start}:{train_end}]' - - if dataset_name == 'glue/mnli': - eval_split = ( - 'validation_mismatched' if use_alt_eval else 'validation_matched') - elif dataset_name == 'c4/multilingual': - eval_split = 'en-validation' - elif eval_split is None: - if tfds.Split.VALIDATION not in splits and 'test' not in splits: - raise ValueError('We require a validation or test split in the dataset.') - eval_split = tfds.Split.VALIDATION - if tfds.Split.VALIDATION not in splits: - eval_split = tfds.Split.TEST - - train = tfds.load( - name=dataset_name, - split=train_split, - data_dir=data_dir, - shuffle_files=train_shuffle_files) - valid = tfds.load( - name=dataset_name, - split=eval_split, - data_dir=data_dir, - shuffle_files=eval_shuffle_files) - keys = None - if info.supervised_keys: - keys = ([info.supervised_keys[0]], [info.supervised_keys[1]]) - return train, valid, keys + """Takes a tf.Dataset and creates a numpy stream of ready batches.""" + # All input-pipeline processing should be on CPU. + for example in fastmath.dataset_as_numpy(dataset): + features = example[0] + inp, out = features[input_name], example[1] + mask = features["mask"] if "mask" in features else None + # Some accelerators don't handle uint8 well, cast to int. + if isinstance(inp, np.uint8): + inp = inp.astype(np.int32) + if isinstance(out, np.uint8): + out = out.astype(np.int32) + yield (inp, out) if mask is None else (inp, out, mask) + + +def _train_and_eval_streams( + dataset, + data_dir, + preprocess_fn, + bare_preprocess_fn, + shuffle_buffer_size, + eval_holdout_size, + input_name, + target_name, +): + """Return train and eval batches with input name and shape.""" + (train_data, eval_data, keys) = _train_and_eval_dataset( + dataset, data_dir, eval_holdout_size + ) + # If provided select input_name/target_name else fall back to keys if that is + # available, else [None]. 
+ input_names = ( + [input_name] + if input_name is not None + else keys[0] + if keys is not None + else [None] + ) + target_names = ( + [target_name] + if target_name is not None + else keys[1] + if keys is not None + else [None] + ) + + train_batches = _shuffle_data( + train_data, + target_names, + True, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, + ) + eval_batches = _shuffle_data( + eval_data, + target_names, + False, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, + ) + return (train_batches, eval_batches, input_names[0]) + + +def _shuffle_data( + dataset, + target_names, + training, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, +): + """Shuffle the given dataset and run pre-processing.""" + + def append_targets(example): + """Append targets to the example dictionary. Needed for Keras.""" + if len(target_names) == 1: + return (example, example[target_names[0]]) + targets = {} + for name in target_names: + targets[name] = example[name] + return (example, targets) + + # `bare_preprocess_fn` is called before appending targets etc. + if bare_preprocess_fn is not None: + dataset = bare_preprocess_fn(dataset, training) + dataset = dataset.map(append_targets) + # TODO(pkozakowski): Repeat both the training and evaluation set, so we don't + # have incomplete batches during evaluation. This will be a problem when we + # add an option to evaluate on the whole dataset, then we'll need to think of + # a different solution. + dataset = dataset.repeat() + if training: + # Skip a random fraction at the beginning of the stream. The skip is + # essential for synchronous highly-parallel training to avoid multiple + # replicas reading the same data in lock-step. 
+ dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES)) + dataset = preprocess_fn(dataset, training) + dataset = dataset.shuffle(shuffle_buffer_size) + return dataset.prefetch(8) + + +def _train_and_eval_dataset( + dataset_name, + data_dir, + eval_holdout_size, + train_shuffle_files=True, + eval_shuffle_files=False, + use_alt_eval=False, + subsplit=None, +): + """Return train and evaluation datasets, feature info and supervised keys. + + Args: + dataset_name: a string, the name of the dataset; if it starts with 't2t_' + then we'll search T2T Problem registry for it, otherwise we assume it is a + dataset from TFDS and load it from there. + data_dir: directory where the data is located. + eval_holdout_size: float from 0 to <1; if >0 use this much of training data + for evaluation (instead of looking for a pre-specified VALIDATION split). + train_shuffle_files: Boolean determining whether or not to shuffle the train + files at startup. Set to False if you want data determinism. + eval_shuffle_files: Boolean determining whether or not to shuffle the test + files at startup. Set to False if you want data determinism. + use_alt_eval: If True, use the dataset's alternate/secondary eval split; + else use the dataset's default/only eval split. Currently, only the + `glue/mnli` dataset provides an alternate eval split, and this arg is + ignored for other datasets. + subsplit: a pair of floats (x, y), both in [0, 1], saying which part of the + full training dataset we should return (default: all of it, [0, 1]). + + Returns: + a 4-tuple consisting of: + * the train tf.Dataset + * the eval tf.Dataset + * information about features: a python dictionary with feature names + as keys and an object as value that provides .shape and .n_classes. + * supervised_keys: information what's the input and what's the target, + ie., a pair of lists with input and target feature names. 
+ """ + logging.info("Building TF data pipeline for %s", dataset_name) + if dataset_name.startswith("t2t_"): + return _train_and_eval_dataset_v1( + dataset_name[4:], data_dir, train_shuffle_files, eval_shuffle_files + ) + dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) + info = dataset_builder.info + splits = dataset_builder.info.splits + if dataset_name != "c4/multilingual" and tfds.Split.TRAIN not in splits: + raise ValueError("To train we require a train split in the dataset.") + train_split = tfds.Split.TRAIN if dataset_name != "c4/multilingual" else "en" + eval_split = None + train_examples = info.splits[train_split].num_examples + eval_holdout_examples = int(train_examples * eval_holdout_size) + if eval_holdout_examples > 0 or subsplit is not None: + if subsplit is None: + subsplit = (0, 1) + n_train = train_examples - eval_holdout_examples + train_start = int(n_train * subsplit[0]) + train_end = int(n_train * subsplit[1]) + if train_end - train_start < 1: + raise ValueError( + "Requested train subsplit has no examples: " + "n_train %d subsplit %s" % (n_train, subsplit) + ) + # Eval holdout examples from the end of the training set. + if eval_holdout_examples > 0: + eval_split = f"{train_split}[-{eval_holdout_examples}:]" + # Shard the training set for this host. 
+ train_split = f"{train_split}[{train_start}:{train_end}]" + + if dataset_name == "glue/mnli": + eval_split = "validation_mismatched" if use_alt_eval else "validation_matched" + elif dataset_name == "c4/multilingual": + eval_split = "en-validation" + elif eval_split is None: + if tfds.Split.VALIDATION not in splits and "test" not in splits: + raise ValueError("We require a validation or test split in the dataset.") + eval_split = tfds.Split.VALIDATION + if tfds.Split.VALIDATION not in splits: + eval_split = tfds.Split.TEST + + train = tfds.load( + name=dataset_name, + split=train_split, + data_dir=data_dir, + shuffle_files=train_shuffle_files, + ) + valid = tfds.load( + name=dataset_name, + split=eval_split, + data_dir=data_dir, + shuffle_files=eval_shuffle_files, + ) + keys = None + if info.supervised_keys: + keys = ([info.supervised_keys[0]], [info.supervised_keys[1]]) + return train, valid, keys # TODO(jonni): Consider renaming this function. -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def TFDS( # pylint: disable=invalid-name dataset_name, data_dir=None, @@ -296,1161 +367,1163 @@ def TFDS( # pylint: disable=invalid-name shuffle_train=True, host_id=None, n_hosts=None, - eval_holdout_size=0): - """Creates a data source from TensorFlow dataset ``dataset_name``. - - Args: - dataset_name: Name of the dataset, as registered in TensorFlow datasets - (e.g., ``'glue/mnli'``). - data_dir: Directory where the data is located. - tfds_preprocess_fn: If specified, function that applies to items in raw - dataset (before selecting specific features). - keys: Tuple of dataset-specific strings that select features from the - dataset. - train: If True, select the training split from the dataset; else select an - eval split. - use_alt_eval: If True, and if ``train`` is False, select the dataset's - alternate eval split if it has one (or fall back to the dataset's only - eval split). This currently affects only the `glue/mnli` dataset. 
- shuffle_train: If True, have TensorFlow pre-shuffle the training data; else - receive training data in deterministic sequence. - host_id: Integer id used for tracking data subsplits, in cases where - ``n_hosts`` > 1. - n_hosts: If greater than 1, prepare data subsplits for the given number of - hosts. - eval_holdout_size: If greater than 0, specifies a fraction of training data - to siphon off and use as eval data, in place of an separate eval split. - - Returns: - A function `f` for use as a training or eval data source; call `f()` to get - an iterator of data samples. - """ - data_dir = download_and_prepare(dataset_name, data_dir) - - host_id = jax.process_index() if host_id is None else host_id - n_hosts = n_hosts or jax.host_count() - if n_hosts > 1: - subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts) - else: - subsplit = None - train_data, eval_data, _ = ( - _train_and_eval_dataset(dataset_name, - data_dir, - eval_holdout_size, - train_shuffle_files=shuffle_train, - use_alt_eval=use_alt_eval, - subsplit=subsplit)) - dataset = train_data if train else eval_data - dataset = dataset if tfds_preprocess_fn is None else tfds_preprocess_fn( - dataset) - - def select_from(example): - return tuple(example[k] for k in keys) - - dataset = dataset.map(select_from) - dataset = dataset.repeat() - - def gen(generator=None): - del generator - for example in fastmath.dataset_as_numpy(dataset): - yield example + eval_holdout_size=0, +): + """Creates a data source from TensorFlow dataset ``dataset_name``. + + Args: + dataset_name: Name of the dataset, as registered in TensorFlow datasets + (e.g., ``'glue/mnli'``). + data_dir: Directory where the data is located. + tfds_preprocess_fn: If specified, function that applies to items in raw + dataset (before selecting specific features). + keys: Tuple of dataset-specific strings that select features from the + dataset. + train: If True, select the training split from the dataset; else select an + eval split. 
+ use_alt_eval: If True, and if ``train`` is False, select the dataset's + alternate eval split if it has one (or fall back to the dataset's only + eval split). This currently affects only the `glue/mnli` dataset. + shuffle_train: If True, have TensorFlow pre-shuffle the training data; else + receive training data in deterministic sequence. + host_id: Integer id used for tracking data subsplits, in cases where + ``n_hosts`` > 1. + n_hosts: If greater than 1, prepare data subsplits for the given number of + hosts. + eval_holdout_size: If greater than 0, specifies a fraction of training data + to siphon off and use as eval data, in place of an separate eval split. + + Returns: + A function `f` for use as a training or eval data source; call `f()` to get + an iterator of data samples. + """ + data_dir = download_and_prepare(dataset_name, data_dir) + + host_id = jax.process_index() if host_id is None else host_id + n_hosts = n_hosts or jax.host_count() + if n_hosts > 1: + subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts) + else: + subsplit = None + train_data, eval_data, _ = _train_and_eval_dataset( + dataset_name, + data_dir, + eval_holdout_size, + train_shuffle_files=shuffle_train, + use_alt_eval=use_alt_eval, + subsplit=subsplit, + ) + dataset = train_data if train else eval_data + dataset = dataset if tfds_preprocess_fn is None else tfds_preprocess_fn(dataset) + + def select_from(example): + return tuple(example[k] for k in keys) + + dataset = dataset.map(select_from) + dataset = dataset.repeat() - return gen + def gen(generator=None): + del generator + for example in fastmath.dataset_as_numpy(dataset): + yield example + + return gen def _select_features(example, feature_list=None): - """Select a subset of features from the example dict.""" - feature_list = feature_list or ['inputs', 'targets'] - return {f: example[f] for f in feature_list if f in example} + """Select a subset of features from the example dict.""" + feature_list = feature_list or ["inputs", 
"targets"] + return {f: example[f] for f in feature_list if f in example} def _eager_dataset_iterator(dataset): - for item in dataset: - flat = tf.nest.flatten(item) - flat = [el.numpy() for el in flat] - yield tf.nest.pack_sequence_as(item, flat) - - -def _train_and_eval_dataset_v1(problem_name, data_dir, train_shuffle_files, - eval_shuffle_files): - """Return train and evaluation datasets, feature info and supervised keys.""" - with tf.device('cpu:0'): - problem = t2t_problems().problem(problem_name) - hparams = None - if problem_name == 'video_bair_robot_pushing': - hparams = problem.get_hparams() - bair_robot_pushing_hparams(hparams) - train_dataset = problem.dataset( - tf_estimator.ModeKeys.TRAIN, - data_dir, - shuffle_files=train_shuffle_files, - hparams=hparams) - train_dataset = train_dataset.map(_select_features) - eval_dataset = problem.dataset( - tf_estimator.ModeKeys.EVAL, - data_dir, - shuffle_files=eval_shuffle_files, - hparams=hparams) - eval_dataset = eval_dataset.map(_select_features) - # TODO(lukaszkaiser): remove this need for one example, just input_key. - examples = list(tfds.as_numpy(train_dataset.take(1))) - # We use 'inputs' as input except for purely auto-regressive tasks like - # language models where 'targets' are used as input_key. 
- input_key = 'inputs' if 'inputs' in examples[0] else 'targets' - supervised_keys = ([input_key], ['targets']) - return train_dataset, eval_dataset, supervised_keys + for item in dataset: + flat = tf.nest.flatten(item) + flat = [el.numpy() for el in flat] + yield tf.nest.pack_sequence_as(item, flat) + + +def _train_and_eval_dataset_v1( + problem_name, data_dir, train_shuffle_files, eval_shuffle_files +): + """Return train and evaluation datasets, feature info and supervised keys.""" + with tf.device("cpu:0"): + problem = t2t_problems().problem(problem_name) + hparams = None + if problem_name == "video_bair_robot_pushing": + hparams = problem.get_hparams() + bair_robot_pushing_hparams(hparams) + + train_dataset = problem.dataset( + tf_estimator.ModeKeys.TRAIN, + data_dir, + shuffle_files=train_shuffle_files, + hparams=hparams, + ) + train_dataset = train_dataset.map(_select_features) + eval_dataset = problem.dataset( + tf_estimator.ModeKeys.EVAL, + data_dir, + shuffle_files=eval_shuffle_files, + hparams=hparams, + ) + eval_dataset = eval_dataset.map(_select_features) + # TODO(lukaszkaiser): remove this need for one example, just input_key. + examples = list(tfds.as_numpy(train_dataset.take(1))) + # We use 'inputs' as input except for purely auto-regressive tasks like + # language models where 'targets' are used as input_key. + input_key = "inputs" if "inputs" in examples[0] else "targets" + supervised_keys = ([input_key], ["targets"]) + return train_dataset, eval_dataset, supervised_keys # Tokenization. @debug_data_pipeline.debug_pipeline -def tokenize(stream, - keys=None, - vocab_type='subword', - vocab_file=None, - vocab_dir=None, - n_reserved_ids=0): - """Tokenize examples from the stream. - - This function assumes that `stream` generates either strings or tuples/dicts - containing strings at some `keys`. This function maps these strings to - numpy arrays of integers -- the tokenized version of each string. 
- - Args: - stream: A python generator yielding strings, tuples or dicts. - keys: which keys of the tuple/dict to tokenize (by default: all) - vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. - vocab_file: Name of the vocabulary file. - vocab_dir: Directory which contains the vocabulary file. - n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; - This is common for example when reserving the 0 for padding and 1 for EOS, - but it's only needed if these symbols are not already included (and thus - reserved) in the vocab_file. - - Yields: - Examples from stream with strings at `keys` replaced by np.arrays of - integers -- the tokenized version of these strings. - """ - vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) - for example in stream: - if isinstance(example, (list, tuple)): - new_example = [] - for i, x in enumerate(example): - if keys is None or i in keys: - new_example.append(np.array(vocab.encode(x)) + n_reserved_ids) - else: - new_example.append(x) - output = tuple(new_example) - yield output - elif isinstance(example, dict): - new_example = {} - for k in example: - if keys is None or k in keys: - new_example[k] = np.array(vocab.encode(example[k])) + n_reserved_ids +def tokenize( + stream, + keys=None, + vocab_type="subword", + vocab_file=None, + vocab_dir=None, + n_reserved_ids=0, +): + """Tokenize examples from the stream. + + This function assumes that `stream` generates either strings or tuples/dicts + containing strings at some `keys`. This function maps these strings to + numpy arrays of integers -- the tokenized version of each string. + + Args: + stream: A python generator yielding strings, tuples or dicts. + keys: which keys of the tuple/dict to tokenize (by default: all) + vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. + vocab_file: Name of the vocabulary file. + vocab_dir: Directory which contains the vocabulary file. 
+ n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; + This is common for example when reserving the 0 for padding and 1 for EOS, + but it's only needed if these symbols are not already included (and thus + reserved) in the vocab_file. + + Yields: + Examples from stream with strings at `keys` replaced by np.arrays of + integers -- the tokenized version of these strings. + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + for example in stream: + if isinstance(example, (list, tuple)): + new_example = [] + for i, x in enumerate(example): + if keys is None or i in keys: + new_example.append(np.array(vocab.encode(x)) + n_reserved_ids) + else: + new_example.append(x) + output = tuple(new_example) + yield output + elif isinstance(example, dict): + new_example = {} + for k in example: + if keys is None or k in keys: + new_example[k] = np.array(vocab.encode(example[k])) + n_reserved_ids + else: + new_example[k] = example[k] + yield new_example else: - new_example[k] = example[k] - yield new_example - else: - output = np.array(vocab.encode(example)) + n_reserved_ids - yield output + output = np.array(vocab.encode(example)) + n_reserved_ids + yield output -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def Tokenize( # pylint: disable=invalid-name keys=None, - vocab_type='subword', # pylint: disable=invalid-name + vocab_type="subword", # pylint: disable=invalid-name vocab_file=None, vocab_dir=None, - n_reserved_ids=0): - """Returns a function that maps text to integer arrays; see `tokenize`.""" - return lambda g: tokenize( # pylint: disable=g-long-lambda - g, - keys=keys, - vocab_type=vocab_type, - vocab_file=vocab_file, - vocab_dir=vocab_dir, - n_reserved_ids=n_reserved_ids) - - -def detokenize(x, - vocab_type='subword', - vocab_file=None, - vocab_dir=None, - n_reserved_ids=0): - """Maps integer arrays to text; the opposite of `tokenize`. 
- - In many cases (all char- and subword-type vocabularies and most sentencepiece - ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In some - more rare cases this can remove some spacing, but it is still often useful - to run detokenize to get a readable version for a tokenized string. - - Args: - x: a list or numpy array of integers. - vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. - vocab_file: Name of the vocabulary file. - vocab_dir: Directory which contains the vocabulary file. - n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; - This is common for example when reserving the 0 for padding and 1 for EOS, - but it's only needed if these symbols are not already included (and thus - reserved) in the vocab_file. - - Returns: - A string corresponding to the de-tokenized version of x. - """ - vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) - x_unreserved = np.array(x) - n_reserved_ids - return str(vocab.decode(x_unreserved.tolist())) + n_reserved_ids=0, +): + """Returns a function that maps text to integer arrays; see `tokenize`.""" + return lambda g: tokenize( # pylint: disable=g-long-lambda + g, + keys=keys, + vocab_type=vocab_type, + vocab_file=vocab_file, + vocab_dir=vocab_dir, + n_reserved_ids=n_reserved_ids, + ) + + +def detokenize( + x, vocab_type="subword", vocab_file=None, vocab_dir=None, n_reserved_ids=0 +): + """Maps integer arrays to text; the opposite of `tokenize`. + + In many cases (all char- and subword-type vocabularies and most sentencepiece + ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In some + more rare cases this can remove some spacing, but it is still often useful + to run detokenize to get a readable version for a tokenized string. + + Args: + x: a list or numpy array of integers. + vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. + vocab_file: Name of the vocabulary file. 
+ vocab_dir: Directory which contains the vocabulary file. + n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; + This is common for example when reserving the 0 for padding and 1 for EOS, + but it's only needed if these symbols are not already included (and thus + reserved) in the vocab_file. + + Returns: + A string corresponding to the de-tokenized version of x. + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + x_unreserved = np.array(x) - n_reserved_ids + return str(vocab.decode(x_unreserved.tolist())) def _to_unicode(s): - # Errors of the casting are ignored (e.g. sequences not allowed by UTF-8), - # in order not to stay with incomplete examples (with empty values). - return str(s, encoding='utf-8', errors='ignore') + # Errors of the casting are ignored (e.g. sequences not allowed by UTF-8), + # in order not to stay with incomplete examples (with empty values). + return str(s, encoding="utf-8", errors="ignore") -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def ConvertToUnicode(keys=None): # pylint: disable=invalid-name - """Converts to Unicode UTF-8 elements of an example. + """Converts to Unicode UTF-8 elements of an example. + + Useful for when TFDS outputs byte arrays. All of the errors of the conversion + are ignored. + + Args: + keys: tuple/list of example dimensions to convert. + + Returns: + Function converting chosen elements of an example to UTF-8. + """ - Useful for when TFDS outputs byte arrays. All of the errors of the conversion - are ignored. 
+ @debug_data_pipeline.debug_pipeline + def _convert_to_unicode_str(stream): + for example in stream: + if isinstance(example, (list, tuple)): + new_example = [] + for i, x in enumerate(example): + if keys is None or i in keys: + new_example.append(_to_unicode(x)) + else: + new_example.append(x) + output = tuple(new_example) + yield output + elif isinstance(example, dict): + new_example = {} + for k in example: + if keys is None or k in keys: + new_example[k] = _to_unicode(example[k]) + else: + new_example[k] = example[k] + yield new_example + else: + output = _to_unicode(example) + yield output - Args: - keys: tuple/list of example dimensions to convert. + return _convert_to_unicode_str - Returns: - Function converting chosen elements of an example to UTF-8. - """ - @debug_data_pipeline.debug_pipeline - def _convert_to_unicode_str(stream): - for example in stream: - if isinstance(example, (list, tuple)): - new_example = [] - for i, x in enumerate(example): - if keys is None or i in keys: - new_example.append(_to_unicode(x)) - else: - new_example.append(x) - output = tuple(new_example) - yield output - elif isinstance(example, dict): - new_example = {} - for k in example: - if keys is None or k in keys: - new_example[k] = _to_unicode(example[k]) - else: - new_example[k] = example[k] - yield new_example - else: - output = _to_unicode(example) - yield output - - return _convert_to_unicode_str - - -def vocab_size(vocab_type='subword', - vocab_file=None, - vocab_dir=None, - n_reserved_ids=0): - """Returns the size of the vocabulary (number of symbols used). - - This function can be used to set the size of the final layers of a model that - needs to predict symbols from a given vocabulary. More precisely, if this - function returns N then the last layer size should be set to at least N (it - can be more). Note that this function does take reserved IDs into account. - - Args: - vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. 
- vocab_file: Name of the vocabulary file. - vocab_dir: Directory which contains the vocabulary file. - n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused. - - Returns: - An integer, the number of symbols used (including reserved IDs). - """ - vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) - return vocab.vocab_size + n_reserved_ids - - -def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None, - extra_ids=0): - """Gets the vocabulary object for tokenization; see tokenize for details.""" - if vocab_type not in [ - 'char', 'subword', 'sentencepiece', 'bert', 'bert-lowercase' - ]: - raise ValueError( - 'vocab_type must be "subword", "char", "sentencepiece", "bert" or "bert-lowercase" ' - f'but got {vocab_type}') - - if vocab_type == 'char': - # Note that we set num_reserved_ids=0 below. We could instead pass - # the value n_reserved_ids from tokenize here -- ByteTextEncoder does - # exactly the same thing as tokenize above, ie., adds num_reserved_ids. - return text_encoder.ByteTextEncoder(num_reserved_ids=0) - - vocab_dir = vocab_dir or 'gs://trax-ml/vocabs/' - path = os.path.join(vocab_dir, vocab_file) - - if vocab_type == 'subword': - return text_encoder.SubwordTextEncoder(path) - - if vocab_type == 'bert': - return text_encoder.BertEncoder(path, do_lower_case=False) - - if vocab_type == 'bert-lowercase': - return text_encoder.BertEncoder(path, do_lower_case=True) - - assert vocab_type == 'sentencepiece' - return t5_data().SentencePieceVocabulary(sentencepiece_model_file=path, - extra_ids=extra_ids) +def vocab_size(vocab_type="subword", vocab_file=None, vocab_dir=None, n_reserved_ids=0): + """Returns the size of the vocabulary (number of symbols used). + + This function can be used to set the size of the final layers of a model that + needs to predict symbols from a given vocabulary. More precisely, if this + function returns N then the last layer size should be set to at least N (it + can be more). 
Note that this function does take reserved IDs into account. + + Args: + vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. + vocab_file: Name of the vocabulary file. + vocab_dir: Directory which contains the vocabulary file. + n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused. + + Returns: + An integer, the number of symbols used (including reserved IDs). + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + return vocab.vocab_size + n_reserved_ids + + +def _get_vocab(vocab_type="subword", vocab_file=None, vocab_dir=None, extra_ids=0): + """Gets the vocabulary object for tokenization; see tokenize for details.""" + if vocab_type not in ["char", "subword", "sentencepiece", "bert", "bert-lowercase"]: + raise ValueError( + 'vocab_type must be "subword", "char", "sentencepiece", "bert" or "bert-lowercase" ' + f"but got {vocab_type}" + ) + + if vocab_type == "char": + # Note that we set num_reserved_ids=0 below. We could instead pass + # the value n_reserved_ids from tokenize here -- ByteTextEncoder does + # exactly the same thing as tokenize above, ie., adds num_reserved_ids. + return text_encoder.ByteTextEncoder(num_reserved_ids=0) + + vocab_dir = vocab_dir or "gs://trax-ml/vocabs/" + path = os.path.join(vocab_dir, vocab_file) + + if vocab_type == "subword": + return text_encoder.SubwordTextEncoder(path) + + if vocab_type == "bert": + return text_encoder.BertEncoder(path, do_lower_case=False) + + if vocab_type == "bert-lowercase": + return text_encoder.BertEncoder(path, do_lower_case=True) + + assert vocab_type == "sentencepiece" + return t5_data().SentencePieceVocabulary( + sentencepiece_model_file=path, extra_ids=extra_ids + ) # Makes the function accessible in gin configs, even with all args denylisted. 
-@gin.configurable(module='trax.data', denylist=['dataset', 'training']) +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def cifar10_no_augmentation_preprocess(dataset, training): - del training + del training - def cast_image(features, targets): - features['image'] = tf.cast(features['image'], tf.float32) / 255.0 - return features, targets + def cast_image(features, targets): + features["image"] = tf.cast(features["image"], tf.float32) / 255.0 + return features, targets - dataset = dataset.map(cast_image) - return dataset + dataset = dataset.map(cast_image) + return dataset def _cifar_augment_image(image): - """Image augmentation suitable for CIFAR-10/100. + """Image augmentation suitable for CIFAR-10/100. - As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5). + As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5). - Args: - image: a Tensor. + Args: + image: a Tensor. - Returns: - Tensor of the same shape as image. - """ - image = tf.image.resize_with_crop_or_pad(image, 40, 40) - image = tf.image.random_crop(image, [32, 32, 3]) - image = tf.image.random_flip_left_right(image) - return image + Returns: + Tensor of the same shape as image. + """ + image = tf.image.resize_with_crop_or_pad(image, 40, 40) + image = tf.image.random_crop(image, [32, 32, 3]) + image = tf.image.random_flip_left_right(image) + return image # Makes the function accessible in gin configs, even with all args denylisted. 
-@gin.configurable(module='trax.data', denylist=['dataset', 'training']) +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def cifar10_augmentation_preprocess(dataset, training): - """Preprocessing for cifar10 with augmentation (see below).""" - - def augment(features, targets): - features['image'] = _cifar_augment_image(features['image']) - return features, targets - - def cast_image(features, targets): - features['image'] = tf.cast(features['image'], tf.float32) / 255.0 - return features, targets - - if training: - dataset = dataset.map(augment) - dataset = dataset.map(cast_image) - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def cifar10_augmentation_flatten_preprocess(dataset, - training, - predict_image_train_weight=0.01): - """Preprocessing for cifar10 that flattens it and appends targets.""" - - def augment(features, targets): - features['image'] = _cifar_augment_image(features['image']) - return features, targets - - def flatten_image(features, targets): - """Flatten the image.""" - img = features['image'] - flat = tf.cast(tf.reshape(img, [-1]), tf.int64) - tgt = tf.expand_dims(targets, axis=0) - flat_with_target = tf.concat([flat, tgt], axis=0) - new_features = {} - new_features['image'] = flat_with_target - predict_image_weight = predict_image_train_weight if training else 0.0 - mask_begin = tf.ones_like(flat) - mask_begin = tf.cast(mask_begin, tf.float32) * predict_image_weight - mask_end = tf.cast(tf.ones_like(tgt), tf.float32) - new_features['mask'] = tf.concat([mask_begin, mask_end], axis=0) - return new_features, flat_with_target - - if training: - dataset = dataset.map(augment) - dataset = dataset.map(flatten_image) - - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def downsampled_imagenet_flatten_bare_preprocess(dataset, training): - """Preprocessing for downsampled_imagenet. 
+ """Preprocessing for cifar10 with augmentation (see below).""" - Args: - dataset: the dataset. - training: unused option. + def augment(features, targets): + features["image"] = _cifar_augment_image(features["image"]) + return features, targets - Returns: - Flattened dataset. + def cast_image(features, targets): + features["image"] = tf.cast(features["image"], tf.float32) / 255.0 + return features, targets - Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from - http://arxiv.org/abs/1601.06759 (page 8). - """ - del training + if training: + dataset = dataset.map(augment) + dataset = dataset.map(cast_image) + return dataset - def flatten_image(features): - img = features['image'] - flat = tf.cast(tf.reshape(img, [-1]), tf.int64) - new_features = {'image': flat} - return new_features +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def cifar10_augmentation_flatten_preprocess( + dataset, training, predict_image_train_weight=0.01 +): + """Preprocessing for cifar10 that flattens it and appends targets.""" + + def augment(features, targets): + features["image"] = _cifar_augment_image(features["image"]) + return features, targets + + def flatten_image(features, targets): + """Flatten the image.""" + img = features["image"] + flat = tf.cast(tf.reshape(img, [-1]), tf.int64) + tgt = tf.expand_dims(targets, axis=0) + flat_with_target = tf.concat([flat, tgt], axis=0) + new_features = {} + new_features["image"] = flat_with_target + predict_image_weight = predict_image_train_weight if training else 0.0 + mask_begin = tf.ones_like(flat) + mask_begin = tf.cast(mask_begin, tf.float32) * predict_image_weight + mask_end = tf.cast(tf.ones_like(tgt), tf.float32) + new_features["mask"] = tf.concat([mask_begin, mask_end], axis=0) + return new_features, flat_with_target + + if training: + dataset = dataset.map(augment) + dataset = dataset.map(flatten_image) - return dataset.map(flatten_image) + return dataset 
-@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def concat_preprocess(dataset, training, pad_symbol=0): - """Pre-processing function that concatenates input and target for LM.""" - del training +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def downsampled_imagenet_flatten_bare_preprocess(dataset, training): + """Preprocessing for downsampled_imagenet. + + Args: + dataset: the dataset. + training: unused option. + + Returns: + Flattened dataset. - def concat(features, targets): - inp = features['inputs'] - pad = tf.expand_dims(tf.zeros_like(inp[0]) + pad_symbol, axis=0) - concat = tf.concat([pad, inp, pad, targets], axis=0) - # Note: we're updating existing features dictionary here, so make sure - # it is not re-used in some other ways outside of this function. - features['inputs'] = concat - return features, concat + Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from + http://arxiv.org/abs/1601.06759 (page 8). + """ + del training + + def flatten_image(features): + img = features["image"] + flat = tf.cast(tf.reshape(img, [-1]), tf.int64) + + new_features = {"image": flat} + return new_features + + return dataset.map(flatten_image) - dataset = dataset.map(concat) - return dataset +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def concat_preprocess(dataset, training, pad_symbol=0): + """Pre-processing function that concatenates input and target for LM.""" + del training + + def concat(features, targets): + inp = features["inputs"] + pad = tf.expand_dims(tf.zeros_like(inp[0]) + pad_symbol, axis=0) + concat = tf.concat([pad, inp, pad, targets], axis=0) + # Note: we're updating existing features dictionary here, so make sure + # it is not re-used in some other ways outside of this function. 
+ features["inputs"] = concat + return features, concat + + dataset = dataset.map(concat) + return dataset -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def squeeze_targets_preprocess(dataset, training): - """Pre-processing function that squeezes last axis of targets.""" - del training + """Pre-processing function that squeezes last axis of targets.""" + del training - def squeeze(features, targets): - if targets.shape[-1] == 1: - targets = tf.squeeze(targets, axis=-1) - return features, targets + def squeeze(features, targets): + if targets.shape[-1] == 1: + targets = tf.squeeze(targets, axis=-1) + return features, targets - dataset = dataset.map(squeeze) - return dataset + dataset = dataset.map(squeeze) + return dataset -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def lm1b_preprocess(dataset, - training, - max_target_length=-1, - max_eval_target_length=-1): - """Preprocessing for LM1B: filter out targets exceeding maximum length.""" +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def lm1b_preprocess(dataset, training, max_target_length=-1, max_eval_target_length=-1): + """Preprocessing for LM1B: filter out targets exceeding maximum length.""" - def target_right_length(_, target): - return tf.less(tf.shape(target)[0], max_target_length + 1) + def target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_target_length + 1) - def eval_target_right_length(_, target): - return tf.less(tf.shape(target)[0], max_eval_target_length + 1) + def eval_target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_eval_target_length + 1) - if max_target_length > 0 and training: - dataset = dataset.filter(target_right_length) + if max_target_length > 0 and training: + dataset = dataset.filter(target_right_length) - if max_eval_target_length > 0 and not training: - dataset = 
dataset.filter(eval_target_right_length) + if max_eval_target_length > 0 and not training: + dataset = dataset.filter(eval_target_right_length) - return dataset + return dataset # TODO(lukaszkaiser): find a single more abstract way of text pre-processing. -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def wmt_preprocess(dataset, training, max_length=-1, max_eval_length=-1): - """Preprocessing for LM1B: filter out targets exceeding maximum length.""" + """Preprocessing for LM1B: filter out targets exceeding maximum length.""" - def train_right_length(example, target): - l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0]) - return tf.less(l, max_length + 1) + def train_right_length(example, target): + l = tf.maximum(tf.shape(example["inputs"])[0], tf.shape(target)[0]) + return tf.less(l, max_length + 1) - def eval_right_length(example, target): - l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0]) - return tf.less(l, max_eval_length + 1) + def eval_right_length(example, target): + l = tf.maximum(tf.shape(example["inputs"])[0], tf.shape(target)[0]) + return tf.less(l, max_eval_length + 1) - if max_length > 0 and training: - dataset = dataset.filter(train_right_length) + if max_length > 0 and training: + dataset = dataset.filter(train_right_length) - if max_eval_length > 0 and not training: - dataset = dataset.filter(eval_right_length) + if max_eval_length > 0 and not training: + dataset = dataset.filter(eval_right_length) - return dataset + return dataset -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def wmt_concat_preprocess(dataset, training, max_length=-1, max_eval_length=-1): - """Preprocessing for WMT: filter exceeding maximum length and concatenate.""" - dataset = wmt_preprocess(dataset, training, max_length, max_eval_length) + 
"""Preprocessing for WMT: filter exceeding maximum length and concatenate.""" + dataset = wmt_preprocess(dataset, training, max_length, max_eval_length) + + def concat_and_add_mask(features, targets): + inp = features["inputs"] + pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) + concat = tf.concat([inp, pad, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) + features["inputs"] = concat + features["mask"] = mask + return features, concat + + dataset = dataset.map(concat_and_add_mask) + return dataset + - def concat_and_add_mask(features, targets): - inp = features['inputs'] - pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) - concat = tf.concat([inp, pad, targets], axis=0) - mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) - features['inputs'] = concat - features['mask'] = mask - return features, concat +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def lm_token_preprocessing(dataset, training): + """Concatenates inputs, 0, targets, with masking only for targets.""" + del training + + def concat_and_add_mask(x): + inp = x["inputs"] + targets = x["targets"] + pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) + concat = tf.concat([inp, pad, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) + x["inputs"] = concat + x["targets"] = concat + x["mask"] = mask + return x + + dataset = dataset.map(concat_and_add_mask) + return dataset - dataset = dataset.map(concat_and_add_mask) - return dataset +@gin.configurable(module="trax.data", denylist=["hparams"]) +def bair_robot_pushing_hparams( + hparams=None, video_num_input_frames=1, video_num_target_frames=15 +): + if hparams is not None: + hparams.video_num_input_frames = video_num_input_frames + hparams.video_num_target_frames = video_num_target_frames + else: + return video_num_input_frames, video_num_target_frames -@gin.configurable(module='trax.data', denylist=['dataset', 
'training']) -def lm_token_preprocessing(dataset, training): - """Concatenates inputs, 0, targets, with masking only for targets.""" - del training - - def concat_and_add_mask(x): - inp = x['inputs'] - targets = x['targets'] - pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) - concat = tf.concat([inp, pad, targets], axis=0) - mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) - x['inputs'] = concat - x['targets'] = concat - x['mask'] = mask - return x - - dataset = dataset.map(concat_and_add_mask) - return dataset - - -@gin.configurable(module='trax.data', denylist=['hparams']) -def bair_robot_pushing_hparams(hparams=None, - video_num_input_frames=1, - video_num_target_frames=15): - if hparams is not None: - hparams.video_num_input_frames = video_num_input_frames - hparams.video_num_target_frames = video_num_target_frames - else: - return video_num_input_frames, video_num_target_frames - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def bair_robot_pushing_preprocess(dataset, training): - """Pre-processing function that concatenates input and target frames.""" - del training - - def concat_and_add_mask(features, targets): - """Concatenate input and output frames to form a language modeling setup.""" - inp = features['inputs'] - concat = tf.concat([inp, targets], axis=0) - mask = tf.concat([tf.zeros_like(inp), tf.ones_like(targets)], axis=0) - concat = tf.reshape(concat, (-1,)) - mask = tf.reshape(mask, (-1,)) - concat = tf.cast(concat, tf.int32) - mask = tf.cast(mask, tf.float32) - features['inputs'] = features['targets'] = concat - features['mask'] = mask - return features, concat - - dataset = dataset.map(concat_and_add_mask) - return dataset + """Pre-processing function that concatenates input and target frames.""" + del training + + def concat_and_add_mask(features, targets): + """Concatenate input and output frames to form a 
language modeling setup.""" + inp = features["inputs"] + concat = tf.concat([inp, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), tf.ones_like(targets)], axis=0) + concat = tf.reshape(concat, (-1,)) + mask = tf.reshape(mask, (-1,)) + concat = tf.cast(concat, tf.int32) + mask = tf.cast(mask, tf.float32) + features["inputs"] = features["targets"] = concat + features["mask"] = mask + return features, concat + + dataset = dataset.map(concat_and_add_mask) + return dataset def sentencepiece_tokenize(stream, spm_path=None, extra_ids=0): - """Sentencepiece tokenization.""" - spm_path = spm_path or t5_data().DEFAULT_SPM_PATH - vocab_file = os.path.basename(spm_path) - vocab_dir = os.path.dirname(spm_path) - vocab = _get_vocab(vocab_type='sentencepiece', - vocab_file=vocab_file, - vocab_dir=vocab_dir, - extra_ids=extra_ids) - for example in stream: - # example could either be str or (str,) - if isinstance(example, tuple): - example = example[0] - yield np.array(vocab.encode(example)) - - -@gin.configurable(module='trax.data') -def SentencePieceTokenize( # pylint: disable=invalid-name - spm_path=None, - extra_ids=0): - """Returns a function that maps text to integer arrays.""" - return lambda g: sentencepiece_tokenize( # pylint: disable=g-long-lambda - g, - spm_path=spm_path, - extra_ids=extra_ids) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def c4_preprocess(dataset, - training, - max_target_length=-1, - tokenization=None, - spm_path=None): - """Pre-processing function for C4 dataset.""" - del training - - def unicode_decode_chars(features, targets): - targets = tf.strings.unicode_decode(features['text'], 'UTF-8') - targets = tf.cast(targets, tf.int64) - features['targets'] = targets - features['inputs'] = targets - return (features, targets) - - def spc_tokenize(tokenizer, features, targets): - del targets - tokenized_text = tokenizer.tokenize(features['text']) - features['targets'] = tf.cast(tokenized_text, tf.int64) - 
features['inputs'] = features['targets'] - return features, features['targets'] - - if tokenization == 'spc': + """Sentencepiece tokenization.""" spm_path = spm_path or t5_data().DEFAULT_SPM_PATH - with tf.compat.v1.gfile.GFile(spm_path, 'rb') as f: - spc_model = f.read() - tokenizer = tf_text.SentencepieceTokenizer(model=spc_model) - dataset = dataset.map(functools.partial(spc_tokenize, tokenizer)) - else: - dataset = dataset.map(unicode_decode_chars) - - def target_right_length(_, target): - return tf.less(tf.shape(target)[0], max_target_length + 1) - - if max_target_length > 0: - dataset = dataset.filter(target_right_length) - - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def c4_bare_preprocess_fn(dataset, - training=True, - spm_path=None, - copy_pretokenized=True, - sequence_length=None): - """Returns a dataset that contains 'inputs' and 'targets' from C4.""" - # Set target key to be equal to the text content. - dataset = t5_data().preprocessors.rekey( - dataset, key_map={ - 'targets': 'text', - 'inputs': None - }) - - # Vocabulary for tokenization. - extra_ids = 0 - vocab = t5_data().SentencePieceVocabulary( - sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH, - extra_ids=extra_ids) - feature = t5_data().Feature(vocab) - output_features = {'targets': feature, 'inputs': feature} - - # Tokenize the targets. - keys = output_features - - def encode_string_features_fn(features): - """Encode all specified feature that are strings and return a dictionary. 
+ vocab_file = os.path.basename(spm_path) + vocab_dir = os.path.dirname(spm_path) + vocab = _get_vocab( + vocab_type="sentencepiece", + vocab_file=vocab_file, + vocab_dir=vocab_dir, + extra_ids=extra_ids, + ) + for example in stream: + # example could either be str or (str,) + if isinstance(example, tuple): + example = example[0] + yield np.array(vocab.encode(example)) + + +@gin.configurable(module="trax.data") +def SentencePieceTokenize(spm_path=None, extra_ids=0): # pylint: disable=invalid-name + """Returns a function that maps text to integer arrays.""" + return lambda g: sentencepiece_tokenize( # pylint: disable=g-long-lambda + g, spm_path=spm_path, extra_ids=extra_ids + ) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def c4_preprocess( + dataset, training, max_target_length=-1, tokenization=None, spm_path=None +): + """Pre-processing function for C4 dataset.""" + del training + + def unicode_decode_chars(features, targets): + targets = tf.strings.unicode_decode(features["text"], "UTF-8") + targets = tf.cast(targets, tf.int64) + features["targets"] = targets + features["inputs"] = targets + return (features, targets) + + def spc_tokenize(tokenizer, features, targets): + del targets + tokenized_text = tokenizer.tokenize(features["text"]) + features["targets"] = tf.cast(tokenized_text, tf.int64) + features["inputs"] = features["targets"] + return features, features["targets"] + + if tokenization == "spc": + spm_path = spm_path or t5_data().DEFAULT_SPM_PATH + with tf.compat.v1.gfile.GFile(spm_path, "rb") as f: + spc_model = f.read() + tokenizer = tf_text.SentencepieceTokenizer(model=spc_model) + dataset = dataset.map(functools.partial(spc_tokenize, tokenizer)) + else: + dataset = dataset.map(unicode_decode_chars) + + def target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_target_length + 1) + + if max_target_length > 0: + dataset = dataset.filter(target_right_length) + + return dataset + + 
+@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def c4_bare_preprocess_fn( + dataset, training=True, spm_path=None, copy_pretokenized=True, sequence_length=None +): + """Returns a dataset that contains 'inputs' and 'targets' from C4.""" + # Set target key to be equal to the text content. + dataset = t5_data().preprocessors.rekey( + dataset, key_map={"targets": "text", "inputs": None} + ) + + # Vocabulary for tokenization. + extra_ids = 0 + vocab = t5_data().SentencePieceVocabulary( + sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH, + extra_ids=extra_ids, + ) + feature = t5_data().Feature(vocab) + output_features = {"targets": feature, "inputs": feature} + + # Tokenize the targets. + keys = output_features + + def encode_string_features_fn(features): + """Encode all specified feature that are strings and return a dictionary. + + Args: + features: a dictionary + + Returns: + a dictionary + """ + ret = {} + for k, v in features.items(): + if k in keys and v.dtype == tf.string: + if copy_pretokenized: + ret["%s_pretokenized" % k] = v + v = tf.cast(output_features[k].vocabulary.encode_tf(v), tf.int64) + ret[k] = v + return ret + + dataset = dataset.map( + encode_string_features_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + + # Preprocess the tokens - the exact preprocessors are set via gin. + dataset = t5_data().preprocessors.unsupervised( + dataset, sequence_length=sequence_length, output_features=output_features + ) + + # Add EOS. + dataset = add_eos_to_output_features(dataset, training) + + # Truncate and then pad the examples -- all examples have the same shape. 
+ dataset = truncate_dataset_on_len(dataset, training, sequence_length, True) + dataset = pad_dataset_to_length(dataset, training, sequence_length) + + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def filter_dataset_on_len(dataset, training, len_map=None, filter_on_eval=False): + """Filters a dataset of lengths given in `len_map`. Args: - features: a dictionary + dataset: `tf.data.Dataset` the dataset to filter. + training: bool, true if we are in training mode. + len_map: optional dict of str to (int, int). We filter examples where a + feature's size is beyond the specified bounds. Ex: + {'inputs': (1, 512), 'targets': (64, 128)} will keep only those examples + where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128. + filter_on_eval: bool if true, we will filter in eval mode also. Returns: - a dictionary + a filtered `tf.data.Dataset`. """ - ret = {} - for k, v in features.items(): - if k in keys and v.dtype == tf.string: - if copy_pretokenized: - ret['%s_pretokenized' % k] = v - v = tf.cast(output_features[k].vocabulary.encode_tf(v), tf.int64) - ret[k] = v - return ret - - dataset = dataset.map( - encode_string_features_fn, - num_parallel_calls=tf.data.experimental.AUTOTUNE) - - # Preprocess the tokens - the exact preprocessors are set via gin. - dataset = t5_data().preprocessors.unsupervised( - dataset, sequence_length=sequence_length, output_features=output_features) - - # Add EOS. - dataset = add_eos_to_output_features(dataset, training) - - # Truncate and then pad the examples -- all examples have the same shape. - dataset = truncate_dataset_on_len(dataset, training, sequence_length, True) - dataset = pad_dataset_to_length(dataset, training, sequence_length) - - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def filter_dataset_on_len(dataset, - training, - len_map=None, - filter_on_eval=False): - """Filters a dataset of lengths given in `len_map`. 
- - Args: - dataset: `tf.data.Dataset` the dataset to filter. - training: bool, true if we are in training mode. - len_map: optional dict of str to (int, int). We filter examples where a - feature's size is beyond the specified bounds. Ex: - {'inputs': (1, 512), 'targets': (64, 128)} will keep only those examples - where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128. - filter_on_eval: bool if true, we will filter in eval mode also. - - Returns: - a filtered `tf.data.Dataset`. - """ - if (len_map is None) or (not training and not filter_on_eval): - return dataset + if (len_map is None) or (not training and not filter_on_eval): + return dataset + + assert isinstance(len_map, dict) + for k, bounds in len_map.items(): + # pylint: disable=cell-var-from-loop + # TODO(afrozm): Investigate `cell-var-from-loop` - since this is WAI and + # there is a test too. + def within_bounds(x, key, len_bounds): + size = tf.shape(x[key])[0] + min_len, max_len = len_bounds + return (min_len <= size) and (size <= max_len) + + dataset = dataset.filter(lambda x: within_bounds(x, k, bounds)) + # pylint: enable=cell-var-from-loop - assert isinstance(len_map, dict) - for k, bounds in len_map.items(): - # pylint: disable=cell-var-from-loop - # TODO(afrozm): Investigate `cell-var-from-loop` - since this is WAI and - # there is a test too. - def within_bounds(x, key, len_bounds): - size = tf.shape(x[key])[0] - min_len, max_len = len_bounds - return (min_len <= size) and (size <= max_len) - - dataset = dataset.filter(lambda x: within_bounds(x, k, bounds)) - # pylint: enable=cell-var-from-loop - - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def truncate_dataset_on_len(dataset, - training, - len_map=None, - truncate_on_eval=False): - """Truncates features in an example to lengths given in `len_map`. - - Args: - dataset: `tf.data.Dataset` the dataset to filter. - training: bool, true if we are in training mode. 
- len_map: optional dict of str to int, we truncate examples where a feature's - size is beyond the max. Ex: {'inputs': 512, 'targets': 64} will truncate - examples to be within those bounds. - truncate_on_eval: bool if true, we will truncate in eval mode also. - - Returns: - a filtered `tf.data.Dataset`. - """ - if (len_map is None) or (not training and not truncate_on_eval): return dataset - assert isinstance(len_map, dict) - def truncate_example(x): - for key, max_len in len_map.items(): - x_len = tf.shape(x[key])[0] - if x_len > max_len: - x[key] = x[key][:max_len, ...] - return x +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def truncate_dataset_on_len(dataset, training, len_map=None, truncate_on_eval=False): + """Truncates features in an example to lengths given in `len_map`. - return dataset.map(truncate_example) + Args: + dataset: `tf.data.Dataset` the dataset to filter. + training: bool, true if we are in training mode. + len_map: optional dict of str to int, we truncate examples where a feature's + size is beyond the max. Ex: {'inputs': 512, 'targets': 64} will truncate + examples to be within those bounds. + truncate_on_eval: bool if true, we will truncate in eval mode also. + + Returns: + a filtered `tf.data.Dataset`. + """ + if (len_map is None) or (not training and not truncate_on_eval): + return dataset + assert isinstance(len_map, dict) -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) + def truncate_example(x): + for key, max_len in len_map.items(): + x_len = tf.shape(x[key])[0] + if x_len > max_len: + x[key] = x[key][:max_len, ...] 
+ return x + + return dataset.map(truncate_example) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) def pad_dataset_to_length(dataset, training, len_map=None): - """Pad features less than specified length to specified length.""" - del training - if len_map is None: + """Pad features less than specified length to specified length.""" + del training + if len_map is None: + return dataset + + def pad_to_len(x): + for key, max_len in len_map.items(): + x_shape = tf.shape(x[key]) + x_len = x_shape[0] + if x_len < max_len: + pad_shape = [ + max_len - x_len, + ] + zeros = tf.zeros(pad_shape, dtype=x[key].dtype) + x[key] = tf.concat([x[key], zeros], 0) + return x + + return dataset.map(pad_to_len) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def add_eos_to_output_features(dataset, training, output_features="targets", eos=1): + """Adds `EOS` to all features in `output_features`.""" + del training + if not isinstance(output_features, (list, tuple)): + output_features = [output_features] + + def add_eos(x): + for output_feature in output_features: + x[output_feature] = tf.concat([x[output_feature], [eos]], axis=0) + return x + + return dataset.map(add_eos) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def generic_text_dataset_preprocess_fn( + dataset, + training=True, + text_preprocess_fns=None, + token_preprocess_fns=None, + spm_path=None, + copy_pretokenized=False, + debug_print_examples=False, + debug_print_examples_rate=0.01, +): + """Pre-processes, tokenizes and post-processes a `tf.data.Dataset`. + + Args: + dataset: `tf.data.Dataset` to process. + training: boolean, set to True if training, False otherwise. + text_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> + `tf.data.Dataset` this operates before tokenization. Typically used to + select which fields we want to learn over or change something into "text + to text" form. 
+ token_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> + `tf.data.Dataset`, this operates after tokenization. Since this can view + the tokenized fields, this can be used to filter on length etc. + spm_path: None or str, path to a sentencepiece model to use for tokenization + by default uses the 32k vocabulary from T5. + copy_pretokenized: bool, if True retains the original fields after + tokenization. + debug_print_examples: bool, if True this prints examples to the logging + stream for inspection, both before and after tokenization. + debug_print_examples_rate: float, [0, 1.0], on average this fraction of + dataset examples will be printed out in each phase i.e. pre and post + tokenization. + + Returns: + a `tf.data.Dataset` with all the preprocessing and tokenization performed. + """ + + # The assumption is that `text_preprocess_fns` finally gives us a dataset + # which has `inputs` and `targets`. + if text_preprocess_fns is not None: + for text_preprocess_fn in text_preprocess_fns: + dataset = text_preprocess_fn(dataset, training) + + # Print debugging examples if needed before tokenization. + if debug_print_examples: + + def print_examples(x): + if np.random.uniform() < debug_print_examples_rate: + tf.print(x, output_stream=logging.info) + return x + + dataset = dataset.map(print_examples) + + # Vocabulary for tokenization. + extra_ids = 0 + vocab = t5_data().SentencePieceVocabulary( + sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH, + extra_ids=extra_ids, + ) + feature = t5_data().Feature(vocab) + output_features = {"targets": feature, "inputs": feature} + + # Tokenize the inputs and targets. + dataset = t5_data().preprocessors.tokenize( + dataset, output_features, copy_pretokenized=copy_pretokenized + ) + + # Apply the token-preprocessors. 
+ if token_preprocess_fns is not None: + for token_preprocess_fn in token_preprocess_fns: + dataset = token_preprocess_fn(dataset, training) + + if debug_print_examples: + + def print_examples_and_shapes(x): + if np.random.uniform() < debug_print_examples_rate: + tf.print( + { + "inputs_shape": tf.size(x["inputs"]), + "targets_shape": tf.size(x["targets"]), + "inputs": x["inputs"], + "targets": x["targets"], + }, + output_stream=logging.info, + ) + return x + + dataset = dataset.map(print_examples_and_shapes) + return dataset - def pad_to_len(x): - for key, max_len in len_map.items(): - x_shape = tf.shape(x[key]) - x_len = x_shape[0] - if x_len < max_len: - pad_shape = [ - max_len - x_len, - ] - zeros = tf.zeros(pad_shape, dtype=x[key].dtype) - x[key] = tf.concat([x[key], zeros], 0) - return x - - return dataset.map(pad_to_len) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def add_eos_to_output_features(dataset, - training, - output_features='targets', - eos=1): - """Adds `EOS` to all features in `output_features`.""" - del training - if not isinstance(output_features, (list, tuple)): - output_features = [output_features] - - def add_eos(x): - for output_feature in output_features: - x[output_feature] = tf.concat([x[output_feature], [eos]], axis=0) - return x - - return dataset.map(add_eos) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def generic_text_dataset_preprocess_fn(dataset, - training=True, - text_preprocess_fns=None, - token_preprocess_fns=None, - spm_path=None, - copy_pretokenized=False, - debug_print_examples=False, - debug_print_examples_rate=0.01): - """Pre-processes, tokenizes and post-processes a `tf.data.Dataset`. - - Args: - dataset: `tf.data.Dataset` to process. - training: boolean, set to True if training, False otherwise. - text_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> - `tf.data.Dataset` this operates before tokenization. 
Typically used to - select which fields we want to learn over or change something into "text - to text" form. - token_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> - `tf.data.Dataset`, this operates after tokenization. Since this can view - the tokenized fields, this can be used to filter on length etc. - spm_path: None or str, path to a sentencepiece model to use for tokenization - by default uses the 32k vocabulary from T5. - copy_pretokenized: bool, if True retains the original fields after - tokenization. - debug_print_examples: bool, if True this prints examples to the logging - stream for inspection, both before and after tokenization. - debug_print_examples_rate: float, [0, 1.0], on average this fraction of - dataset examples will be printed out in each phase i.e. pre and post - tokenization. - - Returns: - a `tf.data.Dataset` with all the preprocessing and tokenization performed. - """ - - # The assumption is that `text_preprocess_fns` finally gives us a dataset - # which has `inputs` and `targets`. - if text_preprocess_fns is not None: - for text_preprocess_fn in text_preprocess_fns: - dataset = text_preprocess_fn(dataset, training) - - # Print debugging examples if needed before tokenization. - if debug_print_examples: - - def print_examples(x): - if np.random.uniform() < debug_print_examples_rate: - tf.print(x, output_stream=logging.info) - return x - - dataset = dataset.map(print_examples) - - # Vocabulary for tokenization. - extra_ids = 0 - vocab = t5_data().SentencePieceVocabulary( - sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH, - extra_ids=extra_ids) - feature = t5_data().Feature(vocab) - output_features = {'targets': feature, 'inputs': feature} - - # Tokenize the inputs and targets. - dataset = t5_data().preprocessors.tokenize( - dataset, output_features, copy_pretokenized=copy_pretokenized) - - # Apply the token-preprocessors. 
- if token_preprocess_fns is not None: - for token_preprocess_fn in token_preprocess_fns: - dataset = token_preprocess_fn(dataset, training) - - if debug_print_examples: - - def print_examples_and_shapes(x): - if np.random.uniform() < debug_print_examples_rate: - tf.print( - { - 'inputs_shape': tf.size(x['inputs']), - 'targets_shape': tf.size(x['targets']), - 'inputs': x['inputs'], - 'targets': x['targets'], - }, - output_stream=logging.info) - return x - - dataset = dataset.map(print_examples_and_shapes) - - return dataset - - -@gin.configurable(module='trax.data') + +@gin.configurable(module="trax.data") def get_t5_preprocessor_by_name(name=None, fn_kwargs=None): - """Returns a closure of any T5 preprocessor function with its arguments. + """Returns a closure of any T5 preprocessor function with its arguments. - The main use-case is to use this (with gin scopes) to make any preprocessor - function available in a gin file to configure and use. + The main use-case is to use this (with gin scopes) to make any preprocessor + function available in a gin file to configure and use. - See: `TFInputs.test_gin_configurable_preprocessors` + See: `TFInputs.test_gin_configurable_preprocessors` - Args: - name: str, name of the preprocessor function to configure. - fn_kwargs: optional dictionary, the arguments to configure, these will be - partially applied to the function given by `name`. + Args: + name: str, name of the preprocessor function to configure. + fn_kwargs: optional dictionary, the arguments to configure, these will be + partially applied to the function given by `name`. - Returns: - a closure of the preprocessor function along with its arguments, this - function takes two arguments only, dataset and boolean training and ignores - the training and calls the t5 processor with the dataset (and closed over - arguments only). 
- """ + Returns: + a closure of the preprocessor function along with its arguments, this + function takes two arguments only, dataset and boolean training and ignores + the training and calls the t5 processor with the dataset (and closed over + arguments only). + """ - assert name is not None - f = getattr(t5_data().preprocessors, name) - if fn_kwargs is not None: - f = functools.partial(f, **fn_kwargs) - return lambda ds, unused_training: f(ds) + assert name is not None + f = getattr(t5_data().preprocessors, name) + if fn_kwargs is not None: + f = functools.partial(f, **fn_kwargs) + return lambda ds, unused_training: f(ds) def download_and_prepare(dataset_name, data_dir): - """Downloads and prepares T2T or TFDS dataset. - - Args: - dataset_name: tfds dataset or t2t problem name prefixed by 't2t_'. - data_dir: location of existing dataset or None. - - Returns: - data_dir: path string of downloaded data. - """ - if not data_dir: - data_dir = os.path.expanduser('~/tensorflow_datasets/') - dl_dir = os.path.join(data_dir, 'download') - logging.info( - 'No dataset directory provided. ' - 'Downloading and generating dataset for %s inside data directory %s ' - 'For large datasets it is better to prepare datasets manually!', - dataset_name, data_dir) - if dataset_name.startswith('t2t_'): - # Download and run dataset generator for T2T problem. - data_dir = os.path.join(data_dir, dataset_name) - tf.io.gfile.makedirs(data_dir) - tf.io.gfile.makedirs(dl_dir) - t2t_problems().problem(dataset_name[len('t2t_'):]).generate_data( - data_dir, dl_dir) + """Downloads and prepares T2T or TFDS dataset. + + Args: + dataset_name: tfds dataset or t2t problem name prefixed by 't2t_'. + data_dir: location of existing dataset or None. + + Returns: + data_dir: path string of downloaded data. + """ + if not data_dir: + data_dir = os.path.expanduser("~/tensorflow_datasets/") + dl_dir = os.path.join(data_dir, "download") + logging.info( + "No dataset directory provided. 
" + "Downloading and generating dataset for %s inside data directory %s " + "For large datasets it is better to prepare datasets manually!", + dataset_name, + data_dir, + ) + if dataset_name.startswith("t2t_"): + # Download and run dataset generator for T2T problem. + data_dir = os.path.join(data_dir, dataset_name) + tf.io.gfile.makedirs(data_dir) + tf.io.gfile.makedirs(dl_dir) + t2t_problems().problem(dataset_name[len("t2t_") :]).generate_data( + data_dir, dl_dir + ) + else: + # Download and prepare TFDS dataset. + tfds_builder = tfds.builder(dataset_name) + tfds_builder.download_and_prepare(download_dir=dl_dir) else: - # Download and prepare TFDS dataset. - tfds_builder = tfds.builder(dataset_name) - tfds_builder.download_and_prepare(download_dir=dl_dir) - else: - data_dir = os.path.expanduser(data_dir) - return data_dir - - -def BertSingleSentenceInputs(batch, # pylint: disable=invalid-name - labeled=True, - cls_id=101, - sep_id=102): - """Prepares inputs for BERT: add [SEP], [CLS] and create embeddings.""" - if labeled: - for sent1, label in batch: - value_vector = np.concatenate(([cls_id], sent1, [sep_id])) - segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32) - yield value_vector, segment_embs, segment_embs, label, np.int32(1) - else: - for (sent1,) in batch: # row is a tuple with 1 element - value_vector = np.concatenate(([cls_id], sent1, [sep_id])) - segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32) - yield value_vector, segment_embs, segment_embs - - -def BertDoubleSentenceInputs(batch, # pylint: disable=invalid-name - labeled=True, - cls_id=101, - sep_id=102): - """Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings.""" - if labeled: - for sent1, sent2, label in batch: - value_vector = np.concatenate( - ([cls_id], sent1, [sep_id], sent2, [sep_id])) - - segment_embs = np.zeros( - sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) - second_sent_start = sent1.shape[0] + 2 - 
segment_embs[second_sent_start:] = 1 - yield value_vector, segment_embs, segment_embs, label, np.int32(1) - else: - for sent1, sent2 in batch: - value_vector = np.concatenate( - ([cls_id], sent1, [sep_id], sent2, [sep_id])) - - segment_embs = np.zeros( - sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) - second_sent_start = sent1.shape[0] + 2 - segment_embs[second_sent_start:] = 1 - yield value_vector, segment_embs, segment_embs - - -@gin.configurable(module='trax.data') -def CreateBertInputs(double_sentence=True, # pylint: disable=invalid-name - labeled=True, - cls_id=101, - sep_id=102): - bert_inputs_fn = BertDoubleSentenceInputs if double_sentence else BertSingleSentenceInputs - return functools.partial( - bert_inputs_fn, labeled=labeled, cls_id=cls_id, sep_id=sep_id) - - -@gin.configurable(module='trax.data') -def mask_random_tokens(batch, - explicit_vocab_size=30522, - masking_prob=0.15, - cls_id=101, - sep_id=102, - mask_id=103, - vocab_start_id=999): - """Prepares input for the masking task. - - Preparation consist in masking masking_prob percentage of non-special tokens - at each input row; round(masking_prob * num_nonspecial_tokens) random tokens - are selected out of which each token is either - - replaced with [MASK] token with 80% probability, - - replaced with random token with 10% probability, - - or unchanged with 10%. - The implentation is based on - https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342 - - Examples: - - batch is a stream with each row having tuple (token_ids,). Function yields - rows of form (modified_token_ids, original_tokens, token_weights), where - modified_token_ids have [MASK] tokens or random tokens according to the - procedure described above. - - batch is a stream with each row having tuple (token_ids, segment_embeddings, - nsp_label, nsp_weight).Function yields rows of form (modified_token_ids, - segment_embeddings, nsp_label, nsp_weight, original_tokens, token_weights). 
- - Args: - batch: stream of inputs. Each row in the stream is a tuple which first - element is an array of tokens - explicit_vocab_size: the total size of the vocabulary. - masking_prob: Determines percent of non-special tokens to be selected for - masking. - cls_id: id of the special CLS token. - sep_id: id of the special SEP token. - mask_id: id of the special MASK token. - vocab_start_id: id of first non-special token in the vocabulary. - - Yields: - a stream with tokens masked for MLM training and 2 appended arrays: - - original tokens: a copy of original tokens used as a label for mlm - training - - token_weights: weights distributed uniformly over selected tokens (sum - is 1). Other tokens have 0 weight. - """ - for token_ids, *row_rest in batch: - original_tokens = token_ids.copy() - - # choose tokens for prediction. Chooses 0.15 of - # all non-special tokens - is_special_token = np.logical_or(token_ids == cls_id, - token_ids == sep_id) # CLS and SEP tokens - is_special_token = np.logical_or(is_special_token, - token_ids == 0) # padding - viable_ids = np.arange(token_ids.shape[0])[~is_special_token] - num_to_sample = round(masking_prob * viable_ids.shape[0]) - if num_to_sample == 0: - # sentence is too short to select given percentage of tokens to mask - continue - candidate_ids = np.random.choice(viable_ids, num_to_sample, replace=False) - - # create weights - token_weights = np.zeros(token_ids.shape) - token_weights[candidate_ids] = 1 / candidate_ids.shape[0] - - prob_scores = np.random.random(candidate_ids.shape) - - # change 80 % of tokens to [MASK] - mask_token_ids = candidate_ids[prob_scores < 0.8] - token_ids[mask_token_ids] = mask_id - - # change 10% of tokens to random token - random_token_ids = candidate_ids[(0.8 <= prob_scores) & (prob_scores < 0.9)] - token_ids[random_token_ids] = np.random.randint(vocab_start_id, - explicit_vocab_size, - random_token_ids.shape[0]) - - # rest (10%) is left unchaged - yield (token_ids, *row_rest, original_tokens, 
token_weights) - - -@gin.configurable(module='trax.data') -def BertNextSentencePredictionInputs(dataset_name, # pylint: disable=invalid-name - data_dir=None, - text_key='text', - train=True, - shuffle_size=50000): - """Defines a stream for the next sentence prediction task.""" - stream = TFDS( - dataset_name, - data_dir=data_dir, - tfds_preprocess_fn=functools.partial( - t5_data().preprocessors.next_sentence_prediction, - text_key=text_key, - label_sentences=True, - buffer_size=shuffle_size), - keys=['inputs', 'targets'], - train=train) - - def split_stream(generator=None): - # split string with 'sentence1:' and 'sentence2:' into two separate strings - for text, target in stream(generator): - text_str = str(text)[:-1] # removes last '"' which is always at the end - sentences = text_str.split('sentence1: ')[1].split(' sentence2: ') - if len(sentences) != 2: - # 'sentence2:' appeared in the text and got mixed up with the label - continue - sent1, sent2 = sentences - yield sent1, sent2, target == 'next' - - return split_stream - - -@gin.configurable(module='trax.data') -def CorpusToRandomChunks(dataset_name, num_tokens=512, train=True): # pylint: disable=invalid-name - return TFDS( - dataset_name, - tfds_preprocess_fn=functools.partial( - t5_data().preprocessors.random_split_text, - max_words_per_segment=num_tokens), - train=train, - keys=['text']) + data_dir = os.path.expanduser(data_dir) + return data_dir + + +def BertSingleSentenceInputs( + batch, labeled=True, cls_id=101, sep_id=102 # pylint: disable=invalid-name +): + """Prepares inputs for BERT: add [SEP], [CLS] and create embeddings.""" + if labeled: + for sent1, label in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id])) + segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32) + yield value_vector, segment_embs, segment_embs, label, np.int32(1) + else: + for (sent1,) in batch: # row is a tuple with 1 element + value_vector = np.concatenate(([cls_id], sent1, [sep_id])) + segment_embs = 
np.zeros(sent1.shape[0] + 2, dtype=np.int32) + yield value_vector, segment_embs, segment_embs + + +def BertDoubleSentenceInputs( + batch, labeled=True, cls_id=101, sep_id=102 # pylint: disable=invalid-name +): + """Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings.""" + if labeled: + for sent1, sent2, label in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id], sent2, [sep_id])) + + segment_embs = np.zeros(sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) + second_sent_start = sent1.shape[0] + 2 + segment_embs[second_sent_start:] = 1 + yield value_vector, segment_embs, segment_embs, label, np.int32(1) + else: + for sent1, sent2 in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id], sent2, [sep_id])) + + segment_embs = np.zeros(sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) + second_sent_start = sent1.shape[0] + 2 + segment_embs[second_sent_start:] = 1 + yield value_vector, segment_embs, segment_embs + + +@gin.configurable(module="trax.data") +def CreateBertInputs( + double_sentence=True, # pylint: disable=invalid-name + labeled=True, + cls_id=101, + sep_id=102, +): + bert_inputs_fn = ( + BertDoubleSentenceInputs if double_sentence else BertSingleSentenceInputs + ) + return functools.partial( + bert_inputs_fn, labeled=labeled, cls_id=cls_id, sep_id=sep_id + ) + + +@gin.configurable(module="trax.data") +def mask_random_tokens( + batch, + explicit_vocab_size=30522, + masking_prob=0.15, + cls_id=101, + sep_id=102, + mask_id=103, + vocab_start_id=999, +): + """Prepares input for the masking task. + + Preparation consist in masking masking_prob percentage of non-special tokens + at each input row; round(masking_prob * num_nonspecial_tokens) random tokens + are selected out of which each token is either + - replaced with [MASK] token with 80% probability, + - replaced with random token with 10% probability, + - or unchanged with 10%. 
+ The implentation is based on + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342 + + Examples: + - batch is a stream with each row having tuple (token_ids,). Function yields + rows of form (modified_token_ids, original_tokens, token_weights), where + modified_token_ids have [MASK] tokens or random tokens according to the + procedure described above. + - batch is a stream with each row having tuple (token_ids, segment_embeddings, + nsp_label, nsp_weight).Function yields rows of form (modified_token_ids, + segment_embeddings, nsp_label, nsp_weight, original_tokens, token_weights). + + Args: + batch: stream of inputs. Each row in the stream is a tuple which first + element is an array of tokens + explicit_vocab_size: the total size of the vocabulary. + masking_prob: Determines percent of non-special tokens to be selected for + masking. + cls_id: id of the special CLS token. + sep_id: id of the special SEP token. + mask_id: id of the special MASK token. + vocab_start_id: id of first non-special token in the vocabulary. + + Yields: + a stream with tokens masked for MLM training and 2 appended arrays: + - original tokens: a copy of original tokens used as a label for mlm + training + - token_weights: weights distributed uniformly over selected tokens (sum + is 1). Other tokens have 0 weight. + """ + for token_ids, *row_rest in batch: + original_tokens = token_ids.copy() + + # choose tokens for prediction. 
Chooses 0.15 of + # all non-special tokens + is_special_token = np.logical_or( + token_ids == cls_id, token_ids == sep_id + ) # CLS and SEP tokens + is_special_token = np.logical_or(is_special_token, token_ids == 0) # padding + viable_ids = np.arange(token_ids.shape[0])[~is_special_token] + num_to_sample = round(masking_prob * viable_ids.shape[0]) + if num_to_sample == 0: + # sentence is too short to select given percentage of tokens to mask + continue + candidate_ids = np.random.choice(viable_ids, num_to_sample, replace=False) + + # create weights + token_weights = np.zeros(token_ids.shape) + token_weights[candidate_ids] = 1 / candidate_ids.shape[0] + + prob_scores = np.random.random(candidate_ids.shape) + + # change 80 % of tokens to [MASK] + mask_token_ids = candidate_ids[prob_scores < 0.8] + token_ids[mask_token_ids] = mask_id + + # change 10% of tokens to random token + random_token_ids = candidate_ids[(0.8 <= prob_scores) & (prob_scores < 0.9)] + token_ids[random_token_ids] = np.random.randint( + vocab_start_id, explicit_vocab_size, random_token_ids.shape[0] + ) + + # rest (10%) is left unchaged + yield (token_ids, *row_rest, original_tokens, token_weights) + + +@gin.configurable(module="trax.data") +def BertNextSentencePredictionInputs( + dataset_name, # pylint: disable=invalid-name + data_dir=None, + text_key="text", + train=True, + shuffle_size=50000, +): + """Defines a stream for the next sentence prediction task.""" + stream = TFDS( + dataset_name, + data_dir=data_dir, + tfds_preprocess_fn=functools.partial( + t5_data().preprocessors.next_sentence_prediction, + text_key=text_key, + label_sentences=True, + buffer_size=shuffle_size, + ), + keys=["inputs", "targets"], + train=train, + ) + + def split_stream(generator=None): + # split string with 'sentence1:' and 'sentence2:' into two separate strings + for text, target in stream(generator): + text_str = str(text)[:-1] # removes last '"' which is always at the end + sentences = text_str.split("sentence1: 
")[1].split(" sentence2: ") + if len(sentences) != 2: + # 'sentence2:' appeared in the text and got mixed up with the label + continue + sent1, sent2 = sentences + yield sent1, sent2, target == "next" + + return split_stream + + +@gin.configurable(module="trax.data") +def CorpusToRandomChunks( + dataset_name, num_tokens=512, train=True +): # pylint: disable=invalid-name + return TFDS( + dataset_name, + tfds_preprocess_fn=functools.partial( + t5_data().preprocessors.random_split_text, max_words_per_segment=num_tokens + ), + train=train, + keys=["text"], + ) _GLUE_KEYS = { - 'cola': ('sentence',), - 'sst2': ('sentence',), - 'mrpc': ('sentence1', 'sentence2'), - 'qqp': ('question1', 'question2'), - 'stsb': ('sentence1', 'sentence2'), - 'mnli': ('premise', 'hypothesis'), - 'qnli': ('question', 'sentence'), - 'rte': ('sentence1', 'sentence2'), - 'wnli': ('sentence1', 'sentence2'), + "cola": ("sentence",), + "sst2": ("sentence",), + "mrpc": ("sentence1", "sentence2"), + "qqp": ("question1", "question2"), + "stsb": ("sentence1", "sentence2"), + "mnli": ("premise", "hypothesis"), + "qnli": ("question", "sentence"), + "rte": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), } # Labels inferred from the T5 paper: https://arxiv.org/pdf/1910.10683.pdf _GLUE_LABELS = { - 'cola': ('unacceptable', 'acceptable'), - 'sst2': ('negative', 'positive'), - 'mrpc': ('not_equivalent', 'equivalent'), - 'qqp': ('not_duplicate', 'duplicate'), - 'stsb': ('sentence1', 'sentence2'), - 'mnli': ('entailment', 'neutral', 'contradiction'), - 'qnli': ('entailment', 'not_entailment'), - 'rte': ('entailment', 'not_entailment'), - 'wnli': ('sentence1', 'sentence2'), + "cola": ("unacceptable", "acceptable"), + "sst2": ("negative", "positive"), + "mrpc": ("not_equivalent", "equivalent"), + "qqp": ("not_duplicate", "duplicate"), + "stsb": ("sentence1", "sentence2"), + "mnli": ("entailment", "neutral", "contradiction"), + "qnli": ("entailment", "not_entailment"), + "rte": ("entailment", 
"not_entailment"), + "wnli": ("sentence1", "sentence2"), } # Defining separate TrainStream and EvalStream functions (below) @@ -1462,15 +1535,15 @@ def CorpusToRandomChunks(dataset_name, num_tokens=512, train=True): # pylint: d # pylint: disable=invalid-name -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def BertGlueTrainStream(benchmark=gin.REQUIRED): - """Returns a Bert-preprocessed training stream for ``benchmark``. + """Returns a Bert-preprocessed training stream for ``benchmark``. - Args: - benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, - ``'mnli'``, ``'rte'``. - """ - return _BertGlueDataStream(benchmark + '_t') + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. + """ + return _BertGlueDataStream(benchmark + "_t") # GLUE evals need special handling because one eval in particular, MNLI, has @@ -1478,797 +1551,843 @@ def BertGlueTrainStream(benchmark=gin.REQUIRED): # distinguishes between the two using the suffixes '_e' versus '_e2', # respectively. def _ensure_eval_suffix(benchmark): - """Returns a string ending in an eval suffix; adds ``'_e'`` suffix if needed. + """Returns a string ending in an eval suffix; adds ``'_e'`` suffix if needed. - Args: - benchmark: Name of a benchmark or task, that might already include an - eval-indicating suffix (``'_e'`` or ``'_e2'``). - """ - if benchmark.endswith('_e') or benchmark.endswith('_e2'): - return benchmark - else: - return benchmark + '_e' + Args: + benchmark: Name of a benchmark or task, that might already include an + eval-indicating suffix (``'_e'`` or ``'_e2'``). + """ + if benchmark.endswith("_e") or benchmark.endswith("_e2"): + return benchmark + else: + return benchmark + "_e" -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def BertGlueEvalStream(benchmark=gin.REQUIRED): - """Returns a Bert-preprocessed eval data stream for ``benchmark``. 
+ """Returns a Bert-preprocessed eval data stream for ``benchmark``. - Args: - benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, - ``'mnli'``, ``'rte'``. If the benchmark includes an alternate - eval (e.g., MNLI's "mismatched" eval/validation split), you can - specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. - """ - return _BertGlueDataStream(_ensure_eval_suffix(benchmark)) + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. If the benchmark includes an alternate + eval (e.g., MNLI's "mismatched" eval/validation split), you can + specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. + """ + return _BertGlueDataStream(_ensure_eval_suffix(benchmark)) def _BertGlueDataStream(benchmark_id): - """Returns a Bert-preprocessed data stream for ``benchmark_id``. - - Args: - benchmark_id: String that indicates the name and data split of a GLUE - benchmark. Data splits are indicated as underscore suffixes, e.g., - ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE - benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, - alternate "mismatched" eval/validation split). - """ - benchmark_id = _ensure_eval_suffix(benchmark_id) - benchmark, split = benchmark_id.rsplit('_', 1) - glue_data = TFDS(f'glue/{benchmark}', - keys=_GLUE_KEYS[benchmark], - train=(split == 't'), - use_alt_eval=(split == 'e2')) - return data.Serial( - glue_data, - data.Tokenize(), - data.CreateBertInputs(), - data.Shuffle(), - data.PadToLength(), - data.TruncateToLength(), - data.Batch(), - ) - - -@gin.configurable(module='trax.data') + """Returns a Bert-preprocessed data stream for ``benchmark_id``. + + Args: + benchmark_id: String that indicates the name and data split of a GLUE + benchmark. 
Data splits are indicated as underscore suffixes, e.g., + ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE + benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, + alternate "mismatched" eval/validation split). + """ + benchmark_id = _ensure_eval_suffix(benchmark_id) + benchmark, split = benchmark_id.rsplit("_", 1) + glue_data = TFDS( + f"glue/{benchmark}", + keys=_GLUE_KEYS[benchmark], + train=(split == "t"), + use_alt_eval=(split == "e2"), + ) + return data.Serial( + glue_data, + data.Tokenize(), + data.CreateBertInputs(), + data.Shuffle(), + data.PadToLength(), + data.TruncateToLength(), + data.Batch(), + ) + + +@gin.configurable(module="trax.data") def T5GlueTrainStream(benchmark=gin.REQUIRED): - """Returns a T5-preprocessed training data stream for ``benchmark``. - - Args: - benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, - ``'mnli'``, ``'rte'``. - """ - return _T5GlueDataStream(benchmark + '_t') - - -@gin.configurable(module='trax.data') -def T5GlueTrainStreamsParallel(benchmark_list=gin.REQUIRED, - counters=None, - reweight_by_minimum=False, - gradually_reweight=False): - """Returns a parallel set of training streams, based on ``benchmark_list``. - - Args: - benchmark_list: List of simple lower-case names of GLUE benchmarks, e.g., - ``'cola'``, ``'mnli'``, ``'rte'``. - counters: a list of counters to be passed to data.Parallel, e.g., - [8551, 392702, 2490] would be a reasonable counterpart to - benchmark_list = ["cola", "mnli", "rte"], see - https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 - for more details on counters. - reweight_by_minimum: divide by the minimal counter. - gradually_reweight: a more refined reweighting policy, see inputs.py - for more details. 
- """ - stream_list = list(map(T5GlueTrainStream, benchmark_list)) - return data.Parallel( - stream_list, - counters=counters, - reweight_by_minimum=reweight_by_minimum, - gradually_reweight=gradually_reweight)() - - -@gin.configurable(module='trax.data') + """Returns a T5-preprocessed training data stream for ``benchmark``. + + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. + """ + return _T5GlueDataStream(benchmark + "_t") + + +@gin.configurable(module="trax.data") +def T5GlueTrainStreamsParallel( + benchmark_list=gin.REQUIRED, + counters=None, + reweight_by_minimum=False, + gradually_reweight=False, +): + """Returns a parallel set of training streams, based on ``benchmark_list``. + + Args: + benchmark_list: List of simple lower-case names of GLUE benchmarks, e.g., + ``'cola'``, ``'mnli'``, ``'rte'``. + counters: a list of counters to be passed to data.Parallel, e.g., + [8551, 392702, 2490] would be a reasonable counterpart to + benchmark_list = ["cola", "mnli", "rte"], see + https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 + for more details on counters. + reweight_by_minimum: divide by the minimal counter. + gradually_reweight: a more refined reweighting policy, see inputs.py + for more details. + """ + stream_list = list(map(T5GlueTrainStream, benchmark_list)) + return data.Parallel( + stream_list, + counters=counters, + reweight_by_minimum=reweight_by_minimum, + gradually_reweight=gradually_reweight, + )() + + +@gin.configurable(module="trax.data") def T5GlueEvalStream(benchmark=gin.REQUIRED): - """Returns a T5-preprocessed eval data stream for ``benchmark``. + """Returns a T5-preprocessed eval data stream for ``benchmark``. - Args: - benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, - ``'mnli'``, ``'rte'``. 
If the benchmark includes an alternate - eval (e.g., MNLI's "mismatched" eval/validation split), you can - specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. - """ - return _T5GlueDataStream(_ensure_eval_suffix(benchmark)) + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. If the benchmark includes an alternate + eval (e.g., MNLI's "mismatched" eval/validation split), you can + specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. + """ + return _T5GlueDataStream(_ensure_eval_suffix(benchmark)) -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def T5GlueEvalStreamsParallel(benchmark_list=gin.REQUIRED): - """Returns a parallel set of T5 eval streams, based on ``benchmark_list``. + """Returns a parallel set of T5 eval streams, based on ``benchmark_list``. - Args: - benchmark_list: List of strings, each of which is a simple lower-case name - of a GLUE benchmark, e.g., ``'cola'``, ``'mnli'``, ``'rte'``. If a - benchmark includes an alternate eval (e.g., MNLI's "mismatched" - eval/validation split), you can specify it with an ``'_e2'`` suffix, - e.g., ``'mnli_e2'``. - """ - stream_list = list(map(T5GlueEvalStream, benchmark_list)) - return data.Parallel(stream_list)() + Args: + benchmark_list: List of strings, each of which is a simple lower-case name + of a GLUE benchmark, e.g., ``'cola'``, ``'mnli'``, ``'rte'``. If a + benchmark includes an alternate eval (e.g., MNLI's "mismatched" + eval/validation split), you can specify it with an ``'_e2'`` suffix, + e.g., ``'mnli_e2'``. + """ + stream_list = list(map(T5GlueEvalStream, benchmark_list)) + return data.Parallel(stream_list)() def _T5GlueDataStream(benchmark_id, t5_tokenization=False): - """Returns a T5-preprocessed data stream for ``benchmark_id``. - - Args: - benchmark_id: String that indicates the name and data split of a GLUE - benchmark. 
Data splits are indicated as underscore suffixes, e.g., - ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE - benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, - alternate "mismatched" eval/validation split). - t5_tokenization: if true, then use t5_tokenization. - """ - return data.Serial( - _t5_glue_data_split(benchmark_id) - if t5_tokenization else _t5_glue_data_split_no_token(benchmark_id), - data.Tokenize(), - data.Shuffle(), - data.PadToLength(), - data.TruncateToLength(), - data.Batch(), - ) - - -@gin.configurable(module='trax.data') + """Returns a T5-preprocessed data stream for ``benchmark_id``. + + Args: + benchmark_id: String that indicates the name and data split of a GLUE + benchmark. Data splits are indicated as underscore suffixes, e.g., + ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE + benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, + alternate "mismatched" eval/validation split). + t5_tokenization: if true, then use t5_tokenization. + """ + return data.Serial( + _t5_glue_data_split(benchmark_id) + if t5_tokenization + else _t5_glue_data_split_no_token(benchmark_id), + data.Tokenize(), + data.Shuffle(), + data.PadToLength(), + data.TruncateToLength(), + data.Batch(), + ) + + +@gin.configurable(module="trax.data") def T5GlueEvalTasks(benchmark_list=gin.REQUIRED): - """Returns a list of T5 GLUE eval tasks, based on ``benchmark_list``. + """Returns a list of T5 GLUE eval tasks, based on ``benchmark_list``. - Args: - benchmark_list: List of strings, each of which indicates the name and - data split of a GLUE benchmark. Data splits are indicated as underscore - suffixes, e.g., ``'cola_t'`` (Cola benchmark, training split), - ``'rte_e'`` (RTE benchmark, eval/validation split), and ``'mnli_e2'`` - (MNLI alternate "mismatched" eval/validation split). 
- """ - task_list = list(map(_T5GlueEvalTask, benchmark_list)) - return task_list + Args: + benchmark_list: List of strings, each of which indicates the name and + data split of a GLUE benchmark. Data splits are indicated as underscore + suffixes, e.g., ``'cola_t'`` (Cola benchmark, training split), + ``'rte_e'`` (RTE benchmark, eval/validation split), and ``'mnli_e2'`` + (MNLI alternate "mismatched" eval/validation split). + """ + task_list = list(map(_T5GlueEvalTask, benchmark_list)) + return task_list def _T5GlueEvalTask(benchmark_id): - """Returns a T5 GLUE eval task, based on ``benchmark_id``.""" - eval_data = T5GlueEvalStream(benchmark_id) - benchmark_id = _ensure_eval_suffix(benchmark_id) - metrics = [tl.WeightedCategoryAccuracy(), tl.SequenceAccuracy()] - benchmark, split = benchmark_id.rsplit('_', 1) - if benchmark == 'cola': - name_upper = 'Cola' - elif benchmark == 'mnli': - name_upper = 'MNLI_matched' if split == 'e' else 'MNLI_mismatched' - else: - name_upper = benchmark.upper() - return supervised.training.EvalTask( - eval_data(), - metrics, - metric_names=[f'{name_upper} accuracy', - f'{name_upper} sequence accuracy']) + """Returns a T5 GLUE eval task, based on ``benchmark_id``.""" + eval_data = T5GlueEvalStream(benchmark_id) + benchmark_id = _ensure_eval_suffix(benchmark_id) + metrics = [tl.WeightedCategoryAccuracy(), tl.SequenceAccuracy()] + benchmark, split = benchmark_id.rsplit("_", 1) + if benchmark == "cola": + name_upper = "Cola" + elif benchmark == "mnli": + name_upper = "MNLI_matched" if split == "e" else "MNLI_mismatched" + else: + name_upper = benchmark.upper() + return supervised.training.EvalTask( + eval_data(), + metrics, + metric_names=[f"{name_upper} accuracy", f"{name_upper} sequence accuracy"], + ) def _t5_glue_data_split_no_token(benchmark_id): - """Returns a GLUE data split prepared with the standard T5 preprocessor.""" - benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) - dataset = tfds.load(name=f'glue/{benchmark}', 
split=split) - processed_dataset = t5_data().preprocessors.glue( # pylint: disable=g-long-lambda - dataset, - benchmark_name=benchmark, - label_names=_GLUE_LABELS[benchmark]) - - def stream_of_inputs_targets_weights(generator=None): - del generator - while True: - for example in processed_dataset: - input_values = example['inputs'].numpy() - target_values = example['targets'].numpy() - yield (input_values, - target_values, - jnp.array([1] * len(target_values))) - - return stream_of_inputs_targets_weights + """Returns a GLUE data split prepared with the standard T5 preprocessor.""" + benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) + dataset = tfds.load(name=f"glue/{benchmark}", split=split) + processed_dataset = t5_data().preprocessors.glue( # pylint: disable=g-long-lambda + dataset, benchmark_name=benchmark, label_names=_GLUE_LABELS[benchmark] + ) + + def stream_of_inputs_targets_weights(generator=None): + del generator + while True: + for example in processed_dataset: + input_values = example["inputs"].numpy() + target_values = example["targets"].numpy() + yield (input_values, target_values, jnp.array([1] * len(target_values))) + + return stream_of_inputs_targets_weights def _t5_glue_data_split(benchmark_id): - """Returns a GLUE data split prepared with the standard T5 preprocessor.""" - benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) - dataset = tfds.load(name=f'glue/{benchmark}', split=split) - processed_dataset = generic_text_dataset_preprocess_fn( - dataset, - spm_path=t5_data().DEFAULT_SPM_PATH, - text_preprocess_fns=[ - lambda ds, training: t5_data().preprocessors.glue( # pylint: disable=g-long-lambda - ds, - benchmark_name=benchmark, - label_names=_GLUE_LABELS[benchmark]) - ], - copy_pretokenized=True, - debug_print_examples=True, - debug_print_examples_rate=0.05) - dataset_as_numpy = tfds.as_numpy(processed_dataset) - - def stream_of_inputs_targets_weights(generator=None): - del generator - while True: - for example in 
dataset_as_numpy: - input_values = example['inputs'] - target_values = example['targets'] - yield (jnp.array(input_values), - jnp.array(target_values), - jnp.array([1] * len(target_values))) - - return stream_of_inputs_targets_weights + """Returns a GLUE data split prepared with the standard T5 preprocessor.""" + benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) + dataset = tfds.load(name=f"glue/{benchmark}", split=split) + processed_dataset = generic_text_dataset_preprocess_fn( + dataset, + spm_path=t5_data().DEFAULT_SPM_PATH, + text_preprocess_fns=[ + lambda ds, training: t5_data().preprocessors.glue( # pylint: disable=g-long-lambda + ds, benchmark_name=benchmark, label_names=_GLUE_LABELS[benchmark] + ) + ], + copy_pretokenized=True, + debug_print_examples=True, + debug_print_examples_rate=0.05, + ) + dataset_as_numpy = tfds.as_numpy(processed_dataset) + + def stream_of_inputs_targets_weights(generator=None): + del generator + while True: + for example in dataset_as_numpy: + input_values = example["inputs"] + target_values = example["targets"] + yield ( + jnp.array(input_values), + jnp.array(target_values), + jnp.array([1] * len(target_values)), + ) + + return stream_of_inputs_targets_weights def _t5_glue_benchmark_and_split(benchmark_id): - benchmark, mode = benchmark_id.rsplit('_', 1) - if mode == 't': - split = 'train' - elif benchmark == 'mnli': - split = 'validation_mismatched' if mode == 'e2' else 'validation_matched' - else: - split = 'validation' - return benchmark, split + benchmark, mode = benchmark_id.rsplit("_", 1) + if mode == "t": + split = "train" + elif benchmark == "mnli": + split = "validation_mismatched" if mode == "e2" else "validation_matched" + else: + split = "validation" + return benchmark, split + + # pylint: enable=invalid-name def compute_single_result(op_name, num_args): - """An implementation of the most popular ops from the MathQA dataset.""" - # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ - # and 
specfically line 142 and following in new_DataStructure.py - # for an implementation which covers more details. - if op_name == 'add': - return num_args[0] + num_args[1] - elif op_name == 'circle_arc': - return num_args[0] / 360 * math.pi * 2 * num_args[1] - elif op_name == 'circle_area': - return math.pi * num_args[0]**2 - elif op_name == 'circle_sector_area': - return num_args[1] / 360 * math.pi * (num_args[0]**2) - elif op_name == 'circumface': - return 2 * math.pi * num_args[0] - elif op_name == 'choose': - return scipy.special.comb(num_args[0], num_args[1]) - elif op_name == 'cosine': - return math.cos(num_args[0]) - elif op_name == 'cube_edge_by_volume': - return num_args[0]**(1 / 3) - elif op_name == 'combined_work': - return 1 / ( - min(num_args[0], 1 / num_args[0]) + min(num_args[1], 1 / num_args[1])) - elif op_name == 'count_interval': - return num_args[0] - num_args[1] + 1 - elif op_name == 'diagonal': - return math.sqrt(num_args[0]**2 + num_args[1]**2) - elif op_name == 'divide' or op_name == 'speed': - if num_args[1] != 0: - return num_args[0] / num_args[1] - else: - return 0 - elif op_name == 'factorial': - return math.factorial(min(15, int(num_args[0]))) - elif op_name == 'floor': - return math.floor(num_args[0]) - elif op_name == 'find_work': - return 1 / ( - max( - min(num_args[0], 1 / num_args[0]), min( - num_args[1], 1 / num_args[1])) - min( - min(num_args[0], 1 / num_args[0]), - min(num_args[1], 1 / num_args[1]))) - elif op_name == 'from_percent': - return num_args[0] / 100 - elif op_name == 'gain_percent': - return 100 + num_args[0] - elif op_name == 'gcd': - return scipy.gcd(int(num_args[0]), int(num_args[1])) - elif op_name == 'inverse': - if num_args[0] != 0: - return 1 / num_args[0] - else: - return 0 - elif op_name == 'lcm': - return scipy.lcm(int(num_args[0]), int(num_args[1])) - elif op_name == 'log': - return math.log(max(1e-5, num_args[0]), 2) - elif op_name == 'loss_percent': - return 100 - num_args[0] - elif op_name == 'max': - 
return max(num_args[0], num_args[1]) - elif op_name == 'multiply': - return num_args[0] * num_args[1] - elif op_name == 'negate_percent': - return 100 - num_args[0] - elif op_name == 'negate': - return -num_args[0] - elif op_name == 'original_price_before_loss': - return num_args[1] * 100 / (100 + 1e-5 - num_args[0]) - elif op_name == 'original_price_before_gain': - return num_args[1] * 100 / (100 + num_args[0]) - elif op_name == 'permutation': - n, m = min(num_args[0], num_args[1]), max(num_args[0], num_args[1]) - return math.factorial(int(m)) / math.factorial(int(m - n)) - elif op_name == 'power': - return num_args[0]**min(num_args[1], 5) - elif op_name == 'percent': - return num_args[0] / 100 * num_args[1] - elif op_name == 'price_after_gain' or op_name == 'p_after_gain': - return (1 + num_args[0] / 100) * num_args[1] - elif op_name == 'price_after_loss' or op_name == 'price_after_loss': - return (1 - num_args[0] / 100) * num_args[1] - elif op_name == 'quadrilateral_area': - return num_args[0] * (num_args[1] + num_args[2]) / 2 - elif op_name == 'reminder': - return num_args[0] % num_args[1] - elif op_name == 'rectangle_area': - return num_args[0] * num_args[1] - elif op_name == 'rectangle_perimeter': - return 2 * (num_args[0] + num_args[1]) - elif op_name == 'rhombus_area': - return num_args[0] * num_args[1] / 2 - elif op_name == 'sine': - return math.sin(num_args[0]) - elif op_name == 'sqrt': - return math.sqrt(max(0, num_args[0])) - elif op_name == 'subtract': - return num_args[0] - num_args[1] - elif op_name == 'square_edge_by_perimeter': - return num_args[0] / 4 - elif op_name == 'square_edge_by_area': - return math.sqrt(num_args[0]) - elif op_name == 'square_area': - return num_args[0]**2 - elif op_name == 'surface_cube': - return 6 * num_args[0]**2 - elif op_name == 'surface_rectangular_prism': - return 2 * ( - num_args[0] * num_args[1] + num_args[0] * num_args[2] + - num_args[1] * num_args[2]) - elif op_name == 'semi_circle_perimiter': - return math.pi * 
num_args[0] + 2 * num_args[0] - elif op_name == 'square_perimeter' or op_name == 'rhombus_perimeter': - return 4 * num_args[0] - elif op_name == 'surface_sphere': - return 4 * math.pi * num_args[0]**2 - elif op_name == 'speed_ratio_steel_to_stream': - return (num_args[0] + num_args[1]) / (num_args[0] - num_args[1]) - elif op_name == 'speed_in_still_water': - return (num_args[0] + num_args[1]) / 2 - elif op_name == 'stream_speed': - return (num_args[0] - num_args[1]) / 2 - elif op_name == 'trapezium_area': - return num_args[0] * (num_args[1] + num_args[2]) / 2 - elif op_name == 'triangle_area': - return num_args[0] * num_args[1] / 2 - elif op_name == 'triangle_perimeter': - return num_args[0] + num_args[1] + num_args[2] - elif op_name == 'triangle_area_three_edges': - # Heron's formula - s = (num_args[0] + num_args[1] + num_args[2]) / 2 - return math.sqrt( - max(0, - s * (s - num_args[0]) * (s - num_args[1]) * (s - num_args[2]))) - elif op_name == 'union_prob': - return num_args[0] + num_args[1] - num_args[2] - elif op_name == 'negate_prob': - return 1 - num_args[0] - elif op_name == 'volume_cube': - return num_args[0]**3 - elif op_name == 'volume_cone': - return math.pi * num_args[0]**2 * num_args[1] / 3 - elif op_name == 'volume_cylinder': - return math.pi * num_args[0]**2 * num_args[1] - elif op_name == 'volume_rectangular_prism': - return num_args[0] * num_args[1] * num_args[2] - elif op_name == 'volume_sphere': - return 4 / 3 * math.pi * num_args[0]**3 + """An implementation of the most popular ops from the MathQA dataset.""" + # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ + # and specfically line 142 and following in new_DataStructure.py + # for an implementation which covers more details. 
+ if op_name == "add": + return num_args[0] + num_args[1] + elif op_name == "circle_arc": + return num_args[0] / 360 * math.pi * 2 * num_args[1] + elif op_name == "circle_area": + return math.pi * num_args[0] ** 2 + elif op_name == "circle_sector_area": + return num_args[1] / 360 * math.pi * (num_args[0] ** 2) + elif op_name == "circumface": + return 2 * math.pi * num_args[0] + elif op_name == "choose": + return scipy.special.comb(num_args[0], num_args[1]) + elif op_name == "cosine": + return math.cos(num_args[0]) + elif op_name == "cube_edge_by_volume": + return num_args[0] ** (1 / 3) + elif op_name == "combined_work": + return 1 / ( + min(num_args[0], 1 / num_args[0]) + min(num_args[1], 1 / num_args[1]) + ) + elif op_name == "count_interval": + return num_args[0] - num_args[1] + 1 + elif op_name == "diagonal": + return math.sqrt(num_args[0] ** 2 + num_args[1] ** 2) + elif op_name == "divide" or op_name == "speed": + if num_args[1] != 0: + return num_args[0] / num_args[1] + else: + return 0 + elif op_name == "factorial": + return math.factorial(min(15, int(num_args[0]))) + elif op_name == "floor": + return math.floor(num_args[0]) + elif op_name == "find_work": + return 1 / ( + max(min(num_args[0], 1 / num_args[0]), min(num_args[1], 1 / num_args[1])) + - min(min(num_args[0], 1 / num_args[0]), min(num_args[1], 1 / num_args[1])) + ) + elif op_name == "from_percent": + return num_args[0] / 100 + elif op_name == "gain_percent": + return 100 + num_args[0] + elif op_name == "gcd": + return scipy.gcd(int(num_args[0]), int(num_args[1])) + elif op_name == "inverse": + if num_args[0] != 0: + return 1 / num_args[0] + else: + return 0 + elif op_name == "lcm": + return scipy.lcm(int(num_args[0]), int(num_args[1])) + elif op_name == "log": + return math.log(max(1e-5, num_args[0]), 2) + elif op_name == "loss_percent": + return 100 - num_args[0] + elif op_name == "max": + return max(num_args[0], num_args[1]) + elif op_name == "multiply": + return num_args[0] * num_args[1] + elif 
op_name == "negate_percent": + return 100 - num_args[0] + elif op_name == "negate": + return -num_args[0] + elif op_name == "original_price_before_loss": + return num_args[1] * 100 / (100 + 1e-5 - num_args[0]) + elif op_name == "original_price_before_gain": + return num_args[1] * 100 / (100 + num_args[0]) + elif op_name == "permutation": + n, m = min(num_args[0], num_args[1]), max(num_args[0], num_args[1]) + return math.factorial(int(m)) / math.factorial(int(m - n)) + elif op_name == "power": + return num_args[0] ** min(num_args[1], 5) + elif op_name == "percent": + return num_args[0] / 100 * num_args[1] + elif op_name == "price_after_gain" or op_name == "p_after_gain": + return (1 + num_args[0] / 100) * num_args[1] + elif op_name == "price_after_loss" or op_name == "price_after_loss": + return (1 - num_args[0] / 100) * num_args[1] + elif op_name == "quadrilateral_area": + return num_args[0] * (num_args[1] + num_args[2]) / 2 + elif op_name == "reminder": + return num_args[0] % num_args[1] + elif op_name == "rectangle_area": + return num_args[0] * num_args[1] + elif op_name == "rectangle_perimeter": + return 2 * (num_args[0] + num_args[1]) + elif op_name == "rhombus_area": + return num_args[0] * num_args[1] / 2 + elif op_name == "sine": + return math.sin(num_args[0]) + elif op_name == "sqrt": + return math.sqrt(max(0, num_args[0])) + elif op_name == "subtract": + return num_args[0] - num_args[1] + elif op_name == "square_edge_by_perimeter": + return num_args[0] / 4 + elif op_name == "square_edge_by_area": + return math.sqrt(num_args[0]) + elif op_name == "square_area": + return num_args[0] ** 2 + elif op_name == "surface_cube": + return 6 * num_args[0] ** 2 + elif op_name == "surface_rectangular_prism": + return 2 * ( + num_args[0] * num_args[1] + + num_args[0] * num_args[2] + + num_args[1] * num_args[2] + ) + elif op_name == "semi_circle_perimiter": + return math.pi * num_args[0] + 2 * num_args[0] + elif op_name == "square_perimeter" or op_name == 
"rhombus_perimeter": + return 4 * num_args[0] + elif op_name == "surface_sphere": + return 4 * math.pi * num_args[0] ** 2 + elif op_name == "speed_ratio_steel_to_stream": + return (num_args[0] + num_args[1]) / (num_args[0] - num_args[1]) + elif op_name == "speed_in_still_water": + return (num_args[0] + num_args[1]) / 2 + elif op_name == "stream_speed": + return (num_args[0] - num_args[1]) / 2 + elif op_name == "trapezium_area": + return num_args[0] * (num_args[1] + num_args[2]) / 2 + elif op_name == "triangle_area": + return num_args[0] * num_args[1] / 2 + elif op_name == "triangle_perimeter": + return num_args[0] + num_args[1] + num_args[2] + elif op_name == "triangle_area_three_edges": + # Heron's formula + s = (num_args[0] + num_args[1] + num_args[2]) / 2 + return math.sqrt( + max(0, s * (s - num_args[0]) * (s - num_args[1]) * (s - num_args[2])) + ) + elif op_name == "union_prob": + return num_args[0] + num_args[1] - num_args[2] + elif op_name == "negate_prob": + return 1 - num_args[0] + elif op_name == "volume_cube": + return num_args[0] ** 3 + elif op_name == "volume_cone": + return math.pi * num_args[0] ** 2 * num_args[1] / 3 + elif op_name == "volume_cylinder": + return math.pi * num_args[0] ** 2 * num_args[1] + elif op_name == "volume_rectangular_prism": + return num_args[0] * num_args[1] * num_args[2] + elif op_name == "volume_sphere": + return 4 / 3 * math.pi * num_args[0] ** 3 def compute_result(list_op, list_num): - """Python execution of MathQA ops.""" - # The last of temporary results is the final answer. - temporary_results = [] - for op in list_op: - op_name = op.split('(')[0] - start_bracket = op.find('(') - end_bracket = op.find(')') - op_args = op[start_bracket + 1:end_bracket].split(',') - num_args = [] - for arg in op_args: - # The hash stands for a number stored in temporary_results. - # For example #2 refers to the third temporary result. 
- if arg[0] == '#': - temp_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append(temporary_results[temp_index]) - # The n prefix stands for numbers which listed in list_num - - # originally they were contained in the text. - elif arg[0] == 'n': - n_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append(list_num[n_index]) - elif arg[0] == 'c': - if arg == 'const_pi': - constant = math.pi - elif arg == 'const_deg_to_rad': - constant = math.pi / 180 - else: - consts = re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', arg) - if len(consts) == 1: - constant = float(consts[0]) - else: - constant1 = float(consts[0]) - constant2 = float('0.' + consts[1]) - constant = constant1 + constant2 - num_args.append(constant) - temporary_results.append(compute_single_result(op_name, num_args)) - return temporary_results + """Python execution of MathQA ops.""" + # The last of temporary results is the final answer. + temporary_results = [] + for op in list_op: + op_name = op.split("(")[0] + start_bracket = op.find("(") + end_bracket = op.find(")") + op_args = op[start_bracket + 1 : end_bracket].split(",") + num_args = [] + for arg in op_args: + # The hash stands for a number stored in temporary_results. + # For example #2 refers to the third temporary result. + if arg[0] == "#": + temp_index = int( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + )[0] + ) + num_args.append(temporary_results[temp_index]) + # The n prefix stands for numbers which listed in list_num - + # originally they were contained in the text. 
+ elif arg[0] == "n": + n_index = int( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + )[0] + ) + num_args.append(list_num[n_index]) + elif arg[0] == "c": + if arg == "const_pi": + constant = math.pi + elif arg == "const_deg_to_rad": + constant = math.pi / 180 + else: + consts = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + ) + if len(consts) == 1: + constant = float(consts[0]) + else: + constant1 = float(consts[0]) + constant2 = float("0." + consts[1]) + constant = constant1 + constant2 + num_args.append(constant) + temporary_results.append(compute_single_result(op_name, num_args)) + return temporary_results def single_op_to_python_command(op_name, num_args): - """An implementation of the most popular ops from the MathQA dataset.""" - # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ - # and specfically line 142 and following in new_DataStructure.py - # for an implementation which covers more details. - if op_name == 'add': - return '{} + {}'.format(num_args[0], num_args[1]) - elif op_name == 'circle_arc': - return '{} / 360 * math.pi * 2 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'circle_area': - return 'math.pi * {}**2'.format(num_args[0]) - elif op_name == 'circle_sector_area': - return '{} / 360 * math.pi * ({}**2)'.format(num_args[1], num_args[0]) - elif op_name == 'circumface': - return '2 * math.pi * {}'.format(num_args[0]) - elif op_name == 'choose': - return 'scipy.special.comb({}, {})'.format(num_args[0], num_args[1]) - elif op_name == 'cosine': - return 'math.cos({})'.format(num_args[0]) - elif op_name == 'cube_edge_by_volume': - return '{}**(1 / 3)'.format(num_args[0]) - elif op_name == 'combined_work': - return '1 / (min({}, 1 / {}) + min({}, 1 / {}))'.format( - num_args[0], num_args[0], num_args[1], num_args[1]) - elif op_name == 'count_interval': - return '{} - {} + 1'.format(num_args[0], num_args[1]) - elif op_name == 'diagonal': - return 'math.sqrt({}**2 
+ {}**2)'.format(num_args[0], num_args[1]) - elif op_name == 'divide' or op_name == 'speed': - # safe divide - if num_args[1] != 0: - return '{} / {}'.format(num_args[0], num_args[1]) - else: - return '0' - elif op_name == 'factorial': - return 'math.factorial(min(15, int({})))'.format(num_args[0]) - elif op_name == 'floor': - return 'math.floor({})'.format(num_args[0]) - elif op_name == 'find_work': - return ('1 / (max(min({}, 1 / {}), min({}, 1 / {})) - min(min({}, 1 / {}), ' - 'min({}, 1 / {})))').format(num_args[0], num_args[0], num_args[1], - num_args[1], num_args[0], num_args[0], - num_args[1], num_args[1]) - elif op_name == 'from_percent': - return '{} / 100'.format(num_args[0]) - elif op_name == 'gain_percent': - return '100 + {}'.format(num_args[0]) - elif op_name == 'gcd': - return 'scipy.gcd(int({}), int({}))'.format(num_args[0], num_args[1]) - elif op_name == 'inverse': - # safe inverse - if num_args[0] != 0: - return '1 / {}'.format(num_args[0]) - else: - return '0' - elif op_name == 'lcm': - return 'scipy.lcm(int({}), int({}))'.format(num_args[0], num_args[1]) - elif op_name == 'log': - return 'math.log(max(1e-5, {}), 2)'.format(num_args[0]) - elif op_name == 'loss_percent': - return '100 - {}'.format(num_args[0]) - elif op_name == 'max': - return 'max({},{})'.format(num_args[0], num_args[1]) - elif op_name == 'multiply': - return '{} * {}'.format(num_args[0], num_args[1]) - elif op_name == 'negate_percent': - return '100 - {}'.format(num_args[0]) - elif op_name == 'negate': - return '-{}'.format(num_args[0]) - elif op_name == 'original_price_before_loss': - return '{} * 100 / (100 + 1e-5 - {}) # original price before loss'.format( - num_args[1], num_args[0]) - elif op_name == 'original_price_before_gain': - return '{} * 100 / (100 + {}) # original_price_before gain'.format( - num_args[1], num_args[0]) - elif op_name == 'permutation': - return ('math.factorial(int(max({}, {}))) / math.factorial(int(max({}, {}) ' - '- min({}, {}))) # find all 
permutations').format( - num_args[0], num_args[1], num_args[0], num_args[1], num_args[0], - num_args[1]) - elif op_name == 'power': - return '{}**min({}, 5)'.format(num_args[0], num_args[1]) - elif op_name == 'percent': - return '{} / 100 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'price_after_gain' or op_name == 'p_after_gain': - return '(1 + {} / 100) * {}'.format(num_args[0], num_args[1]) - elif op_name == 'price_after_loss' or op_name == 'price_after_loss': - return '(1 - {} / 100) * {}'.format(num_args[0], num_args[1]) - elif op_name == 'quadrilateral_area': - return '{} * ({} + {}) / 2 # quadrilateral area'.format( - num_args[0], num_args[1], num_args[2]) - elif op_name == 'reminder': - return '{} % {}'.format(num_args[0], num_args[1]) - elif op_name == 'rectangle_area': - return '{} * {} # area of rectangle'.format(num_args[0], num_args[1]) - elif op_name == 'rectangle_perimeter': - return '2 * ({} + {}) # perimetere of rectangle'.format( - num_args[0], num_args[1]) - elif op_name == 'rhombus_area': - return '{} * {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'sine': - return 'math.sin({})'.format(num_args[0]) - elif op_name == 'sqrt': - return 'math.sqrt(max(0, {}))'.format(num_args[0]) - elif op_name == 'subtract': - return '{} - {}'.format(num_args[0], num_args[1]) - elif op_name == 'square_edge_by_perimeter': - return '{} / 4. 
# square edge given perimeter'.format(num_args[0]) - elif op_name == 'square_edge_by_area': - return 'math.sqrt({}) # square edge given area'.format(num_args[0]) - elif op_name == 'square_area': - return '{}**2'.format(num_args[0]) - elif op_name == 'surface_cube': - return '6 * {}**2 # surface of a cube'.format(num_args[0]) - elif op_name == 'surface_rectangular_prism': - return '2 * ({} * {} + {} * {} + {} * {}) # surface of a rectangular prism'.format( - num_args[0], num_args[1], num_args[0], num_args[2], num_args[1], - num_args[2]) - elif op_name == 'semi_circle_perimiter': - return 'math.pi * {} + 2 * {} # perimeter of a semi-circle'.format( - num_args[0], num_args[0]) - elif op_name == 'square_perimeter' or op_name == 'rhombus_perimeter': - return '4 * {}'.format(num_args[0]) - elif op_name == 'surface_sphere': - return '4 * math.pi * {}**2'.format(num_args[0]) - elif op_name == 'speed_ratio_steel_to_stream': - return '({} + {}) / ({} - {})'.format(num_args[0], num_args[1], num_args[0], - num_args[1]) - elif op_name == 'speed_in_still_water': - return '{} + {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'stream_speed': - return '{} - {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'trapezium_area': - return '{} * ({} + {}) / 2'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'triangle_area': - return '{} * {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'triangle_perimeter': - return '{} + {} + {} # perimeter of a triangle'.format( - num_args[0], num_args[1], num_args[2]) - elif op_name == 'triangle_area_three_edges': - return ("(lambda s, a, b, c: math.sqrt(max(0, s * (s - a) * (s - b) * (s - " - "c))))(({} + {} + {}) / 2, {}, {}, {}) # Heron's formula").format( - num_args[0], num_args[1], num_args[2], num_args[0], num_args[1], - num_args[2]) - elif op_name == 'union_prob': - return '{} + {} - {}'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'negate_prob': - return '1 - 
{}'.format(num_args[0]) - elif op_name == 'volume_cube': - return '{}**3'.format(num_args[0]) - elif op_name == 'volume_cone': - return 'math.pi * {}**2 * {} / 3'.format(num_args[0], num_args[1]) - elif op_name == 'volume_cylinder': - return 'math.pi * {}**2 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'volume_rectangular_prism': - return '{} * {} * {}'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'volume_sphere': - return '4 / 3 * math.pi * {}**3'.format(num_args[0]) + """An implementation of the most popular ops from the MathQA dataset.""" + # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ + # and specfically line 142 and following in new_DataStructure.py + # for an implementation which covers more details. + if op_name == "add": + return "{} + {}".format(num_args[0], num_args[1]) + elif op_name == "circle_arc": + return "{} / 360 * math.pi * 2 * {}".format(num_args[0], num_args[1]) + elif op_name == "circle_area": + return "math.pi * {}**2".format(num_args[0]) + elif op_name == "circle_sector_area": + return "{} / 360 * math.pi * ({}**2)".format(num_args[1], num_args[0]) + elif op_name == "circumface": + return "2 * math.pi * {}".format(num_args[0]) + elif op_name == "choose": + return "scipy.special.comb({}, {})".format(num_args[0], num_args[1]) + elif op_name == "cosine": + return "math.cos({})".format(num_args[0]) + elif op_name == "cube_edge_by_volume": + return "{}**(1 / 3)".format(num_args[0]) + elif op_name == "combined_work": + return "1 / (min({}, 1 / {}) + min({}, 1 / {}))".format( + num_args[0], num_args[0], num_args[1], num_args[1] + ) + elif op_name == "count_interval": + return "{} - {} + 1".format(num_args[0], num_args[1]) + elif op_name == "diagonal": + return "math.sqrt({}**2 + {}**2)".format(num_args[0], num_args[1]) + elif op_name == "divide" or op_name == "speed": + # safe divide + if num_args[1] != 0: + return "{} / {}".format(num_args[0], num_args[1]) + else: + return "0" + elif op_name == 
"factorial": + return "math.factorial(min(15, int({})))".format(num_args[0]) + elif op_name == "floor": + return "math.floor({})".format(num_args[0]) + elif op_name == "find_work": + return ( + "1 / (max(min({}, 1 / {}), min({}, 1 / {})) - min(min({}, 1 / {}), " + "min({}, 1 / {})))" + ).format( + num_args[0], + num_args[0], + num_args[1], + num_args[1], + num_args[0], + num_args[0], + num_args[1], + num_args[1], + ) + elif op_name == "from_percent": + return "{} / 100".format(num_args[0]) + elif op_name == "gain_percent": + return "100 + {}".format(num_args[0]) + elif op_name == "gcd": + return "scipy.gcd(int({}), int({}))".format(num_args[0], num_args[1]) + elif op_name == "inverse": + # safe inverse + if num_args[0] != 0: + return "1 / {}".format(num_args[0]) + else: + return "0" + elif op_name == "lcm": + return "scipy.lcm(int({}), int({}))".format(num_args[0], num_args[1]) + elif op_name == "log": + return "math.log(max(1e-5, {}), 2)".format(num_args[0]) + elif op_name == "loss_percent": + return "100 - {}".format(num_args[0]) + elif op_name == "max": + return "max({},{})".format(num_args[0], num_args[1]) + elif op_name == "multiply": + return "{} * {}".format(num_args[0], num_args[1]) + elif op_name == "negate_percent": + return "100 - {}".format(num_args[0]) + elif op_name == "negate": + return "-{}".format(num_args[0]) + elif op_name == "original_price_before_loss": + return "{} * 100 / (100 + 1e-5 - {}) # original price before loss".format( + num_args[1], num_args[0] + ) + elif op_name == "original_price_before_gain": + return "{} * 100 / (100 + {}) # original_price_before gain".format( + num_args[1], num_args[0] + ) + elif op_name == "permutation": + return ( + "math.factorial(int(max({}, {}))) / math.factorial(int(max({}, {}) " + "- min({}, {}))) # find all permutations" + ).format( + num_args[0], num_args[1], num_args[0], num_args[1], num_args[0], num_args[1] + ) + elif op_name == "power": + return "{}**min({}, 5)".format(num_args[0], num_args[1]) + 
elif op_name == "percent": + return "{} / 100 * {}".format(num_args[0], num_args[1]) + elif op_name == "price_after_gain" or op_name == "p_after_gain": + return "(1 + {} / 100) * {}".format(num_args[0], num_args[1]) + elif op_name == "price_after_loss" or op_name == "price_after_loss": + return "(1 - {} / 100) * {}".format(num_args[0], num_args[1]) + elif op_name == "quadrilateral_area": + return "{} * ({} + {}) / 2 # quadrilateral area".format( + num_args[0], num_args[1], num_args[2] + ) + elif op_name == "reminder": + return "{} % {}".format(num_args[0], num_args[1]) + elif op_name == "rectangle_area": + return "{} * {} # area of rectangle".format(num_args[0], num_args[1]) + elif op_name == "rectangle_perimeter": + return "2 * ({} + {}) # perimetere of rectangle".format( + num_args[0], num_args[1] + ) + elif op_name == "rhombus_area": + return "{} * {} / 2".format(num_args[0], num_args[1]) + elif op_name == "sine": + return "math.sin({})".format(num_args[0]) + elif op_name == "sqrt": + return "math.sqrt(max(0, {}))".format(num_args[0]) + elif op_name == "subtract": + return "{} - {}".format(num_args[0], num_args[1]) + elif op_name == "square_edge_by_perimeter": + return "{} / 4. 
# square edge given perimeter".format(num_args[0]) + elif op_name == "square_edge_by_area": + return "math.sqrt({}) # square edge given area".format(num_args[0]) + elif op_name == "square_area": + return "{}**2".format(num_args[0]) + elif op_name == "surface_cube": + return "6 * {}**2 # surface of a cube".format(num_args[0]) + elif op_name == "surface_rectangular_prism": + return "2 * ({} * {} + {} * {} + {} * {}) # surface of a rectangular prism".format( + num_args[0], num_args[1], num_args[0], num_args[2], num_args[1], num_args[2] + ) + elif op_name == "semi_circle_perimiter": + return "math.pi * {} + 2 * {} # perimeter of a semi-circle".format( + num_args[0], num_args[0] + ) + elif op_name == "square_perimeter" or op_name == "rhombus_perimeter": + return "4 * {}".format(num_args[0]) + elif op_name == "surface_sphere": + return "4 * math.pi * {}**2".format(num_args[0]) + elif op_name == "speed_ratio_steel_to_stream": + return "({} + {}) / ({} - {})".format( + num_args[0], num_args[1], num_args[0], num_args[1] + ) + elif op_name == "speed_in_still_water": + return "{} + {} / 2".format(num_args[0], num_args[1]) + elif op_name == "stream_speed": + return "{} - {} / 2".format(num_args[0], num_args[1]) + elif op_name == "trapezium_area": + return "{} * ({} + {}) / 2".format(num_args[0], num_args[1], num_args[2]) + elif op_name == "triangle_area": + return "{} * {} / 2".format(num_args[0], num_args[1]) + elif op_name == "triangle_perimeter": + return "{} + {} + {} # perimeter of a triangle".format( + num_args[0], num_args[1], num_args[2] + ) + elif op_name == "triangle_area_three_edges": + return ( + "(lambda s, a, b, c: math.sqrt(max(0, s * (s - a) * (s - b) * (s - " + "c))))(({} + {} + {}) / 2, {}, {}, {}) # Heron's formula" + ).format( + num_args[0], num_args[1], num_args[2], num_args[0], num_args[1], num_args[2] + ) + elif op_name == "union_prob": + return "{} + {} - {}".format(num_args[0], num_args[1], num_args[2]) + elif op_name == "negate_prob": + return "1 - 
{}".format(num_args[0]) + elif op_name == "volume_cube": + return "{}**3".format(num_args[0]) + elif op_name == "volume_cone": + return "math.pi * {}**2 * {} / 3".format(num_args[0], num_args[1]) + elif op_name == "volume_cylinder": + return "math.pi * {}**2 * {}".format(num_args[0], num_args[1]) + elif op_name == "volume_rectangular_prism": + return "{} * {} * {}".format(num_args[0], num_args[1], num_args[2]) + elif op_name == "volume_sphere": + return "4 / 3 * math.pi * {}**3".format(num_args[0]) def compute_program(list_op): - """Python execution of MathQA ops.""" - # The last of temporary results is the final answer. - temporary_results = [] - num_op = 0 - for op in list_op: - op_name = op.split('(')[0] - start_bracket = op.find('(') - end_bracket = op.find(')') - op_args = op[start_bracket + 1:end_bracket].split(',') - num_args = [] - for arg in op_args: - # The hash stands for a number stored in temporary_results. - # For example #2 refers to the third temporary result. - if arg[0] == '#': - temp_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append('t{}'.format(temp_index)) - # The n prefix stands for numbers which listed in list_num - - # originally they were contained in the text. - elif arg[0] == 'n': - # n_index = int( - # re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - # arg)[0]) - num_args.append(arg) - elif arg[0] == 'c': - if arg == 'const_pi': - constant = math.pi - elif arg == 'const_deg_to_rad': - constant = math.pi / 180 - else: - consts = re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', arg) - if len(consts) == 1: - constant = float(consts[0]) - else: - constant1 = float(consts[0]) - constant2 = float('0.' 
+ consts[1]) - constant = constant1 + constant2 - num_args.append(str(constant)) - temporary_result = 't{} = {}'.format( - num_op, single_op_to_python_command(op_name, num_args)) - temporary_results.append(temporary_result) - num_op += 1 - return temporary_results + """Python execution of MathQA ops.""" + # The last of temporary results is the final answer. + temporary_results = [] + num_op = 0 + for op in list_op: + op_name = op.split("(")[0] + start_bracket = op.find("(") + end_bracket = op.find(")") + op_args = op[start_bracket + 1 : end_bracket].split(",") + num_args = [] + for arg in op_args: + # The hash stands for a number stored in temporary_results. + # For example #2 refers to the third temporary result. + if arg[0] == "#": + temp_index = int( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + )[0] + ) + num_args.append("t{}".format(temp_index)) + # The n prefix stands for numbers which listed in list_num - + # originally they were contained in the text. + elif arg[0] == "n": + # n_index = int( + # re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', + # arg)[0]) + num_args.append(arg) + elif arg[0] == "c": + if arg == "const_pi": + constant = math.pi + elif arg == "const_deg_to_rad": + constant = math.pi / 180 + else: + consts = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + ) + if len(consts) == 1: + constant = float(consts[0]) + else: + constant1 = float(consts[0]) + constant2 = float("0." + consts[1]) + constant = constant1 + constant2 + num_args.append(str(constant)) + temporary_result = "t{} = {}".format( + num_op, single_op_to_python_command(op_name, num_args) + ) + temporary_results.append(temporary_result) + num_op += 1 + return temporary_results def compute_nums(question): - """Finds numbers in a string and convert them to floats.""" - # The funny looking replace is needed to deal with numbers such as 4,000 - # TODO(henrykm) deal with numbers written as words "one", "two", ... 
- return [ - float(num.replace(',', '')) for num in re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question) - ] + """Finds numbers in a string and convert them to floats.""" + # The funny looking replace is needed to deal with numbers such as 4,000 + # TODO(henrykm) deal with numbers written as words "one", "two", ... + return [ + float(num.replace(",", "")) + for num in re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", question + ) + ] def compute_ops(linear_formula): - list_op = linear_formula.split('|') - # In some cases the list of operations contains a superflous last element, - # namely an empty string. - if not list_op[-1]: - list_op = list_op[:-1] - return list_op + list_op = linear_formula.split("|") + # In some cases the list of operations contains a superflous last element, + # namely an empty string. + if not list_op[-1]: + list_op = list_op[:-1] + return list_op def process_single_mathqa_example(example): - """Execute a single example and verify coherence of a MathQA problem. - - Args: - example: a dictionary with the following fields: Problem - a natural - language formulation of the problem Rationale - a natural language - solution of the problem options - five possible answers ( a) b) c) d) and - e) ) correct - the letter representing the correct answer - annotated_formula - formula representing the full solution linear_formula - - a string of operations separated by the | character, e.g. - multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)| - multiply(#2,const_100)|divide(#3,#1)| category - a natural language - description of the category to which a given problem belongs. - - Returns: - answer_num: numerical answer contained in the example - python_result: numerical answers computed in Python, including intermediate - results. 
The answer_num should be close python_result[-1] - list_op: list of arithmetic operations - list_num: list of identified numbers in the text - """ - question = example['Problem'] - list_num = compute_nums(question) - list_op = compute_ops(example['linear_formula']) - answers = example['options'] - correct_answer = example['correct'] - index = answers.find('{} )'.format(correct_answer)) - answer_string = re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', answers[index:]) - # The if statement deals with empty lists - they are needed to treat - # a correct non-numerical answer e) None of the above. Here we do not want - # non-numerical answers, hence we return None. - if answer_string: - answer_num = float( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - answers[index:])[0].replace(',', '')) - else: - return None - # The if statements below deals with answers written as fractions e.g. - # a ) 1 / 2 , b ) 1 / 3 , c ) 1 / 5 , d ) 10 / 30 , e ) 2 / 5 ? - index_end_of_answer = index + len(str(answer_num)) + 3 - if index_end_of_answer < len(answers) and answers[index_end_of_answer] == '/': - answer_denom = float( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - answers[index_end_of_answer:])[0].replace(',', '')) - answer_num /= answer_denom - python_result = compute_result(list_op, list_num) - python_program = compute_program(list_op) - return answer_num, python_result, python_program, list_op, list_num + """Execute a single example and verify coherence of a MathQA problem. + + Args: + example: a dictionary with the following fields: Problem - a natural + language formulation of the problem Rationale - a natural language + solution of the problem options - five possible answers ( a) b) c) d) and + e) ) correct - the letter representing the correct answer + annotated_formula - formula representing the full solution linear_formula + - a string of operations separated by the | character, e.g. 
+ multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)| + multiply(#2,const_100)|divide(#3,#1)| category - a natural language + description of the category to which a given problem belongs. + + Returns: + answer_num: numerical answer contained in the example + python_result: numerical answers computed in Python, including intermediate + results. The answer_num should be close python_result[-1] + list_op: list of arithmetic operations + list_num: list of identified numbers in the text + """ + question = example["Problem"] + list_num = compute_nums(question) + list_op = compute_ops(example["linear_formula"]) + answers = example["options"] + correct_answer = example["correct"] + index = answers.find("{} )".format(correct_answer)) + answer_string = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", answers[index:] + ) + # The if statement deals with empty lists - they are needed to treat + # a correct non-numerical answer e) None of the above. Here we do not want + # non-numerical answers, hence we return None. + if answer_string: + answer_num = float( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", answers[index:] + )[0].replace(",", "") + ) + else: + return None + # The if statements below deals with answers written as fractions e.g. + # a ) 1 / 2 , b ) 1 / 3 , c ) 1 / 5 , d ) 10 / 30 , e ) 2 / 5 ? 
+ index_end_of_answer = index + len(str(answer_num)) + 3 + if index_end_of_answer < len(answers) and answers[index_end_of_answer] == "/": + answer_denom = float( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", + answers[index_end_of_answer:], + )[0].replace(",", "") + ) + answer_num /= answer_denom + python_result = compute_result(list_op, list_num) + python_program = compute_program(list_op) + return answer_num, python_result, python_program, list_op, list_num def convert_float_to_mathqa(number): - floor = int(float(number)) - if floor == number: - return 'const_' + str(floor) - else: - return 'const_' + str(floor) + '_' + str(number)[len(str(floor)) + 1:] + floor = int(float(number)) + if floor == number: + return "const_" + str(floor) + else: + return "const_" + str(floor) + "_" + str(number)[len(str(floor)) + 1 :] def convert_to_subtract(const_string): - return 'subtract({},const_0)'.format(const_string) + return "subtract({},const_0)".format(const_string) def execute_mathqa_dsl_program(problem, dsl_code): - """Executes the DSL code for a given problem. - - Args: - problem: problem formulation (needed to get parameters). - dsl_code: DSL code. - - Returns: - the result of executing of the DSL code. - """ - n0_loc = problem.find('n0') - list_num = compute_nums(problem[n0_loc:]) - # The list contains _all_ numbers in the string, hence in particular - # for n0 = 2.0 n1 = 3.0 we are getting list_num = [0.0, 2.0, 1.0, 3.0], - # so that below we are filtering the odd occurrences. - assert len(list_num) % 2 == 0 - list_num = [list_num[2 * i + 1] for i in range(int(len(list_num) / 2))] - - # dsl_code is a list of strings; since all DSL programs are single liners, - # we need to guess the correct line. For now we use the same location as in - # in the ground truth examples, that is the first line. 
- list_op = compute_ops(dsl_code[0]) - - try: - results = compute_result(list_op, list_num)[-1] - except: # pylint: disable=bare-except - results = None - return results + """Executes the DSL code for a given problem. + + Args: + problem: problem formulation (needed to get parameters). + dsl_code: DSL code. + + Returns: + the result of executing of the DSL code. + """ + n0_loc = problem.find("n0") + list_num = compute_nums(problem[n0_loc:]) + # The list contains _all_ numbers in the string, hence in particular + # for n0 = 2.0 n1 = 3.0 we are getting list_num = [0.0, 2.0, 1.0, 3.0], + # so that below we are filtering the odd occurrences. + assert len(list_num) % 2 == 0 + list_num = [list_num[2 * i + 1] for i in range(int(len(list_num) / 2))] + + # dsl_code is a list of strings; since all DSL programs are single liners, + # we need to guess the correct line. For now we use the same location as in + # in the ground truth examples, that is the first line. + list_op = compute_ops(dsl_code[0]) + + try: + results = compute_result(list_op, list_num)[-1] + except: # pylint: disable=bare-except + results = None + return results def is_number(s): - try: - float(s) - return True - except: # pylint: disable=bare-except - return False + try: + float(s) + return True + except: # pylint: disable=bare-except + return False def execute_mathqa_program(problem, program): - """Executes the DSL code for a given problem. - - Args: - problem: problem formulation (not needed, but we want the same API as - in the DSL case). - program: Python code. - - Returns: - the result of executing of the Python code. - """ - del problem # problem only needed in the DSL version. - # Programs are lists of strings. We need to concatenate them in order to exec. - program = '\n'.join(program) - var_dict = {} - try: - # The logic of this is the following: if exec with timeout is working - # without exceptions, then we can call exec again and gather the variables. 
- exec(program, globals(), var_dict) # pylint: disable=exec-used - if 'answer' in var_dict and is_number(var_dict['answer']): - return float(var_dict['answer']) - else: - return None - except: # pylint: disable=bare-except - return None + """Executes the DSL code for a given problem. + + Args: + problem: problem formulation (not needed, but we want the same API as + in the DSL case). + program: Python code. + + Returns: + the result of executing of the Python code. + """ + del problem # problem only needed in the DSL version. + # Programs are lists of strings. We need to concatenate them in order to exec. + program = "\n".join(program) + var_dict = {} + try: + # The logic of this is the following: if exec with timeout is working + # without exceptions, then we can call exec again and gather the variables. + exec(program, globals(), var_dict) # pylint: disable=exec-used + if "answer" in var_dict and is_number(var_dict["answer"]): + return float(var_dict["answer"]) + else: + return None + except: # pylint: disable=bare-except + return None -@gin.configurable(module='trax.data') +@gin.configurable(module="trax.data") def CreateMathQAInputs( # pylint: disable=invalid-name dataset_path=None, train=True, @@ -2286,185 +2405,208 @@ def CreateMathQAInputs( # pylint: disable=invalid-name category=False, order_prediction=False, reduced_operation_name=True, - qed=False): - """Prepares MathQA inputs. - - The generation procedure leaves a lot parameters to be set by the user. - Currently we support only correct examples in the following sense: - python execution agrees with the declared answer up to 1%. - - According to this criterion wrong examples such as - problem: calculate 85184 Γ· ? = 352 - operations ['multiply(n0,n1)'] - are ignored (this should be divide(n0,n1) in this case). - - Args: - dataset_path: a path with the MathQA dataset. - train: if True, then generate training examples; if train, test and - challenge are set to False generate validation examples. 
- test: if train is set to False and test is set to True, - then generate test examples. - challenge: if train and test are set to False and challenge is set to True, - then generate challenge examples. - tolerance: if for a given example relative difference between Python result - and the result declared in the dataset exceeds the level, then the example - is dropped; tolerances ranging from 0.1 to 0.001 yield from 18K to 21K - examples. - cumulative: if set to True, then generate examples in the format input - - problem + numbers + op1 + op2 + op3 target - op4 If set to False, then - examples are in the format input - problem + numbers target - all - operations. - python_code: if set to True, then generates python code instead of - MathQA commands. - full_dict: if set to True, then Python examples are returned together with - the DSL code and the NLP rationale. - partial_results: if set to True, then partial results will be reported as - part of the input, e.g. input - problem + numbers + op1 + #1 + op2 + #2 + - op3 + #3, target - op4, where #k is the partial results from operation - opk. Activated only in cumulative set to True. - nlp_rationale: if set to True, then input is the problem and the target is - the nlp rationale. - correct_answer: if set to True, then input is the problem plus all possible - answers and the target is the correct answer. - answer_in_mathqa_format: if set to True, then convert numerical answer to - the MathQA format and wrap it in the subtract operation. - E.g. "3.13" is converted to "subtract(const_3_13,const_0)". - correct_answer_given_reasoning: if set to True, then input is the problem - plus linear formula plus all possible answers and the target is the - correct answer. - category: if set to True, then input is the problem and the target is its - category. 
- order_prediction: if set to True, then input is the problem and a list of - all operations; with probability 0.5 two operations are swapped; the task - consists in detecting whether the operations were swapped. See the - order prediction task in CreateAquaInputs in this file. - reduced_operation_name: If set to True, then in order prediction consider - only the operation token without parameterers. - qed: if set to True, then the reasoning is finished with an additional - operation qed. - - Returns: - mathqa_yield_examples: a generator of MathQA examples; the generator yields - non-tokenized examples - they can be further processed using for example - the tokenize function from this module - """ - if train: - dataset_path = os.path.join(dataset_path, 'train.json') - elif test: - dataset_path = os.path.join(dataset_path, 'test.json') - elif challenge: - dataset_path = os.path.join(dataset_path, 'challenge_test.json') - else: - dataset_path = os.path.join(dataset_path, 'dev.json') - # Opening with GFile allows to use remotely stored files, e.g. - # in a gs bucket. - dataset_handle = tf.io.gfile.GFile(dataset_path, 'r') - dataset = json.load(dataset_handle) - - def mathqa_yield_examples(generator=None): - del generator - while True: - for example in itertools.cycle(dataset): - result = process_single_mathqa_example(example) - # TODO(henrykm): Remove the first two ifs. 
- if not result: - continue - answer_num, python_result, python_program, list_op, list_num = result - if not answer_num or not python_result[-1]: - continue - if qed: - list_op.append('qed') - if math.isclose(answer_num, python_result[-1], rel_tol=tolerance): - input_prefix = example['Problem'] - for i in range(len(list_num)): - input_prefix += ' n{} = {}'.format(i, list_num[i]) - if cumulative: - for i in range(len(list_op)): - input_values = input_prefix - target_values = list_op[i] - input_prefix += ' ' + list_op[i] - if partial_results: - input_prefix += ' #{} = {}'.format(i, answer_num) - yield input_values, target_values, np.array([1] * - len(target_values)) - elif python_code: - input_values = '# ' + input_prefix - target_values = '' - for command in python_program: - if 'math' in command: - target_values += 'import math\n' - break - for command in python_program: - if 'scipy' in command: - target_values += 'import scipy\n' - break - for i in range(len(list_num)): - target_values += 'n{} = {}\n'.format(i, list_num[i]) - target_values += '\n'.join(python_program[:-1]) - final_line = python_program[-1].split('=')[1] - target_values += '\nanswer ={}'.format(final_line) - var_dict = {} - # We generate a python code and want to check whether the answer - # is coorect. 
- exec(target_values, globals(), var_dict) # pylint: disable=exec-used - if math.isclose(answer_num, var_dict['answer'], rel_tol=tolerance): - if full_dict: - yield input_values, target_values, example[ - 'linear_formula'], example['Rationale'] - else: - yield input_values, target_values, np.array([1] * - len(target_values)) - elif nlp_rationale: - input_values = 'infer full rationale: ' + input_prefix - target_values = example['Rationale'] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif correct_answer: - input_values = 'infer correct answer: ' + input_prefix - input_values += ' ' + example['options'] - if answer_in_mathqa_format: - target_values = str(answer_num) - target_values = convert_to_subtract( - convert_float_to_mathqa(target_values)) - else: - target_values = example['correct'] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif correct_answer_given_reasoning: - input_values = 'infer correct answer given reasoning: ' + input_prefix - input_values += ' ' + ' '.join(list_op) + ' ' + example['options'] - target_values = example['correct'] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif category: - input_values = 'infer category: ' + input_prefix - target_values = example['category'] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif order_prediction: - if np.random.uniform() < 0.5 and len(list_op) >= 2: - idx = range(len(list_op)) - i1, i2 = random.sample(idx, 2) - list_op[i1], list_op[i2] = list_op[i2], list_op[i1] - target_values = 'not_ordered' - else: - target_values = 'ordered' - if reduced_operation_name: - list_op = [op.split('(')[0] for op in list_op] - input_values = 'order prediction: ' + input_prefix + ' ' + ' '.join( - list_op) - yield input_values, target_values, np.array([1] * - len(target_values)) - else: - input_values = 'infer full calculation: ' + input_prefix - target_values = example['linear_formula'] - yield 
input_values, target_values, np.array([1] * - len(target_values)) - - return mathqa_yield_examples - - -@gin.configurable(module='trax.data') + qed=False, +): + """Prepares MathQA inputs. + + The generation procedure leaves a lot parameters to be set by the user. + Currently we support only correct examples in the following sense: + python execution agrees with the declared answer up to 1%. + + According to this criterion wrong examples such as + problem: calculate 85184 Γ· ? = 352 + operations ['multiply(n0,n1)'] + are ignored (this should be divide(n0,n1) in this case). + + Args: + dataset_path: a path with the MathQA dataset. + train: if True, then generate training examples; if train, test and + challenge are set to False generate validation examples. + test: if train is set to False and test is set to True, + then generate test examples. + challenge: if train and test are set to False and challenge is set to True, + then generate challenge examples. + tolerance: if for a given example relative difference between Python result + and the result declared in the dataset exceeds the level, then the example + is dropped; tolerances ranging from 0.1 to 0.001 yield from 18K to 21K + examples. + cumulative: if set to True, then generate examples in the format input - + problem + numbers + op1 + op2 + op3 target - op4 If set to False, then + examples are in the format input - problem + numbers target - all + operations. + python_code: if set to True, then generates python code instead of + MathQA commands. + full_dict: if set to True, then Python examples are returned together with + the DSL code and the NLP rationale. + partial_results: if set to True, then partial results will be reported as + part of the input, e.g. input - problem + numbers + op1 + #1 + op2 + #2 + + op3 + #3, target - op4, where #k is the partial results from operation + opk. Activated only in cumulative set to True. 
+ nlp_rationale: if set to True, then input is the problem and the target is + the nlp rationale. + correct_answer: if set to True, then input is the problem plus all possible + answers and the target is the correct answer. + answer_in_mathqa_format: if set to True, then convert numerical answer to + the MathQA format and wrap it in the subtract operation. + E.g. "3.13" is converted to "subtract(const_3_13,const_0)". + correct_answer_given_reasoning: if set to True, then input is the problem + plus linear formula plus all possible answers and the target is the + correct answer. + category: if set to True, then input is the problem and the target is its + category. + order_prediction: if set to True, then input is the problem and a list of + all operations; with probability 0.5 two operations are swapped; the task + consists in detecting whether the operations were swapped. See the + order prediction task in CreateAquaInputs in this file. + reduced_operation_name: If set to True, then in order prediction consider + only the operation token without parameterers. + qed: if set to True, then the reasoning is finished with an additional + operation qed. + + Returns: + mathqa_yield_examples: a generator of MathQA examples; the generator yields + non-tokenized examples - they can be further processed using for example + the tokenize function from this module + """ + if train: + dataset_path = os.path.join(dataset_path, "train.json") + elif test: + dataset_path = os.path.join(dataset_path, "test.json") + elif challenge: + dataset_path = os.path.join(dataset_path, "challenge_test.json") + else: + dataset_path = os.path.join(dataset_path, "dev.json") + # Opening with GFile allows to use remotely stored files, e.g. + # in a gs bucket. 
+ dataset_handle = tf.io.gfile.GFile(dataset_path, "r") + dataset = json.load(dataset_handle) + + def mathqa_yield_examples(generator=None): + del generator + while True: + for example in itertools.cycle(dataset): + result = process_single_mathqa_example(example) + # TODO(henrykm): Remove the first two ifs. + if not result: + continue + answer_num, python_result, python_program, list_op, list_num = result + if not answer_num or not python_result[-1]: + continue + if qed: + list_op.append("qed") + if math.isclose(answer_num, python_result[-1], rel_tol=tolerance): + input_prefix = example["Problem"] + for i in range(len(list_num)): + input_prefix += " n{} = {}".format(i, list_num[i]) + if cumulative: + for i in range(len(list_op)): + input_values = input_prefix + target_values = list_op[i] + input_prefix += " " + list_op[i] + if partial_results: + input_prefix += " #{} = {}".format(i, answer_num) + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif python_code: + input_values = "# " + input_prefix + target_values = "" + for command in python_program: + if "math" in command: + target_values += "import math\n" + break + for command in python_program: + if "scipy" in command: + target_values += "import scipy\n" + break + for i in range(len(list_num)): + target_values += "n{} = {}\n".format(i, list_num[i]) + target_values += "\n".join(python_program[:-1]) + final_line = python_program[-1].split("=")[1] + target_values += "\nanswer ={}".format(final_line) + var_dict = {} + # We generate a python code and want to check whether the answer + # is coorect. 
+ exec( + target_values, globals(), var_dict + ) # pylint: disable=exec-used + if math.isclose( + answer_num, var_dict["answer"], rel_tol=tolerance + ): + if full_dict: + yield input_values, target_values, example[ + "linear_formula" + ], example["Rationale"] + else: + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif nlp_rationale: + input_values = "infer full rationale: " + input_prefix + target_values = example["Rationale"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif correct_answer: + input_values = "infer correct answer: " + input_prefix + input_values += " " + example["options"] + if answer_in_mathqa_format: + target_values = str(answer_num) + target_values = convert_to_subtract( + convert_float_to_mathqa(target_values) + ) + else: + target_values = example["correct"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif correct_answer_given_reasoning: + input_values = ( + "infer correct answer given reasoning: " + input_prefix + ) + input_values += ( + " " + " ".join(list_op) + " " + example["options"] + ) + target_values = example["correct"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif category: + input_values = "infer category: " + input_prefix + target_values = example["category"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif order_prediction: + if np.random.uniform() < 0.5 and len(list_op) >= 2: + idx = range(len(list_op)) + i1, i2 = random.sample(idx, 2) + list_op[i1], list_op[i2] = list_op[i2], list_op[i1] + target_values = "not_ordered" + else: + target_values = "ordered" + if reduced_operation_name: + list_op = [op.split("(")[0] for op in list_op] + input_values = ( + "order prediction: " + + input_prefix + + " " + + " ".join(list_op) + ) + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + else: + input_values = "infer full calculation: " + 
input_prefix + target_values = example["linear_formula"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + + return mathqa_yield_examples + + +@gin.configurable(module="trax.data") def CreateAquaInputs( # pylint: disable=invalid-name dataset_path=None, train=True, @@ -2473,283 +2615,308 @@ def CreateAquaInputs( # pylint: disable=invalid-name correct_answer=False, correct_answer_given_reasoning=False, partial_reasoning=True, - order_prediction=False): - """Prepares Aqua inputs. - - Args: - dataset_path: a path with the Aqua dataset. - train: if True, then generate training examples, otherwhise generate - validation examples (the dataset has also a test set). - cumulative: if set to True, then generate examples in the format input - - problem + step1 + step3 + step3 target - step4 If set to False, then - examples are in the format input - problem, target - all operations. - rationale: if set to True, then input is the problem and the target is the - rationale. - correct_answer: if set to True, then input is the problem plus all possible - answers and the target is the correct answer. - correct_answer_given_reasoning: if set to True, then input is the problem - plus reasoning (aka rationale) plus all possible answers and the target is - the correct answer. - partial_reasoning: an additional option related to - correct_answer_given_reasoning; if set to True, then we take a random - prefix of the reasoning. - order_prediction: if set to True, then input is the problem and a list of - all operations; with probability 0.5 two operations are swapped; the task - consists in detecting whether the operations were swapped. A similar - additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and - in a recent work of Piotr PiΔ™kos, henrykm@ and mateuszm@. 
- - Returns: - aqua_yield_examples: a generator of Aqua examples; the generator yields - non-tokenized examples - they can be further processed using for example - the tokenize function from this module - """ - if train: - dataset_path = os.path.join(dataset_path, 'train.json') - else: - dataset_path = os.path.join(dataset_path, 'dev.json') - # Opening with GFile allows to use remotely stored files, e.g. - # in a gs bucket. - dataset_handle = tf.io.gfile.GFile(dataset_path, 'r') - dataset = [] - for line in dataset_handle: - dataset.append(json.loads(line)) - - def aqua_yield_examples(generator=None): - del generator - while True: - for example in itertools.cycle(dataset): - input_prefix = example['question'] - steps = example['rationale'].split('\n') - if cumulative: - for i in range(len(steps)): - input_values = 'infer cumulative rationale: ' + input_prefix - target_values = steps[i] - input_prefix += ' ' + steps[i] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif rationale: - input_values = 'infer full rationale: ' + input_prefix - target_values = example['rationale'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif correct_answer: - input_values = 'infer correct answer: ' + input_prefix - input_values += ' ' + ' '.join(example['options']) - target_values = example['correct'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif correct_answer_given_reasoning: - input_values = 'infer correct answer given reasoning: ' + input_prefix - if partial_reasoning: - reasoning_list = example['rationale'].split('\n') - reasoning_list = reasoning_list[0:np.random - .randint(0, len(reasoning_list))] - reasoning = '\n'.join(reasoning_list) - else: - reasoning = example['rationale'] - input_values += ' ' + example['rationale'] + ' ' + ' '.join( - example['options']) - target_values = example['correct'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif 
order_prediction: - if np.random.uniform() < 0.5 and len(steps) >= 2: - idx = range(len(steps)) - i1, i2 = random.sample(idx, 2) - steps[i1], steps[i2] = steps[i2], steps[i1] - target_values = 'not_ordered' - else: - target_values = 'ordered' - input_values = 'order prediction: ' + input_prefix + ' ' + '\n'.join( - steps) - yield input_values, target_values, np.array([1] * len(target_values)) - else: - raise ValueError( - 'One of the boolean parameters of the Aqua generator must be set to True.' - ) - - return aqua_yield_examples - - -@gin.configurable(module='trax.data') -def CreateDropInputs( # pylint: disable=invalid-name - train=True, mathqa_format=False): - """Prepares Drop inputs. - - Args: - train: if True, then generate training examples, otherwhise generate - validation examples (the dataset has also a test set). - mathqa_format: if True, then floats in targets are converted to the - the MathQA convention and wrapped in the subtract operation. - E.g. "3.13" is converted to "subtract(const_3_13,const_0)". - - Returns: - drop_yield_examples: a generator of Drop examples; the generator yields - non-tokenized examples - they can be further processed using for example - the tokenize function from this module - """ - if train: - dataset = tfds.load(name='drop', split='train') - else: - dataset = tfds.load(name='drop', split='dev') - dataset = tfds.as_numpy(dataset) - - def drop_yield_examples(generator=None): - del generator - while True: - for example in itertools.cycle(dataset): - input_values = 'drop question: ' + example['passage'].decode( - 'utf-8') + ' ' + example['question'].decode('utf-8') - target_values = example['answer'].decode('utf-8') - # Apparently the dataset has some empty "target values" - - # when such a value is encountered, the Tokenizer decides to assign - # to it a float32 tensor and the training fails. 
- if not target_values: - continue - if mathqa_format: - if target_values.replace('.', '', 1).isdigit(): - target_values = convert_to_subtract( - convert_float_to_mathqa(target_values)) - yield input_values, target_values, np.array( - [1] * len(target_values), dtype=np.int32) - - return drop_yield_examples - - -@gin.configurable(module='trax.data') + order_prediction=False, +): + """Prepares Aqua inputs. + + Args: + dataset_path: a path with the Aqua dataset. + train: if True, then generate training examples, otherwhise generate + validation examples (the dataset has also a test set). + cumulative: if set to True, then generate examples in the format input - + problem + step1 + step3 + step3 target - step4 If set to False, then + examples are in the format input - problem, target - all operations. + rationale: if set to True, then input is the problem and the target is the + rationale. + correct_answer: if set to True, then input is the problem plus all possible + answers and the target is the correct answer. + correct_answer_given_reasoning: if set to True, then input is the problem + plus reasoning (aka rationale) plus all possible answers and the target is + the correct answer. + partial_reasoning: an additional option related to + correct_answer_given_reasoning; if set to True, then we take a random + prefix of the reasoning. + order_prediction: if set to True, then input is the problem and a list of + all operations; with probability 0.5 two operations are swapped; the task + consists in detecting whether the operations were swapped. A similar + additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and + in a recent work of Piotr PiΔ™kos, henrykm@ and mateuszm@. 
+ + Returns: + aqua_yield_examples: a generator of Aqua examples; the generator yields + non-tokenized examples - they can be further processed using for example + the tokenize function from this module + """ + if train: + dataset_path = os.path.join(dataset_path, "train.json") + else: + dataset_path = os.path.join(dataset_path, "dev.json") + # Opening with GFile allows to use remotely stored files, e.g. + # in a gs bucket. + dataset_handle = tf.io.gfile.GFile(dataset_path, "r") + dataset = [] + for line in dataset_handle: + dataset.append(json.loads(line)) + + def aqua_yield_examples(generator=None): + del generator + while True: + for example in itertools.cycle(dataset): + input_prefix = example["question"] + steps = example["rationale"].split("\n") + if cumulative: + for i in range(len(steps)): + input_values = "infer cumulative rationale: " + input_prefix + target_values = steps[i] + input_prefix += " " + steps[i] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif rationale: + input_values = "infer full rationale: " + input_prefix + target_values = example["rationale"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif correct_answer: + input_values = "infer correct answer: " + input_prefix + input_values += " " + " ".join(example["options"]) + target_values = example["correct"] + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + elif correct_answer_given_reasoning: + input_values = ( + "infer correct answer given reasoning: " + input_prefix + ) + if partial_reasoning: + reasoning_list = example["rationale"].split("\n") + reasoning_list = reasoning_list[ + 0 : np.random.randint(0, len(reasoning_list)) + ] + reasoning = "\n".join(reasoning_list) + else: + reasoning = example["rationale"] + input_values += ( + " " + example["rationale"] + " " + " ".join(example["options"]) + ) + target_values = example["correct"] + yield input_values, target_values, np.array( + [1] * 
len(target_values) + ) + elif order_prediction: + if np.random.uniform() < 0.5 and len(steps) >= 2: + idx = range(len(steps)) + i1, i2 = random.sample(idx, 2) + steps[i1], steps[i2] = steps[i2], steps[i1] + target_values = "not_ordered" + else: + target_values = "ordered" + input_values = ( + "order prediction: " + input_prefix + " " + "\n".join(steps) + ) + yield input_values, target_values, np.array( + [1] * len(target_values) + ) + else: + raise ValueError( + "One of the boolean parameters of the Aqua generator must be set to True." + ) + + return aqua_yield_examples + + +@gin.configurable(module="trax.data") +def CreateDropInputs(train=True, mathqa_format=False): # pylint: disable=invalid-name + """Prepares Drop inputs. + + Args: + train: if True, then generate training examples, otherwhise generate + validation examples (the dataset has also a test set). + mathqa_format: if True, then floats in targets are converted to the + the MathQA convention and wrapped in the subtract operation. + E.g. "3.13" is converted to "subtract(const_3_13,const_0)". + + Returns: + drop_yield_examples: a generator of Drop examples; the generator yields + non-tokenized examples - they can be further processed using for example + the tokenize function from this module + """ + if train: + dataset = tfds.load(name="drop", split="train") + else: + dataset = tfds.load(name="drop", split="dev") + dataset = tfds.as_numpy(dataset) + + def drop_yield_examples(generator=None): + del generator + while True: + for example in itertools.cycle(dataset): + input_values = ( + "drop question: " + + example["passage"].decode("utf-8") + + " " + + example["question"].decode("utf-8") + ) + target_values = example["answer"].decode("utf-8") + # Apparently the dataset has some empty "target values" - + # when such a value is encountered, the Tokenizer decides to assign + # to it a float32 tensor and the training fails. 
+ if not target_values: + continue + if mathqa_format: + if target_values.replace(".", "", 1).isdigit(): + target_values = convert_to_subtract( + convert_float_to_mathqa(target_values) + ) + yield input_values, target_values, np.array( + [1] * len(target_values), dtype=np.int32 + ) + + return drop_yield_examples + + +@gin.configurable(module="trax.data") def CreateAnnotatedDropInputs( # pylint: disable=invalid-name dataset_path=None, train=True, single_file=True, unique=False, total_number_of_samples=None, - percentile=1.): - r"""Prepares annotated Drop inputs. - - Example of an annotated input which can be used with this interface: - - { - 'passage': 'The Armenian Prelature of Cyprus was established in 973 by - Catholicos Khatchig I. Historically, the Prelature has been under the - jurisdiction of the Catholicosate of the Great House of Cilicia, while today - it is the oldest theme that falls under its jurisdiction. Since 2014 the - Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. - The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest - in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the - Prelature building was located within the Armenian compound in Victoria - street in walled Nicosia; when that area was taken over by Turkish-Cypriot - extremists in 1963-1964, the Prelature was temporarily housed in Aram - Ouzounian street and, later on, in Kyriakos Matsis street in Ayios - Dhometios. Thanks to the efforts of Bishop Zareh Aznavorian and with - financial aid from the Evangelical Church of Westphalia, the new Prelature - building was erected in 1983, next to the Virgin Mary church and the Nareg - school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was - officially inaugurated on 4 March 1984, during the pastoral visit of - Catholicos Karekin II. 
By initiative of Archbishop Varoujan Hergelian, in - 1998 the basement of the building was renovated and the "Vahram Utidjian" - Hall was formed; previously a store room, it became a reality from the - proceeds of the auction in 1994 of the art collection that Vahram Utidjian - had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 - by Catholicos Aram I; numerous charity, communal and cultural events take - place there. The Prelature\'s consistory houses a collection of - ecclesiastical relics, some of which were previously in the old Virgin Mary - church or the Magaravank.', - 'question': 'How many years after the Vahram Utidjian was donated to the - Prelature was it sold at an auction?', - 'answer': 40, - 'calculation': 'subtract(n8,n9)' - } - - In this example the calculation is formulated using the notation from the - MathQA dataset, but this is not required. subtract(n8,n9) means that the - answer 40 can be obtained through the substraction of the 9th and and the 10th - number in the input. The input consists of the passage concatened with the - question. The annotations can be generated using, for example, a method - from the paper https://arxiv.org/abs/1909.00109. - - Args: - dataset_path: a path with the Aqua dataset. - train: if True, then generate training examples, otherwhise generate - validation examples (the dataset has also a test set). - single_file: if True, then look just for one file. If False, read all - json files in a given directory and assume that each file contains one - example. Applied only to training data. - unique: if set to True, then the generator will provide at most one question - per passage. - total_number_of_samples: if set to a positive integer, then the total number - of unique samples will be bounded total_number_of_samples. 
- percentile: the percentile of the train dataset used for training; default - set to 1., though setting to a lower value can be interesting when - combined train is combined with another source of data. - - Returns: - drop_annotated_yield_examples: a generator of annotated Drop examples; - the generator yields non-tokenized examples - they can be further processed - using for example the tokenize function from this module. - """ - if train: - if single_file: - dataset_path = os.path.join(dataset_path, 'train_annotated.json') - else: - dataset_path = os.path.join(dataset_path, 'dev_annotated.json') - - def load_dataset(): - dataset = [] - if single_file: - # Opening with GFile allows to use remotely stored files, e.g. - # in a gs bucket. - dataset_handle = tf.io.gfile.GFile(dataset_path, 'r') - for line in dataset_handle: - dataset.append(json.loads(line)) + percentile=1.0, +): + r"""Prepares annotated Drop inputs. + + Example of an annotated input which can be used with this interface: + + { + 'passage': 'The Armenian Prelature of Cyprus was established in 973 by + Catholicos Khatchig I. Historically, the Prelature has been under the + jurisdiction of the Catholicosate of the Great House of Cilicia, while today + it is the oldest theme that falls under its jurisdiction. Since 2014 the + Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. + The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest + in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the + Prelature building was located within the Armenian compound in Victoria + street in walled Nicosia; when that area was taken over by Turkish-Cypriot + extremists in 1963-1964, the Prelature was temporarily housed in Aram + Ouzounian street and, later on, in Kyriakos Matsis street in Ayios + Dhometios. 
Thanks to the efforts of Bishop Zareh Aznavorian and with + financial aid from the Evangelical Church of Westphalia, the new Prelature + building was erected in 1983, next to the Virgin Mary church and the Nareg + school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was + officially inaugurated on 4 March 1984, during the pastoral visit of + Catholicos Karekin II. By initiative of Archbishop Varoujan Hergelian, in + 1998 the basement of the building was renovated and the "Vahram Utidjian" + Hall was formed; previously a store room, it became a reality from the + proceeds of the auction in 1994 of the art collection that Vahram Utidjian + had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 + by Catholicos Aram I; numerous charity, communal and cultural events take + place there. The Prelature\'s consistory houses a collection of + ecclesiastical relics, some of which were previously in the old Virgin Mary + church or the Magaravank.', + 'question': 'How many years after the Vahram Utidjian was donated to the + Prelature was it sold at an auction?', + 'answer': 40, + 'calculation': 'subtract(n8,n9)' + } + + In this example the calculation is formulated using the notation from the + MathQA dataset, but this is not required. subtract(n8,n9) means that the + answer 40 can be obtained through the subtraction of the 9th and the 10th + number in the input. The input consists of the passage concatenated with the + question. The annotations can be generated using, for example, a method + from the paper https://arxiv.org/abs/1909.00109. + + Args: + dataset_path: a path with the annotated Drop dataset. + train: if True, then generate training examples, otherwise generate + validation examples (the dataset has also a test set). + single_file: if True, then look just for one file. If False, read all + json files in a given directory and assume that each file contains one + example. Applied only to training data.
+ unique: if set to True, then the generator will provide at most one question + per passage. + total_number_of_samples: if set to a positive integer, then the total number + of unique samples will be bounded total_number_of_samples. + percentile: the percentile of the train dataset used for training; default + set to 1., though setting to a lower value can be interesting when + combined train is combined with another source of data. + + Returns: + drop_annotated_yield_examples: a generator of annotated Drop examples; + the generator yields non-tokenized examples - they can be further processed + using for example the tokenize function from this module. + """ + if train: + if single_file: + dataset_path = os.path.join(dataset_path, "train_annotated.json") else: - all_files = tf.io.gfile.listdir(dataset_path) - for filename in all_files: - if 'json' in filename: - print('Loading data from file {}'.format(filename)) - with tf.io.gfile.GFile(os.path.join(dataset_path, filename)) as f: - for line in f: - dataset.append(json.loads(line)) - print('The total size of the dataset {}'.format(len(dataset))) - return dataset[:int(len(dataset) * percentile)] - - def drop_annotated_yield_examples(generator=None): - del generator - while True: - passages = set() - unique_examples = set() - # Notice that below we enable a poor man RL loop - # aka the DAgger algorithm: https://arxiv.org/pdf/1011.0686.pdf - # tl;dr: after parsing all examples we re-load the dataset - this - # may become handy if a prediction service generates new examples. - dataset = load_dataset() - for example in dataset: - # If total_number_of_samples is not None and we have reached this - # number of samples, then we re-load the dataset. - if total_number_of_samples: - if len(unique_examples) >= total_number_of_samples: - break - # Do we have a pre-calculated input in the example? 
- if 'input' in example.keys(): - question = example['input'] - # Remove the old prompt - question = question[question.find(':') + 2:] + dataset_path = os.path.join(dataset_path, "dev_annotated.json") + + def load_dataset(): + dataset = [] + if single_file: + # Opening with GFile allows to use remotely stored files, e.g. + # in a gs bucket. + dataset_handle = tf.io.gfile.GFile(dataset_path, "r") + for line in dataset_handle: + dataset.append(json.loads(line)) else: - # If input is not present, then we expect that this is an - # original drop example. - if unique and example['passage'] in passages: - continue - passages.add(example['passage']) - question = example['passage'] + ' ' + example['question'] - list_num = [ - float(num.replace(',', '').rstrip('.').lstrip('.')) # pylint: disable=g-complex-comprehension - for num in re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - question) - ] - for i in range(len(list_num)): - question += ' n{} = {}'.format(i, list_num[i]) - input_values = 'drop annotated question: ' + question - target_values = example['calculation'] - unique_examples.add((input_values, target_values)) - yield input_values, target_values, np.array( - [1] * len(target_values), dtype=np.int32) - - return drop_annotated_yield_examples + all_files = tf.io.gfile.listdir(dataset_path) + for filename in all_files: + if "json" in filename: + print("Loading data from file {}".format(filename)) + with tf.io.gfile.GFile(os.path.join(dataset_path, filename)) as f: + for line in f: + dataset.append(json.loads(line)) + print("The total size of the dataset {}".format(len(dataset))) + return dataset[: int(len(dataset) * percentile)] + + def drop_annotated_yield_examples(generator=None): + del generator + while True: + passages = set() + unique_examples = set() + # Notice that below we enable a poor man RL loop + # aka the DAgger algorithm: https://arxiv.org/pdf/1011.0686.pdf + # tl;dr: after parsing all examples we re-load the dataset - this + # 
may become handy if a prediction service generates new examples. + dataset = load_dataset() + for example in dataset: + # If total_number_of_samples is not None and we have reached this + # number of samples, then we re-load the dataset. + if total_number_of_samples: + if len(unique_examples) >= total_number_of_samples: + break + # Do we have a pre-calculated input in the example? + if "input" in example.keys(): + question = example["input"] + # Remove the old prompt + question = question[question.find(":") + 2 :] + else: + # If input is not present, then we expect that this is an + # original drop example. + if unique and example["passage"] in passages: + continue + passages.add(example["passage"]) + question = example["passage"] + " " + example["question"] + list_num = [ + float( + num.replace(",", "").rstrip(".").lstrip(".") + ) # pylint: disable=g-complex-comprehension + for num in re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", + question, + ) + ] + for i in range(len(list_num)): + question += " n{} = {}".format(i, list_num[i]) + input_values = "drop annotated question: " + question + target_values = example["calculation"] + unique_examples.add((input_values, target_values)) + yield input_values, target_values, np.array( + [1] * len(target_values), dtype=np.int32 + ) + + return drop_annotated_yield_examples diff --git a/trax/data/tf_inputs_test.py b/trax/data/tf_inputs_test.py deleted file mode 100644 index 376f59c2b..000000000 --- a/trax/data/tf_inputs_test.py +++ /dev/null @@ -1,873 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.tf_inputs.""" - -import collections -import os -from unittest import mock - -import gin -import numpy as np -from t5.data import assert_dataset -from t5.data import preprocessors as t5_processors -import tensorflow as tf -import tensorflow_datasets as tfds -from trax.data import inputs # pylint: disable=unused-import -from trax.data import tf_inputs - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') - - -def _test_dataset_ints(inp_lengths, tgt_lengths): - """Create a test dataset of int64 tensors of given shapes.""" - - def generator(): - for inp_len, tgt_len in zip(inp_lengths, tgt_lengths): - inp = np.ones([inp_len], dtype=np.int64) - tgt = np.ones([tgt_len], dtype=np.int64) - yield {'inputs': inp, 'targets': tgt} - - types = {'inputs': tf.int64, 'targets': tf.int64} - shapes = {'inputs': tf.TensorShape([None]), 'targets': tf.TensorShape([None])} - return tf.data.Dataset.from_generator( - generator, output_types=types, output_shapes=shapes) - - -def _load_dataset(name, split='train'): - return tfds.load( - name=name, split=split, data_dir=_TESTDATA, shuffle_files=False) - - -def _c4_dataset(split='train'): - return _load_dataset('c4:2.3.0', split=split) - - -def _spm_path(): - return os.path.join(_TESTDATA, 'sentencepiece.model') - - -def _t5_gin_config(): - # The following pages worth of gin configuration are required because a lot - # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these - # functions at all without having configured gin. 
- - noise_density = 0.15 - max_input_length = 50 - - # What preprocessors to apply - we select a random chunk of the document if - # it exceeds a certain lengths (`select_random_chunk`), then split up long - # examples (`split_tokens`) and finally the denoising objective (`denoise`). - # - # In addition to this T5 concates multiple documents together to reduce - # padding (`reduce_concat_tokens`) after `select_random_chunk`, but we skip - # that since we don't do sequence packing. - gin.bind_parameter('unsupervised.preprocessors', [ - t5_processors.select_random_chunk, - t5_processors.split_tokens, - t5_processors.denoise, - ]) - - # select_random_chunk - gin.bind_parameter('select_random_chunk.feature_key', 'targets') - gin.bind_parameter('select_random_chunk.max_length', max_input_length) - - # reduce_concat_tokens - gin.bind_parameter('random_spans_helper.extra_tokens_per_span_inputs', 1) - gin.bind_parameter('random_spans_helper.extra_tokens_per_span_targets', 1) - gin.bind_parameter('random_spans_helper.inputs_length', max_input_length) - gin.bind_parameter('random_spans_helper.mean_noise_span_length', 3.0) - gin.bind_parameter('random_spans_helper.noise_density', noise_density) - - # split_tokens - gin.bind_parameter('split_tokens.max_tokens_per_segment', - t5_processors.random_spans_tokens_length()) - - # denoise - gin.bind_parameter('denoise.inputs_fn', - t5_processors.noise_span_to_unique_sentinel) - gin.bind_parameter('denoise.noise_density', noise_density) - gin.bind_parameter('denoise.noise_mask_fn', - t5_processors.random_spans_noise_mask) - gin.bind_parameter('denoise.targets_fn', - t5_processors.nonnoise_span_to_unique_sentinel) - - -class TFInputsTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - - def test_TFDS_single_host_with_eval_holdout(self): - train_ds_gen = tf_inputs.TFDS( - 'c4/en:2.3.0', - data_dir=_TESTDATA, - train=True, - host_id=0, - keys=('text',), - n_hosts=1, - eval_holdout_size=0.1) - - # Just 
ensure that this doesn't crash. - for d in train_ds_gen(): - print(f'Train: {d}') - break - - valid_ds_gen = tf_inputs.TFDS( - 'c4/en:2.3.0', - data_dir=_TESTDATA, - train=False, - host_id=0, - keys=('text',), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. - for d in valid_ds_gen(): - print(f'Eval: {d}') - break - - def test_TFDS_single_host_with_eval_holdout_no_valid_split(self): - train_ds_gen = tf_inputs.TFDS( - 'para_crawl/ende', - data_dir=_TESTDATA, - train=True, - host_id=0, - keys=('en', 'de'), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. - for d in train_ds_gen(): - print(f'Train: {d}') - break - - # para_crawl doesn't have a validation set, see that this still doesn't - # crash because of eval_holdout_set. - valid_ds_gen = tf_inputs.TFDS( - 'para_crawl/ende', - data_dir=_TESTDATA, - train=False, - host_id=0, - keys=('en', 'de'), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. 
- for d in valid_ds_gen(): - print(f'Eval: {d}') - break - - def test_TFDS_mnli_split_is_eval(self): - with mock.patch('tensorflow_datasets.load') as tfds_load: - with mock.patch('trax.data.tf_inputs.download_and_prepare', - lambda _, data_dir: data_dir): - _ = tf_inputs.TFDS('glue/mnli', - keys=('premise', 'hypothesis'), - train=False) - call_kwargs = tfds_load.call_args[1] - self.assertEqual(call_kwargs['split'], 'validation_matched') - - def test_TFDS_mnli_split_is_alt_eval(self): - with mock.patch('tensorflow_datasets.load') as tfds_load: - with mock.patch('trax.data.tf_inputs.download_and_prepare', - lambda _, data_dir: data_dir): - _ = tf_inputs.TFDS('glue/mnli', - keys=('premise', 'hypothesis'), - train=False, - use_alt_eval=True) - call_kwargs = tfds_load.call_args[1] - self.assertEqual(call_kwargs['split'], 'validation_mismatched') - - def test_convert_to_unicode(self): - - def dataset1(): - yield (b'Audentes fortuna iuvat.', b'Fortune favors the bold.') - - def dataset2(): - yield (b'\x81aabb', b'Value') - - convert_function1 = tf_inputs.ConvertToUnicode(keys=[0]) - convert_output1 = next(convert_function1(dataset1())) - self.assertEqual(convert_output1[0], 'Audentes fortuna iuvat.') - self.assertEqual(convert_output1[1], b'Fortune favors the bold.') - self.assertIsInstance(convert_output1[0], str) - self.assertIsInstance(convert_output1[1], bytes) - - # Contains an invalid bytes array from the point of view of UTF-8. - try: - convert_function2 = tf_inputs.ConvertToUnicode(keys=[0]) - convert_output2 = next(convert_function2(dataset2())) - except UnicodeDecodeError: - self.fail('ConvertToUnicode threw UnicodeDecodeError.') - self.assertEqual(convert_output2[0], 'aabb') - self.assertIsInstance(convert_output2[0], str) - - def test_tokenize_detokenize(self): - - def dataset(): - yield 'I have a cat.' - - # Character-level. 
- tok_char = list(tf_inputs.tokenize(dataset(), vocab_type='char')) - self.assertAllEqual(tok_char[0], - np.array([ord(c) for c in 'I have a cat.'])) - detok = tf_inputs.detokenize(tok_char[0], vocab_type='char') - self.assertEqual(detok, 'I have a cat.') - - # Sentencepiece. - tok_spc = list( - tf_inputs.tokenize( - dataset(), - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model')) - self.assertAllEqual(tok_spc[0], np.array([27, 43, 3, 9, 1712, 5])) - detok = tf_inputs.detokenize( - list(tok_spc[0]), - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model') - self.assertEqual(detok, 'I have a cat.') - - # Subword. - tok_sbw = list( - tf_inputs.tokenize( - dataset(), - vocab_type='subword', - vocab_dir=_TESTDATA, - vocab_file='en_8k.subword')) - self.assertAllEqual(tok_sbw[0], np.array([139, 96, 12, 2217, 2, 21])) - detok = tf_inputs.detokenize( - tok_sbw[0], - vocab_type='subword', - vocab_dir=_TESTDATA, - vocab_file='en_8k.subword') - self.assertEqual(detok, 'I have a cat.') - - # bert-lowercase - tok_sbw = list( - tf_inputs.tokenize( - dataset(), - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt')) - self.assertAllEqual(tok_sbw[0], np.array([1045, 2031, 1037, 4937, 1012])) - detok = tf_inputs.detokenize( - tok_sbw[0], - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt') - self.assertEqual(detok, 'i have a cat .') - # note: BERT tokenizer is not reversible, therefore - # difference between original input - - def test_tokenize_keys_reservedids(self): - - def dataset(): - yield ('Cat.', 'Dog.') - - tok_char1 = list( - tf_inputs.tokenize(dataset(), vocab_type='char', n_reserved_ids=5)) - self.assertAllEqual(tok_char1[0][0], np.array([ord(c) + 5 for c in 'Cat.'])) - self.assertAllEqual(tok_char1[0][1], np.array([ord(c) + 5 for c in 'Dog.'])) - - tok_char2 = list( - tf_inputs.tokenize( - dataset(), keys=[0], 
vocab_type='char', n_reserved_ids=2)) - self.assertAllEqual(tok_char2[0][0], np.array([ord(c) + 2 for c in 'Cat.'])) - self.assertEqual(tok_char2[0][1], 'Dog.') - - def test_tokenize_dict(self): - - def dataset(): - yield {'a': 'Cat.', 'b': 'Dog.'} - - tok_char1 = list(tf_inputs.tokenize(dataset(), vocab_type='char')) - self.assertAllEqual(tok_char1[0]['a'], np.array([ord(c) for c in 'Cat.'])) - self.assertAllEqual(tok_char1[0]['b'], np.array([ord(c) for c in 'Dog.'])) - - tok_char2 = list( - tf_inputs.tokenize(dataset(), keys=['a'], vocab_type='char')) - self.assertAllEqual(tok_char2[0]['a'], np.array([ord(c) for c in 'Cat.'])) - self.assertEqual(tok_char2[0]['b'], 'Dog.') - - def test_vocab_size(self): - # Character-level. - char_size = tf_inputs.vocab_size(vocab_type='char', n_reserved_ids=11) - self.assertEqual(char_size, 256 + 11) - # Sentencepiece. - spc_size = tf_inputs.vocab_size( - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model') - self.assertEqual(spc_size, 32000) - # Subword. - sbw_size = tf_inputs.vocab_size( - vocab_type='subword', vocab_dir=_TESTDATA, vocab_file='en_8k.subword') - self.assertEqual(sbw_size, 8183) - # Bert_uncased. - sbw_size = tf_inputs.vocab_size( - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt') - self.assertEqual(sbw_size, 30522) - - def test_c4_bare_preprocess_fn(self): - dataset = _c4_dataset() - - example = list(tfds.as_numpy(dataset.take(1)))[0] - - # Targets are NOT in the example. - self.assertNotIn('targets', example) - self.assertIn('text', example) - text = example['text'] - - # This should convert the dataset to an inputs/targets that are tokenized. 
- dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) - - example = list(tfds.as_numpy(dataset.take(1)))[0] - - # Earlier text is now stored in targets_pretokenized - self.assertIn('targets_pretokenized', example) - self.assertEqual(example['targets_pretokenized'], text) - - # Targets are now tokenized. - self.assertIn('targets', example) - self.assertIsInstance(example['targets'], np.ndarray) - self.assertEqual(example['targets'].dtype, np.int64) - self.assertGreater(len(example['targets']), 0) - self.assertEqual(example['targets'][-1], 1) # we add EOS at the end. - - # Inputs exist but is empty because t5 preprocessors' unsupervised wasn't - # gin configured with any. - self.assertIn('inputs', example) - self.assertEqual(len(example['inputs']), 0) - - def test_c4_preprocess(self): - - def load_c4_dataset(split='train'): - dataset = _c4_dataset(split=split) - return dataset.map(lambda example: (example, example['text'])) - - def examine_processed_dataset(proc_dataset): - count = 0 - lengths = [] - for example in tfds.as_numpy(proc_dataset): - count += 1 - ex = example[0] - # Targets are in the example. - self.assertIn('targets', ex) - self.assertEqual(ex['targets'].dtype, np.int64) - lengths.append(len(ex['targets'])) - return count, lengths - - unfiltered_count = 0 - for example in tfds.as_numpy(load_c4_dataset()): - unfiltered_count += 1 - # Targets are NOT in the example. - self.assertNotIn('targets', example[0]) - - proc_dataset = tf_inputs.c4_preprocess(load_c4_dataset(), False, 2048) - - # `examine_processed_dataset` has some asserts in it. - proc_count, char_lengths = examine_processed_dataset(proc_dataset) - - # Both the original and filtered datasets have examples. - self.assertGreater(unfiltered_count, 0) - self.assertGreater(proc_count, 0) - - # Because we filter out some entries on length. - self.assertLess(proc_count, unfiltered_count) - - # Preprocess using the sentencepiece model in testdata. 
- spc_proc_dataset = tf_inputs.c4_preprocess( - load_c4_dataset(), - False, - 2048, - tokenization='spc', - spm_path=_spm_path()) - - spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset) - - # spc shortens the target sequence a lot, should be almost equal to - # unfiltered - self.assertLessEqual(proc_count, spc_proc_count) - self.assertEqual(unfiltered_count, spc_proc_count) - - # Assert all spc_lengths are lesser than their char counterparts. - for spc_len, char_len in zip(spc_lengths, char_lengths): - self.assertLessEqual(spc_len, char_len) - - def test_c4(self): - gin.bind_parameter('c4_preprocess.max_target_length', 2048) - gin.bind_parameter('c4_preprocess.tokenization', 'spc') - gin.bind_parameter('c4_preprocess.spm_path', _spm_path()) - - # Just make sure this doesn't throw. - _ = tf_inputs.data_streams( - 'c4', - data_dir=_TESTDATA, - input_name='targets', - target_name='text', - preprocess_fn=tf_inputs.c4_preprocess) - - def test_c4_bare_preprocess_fn_denoising_objective(self): - _t5_gin_config() - - dataset = _c4_dataset() - dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) - - example = list(tfds.as_numpy(dataset.take(1)))[0] - - # Assertions now. - - self.assertIn('targets', example) - targets = example['targets'] - self.assertIsInstance(targets, np.ndarray) - self.assertEqual(targets.dtype, np.int64) - self.assertGreater(len(targets), 0) - - self.assertIn('inputs', example) - _inputs = example['inputs'] # pylint: disable=invalid-name - self.assertIsInstance(_inputs, np.ndarray) - self.assertEqual(_inputs.dtype, np.int64) - self.assertGreater(len(_inputs), 0) - - # WHP inputs will have the bulk of the text. - self.assertGreater(len(_inputs), len(targets)) - - # WHP there will be one sentinel token in the inputs and targets. 
- inputs_counter = collections.Counter(_inputs.tolist()) - targets_counter = collections.Counter(targets.tolist()) - self.assertEqual(1, inputs_counter[31999]) - self.assertEqual(1, targets_counter[31999]) - - def test_c4_pretrain(self): - _t5_gin_config() - - gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path()) - - gin.bind_parameter('batcher.batch_size_per_device', 8) - gin.bind_parameter('batcher.eval_batch_size', 8) - gin.bind_parameter('batcher.max_eval_length', 50) - gin.bind_parameter('batcher.buckets', ([51], [8, 1])) - - # Just make sure this doesn't throw. - _ = tf_inputs.data_streams( - 'c4', - data_dir=_TESTDATA, - input_name='inputs', - target_name='targets', - bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn) - - def test_generic_text_dataset_preprocess_fn(self): - dataset = _load_dataset('squad/v1.1:3.0.0') - - example, = tfds.as_numpy(dataset.take(1)) - - self.assertNotIn('inputs', example) - self.assertNotIn('targets', example) - - proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn( - dataset, - spm_path=_spm_path(), - text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)], - copy_pretokenized=True, - debug_print_examples=True, - debug_print_examples_rate=1.0) - - proc_example, = tfds.as_numpy(proc_dataset.take(1)) - - self.assertIn('inputs', proc_example) - self.assertIn('targets', proc_example) - - self.assertEqual(proc_example['inputs'].dtype, np.int32) - self.assertEqual(proc_example['targets'].dtype, np.int32) - - # TODO(afrozm): Why does this test take so much time? - def test_inputs_using_generic_text_dataset_preprocess_fn(self): - gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path', - _spm_path()) - gin.bind_parameter('generic_text_dataset_preprocess_fn.text_preprocess_fns', - [lambda ds, training: t5_processors.squad(ds)]) - - # Just make sure this doesn't throw. 
- def data_streams(): - return tf_inputs.data_streams( - 'squad', - data_dir=_TESTDATA, - input_name='inputs', - target_name='targets', - bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn, - shuffle_buffer_size=1) - - n_devices = 3 - - squad_inputs = inputs.batcher( - data_streams=data_streams, - max_eval_length=512, - buckets=([ - 513, - ], [n_devices, n_devices])) - - eval_stream = squad_inputs.eval_stream(n_devices) - inps, tgts, _ = next(eval_stream) - - # We can only assert that the batch dim gets divided by n_devices. - self.assertEqual(inps.shape[0] % n_devices, 0) - self.assertEqual(tgts.shape[0] % n_devices, 0) - - def test_filter_dataset_on_len(self): - # {1, 2}, {2, 4}, {3, 6} ... {10, 20} - ds = _test_dataset_ints(range(1, 11), range(2, 21, 2)) - - ds1 = tf_inputs.filter_dataset_on_len(ds, True, { - 'inputs': [4, 8], - 'targets': [14, 20] - }) - # Only {7, 14} and {8, 16} satisfy this. - self.assertLen(list(ds1.as_numpy_iterator()), 2) - - ds2 = tf_inputs.filter_dataset_on_len( - ds, - False, - len_map={ - 'inputs': [4, 8], - 'targets': [14, 20] - }, - filter_on_eval=False) - # This is eval and we aren't supposed to filter it. - self.assertLen(list(ds2.as_numpy_iterator()), 10) - - ds3 = tf_inputs.filter_dataset_on_len( - ds, - False, - len_map={ - 'inputs': [4, 8], - 'targets': [14, 20] - }, - filter_on_eval=True) - # This is eval and we are asked to filter it. - self.assertLen(list(ds3.as_numpy_iterator()), 2) - - def test_truncate_dataset_on_len(self): - ds = _test_dataset_ints([5, 6, 7], [8, 9, 10]) - ds1 = tf_inputs.truncate_dataset_on_len( - ds, True, len_map={ - 'inputs': 6, - 'targets': 4 - }) - expected_ds = _test_dataset_ints([5, 6, 6], [4, 4, 4]) - - # training, should filter. - assert_dataset(ds1, list(expected_ds.as_numpy_iterator())) - - # not Training, shouldn't filter. 
- ds2 = tf_inputs.truncate_dataset_on_len( - ds, False, len_map={ - 'inputs': 6, - 'targets': 4 - }) - assert_dataset(ds2, list(ds.as_numpy_iterator())) - - # not Training, but asked to filter, should filter. - ds3 = tf_inputs.truncate_dataset_on_len( - ds, False, len_map={ - 'inputs': 6, - 'targets': 4 - }, truncate_on_eval=True) - assert_dataset(ds3, list(expected_ds.as_numpy_iterator())) - - def test_get_t5_preprocessor_by_name(self): - gin.clear_config() - - gin.parse_config(""" - get_t5_preprocessor_by_name.name = 'rekey' - get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}} - """) - prep_rekey = tf_inputs.get_t5_preprocessor_by_name() - og_dataset = tf.data.Dataset.from_tensors({ - 'text': 'That is good.', - 'other': 'That is bad.' - }) - training = True - dataset = prep_rekey(og_dataset, training) - assert_dataset(dataset, { - 'inputs': 'That is bad.', - 'targets': 'That is good.' - }) - - def test_pad_dataset_to_length(self): - ds = _test_dataset_ints([5, 6, 7], [6, 7, 8]) - ds1 = tf_inputs.pad_dataset_to_length( - ds, True, len_map={ - 'inputs': 7, - 'targets': 10 - }) - - expected_ds = [ - { - 'inputs': np.array([1, 1, 1, 1, 1, 0, 0], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.int64), - }, - ] - - assert_dataset(ds1, expected_ds) - - def test_lm_token_preprocessing(self): - ds = _test_dataset_ints([1, 2, 3], [3, 2, 1]) - ds1 = tf_inputs.lm_token_preprocessing(ds, True) - - # pylint: disable=bad-whitespace - expected_ds = [ - { - 'inputs': np.array([1, 0, 1, 1, 1], dtype=np.int64), - 'targets': np.array([1, 0, 1, 1, 1], dtype=np.int64), - 'mask': np.array([0, 0, 1, 1, 1], dtype=np.int64), - }, - { - 
'inputs': np.array([1, 1, 0, 1, 1], dtype=np.int64), - 'targets': np.array([1, 1, 0, 1, 1], dtype=np.int64), - 'mask': np.array([0, 0, 0, 1, 1], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 0, 1], dtype=np.int64), - 'targets': np.array([1, 1, 1, 0, 1], dtype=np.int64), - 'mask': np.array([0, 0, 0, 0, 1], dtype=np.int64), - }, - ] - # pylint: enable=bad-whitespace - - assert_dataset(ds1, expected_ds) - - def test_create_bert_inputs(self): - inputs_sentences_1 = [np.array([100, 150, 200])] - inputs_sentences_2 = [np.array([300, 500])] - labels = [np.array(1)] - - create_inputs_1 = tf_inputs.CreateBertInputs(False) - create_inputs_2 = tf_inputs.CreateBertInputs(True) - for res in create_inputs_1(zip(inputs_sentences_1, labels)): - values, segment_embs, _, label, weight = res - self.assertAllEqual(values, np.array([101, 100, 150, 200, 102])) - self.assertAllEqual(segment_embs, np.zeros(5)) - self.assertEqual(label, np.int64(1)) - self.assertEqual(weight, np.int64(1)) - - for res in create_inputs_2( - zip(inputs_sentences_1, inputs_sentences_2, labels)): - values, segment_embs, _, label, weight = res - self.assertAllEqual(values, - np.array([101, 100, 150, 200, 102, 300, 500, 102])) - exp_segment = np.concatenate((np.zeros(5), np.ones(3))) - self.assertAllEqual(segment_embs, exp_segment) - self.assertEqual(label, np.int64(1)) - self.assertEqual(weight, np.int64(1)) - - def test_mask_random_tokens(self): - """Test only standard tokens. - - This test deals with sentences composed of two parts: [100 CLS tokens, 100 - chosen standard tokens]. CLS is the token that is added at the beginning of - the sentence and there is only one token in standard scenario. It is never - masked because it is not a part of the sentence. 
- This tests whether mask_random_tokens will: - - mask only standard tokens - - mask expected number of tokens (15 percent candidates for masking) - """ - cls_token = 101 - mask_token = 103 - example_standard_token = 1001 - test_case_row = np.array([cls_token] * 100 + [example_standard_token] * 100) - test_case = [(test_case_row.copy(),)] - - out, original_tokens, token_weights = next( - tf_inputs.mask_random_tokens(test_case)) - # test whether original tokens are unchanged - self.assertAllEqual(test_case_row, original_tokens) - - self.assertEqual(1, token_weights.sum()) - self.assertEqual( - 15, - (token_weights > 0).sum()) # we should have 15 candidates for masking - - # 101 is a special token, so only 1001 should be masked - self.assertAllEqual(out[:100], test_case_row[:100]) - - # Each candidate has 0.8 probability to be masked while others have 0, so - # no more than 15 tokens with MASK - self.assertLessEqual((out == mask_token).sum(), 15) - - def test_bert_next_sentence_prediction_inputs(self): - stream = tf_inputs.BertNextSentencePredictionInputs( - 'c4/en:2.3.0', data_dir=_TESTDATA, train=False, shuffle_size=1) - exp_sent1 = 'Police were called to the carriageway around 6.' - exp_sent2 = 'I am sorry we did not see how lost and alone you felt.' - sent1, sent2, label = next(stream()) - self.assertEqual(exp_sent1, sent1) - self.assertEqual(exp_sent2, sent2) - self.assertFalse(label) - - def test_process_single_mathqa_example_0(self): - # This is the first problem in the MathQA dataset. - example = { - 'Problem': - "the banker ' s gain of a certain sum due 3 years hence at 10 % " - 'per annum is rs . 36 . what is the present worth ?', - 'Rationale': - '"explanation : t = 3 years r = 10 % td = ( bg Γ— 100 ) / tr = ( ' - '36 Γ— 100 ) / ( 3 Γ— 10 ) = 12 Γ— 10 = rs . 120 td = ( pw Γ— tr )' - ' / 100 β‡’ 120 = ( pw Γ— 3 Γ— 10 ) / 100 β‡’ 1200 = pw Γ— 3 pw = ' - '1200 / 3 = rs . 400 answer : option a"', - 'options': - 'a ) rs . 400 , b ) rs . 300 , c ) rs . 
500 , d ) rs . 350 , e ) ' - 'none of these', - 'correct': - 'a', - 'annotated_formula': - 'divide(multiply(const_100, divide(multiply(36, const_100), ' - 'multiply(3, 10))), multiply(3, 10))', - 'linear_formula': - 'multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)|', - 'category': - 'gain' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, - 400) # we know it, because correct answer is a) - self.assertEqual(python_result, [3600.0, 30.0, 120.0, 12000.0, 400.0]) - - self.assertEqual(python_program, [ - 't0 = n2 * 100.0', 't1 = n0 * n1', 't2 = t0 / t1', 't3 = t2 * 100.0', - 't4 = t3 / t1' - ]) - self.assertEqual(list_op, [ - 'multiply(n2,const_100)', 'multiply(n0,n1)', 'divide(#0,#1)', - 'multiply(#2,const_100)', 'divide(#3,#1)' - ]) - self.assertEqual(list_num, [3.0, 10.0, 36.0]) - - def test_process_single_mathqa_example_1(self): - # This is the third problem in the MathQA dataset. - example = { - 'Problem': - 'sophia finished 2 / 3 of a book . she calculated that she ' - 'finished 90 more pages than she has yet to read . how long is her' - ' book ?', - 'Rationale': - 'let xx be the total number of pages in the book , then she ' - 'finished 23 β‹… x 23 β‹… x pages . then she has x βˆ’ 23 β‹… x = ' - '13 β‹… xx βˆ’ 23 β‹… x = 13 β‹… x pages left . 23 β‹… x βˆ’ 13 ' - 'β‹… x = 9023 β‹… x βˆ’ 13 β‹… x = 90 13 β‹… x = 9013 β‹… x = 90 x' - ' = 270 x = 270 so the book is 270 pages long . 
answer : b', - 'options': 'a ) 229 , b ) 270 , c ) 877 , d ) 266 , e ) 281', - 'correct': 'b', - 'annotated_formula': 'divide(90, subtract(const_1, divide(2, 3)))', - 'linear_formula': 'divide(n0,n1)|subtract(const_1,#0)|divide(n2,#1)', - 'category': 'general' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, - 270) # we know it, because correct answer is b) - self.assertAllClose( - python_result, - [0.6666666666666666, 0.33333333333333337, 269.99999999999994]) - self.assertEqual(python_program, - ['t0 = n0 / n1', 't1 = 1.0 - t0', 't2 = n2 / t1']) - self.assertEqual(list_op, - ['divide(n0,n1)', 'subtract(const_1,#0)', 'divide(n2,#1)']) - self.assertEqual(list_num, [2.0, 3.0, 90.0]) - - def test_process_single_mathqa_example_with_import(self): - # This is a training MathQA problem which involve an import. - example = { - 'Problem': - 'the length of a rectangular garden is three times its width . if ' - 'the area of the rectangular garden is 588 square meters , then ' - 'what is the width of the rectangular garden ?', - 'Rationale': - '\"let x be the width of the garden . 3 x ^ 2 = 588 x ^ 2 = 196 x ' - '= 14 the answer is c .\"', - 'options': - 'a ) 12 , b ) 13 , c ) 14 , d ) 15 , e ) 16', - 'correct': - 'c', - 'annotated_formula': - 'sqrt(divide(588, const_3))', - 'linear_formula': - 'divide(n0,const_3)|sqrt(#0)|', - 'category': - 'geometry' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, 14) # we know it, because correct answer is c) - self.assertAllClose(python_result, [196, 14]) - self.assertEqual( - python_program, - ['t0 = n0 / 3.0', 't1 = math.sqrt(max(0, t0))']) - self.assertEqual(list_op, ['divide(n0,const_3)', 'sqrt(#0)']) - self.assertEqual(list_num, [588]) - - # Below we execute twice the Python program and once the DSL program. 
- target_values = 'import math\n' - problem = example['Problem'] - for i in range(len(list_num)): - target_values += 'n{} = {}\n'.format(i, list_num[i]) - problem += ' n{} = {}'.format(i, list_num[i]) - target_values += '\n'.join(python_program[:-1]) - final_line = python_program[-1].split('=')[1] - target_values += '\nanswer ={}'.format(final_line) - var_dict = {} - exec(target_values, globals(), var_dict) # pylint: disable=exec-used - self.assertAllClose(var_dict['answer'], 14) - self.assertAllClose( - tf_inputs.execute_mathqa_program(problem, target_values.split('\n')), - 14) - self.assertAllClose( - tf_inputs.execute_mathqa_dsl_program(problem, - [example['linear_formula']]), 14) - - - def test_sentencepiece_tokenize(self): - def dataset(): - yield 'I have a cat.' - - examples = [] - for example in tf_inputs.sentencepiece_tokenize(dataset(), _spm_path()): - examples.append(example) - toks = list(examples[0]) - self.assertSequenceEqual([27, 43, 3, 9, 1712, 5], toks) - - -if __name__ == '__main__': - tf.test.main() diff --git a/trax/data/tokenizer.py b/trax/data/tokenizer.py index 64081f4da..669758439 100644 --- a/trax/data/tokenizer.py +++ b/trax/data/tokenizer.py @@ -51,138 +51,141 @@ # This set contains all letter and number characters. _ALPHANUMERIC_CHAR_SET = set( - six.unichr(i) for i in range(sys.maxunicode) - if (unicodedata.category(six.unichr(i)).startswith("L") or - unicodedata.category(six.unichr(i)).startswith("N"))) + six.unichr(i) + for i in range(sys.maxunicode) + if ( + unicodedata.category(six.unichr(i)).startswith("L") + or unicodedata.category(six.unichr(i)).startswith("N") + ) +) def encode(text): - """Encode a unicode string as a list of tokens. 
- - Args: - text: a unicode string - Returns: - a list of tokens as Unicode strings - """ - if not text: - return [] - ret = [] - token_start = 0 - # Classify each character in the input string - is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] - for pos in range(1, len(text)): - if is_alnum[pos] != is_alnum[pos - 1]: - token = text[token_start:pos] - if token != u" " or token_start == 0: - ret.append(token) - token_start = pos - final_token = text[token_start:] - ret.append(final_token) - return ret + """Encode a unicode string as a list of tokens. + + Args: + text: a unicode string + Returns: + a list of tokens as Unicode strings + """ + if not text: + return [] + ret = [] + token_start = 0 + # Classify each character in the input string + is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] + for pos in range(1, len(text)): + if is_alnum[pos] != is_alnum[pos - 1]: + token = text[token_start:pos] + if token != " " or token_start == 0: + ret.append(token) + token_start = pos + final_token = text[token_start:] + ret.append(final_token) + return ret def decode(tokens): - """Decode a list of tokens to a unicode string. - - Args: - tokens: a list of Unicode strings - Returns: - a unicode string - """ - token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] - ret = [] - for i, token in enumerate(tokens): - if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: - ret.append(u" ") - ret.append(token) - return "".join(ret) + """Decode a list of tokens to a unicode string. + + Args: + tokens: a list of Unicode strings + Returns: + a unicode string + """ + token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] + ret = [] + for i, token in enumerate(tokens): + if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: + ret.append(" ") + ret.append(token) + return "".join(ret) def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True): - """Reads files matching a wildcard pattern, yielding the contents. 
- - Args: - filepattern: A wildcard pattern matching one or more files. - max_lines: If set, stop reading after reading this many lines. - split_on_newlines: A boolean. If true, then split files by lines and strip - leading and trailing whitespace from each line. Otherwise, treat each - file as a single string. - - Yields: - The contents of the files as lines, if split_on_newlines is True, or - the entire contents of each file if False. - """ - filenames = sorted(tf.io.gfile.glob(filepattern)) - lines_read = 0 - for filename in filenames: - with tf.io.gfile.GFile(filename) as f: - if split_on_newlines: - for line in f: - yield line.strip() - lines_read += 1 - if max_lines and lines_read >= max_lines: - return - - else: - if max_lines: - doc = [] - for line in f: - doc.append(line) - lines_read += 1 - if max_lines and lines_read >= max_lines: - yield "".join(doc) - return - yield "".join(doc) - - else: - yield f.read() - - -def corpus_token_counts( - text_filepattern, corpus_max_lines, split_on_newlines=True): - """Read the corpus and compute a dictionary of token counts. - - Args: - text_filepattern: A pattern matching one or more files. - corpus_max_lines: An integer; maximum total lines to read. - split_on_newlines: A boolean. If true, then split files by lines and strip - leading and trailing whitespace from each line. Otherwise, treat each - file as a single string. - - Returns: - a dictionary mapping token to count. - """ - counts = collections.Counter() - for doc in _read_filepattern( - text_filepattern, - max_lines=corpus_max_lines, - split_on_newlines=split_on_newlines): - counts.update(encode(doc)) - - return counts + """Reads files matching a wildcard pattern, yielding the contents. + + Args: + filepattern: A wildcard pattern matching one or more files. + max_lines: If set, stop reading after reading this many lines. + split_on_newlines: A boolean. If true, then split files by lines and strip + leading and trailing whitespace from each line. 
Otherwise, treat each + file as a single string. + + Yields: + The contents of the files as lines, if split_on_newlines is True, or + the entire contents of each file if False. + """ + filenames = sorted(tf.io.gfile.glob(filepattern)) + lines_read = 0 + for filename in filenames: + with tf.io.gfile.GFile(filename) as f: + if split_on_newlines: + for line in f: + yield line.strip() + lines_read += 1 + if max_lines and lines_read >= max_lines: + return + + else: + if max_lines: + doc = [] + for line in f: + doc.append(line) + lines_read += 1 + if max_lines and lines_read >= max_lines: + yield "".join(doc) + return + yield "".join(doc) + + else: + yield f.read() + + +def corpus_token_counts(text_filepattern, corpus_max_lines, split_on_newlines=True): + """Read the corpus and compute a dictionary of token counts. + + Args: + text_filepattern: A pattern matching one or more files. + corpus_max_lines: An integer; maximum total lines to read. + split_on_newlines: A boolean. If true, then split files by lines and strip + leading and trailing whitespace from each line. Otherwise, treat each + file as a single string. + + Returns: + a dictionary mapping token to count. + """ + counts = collections.Counter() + for doc in _read_filepattern( + text_filepattern, + max_lines=corpus_max_lines, + split_on_newlines=split_on_newlines, + ): + counts.update(encode(doc)) + + return counts def vocab_token_counts(text_filepattern, max_lines): - """Read a vocab file and return a dictionary of token counts. + """Read a vocab file and return a dictionary of token counts. - Reads a two-column CSV file of tokens and their frequency in a dataset. The - tokens are presumed to be generated by encode() or the equivalent. + Reads a two-column CSV file of tokens and their frequency in a dataset. The + tokens are presumed to be generated by encode() or the equivalent. - Args: - text_filepattern: A pattern matching one or more files. - max_lines: An integer; maximum total lines to read. 
+ Args: + text_filepattern: A pattern matching one or more files. + max_lines: An integer; maximum total lines to read. - Returns: - a dictionary mapping token to count. - """ - ret = {} - for i, line in enumerate( - _read_filepattern(text_filepattern, max_lines=max_lines)): - if "," not in line: - logging.warning("Malformed vocab line #%d '%s'", i, line) - continue + Returns: + a dictionary mapping token to count. + """ + ret = {} + for i, line in enumerate(_read_filepattern(text_filepattern, max_lines=max_lines)): + if "," not in line: + logging.warning("Malformed vocab line #%d '%s'", i, line) + continue - token, count = line.rsplit(",", 1) - ret[token] = int(count) + token, count = line.rsplit(",", 1) + ret[token] = int(count) - return ret + return ret diff --git a/trax/data/tokenizer_test.py b/trax/data/tokenizer_test.py deleted file mode 100644 index 593ebe83d..000000000 --- a/trax/data/tokenizer_test.py +++ /dev/null @@ -1,136 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.data..tokenizer.""" -import os -import random - -import six -from six.moves import range # pylint: disable=redefined-builtin -import tensorflow.compat.v1 as tf -from trax.data import tokenizer - - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, "testdata") - - -class TokenizerTest(tf.test.TestCase): - - def test_encode(self): - self.assertListEqual( - [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."], - tokenizer.encode(u"Dude - that's so cool.")) - self.assertListEqual([u"Łukasz", u"est", u"nΓ©", u"en", u"1981", u"."], - tokenizer.encode(u"Łukasz est nΓ© en 1981.")) - self.assertListEqual([u" ", u"Spaces", u"at", u"the", u"ends", u" "], - tokenizer.encode(u" Spaces at the ends ")) - self.assertListEqual([u"802", u".", u"11b"], tokenizer.encode(u"802.11b")) - self.assertListEqual([u"two", u". \n", u"lines"], - tokenizer.encode(u"two. \nlines")) - - def test_decode(self): - self.assertEqual( - u"Dude - that's so cool.", - tokenizer.decode( - [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."])) - - def test_invertibility_on_random_strings(self): - for _ in range(1000): - s = u"".join(six.unichr(random.randint(0, 65535)) for _ in range(10)) - self.assertEqual(s, tokenizer.decode(tokenizer.encode(s))) - - -class TestTokenCounts(tf.test.TestCase): - - def setUp(self): - super(TestTokenCounts, self).setUp() - self.corpus_path = os.path.join(_TESTDATA, "corpus-*.txt") - self.vocab_path = os.path.join(_TESTDATA, "vocab-*.txt") - - def test_corpus_token_counts_split_on_newlines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=0, split_on_newlines=True) - - expected = { - u"'": 2, - u".": 2, - u". ": 1, - u"... 
": 1, - u"Groucho": 1, - u"Marx": 1, - u"Mitch": 1, - u"Hedberg": 1, - u"I": 3, - u"in": 2, - u"my": 2, - u"pajamas": 2, - } - self.assertDictContainsSubset(expected, token_counts) - self.assertNotIn(u".\n\n", token_counts) - self.assertNotIn(u"\n", token_counts) - - def test_corpus_token_counts_no_split_on_newlines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=0, split_on_newlines=False) - - self.assertDictContainsSubset({u".\n\n": 2, u"\n": 3}, token_counts) - - def test_corpus_token_counts_split_with_max_lines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=5, split_on_newlines=True) - - self.assertIn(u"slept", token_counts) - self.assertNotIn(u"Mitch", token_counts) - - def test_corpus_token_counts_no_split_with_max_lines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=5, split_on_newlines=False) - - self.assertIn(u"slept", token_counts) - self.assertNotIn(u"Mitch", token_counts) - self.assertDictContainsSubset({ - u".\n\n": 1, - u"\n": 2, - u".\n": 1 - }, token_counts) - - def test_vocab_token_counts(self): - token_counts = tokenizer.vocab_token_counts(self.vocab_path, 0) - - expected = { - u"lollipop": 8, - u"reverberated": 12, - u"kattywampus": 11, - u"balderdash": 10, - u"jiggery-pokery": 14, - } - self.assertDictEqual(expected, token_counts) - - def test_vocab_token_counts_with_max_lines(self): - # vocab-1 has 2 lines, vocab-2 has 3 - token_counts = tokenizer.vocab_token_counts(self.vocab_path, 5) - - expected = { - u"lollipop": 8, - u"reverberated": 12, - u"kattywampus": 11, - u"balderdash": 10, - } - self.assertDictEqual(expected, token_counts) - - -if __name__ == "__main__": - tf.test.main() diff --git a/trax/fastmath/jax.py b/trax/fastmath/jax.py index df838c708..a5a6ea259 100644 --- a/trax/fastmath/jax.py +++ b/trax/fastmath/jax.py @@ -29,191 +29,220 @@ from trax.shapes import signature -def jax_conv(inp, fltr, 
window_strides, padding, dimension_numbers, - filter_dilation=None): - """A wrapper around `lax.conv_general_dilated`. - - It requires `dimension_numbers` and disallows `inp_dilation`. - - Args: - inp: an (N+2)-D array. The input of the convolution. - fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. - window_strides: the strides for moving the convolution window. - padding: a string, either 'VALID' or 'SAME'. The padding algorithm. - dimension_numbers: a tuple of three strings encoding the data format of - input, filter and output. 'I' means input; 'O' means output; 'C' means - channel; other characters such as 'W', 'H' and 'D' means spatial - dimensions. - filter_dilation: the dilation rates for the filter. Dilating the filter - means adding "holes" to the filter. - - Returns: - An (N+2)-D array. The convolution result. - """ - return lax.conv_general_dilated(inp, fltr, window_strides, padding, - lhs_dilation=None, - rhs_dilation=filter_dilation, - dimension_numbers=dimension_numbers) - - -def _pooling_general(inputs, reducer, init_val, rescaler=None, - pool_size=(2, 2), strides=None, padding='VALID'): - """Helper: general pooling computation used in pooling layers later.""" - spatial_strides = strides or (1,) * len(pool_size) - rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None - dims = (1,) + pool_size + (1,) # NHWC - strides = (1,) + spatial_strides + (1,) - out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding) - return rescale(out, inputs) if rescale else out # pylint: disable=not-callable +def jax_conv( + inp, fltr, window_strides, padding, dimension_numbers, filter_dilation=None +): + """A wrapper around `lax.conv_general_dilated`. + + It requires `dimension_numbers` and disallows `inp_dilation`. + + Args: + inp: an (N+2)-D array. The input of the convolution. + fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. 
+ window_strides: the strides for moving the convolution window. + padding: a string, either 'VALID' or 'SAME'. The padding algorithm. + dimension_numbers: a tuple of three strings encoding the data format of + input, filter and output. 'I' means input; 'O' means output; 'C' means + channel; other characters such as 'W', 'H' and 'D' means spatial + dimensions. + filter_dilation: the dilation rates for the filter. Dilating the filter + means adding "holes" to the filter. + + Returns: + An (N+2)-D array. The convolution result. + """ + return lax.conv_general_dilated( + inp, + fltr, + window_strides, + padding, + lhs_dilation=None, + rhs_dilation=filter_dilation, + dimension_numbers=dimension_numbers, + ) + + +def _pooling_general( + inputs, + reducer, + init_val, + rescaler=None, + pool_size=(2, 2), + strides=None, + padding="VALID", +): + """Helper: general pooling computation used in pooling layers later.""" + spatial_strides = strides or (1,) * len(pool_size) + rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None + dims = (1,) + pool_size + (1,) # NHWC + strides = (1,) + spatial_strides + (1,) + out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding) + return rescale(out, inputs) if rescale else out # pylint: disable=not-callable def jax_max_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size, - strides=strides, padding=padding) + return _pooling_general( + x, lax.max, -jnp.inf, pool_size=pool_size, strides=strides, padding=padding + ) def jax_sum_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.add, 0., pool_size=pool_size, - strides=strides, padding=padding) + return _pooling_general( + x, lax.add, 0.0, pool_size=pool_size, strides=strides, padding=padding + ) -def _normalize_by_window_size(dims, spatial_strides, padding): # pylint: disable=invalid-name - def rescale(outputs, inputs): - one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype) - 
window_sizes = lax.reduce_window( - one, 0., lax.add, dims, spatial_strides, padding) - return outputs / window_sizes[..., jnp.newaxis] - return rescale +def _normalize_by_window_size( + dims, spatial_strides, padding +): # pylint: disable=invalid-name + def rescale(outputs, inputs): + one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype) + window_sizes = lax.reduce_window( + one, 0.0, lax.add, dims, spatial_strides, padding + ) + return outputs / window_sizes[..., jnp.newaxis] + + return rescale def jax_avg_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.add, 0., _normalize_by_window_size, - pool_size, strides=strides, padding=padding) + return _pooling_general( + x, + lax.add, + 0.0, + _normalize_by_window_size, + pool_size, + strides=strides, + padding=padding, + ) def jax_abstract_eval(f): - """Returns a function that evaluates `f` given input shapes and dtypes. + """Returns a function that evaluates `f` given input shapes and dtypes. + + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). - It transforms function `f` to a function that performs the same computation as - `f` but only on shapes and dtypes (a.k.a. shape inference). + Args: + f: the function to be transformed. - Args: - f: the function to be transformed. + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `ShapeDtype`, and whose return values are + `ShapeDtype`s with the same nested structure as `f`'s return values. + """ - Returns: - A function whose input arguments can be either the same as `f`'s or only - their shapes/dtypes represented by `ShapeDtype`, and whose return values are - `ShapeDtype`s with the same nested structure as `f`'s return values. 
- """ - def shape_fun(*args, **kwargs): - jax_shapes = jax.eval_shape(f, *args, **kwargs) - return tnp.nested_map(signature, jax_shapes) - return shape_fun + def shape_fun(*args, **kwargs): + jax_shapes = jax.eval_shape(f, *args, **kwargs) + return tnp.nested_map(signature, jax_shapes) + + return shape_fun # The default value of dtype is different from jax_random.randint def jax_randint(key, shape, minval, maxval, dtype=np.int32): - """Sample uniform random values in [minval, maxval) with given shape/dtype. + """Sample uniform random values in [minval, maxval) with given shape/dtype. - Args: - key: a PRNGKey used as the random key. - shape: a tuple of nonnegative integers representing the shape. - minval: int or array of ints broadcast-compatible with ``shape``, a minimum - (inclusive) value for the range. - maxval: int or array of ints broadcast-compatible with ``shape``, a maximum - (exclusive) value for the range. - dtype: optional, an int dtype for the returned values (default int32). + Args: + key: a PRNGKey used as the random key. + shape: a tuple of nonnegative integers representing the shape. + minval: int or array of ints broadcast-compatible with ``shape``, a minimum + (inclusive) value for the range. + maxval: int or array of ints broadcast-compatible with ``shape``, a maximum + (exclusive) value for the range. + dtype: optional, an int dtype for the returned values (default int32). - Returns: - A random array with the specified shape and dtype. - """ - return jax_random.randint(key, shape, minval=minval, maxval=maxval, - dtype=dtype) + Returns: + A random array with the specified shape and dtype. 
+ """ + return jax_random.randint(key, shape, minval=minval, maxval=maxval, dtype=dtype) def _to_numpy(x): - """Converts non-NumPy tensors to NumPy arrays.""" - return x if isinstance(x, np.ndarray) else x.numpy() + """Converts non-NumPy tensors to NumPy arrays.""" + return x if isinstance(x, np.ndarray) else x.numpy() def _dataset_as_numpy(ds, batch_size=None): - """Speed up tfds.as_numpy by batching and then iterating over the batches.""" - batch_size = batch_size or 1 - try: # Check that dense_to_ragged_batch exists. - if batch_size < 2: # Fall back to default if no batching requested. - raise AttributeError - ds_batch = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size)) - for example in tfds.as_numpy(ds_batch): - flat_example = tnp.tree_flatten(example) - np_flat_example = [_to_numpy(x) for x in flat_example] - for single_example_flat in zip(*np_flat_example): - single_example, _ = tnp.tree_unflatten(single_example_flat, example) - yield single_example - except AttributeError: - # In TF 1.X there is not dense_to_ragged_batch: fallback. - for example in tfds.as_numpy(ds): - yield example + """Speed up tfds.as_numpy by batching and then iterating over the batches.""" + batch_size = batch_size or 1 + try: # Check that dense_to_ragged_batch exists. + if batch_size < 2: # Fall back to default if no batching requested. + raise AttributeError + ds_batch = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size)) + for example in tfds.as_numpy(ds_batch): + flat_example = tnp.tree_flatten(example) + np_flat_example = [_to_numpy(x) for x in flat_example] + for single_example_flat in zip(*np_flat_example): + single_example, _ = tnp.tree_unflatten(single_example_flat, example) + yield single_example + except AttributeError: + # In TF 1.X there is not dense_to_ragged_batch: fallback. 
+ for example in tfds.as_numpy(ds): + yield example def _custom_grad(f_vjp, f_original): - f_ = jax.custom_transforms(f_original) - jax.defvjp_all(f_, f_vjp) - return f_ + f_ = jax.custom_transforms(f_original) + jax.defvjp_all(f_, f_vjp) + return f_ def _custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=()): - @functools.partial(jax.custom_vjp, nondiff_argnums=nondiff_argnums) - def _f(*args, **kwargs): - return f(*args, **kwargs) - _f.defvjp(f_fwd, f_bwd) - return _f + @functools.partial(jax.custom_vjp, nondiff_argnums=nondiff_argnums) + def _f(*args, **kwargs): + return f(*args, **kwargs) + + _f.defvjp(f_fwd, f_bwd) + return _f JAX_BACKEND = { - 'name': 'jax', - 'np': jnp, - 'abstract_eval': jax_abstract_eval, - 'avg_pool': jax_avg_pool, - 'cond': lax.cond, - 'conv': jax_conv, - 'custom_vjp': _custom_vjp, - 'custom_grad': _custom_grad, - 'dataset_as_numpy': _dataset_as_numpy, - 'dynamic_slice': jax.lax.dynamic_slice, - 'dynamic_slice_in_dim': jax.lax.dynamic_slice_in_dim, - 'dynamic_update_slice': jax.lax.dynamic_update_slice, - 'dynamic_update_slice_in_dim': jax.lax.dynamic_update_slice_in_dim, - 'erf': jax_special.erf, - 'expit': jax_special.expit, - 'fori_loop': lax.fori_loop, - 'global_device_count': jax.device_count, - 'grad': jax.grad, - 'value_and_grad': jax.value_and_grad, - 'index_add': lambda x, idx, y: jnp.asarray(x).at[idx].add(y), - 'index_max': lambda x, idx, y: jnp.asarray(x).at[idx].max(y), - 'index_min': lambda x, idx, y: jnp.asarray(x).at[idx].min(y), - 'index_update': lambda x, idx, y: jnp.asarray(x).at[idx].set(y), - 'jit': jax.jit, - 'local_device_count': jax.local_device_count, - 'logsumexp': jax_special.logsumexp, - 'lt': lax.lt, - 'map': lax.map, - 'max_pool': jax_max_pool, - 'pmap': jax.pmap, - 'psum': lax.psum, - 'random_bernoulli': jax_random.bernoulli, - 'random_get_prng': jax.jit(jax_random.PRNGKey), - 'random_normal': jax_random.normal, - 'random_randint': jax_randint, - 'random_split': jax_random.split, - 'random_fold_in': 
jax_random.fold_in, - 'random_uniform': jax_random.uniform, - 'remat': jax.remat, - 'scan': lax.scan, - 'sort_key_val': jax.lax.sort_key_val, - 'stop_gradient': lax.stop_gradient, - 'sum_pool': jax_sum_pool, - 'top_k': lax.top_k, - 'vjp': jax.vjp, - 'vmap': jax.vmap, + "name": "jax", + "np": jnp, + "abstract_eval": jax_abstract_eval, + "avg_pool": jax_avg_pool, + "cond": lax.cond, + "cond": lax.cond, + "conv": jax_conv, + "custom_vjp": _custom_vjp, + "custom_grad": _custom_grad, + "dataset_as_numpy": _dataset_as_numpy, + "dynamic_slice": jax.lax.dynamic_slice, + "dynamic_slice_in_dim": jax.lax.dynamic_slice_in_dim, + "dynamic_update_slice": jax.lax.dynamic_update_slice, + "dynamic_update_slice_in_dim": jax.lax.dynamic_update_slice_in_dim, + "erf": jax_special.erf, + "expit": jax_special.expit, + "fori_loop": lax.fori_loop, + "global_device_count": jax.device_count, + "grad": jax.grad, + "value_and_grad": jax.value_and_grad, + "index_add": lambda x, idx, y: jnp.asarray(x).at[idx].add(y), + "index_max": lambda x, idx, y: jnp.asarray(x).at[idx].max(y), + "index_min": lambda x, idx, y: jnp.asarray(x).at[idx].min(y), + "index_update": lambda x, idx, y: jnp.asarray(x).at[idx].set(y), + "jit": jax.jit, + "local_device_count": jax.local_device_count, + "logsumexp": jax_special.logsumexp, + "lt": lax.lt, + "map": lax.map, + "max_pool": jax_max_pool, + "pmap": jax.pmap, + "psum": lax.psum, + "random_bernoulli": jax_random.bernoulli, + "random_get_prng": jax.jit(jax_random.PRNGKey), + "random_normal": jax_random.normal, + "random_randint": jax_randint, + "random_split": jax_random.split, + "random_fold_in": jax_random.fold_in, + "random_uniform": jax_random.uniform, + "remat": jax.remat, + "scan": lax.scan, + "sort_key_val": jax.lax.sort_key_val, + "stop_gradient": lax.stop_gradient, + "sum_pool": jax_sum_pool, + "top_k": lax.top_k, + "vjp": jax.vjp, + "vmap": jax.vmap, } diff --git a/trax/fastmath/numpy.py b/trax/fastmath/numpy.py index 0826fa416..9861a8fa2 100644 --- 
a/trax/fastmath/numpy.py +++ b/trax/fastmath/numpy.py @@ -21,269 +21,271 @@ def get_prng(seed): - """JAX-compatible way of getting PRNG seeds.""" - if np.shape(seed): - raise TypeError('PRNGKey seed must be a scalar.') - convert = lambda k: np.reshape(np.asarray(k, np.uint32), [1]) - k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF)) - k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF)) - return np.concatenate([k1, k2], 0) + """JAX-compatible way of getting PRNG seeds.""" + if np.shape(seed): + raise TypeError("PRNGKey seed must be a scalar.") + convert = lambda k: np.reshape(np.asarray(k, np.uint32), [1]) + k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF)) + k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF)) + return np.concatenate([k1, k2], 0) def random_uniform(rng, shape=(), dtype=np.float64, minval=0.0, maxval=1.0): - del rng - return np.random.uniform(minval, maxval, size=shape).astype(dtype) + del rng + return np.random.uniform(minval, maxval, size=shape).astype(dtype) def random_normal(rng, shape=(), dtype=np.float64): - del rng - return np.random.normal(size=shape).astype(dtype) + del rng + return np.random.normal(size=shape).astype(dtype) def random_randint(rng, shape, minval, maxval, dtype=np.int64): - del rng - return np.random.randint(minval, maxval, size=shape).astype(dtype) + del rng + return np.random.randint(minval, maxval, size=shape).astype(dtype) def random_bernoulli(rng, p=0.5, shape=()): - del rng - return np.random.binomial(1, p, size=shape) + del rng + return np.random.binomial(1, p, size=shape) def np_abstract_eval(f): - """Abstract evaluation in numpy by running the real function on 0s.""" - def abstract_f(*args, **kwargs): - real_args = [nested_map(lambda x: np.zeros(x.shape, x.dtype), a) - for a in args] - real_res = f(*real_args, **kwargs) - return signature(real_res) - return abstract_f + """Abstract evaluation in numpy by running the real function on 0s.""" + + def abstract_f(*args, **kwargs): + real_args = 
[nested_map(lambda x: np.zeros(x.shape, x.dtype), a) for a in args] + real_res = f(*real_args, **kwargs) + return signature(real_res) + + return abstract_f NUMPY_BACKEND = { - 'abstract_eval': np_abstract_eval, - 'local_device_count': lambda: 1, - 'global_device_count': lambda: 1, - 'jit': lambda f: f, - 'logsumexp': logsumexp, - 'name': 'numpy', - 'np': np, - 'random_bernoulli': random_bernoulli, - 'random_get_prng': get_prng, - 'random_normal': random_normal, - 'random_randint': random_randint, - 'random_split': lambda prng, num=2: (None,) * num, - 'random_uniform': random_uniform, - 'expit': lambda x: 1. / (1. + np.exp(-x)), + "abstract_eval": np_abstract_eval, + "local_device_count": lambda: 1, + "global_device_count": lambda: 1, + "jit": lambda f: f, + "logsumexp": logsumexp, + "name": "numpy", + "np": np, + "random_bernoulli": random_bernoulli, + "random_get_prng": get_prng, + "random_normal": random_normal, + "random_randint": random_randint, + "random_split": lambda prng, num=2: (None,) * num, + "random_uniform": random_uniform, + "expit": lambda x: 1.0 / (1.0 + np.exp(-x)), } def nested_map(f, obj, level=0, ignore_nones=True): - """Maps `f` recursively inside any dicts/lists/tuples in `obj`. - - Args: - f: A function taking a single object as input. f's input must NOT be a - dict, list, or tuple, or any subclass of those. - obj: Either an input object to f or some nested structure of collections - of (collections of ...) input objects to f. - level: Level in the nested structure to stop at, counted from the leaves - - so level 0 is the leaf, level 1 is such that all of its children are at - level 0 etc. - ignore_nones: Whether to ignore Nones in the structure, i.e. return None - without calling `f`. - - Returns: - An object with the same nested structure as `obj`, but with each input - object `x` replaced by `f(x)`. 
- """ - if _is_at_level(obj, level): - if ignore_nones and _is_made_of_nones(obj): - return None - else: - return f(obj) - - if _is_namedtuple_instance(obj): - return type(obj)(*nested_map(f, list(obj), level=level)) - if isinstance(obj, list): - return [nested_map(f, y, level=level) for y in obj] - if isinstance(obj, tuple): - return tuple([nested_map(f, y, level=level) for y in obj]) - if isinstance(obj, dict): - return {k: nested_map(f, v, level=level) for (k, v) in obj.items()} - - raise ValueError('Non-exhaustive pattern match for {}.'.format(obj)) + """Maps `f` recursively inside any dicts/lists/tuples in `obj`. + + Args: + f: A function taking a single object as input. f's input must NOT be a + dict, list, or tuple, or any subclass of those. + obj: Either an input object to f or some nested structure of collections + of (collections of ...) input objects to f. + level: Level in the nested structure to stop at, counted from the leaves - + so level 0 is the leaf, level 1 is such that all of its children are at + level 0 etc. + ignore_nones: Whether to ignore Nones in the structure, i.e. return None + without calling `f`. + + Returns: + An object with the same nested structure as `obj`, but with each input + object `x` replaced by `f(x)`. + """ + if _is_at_level(obj, level): + if ignore_nones and _is_made_of_nones(obj): + return None + else: + return f(obj) + + if _is_namedtuple_instance(obj): + return type(obj)(*nested_map(f, list(obj), level=level)) + if isinstance(obj, list): + return [nested_map(f, y, level=level) for y in obj] + if isinstance(obj, tuple): + return tuple([nested_map(f, y, level=level) for y in obj]) + if isinstance(obj, dict): + return {k: nested_map(f, v, level=level) for (k, v) in obj.items()} + + raise ValueError("Non-exhaustive pattern match for {}.".format(obj)) def nested_map_multiarg(f, *objs, ignore_nones=True): - """Maps multi-arg `f` recursively inside any dicts/lists/tuples in `objs`. 
- - Args: - f: A function taking len(objs) inputs. f's input must NOT be a - dict, list, or tuple, or any subclass of those. - *objs: Either input objects to f or some nested structure of collections - of (collections of ...) input objects to f. - ignore_nones: Whether to ignore Nones in the structure, i.e. return None - without calling `f`. - - Returns: - An object with the same nested structure as `objs[0]`, but with each input - object `x` replaced by `f(*xs)`. - """ - if isinstance(objs[0], list): - return [nested_map_multiarg(f, *[o[i] for o in objs]) - for i in range(len(objs[0]))] - if isinstance(objs[0], tuple): - return tuple([nested_map_multiarg(f, *[o[i] for o in objs]) - for i in range(len(objs[0]))]) - if isinstance(objs[0], dict): - return {k: nested_map_multiarg(f, *[o[k] for o in objs]) - for k in objs[0]} - if ignore_nones and _is_made_of_nones(objs): - return None - return f(*objs) + """Maps multi-arg `f` recursively inside any dicts/lists/tuples in `objs`. + + Args: + f: A function taking len(objs) inputs. f's input must NOT be a + dict, list, or tuple, or any subclass of those. + *objs: Either input objects to f or some nested structure of collections + of (collections of ...) input objects to f. + ignore_nones: Whether to ignore Nones in the structure, i.e. return None + without calling `f`. + + Returns: + An object with the same nested structure as `objs[0]`, but with each input + object `x` replaced by `f(*xs)`. + """ + if isinstance(objs[0], list): + return [ + nested_map_multiarg(f, *[o[i] for o in objs]) for i in range(len(objs[0])) + ] + if isinstance(objs[0], tuple): + return tuple( + [nested_map_multiarg(f, *[o[i] for o in objs]) for i in range(len(objs[0]))] + ) + if isinstance(objs[0], dict): + return {k: nested_map_multiarg(f, *[o[k] for o in objs]) for k in objs[0]} + if ignore_nones and _is_made_of_nones(objs): + return None + return f(*objs) def nested_zip(objs): - """Zips the leaves of each nested structure in `objs`. 
+ """Zips the leaves of each nested structure in `objs`. - Args: - objs: List of nested structures to zip. + Args: + objs: List of nested structures to zip. - Returns: - An object with the same nested structure as each element of `objs`, with - leaves zipped together into tuples. - """ - assert isinstance(objs, (list, tuple)) - assert objs, 'Cannot zip an empty sequence.' + Returns: + An object with the same nested structure as each element of `objs`, with + leaves zipped together into tuples. + """ + assert isinstance(objs, (list, tuple)) + assert objs, "Cannot zip an empty sequence." - if _is_at_level(objs, 1): - return tuple(objs) + if _is_at_level(objs, 1): + return tuple(objs) - if _is_namedtuple_instance(objs[0]): - return type(objs[0])(*nested_zip(list(map(list, objs)))) - if isinstance(objs[0], list): - return [nested_zip([obj[i] for obj in objs]) for i in range(len(objs[0]))] - if isinstance(objs[0], tuple): - return nested_zip(list(map(list, objs))) - if isinstance(objs[0], dict): - return {k: nested_zip([obj[k] for obj in objs]) for k in objs[0]} + if _is_namedtuple_instance(objs[0]): + return type(objs[0])(*nested_zip(list(map(list, objs)))) + if isinstance(objs[0], list): + return [nested_zip([obj[i] for obj in objs]) for i in range(len(objs[0]))] + if isinstance(objs[0], tuple): + return nested_zip(list(map(list, objs))) + if isinstance(objs[0], dict): + return {k: nested_zip([obj[k] for obj in objs]) for k in objs[0]} - raise ValueError('Non-exhaustive pattern match for {}.'.format(objs[0])) + raise ValueError("Non-exhaustive pattern match for {}.".format(objs[0])) def nested_stack(objs, axis=0, np_module=np): - """Stacks the numpy arrays inside any dicts/lists/tuples in `objs`. - - Args: - objs: List of nested structures to stack. - axis: Axis to stack along. - np_module: numpy module to use - typically numpy or jax.numpy. 
- - Returns: - An object with the same nested structure as each element of `objs`, with - leaves stacked together into numpy arrays. Nones are propagated, i.e. if - each element of the stacked sequence is None, the output will be None. - """ - # nested_map the stacking operation, but stopping at level 1 so at tuples of - # numpy arrays. - return nested_map( - lambda x: np_module.stack(x, axis=axis), - nested_zip(objs), - level=1, - ) + """Stacks the numpy arrays inside any dicts/lists/tuples in `objs`. + + Args: + objs: List of nested structures to stack. + axis: Axis to stack along. + np_module: numpy module to use - typically numpy or jax.numpy. + + Returns: + An object with the same nested structure as each element of `objs`, with + leaves stacked together into numpy arrays. Nones are propagated, i.e. if + each element of the stacked sequence is None, the output will be None. + """ + # nested_map the stacking operation, but stopping at level 1 so at tuples of + # numpy arrays. + return nested_map( + lambda x: np_module.stack(x, axis=axis), + nested_zip(objs), + level=1, + ) def tree_flatten(tree): - """Flatten a tree into a list.""" - if isinstance(tree, (list, tuple)): - # In python, sum of lists starting from [] is the concatenation. - return sum([tree_flatten(t) for t in tree], []) - if isinstance(tree, dict): - # Only use the values in case of a dictionary node. - return sum([tree_flatten(v) for v in tree.values()], []) - return [tree] + """Flatten a tree into a list.""" + if isinstance(tree, (list, tuple)): + # In python, sum of lists starting from [] is the concatenation. + return sum([tree_flatten(t) for t in tree], []) + if isinstance(tree, dict): + # Only use the values in case of a dictionary node. 
+ return sum([tree_flatten(v) for v in tree.values()], []) + return [tree] def tree_leaves(tree, ignore_nones=True): - """Gets the leaves of a tree.""" + """Gets the leaves of a tree.""" - # Right now this is just `tree_flatten`, but we keep this separate since - # JAX's tree_flatten returns the structure of the tree as well. - flattened = tree_flatten(tree) - return [flat for flat in flattened if (not ignore_nones) or flat is not None] + # Right now this is just `tree_flatten`, but we keep this separate since + # JAX's tree_flatten returns the structure of the tree as well. + flattened = tree_flatten(tree) + return [flat for flat in flattened if (not ignore_nones) or flat is not None] def tree_unflatten(flat, tree, copy_from_tree=None): - """Unflatten a list into a tree given the tree shape as second argument. - - Args: - flat: a flat list of elements to be assembled into a tree. - tree: a tree with the structure we want to have in the new tree. - copy_from_tree: optional list of elements that we just copy from tree. - This argument is used when the flat version does not contain all elements - of the expected tree but just a subset, while the rest are filled from - the tree itself. It allows to omit "unnecessary" elements. For example, - consider trees (A, (B, X), X) and (X, (A, X), B) where X is some element - we do not care about. Flattening the first tree and removing X will yield - a flat list [A, B] and the second tree can then be reconstructed from this - list and the tree (X, (E, X), E) with copy_from_tree=[X]. One example - where this is used is the weights-tree of a model, where layers with no - weights have () in the tree and we use copy_from_tree=[()] to restore - a model from a file that only has a list of trainable weights. 
- - Returns: - A pair (new_tree, rest_of_flat) where the new tree that has the structure - of tree but with leaves from flat, and the remaining elements of flat if - more were provided than the number of leaves of tree (useful for recursion). - """ - if copy_from_tree is not None: - for el in copy_from_tree: - # Equality checks comparing a DeviceArray with other Python objects - # may legitimately raise a TypeError. - try: - if tree == el: - return tree, flat - except TypeError: - continue - - if isinstance(tree, (list, tuple)): - new_tree, rest = [], flat - for t in tree: - new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) - new_tree.append(new_t) - new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree - return new_tree, rest - if isinstance(tree, dict): - new_tree, rest = {}, flat - for k in tree: - new_v, rest = tree_unflatten(rest, tree[k], copy_from_tree=copy_from_tree) - new_tree[k] = new_v - return new_tree, rest - return flat[0], flat[1:] + """Unflatten a list into a tree given the tree shape as second argument. + + Args: + flat: a flat list of elements to be assembled into a tree. + tree: a tree with the structure we want to have in the new tree. + copy_from_tree: optional list of elements that we just copy from tree. + This argument is used when the flat version does not contain all elements + of the expected tree but just a subset, while the rest are filled from + the tree itself. It allows to omit "unnecessary" elements. For example, + consider trees (A, (B, X), X) and (X, (A, X), B) where X is some element + we do not care about. Flattening the first tree and removing X will yield + a flat list [A, B] and the second tree can then be reconstructed from this + list and the tree (X, (E, X), E) with copy_from_tree=[X]. 
One example + where this is used is the weights-tree of a model, where layers with no + weights have () in the tree and we use copy_from_tree=[()] to restore + a model from a file that only has a list of trainable weights. + + Returns: + A pair (new_tree, rest_of_flat) where the new tree that has the structure + of tree but with leaves from flat, and the remaining elements of flat if + more were provided than the number of leaves of tree (useful for recursion). + """ + if copy_from_tree is not None: + for el in copy_from_tree: + # Equality checks comparing a DeviceArray with other Python objects + # may legitimately raise a TypeError. + try: + if tree == el: + return tree, flat + except TypeError: + continue + + if isinstance(tree, (list, tuple)): + new_tree, rest = [], flat + for t in tree: + new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) + new_tree.append(new_t) + new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree + return new_tree, rest + if isinstance(tree, dict): + new_tree, rest = {}, flat + for k in tree: + new_v, rest = tree_unflatten(rest, tree[k], copy_from_tree=copy_from_tree) + new_tree[k] = new_v + return new_tree, rest + return flat[0], flat[1:] def _is_namedtuple_instance(x): - """Checks if `x` is an instance of a `namedtuple` type.""" - if not isinstance(x, tuple): - return False - return hasattr(x, '_fields') + """Checks if `x` is an instance of a `namedtuple` type.""" + if not isinstance(x, tuple): + return False + return hasattr(x, "_fields") def _is_at_level(obj, level): - """Checks if `obj` is an at level `level`.""" - is_leaf = not isinstance(obj, (list, tuple, dict)) - if level == 0 or is_leaf: - return (level == 0) == is_leaf + """Checks if `obj` is an at level `level`.""" + is_leaf = not isinstance(obj, (list, tuple, dict)) + if level == 0 or is_leaf: + return (level == 0) == is_leaf - if isinstance(obj, dict): - elems = obj.values() - else: - elems = obj - return elems and all(_is_at_level(x, level - 
1) for x in elems) + if isinstance(obj, dict): + elems = obj.values() + else: + elems = obj + return elems and all(_is_at_level(x, level - 1) for x in elems) def _is_made_of_nones(obj): - """Checks if `obj` is a nested structure of `None`s.""" - elems = tree_flatten(obj) - # Returning False for an empty list, because it doesn't have any Nones inside. - return elems and all(x is None for x in elems) + """Checks if `obj` is a nested structure of `None`s.""" + elems = tree_flatten(obj) + # Returning False for an empty list, because it doesn't have any Nones inside. + return elems and all(x is None for x in elems) diff --git a/trax/fastmath/ops.py b/trax/fastmath/ops.py index dbd6dfb83..3962a72b9 100644 --- a/trax/fastmath/ops.py +++ b/trax/fastmath/ops.py @@ -39,9 +39,9 @@ @enum.unique class Backend(enum.Enum): - JAX = 'jax' - TFNP = 'tensorflow-numpy' - NUMPY = 'numpy' + JAX = "jax" + TFNP = "tensorflow-numpy" + NUMPY = "numpy" # For numpy and random modules, we need to call "backend()" lazily, only when @@ -52,320 +52,338 @@ class Backend(enum.Enum): # A class that just forwards attribute accesses to backend's numpy object. class NumpyBackend: - """Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.""" + """Numpy functions accelerated to run on GPUs and TPUs. 
Use like numpy.""" + + def __getattr__(self, attr): + return getattr(backend()["np"], attr) - def __getattr__(self, attr): - return getattr(backend()['np'], attr) numpy = NumpyBackend() class RandomBackend: - """Backend providing random functions.""" + """Backend providing random functions.""" - def get_prng(self, seed): - return backend()['random_get_prng'](seed) + def get_prng(self, seed): + return backend()["random_get_prng"](seed) - def split(self, prng, num=2): - return backend()['random_split'](prng, num) + def split(self, prng, num=2): + return backend()["random_split"](prng, num) - def fold_in(self, rng, data): - return backend()['random_fold_in'](rng, data) + def fold_in(self, rng, data): + return backend()["random_fold_in"](rng, data) - def uniform(self, *args, **kwargs): - return backend()['random_uniform'](*args, **kwargs) + def uniform(self, *args, **kwargs): + return backend()["random_uniform"](*args, **kwargs) - def randint(self, *args, **kwargs): - return backend()['random_randint'](*args, **kwargs) + def randint(self, *args, **kwargs): + return backend()["random_randint"](*args, **kwargs) - def normal(self, *args, **kwargs): - return backend()['random_normal'](*args, **kwargs) + def normal(self, *args, **kwargs): + return backend()["random_normal"](*args, **kwargs) - def bernoulli(self, *args, **kwargs): - return backend()['random_bernoulli'](*args, **kwargs) + def bernoulli(self, *args, **kwargs): + return backend()["random_bernoulli"](*args, **kwargs) random = RandomBackend() def logsumexp(*args, **kwargs): - """Computes the log of the sum of exponentials of input elements.""" - return backend()['logsumexp'](*args, **kwargs) + """Computes the log of the sum of exponentials of input elements.""" + return backend()["logsumexp"](*args, **kwargs) def expit(*args, **kwargs): - """Computes the expit (sigmoid) function.""" - return backend()['expit'](*args, **kwargs) + """Computes the expit (sigmoid) function.""" + return backend()["expit"](*args, 
**kwargs) def sigmoid(*args, **kwargs): - """Computes the sigmoid (expit) function.""" - return backend()['expit'](*args, **kwargs) + """Computes the sigmoid (expit) function.""" + return backend()["expit"](*args, **kwargs) def erf(*args, **kwargs): - """Computes the erf function.""" - return backend()['erf'](*args, **kwargs) + """Computes the erf function.""" + return backend()["erf"](*args, **kwargs) def conv(*args, **kwargs): - """Computes a generalized convolution.""" - return backend()['conv'](*args, **kwargs) + """Computes a generalized convolution.""" + return backend()["conv"](*args, **kwargs) def avg_pool(*args, **kwargs): - """Average pooling.""" - return backend()['avg_pool'](*args, **kwargs) + """Average pooling.""" + return backend()["avg_pool"](*args, **kwargs) def max_pool(*args, **kwargs): - """Max pooling.""" - return backend()['max_pool'](*args, **kwargs) + """Max pooling.""" + return backend()["max_pool"](*args, **kwargs) def sum_pool(*args, **kwargs): - """Sum pooling.""" - return backend()['sum_pool'](*args, **kwargs) + """Sum pooling.""" + return backend()["sum_pool"](*args, **kwargs) def top_k(*args, **kwargs): - """Top k.""" - return backend()['top_k'](*args, **kwargs) + """Top k.""" + return backend()["top_k"](*args, **kwargs) def sort_key_val(*args, **kwargs): - """Sorts keys along dimension and applies same permutation to values.""" - return backend()['sort_key_val'](*args, **kwargs) + """Sorts keys along dimension and applies same permutation to values.""" + return backend()["sort_key_val"](*args, **kwargs) def scan(*args, **kwargs): - """Scan to make recurrent functions run faster on accelerators.""" - return backend()['scan'](*args, **kwargs) + """Scan to make recurrent functions run faster on accelerators.""" + return backend()["scan"](*args, **kwargs) def map(*args, **kwargs): # pylint: disable=redefined-builtin - """Map a function over leading array axes.""" - return backend()['map'](*args, **kwargs) + """Map a function over leading 
array axes.""" + return backend()["map"](*args, **kwargs) def fori_loop(lower, upper, body_fn, init_val): - """Loop from `lower` to `upper` running `body_fn` starting from `init_val`. - - The semantics of `fori_loop` is as follows:: - - def fori_loop(lower, upper, body_fn, init_val): - val = init_val - for i in range(lower, upper): - val = body_fn(i, val) - return val - - Args: - lower: an integer representing the loop index lower bound (inclusive) - upper: an integer representing the loop index upper bound (exclusive) - body_fn: function of type `(int, a) -> a`. - init_val: initial loop carry value of type `a`. - - Returns: - Loop value from the final iteration. - """ - if 'fori_loop' in backend(): - return backend()['fori_loop'](lower, upper, body_fn, init_val) - # Use scan otherwise. - def scanned_fn(loop_carry, _): - i, x = loop_carry - return (i + 1, body_fn(i, x)), None - (_, result), _ = scan( - scanned_fn, (lower, init_val), None, length=upper - lower) - return result + """Loop from `lower` to `upper` running `body_fn` starting from `init_val`. + + The semantics of `fori_loop` is as follows:: + + def fori_loop(lower, upper, body_fn, init_val): + val = init_val + for i in range(lower, upper): + val = body_fn(i, val) + return val + + Args: + lower: an integer representing the loop index lower bound (inclusive) + upper: an integer representing the loop index upper bound (exclusive) + body_fn: function of type `(int, a) -> a`. + init_val: initial loop carry value of type `a`. + + Returns: + Loop value from the final iteration. + """ + if "fori_loop" in backend(): + return backend()["fori_loop"](lower, upper, body_fn, init_val) + # Use scan otherwise. 
+ def scanned_fn(loop_carry, _): + i, x = loop_carry + return (i + 1, body_fn(i, x)), None + + (_, result), _ = scan(scanned_fn, (lower, init_val), None, length=upper - lower) + return result def remat(*args, **kwargs): - """Recompute everything in the backward pass to same memory.""" - return backend()['remat'](*args, **kwargs) + """Recompute everything in the backward pass to same memory.""" + return backend()["remat"](*args, **kwargs) def cond(*args, **kwargs): - """Conditional computation to run on accelerators.""" - return backend()['cond'](*args, **kwargs) + """Conditional computation to run on accelerators.""" + return backend()["cond"](*args, **kwargs) def lt(*args, **kwargs): - """Less-than function for backends that do not override <.""" - return backend()['lt'](*args, **kwargs) + """Less-than function for backends that do not override <.""" + return backend()["lt"](*args, **kwargs) def index_update(*args, **kwargs): - return backend()['index_update'](*args, **kwargs) + return backend()["index_update"](*args, **kwargs) def index_add(*args, **kwargs): - return backend()['index_add'](*args, **kwargs) + return backend()["index_add"](*args, **kwargs) def index_min(*args, **kwargs): - return backend()['index_min'](*args, **kwargs) + return backend()["index_min"](*args, **kwargs) def index_max(*args, **kwargs): - return backend()['index_max'](*args, **kwargs) + return backend()["index_max"](*args, **kwargs) def dynamic_slice(*args, **kwargs): - return backend()['dynamic_slice'](*args, **kwargs) + return backend()["dynamic_slice"](*args, **kwargs) def dynamic_slice_in_dim(*args, **kwargs): - return backend()['dynamic_slice_in_dim'](*args, **kwargs) + return backend()["dynamic_slice_in_dim"](*args, **kwargs) def dynamic_update_slice(*args, **kwargs): - return backend()['dynamic_update_slice'](*args, **kwargs) + return backend()["dynamic_update_slice"](*args, **kwargs) def dynamic_update_slice_in_dim(*args, **kwargs): - return 
backend()['dynamic_update_slice_in_dim'](*args, **kwargs) + return backend()["dynamic_update_slice_in_dim"](*args, **kwargs) def stop_gradient(*args, **kwargs): - """Identity on the forward pass but 0 (no gradient) on the backward pass.""" - return backend()['stop_gradient'](*args, **kwargs) + """Identity on the forward pass but 0 (no gradient) on the backward pass.""" + return backend()["stop_gradient"](*args, **kwargs) _disable_jit = False def jit(*args, **kwargs): - """Just-In-Time compiles the given function for use on accelerators.""" - global _disable_jit - if _disable_jit: - return args[0] # jit(f, **unused_now_jit_kwargs) = f - return backend()['jit'](*args, **kwargs) + """Just-In-Time compiles the given function for use on accelerators.""" + global _disable_jit + if _disable_jit: + return args[0] # jit(f, **unused_now_jit_kwargs) = f + return backend()["jit"](*args, **kwargs) def disable_jit(): - """Disables JIT-compilation; helpful for debugging.""" - global _disable_jit - _disable_jit = True + """Disables JIT-compilation; helpful for debugging.""" + global _disable_jit + _disable_jit = True def vmap(*args, **kwargs): - """Vectorizes the specified function (returns a function).""" - return backend()['vmap'](*args, **kwargs) + """Vectorizes the specified function (returns a function).""" + return backend()["vmap"](*args, **kwargs) def grad(*args, **kwargs): - """Computes the gradient of the specified function (returns a function).""" - return backend()['grad'](*args, **kwargs) + """Computes the gradient of the specified function (returns a function).""" + return backend()["grad"](*args, **kwargs) def value_and_grad(*args, **kwargs): - """Computes the gradient of the specified function together with the value.""" - if 'value_and_grad' in backend(): - return backend()['value_and_grad'](*args, **kwargs) - grad_fn = grad(*args, **kwargs) - fn = args[0] - has_aux = False - if has_aux in kwargs: - has_aux = kwargs['has_aux'] - if not has_aux: - def 
val_and_grad(*fn_args, **fn_kwargs): - return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs) - return val_and_grad - def val_and_grad_aux(*fn_args, **fn_kwargs): - g, aux = grad_fn(*fn_args, **fn_kwargs) - res, _ = fn(*fn_args, **fn_kwargs) - return (res, aux), g - return val_and_grad_aux + """Computes the gradient of the specified function together with the value.""" + if "value_and_grad" in backend(): + return backend()["value_and_grad"](*args, **kwargs) + + grad_fn = grad(*args, **kwargs) + fn = args[0] + has_aux = False + if has_aux in kwargs: + has_aux = kwargs["has_aux"] + if not has_aux: + + def val_and_grad(*fn_args, **fn_kwargs): + return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs) + + return val_and_grad + + def val_and_grad_aux(*fn_args, **fn_kwargs): + g, aux = grad_fn(*fn_args, **fn_kwargs) + res, _ = fn(*fn_args, **fn_kwargs) + return (res, aux), g + + return val_and_grad_aux def vjp(*args, **kwargs): - """Computes the vector-Jacobian product for the specified function.""" - return backend()['vjp'](*args, **kwargs) + """Computes the vector-Jacobian product for the specified function.""" + return backend()["vjp"](*args, **kwargs) def custom_grad(*args, **kwargs): - """Set a custom gradient computation (override the default) for a function.""" - return backend()['custom_grad'](*args, **kwargs) + """Set a custom gradient computation (override the default) for a function.""" + return backend()["custom_grad"](*args, **kwargs) def custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=()): - """Set a custom vjp computation (override the default) for a function.""" - # Call backend custom_vjp if it exists. - # TODO(lukaszkaiser): unify the APIs and remove nondiff_argnums altogether. - if 'custom_vjp' in backend(): - return backend()['custom_vjp'](f, f_fwd, f_bwd) - - # Check that nondiff_argnums is (0, 1, ..., N) for some N. - # Currently we only support nondiff_argnums at the front. 
- counter = -1 - for i in nondiff_argnums: - counter += 1 - if i != counter: - raise ValueError('Currently we only support custom_vjps with all nondiff' - '_argnums up front, like (0,) or (0, 1) but not (1,) or' - ' (1, 2). Found: %s' % str(nondiff_argnums)) - - # Use custom_grad. - if counter == -1: # no non-diff args - def f_vjp(*args): - out, residual = f_fwd(*args) - def vjpfn(g): - return f_bwd(residual, g) - return out, vjpfn - return backend()['custom_grad'](f_vjp, f) - - # Handle non-diff args by closure. - def f_joint(*args): - """This function takes all args, first counter+1 are non-diff ones.""" - nondiff_args = list(args[:counter+1]) - def f_diff(*diff_args): # Takes only diff args, will define custom grad. - args = nondiff_args + list(diff_args) - return f(*args) - def f_vjp(*diff_args): # Custom VJP for diff args. - args = nondiff_args + list(diff_args) - out, residual = f_fwd(*args) - def vjpfn(g): - bwd_args = [residual, g] - res = f_bwd(*bwd_args) - return res[counter+1:] - return out, vjpfn - # This is the function taking only diff args with custom vjp. - f_diff_vjp = backend()['custom_grad'](f_vjp, f_diff) - # Call it on the diff args. - return f_diff_vjp(*args[counter+1:]) - return f_joint + """Set a custom vjp computation (override the default) for a function.""" + # Call backend custom_vjp if it exists. + # TODO(lukaszkaiser): unify the APIs and remove nondiff_argnums altogether. + if "custom_vjp" in backend(): + return backend()["custom_vjp"](f, f_fwd, f_bwd) + + # Check that nondiff_argnums is (0, 1, ..., N) for some N. + # Currently we only support nondiff_argnums at the front. + counter = -1 + for i in nondiff_argnums: + counter += 1 + if i != counter: + raise ValueError( + "Currently we only support custom_vjps with all nondiff" + "_argnums up front, like (0,) or (0, 1) but not (1,) or" + " (1, 2). Found: %s" % str(nondiff_argnums) + ) + + # Use custom_grad. 
+ if counter == -1: # no non-diff args + + def f_vjp(*args): + out, residual = f_fwd(*args) + + def vjpfn(g): + return f_bwd(residual, g) + + return out, vjpfn + + return backend()["custom_grad"](f_vjp, f) + + # Handle non-diff args by closure. + def f_joint(*args): + """This function takes all args, first counter+1 are non-diff ones.""" + nondiff_args = list(args[: counter + 1]) + + def f_diff(*diff_args): # Takes only diff args, will define custom grad. + args = nondiff_args + list(diff_args) + return f(*args) + + def f_vjp(*diff_args): # Custom VJP for diff args. + args = nondiff_args + list(diff_args) + out, residual = f_fwd(*args) + + def vjpfn(g): + bwd_args = [residual, g] + res = f_bwd(*bwd_args) + return res[counter + 1 :] + + return out, vjpfn + + # This is the function taking only diff args with custom vjp. + f_diff_vjp = backend()["custom_grad"](f_vjp, f_diff) + # Call it on the diff args. + return f_diff_vjp(*args[counter + 1 :]) + + return f_joint def pmap(*args, **kwargs): - """Parallel-map to apply a function on multiple accelerators in parallel.""" - return backend()['pmap'](*args, **kwargs) + """Parallel-map to apply a function on multiple accelerators in parallel.""" + return backend()["pmap"](*args, **kwargs) def psum(*args, **kwargs): - """Parallel-sum to use within a pmap'd function for aggregation.""" - return backend()['psum'](*args, **kwargs) + """Parallel-sum to use within a pmap'd function for aggregation.""" + return backend()["psum"](*args, **kwargs) def abstract_eval(*args, **kwargs): - """Evaluates function just on signatures of parameters, return signatures.""" - return backend()['abstract_eval'](*args, **kwargs) + """Evaluates function just on signatures of parameters, return signatures.""" + return backend()["abstract_eval"](*args, **kwargs) def dataset_as_numpy(*args, **kwargs): - """Convert a tf.data.Dataset to a stream of numpy arrays.""" - if 'dataset_as_numpy' in backend(): - return backend()['dataset_as_numpy'](*args, 
**kwargs) - return JAX_BACKEND['dataset_as_numpy'](*args, **kwargs) + """Convert a tf.data.Dataset to a stream of numpy arrays.""" + if "dataset_as_numpy" in backend(): + return backend()["dataset_as_numpy"](*args, **kwargs) + return JAX_BACKEND["dataset_as_numpy"](*args, **kwargs) def global_device_count(*args, **kwargs): - """Return the number of accelerators (GPUs or TPUs) in all hosts.""" - return backend()['global_device_count'](*args, **kwargs) + """Return the number of accelerators (GPUs or TPUs) in all hosts.""" + return backend()["global_device_count"](*args, **kwargs) def local_device_count(*args, **kwargs): - """Return the number of accelerators (GPUs or TPUs) available on this host.""" - return backend()['local_device_count'](*args, **kwargs) + """Return the number of accelerators (GPUs or TPUs) available on this host.""" + return backend()["local_device_count"](*args, **kwargs) # Backend selection functions. @@ -380,65 +398,65 @@ def local_device_count(*args, **kwargs): def _assert_valid_backend_name(name): - for backend_ in Backend: - if backend_.value == name: - return - raise ValueError(f'No backend with name {name}') + for backend_ in Backend: + if backend_.value == name: + return + raise ValueError(f"No backend with name {name}") def set_backend(name): - """Sets the default backend to use in Trax.""" - if name: - _assert_valid_backend_name(name) - global default_backend - default_backend = name + """Sets the default backend to use in Trax.""" + if name: + _assert_valid_backend_name(name) + global default_backend + default_backend = name def _get_backend_from_string(name_str): - # name is a string. - for backend_ in Backend: - if backend_.value == name_str: - return _backend_dict[backend_] - return JAX_BACKEND + # name is a string. 
+ for backend_ in Backend: + if backend_.value == name_str: + return _backend_dict[backend_] + return JAX_BACKEND @gin.configurable -def backend(name='jax'): - """Returns the backend used to provide fastmath ops ('tf' or 'jax').""" - if override_backend: - return _get_backend_from_string(override_backend) +def backend(name="jax"): + """Returns the backend used to provide fastmath ops ('tf' or 'jax').""" + if override_backend: + return _get_backend_from_string(override_backend) - if default_backend: - return _get_backend_from_string(default_backend) + if default_backend: + return _get_backend_from_string(default_backend) - if isinstance(name, Backend): - return _backend_dict[name] + if isinstance(name, Backend): + return _backend_dict[name] - # name is a string. - return _get_backend_from_string(name) + # name is a string. + return _get_backend_from_string(name) @contextlib.contextmanager def use_backend(name): - """Call fastmath functions with a specified backend.""" - if isinstance(name, Backend): - name = name.value + """Call fastmath functions with a specified backend.""" + if isinstance(name, Backend): + name = name.value - _assert_valid_backend_name(name) - global override_backend - prev_name_or_backend = override_backend - override_backend = name - # Run the decorated function in try-finally in case it throws, e.g. for tests. - try: - yield - finally: - override_backend = prev_name_or_backend + _assert_valid_backend_name(name) + global override_backend + prev_name_or_backend = override_backend + override_backend = name + # Run the decorated function in try-finally in case it throws, e.g. for tests. 
+ try: + yield + finally: + override_backend = prev_name_or_backend def backend_name(): - """Returns the name of the backend currently in use ('tf' or 'jax').""" - return backend()['name'] + """Returns the name of the backend currently in use ('tf' or 'jax').""" + return backend()["name"] def is_backend(backend_): - return backend()['name'] == backend_.value + return backend()["name"] == backend_.value diff --git a/trax/fastmath/ops_test.py b/trax/fastmath/ops_test.py deleted file mode 100644 index 2e22b91b6..000000000 --- a/trax/fastmath/ops_test.py +++ /dev/null @@ -1,123 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.fastmath.ops.""" - -import collections -from absl.testing import parameterized - -import gin -import jax.numpy as jnp -import numpy as onp -from tensorflow import test -from trax import fastmath - - -_TestNamedtuple = collections.namedtuple('_TestNamedtuple', ['x']) - - -class BackendTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def override_gin(self, bindings): - gin.parse_config_files_and_bindings(None, bindings) - - def test_backend_imports_correctly(self): - backend = fastmath.backend() - self.assertEqual(jnp, backend['np']) - self.assertNotEqual(onp, backend['np']) - - self.override_gin("backend.name = 'numpy'") - - backend = fastmath.backend() - self.assertNotEqual(jnp, backend['np']) - self.assertEqual(onp, backend['np']) - - def test_backend_can_be_set(self): - self.assertEqual(fastmath.backend_name(), 'jax') - fastmath.set_backend('tensorflow-numpy') - self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy') - fastmath.set_backend(None) - self.assertEqual(fastmath.backend_name(), 'jax') - - def test_numpy_backend_delegation(self): - # Assert that we are getting JAX's numpy backend. - backend = fastmath.backend() - numpy = fastmath.numpy - self.assertEqual(jnp, backend['np']) - - # Assert that `numpy` calls the appropriate gin configured functions and - # properties. - self.assertTrue(numpy.isinf(numpy.inf)) - self.assertEqual(jnp.isinf, numpy.isinf) - self.assertEqual(jnp.inf, numpy.inf) - - # Assert that we will now get the pure numpy backend. - - self.override_gin("backend.name = 'numpy'") - - backend = fastmath.backend() - numpy = fastmath.numpy - self.assertEqual(onp, backend['np']) - - # Assert that `numpy` calls the appropriate gin configured functions and - # properties. 
- self.assertTrue(numpy.isinf(numpy.inf)) - self.assertEqual(onp.isinf, numpy.isinf) - self.assertEqual(onp.inf, numpy.inf) - - @parameterized.named_parameters( - ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP)) - def test_fori_loop(self, backend): - with fastmath.use_backend(backend): - res = fastmath.fori_loop(2, 5, lambda i, x: x + i, 1) - self.assertEqual(res, 1 + 2 + 3 + 4) - - def test_nested_map(self): - inp = {'a': ([0, 1], 2), 'b': _TestNamedtuple(3)} - out = {'a': ([1, 2], 3), 'b': _TestNamedtuple(4)} - self.assertEqual(fastmath.nested_map(lambda x: x + 1, inp), out) - - def test_nested_stack(self): - inp = [ - {'a': ([0, 1], 2), 'b': _TestNamedtuple(3)}, - {'a': ([1, 2], 3), 'b': _TestNamedtuple(4)}, - ] - out = {'a': ([[0, 1], [1, 2]], [2, 3]), 'b': _TestNamedtuple([3, 4])} - onp.testing.assert_equal(fastmath.nested_stack(inp), out) - - def test_names_match(self): - # Names match up. - for backend_enum, backend_obj in fastmath.ops._backend_dict.items(): - self.assertEqual(backend_enum.value, backend_obj['name']) - - # Every backend appears in the dictionary. - for backend_enum in fastmath.ops.Backend: - self.assertIn(backend_enum, fastmath.ops._backend_dict) - - def test_use_backend_str(self): - with fastmath.use_backend('tensorflow-numpy'): - self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy') - - def test_use_backend_enum(self): - with fastmath.use_backend(fastmath.Backend.NUMPY): - self.assertEqual(fastmath.backend_name(), 'numpy') - - -if __name__ == '__main__': - test.main() diff --git a/trax/fastmath/tf.py b/trax/fastmath/tf.py index e02ba40a7..4d4a322c4 100644 --- a/trax/fastmath/tf.py +++ b/trax/fastmath/tf.py @@ -24,158 +24,171 @@ def tf_abstract_eval(f): - """Returns a function that evaluates `f` given input shapes and dtypes. - - It transforms function `f` to a function that performs the same computation as - `f` but only on shapes and dtypes (a.k.a. shape inference). 
- - Args: - f: the function to be transformed. - - Returns: - A function whose input arguments can be either the same as `f`'s or only - their shapes/dtypes represented by `ShapeDtype`, and whose return values are - `ShapeDtype`s with the same nested structure as `f`'s return values. - """ - f_shape = tf_np_extensions.eval_on_shapes(f) - def from_shape_type(x): - if isinstance(x, ShapeDtype): - return tf.TensorSpec(x.shape, x.dtype) - else: - return x - def to_shape_type(x): # pylint: disable=missing-docstring - # TODO(wangpeng): handle partial output shapes using `tf.shape`. - def to_numpy_shape(s): - if s.is_fully_defined(): - return tuple(s.as_list()) - else: - raise ValueError("The output shapes (%s) of the dry-run'ed function are" - ' not fully defined.' % s) - def to_numpy_dtype(t): - return np.dtype(t.as_numpy_dtype) - if isinstance(x, tf.TensorSpec): - return ShapeDtype(to_numpy_shape(x.shape), to_numpy_dtype(x.dtype)) - else: - return x - def f_return(*args): - args = tf.nest.map_structure(from_shape_type, args) - res = f_shape(*args) - return tf.nest.map_structure(to_shape_type, res) - return f_return + """Returns a function that evaluates `f` given input shapes and dtypes. + + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). + + Args: + f: the function to be transformed. + + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `ShapeDtype`, and whose return values are + `ShapeDtype`s with the same nested structure as `f`'s return values. + """ + f_shape = tf_np_extensions.eval_on_shapes(f) + + def from_shape_type(x): + if isinstance(x, ShapeDtype): + return tf.TensorSpec(x.shape, x.dtype) + else: + return x + + def to_shape_type(x): # pylint: disable=missing-docstring + # TODO(wangpeng): handle partial output shapes using `tf.shape`. 
+ def to_numpy_shape(s): + if s.is_fully_defined(): + return tuple(s.as_list()) + else: + raise ValueError( + "The output shapes (%s) of the dry-run'ed function are" + " not fully defined." % s + ) + + def to_numpy_dtype(t): + return np.dtype(t.as_numpy_dtype) + + if isinstance(x, tf.TensorSpec): + return ShapeDtype(to_numpy_shape(x.shape), to_numpy_dtype(x.dtype)) + else: + return x + + def f_return(*args): + args = tf.nest.map_structure(from_shape_type, args) + res = f_shape(*args) + return tf.nest.map_structure(to_shape_type, res) + + return f_return # The arguments order is different from tf_np_extensions.uniform def tf_randint(key, shape, minval, maxval, dtype=np.int32): - """Sample uniform random values in [minval, maxval) with given shape/dtype. + """Sample uniform random values in [minval, maxval) with given shape/dtype. - Args: - key: a PRNGKey used as the random key. - shape: a tuple of nonnegative integers representing the shape. - minval: int or array of ints broadcast-compatible with ``shape``, a minimum - (inclusive) value for the range. - maxval: int or array of ints broadcast-compatible with ``shape``, a maximum - (exclusive) value for the range. - dtype: optional, an int dtype for the returned values (default int32). + Args: + key: a PRNGKey used as the random key. + shape: a tuple of nonnegative integers representing the shape. + minval: int or array of ints broadcast-compatible with ``shape``, a minimum + (inclusive) value for the range. + maxval: int or array of ints broadcast-compatible with ``shape``, a maximum + (exclusive) value for the range. + dtype: optional, an int dtype for the returned values (default int32). - Returns: - A random array with the specified shape and dtype. - """ - return tf_np_extensions.uniform(key, shape, minval=minval, maxval=maxval, - dtype=dtype) + Returns: + A random array with the specified shape and dtype. 
+ """ + return tf_np_extensions.uniform( + key, shape, minval=minval, maxval=maxval, dtype=dtype + ) _tf_xla_forced_compile_enabled = False def tf_xla_forced_compile_enabled(): - return _tf_xla_forced_compile_enabled + return _tf_xla_forced_compile_enabled def set_tf_xla_forced_compile(b): - global _tf_xla_forced_compile_enabled - _tf_xla_forced_compile_enabled = b + global _tf_xla_forced_compile_enabled + _tf_xla_forced_compile_enabled = b def _tf_jit(*args, **kwargs): - kwargs['xla_forced_compile'] = tf_xla_forced_compile_enabled() - kwargs.pop('donate_argnums', None) # donate_argnums not used in TF - return tf_np_extensions.jit(*args, **kwargs) + kwargs["xla_forced_compile"] = tf_xla_forced_compile_enabled() + kwargs.pop("donate_argnums", None) # donate_argnums not used in TF + return tf_np_extensions.jit(*args, **kwargs) def _tf_pmap(*args, **kwargs): - kwargs.pop('donate_argnums', None) # donate_argnums not used in TF - return tf_np_extensions.pmap(*args, **kwargs) + kwargs.pop("donate_argnums", None) # donate_argnums not used in TF + return tf_np_extensions.pmap(*args, **kwargs) def _tf_grad(f, **kwargs): - """Grad with support for argnums.""" - argnums = kwargs.pop('argnums', 0) - if argnums != 0: - def g(*args, **kwargs): - args = list(args) - args[0], args[argnums] = args[argnums], args[0] - return f(*args, **kwargs) - else: - g = f - grad_g = tf_np_extensions.grad(g, **kwargs) - if argnums == 0: - return grad_g - def grad_f(*args, **kwargs): - args = list(args) - args[0], args[argnums] = args[argnums], args[0] - return grad_g(*args, **kwargs) - return grad_f + """Grad with support for argnums.""" + argnums = kwargs.pop("argnums", 0) + if argnums != 0: + + def g(*args, **kwargs): + args = list(args) + args[0], args[argnums] = args[argnums], args[0] + return f(*args, **kwargs) + + else: + g = f + grad_g = tf_np_extensions.grad(g, **kwargs) + if argnums == 0: + return grad_g + + def grad_f(*args, **kwargs): + args = list(args) + args[0], args[argnums] = 
args[argnums], args[0] + return grad_g(*args, **kwargs) + + return grad_f def _fold_in(rng, d): - """Equivalent of jax.random.fold_in.""" - # TODO(lukaszkaiser): verify that this function has good randomness - # properties or switch to an implementation equivalent to JAX. - _, rng = tf_np_extensions.split(rng + tf_np.sum(d).astype(tf_np.int64), 2) - return rng + """Equivalent of jax.random.fold_in.""" + # TODO(lukaszkaiser): verify that this function has good randomness + # properties or switch to an implementation equivalent to JAX. + _, rng = tf_np_extensions.split(rng + tf_np.sum(d).astype(tf_np.int64), 2) + return rng TF_BACKEND = { - 'name': 'tensorflow-numpy', - 'np': tf_np, - 'jit': _tf_jit, - 'stop_gradient': tf_np_extensions.stop_gradient, - 'grad': _tf_grad, - 'vjp': tf_np_extensions.vjp, - 'custom_grad': tf_np_extensions.custom_grad, - 'abstract_eval': tf_abstract_eval, - 'expit': tf_np_extensions.expit, - 'erf': tf_np_extensions.erf, - 'index_update': tf_np_extensions.index_update, - 'index_add': tf_np_extensions.index_add, - 'index_min': tf_np_extensions.index_min, - 'index_max': tf_np_extensions.index_max, - 'dynamic_slice': tf_np_extensions.dynamic_slice, - 'dynamic_slice_in_dim': tf_np_extensions.dynamic_slice_in_dim, - 'dynamic_update_slice': tf_np_extensions.dynamic_update_slice, - 'dynamic_update_slice_in_dim': tf_np_extensions.dynamic_update_slice_in_dim, - 'logsumexp': tf_np_extensions.logsumexp, - 'conv': tf_np_extensions.conv, - 'lt': lambda x, y: x < y, - 'avg_pool': tf_np_extensions.avg_pool, - 'max_pool': tf_np_extensions.max_pool, - 'sort_key_val': tf_np_extensions.sort_key_val, - 'random_uniform': tf_np_extensions.uniform, - 'random_randint': tf_randint, - 'random_normal': tf_np_extensions.normal, - 'random_bernoulli': tf_np_extensions.bernoulli, - 'random_get_prng': tf_np_extensions.prng, - 'random_split': tf_np_extensions.split, - 'random_fold_in': _fold_in, + "name": "tensorflow-numpy", + "np": tf_np, + "jit": _tf_jit, + 
"stop_gradient": tf_np_extensions.stop_gradient, + "grad": _tf_grad, + "vjp": tf_np_extensions.vjp, + "custom_grad": tf_np_extensions.custom_grad, + "abstract_eval": tf_abstract_eval, + "expit": tf_np_extensions.expit, + "erf": tf_np_extensions.erf, + "index_update": tf_np_extensions.index_update, + "index_add": tf_np_extensions.index_add, + "index_min": tf_np_extensions.index_min, + "index_max": tf_np_extensions.index_max, + "dynamic_slice": tf_np_extensions.dynamic_slice, + "dynamic_slice_in_dim": tf_np_extensions.dynamic_slice_in_dim, + "dynamic_update_slice": tf_np_extensions.dynamic_update_slice, + "dynamic_update_slice_in_dim": tf_np_extensions.dynamic_update_slice_in_dim, + "logsumexp": tf_np_extensions.logsumexp, + "conv": tf_np_extensions.conv, + "lt": lambda x, y: x < y, + "avg_pool": tf_np_extensions.avg_pool, + "max_pool": tf_np_extensions.max_pool, + "sort_key_val": tf_np_extensions.sort_key_val, + "random_uniform": tf_np_extensions.uniform, + "random_randint": tf_randint, + "random_normal": tf_np_extensions.normal, + "random_bernoulli": tf_np_extensions.bernoulli, + "random_get_prng": tf_np_extensions.prng, + "random_split": tf_np_extensions.split, + "random_fold_in": _fold_in, # TODO(wangpeng): See whether and how to support `remat` - 'remat': lambda f: f, - 'scan': tf_np_extensions.scan, - 'map': tf_np_extensions.tf_map, + "remat": lambda f: f, + "scan": tf_np_extensions.scan, + "map": tf_np_extensions.tf_map, # TODO(wangpeng): can we make extensions ds_as_numpy compatible with data? 
# 'dataset_as_numpy': tf_np_extensions.dataset_as_numpy, - 'global_device_count': lambda: max(len(tf_np_extensions.accelerators()), 1), - 'local_device_count': lambda: max(len(tf_np_extensions.accelerators()), 1), - 'pmap': _tf_pmap, - 'psum': tf_np_extensions.psum, - 'vmap': tf_np_extensions.vmap, + "global_device_count": lambda: max(len(tf_np_extensions.accelerators()), 1), + "local_device_count": lambda: max(len(tf_np_extensions.accelerators()), 1), + "pmap": _tf_pmap, + "psum": tf_np_extensions.psum, + "vmap": tf_np_extensions.vmap, } diff --git a/trax/jaxboard.py b/trax/jaxboard.py index c160c63fa..7be429199 100644 --- a/trax/jaxboard.py +++ b/trax/jaxboard.py @@ -23,10 +23,11 @@ import warnings import wave import matplotlib as mpl + # Necessary to prevent attempted Tk import: with warnings.catch_warnings(): - warnings.simplefilter('ignore') - mpl.use('Agg') + warnings.simplefilter("ignore") + mpl.use("Agg") # pylint: disable=g-import-not-at-top import matplotlib.pyplot as plt import numpy as np @@ -35,326 +36,341 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.core.util import event_pb2 from tensorflow.python.summary.writer.event_file_writer import EventFileWriter + # pylint: enable=g-direct-tensorflow-import def _pack_images(images, rows, cols): - """Helper utility to make a tiled field of images from numpy arrays. - - Args: - images: Image tensor in shape [N, W, H, C]. - rows: Number of images per row in tiled image. - cols: Number of images per column in tiled image. - - Returns: - A tiled image of shape [W * rows, H * cols, C]. - Truncates incomplete rows. 
- """ - shape = np.shape(images) - width, height, depth = shape[-3:] - images = np.reshape(images, (-1, width, height, depth)) - batch = np.shape(images)[0] - rows = np.minimum(rows, batch) - cols = np.minimum(batch // rows, cols) - images = images[:rows * cols] - images = np.reshape(images, (rows, cols, width, height, depth)) - images = np.transpose(images, [0, 2, 1, 3, 4]) - images = np.reshape(images, [rows * width, cols * height, depth]) - return images - - -class SummaryWriter: - """Saves data in event and summary protos for tensorboard.""" - - def __init__(self, log_dir, enable=True): - """Create a new SummaryWriter. + """Helper utility to make a tiled field of images from numpy arrays. Args: - log_dir: path to record tfevents files in. - enable: bool: if False don't actually write or flush data. Used in - multihost training. - """ - # If needed, create log_dir directory as well as missing parent directories. - if not tf.io.gfile.isdir(log_dir): - tf.io.gfile.makedirs(log_dir) - - self._event_writer = EventFileWriter(log_dir, 10, 120, None) - self._step = 0 - self._closed = False - self._enabled = enable - - def add_summary(self, summary, step): - if not self._enabled: - return - event = event_pb2.Event(summary=summary) - event.wall_time = time.time() - if step is not None: - event.step = int(step) - self._event_writer.add_event(event) - - def close(self): - """Close SummaryWriter. Final!""" - if not self._closed: - self._event_writer.close() - self._closed = True - del self._event_writer - - def __del__(self): # safe? - # TODO(afrozm): Sometimes this complains with - # `TypeError: 'NoneType' object is not callable` - try: - self.close() - except Exception: # pylint: disable=broad-except - pass - - def flush(self): - if not self._enabled: - return - self._event_writer.flush() - - def scalar(self, tag, value, step=None): - """Saves scalar value. + images: Image tensor in shape [N, W, H, C]. + rows: Number of images per row in tiled image. 
+ cols: Number of images per column in tiled image. - Args: - tag: str: label for this data - value: int/float: number to log - step: int: training step - """ - value = float(np.array(value)) - if step is None: - step = self._step - else: - self._step = step - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]) - self.add_summary(summary, step) - - def image(self, tag, image, step=None): - """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3]. - - Args: - tag: str: label for this data - image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors/ - step: int: training step + Returns: + A tiled image of shape [W * rows, H * cols, C]. + Truncates incomplete rows. """ - image = np.array(image) - if step is None: - step = self._step - else: - self._step = step - if len(np.shape(image)) == 2: - image = image[:, :, np.newaxis] - if np.shape(image)[-1] == 1: - image = np.repeat(image, 3, axis=-1) - image_strio = io.BytesIO() - plt.imsave(image_strio, image, format='png') - image_summary = tf.compat.v1.Summary.Image( - encoded_image_string=image_strio.getvalue(), - colorspace=3, - height=image.shape[0], - width=image.shape[1]) - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)]) - self.add_summary(summary, step) - - def images(self, tag, images, step=None, rows=None, cols=None): - """Saves (rows, cols) tiled images from np.ndarray. - - If either rows or cols aren't given, they are determined automatically - from the size of the image batch, if neither are given a long column - of images is produced. This truncates the image batch rather than padding - if it doesn't fill the final row. 
+ shape = np.shape(images) + width, height, depth = shape[-3:] + images = np.reshape(images, (-1, width, height, depth)) + batch = np.shape(images)[0] + rows = np.minimum(rows, batch) + cols = np.minimum(batch // rows, cols) + images = images[: rows * cols] + images = np.reshape(images, (rows, cols, width, height, depth)) + images = np.transpose(images, [0, 2, 1, 3, 4]) + images = np.reshape(images, [rows * width, cols * height, depth]) + return images - Args: - tag: str: label for this data - images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d - step: int: training step - rows: int: number of rows in tile - cols: int: number of columns in tile - """ - images = np.array(images) - if step is None: - step = self._step - else: - self._step = step - n_images = np.shape(images)[0] - if rows is None and cols is None: - rows = 1 - cols = n_images - elif rows is None: - rows = n_images // cols - elif cols is None: - cols = n_images // rows - tiled_images = _pack_images(images, rows, cols) - self.image(tag, tiled_images, step=step) - - def plot(self, tag, mpl_plt, step=None, close_plot=True): - """Saves matplotlib plot output to summary image. - Args: - tag: str: label for this data - mpl_plt: matplotlib stateful pyplot object with prepared plotting state - step: int: training step - close_plot: bool: automatically closes plot - """ - if step is None: - step = self._step - else: - self._step = step - fig = mpl_plt.get_current_fig_manager() - img_w, img_h = fig.canvas.get_width_height() - image_buf = io.BytesIO() - mpl_plt.savefig(image_buf, format='png') - image_summary = tf.compat.v1.Summary.Image( - encoded_image_string=image_buf.getvalue(), - colorspace=4, # RGBA - height=img_h, - width=img_w) - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)]) - self.add_summary(summary, step) - if close_plot: - mpl_plt.close() - - def audio(self, tag, audiodata, step=None, sample_rate=44100): - """Saves audio. 
- - NB: single channel only right now. - - Args: - tag: str: label for this data - audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave - step: int: training step - sample_rate: sample rate of passed in audio buffer - """ - audiodata = np.array(audiodata) - if step is None: - step = self._step - else: - self._step = step - audiodata = np.clip(np.squeeze(audiodata), -1, 1) - if audiodata.ndim != 1: - raise ValueError('Audio data must be 1D.') - sample_list = (32767.0 * audiodata).astype(int).tolist() - wio = io.BytesIO() - wav_buf = wave.open(wio, 'wb') - wav_buf.setnchannels(1) - wav_buf.setsampwidth(2) - wav_buf.setframerate(sample_rate) - enc = b''.join([struct.pack(' 0 else np.concatenate([[0], counts[:end]])) - limits = limits[start:end + 1] - sum_sq = values.dot(values) - histo = tf.compat.v1.HistogramProto( - min=values.min(), - max=values.max(), - num=len(values), - sum=values.sum(), - sum_squares=sum_sq, - bucket_limit=limits.tolist(), - bucket=counts.tolist()) - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)]) - self.add_summary(summary, step) - - def text(self, tag, textdata, step=None): - """Saves a text summary. - - Args: - tag: str: label for this data - textdata: string, or 1D/2D list/numpy array of strings - step: int: training step - Note: markdown formatting is rendered by tensorboard. - """ - if step is None: - step = self._step - else: - self._step = step - smd = tf.compat.v1.SummaryMetadata( - plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name='text')) - if isinstance(textdata, (str, bytes)): - tensor = tf.make_tensor_proto( - values=[textdata.encode(encoding='utf_8')], shape=(1,)) - else: - textdata = np.array(textdata) # convert lists, jax arrays, etc. 
- datashape = np.shape(textdata) - if len(datashape) == 1: - tensor = tf.make_tensor_proto( - values=[td.encode(encoding='utf_8') for td in textdata], - shape=(datashape[0],)) - elif len(datashape) == 2: - tensor = tf.make_tensor_proto( - values=[ - td.encode(encoding='utf_8') for td in np.reshape(textdata, -1) - ], - shape=(datashape[0], datashape[1])) - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value( - tag=tag, metadata=smd, tensor=tensor)]) - self.add_summary(summary, step) +class SummaryWriter: + """Saves data in event and summary protos for tensorboard.""" + + def __init__(self, log_dir, enable=True): + """Create a new SummaryWriter. + + Args: + log_dir: path to record tfevents files in. + enable: bool: if False don't actually write or flush data. Used in + multihost training. + """ + # If needed, create log_dir directory as well as missing parent directories. + if not tf.io.gfile.isdir(log_dir): + tf.io.gfile.makedirs(log_dir) + + self._event_writer = EventFileWriter(log_dir, 10, 120, None) + self._step = 0 + self._closed = False + self._enabled = enable + + def add_summary(self, summary, step): + if not self._enabled: + return + event = event_pb2.Event(summary=summary) + event.wall_time = time.time() + if step is not None: + event.step = int(step) + self._event_writer.add_event(event) + + def close(self): + """Close SummaryWriter. Final!""" + if not self._closed: + self._event_writer.close() + self._closed = True + del self._event_writer + + def __del__(self): # safe? + # TODO(afrozm): Sometimes this complains with + # `TypeError: 'NoneType' object is not callable` + try: + self.close() + except Exception: # pylint: disable=broad-except + pass + + def flush(self): + if not self._enabled: + return + self._event_writer.flush() + + def scalar(self, tag, value, step=None): + """Saves scalar value. 
+ + Args: + tag: str: label for this data + value: int/float: number to log + step: int: training step + """ + value = float(np.array(value)) + if step is None: + step = self._step + else: + self._step = step + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)] + ) + self.add_summary(summary, step) + + def image(self, tag, image, step=None): + """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3]. + + Args: + tag: str: label for this data + image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors/ + step: int: training step + """ + image = np.array(image) + if step is None: + step = self._step + else: + self._step = step + if len(np.shape(image)) == 2: + image = image[:, :, np.newaxis] + if np.shape(image)[-1] == 1: + image = np.repeat(image, 3, axis=-1) + image_strio = io.BytesIO() + plt.imsave(image_strio, image, format="png") + image_summary = tf.compat.v1.Summary.Image( + encoded_image_string=image_strio.getvalue(), + colorspace=3, + height=image.shape[0], + width=image.shape[1], + ) + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)] + ) + self.add_summary(summary, step) + + def images(self, tag, images, step=None, rows=None, cols=None): + """Saves (rows, cols) tiled images from np.ndarray. + + If either rows or cols aren't given, they are determined automatically + from the size of the image batch, if neither are given a long column + of images is produced. This truncates the image batch rather than padding + if it doesn't fill the final row. 
+ + Args: + tag: str: label for this data + images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d + step: int: training step + rows: int: number of rows in tile + cols: int: number of columns in tile + """ + images = np.array(images) + if step is None: + step = self._step + else: + self._step = step + n_images = np.shape(images)[0] + if rows is None and cols is None: + rows = 1 + cols = n_images + elif rows is None: + rows = n_images // cols + elif cols is None: + cols = n_images // rows + tiled_images = _pack_images(images, rows, cols) + self.image(tag, tiled_images, step=step) + + def plot(self, tag, mpl_plt, step=None, close_plot=True): + """Saves matplotlib plot output to summary image. + + Args: + tag: str: label for this data + mpl_plt: matplotlib stateful pyplot object with prepared plotting state + step: int: training step + close_plot: bool: automatically closes plot + """ + if step is None: + step = self._step + else: + self._step = step + fig = mpl_plt.get_current_fig_manager() + img_w, img_h = fig.canvas.get_width_height() + image_buf = io.BytesIO() + mpl_plt.savefig(image_buf, format="png") + image_summary = tf.compat.v1.Summary.Image( + encoded_image_string=image_buf.getvalue(), + colorspace=4, # RGBA + height=img_h, + width=img_w, + ) + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)] + ) + self.add_summary(summary, step) + if close_plot: + mpl_plt.close() + + def audio(self, tag, audiodata, step=None, sample_rate=44100): + """Saves audio. + + NB: single channel only right now. 
+ + Args: + tag: str: label for this data + audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave + step: int: training step + sample_rate: sample rate of passed in audio buffer + """ + audiodata = np.array(audiodata) + if step is None: + step = self._step + else: + self._step = step + audiodata = np.clip(np.squeeze(audiodata), -1, 1) + if audiodata.ndim != 1: + raise ValueError("Audio data must be 1D.") + sample_list = (32767.0 * audiodata).astype(int).tolist() + wio = io.BytesIO() + wav_buf = wave.open(wio, "wb") + wav_buf.setnchannels(1) + wav_buf.setsampwidth(2) + wav_buf.setframerate(sample_rate) + enc = b"".join([struct.pack(" 0 + else np.concatenate([[0], counts[:end]]) + ) + limits = limits[start : end + 1] + sum_sq = values.dot(values) + histo = tf.compat.v1.HistogramProto( + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limit=limits.tolist(), + bucket=counts.tolist(), + ) + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)] + ) + self.add_summary(summary, step) + + def text(self, tag, textdata, step=None): + """Saves a text summary. + + Args: + tag: str: label for this data + textdata: string, or 1D/2D list/numpy array of strings + step: int: training step + Note: markdown formatting is rendered by tensorboard. + """ + if step is None: + step = self._step + else: + self._step = step + smd = tf.compat.v1.SummaryMetadata( + plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name="text") + ) + if isinstance(textdata, (str, bytes)): + tensor = tf.make_tensor_proto( + values=[textdata.encode(encoding="utf_8")], shape=(1,) + ) + else: + textdata = np.array(textdata) # convert lists, jax arrays, etc. 
+ datashape = np.shape(textdata) + if len(datashape) == 1: + tensor = tf.make_tensor_proto( + values=[td.encode(encoding="utf_8") for td in textdata], + shape=(datashape[0],), + ) + elif len(datashape) == 2: + tensor = tf.make_tensor_proto( + values=[ + td.encode(encoding="utf_8") for td in np.reshape(textdata, -1) + ], + shape=(datashape[0], datashape[1]), + ) + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, metadata=smd, tensor=tensor)] + ) + self.add_summary(summary, step) # Copied from gin/tf/utils.py:GinConfigSaverHook def markdownify_operative_config_str(string): - """Convert an operative config string to markdown format.""" - - # TODO(b/37527917): Total hack below. Implement more principled formatting. - def process(line): - """Convert a single line to markdown format.""" - if not line.startswith('#'): - return ' ' + line - - line = line[2:] - if line.startswith('===='): - return '' - if line.startswith('None'): - return ' # None.' - if line.endswith(':'): - return '#### ' + line - return line - - output_lines = [] - for line in string.splitlines(): - procd_line = process(line) - if procd_line is not None: - output_lines.append(procd_line) - - return '\n'.join(output_lines) + """Convert an operative config string to markdown format.""" + + # TODO(b/37527917): Total hack below. Implement more principled formatting. + def process(line): + """Convert a single line to markdown format.""" + if not line.startswith("#"): + return " " + line + + line = line[2:] + if line.startswith("===="): + return "" + if line.startswith("None"): + return " # None." 
+ if line.endswith(":"): + return "#### " + line + return line + + output_lines = [] + for line in string.splitlines(): + procd_line = process(line) + if procd_line is not None: + output_lines.append(procd_line) + + return "\n".join(output_lines) diff --git a/trax/layers/__init__.py b/trax/layers/__init__.py index 3913fdafc..ce8f2f83e 100644 --- a/trax/layers/__init__.py +++ b/trax/layers/__init__.py @@ -16,6 +16,7 @@ """Layers: trainable functions as neural network building blocks.""" import gin + # We create a flat layers.* namespace for uniform calling conventions as we # upstream changes. # pylint: disable=wildcard-import @@ -44,8 +45,9 @@ # Ginify def layer_configure(*args, **kwargs): - kwargs['module'] = 'trax.layers' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.layers" + return gin.external_configurable(*args, **kwargs) + # pylint: disable=used-before-assignment # pylint: disable=invalid-name @@ -69,41 +71,44 @@ def layer_configure(*args, **kwargs): FilterResponseNorm = layer_configure(FilterResponseNorm) ThresholdedLinearUnit = layer_configure(ThresholdedLinearUnit) -Attention = layer_configure(Attention, denylist=['mode']) -CausalAttention = layer_configure(CausalAttention, denylist=['mode']) -FavorAttention = layer_configure(FavorAttention, denylist=['mode']) -Favor = layer_configure(Favor, denylist=['mode']) -CausalFavor = layer_configure(CausalFavor, denylist=['mode']) -CausalFavorAttention = layer_configure(CausalFavorAttention, denylist=['mode']) +Attention = layer_configure(Attention, denylist=["mode"]) +CausalAttention = layer_configure(CausalAttention, denylist=["mode"]) +FavorAttention = layer_configure(FavorAttention, denylist=["mode"]) +Favor = layer_configure(Favor, denylist=["mode"]) +CausalFavor = layer_configure(CausalFavor, denylist=["mode"]) +CausalFavorAttention = layer_configure(CausalFavorAttention, denylist=["mode"]) DotProductCausalAttention = layer_configure( - DotProductCausalAttention, 
denylist=['mode']) -SelfAttention = layer_configure(SelfAttention, denylist=['mode']) -ModularCausalAttention = layer_configure(ModularCausalAttention, - denylist=['mode']) -LowRankCausalAttention = layer_configure(LowRankCausalAttention, - denylist=['mode']) -MultiplicativeCausalAttention = layer_configure(MultiplicativeCausalAttention, - denylist=['mode']) + DotProductCausalAttention, denylist=["mode"] +) +SelfAttention = layer_configure(SelfAttention, denylist=["mode"]) +ModularCausalAttention = layer_configure(ModularCausalAttention, denylist=["mode"]) +LowRankCausalAttention = layer_configure(LowRankCausalAttention, denylist=["mode"]) +MultiplicativeCausalAttention = layer_configure( + MultiplicativeCausalAttention, denylist=["mode"] +) MultiplicativeModularCausalAttention = layer_configure( - MultiplicativeModularCausalAttention, denylist=['mode']) -ConvCausalAttention = layer_configure(ConvCausalAttention, denylist=['mode']) + MultiplicativeModularCausalAttention, denylist=["mode"] +) +ConvCausalAttention = layer_configure(ConvCausalAttention, denylist=["mode"]) MultiplicativeConvCausalAttention = layer_configure( - MultiplicativeConvCausalAttention, denylist=['mode']) + MultiplicativeConvCausalAttention, denylist=["mode"] +) ConvTranspose = layer_configure(ConvTranspose) -LSHSelfAttention = layer_configure(LSHSelfAttention, denylist=['mode']) -PureLSHSelfAttention = layer_configure(PureLSHSelfAttention, denylist=['mode']) -MixedLSHSelfAttention = layer_configure( - MixedLSHSelfAttention, denylist=['mode']) +LSHSelfAttention = layer_configure(LSHSelfAttention, denylist=["mode"]) +PureLSHSelfAttention = layer_configure(PureLSHSelfAttention, denylist=["mode"]) +MixedLSHSelfAttention = layer_configure(MixedLSHSelfAttention, denylist=["mode"]) PureLSHSelfAttentionWrapper = layer_configure( - PureLSHSelfAttentionWrapper, denylist=['mode']) -EncDecAttention = layer_configure(EncDecAttention, denylist=['mode']) + PureLSHSelfAttentionWrapper, denylist=["mode"] +) 
+EncDecAttention = layer_configure(EncDecAttention, denylist=["mode"]) -PositionalEncoding = layer_configure( - PositionalEncoding, denylist=['mode']) +PositionalEncoding = layer_configure(PositionalEncoding, denylist=["mode"]) InfinitePositionalEncoding = layer_configure( - InfinitePositionalEncoding, denylist=['mode']) + InfinitePositionalEncoding, denylist=["mode"] +) TimeBinPositionalEncoding = layer_configure( - TimeBinPositionalEncoding, denylist=['mode']) + TimeBinPositionalEncoding, denylist=["mode"] +) AtariConvInit = layer_configure(AtariConvInit) CrossEntropyLossWithLogSoftmax = layer_configure(CrossEntropyLossWithLogSoftmax) diff --git a/trax/layers/acceleration.py b/trax/layers/acceleration.py index 57fd7ffe5..24642b2df 100644 --- a/trax/layers/acceleration.py +++ b/trax/layers/acceleration.py @@ -23,247 +23,261 @@ class Accelerate(base.Layer): - """Accelerates a layer, running in data-parallel way on multiple devices. - - By default it uses all available accelerators, splits the input on the - first (batch) axis, and runs each part on the corresponding accelerator. - If only one accelerator is available, this layer JIT-compiles the underlying - layer and in this way makes it run faster. - - The output is guaranteed to be the same as the output of the original layer - if the batch dimension is divisible by the number of devices. If it is not, - then 0-padding is added to make it divisible and the output may be affected - if it relies on layers like batch normalization. - - This layer does not require calling ``init`` if the underlying layer has - already been initialized, so it can be used as follows:: - - layer = tl.Serial(...) - layer.init(...) 
- fast_layer = tl.Accelerate(layer) - y = fast_layer(x) # Split x on batch and run data-parallel - - In case the weights of this layer need to be set using the weights of - the sublayer, use the ``replicate_weights`` function:: - - # Instead of layer.weights = new_weights: - fast_layer.replicate_weights(new_weights) - - """ - - def __init__(self, layer, n_devices=None): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_devices = n_devices or fastmath.local_device_count() - self._jit_pure_fn = jit_forward( - layer.pure_fn, self._n_devices, do_mean=False) - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - def pure_fn(self, x, weights, state, rng, use_cache=False): - """Calls ``self.sublayer.pure_fn`` in an accelerated way.""" - # Check if we can divide x evenly across devices. - # Note: x can be a list/tuple because the underlying layer may take - # its input as a list/tuple, ex: (inputs, targets, weight). - if isinstance(x, (list, tuple)): - remainder = x[0].shape[0] % self._n_devices - else: - remainder = x.shape[0] % self._n_devices - if remainder == 0: # If yes, run the accelerated sublayer.pure_fn. - return self._jit_pure_fn(x, weights, state, rng) - # If not, pad first. - def pad(z): - pad_widths = [(0, 0)] * len(z.shape) - pad_widths[0] = (0, self._n_devices - remainder) - return jnp.pad(z, pad_widths, mode='constant', - constant_values=z.dtype.type(0)) - padded_x = [pad(z) for z in x] if isinstance(x, (list, tuple)) else pad(x) - # Run and un-pad. 
- padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng) - if isinstance(x, (list, tuple)): - y = tuple(padded_z[:z.shape[0]] for (padded_z, z) in zip(padded_y, x)) - y = list(y) if isinstance(x, list) else y - else: - y = padded_y[:x.shape[0]] - return y, state - - def _prepare_weights(self, weights): - """Replicate or shard weights for the number of devices requested.""" - if base.N_WEIGHTS_SHARDS > 1: - if base.N_WEIGHTS_SHARDS % self._n_devices != 0: - raise ValueError(f'Number of shards ({base.N_WEIGHTS_SHARDS}) must ' - f'be a multiple of n_devices ({self._n_devices}).') - return base.shard(weights, base.N_WEIGHTS_SHARDS) - else: - return for_n_devices(weights, self._n_devices) - - def init(self, input_signature): - """Calls ``self.sublayer.init`` and replicates its values onto devices.""" - weights, state = self.sublayer.init(input_signature, use_cache=True) - self._weights = self._prepare_weights(weights) - self._state = for_n_devices(state, self._n_devices) - return (self.weights, self.state) - - def replicate_weights(self, weights): - """Sets the weights of the sublayer and replicates them for this layer.""" - self.sublayer.weights = weights - self._weights = self._prepare_weights(weights) - - def replicate_state(self, state): - """Sets the state of the sublayer and replicates it for this layer.""" - self.sublayer.state = state - self._state = for_n_devices(state, self._n_devices) - - def _unreplicate(self, x): - """Returns a single-device version of ``x``.""" - if self._n_devices < 2: - return x - return fastmath.nested_map(lambda y: y[0], x) - - @property - def weights(self): - # Override the getter so it works even if only sublayer is initialized. 
- if self._weights is base.EMPTY_WEIGHTS: - self._weights = self._prepare_weights(self.sublayer.weights) - return self._weights - - @weights.setter - def weights(self, weights): - self._weights = weights - self.sublayer.weights = self._unreplicate(weights) - - @property - def state(self): - # Override the getter so it works even if only sublayer is initialized. - if self._state is base.EMPTY_STATE: - self._state = for_n_devices(self.sublayer.state, self._n_devices) - return self._state - - @state.setter - def state(self, state): - self._state = state - self.sublayer.state = self._unreplicate(state) + """Accelerates a layer, running in data-parallel way on multiple devices. + + By default it uses all available accelerators, splits the input on the + first (batch) axis, and runs each part on the corresponding accelerator. + If only one accelerator is available, this layer JIT-compiles the underlying + layer and in this way makes it run faster. + + The output is guaranteed to be the same as the output of the original layer + if the batch dimension is divisible by the number of devices. If it is not, + then 0-padding is added to make it divisible and the output may be affected + if it relies on layers like batch normalization. + + This layer does not require calling ``init`` if the underlying layer has + already been initialized, so it can be used as follows:: + + layer = tl.Serial(...) + layer.init(...) 
+ fast_layer = tl.Accelerate(layer) + y = fast_layer(x) # Split x on batch and run data-parallel + + In case the weights of this layer need to be set using the weights of + the sublayer, use the ``replicate_weights`` function:: + + # Instead of layer.weights = new_weights: + fast_layer.replicate_weights(new_weights) + + """ + + def __init__(self, layer, n_devices=None): + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + self._n_devices = n_devices or fastmath.local_device_count() + self._jit_pure_fn = jit_forward(layer.pure_fn, self._n_devices, do_mean=False) + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + def pure_fn(self, x, weights, state, rng, use_cache=False): + """Calls ``self.sublayer.pure_fn`` in an accelerated way.""" + # Check if we can divide x evenly across devices. + # Note: x can be a list/tuple because the underlying layer may take + # its input as a list/tuple, ex: (inputs, targets, weight). + if isinstance(x, (list, tuple)): + remainder = x[0].shape[0] % self._n_devices + else: + remainder = x.shape[0] % self._n_devices + if remainder == 0: # If yes, run the accelerated sublayer.pure_fn. + return self._jit_pure_fn(x, weights, state, rng) + # If not, pad first. + def pad(z): + pad_widths = [(0, 0)] * len(z.shape) + pad_widths[0] = (0, self._n_devices - remainder) + return jnp.pad( + z, pad_widths, mode="constant", constant_values=z.dtype.type(0) + ) + + padded_x = [pad(z) for z in x] if isinstance(x, (list, tuple)) else pad(x) + # Run and un-pad. 
+ padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng) + if isinstance(x, (list, tuple)): + y = tuple(padded_z[: z.shape[0]] for (padded_z, z) in zip(padded_y, x)) + y = list(y) if isinstance(x, list) else y + else: + y = padded_y[: x.shape[0]] + return y, state + + def _prepare_weights(self, weights): + """Replicate or shard weights for the number of devices requested.""" + if base.N_WEIGHTS_SHARDS > 1: + if base.N_WEIGHTS_SHARDS % self._n_devices != 0: + raise ValueError( + f"Number of shards ({base.N_WEIGHTS_SHARDS}) must " + f"be a multiple of n_devices ({self._n_devices})." + ) + return base.shard(weights, base.N_WEIGHTS_SHARDS) + else: + return for_n_devices(weights, self._n_devices) + + def init(self, input_signature): + """Calls ``self.sublayer.init`` and replicates its values onto devices.""" + weights, state = self.sublayer.init(input_signature, use_cache=True) + self._weights = self._prepare_weights(weights) + self._state = for_n_devices(state, self._n_devices) + return (self.weights, self.state) + + def replicate_weights(self, weights): + """Sets the weights of the sublayer and replicates them for this layer.""" + self.sublayer.weights = weights + self._weights = self._prepare_weights(weights) + + def replicate_state(self, state): + """Sets the state of the sublayer and replicates it for this layer.""" + self.sublayer.state = state + self._state = for_n_devices(state, self._n_devices) + + def _unreplicate(self, x): + """Returns a single-device version of ``x``.""" + if self._n_devices < 2: + return x + return fastmath.nested_map(lambda y: y[0], x) + + @property + def weights(self): + # Override the getter so it works even if only sublayer is initialized. 
+ if self._weights is base.EMPTY_WEIGHTS: + self._weights = self._prepare_weights(self.sublayer.weights) + return self._weights + + @weights.setter + def weights(self, weights): + self._weights = weights + self.sublayer.weights = self._unreplicate(weights) + + @property + def state(self): + # Override the getter so it works even if only sublayer is initialized. + if self._state is base.EMPTY_STATE: + self._state = for_n_devices(self.sublayer.state, self._n_devices) + return self._state + + @state.setter + def state(self, state): + self._state = state + self.sublayer.state = self._unreplicate(state) # TODO(jonni): Rename, since implementation does not use pmean. def mean_or_pmean(n_devices, x, axis=None): - """Computes the mean of a distributed value ``x``. - - Args: - n_devices: Number of devices. - x: Distributed array. - axis: Axis along which to compute means; can only be ``0`` or ``None``. - - Returns: - A local array. - """ - if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1: - if axis not in (None, 0): - raise ValueError('axis can only be None or 0') - x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices - if axis is None: - x = jnp.mean(x) - return x - else: - return jnp.mean(x, axis=axis) + """Computes the mean of a distributed value ``x``. + + Args: + n_devices: Number of devices. + x: Distributed array. + axis: Axis along which to compute means; can only be ``0`` or ``None``. + + Returns: + A local array. 
+ """ + if fastmath.backend_name() == "tensorflow-numpy" and n_devices > 1: + if axis not in (None, 0): + raise ValueError("axis can only be None or 0") + x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices + if axis is None: + x = jnp.mean(x) + return x + else: + return jnp.mean(x, axis=axis) def jit_forward(forward, n_devices, do_mean=True): - """Returns a JIT-compiled forward function running on ``n_devices``.""" - model_predict = _accelerate(forward, n_devices) - # n_devices == 0 => CPU - if n_devices < 2: - return model_predict - - def predict(x, weights, state, rng): - """Predict function JIT-compiled and parallelized as requested.""" - res, state = model_predict( - reshape_by_device(x, n_devices), - weights, - state, - jnp.stack(fastmath.random.split(rng, n_devices))) - res = _combine_devices(res) - if do_mean: - return fastmath.nested_map( - lambda y: mean_or_pmean(n_devices, y, axis=0), res), state - else: - return res, state - - return predict + """Returns a JIT-compiled forward function running on ``n_devices``.""" + model_predict = _accelerate(forward, n_devices) + # n_devices == 0 => CPU + if n_devices < 2: + return model_predict + + def predict(x, weights, state, rng): + """Predict function JIT-compiled and parallelized as requested.""" + res, state = model_predict( + reshape_by_device(x, n_devices), + weights, + state, + jnp.stack(fastmath.random.split(rng, n_devices)), + ) + res = _combine_devices(res) + if do_mean: + return ( + fastmath.nested_map(lambda y: mean_or_pmean(n_devices, y, axis=0), res), + state, + ) + else: + return res, state + + return predict def _combine_devices(x_tuple): - """Combines multi-device tensors into a single batch.""" - def f(x): - if len(x.shape) < 2: - return x # No extra batch dimension: use devices as batch, so return. 
- batch_size = x.shape[0] * x.shape[1] - return jnp.reshape(x, [batch_size] + list(x.shape[2:])) - return fastmath.nested_map(f, x_tuple) + """Combines multi-device tensors into a single batch.""" + + def f(x): + if len(x.shape) < 2: + return x # No extra batch dimension: use devices as batch, so return. + batch_size = x.shape[0] * x.shape[1] + return jnp.reshape(x, [batch_size] + list(x.shape[2:])) + + return fastmath.nested_map(f, x_tuple) def _accelerate(f, n_devices): - """Returns an accelerated version of ``f`` running on ``n_devices``.""" - if n_devices == 0: # no accelerators - run on CPU - return fastmath.jit(f, device=jax.devices('cpu')[0]) + """Returns an accelerated version of ``f`` running on ``n_devices``.""" + if n_devices == 0: # no accelerators - run on CPU + return fastmath.jit(f, device=jax.devices("cpu")[0]) - if n_devices == 1: - return fastmath.jit(f) + if n_devices == 1: + return fastmath.jit(f) - return fastmath.pmap(f, axis_name='batch') + return fastmath.pmap(f, axis_name="batch") def reshape_by_device(x, n_devices, pure_np=False): - """Reshapes possibly nested ``x`` into a shape ``(n_devices, ...)``.""" - def f(x): - x_shape = list(x.shape) - batch_size = x_shape[0] - batch_size_per_device = batch_size // n_devices - if batch_size_per_device * n_devices != batch_size: - raise ValueError(f'Number of devices ({n_devices}) does not evenly ' - f'divide batch size ({batch_size}).') - new_shape_prefix = [n_devices, batch_size_per_device] - if pure_np: - return np.reshape(x, new_shape_prefix + x_shape[1:]) - else: - return jnp.reshape(x, new_shape_prefix + x_shape[1:]) - return fastmath.nested_map(f, x) + """Reshapes possibly nested ``x`` into a shape ``(n_devices, ...)``.""" + + def f(x): + x_shape = list(x.shape) + batch_size = x_shape[0] + batch_size_per_device = batch_size // n_devices + if batch_size_per_device * n_devices != batch_size: + raise ValueError( + f"Number of devices ({n_devices}) does not evenly " + f"divide batch size 
({batch_size})." + ) + new_shape_prefix = [n_devices, batch_size_per_device] + if pure_np: + return np.reshape(x, new_shape_prefix + x_shape[1:]) + else: + return jnp.reshape(x, new_shape_prefix + x_shape[1:]) + + return fastmath.nested_map(f, x) def for_n_devices(x, n_devices): - """Replicates/broadcasts ``x`` for ``n_devices``.""" - def f(x): - if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX): - return jax.device_put_replicated(x, jax.local_devices()) - elif n_devices > 1: - return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape) - else: - return x - return fastmath.nested_map(f, x) + """Replicates/broadcasts ``x`` for ``n_devices``.""" + + def f(x): + if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX): + return jax.device_put_replicated(x, jax.local_devices()) + elif n_devices > 1: + return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape) + else: + return x + + return fastmath.nested_map(f, x) def on_cpu(x): - """Puts ``x`` in CPU memory in JAX.""" - if fastmath.is_backend(fastmath.Backend.JAX): - return jax.device_put(x, jax.devices('cpu')[0]) - else: - return x + """Puts ``x`` in CPU memory in JAX.""" + if fastmath.is_backend(fastmath.Backend.JAX): + return jax.device_put(x, jax.devices("cpu")[0]) + else: + return x def on_accelerator(x): - """Puts ``x`` in (single) accelerator memory in JAX.""" - try: - accelerator_devices = jax.devices('gpu') - except RuntimeError: + """Puts ``x`` in (single) accelerator memory in JAX.""" try: - accelerator_devices = jax.devices('tpu') + accelerator_devices = jax.devices("gpu") except RuntimeError: - accelerator_devices = [] - if not accelerator_devices: - return x - if len(accelerator_devices) != 1: - return x - return jax.device_put(x, accelerator_devices[0]) + try: + accelerator_devices = jax.devices("tpu") + except RuntimeError: + accelerator_devices = [] + if not accelerator_devices: + return x + if len(accelerator_devices) != 1: + return x + return jax.device_put(x, 
accelerator_devices[0]) diff --git a/trax/layers/acceleration_test.py b/trax/layers/acceleration_test.py deleted file mode 100644 index 57002474d..000000000 --- a/trax/layers/acceleration_test.py +++ /dev/null @@ -1,102 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for acceleration.""" - -from absl.testing import absltest - -from jax.config import config -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes - - -class AccelerationTest(absltest.TestCase): - - def test_accelerated_same_result(self): - layer = tl.Dense(2) - x = np.random.uniform(size=(8, 7)) - layer.init(shapes.signature(x)) - y = layer(x) - z = tl.Accelerate(layer)(x) - for i in range(8): - self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) - self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) - - def test_accelerated_pad(self): - layer = tl.Dense(2) - x = np.random.uniform(size=(3, 7)) - layer.init(shapes.signature(x)) - y = layer(x) - z = tl.Accelerate(layer)(x) - self.assertEqual(z.shape, y.shape) - for i in range(3): - self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) - self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) - - def test_accelerated_weighted_category_accuracy(self): - """Test multi-device aggregation of weights.""" - layer = tl.Accelerate(tl.WeightedCategoryAccuracy()) - weights = np.array([1., 1., 1., 0.]) - 
targets = np.array([0, 1, 2, 3]) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(np.mean(accuracy), 1 / 3) - - def test_chunk_memory(self): - """Test chunking here to exercise accelerator memory usage.""" - layer = tl.Serial(tl.Dense(1024*1024), tl.Dense(128)) - chunked = tl.Chunk(layer, 256) - x = np.random.uniform(size=(16*1024, 16)) - chunked.init(shapes.signature(x)) - y = chunked(x) - z = tl.Accelerate(chunked)(x) - self.assertEqual(y.shape, (16*1024, 128)) - self.assertEqual(z.shape, (16*1024, 128)) - - def test_chunk_grad_memory(self): - """Test chunking gradient here to exercise accelerator memory usage.""" - layer = tl.Serial(tl.Dense(1024*1024), tl.Dense(24)) - chunked = tl.Chunk(layer, 256) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits, new_state = chunked.pure_fn(x, weights, state, rng) - loss = fastmath.numpy.mean(logits) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - x = np.random.uniform(size=(32*1024, 16)) - chunked.init(shapes.signature(x)) - weights, _, logits = mock_training_step( - x, chunked.weights, chunked.state, fastmath.random.get_prng(0)) - self.assertEqual(logits.shape, (32*1024, 24)) - self.assertEqual(weights[1][0][0][0].shape, (16, 1024*1024)) - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() diff --git a/trax/layers/activation_fns.py b/trax/layers/activation_fns.py index 625ff87ab..133273d2f 100644 --- a/trax/layers/activation_fns.py +++ b/trax/layers/activation_fns.py @@ -29,9 +29,9 @@ from trax.layers.base import Fn -@assert_shape('...->...') # The output and input shapes are the same. 
+@assert_shape("...->...") # The output and input shapes are the same. def Relu(): - r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function. + r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -39,12 +39,12 @@ def Relu(): x & \text{otherwise}. \end{array} \right. """ - return Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)) + return Fn("Relu", lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)) -@assert_shape('...->...') # The output and input shapes are the same. -def ParametricRelu(a=1.): - r"""Returns a layer that computes a ReLU function with the given slope. +@assert_shape("...->...") # The output and input shapes are the same. +def ParametricRelu(a=1.0): + r"""Returns a layer that computes a ReLU function with the given slope. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -55,12 +55,12 @@ def ParametricRelu(a=1.): Args: a: Slope of line for positive inputs. """ - return Fn('ParametricRelu', lambda x: jnp.maximum(a * x, jnp.zeros_like(x))) + return Fn("ParametricRelu", lambda x: jnp.maximum(a * x, jnp.zeros_like(x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def LeakyRelu(a=0.01): - r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs. + r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -71,12 +71,12 @@ def LeakyRelu(a=0.01): Args: a: Slope of line for negative inputs. """ - return Fn('LeakyRelu', lambda x: jnp.where(x >= 0, x, a * x)) + return Fn("LeakyRelu", lambda x: jnp.where(x >= 0, x, a * x)) -@assert_shape('...->...') # The output and input shapes are the same. -def Elu(a=1.): - r"""Returns a ReLU-like layer with exponential outputs for negative inputs. +@assert_shape("...->...") # The output and input shapes are the same. 
+def Elu(a=1.0): + r"""Returns a ReLU-like layer with exponential outputs for negative inputs. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -89,13 +89,14 @@ def Elu(a=1.): Args: a: Coefficient multiplying the exponential, for negative inputs. """ - return Fn('Elu', lambda x: jnp.where(x > 0, x, a * jnp.expm1(x))) + return Fn("Elu", lambda x: jnp.where(x > 0, x, a * jnp.expm1(x))) -@assert_shape('...->...') # The output and input shapes are the same. -def Selu(alpha=1.6732632423543772848170429916717, - lmbda=1.0507009873554804934193349852946): - r"""Returns an `Elu`-like layer with an additional scaling/slope parameter. +@assert_shape("...->...") # The output and input shapes are the same. +def Selu( + alpha=1.6732632423543772848170429916717, lmbda=1.0507009873554804934193349852946 +): + r"""Returns an `Elu`-like layer with an additional scaling/slope parameter. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -107,58 +108,62 @@ def Selu(alpha=1.6732632423543772848170429916717, alpha: Coefficient multiplying the exponential, for negative inputs. lmbda: Coefficient scaling the whole function. """ - return Fn('Selu', lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))) + return Fn("Selu", lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Gelu(): - r"""Returns a layer that computes the Gaussian Error Linear Unit function. + r"""Returns a layer that computes the Gaussian Error Linear Unit function. - .. math:: - f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}})) - """ - return Fn('Gelu', lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0)))) + .. math:: + f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}})) + """ + return Fn("Gelu", lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0)))) -@assert_shape('...->...') # The output and input shapes are the same. 
+@assert_shape("...->...") # The output and input shapes are the same. def FastGelu(): - r"""Returns a layer that computes a fast approximation to `Gelu`. + r"""Returns a layer that computes a fast approximation to `Gelu`. - .. math:: - f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3)) + .. math:: + f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3)) - where :math:`a = 0.7978845608` and :math:`b = 0.044715`. - """ - def f(x): # pylint: disable=invalid-name - return 0.5 * x * (1 + jnp.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) - return Fn('FastGelu', f) + where :math:`a = 0.7978845608` and :math:`b = 0.044715`. + """ + + def f(x): # pylint: disable=invalid-name + return 0.5 * x * (1 + jnp.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + + return Fn("FastGelu", f) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Sigmoid(): - r"""Returns a layer that computes the sigmoid function. + r"""Returns a layer that computes the sigmoid function. - .. math:: - f(x) = \frac{1}{1 + e^{-x}} - """ - return Fn('Sigmoid', lambda x: fastmath.expit(x)) + .. math:: + f(x) = \frac{1}{1 + e^{-x}} + """ + return Fn("Sigmoid", lambda x: fastmath.expit(x)) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Tanh(): - r"""Returns a layer that computes the hyperbolic tangent function. + r"""Returns a layer that computes the hyperbolic tangent function. + + .. math:: + f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} + """ + return Fn("Tanh", lambda x: jnp.tanh(x)) + - .. math:: - f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} - """ - return Fn('Tanh', lambda x: jnp.tanh(x)) # pylint: enable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. 
def HardSigmoid(): - r"""Returns a layer that computes a linear approximation to `Sigmoid`. + r"""Returns a layer that computes a linear approximation to `Sigmoid`. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -167,12 +172,12 @@ def HardSigmoid(): 1 & \text{otherwise}. \end{array} \right. """ - return Fn('HardSigmoid', lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x)))) + return Fn("HardSigmoid", lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x)))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def HardTanh(): - r"""Returns a layer that computes a linear approximation to `Tanh`. + r"""Returns a layer that computes a linear approximation to `Tanh`. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -181,76 +186,76 @@ def HardTanh(): 1 & \text{otherwise}. \end{array} \right. """ - return Fn('HardTanh', lambda x: jnp.maximum(-1, jnp.minimum(1, x))) + return Fn("HardTanh", lambda x: jnp.maximum(-1, jnp.minimum(1, x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Softplus(): - r"""Returns a layer that computes the softplus function. + r"""Returns a layer that computes the softplus function. - .. math:: - f(x) = \ln(e^x + 1) - """ - return Fn('Softplus', lambda x: jnp.logaddexp(x, 0.)) + .. math:: + f(x) = \ln(e^x + 1) + """ + return Fn("Softplus", lambda x: jnp.logaddexp(x, 0.0)) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. 
def Exp(): - """Returns a layer that computes the element-wise exponential of a tensor.""" - return Fn('Exp', lambda x: jnp.exp(x)) # pylint: disable=unnecessary-lambda + """Returns a layer that computes the element-wise exponential of a tensor.""" + return Fn("Exp", lambda x: jnp.exp(x)) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Log(): - """Returns a layer that computes the element-wise logarithm of a tensor.""" - return Fn('Log', lambda x: jnp.log(x)) # pylint: disable=unnecessary-lambda + """Returns a layer that computes the element-wise logarithm of a tensor.""" + return Fn("Log", lambda x: jnp.log(x)) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Swish(): - r"""Returns a layer that computes the Swish function. + r"""Returns a layer that computes the Swish function. - .. math:: - f(x) = x \cdot \text{sigmoid}(x) - """ - return Fn('Swish', lambda x: x * fastmath.expit(x)) + .. math:: + f(x) = x \cdot \text{sigmoid}(x) + """ + return Fn("Swish", lambda x: x * fastmath.expit(x)) -@assert_shape('...a->...b') # The output and input shapes are not the same. +@assert_shape("...a->...b") # The output and input shapes are not the same. def Glu(): - r"""Returns a layer that computes the Gated Linear Unit function. + r"""Returns a layer that computes the Gated Linear Unit function. - .. math:: - f(x) = a \cdot \text{sigmoid}(b) - where a and b are formed by splitting input in half along axis + .. 
math:: + f(x) = a \cdot \text{sigmoid}(b) + where a and b are formed by splitting input in half along axis - """ + """ - def _f(x, axis=-1): # pylint: disable=invalid-name - size = x.shape[axis] - assert size % 2 == 0, f'axis {axis} of size {size} is not be divisible by 2' - a, b = jnp.split(x, 2, axis) - return a * fastmath.expit(b) + def _f(x, axis=-1): # pylint: disable=invalid-name + size = x.shape[axis] + assert size % 2 == 0, f"axis {axis} of size {size} is not be divisible by 2" + a, b = jnp.split(x, 2, axis) + return a * fastmath.expit(b) - return Fn('Glu', _f) + return Fn("Glu", _f) class ThresholdedLinearUnit(base.Layer): - """Thresholded Linear Unit, c.f. https://arxiv.org/pdf/1911.09737.pdf .""" + """Thresholded Linear Unit, c.f. https://arxiv.org/pdf/1911.09737.pdf .""" - def init_weights_and_state(self, input_signature): - """Initializes this layer's single weight to zero.""" - del input_signature - self.weights = jnp.zeros((), dtype=jnp.float32) + def init_weights_and_state(self, input_signature): + """Initializes this layer's single weight to zero.""" + del input_signature + self.weights = jnp.zeros((), dtype=jnp.float32) - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model. + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model. - Args: - inputs: Tensor. + Args: + inputs: Tensor. - Returns: - Tensor of same shape and dtype as the input. - """ - threshold = self.weights - return jnp.maximum(inputs, threshold) + Returns: + Tensor of same shape and dtype as the input. + """ + threshold = self.weights + return jnp.maximum(inputs, threshold) diff --git a/trax/layers/activation_fns_test.py b/trax/layers/activation_fns_test.py deleted file mode 100644 index 2f128bd47..000000000 --- a/trax/layers/activation_fns_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for activation function layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class ActivationFnsTest(absltest.TestCase): - - def test_relu(self): - layer = tl.Relu() - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]) - - def test_parametric_relu(self): - layer = tl.ParametricRelu(a=.25) - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, .5, .75, 1.25]) - - def test_leaky_relu(self): - layer = tl.LeakyRelu(a=.125) - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [-.25, -.125, 0.0, 2.0, 3.0, 5.0]) - - def test_hard_sigmoid(self): - layer = tl.HardSigmoid() - x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]) - - def test_hard_tanh(self): - layer = tl.HardTanh() - x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) - y = layer(x) - self.assertEqual(tl.to_list(y), [-1.0, -.5, -.25, 0.0, .25, .5, 1.0]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/assert_shape.py b/trax/layers/assert_shape.py index dffa85392..0ae28e467 100644 --- a/trax/layers/assert_shape.py +++ b/trax/layers/assert_shape.py @@ -24,268 +24,276 @@ def assert_shape(specification): - """Decorator for 
checking the input and output shapes of Layer. - - Decorator can be applied on trax.base.Layer class, or a function returning - a trax.base.Layer class. It uses notation similar to einsum (Einstein - summation convention), achieving concise and simple representation of tensors. - For example 'ij,jh->ih' is a valid representation of a function taking two - 2D matrices as input, and returning a single output, also a 2D matrix. - - It improves readability and puts puts three levels of asserts on the function: - first level is the number of input tensors and output tensors; second level is - the rank of each tensor; third level is the size of each dimension of each - tensor. The decorator inserts those asserts right before and right after - 'forward' call. - - First level, assert on number of inputs and outputs. In the representation - input tensors are separated from output tensors by an arrow '->'. For layers - taking multiple input tensors or returning multiple output tensors, those - tensors will be separated by a comma ','. - For example, specification 'bsd,df->bsf' asserts that there will be two - input tensors, with shapes represented by 'bsd' and 'df' respectively; and - a single output tensor with shape represented by 'bsf'. - - Second level, asserts on possible rank of each tensor. Most commonly, - each letter represents a single dimension. For example,the tensor with shapes - represented by 'bsd' has rank three; with 'df' it has rank two. The special - case is an ellipsis ('...'), which expand to arbitrary number of dimensions, - including zero. For example, the tensor with specification '...sf' has at - least two dimensions. Each tensor may have in its representation one ellipsis. - - Third level, asserts the size of each dimension. If two dimensions in any - of input or output tensors have the same letter in the representation then - they must have the same size. 
For example, with a tensor A represented by 'df' - and a tensor B represented by 'bsf', the size of the second dimension of A - must equal the size of the third dimension of B. Another example: with a - tensor C represented by '...dv' and a tensor D represented by 'd', the size of - the first and only dimension of D must be equal to the size of the second to - last dimension of tensor C. - - If two distinct tensors have an ellipsis in their representation then all of - dimensions covered by those ellipses must match. For example, with a tensor E - represented by '...d' and tensor F represented by '...x' then E and F must - have the same rank, and the sizes of all but the last dimensions must match. - - Examples: - # In Dense layer there is a single input and single output; the last dimension - # may change in size, while the sizes of all previous dimensions, marked by - # an ellipsis, will stay the same. - @assert_shape('...a->...b') - class Dense(base.Layer): - (...) - - # DotProductCausalAttention takes three tensors as input: Queries, Keys, and - # Values, and outputs a single tensor. Sizes of the first two dimensions in - # all those tensors must match, while the last dimension must match only - # between Queries and Keys, and separately between Values and output tensor. - @assert_shape('blk,blk,bld->bld') - class DotProductCausalAttention(base.Layer): - (...) - - # assert_shape can also be placed before the function returning base.Layer. - @assert_shape('...d->...') - def ReduceSum(): - return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False)) - - Args: - specification: A text specification for the input/output tensors. - - Returns: - The decorator changing the class or function. 
- """ - caller = inspect.getframeinfo(inspect.stack()[1][0]) - message = f'Defined at {caller.filename}:{caller.lineno}' - - def wrap_cls(cls): - forward = getattr(cls, 'forward') - init = getattr(cls, '__init__') - - before_spec, after_spec = specification.split('->') - - @functools.wraps(init) - def init_wrapper(self, *args, **kwargs): - before_assert = AssertShape(before_spec, - message=message + ' function input') - after_assert = AssertShape(after_spec, - message=message + ' function output') - after_assert._create_link(before_assert) # pylint: disable=protected-access - out = init(self, *args, **kwargs) - self._before_assert_fun = before_assert # pylint: disable=protected-access - self._after_assert_fun = after_assert # pylint: disable=protected-access - return out - - @functools.wraps(forward) - def forward_wrapper(self, x, *args, **kwargs): - x = self._before_assert_fun.forward(x) # pylint: disable=protected-access - y = forward(self, x, *args, **kwargs) - y = self._after_assert_fun.forward(y) # pylint: disable=protected-access - return y - - setattr(cls, 'forward', forward_wrapper) - setattr(cls, '__init__', init_wrapper) - return cls - - # TODO(jaszczur): replace this with forward/init override. - def wrap_fun(fun): - @functools.wraps(fun) - def fun_wrapper(*args, **kwargs): - layer = fun(*args, **kwargs) - return AssertFunction(specification, layer, message) - return fun_wrapper - - def wrap_fun_or_cls(fun_or_cls): - return (wrap_cls(fun_or_cls) if inspect.isclass(fun_or_cls) else - wrap_fun(fun_or_cls)) - - return wrap_fun_or_cls + """Decorator for checking the input and output shapes of Layer. + + Decorator can be applied on trax.base.Layer class, or a function returning + a trax.base.Layer class. It uses notation similar to einsum (Einstein + summation convention), achieving concise and simple representation of tensors. 
+ For example 'ij,jh->ih' is a valid representation of a function taking two + 2D matrices as input, and returning a single output, also a 2D matrix. + + It improves readability and puts puts three levels of asserts on the function: + first level is the number of input tensors and output tensors; second level is + the rank of each tensor; third level is the size of each dimension of each + tensor. The decorator inserts those asserts right before and right after + 'forward' call. + + First level, assert on number of inputs and outputs. In the representation + input tensors are separated from output tensors by an arrow '->'. For layers + taking multiple input tensors or returning multiple output tensors, those + tensors will be separated by a comma ','. + For example, specification 'bsd,df->bsf' asserts that there will be two + input tensors, with shapes represented by 'bsd' and 'df' respectively; and + a single output tensor with shape represented by 'bsf'. + + Second level, asserts on possible rank of each tensor. Most commonly, + each letter represents a single dimension. For example,the tensor with shapes + represented by 'bsd' has rank three; with 'df' it has rank two. The special + case is an ellipsis ('...'), which expand to arbitrary number of dimensions, + including zero. For example, the tensor with specification '...sf' has at + least two dimensions. Each tensor may have in its representation one ellipsis. + + Third level, asserts the size of each dimension. If two dimensions in any + of input or output tensors have the same letter in the representation then + they must have the same size. For example, with a tensor A represented by 'df' + and a tensor B represented by 'bsf', the size of the second dimension of A + must equal the size of the third dimension of B. 
Another example: with a + tensor C represented by '...dv' and a tensor D represented by 'd', the size of + the first and only dimension of D must be equal to the size of the second to + last dimension of tensor C. + + If two distinct tensors have an ellipsis in their representation then all of + dimensions covered by those ellipses must match. For example, with a tensor E + represented by '...d' and tensor F represented by '...x' then E and F must + have the same rank, and the sizes of all but the last dimensions must match. + + Examples: + # In Dense layer there is a single input and single output; the last dimension + # may change in size, while the sizes of all previous dimensions, marked by + # an ellipsis, will stay the same. + @assert_shape('...a->...b') + class Dense(base.Layer): + (...) + + # DotProductCausalAttention takes three tensors as input: Queries, Keys, and + # Values, and outputs a single tensor. Sizes of the first two dimensions in + # all those tensors must match, while the last dimension must match only + # between Queries and Keys, and separately between Values and output tensor. + @assert_shape('blk,blk,bld->bld') + class DotProductCausalAttention(base.Layer): + (...) + + # assert_shape can also be placed before the function returning base.Layer. + @assert_shape('...d->...') + def ReduceSum(): + return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False)) + Args: + specification: A text specification for the input/output tensors. -def AssertFunction(specification, layer, message=None): # pylint: disable=invalid-name - """AssertFunction asserts shapes on the input/output tensors of a layer. - - It passes all inputs to the layer, and returns all outputs of the layer - unchanged. - - Args: - specification: A specification. See assert_shape decorator for a full - documentation. - layer: A base.Layer to wrap around. - message: An optional message to print if an assert fails. 
By default it will - print the filename and the line number where AssertFunction was called. - - Returns: - The given layer wrapped in asserts on its inputs and outputs. - """ - if message is None: + Returns: + The decorator changing the class or function. + """ caller = inspect.getframeinfo(inspect.stack()[1][0]) - message = f'Defined at {caller.filename}:{caller.lineno}' - before_spec, after_spec = specification.split('->') - before_assert = AssertShape(before_spec, message=message + ' function input') - after_assert = AssertShape(after_spec, message=message + ' function output') - after_assert._create_link(before_assert) # pylint: disable=protected-access - return combinators.Serial( - before_assert, layer, after_assert) + message = f"Defined at {caller.filename}:{caller.lineno}" + + def wrap_cls(cls): + forward = getattr(cls, "forward") + init = getattr(cls, "__init__") + + before_spec, after_spec = specification.split("->") + + @functools.wraps(init) + def init_wrapper(self, *args, **kwargs): + before_assert = AssertShape( + before_spec, message=message + " function input" + ) + after_assert = AssertShape(after_spec, message=message + " function output") + after_assert._create_link(before_assert) # pylint: disable=protected-access + out = init(self, *args, **kwargs) + self._before_assert_fun = before_assert # pylint: disable=protected-access + self._after_assert_fun = after_assert # pylint: disable=protected-access + return out + + @functools.wraps(forward) + def forward_wrapper(self, x, *args, **kwargs): + x = self._before_assert_fun.forward(x) # pylint: disable=protected-access + y = forward(self, x, *args, **kwargs) + y = self._after_assert_fun.forward(y) # pylint: disable=protected-access + return y + + setattr(cls, "forward", forward_wrapper) + setattr(cls, "__init__", init_wrapper) + return cls + + # TODO(jaszczur): replace this with forward/init override. 
+ def wrap_fun(fun): + @functools.wraps(fun) + def fun_wrapper(*args, **kwargs): + layer = fun(*args, **kwargs) + return AssertFunction(specification, layer, message) + + return fun_wrapper + + def wrap_fun_or_cls(fun_or_cls): + return ( + wrap_cls(fun_or_cls) + if inspect.isclass(fun_or_cls) + else wrap_fun(fun_or_cls) + ) + + return wrap_fun_or_cls -class AssertShape(base.Layer): - """Layer which put asserts on shapes of tensors, and returns them unchanged. - - It borrows the notation from assert_shape decorator, except it doesn't have - the arrow '->' special character, as the input tensors are the same as output. - """ +def AssertFunction(specification, layer, message=None): # pylint: disable=invalid-name + """AssertFunction asserts shapes on the input/output tensors of a layer. - def __init__(self, spec, message=None, visible_layer=False): - """Creates AssertShape layer. + It passes all inputs to the layer, and returns all outputs of the layer + unchanged. Args: - spec: Specification for input tensors. See assert_shape decorator for the - full documentation. - message: An optional message to include when assert fails. By default it - includes the filename and line number where this function was called. - visible_layer: If true, print this layer inside the model (default: False) + specification: A specification. See assert_shape decorator for a full + documentation. + layer: A base.Layer to wrap around. + message: An optional message to print if an assert fails. By default it will + print the filename and the line number where AssertFunction was called. + + Returns: + The given layer wrapped in asserts on its inputs and outputs. """ - name = 'AssertShape' if visible_layer else '' - super().__init__(name=name) - spec = spec.replace('...', '.') - for letter in spec: - assert letter in string.ascii_letters + string.digits + '.' 
+ ',' - self._specs = spec.split(',') - self._n_in = self._n_out = len(self._specs) + if message is None: + caller = inspect.getframeinfo(inspect.stack()[1][0]) + message = f"Defined at {caller.filename}:{caller.lineno}" + before_spec, after_spec = specification.split("->") + before_assert = AssertShape(before_spec, message=message + " function input") + after_assert = AssertShape(after_spec, message=message + " function output") + after_assert._create_link(before_assert) # pylint: disable=protected-access + return combinators.Serial(before_assert, layer, after_assert) - self._defined_shapes = {str(i): i for i in range(10)} - self._linked = False - if message is None: - caller = inspect.getframeinfo(inspect.stack()[1][0]) - self._message = f'Defined at {caller.filename}:{caller.lineno}' - else: - self._message = message - - def forward(self, xs): - if not self._linked: - for k in list(self._defined_shapes.keys()): - if not k.isdigit(): - del self._defined_shapes[k] - - if not isinstance(xs, (list, tuple)): - xs = [xs] - - # Try-except below checks if something is wrong with shapes. It can happen - # e.g. when using trax2keras. If this is the case we cannot check if shapes - # are correct or not - try: - for x in xs: - for i in range(len(x.shape)): - if x.shape[i] != x.shape[i]: - raise TypeError() - except TypeError: - message = ('AssertShape cannot check shapes. This often happens when' - ' using trax2keras. Shape asserts are skipped.') - print(message) - logging.warning(message) - if len(xs) == 1: - return xs[0] - else: - return xs - - # helper functions - def assert_true(cond): - if not cond: - shapes = [x.shape for x in xs] - defined_shapes_dict_without_digits = { - k: v for k, v in self._defined_shapes.items() if not k.isdigit()} - raise ValueError( - f'AssertShape Error. Expected {self._specs}, got {shapes} with dict' - f' {defined_shapes_dict_without_digits}. 
{self._message}') - - def assert_equal(a, b): - assert_true(a == b) - return a - - def check_shape(shape, spec): - assert_equal(len(shape), len(spec)) - for shape_dim, letter in zip(shape, spec): - if letter in self._defined_shapes: - self._defined_shapes[letter] = assert_equal( - self._defined_shapes[letter], shape_dim) +class AssertShape(base.Layer): + """Layer which put asserts on shapes of tensors, and returns them unchanged. + + It borrows the notation from assert_shape decorator, except it doesn't have + the arrow '->' special character, as the input tensors are the same as output. + """ + + def __init__(self, spec, message=None, visible_layer=False): + """Creates AssertShape layer. + + Args: + spec: Specification for input tensors. See assert_shape decorator for the + full documentation. + message: An optional message to include when assert fails. By default it + includes the filename and line number where this function was called. + visible_layer: If true, print this layer inside the model (default: False) + """ + name = "AssertShape" if visible_layer else "" + super().__init__(name=name) + spec = spec.replace("...", ".") + for letter in spec: + assert letter in string.ascii_letters + string.digits + "." + "," + self._specs = spec.split(",") + self._n_in = self._n_out = len(self._specs) + + self._defined_shapes = {str(i): i for i in range(10)} + self._linked = False + + if message is None: + caller = inspect.getframeinfo(inspect.stack()[1][0]) + self._message = f"Defined at {caller.filename}:{caller.lineno}" else: - self._defined_shapes[letter] = shape_dim - - def check_ellipsys(shape): - if '.' not in self._defined_shapes: - self._defined_shapes['.'] = shape - else: - assert_equal(len(shape), len(self._defined_shapes['.'])) - for s1, s2 in zip(shape, self._defined_shapes['.']): - assert_equal(s1, s2) - - # actual asserts - assert_equal(len(xs), len(self._specs)) - - for x, spec in zip(xs, self._specs): - if '.' 
in spec: - assert_true(len(x.shape) >= (len(spec) - 1)) - - before, after = spec.split('.') - check_shape(x.shape[:len(before)], before) - if after: - check_shape(x.shape[-len(after):], after) - check_ellipsys(x.shape[len(before):-len(after)]) + self._message = message + + def forward(self, xs): + if not self._linked: + for k in list(self._defined_shapes.keys()): + if not k.isdigit(): + del self._defined_shapes[k] + + if not isinstance(xs, (list, tuple)): + xs = (xs,) + + # Try-except below checks if something is wrong with shapes. It can happen + # e.g. when using trax2keras. If this is the case we cannot check if shapes + # are correct or not + try: + for x in xs: + for i in range(len(x.shape)): + if x.shape[i] != x.shape[i]: + raise TypeError() + except TypeError: + message = ( + "AssertShape cannot check shapes. This often happens when" + " using trax2keras. Shape asserts are skipped." + ) + print(message) + logging.warning(message) + if len(xs) == 1: + return xs[0] + else: + return xs + + # helper functions + def assert_true(cond): + if not cond: + shapes = [x.shape for x in xs] + defined_shapes_dict_without_digits = { + k: v for k, v in self._defined_shapes.items() if not k.isdigit() + } + raise ValueError( + f"AssertShape Error. Expected {self._specs}, got {shapes} with dict" + f" {defined_shapes_dict_without_digits}. {self._message}" + ) + + def assert_equal(a, b): + assert_true(a == b) + return a + + def check_shape(shape, spec): + assert_equal(len(shape), len(spec)) + for shape_dim, letter in zip(shape, spec): + if letter in self._defined_shapes: + self._defined_shapes[letter] = assert_equal( + self._defined_shapes[letter], shape_dim + ) + else: + self._defined_shapes[letter] = shape_dim + + def check_ellipsys(shape): + if "." 
not in self._defined_shapes: + self._defined_shapes["."] = shape + else: + assert_equal(len(shape), len(self._defined_shapes["."])) + for s1, s2 in zip(shape, self._defined_shapes["."]): + assert_equal(s1, s2) + + # actual asserts + assert_equal(len(xs), len(self._specs)) + + for x, spec in zip(xs, self._specs): + if "." in spec: + assert_true(len(x.shape) >= (len(spec) - 1)) + + before, after = spec.split(".") + check_shape(x.shape[: len(before)], before) + if after: + check_shape(x.shape[-len(after) :], after) + check_ellipsys(x.shape[len(before) : -len(after)]) + else: + # if len(after) == 0 then -len(after) in indices evaluates badly. + check_ellipsys(x.shape[len(before) :]) + else: + check_shape(x.shape, spec) + + if len(xs) == 1: + return xs[0] else: - # if len(after) == 0 then -len(after) in indices evaluates badly. - check_ellipsys(x.shape[len(before):]) - else: - check_shape(x.shape, spec) - - if len(xs) == 1: - return xs[0] - else: - return xs - - def _create_link(self, other): - """Internal. Used to create a shared dictionary.""" - # This works well for assert_shape and AssertFunction; but it can break - # easily if the order of calls to forward() is not known in advance. - self._linked = True - self._defined_shapes = other._defined_shapes # pylint: disable=protected-access + return xs + + def _create_link(self, other): + """Internal. Used to create a shared dictionary.""" + # This works well for assert_shape and AssertFunction; but it can break + # easily if the order of calls to forward() is not known in advance. + self._linked = True + self._defined_shapes = other._defined_shapes # pylint: disable=protected-access diff --git a/trax/layers/assert_shape_test.py b/trax/layers/assert_shape_test.py deleted file mode 100644 index 5c3995ba6..000000000 --- a/trax/layers/assert_shape_test.py +++ /dev/null @@ -1,275 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for assert shape layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class AssertFunctionTest(absltest.TestCase): - """Test AssertFunction layer.""" - - def test_simple_pass(self): - layer = tl.AssertFunction('abc->abc', tl.Dropout(rate=0.1)) - x = np.ones((2, 5, 20)) - layer(x) - - def test_simple_fail(self): - layer = tl.AssertFunction('abc->cba', tl.Dropout(rate=0.1)) - x = np.ones((2, 5, 20)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_reduce_rank_ellipsis_pass(self): - layer = tl.AssertFunction('...ab->...c', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_explicit_pass(self): - layer = tl.AssertFunction('xyzab->xyzc', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_to_one_pass(self): - layer = tl.AssertFunction('abcde->x', tl.Flatten(n_axes_to_keep=0)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_explicit_fail1(self): - layer = tl.AssertFunction('abcde->abcde', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_reduce_rank_explicit_fail2(self): - layer = tl.AssertFunction('abcde->abcd', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_two_outputs_pass(self): - layer = 
tl.AssertFunction( - '...cd->...x,...cd', - tl.Branch( - tl.Flatten(n_axes_to_keep=2), - tl.Dropout(rate=0.1), - )) - x = np.ones((1, 2, 3, 4)) - layer(x) - - def test_numeric_dimensions_pass(self): - layer = tl.AssertFunction( - '...34->1234,...34', - tl.Branch( - tl.Dropout(rate=0.1), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - layer(x) - - def test_too_many_outputs_fail(self): - layer = tl.AssertFunction( - '...cd->...x,...cd,...cd,...cd', - tl.Branch( - tl.Flatten(n_axes_to_keep=2), - tl.Dropout(rate=0.1), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_multi_output_rank_fail(self): - layer = tl.AssertFunction( - '...34->...x,...y', - tl.Branch( - tl.Flatten(n_axes_to_keep=3), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - with self.assertRaises(tl.LayerError): - layer(x) - - -class AssertShapeTest(absltest.TestCase): - """Test AssertShape layer.""" - - def test_simple_pass(self): - layer = tl.AssertShape('aba,ba') - x = [np.ones((10, 5, 10)), np.zeros((5, 10))] - y = layer(x) - self.assertEqual(y, x) - - def test_same_shapes_pass(self): - layer = tl.AssertShape('aba,ba') - x = [np.ones((5, 5, 5)), np.zeros((5, 5))] - y = layer(x) - self.assertEqual(y, x) - - def test_single_arg_pass(self): - layer = tl.AssertShape('a') - x = np.ones((5,)) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_scalar_pass(self): - layer = tl.AssertShape('') - x = np.ones(()) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_square_matrix_pass(self): - layer = tl.AssertShape('aa') - x = np.ones((3, 3)) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_vector_scalar_pass(self): - layer = tl.AssertShape('a,') - x = [np.ones((5,)), np.zeros(())] - y = layer(x) - self.assertEqual(y, x) - - def test_three_args_pass(self): - layer = tl.AssertShape('a,b,a') - x = [np.ones((5,)), np.zeros((2)), np.zeros((5))] - y = layer(x) - self.assertEqual(y, x) - - def 
test_multiple_matching_dims_pass(self): - layer = tl.AssertShape('a,b,a,ab') - x = [np.ones((5,)), np.zeros((2)), np.zeros((5)), np.zeros((5, 2))] - y = layer(x) - self.assertEqual(y, x) - - def test_numeric_dims_pass(self): - layer = tl.AssertShape('23,1,93') - x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_numeric_dims_fail(self): - layer = tl.AssertShape('24,1,93') - x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_middle_pass(self): - layer = tl.AssertShape('a...bc,abc') - x = [np.ones((1, 5, 5, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_prefix_pass(self): - layer = tl.AssertShape('...bc,abc') - x = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_matching_zero_dims_pass(self): - layer = tl.AssertShape('...bc,abc') - x = [np.ones((2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_matching_ellipsis_pass(self): - layer = tl.AssertShape('...bc,...bc') - x = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_prefix_ellipsis_matching_sufix_ellipsis_pass(self): - layer = tl.AssertShape('bb...,...bb') - x = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))] - y = layer(x) - self.assertEqual(y, x) - - def test_middle_ellipsis_fail(self): - layer = tl.AssertShape('ab...cde,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_short_middle_ellipsis_fail(self): - layer = tl.AssertShape('b...c,2') - x = [np.ones((2)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_double_ellipsis_fail(self): - layer = tl.AssertShape('b......c,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def 
test_typo_ellipsis_fail(self): - layer = tl.AssertShape('b..c,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_matching_ellipsis_fail(self): - layer = tl.AssertShape('...a,...b') - x = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_numeric_pass(self): - layer = tl.AssertShape('...22,...3') - x = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_prefix_and_sufix_ellipsis_fail(self): - layer = tl.AssertShape('...c...,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_too_few_dims_fail(self): - layer = tl.AssertShape('...abc,2') - x = [np.ones((4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipses_matching_dims_fail(self): - layer = tl.AssertShape('...2,...8') - x = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_dims_matching_fail(self): - layer = tl.AssertShape('aba,ab') - x = [np.ones((10, 5, 10)), np.ones((5, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_rank_fail(self): - layer = tl.AssertShape('aba,ab') - x = [np.ones((10, 5, 10)), np.ones((5, 10, 4))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_square_matrix_fail(self): - layer = tl.AssertShape('aa') - x = np.ones((10, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/attention.py b/trax/layers/attention.py index 231c1eb8f..3820d58de 100644 --- a/trax/layers/attention.py +++ b/trax/layers/attention.py @@ -63,138 +63,33 @@ # inputs are [batch, length, depth], [batch, 1, 1 length] -@assert_shape('bld,b11l->bld,b11l') -def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'): - """Returns a layer that maps 
`(vectors, mask)` to `(new_vectors, mask)`. - - This layer type represents one pass of multi-head self-attention, from vector - set to vector set, using masks to represent out-of-bound (e.g., padding) - positions. It: - - - makes three copies of incoming activations and maps these to multi-head - query (Q) vectors, key (K) vectors, and value (V) vectors, respectively; - - for each head, computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - for each head, computes Q-K attention strengths using a per-query softmax - of the Q-K dot products; - - for each head, for each query position, combines V vectors according - to the Q-K attention strengths; and - - concatenates and fuses resulting per-head vectors into outgoing - activations matching original input activation shapes. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. 
- """ - return cb.Serial( - cb.Select([0, 0, 0]), - AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), - ) - - -@assert_shape('bSq,blk,blv,b1xl->bSd,b1xl') -def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train', - cache_KV_in_predict=False, q_sparsity=None, - result_sparsity=None): - """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`. - - Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the - incoming activations (`AQ`, `AK`, and `AV`) to come from different sources. - This is used, for instance, in encoder-decoder attention (Q-related - activations `AQ` from the decoder, K- and V-related activations -- `AK` and - `AV` -- from the encoder). Otherwise, see the :py:class:`Attention` - description for further context/details. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode. - q_sparsity: Sparsity with which to process queries. If ``None``, - :py:class:`Dense` is used; if ``'noop'``, no processing is used. - result_sparsity: Sparsity with which to process result of the attention. - If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is - used. 
- """ - def _SparsifiableDense(layer_sparsity): - if layer_sparsity is None: - return core.Dense(d_feature) - elif layer_sparsity == 'noop': - return cb.Serial() # No-op layer. - else: - d_module = d_feature // layer_sparsity - return cb.Serial( - sparsity.FactoredDense(layer_sparsity, d_feature, d_feature), - sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode, - kernel_size=3, length_kernel_size=3) - ) - - def _CacheableDense(): - if cache_KV_in_predict and mode == 'predict': - return cb.Cache(core.Dense(d_feature)) - else: - return core.Dense(d_feature) - - def _PureAttention(): - return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode) - - return cb.Serial( - cb.Parallel(_SparsifiableDense(q_sparsity), - _CacheableDense(), - _CacheableDense()), - _PureAttention(), - _SparsifiableDense(result_sparsity), - ) - - -# 'k' is number of keys/values, while 'l' is number of queries. Typically they -# will be the same, but it is not necessary. -@assert_shape('blq,bkq,bkd,b1xk->bld,b1xk') -class PureAttention(base.Layer): - """Returns a layer that maps `(Q, K, V, mask)` to `(activations, mask)`. - - This layer type performs the inner workings of one pass of multi-head - self-attention. It: - - - subdivides incoming Q/K/V activations into multi-head versions; - - for each head, computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - for each head, computes Q-K attention strengths using a per-query softmax - of the Q-K dot products; - - for each head, for each query position, combines V vectors according - to the Q-K attention strengths; and - - concatenates and fuses resulting per-head vectors into outgoing - activations matching original input activation shapes. - """ - - def __init__(self, n_heads=1, dropout=0.0, mode='train'): - """Returns a new :py:class:`PureAttention` instance. 
+@assert_shape("bld,b11l->bld,b11l") +def Attention(d_feature, n_heads=1, dropout=0.0, mode="train"): + """Returns a layer that maps `(vectors, mask)` to `(new_vectors, mask)`. + + This layer type represents one pass of multi-head self-attention, from vector + set to vector set, using masks to represent out-of-bound (e.g., padding) + positions. It: + + - makes three copies of incoming activations and maps these to multi-head + query (Q) vectors, key (K) vectors, and value (V) vectors, respectively; + - for each head, computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - for each head, computes Q-K attention strengths using a per-query softmax + of the Q-K dot products; + - for each head, for each query position, combines V vectors according + to the Q-K attention strengths; and + - concatenates and fuses resulting per-head vectors into outgoing + activations matching original input activation shapes. Args: - n_heads: Number of attention heads. + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. dropout: Probababilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors @@ -203,305 +98,384 @@ def __init__(self, n_heads=1, dropout=0.0, mode='train'): created in ``'train'`` mode. mode: One of ``'train'``, ``'eval'``, or ``'predict'``. """ - super().__init__(n_in=4, n_out=2) - self._n_heads = n_heads - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed activations and unmodified mask. 
+ return cb.Serial( + cb.Select([0, 0, 0]), + AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), + ) + + +@assert_shape("bSq,blk,blv,b1xl->bSd,b1xl") +def AttentionQKV( + d_feature, + n_heads=1, + dropout=0.0, + mode="train", + cache_KV_in_predict=False, + q_sparsity=None, + result_sparsity=None, +): + """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`. + + Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the + incoming activations (`AQ`, `AK`, and `AV`) to come from different sources. + This is used, for instance, in encoder-decoder attention (Q-related + activations `AQ` from the decoder, K- and V-related activations -- `AK` and + `AV` -- from the encoder). Otherwise, see the :py:class:`Attention` + description for further context/details. Args: - inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value - activations have not yet been subdivided into heads. + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. + dropout: Probababilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode. + q_sparsity: Sparsity with which to process queries. If ``None``, + :py:class:`Dense` is used; if ``'noop'``, no processing is used. + result_sparsity: Sparsity with which to process result of the attention. + If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is + used. 
""" - q, k, v, mask = inputs - d_feature = q.shape[-1] - n_heads = self._n_heads - if d_feature % n_heads != 0: - raise ValueError( - f'Dimensionality of feature embedding ({d_feature}) is not a ' - f'multiple of the requested number of attention heads ({n_heads}).') - - per_head_results, dots = _per_head_attention( - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), - mask, - dropout=self._dropout, - mode=self._mode, - rng=self.rng) - if self._mode == 'viz': - self.state = dots - merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( - per_head_results) - return (merged_results, mask) + def _SparsifiableDense(layer_sparsity): + if layer_sparsity is None: + return core.Dense(d_feature) + elif layer_sparsity == "noop": + return cb.Serial() # No-op layer. + else: + d_module = d_feature // layer_sparsity + return cb.Serial( + sparsity.FactoredDense(layer_sparsity, d_feature, d_feature), + sparsity.LocallyConvDense( + layer_sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=3, + ), + ) + + def _CacheableDense(): + if cache_KV_in_predict and mode == "predict": + return cb.Cache(core.Dense(d_feature)) + else: + return core.Dense(d_feature) + + def _PureAttention(): + return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode) + + return cb.Serial( + cb.Parallel( + _SparsifiableDense(q_sparsity), _CacheableDense(), _CacheableDense() + ), + _PureAttention(), + _SparsifiableDense(result_sparsity), + ) -def _per_head_attention(queries, keys, values, mask, dropout, mode, rng): - """Computes new per-head activations via scaled dot-product attention. - - This function is the core of the attention mechanism. 
Given per-head - ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it: - - - computes the scaled dot product of each Q-K pair; - - applies ``mask`` to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - computes Q-K attention strengths using a per-query softmax of the Q-K dot - products; and - - for each query position, combines V vectors according to the Q-K - attention strengths. - - Args: - queries: Per-head activations representing attention queries. - keys: Per-head activations representing attention keys. - values: Per-head activations to be combined by computed attention strengths. - mask: Mask that distinguishes positions with real content vs. padding. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only in ``'train'`` - mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Tuple of (activations, attn_strengths), where activations are new per-head - activation vectors and attn_strengths is a matrix of per-head attention - strengths. 
- """ - if dropout >= 1.0: - raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.') - - d_feature = queries.shape[-1] - - dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature) - if mask is not None: - dots = jnp.where(mask, - dots, - jnp.full_like(dots, -1e9)) - attn_strengths = ( - jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))) - if dropout is not None and dropout > 0.0 and mode == 'train': - keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape) - attn_strengths = jnp.where(keep, - attn_strengths / (1.0 - dropout), - jnp.zeros_like(attn_strengths)) - activations = jnp.matmul(attn_strengths, values).astype(jnp.float32) - attn_strengths = attn_strengths.astype(jnp.float32) - return activations, attn_strengths +# 'k' is number of keys/values, while 'l' is number of queries. Typically they +# will be the same, but it is not necessary. +@assert_shape("blq,bkq,bkd,b1xk->bld,b1xk") +class PureAttention(base.Layer): + """Returns a layer that maps `(Q, K, V, mask)` to `(activations, mask)`. + + This layer type performs the inner workings of one pass of multi-head + self-attention. It: + + - subdivides incoming Q/K/V activations into multi-head versions; + - for each head, computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - for each head, computes Q-K attention strengths using a per-query softmax + of the Q-K dot products; + - for each head, for each query position, combines V vectors according + to the Q-K attention strengths; and + - concatenates and fuses resulting per-head vectors into outgoing + activations matching original input activation shapes. + """ + def __init__(self, n_heads=1, dropout=0.0, mode="train"): + """Returns a new :py:class:`PureAttention` instance. + + Args: + n_heads: Number of attention heads. 
+ dropout: Probababilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + super().__init__(n_in=4, n_out=2) + self._n_heads = n_heads + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed activations and unmodified mask. + + Args: + inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value + activations have not yet been subdivided into heads. + """ + q, k, v, mask = inputs + + d_feature = q.shape[-1] + n_heads = self._n_heads + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a " + f"multiple of the requested number of attention heads ({n_heads})." + ) + + per_head_results, dots = _per_head_attention( + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), + mask, + dropout=self._dropout, + mode=self._mode, + rng=self.rng, + ) + if self._mode == "viz": + self.state = dots + merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( + per_head_results + ) + return (merged_results, mask) -class DotProductAttention(base.Layer): - """Returns a layer that computes per-head attention (via scaled dot-product). - This layer computes the core of the attention mechanism. Given per-head - queries (Q), keys (K), values (V), and mask, it: +def _per_head_attention(queries, keys, values, mask, dropout, mode, rng): + """Computes new per-head activations via scaled dot-product attention. 
- - computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [if created in ``'train'`` mode] applies dropout to Q-K dot products; - - computes Q-K attention strengths using a per-query softmax of the Q-K dot - products; and - - for each query position, combines V vectors according to the Q-K - attention strengths. - """ + This function is the core of the attention mechanism. Given per-head + ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it: - def __init__(self, dropout=0.0, mode='train'): - """Creates a :py:class:`DotProductAttention` instance in a specific mode. + - computes the scaled dot product of each Q-K pair; + - applies ``mask`` to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - computes Q-K attention strengths using a per-query softmax of the Q-K dot + products; and + - for each query position, combines V vectors according to the Q-K + attention strengths. Args: + queries: Per-head activations representing attention queries. + keys: Per-head activations representing attention keys. + values: Per-head activations to be combined by computed attention strengths. + mask: Mask that distinguishes positions with real content vs. padding. dropout: Probababilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, ``'predict'`` or ``'viz'``. - """ - super().__init__(n_in=4, n_out=1) - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed per-head activations and unchanged mask. 
+ cause some node activations to be ignored. Applies only in ``'train'`` + mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + rng: Single-use random number generator (JAX PRNG key). - Args: - inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value - activations have been subdivided into heads. + Returns: + Tuple of (activations, attn_strengths), where activations are new per-head + activation vectors and attn_strengths is a matrix of per-head attention + strengths. """ - q, k, v, mask = inputs - activations, attn_strengths = _per_head_attention( - q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng) - if self._mode == 'viz': - self.state = attn_strengths - return activations + if dropout >= 1.0: + raise ValueError(f"Dropout rate ({dropout}) must be lower than 1.") + d_feature = queries.shape[-1] -# (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) -@assert_shape('bld->...lh') -def SplitIntoHeads(n_heads, merged_batch_and_head=True): - """Returns a layer that reshapes an array for multi-head computation.""" - def f(x): - batch_size, seq_len, d_feature = x.shape - if d_feature % n_heads != 0: - raise ValueError( - f'Feature embedding dimensionality ({d_feature}) is not a multiple' - f' of the requested number of attention heads ({n_heads}).') + dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature) + if mask is not None: + dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) + attn_strengths = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) + if dropout is not None and dropout > 0.0 and mode == "train": + keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape) + attn_strengths = jnp.where( + keep, attn_strengths / (1.0 - dropout), jnp.zeros_like(attn_strengths) + ) + activations = jnp.matmul(attn_strengths, values).astype(jnp.float32) + attn_strengths = attn_strengths.astype(jnp.float32) + return activations, attn_strengths - d_head = d_feature // 
n_heads - # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) - x = x.reshape((batch_size, seq_len, n_heads, d_head)) - x = x.transpose((0, 2, 1, 3)) - if merged_batch_and_head: - x = x.reshape((batch_size * n_heads, seq_len, d_head)) - return x - return Fn('SplitIntoHeads', f) +class DotProductAttention(base.Layer): + """Returns a layer that computes per-head attention (via scaled dot-product). + + This layer computes the core of the attention mechanism. Given per-head + queries (Q), keys (K), values (V), and mask, it: + + - computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [if created in ``'train'`` mode] applies dropout to Q-K dot products; + - computes Q-K attention strengths using a per-query softmax of the Q-K dot + products; and + - for each query position, combines V vectors according to the Q-K + attention strengths. + """ + def __init__(self, dropout=0.0, mode="train"): + """Creates a :py:class:`DotProductAttention` instance in a specific mode. + + Args: + dropout: Probababilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, ``'predict'`` or ``'viz'``. + """ + super().__init__(n_in=4, n_out=1) + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed per-head activations and unchanged mask. + + Args: + inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value + activations have been subdivided into heads. 
+ """ + q, k, v, mask = inputs + activations, attn_strengths = _per_head_attention( + q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng + ) + if self._mode == "viz": + self.state = attn_strengths + return activations -# (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) -@assert_shape('...lh->bld') -def MergeHeads(n_heads, merged_batch_and_head=True): - """Returns a layer that rejoins heads, after multi-head computation.""" - def f(x): - if merged_batch_and_head: - dim_0, seq_len, d_head = x.shape - if dim_0 % n_heads != 0: - raise ValueError( - f"Array's leading dimension ({dim_0}) is not a multiple of the" - f" number of attention heads ({n_heads}).") - batch_size = dim_0 // n_heads - x = x.reshape((batch_size, n_heads, seq_len, d_head)) - else: - batch_size, _, seq_len, d_head = x.shape - - # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) - x = x.transpose((0, 2, 1, 3)) - x = x.reshape((batch_size, seq_len, n_heads * d_head)) - return x - return Fn('MergeHeads', f) - - -@assert_shape('bld->bld') -def ConfigurableAttention(q_layer, k_layer, v_layer, final_layer, # pylint: disable=invalid-name - qkv_attention_layer, n_heads=1): - """Returns a configured multi-head self-attention layer. - - A :py:class:`ConfigurableAttention` layer acts similarly to - :py:class:`Attention` layers, but with configurable components. 
It - - - makes three copies of incoming activations and uses ``q_layer``, - ``k_layer``, and ``v_layer`` to map activations to multi-head query (Q) - vectors, key (K) vectors, and value (V) vectors, respectively; - - uses ``qkv_attention_layer`` to compute per-head attention, similar to - :py:class:`DotProductAttention` or :py:class:`DotProductCausalAttention`; - - concatenates and fuses resulting per-head vectors into activations - matching original input activation shapes; and - - applies a final layer, ``final_layer``, mapping activations to - activations (with shape matching the original input activations). - - Args: - q_layer: Layer that maps input activations to per-head query activations. - k_layer: Layer that maps input activations to per-head key activations. - v_layer: Layer that maps input activations to per-head value activations. - final_layer: After main multi-head computation and rejoining of heads, - layer that maps activations to activations (with shape matching the - original input activations). - qkv_attention_layer: Layer the does the core multi-head self-attention - computation. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - """ - return cb.Serial( - cb.Branch( - [q_layer, SplitIntoHeads(n_heads)], - [k_layer, SplitIntoHeads(n_heads)], - [v_layer, SplitIntoHeads(n_heads)], - ), - qkv_attention_layer, - MergeHeads(n_heads), - final_layer - ) - - -@assert_shape('bld->bld') -def CausalAttention(d_feature, - n_heads=1, - dropout=0.0, - max_inference_length=2048, - use_dconv=False, - mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like :py:class:`Attention`, this layer type represents one pass of multi-head - self-attention, but with causal masking rather than padding-based masking. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. 
- n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - max_inference_length: Maximum sequence length allowed in non-training - modes. - use_dconv: if True, use depthwise convolutions on top of dense layers - for Q, K and V. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - """ - if d_feature % n_heads != 0: - raise ValueError( - f'Dimensionality of feature embedding ({d_feature}) is not a multiple ' - f'of the requested number of attention heads ({n_heads}).') - - def QKVLayer(): - """Function returning the Q, K and V layer.""" - if use_dconv: - return cb.Serial(core.Dense(d_feature), convolution.CausalDepthwiseConv()) - else: - return core.Dense(d_feature) +# (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) +@assert_shape("bld->...lh") +def SplitIntoHeads(n_heads, merged_batch_and_head=True): + """Returns a layer that reshapes an array for multi-head computation.""" - return ConfigurableAttention( - QKVLayer(), - QKVLayer(), - QKVLayer(), - core.Dense(d_feature), - n_heads=n_heads, - qkv_attention_layer=DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) + def f(x): + batch_size, seq_len, d_feature = x.shape + if d_feature % n_heads != 0: + raise ValueError( + f"Feature embedding dimensionality ({d_feature}) is not a multiple" + f" of the requested number of attention heads ({n_heads})." 
+ ) + d_head = d_feature // n_heads -@assert_shape('bld,bld,bld->bld') -class DotProductCausalAttention(base.Layer): - """Layer that computes attention strengths by masking out the "future". + # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) + x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head)) + x = x.transpose((0, 2, 1, 3)) + if merged_batch_and_head: + x = jnp.reshape(x, (batch_size * n_heads, seq_len, d_head)) + return x - Causal attention uses masking to prevent a given sequence position from - attending to positions greater than / following it. This is used, for - example, when training autoregressive sequence models, or when decoding a - sequence symbol by symbol. + return Fn("SplitIntoHeads", f) - This layer performs the core per-head attention calculation. The layer - assumes that any splitting into attention heads precedes it, and that any - merging of attention heads will follow it. - """ - def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'): - """Creates a :py:class:`DotProductCausalAttention` instance. +# (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) +@assert_shape("...lh->bld") +def MergeHeads(n_heads, merged_batch_and_head=True): + """Returns a layer that rejoins heads, after multi-head computation.""" + + def f(x): + if merged_batch_and_head: + dim_0, seq_len, d_head = x.shape + if dim_0 % n_heads != 0: + raise ValueError( + f"Array's leading dimension ({dim_0}) is not a multiple of the" + f" number of attention heads ({n_heads})." 
+ ) + + batch_size = dim_0 // n_heads + x = x.reshape((batch_size, n_heads, seq_len, d_head)) + else: + batch_size, _, seq_len, d_head = x.shape + + # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) + x = x.transpose((0, 2, 1, 3)) + x = x.reshape((batch_size, seq_len, n_heads * d_head)) + return x + + return Fn("MergeHeads", f) + + +@assert_shape("bld->bld") +def ConfigurableAttention( + q_layer, + k_layer, + v_layer, + final_layer, # pylint: disable=invalid-name + qkv_attention_layer, + n_heads=1, +): + """Returns a configured multi-head self-attention layer. + + A :py:class:`ConfigurableAttention` layer acts similarly to + :py:class:`Attention` layers, but with configurable components. It + + - makes three copies of incoming activations and uses ``q_layer``, + ``k_layer``, and ``v_layer`` to map activations to multi-head query (Q) + vectors, key (K) vectors, and value (V) vectors, respectively; + - uses ``qkv_attention_layer`` to compute per-head attention, similar to + :py:class:`DotProductAttention` or :py:class:`DotProductCausalAttention`; + - concatenates and fuses resulting per-head vectors into activations + matching original input activation shapes; and + - applies a final layer, ``final_layer``, mapping activations to + activations (with shape matching the original input activations). + + Args: + q_layer: Layer that maps input activations to per-head query activations. + k_layer: Layer that maps input activations to per-head key activations. + v_layer: Layer that maps input activations to per-head value activations. + final_layer: After main multi-head computation and rejoining of heads, + layer that maps activations to activations (with shape matching the + original input activations). + qkv_attention_layer: Layer the does the core multi-head self-attention + computation. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. 
+ """ + return cb.Serial( + cb.Branch( + [q_layer, SplitIntoHeads(n_heads)], + [k_layer, SplitIntoHeads(n_heads)], + [v_layer, SplitIntoHeads(n_heads)], + ), + qkv_attention_layer, + MergeHeads(n_heads), + final_layer, + ) + + +@assert_shape("bld->bld") +def CausalAttention( + d_feature, + n_heads=1, + dropout=0.0, + max_inference_length=2048, + use_dconv=False, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like :py:class:`Attention`, this layer type represents one pass of multi-head + self-attention, but with causal masking rather than padding-based masking. Args: + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. dropout: Probababilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors @@ -510,345 +484,435 @@ def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'): created in ``'train'`` mode. max_inference_length: Maximum sequence length allowed in non-training modes. + use_dconv: if True, use depthwise convolutions on top of dense layers + for Q, K and V. mode: One of ``'train'``, ``'eval'``, or ``'predict'``. """ - super().__init__(n_in=3, n_out=1) - self._dropout = dropout - self._mode = mode - self._max_len = max_inference_length - self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none - - def monkey_patched_mask(self): - # This is necessary for Terraformer model. See comments there. - # The mask will only be used in Terraformer in predict mode. 
- return None + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a multiple " + f"of the requested number of attention heads ({n_heads})." + ) + + def QKVLayer(): + """Function returning the Q, K and V layer.""" + if use_dconv: + return cb.Serial(core.Dense(d_feature), convolution.CausalDepthwiseConv()) + else: + return core.Dense(d_feature) + + return ConfigurableAttention( + QKVLayer(), + QKVLayer(), + QKVLayer(), + core.Dense(d_feature), + n_heads=n_heads, + qkv_attention_layer=DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("bld,bld,bld->bld") +class DotProductCausalAttention(base.Layer): + """Layer that computes attention strengths by masking out the "future". - def forward(self, inputs): - """Returns attention-computed activations. + Causal attention uses masking to prevent a given sequence position from + attending to positions greater than / following it. This is used, for + example, when training autoregressive sequence models, or when decoding a + sequence symbol by symbol. - Args: - inputs: A (queries, keys, values) tuple. + This layer performs the core per-head attention calculation. The layer + assumes that any splitting into attention heads precedes it, and that any + merging of attention heads will follow it. """ - q, k, v = inputs - if self._portal_mask is not None: - mask_for_predict = self._portal_mask.get_value() - else: - mask_for_predict = None - - if self._mode == 'predict': - self.state, mask = _fast_inference_update_state( - inputs, self.state, - mask_for_predict=mask_for_predict) - if self._portal_mask is not None: - (_, k, v, _) = self.state - else: - (k, v, _) = self.state + def __init__(self, dropout=0.0, max_inference_length=2048, mode="train"): + """Creates a :py:class:`DotProductCausalAttention` instance. 
+ + Args: + dropout: Probabilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + max_inference_length: Maximum sequence length allowed in non-training + modes. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + super().__init__(n_in=3, n_out=1) + self._dropout = dropout + self._mode = mode + self._max_len = max_inference_length + self._portal_mask = ( + self.monkey_patched_mask() + ) # pylint: disable=assignment-from-none + + def monkey_patched_mask(self): + # This is necessary for Terraformer model. See comments there. + # The mask will only be used in Terraformer in predict mode. + return None + + def forward(self, inputs): + """Returns attention-computed activations. + + Args: + inputs: A (queries, keys, values) tuple. 
+ """ + q, k, v = inputs + + if self._portal_mask is not None: + mask_for_predict = self._portal_mask.get_value() + else: + mask_for_predict = None + + if self._mode == "predict": + self.state, mask = _fast_inference_update_state( + inputs, self.state, mask_for_predict=mask_for_predict + ) + if self._portal_mask is not None: + (_, k, v, _) = self.state + else: + (k, v, _) = self.state + else: + sequence_length = q.shape[-2] + mask = _causal_mask(sequence_length) + + activations, attn_strengths = _per_head_attention( + q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng + ) + if self._mode == "viz": + self.state = attn_strengths + return activations + + def init_weights_and_state(self, input_signature): + """Initializes this layer for fast inference, if in ``'predict'`` mode.""" + if self._mode == "predict": + self.state = _fast_inference_init_state( + input_signature, self._max_len, predict_mask=self._portal_mask + ) + + +def _causal_mask(length): + # Not all backends define jnp.tril. However, using np.tril is inefficient + # in that it creates a large global constant. TODO(kitaev): try to find an + # alternative that works across all backends. 
+ if fastmath.is_backend(fastmath.Backend.JAX): + return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0) else: - sequence_length = q.shape[-2] - mask = _causal_mask(sequence_length) + return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0) - activations, attn_strengths = _per_head_attention( - q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng) - if self._mode == 'viz': - self.state = attn_strengths - return activations - def init_weights_and_state(self, input_signature): - """Initializes this layer for fast inference, if in ``'predict'`` mode.""" - if self._mode == 'predict': - self.state = _fast_inference_init_state( - input_signature, self._max_len, - predict_mask=self._portal_mask) +@assert_shape("...d->...d") +def ShiftRight(n_positions=1, mode="train"): + """Returns a layer that can insert padding to shift the input sequence. + Args: + n_positions: Number of positions to shift the input sequence rightward; + initial positions freed by the shift get padded with zeros. Applies + only if layer is created in a non-``'predict'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads? + def f(x): + if mode == "predict": + return x + padded = _zero_pad(x, (n_positions, 0), 1) + return padded[:, :-n_positions] -def _causal_mask(length): - # Not all backends define jnp.tril. However, using np.tril is inefficient - # in that it creates a large global constant. TODO(kitaev): try to find an - # alternative that works across all backends. - if fastmath.is_backend(fastmath.Backend.JAX): - return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0) - else: - return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0) - - -@assert_shape('...d->...d') -def ShiftRight(n_positions=1, mode='train'): - """Returns a layer that can insert padding to shift the input sequence. 
- - Args: - n_positions: Number of positions to shift the input sequence rightward; - initial positions freed by the shift get padded with zeros. Applies - only if layer is created in a non-``'eval'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - """ - # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads? - def f(x): - if mode == 'predict': - return x - padded = _zero_pad(x, (n_positions, 0), 1) - return padded[:, :-n_positions] - return Fn(f'ShiftRight({n_positions})', f) - - -@assert_shape('bs->b11l') + return Fn(f"ShiftRight({n_positions})", f) + + +@assert_shape("bs->b11l") def PaddingMask(pad=0): - """Returns a layer that maps integer sequences to padding masks. - - The layer expects as input a batch of integer sequences. The layer output is - an N-D array that marks for each sequence position whether the integer (e.g., - a token ID) in that position represents padding -- value ``pad`` -- versus - text/content -- all other values. The padding mask shape is - (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast - to cover any number of attention heads and axis 2 will broadcast to cover - decoder sequence positions. - - Args: - pad: Integer that represents padding rather than a token/content ID. - """ - def f(x): - if len(x.shape) != 2: - raise ValueError( - f'Input to PaddingMask must be a 2-D array with shape ' - f'(batch_size, sequence_length); instead got shape {x.shape}.') - batch_size = x.shape[0] - sequence_length = x.shape[1] - content_positions = (x != pad) - return content_positions.reshape((batch_size, 1, 1, sequence_length)) - return Fn(f'PaddingMask({pad})', f) + """Returns a layer that maps integer sequences to padding masks. + + The layer expects as input a batch of integer sequences. 
The layer output is + an N-D array that marks for each sequence position whether the integer (e.g., + a token ID) in that position represents padding -- value ``pad`` -- versus + text/content -- all other values. The padding mask shape is + (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast + to cover any number of attention heads and axis 2 will broadcast to cover + decoder sequence positions. + + Args: + pad: Integer that represents padding rather than a token/content ID. + """ + + def f(x): + if len(x.shape) != 2: + raise ValueError( + f"Input to PaddingMask must be a 2-D array with shape " + f"(batch_size, sequence_length); instead got shape {x.shape}." + ) + batch_size = x.shape[0] + sequence_length = x.shape[1] + content_positions = x != pad + return content_positions.reshape((batch_size, 1, 1, sequence_length)) + + return Fn(f"PaddingMask({pad})", f) def EncoderDecoderMask(): - """Returns a layer that creates a mask for encoder-decoder cross attention. - - The layer expects two inputs: - - - decoder_input: batch of integer (e.g., token ID) sequences - - mask: padding mask from the encoder - - The layer output is a mask that marks for each sequence position (for both - encoder and decoder) whether that position can be attended to or not. The - encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, - encoder_sequence_length), such that axis 1 will automatically broadcast to - cover any number of attention heads. 
- """ - def f(decoder_input, mask): - if len(decoder_input.shape) != 3: - raise ValueError( - f'Decoder input to EncoderDecoderMask must be a 3-D array with ' - f'shape (batch_size, decoder_sequence_length, d_model); instead got ' - f'shape {decoder_input.shape}.') - batch_size = mask.shape[0] - encoder_sequence_length = mask.shape[-1] - decoder_sequence_length = decoder_input.shape[1] - mask = mask.reshape((batch_size, 1, 1, encoder_sequence_length)) - return mask + jnp.zeros((1, 1, decoder_sequence_length, 1)) - return Fn('EncoderDecoderMask', f) - - -@assert_shape('...d->...d') -class PositionalEncoding(base.Layer): - """Implements bare positional encoding. + """Returns a layer that creates a mask for encoder-decoder cross attention. - Positional encoding includes a kind of dropout, if the layer is created in - ``'train'`` mode with a nonzero ``dropout`` value. For such a layer, on each - forward pass a subset of sequence positions selected at random will *not* - receive positional marking. - """ + The layer expects two inputs: - def __init__(self, max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2,), - use_bfloat16=False, start_from_zero_prob=1.0, - max_offset_to_add=0, d_feature=None, mode='train'): - """Creates a :py:class:`PositionalEncoding` instance in a given mode. + - decoder_input: batch of integer (e.g., token ID) sequences + - mask: padding mask from the encoder - Args: - max_len: Maximum input sequence length. - dropout: Probability of *not* adding positional encoding to a sequence - position. Applies only if layer is created in ``'train'`` mode. - dropout_broadcast_dims: Axes along which dropout mask values are - broadcast rather than individually set at random. - use_bfloat16: If ``True``, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). 
- max_offset_to_add: maximum offset to add to the positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - d_feature: int or None; have this dimension for embeddings + shared FF if - not None. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + The layer output is a mask that marks for each sequence position (for both + encoder and decoder) whether that position can be attended to or not. The + encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, + encoder_sequence_length), such that axis 1 will automatically broadcast to + cover any number of attention heads. """ - super().__init__() - self._max_len = max_len - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if mode == 'train': - self._dropout = dropout - else: - self._dropout = 0.0 - self._dropout_broadcast_dims = dropout_broadcast_dims - self._use_bfloat16 = use_bfloat16 - self._start_from_zero_prob = start_from_zero_prob - self._max_offset_to_add = max_offset_to_add - self._mode = mode - self._d_feature = d_feature - - def forward(self, inputs): - """Returns the input activations, with added positional information.""" - weights = self.weights - if self._d_feature is not None: - weights, ff = weights - weights = jnp.dot(weights[:inputs.shape[1], :], ff) - if len(weights.shape) < 3: # old checkpoints have 1 in first dim already - weights = weights[None, :, :] # [1, self._max_len, d_feature] - if self._mode != 'predict': - x = inputs - symbol_size = jnp.shape(x)[1] - if self._mode != 'train' or self._start_from_zero_prob >= 1.0: - px = weights[:, :symbol_size, :] - else: - rng1, rng2 = fastmath.random.split(self.rng, 2) - start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add) - start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1) - start = jnp.where(start_from_zero < self._start_from_zero_prob, - jnp.zeros((), dtype=jnp.int32), start) - px = 
fastmath.dynamic_slice_in_dim(weights, start, symbol_size, - axis=1) - if self._dropout == 0: - return x + px - else: - noise_shape = list(px.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(self.rng, keep_prob, - tuple(noise_shape)) - multiplier = keep.astype(x.dtype) / keep_prob - return x + px * multiplier - else: - if self._dropout != 0: - raise ValueError(f'In predict mode, but dropout rate ' - f'({self._dropout}) is not zero.') - - # State in this class is only used for fast inference. In that case, - # the model is called with consecutive elements position-by-position. - # This positional encoding layer stores the index of the current - # position and increments it on each call. - emb = fastmath.dynamic_slice_in_dim( - weights, self.state, inputs.shape[1], axis=1) - self.state += inputs.shape[1] - return inputs + emb - - def init_weights_and_state(self, input_signature): - """Randomly initializes the positional encoding vectors. - Args: - input_signature: :py:class:`ShapeDtype` instance characterizing the input - this layer should compute on. + def f(decoder_input, mask): + if len(decoder_input.shape) != 3: + raise ValueError( + f"Decoder input to EncoderDecoderMask must be a 3-D array with " + f"shape (batch_size, decoder_sequence_length, d_model); instead got " + f"shape {decoder_input.shape}." + ) + batch_size = mask.shape[0] + encoder_sequence_length = mask.shape[-1] + decoder_sequence_length = decoder_input.shape[1] + mask = mask.reshape((batch_size, 1, 1, encoder_sequence_length)) + return mask + jnp.zeros((1, 1, decoder_sequence_length, 1)) + + return Fn("EncoderDecoderMask", f) + + +@assert_shape("...d->...d") +class PositionalEncoding(base.Layer): + """Implements bare positional encoding. + + Positional encoding includes a kind of dropout, if the layer is created in + ``'train'`` mode with a nonzero ``dropout`` value. 
For such a layer, on each + forward pass a subset of sequence positions selected at random will *not* + receive positional marking. """ - d_feature = input_signature.shape[-1] - if self._d_feature is not None: - d_feature = self._d_feature - pe = np.zeros((self._max_len, d_feature), dtype=np.float32) - position = np.arange(0, self._max_len)[:, np.newaxis] - div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) - pe[:, 0::2] = np.sin(position * div_term) - pe[:, 1::2] = np.cos(position * div_term) # [self._max_len, d_feature] - if self._use_bfloat16: - pe = pe.astype(jnp.bfloat16) - w = jnp.array(pe) # Trainable parameters, initialized above. - if self._d_feature is not None: - ff = init.GlorotUniformInitializer()( - (d_feature, input_signature.shape[-1]), self.rng) - self.weights = w, ff - else: - self.weights = w - if self._mode == 'predict': - self.state = jnp.zeros((), dtype=jnp.int32) + + def __init__( + self, + max_len=2048, + dropout=0.0, + dropout_broadcast_dims=(-2,), + use_bfloat16=False, + start_from_zero_prob=1.0, + max_offset_to_add=0, + d_feature=None, + mode="train", + ): + """Creates a :py:class:`PositionalEncoding` instance in a given mode. + + Args: + max_len: Maximum input sequence length. + dropout: Probability of *not* adding positional encoding to a sequence + position. Applies only if layer is created in ``'train'`` mode. + dropout_broadcast_dims: Axes along which dropout mask values are + broadcast rather than individually set at random. + use_bfloat16: If ``True``, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + max_offset_to_add: maximum offset to add to the positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. 
+ d_feature: int or None; have this dimension for embeddings + shared FF if + not None. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + super().__init__() + self._max_len = max_len + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._use_bfloat16 = use_bfloat16 + self._start_from_zero_prob = start_from_zero_prob + self._max_offset_to_add = max_offset_to_add + self._mode = mode + self._d_feature = d_feature + + def forward(self, inputs): + """Returns the input activations, with added positional information.""" + weights = self.weights + if self._d_feature is not None: + weights, ff = weights + weights = jnp.dot(weights[: inputs.shape[1], :], ff) + if len(weights.shape) < 3: # old checkpoints have 1 in first dim already + weights = weights[None, :, :] # [1, self._max_len, d_feature] + if self._mode != "predict": + x = inputs + symbol_size = jnp.shape(x)[1] + if self._mode != "train" or self._start_from_zero_prob >= 1.0: + px = weights[:, :symbol_size, :] + else: + rng1, rng2 = fastmath.random.split(self.rng, 2) + start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add) + start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1) + start = jnp.where( + start_from_zero < self._start_from_zero_prob, + jnp.zeros((), dtype=jnp.int32), + start, + ) + px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size, axis=1) + if self._dropout == 0: + return x + px + else: + noise_shape = list(px.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli( + self.rng, keep_prob, tuple(noise_shape) + ) + multiplier = keep.astype(x.dtype) / keep_prob + return x + px * multiplier + else: + if self._dropout != 0: + raise ValueError( + f"In predict mode, but dropout rate " + f"({self._dropout}) 
is not zero." + ) + + # State in this class is only used for fast inference. In that case, + # the model is called with consecutive elements position-by-position. + # This positional encoding layer stores the index of the current + # position and increments it on each call. + emb = fastmath.dynamic_slice_in_dim( + weights, self.state, inputs.shape[1], axis=1 + ) + self.state += inputs.shape[1] + return inputs + emb + + def init_weights_and_state(self, input_signature): + """Randomly initializes the positional encoding vectors. + + Args: + input_signature: :py:class:`ShapeDtype` instance characterizing the input + this layer should compute on. + """ + d_feature = input_signature.shape[-1] + if self._d_feature is not None: + d_feature = self._d_feature + pe = np.zeros((self._max_len, d_feature), dtype=np.float32) + position = np.arange(0, self._max_len)[:, np.newaxis] + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + pe[:, 0::2] = np.sin(position * div_term) + pe[:, 1::2] = np.cos(position * div_term) # [self._max_len, d_feature] + if self._use_bfloat16: + pe = pe.astype(jnp.bfloat16) + w = jnp.array(pe) # Trainable parameters, initialized above. + if self._d_feature is not None: + ff = init.GlorotUniformInitializer()( + (d_feature, input_signature.shape[-1]), self.rng + ) + self.weights = w, ff + else: + self.weights = w + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) def _zero_pad(x, pad, axis): - """Helper for jnp.pad with 0s for single-axis case.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = pad # Padding on axis. 
- return jnp.pad(x, pad_widths, mode='constant') - - -def _fast_inference_init_state(input_signature, buffer_length, - predict_mask=None): - """Returns an initial state for causal attention layer fast inference.""" - def zeros_for(batch_size, shape_dtype): - shape, dtype = shape_dtype.as_tuple() - d_feature = shape[-1] - return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype) - - batch_size = input_signature[0].shape[0] - k = zeros_for(batch_size, input_signature[1]) - v = zeros_for(batch_size, input_signature[2]) - if predict_mask is not None: - mask_for_predict = jnp.zeros((buffer_length,)) != 0 - return (mask_for_predict, k, v, jnp.array(0)) - else: - return (k, v, jnp.array(0)) + """Helper for jnp.pad with 0s for single-axis case.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = pad # Padding on axis. + return jnp.pad(x, pad_widths, mode="constant") + + +def _fast_inference_init_state(input_signature, buffer_length, predict_mask=None): + """Returns an initial state for causal attention layer fast inference.""" + + def zeros_for(batch_size, shape_dtype): + shape, dtype = shape_dtype.as_tuple() + d_feature = shape[-1] + return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype) + + batch_size = input_signature[0].shape[0] + k = zeros_for(batch_size, input_signature[1]) + v = zeros_for(batch_size, input_signature[2]) + if predict_mask is not None: + mask_for_predict = jnp.zeros((buffer_length,)) != 0 + return (mask_for_predict, k, v, jnp.array(0)) + else: + return (k, v, jnp.array(0)) def _fast_inference_update_state(inputs, state, mask_for_predict=None): - """Updates state of a causal attention layer for fast inference. - - The layer state stores arrays with cached values of keys and values, - as well as an index. To make shapes static, keys and values in the state are - long, and the index indicates where the new keys and values from inputs need - to be appended. 
- - During update, we append new_keys and new_values to keys and values at - position given by index. And we increment index by length of new keys. - We also create a mask to be 1 at appropriate positions (causal mask). - - Args: - inputs: a triple (new_queries, new_keys, new_values) - state: layer state with (keys, values, index) - mask_for_predict: mask used for predict mode. This is used only in - Terraformer. - - Returns: - Updated state and mask to be used. - """ - # Fast inference: run step-by-step, storing the sequence - # of keys and values calculated so far in state. - (_, new_k, new_v) = inputs - if mask_for_predict is not None: - (state_mask_for_predict, ks, vs, idx) = state - else: - (ks, vs, idx) = state - length = new_k.shape[1] - # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path - # with index_update when length == 1 is worth it. - # Keys and values are of shape [batch_size, length, d_kv]. - ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1) - vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1) - k_length = ks.shape[1] - - # Mask is of shape [1, q_length, k_length]. - # Mask should be true for every pair of (query_token, key_token) such that - # index of query_token is equal or larger to index of key_token. 
- mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) - <= jnp.reshape(jnp.arange(length) + idx, (1, length, 1))) - if mask_for_predict is None: - return (ks, vs, idx + length), mask - else: - state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( - state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0, - axis=0) - - state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( - state_mask_for_predict != 0, jnp.ones((1,)) != 0, - jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0) - - state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( - state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0) - placeholder = jnp.reshape(state_mask_for_predict != 0, - (1, 1, mask.shape[2],)) - mask = mask * placeholder - - return (state_mask_for_predict, ks, vs, idx + length), mask + """Updates state of a causal attention layer for fast inference. + + The layer state stores arrays with cached values of keys and values, + as well as an index. To make shapes static, keys and values in the state are + long, and the index indicates where the new keys and values from inputs need + to be appended. + + During update, we append new_keys and new_values to keys and values at + position given by index. And we increment index by length of new keys. + We also create a mask to be 1 at appropriate positions (causal mask). + + Args: + inputs: a triple (new_queries, new_keys, new_values) + state: layer state with (keys, values, index) + mask_for_predict: mask used for predict mode. This is used only in + Terraformer. + + Returns: + Updated state and mask to be used. + """ + # Fast inference: run step-by-step, storing the sequence + # of keys and values calculated so far in state. 
+ (_, new_k, new_v) = inputs + if mask_for_predict is not None: + (state_mask_for_predict, ks, vs, idx) = state + else: + (ks, vs, idx) = state + length = new_k.shape[1] + # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path + # with index_update when length == 1 is worth it. + # Keys and values are of shape [batch_size, length, d_kv]. + ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1) + vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1) + k_length = ks.shape[1] + + # Mask is of shape [1, q_length, k_length]. + # Mask should be true for every pair of (query_token, key_token) such that + # index of query_token is equal or larger to index of key_token. + mask = jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <= jnp.reshape( + jnp.arange(length) + idx, (1, length, 1) + ) + if mask_for_predict is None: + return (ks, vs, idx + length), mask + else: + state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( + state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0, axis=0 + ) + + state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( + state_mask_for_predict != 0, + jnp.ones((1,)) != 0, + jnp.sum(mask_for_predict, dtype=jnp.int32), + axis=0, + ) + + state_mask_for_predict = fastmath.dynamic_update_slice_in_dim( + state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0 + ) + placeholder = jnp.reshape( + state_mask_for_predict != 0, + ( + 1, + 1, + mask.shape[2], + ), + ) + mask = mask * placeholder + + return (state_mask_for_predict, ks, vs, idx + length), mask diff --git a/trax/layers/attention_test.py b/trax/layers/attention_test.py deleted file mode 100644 index 165866d62..000000000 --- a/trax/layers/attention_test.py +++ /dev/null @@ -1,190 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.attention.""" - -import functools -from absl.testing import absltest -import numpy as np - -from trax import shapes -import trax.layers as tl -from trax.layers import test_utils - - -class AttentionTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.CausalAttention(d_feature=4, n_heads=2) - x = [np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]), - np.array([[[[1, 0, 1]]]])] - _, _ = layer.init(shapes.signature(x)) - - y, mask = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - self.assertEqual(mask.shape, (1, 1, 1, 3)) - - def test_shift_right(self): - # Test shifts right on axis=1 - layer = tl.ShiftRight() - x = np.array([[[9, 9, 9], - [8, 8, 8], - [7, 7, 7], - [6, 6, 6]], - [[99, 98, 97], - [96, 95, 94], - [93, 92, 91], - [90, 89, 88]]]) - y = layer(x) - self.assertEqual(x.shape, y.shape) - self.assertEqual(tl.to_list(y), [[[0, 0, 0], - [9, 9, 9], - [8, 8, 8], - [7, 7, 7]], - [[0, 0, 0], - [99, 98, 97], - [96, 95, 94], - [93, 92, 91]]]) - - def test_shift_right_float(self): - layer = tl.ShiftRight() - x = np.array([[[9, 9, 9], - [8, 8, 8], - [7, 7, 7], - [6, 6, 6]], - [[99, 98, 97], - [96, 95, 94], - [93, 92, 91], - [90, 89, 88]]]).astype(np.float32) - x /= 2.0 - self.assertEqual(x.dtype, np.float32) - - y = layer(x) - self.assertEqual(y.dtype, np.float32) - self.assertEqual(tl.to_list(y), [[[0.0, 0.0, 0.0], - [4.5, 4.5, 4.5], - [4.0, 4.0, 4.0], - [3.5, 3.5, 3.5]], - [[0.0, 0.0, 0.0], - [49.5, 49.0, 48.5], - [48.0, 47.5, 47.0], - [46.5, 46.0, 45.5]]]) - - def test_padding_mask(self): - layer = 
tl.PaddingMask() - x = np.array([ - [1., 2., 3., 4., 0.], - [1., 2., 3., 0., 0.], - [1., 2., 0., 0., 0.], - ]) - y = layer(x) - self.assertEqual(x.shape, (3, 5)) - self.assertEqual(y.shape, (3, 1, 1, 5)) - np.testing.assert_equal(y, [[[[True, True, True, True, False]]], - [[[True, True, True, False, False]]], - [[[True, True, False, False, False]]]]) - - -class CausalAttentionTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.CausalAttention(d_feature=4, n_heads=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_deterministic_eval(self): - d_model = 32 - seq_len = 3 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - tl.CausalAttention, - d_feature=d_model, - n_heads=4, - ) - - test_utils.test_eval_is_deterministic(inp, model_fn) - - def test_predict_equals_eval(self): - d_model = 32 - seq_len = 10 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - tl.CausalAttention, - d_feature=d_model, - n_heads=4, - ) - - test_utils.test_eval_equals_predict(inp, model_fn) - - -class PositionalEncodingTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.PositionalEncoding(max_len=8) - x = np.array([[[2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0]]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 2, 4)) - - def test_predict(self): - layer = tl.PositionalEncoding(max_len=8) - x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) - self.assertEqual(x.shape, (1, 4, 2)) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 4, 2)) - layer = tl.PositionalEncoding(max_len=8, mode='predict') - layer.init(shapes.signature(x[:, :1, :])) - y0 = layer(x[:, :1, :]) # just the first token - self.assertEqual(y0.shape, (1, 1, 2)) - 
self.assertTrue(np.array_equal(y0, y[:, :1, :])) - y1 = layer(x[:, 1:3, :]) # now the next 2 tokens - self.assertEqual(y1.shape, (1, 2, 2)) - self.assertTrue(np.array_equal(y1, y[:, 1:3, :])) - y2 = layer(x[:, 3:4, :]) # final one token - self.assertEqual(y2.shape, (1, 1, 2)) - self.assertTrue(np.array_equal(y2, y[:, 3:4, :])) - - def test_predict_equals_eval(self): - x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) - self.assertEqual(x.shape, (1, 4, 2)) - - layer_eval = tl.PositionalEncoding(max_len=8, d_feature=4, mode='eval') - layer_eval.init(shapes.signature(x)) - - output_eval = layer_eval(x) - - layer_predict = tl.PositionalEncoding(max_len=8, d_feature=4, - mode='predict') - layer_predict.init(shapes.signature(x)) - layer_predict.weights = layer_eval.weights - - output_predict = layer_predict(x) - self.assertTrue(np.array_equal(output_eval, output_predict)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/base.py b/trax/layers/base.py index 5c80f1521..f9901eefd 100644 --- a/trax/layers/base.py +++ b/trax/layers/base.py @@ -35,1021 +35,1089 @@ # TODO(lukaszkaiser): should we use special objects for these for clarity? -EMPTY_WEIGHTS = () # Used for layers that have no trainable weights. -EMPTY_STATE = () # Used for layers that have no non-trainable state. -GET_WEIGHTS_FROM_CACHE = {'__marker_for_cached_weights_': ()} -GET_STATE_FROM_CACHE = {'__marker_for_cached_state_': ()} +EMPTY_WEIGHTS = () # Used for layers that have no trainable weights. +EMPTY_STATE = () # Used for layers that have no non-trainable state. +GET_WEIGHTS_FROM_CACHE = {"__marker_for_cached_weights_": ()} +GET_STATE_FROM_CACHE = {"__marker_for_cached_state_": ()} N_WEIGHTS_SHARDS = 1 # TODO(lukaszkaiser): make weight-sharding non-global class Layer: - """Base class for composable layers in a deep learning network. + """Base class for composable layers in a deep learning network. - Layers are the basic building blocks for deep learning models. 
A layer - computes a function from zero or more inputs to zero or more outputs, - optionally using trainable weights (common) and non-parameter state (not - common). + Layers are the basic building blocks for deep learning models. A layer + computes a function from zero or more inputs to zero or more outputs, + optionally using trainable weights (common) and non-parameter state (not + common). - Layer subclasses typically override at most two methods of the base `Layer` - class: + Layer subclasses typically override at most two methods of the base `Layer` + class: - `forward(inputs)`: - Computes the layer's output as part of a forward pass through the model. + `forward(inputs)`: + Computes the layer's output as part of a forward pass through the model. - `init_weights_and_state(self, input_signature)`: - Initializes the layer's weights and state to handle input with the given - signature (number, shapes and dtypes of input arguments). + `init_weights_and_state(self, input_signature)`: + Initializes the layer's weights and state to handle input with the given + signature (number, shapes and dtypes of input arguments). - A small number of layer types are combinators -- they organize the computation - of their sublayers, e.g., applying their sublayers in series or in parallel. + A small number of layer types are combinators -- they organize the computation + of their sublayers, e.g., applying their sublayers in series or in parallel. 
- All layers have the following properties, with default values implemented - in the base `Layer` class: + All layers have the following properties, with default values implemented + in the base `Layer` class: - - `n_in`: int (default 1) - - `n_out`: int (default 1) - - `weights`: tuple (default empty -- the layer has no weights) - - `state`: tuple (default empty -- the layer has no non-parameter state) - - `sublayers`: tuple (default empty -- the layer has no sublayers) + - `n_in`: int (default 1) + - `n_out`: int (default 1) + - `weights`: tuple (default empty -- the layer has no weights) + - `state`: tuple (default empty -- the layer has no non-parameter state) + - `sublayers`: tuple (default empty -- the layer has no sublayers) - The inputs to a layer are tensors, packaged according to how many there are: + The inputs to a layer are tensors, packaged according to how many there are: - - `n_in = 0`: an empty tuple - - `n_in = 1`: one tensor (NOT wrapped in a tuple) - - `n_in > 1`: a tuple of tensors + - `n_in = 0`: an empty tuple + - `n_in = 1`: one tensor (NOT wrapped in a tuple) + - `n_in > 1`: a tuple of tensors - (The special treatment of the single-input case is meant to simplify the - work of layer writers; this design choice may be revisited in the future.) + (The special treatment of the single-input case is meant to simplify the + work of layer writers; this design choice may be revisited in the future.) - The outputs from a layer are also tensors, packaged the same as layer inputs: + The outputs from a layer are also tensors, packaged the same as layer inputs: - - `n_out = 0`: an empty tuple - - `n_out = 1`: the tensor (NOT wrapped in a tuple) - - `n_out > 1`: a tuple of tensors + - `n_out = 0`: an empty tuple + - `n_out = 1`: the tensor (NOT wrapped in a tuple) + - `n_out > 1`: a tuple of tensors - The Trax runtime maintains a data stack with which layer calls are composed. 
- For more complex data network architectures, possibly involving multiple data - flows, one can view each layer as a function from stack state to stack state, - where the function's inputs are a slice from the stack, and the function's - outputs are spliced back into the stack. - """ - - def __init__(self, n_in=1, n_out=1, name=None, sublayers_to_print=None): - """Creates a partially initialized, unconnected layer instance. - - Args: - n_in: Number of inputs expected by this layer. - n_out: Number of outputs promised by this layer. - name: Class-like name for this layer; for use when printing this layer. - sublayers_to_print: Sublayers to display when printing out this layer; - if None (the default), display all sublayers. + The Trax runtime maintains a data stack with which layer calls are composed. + For more complex data network architectures, possibly involving multiple data + flows, one can view each layer as a function from stack state to stack state, + where the function's inputs are a slice from the stack, and the function's + outputs are spliced back into the stack. """ - self._n_in = n_in - self._n_out = n_out - self._name = self.__class__.__name__ if name is None else name - self._sublayers_to_print = sublayers_to_print - self._sublayers = () # Default is no sublayers. - - # The actual rng value/shape depends on the backend, which may not yet be - # initialized at the point this method is run. Hence, at first initialize - # only a seed random integer, in a backend-neutral way. - self._rng = None - self._rng_seed_int = random.randint(0, 2**31 - 1) - - # The private fields _weights and _state store the private part of - # layer weights and state. When a layer has no sublayers, these are - # the same as layer.weights and layer.state. For layers with sublayers - # (i.e., combinators), these just mark which weights are cached -- see - # the getter and setter for weights and state for details. 
- # There is no need to use these fields in most user-implemented classes. - self._weights = EMPTY_WEIGHTS # By default no trainable weights. - self._state = EMPTY_STATE # By default no non-trainable state. - - # Record layer creation site for use in LayerError messages. - # The frame can mutate, so copy relevant values out of it. - frame = _find_frame(inspect.currentframe()) - self._caller = {'filename': copy.copy(frame.f_code.co_filename), - 'lineno': int(frame.f_lineno)} - del frame # Just in case. - - self._init_cached = False - self._jit_cache = {} - - def __repr__(self): - """Renders this layer as a medium-detailed string, to help in debugging. - - Subclasses should aim for high-signal/low-noise when overriding this - method. - Returns: - A high signal-to-noise string representing this layer. - """ - def indent_string(x): - return ' ' + x.replace('\n', '\n ') + def __init__(self, n_in=1, n_out=1, name=None, sublayers_to_print=None): + """Creates a partially initialized, unconnected layer instance. + + Args: + n_in: Number of inputs expected by this layer. + n_out: Number of outputs promised by this layer. + name: Class-like name for this layer; for use when printing this layer. + sublayers_to_print: Sublayers to display when printing out this layer; + if None (the default), display all sublayers. + """ + self._n_in = n_in + self._n_out = n_out + self._name = self.__class__.__name__ if name is None else name + self._sublayers_to_print = sublayers_to_print + self._sublayers = () # Default is no sublayers. + + # The actual rng value/shape depends on the backend, which may not yet be + # initialized at the point this method is run. Hence, at first initialize + # only a seed random integer, in a backend-neutral way. + self._rng = None + self._rng_seed_int = random.randint(0, 2**31 - 1) + + # The private fields _weights and _state store the private part of + # layer weights and state. 
When a layer has no sublayers, these are + # the same as layer.weights and layer.state. For layers with sublayers + # (i.e., combinators), these just mark which weights are cached -- see + # the getter and setter for weights and state for details. + # There is no need to use these fields in most user-implemented classes. + self._weights = EMPTY_WEIGHTS # By default no trainable weights. + self._state = EMPTY_STATE # By default no non-trainable state. + + # Record layer creation site for use in LayerError messages. + # The frame can mutate, so copy relevant values out of it. + frame = _find_frame(inspect.currentframe()) + self._caller = { + "filename": copy.copy(frame.f_code.co_filename), + "lineno": int(frame.f_lineno), + } + del frame # Just in case. + + self._init_cached = False + self._jit_cache = {} + + def __repr__(self): + """Renders this layer as a medium-detailed string, to help in debugging. + + Subclasses should aim for high-signal/low-noise when overriding this + method. + + Returns: + A high signal-to-noise string representing this layer. + """ + + def indent_string(x): + return " " + x.replace("\n", "\n ") + + name_str = self._name + n_in, n_out = self.n_in, self.n_out + if n_in != 1: + name_str += f"_in{n_in}" + if n_out != 1: + name_str += f"_out{n_out}" + + if self._sublayers_to_print is not None: + substructure = self._sublayers_to_print + else: + substructure = self.sublayers + if substructure: + substructure_strs = [str(x) for x in substructure if str(x)] + substructure_str = "\n".join(indent_string(s) for s in substructure_strs) + return f"{name_str}[\n{substructure_str}\n]" + else: + return name_str + + def __call__(self, x, weights=None, state=None, rng=None): + """Makes layers callable; for use in tests or interactive settings. + + This convenience method helps library users play with, test, or otherwise + probe the behavior of layers outside of a full training environment. 
It + presents the layer as callable function from inputs to outputs, with the + option of manually specifying weights and non-parameter state per individual + call. For convenience, weights and non-parameter state are cached per layer + instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`, + and acquiring non-empty values either by initialization or from values + explicitly provided via the weights and state keyword arguments, in which + case the old weights will be preserved, and the state will be updated. + + Args: + x: Zero or more input tensors, packaged as described in the `Layer` class + docstring. + weights: Weights or `None`; if `None`, use self's cached weights value. + state: State or `None`; if `None`, use self's cached state value. + rng: Single-use random number generator (JAX PRNG key), or `None`; + if `None`, use a default computed from an integer 0 seed. + + Returns: + Zero or more output tensors, packaged as described in the `Layer` class + docstring. + """ + weights = self.weights if weights is None else weights + rng = self.rng if rng is None else rng + if state is not None: + self.state = state # Needed if the model wasn't fully initialized. + state = self.state + outputs, new_state = self.pure_fn(x, weights, state, rng) + self.state = new_state + return outputs - name_str = self._name - n_in, n_out = self.n_in, self.n_out - if n_in != 1: name_str += f'_in{n_in}' - if n_out != 1: name_str += f'_out{n_out}' + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. 
- if self._sublayers_to_print is not None: - substructure = self._sublayers_to_print - else: - substructure = self.sublayers - if substructure: - substructure_strs = [str(x) for x in substructure if str(x)] - substructure_str = '\n'.join(indent_string(s) for s in substructure_strs) - return f'{name_str}[\n{substructure_str}\n]' - else: - return name_str - - def __call__(self, x, weights=None, state=None, rng=None): - """Makes layers callable; for use in tests or interactive settings. + A layer subclass overrides this method to define how the layer computes + outputs from inputs. If the layer depends on weights, state, or randomness + as part of the computation, the needed information can be accessed as + properties of the layer object: `self.weights`, `self.state`, and + `self.rng`. (See numerous examples in `trax.layers.core`.) - This convenience method helps library users play with, test, or otherwise - probe the behavior of layers outside of a full training environment. It - presents the layer as callable function from inputs to outputs, with the - option of manually specifying weights and non-parameter state per individual - call. For convenience, weights and non-parameter state are cached per layer - instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`, - and acquiring non-empty values either by initialization or from values - explicitly provided via the weights and state keyword arguments, in which - case the old weights will be preserved, and the state will be updated. + Args: + inputs: Zero or more input tensors, packaged as described in the `Layer` + class docstring. - Args: - x: Zero or more input tensors, packaged as described in the `Layer` class + Returns: + Zero or more output tensors, packaged as described in the `Layer` class docstring. - weights: Weights or `None`; if `None`, use self's cached weights value. - state: State or `None`; if `None`, use self's cached state value. 
- rng: Single-use random number generator (JAX PRNG key), or `None`; - if `None`, use a default computed from an integer 0 seed. + """ + raise NotImplementedError + + def init_weights_and_state(self, input_signature): + """Initializes weights and state, to handle input with the given signature. + + A layer subclass must override this method if the layer uses weights or + state. To initialize weights, set `self.weights` to desired (typically + random) values. To initialize state (uncommon), set `self.state` to desired + starting values. + + Args: + input_signature: A `ShapeDtype` instance (if this layer takes one input) + or a list/tuple of `ShapeDtype` instances. + """ + del input_signature + + @property + def has_backward(self): + """Returns `True` if this layer provides its own custom backward pass code. + + A layer subclass that provides custom backward pass code (for custom + gradients) must override this method to return `True`. + """ + return False + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + """Custom backward pass to propagate gradients in a custom way. + + Args: + inputs: Input tensors; can be a (possibly nested) tuple. + output: The result of running this layer on inputs. + grad: Gradient signal computed based on subsequent layers; its structure + and shape must match output. + weights: This layer's weights. + state: This layer's state prior to the current forward pass. + new_state: This layer's state after the current forward pass. + rng: Single-use random number generator (JAX PRNG key). + + Returns: + The custom gradient signal for the input. Note that we need to return + a gradient for each argument of forward, so it will usually be a tuple + of signals: the gradient for inputs and weights. + """ + raise NotImplementedError + + # End of public subclassing interface. + # Begin public callable interface. 
+ + def init(self, input_signature, rng=None, use_cache=False): + """Initializes weights/state of this layer and its sublayers recursively. + + Initialization creates layer weights and state, for layers that use them. + It derives the necessary array shapes and data types from the layer's input + signature, which is itself just shape and data type information. + + For layers without weights or state, this method safely does nothing. + + This method is designed to create weights/state only once for each layer + instance, even if the same layer instance occurs in multiple places in the + network. This enables weight sharing to be implemented as layer sharing. + + Args: + input_signature: `ShapeDtype` instance (if this layer takes one input) + or list/tuple of `ShapeDtype` instances. + rng: Single-use random number generator (JAX PRNG key), or `None`; + if `None`, use a default computed from an integer 0 seed. + use_cache: If `True`, and if this layer instance has already been + initialized elsewhere in the network, then return special marker + values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`. + Else return this layer's newly initialized weights and state. + + Returns: + A `(weights, state)` tuple. + """ + try: + if self._init_cached and use_cache: + return (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE) + + if rng is not None: + self.rng = rng + self.init_weights_and_state(input_signature) + + if use_cache: + self._init_cached = True + else: + self._clear_init_cache() + + return (self.weights, self.state) + + except Exception: + # Skipping 3 lines as it's always the uninteresting internal call. + name, trace = self._name, _short_traceback(skip=3) + raise LayerError( + name, "init", self._caller, input_signature, trace + ) from None + + def init_from_file(self, file_name, weights_only=False, input_signature=None): + """Initializes this layer and its sublayers from a pickled checkpoint. 
+ + In the common case (`weights_only=False`), the file must be a gziped pickled + dictionary containing items with keys `'flat_weights', `'flat_state'` and + `'input_signature'`, which are used to initialize this layer. + If `input_signature` is specified, it's used instead of the one in the file. + If `weights_only` is `True`, the dictionary does not need to have the + `'flat_state'` item and the state it not restored either. + + Args: + file_name: Name/path of the pickled weights/state file. + weights_only: If `True`, initialize only the layer's weights. Else + initialize both weights and state. + input_signature: Input signature to be used instead of the one from file. + + Returns: + A `(weights, state)` tuple. + """ + with tf.io.gfile.GFile(file_name, "rb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: + dictionary = pickle.load(gzipf) + # In the current checkpoint format, we store weights in a separate + # non-pickled file with the same name but added ".npy". + if isinstance(dictionary["flat_weights"], int): + if file_name.endswith(".pkl.gz"): + weights_path = file_name[:-6] + "weights.npy.gz" + else: + weights_path = file_name + ".npy" + if not tf.io.gfile.exists(weights_path): # old format compatibility + weights_path = file_name + ".npy" + dictionary["flat_weights"] = np_from_file( + weights_path, compresslevel=dictionary["flat_weights"] + ) + if input_signature is None: + input_signature = dictionary["input_signature"] + if weights_only and input_signature is not None: + self.init(input_signature) + weights_and_state_sig = self.weights_and_state_signature(input_signature) + weights, state = unflatten_weights_and_state( + dictionary["flat_weights"], + dictionary["flat_state"], + weights_and_state_sig, + weights_only=weights_only, + ) + if not weights_only: + self.state = state + self.weights = weights + return (self.weights, self.state) + + def save_to_file(self, file_name, weights_only=False, input_signature=None): + """Saves this layer 
and its sublayers to a pickled checkpoint. + + Args: + file_name: Name/path of the pickled weights/state file. + weights_only: If `True`, save only the layer's weights. Else + save both weights and state. + input_signature: Input signature to be used. + """ + flat_weights, flat_state = flatten_weights_and_state(self.weights, self.state) + dictionary = { + "flat_weights": flat_weights, + } + if not weights_only: + dictionary["flat_state"] = flat_state + if input_signature is not None: + dictionary["input_signature"] = input_signature + + tmp_file_path = file_name + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: + pickle.dump(dictionary, gzipf, protocol=pickle.HIGHEST_PROTOCOL) + # Moving a file is much less error-prone than pickling large files. + tf.io.gfile.rename(tmp_file_path, file_name, overwrite=True) + + def flatten_tuple(self, inputs): + flat_tuple = () + for _input in inputs: + if isinstance(_input, tuple): + flat_tuple += self.flatten_tuple(_input) + else: + flat_tuple += (_input,) + return flat_tuple + + # End of public callable methods. + # Methods and properties below are reserved for internal use. + + @property + def name(self): + """Returns the name of this layer.""" + return self._name + + @property + def n_in(self): + """Returns how many tensors this layer expects as input.""" + return self._n_in + + @property + def n_out(self): + """Returns how many tensors this layer promises as output.""" + return self._n_out + + @property + def sublayers(self): + """Returns a tuple containing this layer's sublayers; may be empty.""" + return self._sublayers + + @property + def weights(self): + """Returns this layer's weights. 
+ + Depending on the layer, the weights can be in the form of: + + - an empty tuple + - a tensor (ndarray) + - a nested structure of tuples and tensors + + If the layer has sublayers, the weights by convention will be + a tuple of length `len(sublayers)` containing the weights of sublayers. + Note that in this case self._weights only marks which ones are shared. + """ + if not self.sublayers: + return self._weights + else: + return tuple( + layer.weights if w is None else w + for (layer, w) in zip(self.sublayers, self._weights) + ) + + @weights.setter + def weights(self, weights): + """Sets the weights of this layer and its sublayers. + + Args: + weights: the weights to set; if layer has sublayers, weights should be + either a list or a tuple of the same length as `len(self.sublayers)` + and it will be used to set the weights of all sublayers. + """ + if isinstance(weights, dict) and weights == GET_WEIGHTS_FROM_CACHE: + return + if not self.sublayers: + self._weights = weights + else: + # When having sublayers, self._weights just marks which are cached, + # the actual weights are stored by sublayers. + self._weights = [] + for w in weights: + if isinstance(w, dict) and w == GET_WEIGHTS_FROM_CACHE: + self._weights.append(w) + else: + self._weights.append(None) + # Set sublayer weights. + n_layers = len(self.sublayers) + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal the " + f"number of sublayers ({n_layers}) in: {str(self)}." + ) + for sublayer, sublayer_weights in zip(self.sublayers, weights): + sublayer.weights = sublayer_weights + + @property + def state(self): + """Returns a tuple containing this layer's state; may be empty. + + If the layer has sublayers, the state by convention will be + a tuple of length `len(sublayers)` containing sublayer states. + Note that in this case self._state only marks which ones are shared. 
+ """ + if not self.sublayers: + return self._state + else: + return tuple( + layer.state if s is None else s + for (layer, s) in zip(self.sublayers, self._state) + ) + + @state.setter + def state(self, state): + """Sets the state of this layer and its sublayers. + + Args: + state: the state to set; if layer has sublayers, state should be + either a list or a tuple of the same length as `len(self.sublayers)` + and it will be used to set the state of all sublayers. + """ + if isinstance(state, dict) and state == GET_STATE_FROM_CACHE: + return + if not self._sublayers: + self._state = state + else: + # When having sublayers, self._state just marks which are cached, + # the actual weights are stored by sublayers. + self._state = [] + for s in state: + if isinstance(s, dict) and s == GET_STATE_FROM_CACHE: + self._state.append(s) + else: + self._state.append(None) + # Set sublayer states. + n_layers = len(self.sublayers) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal the " + f"number of sublayers ({n_layers}) in: {str(self)}." + ) + for sublayer, sublayer_state in zip(self.sublayers, state): + sublayer.state = sublayer_state + + def weights_and_state_signature(self, input_signature, unsafe=False): + """Return a pair containing the signatures of weights and state.""" + rng, state, weights = self.rng, self.state, self.weights + abstract_init = fastmath.abstract_eval(self.init) + sig = abstract_init(input_signature) + self.rng = rng + if not unsafe: + self.state, self.weights = state, weights + return sig + + @property + def rng(self): + """Returns this layer's current single-use random number generator. + + Code that wants to base random samples on this generator must explicitly + split off new generators from it. (See, for example, the `rng` setter code + below.) + """ + if self._rng is None: + # One-time initialization from backend-neutral seed int. 
+ self._rng = fastmath.random.get_prng(self._rng_seed_int) + return self._rng + + @rng.setter + def rng(self, rng): + """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively.""" + self._rng = rng + sublayers = self.sublayers + if sublayers: + rngs = fastmath.random.split(rng, len(sublayers)) + for sublayer, rng in zip(sublayers, rngs): + sublayer.rng = rng + + def _clear_init_cache(self): + self._init_cached = False + for sublayer in self.sublayers: + sublayer._clear_init_cache() # pylint: disable=protected-access + + def pure_fn(self, x, weights, state, rng, use_cache=False): + """Applies this layer as a pure function with no optional args. + + This method exposes the layer's computation as a pure function. This is + especially useful for JIT compilation. Do not override, use `forward` + instead. + + Args: + x: Zero or more input tensors, packaged as described in the `Layer` class + docstring. + weights: A tuple or list of trainable weights, with one element for this + layer if this layer has no sublayers, or one for each sublayer if + this layer has sublayers. If a layer (or sublayer) has no trainable + weights, the corresponding weights element is an empty tuple. + state: Layer-specific non-parameter state that can update between batches. + rng: Single-use random number generator (JAX PRNG key). + use_cache: if `True`, cache weights and state in the layer object; used + to implement layer sharing in combinators. + + Returns: + A tuple of `(tensors, state)`. The tensors match the number (`n_out`) + promised by this layer, and are packaged as described in the `Layer` + class docstring. + """ + try: + old_weights, old_state, old_rng = self.weights, self.state, self.rng + self._rng = rng + # The isinstance check is only needed when == is overloaded, as in TF. 
+ if ( + isinstance(weights, dict) + and isinstance(state, dict) + and weights == GET_WEIGHTS_FROM_CACHE + and state == GET_STATE_FROM_CACHE + ): + was_cached = True + weights = self.weights + state = self.state + else: + # In this case, we're called for the first time: cache weights. + was_cached = False + self.weights, self.state = weights, state + + # If weights are sharded across multiple devices, unshard before forward. + sharded_weights, weights_were_unsharded = weights, False + if N_WEIGHTS_SHARDS > 1 and not self.sublayers: + self.weights, weights_were_unsharded = unshard_in_pmap( + weights, N_WEIGHTS_SHARDS + ) + + if not self.has_backward: + outputs = self.forward(x) + s = self.state + else: + outputs, s = self._do_custom_gradients(x) + self.state = s + self._rng = old_rng + if weights_were_unsharded: # only store a shard of weights if sharded + self.weights = sharded_weights + + if not use_cache: + self.weights, self.state = old_weights, old_state + if was_cached: # If the layer was shared, return a state marking this. + s = GET_STATE_FROM_CACHE + return outputs, s + + except Exception: + # Skipping 3 lines as it's always the uninteresting internal call. + name, trace = self._name, _short_traceback(skip=3) + raise LayerError( + name, "pure_fn", self._caller, signature(x), trace + ) from None + + def output_signature(self, input_signature): + """Returns output signature this layer would give for `input_signature`.""" + return self._forward_abstract(input_signature)[0] # output only, not state + + def _forward_abstract(self, input_signature): + """Computes shapes and dtypes this layer would produce in a forward pass. + + Args: + input_signature: `ShapeDtype` instance (if this layer takes one input) + or list/tuple of `ShapeDtype` instances. + + Returns: + Tuple of (output, state). 
+ + The output part of the tuple is a `ShapeDtype` instance representing the + shape and type of the output (if this layer has one output) or a tuple + of `ShapeDtype` instances (if this layer has more than one output). + """ + try: + # Note: By using rng_signature in place of an rng, we avoid computing and + # permanently storing in global memory a large number of dropout masks. + # TODO(jonni): Check if using an rng still carries this cost. + dummy_rng = fastmath.random.get_prng(0) + rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype) + weights_signature = nested_map(signature, self.weights) + state_signature = nested_map(signature, self.state) + forward_infer_shapes = fastmath.abstract_eval(self.pure_fn) + return forward_infer_shapes( + input_signature, weights_signature, state_signature, rng_signature + ) + except Exception: + # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it? + # Skipping 7 lines which are all JAX abstract'ifying wrappers. + name, trace = self._name, _short_traceback(skip=7) + raise LayerError( + name, "_forward_abstract", self._caller, input_signature, trace + ) from None + + # pylint: disable=protected-access + def _do_custom_gradients(self, x): + """Calls this layer for a forward pass, but with custom gradients.""" + + def _f(state, rng, y, weights): + old_weights, old_state, old_rng = self.weights, self.state, self._rng + self.weights, self.state, self._rng = weights, state, rng + res = self.forward(y) + s = self.state + self.weights, self.state, self._rng = old_weights, old_state, old_rng + return res, s + + def _f_fwd(state, rng, y, weights): + old_weights, old_state, old_rng = self.weights, self.state, self._rng + self.weights, self.state, self._rng = weights, state, rng + res = self.forward(y) + s = self.state + self.weights, self.state, self._rng = old_weights, old_state, old_rng + return (res, s), (state, rng, y, res, weights, s) + + def _f_bwd(residual, grad): + """Custom gradient function.""" + 
state, rng, y, output, weights, new_state = residual + grad = grad[0] # Ignore dummy gradient wrt state. + out = self.backward(y, output, grad, weights, state, new_state, rng) + return (None, None, *out) + + do_forward = fastmath.custom_vjp(_f, _f_fwd, _f_bwd, nondiff_argnums=(0, 1)) + + output, state = do_forward(self.state, self._rng, x, self.weights) + return output, state + + def _settable_attrs(self): + """We only allow to set these attributes in Trax layers to prevent typos.""" + return ("weights", "state", "rng") + + def __setattr__(self, attr, value): + """Sets class attributes and protects from typos. + + In Trax layers, we only allow to set the following public attributes:: + + - weights + - state + - rng + + This function prevents from setting other public attributes to avoid typos, + for example, this is not possible and would be without this function:: + + [typo] layer.weighs = some_tensor + + If you need to set other public attributes in a derived class (which we + do not recommend as in almost all cases it suffices to use a private + attribute), override self._settable_attrs to include the attribute name. + + Args: + attr: Name of the attribute to be set. + value: Value to be assigned to the attribute. + """ + if attr[0] != "_" and attr not in self._settable_attrs(): + raise ValueError( + f"Trax layers only allow to set {self._settable_attrs()} as public " + f"attribues, not {attr}." + ) + else: + super().__setattr__(attr, value) - Returns: - Zero or more output tensors, packaged as described in the `Layer` class - docstring. - """ - weights = self.weights if weights is None else weights - rng = self.rng if rng is None else rng - if state is not None: - self.state = state # Needed if the model wasn't fully initialized. - state = self.state - outputs, new_state = self.pure_fn(x, weights, state, rng) - self.state = new_state - return outputs - - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. 
- - A layer subclass overrides this method to define how the layer computes - outputs from inputs. If the layer depends on weights, state, or randomness - as part of the computation, the needed information can be accessed as - properties of the layer object: `self.weights`, `self.state`, and - `self.rng`. (See numerous examples in `trax.layers.core`.) - Args: - inputs: Zero or more input tensors, packaged as described in the `Layer` - class docstring. +class PureLayer(Layer): + """Pure function from inputs to outputs, packaged as neural network layer. - Returns: - Zero or more output tensors, packaged as described in the `Layer` class - docstring. + The `PureLayer` class represents the simplest kinds of layers: layers with + no trainable weights and no randomness, hence pure functions from inputs to + outputs. """ - raise NotImplementedError - def init_weights_and_state(self, input_signature): - """Initializes weights and state, to handle input with the given signature. + def __init__(self, forward_fn, n_in=1, n_out=1, name="PureLayer"): + """Creates an unconnected `PureLayer` instance. - A layer subclass must override this method if the layer uses weights or - state. To initialize weights, set `self.weights` to desired (typically - random) values. To initialize state (uncommon), set `self.state` to desired - starting values. + Args: + forward_fn: Pure function from input tensors to output tensors, where + inputs and outputs are packaged as specified for `forward`. + n_in: Number of inputs expected by this layer. + n_out: Number of outputs promised by this layer. + name: Class-like name for this layer; for use only in debugging. + """ + super().__init__(n_in, n_out, name) + self._forward_fn = forward_fn - Args: - input_signature: A `ShapeDtype` instance (if this layer takes one input) - or a list/tuple of `ShapeDtype` instances. - """ - del input_signature + def forward(self, inputs): + """Overrides `Layer.forward`. 
- @property - def has_backward(self): - """Returns `True` if this layer provides its own custom backward pass code. + Args: + inputs: Zero or more input tensors, packaged as described in the `Layer` + class docstring. - A layer subclass that provides custom backward pass code (for custom - gradients) must override this method to return `True`. - """ - return False + Returns: + Zero or more output tensors, packaged as described in the `Layer` class + docstring. + """ + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) - def backward(self, inputs, output, grad, weights, state, new_state, rng): - """Custom backward pass to propagate gradients in a custom way. + # The input should be a flat single tuple without nested tuples + inputs = self.flatten_tuple(inputs) - Args: - inputs: Input tensors; can be a (possibly nested) tuple. - output: The result of running this layer on inputs. - grad: Gradient signal computed based on subsequent layers; its structure - and shape must match output. - weights: This layer's weights. - state: This layer's state prior to the current forward pass. - new_state: This layer's state after the current forward pass. - rng: Single-use random number generator (JAX PRNG key). + _validate_forward_input(inputs, self.n_in) - Returns: - The custom gradient signal for the input. Note that we need to return - a gradient for each argument of forward, so it will usually be a tuple - of signals: the gradient for inputs and weights. - """ - raise NotImplementedError + raw_output = self._forward_fn(inputs) + output = () if _is_empty(raw_output) else raw_output + return output - # End of public subclassing interface. - # Begin public callable interface. - def init(self, input_signature, rng=None, use_cache=False): - """Initializes weights/state of this layer and its sublayers recursively. +def Fn(name, f, n_out=1): # pylint: disable=invalid-name + """Returns a layer with no weights that applies the function `f`. 
- Initialization creates layer weights and state, for layers that use them. - It derives the necessary array shapes and data types from the layer's input - signature, which is itself just shape and data type information. + `f` can take and return any number of arguments, and takes only positional + arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`). + The following, for example, would create a layer that takes two inputs and + returns two outputs -- element-wise sums and maxima: - For layers without weights or state, this method safely does nothing. + `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)` - This method is designed to create weights/state only once for each layer - instance, even if the same layer instance occurs in multiple places in the - network. This enables weight sharing to be implemented as layer sharing. + The layer's number of inputs (`n_in`) is automatically set to number of + positional arguments in `f`, but you must explicitly set the number of + outputs (`n_out`) whenever it's not the default value 1. Args: - input_signature: `ShapeDtype` instance (if this layer takes one input) - or list/tuple of `ShapeDtype` instances. - rng: Single-use random number generator (JAX PRNG key), or `None`; - if `None`, use a default computed from an integer 0 seed. - use_cache: If `True`, and if this layer instance has already been - initialized elsewhere in the network, then return special marker - values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`. - Else return this layer's newly initialized weights and state. + name: Class-like name for the resulting layer; for use in debugging. + f: Pure function from input tensors to output tensors, where each input + tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`. + Output tensors must be packaged as specified in the `Layer` class + docstring. + n_out: Number of outputs promised by the layer; default value 1. 
Returns: - A `(weights, state)` tuple. + Layer executing the function `f`. """ - try: - if self._init_cached and use_cache: - return (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE) - - if rng is not None: - self.rng = rng - self.init_weights_and_state(input_signature) + argspec = inspect.getfullargspec(f) + if argspec.defaults is not None: + raise ValueError("Function has default arguments (not allowed).") + if argspec.varkw is not None: + raise ValueError("Function has keyword arguments (not allowed).") + if argspec.varargs is not None: + raise ValueError("Function has variable args (not allowed).") - if use_cache: - self._init_cached = True - else: - self._clear_init_cache() + def _forward(xs): # pylint: disable=invalid-name + if not isinstance(xs, (tuple, list)): + xs = (xs,) + return f(*xs) - return (self.weights, self.state) + n_in = len(argspec.args) + name = name or "Fn" + return PureLayer(_forward, n_in=n_in, n_out=n_out, name=name) - except Exception: - # Skipping 3 lines as it's always the uninteresting internal call. - name, trace = self._name, _short_traceback(skip=3) - raise LayerError(name, 'init', self._caller, - input_signature, trace) from None - def init_from_file(self, file_name, weights_only=False, input_signature=None): - """Initializes this layer and its sublayers from a pickled checkpoint. - - In the common case (`weights_only=False`), the file must be a gziped pickled - dictionary containing items with keys `'flat_weights', `'flat_state'` and - `'input_signature'`, which are used to initialize this layer. - If `input_signature` is specified, it's used instead of the one in the file. - If `weights_only` is `True`, the dictionary does not need to have the - `'flat_state'` item and the state it not restored either. 
+class LayerError(Exception): + """Exception raised in the layer stack.""" + + def __init__( + self, layer_name, function_name, caller, input_signature, traceback_string + ): + self._layer_name = layer_name + self._function_name = function_name + self._caller = caller # Python inspect object with init caller info. + self._traceback = traceback_string + self._input_signature = input_signature + super().__init__(self.message) + + @property + def message(self): + """Assembles current layer context into an error message.""" + prefix = "Exception passing through layer " + prefix += "%s (in %s):\n" % (self._layer_name, self._function_name) + short_path = "[...]/" + "/".join(self._caller["filename"].split("/")[-3:]) + caller = " layer created in file %s, line %d\n" % ( + short_path, + self._caller["lineno"], + ) + shapes_str = " layer input shapes: %s\n\n" % str(self._input_signature) + return prefix + caller + shapes_str + self._traceback - Args: - file_name: Name/path of the pickled weights/state file. - weights_only: If `True`, initialize only the layer's weights. Else - initialize both weights and state. - input_signature: Input signature to be used instead of the one from file. - Returns: - A `(weights, state)` tuple. - """ - with tf.io.gfile.GFile(file_name, 'rb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: - dictionary = pickle.load(gzipf) - # In the current checkpoint format, we store weights in a separate - # non-pickled file with the same name but added ".npy". 
- if isinstance(dictionary['flat_weights'], int): - if file_name.endswith('.pkl.gz'): - weights_path = file_name[:-6] + 'weights.npy.gz' - else: - weights_path = file_name + '.npy' - if not tf.io.gfile.exists(weights_path): # old format compatibility - weights_path = file_name + '.npy' - dictionary['flat_weights'] = np_from_file( - weights_path, compresslevel=dictionary['flat_weights']) - if input_signature is None: - input_signature = dictionary['input_signature'] - if weights_only and input_signature is not None: - self.init(input_signature) - weights_and_state_sig = self.weights_and_state_signature(input_signature) - weights, state = unflatten_weights_and_state( - dictionary['flat_weights'], dictionary['flat_state'], - weights_and_state_sig, weights_only=weights_only) - if not weights_only: - self.state = state - self.weights = weights - return (self.weights, self.state) +def flatten_weights_and_state(weights, state): + """Flatten weights and state into lists, excluding empty and cached ones.""" - def save_to_file(self, file_name, weights_only=False, input_signature=None): - """Saves this layer and its sublayers to a pickled checkpoint. + def _is_empty_weight(x): + return x is EMPTY_WEIGHTS or ( + isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE + ) - Args: - file_name: Name/path of the pickled weights/state file. - weights_only: If `True`, save only the layer's weights. Else - save both weights and state. - input_signature: Input signature to be used. 
- """ - flat_weights, flat_state = flatten_weights_and_state( - self.weights, self.state) - dictionary = { - 'flat_weights': flat_weights, - } - if not weights_only: - dictionary['flat_state'] = flat_state - if input_signature is not None: - dictionary['input_signature'] = input_signature - - tmp_file_path = file_name + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: - pickle.dump(dictionary, gzipf, protocol=pickle.HIGHEST_PROTOCOL) - # Moving a file is much less error-prone than pickling large files. - tf.io.gfile.rename(tmp_file_path, file_name, overwrite=True) + flat_weights = [ + w for w in fastmath.tree_flatten(weights) if not _is_empty_weight(w) + ] - # End of public callable methods. - # Methods and properties below are reserved for internal use. + def _is_empty_state(x): + return x is EMPTY_STATE or (isinstance(x, dict) and x == GET_STATE_FROM_CACHE) - @property - def name(self): - """Returns the name of this layer.""" - return self._name + flat_state = [s for s in fastmath.tree_flatten(state) if not _is_empty_state(s)] + return flat_weights, flat_state - @property - def n_in(self): - """Returns how many tensors this layer expects as input.""" - return self._n_in - @property - def n_out(self): - """Returns how many tensors this layer promises as output.""" - return self._n_out +def unflatten_weights_and_state( + flat_weights, flat_state, weights_and_state_signature, weights_only=False +): + """Unflatten weights and state given their signatures.""" + weights_tree, state_tree = weights_and_state_signature + weights_to_copy = [EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE] + weights, _ = fastmath.tree_unflatten( + flat_weights, weights_tree, copy_from_tree=weights_to_copy + ) + state = None + if not weights_only: + states_to_copy = [EMPTY_STATE, GET_STATE_FROM_CACHE] + state, _ = fastmath.tree_unflatten( + flat_state, state_tree, copy_from_tree=states_to_copy + ) + return weights, state - 
@property - def sublayers(self): - """Returns a tuple containing this layer's sublayers; may be empty.""" - return self._sublayers - @property - def weights(self): - """Returns this layer's weights. +def np_to_file(list_of_nparrays, file_path, compresslevel): + """Save numpy arrays to file_path with gzipping and failure protection.""" + # Pickle to tmp file and overwrite to prevent writing partial files. + tmp_file_path = file_path + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: + for x in list_of_nparrays: + np.save(gzipf, x, allow_pickle=False) + # Moving a file is much less error-prone than pickling large files. + tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) - Depending on the layer, the weights can be in the form of: - - an empty tuple - - a tensor (ndarray) - - a nested structure of tuples and tensors +def np_from_file(file_path, compresslevel): + """Load numpy arrays from file_path with gzipping.""" + if not tf.io.gfile.exists(file_path): + raise FileNotFoundError(file_path) + res = [] + with tf.io.gfile.GFile(file_path, "rb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: + while True: + try: + res.append(np.load(gzipf, allow_pickle=False)) + except Exception: # pylint: disable=broad-except + break + return res - If the layer has sublayers, the weights by convention will be - a tuple of length `len(sublayers)` containing the weights of sublayers. - Note that in this case self._weights only marks which ones are shared. - """ - if not self.sublayers: - return self._weights - else: - return tuple(layer.weights if w is None else w - for (layer, w) in zip(self.sublayers, self._weights)) - @weights.setter - def weights(self, weights): - """Sets the weights of this layer and its sublayers. +def to_list(outputs): + """Converts layer outputs to a nested list, for easier equality testing. 
Args: - weights: the weights to set; if layer has sublayers, weights should be - either a list or a tuple of the same length as `len(self.sublayers)` - and it will be used to set the weights of all sublayers. - """ - if isinstance(weights, dict) and weights == GET_WEIGHTS_FROM_CACHE: - return - if not self.sublayers: - self._weights = weights - else: - # When having sublayers, self._weights just marks which are cached, - # the actual weights are stored by sublayers. - self._weights = [] - for w in weights: - if isinstance(w, dict) and w == GET_WEIGHTS_FROM_CACHE: - self._weights.append(w) - else: - self._weights.append(None) - # Set sublayer weights. - n_layers = len(self.sublayers) - if len(weights) != n_layers: - raise ValueError( - f'Number of weight elements ({len(weights)}) does not equal the ' - f'number of sublayers ({n_layers}) in: {str(self)}.') - for sublayer, sublayer_weights in zip(self.sublayers, weights): - sublayer.weights = sublayer_weights - - @property - def state(self): - """Returns a tuple containing this layer's state; may be empty. - - If the layer has sublayers, the state by convention will be - a tuple of length `len(sublayers)` containing sublayer states. - Note that in this case self._state only marks which ones are shared. + outputs: A tensor or tuple/list of tensors coming from the forward + application of a layer. Each tensor is NumPy ndarray-like, which + complicates simple equality testing (e.g., via `assertEquals`): + such tensors require equality testing to use either `all` (all + elements match) or `any` (at least one element matches), which is not + directly supported in `absltest`. + + Returns: + A nested list structure containing all the output values, but now directly + testable using `assertEquals`. 
""" - if not self.sublayers: - return self._state + if isinstance(outputs, (list, tuple)): + return [y.tolist() for y in outputs] else: - return tuple(layer.state if s is None else s - for (layer, s) in zip(self.sublayers, self._state)) + return outputs.tolist() - @state.setter - def state(self, state): - """Sets the state of this layer and its sublayers. - Args: - state: the state to set; if layer has sublayers, state should be - either a list or a tuple of the same length as `len(self.sublayers)` - and it will be used to set the state of all sublayers. - """ - if isinstance(state, dict) and state == GET_STATE_FROM_CACHE: - return - if not self._sublayers: - self._state = state - else: - # When having sublayers, self._state just marks which are cached, - # the actual weights are stored by sublayers. - self._state = [] - for s in state: - if isinstance(s, dict) and s == GET_STATE_FROM_CACHE: - self._state.append(s) - else: - self._state.append(None) - # Set sublayer states. - n_layers = len(self.sublayers) - if len(state) != n_layers: +def _validate_forward_input(x, n_in): + if n_in != 1: + if not isinstance(x, (tuple, list)): + raise TypeError( + f"Expected input to be a tuple or list; instead got {type(x)}." + ) + + if len(x) != n_in: raise ValueError( - f'Number of state elements ({len(state)}) does not equal the ' - f'number of sublayers ({n_layers}) in: {str(self)}.') - for sublayer, sublayer_state in zip(self.sublayers, state): - sublayer.state = sublayer_state - - def weights_and_state_signature(self, input_signature, unsafe=False): - """Return a pair containing the signatures of weights and state.""" - rng, state, weights = self.rng, self.state, self.weights - abstract_init = fastmath.abstract_eval(self.init) - sig = abstract_init(input_signature) - self.rng = rng - if not unsafe: - self.state, self.weights = state, weights - return sig - - @property - def rng(self): - """Returns this layer's current single-use random number generator. 
- - Code that wants to base random samples on this generator must explicitly - split off new generators from it. (See, for example, the `rng` setter code - below.) - """ - if self._rng is None: - # One-time initialization from backend-neutral seed int. - self._rng = fastmath.random.get_prng(self._rng_seed_int) - return self._rng - - @rng.setter - def rng(self, rng): - """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively.""" - self._rng = rng - sublayers = self.sublayers - if sublayers: - rngs = fastmath.random.split(rng, len(sublayers)) - for sublayer, rng in zip(sublayers, rngs): - sublayer.rng = rng - - def _clear_init_cache(self): - self._init_cached = False - for sublayer in self.sublayers: - sublayer._clear_init_cache() # pylint: disable=protected-access - - def pure_fn(self, x, weights, state, rng, use_cache=False): - """Applies this layer as a pure function with no optional args. - - This method exposes the layer's computation as a pure function. This is - especially useful for JIT compilation. Do not override, use `forward` - instead. + f"Input tuple length ({len(x)}) does not equal required " + f"number of inputs ({n_in})." + ) - Args: - x: Zero or more input tensors, packaged as described in the `Layer` class - docstring. - weights: A tuple or list of trainable weights, with one element for this - layer if this layer has no sublayers, or one for each sublayer if - this layer has sublayers. If a layer (or sublayer) has no trainable - weights, the corresponding weights element is an empty tuple. - state: Layer-specific non-parameter state that can update between batches. - rng: Single-use random number generator (JAX PRNG key). - use_cache: if `True`, cache weights and state in the layer object; used - to implement layer sharing in combinators. - Returns: - A tuple of `(tensors, state)`. The tensors match the number (`n_out`) - promised by this layer, and are packaged as described in the `Layer` - class docstring. 
- """ - try: - old_weights, old_state, old_rng = self.weights, self.state, self.rng - self._rng = rng - # The isinstance check is only needed when == is overloaded, as in TF. - if (isinstance(weights, dict) and isinstance(state, dict) and - weights == GET_WEIGHTS_FROM_CACHE and state == GET_STATE_FROM_CACHE): - was_cached = True - weights = self.weights - state = self.state - else: - # In this case, we're called for the first time: cache weights. - was_cached = False - self.weights, self.state = weights, state - - # If weights are sharded across multiple devices, unshard before forward. - sharded_weights, weights_were_unsharded = weights, False - if N_WEIGHTS_SHARDS > 1 and not self.sublayers: - self.weights, weights_were_unsharded = unshard_in_pmap( - weights, N_WEIGHTS_SHARDS) - - if not self.has_backward: - outputs = self.forward(x) - s = self.state - else: - outputs, s = self._do_custom_gradients(x) - self.state = s - self._rng = old_rng - if weights_were_unsharded: # only store a shard of weights if sharded - self.weights = sharded_weights - - if not use_cache: - self.weights, self.state = old_weights, old_state - if was_cached: # If the layer was shared, return a state marking this. - s = GET_STATE_FROM_CACHE - return outputs, s - - except Exception: - # Skipping 3 lines as it's always the uninteresting internal call. - name, trace = self._name, _short_traceback(skip=3) - raise LayerError(name, 'pure_fn', - self._caller, signature(x), trace) from None - - def output_signature(self, input_signature): - """Returns output signature this layer would give for `input_signature`.""" - return self._forward_abstract(input_signature)[0] # output only, not state - - def _forward_abstract(self, input_signature): - """Computes shapes and dtypes this layer would produce in a forward pass. 
+def _is_empty(container): + if container is None: + raise ValueError('Argument "container" is None.') + return ( + isinstance(container, (list, tuple)) and len(container) == 0 + ) # pylint: disable=g-explicit-length-test - Args: - input_signature: `ShapeDtype` instance (if this layer takes one input) - or list/tuple of `ShapeDtype` instances. - Returns: - Tuple of (output, state). +def _find_frame(frame): + """Find the frame with the caller on the stack.""" - The output part of the tuple is a `ShapeDtype` instance representing the - shape and type of the output (if this layer has one output) or a tuple - of `ShapeDtype` instances (if this layer has more than one output). - """ - try: - # Note: By using rng_signature in place of an rng, we avoid computing and - # permanently storing in global memory a large number of dropout masks. - # TODO(jonni): Check if using an rng still carries this cost. - dummy_rng = fastmath.random.get_prng(0) - rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype) - weights_signature = nested_map(signature, self.weights) - state_signature = nested_map(signature, self.state) - forward_infer_shapes = fastmath.abstract_eval(self.pure_fn) - return forward_infer_shapes( - input_signature, weights_signature, state_signature, rng_signature) - except Exception: - # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it? - # Skipping 7 lines which are all JAX abstract'ifying wrappers. 
- name, trace = self._name, _short_traceback(skip=7) - raise LayerError(name, '_forward_abstract', self._caller, input_signature, - trace) from None - - # pylint: disable=protected-access - def _do_custom_gradients(self, x): - """Calls this layer for a forward pass, but with custom gradients.""" - - def _f(state, rng, y, weights): - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(y) - s = self.state - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return res, s - - def _f_fwd(state, rng, y, weights): - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(y) - s = self.state - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return (res, s), (state, rng, y, res, weights, s) - - def _f_bwd(residual, grad): - """Custom gradient function.""" - state, rng, y, output, weights, new_state = residual - grad = grad[0] # Ignore dummy gradient wrt state. - out = self.backward(y, output, grad, weights, state, new_state, rng) - return (None, None, *out) - - do_forward = fastmath.custom_vjp(_f, _f_fwd, _f_bwd, nondiff_argnums=(0, 1)) - - output, state = do_forward(self.state, self._rng, x, self.weights) - return output, state - - def _settable_attrs(self): - """We only allow to set these attributes in Trax layers to prevent typos.""" - return ('weights', 'state', 'rng') - - def __setattr__(self, attr, value): - """Sets class attributes and protects from typos. 
- - In Trax layers, we only allow to set the following public attributes:: - - - weights - - state - - rng - - This function prevents from setting other public attributes to avoid typos, - for example, this is not possible and would be without this function:: - - [typo] layer.weighs = some_tensor - - If you need to set other public attributes in a derived class (which we - do not recommend as in almost all cases it suffices to use a private - attribute), override self._settable_attrs to include the attribute name. + def _dirname_is_trax_layers_or_gin(frame): + """Skip frames coming from trax/layers or .../gin.""" + try: + dirname1 = frame.f_code.co_filename.split("/")[-3] + dirname2 = frame.f_code.co_filename.split("/")[-2] + return (dirname1 == "trax" and dirname2 == "layers") or dirname2 == "gin" + except IndexError: + return False - Args: - attr: Name of the attribute to be set. - value: Value to be assigned to the attribute. - """ - if attr[0] != '_' and attr not in self._settable_attrs(): - raise ValueError( - f'Trax layers only allow to set {self._settable_attrs()} as public ' - f'attribues, not {attr}.') - else: - super().__setattr__(attr, value) + while _dirname_is_trax_layers_or_gin(frame): + frame = frame.f_back + return frame -class PureLayer(Layer): - """Pure function from inputs to outputs, packaged as neural network layer. 
+def _shorten_file_path(line): + """Shorten file path in error lines for more readable tracebacks.""" + start = line.lower().find("file") + if start < 0: + return line + first_quote = line.find('"', start) + if first_quote < 0: + return line + second_quote = line.find('"', first_quote + 1) + if second_quote < 0: + return line + path = line[first_quote + 1 : second_quote] + new_path = "/".join(path.split("/")[-3:]) + return line[:first_quote] + "[...]/" + new_path + line[second_quote + 1 :] - The `PureLayer` class represents the simplest kinds of layers: layers with - no trainable weights and no randomness, hence pure functions from inputs to - outputs. - """ - def __init__(self, forward_fn, n_in=1, n_out=1, name='PureLayer'): - """Creates an unconnected `PureLayer` instance. +def _short_traceback(skip=3): + """Cleaned-up form of traceback.""" + counter, res = 0, [] + # Skipping 3 lines by default: the top (useless) and self-call. + # In python 3, we need to set chain to False (it doesn't exist in python 2). + lines = traceback.format_exc(chain=False).splitlines()[ + skip: + ] # pylint: disable=unexpected-keyword-arg + for l in lines: + if l.startswith("trax.layers.base.LayerError"): + l = l[len("trax.layers.base.") :] # Remove the trax.layers.base prefix. + res.append(_shorten_file_path(l)) + if counter % 2 == 1: + res.append("") + counter += 1 + # If we see a LayerError, the traceback has already been processed. + if l.startswith("LayerError"): + # Skip 4 back except last as these are internal base-layer calls. + res = res[:-4] + [res[-1]] + res += lines[counter:] + break + return "\n".join(res) - Args: - forward_fn: Pure function from input tensors to output tensors, where - inputs and outputs are packaged as specified for `forward`. - n_in: Number of inputs expected by this layer. - n_out: Number of outputs promised by this layer. - name: Class-like name for this layer; for use only in debugging. 
- """ - super().__init__(n_in, n_out, name) - self._forward_fn = forward_fn - def forward(self, inputs): - """Overrides `Layer.forward`. +def _random_values(input_signature, rng): + """Creates random floats or ints of the given shape. Args: - inputs: Zero or more input tensors, packaged as described in the `Layer` - class docstring. + input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input) + or a list/tuple of ShapeDtype instances. + rng: Single-use random number generator (JAX PRNG key). Returns: - Zero or more output tensors, packaged as described in the `Layer` class - docstring. + Random values with the shape and type specified. """ - _validate_forward_input(inputs, self.n_in) - raw_output = self._forward_fn(inputs) - output = () if _is_empty(raw_output) else raw_output - return output - - -def Fn(name, f, n_out=1): # pylint: disable=invalid-name - """Returns a layer with no weights that applies the function `f`. - - `f` can take and return any number of arguments, and takes only positional - arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`). - The following, for example, would create a layer that takes two inputs and - returns two outputs -- element-wise sums and maxima: - - `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)` - - The layer's number of inputs (`n_in`) is automatically set to number of - positional arguments in `f`, but you must explicitly set the number of - outputs (`n_out`) whenever it's not the default value 1. - - Args: - name: Class-like name for the resulting layer; for use in debugging. - f: Pure function from input tensors to output tensors, where each input - tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`. - Output tensors must be packaged as specified in the `Layer` class - docstring. - n_out: Number of outputs promised by the layer; default value 1. - - Returns: - Layer executing the function `f`. 
- """ - argspec = inspect.getfullargspec(f) - if argspec.defaults is not None: - raise ValueError('Function has default arguments (not allowed).') - if argspec.varkw is not None: - raise ValueError('Function has keyword arguments (not allowed).') - if argspec.varargs is not None: - raise ValueError('Function has variable args (not allowed).') - - def _forward(xs): # pylint: disable=invalid-name - if not isinstance(xs, (tuple, list)): - xs = (xs,) - return f(*xs) - - n_in = len(argspec.args) - name = name or 'Fn' - return PureLayer(_forward, n_in=n_in, n_out=n_out, name=name) - - -class LayerError(Exception): - """Exception raised in the layer stack.""" - - def __init__(self, layer_name, function_name, caller, - input_signature, traceback_string): - self._layer_name = layer_name - self._function_name = function_name - self._caller = caller # Python inspect object with init caller info. - self._traceback = traceback_string - self._input_signature = input_signature - super().__init__(self.message) - - @property - def message(self): - """Assembles current layer context into an error message.""" - prefix = 'Exception passing through layer ' - prefix += '%s (in %s):\n' % (self._layer_name, self._function_name) - short_path = '[...]/' + '/'.join( - self._caller['filename'].split('/')[-3:]) - caller = ' layer created in file %s, line %d\n' % (short_path, - self._caller['lineno']) - shapes_str = ' layer input shapes: %s\n\n' % str(self._input_signature) - return prefix + caller + shapes_str + self._traceback - - -def flatten_weights_and_state(weights, state): - """Flatten weights and state into lists, excluding empty and cached ones.""" - def _is_empty_weight(x): - return (x is EMPTY_WEIGHTS or - (isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE)) - flat_weights = [w for w in fastmath.tree_flatten(weights) - if not _is_empty_weight(w)] - def _is_empty_state(x): - return (x is EMPTY_STATE or - (isinstance(x, dict) and x == GET_STATE_FROM_CACHE)) - flat_state = [s for s in 
fastmath.tree_flatten(state) - if not _is_empty_state(s)] - return flat_weights, flat_state - - -def unflatten_weights_and_state( - flat_weights, flat_state, weights_and_state_signature, weights_only=False): - """Unflatten weights and state given their signatures.""" - weights_tree, state_tree = weights_and_state_signature - weights_to_copy = [EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE] - weights, _ = fastmath.tree_unflatten(flat_weights, weights_tree, - copy_from_tree=weights_to_copy) - state = None - if not weights_only: - states_to_copy = [EMPTY_STATE, GET_STATE_FROM_CACHE] - state, _ = fastmath.tree_unflatten(flat_state, state_tree, - copy_from_tree=states_to_copy) - return weights, state - + if isinstance(input_signature, ShapeDtype): + shape, dtype = input_signature.shape, input_signature.dtype + if np.issubdtype(dtype, np.integer): + return fastmath.random.bernoulli(rng, 0.5, shape).astype(np.int32) + else: + return fastmath.random.uniform(rng, shape, minval=-1.0, maxval=1.0) + elif isinstance(input_signature, (list, tuple)): + return tuple(_random_values(x, rng) for x in input_signature) + else: + raise TypeError(type(input_signature)) -def np_to_file(list_of_nparrays, file_path, compresslevel): - """Save numpy arrays to file_path with gzipping and failure protection.""" - # Pickle to tmp file and overwrite to prevent writing partial files. - tmp_file_path = file_path + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: - for x in list_of_nparrays: - np.save(gzipf, x, allow_pickle=False) - # Moving a file is much less error-prone than pickling large files. 
- tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) +def _shapes(x): + """Gets a structure of shapes for a structure of nested arrays.""" -def np_from_file(file_path, compresslevel): - """Load numpy arrays from file_path with gzipping.""" - if not tf.io.gfile.exists(file_path): - raise FileNotFoundError(file_path) - res = [] - with tf.io.gfile.GFile(file_path, 'rb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: - while True: + def shape(x): try: - res.append(np.load(gzipf, allow_pickle=False)) + return tuple([int(i) for i in x.shape]) except Exception: # pylint: disable=broad-except - break - return res + return () + return tuple(nested_map(shape, x)) -def to_list(outputs): - """Converts layer outputs to a nested list, for easier equality testing. - - Args: - outputs: A tensor or tuple/list of tensors coming from the forward - application of a layer. Each tensor is NumPy ndarray-like, which - complicates simple equality testing (e.g., via `assertEquals`): - such tensors require equality testing to use either `all` (all - elements match) or `any` (at least one element matches), which is not - directly supported in `absltest`. - - Returns: - A nested list structure containing all the output values, but now directly - testable using `assertEquals`. 
- """ - if isinstance(outputs, (list, tuple)): - return [y.tolist() for y in outputs] - else: - return outputs.tolist() +@functools.partial(fastmath.pmap, axis_name="batch") +def _axis_index(unused_x): + """Return the axis indices.""" + return jax.lax.axis_index("batch") -def _validate_forward_input(x, n_in): - if n_in != 1: - if not isinstance(x, (tuple, list)): - raise TypeError( - f'Expected input to be a tuple or list; instead got {type(x)}.') - if len(x) != n_in: - raise ValueError(f'Input tuple length ({len(x)}) does not equal required ' - f'number of inputs ({n_in}).') - - -def _is_empty(container): - if container is None: - raise ValueError('Argument "container" is None.') - return isinstance(container, (list, tuple)) and len(container) == 0 # pylint: disable=g-explicit-length-test - - -def _find_frame(frame): - """Find the frame with the caller on the stack.""" - def _dirname_is_trax_layers_or_gin(frame): - """Skip frames coming from trax/layers or .../gin.""" - try: - dirname1 = frame.f_code.co_filename.split('/')[-3] - dirname2 = frame.f_code.co_filename.split('/')[-2] - return (dirname1 == 'trax' and dirname2 == 'layers') or dirname2 == 'gin' - except IndexError: - return False - - while _dirname_is_trax_layers_or_gin(frame): - frame = frame.f_back - return frame +def _axis_to_shard_heuristic(shape): + """Chooses an axis to shard on - a simple heuristic to be revisited.""" + axis = 0 if len(shape) < 3 else -1 + return axis -def _shorten_file_path(line): - """Shorten file path in error lines for more readable tracebacks.""" - start = line.lower().find('file') - if start < 0: - return line - first_quote = line.find('"', start) - if first_quote < 0: - return line - second_quote = line.find('"', first_quote + 1) - if second_quote < 0: - return line - path = line[first_quote + 1:second_quote] - new_path = '/'.join(path.split('/')[-3:]) - return line[:first_quote] + '[...]/' + new_path + line[second_quote + 1:] +def shard(tensors, n_shards=None): + """Shard 
tensors across n_shards.""" + n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards + indices = _axis_index(np.zeros(fastmath.local_device_count())) -def _short_traceback(skip=3): - """Cleaned-up form of traceback.""" - counter, res = 0, [] - # Skipping 3 lines by default: the top (useless) and self-call. - # In python 3, we need to set chain to False (it doesn't exist in python 2). - lines = traceback.format_exc(chain=False).splitlines()[skip:] # pylint: disable=unexpected-keyword-arg - for l in lines: - if l.startswith('trax.layers.base.LayerError'): - l = l[len('trax.layers.base.'):] # Remove the trax.layers.base prefix. - res.append(_shorten_file_path(l)) - if counter % 2 == 1: - res.append('') - counter += 1 - # If we see a LayerError, the traceback has already been processed. - if l.startswith('LayerError'): - # Skip 4 back except last as these are internal base-layer calls. - res = res[:-4] + [res[-1]] - res += lines[counter:] - break - return '\n'.join(res) + def _shard_fn(x): + axis = _axis_to_shard_heuristic(x.shape) + if int(x.shape[axis]) % n_shards != 0: + raise ValueError(f"Cannot split x with shape {x.shape} into {n_shards}.") + split_x = jnp.split(x, n_shards, axis=axis) + split_x = [split_x[i % n_shards] for i in indices] + return np.stack(split_x, axis=0) + return fastmath.nested_map(_shard_fn, tensors) -def _random_values(input_signature, rng): - """Creates random floats or ints of the given shape. - - Args: - input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input) - or a list/tuple of ShapeDtype instances. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Random values with the shape and type specified. 
- """ - if isinstance(input_signature, ShapeDtype): - shape, dtype = input_signature.shape, input_signature.dtype - if np.issubdtype(dtype, np.integer): - return fastmath.random.bernoulli(rng, 0.5, shape).astype(np.int32) - else: - return fastmath.random.uniform(rng, shape, minval=-1.0, maxval=1.0) - elif isinstance(input_signature, (list, tuple)): - return tuple(_random_values(x, rng) for x in input_signature) - else: - raise TypeError(type(input_signature)) +def unshard_in_pmap(tensors, n_shards): + """Unshard tensors that were sharded into n_shards (call inside pmap).""" + groups = [ + [n_shards * i + d for d in range(n_shards)] + for i in range(fastmath.global_device_count() // n_shards) + ] + + def _unshard_fn(x): + y = jax.lax.all_gather(x, "batch", axis_index_groups=groups) + split_y = jnp.split(y, n_shards, axis=0) + split_y = [jnp.squeeze(sy, axis=0) for sy in split_y] + axis = _axis_to_shard_heuristic(split_y[0].shape) + return jnp.concatenate(split_y, axis=axis) -def _shapes(x): - """Gets a structure of shapes for a structure of nested arrays.""" - def shape(x): try: - return tuple([int(i) for i in x.shape]) - except Exception: # pylint: disable=broad-except - return () - return tuple(nested_map(shape, x)) - - -@functools.partial(fastmath.pmap, axis_name='batch') -def _axis_index(unused_x): - """Return the axis indices.""" - return jax.lax.axis_index('batch') - - -def _axis_to_shard_heuristic(shape): - """Chooses an axis to shard on - a simple heuristic to be revisited.""" - axis = 0 if len(shape) < 3 else -1 - return axis + jax.lax.axis_index("batch") # will throw if not in pmap, e.g., on init + res = fastmath.nested_map(_unshard_fn, tensors) + return res, True + except NameError: # thrown from axis_index above + return tensors, False -def shard(tensors, n_shards=None): - """Shard tensors across n_shards.""" - n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards - indices = _axis_index(np.zeros(fastmath.local_device_count())) - def 
_shard_fn(x): - axis = _axis_to_shard_heuristic(x.shape) - if int(x.shape[axis]) % n_shards != 0: - raise ValueError(f'Cannot split x with shape {x.shape} into {n_shards}.') - split_x = jnp.split(x, n_shards, axis=axis) - split_x = [split_x[i % n_shards] for i in indices] - return np.stack(split_x, axis=0) - return fastmath.nested_map(_shard_fn, tensors) - - -def unshard_in_pmap(tensors, n_shards): - """Unshard tensors that were sharded into n_shards (call inside pmap).""" - groups = [[n_shards * i + d for d in range(n_shards)] - for i in range(fastmath.global_device_count() // n_shards)] - def _unshard_fn(x): - y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups) - split_y = jnp.split(y, n_shards, axis=0) - split_y = [jnp.squeeze(sy, axis=0) for sy in split_y] - axis = _axis_to_shard_heuristic(split_y[0].shape) - return jnp.concatenate(split_y, axis=axis) - try: - jax.lax.axis_index('batch') # will throw if not in pmap, e.g., on init - res = fastmath.nested_map(_unshard_fn, tensors) - return res, True - except NameError: # thrown from axis_index above - return tensors, False - - -@functools.partial(fastmath.pmap, axis_name='batch') +@functools.partial(fastmath.pmap, axis_name="batch") def _all_gather(x, groups): - return jax.lax.all_gather(x, 'batch', axis_index_groups=groups) + return jax.lax.all_gather(x, "batch", axis_index_groups=groups) def unshard(tensors, n_shards=None): - """Unshard tensors that were sharded into n_shards (outside of pmap).""" - n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards - def _unshard_fn(x): - # We use numpy here to put the large un-sharded arrays in CPU memory. - # For unsharding on accelerators use ushard_in_pmap above and pmap it. 
- split_y = np.split(np.asarray(x), n_shards, axis=0) - split_y = [np.squeeze(sy, axis=0) for sy in split_y] - axis = _axis_to_shard_heuristic(split_y[0].shape) - return np.concatenate(split_y, axis=axis) - return fastmath.nested_map(_unshard_fn, tensors) + """Unshard tensors that were sharded into n_shards (outside of pmap).""" + n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards + + def _unshard_fn(x): + # We use numpy here to put the large un-sharded arrays in CPU memory. + # For unsharding on accelerators use ushard_in_pmap above and pmap it. + split_y = np.split(np.asarray(x), n_shards, axis=0) + split_y = [np.squeeze(sy, axis=0) for sy in split_y] + axis = _axis_to_shard_heuristic(split_y[0].shape) + return np.concatenate(split_y, axis=axis) + + return fastmath.nested_map(_unshard_fn, tensors) diff --git a/trax/layers/base_test.py b/trax/layers/base_test.py deleted file mode 100644 index b1abf2786..000000000 --- a/trax/layers/base_test.py +++ /dev/null @@ -1,223 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Trax base layer classes and generic layer-creating functions.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.fastmath import numpy as jnp -import trax.layers as tl - -BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] -CUSTOM_GRAD_BACKENDS = [fastmath.Backend.JAX] # TODO(afrozm): del after TF 2.3 - - -class BaseLayerTest(parameterized.TestCase): - - def test_call_raises_error(self): - layer = tl.Layer() - x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) - with self.assertRaisesRegex(tl.LayerError, 'NotImplementedError'): - _ = layer(x) - - def test_set_weighs_raises_error(self): - layer = tl.Layer() - layer.weights = 1.0 # can assign weights - with self.assertRaisesRegex(ValueError, 'weighs'): - layer.weighs = 1.0 # cannot assign weighs - - def test_forward_raises_error(self): - layer = tl.Layer() - x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) - with self.assertRaises(NotImplementedError): - _ = layer.forward(x) - - def test_init_returns_empty_weights_and_state(self): - layer = tl.Layer() - input_signature = shapes.ShapeDtype((2, 5)) - weights, state = layer.init(input_signature) - self.assertEmpty(weights) - self.assertEmpty(state) - - def test_output_signature_no_weights(self): - shape_2_3_5 = shapes.ShapeDtype((2, 3, 5)) - input_signature = (shape_2_3_5, shape_2_3_5) - layer = tl.Fn('2in1out', lambda x, y: x + y) - output_signature = layer.output_signature(input_signature) - self.assertEqual(output_signature, shape_2_3_5) - - shape_5_7 = shapes.ShapeDtype((5, 7)) - input_signature = shape_5_7 - layer = tl.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3) - output_signature = layer.output_signature(input_signature) - self.assertEqual(output_signature, (shape_5_7, shape_5_7, shape_5_7)) - - # TODO(jonni): Define/test behavior of output signature for layers w/weights. 
- - @parameterized.named_parameters( - [('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) - def test_custom_zero_grad(self, backend): - - class IdWithZeroGrad(tl.Layer): - - def forward(self, x): - return x - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - return (jnp.zeros_like(grad), ()) - - with fastmath.use_backend(backend): - layer = IdWithZeroGrad() - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((9, 17)) - random_input = fastmath.random.uniform( - rng, input_signature.shape, minval=-1.0, maxval=1.0) - layer.init(input_signature) - f = lambda x: jnp.mean(layer(x)) - grad = fastmath.grad(f)(random_input) - self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. - self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0. - - @parameterized.named_parameters( - [('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) - def test_custom_id_grad(self, backend): - - class IdWithIdGrad(tl.Layer): - - def forward(self, x): - return x - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - return (inputs, ()) - - with fastmath.use_backend(backend): - layer = IdWithIdGrad() - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((9, 17)) - random_input = fastmath.random.uniform( - rng, input_signature.shape, minval=-1.0, maxval=1.0) - layer.init(input_signature) - f = lambda x: jnp.mean(layer(x)) - grad = fastmath.grad(f)(random_input) - self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. - self.assertEqual(sum(sum(grad)), sum(sum(random_input))) # Same as input. 
- - def test_weights_and_state_signature(self): - - class MyLayer(tl.Layer): - - def init_weights_and_state(self, input_signature): - self.weights = jnp.zeros((2, 3)) - self.state = jnp.ones(input_signature.shape) - - def forward(self, inputs): - return self.weights + self.state - - layer = MyLayer() - w, s = layer.weights_and_state_signature(jnp.zeros((3, 4))) - self.assertEqual(w.shape, (2, 3)) - self.assertEqual(s.shape, (3, 4)) - - def test_custom_name(self): - layer = tl.Layer() - self.assertIn('Layer', str(layer)) - self.assertNotIn('CustomLayer', str(layer)) - - layer = tl.Layer(name='CustomLayer') - self.assertIn('CustomLayer', str(layer)) - - -class PureLayerTest(absltest.TestCase): - - def test_forward(self): - layer = tl.PureLayer(lambda x: 2 * x) - - # Use Layer.__call__. - in_0 = np.array([1, 2]) - out_0 = layer(in_0, weights=jnp.zeros((2, 3))) - self.assertEqual(out_0.tolist(), [2, 4]) - self.assertEmpty(layer.weights) - - # Use PureLayer.forward. - in_1 = np.array([3, 4]) - out_1 = layer.forward(in_1) - self.assertEqual(out_1.tolist(), [6, 8]) - - # Use Layer.pure_fn - in_2 = np.array([5, 6]) - out_2, _ = layer.pure_fn(in_2, tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None) - self.assertEqual(out_2.tolist(), [10, 12]) - - -class FnTest(absltest.TestCase): - - def test_bad_f_has_default_arg(self): - with self.assertRaisesRegex(ValueError, 'default arg'): - _ = tl.Fn('', lambda x, sth=None: x) - - def test_bad_f_has_keyword_arg(self): - with self.assertRaisesRegex(ValueError, 'keyword arg'): - _ = tl.Fn('', lambda x, **kwargs: x) - - def test_bad_f_has_variable_arg(self): - with self.assertRaisesRegex(ValueError, 'variable arg'): - _ = tl.Fn('', lambda *args: args[0]) - - def test_forward(self): - layer = tl.Fn( - 'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2) - - x0 = np.array([1, 2, 3, 4, 5]) - x1 = np.array([10, 20, 30, 40, 50]) - - y0, y1 = layer((x0, x1)) - self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55]) - 
self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50]) - - y2, y3 = layer.forward((x0, x1)) - self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55]) - self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50]) - - (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE, - None) - self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55]) - self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50]) - self.assertEqual(state, tl.EMPTY_STATE) - - def test_weights_state(self): - layer = tl.Fn( - '2in2out', - lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), - n_out=2) - layer.init_weights_and_state(None) - self.assertEmpty(layer.weights) - self.assertEmpty(layer.state) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/combinators.py b/trax/layers/combinators.py index ddde61f8f..15c084711 100644 --- a/trax/layers/combinators.py +++ b/trax/layers/combinators.py @@ -25,987 +25,1079 @@ class Serial(base.Layer): - """Combinator that applies layers serially (by function composition). - - This combinator is commonly used to construct deep networks, e.g., like this:: - - mlp = tl.Serial( - tl.Dense(128), - tl.Relu(), - tl.Dense(10), - ) - - A Serial combinator uses stack semantics to manage data for its sublayers. - Each sublayer sees only the inputs it needs and returns only the outputs it - has generated. The sublayers interact via the data stack. For instance, a - sublayer k, following sublayer j, gets called with the data stack in the - state left after layer j has applied. The Serial combinator then: - - - takes n_in items off the top of the stack (n_in = k.n_in) and calls - layer k, passing those items as arguments; and - - - takes layer k's n_out return values (n_out = k.n_out) and pushes - them onto the data stack. - - A Serial instance with no sublayers acts as a special-case (but useful) - 1-input 1-output no-op. 
- """ - - def __init__(self, *sublayers, name=None, sublayers_to_print=None): - super().__init__( - name=name, sublayers_to_print=sublayers_to_print) - - sublayers = _ensure_flat(sublayers) - self._sublayers = sublayers - self._n_layers = len(sublayers) - - if sublayers: - self._n_in, self._n_out = self._n_inputs_n_outputs(sublayers) - self._weights = tuple(None for l in sublayers) - self._state = tuple(None for l in sublayers) - - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - self._validate_forward_inputs(xs) - if not self.sublayers: # No-op: outputs = inputs - return xs - - state, weights = self.state, self.weights - rngs = _split_rngs(self.rng, self._n_layers) - stack = xs - new_state = [] - n_layers = self._n_layers - if len(weights) != n_layers: - raise ValueError( - f'Number of weight elements ({len(weights)}) does not equal ' - f'number of sublayers ({n_layers}).') - if len(state) != n_layers: - raise ValueError( - f'Number of state elements ({len(state)}) does not equal ' - f'number of sublayers ({n_layers}).') - - for layer, w, s, rng in zip(self.sublayers, weights, state, rngs): - inputs = inputs_from_stack(stack, layer.n_in) - outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True) - stack = outputs_onto_stack(outputs, stack, layer.n_in) - new_state.append(s) - self.state = tuple(new_state) - return stack - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights = [] - states = [] - # In the code below, stack, inputs, and outputs are abstract (shapes and - # dtypes), but weights and states are non-abstract actual values. 
- stack = input_signature - for sublayer in self.sublayers: - inputs = inputs_from_stack(stack, sublayer.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - sublayer.init(inputs, use_cache=True)) - outputs, _ = sublayer._forward_abstract(inputs) - stack = outputs_onto_stack(outputs, stack, sublayer.n_in) - - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - self.state = tuple(states) - self.weights = tuple(weights) - # pylint: enable=protected-access - - def _n_inputs_n_outputs(self, layers): - del self - running_max = 0 - running_total = 0 - for layer in layers: - running_total += layer.n_in - running_max = max(running_max, running_total) - running_total -= layer.n_out - return running_max, (running_max - running_total) - - def _validate_forward_inputs(self, xs): - if not isinstance(xs, (tuple, list)) and self._n_in != 1: - raise TypeError(f'Serial.forward input must be a tuple or list; ' - f'instead got {type(xs)}.') - # TODO(jonni): Include full xs (or shape) in error message? - len_xs = 1 if isinstance(xs, jnp.ndarray) else len(xs) - if len_xs < self.n_in: - raise ValueError( - f'Number of inputs ({len(xs)}) to Serial.forward less than n_in ' - f'({self.n_in}).') + """Combinator that applies layers serially (by function composition). + This combinator is commonly used to construct deep networks, e.g., like this:: -class Parallel(base.Layer): - """Combinator that applies a list of layers in parallel to its inputs. + mlp = tl.Serial( + tl.Dense(128), + tl.Relu(), + tl.Dense(10), + ) - Layers in the list apply to successive spans of inputs, where the spans are - determined how many inputs each layer takes. The resulting output is the - (flattened) concatenation of the respective layer outputs. + A Serial combinator uses stack semantics to manage data for its sublayers. + Each sublayer sees only the inputs it needs and returns only the outputs it + has generated. The sublayers interact via the data stack. 
For instance, a + sublayer k, following sublayer j, gets called with the data stack in the + state left after layer j has applied. The Serial combinator then: - For example, suppose one has three layers: + - takes n_in items off the top of the stack (n_in = k.n_in) and calls + layer k, passing those items as arguments; and - - F: 1 input, 1 output - - G: 3 inputs, 1 output - - H: 2 inputs, 2 outputs (h1, h2) + - takes layer k's n_out return values (n_out = k.n_out) and pushes + them onto the data stack. - Then Parallel(F, G, H) will take 6 inputs and give 4 outputs: + A Serial instance with no sublayers acts as a special-case (but useful) + 1-input 1-output no-op. + """ - - inputs: a, b, c, d, e, f - - outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f) + def __init__(self, *sublayers, name=None, sublayers_to_print=None): + super().__init__(name=name, sublayers_to_print=sublayers_to_print) + + sublayers = _ensure_flat(sublayers) + self._sublayers = sublayers + self._n_layers = len(sublayers) + + if sublayers: + self._n_in, self._n_out = self._n_inputs_n_outputs(sublayers) + self._weights = tuple(None for l in sublayers) + self._state = tuple(None for l in sublayers) + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + if not isinstance(xs, (tuple, list)): + xs = (xs,) + + # The input should be a flat single tuple without nested tuples + xs = self.flatten_tuple(xs) + + self._validate_forward_inputs(xs) + + if not self.sublayers: # No-op: outputs = inputs + return xs + + state, weights = self.state, self.weights + rngs = _split_rngs(self.rng, self._n_layers) + stack = xs + new_state = [] + n_layers = self._n_layers + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal " + f"number of sublayers ({n_layers})." 
+ ) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal " + f"number of sublayers ({n_layers})." + ) + + for layer, w, s, rng in zip(self.sublayers, weights, state, rngs): + inputs = inputs_from_stack(stack, layer.n_in) + outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True) + stack = outputs_onto_stack(outputs, stack, layer.n_in) + new_state.append(s) + self.state = tuple(new_state) + return stack + + # pylint: disable=protected-access + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + weights = [] + states = [] + # In the code below, stack, inputs, and outputs are abstract (shapes and + # dtypes), but weights and states are non-abstract actual values. + stack = input_signature + for sublayer in self.sublayers: + inputs = inputs_from_stack(stack, sublayer.n_in) + weights_or_cache_marker, state_or_cache_marker = sublayer.init( + inputs, use_cache=True + ) + outputs, _ = sublayer._forward_abstract(inputs) + stack = outputs_onto_stack(outputs, stack, sublayer.n_in) + + weights.append(weights_or_cache_marker) + states.append(state_or_cache_marker) + self.state = tuple(states) + self.weights = tuple(weights) - As an important special case, a None argument to Parallel acts as if it takes - one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For - example: + # pylint: enable=protected-access - Parallel(None, F) + def _n_inputs_n_outputs(self, layers): + del self + running_max = 0 + running_total = 0 + for layer in layers: + running_total += layer.n_in + running_max = max(running_max, running_total) + running_total -= layer.n_out + return running_max, (running_max - running_total) + + def _validate_forward_inputs(self, xs): + if not isinstance(xs, (tuple, list)) and self._n_in != 1: + raise TypeError( + f"Serial.forward input must be a tuple or list; " + f"instead got {type(xs)}." 
+ ) + # TODO(jonni): Include full xs (or shape) in error message? + + len_xs = 1 if isinstance(xs, jnp.ndarray) else len(xs) + if len_xs < self.n_in: + raise ValueError( + f"Number of inputs ({len(xs)}) to Serial.forward less than n_in " + f"({self.n_in})." + ) - creates a layer that passes its first input unchanged and applies F to the - following input(s). - """ - def __init__(self, *sublayers, name=None): - """The constructor. +class Parallel(base.Layer): + """Combinator that applies a list of layers in parallel to its inputs. - Args: - *sublayers: A list of sublayers. - name: Descriptive name for this layer. + Layers in the list apply to successive spans of inputs, where the spans are + determined how many inputs each layer takes. The resulting output is the + (flattened) concatenation of the respective layer outputs. - Returns: - A new layer in which each of the given sublayers applies to its - corresponding span of elements in the dataflow stack. - """ - super().__init__(name=name) - sublayers = self._validate(sublayers) - self._n_layers = len(sublayers) - self._sublayers = sublayers - self._n_in = sum(l.n_in for l in sublayers) - self._n_out = sum(l.n_out for l in sublayers) - self._weights = tuple(None for l in sublayers) - self._state = tuple(None for l in sublayers) - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - n_layers, layers = self._n_layers, self.sublayers - sublayer_inputs = self._allot_to_sublayers(inputs) - state, weights = self.state, self.weights - rngs = _split_rngs(self.rng, n_layers) - if len(sublayer_inputs) != n_layers: - raise ValueError( - f'Number of inputs for sublayers ({len(sublayer_inputs)}) does not equal ' - f'number of sublayers ({n_layers}).') - if len(weights) != n_layers: - raise ValueError( - f'Number of weight elements ({len(weights)}) does not equal ' - f'number of sublayers ({n_layers}).') - if len(state) != n_layers: - raise ValueError( - f'Number of state elements 
({len(state)}) does not equal ' - f'number of sublayers ({n_layers}).') - if len(rngs) != n_layers: - raise ValueError( - f'Number of rngs ({len(rngs)}) does not equal ' - f'number of sublayers ({n_layers}).') - outputs = [] - new_state = [] - for layer, x, w, s, r in zip(layers, sublayer_inputs, weights, state, rngs): - # Note that zip silently truncates its result if lengths don't match. - sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True) - if layer.n_out == 1: - outputs.append(sub_outputs) - else: - outputs.extend(sub_outputs) - new_state.append(sub_state) - output = outputs[0] if self.n_out == 1 else tuple(outputs) - self.state = tuple(new_state) - return output - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - sublayer_signatures = self._allot_to_sublayers(input_signature) - inits = [layer.init(signature, use_cache=True) - for layer, signature - in zip(self.sublayers, sublayer_signatures)] - if inits: - weights, state = tuple(zip(*inits)) - self.state = state - self.weights = weights - - def _validate(self, layers): - if not layers or len(layers) < 2: - raise ValueError( - f'layers ({layers}) must be a list with at least two elements') - layers = list(layers) # Ensure we can modify layers. - for i, obj in enumerate(layers): - if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison - layers[i] = Serial(None) - elif isinstance(obj, (list, tuple)): - layers[i] = Serial(obj) - else: - if not isinstance(obj, base.Layer): - raise ValueError( - f'Found nonlayer object ({obj}) in layers list: [{layers}]') - if layers[i].n_in == 0: - raise ValueError( - f'Sublayer with n_in = 0 not allowed in Parallel: {layers[i]}') - return layers + For example, suppose one has three layers: - def _allot_to_sublayers(self, inputs): - """Divides Parallel's inputs for use by the sublayers. 
+ - F: 1 input, 1 output + - G: 3 inputs, 1 output + - H: 2 inputs, 2 outputs (h1, h2) - Args: - inputs: Tuple of ndarrays or ShapeDtype instances. + Then Parallel(F, G, H) will take 6 inputs and give 4 outputs: - Returns: - A tuple that partitions this layer's inputs among its sublayers. - Sublayers that take one argument get that argument directly. All other - sublayers get a tuple of items. + - inputs: a, b, c, d, e, f + - outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f) + + As an important special case, a None argument to Parallel acts as if it takes + one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For + example: + + Parallel(None, F) + + creates a layer that passes its first input unchanged and applies F to the + following input(s). """ - start, end = 0, 0 - sub_inputs = [] - for layer in self.sublayers: - n_in = layer.n_in - end = start + n_in - if n_in == 1: - sub_inputs.append(inputs[start]) - else: - sub_inputs.append(inputs[start:end]) - start = end - return tuple(sub_inputs) + + def __init__(self, *sublayers, name=None): + """The constructor. + + Args: + *sublayers: A list of sublayers. + name: Descriptive name for this layer. + + Returns: + A new layer in which each of the given sublayers applies to its + corresponding span of elements in the dataflow stack. 
+ """ + super().__init__(name=name) + sublayers = self._validate(sublayers) + self._n_layers = len(sublayers) + self._sublayers = sublayers + self._n_in = sum(l.n_in for l in sublayers) + self._n_out = sum(l.n_out for l in sublayers) + self._weights = tuple(None for l in sublayers) + self._state = tuple(None for l in sublayers) + + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model.""" + n_layers, layers = self._n_layers, self.sublayers + sublayer_inputs = self._allot_to_sublayers(inputs) + state, weights = self.state, self.weights + rngs = _split_rngs(self.rng, n_layers) + if len(sublayer_inputs) != n_layers: + raise ValueError( + f"Number of inputs for sublayers ({len(sublayer_inputs)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(rngs) != n_layers: + raise ValueError( + f"Number of rngs ({len(rngs)}) does not equal " + f"number of sublayers ({n_layers})." + ) + outputs = [] + new_state = [] + for layer, x, w, s, r in zip(layers, sublayer_inputs, weights, state, rngs): + # Note that zip silently truncates its result if lengths don't match. 
+ sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True) + if layer.n_out == 1: + outputs.append(sub_outputs) + else: + outputs.extend(sub_outputs) + new_state.append(sub_state) + output = outputs[0] if self.n_out == 1 else tuple(outputs) + self.state = tuple(new_state) + + if not isinstance(output, (tuple, list)): + output = (output,) + + # The input should be a flat single tuple without nested tuples + output = self.flatten_tuple(output) + + return output + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + sublayer_signatures = self._allot_to_sublayers(input_signature) + inits = [ + layer.init(signature, use_cache=True) + for layer, signature in zip(self.sublayers, sublayer_signatures) + ] + if inits: + weights, state = tuple(zip(*inits)) + self.state = state + self.weights = weights + + def _validate(self, layers): + if not layers or len(layers) < 2: + raise ValueError( + f"layers ({layers}) must be a list with at least two elements" + ) + layers = list(layers) # Ensure we can modify layers. + for i, obj in enumerate(layers): + if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison + layers[i] = Serial(None) + elif isinstance(obj, (list, tuple)): + layers[i] = Serial(obj) + else: + if not isinstance(obj, base.Layer): + raise ValueError( + f"Found nonlayer object ({obj}) in layers list: [{layers}]" + ) + if layers[i].n_in == 0: + raise ValueError( + f"Sublayer with n_in = 0 not allowed in Parallel: {layers[i]}" + ) + return layers + + def _allot_to_sublayers(self, inputs): + """Divides Parallel's inputs for use by the sublayers. + + Args: + inputs: Tuple of ndarrays or ShapeDtype instances. + + Returns: + A tuple that partitions this layer's inputs among its sublayers. + Sublayers that take one argument get that argument directly. All other + sublayers get a tuple of items. 
+ """ + start, end = 0, 0 + sub_inputs = [] + for layer in self.sublayers: + n_in = layer.n_in + end = start + n_in + if n_in == 1: + sub_inputs.append(inputs[start]) + else: + sub_inputs.append(inputs[start:end]) + start = end + return tuple(sub_inputs) class Concatenate(base.Layer): - """Concatenates a number of tensors into a single tensor. + """Concatenates a number of tensors into a single tensor. - For example:: + For example:: - x = np.array([1, 2]) - y = np.array([3, 4]) - z = np.array([5, 6]) - concat3 = tl.Concatenate(n_items=3) - z = concat3((x, y, z)) # z = [1, 2, 3, 4, 5, 6] + x = np.array([1, 2]) + y = np.array([3, 4]) + z = np.array([5, 6]) + concat3 = tl.Concatenate(n_items=3) + z = concat3((x, y, z)) # z = [1, 2, 3, 4, 5, 6] - Use the `axis` argument to specify on which axis to concatenate the tensors. - By default it's the last axis, `axis=-1`, and `n_items=2`. - """ + Use the `axis` argument to specify on which axis to concatenate the tensors. + By default it's the last axis, `axis=-1`, and `n_items=2`. 
+ """ - def __init__(self, n_items=2, axis=-1): - name = 'Concatenate' if axis == -1 else f'Concatenate_axis{axis}' - super().__init__(n_in=n_items, name=name) - self._n_items = n_items - self._axis = axis + def __init__(self, n_items=2, axis=-1): + name = "Concatenate" if axis == -1 else f"Concatenate_axis{axis}" + super().__init__(n_in=n_items, name=name) + self._n_items = n_items + self._axis = axis - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - return jnp.concatenate(xs, self._axis) + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + return jnp.concatenate(xs, self._axis) class Split(base.Layer): - """Splits the input into n items along an axis.""" + """Splits the input into n items along an axis.""" - def __init__(self, n_items=2, axis=-1): - super().__init__(n_out=n_items) - self._n_items = n_items - self._axis = axis + def __init__(self, n_items=2, axis=-1): + super().__init__(n_out=n_items) + self._n_items = n_items + self._axis = axis - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - return tuple(jnp.split(inputs, self._n_items, self._axis)) + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model.""" + return tuple(jnp.split(inputs, self._n_items, self._axis)) def _scan(f, xs, init_value, axis=0, remat=False): - """Scans the f over the given axis of xs. - - In pseudo-python, the scan function would look as follows: - - def scan(f, xs, init_value, axis): - xs = [xs[..., i, ...] 
for i in range(xs.shape[axis])] - cur_value = init_value - ys = [] - for x in xs: - y, cur_value = f(x, cur_value) - ys.append(y) - return np.stack(ys, axis), cur_value - - Args: - f: function (x, carry) -> (y, new_carry) - xs: tensor, x will be xs slices on axis - init_value: tensor, initial value of the carry-over - axis: int, the axis on which to slice xs - remat: whether to re-materialize f - - Returns: - A pair (ys, last_value) as described above. - """ - def swapaxes(x): - transposed_axes = list(range(len(x.shape))) - transposed_axes[axis] = 0 - transposed_axes[0] = axis - return jnp.transpose(x, axes=transposed_axes) - if axis != 0: - xs = fastmath.nested_map(swapaxes, xs) - def transposed_f(c, x): - y, d = f(x, c) - return d, y - if remat: - transposed_f = fastmath.remat(transposed_f) - last_value, ys = fastmath.scan(transposed_f, init_value, xs) - if axis != 0: - ys = fastmath.nested_map(swapaxes, ys) - return ys, last_value + """Scans the f over the given axis of xs. + + In pseudo-python, the scan function would look as follows: + + def scan(f, xs, init_value, axis): + xs = [xs[..., i, ...] for i in range(xs.shape[axis])] + cur_value = init_value + ys = [] + for x in xs: + y, cur_value = f(x, cur_value) + ys.append(y) + return np.stack(ys, axis), cur_value + + Args: + f: function (x, carry) -> (y, new_carry) + xs: tensor, x will be xs slices on axis + init_value: tensor, initial value of the carry-over + axis: int, the axis on which to slice xs + remat: whether to re-materialize f + + Returns: + A pair (ys, last_value) as described above. 
+ """ + + def swapaxes(x): + transposed_axes = list(range(len(x.shape))) + transposed_axes[axis] = 0 + transposed_axes[0] = axis + return jnp.transpose(x, axes=transposed_axes) + + if axis != 0: + xs = fastmath.nested_map(swapaxes, xs) + + def transposed_f(c, x): + y, d = f(x, c) + return d, y + + if remat: + transposed_f = fastmath.remat(transposed_f) + last_value, ys = fastmath.scan(transposed_f, init_value, xs) + if axis != 0: + ys = fastmath.nested_map(swapaxes, ys) + return ys, last_value class Scan(base.Layer): - """Applies a layer progressively/cumulatively to an axis-derived sequence. - - Conceptually, this is a function from a list to a same-length list of partial - (cumulative) results. For instance, a list of values (`[1, 2, 3, 4, 5]`) can - transform to a list of cumulative sums (`[1, 3, 6, 10, 15]`). Functions for - the same concept are called `scan` in Scala, `scanl` in Haskell, and - `accumulate*` in Factor. - - In more detail, we assume the layer takes a tuple of inputs of the following - form: - - (input1, ..., inputN, carry1, ..., carryM) - - and returns: - - (output1, ..., outputK, new_carry1, ..., new_carryM) - - The scanned version applies the layer iteratively to a tensor treating values - at the given axis as if they were a list. 
For example, to calculate all - sums of prefixes of a tensor, we can do this:: - - def add(x, carry): - def f(input, carry): - res = input + carry - return res, res # output and carry are the same - return tl.Fn('add', f, n_out=2) - - Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6 - """ - - def __init__(self, layer, axis=0, n_carry=1, remat=False, mode='train'): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_carry = n_carry - self._axis = axis - self._remat = remat - self._weights = (None,) - self._state = (None, ()) - self._mode = mode - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - @property - def state(self): - """Returns a tuple containing this layer's state.""" - return (self.sublayer.state, self._state[1]) - - @state.setter - def state(self, state): - """Recursively sets state on this layer the sublayer.""" - if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: - return - self._state = (None, state[1]) - self.sublayer.state = state[0] - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - weights = self.weights[0] - if isinstance(inputs, list): - inputs = tuple(inputs) # so that inputs structure matches outputs - n_carry = self._n_carry - def scannable_fn(x, carry_and_state): # pylint: disable=invalid-name - carry, state, i = carry_and_state - x_and_carry = x + carry if n_carry > 0 else x - rng = fastmath.random.fold_in(self.rng, i) - res, new_state = self.sublayer.pure_fn( - x_and_carry, weights, state, rng, use_cache=True) - if n_carry > 0: - return (res[:-n_carry], (res[-n_carry:], new_state, i+1)) - else: - return (res, ([], new_state, i+1)) - - if n_carry > 0: - xs = inputs[:-n_carry] # Split input stack into inputs and carry. 
- xs_carry = inputs[-n_carry:] - if self._mode == 'predict' and self._state[1] is not (): # pylint: disable=literal-comparison - xs_carry = self._state[1] - init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32)) - else: - xs_carry = () - xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32)) - ys, (carry, new_state, _) = _scan(scannable_fn, xs, init, - axis=self._axis, remat=self._remat) - res = ys + carry if n_carry > 0 else ys - state_carry = carry if self._mode == 'predict' and n_carry > 0 else () - self.state = (new_state, state_carry) - return res # Put outputs and carry back on stack. - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - n_carry = self._n_carry - if n_carry == 0: - if isinstance(input_signature, (list, tuple)): - layer_sig = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) - for x in input_signature] - layer_sig = tuple(layer_sig) - else: - layer_sig = ShapeDtype(_shape_without_axis(input_signature, self._axis), - input_signature.dtype) - weights, state = self.sublayer.init(layer_sig) - self.state = (state, ()) - self.weights = (weights,) - else: - xs = input_signature[:-n_carry] - init = input_signature[-n_carry:] - xs_slices = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) - for x in xs] - layer_signature = tuple(xs_slices + list(init)) - weights, state = self.sublayer.init(layer_signature, use_cache=True) - self.state = (state, ()) - self.weights = (weights,) + """Applies a layer progressively/cumulatively to an axis-derived sequence. + Conceptually, this is a function from a list to a same-length list of partial + (cumulative) results. For instance, a list of values (`[1, 2, 3, 4, 5]`) can + transform to a list of cumulative sums (`[1, 3, 6, 10, 15]`). Functions for + the same concept are called `scan` in Scala, `scanl` in Haskell, and + `accumulate*` in Factor. -class Cond(base.Layer): - """Applies layers conditionally. 
- - For parameters `cond`, `true`, and `false` runs the equivalent of `true(y) - if cond(x) else false(y)`, where `x` is `cond.n_in` elements from front of the - stack and `y` is the rest of the stack. - Exactly one of `true` and `false` functions is executed, so it can be used to - conditionally run long computations. The state of non-executed function is not - updated. Note that different branches may be executed on different devices - if `cond` returns different values on them. - By default 'false' function is an identity. - - `cond` must return exactly one element: a Boolean value. - `true` and `false` must have the same n_in, and the same n_out. - """ - - def __init__(self, cond, true, false=None, name=None): - super(Cond, self).__init__(name=name) - - if false is None: - self._identity_false_fun = True - # We don't need this function, but it will be useful for checking if - # 'true' has proper n_in/n_out. - false = Serial() - self._false = false - else: - self._identity_false_fun = False - self._false = false - - sublayers = [cond, true, false] - self._sublayers = sublayers - self._n_layers = len(sublayers) - self._cond = cond - self._true = true - - if cond.n_out != 1: - raise ValueError( - 'cond.n_out must be 1: cond:{}->{}'.format(cond.n_in, cond.n_out)) - if true.n_in != false.n_in: - raise ValueError( - 'true.n_in and false.n_in must be equal: true:{}->{} ; false:{}->{}' - .format(true.n_in, true.n_out, false.n_in, false.n_out)) - if true.n_out != false.n_out: - raise ValueError( - 'true.n_out and false.n_out must be equal: true:{}->{} ; false:{}->{}' - .format(true.n_in, true.n_out, false.n_in, false.n_out)) - - self._n_in = cond.n_in + true.n_in - self._n_out = true.n_out - self._weights = tuple(None for l in sublayers) - self._state = tuple(None for l in sublayers) - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights = [] - 
states = [] - # In the code below, stack, inputs, and outputs are abstract (shapes and - # dtypes), but weights and states are non-abstract actual values. - stack = _make_tuple(input_signature) - - # Inputs/outputs of `cond`. - inputs = inputs_from_stack(stack, self._cond.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - self._cond.init(inputs, use_cache=True)) - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - self._cond._forward_abstract(inputs) - stack = _make_tuple(outputs_onto_stack([], stack, self._cond.n_in)) - - # Inputs/outputs of `true` and `false`. - for sublayer in [self._true, self._false]: - inputs = inputs_from_stack(stack, sublayer.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - sublayer.init(inputs, use_cache=True)) - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - - self.state = states - self.weights = weights - # pylint: enable=protected-access + In more detail, we assume the layer takes a tuple of inputs of the following + form: - def _validate_forward_inputs(self, xs): - xs = _make_tuple(xs) - if len(xs) < self.n_in: - raise ValueError( - f'Number of inputs ({len(xs)}) to Cond.forward less than n_in ' - f'({self.n_in}).') + (input1, ..., inputN, carry1, ..., carryM) - def forward(self, xs): - """Executes this layer as part of a forward pass through the model. + and returns: - Args: - xs: Tensors of as required by the branches of this conditional. + (output1, ..., outputK, new_carry1, ..., new_carryM) - Returns: - Tensors resulting from running the chosen branch. + The scanned version applies the layer iteratively to a tensor treating values + at the given axis as if they were a list. 
For example, to calculate all + sums of prefixes of a tensor, we can do this:: + + def add(x, carry): + def f(input, carry): + res = input + carry + return res, res # output and carry are the same + return tl.Fn('add', f, n_out=2) + + Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6 """ - # TODO(jaszczur): modify; it's a copy from SkippingSerial - self._validate_forward_inputs(xs) - layers_state = self.state - # Get 3 rngs, one for each layer. - rngs = _split_rngs(self.rng, 3) - - # Prepare the stack and do some safety checks as in the parent class. - stack = _make_tuple(xs) - weights = self.weights - if len(weights) != 3: - raise ValueError('number of weights ({}) not equal to 3' - .format(len(weights))) - if len(layers_state) != 3: - raise ValueError('length of state ({}) not equal to 3' - .format(len(layers_state))) - - def true_func(t): - outputs, new_true_state = self._true.pure_fn( - t[0][0], t[1][0], t[2][0], t[3][0]) - # t[2][1] is old_false_state which is not changing if true is executed. - return outputs, (new_true_state, t[2][1]) - - def false_func(t): - if self._identity_false_fun: - # Memory optimization: we don't need pure_fn call. - return t[0][1], t[2] - outputs, new_false_state = self._false.pure_fn( - t[0][1], t[1][1], t[2][1], t[3][1]) - # t[2][1] is old_true_state, which is not changing if false is executed. 
- return outputs, (t[2][0], new_false_state) - - cond_inputs = inputs_from_stack(xs, self._cond.n_in) - cond_output, s = self._cond.pure_fn(cond_inputs, self.weights[0], - self.state[0], rngs[0], use_cache=True) - stack = outputs_onto_stack([], stack, self._cond.n_in) - self._cond.state = s - - outputs, both_states = fastmath.cond( - cond_output, - true_func, - false_func, - [(stack, stack), - (self.weights[1], self.weights[2]), - (self.state[1], self.state[2]), - (rngs[1], rngs[2])] - ) - stack = outputs_onto_stack([], stack, self._cond.n_in) - # We don't know which (`true` or `false`) branch was run, but both of them - # are adding (n_out) and removing (n_in) the same number of elements of the - # stack (this was checked in __init__). outputs_onto_stack just uses the - # layer's n_in, so we can pass either `true` or `false` to it. - # Note that `outputs` is the actual output of `true` or `false` branch, - # whichever was run, and we add it to the stack in any case. - stack = outputs_onto_stack(outputs, stack, self._true.n_in) - self._true.state = both_states[0] - self._false.state = both_states[1] - return _make_singleitem_or_original(stack) + def __init__(self, layer, axis=0, n_carry=1, remat=False, mode="train"): + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + self._n_carry = n_carry + self._axis = axis + self._remat = remat + self._weights = (None,) + self._state = (None, ()) + self._mode = mode + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + @property + def state(self): + """Returns a tuple containing this layer's state.""" + return (self.sublayer.state, self._state[1]) + + @state.setter + def state(self, state): + """Recursively sets state on this layer the sublayer.""" + if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: + return + self._state = (None, state[1]) + self.sublayer.state = state[0] + + def forward(self, inputs): + 
"""Executes this layer as part of a forward pass through the model.""" + weights = self.weights[0] + if isinstance(inputs, list): + inputs = tuple(inputs) # so that inputs structure matches outputs + n_carry = self._n_carry + + def scannable_fn(x, carry_and_state): # pylint: disable=invalid-name + carry, state, i = carry_and_state + x_and_carry = x + carry if n_carry > 0 else x + rng = fastmath.random.fold_in(self.rng, i) + res, new_state = self.sublayer.pure_fn( + x_and_carry, weights, state, rng, use_cache=True + ) + if n_carry > 0: + return (res[:-n_carry], (res[-n_carry:], new_state, i + 1)) + else: + return (res, ([], new_state, i + 1)) + + if n_carry > 0: + xs = inputs[:-n_carry] # Split input stack into inputs and carry. + xs_carry = inputs[-n_carry:] + if ( + self._mode == "predict" and self._state[1] is not () + ): # pylint: disable=literal-comparison + xs_carry = self._state[1] + init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32)) + else: + xs_carry = () + xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32)) + ys, (carry, new_state, _) = _scan( + scannable_fn, xs, init, axis=self._axis, remat=self._remat + ) + res = ys + carry if n_carry > 0 else ys + state_carry = carry if self._mode == "predict" and n_carry > 0 else () + self.state = (new_state, state_carry) + return res # Put outputs and carry back on stack. 
+ + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + n_carry = self._n_carry + if n_carry == 0: + if isinstance(input_signature, (list, tuple)): + layer_sig = [ + ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) + for x in input_signature + ] + layer_sig = tuple(layer_sig) + else: + layer_sig = ShapeDtype( + _shape_without_axis(input_signature, self._axis), + input_signature.dtype, + ) + weights, state = self.sublayer.init(layer_sig) + self.state = (state, ()) + self.weights = (weights,) + else: + xs = input_signature[:-n_carry] + init = input_signature[-n_carry:] + xs_slices = [ + ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) for x in xs + ] + layer_signature = tuple(xs_slices + list(init)) + weights, state = self.sublayer.init(layer_signature, use_cache=True) + self.state = (state, ()) + self.weights = (weights,) + + +class Cond(base.Layer): + """Applies layers conditionally. + + For parameters `cond`, `true`, and `false` runs the equivalent of `true(y) + if cond(x) else false(y)`, where `x` is `cond.n_in` elements from front of the + stack and `y` is the rest of the stack. + Exactly one of `true` and `false` functions is executed, so it can be used to + conditionally run long computations. The state of non-executed function is not + updated. Note that different branches may be executed on different devices + if `cond` returns different values on them. + By default 'false' function is an identity. + + `cond` must return exactly one element: a Boolean value. + `true` and `false` must have the same n_in, and the same n_out. + """ + + def __init__(self, cond, true, false=None, name=None): + super(Cond, self).__init__(name=name) + + if false is None: + self._identity_false_fun = True + # We don't need this function, but it will be useful for checking if + # 'true' has proper n_in/n_out. 
+ false = Serial() + self._false = false + else: + self._identity_false_fun = False + self._false = false + + sublayers = [cond, true, false] + self._sublayers = sublayers + self._n_layers = len(sublayers) + self._cond = cond + self._true = true + + if cond.n_out != 1: + raise ValueError( + "cond.n_out must be 1: cond:{}->{}".format(cond.n_in, cond.n_out) + ) + if true.n_in != false.n_in: + raise ValueError( + "true.n_in and false.n_in must be equal: true:{}->{} ; false:{}->{}".format( + true.n_in, true.n_out, false.n_in, false.n_out + ) + ) + if true.n_out != false.n_out: + raise ValueError( + "true.n_out and false.n_out must be equal: true:{}->{} ; false:{}->{}".format( + true.n_in, true.n_out, false.n_in, false.n_out + ) + ) + + self._n_in = cond.n_in + true.n_in + self._n_out = true.n_out + self._weights = tuple(None for l in sublayers) + self._state = tuple(None for l in sublayers) + + # pylint: disable=protected-access + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + weights = [] + states = [] + # In the code below, stack, inputs, and outputs are abstract (shapes and + # dtypes), but weights and states are non-abstract actual values. + stack = _make_tuple(input_signature) + + # Inputs/outputs of `cond`. + inputs = inputs_from_stack(stack, self._cond.n_in) + weights_or_cache_marker, state_or_cache_marker = self._cond.init( + inputs, use_cache=True + ) + weights.append(weights_or_cache_marker) + states.append(state_or_cache_marker) + self._cond._forward_abstract(inputs) + stack = _make_tuple(outputs_onto_stack([], stack, self._cond.n_in)) + + # Inputs/outputs of `true` and `false`. 
+ for sublayer in [self._true, self._false]: + inputs = inputs_from_stack(stack, sublayer.n_in) + weights_or_cache_marker, state_or_cache_marker = sublayer.init( + inputs, use_cache=True + ) + weights.append(weights_or_cache_marker) + states.append(state_or_cache_marker) + + self.state = states + self.weights = weights + # pylint: enable=protected-access + + def _validate_forward_inputs(self, xs): + xs = _make_tuple(xs) + if len(xs) < self.n_in: + raise ValueError( + f"Number of inputs ({len(xs)}) to Cond.forward less than n_in " + f"({self.n_in})." + ) + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model. + + Args: + xs: Tensors of as required by the branches of this conditional. + + Returns: + Tensors resulting from running the chosen branch. + """ + + # TODO(jaszczur): modify; it's a copy from SkippingSerial + self._validate_forward_inputs(xs) + xs = _make_tuple(xs) + + # The input should be a flat single tuple without nested tuples + xs = self.flatten_tuple(xs) + + layers_state = self.state + # Get 3 rngs, one for each layer. + rngs = _split_rngs(self.rng, 3) + + # Prepare the stack and do some safety checks as in the parent class. + stack = _make_tuple(xs) + stack = self.flatten_tuple(stack) + weights = self.weights + if len(weights) != 3: + raise ValueError( + "number of weights ({}) not equal to 3".format(len(weights)) + ) + if len(layers_state) != 3: + raise ValueError( + "length of state ({}) not equal to 3".format(len(layers_state)) + ) + + def true_func(t): + outputs, new_true_state = self._true.pure_fn( + t[0][0], t[1][0], t[2][0], t[3][0] + ) + # t[2][1] is old_false_state which is not changing if true is executed. + outputs = _make_tuple(outputs) + return outputs, (new_true_state, t[2][1]) + + def false_func(t): + if self._identity_false_fun: + # Memory optimization: we don't need pure_fn call. 
+ return t[0][1], t[2] + outputs, new_false_state = self._false.pure_fn( + t[0][1], t[1][1], t[2][1], t[3][1] + ) + # t[2][1] is old_true_state, which is not changing if false is executed. + outputs = _make_tuple(outputs) + return outputs, (t[2][0], new_false_state) + + cond_inputs = inputs_from_stack(xs, self._cond.n_in) + cond_inputs = _make_tuple(cond_inputs) + cond_output, s = self._cond.pure_fn( + cond_inputs, self.weights[0], self.state[0], rngs[0], use_cache=True + ) + stack = outputs_onto_stack([], stack, self._cond.n_in) + stack = _make_tuple(stack) + self._cond.state = s + + outputs, both_states = fastmath.cond( + cond_output, + true_func, + false_func, + [ + (stack, stack), + (self.weights[1], self.weights[2]), + (self.state[1], self.state[2]), + (rngs[1], rngs[2]), + ], + ) + stack = outputs_onto_stack([], stack, self._cond.n_in) + + # We don't know which (`true` or `false`) branch was run, but both of them + # are adding (n_out) and removing (n_in) the same number of elements of the + # stack (this was checked in __init__). outputs_onto_stack just uses the + # layer's n_in, so we can pass either `true` or `false` to it. + # Note that `outputs` is the actual output of `true` or `false` branch, + # whichever was run, and we add it to the stack in any case. 
+ stack = outputs_onto_stack(outputs, stack, self._true.n_in) + self._true.state = both_states[0] + self._false.state = both_states[1] + return _make_singleitem_or_original(stack) # pylint: disable=invalid-name def Chunk(layer, chunk_size, pass_unchunkable=True): - """Executes `layer` using batch chunks of size `chunk_size` to save memory.""" - if chunk_size < 1: - return layer - def reshape_to_chunks(x): - chunk_batch = x.shape[0] - size = chunk_size - n_chunks = chunk_batch // size - if chunk_batch % size != 0: - if pass_unchunkable: - n_chunks = 1 - size = chunk_batch - else: - raise ValueError(f'Chunk size {size} must divide batch ' - f'size {chunk_batch}') - return jnp.reshape(x, [n_chunks, size] + list(x.shape[1:])) - reshape_to_chunks_layer = base.PureLayer( - lambda xs: fastmath.nested_map(reshape_to_chunks, xs), - n_in=layer.n_in, n_out=layer.n_in, name='ReshapeToChunks') - def reshape_from_chunks(x): - batch_size = x.shape[0] * x.shape[1] - return jnp.reshape(x, [batch_size] + list(x.shape[2:])) - reshape_from_chunks_layer = base.PureLayer( - lambda xs: fastmath.nested_map(reshape_from_chunks, xs), - n_in=layer.n_out, n_out=layer.n_out, name='ReshapeFromChunks') - return Serial( - reshape_to_chunks_layer, - Scan(layer, axis=0, n_carry=0, remat=True), - reshape_from_chunks_layer, - ) - - -def Branch(*layers, name='Branch'): - """Combinator that applies a list of layers in parallel to copies of inputs. - - Each layer in the input list is applied to as many inputs from the stack - as it needs, and their outputs are successively combined on stack. - - For example, suppose one has three layers: - - - F: 1 input, 1 output - - G: 3 inputs, 1 output - - H: 2 inputs, 2 outputs (h1, h2) - - Then Branch(F, G, H) will take 3 inputs and give 4 outputs: - - - inputs: a, b, c - - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b) - - As an important special case, a None argument to Branch acts as if it takes - one argument, which it leaves unchanged. 
(It acts as a one-arg no-op.) - - Args: - *layers: List of layers. - name: Descriptive name for this layer. - - Returns: - A branch layer built from the given sublayers. - """ - if len(layers) == 1: - return layers[0] - parallel_layer = Parallel(*layers) - indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers] - return Serial(Select(_deep_flatten(indices)), parallel_layer, - name=name, sublayers_to_print=layers) + """Executes `layer` using batch chunks of size `chunk_size` to save memory.""" + if chunk_size < 1: + return layer + + def reshape_to_chunks(x): + chunk_batch = x.shape[0] + size = chunk_size + n_chunks = chunk_batch // size + if chunk_batch % size != 0: + if pass_unchunkable: + n_chunks = 1 + size = chunk_batch + else: + raise ValueError( + f"Chunk size {size} must divide batch " f"size {chunk_batch}" + ) + return jnp.reshape(x, [n_chunks, size] + list(x.shape[1:])) + + reshape_to_chunks_layer = base.PureLayer( + lambda xs: fastmath.nested_map(reshape_to_chunks, xs), + n_in=layer.n_in, + n_out=layer.n_in, + name="ReshapeToChunks", + ) + + def reshape_from_chunks(x): + batch_size = x.shape[0] * x.shape[1] + return jnp.reshape(x, [batch_size] + list(x.shape[2:])) + + reshape_from_chunks_layer = base.PureLayer( + lambda xs: fastmath.nested_map(reshape_from_chunks, xs), + n_in=layer.n_out, + n_out=layer.n_out, + name="ReshapeFromChunks", + ) + return Serial( + reshape_to_chunks_layer, + Scan(layer, axis=0, n_carry=0, remat=True), + reshape_from_chunks_layer, + ) + + +def Branch(*layers, name="Branch"): + """Combinator that applies a list of layers in parallel to copies of inputs. + + Each layer in the input list is applied to as many inputs from the stack + as it needs, and their outputs are successively combined on stack. 
+ + For example, suppose one has three layers: + + - F: 1 input, 1 output + - G: 3 inputs, 1 output + - H: 2 inputs, 2 outputs (h1, h2) + + Then Branch(F, G, H) will take 3 inputs and give 4 outputs: + + - inputs: a, b, c + - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b) + + As an important special case, a None argument to Branch acts as if it takes + one argument, which it leaves unchanged. (It acts as a one-arg no-op.) + + Args: + *layers: List of layers. + name: Descriptive name for this layer. + + Returns: + A branch layer built from the given sublayers. + """ + if len(layers) == 1: + return layers[0] + parallel_layer = Parallel(*layers) + indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers] + return Serial( + Select(_deep_flatten(indices)), + parallel_layer, + name=name, + sublayers_to_print=layers, + ) def Residual(*layers, shortcut=None): - """Wraps a series of layers with a residual connection. - - Args: - *layers: One or more layers, to be applied in series. - shortcut: If None (the usual case), the Residual layer computes the - element-wise sum of the stack-top input with the output of the layer - series. If specified, the `shortcut` layer applies to a copy of the - inputs and (elementwise) adds its output to the output from the main - layer series. - - Returns: - A layer representing a residual connection paired with a layer series. - """ - layers = _ensure_flat(layers) - layer = layers[0] if len(layers) == 1 else Serial(layers) - # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1? - return Serial( - Branch(shortcut, layer), - Add(), # pylint: disable=no-value-for-parameter - ) + """Wraps a series of layers with a residual connection. + + Args: + *layers: One or more layers, to be applied in series. + shortcut: If None (the usual case), the Residual layer computes the + element-wise sum of the stack-top input with the output of the layer + series. 
If specified, the `shortcut` layer applies to a copy of the + inputs and (elementwise) adds its output to the output from the main + layer series. + + Returns: + A layer representing a residual connection paired with a layer series. + """ + layers = _ensure_flat(layers) + layer = layers[0] if len(layers) == 1 else Serial(layers) + # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1? + return Serial( + Branch(shortcut, layer), + Add(), # pylint: disable=no-value-for-parameter + ) def Select(indices, n_in=None, name=None): - """Copies, reorders, or deletes stack elements according to `indices`. - - Args: - indices: A list or tuple of 0-based indices to select elements relative to - the top of the stack. - n_in: Number of input elements to pop from the stack, and replace with - those specified by `indices`. If not specified, its value will be - calculated as `max(indices) + 1`. - name: Descriptive name for this layer. - - Returns: - Tensors, matching the number selected (`n_out = len(indices)`). - Specifically: - - - n_out = 0: an empty tuple - - n_out = 1: one tensor (NOT wrapped in a tuple) - - n_out > 1: a tuple of tensors, with n_out items - """ - if n_in is None: - n_in = max(indices) + 1 - if name is None: - name = f'Select{indices}'.replace(' ', '') - - def select(xs): # pylint: disable=invalid-name - if not isinstance(xs, (tuple, list)): - xs = (xs,) - selected = tuple(xs[i] for i in indices) - return selected[0] if len(selected) == 1 else selected - - return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name) + """Copies, reorders, or deletes stack elements according to `indices`. + + Args: + indices: A list or tuple of 0-based indices to select elements relative to + the top of the stack. + n_in: Number of input elements to pop from the stack, and replace with + those specified by `indices`. If not specified, its value will be + calculated as `max(indices) + 1`. + name: Descriptive name for this layer. 
+ + Returns: + Tensors, matching the number selected (`n_out = len(indices)`). + Specifically: + + - n_out = 0: an empty tuple + - n_out = 1: one tensor (NOT wrapped in a tuple) + - n_out > 1: a tuple of tensors, with n_out items + """ + if n_in is None: + n_in = max(indices) + 1 + if name is None: + name = f"Select{indices}".replace(" ", "") + + def select(xs): # pylint: disable=invalid-name + if not isinstance(xs, (tuple, list)): + xs = (xs,) + selected = tuple(xs[i] for i in indices) + return selected[0] if len(selected) == 1 else selected + + return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name) def Drop(): - """Drops the top stack element.""" - return Fn('Drop', lambda x: (), n_out=0) + """Drops the top stack element.""" + return Fn("Drop", lambda x: (), n_out=0) def Dup(): - """Duplicates (copies) the top element on the data stack.""" - return Fn('Dup', lambda x: (x, x), n_out=2) + """Duplicates (copies) the top element on the data stack.""" + return Fn("Dup", lambda x: (x, x), n_out=2) def Swap(): - """Swaps the top two stack elements.""" - return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2) + """Swaps the top two stack elements.""" + return Fn("Swap", lambda x0, x1: (x1, x0), n_out=2) def SerialWithSideOutputs(layers, n_side_outputs=1): - """Serial layer with side outputs. - - This layer makes it easier to manage the stack when layers have side outputs. 
- - In the simplest case of layers with n_in=1, n_out=2 and with - n_side_outputs=1, this layer runs the following computation on x:: - - side_outputs = [] - for i in range(len(layers)): - x, side_output = layers[i](x) - side_outputs.append(side_output) - return [x] + side_outputs - - In the general case of layers with variable n_in and n_out and - n_side_outputs being a list of N integers, it does the following:: - - side_outputs = [] - for i in range(N): - res = layer[i](cur_stack) # remove n_in from stack - cur_stack.append(res[:n_side_outputs[i]]) # put back some on stack - side_outputs.extend(res[n_side_outputs:]) - return cur_stack + side_outputs - - Args: - layers: a list of layers to execute - n_side_outputs: an int or a list of ints, how many outputs of each layer - to put aside - - Returns: - A layer that performs the above computation. - """ - if isinstance(n_side_outputs, int): - n_side_outputs = [n_side_outputs] * len(layers) - - # Calculate the n_in for this layer. - running_max = 0 - running_total = 0 - for layer, n_side_output in zip(layers, n_side_outputs): - running_total += layer.n_in - running_max = max(running_max, running_total) - running_total -= layer.n_out - n_side_output - n_in = running_max - - # Create the list of layers to run serially. - cur_stack_size = n_in - serial_layers = [] - for layer, n_side_output in zip(layers, n_side_outputs): - serial_layers.append(layer) - cur_stack_size += layer.n_out - layer.n_in - # Indices to move n_side_outputs to the back of the stack. - # Don't touch first n_out - n_side_outputs. - move_back_indices = list(range(layer.n_out - n_side_output)) - # Then comes the rest of the stack that we're not moving. - move_back_indices += [i + layer.n_out - for i in range(cur_stack_size - layer.n_out)] - # Finally the indices we move. - move_back_indices += [i + layer.n_out - n_side_output - for i in range(n_side_output)] - # Swap them on stack. 
- serial_layers.append(Select(move_back_indices)) - - return Serial(serial_layers) + """Serial layer with side outputs. + + This layer makes it easier to manage the stack when layers have side outputs. + + In the simplest case of layers with n_in=1, n_out=2 and with + n_side_outputs=1, this layer runs the following computation on x:: + + side_outputs = [] + for i in range(len(layers)): + x, side_output = layers[i](x) + side_outputs.append(side_output) + return [x] + side_outputs + + In the general case of layers with variable n_in and n_out and + n_side_outputs being a list of N integers, it does the following:: + + side_outputs = [] + for i in range(N): + res = layer[i](cur_stack) # remove n_in from stack + cur_stack.append(res[:n_side_outputs[i]]) # put back some on stack + side_outputs.extend(res[n_side_outputs:]) + return cur_stack + side_outputs + + Args: + layers: a list of layers to execute + n_side_outputs: an int or a list of ints, how many outputs of each layer + to put aside + + Returns: + A layer that performs the above computation. + """ + if isinstance(n_side_outputs, int): + n_side_outputs = [n_side_outputs] * len(layers) + + # Calculate the n_in for this layer. + running_max = 0 + running_total = 0 + for layer, n_side_output in zip(layers, n_side_outputs): + running_total += layer.n_in + running_max = max(running_max, running_total) + running_total -= layer.n_out - n_side_output + n_in = running_max + + # Create the list of layers to run serially. + cur_stack_size = n_in + serial_layers = [] + for layer, n_side_output in zip(layers, n_side_outputs): + serial_layers.append(layer) + cur_stack_size += layer.n_out - layer.n_in + # Indices to move n_side_outputs to the back of the stack. + # Don't touch first n_out - n_side_outputs. + move_back_indices = list(range(layer.n_out - n_side_output)) + # Then comes the rest of the stack that we're not moving. 
+ move_back_indices += [ + i + layer.n_out for i in range(cur_stack_size - layer.n_out) + ] + # Finally the indices we move. + move_back_indices += [ + i + layer.n_out - n_side_output for i in range(n_side_output) + ] + # Swap them on stack. + serial_layers.append(Select(move_back_indices)) + + return Serial(serial_layers) def FlattenList(): - """Flatten lists.""" - # TODO(jonni): Consider renaming layer to DeepFlatten. - return Fn('FlattenList', lambda x: tuple(_deep_flatten(x))) + """Flatten lists.""" + # TODO(jonni): Consider renaming layer to DeepFlatten. + return Fn("FlattenList", lambda x: tuple(_deep_flatten(x))) def Add(): - """Adds two tensors.""" - return Fn('Add', lambda x0, x1: x0 + x1) + """Adds two tensors.""" + return Fn("Add", lambda x0, x1: jnp.add(x0, x1)) def SubtractTop(): - """Subtracts the first tensor from the second.""" - return Fn('SubtractTop', lambda x0, x1: x1 - x0) + """Subtracts the first tensor from the second.""" + return Fn("SubtractTop", lambda x0, x1: jnp.subtract(x1, x0)) def Multiply(): - """Multiplies two tensors.""" - return Fn('Multiply', lambda x0, x1: x0 * x1) + """Multiplies two tensors.""" + return Fn("Multiply", lambda x0, x1: jnp.multiply(x0, x1)) def Gate(): - """Returns a gating layer on a (memory, gate, candidate) tuple. + """Returns a gating layer on a (memory, gate, candidate) tuple. - Final update is memory * gate + (1 - gate) * candidate + Final update is memory * gate + (1 - gate) * candidate - This gating equation may also be referred to as Highway Network. - Highway Networks: https://arxiv.org/abs/1505.00387 - """ - return Fn('Gate', lambda m, g, c: g * m + (1.0 - g) * c) + This gating equation may also be referred to as Highway Network. 
+ Highway Networks: https://arxiv.org/abs/1505.00387 + """ + return Fn("Gate", lambda m, g, c: g * m + (1.0 - g) * c) class Cache(base.Layer): - """Applies a layer on the first run and returns the outputs on next calls.""" - - def __init__(self, layer): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - @property - def state(self): - """Returns a tuple containing this layer's state; may be empty.""" - return self._state - - @state.setter - def state(self, state): - """Recursively sets state on this layer and all sublayers.""" - if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: - return - self._state = state - self.sublayer.state = state[1] - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights, layer_state = self.sublayer.init(input_signature, use_cache=True) - self.state = ((), layer_state) - self._weights = (weights,) - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model. 
+ """Applies a layer on the first run and returns the outputs on next calls.""" + + def __init__(self, layer): + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + @property + def state(self): + """Returns a tuple containing this layer's state; may be empty.""" + return self._state + + @state.setter + def state(self, state): + """Recursively sets state on this layer and all sublayers.""" + if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: + return + self._state = state + self.sublayer.state = state[1] + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + weights, layer_state = self.sublayer.init(input_signature, use_cache=True) + self.state = ((), layer_state) + self._weights = (weights,) + + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model. + + Args: + inputs: Tensors required by the sublayer. + + Returns: + Tensors resulting from running the sublayer the first time. + """ + state, weights = self.state, self.weights[0] + if state[0] is (): # pylint: disable=literal-comparison + res, layer_state = self.sublayer.pure_fn( + inputs, weights, state[1], self.rng + ) + self.state = (res, layer_state) + return res + else: + return state[0] - Args: - inputs: Tensors required by the sublayer. - Returns: - Tensors resulting from running the sublayer the first time. - """ - state, weights = self.state, self.weights[0] - if state[0] is (): # pylint: disable=literal-comparison - res, layer_state = self.sublayer.pure_fn( - inputs, weights, state[1], self.rng) - self.state = (res, layer_state) - return res - else: - return state[0] +class BatchLeadingAxes(base.Layer): + """Applies a layer after flattening all but n_last_axes_to_keep to batch. 
+ This can be used to make layers accept an arbitrary number of leading + axes (dimensions) as batch. For example, a Convolution layer may normally + only operate on tensors of shape [B, W, H, C]. In this case, the layer -class BatchLeadingAxes(base.Layer): - """Applies a layer after flattening all but n_last_axes_to_keep to batch. - - This can be used to make layers accept an arbitrary number of leading - axes (dimensions) as batch. For example, a Convolution layer may normally - only operate on tensors of shape [B, W, H, C]. In this case, the layer - - BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3) - - will operate on any tensor [..., W, H, C] and treat the leading axes as batch. - """ - - def __init__(self, layer, n_last_axes_to_keep=1): - if layer.n_out != 1: - raise ValueError('BatchLeadingAxes currently only works for layers with ' - f'n_out = 1, got {layer.n_out}.') - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_last_axes_to_keep = n_last_axes_to_keep - self._weights = (None,) - self._state = (None,) - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - if self._n_in == 1: - inputs = [inputs] - new_inputs = [] - for old_input in inputs: - batched_axes_shape = list(old_input.shape[:-self._n_last_axes_to_keep]) - batched_shape = [-1] + list(old_input.shape[-self._n_last_axes_to_keep:]) - new_inputs.append(jnp.reshape(old_input, batched_shape)) - new_inputs = tuple(new_inputs) - if self._n_in == 1: - new_inputs = new_inputs[0] - res, layer_state = self.sublayer.pure_fn( - new_inputs, self.weights[0], self.state[0], self.rng) - self.state = (layer_state,) - return jnp.reshape(res, batched_axes_shape + list(res.shape[1:])) - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given 
signature.""" - if self._n_in == 1 and not isinstance(input_signature, (list, tuple)): - input_signature = (input_signature,) - batched_signature = [] - for sub_input_signature in input_signature: - batched_size = 1 - for d in sub_input_signature.shape[:-self._n_last_axes_to_keep]: - batched_size *= d - batched_shape = [batched_size] + list( - sub_input_signature.shape[-self._n_last_axes_to_keep:]) - batched_signature.append(ShapeDtype(batched_shape, - sub_input_signature.dtype)) - if self._n_in == 1: - batched_signature = batched_signature[0] - weights, layer_state = self.sublayer.init(batched_signature, use_cache=True) - self.state = (layer_state,) - self.weights = (weights,) + BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3) + + will operate on any tensor [..., W, H, C] and treat the leading axes as batch. + """ + + def __init__(self, layer, n_last_axes_to_keep=1): + if layer.n_out != 1: + raise ValueError( + "BatchLeadingAxes currently only works for layers with " + f"n_out = 1, got {layer.n_out}." 
+ ) + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + self._n_last_axes_to_keep = n_last_axes_to_keep + self._weights = (None,) + self._state = (None,) + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model.""" + if self._n_in == 1: + inputs = [inputs] + new_inputs = [] + for old_input in inputs: + batched_axes_shape = list(old_input.shape[: -self._n_last_axes_to_keep]) + batched_shape = [-1] + list(old_input.shape[-self._n_last_axes_to_keep :]) + new_inputs.append(jnp.reshape(old_input, batched_shape)) + new_inputs = tuple(new_inputs) + if self._n_in == 1: + new_inputs = new_inputs[0] + res, layer_state = self.sublayer.pure_fn( + new_inputs, self.weights[0], self.state[0], self.rng + ) + self.state = (layer_state,) + return jnp.reshape(res, batched_axes_shape + list(res.shape[1:])) + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + if self._n_in == 1 and not isinstance(input_signature, (list, tuple)): + input_signature = (input_signature,) + batched_signature = [] + for sub_input_signature in input_signature: + batched_size = 1 + for d in sub_input_signature.shape[: -self._n_last_axes_to_keep]: + batched_size *= d + batched_shape = [batched_size] + list( + sub_input_signature.shape[-self._n_last_axes_to_keep :] + ) + batched_signature.append( + ShapeDtype(batched_shape, sub_input_signature.dtype) + ) + if self._n_in == 1: + batched_signature = batched_signature[0] + weights, layer_state = self.sublayer.init(batched_signature, use_cache=True) + self.state = (layer_state,) + self.weights = (weights,) def Bidirectional(forward_layer, axis=1, merge_layer=Concatenate()): - """Bidirectional combinator for RNNs. 
- - Args: - forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`. - axis: a time axis of the inputs. Default value is `1`. - merge_layer: A combinator used to combine outputs of the forward - and backward RNNs. Default value is 'trax.layers.Concatenate'. - - Example: - Bidirectional(RNN(n_units=8)) - - Returns: - The Bidirectional combinator for RNNs. - """ - backward_layer = copy.deepcopy(forward_layer) - flip = base.Fn('_FlipAlongTimeAxis', lambda x: jnp.flip(x, axis=axis)) - backward = Serial( - flip, - backward_layer, - flip, - ) - - return Serial( - Branch(forward_layer, backward), - merge_layer, - ) + """Bidirectional combinator for RNNs. + + Args: + forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`. + axis: a time axis of the inputs. Default value is `1`. + merge_layer: A combinator used to combine outputs of the forward + and backward RNNs. Default value is 'trax.layers.Concatenate'. + + Example: + Bidirectional(RNN(n_units=8)) + + Returns: + The Bidirectional combinator for RNNs. + """ + backward_layer = copy.deepcopy(forward_layer) + flip = base.Fn("_FlipAlongTimeAxis", lambda x: jnp.flip(x, axis=axis)) + backward = Serial( + flip, + backward_layer, + flip, + ) + + return Serial( + Branch(forward_layer, backward), + merge_layer, + ) # All module-private helper functions are below. @@ -1013,99 +1105,103 @@ def Bidirectional(forward_layer, axis=1, merge_layer=Concatenate()): def _deep_flatten(items): - """Returns a list of objects, flattening sublists/subtuples along the way. + """Returns a list of objects, flattening sublists/subtuples along the way. - Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return - the list [1, 2, 3, 4, 5, 6, 7, 8]. + Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return + the list [1, 2, 3, 4, 5, 6, 7, 8]. - Args: - items: An iterable. 
If elements of this iterable are lists or tuples, they - will be (recursively) flattened until non-list non-tuple objects are - reached. + Args: + items: An iterable. If elements of this iterable are lists or tuples, they + will be (recursively) flattened until non-list non-tuple objects are + reached. + + Returns: + A list of non-list, non-tuple objects. + """ - Returns: - A list of non-list, non-tuple objects. - """ - def _flat_gen(xs): - for x in xs: - if isinstance(x, (list, tuple)): - for y in _flat_gen(x): - yield y - else: - yield x - return list(_flat_gen(items)) + def _flat_gen(xs): + for x in xs: + if isinstance(x, (list, tuple)): + for y in _flat_gen(x): + yield y + else: + yield x + + return list(_flat_gen(items)) def _ensure_sublayers(layers): - """Ensures that elements in a layer list are layers. - - Args: - layers: A tuple or list whose elements can each be a layer, tuple, or list, - and so on recursively. - - Returns: - An analogous collection of layers in which embedded layer lists are - wrapped in Serial layer instances. - """ - if not layers: # None or an empty list can signal a no-op. - return Serial(None) # no-op, but still handles shapes and initialization - elif isinstance(layers, (list, tuple)): - sublayers_not_lists = [] - for layer in layers: - sublayers_not_lists.append( - Serial(layer) if isinstance(layer, (list, tuple)) else layer) - return sublayers_not_lists - else: - raise TypeError(type(layers)) + """Ensures that elements in a layer list are layers. + + Args: + layers: A tuple or list whose elements can each be a layer, tuple, or list, + and so on recursively. + + Returns: + An analogous collection of layers in which embedded layer lists are + wrapped in Serial layer instances. + """ + if not layers: # None or an empty list can signal a no-op. 
+ return Serial(None) # no-op, but still handles shapes and initialization + elif isinstance(layers, (list, tuple)): + sublayers_not_lists = [] + for layer in layers: + sublayers_not_lists.append( + Serial(layer) if isinstance(layer, (list, tuple)) else layer + ) + return sublayers_not_lists + else: + raise TypeError(type(layers)) def _split_rngs(rng, n_copies): - if rng is None: - return (None,) * n_copies - return fastmath.random.split(rng, n_copies) + if rng is None: + return (None,) * n_copies + return fastmath.random.split(rng, n_copies) def inputs_from_stack(stack, n): - """Returns n inputs from stack.""" - stack = _make_tuple(stack) - return _make_singleitem_or_original(stack[:n]) + """Returns n inputs from stack.""" + stack = _make_tuple(stack) + return _make_singleitem_or_original(stack[:n]) def outputs_onto_stack(outputs, stack, n): - """"Returns the new stack after removing n items and pushing outputs there.""" - outputs = _make_tuple(outputs) - stack = _make_tuple(stack) - return _make_singleitem_or_original(outputs + stack[n:]) + """ "Returns the new stack after removing n items and pushing outputs there.""" + outputs = _make_tuple(outputs) + stack = _make_tuple(stack) + return _make_singleitem_or_original(outputs + stack[n:]) def _make_tuple(xs): - """Returns a tuple from a list, a tuple, or a single element.""" - if isinstance(xs, (list, tuple)): - return tuple(xs) - else: - return (xs,) + """Returns a tuple from a list, a tuple, or a single element.""" + if isinstance(xs, (list, tuple)): + return tuple(xs) + else: + return (xs,) def _make_singleitem_or_original(xs): - """Returns a single element if possible, or the original list/tuple if not.""" - if isinstance(xs, (list, tuple)) and len(xs) == 1: - return xs[0] - else: - return xs + """Returns a single element if possible, or the original list/tuple if not.""" + if isinstance(xs, (list, tuple)) and len(xs) == 1: + return xs[0] + else: + return xs def _shape_without_axis(x, axis): - return 
x.shape[:axis] + x.shape[axis + 1:] + return x.shape[:axis] + x.shape[axis + 1 :] def _ensure_flat(layers): - """Ensures that layers is a single flat list of Layer instances.""" - if len(layers) == 1 and layers[0] is None: - layers = () - else: - layers = _deep_flatten(layers) - for obj in layers: - if not isinstance(obj, base.Layer): - raise ValueError( - f'Found nonlayer object ({obj}) in layers: {layers}') - return layers + """Ensures that layers is a single flat list of Layer instances.""" + if len(layers) == 1 and layers[0] is None: + layers = () + else: + layers = _deep_flatten(layers) + for obj in layers: + if not isinstance(obj, base.Layer): + raise ValueError( + f"Found non-layer object ({obj}) type ({type(obj)}) ({type(obj)}) in layers: {layers}" + ) + return layers diff --git a/trax/layers/combinators_test.py b/trax/layers/combinators_test.py deleted file mode 100644 index 4f6ba40b8..000000000 --- a/trax/layers/combinators_test.py +++ /dev/null @@ -1,802 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for combinator layers.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -def DivideBy(val): # pylint: disable=invalid-name - """Returns a simple division layer with n_in == 1 and n_out == 1.""" - return tl.Fn('DivideBy', lambda x: x / val) - - -def ReturnConst(val): # pylint: disable=invalid-name - """Returns a simple const layer with n_in == 0 and n_out == 1.""" - return tl.Fn('ReturnConst', lambda: val) - - -def SmallerThan(val): # pylint: disable=invalid-name - """Checks if the input is smaller than certain value.""" - return tl.Fn('SmallerThan', lambda x: x < val) - - -# TODO(jonni): Consider a more generic home for this utiliity function. -def as_list(outputs): - """Converts layer outputs to a nested list, for easier equality testing. - - Args: - outputs: A tensor or tuple/list of tensors coming from the forward - application of a layer. Each tensor is NumPy ndarray-like, which - complicates simple equality testing (e.g., via `assertEquals`): - such tensors require equality testing to use either `all` (all - elements match) or `any` (at least one element matches), which is not - directly supported in absltest. - - Returns: - A nested list structure containing all the output values, but now directly - testable using `assertEquals`. 
- """ - if isinstance(outputs, (list, tuple)): - return [as_list(y) for y in outputs] - else: - return outputs.tolist() - - -class SerialTest(absltest.TestCase): - - def test_none_is_no_op(self): - layer = tl.Serial(None) - xs = [np.array([1, 2, 3, 4]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 4], - [10, 20, 30]]) - - def test_empty_list_is_no_op(self): - layer = tl.Serial([]) - xs = [np.array([1, 2, 3, 4]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 4], - [10, 20, 30]]) - - def test_one_in_one_out(self): - layer = tl.Serial(DivideBy(3)) - x = np.array([3, 6, 9, 12]) - y = layer(x) - self.assertEqual(as_list(y), [1, 2, 3, 4]) - - def test_zero_in_one_out(self): - layer = tl.Serial(ReturnConst(np.array([3, 4, 5, 6]))) - y = layer(()) - self.assertEqual(as_list(y), [3, 4, 5, 6]) - - def test_one_in_two_out(self): - layer = tl.Serial(DivideBy(3), - ReturnConst(np.array([3, 4, 5, 6]))) - x = np.array([3, 6, 9, 12]) - y = layer(x) - self.assertEqual(as_list(y), [[3, 4, 5, 6], - [1, 2, 3, 4]]) - - def test_const_div(self): - layer = tl.Serial(ReturnConst(np.array([3, 6, 9, 12])), - DivideBy(3)) - y = layer(()) - self.assertEqual(as_list(y), [1, 2, 3, 4]) - - def test_div_div(self): - layer = tl.Serial(DivideBy(2.0), DivideBy(5.0)) - x = np.array([10, 20, 30]) - y = layer(x) - self.assertEqual(as_list(y), [1, 2, 3]) - - def test_dup_dup(self): - layer = tl.Serial(tl.Dup(), tl.Dup()) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [1, 2, 3]]) - - def test_default_name(self): - layer = tl.Serial(tl.Dup(), tl.Dup()) - self.assertIn('Serial', str(layer)) - - def test_custom_name(self): - layer = tl.Serial(tl.Dup(), tl.Dup(), name='Branch') - self.assertIn('Branch', str(layer)) - - def test_weights(self): - model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) - self.assertIsInstance(model.weights, tuple) - 
self.assertLen(model.weights, 3) - - def test_flat_weights_and_state(self): - model = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup())) - sample_input_signature = shapes.signature(np.zeros((2, 3))) - model.init(sample_input_signature) - flat_weights, flat_state = tl.flatten_weights_and_state( - model.weights, model.state) - # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. - # So after making them flat, there are 4 trainable weights. - self.assertLen(flat_weights, 4) - self.assertEmpty(flat_state) - model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7)) - sig = model2.weights_and_state_signature(sample_input_signature) - weights2, state2 = tl.unflatten_weights_and_state( - flat_weights, flat_state, sig) - model2.weights = weights2 - model2.state = state2 - self.assertLen(model2.weights, 3) - self.assertEqual(model.weights[1], model2.weights[0]) - self.assertEqual(model.weights[2][0], model2.weights[2]) - - def test_flat_weights_and_state_shared(self): - shared = tl.Dense(5) - model = tl.Serial(tl.Dense(5), shared, tl.Serial(shared, tl.Dup())) - sample_input_signature = shapes.signature(np.zeros((2, 3))) - model.init(sample_input_signature) - flat_weights, flat_state = tl.flatten_weights_and_state( - model.weights, model.state) - # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. - # So after making them flat, there are 4 trainable weights. 
- self.assertLen(flat_weights, 4) - self.assertEmpty(flat_state) - model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(5)) - sig = model2.weights_and_state_signature(sample_input_signature) - weights2, state2 = tl.unflatten_weights_and_state( - flat_weights, flat_state, sig) - model2.weights = weights2 - model2.state = state2 - self.assertLen(model2.weights, 3) - self.assertEqual(model.weights[0], model2.weights[0]) - self.assertEqual(model.weights[1], model2.weights[2]) - - def test_assign_sublayer_weights(self): - layer = tl.Dense(5, use_bias=False) - model = tl.Serial(tl.Serial(layer, tl.Dense(6)), tl.Dense(7)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - new_layer_weights = np.random.uniform(weights[0][0].shape) - layer.weights = new_layer_weights - self.assertIs(model.weights[0][0], new_layer_weights) - - def test_shared_weights(self): - layer = tl.Dense(5) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_nested(self): - layer = tl.Dense(5) - model = tl.Serial(layer, tl.Serial(layer)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_double_nested(self): - layer = tl.Dense(5) - model = tl.Serial(tl.Serial(layer), tl.Serial(layer)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_for_shared_serial(self): - layer = tl.Serial(tl.Dense(5), tl.Dense(5)) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - # Init gives weights reflecting weight sharing. 
- weights, _ = model.init(shapes.signature(sample_input)) - self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - # Forward pass runs successfully. - y = model(sample_input) - self.assertEqual(y.shape, (5,)) - - def test_state(self): - model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) - self.assertIsInstance(model.state, tuple) - self.assertLen(model.state, 3) - - def test_set_rng_recurse_two_levels(self): - dense_00 = tl.Dense(2) - dense_01 = tl.Dense(2) - dense_10 = tl.Dense(2) - dense_11 = tl.Dense(2) - layer = tl.Serial( - tl.Serial(dense_00, dense_01), - tl.Serial(dense_10, dense_11), - ) - input_signature = shapes.ShapeDtype((1, 2)) - - _, _ = layer.init(input_signature) - weights = layer.weights - dense_00_w, dense_00_b = weights[0][0] - dense_01_w, dense_01_b = weights[0][1] - dense_10_w, dense_10_b = weights[1][0] - dense_11_w, dense_11_b = weights[1][1] - - # Setting rng's recursively during init should yield differing weights. 
- self.assertFalse(np.array_equal(dense_00_w, dense_01_w)) - self.assertFalse(np.array_equal(dense_00_b, dense_01_b)) - self.assertFalse(np.array_equal(dense_10_w, dense_11_w)) - self.assertFalse(np.array_equal(dense_10_b, dense_11_b)) - - -class ParallelTest(absltest.TestCase): - - def test_dup_dup(self): - layer = tl.Parallel(tl.Dup(), tl.Dup()) - xs = [np.array([1, 2, 3]), - np.array([10, 20])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [10, 20], - [10, 20]]) - - def test_div_div(self): - layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0)) - xs = [np.array([1, 2, 3]), - np.array([30, 60])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[2, 4, 6], - [10, 20]]) - - def test_two_no_ops(self): - layer = tl.Parallel([], None) - xs = [np.array([1, 2, 3]), - np.array([10, 20])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3], - [10, 20]]) - - def test_default_name(self): - layer = tl.Parallel(tl.Dup(), tl.Dup()) - self.assertIn('Parallel', str(layer)) - - def test_custom_name(self): - layer = tl.Parallel(tl.Dup(), tl.Dup(), name='DupDup') - self.assertIn('DupDup', str(layer)) - - def test_weights(self): - model = tl.Parallel(tl.Dense(3), tl.Dense(5)) - self.assertIsInstance(model.weights, tuple) - self.assertLen(model.weights, 2) - - def test_shared_weights(self): - layer = tl.Dense(5) - model = tl.Parallel(layer, layer) - sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_nested(self): - layer = tl.Dense(5) - model = tl.Parallel([layer, tl.Dense(2)], - [layer, tl.Dense(2)]) - sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_for_shared_parallel(self): - layer = tl.Parallel(tl.Dense(5), tl.Dense(7)) - 
model = tl.Parallel(layer, layer) - sample_input = [ - np.array([1, 2, 3]), - np.array([10, 20, 30]), - np.array([100, 200, 300]), - np.array([1000, 2000, 3000]), - ] - # Init gives weights reflecting weight sharing. - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - # Forward pass runs successfully. - y0, y1, y2, y3 = model(sample_input) - self.assertEqual(y0.shape, (5,)) - self.assertEqual(y1.shape, (7,)) - self.assertEqual(y2.shape, (5,)) - self.assertEqual(y3.shape, (7,)) - - def test_state(self): - model = tl.Parallel(tl.Dense(3), tl.Dense(5)) - self.assertIsInstance(model.state, tuple) - self.assertLen(model.state, 2) - - -class ConcatenateTest(absltest.TestCase): - - def test_n_in_n_out(self): - layer = tl.Concatenate() - self.assertEqual(layer.n_in, 2) - self.assertEqual(layer.n_out, 1) - - def test_with_defaults(self): - layer = tl.Concatenate() # Default n_items=2, axis=-1 - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 10, 20, 30], - [4, 5, 6, 40, 50, 60]]) - - def test_axis_0(self): - layer = tl.Concatenate(axis=0) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - y = layer(xs) - self.assertEqual(as_list(y), [[1, 2, 3], - [4, 5, 6], - [10, 20, 30], - [40, 50, 60]]) - - def test_axis_1(self): - layer = tl.Concatenate(axis=1) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - y = layer(xs) - self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30], - [4, 5, 6, 40, 50, 60]]) - - def test_n_items_is_not_default(self): - layer = tl.Concatenate(n_items=3) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]]), - np.array([[100, 200, 300], - [400, 500, 600]])] - y = layer(xs) - self.assertEqual(y.shape, (2, 9)) - 
self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30, 100, 200, 300], - [4, 5, 6, 40, 50, 60, 400, 500, 600]]) - - def test_repr(self): - layer = tl.Concatenate() - self.assertEqual(repr(layer), 'Concatenate_in2') - - layer = tl.Concatenate(axis=0) - self.assertEqual(repr(layer), 'Concatenate_axis0_in2') - - layer = tl.Concatenate(axis=1) - self.assertEqual(repr(layer), 'Concatenate_axis1_in2') - - layer = tl.Concatenate(n_items=3) - self.assertEqual(repr(layer), 'Concatenate_in3') - - -class BranchTest(absltest.TestCase): - - def test_noop_dup(self): - layer = tl.Branch([], tl.Dup()) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [1, 2, 3]]) - - def test_add_div(self): - layer = tl.Branch(tl.Add(), DivideBy(0.5)) - xs = [np.array([1, 2, 3]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[11, 22, 33], - [2, 4, 6]]) - - def test_one_sublayer(self): - layer = tl.Branch(DivideBy(0.5)) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [2, 4, 6]) - - def test_default_name(self): - layer = tl.Branch(tl.Add(), DivideBy(0.5)) - self.assertIn('Branch', str(layer)) - - def test_printing_sublayers(self): - layer = tl.Branch(tl.Add(), tl.Add()) - expected_result = 'Branch_in2_out2[\n Add_in2\n Add_in2\n]' - self.assertEqual(expected_result, str(layer)) - - -class SelectTest(absltest.TestCase): - - def test_computes_n_in(self): - layer = tl.Select([0, 0]) - self.assertEqual(layer.n_in, 1) - - layer = tl.Select([1, 0]) - self.assertEqual(layer.n_in, 2) - - layer = tl.Select([2]) - self.assertEqual(layer.n_in, 3) - - def test_given_n_in(self): - layer = tl.Select([0], n_in=2) - self.assertEqual(layer.n_in, 2) - - layer = tl.Select([0], n_in=3) - self.assertEqual(layer.n_in, 3) - - def test_first_of_3(self): - layer = tl.Select([0], n_in=3) - xs = [np.array([1, 2, 3]), - np.array([10, 20]), - np.array([100])] - y = layer(xs) - self.assertEqual(as_list(y), [1, 2, 3]) - - def 
test_second_of_3(self): - layer = tl.Select([1], n_in=3) - xs = [np.array([1, 2, 3]), - np.array([10, 20]), - np.array([100])] - y = layer(xs) - self.assertEqual(as_list(y), [10, 20]) - - -class DropTest(absltest.TestCase): - - def test_drop(self): - layer = tl.Drop() - x = np.array([1, 2, 3]) - y = layer(x) - self.assertEqual(as_list(y), []) - - -class SwapTest(absltest.TestCase): - - def test_swap(self): - layer = tl.Swap() - xs = [np.array([1, 2, 3]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[10, 20, 30], - [1, 2, 3]]) - - -class ChunkTest(absltest.TestCase): - - def test_chunk(self): - layer = tl.Dense(4) - x = np.array([[1, 2, 3], [4, 5, 6]]) - layer.init(x) - y = layer(x) - z = tl.Chunk(layer, 1)(x) - self.assertLess(np.sum((y - z)**2), 1e-5) # y == z upto numerics - - def test_chunk_uneven_numbers(self): - layer = tl.Dense(4) - x = np.array([[1, 2, 3], [4, 5, 6]]) - layer.init(x) - y = layer(x) - z = tl.Chunk(layer, 3)(x) # By default it should just pass - self.assertLess(np.sum((y - z)**2), 1e-5) # y == z upto numerics - chunk_with_test = tl.Chunk(layer, 3, pass_unchunkable=False) - self.assertRaises(tl.LayerError, lambda: chunk_with_test(x)) - - -class SerialWithSideOutputsTest(absltest.TestCase): - - def test_serial_with_side_outputs_div_div(self): - def some_layer(): - return tl.Parallel(DivideBy(2.0), DivideBy(5.0)) - layer = tl.SerialWithSideOutputs([some_layer(), some_layer()]) - xs = (np.array([1, 2, 3]), - np.array([10, 20, 30, 40, 50]), - np.array([100, 200])) - ys = layer(xs) - output_shapes = [y.shape for y in ys] - self.assertEqual(output_shapes, [(3,), (5,), (2,)]) - - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class ScanTest(parameterized.TestCase): - - def _AddWithCarry(self): # pylint: disable=invalid-name - del self - def f(x, carry): - res = x + carry - return res, res # output and carry are the same - return tl.Fn('AddWithCarry', f, 
n_out=2) - - def test_default_axis(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry()) - xs = [ - np.array([[0, 1, 2, 3], - [0, 10, 20, 30], - [0, 100, 200, 300]]), - np.array([9000, 8000, 7000, 6000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[9000, 8001, 7002, 6003], - [9000, 8011, 7022, 6033], - [9000, 8111, 7222, 6333] - ], - [9000, 8111, 7222, 6333] - ]) - - def test_axis_1(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry(), axis=1) - xs = [ - np.array([[0, 1, 2, 3], - [0, 10, 20, 30], - [0, 100, 200, 300]]), - np.array([9000, - 8000, - 7000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[9000, 9001, 9003, 9006], - [8000, 8010, 8030, 8060], - [7000, 7100, 7300, 7600] - ], - [9006, - 8060, - 7600] - ]) - - def test_predict(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry(), axis=1, mode='predict') - xs = [np.array([[0, 1, 2]]), - np.array([90])] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[90, 91, 93]], - [93]]) - xs = [np.array([[3, 4]]), - np.array([90])] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[96, 100]], - [100]]) - - def test_multi_input(self, backend): - def _MultiInputFn(): # pylint: disable=invalid-name - def f(a, b, carry): - return a + b, b, carry + 1 - return tl.Fn('MultiInputFn', f, n_out=2) - - with fastmath.use_backend(backend): - layer = tl.Scan(_MultiInputFn(), axis=1) - xs = [ - np.array([[0, 1, 2], - [0, 10, 20]]), - np.array([[4, 5, 6], - [40, 50, 60]]), - np.array([9000, - 8000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[4, 6, 8], - [40, 60, 80]], - [[4, 5, 6], - [40, 50, 60]], - [9003, - 8003] - ]) - - def test_no_carry(self, backend): - def _AddOne(): # pylint: disable=invalid-name - return tl.Fn('AddOne', lambda x: x + 1) - - with fastmath.use_backend(backend): - layer = tl.Scan(_AddOne(), n_carry=0) - x = np.array([[1, 3, 7], - [10, 30, 70]]) - y = 
layer(x) - self.assertEqual(as_list(y), [[2, 4, 8], - [11, 31, 71]]) - - -class CondTest(absltest.TestCase): - - def test_basic_true(self): - cond = ReturnConst(True) - true = ReturnConst([2]) - false = ReturnConst([5]) - layer = tl.Cond(cond, true, false) - layer.init(()) - xs = tuple() - ys = layer(xs) - self.assertEqual(as_list(ys), 2) - - def test_basic_false(self): - cond = ReturnConst(False) - true = ReturnConst([2]) - false = ReturnConst([5]) - layer = tl.Cond(cond, true, false) - layer.init(()) - xs = tuple() - ys = layer(xs) - self.assertEqual(as_list(ys), 5) - - def test_complex_blocks(self): - cond = ReturnConst(True) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = [np.arange(5).astype(np.float32)] - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [0., 0.5, 1.0, 1.5, 2.0]) - - def test_condition_func_true(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = (np.array(2.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [2., 6.]) - - def test_condition_func_false(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = (np.array(4.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [1., 3.]) - - def test_condition_func_default_false(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - layer = tl.Cond(cond, true) - xs = (np.array(4.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [4., 12.]) - - def test_exception_n_out(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = tl.Dup() - self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) - - def test_exception_n_in(self): - cond = SmallerThan(3.0) - true = ReturnConst(2.) - false = DivideBy(2.) 
- self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) - - def test_exception_run1(self): - # We expect exactly one input. - cond = SmallerThan(3.0) - true = ReturnConst(2.) - false = ReturnConst(5.) - def init_and_run(layer, xs): - layer.init(shapes.signature(xs)) - layer(xs) - # It will pass with one input. - xs = np.array(4.) - layer = tl.Cond(cond, true, false) - init_and_run(layer, xs) - # It will fail with zero or two inputs. - for xs in ((), (np.array(4.), np.array([4., 12.]))): - layer = tl.Cond(cond, true, false) - # pylint: disable=cell-var-from-loop - self.assertRaises(Exception, lambda: init_and_run(layer, xs)) - - def test_exception_run2(self): - # We expect exactly two inputs. - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(5.) - def init_and_run(layer, xs): - layer.init(shapes.signature(xs)) - layer(xs) - # It will pass with two inputs. - xs = (np.array(4.), np.array([4., 12.])) - layer = tl.Cond(cond, true, false) - init_and_run(layer, xs) - # It will fail with zero or one input. 
- for xs in ((), (np.array(4.))): - # pylint: disable=cell-var-from-loop - self.assertRaises(Exception, lambda: init_and_run(layer, xs)) - - def test_weights_and_state(self): - cond = SmallerThan(3.0) - true = tl.Dense(5) - false = tl.Dense(5) - different = tl.Dense(5) - layer = tl.Cond(cond, true, false) - xs = (np.array(2.), np.array([0., 1., 2.])) - layer.init(shapes.signature(xs)) - - # weights - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, false.weights))) - self.assertNotEqual(as_list(true.weights), as_list(false.weights)) - self.assertNotEqual(as_list(true.weights), as_list(different.weights)) - - false.weights = true.weights - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, true.weights))) - - layer.weights = (cond.weights, true.weights, different.weights) - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, different.weights))) - # state - self.assertEqual(as_list(layer.state), - as_list((cond.state, true.state, false.state))) - # just check if simple assignments (setter from base.Layer) work correctly - # with Cond.init_weights_and_state ; all states are empty so there is no - # point in checking equality - false.state = true.state - layer.state = (cond.state, true.state, different.state) - - -class BatchLeadingAxesTest(absltest.TestCase): - - def _Id3Dim(self): # pylint: disable=invalid-name - del self - def f(x): - assert len(x.shape) == 3 - return x - return tl.Fn('Id3Dim', f, n_out=1) - - def test_2axes(self): - layer = tl.BatchLeadingAxes(self._Id3Dim(), n_last_axes_to_keep=2) - ys = layer(np.zeros((3, 4, 5))) - self.assertEqual(ys.shape, (3, 4, 5)) - ys = layer(np.zeros((2, 3, 4, 5))) - self.assertEqual(ys.shape, (2, 3, 4, 5)) - ys = layer(np.zeros((1, 2, 3, 4, 5))) - self.assertEqual(ys.shape, (1, 2, 3, 4, 5)) - - -class BidirectionalTest(absltest.TestCase): - - def test_dimensionality(self): - x = np.ones((2, 3, 8)) - layer = 
tl.Bidirectional(tl.GRU(n_units=8)) - input_signature = shapes.signature(x) - _, _ = layer.init(input_signature) - yhat = layer(x) - - self.assertEqual(yhat.shape, (2, 3, 8 + 8)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/convolution.py b/trax/layers/convolution.py index d0658f679..54e44ff8b 100644 --- a/trax/layers/convolution.py +++ b/trax/layers/convolution.py @@ -26,167 +26,193 @@ class Conv(base.Layer): - """Layer constructor function for a general convolution layer.""" - - def __init__(self, filters, kernel_size, strides=None, padding='VALID', - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - super().__init__() - self._filters = filters - self._kernel_size = kernel_size - self._padding = padding - self._dimension_numbers = dimension_numbers - self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers - self._one = (1,) * len(kernel_size) - self._strides = strides or self._one - self._bias_initializer = bias_initializer - self._use_bias = use_bias - rhs_spec = self._rhs_spec - self._kernel_initializer = kernel_initializer - if kernel_initializer is None: - self._kernel_initializer = init.GlorotNormalInitializer( - rhs_spec.index('O'), rhs_spec.index('I')) - - def _check_nhwc(self): - msg = 'Convolutions on more than 4 dimensions only supported in NHWC.' 
- assert self._lhs_spec == self._out_spec == 'NHWC', msg - - def forward(self, x): - if self._use_bias: - w, b = self.weights - else: - w = self.weights - x_shape = list(x.shape) - if len(x_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, x_shape[:-3]) - x = jnp.reshape(x, [new_batch_dim] + x_shape[-3:]) - res = fastmath.conv( - x, w, self._strides, self._padding, self._dimension_numbers, - self._one) - if self._use_bias: - res = res + b - if len(x_shape) > 4: - res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) - return res - - def _kernel_shape(self, input_shape): - """Helper to calculate the kernel shape.""" - kernel_size_iter = iter(self._kernel_size) - return [self._filters if c == 'O' else - input_shape[self._lhs_spec.index('C')] if c == 'I' else - next(kernel_size_iter) for c in self._rhs_spec] - - def init_weights_and_state(self, input_signature): - input_shape = input_signature.shape - if len(input_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) - input_shape = [new_batch_dim] + list(input_shape[-3:]) - kernel_shape = self._kernel_shape(input_shape) - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(kernel_shape, rng1) - if self._use_bias: - bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec] - bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) - b = self._bias_initializer(bias_shape, rng2) - self.weights = (w, b) - else: - self.weights = w + """Layer constructor function for a general convolution layer.""" + + def __init__( + self, + filters, + kernel_size, + strides=None, + padding="VALID", + dimension_numbers=("NHWC", "HWIO", "NHWC"), + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + ): + super().__init__() + self._filters = filters + self._kernel_size = kernel_size + self._padding = padding + self._dimension_numbers = dimension_numbers + 
self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers + self._one = (1,) * len(kernel_size) + self._strides = strides or self._one + self._bias_initializer = bias_initializer + self._use_bias = use_bias + rhs_spec = self._rhs_spec + self._kernel_initializer = kernel_initializer + if kernel_initializer is None: + self._kernel_initializer = init.GlorotNormalInitializer( + rhs_spec.index("O"), rhs_spec.index("I") + ) + + def _check_nhwc(self): + msg = "Convolutions on more than 4 dimensions only supported in NHWC." + assert self._lhs_spec == self._out_spec == "NHWC", msg + + def forward(self, x): + if self._use_bias: + w, b = self.weights + else: + w = self.weights + x_shape = list(x.shape) + if len(x_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, x_shape[:-3]) + x = jnp.reshape(x, [new_batch_dim] + x_shape[-3:]) + res = fastmath.conv( + x, w, self._strides, self._padding, self._dimension_numbers, self._one + ) + if self._use_bias: + res = res + b + if len(x_shape) > 4: + res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) + return res + + def _kernel_shape(self, input_shape): + """Helper to calculate the kernel shape.""" + kernel_size_iter = iter(self._kernel_size) + return [ + self._filters + if c == "O" + else input_shape[self._lhs_spec.index("C")] + if c == "I" + else next(kernel_size_iter) + for c in self._rhs_spec + ] + + def init_weights_and_state(self, input_signature): + input_shape = input_signature.shape + if len(input_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) + input_shape = [new_batch_dim] + list(input_shape[-3:]) + kernel_shape = self._kernel_shape(input_shape) + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(kernel_shape, rng1) + if self._use_bias: + bias_shape = [self._filters if c == "C" else 1 for c in self._out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + b = 
self._bias_initializer(bias_shape, rng2) + self.weights = (w, b) + else: + self.weights = w class CausalConv(Conv): - """Causal (masked) convolution for [batch x time x depth] sequences. - - Maintains causality along time axis. Used in language modeling tasks. - """ - - def __init__(self, - filters, - kernel_width=3, - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - super().__init__( - filters=filters, - kernel_size=(kernel_width,), - strides=None, - padding='VALID', - dimension_numbers=('NWC', 'WIO', 'NWC'), + """Causal (masked) convolution for [batch x time x depth] sequences. + + Maintains causality along time axis. Used in language modeling tasks. + """ + + def __init__( + self, + filters, + kernel_width=3, + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + ): + super().__init__( + filters=filters, + kernel_size=(kernel_width,), + strides=None, + padding="VALID", + dimension_numbers=("NWC", "WIO", "NWC"), + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + ) + + def forward(self, x): + assert self._padding == "VALID" + # Left pad with 0s. Applying an unmasked valid convolution on top of this + # yields a causal convolution. + # TODO(ddohan): Support strided and dilated convolutions. 
+ rate = 1 + effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1) + pad = effective_kernel_size - 1 + x_leftpad = jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode="constant") + return super().forward(x_leftpad) + + +def Conv1d( + filters, + kernel_size, + stride=1, + padding="VALID", + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, +): + return Conv( + filters, + (kernel_size,), + strides=(stride,), + padding=padding, + dimension_numbers=("NWC", "WIO", "NWC"), kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, - use_bias=use_bias) - - def forward(self, x): - assert self._padding == 'VALID' - # Left pad with 0s. Applying an unmasked valid convolution on top of this - # yields a causal convolution. - # TODO(ddohan): Support strided and dilated convolutions. - rate = 1 - effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1) - pad = effective_kernel_size - 1 - x_leftpad = ( - jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant')) - return super().forward(x_leftpad) - - -def Conv1d(filters, kernel_size, stride=1, padding='VALID', - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - return Conv(filters, (kernel_size,), strides=(stride,), padding=padding, - dimension_numbers=('NWC', 'WIO', 'NWC'), - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - use_bias=use_bias) + use_bias=use_bias, + ) def _zero_pad(x, pad, axis): # pylint: disable = invalid-name - """Helper for jnp.pad with 0s for single-axis case.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = pad # Padding on axis. - return jnp.pad(x, pad_widths, mode='constant') + """Helper for jnp.pad with 0s for single-axis case.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = pad # Padding on axis. 
+ return jnp.pad(x, pad_widths, mode="constant") # @assert_shape('bld->bld') class CausalDepthwiseConv(base.Layer): - """A causal depthwise convolution layer.""" - - def __init__(self, - kernel_size=3, - kernel_initializer=init.GlorotUniformInitializer(), - use_bfloat16=False): - """Returns a causal depthwise convolution layer.""" - super().__init__(n_in=1, n_out=1) - self._kernel_size = kernel_size - self._kernel_initializer = kernel_initializer - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input. - """ - w = self.weights - res = x * w[0, :][None, None, :] - for i in range(1, self._kernel_size): - x = _zero_pad(x, (1, 0), 1) - x = x[:, :-1, :] - res += x * w[i, :][None, None, :] - return res - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - shape_w = (self._kernel_size, input_signature.shape[-1]) - rng_w, _ = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(shape_w, rng_w) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - self.weights = w + """A causal depthwise convolution layer.""" + + def __init__( + self, + kernel_size=3, + kernel_initializer=init.GlorotUniformInitializer(), + use_bfloat16=False, + ): + """Returns a causal depthwise convolution layer.""" + super().__init__(n_in=1, n_out=1) + self._kernel_size = kernel_size + self._kernel_initializer = kernel_initializer + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. 
+ """ + w = self.weights + res = x * w[0, :][None, None, :] + for i in range(1, self._kernel_size): + x = _zero_pad(x, (1, 0), 1) + x = x[:, :-1, :] + res += x * w[i, :][None, None, :] + return res + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + shape_w = (self._kernel_size, input_signature.shape[-1]) + rng_w, _ = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + self.weights = w diff --git a/trax/layers/convolution_test.py b/trax/layers/convolution_test.py deleted file mode 100644 index 7d7c69d30..000000000 --- a/trax/layers/convolution_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for convolution layers.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -import trax.layers as tl - - -class ConvolutionTest(absltest.TestCase): - - def test_call(self): - layer = tl.Conv(30, (3, 3)) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - - def test_use_bias_true(self): - layer = tl.Conv(30, (3, 3), use_bias=True) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - - self.assertIsInstance(layer.weights, tuple) - self.assertLen(layer.weights, 2) - self.assertEqual(layer.weights[0].shape, (3, 3, 20, 30)) - self.assertEqual(layer.weights[1].shape, (30,)) - - def test_use_bias_false(self): - layer = tl.Conv(30, (3, 3), use_bias=False) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - # With use_bias=False, layer.weights is just 'w' and there is no 'b'. - self.assertEqual(layer.weights.shape, (3, 3, 20, 30)) - - def test_call_rebatch(self): - layer = tl.Conv(30, (3, 3)) - x = np.ones((2, 9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (2, 9, 3, 3, 30)) - - -class CausalConvolutionTest(absltest.TestCase): - - def test_causal_conv(self): - layer = tl.CausalConv(filters=30, kernel_width=3) - x = np.ones((9, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 5, 30)) - - # TODO(ddohan): How to test for causality? Gradient check between positions? 
- - def test_causal_conv_use_bias_false(self): - layer = tl.CausalConv(filters=30, kernel_width=3, use_bias=False) - x = np.ones((9, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 5, 30)) - - self.assertEqual(layer.weights.shape, (3, 20, 30)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/core.py b/trax/layers/core.py index b7fd6fc31..be9e0d089 100644 --- a/trax/layers/core.py +++ b/trax/layers/core.py @@ -29,841 +29,890 @@ # The output tensor has the same shape as the input tensor, except for the size # of the last dimension. -@assert_shape('...a->...b') +@assert_shape("...a->...b") class Dense(base.Layer): - """A dense (a.k.a. fully-connected, affine) layer. - - Dense layers are the prototypical example of a trainable layer, i.e., a layer - with trainable weights. Each node in a dense layer computes a weighted sum of - all node values from the preceding layer and adds to that sum a node-specific - bias term. The full layer computation is expressed compactly in linear - algebra as an affine map `y = Wx + b`, where `W` is a matrix and `y`, `x`, - and `b` are vectors. The layer is trained, or "learns", by updating the - values in `W` and `b`. - - Less commonly, a dense layer can omit the bias term and be a pure linear map: - `y = Wx`. - """ - - def __init__(self, - n_units, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True, - use_bfloat16=False): - """Returns a dense (fully connected) layer of width `n_units`. - - A dense layer maps collections of `R^m` vectors to `R^n`, where `n` - (`= n_units`) is fixed at layer creation time, and `m` is set at layer - initialization time. - - Args: - n_units: Number of nodes in the layer, also known as the width of the - layer. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. 
- bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. - use_bias: If `True`, compute an affine map `y = Wx + b`; else compute - a linear map `y = Wx`. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - """ - super().__init__(name=f'Dense_{n_units}') - self._n_units = n_units - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - self._use_bias = use_bias - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input, except the final dimension - is the layer's `n_units` value. + """A dense (a.k.a. fully-connected, affine) layer. + + Dense layers are the prototypical example of a trainable layer, i.e., a layer + with trainable weights. Each node in a dense layer computes a weighted sum of + all node values from the preceding layer and adds to that sum a node-specific + bias term. The full layer computation is expressed compactly in linear + algebra as an affine map `y = Wx + b`, where `W` is a matrix and `y`, `x`, + and `b` are vectors. The layer is trained, or "learns", by updating the + values in `W` and `b`. + + Less commonly, a dense layer can omit the bias term and be a pure linear map: + `y = Wx`. """ - if self._use_bias: - if not isinstance(self.weights, (tuple, list)): - raise ValueError(f'Weights should be a (w, b) tuple or list; ' - f'instead got: {self.weights}') - w, b = self.weights - return jnp.dot(x, w) + b # Affine map. - else: - w = self.weights - return jnp.dot(x, w) # Linear map. - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights. 
- - Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the - default case), or a `w` tensor for layers created with `use_bias=False`. - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. - """ - shape_w = (input_signature.shape[-1], self._n_units) - shape_b = (self._n_units,) - rng_w, rng_b = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(shape_w, rng_w) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - - if self._use_bias: - b = self._bias_initializer(shape_b, rng_b) - if self._use_bfloat16: - b = b.astype(jnp.bfloat16) - self.weights = (w, b) - else: - self.weights = w + def __init__( + self, + n_units, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + use_bfloat16=False, + ): + """Returns a dense (fully connected) layer of width `n_units`. + + A dense layer maps collections of `R^m` vectors to `R^n`, where `n` + (`= n_units`) is fixed at layer creation time, and `m` is set at layer + initialization time. + + Args: + n_units: Number of nodes in the layer, also known as the width of the + layer. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + use_bias: If `True`, compute an affine map `y = Wx + b`; else compute + a linear map `y = Wx`. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + """ + super().__init__(name=f"Dense_{n_units}") + self._n_units = n_units + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + self._use_bias = use_bias + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. 
+ + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input, except the final dimension + is the layer's `n_units` value. + """ + if self._use_bias: + if not isinstance(self.weights, (tuple, list)): + raise ValueError( + f"Weights should be a (w, b) tuple or list; " + f"instead got: {self.weights}" + ) + w, b = self.weights + return jnp.dot(x, w) + b # Affine map. + else: + w = self.weights + return jnp.dot(x, w) # Linear map. + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights. + + Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the + default case), or a `w` tensor for layers created with `use_bias=False`. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. + """ + shape_w = (input_signature.shape[-1], self._n_units) + shape_b = (self._n_units,) + rng_w, rng_b = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + + if self._use_bias: + b = self._bias_initializer(shape_b, rng_b) + if self._use_bfloat16: + b = b.astype(jnp.bfloat16) + self.weights = (w, b) + else: + self.weights = w # The output tensor has the same shape as the input tensor, but with added # dimension at the end. This dimension size corresponds to embedding depth. -@assert_shape('...->...d') +@assert_shape("...->...d") class Embedding(base.Layer): - """Trainable layer that maps discrete tokens/IDs to vectors. - - Embedding layers are commonly used to map discrete data, like words in NLP, - into vectors. 
Here is a canonical example:: - - vocab_size = 5 - word_ids = np.array([1, 2, 3, 4], dtype=np.int32) # word_ids < vocab_size - embedding_layer = tl.Embedding(vocab_size, 32) - embedding_layer.init(trax.shapes.signature(word_ids)) - embedded = embedding_layer(word_ids) # embedded.shape = (4, 32) - """ - - def __init__(self, - vocab_size, - d_feature, - use_bfloat16=False, - kernel_initializer=init.ScaledInitializer( - out_dim=-1, in_dim=-2, scale=1., mode='fan_out', - distribution='uniform')): - """Returns an embedding layer with given vocabulary size and vector size. - - The layer clips input values (token IDs) to the range `[0, vocab_size)`. - That is, negative token IDs all clip to `0` before being mapped to a - vector, and token IDs with value `vocab_size` or greater all clip to - `vocab_size - 1` before being mapped to a vector. - - Args: - vocab_size: Size of the input vocabulary. The layer will assign a unique - vector to each id in `range(vocab_size)`. - d_feature: Dimensionality/depth of the output vectors. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - kernel_initializer: Function that creates (random) initial vectors for - the embedding. - """ - # TODO(jonni): is the clipping behavior what we want going forward? - super().__init__(name=f'Embedding_{vocab_size}_{d_feature}') - self._d_feature = d_feature # feature dimensionality - self._vocab_size = vocab_size - self._use_bfloat16 = use_bfloat16 - self._kernel_initializer = kernel_initializer + """Trainable layer that maps discrete tokens/IDs to vectors. - def forward(self, x): - """Returns embedding vectors corresponding to input token IDs. + Embedding layers are commonly used to map discrete data, like words in NLP, + into vectors. Here is a canonical example:: - Args: - x: Tensor of token IDs. - - Returns: - Tensor of embedding vectors. 
+ vocab_size = 5 + word_ids = np.array([1, 2, 3, 4], dtype=np.int32) # word_ids < vocab_size + embedding_layer = tl.Embedding(vocab_size, 32) + embedding_layer.init(trax.shapes.signature(word_ids)) + embedded = embedding_layer(word_ids) # embedded.shape = (4, 32) """ - embedded = jnp.take(self.weights, x, axis=0, mode='clip') - if self._use_bfloat16: # Return float32 activations w/ bfloat16 weights. - embedded = embedded.astype(jnp.float32) - return embedded - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - del input_signature - shape_w = (self._vocab_size, self._d_feature) - # TODO(lukaszkaiser): do we split self.rng for consistency? Add a method? - w = self._kernel_initializer(shape_w, self.rng) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - self.weights = w - - -@assert_shape('...->...') # The output and input shapes are the same. -class Dropout(base.Layer): - """A layer that stochastically ignores a subset of inputs each training step. - In training, to compensate for the fraction of input values dropped (`rate`), - all surviving values are multiplied by `1 / (1 - rate)`. - - The parameter `shared_axes` allows to specify a list of axes on which - the mask will be shared: we will use size 1 on those axes for dropout mask - and broadcast it. Sharing reduces randomness, but can save memory. + def __init__( + self, + vocab_size, + d_feature, + use_bfloat16=False, + kernel_initializer=init.ScaledInitializer( + out_dim=-1, in_dim=-2, scale=1.0, mode="fan_out", distribution="uniform" + ), + ): + """Returns an embedding layer with given vocabulary size and vector size. + + The layer clips input values (token IDs) to the range `[0, vocab_size)`. + That is, negative token IDs all clip to `0` before being mapped to a + vector, and token IDs with value `vocab_size` or greater all clip to + `vocab_size - 1` before being mapped to a vector. + + Args: + vocab_size: Size of the input vocabulary. 
The layer will assign a unique + vector to each id in `range(vocab_size)`. + d_feature: Dimensionality/depth of the output vectors. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + kernel_initializer: Function that creates (random) initial vectors for + the embedding. + """ + # TODO(jonni): is the clipping behavior what we want going forward? + super().__init__(name=f"Embedding_{vocab_size}_{d_feature}") + + self._d_feature = d_feature # feature dimensionality + self._vocab_size = vocab_size + self._use_bfloat16 = use_bfloat16 + self._kernel_initializer = kernel_initializer + + def forward(self, x): + """Returns embedding vectors corresponding to input token IDs. + + Args: + x: Tensor of token IDs. + + Returns: + Tensor of embedding vectors. + """ + embedded = jnp.take(self.weights, x, axis=0, mode="clip") + if self._use_bfloat16: # Return float32 activations w/ bfloat16 weights. + embedded = embedded.astype(jnp.float32) + return embedded + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + del input_signature + shape_w = (self._vocab_size, self._d_feature) + # TODO(lukaszkaiser): do we split self.rng for consistency? Add a method? + w = self._kernel_initializer(shape_w, self.rng) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + self.weights = w + + +@assert_shape("...->...") # The output and input shapes are the same. +class Dropout(base.Layer): + """A layer that stochastically ignores a subset of inputs each training step. - This layer is active only during training (`mode='train'`). In other - circumstances it is a no-op. + In training, to compensate for the fraction of input values dropped (`rate`), + all surviving values are multiplied by `1 / (1 - rate)`. 
- Originally introduced in the paper "Dropout: A Simple Way to Prevent Neural - Networks from Overfitting" available under the following link: - https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf - """ + The parameter `shared_axes` allows to specify a list of axes on which + the mask will be shared: we will use size 1 on those axes for dropout mask + and broadcast it. Sharing reduces randomness, but can save memory. - def __init__(self, rate=0.0, shared_axes=None, mode='train'): - """Creates a dropout layer with the given target drop rate. + This layer is active only during training (`mode='train'`). In other + circumstances it is a no-op. - Args: - rate: Stochastic rate (probability) for dropping an activation value - from the preceding layer (setting it to zero). - shared_axes: List of axes on which the mask is shared. - mode: If `'train'`, this layer will perform dropout; else, it will pass - all values through unaltered. + Originally introduced in the paper "Dropout: A Simple Way to Prevent Neural + Networks from Overfitting" available under the following link: + https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf """ - super().__init__() - self._initial_rate = rate - self._shared_axes = [] if shared_axes is None else shared_axes - self._mode = mode - def init_weights_and_state(self, input_signature): - """Sets layer-specific internal state.""" - del input_signature - self.state = jnp.array(self._initial_rate) - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of activations. - - Returns: - Tensor of same shape and dtype as the input. 
- """ - if self._mode != 'train': - return x - state, rng = self.state, self.rng - rate = self._initial_rate - if isinstance(state, dict) and self._name in state: - rate = state[self._name] - if rate == 0.0: - return x - mask_shape = list(x.shape) - for axis in self._shared_axes: - mask_shape[axis] = 1 - keep_prob = 1.0 - rate - keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape)) - mask = keep.astype(x.dtype) / keep_prob - return x * mask + def __init__(self, rate=0.0, shared_axes=None, mode="train"): + """Creates a dropout layer with the given target drop rate. + + Args: + rate: Stochastic rate (probability) for dropping an activation value + from the preceding layer (setting it to zero). + shared_axes: List of axes on which the mask is shared. + mode: If `'train'`, this layer will perform dropout; else, it will pass + all values through unaltered. + """ + super().__init__() + self._initial_rate = rate + self._shared_axes = [] if shared_axes is None else shared_axes + self._mode = mode + + def init_weights_and_state(self, input_signature): + """Sets layer-specific internal state.""" + del input_signature + self.state = jnp.array(self._initial_rate) + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of activations. + + Returns: + Tensor of same shape and dtype as the input. + """ + if self._mode != "train": + return x + state, rng = self.state, self.rng + rate = self._initial_rate + if isinstance(state, dict) and self._name in state: + rate = state[self._name] + if rate == 0.0: + return x + mask_shape = list(jnp.shape(x)) + + for axis in self._shared_axes: + mask_shape[axis] = 1 + keep_prob = 1.0 - rate + keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape)) + mask = keep.astype(x.dtype) / keep_prob + return x * mask class Weights(base.Layer): - """Learnable weights as a layer. - - It takes no input and returns a single tensor: weights. 
- """ - - def __init__(self, initializer, shape=tuple(), use_bfloat16=False): - """Returns a learnable tensor of shape `shape`. + """Learnable weights as a layer. - Args: - initializer: Function taking shape and rng as arguments. - shape: Shape of the learnable weights. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - """ - super().__init__(name=f'Weights_{shape}', n_in=0, n_out=1) - self._shape = shape - self._initializer = initializer - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor with previously specified shape and dtype. + It takes no input and returns a single tensor: weights. """ - del x # Unused. There is no input to this layer. - return self.weights - - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. - Weights is a single `w` tensor with previously specified shape. + def __init__(self, initializer, shape=tuple(), use_bfloat16=False): + """Returns a learnable tensor of shape `shape`. + + Args: + initializer: Function taking shape and rng as arguments. + shape: Shape of the learnable weights. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + """ + super().__init__(name=f"Weights_{shape}", n_in=0, n_out=1) + self._shape = shape + self._initializer = initializer + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor with previously specified shape and dtype. + """ + del x # Unused. 
There is no input to this layer. + return self.weights + + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. + + Weights is a single `w` tensor with previously specified shape. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. There is no input to this layer. + self.weights = self._initializer(self._shape, self.rng) + if self._use_bfloat16: + self.weights = self.weights.astype(jnp.bfloat16) + + +def PrintShape(n_in=1, msg=""): + """Prints the shapes of `n_in` inputs and returns then unchanged.""" + + def Fwd(xs): + def format_shape(x): # pylint: disable = invalid-name + return str(jnp.shape(x)) + f"[{x.dtype}]" + + if n_in > 1: + shapes_and_dtypes = ", ".join([format_shape(x) for x in xs]) + else: + shapes_and_dtypes = format_shape(xs) + info = f"PrintShape: {msg}: [{shapes_and_dtypes}]" + print(info) + logging.info(info) + return xs - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. - """ - del input_signature # Unused. There is no input to this layer. 
- self.weights = self._initializer(self._shape, self.rng) - if self._use_bfloat16: - self.weights = self.weights.astype(jnp.bfloat16) - - -def PrintShape(n_in=1, msg=''): - """Prints the shapes of `n_in` inputs and returns then unchanged.""" - def Fwd(xs): - def format_shape(x): # pylint: disable = invalid-name - return str(x.shape) + f'[{x.dtype}]' - if n_in > 1: - shapes_and_dtypes = ', '.join([format_shape(x) for x in xs]) - else: - shapes_and_dtypes = format_shape(xs) - info = f'PrintShape: {msg}: [{shapes_and_dtypes}]' - print(info) - logging.info(info) - return xs - return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}') + return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f"PrintShape_{n_in}") class SummaryImage(base.Layer): - """A layer receiving a tensor, and adding it to TensorBoard as an image. - - It takes an input and returns it unchanged. It stores this input as a state to - be used as a metric in TensorBoard. - It converts a tensor to a scalar by running a given aggregation function (mean - by default). On TensorBoard, results for each device will be reported - separately. - """ + """A layer receiving a tensor, and adding it to TensorBoard as an image. - def __init__(self, name, n_in, num_summaries=5, - recover_fn=None): - """Takes a tensor and returns it. - - Args: - name: Name of the metric to be reported. - n_in: Number of inputs. - num_summaries: Number of images to show. - recover_fn: the function for converting a tensor to a dipslayable image. + It takes an input and returns it unchanged. It stores this input as a state to + be used as a metric in TensorBoard. + It converts a tensor to a scalar by running a given aggregation function (mean + by default). On TensorBoard, results for each device will be reported + separately. 
""" - super().__init__(name=f'Summary_{name}', n_in=n_in, n_out=n_in) - name = 'summary_' + name - self._name = name - self._num_summaries = num_summaries - self._recover_fn = recover_fn - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor with previously specified shape and dtype. - """ - self.state = {} - batch_size = x[0].shape[0] - num_images = min(self._num_summaries, batch_size) - for s in range(num_images): - images = [] - for i in range(self._n_in): - images.append( - self._recover_fn(x[i][s]) if self._recover_fn else x[i][s]) - self.state[self._name + str(s)] = jnp.concatenate(images, axis=0) - return x[:self._n_in] - - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. - - Weights is a single `w` tensor with previously specified shape. - - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. - """ - del input_signature # Unused. - self.weights = () - self.state = {self._name: jnp.array(0.)} + def __init__(self, name, n_in, num_summaries=5, recover_fn=None): + """Takes a tensor and returns it. + + Args: + name: Name of the metric to be reported. + n_in: Number of inputs. + num_summaries: Number of images to show. + recover_fn: the function for converting a tensor to a dipslayable image. + """ + super().__init__(name=f"Summary_{name}", n_in=n_in, n_out=n_in) + name = "summary_" + name + self._name = name + self._num_summaries = num_summaries + self._recover_fn = recover_fn + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor with previously specified shape and dtype. 
+ """ + self.state = {} + batch_size = x[0].shape[0] + num_images = min(self._num_summaries, batch_size) + for s in range(num_images): + images = [] + for i in range(self._n_in): + images.append( + self._recover_fn(x[i][s]) if self._recover_fn else x[i][s] + ) + self.state[self._name + str(s)] = jnp.concatenate(images, axis=0) + return x[: self._n_in] + + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. + + Weights is a single `w` tensor with previously specified shape. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. + self.weights = () + self.state = {self._name: jnp.array(0.0)} class SummaryScalar(base.Layer): - """A layer receiving a tensor, and adding it to TensorBoard as a scalar. + """A layer receiving a tensor, and adding it to TensorBoard as a scalar. - It takes an input and returns it unchanged. It stores this input as a state to - be used as a metric in TensorBoard. - It converts a tensor to a scalar by running a given aggregation function (mean - by default). On TensorBoard, results for each device will be reported - separately. - """ + It takes an input and returns it unchanged. It stores this input as a state to + be used as a metric in TensorBoard. + It converts a tensor to a scalar by running a given aggregation function (mean + by default). On TensorBoard, results for each device will be reported + separately. + """ - def __init__(self, name, aggregation_fun=jnp.mean): - """Takes a tensor and returns it. + def __init__(self, name, aggregation_fun=jnp.mean): + """Takes a tensor and returns it. - Args: - name: Name of the metric to be reported. - aggregation_fun: Aggregation function to be used. 
- """ - super().__init__(name=f'Summary_{name}', n_in=1, n_out=1) - name = 'summary_' + name - self._name = name - self._aggregation_fun = aggregation_fun + Args: + name: Name of the metric to be reported. + aggregation_fun: Aggregation function to be used. + """ + super().__init__(name=f"Summary_{name}", n_in=1, n_out=1) + name = "summary_" + name + self._name = name + self._aggregation_fun = aggregation_fun - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + def forward(self, x): + """Executes this layer as part of a forward pass through the model. - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. - Returns: - Tensor with previously specified shape and dtype. - """ - self.state = {self._name: self._aggregation_fun(x)} - return x + Returns: + Tensor with previously specified shape and dtype. + """ + self.state = {self._name: self._aggregation_fun(x)} + return x - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. - Weights is a single `w` tensor with previously specified shape. + Weights is a single `w` tensor with previously specified shape. - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. - """ - del input_signature # Unused. - self.weights = () - self.state = {self._name: jnp.array(0.)} + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. 
+ self.weights = () + self.state = {self._name: jnp.array(0.0)} class RandomUniform(base.Layer): - """Layer returning a tensor with random values distributed uniformly.""" + """Layer returning a tensor with random values distributed uniformly.""" + + def __init__( + self, min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, sync=False + ): + """Layer returning a tensor with random values distributed uniformly. + + Args: + min_val: Lower end of uniform distribution. + max_val: Upper end of uniform distribution. + shape: Shape of the tensor to return. Values are sampled independently. + dtype: Type of value to return. + sync: Whether to synchronise `rng` across devices. + """ + super().__init__(n_in=0, n_out=1) + self._min_val = min_val + self._max_val = max_val + self._shape = shape + self._dtype = dtype + self._sync = sync + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model. + + Args: + xs: Unused tensors. + + Returns: + Random uniform tensor of the shape and type specified in constructor. + """ + rng = self._get_conditionally_synced_rng() + result = fastmath.random.uniform( + rng, self._shape, self._dtype, self._min_val, self._max_val + ) + return result + + def _get_conditionally_synced_rng(self): + if self._sync and fastmath.global_device_count() > 1: + return fastmath.psum(self.rng, "batch") + else: + return self.rng - def __init__(self, min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, - sync=False): - """Layer returning a tensor with random values distributed uniformly. - Args: - min_val: Lower end of uniform distribution. - max_val: Upper end of uniform distribution. - shape: Shape of the tensor to return. Values are sampled independently. - dtype: Type of value to return. - sync: Whether to synchronise `rng` across devices. 
- """ - super().__init__(n_in=0, n_out=1) - self._min_val = min_val - self._max_val = max_val - self._shape = shape - self._dtype = dtype - self._sync = sync +class LocallyConnected1d(base.Layer): + """Locally-connected layer for 1D inputs. - def forward(self, xs): - """Executes this layer as part of a forward pass through the model. + The LocallyConnected1d layer applies a different set of filters to each patch + of the input. This is similar to applying a convolution layer, except that + locally-connected layer uses a different set of weights for each patch. - Args: - xs: Unused tensors. + The size of patch is determined by the kernel size. The stride is currently + not modifiable and set to one. This means for the input of shape (..., L, D) + the output shape for paddings 'SAME' and 'WRAP' will be (..., L, filters) and + for padding 'VALID' (..., L-kernel_size+1, filters); where L is the number of + "pixels" or "steps" in the input, D is the size of the embedding. - Returns: - Random uniform tensor of the shape and type specified in constructor. + Note that, since the weights for different patches are not shared, the number + of "pixels" or "steps" cannot change after calling init_weights_and_state. + This is because each "pixel" is assigned its own set of weights. """ - rng = self._get_conditionally_synced_rng() - result = fastmath.random.uniform( - rng, self._shape, self._dtype, self._min_val, self._max_val) - return result - - def _get_conditionally_synced_rng(self): - if self._sync and fastmath.global_device_count() > 1: - return fastmath.psum(self.rng, 'batch') - else: - return self.rng + def __init__( + self, + filters, + kernel_size, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + padding="VALID", + ): + """Returns a locally-connected conv-like layer. + + Args: + filters: Number of output filters in the convolution. + kernel_size: A length of the convolution window. 
Must be an odd number. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + use_bias: If `True`, the layer uses a bias vector. + padding: The type of padding to use; must be 'VALID', 'SAME', or 'WRAP'. + """ + super().__init__(name=f"LocallyConnected1d_{filters}_{kernel_size}") + self._filters = filters + self._kernel_size = kernel_size + assert self._kernel_size % 2 == 1 # kernel size has to be odd + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + self._use_bias = use_bias + self._padding = padding + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input, except the final dimension + is the layer's `filters` value, and the second to last dimension is + shrinked if 'VALID' padding is used with kernel_size bigger than one. + """ + if self._use_bias: + if not isinstance(self.weights, (tuple, list)): + raise ValueError( + f"Weights should be a (w, b) tuple or list; " + f"instead got: {self.weights}" + ) + w, b = self.weights + else: + w = self.weights -class LocallyConnected1d(base.Layer): - """Locally-connected layer for 1D inputs. - - The LocallyConnected1d layer applies a different set of filters to each patch - of the input. This is similar to applying a convolution layer, except that - locally-connected layer uses a different set of weights for each patch. - - The size of patch is determined by the kernel size. The stride is currently - not modifiable and set to one. 
This means for the input of shape (..., L, D) - the output shape for paddings 'SAME' and 'WRAP' will be (..., L, filters) and - for padding 'VALID' (..., L-kernel_size+1, filters); where L is the number of - "pixels" or "steps" in the input, D is the size of the embedding. - - Note that, since the weights for different patches are not shared, the number - of "pixels" or "steps" cannot change after calling init_weights_and_state. - This is because each "pixel" is assigned its own set of weights. - """ - - def __init__(self, filters, kernel_size, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True, padding='VALID'): - """Returns a locally-connected conv-like layer. + linear_results_before_shifting = jnp.einsum("...lp,lkpd->...lkd", x, w) + # TODO(jaszczur): this could be run after padding for better efficiency - Args: - filters: Number of output filters in the convolution. - kernel_size: A length of the convolution window. Must be an odd number. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. - bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. - use_bias: If `True`, the layer uses a bias vector. - padding: The type of padding to use; must be 'VALID', 'SAME', or 'WRAP'. - """ - super().__init__(name=f'LocallyConnected1d_{filters}_{kernel_size}') - self._filters = filters - self._kernel_size = kernel_size - assert self._kernel_size % 2 == 1 # kernel size has to be odd - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - self._use_bias = use_bias - self._padding = padding - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + if self._kernel_size == 1: + # With kernel size 1 we don't have to split or shift anything. 
+ linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2) + else: + # We computed a result for every "pixel", but each direction from the + # receptive field (there are 'self._kernel_size' such directions) must be + # shifted by a different amount. The easiest way to do it is to split + # the tensor to 'self._kernel_size' smaller tensors, shift each one + # appropriately, and then sum them together. + split_shifting_linear_results = jnp.split( + linear_results_before_shifting, self._kernel_size, axis=-2 + ) + + for i in range(self._kernel_size): + # Each tensor has to be shifted a different amount. + if self._padding == "WRAP": + # We can shift by padding and cutting. With 'wrap' padding we + # essentially have a torus. + padding = [(0, 0) for i in split_shifting_linear_results[i].shape] + padding[-3] = ((self._kernel_size - 1) - i, i) + split_shifting_linear_results[i] = jnp.pad( + split_shifting_linear_results[i], padding, mode="wrap" + ) + split_shifting_linear_results[i] = split_shifting_linear_results[i][ + ..., + (self._kernel_size - 1) // 2 : -(self._kernel_size - 1) // 2, + :, + :, + ] + elif self._padding == "SAME": + # We can shift by padding and cutting. + padding = [(0, 0) for i in split_shifting_linear_results[i].shape] + padding[-3] = ((self._kernel_size - 1) - i, i) + split_shifting_linear_results[i] = jnp.pad( + split_shifting_linear_results[i], padding + ) + split_shifting_linear_results[i] = split_shifting_linear_results[i][ + ..., + (self._kernel_size - 1) // 2 : -(self._kernel_size - 1) // 2, + :, + :, + ] + # TODO(jaszczur): improve efficiency by not padding things to cut + elif self._padding == "VALID": + # We don't need to shift - just cut the leftmost and rightmost values. 
+ cut_left = (self._kernel_size - 1) - i + cut_right = split_shifting_linear_results[i].shape[-3] - i + split_shifting_linear_results[i] = split_shifting_linear_results[i][ + ..., cut_left:cut_right, :, : + ] + else: + raise ValueError(f"Invalid padding {self._padding}") + # After shifting. + shifted_linear_results = jnp.concatenate( + split_shifting_linear_results, axis=-2 + ) + linear_result = jnp.sum(shifted_linear_results, axis=-2) + + if self._use_bias: + return linear_result + b + else: + return linear_result + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights. + + Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the + default case), or a `w` tensor for layers created with `use_bias=False`. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. + """ + shape_w = ( + input_signature.shape[-2], + self._kernel_size, + input_signature.shape[-1], + self._filters, + ) + if self._padding == "VALID": + shape_b = ( + input_signature.shape[-2] - self._kernel_size + 1, + self._filters, + ) + else: + shape_b = ( + input_signature.shape[-2], + self._filters, + ) - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + rng_w, rng_b = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w, nonreceptive_dims=[0]) - Returns: - Tensor of same shape and dtype as the input, except the final dimension - is the layer's `filters` value, and the second to last dimension is - shrinked if 'VALID' padding is used with kernel_size bigger than one. 
- """ - if self._use_bias: - if not isinstance(self.weights, (tuple, list)): - raise ValueError(f'Weights should be a (w, b) tuple or list; ' - f'instead got: {self.weights}') - w, b = self.weights - else: - w = self.weights - - linear_results_before_shifting = jnp.einsum( - '...lp,lkpd->...lkd', x, w) - # TODO(jaszczur): this could be run after padding for better efficiency - - if self._kernel_size == 1: - # With kernel size 1 we don't have to split or shift anything. - linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2) - else: - # We computed a result for every "pixel", but each direction from the - # receptive field (there are 'self._kernel_size' such directions) must be - # shifted by a different amount. The easiest way to do it is to split - # the tensor to 'self._kernel_size' smaller tensors, shift each one - # appropriately, and then sum them together. - split_shifting_linear_results = jnp.split( - linear_results_before_shifting, self._kernel_size, axis=-2) - - for i in range(self._kernel_size): - # Each tensor has to be shifted a different amount. - if self._padding == 'WRAP': - # We can shift by padding and cutting. With 'wrap' padding we - # essentially have a torus. - padding = [(0, 0) for i in split_shifting_linear_results[i].shape] - padding[-3] = ((self._kernel_size - 1) - i, i) - split_shifting_linear_results[i] = jnp.pad( - split_shifting_linear_results[i], padding, mode='wrap') - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :] - elif self._padding == 'SAME': - # We can shift by padding and cutting. 
- padding = [(0, 0) for i in split_shifting_linear_results[i].shape] - padding[-3] = ((self._kernel_size - 1) - i, i) - split_shifting_linear_results[i] = jnp.pad( - split_shifting_linear_results[i], padding) - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :] - # TODO(jaszczur): improve efficiency by not padding things to cut - elif self._padding == 'VALID': - # We don't need to shift - just cut the leftmost and rightmost values. - cut_left = (self._kernel_size - 1) - i - cut_right = split_shifting_linear_results[i].shape[-3] - i - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., cut_left:cut_right, :, :] + if self._use_bias: + b = self._bias_initializer(shape_b, rng_b) + self.weights = (w, b) else: - raise ValueError(f'Invalid padding {self._padding}') - # After shifting. - shifted_linear_results = jnp.concatenate(split_shifting_linear_results, - axis=-2) - linear_result = jnp.sum(shifted_linear_results, axis=-2) + self.weights = w - if self._use_bias: - return linear_result + b - else: - return linear_result - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights. +def Flatten(n_axes_to_keep=1): + """Returns a layer that combines one or more trailing axes of a tensor. - Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the - default case), or a `w` tensor for layers created with `use_bias=False`. + Flattening keeps all the values of the input tensor, but reshapes it by + collapsing one or more trailing axes into a single axis. For example, a + `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape + `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`. Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. + n_axes_to_keep: Number of leading axes to leave unchanged when reshaping; + collapse only the axes after these. 
""" - shape_w = (input_signature.shape[-2], self._kernel_size, - input_signature.shape[-1], self._filters) - if self._padding == 'VALID': - shape_b = (input_signature.shape[-2] - self._kernel_size + 1, - self._filters,) - else: - shape_b = (input_signature.shape[-2], self._filters,) + layer_name = f"Flatten_keep{n_axes_to_keep}" - rng_w, rng_b = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(shape_w, rng_w, nonreceptive_dims=[0]) + def f(x): # pylint: disable=invalid-name + in_rank = len(jnp.shape(x)) + if in_rank <= n_axes_to_keep: + raise ValueError( + f"Input rank ({in_rank}) must exceed the number of " + f"axes to keep ({n_axes_to_keep}) after flattening." + ) + shape = jnp.shape(x) + if isinstance(shape, tf.TensorShape): + shape = tuple(shape.as_list()) + return jnp.reshape(x, (shape[:n_axes_to_keep] + (-1,))) - if self._use_bias: - b = self._bias_initializer(shape_b, rng_b) - self.weights = (w, b) - else: - self.weights = w - - -def Flatten(n_axes_to_keep=1): - """Returns a layer that combines one or more trailing axes of a tensor. - - Flattening keeps all the values of the input tensor, but reshapes it by - collapsing one or more trailing axes into a single axis. For example, a - `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape - `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`. - - Args: - n_axes_to_keep: Number of leading axes to leave unchanged when reshaping; - collapse only the axes after these. 
- """ - layer_name = f'Flatten_keep{n_axes_to_keep}' - def f(x): # pylint: disable=invalid-name - in_rank = len(x.shape) - if in_rank <= n_axes_to_keep: - raise ValueError(f'Input rank ({in_rank}) must exceed the number of ' - f'axes to keep ({n_axes_to_keep}) after flattening.') - shape = x.shape - if isinstance(shape, tf.TensorShape): - shape = tuple(shape.as_list()) - return jnp.reshape(x, (shape[:n_axes_to_keep] + (-1,))) - return Fn(layer_name, f) + return Fn(layer_name, f) def LogSoftmax(axis=-1): - """Returns a layer that applies log softmax along one tensor axis. + """Returns a layer that applies log softmax along one tensor axis. - Note that the implementation actually computes x - LogSumExp(x), - which is mathematically equal to LogSoftmax(x). + Note that the implementation actually computes x - LogSumExp(x), + which is mathematically equal to LogSoftmax(x). - `LogSoftmax` acts on a group of values and normalizes them to look like a set - of log probability values. (Probability values must be non-negative, and as - a set must sum to 1. A group of log probability values can be seen as the - natural logarithm function applied to a set of probability values.) + `LogSoftmax` acts on a group of values and normalizes them to look like a set + of log probability values. (Probability values must be non-negative, and as + a set must sum to 1. A group of log probability values can be seen as the + natural logarithm function applied to a set of probability values.) - Args: - axis: Axis along which values are grouped for computing log softmax. - """ - return Fn('LogSoftmax', lambda x: log_softmax(x, axis=axis)) + Args: + axis: Axis along which values are grouped for computing log softmax. + """ + return Fn("LogSoftmax", lambda x: log_softmax(x, axis=axis)) def LogSumExp(axis=-1): - """Returns a layer that computes log(sum(exp(x))) along one tensor axis. + """Returns a layer that computes log(sum(exp(x))) along one tensor axis. 
- Args: - axis: Axis along which values are grouped for computing log-sum-exp. - """ - return Fn('LogSumExp', - lambda x: fastmath.logsumexp(x, axis=axis, keepdims=True)) + Args: + axis: Axis along which values are grouped for computing log-sum-exp. + """ + return Fn("LogSumExp", lambda x: fastmath.logsumexp(x, axis=axis, keepdims=True)) def Softmax(axis=-1): - """Returns a layer that applies softmax along one tensor axis. + """Returns a layer that applies softmax along one tensor axis. - `Softmax` acts on a group of values and normalizes them to look like a set - of probability values. (Probability values must be non-negative, and as a - set must sum to 1.) + `Softmax` acts on a group of values and normalizes them to look like a set + of probability values. (Probability values must be non-negative, and as a + set must sum to 1.) - Args: - axis: Axis along which values are grouped for computing softmax. - """ - return Fn('Softmax', - lambda x: jnp.exp(log_softmax(x, axis=axis))) + Args: + axis: Axis along which values are grouped for computing softmax. + """ + return Fn("Softmax", lambda x: jnp.exp(log_softmax(x, axis=axis))) def ToFloat(): - """Returns a layer that changes the dtype of a tensor to `float32`.""" - return Fn('ToFloat', lambda x: x.astype(np.float32)) + """Returns a layer that changes the dtype of a tensor to `float32`.""" + return Fn("ToFloat", lambda x: x.astype(np.float32)) def Mean(axis=-1, keepdims=False): - """Returns a layer that computes mean values using one tensor axis. + """Returns a layer that computes mean values using one tensor axis. - `Mean` uses one tensor axis to form groups of values and replaces each group - with the mean value of that group. The resulting values can either remain - in their own size 1 axis (`keepdims=True`), or that axis can be removed from - the overall tensor (default `keepdims=False`), lowering the rank of the - tensor by one. 
+ `Mean` uses one tensor axis to form groups of values and replaces each group + with the mean value of that group. The resulting values can either remain + in their own size 1 axis (`keepdims=True`), or that axis can be removed from + the overall tensor (default `keepdims=False`), lowering the rank of the + tensor by one. - Args: - axis: Axis along which values are grouped for computing a mean. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing a mean. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Mean", lambda x: jnp.mean(x, axis=axis, keepdims=keepdims)) def Min(axis=-1, keepdims=False): - """Returns a layer that applies min along one tensor axis. + """Returns a layer that applies min along one tensor axis. - Args: - axis: Axis along which values are grouped for computing minimum. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Min', lambda x: jnp.min(x, axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing minimum. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Min", lambda x: jnp.min(x, axis, keepdims=keepdims)) def Max(axis=-1, keepdims=False): - """Returns a layer that applies max along one tensor axis. + """Returns a layer that applies max along one tensor axis. - Args: - axis: Axis along which values are grouped for computing maximum. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Max', lambda x: jnp.max(x, axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing maximum. 
+ keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Max", lambda x: jnp.max(x, axis, keepdims=keepdims)) def Sum(axis=None, keepdims=False): - """Returns a layer that computes sums using one tensor axis. + """Returns a layer that computes sums using one tensor axis. - `Sum` uses one tensor axis to form groups of values and replaces each group - with the sum of that group. The resulting sum values can either remain in - their own size 1 axis (`keepdims=True`), or that axis can be removed from the - overall tensor (default `keepdims=False`), lowering the rank of the tensor by - one. + `Sum` uses one tensor axis to form groups of values and replaces each group + with the sum of that group. The resulting sum values can either remain in + their own size 1 axis (`keepdims=True`), or that axis can be removed from the + overall tensor (default `keepdims=False`), lowering the rank of the tensor by + one. - Args: - axis: Axis along which values are grouped for computing a sum; if None, - compute sum over all elements in tensor. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing a sum; if None, + compute sum over all elements in tensor. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. 
+ """ + return Fn("Sum", lambda x: jnp.sum(x, axis=axis, keepdims=keepdims)) + + +def ThresholdToBinary(threshold=0.5): + """Returns a layer that thresholds inputs to yield outputs in {0, 1}.""" + def f(model_output): # pylint: disable=invalid-name + return (model_output > threshold).astype(jnp.int32) -def ThresholdToBinary(threshold=.5): - """Returns a layer that thresholds inputs to yield outputs in {0, 1}.""" - def f(model_output): # pylint: disable=invalid-name - return (model_output > threshold).astype(jnp.int32) - return Fn('ThresholdToBinary', f) + return Fn("ThresholdToBinary", f) def ArgMax(axis=-1): - """Returns a layer that calculates argmax along the given axis.""" - def f(model_output): # pylint: disable=invalid-name - return jnp.argmax(model_output, axis=axis) - return Fn('ArgMax', f) + """Returns a layer that calculates argmax along the given axis.""" + def f(model_output): # pylint: disable=invalid-name + return jnp.argmax(model_output, axis=axis) -@assert_shape('...->...') # The output and input shapes are the same. + return Fn("ArgMax", f) + + +@assert_shape("...->...") # The output and input shapes are the same. def Negate(): - """Returns a layer that computes the element-wise negation of a tensor.""" - return Fn('Negate', lambda x: -x) + """Returns a layer that computes the element-wise negation of a tensor.""" + return Fn("Negate", lambda x: -x) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. 
def StopGradient(): - """Returns an identity layer with a stop gradient.""" - return Fn('StopGradient', lambda x: fastmath.stop_gradient(x)) # pylint: disable=unnecessary-lambda + """Returns an identity layer with a stop gradient.""" + return Fn( + "StopGradient", lambda x: fastmath.stop_gradient(x) + ) # pylint: disable=unnecessary-lambda def one_hot(x, n_categories, dtype=jnp.float32): # pylint: disable=invalid-name - """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).""" - indices_less_than_n = jnp.arange(n_categories) - return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype) + """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).""" + indices_less_than_n = jnp.arange(n_categories) + return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype) def log_softmax(x, axis=-1): # pylint: disable=invalid-name - """Transforms activation vectors to log-probability vectors. + """Transforms activation vectors to log-probability vectors. - Log probability vectors are derived by, in effect, applying softmax to raw - activation vectors and then applying log element-wise. The actual - implementation uses a mathematically valid simplification of this. + Log probability vectors are derived by, in effect, applying softmax to raw + activation vectors and then applying log element-wise. The actual + implementation uses a mathematically valid simplification of this. - Args: - x: An ndarray with activation vectors along the given axis. - axis: Axis along which values are grouped for computing log softmax. + Args: + x: An ndarray with activation vectors along the given axis. + axis: Axis along which values are grouped for computing log softmax. - Returns: - An ndarray containing log-probability vectors derived from the raw - activation vectors in `x`. 
- """ - return x - fastmath.logsumexp(x, axis=axis, keepdims=True) + Returns: + An ndarray containing log-probability vectors derived from the raw + activation vectors in `x`. + """ + return x - fastmath.logsumexp(x, axis=axis, keepdims=True) def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name - """Returns `log N(x | mu, sigma)`. - - Args: - x: - mu: - sigma: - """ - a = mu.shape[-1] * jnp.log(2 * jnp.pi) - _, b = jnp.linalg.slogdet(sigma) - y = jnp.linalg.solve(sigma, x - mu) - y = jnp.expand_dims(y, axis=-1) - xm = jnp.expand_dims(x - mu, axis=-2) - c = jnp.matmul(xm, y) - c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) - return -0.5 * (a + b + c) + """Returns `log N(x | mu, sigma)`. + + Args: + x: + mu: + sigma: + """ + a = mu.shape[-1] * jnp.log(2 * jnp.pi) + _, b = jnp.linalg.slogdet(sigma) + y = jnp.linalg.solve(sigma, x - mu) + y = jnp.expand_dims(y, axis=-1) + xm = jnp.expand_dims(x - mu, axis=-2) + c = jnp.matmul(xm, y) + c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) def log_gaussian_diag_pdf(x, mu, diag_sigma): # pylint: disable=invalid-name - """Returns `log N(x | mu, eye(diag_sigma))`. - - Args: - x: - mu: - diag_sigma: - """ - a = mu.shape[-1] * jnp.log(2 * jnp.pi) - b = jnp.sum(jnp.log(diag_sigma), axis=-1) - y = x - mu / diag_sigma - y = jnp.expand_dims(y, axis=-1) - xm = jnp.expand_dims(x - mu, axis=-2) - c = jnp.matmul(xm, y) - c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) - return -0.5 * (a + b + c) + """Returns `log N(x | mu, eye(diag_sigma))`. + + Args: + x: + mu: + diag_sigma: + """ + a = mu.shape[-1] * jnp.log(2 * jnp.pi) + b = jnp.sum(jnp.log(diag_sigma), axis=-1) + y = x - mu / diag_sigma + y = jnp.expand_dims(y, axis=-1) + xm = jnp.expand_dims(x - mu, axis=-2) + c = jnp.matmul(xm, y) + c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) def multigaussian_loss(preds, targets, ngauss=1): # pylint: disable=invalid-name - """Returns a mixture of gaussians loss. 
- - Args: - preds: - targets: - ngauss: - """ - ndims = targets.shape[-1] - logits = preds[:, :ngauss] - mus = preds[:, ngauss:ngauss*(ndims + 1)] - sigmas = preds[:, ngauss(ndims + 1):] - sigmas = sigmas * sigmas + 1e-6 # Make positive. - loglogits = logits - fastmath.logsumexp(logits, axis=-1, keepdims=True) - mus = jnp.reshape(mus, [-1, ngauss, ndims]) - sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims]) - targets = jnp.reshape(targets, [-1, 1, ndims]) - glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas) - return fastmath.logsumexp(loglogits + glogprobs, axis=-1) + """Returns a mixture of gaussians loss. + + Args: + preds: + targets: + ngauss: + """ + ndims = targets.shape[-1] + logits = preds[:, :ngauss] + mus = preds[:, ngauss : ngauss * (ndims + 1)] + sigmas = preds[:, ngauss(ndims + 1) :] + sigmas = sigmas * sigmas + 1e-6 # Make positive. + loglogits = logits - fastmath.logsumexp(logits, axis=-1, keepdims=True) + mus = jnp.reshape(mus, [-1, ngauss, ndims]) + sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims]) + targets = jnp.reshape(targets, [-1, 1, ndims]) + glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas) + return fastmath.logsumexp(loglogits + glogprobs, axis=-1) # TODO(jonni): Rename to log_softmax_sample. def logsoftmax_sample(log_probs, temperature=1.0): # pylint: disable=invalid-name - """Returns a sample from a log-softmax output, with temperature. - - Args: - log_probs: Logarithms of probabilities (often coming from LogSoftmax) - temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax) - """ - # This is equivalent to sampling from a softmax with temperature. - u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape) - g = -np.log(-np.log(u)) - return np.argmax(log_probs + g * temperature, axis=-1) + """Returns a sample from a log-softmax output, with temperature. 
+ + Args: + log_probs: Logarithms of probabilities (often coming from LogSoftmax) + temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax) + """ + # This is equivalent to sampling from a softmax with temperature. + u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape) + g = -np.log(-np.log(u)) + return np.argmax(log_probs + g * temperature, axis=-1) diff --git a/trax/layers/core_test.py b/trax/layers/core_test.py deleted file mode 100644 index 85143f0ca..000000000 --- a/trax/layers/core_test.py +++ /dev/null @@ -1,492 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for core layers.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.fastmath import numpy as jnp -import trax.layers as tl -import trax.layers.initializers as init - - -class DenseTest(absltest.TestCase): - """Test Dense layer per se and as a key example of trainable layers.""" - - def test_call_before_init_raises_error(self): - layer = tl.Dense(5) - x = np.array([1, 2, 3]) - - # Without init, layer lacks the weights it needs for forward computation. - with self.assertRaises(tl.LayerError): - _ = layer(x) - - def test_call_uses_and_caches_supplied_weights(self): - layer = tl.Dense(4) - x = np.array([2, 3]) - - # Weights from random initialization are cached in the layer. 
- _, _ = layer.init(shapes.signature(x)) - w_init, b_init = layer.weights - - # Call the layer with externally specified weights. - w = np.array([[10000, 20000, 30000, 40000], [100, 200, 100, 200]]) - b = np.array([9, 8, 7, 6]) - y = layer(x, weights=(w, b)) - - # Using weights keyword arg overrides any previous cached weights ... - self.assertEqual(y.tolist(), [20309, 40608, 60307, 80606]) - self.assertNotEqual(w.tolist(), w_init.tolist()) - self.assertNotEqual(b.tolist(), b_init.tolist()) - - # ... and do not over-write the old weights. - w_cached, b_cached = layer.weights - self.assertNotEqual(w.tolist(), w_cached.tolist()) - self.assertNotEqual(b.tolist(), b_cached.tolist()) - - def test_separate_instances_have_separate_weights(self): - # Two dense layer instances: each will get its own initial weights (w, b). - model = tl.Serial(tl.Dense(5), tl.Dense(5)) - - sample_input = np.array([1, 2, 3, 4, 5]) - _, _ = model.init(shapes.signature(sample_input)) - weights_0 = model.sublayers[0].weights - weights_1 = model.sublayers[1].weights - - w0, b0 = weights_0 - w1, b1 = weights_1 - self.assertNotEqual(w0.tolist(), w1.tolist()) - self.assertNotEqual(b0.tolist(), b1.tolist()) - - def test_shared_instance_means_shared_weights(self): - # Same dense layer instance in two places --> shared weights. 
- layer = tl.Dense(5) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_call_no_bias(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([2, 5, 3]) - _, _ = layer.init(shapes.signature(x)) - - w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]]) - y = layer(x, weights=w) - self.assertEqual(y.tolist(), [253, 456, 653, 856]) - - def test_new_weights_use_bias(self): - layer = tl.Dense(4) - x = np.array([1, 2]) - _, _ = layer.init(shapes.signature(x)) - self.assertLen(layer.weights, 2) - self.assertEqual(layer.weights[0].shape, (2, 4)) - self.assertEqual(layer.weights[1].shape, (4,)) - - def test_new_weights_no_bias(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - _, _ = layer.init(shapes.signature(x)) - self.assertEqual(layer.weights.shape, (2, 4)) - - def test_init_twice_weights_same_shape(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - w1, _ = layer.init(shapes.signature(x)) - w2, _ = layer.init(shapes.signature(x)) - self.assertEqual(w1.shape, (2, 4)) - self.assertEqual(w2.shape, (2, 4)) - - def test_save_to_file_and_init_to_file(self): - layer1 = tl.Dense(4, use_bias=False) - layer2 = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - w1, _ = layer1.init(shapes.signature(x)) - layer1.save_to_file('/tmp/dense_weights', - input_signature=shapes.signature(x)) - w2, _ = layer2.init_from_file('/tmp/dense_weights') - self.assertEqual(w1.shape, (2, 4)) - self.assertEqual(w2.shape, (2, 4)) - self.assertEqual(w1.tolist(), w2.tolist()) - - -class EmbeddingTest(absltest.TestCase): - - def test_forward(self): - layer = tl.Embedding(10, 3) # vocab_size=10, d_feature=3 - _, _ = layer.init(None) # Embedding init doesn't use input signature. 
- x = np.array([2, 3, 5, 3, 2]) - y = layer(x) - self.assertEqual(y.shape, (5, 3)) - - # For distinct in-domain token IDs, resulting vectors should be distinct. - self.assertNotEqual(y[0].tolist(), y[1].tolist()) - self.assertNotEqual(y[0].tolist(), y[2].tolist()) - self.assertNotEqual(y[1].tolist(), y[2].tolist()) - - # For repeats of a token id, resulting vectors should match. - self.assertEqual(y[0].tolist(), y[4].tolist()) - self.assertEqual(y[1].tolist(), y[3].tolist()) - - def test_negative_inputs_clip_to_zero(self): - layer = tl.Embedding(10, 3) - _, _ = layer.init(None) - x = np.array([0, 2, 3, -2, -3]) - y = layer(x) - self.assertNotEqual(y[0].tolist(), y[1].tolist()) - self.assertNotEqual(y[0].tolist(), y[2].tolist()) - self.assertEqual(y[0].tolist(), y[3].tolist()) - self.assertEqual(y[0].tolist(), y[4].tolist()) - - def test_large_inputs_clip_to_upper_bound(self): - layer = tl.Embedding(10, 3) - _, _ = layer.init(None) - x = np.array([2, 3, 9, 10, 20]) - y = layer(x) - - # vocab_size of 10 means max valid token id is 9. - self.assertNotEqual(y[2].tolist(), y[0].tolist()) - self.assertNotEqual(y[2].tolist(), y[1].tolist()) - self.assertEqual(y[2].tolist(), y[3].tolist()) - self.assertEqual(y[2].tolist(), y[4].tolist()) - - def test_new_weights(self): - layer = tl.Embedding(20, 5) - _, _ = layer.init(None) - - # Default weights sampled from Gaussian, mu = 0, sigma = 1. 
- w = layer.weights - self.assertEqual(w.shape, (20, 5)) - self.assertLess(np.abs(np.mean(w)), .4) # .4 is 4 sigma deviation - - def test_explicit_kernel_initializer(self): - - def f(shape, rng): - del rng - n_elements = np.prod(shape) - return np.arange(n_elements).reshape(shape) - - layer = tl.Embedding(5, 2, kernel_initializer=f) - _, _ = layer.init(None) - x = np.array([0, 1, 2, 3, 4]) - y = layer(x) - self.assertEqual(y.tolist(), [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) - - -class DropoutTest(absltest.TestCase): - - def test_call_in_train_mode(self): - layer = tl.Dropout(rate=0.1, mode='train') - x = np.ones((2, 5, 1000)) # 10,000 values - y = layer(x) - self.assertEqual(y.shape, (2, 5, 1000)) - - # Dropout is stochastic; test it nonflakily at 4 sigmas (.99994). - n_remaining = np.count_nonzero(y) - mu_of_remaining = 9000 # N * q: 10000 * .9 - sigma_of_remaining = 30 # sqrt(N * p * q): sqrt(10000 * .1 * .9) - self.assertLess( - np.abs(n_remaining - mu_of_remaining), 4 * sigma_of_remaining) - - def test_call_in_eval_mode_does_no_dropout(self): - layer = tl.Dropout(rate=0.1, mode='eval') - x = np.ones((2, 5, 1000)) - y = layer(x) - self.assertEqual(np.count_nonzero(y), 10_000) - - def test_new_weights(self): - layer = tl.Dropout(rate=0.1, mode='train') - layer.init(None) - self.assertEmpty(layer.weights) - - -class WeightsTest(absltest.TestCase): - """Test Weights layer.""" - - def test_simple(self): - layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)) - layer.init(()) - y = layer(()) - self.assertEqual(y.tolist(), 0.) - - def test_shape(self): - layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (5, 10, 3)) - - def test_simple_custom_initializer(self): - layer = tl.Weights(init.RandomNormalInitializer()) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, ()) - self.assertNotEqual(y.tolist(), 0.) 
- - def test_custom_initializer_shape(self): - layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), - (2, 2)) - layer.init(()) - y = layer(()) - self.assertEqual(y.tolist(), [[0., 0.], - [0., 0.]]) - - layer = tl.Weights(init.RandomNormalInitializer(), (2, 2)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (2, 2)) - self.assertNotEqual(y.tolist(), [[0., 0.], - [0., 0.]]) - - -class SummaryScalarTest(absltest.TestCase): - - def test_passes(self): - layer = tl.SummaryScalar('test') - x = np.array([[3., 5.], [2., 6.]]) # 10,000 values - y = layer(x) - self.assertEqual(y.tolist(), [[3., 5.], [2., 6.]]) - self.assertEqual(layer.state['summary_test'].tolist(), 4.0) - - -class RandomUniformTest(absltest.TestCase): - """Test Weights layer.""" - - def test_simple(self): - layer = tl.RandomUniform() - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, ()) - self.assertBetween(y, 0.0, 1.0) - - def test_shape(self): - layer = tl.RandomUniform(shape=(5, 10, 3)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (5, 10, 3)) - - def test_simple_range(self): - layer = tl.RandomUniform(1., 2., shape=(1000,)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (1000,)) - self.assertBetween(min(y.tolist()), 1., 2.) - self.assertBetween(max(y.tolist()), 1., 2.) 
- self.assertBetween(1.5, min(y.tolist()), max(y.tolist())) - - -class LocallyConnected1dTest(absltest.TestCase): - - def test_shape_kernel1(self): - for padding in ['WRAP', 'SAME', 'VALID']: - layer = tl.LocallyConnected1d(6, 1, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (3, 6)) - - def test_shape_kernel3(self): - for padding in ['WRAP', 'SAME']: - layer = tl.LocallyConnected1d(6, 3, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (3, 6)) - - for padding in ['VALID']: - layer = tl.LocallyConnected1d(6, 3, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 6)) - - -class FlattenTest(absltest.TestCase): - - def test_keep_default(self): - layer = tl.Flatten() - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - # Default is leave first axis untouched, flatten the rest. 
- self.assertEqual(y.shape, (1, 2 * 3 * 4 * 5)) - - def test_keep_3(self): - layer = tl.Flatten(n_axes_to_keep=3) - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - self.assertEqual(y.shape, (1, 2, 3, 4 * 5)) - - def test_keep_max_number(self): - layer = tl.Flatten(n_axes_to_keep=4) - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - self.assertEqual(y.shape, (1, 2, 3, 4, 5)) - - def test_keep_too_many_raises_error(self): - layer = tl.Flatten(n_axes_to_keep=5) - with self.assertRaises(tl.LayerError): - x = np.ones((1, 2, 3, 4, 5)) - _ = layer(x) - - -class LogSoftmaxTest(absltest.TestCase): - - def test_call(self): - layer = tl.LogSoftmax() - x = np.array([[2., 1., -10.], - [1., 1., -10.]]) - y = layer(x) - np.testing.assert_allclose(y, - [[-0.313, -1.313, -12.313], - [-0.693, -0.693, -11.693]], - atol=.001) - - -class SoftmaxTest(absltest.TestCase): - - def test_call(self): - layer = tl.Softmax() - x = np.array([[2., 1., -10.], - [1., 1., -10.]]) - y = layer(x) - np.testing.assert_allclose(y, - [[.731, .269, .00000449], - [.500, .500, .00000835]], - atol=.001) - - -class CoreFunctionsTest(absltest.TestCase): - - def test_one_hot(self): - targets = np.array([2, 0, 1]) - n_categories = 5 - target_distributions = tl.one_hot(targets, n_categories) - self.assertEqual(tl.to_list(target_distributions), - [[0., 0., 1., 0., 0.], - [1., 0., 0., 0., 0.], - [0., 1., 0., 0., 0.]]) - - def test_log_softmax(self): - activations = np.array([[2., 1., -10.], - [1., 1., -10.]]) - log_probabilities = tl.log_softmax(activations) - np.testing.assert_allclose(log_probabilities, - [[-0.313, -1.313, -12.313], - [-0.693, -0.693, -11.693]], - atol=.001) - - def test_log_gaussian_pdf(self): - x = np.zeros((2, 5), dtype=np.float32) - mu = x - dsigma = np.eye(5)[None, :, :] - sigma = np.concatenate([dsigma, 2 * dsigma], axis=0) - prob = tl.log_gaussian_pdf(x, mu, sigma) - self.assertEqual(prob.shape, (2,)) - self.assertEqual(int(prob[0]), -4) - self.assertEqual(int(prob[1]), -6) - - def 
test_log_gaussian_diag_pdf(self): - x = np.zeros((2, 5), dtype=np.float32) - mu = x - sigma = np.ones((5,))[None, :] - sigma = np.concatenate([sigma, 2 * sigma], axis=0) - prob = tl.log_gaussian_diag_pdf(x, mu, sigma) - self.assertEqual(prob.shape, (2,)) - self.assertEqual(int(prob[0]), -4) - self.assertEqual(int(prob[1]), -6) - - -class StopGradientTest(absltest.TestCase): - - def test_passes(self): - layer = tl.StopGradient() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2, 2)) - self.assertEqual(y.tolist(), [[3., 5.], [2., 6.]]) - - -class MinMaxTest(absltest.TestCase): - - def test_min(self): - layer = tl.Min() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [3., 2.]) - - layer = tl.Min(axis=0) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [2., 5.]) - - layer = tl.Min(axis=None) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, ()) - self.assertEqual(y.tolist(), 2.) - - layer = tl.Min(keepdims=True) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2, 1)) - self.assertEqual(y.tolist(), [[3.], [2.]]) - - def test_max(self): - layer = tl.Max() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [5., 6.]) - - layer = tl.Max(axis=0) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [3., 6.]) - - layer = tl.Max(axis=None) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, ()) - self.assertEqual(y.tolist(), 6.) 
- - layer = tl.Max(axis=0, keepdims=True) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (1, 2)) - self.assertEqual(y.tolist(), [[3., 6.]]) - - -class ClassifierLayersTest(absltest.TestCase): - - def test_threshold_to_binary(self): - layer = tl.ThresholdToBinary() - x = np.array([.30, .49, .50, .51, .70]) - y = layer(x) - self.assertEqual(y.tolist(), [0, 0, 0, 1, 1]) - - def test_arg_max(self): - layer = tl.ArgMax() - x = np.array([[.10, .90, .20, .80], - [.22, .88, .11, .99]]) - y = layer(x) - self.assertEqual(y.tolist(), [1, 3]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/deconvolution.py b/trax/layers/deconvolution.py index 05c5a26de..e6bf53e90 100644 --- a/trax/layers/deconvolution.py +++ b/trax/layers/deconvolution.py @@ -27,68 +27,84 @@ class ConvTranspose(base.Layer): - """Layer constructor function for a general Transpose Convolutional Layer.""" + """Layer constructor function for a general Transpose Convolutional Layer.""" - def __init__(self, - filters, - kernel_size, - strides=None, - padding='VALID', - rhs_dilation=None, - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - kernel_initialzer=None, - bias_initializer=init.RandomNormalInitializer(1e-6)): - super(ConvTranspose, self).__init__() - self._filters = filters - self._kernel_size = kernel_size - self._padding = padding - self._rhs_dilation = rhs_dilation - self._dimension_numbers = dimension_numbers - self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers - self._one = (1,) * len(kernel_size) - self._strides = strides or self._one - self._bias_initializer = bias_initializer - rhs_spec = self._rhs_spec - self._kernel_initializer = kernel_initialzer - if kernel_initialzer is None: - self._kernel_initializer = init.GlorotNormalInitializer( - rhs_spec.index('O'), rhs_spec.index('I')) + def __init__( + self, + filters, + kernel_size, + strides=None, + padding="VALID", + rhs_dilation=None, + dimension_numbers=("NHWC", "HWIO", 
"NHWC"), + kernel_initialzer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + ): + super(ConvTranspose, self).__init__() + self._filters = filters + self._kernel_size = kernel_size + self._padding = padding + self._rhs_dilation = rhs_dilation + self._dimension_numbers = dimension_numbers + self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers + self._one = (1,) * len(kernel_size) + self._strides = strides or self._one + self._bias_initializer = bias_initializer + rhs_spec = self._rhs_spec + self._kernel_initializer = kernel_initialzer + if kernel_initialzer is None: + self._kernel_initializer = init.GlorotNormalInitializer( + rhs_spec.index("O"), rhs_spec.index("I") + ) - def _check_nhwc(self): - msg = 'Deconvolutions on more than 4 dimensions only supported in NHWC.' - assert self._lhs_spec == self._out_spec == 'NHWC', msg + def _check_nhwc(self): + msg = "Deconvolutions on more than 4 dimensions only supported in NHWC." + assert self._lhs_spec == self._out_spec == "NHWC", msg - def forward(self, x): - w, b = self.weights - x_shape = list(x.shape) - if len(x_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, x.shape[:-3]) - x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:])) - res = lax.conv_transpose(x, w, self._strides, self._padding, - self._rhs_dilation, self._dimension_numbers) + b - if len(x_shape) > 4: - res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) - return res + def forward(self, x): + w, b = self.weights + x_shape = list(x.shape) + if len(x_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, x.shape[:-3]) + x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:])) + res = ( + lax.conv_transpose( + x, + w, + self._strides, + self._padding, + self._rhs_dilation, + self._dimension_numbers, + ) + + b + ) + if len(x_shape) > 4: + res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) + return res - def _kernel_shape(self, input_shape): - """Helper to 
calculate the kernel shape.""" - kernel_size_iter = iter(self._kernel_size) - return [ - self._filters if c == 'O' else input_shape[self._lhs_spec.index('C')] - if c == 'I' else next(kernel_size_iter) for c in self._rhs_spec - ] + def _kernel_shape(self, input_shape): + """Helper to calculate the kernel shape.""" + kernel_size_iter = iter(self._kernel_size) + return [ + self._filters + if c == "O" + else input_shape[self._lhs_spec.index("C")] + if c == "I" + else next(kernel_size_iter) + for c in self._rhs_spec + ] - def init_weights_and_state(self, input_signature): - input_shape = input_signature.shape - if len(input_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) - input_shape = [new_batch_dim] + list(input_shape[-3:]) - kernel_shape = self._kernel_shape(input_shape) - bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec] - bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(kernel_shape, rng1) - b = self._bias_initializer(bias_shape, rng2) - self.weights = (w, b) + def init_weights_and_state(self, input_signature): + input_shape = input_signature.shape + if len(input_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) + input_shape = [new_batch_dim] + list(input_shape[-3:]) + kernel_shape = self._kernel_shape(input_shape) + bias_shape = [self._filters if c == "C" else 1 for c in self._out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(kernel_shape, rng1) + b = self._bias_initializer(bias_shape, rng2) + self.weights = (w, b) diff --git a/trax/layers/initializers.py b/trax/layers/initializers.py index 42e9220d5..f089c3130 100644 --- a/trax/layers/initializers.py +++ b/trax/layers/initializers.py @@ -24,174 +24,185 @@ def _GetFans(shape, 
out_dim=-1, in_dim=-2, nonreceptive_dims=None): - """Get the fan-in and fan-out sizes for the given shape and dims.""" - # Temporary fix until numpy.delete supports negative indices. - if out_dim < 0: - out_dim += len(shape) - if in_dim < 0: - in_dim += len(shape) - - if nonreceptive_dims is None: - nonreceptive_dims = [] - if not isinstance(nonreceptive_dims, (list, tuple)): - nonreceptive_dims = [nonreceptive_dims] - - receptive_field = jnp.prod(np.delete(shape, [in_dim, out_dim, - *nonreceptive_dims])) - if len(shape) >= 2: - fan_in, fan_out = shape[in_dim], shape[out_dim] - elif len(shape) == 1: - fan_in = shape[0] - fan_out = shape[0] - else: - fan_in = 1. - fan_out = 1. - fan_in *= receptive_field - fan_out *= receptive_field - return fan_in, fan_out + """Get the fan-in and fan-out sizes for the given shape and dims.""" + # Temporary fix until numpy.delete supports negative indices. + if out_dim < 0: + out_dim += len(shape) + if in_dim < 0: + in_dim += len(shape) + + if nonreceptive_dims is None: + nonreceptive_dims = [] + if not isinstance(nonreceptive_dims, (list, tuple)): + nonreceptive_dims = [nonreceptive_dims] + + receptive_field = jnp.prod(np.delete(shape, [in_dim, out_dim, *nonreceptive_dims])) + if len(shape) >= 2: + fan_in, fan_out = shape[in_dim], shape[out_dim] + elif len(shape) == 1: + fan_in = shape[0] + fan_out = shape[0] + else: + fan_in = 1.0 + fan_out = 1.0 + fan_in *= receptive_field + fan_out *= receptive_field + return fan_in, fan_out def InitializerFromFile(path): - """Loads parameters from .npy file.""" + """Loads parameters from .npy file.""" - def Initializer(shape, rng): - del rng - logging.info('Loading pretrained embeddings from %s', path) - with tf.io.gfile.GFile(path, 'rb') as f: - parameters = jnp.load(f) - assert jnp.shape(parameters) == shape, ( - 'Expected shape %s, got %s' % (shape, jnp.shape(parameters))) - return parameters + def Initializer(shape, rng): + del rng + logging.info("Loading pretrained embeddings from %s", 
path) + with tf.io.gfile.GFile(path, "rb") as f: + parameters = jnp.load(f) + assert jnp.shape(parameters) == shape, "Expected shape %s, got %s" % ( + shape, + jnp.shape(parameters), + ) + return parameters - return Initializer + return Initializer def _PureShape(shape): - """Make sure shape does not contain int tensors by calling int().""" - return [int(x) for x in shape] + """Make sure shape does not contain int tensors by calling int().""" + return [int(x) for x in shape] def RandomNormalInitializer(stddev=1e-2): - """Returns an initializer for random normal coefficients.""" - return lambda shape, rng: (stddev * random.normal( # pylint: disable=g-long-lambda - rng, _PureShape(shape)).astype('float32')) + """Returns an initializer for random normal coefficients.""" + return lambda shape, rng: ( + stddev + * random.normal(rng, _PureShape(shape)).astype( # pylint: disable=g-long-lambda + "float32" + ) + ) def RandomUniformInitializer(lim=1.0): - """Returns an initializer for random uniform coefficients.""" - # Make sure shape does not contain int tensors by calling int() below. - return lambda shape, rng: random.uniform( # pylint: disable=g-long-lambda - rng, _PureShape(shape), jnp.float32, -lim, lim) + """Returns an initializer for random uniform coefficients.""" + # Make sure shape does not contain int tensors by calling int() below. 
+ return lambda shape, rng: random.uniform( # pylint: disable=g-long-lambda + rng, _PureShape(shape), jnp.float32, -lim, lim + ) def ScaledInitializer(out_dim, in_dim, scale, mode, distribution): - """Returns an initializer that adjusts its scale based on weight shapes.""" - if scale <= 0.: - raise ValueError('scale must be positive float, {} given'.format(scale)) - if mode not in {'fan_in', 'fan_out', 'fan_avg'}: - raise ValueError( - 'Invalid mode argument:, {}, must be either fan_in, fan_out or fan_avg' - .format(mode)) - - def Init(shape, rng, nonreceptive_dims=None): - """Returns random values for initializing weights of the given `shape`.""" - shape = _PureShape(shape) - fan_in, fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims) - gain = scale - if mode == 'fan_in': - gain /= fan_in - elif mode == 'fan_out': - gain /= fan_out - elif mode == 'fan_avg': - gain /= (fan_in + fan_out) / 2 - if distribution == 'truncated_normal': - # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) - stddev = jnp.sqrt(gain) / .87962566103423978 - new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev - return new_weights.astype('float32') - elif distribution == 'normal': - new_weights = random.normal(rng, shape) * jnp.sqrt(gain) - return new_weights.astype('float32') - elif distribution == 'uniform': - lim = jnp.sqrt(3. 
* gain) - return random.uniform(rng, shape, jnp.float32, -lim, lim) - else: - raise ValueError('invalid distribution for ScaleInitializer') - - return Init - - -def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.): - """Returns an initializer for random Glorot-scaled coefficients.""" - return ScaledInitializer(out_dim, in_dim, scale, 'fan_avg', 'normal') - - -def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.): - """Returns an initializer for random uniform Glorot-scaled coefficients.""" - return ScaledInitializer(out_dim, in_dim, scale, 'fan_avg', 'uniform') - - -def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.): - """Returns an initializer for random LeCun-scaled coefficients.""" - return ScaledInitializer(out_dim, in_dim, scale, 'fan_in', 'normal') - - -def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.): - """Returns an initializer for random uniform LeCun-scaled coefficients.""" - return ScaledInitializer(out_dim, in_dim, scale, 'fan_in', 'uniform') - - -def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.): - """Returns an initializer for random Kaiming-scaled coefficients.""" - return ScaledInitializer( - out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), 'fan_in', 'normal') - - -def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.): - """Returns an initializer for random uniform Kaiming-scaled coefficients.""" - return ScaledInitializer( - out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), 'fan_in', 'uniform') + """Returns an initializer that adjusts its scale based on weight shapes.""" + if scale <= 0.0: + raise ValueError("scale must be positive float, {} given".format(scale)) + if mode not in {"fan_in", "fan_out", "fan_avg"}: + raise ValueError( + "Invalid mode argument:, {}, must be either fan_in, fan_out or fan_avg".format( + mode + ) + ) + + def Init(shape, rng, nonreceptive_dims=None): + """Returns random values for initializing weights of the given `shape`.""" + shape = _PureShape(shape) + fan_in, 
fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims) + gain = scale + if mode == "fan_in": + gain /= fan_in + elif mode == "fan_out": + gain /= fan_out + elif mode == "fan_avg": + gain /= (fan_in + fan_out) / 2 + if distribution == "truncated_normal": + # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + stddev = jnp.sqrt(gain) / 0.87962566103423978 + new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev + return new_weights.astype("float32") + elif distribution == "normal": + new_weights = random.normal(rng, shape) * jnp.sqrt(gain) + return new_weights.astype("float32") + elif distribution == "uniform": + lim = jnp.sqrt(3.0 * gain) + return random.uniform(rng, shape, jnp.float32, -lim, lim) + else: + raise ValueError("invalid distribution for ScaleInitializer") + + return Init + + +def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0): + """Returns an initializer for random Glorot-scaled coefficients.""" + return ScaledInitializer(out_dim, in_dim, scale, "fan_avg", "normal") + + +def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0): + """Returns an initializer for random uniform Glorot-scaled coefficients.""" + return ScaledInitializer(out_dim, in_dim, scale, "fan_avg", "uniform") + + +def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0): + """Returns an initializer for random LeCun-scaled coefficients.""" + return ScaledInitializer(out_dim, in_dim, scale, "fan_in", "normal") + + +def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0): + """Returns an initializer for random uniform LeCun-scaled coefficients.""" + return ScaledInitializer(out_dim, in_dim, scale, "fan_in", "uniform") + + +def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.0): + """Returns an initializer for random Kaiming-scaled coefficients.""" + return ScaledInitializer( + out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), "fan_in", "normal" + ) + + +def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.0): + 
"""Returns an initializer for random uniform Kaiming-scaled coefficients.""" + return ScaledInitializer( + out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), "fan_in", "uniform" + ) def OrthogonalInitializer(stddev=1.0): - """Returns an orthogonal initializer.""" - def Init(shape, rng): - """Returns orthogonalized random normal values with the given `shape`.""" - # Have at least 2 elements in shape. - cur_shape = list(shape) - while len(cur_shape) < 2: - cur_shape = [1] + cur_shape + """Returns an orthogonal initializer.""" + + def Init(shape, rng): + """Returns orthogonalized random normal values with the given `shape`.""" + # Have at least 2 elements in shape. + cur_shape = list(shape) + while len(cur_shape) < 2: + cur_shape = [1] + cur_shape - # Flatten the input shape with the last dimension remaining. - n_rows = 1 - for dim in cur_shape[:-1]: - n_rows *= dim - n_cols = cur_shape[-1] - flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols) + # Flatten the input shape with the last dimension remaining. + n_rows = 1 + for dim in cur_shape[:-1]: + n_rows *= dim + n_cols = cur_shape[-1] + flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols) - # Generate a random matrix - a = random.normal(rng, flat_shape, dtype=jnp.float32) + # Generate a random matrix + a = random.normal(rng, flat_shape, dtype=jnp.float32) - # Compute the qr factorization - q, r = jnp.linalg.qr(a) + # Compute the qr factorization + q, r = jnp.linalg.qr(a) - # Make Q uniform - d = jnp.diag(r) - q *= jnp.sign(d) + # Make Q uniform + d = jnp.diag(r) + q *= jnp.sign(d) - # Transpose and reshape back q if needed. - if n_rows < n_cols: - q = jnp.transpose(q) - q = jnp.reshape(q, shape) + # Transpose and reshape back q if needed. + if n_rows < n_cols: + q = jnp.transpose(q) + q = jnp.reshape(q, shape) - # Return scaled as requested. - return stddev * q + # Return scaled as requested. 
+ return stddev * q - return Init + return Init def AtariConvInit(kernel_shape, rng, dtype=jnp.float32): - """The standard init for Conv laters and Atari.""" - filter_height, filter_width, fan_in, _ = kernel_shape - std = 1 / jnp.sqrt(fan_in * filter_height * filter_width) - return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std) + """The standard init for Conv laters and Atari.""" + filter_height, filter_width, fan_in, _ = kernel_shape + std = 1 / jnp.sqrt(fan_in * filter_height * filter_width) + return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std) diff --git a/trax/layers/initializers_test.py b/trax/layers/initializers_test.py deleted file mode 100644 index 921452c58..000000000 --- a/trax/layers/initializers_test.py +++ /dev/null @@ -1,96 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for initializers.""" - -from absl.testing import absltest -import numpy as np - -from trax import fastmath -from trax import test_utils -import trax.layers as tl - - -INPUT_SHAPE = (5, 7, 20) - - -def rng(): # Can't be a constant, because JAX has to init itself in main first. 
- return fastmath.random.get_prng(0) - - -class InitializersTest(absltest.TestCase): - - def test_random_normal(self): - f = tl.RandomNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_lecun_uniform(self): - f = tl.LeCunUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_random_uniform(self): - f = tl.RandomUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_glorot_normal(self): - f = tl.GlorotNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_glorot_uniform(self): - f = tl.GlorotUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_lecun_normal(self): - f = tl.LeCunNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_kaiming_normal(self): - f = tl.KaimingNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_kaiming_uniform(self): - f = tl.KaimingUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_orthogonal(self): - f = tl.OrthogonalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_from_file(self): - params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]) - # `create_tempfile` needs access to --test_tmpdir, however in the OSS world - # pytest doesn't run `absltest.main`, so we need to manually parse the flags - test_utils.ensure_flag('test_tmpdir') - filename = self.create_tempfile('params.npy').full_path - with open(filename, 'wb') as f: - np.save(f, params) - f = tl.InitializerFromFile(filename) - init_value = f(params.shape, rng()) - 
np.testing.assert_almost_equal( - tl.to_list(init_value), tl.to_list(params), decimal=4) - # self.assertEqual('%s' % init_value, '%s' % params) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/metrics.py b/trax/layers/metrics.py index 9cea73922..edc1fb33f 100644 --- a/trax/layers/metrics.py +++ b/trax/layers/metrics.py @@ -53,231 +53,236 @@ def CategoryAccuracy(): - r"""Returns a layer that computes category prediction accuracy. + r"""Returns a layer that computes category prediction accuracy. - The layer takes two inputs: + The layer takes two inputs: - - A batch of activation vectors. The components in a given vector should - be mappable to a probability distribution in the following loose sense: - within a vector, a higher component value corresponds to a higher - probability, such that argmax within a vector (``axis=-1``) picks the - index (category) having the highest probablity. + - A batch of activation vectors. The components in a given vector should + be mappable to a probability distribution in the following loose sense: + within a vector, a higher component value corresponds to a higher + probability, such that argmax within a vector (``axis=-1``) picks the + index (category) having the highest probablity. - - A batch of target categories; each target is an integer in - :math:`\{0, ..., N-1\}`. + - A batch of target categories; each target is an integer in + :math:`\{0, ..., N-1\}`. - The predicted category from each vector is the index of the highest-valued - vector component. The layer returns the accuracy of these predictions - averaged over the batch. - """ - def f(model_output, targets): # pylint: disable=invalid-name - predictions = jnp.argmax(model_output, axis=-1) - shapes.assert_same_shape(predictions, targets) - n_total = predictions.size - n_correct = jnp.sum(jnp.equal(predictions, targets)) - return n_correct / n_total + The predicted category from each vector is the index of the highest-valued + vector component. 
The layer returns the accuracy of these predictions + averaged over the batch. + """ + + def f(model_output, targets): # pylint: disable=invalid-name + predictions = jnp.argmax(model_output, axis=-1) + shapes.assert_same_shape(predictions, targets) + n_total = predictions.size + n_correct = jnp.sum(jnp.equal(predictions, targets)) + return n_correct / n_total - return base.Fn('CategoryAccuracy', f) + return base.Fn("CategoryAccuracy", f) def _n_weights_per_core(weights): # pylint: disable=invalid-name - """Calculates the number of weights per core. - - In multi-device settings, gradients and losses are averaged over all devices. - When loss is weighted and the number of weights can differ by device, e.g., - when the weights represent the number of tokens in a batch of sentences (which - can differ from device to device), we want to make sure each token on each - device is weighted in the same way. This function ensures that by reporting - the number of weights per core in multi-core settings (and simply - np.sum(weights) in a single-core setting). - - Args: - weights: tensor with arbitrary shape - - Returns: - a scalar equal to np.sum(weights) in 1-machine settings and to the sum - of weights over all cores divided by the number of cores otherwise - """ - weights_sum = jnp.sum(weights) - if fastmath.global_device_count() < 2: - return weights_sum - else: - try: - n_devices_total = fastmath.psum(1, 'batch') - return fastmath.psum(weights_sum, 'batch') / n_devices_total - except (NameError, ValueError): # running outside of pmap, e.g., on init - return weights_sum # fall back to the sum + """Calculates the number of weights per core. + + In multi-device settings, gradients and losses are averaged over all devices. 
+ When loss is weighted and the number of weights can differ by device, e.g., + when the weights represent the number of tokens in a batch of sentences (which + can differ from device to device), we want to make sure each token on each + device is weighted in the same way. This function ensures that by reporting + the number of weights per core in multi-core settings (and simply + np.sum(weights) in a single-core setting). + + Args: + weights: tensor with arbitrary shape + + Returns: + a scalar equal to np.sum(weights) in 1-machine settings and to the sum + of weights over all cores divided by the number of cores otherwise + """ + weights_sum = jnp.sum(weights) + if fastmath.global_device_count() < 2: + return weights_sum + else: + try: + n_devices_total = fastmath.psum(1, "batch") + return fastmath.psum(weights_sum, "batch") / n_devices_total + except (NameError, ValueError): # running outside of pmap, e.g., on init + return weights_sum # fall back to the sum def _non_nan(x): # pylint: disable=invalid-name - """Replaces NaN values with zeros. + """Replaces NaN values with zeros. - A support function replaces NaN values with zeros to escape - the undefined behavior of the division by zero. + A support function replaces NaN values with zeros to escape + the undefined behavior of the division by zero. - Args: - x: tensor with arbitrary shape. + Args: + x: tensor with arbitrary shape. - Returns: - Array with NaNs replaced with 0. - """ - return jnp.where(jnp.isnan(x), 0., x) + Returns: + Array with NaNs replaced with 0. + """ + return jnp.where(jnp.isnan(x), 0.0, x) def _precision_recall(predictions, targets, k): # pylint: disable=invalid-name - """Returns precision, recall, and intermediate values for the category `k`. - - A support function for calculating precision, recall, - and intermediate values for the single category `k` - for future use in metric layers. - - Args: - predictions: predicted categories. - targets: target categories. - k: a category number. 
- - Returns a tuple: - n_correct: a number of correct (or true) examples. - n_k_predictions: a number of predictions of the `k` category. - n_k_targets: a number of targets for the `k` category. - precision: a precision score. - recall: a recall score. - """ - n_correct = sum((predictions == k) & (targets == k)) - n_k_predictions = sum(predictions == k) - precision = _non_nan(n_correct / n_k_predictions) - n_k_targets = sum(targets == k) - recall = _non_nan(n_correct / n_k_targets) - return (n_correct, n_k_predictions, n_k_targets, precision, recall) + """Returns precision, recall, and intermediate values for the category `k`. + + A support function for calculating precision, recall, + and intermediate values for the single category `k` + for future use in metric layers. + + Args: + predictions: predicted categories. + targets: target categories. + k: a category number. + + Returns a tuple: + n_correct: a number of correct (or true) examples. + n_k_predictions: a number of predictions of the `k` category. + n_k_targets: a number of targets for the `k` category. + precision: a precision score. + recall: a recall score. + """ + n_correct = sum((predictions == k) & (targets == k)) + n_k_predictions = sum(predictions == k) + precision = _non_nan(n_correct / n_k_predictions) + n_k_targets = sum(targets == k) + recall = _non_nan(n_correct / n_k_targets) + return (n_correct, n_k_predictions, n_k_targets, precision, recall) def _f_score(precision, recall, beta2): # pylint: disable=invalid-name - """Returns F-score. + """Returns F-score. - Args: - precision: a precision score. - recall: a recall score. - beta2: a square of the parameter that determines the weight of recall. + Args: + precision: a precision score. + recall: a recall score. + beta2: a square of the parameter that determines the weight of recall. - A support function to calculate F-score for the single category. 
- """ - return _non_nan( - (beta2 + 1) * (precision * recall) / ((beta2 * precision) + recall)) + A support function to calculate F-score for the single category. + """ + return _non_nan((beta2 + 1) * (precision * recall) / ((beta2 * precision) + recall)) def WeightedCategoryAccuracy(): - r"""Returns a layer that computes a weighted category prediction accuracy. + r"""Returns a layer that computes a weighted category prediction accuracy. - The layer takes three inputs: + The layer takes three inputs: - - A batch of activation vectors. The components in a given vector should - be mappable to a probability distribution in the following loose sense: - within a vector, a higher component value corresponds to a higher - probability, such that argmax within a vector (``axis=-1``) picks the - index (category) having the highest probablity. + - A batch of activation vectors. The components in a given vector should + be mappable to a probability distribution in the following loose sense: + within a vector, a higher component value corresponds to a higher + probability, such that argmax within a vector (``axis=-1``) picks the + index (category) having the highest probablity. - - A batch of target categories; each target is an integer in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. + - A batch of target categories; each target is an integer in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. - - A batch of weights, which matches or can be broadcast to match the shape - of the target ndarray. This arg can give uneven weighting to different - items in the batch (depending, for instance, on the item's target - category). + - A batch of weights, which matches or can be broadcast to match the shape + of the target ndarray. This arg can give uneven weighting to different + items in the batch (depending, for instance, on the item's target + category). 
- The predicted category from each vector is the index of the highest-valued - vector component. The layer returns a weighted average accuracy of these - predictions. - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - predictions = jnp.argmax(model_output, axis=-1) - shapes.assert_same_shape(predictions, targets) - ones_and_zeros = jnp.equal(predictions, targets) - return jnp.sum(ones_and_zeros * weights) / _n_weights_per_core(weights) + The predicted category from each vector is the index of the highest-valued + vector component. The layer returns a weighted average accuracy of these + predictions. + """ - return base.Fn('WeightedCategoryAccuracy', f) + def f(model_output, targets, weights): # pylint: disable=invalid-name + predictions = jnp.argmax(model_output, axis=-1) + shapes.assert_same_shape(predictions, targets) + ones_and_zeros = jnp.equal(predictions, targets) + return jnp.sum(ones_and_zeros * weights) / _n_weights_per_core(weights) + + return base.Fn("WeightedCategoryAccuracy", f) def CategoryCrossEntropy(label_smoothing=None): - r"""Returns a layer that computes cross-entropy from activations and integers. + r"""Returns a layer that computes cross-entropy from activations and integers. - The layer takes two inputs: + The layer takes two inputs: - - A batch of activation vectors. The components in a given vector should - be pre-softmax activations (mappable to a probability distribution via - softmax). For performance reasons, the softmax and cross-entropy - computations are combined inside the layer. + - A batch of activation vectors. The components in a given vector should + be pre-softmax activations (mappable to a probability distribution via + softmax). For performance reasons, the softmax and cross-entropy + computations are combined inside the layer. - - A batch of target categories; each target is an integer in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. 
+ - A batch of target categories; each target is an integer in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. - To compute cross-entropy per batch item, the layer derives probability - distributions: + To compute cross-entropy per batch item, the layer derives probability + distributions: - - from model output (vectors): :math:`\ q = \text{softmax}(v)` + - from model output (vectors): :math:`\ q = \text{softmax}(v)` - - from target categories (integers): :math:`\ p = \text{one_hot}(n)` or - :math:`p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}`, - where :math:`\varepsilon` is the label smoothing factor. + - from target categories (integers): :math:`\ p = \text{one_hot}(n)` or + :math:`p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}`, + where :math:`\varepsilon` is the label smoothing factor. - (The conversion of integer category targets to one-hot vectors amounts to - assigning all the probability mass to the target category.) Cross-entropy - per batch item is computed between the resulting distributions: + (The conversion of integer category targets to one-hot vectors amounts to + assigning all the probability mass to the target category.) Cross-entropy + per batch item is computed between the resulting distributions: - .. math:: - \text{cross_entropy} = - \sum_{i=0}^{N-1} p_i \log q_i + .. math:: + \text{cross_entropy} = - \sum_{i=0}^{N-1} p_i \log q_i - The layer returns the average of these cross-entropy values over all items in - the batch. + The layer returns the average of these cross-entropy values over all items in + the batch. - Args: - label_smoothing: Creates soft targets if provided. Must be between 0 and 1. - """ - def f(model_output, targets): # pylint: disable=invalid-name - cross_entropies = _category_cross_entropy( - model_output, targets, label_smoothing, 0.0) - return jnp.average(cross_entropies) + Args: + label_smoothing: Creates soft targets if provided. 
Must be between 0 and 1. + """ + + def f(model_output, targets): # pylint: disable=invalid-name + cross_entropies = _category_cross_entropy( + model_output, targets, label_smoothing, 0.0 + ) + return jnp.average(cross_entropies) - return base.Fn('CategoryCrossEntropy', f) + return base.Fn("CategoryCrossEntropy", f) def WeightedCategoryCrossEntropy(label_smoothing=None, cutoff=0.0): - r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input. + r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input. - The layer takes three inputs: + The layer takes three inputs: - - A batch of activation vectors. The components in a given vector should - be pre-softmax activations (mappable to a probability distribution via - softmax). For performance reasons, the softmax and cross-entropy - computations are combined inside the layer. + - A batch of activation vectors. The components in a given vector should + be pre-softmax activations (mappable to a probability distribution via + softmax). For performance reasons, the softmax and cross-entropy + computations are combined inside the layer. - - A batch of target categories; each target is an integer in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. + - A batch of target categories; each target is an integer in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. - - A batch of weights, which matches or can be broadcast to match the shape - of the target ndarray. This arg can give uneven weighting to different - items in the batch (depending, for instance, on the item's target - category). + - A batch of weights, which matches or can be broadcast to match the shape + of the target ndarray. This arg can give uneven weighting to different + items in the batch (depending, for instance, on the item's target + category). 
- The layer returns the weighted average of these cross-entropy values over all - items in the batch. + The layer returns the weighted average of these cross-entropy values over all + items in the batch. - Args: - label_smoothing: Creates soft targets if provided. Must be between 0 and 1. - cutoff: Prevent loss lower than this cutoff (0.0 meaning none by default). - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - cross_entropies = _category_cross_entropy( - model_output, targets, label_smoothing, cutoff) - return jnp.sum(cross_entropies * weights) / _n_weights_per_core(weights) + Args: + label_smoothing: Creates soft targets if provided. Must be between 0 and 1. + cutoff: Prevent loss lower than this cutoff (0.0 meaning none by default). + """ - return base.Fn('WeightedCategoryCrossEntropy', f) + def f(model_output, targets, weights): # pylint: disable=invalid-name + cross_entropies = _category_cross_entropy( + model_output, targets, label_smoothing, cutoff + ) + return jnp.sum(cross_entropies * weights) / _n_weights_per_core(weights) + + return base.Fn("WeightedCategoryCrossEntropy", f) def BinaryCrossEntropy(): - r"""Returns a layer that computes cross-entropy for binary classification. + r"""Returns a layer that computes cross-entropy for binary classification. The layer takes two inputs: @@ -305,156 +310,168 @@ def BinaryCrossEntropy(): The layer returns the average of these cross-entropy values over all items in the batch. 
""" - def f(model_output, targets): # pylint: disable=invalid-name - probabilities = fastmath.expit(model_output) - binary_entropies = - (targets * jnp.log(probabilities) + - (1 - targets) * (jnp.log(1 - probabilities))) - return jnp.average(binary_entropies) - - return base.Fn('BinaryCrossEntropy', f) + def f(model_output, targets): # pylint: disable=invalid-name + probabilities = fastmath.expit(model_output) + binary_entropies = -( + targets * jnp.log(probabilities) + + (1 - targets) * (jnp.log(1 - probabilities)) + ) + return jnp.average(binary_entropies) -def MaskedSequenceAccuracy(): - r"""Returns a layer that computes sequence prediction accuracy with masking. + return base.Fn("BinaryCrossEntropy", f) - This layer type is intended for variable length sequences, especially text, - represented as a batch of fixed-length sequences via padding for unused - positions. - - The layer takes three inputs: - - A batch of sequences of activation vectors. The components in a given - vector should be mappable to a probability distribution in the following - loose sense: within a vector, a higher component value corresponds to a - higher probability, such that argmax within a vector (``axis=-1``) picks - the index having the highest probablity. In text modeling, the index - represents a token id from a predetermined token vocabulary (or padding). - - - A batch of target integer sequences, with values in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. In text modeling, these sequences typically - represent token ids from a predetermined token vocabulary (or padding). - - - A batch of weights/masks, which matches or can be broadcast to match the - shape of the target ndarray. This arg is used to give weight 0 to padding - positions, which masks those positions out of the calculation. Only the - zero/non-zero distinction matters; all non-zero values are treated alike - as signaling non-masked (i.e., valid/in-use) positions. 
- - The predicted integer value for each sequence position is the index of the - highest-valued component of the position's vector. A predicted integer - sequence is judged correct if it matches the target integer sequence in all - non-zero-weighted positions. The layer returns the accuracy of predicted - sequences averaged over the batch. - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - predictions = jnp.argmax(model_output, axis=-1) - shapes.assert_same_shape(predictions, targets) - position_is_padding = jnp.equal(weights, 0) - position_is_accurate = jnp.logical_or(jnp.equal(predictions, targets), - position_is_padding) - sequence_is_accurate = jnp.all(position_is_accurate, axis=-1) - return jnp.average(sequence_is_accurate) - - return base.Fn('MaskedSequenceAccuracy', f) +def MaskedSequenceAccuracy(): + r"""Returns a layer that computes sequence prediction accuracy with masking. + + This layer type is intended for variable length sequences, especially text, + represented as a batch of fixed-length sequences via padding for unused + positions. + + The layer takes three inputs: + + - A batch of sequences of activation vectors. The components in a given + vector should be mappable to a probability distribution in the following + loose sense: within a vector, a higher component value corresponds to a + higher probability, such that argmax within a vector (``axis=-1``) picks + the index having the highest probablity. In text modeling, the index + represents a token id from a predetermined token vocabulary (or padding). + + - A batch of target integer sequences, with values in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. In text modeling, these sequences typically + represent token ids from a predetermined token vocabulary (or padding). + + - A batch of weights/masks, which matches or can be broadcast to match the + shape of the target ndarray. 
This arg is used to give weight 0 to padding + positions, which masks those positions out of the calculation. Only the + zero/non-zero distinction matters; all non-zero values are treated alike + as signaling non-masked (i.e., valid/in-use) positions. + + The predicted integer value for each sequence position is the index of the + highest-valued component of the position's vector. A predicted integer + sequence is judged correct if it matches the target integer sequence in all + non-zero-weighted positions. The layer returns the accuracy of predicted + sequences averaged over the batch. + """ + + def f(model_output, targets, weights): # pylint: disable=invalid-name + predictions = jnp.argmax(model_output, axis=-1) + shapes.assert_same_shape(predictions, targets) + position_is_padding = jnp.equal(weights, 0) + position_is_accurate = jnp.logical_or( + jnp.equal(predictions, targets), position_is_padding + ) + sequence_is_accurate = jnp.all(position_is_accurate, axis=-1) + return jnp.average(sequence_is_accurate) + + return base.Fn("MaskedSequenceAccuracy", f) def Accuracy(classifier=core.ArgMax()): - """Returns a layer that computes mean category prediction accuracy. + """Returns a layer that computes mean category prediction accuracy. - DEPRECATED; use ``WeightedCategoryAccuracy`` instead. + DEPRECATED; use ``WeightedCategoryAccuracy`` instead. - Args: - classifier: Layer that transforms activation vectors into category - predictions. - """ - return cb.Serial(classifier, - _Accuracy(), - _WeightedMean(), - name='Accuracy', - sublayers_to_print=[]) + Args: + classifier: Layer that transforms activation vectors into category + predictions. + """ + return cb.Serial( + classifier, _Accuracy(), _WeightedMean(), name="Accuracy", sublayers_to_print=[] + ) def SequenceAccuracy(classifier=core.ArgMax()): - """Returns a layer that computes mean sequence prediction accuracy. + """Returns a layer that computes mean sequence prediction accuracy. 
- DEPRECATED; use ``MaskedSequenceAccuracy`` instead. + DEPRECATED; use ``MaskedSequenceAccuracy`` instead. - Args: - classifier: Layer that transforms activation vectors into category - predictions. - """ - return cb.Serial(classifier, - _Accuracy(), - _WeightedSequenceMean(), - name='SequenceAccuracy', - sublayers_to_print=[]) + Args: + classifier: Layer that transforms activation vectors into category + predictions. + """ + return cb.Serial( + classifier, + _Accuracy(), + _WeightedSequenceMean(), + name="SequenceAccuracy", + sublayers_to_print=[], + ) def CrossEntropyLoss(): - """Returns a layer that outputs multiclass prediction-target cross-entropy. + """Returns a layer that outputs multiclass prediction-target cross-entropy. - DEPRECATED; refactor to use ``WeightedCategoryCrossEntropy`` or - ``CategoryCrossEntropy`` instead. + DEPRECATED; refactor to use ``WeightedCategoryCrossEntropy`` or + ``CategoryCrossEntropy`` instead. - (``CrossEntropyLoss`` by itself does not compute cross-entropy. In older - code, this layer had to be preceded by ``LogSoftmax``, and the two layers - together did the work of converting category information to probability - distributions and computing the cross-entropy between those distributions. - All this is now done by ``WeightedCategoryCrossEntropy``.) - """ - return cb.Serial(_CrossEntropy(), - _WeightedMean(), - name='CrossEntropyLoss', - sublayers_to_print=[]) + (``CrossEntropyLoss`` by itself does not compute cross-entropy. In older + code, this layer had to be preceded by ``LogSoftmax``, and the two layers + together did the work of converting category information to probability + distributions and computing the cross-entropy between those distributions. + All this is now done by ``WeightedCategoryCrossEntropy``.) 
+ """ + return cb.Serial( + _CrossEntropy(), _WeightedMean(), name="CrossEntropyLoss", sublayers_to_print=[] + ) def CrossEntropyLossWithLogSoftmax(): - """Mean prediction-target cross-entropy for multiclass classification.""" - return cb.Serial(core.LogSoftmax(), _CrossEntropy(), _WeightedMean(), - name='CrossEntropyLossWithLogSoftmax', - sublayers_to_print=[]) + """Mean prediction-target cross-entropy for multiclass classification.""" + return cb.Serial( + core.LogSoftmax(), + _CrossEntropy(), + _WeightedMean(), + name="CrossEntropyLossWithLogSoftmax", + sublayers_to_print=[], + ) def BinaryCrossEntropyLoss(): - """Returns a layer that outputs binary prediction-target cross-entropy. + """Returns a layer that outputs binary prediction-target cross-entropy. - DEPRECATED; refactor to use ``BinaryCrossEntropy`` instead. (The newer - ``BinaryCrossEntropy`` does not use weights, so refactor accordingly. Unless - and until clear motivating use cases arise, the library will not include a - binary cross-entropy function with weights.) - """ - return cb.Serial(_BinaryCrossEntropy(), - _WeightedMean(), - name='BinaryCrossEntropyLoss', - sublayers_to_print=[]) + DEPRECATED; refactor to use ``BinaryCrossEntropy`` instead. (The newer + ``BinaryCrossEntropy`` does not use weights, so refactor accordingly. Unless + and until clear motivating use cases arise, the library will not include a + binary cross-entropy function with weights.) + """ + return cb.Serial( + _BinaryCrossEntropy(), + _WeightedMean(), + name="BinaryCrossEntropyLoss", + sublayers_to_print=[], + ) def L2Loss(): - r"""Returns a layer that computes an L2-like loss for one batch. + r"""Returns a layer that computes an L2-like loss for one batch. - The layer takes three inputs: + The layer takes three inputs: - - Model output from one batch, an ndarray of float-valued elements. + - Model output from one batch, an ndarray of float-valued elements. 
- - A batch of element-wise target values, which matches the shape of the - model output. + - A batch of element-wise target values, which matches the shape of the + model output. - - A batch of weights, which matches the shape of the model output. + - A batch of weights, which matches the shape of the model output. - The layer returns a weighted average of element-wise squared error terms - :math:`(y_i - t_i)^2`. - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, targets) - shapes.assert_same_shape(model_output, weights) - weighted_sse = weights * (model_output - targets)**2 - return jnp.sum(weighted_sse) / jnp.sum(weights) - return base.Fn('L2Loss', f) + The layer returns a weighted average of element-wise squared error terms + :math:`(y_i - t_i)^2`. + """ + + def f(model_output, targets, weights): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, targets) + shapes.assert_same_shape(model_output, weights) + weighted_sse = weights * (model_output - targets) ** 2 + return jnp.sum(weighted_sse) / jnp.sum(weights) + + return base.Fn("L2Loss", f) def SmoothL1Loss(): - r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch. + r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch. The layer takes three inputs: @@ -476,178 +493,200 @@ def SmoothL1Loss(): The layer returns a weighted average of these element-wise values. 
""" - def f(model_output, targets, weights): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, targets) - shapes.assert_same_shape(model_output, weights) - l1_dist = jnp.abs(model_output - targets) - smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5) - weighted_smooth_dist = weights * smooth_dist - return jnp.sum(weighted_smooth_dist) / jnp.sum(weights) - return base.Fn('SmoothL1Loss', f) + def f(model_output, targets, weights): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, targets) + shapes.assert_same_shape(model_output, weights) + l1_dist = jnp.abs(model_output - targets) + smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5) + weighted_smooth_dist = weights * smooth_dist + return jnp.sum(weighted_smooth_dist) / jnp.sum(weights) -def MacroAveragedFScore(beta=1., initial_category_index=0): - r"""Returns a layer that computes a macro-averaged F-score. + return base.Fn("SmoothL1Loss", f) - The macro-averaged F-score summarize how well the classifier's `k` predictions - align with the observed/gold instances of `k`. It additionally cares about - all the classes equally regardless of their size. - Args: - beta: a parameter that determines the weight of recall in the F-score. - initial_category_index: an index of the initial category. +def MacroAveragedFScore(beta=1.0, initial_category_index=0): + r"""Returns a layer that computes a macro-averaged F-score. - The layer takes two inputs: + The macro-averaged F-score summarize how well the classifier's `k` predictions + align with the observed/gold instances of `k`. It additionally cares about + all the classes equally regardless of their size. - - Model output from one batch, an ndarray of float-valued elements. + Args: + beta: a parameter that determines the weight of recall in the F-score. + initial_category_index: an index of the initial category. 
- - A batch of element-wise target values, which matches the shape of the - model output. + The layer takes two inputs: - The layer returns an macro-averaged F-score across all the classes. - """ - def f(model_output, targets): # pylint: disable=invalid-name - beta2 = beta ** 2 - predictions = jnp.argmax(model_output, axis=-1) - n_categories = model_output.shape[-1] - f_scores = jnp.empty(0) - for k in range(initial_category_index, n_categories): - _, _, _, precision, recall = _precision_recall(predictions, targets, k) - f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) - return jnp.mean(f_scores) + - Model output from one batch, an ndarray of float-valued elements. - return base.Fn('MacroAveragedFScore', f) + - A batch of element-wise target values, which matches the shape of the + model output. + The layer returns an macro-averaged F-score across all the classes. + """ -def WeightedFScore(beta=1., initial_category_index=0): - """Returns a layer that computes a weighted F-score. + def f(model_output, targets): # pylint: disable=invalid-name + beta2 = beta**2 + predictions = jnp.argmax(model_output, axis=-1) + n_categories = model_output.shape[-1] + f_scores = jnp.empty(0) + for k in range(initial_category_index, n_categories): + _, _, _, precision, recall = _precision_recall(predictions, targets, k) + f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) + return jnp.mean(f_scores) - The weighted F-score summarize how well the classifier's `k` predictions - align with the observed/gold instances of `k`. It additionally - weights the summary by the number of observed/gold and predicted examples - in each class. + return base.Fn("MacroAveragedFScore", f) - Args: - beta: a parameter that determines the weight of recall in the F-score. - initial_category_index: an index of the initial category. - The layer takes two inputs: +def WeightedFScore(beta=1.0, initial_category_index=0): + """Returns a layer that computes a weighted F-score. 
- - Model output from one batch, an ndarray of float-valued elements. + The weighted F-score summarize how well the classifier's `k` predictions + align with the observed/gold instances of `k`. It additionally + weights the summary by the number of observed/gold and predicted examples + in each class. - - A batch of element-wise target values, which matches the shape of the - model output. + Args: + beta: a parameter that determines the weight of recall in the F-score. + initial_category_index: an index of the initial category. - The layer returns a weighted F-score across all the classes. - """ - def f(model_output, targets): # pylint: disable=invalid-name - beta2 = beta ** 2 - predictions = jnp.argmax(model_output, axis=-1) - n_categories = model_output.shape[-1] - f_scores = jnp.empty(0) - weights = jnp.empty(0) - for k in range(initial_category_index, n_categories): - _, _, n_k_targets, precision, recall = _precision_recall( - predictions, targets, k) - f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) - weights = jnp.append(weights, n_k_targets) - return jnp.average(f_scores, weights=weights) + The layer takes two inputs: + + - Model output from one batch, an ndarray of float-valued elements. + + - A batch of element-wise target values, which matches the shape of the + model output. - return base.Fn('WeightedFScore', f) + The layer returns a weighted F-score across all the classes. 
+ """ + + def f(model_output, targets): # pylint: disable=invalid-name + beta2 = beta**2 + predictions = jnp.argmax(model_output, axis=-1) + n_categories = model_output.shape[-1] + f_scores = jnp.empty(0) + weights = jnp.empty(0) + for k in range(initial_category_index, n_categories): + _, _, n_k_targets, precision, recall = _precision_recall( + predictions, targets, k + ) + f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) + weights = jnp.append(weights, n_k_targets) + return jnp.average(f_scores, weights=weights) + + return base.Fn("WeightedFScore", f) def WeightedSum(): - """Returns a layer that computes a weighted sum of the given values.""" - def f(values, weights): # pylint: disable=invalid-name - return jnp.sum(values * weights) - return base.Fn('WeightedSum', f) + """Returns a layer that computes a weighted sum of the given values.""" + + def f(values, weights): # pylint: disable=invalid-name + return jnp.sum(values * weights) + + return base.Fn("WeightedSum", f) def _Accuracy(): - """Returns a layer that scores predicted versus target category.""" - def f(predicted_category, target_category): # pylint: disable=invalid-name - # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. - # shapes.assert_same_shape(predicted_category, target_category) - return jnp.equal(predicted_category, target_category).astype(jnp.float32) - return base.Fn('_Accuracy', f) + """Returns a layer that scores predicted versus target category.""" + + def f(predicted_category, target_category): # pylint: disable=invalid-name + # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. 
+ # shapes.assert_same_shape(predicted_category, target_category) + return jnp.equal(predicted_category, target_category).astype(jnp.float32) + + return base.Fn("_Accuracy", f) def _CrossEntropy(): - """Returns a layer that computes prediction-target cross entropies.""" - def f(model_output, target_category): # pylint: disable=invalid-name - # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. - # shapes.assert_shape_equals(target_category, model_output.shape[:-1]) - target_distribution = core.one_hot(target_category, model_output.shape[-1]) - return -1.0 * jnp.sum(model_output * target_distribution, axis=-1) - return base.Fn('_CrossEntropy', f) + """Returns a layer that computes prediction-target cross entropies.""" + + def f(model_output, target_category): # pylint: disable=invalid-name + # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. + # shapes.assert_shape_equals(target_category, model_output.shape[:-1]) + target_distribution = core.one_hot(target_category, model_output.shape[-1]) + return -1.0 * jnp.sum(model_output * target_distribution, axis=-1) + + return base.Fn("_CrossEntropy", f) def _BinaryCrossEntropy(): - """Returns a layer that computes prediction-target cross entropies.""" - def f(model_output, target_category): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, target_category) - batch_size = model_output.shape[0] - j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output)) - j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output)) - j = -1.0/batch_size * jnp.squeeze(j) - return j - return base.Fn('_BinaryCrossEntropy', f) + """Returns a layer that computes prediction-target cross entropies.""" + + def f(model_output, target_category): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, target_category) + batch_size = model_output.shape[0] + j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output)) + j += 
jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output)) + j = -1.0 / batch_size * jnp.squeeze(j) + return j + + return base.Fn("_BinaryCrossEntropy", f) def CrossEntropySum(): - """Sum of prediction-target cross entropies for multiclass classification.""" - return cb.Serial(_CrossEntropy(), - WeightedSum(), - name='CrossEntropySum', - sublayers_to_print=[]) + """Sum of prediction-target cross entropies for multiclass classification.""" + return cb.Serial( + _CrossEntropy(), WeightedSum(), name="CrossEntropySum", sublayers_to_print=[] + ) def BinaryCrossEntropySum(): - """Sum of prediction-target cross entropies for binary classification.""" - return cb.Serial(_BinaryCrossEntropy(), - WeightedSum(), - name='BinaryCrossEntropySum', - sublayers_to_print=[]) + """Sum of prediction-target cross entropies for binary classification.""" + return cb.Serial( + _BinaryCrossEntropy(), + WeightedSum(), + name="BinaryCrossEntropySum", + sublayers_to_print=[], + ) + + # pylint: enable=no-value-for-parameter def _WeightedMean(): - """Returns a layer that computes a weighted mean of the given values.""" - def f(values, weights): # pylint: disable=invalid-name - return jnp.sum(values * weights) / _n_weights_per_core(weights) - return base.Fn('_WeightedMean', f) + """Returns a layer that computes a weighted mean of the given values.""" + + def f(values, weights): # pylint: disable=invalid-name + return jnp.sum(values * weights) / _n_weights_per_core(weights) + + return base.Fn("_WeightedMean", f) def _WeightedSequenceMean(): - """Returns a layer that computes a weighted sequence accuracy mean.""" - def f(values, weights): # pylint: disable=invalid-name - # This function assumes weights are 0 or 1. - # Then compute 1: not-correct, 0: correct or masked - not_correct = (1.0 - values) * weights - axis_to_sum = list(range(1, len(not_correct.shape))) - # Summing not-correct on all axes but batch. 
We're summing 0s and 1s, - # so the sum is 0 if it's all 0 and >=1 in all other cases. - not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum) - # Sequence is correct if not_correct_seq is 0, reverting here. - correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq) - return jnp.mean(correct_seq) # Mean over batch. - return base.Fn('_WeightedSequenceMean', f) + """Returns a layer that computes a weighted sequence accuracy mean.""" + + def f(values, weights): # pylint: disable=invalid-name + # This function assumes weights are 0 or 1. + # Then compute 1: not-correct, 0: correct or masked + not_correct = (1.0 - values) * weights + axis_to_sum = list(range(1, len(not_correct.shape))) + # Summing not-correct on all axes but batch. We're summing 0s and 1s, + # so the sum is 0 if it's all 0 and >=1 in all other cases. + not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum) + # Sequence is correct if not_correct_seq is 0, reverting here. + correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq) + return jnp.mean(correct_seq) # Mean over batch. + + return base.Fn("_WeightedSequenceMean", f) def _category_cross_entropy( # pylint: disable=invalid-name - model_output, targets, label_smoothing, cutoff): - """Computes category cross entropy with label smoothing.""" - n_categories = model_output.shape[-1] - target_distributions = core.one_hot(targets, n_categories) - if label_smoothing: - if label_smoothing < 0. or label_smoothing > 1.: - raise ValueError( - f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.') - target_distributions *= (1. 
- label_smoothing) - target_distributions += label_smoothing / n_categories - model_log_distributions = core.log_softmax(model_output) - cross_ent = - jnp.sum(target_distributions * model_log_distributions, axis=-1) - if cutoff > 0.0: - return jnp.maximum(cross_ent, cutoff) - cutoff - else: - return cross_ent + model_output, targets, label_smoothing, cutoff +): + """Computes category cross entropy with label smoothing.""" + n_categories = model_output.shape[-1] + target_distributions = core.one_hot(targets, n_categories) + if label_smoothing: + if label_smoothing < 0.0 or label_smoothing > 1.0: + raise ValueError( + f"Arg label_smoothing ({label_smoothing}) must be between 0 and 1." + ) + target_distributions *= 1.0 - label_smoothing + target_distributions += label_smoothing / n_categories + model_log_distributions = core.log_softmax(model_output) + cross_ent = -jnp.sum(target_distributions * model_log_distributions, axis=-1) + if cutoff > 0.0: + return jnp.maximum(cross_ent, cutoff) - cutoff + else: + return cross_ent diff --git a/trax/layers/metrics_test.py b/trax/layers/metrics_test.py deleted file mode 100644 index 3c59da790..000000000 --- a/trax/layers/metrics_test.py +++ /dev/null @@ -1,430 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for metrics layers.""" - -from absl.testing import absltest -import numpy as np -import trax.layers as tl - - -class MetricsTest(absltest.TestCase): - - def test_category_accuracy(self): - layer = tl.CategoryAccuracy() - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets]) - self.assertEqual(accuracy, 1 / 3) - - def test_weighted_category_accuracy_even_weights(self): - layer = tl.WeightedCategoryAccuracy() - weights = np.array([1., 1., 1.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1 / 3) - - def test_weighted_category_accuracy_uneven_weights(self): - layer = tl.WeightedCategoryAccuracy() - weights = np.array([1., 5., 2.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .7, .1, 0.], - [.2, .7, .1, 0.], - [.2, .7, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .625) - - def test_category_cross_entropy(self): - layer = tl.CategoryCrossEntropy() - targets = np.array([0, 1]) - - # Near-perfect prediction (for both items in batch). 
- model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .001, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .665, places=3) - - # First item near perfect, second item more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .333, places=3) - - def test_category_cross_entropy_with_label_smoothing(self): - epsilon = 0.01 - layer = tl.CategoryCrossEntropy(label_smoothing=epsilon) - targets = np.array([0, 1]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .069, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .682, places=3) - - # First item near perfect, second item more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .375, places=3) - - def test_weighted_category_cross_entropy(self): - layer = tl.WeightedCategoryCrossEntropy() - targets = np.array([0, 1]) - weights = np.array([30, 10]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .001, places=3) - - # More right than wrong (for both items in batch). 
- model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .665, places=3) - - # First item (with 75% weight) near perfect, second more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .167, places=3) - - def test_weighted_category_cross_entropy_with_label_smoothing(self): - epsilon = 0.01 - layer = tl.WeightedCategoryCrossEntropy(label_smoothing=epsilon) - targets = np.array([0, 1]) - weights = np.array([30, 10]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .069, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .682, places=3) - - # First item (with 75% weight) near perfect, second more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .222, places=3) - - def test_masked_sequence_accuracy(self): - layer = tl.MaskedSequenceAccuracy() - targets = np.array([[0, 1, 0, 0], - [1, 0, 1, 0]]) - weights = np.array([[1., 1., 1., 0.], - [1., 1., 1., 0.]]) - - # Model gets both sequences right; output in final position would give - # wrong category but is ignored. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]], - [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. 
- model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]], - [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets second-to-last element of each sequence barely wrong. - model_outputs = np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]], - [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_binary_cross_entropy(self): - layer = tl.BinaryCrossEntropy() - targets = np.array([1, 1, 0, 0, 0]) - - # Near-perfect prediction for all five items in batch. - model_outputs = np.array([9., 9., -9., -9., -9.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.000123, places=6) - - # More right than wrong for all five items in batch. - model_outputs = np.array([1., 1., -1., -1., -1.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.313, places=3) - - # Near-perfect for 2, more right than wrong for 3. - model_outputs = np.array([9., 1., -1., -1., -9.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.188, places=3) - - # More wrong than right for all five. 
- model_outputs = np.array([-1., -1., 1., 1., 1.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 1.313, places=3) - - def test_accuracy_even_weights(self): - layer = tl.Accuracy() - weights = np.array([1., 1., 1.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1 / 3) - - def test_accuracy_uneven_weights(self): - layer = tl.Accuracy() - weights = np.array([1., 5., 2.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .7, .1, 0.], - [.2, .7, .1, 0.], - [.2, .7, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .625) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.7, .2, .1, 0.], - [.7, .2, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .125) - - def test_accuracy_binary_classifier(self): - layer = tl.Accuracy(classifier=tl.ThresholdToBinary()) - targets = np.array([[0, 0, 1, 1], - [1, 1, 1, 0]]) - weights = np.ones_like(targets) - - model_outputs = np.array([[.499, .500, .501, .502], - [.503, .502, .501, .500]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.498, .499, .500, .501], - [.502, .501, .500, .499]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .75) - - def test_sequence_accuracy_weights_all_ones(self): - layer = tl.SequenceAccuracy() - targets = np.array([[0, 1, 0, 1], - [1, 0, 1, 1]]) - 
weights = np.ones_like(targets) - - # Model gets both sequences right; for each position in each sequence, the - # category (integer ID) selected by argmax matches the target category. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.4, .6]], - [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. - model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.4, .6]], - [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets the last element of each sequence barely wrong. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.55, .45]], - [[.3, .7], [.8, .2], [.1, .9], [.52, .48]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_sequence_accuracy_last_position_zero_weight(self): - layer = tl.SequenceAccuracy() - targets = np.array([[0, 1, 0, 0], - [1, 0, 1, 0]]) - weights = np.array([[1., 1., 1., 0.], - [1., 1., 1., 0.]]) - - # Model gets both sequences right; output in final position would give - # wrong category but is ignored. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]], - [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. - model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]], - [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets second-to-last element of each sequence barely wrong. 
- model_outputs = np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]], - [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_binary_cross_entropy_loss(self): - # TODO(jonni): Clarify desired semantics/naming, then test it. - layer = tl.BinaryCrossEntropyLoss() - xs = [np.ones((9, 1)), - np.ones((9, 1)), - np.ones((9, 1))] - y = layer(xs) - self.assertEqual(y.shape, ()) - - def test_cross_entropy_loss(self): - # TODO(jonni): Clarify desired semantics/naming, then test it. - layer = tl.CrossEntropyLoss() - xs = [np.ones((9, 4, 4, 20)), - np.ones((9, 4, 4)), - np.ones((9, 4, 4))] - y = layer(xs) - self.assertEqual(y.shape, ()) - - def test_l2_loss(self): - layer = tl.L2Loss() - - model_outputs = np.array([[1., 1.], [1., 1.]]) - targets = np.array([[1., 1.], [1., 0.]]) - weights = np.array([[1., 1.], [1., 0.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.0) - - weights = np.array([[1., 0.], [0., 1.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.5) - - def test_smooth_l1_loss(self): - layer = tl.SmoothL1Loss() - - model_outputs = np.array([[1., 1.], [1., 2.]]) - targets = np.array([[1., 1.], [1., 0.]]) - l1_dist = 2 - - weights = np.array([[1., 1.], [1., 0.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.0) - - weights = np.array([[1., 0.], [0., 1.]]) - sum_weights = 2 - - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, (l1_dist-0.5) / sum_weights) - - model_outputs = np.array([[1., 1.], [1., 1.5]]) - targets = np.array([[1., 1.], [1., 1.]]) - l1_dist = 0.5 - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.5 * l1_dist**2 / sum_weights) - - def test_macro_averaged_f_score(self): - # predictions = [1, 1, 2, 1, 1]. 
- model_outputs = np.array([[0, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 1, 0, 0]]) - targets = np.array([1, 2, 2, 3, 1]) - # Category indices starting with `0`. - layer = tl.MacroAveragedFScore() - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .333, places=3) - # Excluding the padding index `0`. - layer = tl.MacroAveragedFScore(initial_category_index=1) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .444, places=3) - - def test_weighted_f_score(self): - # predictions = [1, 1, 2, 1, 1]. - model_outputs = np.array([[0, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 1, 0, 0]]) - targets = np.array([1, 2, 2, 3, 1]) - # Category indices starting with `0`. - layer = tl.WeightedFScore() - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .533, places=3) - # Excluding the padding index `0`. - layer = tl.WeightedFScore(initial_category_index=1) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .533, places=3) - - def test_names(self): - layer = tl.L2Loss() - self.assertEqual('L2Loss_in3', str(layer)) - layer = tl.Accuracy() - self.assertEqual('Accuracy_in3', str(layer)) - layer = tl.SequenceAccuracy() - self.assertEqual('SequenceAccuracy_in3', str(layer)) - layer = tl.BinaryCrossEntropyLoss() - self.assertEqual('BinaryCrossEntropyLoss_in3', str(layer)) - layer = tl.CrossEntropyLoss() - self.assertEqual('CrossEntropyLoss_in3', str(layer)) - layer = tl.BinaryCrossEntropySum() - self.assertEqual('BinaryCrossEntropySum_in3', str(layer)) - layer = tl.CrossEntropySum() - self.assertEqual('CrossEntropySum_in3', str(layer)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/normalization.py b/trax/layers/normalization.py index 8c3e5c138..ad2d91407 100644 --- a/trax/layers/normalization.py +++ b/trax/layers/normalization.py @@ -20,189 +20,200 @@ class BatchNorm(base.Layer): - """Layer that performs batch normalization. 
- - In training, batch normalization keeps smoothed cumulative statistics across - batches of input data and modifies each new batch so that its components are - normally distributed. In eval or inference, a `BatchNorm` instance uses its - stored mean and variance to approximately normalize each new batch of data. - - See https://arxiv.org/abs/1502.03167 for original presentation and motivation - of batch normalization). - """ - - def __init__(self, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True, - momentum=0.999, mode='train'): - super().__init__() - self._axis = axis - self._epsilon = epsilon - self._center = center - self._scale = scale - self._momentum = momentum - self._mode = mode - - def forward(self, x): - """Computes batch normalization as part of a forward pass in the model.""" - running_mean, running_var, n_batches = self.state - if self._mode == 'train': - n_batches += 1 - mean, var = self._fast_mean_and_variance(x) - # Gather smoothed input statistics for later use in evals or inference. - running_mean = _exponentially_smoothed(self._momentum, running_mean, mean) - running_var = _exponentially_smoothed(self._momentum, running_var, var) - self.state = (running_mean, running_var, n_batches) - else: - mean = running_mean - var = running_var - - z = self._z_score(x, mean, var) - beta, gamma = self._beta_gamma_with_correct_axes(x, self.weights) - - # Return the z rescaled by the parameters if requested. - if self._center and self._scale: - output = gamma * z + beta - elif self._center: - output = z + beta - elif self._scale: - output = gamma * z - else: - output = z - if output.dtype != x.dtype: - raise TypeError(f'The dtype of the output ({output.dtype}) of batch ' - f'norm is not the same as the input ({x.dtype}). 
' - f'Batch norm should not change the dtype.') - return output - - def init_weights_and_state(self, input_signature): - """Helper to initialize batch norm weights and state.""" - axis = self._axis - axis = (axis,) if jnp.isscalar(axis) else axis - input_shape = input_signature.shape - shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) - # TODO(jonni): Should beta and gamma match the dtype in the input signature? - beta = jnp.zeros(shape, dtype='float32') if self._center else () - gamma = jnp.ones(shape, dtype='float32') if self._scale else () - def get_stats_axis(i, d): - if i in axis: - return 1 - else: - return d - stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) - running_mean = jnp.zeros(stats_shape, dtype=jnp.float32) - running_var = jnp.ones(stats_shape, dtype=jnp.float32) - n_batches = jnp.zeros((), dtype=jnp.int64) - self.weights = (beta, gamma) - self.state = (running_mean, running_var, n_batches) - - def _fast_mean_and_variance(self, x): - mean = jnp.mean(x, self._axis, keepdims=True) - # Fast but less numerically-stable variance calculation than jnp.var. - m1 = jnp.mean(x**2, self._axis, keepdims=True) - variance = m1 - mean**2 - return mean, variance - - def _z_score(self, x, mean, variance): - mu = mean.astype(x.dtype) - sigma = jnp.sqrt(variance + self._epsilon).astype(x.dtype) - return (x - mu) / sigma - - def _beta_gamma_with_correct_axes(self, x, weights): - # Expand the parameters to have the right axes. - beta, gamma = weights - # TODO(phawkins): jnp.expand_dims should accept an axis tuple. - # (https://github.com/numpy/numpy/issues/12290) - ed = tuple(None if i in self._axis else slice(None) - for i in range(jnp.ndim(x))) - beta = beta[ed] - gamma = gamma[ed] - return beta, gamma + """Layer that performs batch normalization. 
+ + In training, batch normalization keeps smoothed cumulative statistics across + batches of input data and modifies each new batch so that its components are + normally distributed. In eval or inference, a `BatchNorm` instance uses its + stored mean and variance to approximately normalize each new batch of data. + + See https://arxiv.org/abs/1502.03167 for original presentation and motivation + of batch normalization). + """ + + def __init__( + self, + axis=(0, 1, 2), + epsilon=1e-5, + center=True, + scale=True, + momentum=0.999, + mode="train", + ): + super().__init__() + self._axis = axis + self._epsilon = epsilon + self._center = center + self._scale = scale + self._momentum = momentum + self._mode = mode + + def forward(self, x): + """Computes batch normalization as part of a forward pass in the model.""" + running_mean, running_var, n_batches = self.state + if self._mode == "train": + n_batches += 1 + mean, var = self._fast_mean_and_variance(x) + # Gather smoothed input statistics for later use in evals or inference. + running_mean = _exponentially_smoothed(self._momentum, running_mean, mean) + running_var = _exponentially_smoothed(self._momentum, running_var, var) + self.state = (running_mean, running_var, n_batches) + else: + mean = running_mean + var = running_var + + z = self._z_score(x, mean, var) + beta, gamma = self._beta_gamma_with_correct_axes(x, self.weights) + + # Return the z rescaled by the parameters if requested. + if self._center and self._scale: + output = gamma * z + beta + elif self._center: + output = z + beta + elif self._scale: + output = gamma * z + else: + output = z + if output.dtype != x.dtype: + raise TypeError( + f"The dtype of the output ({output.dtype}) of batch " + f"norm is not the same as the input ({x.dtype}). " + f"Batch norm should not change the dtype." 
+ ) + return output + + def init_weights_and_state(self, input_signature): + """Helper to initialize batch norm weights and state.""" + axis = self._axis + axis = (axis,) if jnp.isscalar(axis) else axis + input_shape = input_signature.shape + shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) + # TODO(jonni): Should beta and gamma match the dtype in the input signature? + beta = jnp.zeros(shape, dtype="float32") if self._center else () + gamma = jnp.ones(shape, dtype="float32") if self._scale else () + + def get_stats_axis(i, d): + if i in axis: + return 1 + else: + return d + + stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) + running_mean = jnp.zeros(stats_shape, dtype=jnp.float32) + running_var = jnp.ones(stats_shape, dtype=jnp.float32) + n_batches = jnp.zeros((), dtype=jnp.int64) + self.weights = (beta, gamma) + self.state = (running_mean, running_var, n_batches) + + def _fast_mean_and_variance(self, x): + mean = jnp.mean(x, self._axis, keepdims=True) + # Fast but less numerically-stable variance calculation than jnp.var. + m1 = jnp.mean(x**2, self._axis, keepdims=True) + variance = m1 - mean**2 + return mean, variance + + def _z_score(self, x, mean, variance): + mu = mean.astype(x.dtype) + sigma = jnp.sqrt(variance + self._epsilon).astype(x.dtype) + return (x - mu) / sigma + + def _beta_gamma_with_correct_axes(self, x, weights): + # Expand the parameters to have the right axes. + beta, gamma = weights + # TODO(phawkins): jnp.expand_dims should accept an axis tuple. 
+ # (https://github.com/numpy/numpy/issues/12290) + ed = tuple(None if i in self._axis else slice(None) for i in range(jnp.ndim(x))) + beta = beta[ed] + gamma = gamma[ed] + return beta, gamma class LayerNorm(base.Layer): - """Layer normalization.""" + """Layer normalization.""" - def __init__(self, center=True, epsilon=1e-6): - super().__init__() - self._epsilon = epsilon - self._center = center + def __init__(self, center=True, epsilon=1e-6): + super().__init__() + self._epsilon = epsilon + self._center = center - def forward(self, x): - scale, bias = self.weights - mean = jnp.mean(x, axis=-1, keepdims=True) - centered = x - mean if self._center else x - variance = jnp.mean(centered * centered, axis=-1, keepdims=True) - norm_inputs = centered / jnp.sqrt(variance + self._epsilon) - scaled = norm_inputs * scale - return scaled + bias if self._center else scaled + def forward(self, x): + scale, bias = self.weights + mean = jnp.mean(x, axis=-1, keepdims=True) + centered = x - mean if self._center else x + variance = jnp.mean(centered * centered, axis=-1, keepdims=True) + norm_inputs = centered / jnp.sqrt(variance + self._epsilon) + scaled = norm_inputs * scale + return scaled + bias if self._center else scaled - def init_weights_and_state(self, input_signature): - features = input_signature.shape[-1] - scale = jnp.ones(features, dtype=input_signature.dtype) - bias = jnp.zeros(features, dtype=input_signature.dtype) - self.weights = scale, bias + def init_weights_and_state(self, input_signature): + features = input_signature.shape[-1] + scale = jnp.ones(features, dtype=input_signature.dtype) + bias = jnp.zeros(features, dtype=input_signature.dtype) + self.weights = scale, bias class FilterResponseNorm(base.Layer): - """Filter Response Normalization layer without Threshold Linear Unit. + """Filter Response Normalization layer without Threshold Linear Unit. - c.f. https://arxiv.org/pdf/1911.09737.pdf - """ + c.f. 
https://arxiv.org/pdf/1911.09737.pdf + """ - def __init__(self, - mode=None, - learn_epsilon=False, - init_epsilon=1e-6, - init_learnt_epsilon=1e-4): - super().__init__() + def __init__( + self, + mode=None, + learn_epsilon=False, + init_epsilon=1e-6, + init_learnt_epsilon=1e-4, + ): + super().__init__() - del mode + del mode - # If we learn epsilon then epsilon = init_epsilon + |learnt_value| - # where learnt_value is initialized to init_learnt_epsilon. - # If learn_epsilon is false then epsilon is just init_epsilon. - # - # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`. - self._learn_epsilon = learn_epsilon + # If we learn epsilon then epsilon = init_epsilon + |learnt_value| + # where learnt_value is initialized to init_learnt_epsilon. + # If learn_epsilon is false then epsilon is just init_epsilon. + # + # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`. + self._learn_epsilon = learn_epsilon - # TODO(jonni): Replace asserts with ValueError. - assert init_epsilon > 0 - assert init_learnt_epsilon > 0 + # TODO(jonni): Replace asserts with ValueError. 
+ assert init_epsilon > 0 + assert init_learnt_epsilon > 0 - self._init_epsilon = jnp.array(init_epsilon, dtype=jnp.float32) - self._init_learnt_epsilon = jnp.array(init_learnt_epsilon, - dtype=jnp.float32) + self._init_epsilon = jnp.array(init_epsilon, dtype=jnp.float32) + self._init_learnt_epsilon = jnp.array(init_learnt_epsilon, dtype=jnp.float32) - def forward(self, inputs): - gamma, beta, epsilon_l = self.weights + def forward(self, inputs): + gamma, beta, epsilon_l = self.weights - epsilon = self._init_epsilon - if epsilon_l is not base.EMPTY_WEIGHTS: - epsilon += jnp.abs(epsilon_l[0]) + epsilon = self._init_epsilon + if epsilon_l is not base.EMPTY_WEIGHTS: + epsilon += jnp.abs(epsilon_l[0]) - # Omit B and C - axis = tuple(range(1, len(jnp.shape(inputs)) - 1)) - # (B, 1, 1, C) - nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True) - # (B, W, H, C) - xhat = inputs / jnp.sqrt(nu2 + epsilon) + # Omit B and C + axis = tuple(range(1, len(jnp.shape(inputs)) - 1)) + # (B, 1, 1, C) + nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True) + # (B, W, H, C) + xhat = inputs / jnp.sqrt(nu2 + epsilon) - return gamma * xhat + beta + return gamma * xhat + beta - def init_weights_and_state(self, input_signature): - # Usually (B, W, H, C) - shape = input_signature.shape - num_channels = shape[-1] + def init_weights_and_state(self, input_signature): + # Usually (B, W, H, C) + shape = input_signature.shape + num_channels = shape[-1] - gamma = jnp.ones((num_channels,), dtype=jnp.float32) - beta = jnp.zeros((num_channels,), dtype=jnp.float32) + gamma = jnp.ones((num_channels,), dtype=jnp.float32) + beta = jnp.zeros((num_channels,), dtype=jnp.float32) - epsilon_l = base.EMPTY_WEIGHTS - if self._learn_epsilon: - epsilon_l = (self._init_learnt_epsilon,) + epsilon_l = base.EMPTY_WEIGHTS + if self._learn_epsilon: + epsilon_l = (self._init_learnt_epsilon,) - self.weights = gamma, beta, epsilon_l + self.weights = gamma, beta, epsilon_l def _exponentially_smoothed(momentum, old, new): - 
smoothed_value = momentum * old + (1 - momentum) * new - return smoothed_value.astype(old.dtype) + smoothed_value = momentum * old + (1 - momentum) * new + return smoothed_value.astype(old.dtype) diff --git a/trax/layers/normalization_test.py b/trax/layers/normalization_test.py deleted file mode 100644 index b844c0d1c..000000000 --- a/trax/layers/normalization_test.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for normalization layers.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -class BatchNormTest(parameterized.TestCase): - - def test_forward_shape(self): - layer = tl.BatchNorm() - x = np.ones((30, 20, 70)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - @parameterized.named_parameters( - ('jax32', fastmath.Backend.JAX, np.float32), - ('tf32', fastmath.Backend.TFNP, np.float32), - ('tf64', fastmath.Backend.TFNP, np.float64), - ) - def test_forward_dtype(self, backend, dtype): - with fastmath.use_backend(backend): - layer = tl.BatchNorm() - x = np.ones((3, 2, 7)).astype(dtype) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.dtype, dtype) - - @parameterized.named_parameters( - ('momentum_999', .999), - ('momentum_900', .900), - ('momentum_800', .800), - ) - def test_forward(self, momentum): - layer = tl.BatchNorm(momentum=momentum) - x = np.array([[[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]], - [[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23]]]).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - running_mean, running_var, n_batches = layer.state - - fraction_old = momentum - fraction_new = 1.0 - momentum - mean_of_x = 11.5 # mean of range(24) - var_of_x = 47.9167 # variance of range(24) - np.testing.assert_allclose( - running_mean, 0.0 * fraction_old + mean_of_x * fraction_new) - np.testing.assert_allclose( - running_var, 1.0 * fraction_old + var_of_x * fraction_new, rtol=1e-6) - self.assertEqual(n_batches, 1) - eps = 1e-5 - np.testing.assert_allclose( - y, (x - mean_of_x) / np.sqrt(var_of_x + eps), rtol=1e-6) - - def test_new_weights_and_state(self): - layer = tl.BatchNorm() - x = np.ones((3, 2, 7)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - - running_mean, 
running_var, n_batches = layer.state - np.testing.assert_allclose(running_mean, 0.0) - np.testing.assert_allclose(running_var, 1.0) - self.assertEqual(n_batches, 0) - - -class LayerNormTest(parameterized.TestCase): - - def test_forward_shape(self): - layer = tl.LayerNorm() - x = np.ones((3, 2, 7)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - @parameterized.named_parameters( - ('jax32', fastmath.Backend.JAX, np.float32), - ('tf32', fastmath.Backend.TFNP, np.float32), - ('tf64', fastmath.Backend.TFNP, np.float64), - ) - def test_forward_dtype(self, backend, dtype): - with fastmath.use_backend(backend): - layer = tl.LayerNorm() - x = np.ones((3, 2, 7)).astype(dtype) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.dtype, dtype) - - -class FilterResponseNormTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('learn_epsilon_false', False), - ('learn_epsilon_true', True), - ) - def test_forward_shape(self, learn_epsilon): - layer = tl.FilterResponseNorm(learn_epsilon=learn_epsilon) - - B, H, W, C = 64, 5, 7, 3 # pylint: disable=invalid-name - x = np.ones((B, H, W, C)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/pooling.py b/trax/layers/pooling.py index f10241d3c..00e3d5abd 100644 --- a/trax/layers/pooling.py +++ b/trax/layers/pooling.py @@ -20,100 +20,109 @@ # pylint: disable=invalid-name -def MaxPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the max of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. 
- If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed to avoid partial windows - but does not otherwise affect the selection of max values. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the max value from that window. - An output array has the same number of dimensions as its input, but has - fewer elements. - """ - layer_name = f'MaxPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.max_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) - - -def SumPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the sum of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. 
If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed to avoid partial - windows but does not otherwise affect the computation of sums. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the sum of values in that window. - An output array has the same number of dimensions as its input, but has - fewer elements. - """ - layer_name = f'SumPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.sum_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) - - -def AvgPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the mean of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed but is not counted in the - computation of averages. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the mean of values in that window. - An output array has the same number of dimensions as its input, but has - fewer elements. 
- """ - layer_name = f'AvgPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.avg_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) +def MaxPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the max of the window's values. + + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. + strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed to avoid partial windows + but does not otherwise affect the selection of max values. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the max value from that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"MaxPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.max_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) + + +def SumPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the sum of the window's values. 
+ + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. + strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed to avoid partial + windows but does not otherwise affect the computation of sums. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the sum of values in that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"SumPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.sum_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) + + +def AvgPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the mean of the window's values. + + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. 
+ strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed but is not counted in the + computation of averages. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the mean of values in that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"AvgPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.avg_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) diff --git a/trax/layers/pooling_test.py b/trax/layers/pooling_test.py deleted file mode 100644 index 7c858cd8b..000000000 --- a/trax/layers/pooling_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for conv layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class MaxPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.MaxPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.MaxPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[7, 5, 6], [70, 50, 60]]]]) - - def test_padding_default(self): - layer = tl.MaxPool(pool_size=(3,), strides=(3,)) - - # Discard incomplete window at end: [[3, 6], [4, 5]]. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[2, 9]]]) - - def test_padding_same(self): - layer = tl.MaxPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[2, 9], [4, 6]]]) - - # Two padding positions needed; add one at end and one at start. 
- x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 9], [3, 7]]]) - - -class SumPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.SumPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.SumPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[16, 10, 14], [160, 100, 140]]]]) - - def test_padding_same(self): - layer = tl.SumPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[3, 24], [7, 11]]]) - - # Two padding positions needed; add one at end and one at start. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 17], [5, 13]]]) - - -class AvgPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.AvgPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.AvgPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[4.0, 2.5, 3.5], [40, 25, 35]]]]) - - def test_padding_same(self): - layer = tl.AvgPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. 
- x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 8], [3.5, 5.5]]]) - - # Two padding positions needed; add one at end and one at start. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[.5, 8.5], [2.5, 6.5]]]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/research/efficient_attention.py b/trax/layers/research/efficient_attention.py index ecc52e523..f34a85d5a 100644 --- a/trax/layers/research/efficient_attention.py +++ b/trax/layers/research/efficient_attention.py @@ -52,3217 +52,3668 @@ def length_normalized(x, epsilon=1e-6): - variance = np.mean(x**2, axis=-1, keepdims=True) - norm_inputs = x / np.sqrt(variance + epsilon) - return norm_inputs + variance = np.mean(x**2, axis=-1, keepdims=True) + norm_inputs = x / np.sqrt(variance + epsilon) + return norm_inputs def hash_vecs(vecs, n_buckets_in, n_hashes, rng): - """Hash vectors into buckets. - - Args: - vecs: vectors to hash, a tensor of shape [batch_size, depth] - n_buckets_in: an int or a list of ints, number of hash buckets; - if it is a list, we do hierarchical hashing as specified by the list - n_hashes: number of hashes - rng: random generator to use for hashing - - Returns: - A pair (buckets, n_buckets) where buckets is a tensor of shape - [n_hashes, batch_size] of integers -- the hash bucket IDs, and - n_buckets is an int, the total number of hash buckets, equal to - the product of all items in n_buckets_in. - """ - # See https://arxiv.org/pdf/1509.02897.pdf - # We sample a different random rotation for each round of hashing to - # decrease the probability of hash misses. 
- if isinstance(n_buckets_in, int): - assert n_buckets_in % 2 == 0 - rot_size = n_buckets_in - n_buckets = n_buckets_in - else: - # Factorize the hash if n_buckets_in is a list or tuple - rot_size, n_buckets = 0, 1 - for factor in n_buckets_in: - assert factor % 2 == 0 - rot_size += factor - n_buckets *= factor - - rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2) - random_rotations = fastmath.random.normal(rng, rotations_shape).astype( - np.float32) - if fastmath.is_backend(fastmath.Backend.JAX): - rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations) - else: - random_rotations = np.reshape(random_rotations, - [-1, n_hashes * (rot_size // 2)]) - rotated_vecs = np.dot(vecs, random_rotations) - rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size//2]) - rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2)) - - if isinstance(n_buckets_in, int) or len(n_buckets_in) == 1: - rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) - buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32) - else: - # Get the buckets for them and combine. - buckets, cur_sum, cur_product = None, 0, 1 - for factor in n_buckets_in: - rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)] - cur_sum += factor // 2 - rv = np.concatenate([rv, -rv], axis=-1) - if buckets is None: - buckets = np.argmax(rv, axis=-1).astype(np.int32) - else: - buckets += cur_product * np.argmax(rv, axis=-1).astype(np.int32) - cur_product *= factor - - return buckets, n_buckets # buckets is now (n_hashes, batch_size) + """Hash vectors into buckets. + Args: + vecs: vectors to hash, a tensor of shape [batch_size, depth] + n_buckets_in: an int or a list of ints, number of hash buckets; + if it is a list, we do hierarchical hashing as specified by the list + n_hashes: number of hashes + rng: random generator to use for hashing -def look_adjacent(x, n_chunks_before, n_chunks_after): - """Used to implement attention between consecutive chunks. 
- - Args: - x: array of shape [n_chunks, chunk_len, ...] - n_chunks_before: Number of previous chunks to attend to. - n_chunks_after: Number of subsequent chunks to attend to. - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if n_chunks_before == 0 and n_chunks_after == 0: - return x - - slices = [] - for i in range(-n_chunks_before, n_chunks_after + 1): - if i == 0: - slices.append(x) + Returns: + A pair (buckets, n_buckets) where buckets is a tensor of shape + [n_hashes, batch_size] of integers -- the hash bucket IDs, and + n_buckets is an int, the total number of hash buckets, equal to + the product of all items in n_buckets_in. + """ + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(n_buckets_in, int): + assert n_buckets_in % 2 == 0 + rot_size = n_buckets_in + n_buckets = n_buckets_in + else: + # Factorize the hash if n_buckets_in is a list or tuple + rot_size, n_buckets = 0, 1 + for factor in n_buckets_in: + assert factor % 2 == 0 + rot_size += factor + n_buckets *= factor + + rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2) + random_rotations = fastmath.random.normal(rng, rotations_shape).astype(np.float32) + if fastmath.is_backend(fastmath.Backend.JAX): + rotated_vecs = np.einsum("tf,fhb->htb", vecs, random_rotations) + else: + random_rotations = np.reshape( + random_rotations, [-1, n_hashes * (rot_size // 2)] + ) + rotated_vecs = np.dot(vecs, random_rotations) + rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size // 2]) + rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2)) + + if isinstance(n_buckets_in, int) or len(n_buckets_in) == 1: + rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) + buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32) else: - slices.append(np.concatenate([x[i:, ...], x[:i, ...]], axis=0)) - 
return np.concatenate(slices, axis=1) + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for factor in n_buckets_in: + rv = rotated_vecs[..., cur_sum : cur_sum + (factor // 2)] + cur_sum += factor // 2 + rv = np.concatenate([rv, -rv], axis=-1) + if buckets is None: + buckets = np.argmax(rv, axis=-1).astype(np.int32) + else: + buckets += cur_product * np.argmax(rv, axis=-1).astype(np.int32) + cur_product *= factor + + return buckets, n_buckets # buckets is now (n_hashes, batch_size) + + +def look_adjacent(x, n_chunks_before, n_chunks_after): + """Used to implement attention between consecutive chunks. + + Args: + x: array of shape [n_chunks, chunk_len, ...] + n_chunks_before: Number of previous chunks to attend to. + n_chunks_after: Number of subsequent chunks to attend to. + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if n_chunks_before == 0 and n_chunks_after == 0: + return x + + slices = [] + for i in range(-n_chunks_before, n_chunks_after + 1): + if i == 0: + slices.append(x) + else: + slices.append(np.concatenate([x[i:, ...], x[:i, ...]], axis=0)) + return np.concatenate(slices, axis=1) def mask_self_attention( - dots, q_info, kv_info, causal=True, exclude_self=True, masked=False): - """Performs masking for self-attention.""" - q_info = q_info.astype(np.float32) - kv_info = kv_info.astype(np.float32) - if causal: - mask = fastmath.lt(q_info, kv_info) - dots = dots - 1e9 * mask - if exclude_self: - mask = np.equal(q_info, kv_info) - dots = dots - 1e5 * mask - if masked: - zeros_like_kv_info = np.zeros_like(kv_info) - mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32) - dots = dots - 1e9 * mask - return dots + dots, q_info, kv_info, causal=True, exclude_self=True, masked=False +): + """Performs masking for self-attention.""" + q_info = q_info.astype(np.float32) + kv_info = kv_info.astype(np.float32) + if causal: + mask = 
fastmath.lt(q_info, kv_info) + dots = dots - 1e9 * mask + if exclude_self: + mask = np.equal(q_info, kv_info) + dots = dots - 1e5 * mask + if masked: + zeros_like_kv_info = np.zeros_like(kv_info) + mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32) + dots = dots - 1e9 * mask + return dots def attend( - q, k=None, v=None, - q_chunk_len=None, kv_chunk_len=None, - n_chunks_before=0, n_chunks_after=0, - mask_fn=None, q_info=None, kv_info=None, - dropout=0.0, rng=None, - ): - """Dot-product attention, with optional chunking and/or masking. - - Args: - q: Query vectors, shape [q_len, d_qk] - k: Key vectors, shape [kv_len, d_qk]; or None - v: Value vectors, shape [kv_len, d_v] - q_chunk_len: Set to non-zero to enable chunking for query vectors - kv_chunk_len: Set to non-zero to enable chunking for key/value vectors - n_chunks_before: Number of adjacent previous chunks to attend to - n_chunks_after: Number of adjacent subsequent chunks to attend to - mask_fn: TODO(kitaev) doc - q_info: Query-associated metadata for masking - kv_info: Key-associated metadata for masking - dropout: Dropout rate - rng: RNG for dropout - - Returns: - A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and - dots_logsumexp has shape [q_len]. The logsumexp of the attention - probabilities is useful for combining multiple rounds of attention (as in - LSH attention). - """ - assert v is not None - share_qk = (k is None) - - # `q_info` and `kv_info` if supplied are 0 indexed, we want them to be 1 - # indexed instead so that we can mask position 0 as well - see Github #820 - - if q_info is None: - q_info = np.arange(1, q.shape[-2] + 1, dtype=np.int32) - else: - q_info += 1 - - if kv_info is None and not share_qk: - kv_info = np.arange(1, v.shape[-2] + 1, dtype=np.int32) - elif kv_info is not None: - kv_info += 1 - - # Split q/k/v into chunks along the time axis, if desired. 
- if q_chunk_len is not None: - q = np.reshape(q, (-1, q_chunk_len, q.shape[-1])) - q_info = np.reshape(q_info, (-1, q_chunk_len)) - - if share_qk: - assert kv_chunk_len is None or kv_chunk_len == q_chunk_len - k = q - kv_chunk_len = q_chunk_len - if kv_info is None: - kv_info = q_info - elif kv_chunk_len is not None: - # kv_info is not None, but reshape as required. - kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) - elif kv_chunk_len is not None: - k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1])) - kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) - - if kv_chunk_len is not None: - v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1])) - - if share_qk: - k = length_normalized(k) - k = k / np.sqrt(k.shape[-1]) - - # Optionally include adjacent chunks. - if q_chunk_len is not None or kv_chunk_len is not None: - assert q_chunk_len is not None and kv_chunk_len is not None - else: - assert n_chunks_before == 0 and n_chunks_after == 0 - - k = look_adjacent(k, n_chunks_before, n_chunks_after) - v = look_adjacent(v, n_chunks_before, n_chunks_after) - kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after) - - # Dot-product attention. - dots = np.matmul(q, np.swapaxes(k, -1, -2)) - - # Masking - if mask_fn is not None: - dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :]) - - # Softmax. - dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True) - dots = np.exp(dots - dots_logsumexp) - - if dropout > 0.0: - assert rng is not None - # Dropout is broadcast across the bin dimension - dropout_shape = (dots.shape[-2], dots.shape[-1]) - # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix) - keep_prob = 1.0 - dropout - keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape) - multiplier = keep.astype(dots.dtype) / keep_prob - dots = dots * multiplier - - # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn. 
- out = np.matmul(dots, v) - out = np.reshape(out, (-1, out.shape[-1])) - dots_logsumexp = np.reshape(dots_logsumexp, (-1,)) - return out, dots_logsumexp + q, + k=None, + v=None, + q_chunk_len=None, + kv_chunk_len=None, + n_chunks_before=0, + n_chunks_after=0, + mask_fn=None, + q_info=None, + kv_info=None, + dropout=0.0, + rng=None, +): + """Dot-product attention, with optional chunking and/or masking. + Args: + q: Query vectors, shape [q_len, d_qk] + k: Key vectors, shape [kv_len, d_qk]; or None + v: Value vectors, shape [kv_len, d_v] + q_chunk_len: Set to non-zero to enable chunking for query vectors + kv_chunk_len: Set to non-zero to enable chunking for key/value vectors + n_chunks_before: Number of adjacent previous chunks to attend to + n_chunks_after: Number of adjacent subsequent chunks to attend to + mask_fn: TODO(kitaev) doc + q_info: Query-associated metadata for masking + kv_info: Key-associated metadata for masking + dropout: Dropout rate + rng: RNG for dropout -def apply_broadcasted_dropout(vecs, dropout_rate, rng): - """Apply dropout, broadcasted across all but the last dimension of `vecs`.""" - if dropout_rate > 0.0: - assert rng is not None - keep_prob = 1.0 - dropout_rate - keep = fastmath.random.bernoulli(rng, keep_prob, (vecs.shape[-1],)) - multiplier = keep.astype(vecs.dtype) / keep_prob - return vecs * multiplier - else: - return vecs + Returns: + A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and + dots_logsumexp has shape [q_len]. The logsumexp of the attention + probabilities is useful for combining multiple rounds of attention (as in + LSH attention). + """ + assert v is not None + share_qk = k is None + # `q_info` and `kv_info` if supplied are 0 indexed, we want them to be 1 + # indexed instead so that we can mask position 0 as well - see Github #820 -# The new implementations below don't use custom_transforms in JAX but -# do cause Tracer errors, so we don't use them for now. 
+ if q_info is None: + q_info = np.arange(1, q.shape[-2] + 1, dtype=np.int32) + else: + q_info += 1 + + if kv_info is None and not share_qk: + kv_info = np.arange(1, v.shape[-2] + 1, dtype=np.int32) + elif kv_info is not None: + kv_info += 1 + + # Split q/k/v into chunks along the time axis, if desired. + if q_chunk_len is not None: + q = np.reshape(q, (-1, q_chunk_len, q.shape[-1])) + q_info = np.reshape(q_info, (-1, q_chunk_len)) + + if share_qk: + assert kv_chunk_len is None or kv_chunk_len == q_chunk_len + k = q + kv_chunk_len = q_chunk_len + if kv_info is None: + kv_info = q_info + elif kv_chunk_len is not None: + # kv_info is not None, but reshape as required. + kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) + elif kv_chunk_len is not None: + k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1])) + kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) + if kv_chunk_len is not None: + v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1])) -def permute_via_gather(val, permutation, inverse_permutation, axis=0): - """Permutation helper for LSH attention.""" - def permute_impl(p, unused_ip, val): - return np.take(val, p, axis=axis) - def permute_fwd(p, ip, val): - return np.take(val, p, axis=axis), ip - def permute_bwd(ip, permuted_grad): - # JAX autodiff would synthesize a scatter operation because it doesn't - # know that the indices are a permutation. However on TPU, gathers are - # faster than scatters (at least in the regime the LSH attention uses). - return (None, None, np.take(permuted_grad, ip, axis=axis)) - permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) - return permute(permutation, inverse_permutation, val) + if share_qk: + k = length_normalized(k) + k = k / np.sqrt(k.shape[-1]) + # Optionally include adjacent chunks. 
+ if q_chunk_len is not None or kv_chunk_len is not None: + assert q_chunk_len is not None and kv_chunk_len is not None + else: + assert n_chunks_before == 0 and n_chunks_after == 0 -def permute_via_sort(val, keys, inverse_keys, axis=0): - """Permutation helper for LSH attention.""" - def permute_impl(k, unused_ik, val): - # On TPU, sorting scalars by key is faster than a gather. - _, permuted = fastmath.sort_key_val(k, val, dimension=axis) - return permuted - def permute_fwd(k, ik, val): - # On TPU, sorting scalars by key is faster than a gather. - _, permuted = fastmath.sort_key_val(k, val, dimension=axis) - return permuted, ik - def permute_bwd(ik, permuted_grad): - _, val_grad = fastmath.sort_key_val( - ik, permuted_grad, dimension=axis) - return (None, None, val_grad) - permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) - return permute(keys, inverse_keys, val) + k = look_adjacent(k, n_chunks_before, n_chunks_after) + v = look_adjacent(v, n_chunks_before, n_chunks_after) + kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after) + # Dot-product attention. + dots = np.matmul(q, np.swapaxes(k, -1, -2)) -####################################################### Classes + # Masking + if mask_fn is not None: + dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :]) + # Softmax. + dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True) + dots = np.exp(dots - dots_logsumexp) -class EfficientAttentionBase(base.Layer): - """Base class for efficient attention. 
+ if dropout > 0.0: + assert rng is not None + # Dropout is broadcast across the bin dimension + dropout_shape = (dots.shape[-2], dots.shape[-1]) + # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix) + keep_prob = 1.0 - dropout + keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape) + multiplier = keep.astype(dots.dtype) / keep_prob + dots = dots * multiplier - This is a base class that implements memory-efficient batching for both the - forward and backward passes. Subclasses should override - `create_weights_unbatched`, `create_state_unbatched`, `forward_unbatched`, and - optionally `incremental_forward_unbatched` to define the actual attention - mechanism. - """ + # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn. + out = np.matmul(dots, v) + out = np.reshape(out, (-1, out.shape[-1])) + dots_logsumexp = np.reshape(dots_logsumexp, (-1,)) + return out, dots_logsumexp - def __init__(self, n_heads, n_in=1, n_parallel_heads=None, - incremental=False, predict_mem_len=None, predict_drop_len=None, - use_python_loop=False, use_reference_code=False): - """Constructs an EfficientAttentionBase instance. - Args: - n_heads: Number of attention heads. - n_in: Number of inputs to the layer (default 1). - n_parallel_heads: Number of attention heads to compute in parallel. - - - If `n_parallel_heads` is None (default), the entire layer is - computed with maximum parallelism. This mode is the fastest, but - also uses the most memory. Start with this mode, but switch to one - of the others if memory runs out. - - If `n_parallel_heads` is 1, attention is computed one head at a - time, and one example at a time. This mode uses the least memory - but is not as fast as batched attention. Use this mode when working - with very long sequences, such that any amount of parallelism won't - fit in memory. 
- - If `n_parallel_heads` is a multiple of `n_heads`, attention is - computed for sub-batches of (`n_parallel_heads // n_heads`) - examples at a time. - - If `1 < n_parallel_heads < n_heads`, attention is computed for - several heads at a time, but only within a single example. It must - be the case that `n_heads` is a multiple of `n_parallel_heads`. Use - this mode for long sequences, to strike a balance between - parallelism and memory usage. - incremental: If `True`, enable fast inference for self-attention types. - Note that this flag should *not* be set when doing encoder-decoder - attention, but only when doing self-attention. - predict_mem_len: Number of input positions to remember in a cache - when doing fast inference. Whenever the cache fills up, some input - elements will be forgotten. - predict_drop_len: Number of input elements to drop once the fast - inference input cache fills up. - use_python_loop: Set to True to use a Python loop when iterating over - sub-batches of examples/heads (as opposed to a JAX/XLA loop). - This option will increase compilation time and jitted code size, - potentially drastically. Using it is not recommended except for - testing/debugging. In particular, note that enabling this option on - TPU can decrease the maximum model size that will fit in memory. - use_reference_code: Set to True to fall back to the reference - implementation of batched attention. This option will increase - compilation time and jitted code size, potentially drastically. Using - it is not recommended except for testing/debugging. 
- """ - super().__init__(n_in=n_in, n_out=1) - self._n_heads = n_heads - self._incremental = incremental - if self._incremental: - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads +def apply_broadcasted_dropout(vecs, dropout_rate, rng): + """Apply dropout, broadcasted across all but the last dimension of `vecs`.""" + if dropout_rate > 0.0: + assert rng is not None + keep_prob = 1.0 - dropout_rate + keep = fastmath.random.bernoulli(rng, keep_prob, (vecs.shape[-1],)) + multiplier = keep.astype(vecs.dtype) / keep_prob + return vecs * multiplier else: - self._n_parallel_heads = None - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - 
state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - raise NotImplementedError( - 'Subclasses should override create_weights_unbatched') + return vecs - def create_state_unbatched(self, input_signature, rng): - return () - def forward_unbatched(self, *inputs, weights, state): - """Perform attention for a single batch element and head. +# The new implementations below don't use custom_transforms in JAX but +# do cause Tracer errors, so we don't use them for now. - Subclasses should override this method. - Args: - *inputs: Inputs for a single example (subclasses may use different inputs) - weights: Weights for a single attention head - state: State for a single example & attention head pair. +def permute_via_gather(val, permutation, inverse_permutation, axis=0): + """Permutation helper for LSH attention.""" - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. - """ - raise NotImplementedError('Subclasses should override forward_unbatched') + def permute_impl(p, unused_ip, val): + return np.take(val, p, axis=axis) - def _incremental_forward_unbatched(self, *inputs, q_start, q_len, - weights, state): - """Perform fast inference for a single batch element and head. + def permute_fwd(p, ip, val): + return np.take(val, p, axis=axis), ip - Subclasses should override this method. 
+ def permute_bwd(ip, permuted_grad): + # JAX autodiff would synthesize a scatter operation because it doesn't + # know that the indices are a permutation. However on TPU, gathers are + # faster than scatters (at least in the regime the LSH attention uses). + return (None, None, np.take(permuted_grad, ip, axis=axis)) - Args: - *inputs: Inputs for a single example (subclasses may use different inputs) - q_start: Index along the sequence-length dimension that points to the - first input element that should be used as a query (and not just a key). - q_len: Number of new query elements in this call to the attention - mechanism. This is typically 1 for autoregressive decoding, but may be - longer if initializing a language model with a prefix. - weights: Weights for a single attention head - state: State for a single example & attention head pair. - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. - """ - raise NotImplementedError( - 'Fast inference is not implemented for this attention type.') + permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) + return permute(permutation, inverse_permutation, val) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. - Args: - inputs: Layer inputs (subclasses may use different inputs) +def permute_via_sort(val, keys, inverse_keys, axis=0): + """Permutation helper for LSH attention.""" - Returns: - A tuple (output, new_state). + def permute_impl(k, unused_ik, val): + # On TPU, sorting scalars by key is faster than a gather. + _, permuted = fastmath.sort_key_val(k, val, dimension=axis) + return permuted + + def permute_fwd(k, ik, val): + # On TPU, sorting scalars by key is faster than a gather. 
+ _, permuted = fastmath.sort_key_val(k, val, dimension=axis) + return permuted, ik + + def permute_bwd(ik, permuted_grad): + _, val_grad = fastmath.sort_key_val(ik, permuted_grad, dimension=axis) + return (None, None, val_grad) + + permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) + return permute(keys, inverse_keys, val) + + +####################################################### Classes + + +class EfficientAttentionBase(base.Layer): + """Base class for efficient attention. + + This is a base class that implements memory-efficient batching for both the + forward and backward passes. Subclasses should override + `create_weights_unbatched`, `create_state_unbatched`, `forward_unbatched`, and + optionally `incremental_forward_unbatched` to define the actual attention + mechanism. """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. It's not optimized, however, and should only be used - # when testing this class for correctness. 
- if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop + + def __init__( + self, + n_heads, + n_in=1, + n_parallel_heads=None, + incremental=False, + predict_mem_len=None, + predict_drop_len=None, + use_python_loop=False, + use_reference_code=False, + ): + """Constructs an EfficientAttentionBase instance. + + Args: + n_heads: Number of attention heads. + n_in: Number of inputs to the layer (default 1). + n_parallel_heads: Number of attention heads to compute in parallel. + + - If `n_parallel_heads` is None (default), the entire layer is + computed with maximum parallelism. This mode is the fastest, but + also uses the most memory. Start with this mode, but switch to one + of the others if memory runs out. + - If `n_parallel_heads` is 1, attention is computed one head at a + time, and one example at a time. This mode uses the least memory + but is not as fast as batched attention. Use this mode when working + with very long sequences, such that any amount of parallelism won't + fit in memory. + - If `n_parallel_heads` is a multiple of `n_heads`, attention is + computed for sub-batches of (`n_parallel_heads // n_heads`) + examples at a time. 
+ - If `1 < n_parallel_heads < n_heads`, attention is computed for + several heads at a time, but only within a single example. It must + be the case that `n_heads` is a multiple of `n_parallel_heads`. Use + this mode for long sequences, to strike a balance between + parallelism and memory usage. + incremental: If `True`, enable fast inference for self-attention types. + Note that this flag should *not* be set when doing encoder-decoder + attention, but only when doing self-attention. + predict_mem_len: Number of input positions to remember in a cache + when doing fast inference. Whenever the cache fills up, some input + elements will be forgotten. + predict_drop_len: Number of input elements to drop once the fast + inference input cache fills up. + use_python_loop: Set to True to use a Python loop when iterating over + sub-batches of examples/heads (as opposed to a JAX/XLA loop). + This option will increase compilation time and jitted code size, + potentially drastically. Using it is not recommended except for + testing/debugging. In particular, note that enabling this option on + TPU can decrease the maximum model size that will fit in memory. + use_reference_code: Set to True to fall back to the reference + implementation of batched attention. This option will increase + compilation time and jitted code size, potentially drastically. Using + it is not recommended except for testing/debugging. 
+ """ + super().__init__(n_in=n_in, n_out=1) + self._n_heads = n_heads + self._incremental = incremental if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) - else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. 
- def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. 
- new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. - def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. - return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. 
The - reference implementation is not very efficient, however, and this method - provides a more performant version. + self._n_parallel_heads = None + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. 
+ if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + raise NotImplementedError("Subclasses should override create_weights_unbatched") + + def create_state_unbatched(self, input_signature, rng): + return () + + def forward_unbatched(self, *inputs, weights, state): + """Perform attention for a single batch element and head. + + Subclasses should override this method. + + Args: + *inputs: Inputs for a single example (subclasses may use different inputs) + weights: Weights for a single attention head + state: State for a single example & attention head pair. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. + """ + raise NotImplementedError("Subclasses should override forward_unbatched") + + def _incremental_forward_unbatched(self, *inputs, q_start, q_len, weights, state): + """Perform fast inference for a single batch element and head. + + Subclasses should override this method. + + Args: + *inputs: Inputs for a single example (subclasses may use different inputs) + q_start: Index along the sequence-length dimension that points to the + first input element that should be used as a query (and not just a key). + q_len: Number of new query elements in this call to the attention + mechanism. This is typically 1 for autoregressive decoding, but may be + longer if initializing a language model with a prefix. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. 
+ """ + raise NotImplementedError( + "Fast inference is not implemented for this attention type." + ) - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. 
There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' 
- - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. - n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + if self._incremental: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + + output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] + new_state = [] + for example_idx in range(batch_size): + for head_idx in range(self._n_heads): + # pylint: disable=cell-var-from-loop + single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) + single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) + single_state = fastmath.nested_map( + lambda s: s[example_idx * self._n_heads + head_idx], state + ) + # pylint: enable=cell-var-from-loop + if self._incremental: + single_out, single_new_state = 
self._incremental_forward_unbatched( + *single_inputs, + q_start=q_start, + q_len=seqlen, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + else: + single_out, single_new_state = self.forward_unbatched( + *single_inputs, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + new_state.append(single_new_state) + output_accum[example_idx] = output_accum[example_idx] + single_out + + output = np.stack(output_accum, 0) + if new_state and fastmath.tree_leaves(new_state[0]): + new_state = fastmath.nested_map_multiarg( + lambda *s: np.stack(s, 0), *new_state + ) + else: + new_state = state + if self._incremental: + new_state = (new_mem_end, new_mem, new_state) + self.state = tuple(new_state) + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + mem_end, mem, state = state + seqlen = inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. 
+ def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen + else: + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. + new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. 
There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + # Use an efficient backward pass, unless we're running the reference code. + return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. 
+ + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". 
When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) else: - o_h, s_h = forward_fn(i_h, w_h) + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. 
+ n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): - o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + 
lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) 
+ + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) - if compute_grad: - o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == 
ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) - + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + 
w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) if update_state: - s_all = tree_update(s_all, state_range, s_mh) + s_all = state if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = 
split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - if compute_grad: - o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) - else: - o_mex, s_mex = forward_fn(i_mex, weights) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) - if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - (o_all, s_all, i_ct_all, w_ct_all) = loop_val + (o_all, 
s_all, i_ct_all, w_ct_all) = loop_val - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) + else: + return (o_all, s_all, i_ct_all, w_ct_all) class SelfAttention(base.Layer): - """Memory-efficient self-attention (second attempt).""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, share_qk=False, - causal=False, masked=False, - chunk_len=None, n_chunks_before=0, n_chunks_after=0, - bias=False, - mode='train', - predict_mem_len=None, predict_drop_len=None, - attention_dropout=0.0, - output_dropout=0.0, - n_parallel_heads=None, - use_python_loop=False, - use_reference_code=False, - ): - """Construct a self-attention layer. + """Memory-efficient self-attention (second attempt).""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + share_qk=False, + causal=False, + masked=False, + chunk_len=None, + n_chunks_before=0, + n_chunks_after=0, + bias=False, + mode="train", + predict_mem_len=None, + predict_drop_len=None, + attention_dropout=0.0, + output_dropout=0.0, + n_parallel_heads=None, + use_python_loop=False, + use_reference_code=False, + ): + """Construct a self-attention layer. 
+ + Args: + n_heads: int: Number of attention heads + d_qk: int: Depth of query ond key vectors + d_v: int: Depth of value vectors + share_qk: bool: Set to True to share query and key projection weights + causal: bool: Set to True to mask out attention to future items + masked: bool: Set to True to accept an additional mask argument, that + allows masking out attention to padding tokens. + chunk_len (optional): Number of tokens per chunk. Setting this option will + enable chunked attention. + n_chunks_before: Number of previous chunks to attend to, when using + chunked attention. + n_chunks_after: Number of subsequent chunks to attend to, when using + chunked attention. Don't use this option for causal attention, because + attention to future tokens will be masked out anyway. However, note that + cross-chunk attention "wraps around" in both directions, so this option + is never a strict no-op. + bias: bool: Set to True to add bias vectors when computing query/key/value + mode: 'train', 'eval', or 'predict' + predict_mem_len: int: Number of input positions to remember in a cache + when doing fast inference. Whenever the cache fills up, some input + elements will be forgotten. When chunking is enabled, the default is to + store chunk_len * (1 + n_chunks_before) elements. + predict_drop_len: int: Number of input elements to drop once the fast + inference input cache fills up. When chunking is enabled, the default is + to drop exactly chunk_len elements. + attention_dropout: Dropout probability for attention mask. + output_dropout: Dropout probability for the layer output. + n_parallel_heads: Number of attention heads to compute in parallel. + + - If `n_parallel_heads` is None (default), the entire layer is + computed with maximum parallelism. This mode is the fastest, but + also uses the most memory. Start with this mode, but switch to one + of the others if memory runs out. 
+ - If `n_parallel_heads` is 1, attention is computed one head at a + time, and one example at a time. This mode uses the least memory + but is not as fast as batched attention. Use this mode when working + with very long sequences, such that any amount of parallelism won't + fit in memory. + - If `n_parallel_heads` is a multiple of `n_heads`, attention is + computed for sub-batches of (`n_parallel_heads // n_heads`) + examples at a time. + - If `1 < n_parallel_heads < n_heads`, attention is computed for + several heads at a time, but only within a single example. It must + be the case that `n_heads` is a multiple of `n_parallel_heads`. Use + this mode for long sequences, to strike a balance between + parallelism and memory usage. + use_python_loop: Set to True to use a Python loop when iterating over + sub-batches of examples/heads (as opposed to a JAX/XLA loop). + This option will increase compilation time and jitted code size, + potentially drastically. Using it is not recommended except for + testing/debugging. In particular, note that enabling this option on + TPU can decrease the maximum model size that will fit in memory. + use_reference_code: Set to True to fall back to the reference + implementation of batched attention. This option will increase + compilation time and jitted code size, potentially drastically. Using + it is not recommended except for testing/debugging. 
+ + """ + super().__init__(n_in=(2 if masked else 1), n_out=1) + + self._incremental = mode == "predict" + if self._incremental: + assert causal, "Only causal attention supports fast inference" + assert chunk_len is not None or (predict_mem_len and predict_drop_len) + predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) + predict_drop_len = predict_drop_len or chunk_len + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + self._n_heads = n_heads + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads + else: + self._n_parallel_heads = None + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + self._d_qk = d_qk + self._d_v = d_v + self._share_qk = share_qk + self._causal = causal + self._masked = masked + self._chunk_len = chunk_len + self._n_chunks_before = n_chunks_before + self._n_chunks_after = n_chunks_after + self._bias = bias + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout + else: + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. 
+ # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. + lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + else: + input_signature = (input_signature[0],) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - Args: - n_heads: int: Number of attention heads - d_qk: int: Depth of query ond key vectors - d_v: int: Depth of value vectors - share_qk: bool: Set to True to share query and key projection weights - causal: bool: Set to True to mask out attention to future items - masked: bool: Set to True to accept an additional mask argument, that - allows masking out attention to padding tokens. - chunk_len (optional): Number of tokens per chunk. Setting this option will - enable chunked attention. - n_chunks_before: Number of previous chunks to attend to, when using - chunked attention. - n_chunks_after: Number of subsequent chunks to attend to, when using - chunked attention. 
Don't use this option for causal attention, because - attention to future tokens will be masked out anyway. However, note that - cross-chunk attention "wraps around" in both directions, so this option - is never a strict no-op. - bias: bool: Set to True to add bias vectors when computing query/key/value - mode: 'train', 'eval', or 'predict' - predict_mem_len: int: Number of input positions to remember in a cache - when doing fast inference. Whenever the cache fills up, some input - elements will be forgotten. When chunking is enabled, the default is to - store chunk_len * (1 + n_chunks_before) elements. - predict_drop_len: int: Number of input elements to drop once the fast - inference input cache fills up. When chunking is enabled, the default is - to drop exactly chunk_len elements. - attention_dropout: Dropout probability for attention mask. - output_dropout: Dropout probability for the layer output. - n_parallel_heads: Number of attention heads to compute in parallel. - - - If `n_parallel_heads` is None (default), the entire layer is - computed with maximum parallelism. This mode is the fastest, but - also uses the most memory. Start with this mode, but switch to one - of the others if memory runs out. - - If `n_parallel_heads` is 1, attention is computed one head at a - time, and one example at a time. This mode uses the least memory - but is not as fast as batched attention. Use this mode when working - with very long sequences, such that any amount of parallelism won't - fit in memory. - - If `n_parallel_heads` is a multiple of `n_heads`, attention is - computed for sub-batches of (`n_parallel_heads // n_heads`) - examples at a time. - - If `1 < n_parallel_heads < n_heads`, attention is computed for - several heads at a time, but only within a single example. It must - be the case that `n_heads` is a multiple of `n_parallel_heads`. Use - this mode for long sequences, to strike a balance between - parallelism and memory usage. 
- use_python_loop: Set to True to use a Python loop when iterating over - sub-batches of examples/heads (as opposed to a JAX/XLA loop). - This option will increase compilation time and jitted code size, - potentially drastically. Using it is not recommended except for - testing/debugging. In particular, note that enabling this option on - TPU can decrease the maximum model size that will fit in memory. - use_reference_code: Set to True to fall back to the reference - implementation of batched attention. This option will increase - compilation time and jitted code size, potentially drastically. Using - it is not recommended except for testing/debugging. + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + d_model = input_signature.shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + if not self._share_qk: + w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) + w_v = self._kernel_initializer((d_model, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + + if self._bias: + b_q = np.zeros(self._d_qk) + b_v = np.zeros(self._d_v) + if self._share_qk: + return (w_q, w_v, w_o, b_q, b_v) + else: + b_k = np.zeros(self._d_qk) + return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) + + if self._share_qk: + return (w_q, w_v, w_o) + else: + return (w_q, w_k, w_v, w_o) + + def create_state_unbatched(self, input_signature, rng): + return () + + def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state): + 
"""Perform attention for a single batch element and head. + + Args: + x: Inputs for a single example (subclasses may use different inputs) + mask: Mask for the inputs. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + rng: PRNG key for the layer (shared across all examples and heads) + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. + """ + + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + if self._bias: + if self._share_qk: + w_q, w_v, w_o, b_q, b_v = weights + else: + w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights + else: + if self._share_qk: + w_q, w_v, w_o = weights + else: + w_q, w_k, w_v, w_o = weights + + q = np.matmul(x, w_q) + k = None + if not self._share_qk: + k = np.matmul(x, w_k) + v = np.matmul(x, w_v) + + if self._bias: + q = q + b_q + if not self._share_qk: + k = k + b_k + v = v + b_v + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=self._share_qk, + masked=self._masked, + ) + q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32) + + assert (mask is not None) == self._masked + if self._masked: + # mask is a boolean array (True means "is valid token") + ones_like_mask = np.ones_like(mask, dtype=np.int32) + kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask) + + o, _ = attend( + q, + k, + v, + q_chunk_len=self._chunk_len, + kv_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - """ - super().__init__(n_in=(2 if masked else 1), n_out=1) - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and 
predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) - predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = share_qk - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - else: - input_signature = (input_signature[0],) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - d_model = input_signature.shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - if not self._share_qk: - w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_model, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - - 
if self._bias: - b_q = np.zeros(self._d_qk) - b_v = np.zeros(self._d_v) - if self._share_qk: - return (w_q, w_v, w_o, b_q, b_v) - else: - b_k = np.zeros(self._d_qk) - return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) - - if self._share_qk: - return (w_q, w_v, w_o) - else: - return (w_q, w_k, w_v, w_o) + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state - def create_state_unbatched(self, input_signature, rng): - return () + def _incremental_forward_unbatched( + self, x, mask=None, *, q_start, q_len, weights, state, rng, update_state + ): + """Perform fast inference for a single batch element and head. + + Args: + x: Inputs for a single example (subclasses may use different inputs) + mask: inputs mask. + q_start: Index along the sequence-length dimension that points to the + first input element that should be used as a query (and not just a key). + q_len: Number of new query elements in this call to the attention + mechanism. This is typically 1 for autoregressive decoding, but may be + longer if initializing a language model with a prefix. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + rng: PRNG key for the layer (shared across all examples and heads) + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. + """ + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + if self._share_qk: + w_q, w_v, w_o = weights + else: + w_q, w_k, w_v, w_o = weights + + q_range = q_start + np.arange(q_len, dtype=np.int32) + if q_len == 1: + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. 
Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. + q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) + else: + q = np.matmul(x[q_range], w_q) + if self._share_qk: + k = length_normalized(np.matmul(x, w_q)) + else: + k = np.matmul(x, w_k) + v = np.matmul(x, w_v) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=self._share_qk, + masked=self._masked, + ) + q_info = q_range + kv_info = np.arange(k.shape[-2], dtype=np.int32) + + if self._chunk_len is not None and q_len > self._chunk_len: + assert q_start == 0 + assert q_len % self._chunk_len == 0 + o, _ = attend( + q, + k, + v, + q_chunk_len=self._chunk_len, + kv_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) + else: + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) + + out = np.matmul(o, w_o) + if q_len == 1: + out = out[:1] + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state + + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. 
It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - def forward_unbatched(self, x, mask=None, *, - weights, state, rng, update_state): - """Perform attention for a single batch element and head. + if self._incremental: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + + output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] + new_state = [] + for example_idx in range(batch_size): + for head_idx in range(self._n_heads): + # pylint: disable=cell-var-from-loop + single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) + single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) + single_state = fastmath.nested_map( + lambda s: s[example_idx * self._n_heads + head_idx], state + ) + # pylint: enable=cell-var-from-loop + if self._incremental: + single_out, single_new_state = self._incremental_forward_unbatched( + *single_inputs, + q_start=q_start, + q_len=seqlen, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + else: + single_out, single_new_state = self.forward_unbatched( + *single_inputs, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + new_state.append(single_new_state) + output_accum[example_idx] = output_accum[example_idx] + single_out + + output = np.stack(output_accum, 0) + if new_state and fastmath.tree_leaves(new_state[0]): + new_state = fastmath.nested_map_multiarg( + lambda *s: np.stack(s, 0), *new_state + ) + else: + new_state = state + if self._incremental: + new_state = (new_mem_end, new_mem, new_state) + self.state = tuple(new_state) + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + mem_end, mem, state = state + seqlen = 
inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. + def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen + else: + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. 
+ new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + # Use an efficient backward pass, unless we're running the reference code. + return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - Args: - x: Inputs for a single example (subclasses may use different inputs) - mask: Mask for the inputs. 
- weights: Weights for a single attention head - state: State for a single example & attention head pair. - rng: PRNG key for the layer (shared across all examples and heads) - update_state: bool: whether to return an updated layer state. + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. 
There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" 
+ + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. + n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. 
- """ + if compute_grad: + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if 
compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = 
tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - if self._bias: - if self._share_qk: - w_q, w_v, w_o, b_q, b_v = weights - else: - w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights - else: - if self._share_qk: - w_q, w_v, w_o = weights - else: - w_q, w_k, w_v, w_o = weights + else: + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, 
s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None + if compute_output: + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) + if update_state: + s_all = state + if compute_grad: + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - q = np.matmul(x, w_q) - k = None - if not self._share_qk: - k = np.matmul(x, w_k) - v = np.matmul(x, w_v) - - if self._bias: - q = q + b_q - if not self._share_qk: - k = k + b_k - v = v + b_v - - mask_fn = functools.partial( - mask_self_attention, - causal=self._causal, exclude_self=self._share_qk, masked=self._masked) - q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32) - - assert (mask is not None) == self._masked - if self._masked: - # mask is a boolean array (True means "is valid token") - ones_like_mask = np.ones_like(mask, dtype=np.int32) - kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask) - - o, _ = attend( - q, k, v, - q_chunk_len=self._chunk_len, - kv_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size 
* self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - def _incremental_forward_unbatched(self, x, mask=None, *, - q_start, q_len, - weights, state, rng, update_state): - """Perform fast inference for a single batch element and head. + (o_all, s_all, i_ct_all, w_ct_all) = loop_val - Args: - x: Inputs for a single example (subclasses may use different inputs) - mask: inputs mask. - q_start: Index along the sequence-length dimension that points to the - first input element that should be used as a query (and not just a key). - q_len: Number of new query elements in this call to the attention - mechanism. This is typically 1 for autoregressive decoding, but may be - longer if initializing a language model with a prefix. - weights: Weights for a single attention head - state: State for a single example & attention head pair. - rng: PRNG key for the layer (shared across all examples and heads) - update_state: bool: whether to return an updated layer state. + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. - """ - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - if self._share_qk: - w_q, w_v, w_o = weights - else: - w_q, w_k, w_v, w_o = weights - - q_range = q_start + np.arange(q_len, dtype=np.int32) - if q_len == 1: - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. 
- q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - else: - q = np.matmul(x[q_range], w_q) - if self._share_qk: - k = length_normalized(np.matmul(x, w_q)) - else: - k = np.matmul(x, w_k) - v = np.matmul(x, w_v) - - mask_fn = functools.partial( - mask_self_attention, - causal=self._causal, exclude_self=self._share_qk, masked=self._masked) - q_info = q_range - kv_info = np.arange(k.shape[-2], dtype=np.int32) - - if self._chunk_len is not None and q_len > self._chunk_len: - assert q_start == 0 - assert q_len % self._chunk_len == 0 - o, _ = attend( - q, k, v, - q_chunk_len=self._chunk_len, - kv_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) - else: - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) - out = np.matmul(o, w_o) - if q_len == 1: - out = out[:1] - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) + else: + return (o_all, s_all, i_ct_all, w_ct_all) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. 
- Args: - inputs: Layer inputs (subclasses may use different inputs) +class LSHSelfAttention(base.Layer): + """LSH self-attention (second implementation).""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + share_qk="unused", + causal=False, + masked=False, + chunk_len=128, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=1, + n_buckets=None, + mode="train", + predict_mem_len=2048, + predict_drop_len=256, + attention_dropout=0.0, + output_dropout=0.0, + max_length_for_buckets=None, + bias=False, + n_parallel_heads=1, + use_python_loop=False, + use_reference_code=False, + ): + """Construct an LSH self-attention layer.""" + super().__init__(n_in=(2 if masked else 1), n_out=1) + + self._n_heads = n_heads + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads + else: + self._n_parallel_heads = None - Returns: - A tuple (output, new_state). - """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. It's not optimized, however, and should only be used - # when testing this class for correctness. 
- if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop + self._incremental = mode == "predict" if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) - else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. 
A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. - def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) - else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. 
- new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type + assert causal, "Only causal attention supports fast inference" + assert chunk_len is not None or (predict_mem_len and predict_drop_len) + predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) + predict_drop_len = predict_drop_len or chunk_len + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + self._d_qk = d_qk + self._d_v = d_v + self._share_qk = True + self._causal = causal + self._masked = masked + self._chunk_len = chunk_len + self._n_chunks_before = n_chunks_before + self._n_chunks_after = n_chunks_after + self._bias = bias + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. 
- def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. - return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. - - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. 
- - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. 
- # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' - - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. 
- n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + self._n_hashes = n_hashes + self._n_buckets = n_buckets + self._max_length_for_buckets = max_length_for_buckets + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. + # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. 
+ lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + d_model = input_signature.shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + if not self._share_qk: + w_k = self._kernel_initializer((d_model, 
self._d_qk), rng_k) + w_v = self._kernel_initializer((d_model, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + + if self._bias: + b_q = np.zeros(self._d_qk) + b_v = np.zeros(self._d_v) + if self._share_qk: + return (w_q, w_v, w_o, b_q, b_v) + else: + b_k = np.zeros(self._d_qk) + return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) + + if self._share_qk: + return (w_q, w_v, w_o) else: - o_h, s_h = forward_fn(i_h, w_h) - - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): - o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh - - if compute_grad: - o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) + return (w_q, w_k, w_v, w_o) + + def create_state_unbatched(self, 
input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + # The `rng` argument passed to forward_unbatched is shared across all + # examples and heads. This facilitates using broadcasted dropout, which + # saves memory and hasn't been shown to hurt model quality. Even though the + # same sharing is likely to be safe when selecting random hash functions + # for LSH, we haven't run experiments to demonstrate this. To be on the safe + # side we include a per-head RNG in the state for the purpose of doing LSH. + if not self._incremental: + length = self._max_length_for_buckets or input_signature.shape[0] + buckets = np.zeros(self._n_hashes * length, dtype=np.int32) + return (buckets, rng) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) + buckets = np.zeros(self._n_hashes * self._predict_mem_len, dtype=np.int32) + buckets_idx = np.zeros((), dtype=np.int32) + return (buckets, buckets_idx, rng) + + def hash_vectors(self, vecs, rng, mask=None): + n_buckets_list = self._n_buckets + + # Determine the number of buckets needed from input length if not set. + if n_buckets_list is None: + length = vecs.shape[0] + n_buckets = 2 * max(1, length // self._chunk_len) + if n_buckets <= 128: + n_buckets_list = n_buckets + else: # Factorize n_buckets. + n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets))) + # Both factors must be even. + n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) + n_buckets_list = [n_buckets_div, n_buckets_rest] + + # Hash vectors. + buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) + + if mask is not None: + n_buckets += 1 # Create an extra bucket for padding tokens only + buckets = np.where(mask[None, :], buckets, n_buckets - 1) + + # buckets is now (n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. 
+ offsets = np.arange(self._n_hashes, dtype=np.int32) + offsets = np.reshape(offsets * n_buckets, (-1, 1)) + buckets = np.reshape(buckets + offsets, (-1,)) + return buckets + + def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state): + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_v, w_o = weights - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) - if update_state: - s_all = tree_update(s_all, state_range, s_mh) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + q = np.matmul(x, w_q) + v = np.matmul(x, w_v) - if compute_grad: - 
o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) + if update_state: + _, old_hash_rng = state + hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) + buckets = self.hash_vectors(q, hash_subrng, mask) + s_buckets = buckets + if self._max_length_for_buckets: + length = self._n_hashes * self._max_length_for_buckets + if buckets.shape[0] < length: + s_buckets = np.concatenate( + [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], + axis=0, + ) + state = (s_buckets, hash_rng) else: - o_mex, s_mex = forward_fn(i_mex, weights) + buckets, _ = state + if self._max_length_for_buckets: + buckets = buckets[: self._n_hashes * x.shape[0]] - if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, 
loop_hi, run_inner, loop_val) + seqlen = x.shape[0] + assert int(buckets.shape[0]) == self._n_hashes * seqlen - (o_all, s_all, i_ct_all, w_ct_all) = loop_val + ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = fastmath.stop_gradient(buckets_and_t) - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = fastmath.sort_key_val( + buckets_and_t, ticker, dimension=-1 + ) + _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) + sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) + sticker = fastmath.stop_gradient(sticker) + undo_sort = fastmath.stop_gradient(undo_sort) + + st = sticker % seqlen + sq = np.take(q, st, axis=0) + sv = np.take(v, st, axis=0) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=True, + masked=self._masked, + ) + q_info = st + + assert (mask is not None) == self._masked + kv_info = None + if self._masked: + # mask is a boolean array (True means "is valid token") + smask = np.take(mask, st, axis=0) + ones_like_mask = np.ones_like(smask, dtype=np.int32) + kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) + + so, slogits = attend( + sq, + k=None, + v=sv, + q_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would + # also work, but these helpers include performance optimizations for TPU. 
+ o = permute_via_gather(so, undo_sort, sticker, axis=0) + logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + if self._n_hashes > 1: + o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) + probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) + o = np.sum(o * probs, axis=0) + assert o.shape == (seqlen, w_v.shape[-1]) + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state -class LSHSelfAttention(base.Layer): - """LSH self-attention (second implementation).""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, share_qk='unused', - causal=False, - masked=False, - chunk_len=128, n_chunks_before=1, n_chunks_after=0, - n_hashes=1, - n_buckets=None, - mode='train', - predict_mem_len=2048, predict_drop_len=256, - attention_dropout=0.0, - output_dropout=0.0, - max_length_for_buckets=None, - bias=False, - n_parallel_heads=1, - use_python_loop=False, - use_reference_code=False, - ): - """Construct an LSH self-attention layer.""" - super().__init__(n_in=(2 if masked else 1), n_out=1) - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) 
- predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = True - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - self._n_hashes = n_hashes - self._n_buckets = n_buckets - self._max_length_for_buckets = max_length_for_buckets - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - d_model = input_signature.shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - if not self._share_qk: - w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_model, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - - if self._bias: - b_q = np.zeros(self._d_qk) - b_v 
= np.zeros(self._d_v) - if self._share_qk: - return (w_q, w_v, w_o, b_q, b_v) - else: - b_k = np.zeros(self._d_qk) - return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) - - if self._share_qk: - return (w_q, w_v, w_o) - else: - return (w_q, w_k, w_v, w_o) - - def create_state_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - # The `rng` argument passed to forward_unbatched is shared across all - # examples and heads. This facilitates using broadcasted dropout, which - # saves memory and hasn't been shown to hurt model quality. Even though the - # same sharing is likely to be safe when selecting random hash functions - # for LSH, we haven't run experiments to demonstrate this. To be on the safe - # side we include a per-head RNG in the state for the purpose of doing LSH. - if not self._incremental: - length = self._max_length_for_buckets or input_signature.shape[0] - buckets = np.zeros(self._n_hashes * length, dtype=np.int32) - return (buckets, rng) - else: - buckets = np.zeros( - self._n_hashes * self._predict_mem_len, dtype=np.int32) - buckets_idx = np.zeros((), dtype=np.int32) - return (buckets, buckets_idx, rng) - - def hash_vectors(self, vecs, rng, mask=None): - n_buckets_list = self._n_buckets - - # Determine the number of buckets needed from input length if not set. - if n_buckets_list is None: - length = vecs.shape[0] - n_buckets = 2 * max(1, length // self._chunk_len) - if n_buckets <= 128: - n_buckets_list = n_buckets - else: # Factorize n_buckets. - n_buckets_div = 2**math.ceil(math.log2(math.sqrt(n_buckets))) - # Both factors must be even. - n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) - n_buckets_list = [n_buckets_div, n_buckets_rest] - - # Hash vectors. 
- buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) - - if mask is not None: - n_buckets += 1 # Create an extra bucket for padding tokens only - buckets = np.where(mask[None, :], buckets, n_buckets - 1) - - # buckets is now (n_hashes, seqlen). Next we add offsets so that - # bucket numbers from different hashing rounds don't overlap. - offsets = np.arange(self._n_hashes, dtype=np.int32) - offsets = np.reshape(offsets * n_buckets, (-1, 1)) - buckets = np.reshape(buckets + offsets, (-1,)) - return buckets - - def forward_unbatched(self, x, mask=None, *, weights, state, rng, - update_state): - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_v, w_o = weights - - q = np.matmul(x, w_q) - v = np.matmul(x, w_v) - - if update_state: - _, old_hash_rng = state - hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) - buckets = self.hash_vectors(q, hash_subrng, mask) - s_buckets = buckets - if self._max_length_for_buckets: - length = self._n_hashes * self._max_length_for_buckets - if buckets.shape[0] < length: - s_buckets = np.concatenate( - [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], - axis=0) - state = (s_buckets, hash_rng) - else: - buckets, _ = state - if self._max_length_for_buckets: - buckets = buckets[:self._n_hashes * x.shape[0]] - - seqlen = x.shape[0] - assert int(buckets.shape[0]) == self._n_hashes * seqlen - - ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) - buckets_and_t = seqlen * buckets + (ticker % seqlen) - buckets_and_t = fastmath.stop_gradient(buckets_and_t) - - # Hash-based sort ("s" at the start of variable names means "sorted") - sbuckets_and_t, sticker = fastmath.sort_key_val( - buckets_and_t, ticker, dimension=-1) - _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) - sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) - sticker = fastmath.stop_gradient(sticker) - undo_sort = fastmath.stop_gradient(undo_sort) - - st = (sticker % seqlen) - sq = np.take(q, st, 
axis=0) - sv = np.take(v, st, axis=0) - - mask_fn = functools.partial(mask_self_attention, causal=self._causal, - exclude_self=True, masked=self._masked) - q_info = st - - assert (mask is not None) == self._masked - kv_info = None - if self._masked: - # mask is a boolean array (True means "is valid token") - smask = np.take(mask, st, axis=0) - ones_like_mask = np.ones_like(smask, dtype=np.int32) - kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) - - so, slogits = attend( - sq, k=None, v=sv, - q_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + def _incremental_forward_unbatched( + self, x, *, q_start, q_len, weights, state, rng, update_state + ): + assert ( + update_state + ), "This setting not supported (e.g. no backprop for fast inference)" + if q_len > 1: + if isinstance(q_start, int): + assert q_start == 0, "Chunks larger than 1 only work at start for now." 
+ if x.shape[0] % self._chunk_len == 0: + x_padded = x + else: + pad_amount = self._chunk_len - (x.shape[0] % self._chunk_len) + x_padded = np.pad(x, ((0, pad_amount), (0, 0)), mode="constant") + buckets, buckets_idx, hash_rng = state + q = np.matmul(x_padded, weights[0]) + buckets_update = self.hash_vectors(q, hash_rng) + + out, _ = self.forward_unbatched( + x_padded, + weights=weights, + state=(buckets_update, hash_rng), + rng=rng, + update_state=False, + ) + + out = out[:q_len] + buckets = np.reshape(buckets, (self._n_hashes, -1)) + buckets_update = np.reshape(buckets_update, (self._n_hashes, -1))[:, :q_len] + if q_len > self._predict_mem_len: + buckets_update = buckets_update[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + buckets = fastmath.dynamic_update_slice_in_dim( + buckets, buckets_update, q_start, axis=1 + ) + buckets = np.reshape(buckets, (-1,)) + + return out, (buckets, buckets_idx + q_len, hash_rng) + + # This codepath is for handling one token at a time. + assert q_len == 1 + buckets, buckets_idx, hash_rng = state + + def roll_buckets(buckets): + buckets = np.reshape(buckets, (self._n_hashes, -1)) + new_buckets = np.concatenate( + [ + buckets, + np.zeros( + (self._n_hashes, self._predict_drop_len), dtype=buckets.dtype + ), + ], + axis=1, + ) + new_buckets = fastmath.dynamic_slice_in_dim( + new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1 + ) + new_buckets = np.reshape(new_buckets, (-1,)) + return new_buckets + + buckets = fastmath.cond( + pred=buckets_idx > q_start, + true_operand=buckets, + true_fun=roll_buckets, + false_operand=buckets, + false_fun=lambda x: x, ) - # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would - # also work, but these helpers include performance optimizations for TPU. 
- o = permute_via_gather(so, undo_sort, sticker, axis=0) - logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - - if self._n_hashes > 1: - o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) - logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) - probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) - o = np.sum(o * probs, axis=0) - - assert o.shape == (seqlen, w_v.shape[-1]) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state - - def _incremental_forward_unbatched(self, x, *, - q_start, q_len, - weights, state, rng, update_state): - assert update_state, ( - 'This setting not supported (e.g. no backprop for fast inference)') - if q_len > 1: - if isinstance(q_start, int): - assert q_start == 0, 'Chunks larger than 1 only work at start for now.' - if x.shape[0] % self._chunk_len == 0: - x_padded = x - else: - pad_amount = self._chunk_len - (x.shape[0] % self._chunk_len) - x_padded = np.pad(x, ((0, pad_amount), (0, 0)), mode='constant') - buckets, buckets_idx, hash_rng = state - q = np.matmul(x_padded, weights[0]) - buckets_update = self.hash_vectors(q, hash_rng) - - out, _ = self.forward_unbatched( - x_padded, weights=weights, state=(buckets_update, hash_rng), - rng=rng, update_state=False) - - out = out[:q_len] - buckets = np.reshape(buckets, (self._n_hashes, -1)) - buckets_update = np.reshape( - buckets_update, (self._n_hashes, -1))[:, :q_len] - if q_len > self._predict_mem_len: - buckets_update = buckets_update[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - buckets = fastmath.dynamic_update_slice_in_dim( - buckets, buckets_update, q_start, axis=1) - buckets = np.reshape(buckets, (-1,)) - - return out, (buckets, buckets_idx + q_len, hash_rng) - - # This codepath is for handling one token at a time. 
- assert q_len == 1 - buckets, buckets_idx, hash_rng = state - - def roll_buckets(buckets): - buckets = np.reshape(buckets, (self._n_hashes, -1)) - new_buckets = np.concatenate( - [buckets, np.zeros((self._n_hashes, self._predict_drop_len), - dtype=buckets.dtype) - ], axis=1) - new_buckets = fastmath.dynamic_slice_in_dim( - new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1) - new_buckets = np.reshape(new_buckets, (-1,)) - return new_buckets - - buckets = fastmath.cond( - pred=buckets_idx > q_start, - true_operand=buckets, - true_fun=roll_buckets, - false_operand=buckets, - false_fun=lambda x: x, - ) + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_v, w_o = weights - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_v, w_o = weights - - q_range = q_start + np.arange(q_len, dtype=np.int32) - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. 
- q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - - q_buckets = self.hash_vectors(q, hash_rng) - q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - - unflattened_buckets = fastmath.dynamic_update_slice_in_dim( - np.reshape(buckets, (self._n_hashes, -1)), - q_buckets, q_start, axis=1) - buckets = np.reshape(unflattened_buckets, (-1,)) - is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) - - assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 - seqlen = x.shape[0] - arange_seqlen = np.arange(seqlen, dtype=np.int32) - kv_priorities = np.where( - arange_seqlen > (q_start + q_len), - -(seqlen + arange_seqlen), arange_seqlen) - kv_priorities = kv_priorities + seqlen * is_valid_target.astype(np.int32) - _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) - kv_indices = kv_indices[ - -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before):] - assert self._n_chunks_after == 0 - - x_attend_to = x[kv_indices] - k = length_normalized(np.matmul(x_attend_to, w_q)) - v = np.matmul(x_attend_to, w_v) - - mask_fn = functools.partial( - mask_self_attention, causal=True, masked=True, exclude_self=True) - q_info = q_start + np.arange(q_len, dtype=np.int32) - kv_info = kv_indices.astype(np.int32) - q_info = q_info.astype(np.int32) - # TODO(kitaev): is it better to mask out attention across buckets? - # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + q_range = q_start + np.arange(q_len, dtype=np.int32) + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. 
+ q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - out = np.matmul(o, w_o) - if q_len == 1: - out = out[:1] - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) - return out, (buckets, buckets_idx, hash_rng) + q_buckets = self.hash_vectors(q, hash_rng) + q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. + unflattened_buckets = fastmath.dynamic_update_slice_in_dim( + np.reshape(buckets, (self._n_hashes, -1)), q_buckets, q_start, axis=1 + ) + buckets = np.reshape(unflattened_buckets, (-1,)) + is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) + + assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 + seqlen = x.shape[0] + arange_seqlen = np.arange(seqlen, dtype=np.int32) + kv_priorities = np.where( + arange_seqlen > (q_start + q_len), -(seqlen + arange_seqlen), arange_seqlen + ) + kv_priorities = kv_priorities + seqlen * is_valid_target.astype(np.int32) + _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) + kv_indices = kv_indices[ + -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before) : + ] + assert self._n_chunks_after == 0 + + x_attend_to = x[kv_indices] + k = length_normalized(np.matmul(x_attend_to, w_q)) + v = np.matmul(x_attend_to, w_v) + + mask_fn = functools.partial( + mask_self_attention, causal=True, masked=True, exclude_self=True + ) + q_info = q_start + np.arange(q_len, dtype=np.int32) + kv_info = kv_indices.astype(np.int32) + q_info = q_info.astype(np.int32) + # TODO(kitaev): is it better to mask out attention across buckets? 
+ # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - Args: - inputs: Layer inputs (subclasses may use different inputs) + out = np.matmul(o, w_o) + if q_len == 1: + out = out[:1] + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) + return out, (buckets, buckets_idx, hash_rng) + + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - Returns: - A tuple (output, new_state). - """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. 
It's not optimized, however, and should only be used - # when testing this class for correctness. - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) - else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the 
sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. - def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + + output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] + new_state = [] + for example_idx in range(batch_size): + for head_idx in range(self._n_heads): + # pylint: disable=cell-var-from-loop + single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) + single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) + single_state = fastmath.nested_map( + lambda s: s[example_idx * self._n_heads + head_idx], state + ) + # pylint: enable=cell-var-from-loop + if self._incremental: + single_out, single_new_state = self._incremental_forward_unbatched( + *single_inputs, + q_start=q_start, + q_len=seqlen, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + else: + single_out, single_new_state = self.forward_unbatched( + *single_inputs, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + new_state.append(single_new_state) + output_accum[example_idx] = output_accum[example_idx] + single_out + + output = np.stack(output_accum, 0) + if new_state and 
fastmath.tree_leaves(new_state[0]): + new_state = fastmath.nested_map_multiarg( + lambda *s: np.stack(s, 0), *new_state + ) else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. - new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type + new_state = state + if self._incremental: + new_state = (new_mem_end, new_mem, new_state) + self.state = tuple(new_state) + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + mem_end, mem, state = state + seqlen = inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. 
+ def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. - def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. 
- return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. - - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. - - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. 
+ new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + # Use an efficient backward pass, unless we're running the reference code. 
+ return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". 
When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' - - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. 
- n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. 
+ This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. 
I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. 
+ n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) - else: - o_h, s_h = forward_fn(i_h, w_h) + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one 
slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) + else: + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + 
example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) if update_state: - s_all = tree_update(s_all, idx, s_h) + s_all = state if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 
0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): - o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - if compute_grad: - o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) + + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) - if update_state: - s_all = tree_update(s_all, state_range, s_mh) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert 
n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + (o_all, s_all, i_ct_all, w_ct_all) = loop_val if compute_grad: - o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) + + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) else: - o_mex, s_mex = forward_fn(i_mex, weights) + return (o_all, s_all, i_ct_all, w_ct_all) - 
if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) - (o_all, s_all, i_ct_all, w_ct_all) = loop_val +class PureLSHSelfAttention(base.Layer): + """LSH self-attention without weights.""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + share_qk="unused", + causal=False, + masked=False, + chunk_len=128, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=1, + n_buckets=None, + mode="train", + predict_mem_len=2048, + predict_drop_len=256, + attention_dropout=0.0, + output_dropout=0.0, + max_length_for_buckets=None, + bias=False, + n_parallel_heads=1, + use_python_loop=False, + use_reference_code=False, + ): + """Construct an LSH self-attention layer.""" + # (qk, v, mask) -> (o) if masked + # (qk, v) -> (o) otherwise + super().__init__(n_in=(3 if masked else 2), n_out=1) + + self._n_heads = n_heads + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % 
n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads + else: + self._n_parallel_heads = None - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + self._incremental = mode == "predict" + if self._incremental: + assert causal, "Only causal attention supports fast inference" + assert chunk_len is not None or (predict_mem_len and predict_drop_len) + predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) + predict_drop_len = predict_drop_len or chunk_len + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + self._d_qk = d_qk + self._d_v = d_v + self._share_qk = True + self._causal = causal + self._masked = masked + self._chunk_len = chunk_len + self._n_chunks_before = n_chunks_before + self._n_chunks_after = n_chunks_after + self._bias = bias + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout + else: + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + self._n_hashes = n_hashes + self._n_buckets = n_buckets + self._max_length_for_buckets = max_length_for_buckets + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. 
+ # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. + lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + # input_signature should be the type signature of (qk, v, mask) or (qk, v) + expected_inputs = 3 if self._masked else 2 + if not ( + isinstance(input_signature, (tuple, list)) + and len(input_signature) == expected_inputs + ): + raise ValueError( + f"input_signature should be {expected_inputs}-tuple, " + f"but is: {input_signature}" + ) + + # Each of qk, v are shaped - (batch * heads, length, d_head) + # mask is shaped: (batch, length) + qk_signature = input_signature[0] + v_signature = input_signature[1] + # mask_signature = input_signature[2] + # batch = mask_signature.shape[0] + batch_x_heads = qk_signature.shape[0] + + assert batch_x_heads % self._n_heads == 0 + batch = batch_x_heads // self._n_heads + + query_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), qk_signature + ) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + state_rngs = fastmath.random.split(self.rng, batch_x_heads) + state = [ + self.create_state_unbatched(query_signature_unbatched, rng) + for rng in state_rngs + ] - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + (qk_signature, v_signature), + ) + mem_end = np.zeros((), 
dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = () + + def create_state_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + # The `rng` argument passed to forward_unbatched is shared across all + # examples and heads. This facilitates using broadcasted dropout, which + # saves memory and hasn't been shown to hurt model quality. Even though the + # same sharing is likely to be safe when selecting random hash functions + # for LSH, we haven't run experiments to demonstrate this. To be on the safe + # side we include a per-head RNG in the state for the purpose of doing LSH. + if not self._incremental: + length = self._max_length_for_buckets or input_signature.shape[0] + buckets = np.zeros(self._n_hashes * length, dtype=np.int32) + return (buckets, rng) + else: + buckets = np.zeros(self._n_hashes * self._predict_mem_len, dtype=np.int32) + buckets_idx = np.zeros((), dtype=np.int32) + return (buckets, buckets_idx, rng) + + def hash_vectors(self, vecs, rng, mask=None): + n_buckets_list = self._n_buckets + + # Determine the number of buckets needed from input length if not set. + if n_buckets_list is None: + length = vecs.shape[0] + n_buckets = 2 * max(1, length // self._chunk_len) + if n_buckets <= 128: + n_buckets_list = n_buckets + else: # Factorize n_buckets. + n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets))) + # Both factors must be even. + n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) + n_buckets_list = [n_buckets_div, n_buckets_rest] + + # Hash vectors. + buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) + + if mask is not None: + n_buckets += 1 # Create an extra bucket for padding tokens only + buckets = np.where(mask[None, :], buckets, n_buckets - 1) + + # buckets is now (n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. 
+ offsets = np.arange(self._n_hashes, dtype=np.int32) + offsets = np.reshape(offsets * n_buckets, (-1, 1)) + buckets = np.reshape(buckets + offsets, (-1,)) + return buckets + + def forward_unbatched(self, qk, v, mask=None, *, state, rng, update_state): + attend_rng, output_rng = fastmath.random.split( + rng + ) # pylint: disable=unused-variable + + # Since these are unbatched: + # q, v are shaped (seqlen, d_head) + # mask is shaped (seqlen,) + q = qk + seqlen = q.shape[0] -class PureLSHSelfAttention(base.Layer): - """LSH self-attention without weights.""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, share_qk='unused', - causal=False, - masked=False, - chunk_len=128, - n_chunks_before=1, n_chunks_after=0, - n_hashes=1, - n_buckets=None, - mode='train', - predict_mem_len=2048, predict_drop_len=256, - attention_dropout=0.0, - output_dropout=0.0, - max_length_for_buckets=None, - bias=False, - n_parallel_heads=1, - use_python_loop=False, - use_reference_code=False, - ): - """Construct an LSH self-attention layer.""" - # (qk, v, mask) -> (o) if masked - # (qk, v) -> (o) otherwise - super().__init__(n_in=(3 if masked else 2), n_out=1) - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) - predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < 
predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = True - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - self._n_hashes = n_hashes - self._n_buckets = n_buckets - self._max_length_for_buckets = max_length_for_buckets - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - # input_signature should be the type signature of (qk, v, mask) or (qk, v) - expected_inputs = 3 if self._masked else 2 - if not (isinstance(input_signature, (tuple, list)) and - len(input_signature) == expected_inputs): - raise ValueError( - f'input_signature should be {expected_inputs}-tuple, ' - f'but is: {input_signature}') - - # Each of qk, v are shaped - (batch * heads, length, d_head) - # mask is shaped: (batch, length) - qk_signature = input_signature[0] - v_signature = input_signature[1] - # mask_signature = input_signature[2] - # batch = mask_signature.shape[0] - batch_x_heads = qk_signature.shape[0] - - assert batch_x_heads % self._n_heads == 0 - batch = batch_x_heads // self._n_heads - - query_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - qk_signature) - - state_rngs = fastmath.random.split(self.rng, batch_x_heads) - state = [self.create_state_unbatched(query_signature_unbatched, rng) - for rng in state_rngs] - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - (qk_signature, v_signature)) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = () - - def create_state_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - # The `rng` argument passed to forward_unbatched is shared across all - # examples and heads. This facilitates using broadcasted dropout, which - # saves memory and hasn't been shown to hurt model quality. 
Even though the - # same sharing is likely to be safe when selecting random hash functions - # for LSH, we haven't run experiments to demonstrate this. To be on the safe - # side we include a per-head RNG in the state for the purpose of doing LSH. - if not self._incremental: - length = self._max_length_for_buckets or input_signature.shape[0] - buckets = np.zeros(self._n_hashes * length, dtype=np.int32) - return (buckets, rng) - else: - buckets = np.zeros( - self._n_hashes * self._predict_mem_len, dtype=np.int32) - buckets_idx = np.zeros((), dtype=np.int32) - return (buckets, buckets_idx, rng) - - def hash_vectors(self, vecs, rng, mask=None): - n_buckets_list = self._n_buckets - - # Determine the number of buckets needed from input length if not set. - if n_buckets_list is None: - length = vecs.shape[0] - n_buckets = 2 * max(1, length // self._chunk_len) - if n_buckets <= 128: - n_buckets_list = n_buckets - else: # Factorize n_buckets. - n_buckets_div = 2**math.ceil(math.log2(math.sqrt(n_buckets))) - # Both factors must be even. - n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) - n_buckets_list = [n_buckets_div, n_buckets_rest] - - # Hash vectors. - buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) - - if mask is not None: - n_buckets += 1 # Create an extra bucket for padding tokens only - buckets = np.where(mask[None, :], buckets, n_buckets - 1) - - # buckets is now (n_hashes, seqlen). Next we add offsets so that - # bucket numbers from different hashing rounds don't overlap. 
- offsets = np.arange(self._n_hashes, dtype=np.int32) - offsets = np.reshape(offsets * n_buckets, (-1, 1)) - buckets = np.reshape(buckets + offsets, (-1,)) - return buckets - - def forward_unbatched(self, qk, v, mask=None, *, state, rng, - update_state): - attend_rng, output_rng = fastmath.random.split(rng) # pylint: disable=unused-variable - - # Since these are unbatched: - # q, v are shaped (seqlen, d_head) - # mask is shaped (seqlen,) - q = qk - seqlen = q.shape[0] - - if update_state: - _, old_hash_rng = state - hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) - buckets = self.hash_vectors(q, hash_subrng, mask) - s_buckets = buckets - if self._max_length_for_buckets: - length = self._n_hashes * self._max_length_for_buckets - if buckets.shape[0] < length: - s_buckets = np.concatenate( - [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], - axis=0) - state = (s_buckets, hash_rng) - else: - buckets, _ = state - if self._max_length_for_buckets: - buckets = buckets[:self._n_hashes * seqlen] - - assert int(buckets.shape[0]) == self._n_hashes * seqlen - - ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) - buckets_and_t = seqlen * buckets + (ticker % seqlen) - buckets_and_t = fastmath.stop_gradient(buckets_and_t) - - # Hash-based sort ("s" at the start of variable names means "sorted") - sbuckets_and_t, sticker = fastmath.sort_key_val( - buckets_and_t, ticker, dimension=-1) - _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) - sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) - sticker = fastmath.stop_gradient(sticker) - undo_sort = fastmath.stop_gradient(undo_sort) - - st = (sticker % seqlen) - sq = np.take(q, st, axis=0) - sv = np.take(v, st, axis=0) - - mask_fn = functools.partial(mask_self_attention, causal=self._causal, - exclude_self=True, masked=self._masked) - q_info = st - - assert (mask is not None) == self._masked - kv_info = None - if self._masked: - # mask is a boolean array (True means "is valid 
token") - smask = np.take(mask, st, axis=0) - ones_like_mask = np.ones_like(smask, dtype=np.int32) - kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) - - so, slogits = attend( - sq, k=None, v=sv, - q_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + if update_state: + _, old_hash_rng = state + hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) + buckets = self.hash_vectors(q, hash_subrng, mask) + s_buckets = buckets + if self._max_length_for_buckets: + length = self._n_hashes * self._max_length_for_buckets + if buckets.shape[0] < length: + s_buckets = np.concatenate( + [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], + axis=0, + ) + state = (s_buckets, hash_rng) + else: + buckets, _ = state + if self._max_length_for_buckets: + buckets = buckets[: self._n_hashes * seqlen] + + assert int(buckets.shape[0]) == self._n_hashes * seqlen + + ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = fastmath.stop_gradient(buckets_and_t) + + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = fastmath.sort_key_val( + buckets_and_t, ticker, dimension=-1 + ) + _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) + sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) + sticker = fastmath.stop_gradient(sticker) + undo_sort = fastmath.stop_gradient(undo_sort) + + st = sticker % seqlen + sq = np.take(q, st, axis=0) + sv = np.take(v, st, axis=0) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=True, + masked=self._masked, + ) + q_info = st + + assert (mask is not None) == self._masked + kv_info = None + if self._masked: + # mask is a boolean array (True means "is valid token") + smask = np.take(mask, 
st, axis=0) + ones_like_mask = np.ones_like(smask, dtype=np.int32) + kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) + + so, slogits = attend( + sq, + k=None, + v=sv, + q_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, ) - # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would - # also work, but these helpers include performance optimizations for TPU. - o = permute_via_gather(so, undo_sort, sticker, axis=0) - logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - - if self._n_hashes > 1: - o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) - logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) - probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) - o = np.sum(o * probs, axis=0) - - # assert o.shape == (seqlen, w_v.shape[-1]) - assert o.shape == v.shape - - # TODO(afrozm): Unlike LSHSelfAttention we don't apply output dropout here. - out = o - return out, state - - def _incremental_forward_unbatched(self, qk, v, mask=None, *, - q_start, q_len, - state, rng, update_state): - x = (qk, v) - length = x[0].shape[0] - assert update_state, ( - 'This setting not supported (e.g. no backprop for fast inference)') - if q_len > 1: - if isinstance(q_start, int): - assert q_start == 0, 'Chunks larger than 1 only work at start for now.' 
- if length % self._chunk_len == 0: - x_padded = x - else: - pad_amount = self._chunk_len - (length % self._chunk_len) - x_padded = fastmath.nested_map( - lambda x: np.pad(x, ((0, pad_amount), (0, 0)), mode='constant'), x) - buckets, buckets_idx, hash_rng = state - qk, v = x_padded - buckets_update = self.hash_vectors(qk, hash_rng) - - out, _ = self.forward_unbatched( - qk, v, mask=mask, state=(buckets_update, hash_rng), rng=rng, - update_state=False) - - out = out[:q_len] - buckets = np.reshape(buckets, (self._n_hashes, -1)) - buckets_update = np.reshape( - buckets_update, (self._n_hashes, -1))[:, :q_len] - if q_len > self._predict_mem_len: - buckets_update = buckets_update[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - buckets = fastmath.dynamic_update_slice_in_dim( - buckets, buckets_update, q_start, axis=1) - buckets = np.reshape(buckets, (-1,)) - - return out, (buckets, buckets_idx + q_len, hash_rng) - - # This codepath is for handling one token at a time. - assert q_len == 1 - buckets, buckets_idx, hash_rng = state - - def roll_buckets(buckets): - buckets = np.reshape(buckets, (self._n_hashes, -1)) - new_buckets = np.concatenate( - [buckets, np.zeros((self._n_hashes, self._predict_drop_len), - dtype=buckets.dtype) - ], axis=1) - new_buckets = fastmath.dynamic_slice_in_dim( - new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1) - new_buckets = np.reshape(new_buckets, (-1,)) - return new_buckets - - buckets = fastmath.cond( - pred=buckets_idx > q_start, - true_operand=buckets, - true_fun=roll_buckets, - false_operand=buckets, - false_fun=lambda x: x, - ) + # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would + # also work, but these helpers include performance optimizations for TPU. 
+ o = permute_via_gather(so, undo_sort, sticker, axis=0) + logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) + + if self._n_hashes > 1: + o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) + probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) + o = np.sum(o * probs, axis=0) + + # assert o.shape == (seqlen, w_v.shape[-1]) + assert o.shape == v.shape - attend_rng, unused_output_rng = fastmath.random.split(rng) - - q_range = q_start + np.arange(q_len, dtype=np.int32) - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. - q = np.concatenate([qk[q_range]] * 2, 0) - - q_buckets = self.hash_vectors(q, hash_rng) - q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - - unflattened_buckets = fastmath.dynamic_update_slice_in_dim( - np.reshape(buckets, (self._n_hashes, -1)), - q_buckets, q_start, axis=1) - buckets = np.reshape(unflattened_buckets, (-1,)) - is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) - - assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 - length = qk.shape[0] - arange_seqlen = np.arange(length, dtype=np.int32) - kv_priorities = np.where( - arange_seqlen > (q_start + q_len), - -(length + arange_seqlen), arange_seqlen) - kv_priorities = kv_priorities + length * is_valid_target.astype(np.int32) - _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) - kv_indices = kv_indices[ - -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before):] - assert self._n_chunks_after == 0 - - k = length_normalized(qk[kv_indices]) - v = v[kv_indices] - - mask_fn = functools.partial( - mask_self_attention, causal=True, 
masked=True, exclude_self=True) - q_info = q_start + np.arange(q_len, dtype=np.int32) - kv_info = kv_indices.astype(np.int32) - q_info = q_info.astype(np.int32) - # TODO(kitaev): is it better to mask out attention across buckets? - # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + # TODO(afrozm): Unlike LSHSelfAttention we don't apply output dropout here. + out = o + return out, state + + def _incremental_forward_unbatched( + self, qk, v, mask=None, *, q_start, q_len, state, rng, update_state + ): + x = (qk, v) + length = x[0].shape[0] + assert ( + update_state + ), "This setting not supported (e.g. no backprop for fast inference)" + if q_len > 1: + if isinstance(q_start, int): + assert q_start == 0, "Chunks larger than 1 only work at start for now." + if length % self._chunk_len == 0: + x_padded = x + else: + pad_amount = self._chunk_len - (length % self._chunk_len) + x_padded = fastmath.nested_map( + lambda x: np.pad(x, ((0, pad_amount), (0, 0)), mode="constant"), x + ) + buckets, buckets_idx, hash_rng = state + qk, v = x_padded + buckets_update = self.hash_vectors(qk, hash_rng) + + out, _ = self.forward_unbatched( + qk, + v, + mask=mask, + state=(buckets_update, hash_rng), + rng=rng, + update_state=False, + ) + + out = out[:q_len] + buckets = np.reshape(buckets, (self._n_hashes, -1)) + buckets_update = np.reshape(buckets_update, (self._n_hashes, -1))[:, :q_len] + if q_len > self._predict_mem_len: + buckets_update = buckets_update[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + buckets = fastmath.dynamic_update_slice_in_dim( + buckets, buckets_update, q_start, axis=1 + ) + buckets = np.reshape(buckets, (-1,)) + + return out, (buckets, buckets_idx + q_len, hash_rng) + + # This codepath is for handling one token at a time. 
+ assert q_len == 1 + buckets, buckets_idx, hash_rng = state + + def roll_buckets(buckets): + buckets = np.reshape(buckets, (self._n_hashes, -1)) + new_buckets = np.concatenate( + [ + buckets, + np.zeros( + (self._n_hashes, self._predict_drop_len), dtype=buckets.dtype + ), + ], + axis=1, + ) + new_buckets = fastmath.dynamic_slice_in_dim( + new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1 + ) + new_buckets = np.reshape(new_buckets, (-1,)) + return new_buckets + + buckets = fastmath.cond( + pred=buckets_idx > q_start, + true_operand=buckets, + true_fun=roll_buckets, + false_operand=buckets, + false_fun=lambda x: x, ) - out = o - if q_len == 1: - out = out[:1] - buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) - return out, (buckets, buckets_idx, hash_rng) + attend_rng, unused_output_rng = fastmath.random.split(rng) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. + q_range = q_start + np.arange(q_len, dtype=np.int32) + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. + q = np.concatenate([qk[q_range]] * 2, 0) - Args: - inputs: Layer inputs (subclasses may use different inputs) + q_buckets = self.hash_vectors(q, hash_rng) + q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - Returns: - A tuple (output, new_state). 
- """ - state, rng = self.state, self.rng - - if self._use_reference_code: - raise NotImplementedError( - 'Reference code not implemented for PureLSHSelfAttention') - - output, new_state, unused_input_cotangents = self.forward_and_or_backward( - inputs, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - - # inputs is (qk, v). mask isn't passed in. - # where qk/v are shaped - (batch * n_heads, seq_len, d_head) - - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. - def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) - else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are 
being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. - new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. 
- def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - del output - del state - del kwargs - unused_output, unused_new_state, inputs_grad = self.forward_and_or_backward( - inputs, - new_state, - rng, - output_grad=grad, - compute_output=False, - update_state=False) + unflattened_buckets = fastmath.dynamic_update_slice_in_dim( + np.reshape(buckets, (self._n_hashes, -1)), q_buckets, q_start, axis=1 + ) + buckets = np.reshape(unflattened_buckets, (-1,)) + is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) + + assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 + length = qk.shape[0] + arange_seqlen = np.arange(length, dtype=np.int32) + kv_priorities = np.where( + arange_seqlen > (q_start + q_len), -(length + arange_seqlen), arange_seqlen + ) + kv_priorities = kv_priorities + length * is_valid_target.astype(np.int32) + _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) + kv_indices = kv_indices[ + -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before) : + ] + assert self._n_chunks_after == 0 + + k = length_normalized(qk[kv_indices]) + v = v[kv_indices] + + mask_fn = functools.partial( + mask_self_attention, causal=True, masked=True, exclude_self=True + ) + q_info = q_start + np.arange(q_len, dtype=np.int32) + kv_info = kv_indices.astype(np.int32) + q_info = q_info.astype(np.int32) + # TODO(kitaev): is it better to mask out attention across buckets? 
+ # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - weights_grad = fastmath.nested_map(np.zeros_like, weights) - return inputs_grad, weights_grad + out = o + if q_len == 1: + out = out[:1] + buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) + return out, (buckets, buckets_idx, hash_rng) - def forward_and_or_backward( - self, inputs, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. + Args: + inputs: Layer inputs (subclasses may use different inputs) - Args: - inputs: inputs to the attention layer tuple (qk, v, mask) - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. + Returns: + A tuple (output, new_state). + """ + state, rng = self.state, self.rng - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). 
+ if self._use_reference_code: + raise NotImplementedError( + "Reference code not implemented for PureLSHSelfAttention" + ) - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. 
- # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - if self._masked: - qk, v, mask = inputs - batch_size = mask.shape[0] - else: - qk, v = inputs - mask = None - batch_size = qk.shape[0] // self._n_heads - batch_x_heads, seqlen, d_model = qk.shape + output, new_state, unused_input_cotangents = self.forward_and_or_backward( + inputs, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + + # inputs is (qk, v). mask isn't passed in. + # where qk/v are shaped - (batch * n_heads, seq_len, d_head) + + mem_end, mem, state = state + seqlen = inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. 
+ def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen + else: + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. + new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. 
There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + return True + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + del output + del state + del kwargs + unused_output, unused_new_state, inputs_grad = self.forward_and_or_backward( + inputs, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' + weights_grad = fastmath.nested_map(np.zeros_like, weights) + return inputs_grad, weights_grad - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - assert not compute_grad - - # The input to use_predict_mem is (qk, v) - inputs = (qk, v) - - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - # Reset qk and v to what use_predict_mem/state gave us. 
- qk, v = inputs - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. - n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, addends): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, addends) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads != 1: - raise NotImplementedError( - 'PureLSHSelfAttention is not implemented for n_parallel_heads != 1.') - - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all = loop_val - example_idx = idx // self._n_heads - unused_head_idx = idx % self._n_heads - - s_h = fastmath.nested_map(lambda s: s[idx], state) - - if self._masked: - i_h = (qk[idx], v[idx], mask[example_idx]) - else: - i_h = (qk[idx], v[idx]) - - def forward_fn(i_h): - return forward_unbatched( - *i_h, state=fastmath.stop_gradient(s_h)) - - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, has_aux=True) - ct_h = output_grad[idx] - assert o_h.shape == ct_h.shape - i_ct_h, = backward_fn(ct_h) - else: - o_h, s_h = forward_fn(i_h) - - if compute_output: - o_all = tree_update(o_all, idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) - if compute_grad: - i_ct_all = tree_add(i_ct_all, idx, i_ct_h) - return (o_all, s_all, i_ct_all) - - o_all = s_all = i_ct_all = None - if compute_output: - o_all = np.zeros((batch_x_heads, seqlen, d_model), dtype=v.dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = 
fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - - loop_val = (o_all, s_all, i_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) + def forward_and_or_backward( + self, + inputs, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer tuple (qk, v, mask) + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. 
There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + if self._masked: + qk, v, mask = inputs + batch_size = mask.shape[0] + else: + qk, v = inputs + mask = None + batch_size = qk.shape[0] // self._n_heads + batch_x_heads, seqlen, d_model = qk.shape + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" 
+ + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + assert not compute_grad + + # The input to use_predict_mem is (qk, v) + inputs = (qk, v) + + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + # Reset qk and v to what use_predict_mem/state gave us. + qk, v = inputs + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. + n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, addends): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + addends, + ) - (o_all, s_all, i_ct_all) = loop_val + if compute_grad: + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, 
non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads != 1: + raise NotImplementedError( + "PureLSHSelfAttention is not implemented for n_parallel_heads != 1." 
+ ) + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all = loop_val + example_idx = idx // self._n_heads + unused_head_idx = idx % self._n_heads + + s_h = fastmath.nested_map(lambda s: s[idx], state) + + if self._masked: + i_h = (qk[idx], v[idx], mask[example_idx]) + else: + i_h = (qk[idx], v[idx]) + + def forward_fn(i_h): + return forward_unbatched(*i_h, state=fastmath.stop_gradient(s_h)) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, has_aux=True) + ct_h = output_grad[idx] + assert o_h.shape == ct_h.shape + (i_ct_h,) = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h) + + if compute_output: + o_all = tree_update(o_all, idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, idx, i_ct_h) + return (o_all, s_all, i_ct_all) + + o_all = s_all = i_ct_all = None + if compute_output: + o_all = np.zeros((batch_x_heads, seqlen, d_model), dtype=v.dtype) + if update_state: + s_all = state + if compute_grad: + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + loop_val = (o_all, s_all, i_ct_all) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - return (o_all, s_all, i_ct_all) + (o_all, s_all, i_ct_all) = loop_val + + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) + + return (o_all, s_all, i_ct_all) def 
_ProjectAndSplitHeads( # pylint: disable=invalid-name @@ -3272,522 +3723,632 @@ def _ProjectAndSplitHeads( # pylint: disable=invalid-name num_weights=2, sparsity=16, length_kernel_size=3, - weights_format='sparse', + weights_format="sparse", rotary_position_emb=False, - mode='train'): - """Creates the QK and V activations from input.""" - # There can be either two or three weights: - # two - qk and v or three - q, k, v - # If there are three, we want to average q and k and use that. - - # Weights can also be in 'heads' major format - (n_heads, d_model, d_head) - # this is used by efficient_attention.LSHSelfAttention and - # efficient_attention.SelfAttention - - # Or they can be in 'model' major format - (d_model, d_model), which is what - # tl._attention/CausalAttention etc use -- so use this format if we pretrain a - # model trained with those and finetuning with PureLSHSelfAttention. - - assert weights_format in ('heads', 'model', 'sparse') - - # When an earlier model was trained with 3 separate weights for Q, K, V - # projections with tl._attention/tl._causalAttention etc. - if weights_format == 'model' and num_weights == 3: - return cb.Serial( - # Create the raw Q, K, V projections. - cb.Branch( - core.Dense(d_model, use_bias=use_bias), - core.Dense(d_model, use_bias=use_bias), - core.Dense(d_model, use_bias=use_bias)), # q, k, v - # Optionally, rotate Q and K vectors if rotary embeddings are used. - cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None) - if rotary_position_emb else [], - # Average Q and K into one single QK tensor. - core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1), # qk, v - # Split heads and combine with batch dimension to get two tensors of - # (batch * n_heads, seq_len, d_head) shape. 
- cb.Parallel( - attention.SplitIntoHeads(n_heads), - attention.SplitIntoHeads(n_heads)) # qk, v - ) + mode="train", +): + """Creates the QK and V activations from input.""" + # There can be either two or three weights: + # two - qk and v or three - q, k, v + # If there are three, we want to average q and k and use that. + + # Weights can also be in 'heads' major format - (n_heads, d_model, d_head) + # this is used by efficient_attention.LSHSelfAttention and + # efficient_attention.SelfAttention + + # Or they can be in 'model' major format - (d_model, d_model), which is what + # tl._attention/CausalAttention etc use -- so use this format if we pretrain a + # model trained with those and finetuning with PureLSHSelfAttention. + + assert weights_format in ("heads", "model", "sparse") + + # When an earlier model was trained with 3 separate weights for Q, K, V + # projections with tl._attention/tl._causalAttention etc. + if weights_format == "model" and num_weights == 3: + return cb.Serial( + # Create the raw Q, K, V projections. + cb.Branch( + core.Dense(d_model, use_bias=use_bias), + core.Dense(d_model, use_bias=use_bias), + core.Dense(d_model, use_bias=use_bias), + ), # q, k, v + # Optionally, rotate Q and K vectors if rotary embeddings are used. + cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None) + if rotary_position_emb + else [], + # Average Q and K into one single QK tensor. + core.Fn("QKAvg", lambda x, y: (x + y) / 2.0, n_out=1), # qk, v + # Split heads and combine with batch dimension to get two tensors of + # (batch * n_heads, seq_len, d_head) shape. + cb.Parallel( + attention.SplitIntoHeads(n_heads), attention.SplitIntoHeads(n_heads) + ), # qk, v + ) - if weights_format == 'sparse' and num_weights == 3: - d_module = d_model // sparsity - # This layer matches sparsity.MultiplicativeConvCausalAttention, - # see there for more explanation. - # TODO(lukaszkaiser): unify code so that we don't duplicate so much. 
- return cb.Serial( - cb.Select([0, 0]), # duplicate activations - sp.FactoredDense(sparsity, d_model, d_model), - cb.Select([0, 0, 0]), # use for q, k, v - cb.Parallel( - [sp.LocallyConvDense(sparsity, d_module, mode=mode, - kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - [sp.LocallyConvDense(sparsity, d_module, mode=mode, - kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - [cb.Select([0], n_in=2), - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - ), - core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1), - ) + if weights_format == "sparse" and num_weights == 3: + d_module = d_model // sparsity + # This layer matches sparsity.MultiplicativeConvCausalAttention, + # see there for more explanation. + # TODO(lukaszkaiser): unify code so that we don't duplicate so much. + return cb.Serial( + cb.Select([0, 0]), # duplicate activations + sp.FactoredDense(sparsity, d_model, d_model), + cb.Select([0, 0, 0]), # use for q, k, v + cb.Parallel( + [ + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + [ + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + [ + cb.Select([0], n_in=2), + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + ), + core.Fn("QKAvg", lambda x, y: (x + y) / 2.0, n_out=1), + ) - if weights_format == 'sparse' and num_weights == 2: - d_module = d_model // sparsity - # This layer matches sparsity.MultiplicativeConvCausalAttention, - # see there for more explanation. 
- # TODO(lukaszkaiser): unify code so that we don't duplicate so much. - return cb.Serial( - cb.Select([0, 0]), # pre-qkv, pre-v-for-concat - sp.FactoredDense(sparsity, d_model, d_model), # shared q k - cb.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads), - cb.Parallel( - [], - [cb.Select([0], n_in=2), - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], + if weights_format == "sparse" and num_weights == 2: + d_module = d_model // sparsity + # This layer matches sparsity.MultiplicativeConvCausalAttention, + # see there for more explanation. + # TODO(lukaszkaiser): unify code so that we don't duplicate so much. + return cb.Serial( + cb.Select([0, 0]), # pre-qkv, pre-v-for-concat + sp.FactoredDense(sparsity, d_model, d_model), # shared q k + cb.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + cb.Parallel( + [], + [ + cb.Select([0], n_in=2), + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + ), ) - ) - # We want to train from scratch and have only two weights, w_qk and w_v. - if weights_format == 'model' and num_weights == 2: - return cb.Branch( - [ - core.Dense(d_model, use_bias=use_bias), - rotary_pe.Rotate() if rotary_position_emb else [], - attention.SplitIntoHeads(n_heads) - ], - [ - core.Dense(d_model, use_bias=use_bias), - attention.SplitIntoHeads(n_heads) - ], - ) + # We want to train from scratch and have only two weights, w_qk and w_v. 
+ if weights_format == "model" and num_weights == 2: + return cb.Branch( + [ + core.Dense(d_model, use_bias=use_bias), + rotary_pe.Rotate() if rotary_position_emb else [], + attention.SplitIntoHeads(n_heads), + ], + [core.Dense(d_model, use_bias=use_bias), attention.SplitIntoHeads(n_heads)], + ) - assert weights_format == 'head' + assert weights_format == "head" - raise NotImplementedError('TODO(afrozm): Implement this when we want to use ' - 'checkpoints trained with LSHSelfAttention or ' - 'SelfAttention') + raise NotImplementedError( + "TODO(afrozm): Implement this when we want to use " + "checkpoints trained with LSHSelfAttention or " + "SelfAttention" + ) class MixedLSHSelfAttention(base.Layer): - """LSH attention mixed with standard attention used until std_length.""" - - def __init__(self, - n_heads=1, - d_qk=64, - d_v=64, - causal=False, - masked=False, - std_length=None, - mode='train', - output_dropout=0.0, - attention_dropout=0.0, - force_no_dropout=False, - **pure_lsh_implementation_kwargs): - # This class could be replaced with a Branch and tl.Fn(..) selecting - # one of the arguments based on the class. But, similarly to the Wrapper - # below, we need forward_and_backward currently to pass remembered state - # back to the PureLSH layer. We should switch that to the other Remember - # mechanism used for the SparseFF layer (and clarify and document that too). - # Once this is done, we can remove this and the Wrapper class. 
- attention_dropout = 0.0 if force_no_dropout else attention_dropout - output_dropout = 0.0 if force_no_dropout else output_dropout - self._lsha = PureLSHSelfAttention(n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - mode=mode, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - **pure_lsh_implementation_kwargs) - if causal: - pure_attn = attention.DotProductCausalAttention - preprocess = core.Fn('dup_shared_qk', lambda q, v: (q, q, v), n_out=3) - else: - pure_attn = attention.DotProductAttention - def _add_heads_to_mask(m): - m_with_heads = np.reshape(m, (m.shape[0], 1, m.shape[1])) - m_with_heads = np.broadcast_to(m_with_heads, - (m.shape[0], n_heads, m.shape[1])) - return np.reshape(m_with_heads, (-1, 1, m.shape[1])) - preprocess = core.Fn('dup_shared_qk_and_make_mask', - lambda q, v, m: (q, q, v, _add_heads_to_mask(m)), - n_out=4) - self._stda = cb.Serial( - preprocess, - pure_attn(dropout=attention_dropout, mode=mode) - ) - self._std_length = std_length - self._sublayers = [self._lsha, self._stda] - - if self._stda.n_in != self._lsha.n_in: - raise ValueError(f'n_in diff: {self._stda.n_in} != {self._lsha.n_in}') - if self._stda.n_out != self._lsha.n_out: - raise ValueError(f'n_out diff: {self._stda.n_out} != {self._lsha.n_out}') - super().__init__(n_in=self._stda.n_in, n_out=self._stda.n_out) - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - states = [] - for sublayer in [self._lsha, self._stda]: - unused_weights_or_cache_marker, state_or_cache_marker = sublayer.init( - input_signature, use_cache=False) - states.append(state_or_cache_marker) - self.state = tuple(states) - self.weights = () # Wrapper forward_and_backward assumes this is () - - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - rng1, rng2 = fastmath.random.split(self.rng, 2) - l = xs[0].shape[1] if isinstance(xs, tuple) 
else xs.shape[1] - if self._std_length is None or l > self._std_length: - s = self.state[0] - outputs, s = self._lsha.pure_fn(xs, (), s, rng1, use_cache=True) - self.state = (s, self.state[1]) - else: - s = self.state[1] - w = ((), ()) # std attention is a Serial(Dup, DotProduct), needs 2 () - outputs, s = self._stda.pure_fn(xs, w, s, rng2, use_cache=True) - self.state = (self.state[0], s) - return outputs - - def forward_and_or_backward(self, inputs, state, rng, - output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes.""" - assert compute_output - assert not update_state - - l = inputs[0].shape[1] if isinstance(inputs, tuple) else inputs.shape[1] - rng1, rng2 = fastmath.random.split(rng, 2) - if self._std_length is None or l > self._std_length: - # Run the LSH layer - s = state[0] - (out, unused_new_s, grads_inputs) = self._lsha.forward_and_or_backward( - inputs, s, rng1, output_grad=output_grad, - compute_output=True, update_state=False) - else: - # Run the standard layer - s, w = state[1], ((), ()) - out, std_vjp_fn, unused_new_s = fastmath.vjp( - self._stda.pure_fn, inputs, w, s, rng2, has_aux=True) - if output_grad is not None: - grads_inputs, _, _, _ = std_vjp_fn(output_grad) - else: - grads_inputs = None + """LSH attention mixed with standard attention used until std_length.""" + + def __init__( + self, + n_heads=1, + d_qk=64, + d_v=64, + causal=False, + masked=False, + std_length=None, + mode="train", + output_dropout=0.0, + attention_dropout=0.0, + force_no_dropout=False, + **pure_lsh_implementation_kwargs, + ): + # This class could be replaced with a Branch and tl.Fn(..) selecting + # one of the arguments based on the class. But, similarly to the Wrapper + # below, we need forward_and_backward currently to pass remembered state + # back to the PureLSH layer. We should switch that to the other Remember + # mechanism used for the SparseFF layer (and clarify and document that too). 
+ # Once this is done, we can remove this and the Wrapper class. + attention_dropout = 0.0 if force_no_dropout else attention_dropout + output_dropout = 0.0 if force_no_dropout else output_dropout + self._lsha = PureLSHSelfAttention( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + mode=mode, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + **pure_lsh_implementation_kwargs, + ) + if causal: + pure_attn = attention.DotProductCausalAttention + preprocess = core.Fn("dup_shared_qk", lambda q, v: (q, q, v), n_out=3) + else: + pure_attn = attention.DotProductAttention + + def _add_heads_to_mask(m): + m_with_heads = np.reshape(m, (m.shape[0], 1, m.shape[1])) + m_with_heads = np.broadcast_to( + m_with_heads, (m.shape[0], n_heads, m.shape[1]) + ) + return np.reshape(m_with_heads, (-1, 1, m.shape[1])) + + preprocess = core.Fn( + "dup_shared_qk_and_make_mask", + lambda q, v, m: (q, q, v, _add_heads_to_mask(m)), + n_out=4, + ) + self._stda = cb.Serial( + preprocess, pure_attn(dropout=attention_dropout, mode=mode) + ) + self._std_length = std_length + self._sublayers = [self._lsha, self._stda] + + if self._stda.n_in != self._lsha.n_in: + raise ValueError(f"n_in diff: {self._stda.n_in} != {self._lsha.n_in}") + if self._stda.n_out != self._lsha.n_out: + raise ValueError(f"n_out diff: {self._stda.n_out} != {self._lsha.n_out}") + super().__init__(n_in=self._stda.n_in, n_out=self._stda.n_out) + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + states = [] + for sublayer in [self._lsha, self._stda]: + unused_weights_or_cache_marker, state_or_cache_marker = sublayer.init( + input_signature, use_cache=False + ) + states.append(state_or_cache_marker) + self.state = tuple(states) + self.weights = () # Wrapper forward_and_backward assumes this is () + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + rng1, 
rng2 = fastmath.random.split(self.rng, 2) + l = xs[0].shape[1] if isinstance(xs, tuple) else xs.shape[1] + if self._std_length is None or l > self._std_length: + s = self.state[0] + outputs, s = self._lsha.pure_fn(xs, (), s, rng1, use_cache=True) + self.state = (s, self.state[1]) + else: + s = self.state[1] + w = ((), ()) # std attention is a Serial(Dup, DotProduct), needs 2 () + outputs, s = self._stda.pure_fn(xs, w, s, rng2, use_cache=True) + self.state = (self.state[0], s) + return outputs + + def forward_and_or_backward( + self, + inputs, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes.""" + assert compute_output + assert not update_state + + l = inputs[0].shape[1] if isinstance(inputs, tuple) else inputs.shape[1] + rng1, rng2 = fastmath.random.split(rng, 2) + if self._std_length is None or l > self._std_length: + # Run the LSH layer + s = state[0] + (out, unused_new_s, grads_inputs) = self._lsha.forward_and_or_backward( + inputs, + s, + rng1, + output_grad=output_grad, + compute_output=True, + update_state=False, + ) + else: + # Run the standard layer + s, w = state[1], ((), ()) + out, std_vjp_fn, unused_new_s = fastmath.vjp( + self._stda.pure_fn, inputs, w, s, rng2, has_aux=True + ) + if output_grad is not None: + grads_inputs, _, _, _ = std_vjp_fn(output_grad) + else: + grads_inputs = None - return (out, None, grads_inputs) + return (out, None, grads_inputs) class PureLSHSelfAttentionWrapper(cb.Serial): - """Pure LSH serial.""" - - def __init__(self, - n_heads=1, - d_qk=64, - d_v=64, - causal=False, - masked=False, - output_dropout=0.0, - attention_dropout=0.0, - pure_lsh_implementation=None, - bias=True, - mode='train', - num_weights=3, - sparsity=16, - weights_format='model', - rotary_position_emb=False, - **pure_lsh_implementation_kwargs): - d_model = d_qk * n_heads - self._qkv = _ProjectAndSplitHeads( - d_model, - n_heads, - bias, - num_weights=num_weights, - 
sparsity=sparsity, - weights_format=weights_format, - rotary_position_emb=rotary_position_emb, - mode=mode) - self._attn = pure_lsh_implementation(n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - mode=mode, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - **pure_lsh_implementation_kwargs) - self._merge = attention.MergeHeads(n_heads) - if weights_format != 'sparse': - self._dense = core.Dense(d_model, use_bias=bias) - super().__init__(self._qkv, self._attn, self._merge, self._dense) - else: - self._dense = None - super().__init__(self._qkv, self._attn, self._merge) + """Pure LSH serial.""" + + def __init__( + self, + n_heads=1, + d_qk=64, + d_v=64, + causal=False, + masked=False, + output_dropout=0.0, + attention_dropout=0.0, + pure_lsh_implementation=None, + bias=True, + mode="train", + num_weights=3, + sparsity=16, + weights_format="model", + rotary_position_emb=False, + **pure_lsh_implementation_kwargs, + ): + d_model = d_qk * n_heads + self._qkv = _ProjectAndSplitHeads( + d_model, + n_heads, + bias, + num_weights=num_weights, + sparsity=sparsity, + weights_format=weights_format, + rotary_position_emb=rotary_position_emb, + mode=mode, + ) + self._attn = pure_lsh_implementation( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + mode=mode, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + **pure_lsh_implementation_kwargs, + ) + self._merge = attention.MergeHeads(n_heads) + if weights_format != "sparse": + self._dense = core.Dense(d_model, use_bias=bias) + super().__init__(self._qkv, self._attn, self._merge, self._dense) + else: + self._dense = None + super().__init__(self._qkv, self._attn, self._merge) - def forward_and_or_backward(self, inputs, weights, state, rng, - output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. 
+ def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + assert compute_output + assert not update_state + assert output_grad is not None + + rngs = fastmath.random.split(rng, 4) + # Layer order forward: self._qkv, self._attn, self._merge, self._dense + # Use forward_and_or_backward for attn. + + qkv_output, qkv_vjp_fn, unused_qkv_new_state = fastmath.vjp( + self._qkv.pure_fn, inputs, weights[0], state[0], rngs[0], has_aux=True + ) - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. 
+ attn_output, _, _ = self._attn.forward_and_or_backward( + qkv_output, + state[1], + rngs[1], + output_grad=None, + compute_output=True, + update_state=False, + ) - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - assert compute_output - assert not update_state - assert output_grad is not None - - rngs = fastmath.random.split(rng, 4) - # Layer order forward: self._qkv, self._attn, self._merge, self._dense - # Use forward_and_or_backward for attn. - - qkv_output, qkv_vjp_fn, unused_qkv_new_state = fastmath.vjp( - self._qkv.pure_fn, inputs, weights[0], state[0], - rngs[0], has_aux=True) - - attn_output, _, _ = self._attn.forward_and_or_backward( - qkv_output, state[1], rngs[1], output_grad=None, - compute_output=True, update_state=False) - - merge_output, merge_vjp_fn, unused_merge_new_state = fastmath.vjp( - self._merge.pure_fn, attn_output, weights[2], state[2], rngs[2], - has_aux=True) - - if self._dense is not None: - dense_output, dense_vjp_fn, unused_dense_new_state = fastmath.vjp( - self._dense.pure_fn, merge_output, weights[3], state[3], - rngs[3], has_aux=True) - - # Now backward. - if self._dense is not None: - dense_grads_inputs, dense_grads_weights, _, _ = dense_vjp_fn( - output_grad) - else: - dense_grads_inputs = output_grad - merge_grads_inputs, merge_grads_weights, _, _ = merge_vjp_fn( - dense_grads_inputs) - - # Use forward_and_or_backward for attn. - (attn_output, _, attn_grads_inputs) = self._attn.forward_and_or_backward( - qkv_output, state[1], rngs[1], output_grad=merge_grads_inputs, - compute_output=True, update_state=False) - - # Backward for qkv layer. 
- qkv_grad_inputs, qkv_grads_weights, _, _ = qkv_vjp_fn(attn_grads_inputs) - - if self._dense is None: - grads_weights = (qkv_grads_weights, - (), - merge_grads_weights) - else: - grads_weights = (qkv_grads_weights, - (), - merge_grads_weights, - dense_grads_weights) - - # Output is (output, new_state, inputs_grad, weights_grad). - # new_state is None because update_state is False. - if self._dense is None: - return (merge_output, None, qkv_grad_inputs, grads_weights) - else: - return (dense_output, None, qkv_grad_inputs, grads_weights) + merge_output, merge_vjp_fn, unused_merge_new_state = fastmath.vjp( + self._merge.pure_fn, + attn_output, + weights[2], + state[2], + rngs[2], + has_aux=True, + ) + + if self._dense is not None: + dense_output, dense_vjp_fn, unused_dense_new_state = fastmath.vjp( + self._dense.pure_fn, + merge_output, + weights[3], + state[3], + rngs[3], + has_aux=True, + ) + + # Now backward. + if self._dense is not None: + dense_grads_inputs, dense_grads_weights, _, _ = dense_vjp_fn(output_grad) + else: + dense_grads_inputs = output_grad + merge_grads_inputs, merge_grads_weights, _, _ = merge_vjp_fn(dense_grads_inputs) + + # Use forward_and_or_backward for attn. + (attn_output, _, attn_grads_inputs) = self._attn.forward_and_or_backward( + qkv_output, + state[1], + rngs[1], + output_grad=merge_grads_inputs, + compute_output=True, + update_state=False, + ) + + # Backward for qkv layer. + qkv_grad_inputs, qkv_grads_weights, _, _ = qkv_vjp_fn(attn_grads_inputs) + + if self._dense is None: + grads_weights = (qkv_grads_weights, (), merge_grads_weights) + else: + grads_weights = ( + qkv_grads_weights, + (), + merge_grads_weights, + dense_grads_weights, + ) + + # Output is (output, new_state, inputs_grad, weights_grad). + # new_state is None because update_state is False. 
+ if self._dense is None: + return (merge_output, None, qkv_grad_inputs, grads_weights) + else: + return (dense_output, None, qkv_grad_inputs, grads_weights) class EncDecAttention(EfficientAttentionBase): - """Memory-efficient encoder-decoder attention.""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, - masked=True, - mode='train', - attention_dropout=0.0, - output_dropout=0.0, - n_parallel_heads=None, - use_python_loop=False, - use_reference_code=False, - ): - super().__init__( - n_heads=n_heads, - n_in=(3 if masked else 2), - n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop, - use_reference_code=use_reference_code, + """Memory-efficient encoder-decoder attention.""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + masked=True, + mode="train", + attention_dropout=0.0, + output_dropout=0.0, + n_parallel_heads=None, + use_python_loop=False, + use_reference_code=False, + ): + super().__init__( + n_heads=n_heads, + n_in=(3 if masked else 2), + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + use_reference_code=use_reference_code, ) - self._d_qk = d_qk - self._d_v = d_v - self._masked = masked - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def create_weights_unbatched(self, input_signature, rng): - d_model = input_signature[0].shape[-1] - d_kv_antecedent = input_signature[1].shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - w_k = self._kernel_initializer((d_kv_antecedent, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_kv_antecedent, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - return (w_q, w_k, w_v, w_o) - - def forward_unbatched(self, q_antecedent, kv_antecedent, mask=None, *, - weights, state, rng, update_state): - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_k, w_v, w_o = weights - - q = np.matmul(q_antecedent, w_q) - k = np.matmul(kv_antecedent, w_k) - v = np.matmul(kv_antecedent, w_v) - - if not self._masked: - assert mask is None - q_info = kv_info = mask_fn = None - else: - # mask is a boolean array (True means "is valid token") - assert mask is not None - q_info = None - kv_info = (~mask).astype(np.int32) # pylint: disable=invalid-unary-operand-type - def mask_fn(dots, q_info, kv_info): - del q_info - mask = kv_info.astype(np.float32) - dots = dots - 1e9 * mask - return dots + self._d_qk = d_qk + self._d_v = d_v + self._masked = masked + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout + else: + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. 
+ # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. + lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def create_weights_unbatched(self, input_signature, rng): + d_model = input_signature[0].shape[-1] + d_kv_antecedent = input_signature[1].shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + w_k = self._kernel_initializer((d_kv_antecedent, self._d_qk), rng_k) + w_v = self._kernel_initializer((d_kv_antecedent, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + return (w_q, w_k, w_v, w_o) + + def forward_unbatched( + self, + q_antecedent, + kv_antecedent, + mask=None, + *, + weights, + state, + rng, + update_state, + ): + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_k, w_v, w_o = weights - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + q = np.matmul(q_antecedent, w_q) + k = np.matmul(kv_antecedent, w_k) + v = np.matmul(kv_antecedent, w_v) + + if not self._masked: + assert mask is None + q_info = kv_info = mask_fn = None + else: + # mask is a boolean array (True means "is valid token") + assert mask is not None + q_info = None + kv_info = (~mask).astype( + np.int32 + ) # pylint: disable=invalid-unary-operand-type + + def mask_fn(dots, q_info, kv_info): + del q_info + mask = kv_info.astype(np.float32) + dots = dots - 1e9 * mask + return dots + + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, ) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + out = np.matmul(o, w_o) + out = 
apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state class LSHFF(base.Layer): - """Feed-forward block with LSH. - - The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense - that takes an input, makes it of size d_ff (usually larger than it was) and - then brings it back to the original size after Relu. It is commonly used in - Transformer models where it often accounts for most of the trainable weights. - - The original block can be slow in decoding due to the need to fetch a lot of - weights from memory. The LSH block aims to exploit this sparsity. So in the - first Dense(d_ff) layer, instead of making a full matrix multiplication, - this block only multiplies by the parts of the weights matrix that have - the highest chance to give non-0 after Relu. This is determined by taking - a number of locality-sensitive hashes and masking to only include weights - that have one hash identical to the multiplied element. - """ - - def __init__(self, d_ff, n_buckets, n_hashes=4, mode='train', - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6)): - """Returns a LSH feed-forward block.""" - super().__init__(name=f'LSHFF_{d_ff}') - self._mode = mode - self._d_ff = d_ff - self._n_buckets = n_buckets - self._n_hashes = n_hashes - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input. + """Feed-forward block with LSH. + + The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense + that takes an input, makes it of size d_ff (usually larger than it was) and + then brings it back to the original size after Relu. 
It is commonly used in + Transformer models where it often accounts for most of the trainable weights. + + The original block can be slow in decoding due to the need to fetch a lot of + weights from memory. The LSH block aims to exploit this sparsity. So in the + first Dense(d_ff) layer, instead of making a full matrix multiplication, + this block only multiplies by the parts of the weights matrix that have + the highest chance to give non-0 after Relu. This is determined by taking + a number of locality-sensitive hashes and masking to only include weights + that have one hash identical to the multiplied element. """ - w1, w2, b2 = self.weights - x_shape = x.shape - x = np.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - # Hash x into hash buckets; x_buckets is [n_hashes, joint_batch]. - x_buckets, _ = hash_vecs(x, self._n_buckets, self._n_hashes, self.rng) - - # Hash w1 into hash buckets; w1_buckets is [n_hashes, d_ff]. - # Note that we use the same self.rng - so the same hash vectors as for x. - w1_buckets, _ = hash_vecs(w1, self._n_buckets, self._n_hashes, self.rng) - - # Create a mask to determine which x's have the same hash as which w1's. - # First: just subtract the hashes and make them non-negative. - hash_mask = (x_buckets[:, :, None] - w1_buckets[:, None, :])**2 - hash_mask = fastmath.stop_gradient(hash_mask) # make sure no gradients here - # hash_mask is [n_hashes, joint_batch, d_ff], 0 iff hashes were equal - hash_mask = 1 - np.minimum(hash_mask, 1) # now 1 if equal, 0 otherwise - # we now sum over n_hashes and use min, it's 1 iff any of n_hashes was equal - hash_mask = np.minimum(np.sum(hash_mask, axis=0), 1) - hash_mask = hash_mask.astype(np.float32) # convert to float to use mask - - # First dense layer of the block, with hash masking. - mid = np.dot(x, w1.T) * hash_mask # [joint_batch, d_ff] - - # Relu and the second dense layer, as in a standard feed-forward block. 
- # Note: we merge the second block into this layer because of future plans, - # not anything implemented yet. The potential gain would be as follows: - # in predict mode, we would pre-hash (once) both w1 and w2 and only do - # matmuls (and memory copies) for the parts that correspond to the hash - # of the input. The hash of w1 determines which parts of Relu are 0, so - # it also determines which parts of w2 can be skipped. - relu = np.where(mid <= 0, np.zeros_like(mid), mid) - res = np.dot(relu, w2) + b2 - return np.reshape(res, x_shape) # un-flatten if needed - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - d_model = input_signature.shape[-1] - shape_w1 = (self._d_ff, d_model) - shape_w2 = (self._d_ff, d_model) - shape_b2 = (d_model,) - - rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3) - w1 = self._kernel_initializer(shape_w1, rng_w1) - w2 = self._kernel_initializer(shape_w2, rng_w2) - b2 = self._bias_initializer(shape_b2, rng_b2) - self.weights = (w1, w2, b2) + + def __init__( + self, + d_ff, + n_buckets, + n_hashes=4, + mode="train", + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + ): + """Returns a LSH feed-forward block.""" + super().__init__(name=f"LSHFF_{d_ff}") + self._mode = mode + self._d_ff = d_ff + self._n_buckets = n_buckets + self._n_hashes = n_hashes + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. + """ + w1, w2, b2 = self.weights + x_shape = x.shape + x = np.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. + + # Hash x into hash buckets; x_buckets is [n_hashes, joint_batch]. 
+ x_buckets, _ = hash_vecs(x, self._n_buckets, self._n_hashes, self.rng) + + # Hash w1 into hash buckets; w1_buckets is [n_hashes, d_ff]. + # Note that we use the same self.rng - so the same hash vectors as for x. + w1_buckets, _ = hash_vecs(w1, self._n_buckets, self._n_hashes, self.rng) + + # Create a mask to determine which x's have the same hash as which w1's. + # First: just subtract the hashes and make them non-negative. + hash_mask = (x_buckets[:, :, None] - w1_buckets[:, None, :]) ** 2 + hash_mask = fastmath.stop_gradient(hash_mask) # make sure no gradients here + # hash_mask is [n_hashes, joint_batch, d_ff], 0 iff hashes were equal + hash_mask = 1 - np.minimum(hash_mask, 1) # now 1 if equal, 0 otherwise + # we now sum over n_hashes and use min, it's 1 iff any of n_hashes was equal + hash_mask = np.minimum(np.sum(hash_mask, axis=0), 1) + hash_mask = hash_mask.astype(np.float32) # convert to float to use mask + + # First dense layer of the block, with hash masking. + mid = np.dot(x, w1.T) * hash_mask # [joint_batch, d_ff] + + # Relu and the second dense layer, as in a standard feed-forward block. + # Note: we merge the second block into this layer because of future plans, + # not anything implemented yet. The potential gain would be as follows: + # in predict mode, we would pre-hash (once) both w1 and w2 and only do + # matmuls (and memory copies) for the parts that correspond to the hash + # of the input. The hash of w1 determines which parts of Relu are 0, so + # it also determines which parts of w2 can be skipped. 
+ relu = np.where(mid <= 0, np.zeros_like(mid), mid) + res = np.dot(relu, w2) + b2 + return np.reshape(res, x_shape) # un-flatten if needed + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + d_model = input_signature.shape[-1] + shape_w1 = (self._d_ff, d_model) + shape_w2 = (self._d_ff, d_model) + shape_b2 = (d_model,) + + rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3) + w1 = self._kernel_initializer(shape_w1, rng_w1) + w2 = self._kernel_initializer(shape_w2, rng_w2) + b2 = self._bias_initializer(shape_b2, rng_b2) + self.weights = (w1, w2, b2) diff --git a/trax/layers/research/efficient_attention_test.py b/trax/layers/research/efficient_attention_test.py deleted file mode 100644 index fd4cbc9e3..000000000 --- a/trax/layers/research/efficient_attention_test.py +++ /dev/null @@ -1,441 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.layers.research.efficient_attention.""" - -from absl.testing import parameterized -import jax -import numpy as np -from tensorflow import test - -from trax import fastmath -from trax import shapes -from trax.fastmath import numpy as jnp -from trax.layers.research import efficient_attention - - -class EfficientAttentionTest(test.TestCase, parameterized.TestCase): - - def test_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.SelfAttention( - n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_lsh_ff(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHFF(d_ff=1024*8, n_buckets=[16, 8]) - x = np.ones((3, 7, 1024)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_self_attention_tf(self): - with fastmath.use_backend(fastmath.Backend.TFNP): - layer = efficient_attention.SelfAttention( - n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_lsh_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - 
self.assertEqual(y.shape, x.shape) - - - def _run_forward_and_backward(self, model, inp, weights, state): - def forward(inp, weights): - return model.pure_fn( - inp, weights, state, rng=jax.random.PRNGKey(0)) - out, vjpfun, new_state = jax.vjp(forward, inp, weights, has_aux=True) - inp_grad, weights_grad = vjpfun(fastmath.numpy.ones_like(inp)) - return out, new_state, inp_grad, weights_grad - - def _test_equivalence_to_reference_code( - self, model_cls, inp, input_signature, common_kwargs, *test_kwargs): - ref_model = model_cls(use_reference_code=True, **common_kwargs) - rng = fastmath.random.get_prng(123) - weights, state = ref_model.init(input_signature, rng) - - ref_all = self._run_forward_and_backward(ref_model, inp, weights, state) - ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all - - for kwargs in test_kwargs: - test_model = model_cls(**common_kwargs, **kwargs) - state = test_model.init(input_signature, rng)[1] - test_all = self._run_forward_and_backward(test_model, inp, weights, state) - test_out, test_state, test_inp_grad, test_weights_grad = test_all - - self.assertEqual(jax.tree_structure(ref_out), - jax.tree_structure(test_out)) - self.assertEqual(jax.tree_structure(ref_state), - jax.tree_structure(test_state)) - self.assertEqual(jax.tree_structure(ref_inp_grad), - jax.tree_structure(test_inp_grad)) - self.assertEqual(jax.tree_structure(ref_weights_grad), - jax.tree_structure(test_weights_grad)) - - check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3) - fastmath.nested_map_multiarg(check_close, ref_out, test_out) - fastmath.nested_map_multiarg(check_close, ref_state, test_state) - fastmath.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad) - fastmath.nested_map_multiarg(check_close, ref_weights_grad, - test_weights_grad) - - def test_batching_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=5, 
n_chunks_before=1, n_chunks_after=0, - attention_dropout=0.2, output_dropout=0.1, mode='train', - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_equivalence_to_reference_code( - efficient_attention.SelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def test_batching_lsh_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, causal=True, - chunk_len=5, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - attention_dropout=0.2, output_dropout=0.1, mode='train', - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_equivalence_to_reference_code( - efficient_attention.LSHSelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def _test_fast_inference( - self, model_cls, x, input_signature, common_kwargs, *test_kwargs): - ref_model = model_cls(use_reference_code=True, mode='eval', **common_kwargs) - weights, state = ref_model.init(input_signature) - - ref_out, _ = ref_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - - def get_slice(pytree, i): - def get_slice_for_val(x): - if isinstance(x, shapes.ShapeDtype): - return shapes.ShapeDtype(shape=x.shape[:1] + (1,) + x.shape[2:], - dtype=x.dtype) - else: - return x[:, i:i+1] - return jax.tree_map(get_slice_for_val, pytree) - - seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1] - - for kwargs in test_kwargs: 
- test_model = model_cls(mode='predict', **common_kwargs, **kwargs) - cur_state = test_model.init(get_slice(input_signature, 0))[1] - out = [] - for i in range(seqlen): - cur_out, cur_state = test_model.pure_fn( - get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0)) - out.append(cur_out) - out = jnp.concatenate(out, axis=1) - - self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3) - - def test_fast_inference_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=5, n_chunks_before=1, n_chunks_after=0, - attention_dropout=0.0, output_dropout=0.0, - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_fast_inference( - efficient_attention.SelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def _test_lsh_self_attention_deterministic_given_seed(self, causal=False): - # Once the initialization and the call seeds are pinned down we have - # deterministic output. 
- with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=causal, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - - def get_output(): - _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0)) - return layer(x, rng=jax.random.PRNGKey(1)) - - ys = [get_output() for _ in range(10)] - - self.assertEqual(ys[0].shape, x.shape) - - for y in ys[1:]: - np.testing.assert_array_almost_equal(ys[0], y, decimal=6) - - def test_lsh_determinism_causal(self): - self._test_lsh_self_attention_deterministic_given_seed(causal=True) - - def test_lsh_determinism_non_causal(self): - self._test_lsh_self_attention_deterministic_given_seed(causal=False) - - def test_lsh_self_attention_masked_non_causal(self): - # Test that when the input that is in the masked area changes the attention - # for the un-masked outputs doesn't change, but the masked region does - # change. - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=False, masked=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - - batch = 5 - max_len = 32 - hidden = 8 - - x = np.random.uniform(size=(batch, max_len, hidden)) - mask = np.ones((batch, max_len)).astype(bool) - rngs = jax.random.randint( - jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1) - - # Set some suffix of each mask[b] to 0. - for i in range(batch): - mask[i, rngs[i]:] = 0 - - # Fix rngs and get the output for the LSH layer. - def get_output(x, mask): - xs = [x, mask] - _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0)) - return layer(xs, rng=jax.random.PRNGKey(1)) - - # Get the attention output for masked x. 
- y = get_output(x, mask) - - # Change x, but only in the masked regions. - for i in range(batch): - x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i], hidden)) - - y2 = get_output(x, mask) - - for i in range(batch): - # y and y2 should be identical in the non-masked part. - np.testing.assert_array_almost_equal(y[i, :rngs[i]], y2[i, :rngs[i]], - decimal=6) - - # In the masked out part, they should be different. - self.assertGreater( - np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5) - - @parameterized.named_parameters(('_weights_2', 2), ('_weights_3', 3)) - def test_pure_lsh_wrapper_causal_non_masked(self, num_weights): - with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 5 - batch, seqlen, d_head = 3, 32, 8 - n_hashes = 2 - d_model = n_heads * d_head - layer = efficient_attention.PureLSHSelfAttentionWrapper( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=n_hashes, n_buckets=4, bias=False, - pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, - mode='train', num_weights=num_weights) - - rng = jax.random.PRNGKey(0) - rng, x_rng = jax.random.split(rng) - - input_shape = (batch, seqlen, d_model) - x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) - - inp = x - w, s = layer.init(shapes.signature(inp)) - o = layer(inp) - - # Get the actual weights. - weights = fastmath.tree_leaves(w) - # Assert number of weights is as expected, the extra 1 is for output. - self.assertLen(weights, num_weights + 1) - - # Assert each weight is of the expected shape. - for i in range(num_weights + 1): - self.assertEqual(weights[i].shape, (d_model, d_model)) - - # Test that the output and the input shape match. - self.assertEqual(inp.shape, o.shape) - - # Assert state is the shape expected. 
- state = fastmath.tree_leaves(s) - self.assertLen(state, 2) - # buckets - self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) - # rngs - self.assertEqual(state[1].shape, (batch * n_heads, 2)) - - @parameterized.named_parameters(('_weights_2', 2), ('_weights_3', 3)) - def test_pure_lsh_wrapper_non_causal_masked(self, num_weights): - with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 5 - batch, seqlen, d_head = 3, 32, 8 - num_weights = 2 - n_hashes = 2 - d_model = n_heads * d_head - layer = efficient_attention.PureLSHSelfAttentionWrapper( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=False, masked=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=n_hashes, n_buckets=4, bias=False, - pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, - mode='train', num_weights=num_weights) - - rng = jax.random.PRNGKey(0) - rng, x_rng = jax.random.split(rng) - - input_shape = (batch, seqlen, d_model) - x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) - mask = jnp.ones((batch, seqlen), dtype=jnp.int32) - - inp = (x, mask) - w, s = layer.init(shapes.signature(inp)) - o = layer(inp) - - # Get the actual weights. - weights = fastmath.tree_leaves(w) - # Assert number of weights is as expected, the extra 1 is for output. - self.assertLen(weights, num_weights + 1) - - # Assert each weight is of the expected shape. - for i in range(num_weights + 1): - self.assertEqual(weights[i].shape, (d_model, d_model)) - - # Test that the output and the x's shape match. - self.assertEqual(x.shape, o.shape) - - # Assert state is the shape expected. - state = fastmath.tree_leaves(s) - self.assertLen(state, 2) - # buckets - self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) - # rngs - self.assertEqual(state[1].shape, (batch * n_heads, 2)) - - def test_lsh_and_pure_lsh_self_attention_equivalence(self): - # Given the same weight matrices and random numbers, do these produce the - # same output. 
- with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 4 - d_head = 4 - d_model = n_heads * d_head - pure_lsh_layer = efficient_attention.PureLSHSelfAttention( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=4, n_buckets=8, - use_reference_code=False, - attention_dropout=0.0, - use_python_loop=True, - bias=False, mode='train') - lsh_layer = efficient_attention.LSHSelfAttention( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=4, n_buckets=8, - use_reference_code=False, - attention_dropout=0.0, - use_python_loop=True, - mode='train') - - batch, seqlen = 3, 32 - input_shape = (batch, seqlen, d_model) - - x = jax.random.uniform(jax.random.PRNGKey(0), input_shape, - dtype=jnp.float32) - lsh_layer_input = x - - call_rng = jax.random.PRNGKey(42) - - lsh_layer_weights, lsh_layer_state = lsh_layer.init( - shapes.signature(lsh_layer_input)) - lsh_layer.rng = call_rng - lsh_layer_output = lsh_layer(lsh_layer_input) - - # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head), - # (n_heads, d_head, d_model) - # Abbreviated as - hmn, hmn, hnm - w_qk, w_v, w_o = lsh_layer_weights - - qk = jnp.einsum('blm,hmn->bhln', x, w_qk) - qk = qk.reshape((-1, qk.shape[2], qk.shape[3])) - - v = jnp.einsum('blm,hmn->bhln', x, w_v) - v = v.reshape((-1, v.shape[2], v.shape[3])) - - pure_lsh_layer_input = (qk, v) - _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input)) - pure_lsh_layer.rng = call_rng - pure_lsh_layer.state = lsh_layer_state - pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input) - - # b*h,l,n - pure_lsh_layer_output = pure_lsh_layer_output.reshape( - (batch, -1) + pure_lsh_layer_output.shape[1:]) - pure_lsh_layer_output_projected = ( - jnp.einsum('bhld,hdm->blm', pure_lsh_layer_output, w_o)) - - diff = pure_lsh_layer_output_projected - lsh_layer_output - avg_diff = 
jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff)) - - self.assertLess(avg_diff, 1e-5) - -if __name__ == '__main__': - test.main() diff --git a/trax/layers/research/position_encodings.py b/trax/layers/research/position_encodings.py index b483a8650..6c9abdd73 100644 --- a/trax/layers/research/position_encodings.py +++ b/trax/layers/research/position_encodings.py @@ -17,6 +17,7 @@ import logging import jax +import jax.extend as jex import numpy as np import trax from trax import fastmath @@ -26,528 +27,582 @@ class AxialPositionalEncoding(layer_base.Layer): - """Axial positional encoding.""" - # TODO(kitaev): support variable-length sequences. - - def __init__(self, shape=(64, 64, 3), d_embs=(384, 384, 256), - kernel_initializer=init.RandomNormalInitializer(1.0), - dropout=0.0, dropout_broadcast_dims=(), mode='train'): - super().__init__() - self._kernel_initializer = kernel_initializer - assert len(shape) == len(d_embs) - self._shape = shape - self._d_embs = d_embs - - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if mode == 'train': - self._dropout = dropout - else: - self._dropout = 0.0 - self._dropout_broadcast_dims = dropout_broadcast_dims - self._mode = mode - - def forward(self, inputs): - rng, state = self.rng, self.state - embs = [] - for ax_emb in self.weights: - ax_emb = jnp.broadcast_to( - ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],)) - embs.append(ax_emb) - - if self._mode == 'predict': - assert self._dropout == 0.0 - emb = jnp.concatenate(embs, -1) - emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1])) - emb = fastmath.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1) - self.state = state + inputs.shape[1] - return inputs + emb - elif self._dropout == 0: - # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled) - # leads to memory blow-up on TPU. 
- # emb = jnp.concatenate(embs, -1) - # return inputs + jnp.reshape(emb, inputs.shape), state - return inputs + jnp.concatenate( - [jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],)) - for emb in embs - ], -1) - else: - emb = jnp.concatenate(embs, -1) - noise_shape = list(emb.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape)) - multiplier = keep.astype(inputs.dtype) / keep_prob - return inputs + jnp.reshape(emb * multiplier, inputs.shape) - - def init_weights_and_state(self, input_signature): - d_feature = input_signature.shape[-1] - if sum(self._d_embs) != d_feature: - raise ValueError( - f'sum(self._d_embs) != d_feature: ' - f'sum({self._d_embs}) vs d_feature: {d_feature}') - - rngs = fastmath.random.split(self.rng, len(self._d_embs)) - weights = [] - for ax, (ax_rng, d_emb) in enumerate(zip(rngs, self._d_embs)): - ax_shape = [1] * len(self._shape) - ax_shape[ax] = self._shape[ax] - ax_shape = (1,) + tuple(ax_shape) + (d_emb,) - ax_emb = self._kernel_initializer(ax_shape, ax_rng) - weights.append(ax_emb) - - # State is EMPTY_STATE by default, stays so except for predict mode. - if self._mode == 'predict': - self.state = np.array(0, dtype=np.int32) - self.weights = tuple(weights) + """Axial positional encoding.""" + + # TODO(kitaev): support variable-length sequences. 
+ + def __init__( + self, + shape=(64, 64, 3), + d_embs=(384, 384, 256), + kernel_initializer=init.RandomNormalInitializer(1.0), + dropout=0.0, + dropout_broadcast_dims=(), + mode="train", + ): + super().__init__() + self._kernel_initializer = kernel_initializer + assert len(shape) == len(d_embs) + self._shape = shape + self._d_embs = d_embs + + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._mode = mode + + def forward(self, inputs): + rng, state = self.rng, self.state + embs = [] + for ax_emb in self.weights: + ax_emb = jnp.broadcast_to( + ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],) + ) + embs.append(ax_emb) + + if self._mode == "predict": + assert self._dropout == 0.0 + emb = jnp.concatenate(embs, -1) + emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1])) + emb = fastmath.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1) + self.state = state + inputs.shape[1] + return inputs + emb + elif self._dropout == 0: + # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled) + # leads to memory blow-up on TPU. 
+ # emb = jnp.concatenate(embs, -1) + # return inputs + jnp.reshape(emb, inputs.shape), state + return inputs + jnp.concatenate( + [ + jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],)) + for emb in embs + ], + -1, + ) + else: + emb = jnp.concatenate(embs, -1) + noise_shape = list(emb.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape)) + multiplier = keep.astype(inputs.dtype) / keep_prob + return inputs + jnp.reshape(emb * multiplier, inputs.shape) + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if sum(self._d_embs) != d_feature: + raise ValueError( + f"sum(self._d_embs) != d_feature: " + f"sum({self._d_embs}) vs d_feature: {d_feature}" + ) + + rngs = fastmath.random.split(self.rng, len(self._d_embs)) + weights = [] + for ax, (ax_rng, d_emb) in enumerate(zip(rngs, self._d_embs)): + ax_shape = [1] * len(self._shape) + ax_shape[ax] = self._shape[ax] + ax_shape = (1,) + tuple(ax_shape) + (d_emb,) + ax_emb = self._kernel_initializer(ax_shape, ax_rng) + weights.append(ax_emb) + + # State is EMPTY_STATE by default, stays so except for predict mode. + if self._mode == "predict": + self.state = np.array(0, dtype=np.int32) + self.weights = tuple(weights) class SinCosPositionalEncoding(layer_base.Layer): - """Implements the sin-cos positional encoding.""" - - def __init__(self, add_offset=2048, dropout=0.0, dropout_broadcast_dims=(-2,), - start_from_zero_one_in=2, mode='train'): - """Creates a SinCosPositionalEncoding instance. - - Args: - add_offset: Maximumnumber to add to positions during training. - dropout: Probability of *not* adding positional encoding to a sequence - position. - dropout_broadcast_dims: Axes along which dropout mask values are - broadcast rather than individually set at random. 
- start_from_zero_one_in: how often to start from 0 during training, - every one in that many times (e.g., if 4, then it's 25% of the time). - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - super().__init__() - self._add_offset = add_offset - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if mode == 'train': - self._dropout = dropout - else: - self._dropout = 0.0 - self._dropout_broadcast_dims = dropout_broadcast_dims - self._start_from_zero_one_in = start_from_zero_one_in - self._mode = mode - - def _sincos(self, start, length, d_feature): - """Create the sin-cos tensor of shape [1, length, d_feature].""" - position = jnp.arange(0, length)[:, None] + start - div_term = jnp.exp( - jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature)) - sin = jnp.sin(position * div_term) - cos = jnp.cos(position * div_term) - pe = jnp.concatenate([sin, cos], axis=1) - return pe[None, :, :] # [1, length, d_feature] - - def forward(self, inputs): - """Returns the input activations, with added positional information.""" - if self._mode != 'predict': - x = inputs - length = jnp.shape(x)[1] - if self._mode != 'train': - start = 0 - else: - rng1, rng2 = fastmath.random.split(self.rng, 2) - start = fastmath.random.randint(rng1, (), 0, self._add_offset) - start_from_nonzero = fastmath.random.randint( - rng2, (), 0, self._start_from_zero_one_in) - start_from_nonzero = jnp.minimum(1, start_from_nonzero) - start *= start_from_nonzero - px = self._sincos(start, length, inputs.shape[2]) - if self._dropout == 0: - return x + px - else: - noise_shape = list(px.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(self.rng, keep_prob, - tuple(noise_shape)) - multiplier = keep.astype(x.dtype) / keep_prob - return x + px * multiplier - else: - if self._dropout != 0: - raise ValueError(f'In predict mode, but dropout rate ' - f'({self._dropout}) is not zero.') - - 
# State in this class is only used for fast inference. In that case, - # the model is called with consecutive elements position-by-position. - # This positional encoding layer needs to store the index of the current - # position then and increment it on each call -- that's how state is used - # and updated below. - pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2]) - self.state += inputs.shape[1] - return inputs + pe - - def init_weights_and_state(self, input_signature): - """Randomly initializes the positional encoding vectors. - - Args: - input_signature: `ShapeDtype` instance characterizing the input this - layer should compute on. - """ - if self._mode == 'predict': - self.state = jnp.zeros((), dtype=jnp.int32) + """Implements the sin-cos positional encoding.""" + + def __init__( + self, + add_offset=2048, + dropout=0.0, + dropout_broadcast_dims=(-2,), + start_from_zero_one_in=2, + mode="train", + ): + """Creates a SinCosPositionalEncoding instance. + + Args: + add_offset: Maximumnumber to add to positions during training. + dropout: Probability of *not* adding positional encoding to a sequence + position. + dropout_broadcast_dims: Axes along which dropout mask values are + broadcast rather than individually set at random. + start_from_zero_one_in: how often to start from 0 during training, + every one in that many times (e.g., if 4, then it's 25% of the time). + mode: One of `'train'`, `'eval'`, or `'predict'`. 
+ """ + super().__init__() + self._add_offset = add_offset + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._start_from_zero_one_in = start_from_zero_one_in + self._mode = mode + + def _sincos(self, start, length, d_feature): + """Create the sin-cos tensor of shape [1, length, d_feature].""" + position = jnp.arange(0, length)[:, None] + start + div_term = jnp.exp( + jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature) + ) + sin = jnp.sin(position * div_term) + cos = jnp.cos(position * div_term) + pe = jnp.concatenate([sin, cos], axis=1) + return pe[None, :, :] # [1, length, d_feature] + + def forward(self, inputs): + """Returns the input activations, with added positional information.""" + if self._mode != "predict": + x = inputs + length = jnp.shape(x)[1] + if self._mode != "train": + start = 0 + else: + rng1, rng2 = fastmath.random.split(self.rng, 2) + start = fastmath.random.randint(rng1, (), 0, self._add_offset) + start_from_nonzero = fastmath.random.randint( + rng2, (), 0, self._start_from_zero_one_in + ) + start_from_nonzero = jnp.minimum(1, start_from_nonzero) + start *= start_from_nonzero + px = self._sincos(start, length, inputs.shape[2]) + if self._dropout == 0: + return x + px + else: + noise_shape = list(px.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli( + self.rng, keep_prob, tuple(noise_shape) + ) + multiplier = keep.astype(x.dtype) / keep_prob + return x + px * multiplier + else: + if self._dropout != 0: + raise ValueError( + f"In predict mode, but dropout rate " + f"({self._dropout}) is not zero." + ) + + # State in this class is only used for fast inference. In that case, + # the model is called with consecutive elements position-by-position. 
+ # This positional encoding layer needs to store the index of the current + # position then and increment it on each call -- that's how state is used + # and updated below. + pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2]) + self.state += inputs.shape[1] + return inputs + pe + + def init_weights_and_state(self, input_signature): + """Randomly initializes the positional encoding vectors. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this + layer should compute on. + """ + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) class FixedBasePositionalEncoding(layer_base.Layer): - """Implements fixed-base positional encoding.""" - - def __init__(self, bases=[11, 13, 14, 15], n_digits=8, # pylint: disable=dangerous-default-value - start_from_zero_one_in=2, base_dropout_one_in=100, - mode='train', initializer=init.RandomUniformInitializer(1e-4)): - super().__init__() - self._bases = bases - self._n_digits = n_digits - self._mode = mode - self._initializer = initializer - self._start_from_zero_one_in = start_from_zero_one_in - self._base_dropout_one_in = base_dropout_one_in - - def forward(self, x): - rng = self.rng - base_weights, start_vec = self.weights - batch_size, length = x.shape[0], x.shape[1] - max_pos = min(self._bases)**self._n_digits - rng1, rng2, rng3 = fastmath.random.split(rng, 3) - assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos) - positions = jnp.arange(0, length)[None, :] - # In training we'll randomize starts for better generalization. - # We use the trainable start_vec to compensate and give model a way - # to learn what is the starting position in a sequence. - if self._mode == 'train': - # In 1% of training cases still start from 0 to be exactly as in eval. 
- start_from_nonzero = fastmath.random.randint( - rng1, (batch_size,), 0, self._start_from_zero_one_in) - start_from_nonzero = jnp.minimum(1, start_from_nonzero) - random_start = fastmath.random.randint( - rng2, (batch_size,), 0, max_pos-length) - random_start *= start_from_nonzero - positions += random_start[:, None] - if self._mode == 'predict': - positions += self.state - res = [] - for bn, base in enumerate(self._bases): - pos_embeddings = [] - cur_positions = positions - for i in range(self._n_digits): - cur_indices = jnp.mod(cur_positions, base) - cur_positions = cur_positions // base - s = base_weights[bn][i] - pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s) - embeddings = jnp.concatenate(pos_embeddings, axis=-1) - if self._mode == 'train': - base_dropout = fastmath.random.randint( - rng3, (batch_size,), 0, self._base_dropout_one_in) - base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32) - embeddings *= base_dropout[:, None, None] - res.append(embeddings) - res = sum(res) # Sum embeddings from all bases. - # Add start_vec to the first position only to mark it as starting. - res0 = res[:, 0, :][:, None, :] - start_pos = res0 + start_vec - if self._mode == 'predict': - start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0) - self.state += length # Add input length to state. - res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1) - return x + res - - def init_weights_and_state(self, input_signature): - d_feature = input_signature.shape[-1] - if d_feature % self._n_digits != 0: - raise ValueError( - f'd_feature({d_feature}) % self._n_digits({self._n_digits}) != 0') - d_weight = d_feature // self._n_digits - rng1, rng2 = fastmath.random.split(self.rng, 2) - base_weights = [[self._initializer((1, d_weight), rng) - for rng in fastmath.random.split(rng1, self._n_digits)] - for _ in self._bases] - # Special vector to mark the starting position. 
- start_vec = self._initializer((1, 1, d_feature), rng2) - self.weights = (base_weights, start_vec) - if self._mode == 'predict': - self.state = jnp.zeros((), dtype=jnp.int32) + """Implements fixed-base positional encoding.""" + + def __init__( + self, + bases=[11, 13, 14, 15], + n_digits=8, # pylint: disable=dangerous-default-value + start_from_zero_one_in=2, + base_dropout_one_in=100, + mode="train", + initializer=init.RandomUniformInitializer(1e-4), + ): + super().__init__() + self._bases = bases + self._n_digits = n_digits + self._mode = mode + self._initializer = initializer + self._start_from_zero_one_in = start_from_zero_one_in + self._base_dropout_one_in = base_dropout_one_in + + def forward(self, x): + rng = self.rng + base_weights, start_vec = self.weights + batch_size, length = x.shape[0], x.shape[1] + max_pos = min(self._bases) ** self._n_digits + rng1, rng2, rng3 = fastmath.random.split(rng, 3) + assert length < max_pos, "length (%d) >= max_pos (%d)" % (length, max_pos) + positions = jnp.arange(0, length)[None, :] + # In training we'll randomize starts for better generalization. + # We use the trainable start_vec to compensate and give model a way + # to learn what is the starting position in a sequence. + if self._mode == "train": + # In 1% of training cases still start from 0 to be exactly as in eval. 
+ start_from_nonzero = fastmath.random.randint( + rng1, (batch_size,), 0, self._start_from_zero_one_in + ) + start_from_nonzero = jnp.minimum(1, start_from_nonzero) + random_start = fastmath.random.randint( + rng2, (batch_size,), 0, max_pos - length + ) + random_start *= start_from_nonzero + positions += random_start[:, None] + if self._mode == "predict": + positions += self.state + res = [] + for bn, base in enumerate(self._bases): + pos_embeddings = [] + cur_positions = positions + for i in range(self._n_digits): + cur_indices = jnp.mod(cur_positions, base) + cur_positions = cur_positions // base + s = base_weights[bn][i] + pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s) + embeddings = jnp.concatenate(pos_embeddings, axis=-1) + if self._mode == "train": + base_dropout = fastmath.random.randint( + rng3, (batch_size,), 0, self._base_dropout_one_in + ) + base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32) + embeddings *= base_dropout[:, None, None] + res.append(embeddings) + res = sum(res) # Sum embeddings from all bases. + # Add start_vec to the first position only to mark it as starting. + res0 = res[:, 0, :][:, None, :] + start_pos = res0 + start_vec + if self._mode == "predict": + start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0) + self.state += length # Add input length to state. + res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1) + return x + res + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if d_feature % self._n_digits != 0: + raise ValueError( + f"d_feature({d_feature}) % self._n_digits({self._n_digits}) != 0" + ) + d_weight = d_feature // self._n_digits + rng1, rng2 = fastmath.random.split(self.rng, 2) + base_weights = [ + [ + self._initializer((1, d_weight), rng) + for rng in fastmath.random.split(rng1, self._n_digits) + ] + for _ in self._bases + ] + # Special vector to mark the starting position. 
+ start_vec = self._initializer((1, 1, d_feature), rng2) + self.weights = (base_weights, start_vec) + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray: - """Apply the threefry PRF to an array of inputs. - - This function is vectorized over x. - For threefry_2x32: K = X = uint32[2] - - Args: - key: uint32[2] the key of the PRF - x: uint32[..., 2] the inputs - - Returns: - y: uint32[..., 2] the outputs - """ - if not (key.shape == (2,) and key.dtype == jnp.uint32): - raise TypeError('key must be uint32[2]', key) - if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32): - raise TypeError('x must be uint32[..., 2]', x) - # Threefry-2x32 expects this weird format: - x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten() - y_3f = jax.random.threefry_2x32(key, x_3f) - y = jnp.moveaxis( - jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1) - return y + """Apply the threefry PRF to an array of inputs. + This function is vectorized over x. + For threefry_2x32: K = X = uint32[2] -def threefry_2x32_prange(key, lo: int = 0, hi: int = 2): - """Splits a key into a stream of random keys. - - This uses the little-endian counter mode. - - Args: - key: uint32[2] the key to split - lo: the range to start extracting from - hi: the range to stop extracting from - - Returns: - keys: uint32[hi - lo, 2] the split keys - """ - if not (key.shape == (2,) and key.dtype == jnp.uint32): - raise ValueError('key must be uint32[2]') - if not hi < 2**32: - # You shouldn't really be using more than half the key size anyways. 
- raise NotImplementedError('only 32-bit sizes are supported') - # Create a 64-bit counter: - i_lo = jnp.arange(lo, hi, dtype=jnp.uint32) - i_hi = jnp.zeros_like(i_lo) - i = jnp.stack([i_lo, i_hi], axis=-1) - return threefry_2x32_prf(key, i) - + Args: + key: uint32[2] the key of the PRF + x: uint32[..., 2] the inputs -class InfinitePositionalEncoding(layer_base.Layer): - """Infinite positional encoding.""" + Returns: + y: uint32[..., 2] the outputs + """ + if not (key.shape == (2,) and key.dtype == jnp.uint32): + raise TypeError("key must be uint32[2]", key) + if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32): + raise TypeError("x must be uint32[..., 2]", x) + # Threefry-2x32 expects this weird format: + x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten() + y_3f = jex.random.threefry_2x32(key, x_3f) + y = jnp.moveaxis(jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1) + return y - def __init__( - self, drift=.03, affine=True, transform='any', - time_bin_length=None, - mode='train'): - """Initializes the encoding. - The encoding tries to roughly evenly traverse the latent space. - The recurrence time is dependent on how many bits per dimension you use. +def threefry_2x32_prange(key, lo: int = 0, hi: int = 2): + """Splits a key into a stream of random keys. - There are two parameters to control randomization: - - randomizing the origin every 1/drift steps by letting it drift - - randomizing the origin per call + This uses the little-endian counter mode. Args: - drift: variance in position difference per unit of difference - affine: whether to randomize the origin every call - transform: learnable transform after encoding (any/diag/none) - time_bin_length: Add features AxialPositionalEncoding learns if - TimeBinCausalAttention is the first layer. - bin_length should match TBCA.bin_length - If you set transform='diag', this flag increases your model capacity to - close to transform='any', though it will still train slower. 
- mode: if 'predict', allow evaluating one token at a time - """ - super().__init__() - if transform not in ('any', 'diag', 'none'): - raise ValueError(transform) - self._noise_rng = jax.random.split(jax.random.PRNGKey(234234535))[0] - assert self._noise_rng is not None - self._noise = None - self._drift = drift - self._affine = affine - self._transform = transform - self._time_bin_length = time_bin_length - self._mode = mode - - def _get_noise(self, lo: int, hi: int, depth: int): - """Return pseudorandom noise with shape float[length, depth]. - - Args: - lo: where to start sampling - hi: where to stop sampling - depth: noise depth + key: uint32[2] the key to split + lo: the range to start extracting from + hi: the range to stop extracting from Returns: - noise[lo:hi, :]: the noise, where noise.diff(axis=0) is i.i.d. U(-1,1) + keys: uint32[hi - lo, 2] the split keys """ - if self._noise is None or self._noise.shape[0] < hi: - # Resize the noise: - new_length = 1 - while new_length < hi: - new_length *= 2 - noise = threefry_2x32_prange(self._noise_rng, 0, new_length * depth) - noise = noise.reshape((new_length, depth, 2))[:, :, 0] - # Normalize to [-sqrt(3), sqrt(3)]: - noise = noise.astype(jnp.float32) / np.float32(2**31 - 1) - noise = noise * 3**.5 - # TODO(tying): use multiscale noise for memory-efficient sampling - noise = noise.cumsum(axis=0) - self._noise = noise - assert self._noise.shape[0] >= hi - assert self._noise.shape[1] == depth - return self._noise[lo:hi, :] - - def _get_embeddings(self, lo: int, hi: int, depth, rng=None): - """Get embeddings float[length, depth]. + if not (key.shape == (2,) and key.dtype == jnp.uint32): + raise ValueError("key must be uint32[2]") + if not hi < 2**32: + # You shouldn't really be using more than half the key size anyways. 
+ raise NotImplementedError("only 32-bit sizes are supported") + # Create a 64-bit counter: + i_lo = jnp.arange(lo, hi, dtype=jnp.uint32) + i_hi = jnp.zeros_like(i_lo) + i = jnp.stack([i_lo, i_hi], axis=-1) + return threefry_2x32_prf(key, i) - Args: - lo: where to start sampling - hi: where to stop sampling - depth: embedding depth - rng: rng for random phase - Returns: - embeddings: float[length, depth] - """ - noise = self._get_noise(lo, hi, (depth + 1) // 2) - # Make the stddev around 1 after 1/drift. - noise = noise * self._drift**.5 - - t, c = np.mgrid[lo:hi, :depth] - # Make even channels cos, odd channels sin: - c_div_2, c_mod_2 = divmod(c, 2) - # Off-by-one correction for odd depth: - drift = self._drift - if depth > 2: - drift = drift**(((depth+1)//2)/(depth//2)) - # Spend roughly half the frequencies on noise: - freq = jnp.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2] - cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4 - assert cycles.shape == (hi - lo, depth), cycles.shape - - # Get random phases: - if self._affine: - assert rng is not None - cycles = cycles + trax.fastmath.random.uniform( - rng, (1, depth,), minval=0, maxval=1) - - # Convert from cycles to radians: - embeddings = jnp.cos(jnp.pi * 2 * cycles) - - # Set the last channels to the time bin features: - if self._time_bin_length is not None: - inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length) - bin_parity = inter_bin_idx % 2 - bin_fraction = intra_bin_idx / self._time_bin_length - embeddings = jnp.concatenate( - [ - embeddings[:, :-3], - 1 / (1 + inter_bin_idx), - bin_fraction, - bin_parity.astype(jnp.float32), - ], -1) - - assert embeddings.shape == (hi - lo, depth), embeddings.shape - return embeddings - - def forward(self, inputs): - rng, state = self.rng, self.state - d_feature = inputs.shape[-1] - input_len = inputs.shape[-2] - - if self._mode == 'predict': - # Assume all the positions are pretty close to each other. 
- index, predict_rng = state - lo = index.min() - hi = index.max() + 1 - emb = self._get_embeddings(lo=lo, hi=hi, depth=d_feature, rng=predict_rng) - emb = emb[index - lo, jnp.newaxis, :] - index = index + 1 - state = index, predict_rng - else: - emb = self._get_embeddings(lo=0, hi=input_len, depth=d_feature, rng=rng) - emb = emb[jnp.newaxis, :input_len, :] - # TODO(tying): check that XLA swaps matmul(slice(x)) -> slice(matmul(x)), - # or inline this code into get_embeddings/get_noise - if self._transform == 'diag': - emb = emb * jax.nn.softplus(self.weights) - elif self._transform == 'any': - emb = emb @ self.weights - self.state = state - return inputs + emb - - def init_weights_and_state(self, input_signature): - d_feature = input_signature.shape[-1] - if self._transform == 'diag': - # Initialize it to a small value because JAX has a bug in softplus. - scale_isoftplus = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4 - weights = scale_isoftplus - elif self._transform == 'any': - ortho = trax.layers.initializers.OrthogonalInitializer() - weights = ortho((d_feature, d_feature), self.rng) - else: - weights = layer_base.EMPTY_WEIGHTS - if self._mode == 'predict': - batch_size = input_signature.shape[0] - self.state = jnp.zeros((batch_size,), dtype=jnp.int32), self.rng - self.weights = weights +class InfinitePositionalEncoding(layer_base.Layer): + """Infinite positional encoding.""" + + def __init__( + self, + drift=0.03, + affine=True, + transform="any", + time_bin_length=None, + mode="train", + ): + """Initializes the encoding. + + The encoding tries to roughly evenly traverse the latent space. + The recurrence time is dependent on how many bits per dimension you use. 
+ + There are two parameters to control randomization: + - randomizing the origin every 1/drift steps by letting it drift + - randomizing the origin per call + + Args: + drift: variance in position difference per unit of difference + affine: whether to randomize the origin every call + transform: learnable transform after encoding (any/diag/none) + time_bin_length: Add features AxialPositionalEncoding learns if + TimeBinCausalAttention is the first layer. + bin_length should match TBCA.bin_length + If you set transform='diag', this flag increases your model capacity to + close to transform='any', though it will still train slower. + mode: if 'predict', allow evaluating one token at a time + """ + super().__init__() + if transform not in ("any", "diag", "none"): + raise ValueError(transform) + self._noise_rng = jax.random.split(jax.random.PRNGKey(234234535))[0] + assert self._noise_rng is not None + self._noise = None + self._drift = drift + self._affine = affine + self._transform = transform + self._time_bin_length = time_bin_length + self._mode = mode + + def _get_noise(self, lo: int, hi: int, depth: int): + """Return pseudorandom noise with shape float[length, depth]. + + Args: + lo: where to start sampling + hi: where to stop sampling + depth: noise depth + + Returns: + noise[lo:hi, :]: the noise, where noise.diff(axis=0) is i.i.d. 
U(-1,1) + """ + if self._noise is None or self._noise.shape[0] < hi: + # Resize the noise: + new_length = 1 + while new_length < hi: + new_length *= 2 + noise = threefry_2x32_prange(self._noise_rng, 0, new_length * depth) + noise = noise.reshape((new_length, depth, 2))[:, :, 0] + # Normalize to [-sqrt(3), sqrt(3)]: + noise = noise.astype(jnp.float32) / np.float32(2**31 - 1) + noise = noise * 3**0.5 + # TODO(tying): use multiscale noise for memory-efficient sampling + noise = noise.cumsum(axis=0) + self._noise = noise + assert self._noise.shape[0] >= hi + assert self._noise.shape[1] == depth + return self._noise[lo:hi, :] + + def _get_embeddings(self, lo: int, hi: int, depth, rng=None): + """Get embeddings float[length, depth]. + + Args: + lo: where to start sampling + hi: where to stop sampling + depth: embedding depth + rng: rng for random phase + + Returns: + embeddings: float[length, depth] + """ + noise = self._get_noise(lo, hi, (depth + 1) // 2) + # Make the stddev around 1 after 1/drift. 
+ noise = noise * self._drift**0.5 + + t, c = np.mgrid[lo:hi, :depth] + # Make even channels cos, odd channels sin: + c_div_2, c_mod_2 = divmod(c, 2) + # Off-by-one correction for odd depth: + drift = self._drift + if depth > 2: + drift = drift ** (((depth + 1) // 2) / (depth // 2)) + # Spend roughly half the frequencies on noise: + freq = jnp.geomspace(0.5, 0.5 * drift**2, num=(depth + 1) // 2)[c_div_2] + cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4 + assert cycles.shape == (hi - lo, depth), cycles.shape + + # Get random phases: + if self._affine: + assert rng is not None + cycles = cycles + trax.fastmath.random.uniform( + rng, + ( + 1, + depth, + ), + minval=0, + maxval=1, + ) + + # Convert from cycles to radians: + embeddings = jnp.cos(jnp.pi * 2 * cycles) + + # Set the last channels to the time bin features: + if self._time_bin_length is not None: + inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length) + bin_parity = inter_bin_idx % 2 + bin_fraction = intra_bin_idx / self._time_bin_length + embeddings = jnp.concatenate( + [ + embeddings[:, :-3], + 1 / (1 + inter_bin_idx), + bin_fraction, + bin_parity.astype(jnp.float32), + ], + -1, + ) + + assert embeddings.shape == (hi - lo, depth), embeddings.shape + return embeddings + + def forward(self, inputs): + rng, state = self.rng, self.state + d_feature = inputs.shape[-1] + input_len = inputs.shape[-2] + + if self._mode == "predict": + # Assume all the positions are pretty close to each other. 
+ index, predict_rng = state + lo = index.min() + hi = index.max() + 1 + emb = self._get_embeddings(lo=lo, hi=hi, depth=d_feature, rng=predict_rng) + emb = emb[index - lo, jnp.newaxis, :] + index = index + 1 + state = index, predict_rng + else: + emb = self._get_embeddings(lo=0, hi=input_len, depth=d_feature, rng=rng) + emb = emb[jnp.newaxis, :input_len, :] + # TODO(tying): check that XLA swaps matmul(slice(x)) -> slice(matmul(x)), + # or inline this code into get_embeddings/get_noise + if self._transform == "diag": + emb = emb * jax.nn.softplus(self.weights) + elif self._transform == "any": + emb = emb @ self.weights + self.state = state + return inputs + emb + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if self._transform == "diag": + # Initialize it to a small value because JAX has a bug in softplus. + scale_isoftplus = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4 + weights = scale_isoftplus + elif self._transform == "any": + ortho = trax.layers.initializers.OrthogonalInitializer() + weights = ortho((d_feature, d_feature), self.rng) + else: + weights = layer_base.EMPTY_WEIGHTS + if self._mode == "predict": + batch_size = input_signature.shape[0] + self.state = jnp.zeros((batch_size,), dtype=jnp.int32), self.rng + self.weights = weights class TimeBinPositionalEncoding(layer_base.Layer): - """Just the engineered features from InfinitePositionalEncoding.""" - num_features = 3 - - def __init__(self, time_bin_length, mode='train'): - """Initializes the encoding. - - Args: - time_bin_length: TimeBinCausalAttention.bin_length of the first layer. - mode: if 'predict', allow evaluating one token at a time - """ - super().__init__() - self._time_bin_length = time_bin_length - self._mode = mode - - def _get_embeddings(self, t): - """Get embeddings float[..., num_features]. - - Args: - t: int[...] position (i.e. 
jnp.arange(..., jnp.int32)) - - Returns: - embeddings: float[..., num_features] - """ - inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length) - bin_parity = inter_bin_idx % 2 - bin_fraction = intra_bin_idx / self._time_bin_length - embeddings = jnp.stack([ - 1 / (1 + inter_bin_idx), - bin_fraction, - bin_parity.astype(jnp.float32), - ], -1) - - assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape - return embeddings - - def forward(self, inputs): - state = self.state - depth = inputs.shape[-1] - - if self._mode == 'predict': - emb = self._get_embeddings(t=state) - emb = emb[:, jnp.newaxis, :] - state = state + 1 - else: - input_len = inputs.shape[-2] - emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32)) - # Leave batch axis as 1 for broadcasting: - emb = emb[jnp.newaxis, :, :] - emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,)) - - # Replace the last num_features channels of input. - inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1) - if inputs.shape[-1] > depth: - logging.warning( - 'dropping feature(s): %d down to %d', inputs.shape[-1], depth) - inputs = inputs[..., -depth:] - - assert inputs.shape[-1] == depth, inputs.shape - self.state = state - return inputs - - def init_weights_and_state(self, input_signature): - if self._mode == 'predict': - batch_size = input_signature.shape[0] - self.state = jnp.zeros((batch_size,), dtype=jnp.int32) + """Just the engineered features from InfinitePositionalEncoding.""" + + num_features = 3 + + def __init__(self, time_bin_length, mode="train"): + """Initializes the encoding. + + Args: + time_bin_length: TimeBinCausalAttention.bin_length of the first layer. + mode: if 'predict', allow evaluating one token at a time + """ + super().__init__() + self._time_bin_length = time_bin_length + self._mode = mode + + def _get_embeddings(self, t): + """Get embeddings float[..., num_features]. + + Args: + t: int[...] position (i.e. 
jnp.arange(..., jnp.int32)) + + Returns: + embeddings: float[..., num_features] + """ + inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length) + bin_parity = inter_bin_idx % 2 + bin_fraction = intra_bin_idx / self._time_bin_length + embeddings = jnp.stack( + [ + 1 / (1 + inter_bin_idx), + bin_fraction, + bin_parity.astype(jnp.float32), + ], + -1, + ) + + assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape + return embeddings + + def forward(self, inputs): + state = self.state + depth = inputs.shape[-1] + + if self._mode == "predict": + emb = self._get_embeddings(t=state) + emb = emb[:, jnp.newaxis, :] + state = state + 1 + else: + input_len = inputs.shape[-2] + emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32)) + # Leave batch axis as 1 for broadcasting: + emb = emb[jnp.newaxis, :, :] + emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,)) + + # Replace the last num_features channels of input. + inputs = jnp.concatenate([inputs[..., : -self.num_features], emb], -1) + if inputs.shape[-1] > depth: + logging.warning( + "dropping feature(s): %d down to %d", inputs.shape[-1], depth + ) + inputs = inputs[..., -depth:] + + assert inputs.shape[-1] == depth, inputs.shape + self.state = state + return inputs + + def init_weights_and_state(self, input_signature): + if self._mode == "predict": + batch_size = input_signature.shape[0] + self.state = jnp.zeros((batch_size,), dtype=jnp.int32) diff --git a/trax/layers/research/position_encodings_test.py b/trax/layers/research/position_encodings_test.py deleted file mode 100644 index f59cbc592..000000000 --- a/trax/layers/research/position_encodings_test.py +++ /dev/null @@ -1,100 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.research.position_encodings.""" - -import functools -import absl.testing.absltest as unittest -import numpy as np -import parameterized - -from trax import fastmath -import trax.layers.research.position_encodings as pe - - -@parameterized.parameterized_class([ - # {'Encoding': pe.FixedBasePositionalEncoding}, - {'Encoding': pe.InfinitePositionalEncoding}, - {'Encoding': functools.partial( - pe.InfinitePositionalEncoding, affine=False)}, - {'Encoding': functools.partial( - pe.TimeBinPositionalEncoding, time_bin_length=5)}, -]) -class PositionEncodingsTest(unittest.TestCase): - """Position encodings conform to the position encodings protocol.""" - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 1, 8), # short - (1, 100, 1), # narrow - (2, 100, 8), # batched - ]) - def test_training(self, n, t, c): - encoding = self.Encoding() - input_ntc = np.random.randn(n, t, c) - encoding.init(input_ntc) - output_ntc = encoding(input_ntc) - self.assertEqual(output_ntc.shape, input_ntc.shape) - self.assertTrue(np.not_equal(output_ntc, input_ntc).any()) - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 100, 1), # narrow - (2, 100, 8), # batched - ]) - def test_inference(self, n, t, c): - # Get the eval mode outputs: - encoding = self.Encoding(mode='eval') - input_ntc = np.random.randn(n, t, c) - rng = fastmath.random.get_prng(1234) - encoding.init(input_ntc, rng=rng) - output_ntc = encoding(input_ntc) - - is_random = self.Encoding == pe.InfinitePositionalEncoding - - # Get the predict mode outputs: - 
encoding_pred = self.Encoding(mode='predict') - encoding_pred.init(input_ntc[:, 0:1, :], rng=rng) - output_ntc0 = encoding_pred(input_ntc[:, 0:1, :]) - if not is_random: - np.testing.assert_allclose(output_ntc0, output_ntc[:, 0:1, :], atol=1e-4) - - output_ntc1 = encoding_pred(input_ntc[:, 1:2, :]) - if not is_random: - np.testing.assert_allclose(output_ntc1, output_ntc[:, 1:2, :], atol=1e-4) - - output_ntc2 = encoding_pred(input_ntc[:, 2:3, :]) - if not is_random: - np.testing.assert_allclose(output_ntc2, output_ntc[:, 2:3, :], atol=1e-4) - - -class SinCosEncodingsTest(unittest.TestCase): - """Position encodings conform to the position encodings protocol.""" - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 1, 8), # short - (2, 100, 8), # batched - ]) - def test_training(self, n, t, c): - encoding = pe.SinCosPositionalEncoding() - input_ntc = np.random.randn(n, t, c) - encoding.init(input_ntc) - output_ntc = encoding(input_ntc) - self.assertEqual(output_ntc.shape, input_ntc.shape) - - -if __name__ == '__main__': - unittest.main() diff --git a/trax/layers/research/rel_attention.py b/trax/layers/research/rel_attention.py index 19b25240d..ddf66f8b8 100644 --- a/trax/layers/research/rel_attention.py +++ b/trax/layers/research/rel_attention.py @@ -38,470 +38,495 @@ # pylint: disable=invalid-name -def RelativeAttentionWrapper(d_feature, - n_heads=1, - dropout=0.0, - max_inference_length=2048, - mode='train', - context_bias_layer=None, - location_bias_layer=None, - total_pooling=None): - """Relative attention wrapper. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size ``d_feature / - n_heads``. - dropout: dropout rate. - max_inference_length: max inference length. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - context_bias_layer: context bias layer. 
- location_bias_layer: location bias layer. - total_pooling: total pooling. - - Returns: - relative attention layer. - - Relative attention wrapper for compatibility with configurable attention, - so that it can be called by `ApplyAttentionLayer`. - """ - del max_inference_length - - attention = RelativeAttentionLMLayer( - d_feature, - context_bias_layer, - location_bias_layer, - total_pooling, - n_heads=n_heads, - dropout=dropout, - mode=mode) - - return cb.Serial(cb.Select([0, 0, 0]), attention) +def RelativeAttentionWrapper( + d_feature, + n_heads=1, + dropout=0.0, + max_inference_length=2048, + mode="train", + context_bias_layer=None, + location_bias_layer=None, + total_pooling=None, +): + """Relative attention wrapper. + Args: + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size ``d_feature / + n_heads``. + dropout: dropout rate. + max_inference_length: max inference length. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + context_bias_layer: context bias layer. + location_bias_layer: location bias layer. + total_pooling: total pooling. 
-def get_rel_att_inputs(d_model, n_heads): - """Global relative attentions bias initialization shared across layers.""" - assert d_model % n_heads == 0 and d_model % 2 == 0 - d_head = d_model // n_heads - - bias_initializer = init.RandomNormalInitializer(1e-6) - context_bias_layer = core.Weights( - bias_initializer, shape=(1, n_heads, 1, d_head)) - location_bias_layer = core.Weights( - bias_initializer, shape=(1, n_heads, 1, d_head)) - return context_bias_layer, location_bias_layer - - -@assert_shape('bSq,blk,blv,b1xl->bSd,b1xl') -def RelativeAttentionLayer(d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls, - n_heads=1, - dropout=0.0, - mode='train'): - """Returns a layer that maps (q, k, v, masks) to (activations, masks). - - When number of keys is smaller than number of queries layer works in O(q^2*d). - Otherwise it is O(q*k*d). That is because we need to shift relative distances - by current_pooling. When we upsample this is current pooling is a fraction < 1 - Visual explanation: - [01][23][45][67] -> [0][1][2][3][4][5][6][7] - For token [0] we calculate relative distances as follows: - * 0 2 4 6 - However for token [1] we need relative distances changed by 1, specifically: - * -1 1 3 5 - So we not only need to calculate the distances that corresponds to spacing - between the keys but also for the ones in between because there are more than - one query tokens (on different positions which means different relative - distances) for single key token. - - Args: - d_feature: Depth/dimensionality of feature embedding. - context_bias_layer: Global context bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers - location_bias_layer: Global location bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers. 
- total_kv_pooling: Accumulated pool size of keys/values used at this layer - separate_cls: True/False if we separate_cls in calculations. - - n_heads: Number of attention heads. - dropout: Probabilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - return cb.Serial( - cb.Branch( - PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling), - cb.Select([0]), cb.Select([1])), - cb.Parallel( - core.Dense(d_feature), - core.Dense(d_feature), - core.Dense(d_feature), - core.Dense(d_feature), - ), - context_bias_layer, - location_bias_layer, - RelativeAttention( # pylint: disable=no-value-for-parameter - separate_cls=separate_cls, - n_heads=n_heads, - dropout=dropout, - mode=mode), - core.Dense(d_feature), - ) - - -@assert_shape('bSq,blk,blv->bSd') -def RelativeAttentionLMLayer(d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls=False, - n_heads=1, - dropout=0.0, - mode='train'): - """Returns a layer that maps (q, k, v) to (activations). - - Same as standard Relative attention layer but additionally based on sizes - of queries and keys prepares a mask that masks out the future. - Masking the future is the concept primarily used for Language Modelling. - Args: - d_feature: Depth/dimensionality of feature embedding. - context_bias_layer: Global context bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers - location_bias_layer: Global location bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers. - total_kv_pooling: Accumulated pool size of keys/values used at this layer. - separate_cls: True/False if we separate_cls in calculations. - n_heads: Number of attention heads. 
- dropout: Probabilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - attention = RelativeAttentionLayer( - d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls, - n_heads=n_heads, - dropout=dropout, - mode=mode) - - return cb.Serial( - CreateAttentionMaskLayer(), # q, k, v, mask - attention, # vecs, mask - cb.Select([0], n_in=2), # vecs - ) + Returns: + relative attention layer. + Relative attention wrapper for compatibility with configurable attention, + so that it can be called by `ApplyAttentionLayer`. + """ + del max_inference_length + + attention = RelativeAttentionLMLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_pooling, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) + + return cb.Serial(cb.Select([0, 0, 0]), attention) -class RelativeAttention(base.Layer): - """Relative attention layer. - - Layer that maps (location_bias, context_bias, pos_emb, q, k, v, mask) - to (activations, mask). - This layer type performs the inner workings of one pass of multi-head - self-attention. It: - - splits queries, keys, and values into multiple 'heads', - - splits positional embeddings into multiple 'heads', - - computes per-head attention weights from per-head (queries, keys), - - applies mask to screen out positions that come from padding tokens, - - [in `'train'` mode] applies dropout to attention weights, - - uses attention weights to combine per-head values vectors, and - - merges per-head results into outgoing activations matching original input - activation vector shapes. - """ - - def __init__(self, separate_cls, n_heads=1, dropout=0.0, mode='train'): - """Returns a new PureAttention instance. 
+ +def get_rel_att_inputs(d_model, n_heads): + """Global relative attentions bias initialization shared across layers.""" + assert d_model % n_heads == 0 and d_model % 2 == 0 + d_head = d_model // n_heads + + bias_initializer = init.RandomNormalInitializer(1e-6) + context_bias_layer = core.Weights(bias_initializer, shape=(1, n_heads, 1, d_head)) + location_bias_layer = core.Weights(bias_initializer, shape=(1, n_heads, 1, d_head)) + return context_bias_layer, location_bias_layer + + +@assert_shape("bSq,blk,blv,b1xl->bSd,b1xl") +def RelativeAttentionLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls, + n_heads=1, + dropout=0.0, + mode="train", +): + """Returns a layer that maps (q, k, v, masks) to (activations, masks). + + When number of keys is smaller than number of queries layer works in O(q^2*d). + Otherwise it is O(q*k*d). That is because we need to shift relative distances + by current_pooling. When we upsample this is current pooling is a fraction < 1 + Visual explanation: + [01][23][45][67] -> [0][1][2][3][4][5][6][7] + For token [0] we calculate relative distances as follows: + * 0 2 4 6 + However for token [1] we need relative distances changed by 1, specifically: + * -1 1 3 5 + So we not only need to calculate the distances that corresponds to spacing + between the keys but also for the ones in between because there are more than + one query tokens (on different positions which means different relative + distances) for single key token. Args: + d_feature: Depth/dimensionality of feature embedding. + context_bias_layer: Global context bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers + location_bias_layer: Global location bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. 
+ total_kv_pooling: Accumulated pool size of keys/values used at this layer separate_cls: True/False if we separate_cls in calculations. + n_heads: Number of attention heads. - dropout: Probabilistic rate for dropout applied to attention strengths - (based on query-key pairs) before applying them to values. + dropout: Probabilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. mode: One of `'train'`, `'eval'`, or `'predict'`. """ - super().__init__(n_in=7, n_out=2) - self._separate_cls = separate_cls - self._n_heads = n_heads - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed activations and unmodified mask. + return cb.Serial( + cb.Branch( + PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling), + cb.Select([0]), + cb.Select([1]), + ), + cb.Parallel( + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + ), + context_bias_layer, + location_bias_layer, + RelativeAttention( # pylint: disable=no-value-for-parameter + separate_cls=separate_cls, n_heads=n_heads, dropout=dropout, mode=mode + ), + core.Dense(d_feature), + ) + + +@assert_shape("bSq,blk,blv->bSd") +def RelativeAttentionLMLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls=False, + n_heads=1, + dropout=0.0, + mode="train", +): + """Returns a layer that maps (q, k, v) to (activations). + + Same as standard Relative attention layer but additionally based on sizes + of queries and keys prepares a mask that masks out the future. + Masking the future is the concept primarily used for Language Modelling. Args: - inputs: A (location_bias, context_bias, pos_emb, q, k, v, mask) tuple. + d_feature: Depth/dimensionality of feature embedding. + context_bias_layer: Global context bias from Transformer XL's attention. 
+ There should be one such layer shared for all relative attention layers + location_bias_layer: Global location bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. + total_kv_pooling: Accumulated pool size of keys/values used at this layer. + separate_cls: True/False if we separate_cls in calculations. + n_heads: Number of attention heads. + dropout: Probabilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + mode: One of `'train'`, `'eval'`, or `'predict'`. """ - location_bias, context_bias, pos_emb, q, k, v, mask = inputs - - d_feature = q.shape[-1] - n_heads = self._n_heads - if d_feature % n_heads != 0: - raise ValueError( - f'Dimensionality of feature embedding ({d_feature}) is not a ' - f'multiple of the requested number of attention heads ({n_heads}).') - - per_head_results, dots = DotProductAttention( - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), - pos_emb.reshape((-1, n_heads, d_feature // n_heads)), - context_bias, - location_bias, - mask, - separate_cls=self._separate_cls, - dropout=self._dropout, - mode=self._mode, - rng=self.rng) - if self._mode == 'viz': - self.state = dots - merged_results = MergeHeads( - n_heads, merged_batch_and_head=False).forward(per_head_results) - return merged_results, mask - - -def DotProductAttention(queries, keys, values, pos_emb, context_bias, - location_bias, mask, separate_cls, dropout, mode, rng): - """Computes new activations via masked attention-weighted sum of values. - - Args: - queries: Per-head activations representing attention queries. - keys: Per-head activations representing attention keys. - values: Per-head activations to be combined by computed attention weights. 
- pos_emb: Per-head activations representing positional embeddings. - context_bias: Global context bias from Transformer XL's attention. - location_bias: Global location bias from Transformer XL's attention. - mask: Mask that distinguishes positions with real content vs. padding. - separate_cls: True/False if we separate_cls in calculations. - dropout: Probabilistic rate for dropout applied to attention strengths - (based on query-key pairs) before applying them to values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Per-head activations resulting from masked per-head attention-weighted - sum of per-head values. - - This function is the core of the attention mechanism. It: - - computes per-head attention weights from per-head `queries` and `keys`, - - applies `mask` to screen out positions that come from padding tokens, - - optionally applies dropout to attention weights, and - - uses attention weights to combine per-head `values` vectors. - """ - d_feature = queries.shape[-1] - keys_len, queries_len = keys.shape[-2], queries.shape[-2] - funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - - ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys) - bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb) - bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling) - - if separate_cls: - # Masking out location part of attention for cls token - bd = bd.at[:, :, :, 0].set(0) - bd = bd.at[:, :, 0, :].set(0) - - dots = (ac + bd) / jnp.sqrt(d_feature) - if mask is not None: - dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) - # Softmax. 
- dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if dropout is not None and dropout > 0.0 and mode == 'train': - keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape) - dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots)) - out = jnp.matmul(dots, values) - out = out.astype(jnp.float32) - dots = dots.astype(jnp.float32) - return out, dots + attention = RelativeAttentionLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) -def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling): - """Positional embeddings. + return cb.Serial( + CreateAttentionMaskLayer(), # q, k, v, mask + attention, # vecs, mask + cb.Select([0], n_in=2), # vecs + ) - Args: - d_feature: Depth/dimensionality of feature embedding. - separate_cls: True/False if we separate_cls in calculations. - total_kv_pooling: Accumulated pool size of keys/values until this layer. - Returns: - a layer that based on queries, keys and accumulated pool size of - keys/values until this layer calculates sinusoidal positional embeddings - for relative attention calculations. - """ +class RelativeAttention(base.Layer): + """Relative attention layer. + + Layer that maps (location_bias, context_bias, pos_emb, q, k, v, mask) + to (activations, mask). + This layer type performs the inner workings of one pass of multi-head + self-attention. 
It: + - splits queries, keys, and values into multiple 'heads', + - splits positional embeddings into multiple 'heads', + - computes per-head attention weights from per-head (queries, keys), + - applies mask to screen out positions that come from padding tokens, + - [in `'train'` mode] applies dropout to attention weights, + - uses attention weights to combine per-head values vectors, and + - merges per-head results into outgoing activations matching original input + activation vector shapes. + """ - def PositionsVectors(queries, keys): - assert not separate_cls + def __init__(self, separate_cls, n_heads=1, dropout=0.0, mode="train"): + """Returns a new PureAttention instance. + + Args: + separate_cls: True/False if we separate_cls in calculations. + n_heads: Number of attention heads. + dropout: Probabilistic rate for dropout applied to attention strengths + (based on query-key pairs) before applying them to values. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + super().__init__(n_in=7, n_out=2) + self._separate_cls = separate_cls + self._n_heads = n_heads + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed activations and unmodified mask. + + Args: + inputs: A (location_bias, context_bias, pos_emb, q, k, v, mask) tuple. + """ + location_bias, context_bias, pos_emb, q, k, v, mask = inputs + + d_feature = q.shape[-1] + n_heads = self._n_heads + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a " + f"multiple of the requested number of attention heads ({n_heads})." 
+ ) + + per_head_results, dots = DotProductAttention( + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), + pos_emb.reshape((-1, n_heads, d_feature // n_heads)), + context_bias, + location_bias, + mask, + separate_cls=self._separate_cls, + dropout=self._dropout, + mode=self._mode, + rng=self.rng, + ) + if self._mode == "viz": + self.state = dots + merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( + per_head_results + ) + return merged_results, mask + + +def DotProductAttention( + queries, + keys, + values, + pos_emb, + context_bias, + location_bias, + mask, + separate_cls, + dropout, + mode, + rng, +): + """Computes new activations via masked attention-weighted sum of values. + Args: + queries: Per-head activations representing attention queries. + keys: Per-head activations representing attention keys. + values: Per-head activations to be combined by computed attention weights. + pos_emb: Per-head activations representing positional embeddings. + context_bias: Global context bias from Transformer XL's attention. + location_bias: Global location bias from Transformer XL's attention. + mask: Mask that distinguishes positions with real content vs. padding. + separate_cls: True/False if we separate_cls in calculations. + dropout: Probabilistic rate for dropout applied to attention strengths + (based on query-key pairs) before applying them to values. + mode: One of `'train'`, `'eval'`, or `'predict'`. + rng: Single-use random number generator (JAX PRNG key). + + Returns: + Per-head activations resulting from masked per-head attention-weighted + sum of per-head values. + + This function is the core of the attention mechanism. 
It: + - computes per-head attention weights from per-head `queries` and `keys`, + - applies `mask` to screen out positions that come from padding tokens, + - optionally applies dropout to attention weights, and + - uses attention weights to combine per-head `values` vectors. + """ + d_feature = queries.shape[-1] keys_len, queries_len = keys.shape[-2], queries.shape[-2] funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - if funnel_factor == 1: - offset = keys_len - 1 - positions = (jnp.arange(keys_len) - offset) * total_kv_pooling - else: - if is_upsampling: - positions = jnp.arange(-queries_len + 1, queries_len, 1.0) - else: - positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling + ac = jnp.einsum("bnid,bnjd->bnij", queries + context_bias, keys) + bd = jnp.einsum("bnid,jnd->bnij", queries + location_bias, pos_emb) + bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling) + + if separate_cls: + # Masking out location part of attention for cls token + bd = bd.at[:, :, :, 0].set(0) + bd = bd.at[:, :, 0, :].set(0) + + dots = (ac + bd) / jnp.sqrt(d_feature) + if mask is not None: + dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) + # Softmax. 
+ dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if dropout is not None and dropout > 0.0 and mode == "train": + keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape) + dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots)) + out = jnp.matmul(dots, values) + out = out.astype(jnp.float32) + dots = dots.astype(jnp.float32) + return out, dots - return positions - def Sinusoidal_Embeddings(positions): - inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature)) - sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq) - pos_emb = jnp.concatenate( - [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1) - return pos_emb +def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling): + """Positional embeddings. - return cb.Serial( - cb.Fn('Generate positions vectors', PositionsVectors, n_out=1), - cb.Fn( - 'Transform to sinusoidal encodings', Sinusoidal_Embeddings, n_out=1)) + Args: + d_feature: Depth/dimensionality of feature embedding. + separate_cls: True/False if we separate_cls in calculations. + total_kv_pooling: Accumulated pool size of keys/values until this layer. + Returns: + a layer that based on queries, keys and accumulated pool size of + keys/values until this layer calculates sinusoidal positional embeddings + for relative attention calculations. 
+ """ -def calc_funnel_ratio(keys_len, queries_len): - """Calculate funnel ratio.""" + def PositionsVectors(queries, keys): + assert not separate_cls - if queries_len > keys_len: # Upsampling - assert queries_len % keys_len == 0 - funnel_factor = queries_len // keys_len - is_upsampling = True - else: # Downsampling - assert keys_len % queries_len == 0 - funnel_factor = keys_len // queries_len - is_upsampling = False + keys_len, queries_len = keys.shape[-2], queries.shape[-2] + funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - return funnel_factor, is_upsampling + if funnel_factor == 1: + offset = keys_len - 1 + positions = (jnp.arange(keys_len) - offset) * total_kv_pooling + else: + if is_upsampling: + positions = jnp.arange(-queries_len + 1, queries_len, 1.0) + else: + positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling + return positions -def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False): - """Fast matrix shift.""" + def Sinusoidal_Embeddings(positions): + inv_freq = 1 / (10000 ** (jnp.arange(0.0, d_feature, 2.0) / d_feature)) + sinusoid_freq = jnp.einsum("i,j->ij", positions, inv_freq) + pos_emb = jnp.concatenate( + [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1 + ) + return pos_emb - if funnel_factor == 1 and not is_upsampling: - shift = 1 - batch_size, n_head = x.shape[0], x.shape[1] - queries_len, keys_len = x.shape[2], x.shape[3] - zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift)) - x = jnp.concatenate([zero_pad, x], axis=3) - x = x.reshape(batch_size, n_head, keys_len + shift, queries_len) - x = x[:, :, shift:, :] - return x + return cb.Serial( + cb.Fn("Generate positions vectors", PositionsVectors, n_out=1), + cb.Fn("Transform to sinusoidal encodings", Sinusoidal_Embeddings, n_out=1), + ) - if is_upsampling: - k = funnel_factor - shift = 1 - else: - k = 1 - shift = funnel_factor - bsz, n_head = x.shape[0], x.shape[1] - qlen, klen = x.shape[2], (x.shape[3] + 1) // 2 +def 
calc_funnel_ratio(keys_len, queries_len): + """Calculate funnel ratio.""" - zero_pad = jnp.zeros((bsz, n_head, qlen, shift)) - x = jnp.concatenate([zero_pad, x], axis=3) - x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen) - x = x[:, :, shift:, :] - x = x.reshape(bsz, n_head, qlen, klen * 2 - 1) - x = x[:, :, :, shift - 1:shift - 1 + klen:k] - return x + if queries_len > keys_len: # Upsampling + assert queries_len % keys_len == 0 + funnel_factor = queries_len // keys_len + is_upsampling = True + else: # Downsampling + assert keys_len % queries_len == 0 + funnel_factor = keys_len // queries_len + is_upsampling = False + return funnel_factor, is_upsampling -@assert_shape('bqd,bkd,bvd->bqd,bkd,bvd,b1qk') -def CreateAttentionMaskLayer(): - """Creates attention mask layer. - Returns a layer that based on queries, keys and accumulated pool size of - keys/values until this layer calculates positional embeddings for - causal relative attention calculations. +def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False): + """Fast matrix shift.""" + + if funnel_factor == 1 and not is_upsampling: + shift = 1 + batch_size, n_head = x.shape[0], x.shape[1] + queries_len, keys_len = x.shape[2], x.shape[3] + zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift)) + x = jnp.concatenate([zero_pad, x], axis=3) + x = x.reshape(batch_size, n_head, keys_len + shift, queries_len) + x = x[:, :, shift:, :] + return x + + if is_upsampling: + k = funnel_factor + shift = 1 + else: + k = 1 + shift = funnel_factor - Takes as input q, k, v and appends proper mask in the end. - Causal attention uses masking to prevent a given sequence position from - attending to positions greater than / following it. This is used, for - example, when training autoregressive sequence models, or when decoding a - sequence symbol by symbol. + bsz, n_head = x.shape[0], x.shape[1] + qlen, klen = x.shape[2], (x.shape[3] + 1) // 2 - Returns: - an attention mask layer. 
- """ + zero_pad = jnp.zeros((bsz, n_head, qlen, shift)) + x = jnp.concatenate([zero_pad, x], axis=3) + x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen) + x = x[:, :, shift:, :] + x = x.reshape(bsz, n_head, qlen, klen * 2 - 1) + x = x[:, :, :, shift - 1 : shift - 1 + klen : k] + return x - def calculate_mask(queries, keys): - batch_size = queries.shape[0] - keys_len, queries_len = keys.shape[-2], queries.shape[-2] - funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - return _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, - is_upsampling) +@assert_shape("bqd,bkd,bvd->bqd,bkd,bvd,b1qk") +def CreateAttentionMaskLayer(): + """Creates attention mask layer. - def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, - is_upsampling): - """Funnel mask. + Returns a layer that based on queries, keys and accumulated pool size of + keys/values until this layer calculates positional embeddings for + causal relative attention calculations. - Args: - batch_size: batch size. - keys_len: keys length. - queries_len: queries length. - funnel_factor: funnel factor. - is_upsampling: True or False. + Takes as input q, k, v and appends proper mask in the end. + Causal attention uses masking to prevent a given sequence position from + attending to positions greater than / following it. This is used, for + example, when training autoregressive sequence models, or when decoding a + sequence symbol by symbol. Returns: - funnel mask. - - This function based on keys/queries lengths creates a triangle mask - that prevents tokens from attending to positions following it. - - If funnel_factor is not equal to 1 due to funnel upsampling or - downsampling it adjusts created mask for funnel attention - by repeating each element funnel_factor times. - - This is because after funnel layer one token attends to funnel_factor - different tokens in downsampling. 
During upsampling on the other hand - funnel_factor tokens are attending to single token before upsampling. + an attention mask layer. """ - if funnel_factor != 1: - if not is_upsampling: - mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) - mask = jnp.repeat(mask, funnel_factor, axis=-1) - else: - mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)) - mask = jnp.repeat(mask, funnel_factor, axis=-2) - else: - mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) - - return jnp.repeat(mask[None, None, :, :], batch_size, axis=0) - - return cb.Branch( - cb.Select([0]), cb.Select([1]), cb.Select([2]), - cb.Fn('create attention mask layer', calculate_mask, n_out=1)) - - -@assert_shape('...d->...d') + def calculate_mask(queries, keys): + batch_size = queries.shape[0] + keys_len, queries_len = keys.shape[-2], queries.shape[-2] + funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) + + return _funnel_mask( + batch_size, keys_len, queries_len, funnel_factor, is_upsampling + ) + + def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, is_upsampling): + """Funnel mask. + + Args: + batch_size: batch size. + keys_len: keys length. + queries_len: queries length. + funnel_factor: funnel factor. + is_upsampling: True or False. + + Returns: + funnel mask. + + This function based on keys/queries lengths creates a triangle mask + that prevents tokens from attending to positions following it. + + If funnel_factor is not equal to 1 due to funnel upsampling or + downsampling it adjusts created mask for funnel attention + by repeating each element funnel_factor times. + + This is because after funnel layer one token attends to funnel_factor + different tokens in downsampling. During upsampling on the other hand + funnel_factor tokens are attending to single token before upsampling. 
+ """ + + if funnel_factor != 1: + if not is_upsampling: + mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) + mask = jnp.repeat(mask, funnel_factor, axis=-1) + else: + mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)) + mask = jnp.repeat(mask, funnel_factor, axis=-2) + else: + mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) + + return jnp.repeat(mask[None, None, :, :], batch_size, axis=0) + + return cb.Branch( + cb.Select([0]), + cb.Select([1]), + cb.Select([2]), + cb.Fn("create attention mask layer", calculate_mask, n_out=1), + ) + + +@assert_shape("...d->...d") def ShiftRightCls(cls_id): - """Shifts right and insert cls. + """Shifts right and insert cls. - Args: - cls_id: id of the cls token in embedding dictionary. Returns a layer that - shifts input tokens to the right by one and inserts an cls token to the - beginning like in BERT paper. + Args: + cls_id: id of the cls token in embedding dictionary. Returns a layer that + shifts input tokens to the right by one and inserts an cls token to the + beginning like in BERT paper. - Returns: - layer shifting to right and inserting cls. - """ + Returns: + layer shifting to right and inserting cls. 
+ """ - def shift_right(x): - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[1] = (1, 0) - padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(cls_id)) - return padded[:, :-1] + def shift_right(x): + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[1] = (1, 0) + padded = jnp.pad( + x, pad_widths, mode="constant", constant_values=x.dtype.type(cls_id) + ) + return padded[:, :-1] - return cb.Fn('ShiftRightCls()', shift_right) + return cb.Fn("ShiftRightCls()", shift_right) diff --git a/trax/layers/research/resampling.py b/trax/layers/research/resampling.py index e1866fd8c..2fc8b4ea3 100644 --- a/trax/layers/research/resampling.py +++ b/trax/layers/research/resampling.py @@ -24,109 +24,135 @@ def AveragePooling(shorten_factor, *args, **kwargs): - del args, kwargs - - return AvgPool(pool_size=(shorten_factor,), strides=(shorten_factor,)) - - -def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train', - **kwargs): - del args, kwargs - - return cb.Serial( - core.Fn( - 'Shorten', - lambda x: jnp.reshape( # pylint: disable=g-long-lambda - # Shorten -- move to depth. # pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1] // shorten_factor, -1)), - n_out=1), - core.Dense(d_model), - core.Dropout(rate=dropout, mode=mode) - ) - - -def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train', - **kwargs): - del args, kwargs - - return cb.Serial( - core.Dense(shorten_factor * d_model), - core.Dropout(rate=dropout, mode=mode), - core.Fn( - 'ProlongBack', - lambda x: jnp.reshape( # pylint: disable=g-long-lambda - # Prolong back. 
# pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1] * shorten_factor, -1)), - n_out=1) - ) - - -def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs): # pylint: disable = unused-argument - return core.Fn('Repeat', lambda x: jnp.repeat(x, shorten_factor, axis=1)) + del args, kwargs + + return AvgPool(pool_size=(shorten_factor,), strides=(shorten_factor,)) + + +def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode="train", **kwargs): + del args, kwargs + + return cb.Serial( + core.Fn( + "Shorten", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + # Shorten -- move to depth. # pylint: disable=g-long-lambda + x, + (x.shape[0], x.shape[1] // shorten_factor, -1), + ), + n_out=1, + ), + core.Dense(d_model), + core.Dropout(rate=dropout, mode=mode), + ) + + +def LinearUpsampling( + shorten_factor, d_model, *args, dropout=0.0, mode="train", **kwargs +): + del args, kwargs + + return cb.Serial( + core.Dense(shorten_factor * d_model), + core.Dropout(rate=dropout, mode=mode), + core.Fn( + "ProlongBack", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + # Prolong back. 
# pylint: disable=g-long-lambda + x, + (x.shape[0], x.shape[1] * shorten_factor, -1), + ), + n_out=1, + ), + ) + + +def NaiveUpsampling( + shorten_factor, d_model, *args, **kwargs +): # pylint: disable = unused-argument + return core.Fn("Repeat", lambda x: jnp.repeat(x, shorten_factor, axis=1)) def NoUpsampling(shorten_factor, d_model, *args, **kwargs): - del d_model, args, kwargs - - return core.Fn('ReturnZero', lambda x: jnp.zeros( # pylint: disable=g-long-lambda - (x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype)) - - -def FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): - # We copy the ff block function because we cannot import it from models - return [ - core.Dense(d_ff), - activation(), - core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode), - core.Dense(d_model), - ] - - -def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads, - dropout, dropout_shared_axes, mode, ff_activation, - context_bias_layer, location_bias_layer, total_pooling, - resampling_fn): - """Attention resampling.""" - - attention = RelativeAttentionLMLayer( - d_model, context_bias_layer, location_bias_layer, - total_pooling, n_heads=n_heads, dropout=dropout, - mode=mode) - - feed_forward = FeedForwardBlock( - d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) - - resampling = resampling_fn(shorten_factor, d_model, - mode=mode) - - def _Dropout(): - return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode) - - return [ - LayerNorm(), # h - cb.Branch(cb.Serial( - resampling, - LayerNorm(), - ), None), # h', h - cb.Serial( # pylint: disable=g-long-ternary - cb.Select([0, 2, 1, 2]), - cb.Add(), - ) if is_upsampling else [], - cb.Residual( - cb.Select([0, 1, 1]), # h', h, h - attention, - _Dropout(), - ), - cb.Residual( - LayerNorm(), - feed_forward, - _Dropout(), - ), - ] + del d_model, args, kwargs + + return core.Fn( + "ReturnZero", + lambda x: 
jnp.zeros( # pylint: disable=g-long-lambda + (x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype + ), + ) + + +def FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, activation): + # We copy the ff block function because we cannot import it from models + return [ + core.Dense(d_ff), + activation(), + core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + core.Dense(d_model), + ] + + +def AttentionResampling( + shorten_factor, + d_model, + is_upsampling, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + context_bias_layer, + location_bias_layer, + total_pooling, + resampling_fn, +): + """Attention resampling.""" + + attention = RelativeAttentionLMLayer( + d_model, + context_bias_layer, + location_bias_layer, + total_pooling, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) + + feed_forward = FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + resampling = resampling_fn(shorten_factor, d_model, mode=mode) + + def _Dropout(): + return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + LayerNorm(), # h + cb.Branch( + cb.Serial( + resampling, + LayerNorm(), + ), + None, + ), # h', h + cb.Serial( # pylint: disable=g-long-ternary + cb.Select([0, 2, 1, 2]), + cb.Add(), + ) + if is_upsampling + else [], + cb.Residual( + cb.Select([0, 1, 1]), # h', h, h + attention, + _Dropout(), + ), + cb.Residual( + LayerNorm(), + feed_forward, + _Dropout(), + ), + ] diff --git a/trax/layers/research/rotary_positional_embedding.py b/trax/layers/research/rotary_positional_embedding.py index dbb08fcea..8e9b1b3bd 100644 --- a/trax/layers/research/rotary_positional_embedding.py +++ b/trax/layers/research/rotary_positional_embedding.py @@ -19,30 +19,29 @@ https://arxiv.org/pdf/2104.09864.pdf """ -# from trax import layers as tl from trax.fastmath import numpy as jnp from trax.layers import core def rotate(x): - """Rotate 
function.""" - _, l, d = x.shape - inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d)) - positions = jnp.arange(l) - freqs = jnp.einsum('i,j->ij', positions, inv_freq) - emb = jnp.concatenate((freqs, freqs), axis=-1) - cos = jnp.cos(emb) - sin = jnp.sin(emb) + """Rotate function.""" + _, l, d = x.shape + inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d)) + positions = jnp.arange(l) + freqs = jnp.einsum("i,j->ij", positions, inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + cos = jnp.cos(emb) + sin = jnp.sin(emb) - def mul(vecs, pos_emb): - return jnp.einsum('bld,ld->bld', vecs, pos_emb) + def mul(vecs, pos_emb): + return jnp.einsum("bld,ld->bld", vecs, pos_emb) - def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return jnp.concatenate((-x2, x1), axis=x1.ndim - 1) + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return jnp.concatenate((-x2, x1), axis=x1.ndim - 1) - return mul(x, cos) + mul(rotate_half(x), sin) + return mul(x, cos) + mul(rotate_half(x), sin) def Rotate(): # pylint: disable=invalid-name - return core.Fn('Rotate', rotate) + return core.Fn("Rotate", rotate) diff --git a/trax/layers/research/sparsity.py b/trax/layers/research/sparsity.py index 1ac1c8ca4..e4a3c55cc 100644 --- a/trax/layers/research/sparsity.py +++ b/trax/layers/research/sparsity.py @@ -35,1699 +35,1981 @@ # pylint: disable=invalid-name -@assert_shape('...->...') +@assert_shape("...->...") class ReversibleReshapePermute(reversible.ReversibleLayer): - """Simple and fast, reversible, random-looking permutation layer. - - This layer permutates the last dimension (usually the embedding dimension) - with simple reshapes. It uses the same permutation for every embedding, and - permutation never changes. - The layer works only when the last dimension is a power of 2. The - permutation is not truly random, as it just uses reshapes to get a fast - random-looking permutation. 
It has, however, a permutation cycle length - of just log2(dimension_size). - """ - - def forward(self, x): - shape = x.shape - x = x.reshape(shape[:-1]+(-1, self._get_multiplier(x))) - t_x = jnp.einsum('...ab->...ba', x) # transpose - return t_x.reshape(shape) - - def reverse(self, x, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng - shape = x.shape - x = x.reshape(shape[:-1]+(self._get_multiplier(x), -1)) - t_x = jnp.einsum('...ab->...ba', x) # transpose - return t_x.reshape(shape) - - def _get_multiplier(self, x): - """Return a size of the new dimension for reshaping. - - We want to split the last dimension into two using approximately equal - dimensions, we could split a dimension of size 512 into 16 * 32. - However, not all numbers will work equally well, because we have a different - cycle length for permutations for different numbers. For example, for - dimension size 1024 and multiplier 32 we would get the same permutation - already after applying permutation twice (cycle length is 2), but with - multiplier 8 we would get the same permutation after appling permutation 10 - times (cycle length is 10). - - For powers of two the cycle length is limited by log2(dimension_size). - This function returns the biggest multiplier smaller than - sqrt(dimension_size) that keeps the longest possible cycle lenght of the - permutation. + """Simple and fast, reversible, random-looking permutation layer. + + This layer permutates the last dimension (usually the embedding dimension) + with simple reshapes. It uses the same permutation for every embedding, and + permutation never changes. + The layer works only when the last dimension is a power of 2. The + permutation is not truly random, as it just uses reshapes to get a fast + random-looking permutation. It has, however, a permutation cycle length + of just log2(dimension_size). + """ - Args: - x: The input tensor. 
+ def forward(self, x): + shape = x.shape + x = x.reshape(shape[:-1] + (-1, self._get_multiplier(x))) + t_x = jnp.einsum("...ab->...ba", x) # transpose + return t_x.reshape(shape) + + def reverse(self, x, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng + shape = x.shape + x = x.reshape(shape[:-1] + (self._get_multiplier(x), -1)) + t_x = jnp.einsum("...ab->...ba", x) # transpose + return t_x.reshape(shape) + + def _get_multiplier(self, x): + """Return a size of the new dimension for reshaping. + + We want to split the last dimension into two using approximately equal + dimensions, we could split a dimension of size 512 into 16 * 32. + However, not all numbers will work equally well, because we have a different + cycle length for permutations for different numbers. For example, for + dimension size 1024 and multiplier 32 we would get the same permutation + already after applying permutation twice (cycle length is 2), but with + multiplier 8 we would get the same permutation after appling permutation 10 + times (cycle length is 10). + + For powers of two the cycle length is limited by log2(dimension_size). + This function returns the biggest multiplier smaller than + sqrt(dimension_size) that keeps the longest possible cycle lenght of the + permutation. + + Args: + x: The input tensor. + + Returns: + An appropriate multiplier for the permutation reshape. + """ + last_dim = x.shape[-1] + + def big_relatively_prime(n): + # The longest possible cycle is achieved iff log2(multiplier) and + # log2(dimension_size) are relatively prime. We choose the biggest such + # number smaller than sqrt(dimension_size). + for i in range(n // 2, 0, -1): + if n % i != 0: + return i + return 1 + + max_cycle_len = int(math.log(last_dim, 2)) + assert 2**max_cycle_len == last_dim + + return 2 ** big_relatively_prime(max_cycle_len) + + +@assert_shape("...->...") +class ReversibleRandomPermute(reversible.ReversibleLayer): + """Reversible, random permutation layer. 
- Returns: - An appropriate multiplier for the permutation reshape. + This layer permutates the last dimension (usually the embedding dimension) + by indexing and slicing. It uses the same random permutation for every + embedding, and this permutation never changes. """ - last_dim = x.shape[-1] - def big_relatively_prime(n): - # The longest possible cycle is achieved iff log2(multiplier) and - # log2(dimension_size) are relatively prime. We choose the biggest such - # number smaller than sqrt(dimension_size). - for i in range(n//2, 0, -1): - if n%i != 0: - return i - return 1 + def forward(self, x): + permutation, _ = self._get_permutation_and_reverse_permutation(x) + return x[..., permutation] - max_cycle_len = int(math.log(last_dim, 2)) - assert 2 ** max_cycle_len == last_dim + def reverse(self, x, weights=(), state=(), new_state=(), rng=None): + _, rev_permutation = self._get_permutation_and_reverse_permutation(x) + return x[..., rev_permutation] - return 2 ** big_relatively_prime(max_cycle_len) + def _get_permutation_and_reverse_permutation(self, x): + # TODO(jaszczur): random seed should be stored in state. + # Currently there is no way of doing it reliably. + last_dim = x.shape[-1] + permutation = list(range(last_dim)) + rand = pyrandom.Random(42) + rand.shuffle(permutation) + rev_permutation = [permutation.index(i) for i in range(last_dim)] + return permutation, rev_permutation -@assert_shape('...->...') -class ReversibleRandomPermute(reversible.ReversibleLayer): - """Reversible, random permutation layer. +@assert_shape("...a->...bc") +def SplitLastAxis(num_splits): + return tl.Fn( + f"SplitLastAxis_{num_splits}", + lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1)), + ) - This layer permutates the last dimension (usually the embedding dimension) - by indexing and slicing. It uses the same random permutation for every - embedding, and this permutation never changes. 
- """ - def forward(self, x): - permutation, _ = self._get_permutation_and_reverse_permutation(x) - return x[..., permutation] +@assert_shape("...ab->...c") +def MergeLastTwoAxes(): + return tl.Fn( + "MergeLastTwoAxes", lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,)) + ) - def reverse(self, x, weights=(), state=(), new_state=(), rng=None): - _, rev_permutation = self._get_permutation_and_reverse_permutation(x) - return x[..., rev_permutation] - def _get_permutation_and_reverse_permutation(self, x): - # TODO(jaszczur): random seed should be stored in state. - # Currently there is no way of doing it reliably. - last_dim = x.shape[-1] - permutation = list(range(last_dim)) - rand = pyrandom.Random(42) - rand.shuffle(permutation) - rev_permutation = [permutation.index(i) for i in range(last_dim)] - return permutation, rev_permutation +@assert_shape("...a->...b") +def LocallyConnectedDense( + n_modules, + n_units, + kernel_size=1, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, +): + """Layer using LocallyConnected1d for approximation of Dense layer. + + The layer splits the last axis of a tensor into `n_modules`, then runs + LocallyConnected1d (grouped convolution) on all those modules, and + concatenates their results. It is essentially a locally-sensitive + approximation of Dense layer, with number of parameters smaller by the factor + of `n_modules / kernel_size`. + Args: + n_modules: Indicates how many modules (pixels) should be input and output + split into for processing. + n_units: how many outputs (filters) should each module generate. + kernel_size: The size of the kernel to be used. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. 
+ use_bias: If `True`, compute an affine map `y = Wx + b`; else compute + a linear map `y = Wx`. -@assert_shape('...a->...bc') -def SplitLastAxis(num_splits): - return tl.Fn(f'SplitLastAxis_{num_splits}', - lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1))) + Returns: + LocallyConnectedDense base.Layer. + """ + if n_modules == 1: + return tl.Dense( + n_units, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + ) + return tl.Serial( + tl.SplitLastAxis(n_modules), + tl.LocallyConnected1d( + n_units, + kernel_size, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + padding="WRAP", + ), + tl.MergeLastTwoAxes(), + ) -@assert_shape('...ab->...c') -def MergeLastTwoAxes(): - return tl.Fn('MergeLastTwoAxes', - lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,))) - - -@assert_shape('...a->...b') -def LocallyConnectedDense(n_modules, n_units, kernel_size=1, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - """Layer using LocallyConnected1d for approximation of Dense layer. - - The layer splits the last axis of a tensor into `n_modules`, then runs - LocallyConnected1d (grouped convolution) on all those modules, and - concatenates their results. It is essentially a locally-sensitive - approximation of Dense layer, with number of parameters smaller by the factor - of `n_modules / kernel_size`. - - Args: - n_modules: Indicates how many modules (pixels) should be input and output - split into for processing. - n_units: how many outputs (filters) should each module generate. - kernel_size: The size of the kernel to be used. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. - bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. 
- use_bias: If `True`, compute an affine map `y = Wx + b`; else compute - a linear map `y = Wx`. - - Returns: - LocallyConnectedDense base.Layer. - """ - if n_modules == 1: - return tl.Dense(n_units, kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, use_bias=use_bias) - return tl.Serial( - tl.SplitLastAxis(n_modules), - tl.LocallyConnected1d( - n_units, kernel_size, kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, use_bias=use_bias, padding='WRAP'), - tl.MergeLastTwoAxes()) - - -@assert_shape('bld->bld') -def ModularCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0, - max_inference_length=2048, - kernel_size=1, mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - it uses LocallyConnectedDense instead of Dense layer for computing Q/K/V. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: Number of modules used in LocallyConnectedDense. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - max_inference_length: maximum length for inference. - kernel_size: Kernel size used in LocallyConnectedDense. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - n_modules = n_heads if sparsity is None else sparsity - @assert_shape('...a->...b') - def ProcessingLayer(): - assert d_feature % n_modules == 0 - return LocallyConnectedDense(n_modules, d_feature // n_modules, - kernel_size=kernel_size) - - return tl.ConfigurableAttention( - ProcessingLayer(), ProcessingLayer(), ProcessingLayer(), - ProcessingLayer(), n_heads=n_heads, - qkv_attention_layer=tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) +@assert_shape("bld->bld") +def ModularCausalAttention( + d_feature, + n_heads=1, + sparsity=None, + dropout=0.0, + max_inference_length=2048, + kernel_size=1, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + it uses LocallyConnectedDense instead of Dense layer for computing Q/K/V. + + Args: + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + sparsity: Number of modules used in LocallyConnectedDense. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + max_inference_length: maximum length for inference. + kernel_size: Kernel size used in LocallyConnectedDense. + mode: One of `'train'`, `'eval'`, or `'predict'`. 
+ """ + n_modules = n_heads if sparsity is None else sparsity + + @assert_shape("...a->...b") + def ProcessingLayer(): + assert d_feature % n_modules == 0 + return LocallyConnectedDense( + n_modules, d_feature // n_modules, kernel_size=kernel_size + ) + + return tl.ConfigurableAttention( + ProcessingLayer(), + ProcessingLayer(), + ProcessingLayer(), + ProcessingLayer(), + n_heads=n_heads, + qkv_attention_layer=tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) class _RememberPad(base.Layer): - """Layer which remembers last N elements in predict mode.""" + """Layer which remembers last N elements in predict mode.""" + + def __init__(self, n_items_to_remember, mode): + """Returns a layer which remembers last N elements in predict mode. + + For predict mode, the layer remembers last N elements and pads with them. + For other modes, it pads with zeros. The layer pads/remembers elements from + the second axis. + + Args: + n_items_to_remember: Number of items to remember/pad with. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + super().__init__(name="_RememberPad") + self._n_items_to_remember = n_items_to_remember + self._mode = mode + self._portal_mask = ( + self.monkey_patched_mask() + ) # pylint: disable=assignment-from-none + + def monkey_patched_mask(self): + # This is necessary for Terraformer model. See comments there. + # The mask will only be used in Terraformer in predict mode. + return None + + def forward(self, x): + if self._n_items_to_remember == 0: + return x + if self._mode == "predict": + x = jnp.concatenate([self.state[0], x], axis=1) + if self._portal_mask is not None and "init" in self.state[1]: + # TODO(jaszczur): In predict mode with monkey-patched mask, we + # currently assume that batch size is 1. 
+ assert x.shape[0] == 1 + mask = self._portal_mask.get_value() + count_padding = jnp.sum(mask == 0, dtype=jnp.int32) + self.state = ( + fastmath.dynamic_slice_in_dim( + x, + x.shape[1] - (self._n_items_to_remember + count_padding), + self._n_items_to_remember, + axis=1, + ), + {"forward": ()}, + ) + else: + self.state = (x[:, -self._n_items_to_remember :, ...], {"forward": ()}) + else: + pad_widths = [[0, 0] for _ in range(len(x.shape))] + pad_widths[1][0] = self._n_items_to_remember + x = jnp.pad(x, pad_width=pad_widths, mode="constant") + return x + + def init_weights_and_state(self, input_signature): + """Initializes this layer's weights.""" + if isinstance(input_signature, (list, tuple)): + input_signature = input_signature[0] + self.weights = () + if self._mode == "predict": + shape = list(input_signature.shape) + shape[1] = self._n_items_to_remember + self.state = (jnp.zeros(shape, dtype=jnp.float32), {"init": ()}) + else: + self.state = () + + +@assert_shape("...a->...b") +def LocallyConvDense(n_modules, n_units, mode, kernel_size=1, length_kernel_size=1): + """Layer using local convolutions for approximation of Dense layer. + + The layer splits the last axis of a tensor into `n_modules`, then runs + a convolution on all those modules, and concatenates their results. + It is similar to LocallyConnectedDense above, but shares weights. - def __init__(self, n_items_to_remember, mode): - """Returns a layer which remembers last N elements in predict mode. + Args: + n_modules: Indicates how many modules (pixels) should be input and output + split into for processing. + n_units: how many outputs (filters) should each module generate. + mode: One of `'train'`, `'eval'`, or `'predict'`. + kernel_size: The size of the kernel to be used. + length_kernel_size: If > 1, also do causal convolution on the previous axis, + which is often the sentence length in sequence models. - For predict mode, the layer remembers last N elements and pads with them. 
- For other modes, it pads with zeros. The layer pads/remembers elements from - the second axis. + Returns: + LocallyConvDense base.Layer. + """ + if n_modules == 1: + return tl.Dense(n_units) + if kernel_size % 2 != 1: + raise ValueError("Currently we only handle odd kernel sizes.") + half = (kernel_size - 1) // 2 + pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]] + return tl.Serial( + tl.SplitLastAxis(n_modules), + tl.Fn("Pad", lambda x: jnp.pad(x, pad_width=pad_widths, mode="constant")), + _RememberPad(length_kernel_size - 1, mode=mode), + tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)), + tl.MergeLastTwoAxes(), + ) + + +@assert_shape("bld->bld") +def ConvCausalAttention( + d_feature, + n_heads=1, + sparsity=None, + dropout=0.0, + max_inference_length=2048, + kernel_size=1, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + it uses LocallyConvDense instead of Dense layer for computing Q/K/V. Args: - n_items_to_remember: Number of items to remember/pad with. + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + sparsity: Number of modules used in LocallyConvDense. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + max_inference_length: maximum length for inference. + kernel_size: Kernel size used in LocallyConnectedDense. mode: One of `'train'`, `'eval'`, or `'predict'`. """ - super().__init__(name='_RememberPad') - self._n_items_to_remember = n_items_to_remember - self._mode = mode - self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none - - def monkey_patched_mask(self): - # This is necessary for Terraformer model. See comments there. 
- # The mask will only be used in Terraformer in predict mode. - return None - - def forward(self, x): - if self._n_items_to_remember == 0: - return x - if self._mode == 'predict': - x = jnp.concatenate([self.state[0], x], axis=1) - if self._portal_mask is not None and 'init' in self.state[1]: - # TODO(jaszczur): In predict mode with monkey-patched mask, we - # currently assume that batch size is 1. - assert x.shape[0] == 1 - mask = self._portal_mask.get_value() - count_padding = jnp.sum(mask == 0, dtype=jnp.int32) - self.state = (fastmath.dynamic_slice_in_dim( - x, x.shape[1] - (self._n_items_to_remember + count_padding), - self._n_items_to_remember, axis=1), {'forward': ()}) - else: - self.state = (x[:, -self._n_items_to_remember:, ...], {'forward': ()}) - else: - pad_widths = [[0, 0] for _ in range(len(x.shape))] - pad_widths[1][0] = self._n_items_to_remember - x = jnp.pad(x, pad_width=pad_widths, mode='constant') - return x - - def init_weights_and_state(self, input_signature): - """Initializes this layer's weights.""" - if isinstance(input_signature, (list, tuple)): - input_signature = input_signature[0] - self.weights = () - if self._mode == 'predict': - shape = list(input_signature.shape) - shape[1] = self._n_items_to_remember - self.state = (jnp.zeros(shape, dtype=jnp.float32), {'init': ()}) - else: - self.state = () - - -@assert_shape('...a->...b') -def LocallyConvDense(n_modules, n_units, mode, kernel_size=1, - length_kernel_size=1): - """Layer using local convolutions for approximation of Dense layer. - - The layer splits the last axis of a tensor into `n_modules`, then runs - a convolution on all those modules, and concatenates their results. - It is similar to LocallyConnectedDense above, but shares weights. - - Args: - n_modules: Indicates how many modules (pixels) should be input and output - split into for processing. - n_units: how many outputs (filters) should each module generate. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- kernel_size: The size of the kernel to be used. - length_kernel_size: If > 1, also do causal convolution on the previous axis, - which is often the sentence length in sequence models. - - Returns: - LocallyConvDense base.Layer. - """ - if n_modules == 1: - return tl.Dense(n_units) - if kernel_size % 2 != 1: - raise ValueError('Currently we only handle odd kernel sizes.') - half = (kernel_size - 1) // 2 - pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]] - return tl.Serial( - tl.SplitLastAxis(n_modules), - tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths, mode='constant')), - _RememberPad(length_kernel_size-1, mode=mode), - tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)), - tl.MergeLastTwoAxes() - ) - - -@assert_shape('bld->bld') -def ConvCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0, - max_inference_length=2048, - kernel_size=1, mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - it uses LocallyConvDense instead of Dense layer for computing Q/K/V. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: Number of modules used in LocallyConvDense. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - max_inference_length: maximum length for inference. - kernel_size: Kernel size used in LocallyConnectedDense. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - n_modules = n_heads if sparsity is None else sparsity - @assert_shape('...a->...b') - def ProcessingLayer(): - assert d_feature % n_modules == 0 - return LocallyConvDense(n_modules, d_feature // n_modules, mode=mode, - kernel_size=kernel_size) - - return tl.ConfigurableAttention( - ProcessingLayer(), ProcessingLayer(), ProcessingLayer(), - ProcessingLayer(), n_heads=n_heads, - qkv_attention_layer=tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) - - -@assert_shape('...a->...b') + n_modules = n_heads if sparsity is None else sparsity + + @assert_shape("...a->...b") + def ProcessingLayer(): + assert d_feature % n_modules == 0 + return LocallyConvDense( + n_modules, d_feature // n_modules, mode=mode, kernel_size=kernel_size + ) + + return tl.ConfigurableAttention( + ProcessingLayer(), + ProcessingLayer(), + ProcessingLayer(), + ProcessingLayer(), + n_heads=n_heads, + qkv_attention_layer=tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("...a->...b") def LowRankDense(n_units, d_lowrank): - return tl.Serial( - tl.Dense(d_lowrank), - tl.Dense(n_units) - ) + return tl.Serial(tl.Dense(d_lowrank), tl.Dense(n_units)) -@assert_shape('...a->...b') +@assert_shape("...a->...b") def EinsumDense(d_input, d_output, use_bias): - """Returns a reimplementation of Dense layer, using einsum. - - While this is an equivalent of a Dense layer, it seems to be faster when used - in decoding if used with bias (see decoding_timing_test.py ). - This layer can be removed when we understand better the reason for the - difference in decoding speed. - - Args: - d_input: Dimensionality of the input tensor. - d_output: Dimensionality of the output tensor. - use_bias: Whether to use bias. 
- """ - layers = [ - tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]), - tl.Fn('EinsumDense', - (lambda kernel, embeds: # pylint: disable=g-long-lambda - jnp.einsum('xd,...d->...x', kernel, embeds))) - ] - if use_bias: - layers.extend([ - tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]), - tl.Add() - ]) - return tl.Serial(layers) + """Returns a reimplementation of Dense layer, using einsum. + + While this is an equivalent of a Dense layer, it seems to be faster when used + in decoding if used with bias (see decoding_timing_test.py ). + This layer can be removed when we understand better the reason for the + difference in decoding speed. + + Args: + d_input: Dimensionality of the input tensor. + d_output: Dimensionality of the output tensor. + use_bias: Whether to use bias. + """ + layers = [ + tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]), + tl.Fn( + "EinsumDense", + ( + lambda kernel, embeds: jnp.einsum( # pylint: disable=g-long-lambda + "xd,...d->...x", kernel, embeds + ) + ), + ), + ] + if use_bias: + layers.extend( + [tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]), tl.Add()] + ) + return tl.Serial(layers) def RandomLayer(layer_a, layer_b, prob_a): - """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.""" - condition = tl.Serial( - tl.RandomUniform(), - tl.Fn('SmallerThan', lambda x: x < prob_a) - ) - return tl.Cond(condition, layer_a, layer_b) - - -@assert_shape('...a->...b') -def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None, - sparsity=0, d_lowrank=None, prob_sparse=None, - mode=None, use_bias=True, use_bfloat16=False): - """Configurable sparse version of Dense layer.""" - if prob_sparse is not None: - if mode is not None and mode != 'train': - # For non-training modes, we want to use a sparse variant. - # This is different than simply prob_sparse being None, as the weights of - # the model are different. 
- prob_sparse = 1.0 - return RandomLayer( - SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity, - d_lowrank, use_bias=use_bias, - use_bfloat16=use_bfloat16), - tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16), - prob_sparse) - - if sparsity_type is None or sparsity_type == 'None' or sparsity == 0: - return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16) - if sparsity_type == 'mult': - return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias, - use_bfloat16=use_bfloat16) - - assert not use_bfloat16 # use_bfloat16 is unsupported for other variants - if sparsity_type == 'lowrank': - assert use_bias # use_bias=False is unsupported - return LowRankDense(n_units, d_lowrank) - if sparsity_type == 'einsum': - return EinsumDense(d_input, n_units, use_bias=use_bias) - if sparsity_type == 'local': - assert use_bias # use_bias = False is unsupported - assert n_units % sparsity == 0 - return LocallyConnectedDense(sparsity, n_units/sparsity) - if sparsity_type == 'local3': - assert use_bias # use_bias = False is unsupported - assert n_units % sparsity == 0 - return LocallyConnectedDense(sparsity, n_units/sparsity, kernel_size=3) - - raise ValueError('Unknown sparsity type: {}'.format(sparsity_type)) - - -@assert_shape('bld->bld') -def LowRankCausalAttention(d_feature, n_heads=1, dropout=0.0, - max_inference_length=2048, lowrank=64, - mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - it uses low-rank approximation of kernel in Dense layer for computing Q/K/V. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. 
- max_inference_length: maximum length for inference. - lowrank: The rank of low-rank approximation. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - return tl.ConfigurableAttention( - LowRankDense(d_feature, lowrank), LowRankDense(d_feature, lowrank), - LowRankDense(d_feature, lowrank), LowRankDense(d_feature, lowrank), - n_heads=n_heads, qkv_attention_layer=tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) - - -@assert_shape('...a->...b') + """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.""" + condition = tl.Serial( + tl.RandomUniform(), tl.Fn("SmallerThan", lambda x: x < prob_a) + ) + return tl.Cond(condition, layer_a, layer_b) + + +@assert_shape("...a->...b") +def SparseDenseWithOptions( + n_units, + d_input=None, + sparsity_type=None, + sparsity=0, + d_lowrank=None, + prob_sparse=None, + mode=None, + use_bias=True, + use_bfloat16=False, +): + """Configurable sparse version of Dense layer.""" + if prob_sparse is not None: + if mode is not None and mode != "train": + # For non-training modes, we want to use a sparse variant. + # This is different than simply prob_sparse being None, as the weights of + # the model are different. 
+ prob_sparse = 1.0 + return RandomLayer( + SparseDenseWithOptions( + n_units, + d_input, + sparsity_type, + sparsity, + d_lowrank, + use_bias=use_bias, + use_bfloat16=use_bfloat16, + ), + tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16), + prob_sparse, + ) + + if sparsity_type is None or sparsity_type == "None" or sparsity == 0: + return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16) + if sparsity_type == "mult": + return FactoredDense( + sparsity, d_input, n_units, use_bias=use_bias, use_bfloat16=use_bfloat16 + ) + + assert not use_bfloat16 # use_bfloat16 is unsupported for other variants + if sparsity_type == "lowrank": + assert use_bias # use_bias=False is unsupported + return LowRankDense(n_units, d_lowrank) + if sparsity_type == "einsum": + return EinsumDense(d_input, n_units, use_bias=use_bias) + if sparsity_type == "local": + assert use_bias # use_bias = False is unsupported + assert n_units % sparsity == 0 + return LocallyConnectedDense(sparsity, n_units / sparsity) + if sparsity_type == "local3": + assert use_bias # use_bias = False is unsupported + assert n_units % sparsity == 0 + return LocallyConnectedDense(sparsity, n_units / sparsity, kernel_size=3) + + raise ValueError("Unknown sparsity type: {}".format(sparsity_type)) + + +@assert_shape("bld->bld") +def LowRankCausalAttention( + d_feature, + n_heads=1, + dropout=0.0, + max_inference_length=2048, + lowrank=64, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + it uses low-rank approximation of kernel in Dense layer for computing Q/K/V. + + Args: + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. 
+ dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + max_inference_length: maximum length for inference. + lowrank: The rank of low-rank approximation. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + + return tl.ConfigurableAttention( + LowRankDense(d_feature, lowrank), + LowRankDense(d_feature, lowrank), + LowRankDense(d_feature, lowrank), + LowRankDense(d_feature, lowrank), + n_heads=n_heads, + qkv_attention_layer=tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("...a->...b") def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False): - r"""Returns a Dense-like layer, internally factored to use fewer parameters. - - This layer treats an activation vector as if divided into :math:`M` - subvectors (``n_modules`` 'modules'). It uses this factored view to compute - a :py:class:`Dense`-like mapping with high mixing/connectivity, but using - approximately :math:`1/M` the number of weights of a similarly dimensioned - :py:class:`Dense` layer. - - More specifically, each activation vector of dimensionality ``n_in`` is - multiplied element-wise (a generalized form of gating) with ``n_modules`` - vectors also of dimensionality ``n_in``. The resulting vectors are projected - to the subvector/module dimensionality ``d_out / n_modules`` via a matrix - multiply, and finally reshaped back to a single vector of dimensionality - ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at - the end. All the above-mentioned non-input objects -- gating vectors, - projection matrix, and optional bias -- are trainable weights. - - Args: - n_modules: Number by which an activation vector is divided into subvectors - (modules) for the factored computation. - d_in: Last/innermost dimension of input array. - d_out: Last/innermost dimension of output array. 
- use_bias: If True, add bias vectors at the end of the layer; else end the - layer with the matrix multiply. - use_bfloat16: If True, use bfloat16 weights; else use float32 weights. - """ - if d_out % n_modules != 0: - raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg ' - f'n_modules ({n_modules}).') - d_module = d_out // n_modules - - def GatingVectors(): - return tl.Weights(init.RandomNormalInitializer(stddev=0.5), - shape=[n_modules, d_in], - use_bfloat16=use_bfloat16) - - def ProjectionMatrix(): - return tl.Weights(init.GlorotUniformInitializer(), - shape=[d_in, d_module], - use_bfloat16=use_bfloat16), - - def Bias(): - return tl.Weights(init.RandomNormalInitializer(1e-6), - shape=[d_out], - use_bfloat16=use_bfloat16), - - layers = [ - GatingVectors(), - ProjectionMatrix(), - _GateAndProject(), - MergeLastTwoAxes(), - ] - if use_bias: - layers += [Bias(), tl.Add()] - - return tl.Serial(layers) + r"""Returns a Dense-like layer, internally factored to use fewer parameters. + + This layer treats an activation vector as if divided into :math:`M` + subvectors (``n_modules`` 'modules'). It uses this factored view to compute + a :py:class:`Dense`-like mapping with high mixing/connectivity, but using + approximately :math:`1/M` the number of weights of a similarly dimensioned + :py:class:`Dense` layer. + + More specifically, each activation vector of dimensionality ``n_in`` is + multiplied element-wise (a generalized form of gating) with ``n_modules`` + vectors also of dimensionality ``n_in``. The resulting vectors are projected + to the subvector/module dimensionality ``d_out / n_modules`` via a matrix + multiply, and finally reshaped back to a single vector of dimensionality + ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at + the end. All the above-mentioned non-input objects -- gating vectors, + projection matrix, and optional bias -- are trainable weights. 
+ + Args: + n_modules: Number by which an activation vector is divided into subvectors + (modules) for the factored computation. + d_in: Last/innermost dimension of input array. + d_out: Last/innermost dimension of output array. + use_bias: If True, add bias vectors at the end of the layer; else end the + layer with the matrix multiply. + use_bfloat16: If True, use bfloat16 weights; else use float32 weights. + """ + if d_out % n_modules != 0: + raise ValueError( + f"Value d_out ({d_out}) must be a multiple of arg " + f"n_modules ({n_modules})." + ) + d_module = d_out // n_modules + + def GatingVectors(): + return tl.Weights( + init.RandomNormalInitializer(stddev=0.5), + shape=[n_modules, d_in], + use_bfloat16=use_bfloat16, + ) + + def ProjectionMatrix(): + return ( + tl.Weights( + init.GlorotUniformInitializer(), + shape=[d_in, d_module], + use_bfloat16=use_bfloat16, + ), + ) + + def Bias(): + return ( + tl.Weights( + init.RandomNormalInitializer(1e-6), + shape=[d_out], + use_bfloat16=use_bfloat16, + ), + ) + + layers = [ + GatingVectors(), + ProjectionMatrix(), + _GateAndProject(), + MergeLastTwoAxes(), + ] + if use_bias: + layers += [Bias(), tl.Add()] + + return tl.Serial(layers) def _GateAndProject(): - """Returns a combined gating+projection layer that saves on memory.""" + """Returns a combined gating+projection layer that saves on memory.""" - def f(projection, gating, x): - # Args arrive in reverse order because of how they were put on the stack. - # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules) - return jnp.einsum('...d,nd,dm->...nm', x, gating, projection) + def f(projection, gating, x): + # Args arrive in reverse order because of how they were put on the stack. 
+ # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules) + return jnp.einsum("...d,nd,dm->...nm", x, gating, projection) - return tl.Fn('_GateAndProject', f) + return tl.Fn("_GateAndProject", f) -@assert_shape('...a->...a') +@assert_shape("...a->...a") def MultiplicativeModularSparseDense(sparsity, d_feature): - """Returns a replacement of Dense layer which uses less parameters. - - The layer uses number of modules equal to `sparsity`. It is a combination of - multiplicative dense and locally connected dense layers. - - Args: - sparsity: The sparsity of the layer; the output vector is divided into this - number of modules. - d_feature: Dimensionality of input and output tensor. - """ - - assert d_feature % sparsity == 0 - d_module = d_feature // sparsity - - return tl.Serial( - # Weight below is used for per-head preprocessing of an embedding. - tl.Weights(init.RandomNormalInitializer(stddev=0.5), - shape=[sparsity, d_feature]), - # Weight below is a kernel of multiplicative dense, shared across heads. - tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]), - # Weight below is a kernel of modular dense. - tl.Weights(functools.partial(init.GlorotUniformInitializer(), - nonreceptive_dims=[0]), - [sparsity, d_module, d_module]), - # To save memory the per-head preprocessing and multiplying by - # kernels is done in a single einsum. - tl.Fn('SparseDenseEinsum', - (lambda kmod, kmult, multiplier, embeds: # pylint: disable=g-long-lambda - jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier, embeds - ))), - MergeLastTwoAxes(), - # Weight below is bias after dense, per-head. - tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]), - tl.Add(), - ) - - -@assert_shape('bld->bld') -def MultiplicativeCausalAttention(d_feature, n_heads=1, sparsity=None, - dropout=0.0, max_inference_length=2048, - mode='train'): - """Returns a layer that maps activations to activations, with causal masking. 
- - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - for computing Q/K/V instead of a Dense layer it multiplies each embedding - dimension by a scalar specific to each dimension and each head; then it - produces Q/K/V by applying the same dense layer to each head. In comparison - to standard dense layer for computing Q/K/V, this layer uses less parameters - while still being able to express many functions, like a permutation. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: The sparsity of the layer; usually it should be equal to n_heads. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - max_inference_length: maximum length for inference. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - sparsity = n_heads if sparsity is None else sparsity - return tl.ConfigurableAttention( - FactoredDense(sparsity, d_feature, d_feature), - FactoredDense(sparsity, d_feature, d_feature), - FactoredDense(sparsity, d_feature, d_feature), - FactoredDense(sparsity, d_feature, d_feature), - n_heads=n_heads, qkv_attention_layer=tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) - - -@assert_shape('bld->bld') + """Returns a replacement of Dense layer which uses less parameters. + + The layer uses number of modules equal to `sparsity`. It is a combination of + multiplicative dense and locally connected dense layers. + + Args: + sparsity: The sparsity of the layer; the output vector is divided into this + number of modules. + d_feature: Dimensionality of input and output tensor. + """ + + assert d_feature % sparsity == 0 + d_module = d_feature // sparsity + + return tl.Serial( + # Weight below is used for per-head preprocessing of an embedding. 
+ tl.Weights( + init.RandomNormalInitializer(stddev=0.5), shape=[sparsity, d_feature] + ), + # Weight below is a kernel of multiplicative dense, shared across heads. + tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]), + # Weight below is a kernel of modular dense. + tl.Weights( + functools.partial(init.GlorotUniformInitializer(), nonreceptive_dims=[0]), + [sparsity, d_module, d_module], + ), + # To save memory the per-head preprocessing and multiplying by + # kernels is done in a single einsum. + tl.Fn( + "SparseDenseEinsum", + ( + lambda kmod, kmult, multiplier, embeds: jnp.einsum( # pylint: disable=g-long-lambda + "hxo,dx,hd,...d->...ho", kmod, kmult, multiplier, embeds + ) + ), + ), + MergeLastTwoAxes(), + # Weight below is bias after dense, per-head. + tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]), + tl.Add(), + ) + + +@assert_shape("bld->bld") +def MultiplicativeCausalAttention( + d_feature, + n_heads=1, + sparsity=None, + dropout=0.0, + max_inference_length=2048, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + for computing Q/K/V instead of a Dense layer it multiplies each embedding + dimension by a scalar specific to each dimension and each head; then it + produces Q/K/V by applying the same dense layer to each head. In comparison + to standard dense layer for computing Q/K/V, this layer uses less parameters + while still being able to express many functions, like a permutation. + + Args: + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + sparsity: The sparsity of the layer; usually it should be equal to n_heads. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. 
+ max_inference_length: maximum length for inference. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + sparsity = n_heads if sparsity is None else sparsity + return tl.ConfigurableAttention( + FactoredDense(sparsity, d_feature, d_feature), + FactoredDense(sparsity, d_feature, d_feature), + FactoredDense(sparsity, d_feature, d_feature), + FactoredDense(sparsity, d_feature, d_feature), + n_heads=n_heads, + qkv_attention_layer=tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("bld->bld") def MultiplicativeModularCausalAttention( - d_feature, n_heads=1, sparsity=None, dropout=0.0, max_inference_length=2048, - mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - for computing Q/K/V instead of a Dense layer it combines - FactoredDense layer with LocallyConnectedLayer. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: The sparsity of the layer; usually it should be equal to n_heads. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - max_inference_length: maximum length for inference. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - sparsity = n_heads if sparsity is None else sparsity - return tl.ConfigurableAttention( - MultiplicativeModularSparseDense(sparsity, d_feature), - MultiplicativeModularSparseDense(sparsity, d_feature), - MultiplicativeModularSparseDense(sparsity, d_feature), - MultiplicativeModularSparseDense(sparsity, d_feature), n_heads=n_heads, - qkv_attention_layer=tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) - - -@assert_shape('bld->bld') + d_feature, + n_heads=1, + sparsity=None, + dropout=0.0, + max_inference_length=2048, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + for computing Q/K/V instead of a Dense layer it combines + FactoredDense layer with LocallyConnectedLayer. + + Args: + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + sparsity: The sparsity of the layer; usually it should be equal to n_heads. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + max_inference_length: maximum length for inference. + mode: One of `'train'`, `'eval'`, or `'predict'`. 
+ """ + sparsity = n_heads if sparsity is None else sparsity + return tl.ConfigurableAttention( + MultiplicativeModularSparseDense(sparsity, d_feature), + MultiplicativeModularSparseDense(sparsity, d_feature), + MultiplicativeModularSparseDense(sparsity, d_feature), + MultiplicativeModularSparseDense(sparsity, d_feature), + n_heads=n_heads, + qkv_attention_layer=tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("bld->bld") def MultiplicativeConvCausalAttention( - d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0, - force_no_dropout=False, max_inference_length=2048, share_qk=False, - output_layer_type='none', v_concat_type='none', mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - for computing Q/K/V instead of a Dense layer it combines - FactoredDense layer with LocallyConvLayer. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: The sparsity of the layer; usually it should be equal to n_heads. - length_kernel_size: Size of convolution kernel on the length dimension. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - force_no_dropout: If True, force dropout to be 0.0 independent of the above - value; used to override some configurations. - max_inference_length: maximum length for inference. - share_qk: if True, average Q and K embeddings and share for both Q and K. - output_layer_type: Which sparse layers to use for processing output from the - attention mechanism. One of `'none'`, `'mult'`, `'conv'`, - or `'multconv'`. - v_concat_type: What kind of concatenation to use when computing V tensor. 
- One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just - output from mutliplicative layer shared by Q, K, V. `'fixed'` means - using output from multiplicative layer concatenated, for each module, - with the layer input. `'original'` means using concatenation without - properly taking modules into account; this method was used in - experiments previously, so it is included for backwards-compatibility. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - assert output_layer_type in ['none', 'mult', 'conv', 'multconv'] - assert v_concat_type in ['original', 'fixed', 'none'] - - dropout = 0.0 if force_no_dropout else dropout - sparsity = n_heads if sparsity is None else sparsity - d_module = d_feature // sparsity - - output_layers = [] - if 'mult' in output_layer_type: - output_layers.append(FactoredDense( - sparsity, d_feature, d_feature)) - if 'conv' in output_layer_type: - output_layers.append(LocallyConvDense( - sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size)) - - if v_concat_type == 'original': - # 'original'` uses concatenation without properly taking modules into - # account; this method was used in experiments previously, so it is included - # for backwards-compatibility. - concat_layers = [tl.Concatenate()] # use permuted and original for v - elif v_concat_type == 'fixed': - # `'fixed'` uses the output from multiplicative layer concatenated, for each - # module, with the layer input. This means that every module in Conv layer - # has access both to parts of embeddings which were used to compute Q/K of - # this particular module, and it ha access to parts of the embedding which - # will be modified by this module. 
- concat_layers = [ - tl.Parallel( - tl.Fn('Reshape1', lambda x: jnp.reshape( # pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1], sparsity, d_module))), - tl.Fn('Reshape2', lambda x: jnp.reshape( # pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1], sparsity, d_module)))), - tl.Concatenate(), - tl.Fn('Reshape3', - lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))), - ] - elif v_concat_type == 'none': - # `'none'` doesn't use concatenation: we throw away the original layer - # input and pass to Conv only output of shared Multiplicative layer. - concat_layers = [tl.Select([0], n_in=2)] + d_feature, + n_heads=1, + sparsity=None, + length_kernel_size=3, + dropout=0.0, + force_no_dropout=False, + max_inference_length=2048, + share_qk=False, + output_layer_type="none", + v_concat_type="none", + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + self-attention with causal masking rather than padding-based masking. However, + for computing Q/K/V instead of a Dense layer it combines + FactoredDense layer with LocallyConvLayer. - if share_qk: + Args: + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + sparsity: The sparsity of the layer; usually it should be equal to n_heads. + length_kernel_size: Size of convolution kernel on the length dimension. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + force_no_dropout: If True, force dropout to be 0.0 independent of the above + value; used to override some configurations. + max_inference_length: maximum length for inference. + share_qk: if True, average Q and K embeddings and share for both Q and K. + output_layer_type: Which sparse layers to use for processing output from the + attention mechanism. 
One of `'none'`, `'mult'`, `'conv'`, + or `'multconv'`. + v_concat_type: What kind of concatenation to use when computing V tensor. + One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just + output from mutliplicative layer shared by Q, K, V. `'fixed'` means + using output from multiplicative layer concatenated, for each module, + with the layer input. `'original'` means using concatenation without + properly taking modules into account; this method was used in + experiments previously, so it is included for backwards-compatibility. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + assert output_layer_type in ["none", "mult", "conv", "multconv"] + assert v_concat_type in ["original", "fixed", "none"] + + dropout = 0.0 if force_no_dropout else dropout + sparsity = n_heads if sparsity is None else sparsity + d_module = d_feature // sparsity + + output_layers = [] + if "mult" in output_layer_type: + output_layers.append(FactoredDense(sparsity, d_feature, d_feature)) + if "conv" in output_layer_type: + output_layers.append( + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ) + ) + + if v_concat_type == "original": + # 'original'` uses concatenation without properly taking modules into + # account; this method was used in experiments previously, so it is included + # for backwards-compatibility. + concat_layers = [tl.Concatenate()] # use permuted and original for v + elif v_concat_type == "fixed": + # `'fixed'` uses the output from multiplicative layer concatenated, for each + # module, with the layer input. This means that every module in Conv layer + # has access both to parts of embeddings which were used to compute Q/K of + # this particular module, and it ha access to parts of the embedding which + # will be modified by this module. 
+ concat_layers = [ + tl.Parallel( + tl.Fn( + "Reshape1", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + x, (x.shape[0], x.shape[1], sparsity, d_module) + ), + ), + tl.Fn( + "Reshape2", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + x, (x.shape[0], x.shape[1], sparsity, d_module) + ), + ), + ), + tl.Concatenate(), + tl.Fn( + "Reshape3", + lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2 * d_feature)), + ), + ] + elif v_concat_type == "none": + # `'none'` doesn't use concatenation: we throw away the original layer + # input and pass to Conv only output of shared Multiplicative layer. + concat_layers = [tl.Select([0], n_in=2)] + + if share_qk: + return tl.Serial( + tl.Select([0, 0]), # pre-qkv, pre-v-for-concat + FactoredDense(sparsity, d_feature, d_feature), # shared q k + tl.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + tl.SplitIntoHeads(n_heads), + tl.Select([0, 0]), # use for q and k + tl.Parallel( + [], + [], + [ + concat_layers, + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + tl.SplitIntoHeads(n_heads), + ], + ), + tl.DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + tl.MergeHeads(n_heads), + output_layers, + ) return tl.Serial( - tl.Select([0, 0]), # pre-qkv, pre-v-for-concat - FactoredDense(sparsity, d_feature, d_feature), # shared q k - tl.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat - LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size), - tl.SplitIntoHeads(n_heads), - tl.Select([0, 0]), # use for q and k + tl.Select([0, 0]), # duplicate activations + FactoredDense(sparsity, d_feature, d_feature), # shared q, k + tl.Select([0, 0, 0]), # use for q, k, v tl.Parallel( - [], - [], - [concat_layers, - LocallyConvDense(sparsity, 
d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - tl.SplitIntoHeads(n_heads)], + [ + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + tl.SplitIntoHeads(n_heads), + ], + [ + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + tl.SplitIntoHeads(n_heads), + ], + [ + concat_layers, + LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + tl.SplitIntoHeads(n_heads), + ], ), tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode), + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), tl.MergeHeads(n_heads), output_layers, ) - return tl.Serial( - tl.Select([0, 0]), # duplicate activations - FactoredDense(sparsity, d_feature, d_feature), # shared q, k - tl.Select([0, 0, 0]), # use for q, k, v - tl.Parallel( - [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size), - tl.SplitIntoHeads(n_heads)], - [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size), - tl.SplitIntoHeads(n_heads)], - [concat_layers, - LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - tl.SplitIntoHeads(n_heads)], - ), - tl.DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode), - tl.MergeHeads(n_heads), - output_layers, - ) class FavorAttention(base.Layer): - """Implements FAVOR+ attention. - - Original paper: https://arxiv.org/abs/2006.03555 - The layer expects 4 inputs: (Q, K, V, MASK), and returns two outputs: - (RENORMALIZED_ATTENTION, MASK). - - Attributes: - - d_feature: Dimensionality of feature embedding. - n_heads: Number of attention heads. 
- n_random_features: Free dimension size for the orthogonal random matrix. - numerical_stabilizer: float, small number used for numerical stability. - use_approximate_softmax: Bool, if True uses approximate softmax, otherwise - Relu. - scale_by_norm: Boolean; whether to scale orthogonal random matrix. - normalize_data: predicate indicating whether data should be normalized. - epsilon: numerical stabilizer. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - def __init__(self, d_feature=4, n_heads=1, n_random_features=256, - numerical_stabilizer=0.001, - use_approximate_softmax=False, scale_by_norm=True, - normalize_data=False, - epsilon=0.0001, mode='train'): - super().__init__(n_in=4, n_out=2) - self._d_feature = d_feature - self._n_heads = n_heads - self._n_random_features = n_random_features - self._numerical_stabilizer = numerical_stabilizer - self._mode = mode - self._use_approximate_softmax = use_approximate_softmax - self._normalize_data = normalize_data - self._epsilon = epsilon - if self._use_approximate_softmax: - rng = random.get_prng(0) - self._projection_matrix = self.get_2d_array( - rng=rng, n_rows=self._n_random_features, - n_columns=(self._d_feature // self._n_heads), - scale_by_norm=scale_by_norm, - normalize_data=normalize_data, epsilon=epsilon) - else: - self._projection_matrix = None + """Implements FAVOR+ attention. - def nonnegative_softmax_kernel_feature_creator(self, x, is_query): - """Constructs nonnegative kernel features for fast softmax attention. + Original paper: https://arxiv.org/abs/2006.03555 + The layer expects 4 inputs: (Q, K, V, MASK), and returns two outputs: + (RENORMALIZED_ATTENTION, MASK). - Args: - x: input for which features are computed. - is_query: predicate indicating whether input data corresponds to - queries or keys. + Attributes: - Returns: - Random features for fast softmax attention. + d_feature: Dimensionality of feature embedding. + n_heads: Number of attention heads. 
+ n_random_features: Free dimension size for the orthogonal random matrix. + numerical_stabilizer: float, small number used for numerical stability. + use_approximate_softmax: Bool, if True uses approximate softmax, otherwise + Relu. + scale_by_norm: Boolean; whether to scale orthogonal random matrix. + normalize_data: predicate indicating whether data should be normalized. + epsilon: numerical stabilizer. + mode: One of `'train'`, `'eval'`, or `'predict'`. """ - if self._normalize_data: - # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where - # w_norm = w * data_normalizer for w in {q,k}. - data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(x.shape[-1]))) - else: - data_normalizer = 1.0 - ratio = 1.0 / jnp.sqrt(self._projection_matrix.shape[0]) - # TODO(wgaj): Double-check... Should there be only one batch dimension...? - data_mod_shape = x.shape[0:1] + self._projection_matrix.shape - data_thick_random_matrix = (jnp.zeros(data_mod_shape) + - self._projection_matrix) - - data_dash = jnp.einsum('Bij, Bkj -> Bik', - data_normalizer * x, - data_thick_random_matrix) - diag_data = jnp.square(x) - diag_data = jnp.sum(diag_data, axis=x.ndim - 1) - diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer - diag_data = jnp.expand_dims(diag_data, axis=x.ndim - 1) - - last_dims_t = (len(data_dash.shape) - 1,) - attention_dims_t = (1,) - if is_query: - data_dash = ratio * ( - jnp.exp(data_dash - diag_data - - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + - self._epsilon) - else: - data_dash = ratio * ( - jnp.exp(data_dash - diag_data - jnp.max( - data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) + - self._epsilon) - - return data_dash - @staticmethod - def get_2d_array(rng, n_rows=256, n_columns=0, scale_by_norm=True, - normalize_data=False, epsilon=0.0001): - """Generator for approximate softmax orthogonal kernel feature matrix. 
+ def __init__( + self, + d_feature=4, + n_heads=1, + n_random_features=256, + numerical_stabilizer=0.001, + use_approximate_softmax=False, + scale_by_norm=True, + normalize_data=False, + epsilon=0.0001, + mode="train", + ): + super().__init__(n_in=4, n_out=2) + self._d_feature = d_feature + self._n_heads = n_heads + self._n_random_features = n_random_features + self._numerical_stabilizer = numerical_stabilizer + self._mode = mode + self._use_approximate_softmax = use_approximate_softmax + self._normalize_data = normalize_data + self._epsilon = epsilon + if self._use_approximate_softmax: + rng = random.get_prng(0) + self._projection_matrix = self.get_2d_array( + rng=rng, + n_rows=self._n_random_features, + n_columns=(self._d_feature // self._n_heads), + scale_by_norm=scale_by_norm, + normalize_data=normalize_data, + epsilon=epsilon, + ) + else: + self._projection_matrix = None + + def nonnegative_softmax_kernel_feature_creator(self, x, is_query): + """Constructs nonnegative kernel features for fast softmax attention. + + Args: + x: input for which features are computed. + is_query: predicate indicating whether input data corresponds to + queries or keys. + + Returns: + Random features for fast softmax attention. + """ + if self._normalize_data: + # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where + # w_norm = w * data_normalizer for w in {q,k}. + data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(x.shape[-1]))) + else: + data_normalizer = 1.0 + ratio = 1.0 / jnp.sqrt(self._projection_matrix.shape[0]) + # TODO(wgaj): Double-check... Should there be only one batch dimension...? 
+ data_mod_shape = x.shape[0:1] + self._projection_matrix.shape + data_thick_random_matrix = jnp.zeros(data_mod_shape) + self._projection_matrix + + data_dash = jnp.einsum( + "Bij, Bkj -> Bik", data_normalizer * x, data_thick_random_matrix + ) + diag_data = jnp.square(x) + diag_data = jnp.sum(diag_data, axis=x.ndim - 1) + diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer + diag_data = jnp.expand_dims(diag_data, axis=x.ndim - 1) + + last_dims_t = (len(data_dash.shape) - 1,) + attention_dims_t = (1,) + if is_query: + data_dash = ratio * ( + jnp.exp( + data_dash + - diag_data + - jnp.max(data_dash, axis=last_dims_t, keepdims=True) + ) + + self._epsilon + ) + else: + data_dash = ratio * ( + jnp.exp( + data_dash + - diag_data + - jnp.max( + data_dash, axis=last_dims_t + attention_dims_t, keepdims=True + ) + ) + + self._epsilon + ) + + return data_dash + + @staticmethod + def get_2d_array( + rng, + n_rows=256, + n_columns=0, + scale_by_norm=True, + normalize_data=False, + epsilon=0.0001, + ): + """Generator for approximate softmax orthogonal kernel feature matrix. + + Args: + rng: Random number generator. + n_rows: Number of rows. + n_columns: Number of columns. + scale_by_norm: Boolean; whether to scale orthogonal random matrix. + normalize_data: predicate indicating whether data should be normalized. + epsilon: numerical stabilizer. + + Returns: + Orthogonal kernel feature matrix. 
+ """ + n_full_blocks = int(n_rows / n_columns) + block_list = [] + rng_key = rng + for _ in range(n_full_blocks): + rng, rng_input = random.split(rng) + unstructured_block = random.normal(rng_input, (n_columns, n_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q) + remaining_rows = n_rows - n_full_blocks * n_columns + if remaining_rows > 0: + rng, rng_input = random.split(rng) + unstructured_block = random.normal(rng_input, (n_columns, n_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q[0:remaining_rows]) + final_matrix = jnp.vstack(block_list) + + if scale_by_norm: + multiplier = jnp.linalg.norm( + random.normal(rng_key, (n_rows, n_columns)), axis=1 + ) + else: + multiplier = jnp.sqrt(float(n_columns)) * jnp.ones((n_rows)) + + return jnp.matmul(jnp.diag(multiplier), final_matrix) + + @staticmethod + def bidirectional_numerator(query_prime, key_prime, value): + kvs = jnp.einsum("lbm,lbd->bmd", key_prime, value) + return jnp.einsum("lbm,bmd->lbd", query_prime, kvs) + + @staticmethod + def bidirectional_denominator(query_prime, key_prime): + all_ones = jnp.ones([query_prime.shape[0]]) + ks_sum = jnp.einsum("lbm,l->bm", key_prime, all_ones) + return jnp.einsum("lbm,bm->lb", query_prime, ks_sum) + + @staticmethod + def relu(x): + return jnp.where(x <= 0, jnp.zeros_like(x), x) + + def forward(self, inputs): + query, key, value, mask = inputs + if self._use_approximate_softmax: + query_prime = self.nonnegative_softmax_kernel_feature_creator(query, True) + key_prime = self.nonnegative_softmax_kernel_feature_creator(key, False) + else: + query_prime = self.relu(query) + self._numerical_stabilizer + key_prime = self.relu(key) + self._numerical_stabilizer + mask_batch_1_length = jnp.reshape( + mask, [key.shape[0] // self._n_heads, 1, key.shape[1]] + ).astype(jnp.float32) + mask_heads = mask_batch_1_length + jnp.zeros((1, self._n_heads, 1)) + key_prime *= jnp.reshape(mask_heads, 
[key.shape[0], key.shape[1], 1]) + + w = self.bidirectional_numerator( + jnp.moveaxis(query_prime, 1, 0), + jnp.moveaxis(key_prime, 1, 0), + jnp.moveaxis(value, 1, 0), + ) + r = self.bidirectional_denominator( + jnp.moveaxis(query_prime, 1, 0), jnp.moveaxis(key_prime, 1, 0) + ) + w = jnp.moveaxis(w, 0, 1) + r = jnp.moveaxis(r, 0, 1) + r = jnp.reciprocal(r) + r = jnp.expand_dims(r, len(r.shape)) + renormalized_attention = w * r + return renormalized_attention, mask + + +def Favor( + d_feature, + n_heads=1, + n_random_features=256, + dropout=0.0, + numerical_stabilizer=0.001, + use_approximate_softmax=False, + scale_by_norm=0, + normalize_data=False, + epsilon=0.0001, + mode="train", +): + """Returns a layer that maps (activations, mask) to (new_activations, mask). + + See the FAVOR paper for details: https://arxiv.org/abs/2006.03555 Args: - rng: Random number generator. - n_rows: Number of rows. - n_columns: Number of columns. + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. + n_random_features: Free dimension size for the orthogonal random matrix. + dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + numerical_stabilizer: float, small number used for numerical stability. + use_approximate_softmax: Bool, if True uses approximate softmax, otherwise + Relu. scale_by_norm: Boolean; whether to scale orthogonal random matrix. normalize_data: predicate indicating whether data should be normalized. epsilon: numerical stabilizer. - - Returns: - Orthogonal kernel feature matrix. + mode: One of `'train'`, `'eval'`, or `'predict'`. 
""" - n_full_blocks = int(n_rows / n_columns) - block_list = [] - rng_key = rng - for _ in range(n_full_blocks): - rng, rng_input = random.split(rng) - unstructured_block = random.normal(rng_input, (n_columns, n_columns)) - q, _ = jnp.linalg.qr(unstructured_block) - q = jnp.transpose(q) - block_list.append(q) - remaining_rows = n_rows - n_full_blocks * n_columns - if remaining_rows > 0: - rng, rng_input = random.split(rng) - unstructured_block = random.normal(rng_input, (n_columns, n_columns)) - q, _ = jnp.linalg.qr(unstructured_block) - q = jnp.transpose(q) - block_list.append(q[0:remaining_rows]) - final_matrix = jnp.vstack(block_list) - - if scale_by_norm: - multiplier = jnp.linalg.norm( - random.normal(rng_key, (n_rows, n_columns)), axis=1) - else: - multiplier = jnp.sqrt(float(n_columns)) * jnp.ones((n_rows)) - - return jnp.matmul(jnp.diag(multiplier), final_matrix) - - @staticmethod - def bidirectional_numerator(query_prime, key_prime, value): - kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value) - return jnp.einsum('lbm,bmd->lbd', query_prime, kvs) - - @staticmethod - def bidirectional_denominator(query_prime, key_prime): - all_ones = jnp.ones([query_prime.shape[0]]) - ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones) - return jnp.einsum('lbm,bm->lb', query_prime, ks_sum) - - @staticmethod - def relu(x): - return jnp.where(x <= 0, jnp.zeros_like(x), x) - - def forward(self, inputs): - query, key, value, mask = inputs - if self._use_approximate_softmax: - query_prime = self.nonnegative_softmax_kernel_feature_creator(query, True) - key_prime = self.nonnegative_softmax_kernel_feature_creator(key, False) - else: - query_prime = self.relu(query) + self._numerical_stabilizer - key_prime = self.relu(key) + self._numerical_stabilizer - mask_batch_1_length = jnp.reshape( - mask, [key.shape[0] // self._n_heads, 1, key.shape[1]]).astype( - jnp.float32) - mask_heads = mask_batch_1_length + jnp.zeros((1, self._n_heads, 1)) - key_prime *= jnp.reshape(mask_heads, 
[key.shape[0], key.shape[1], 1]) - - w = self.bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0), - jnp.moveaxis(key_prime, 1, 0), - jnp.moveaxis(value, 1, 0)) - r = self.bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0), - jnp.moveaxis(key_prime, 1, 0)) - w = jnp.moveaxis(w, 0, 1) - r = jnp.moveaxis(r, 0, 1) - r = jnp.reciprocal(r) - r = jnp.expand_dims(r, len(r.shape)) - renormalized_attention = w * r - return renormalized_attention, mask - - -def Favor(d_feature, n_heads=1, n_random_features=256, dropout=0.0, - numerical_stabilizer=0.001, use_approximate_softmax=False, - scale_by_norm=0, normalize_data=False, epsilon=0.0001, mode='train'): - """Returns a layer that maps (activations, mask) to (new_activations, mask). - - See the FAVOR paper for details: https://arxiv.org/abs/2006.03555 - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - n_random_features: Free dimension size for the orthogonal random matrix. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - numerical_stabilizer: float, small number used for numerical stability. - use_approximate_softmax: Bool, if True uses approximate softmax, otherwise - Relu. - scale_by_norm: Boolean; whether to scale orthogonal random matrix. - normalize_data: predicate indicating whether data should be normalized. - epsilon: numerical stabilizer. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - del dropout # not implemented yet but needed in the API - - return tl.ConfigurableAttention( - tl.Dense(d_feature), tl.Dense(d_feature), tl.Dense(d_feature), - tl.Dense(d_feature), - tl.FavorAttention(d_feature, n_heads, n_random_features, - numerical_stabilizer, use_approximate_softmax, - scale_by_norm, normalize_data, epsilon, mode), - n_heads=n_heads) + del dropout # not implemented yet but needed in the API + + return tl.ConfigurableAttention( + tl.Dense(d_feature), + tl.Dense(d_feature), + tl.Dense(d_feature), + tl.Dense(d_feature), + tl.FavorAttention( + d_feature, + n_heads, + n_random_features, + numerical_stabilizer, + use_approximate_softmax, + scale_by_norm, + normalize_data, + epsilon, + mode, + ), + n_heads=n_heads, + ) class CausalFavorAttention(base.Layer): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - causal attention, but using FAVOR fast attention as in the following paper: - https://arxiv.org/abs/2006.03555 - - Layer expects three inputs (Q, K, V), and returns one output - RENORMALIZED_ATTENTION. - - Attributes: - numerical_stabilizer: float, small number used for numerical stability. - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - - def __init__(self, numerical_stabilizer=0.001, mode='train'): - super().__init__(n_in=3, n_out=1) - self._numerical_stabilizer = numerical_stabilizer - self._mode = mode - - def forward(self, inputs): - def favor_numerator_fwd(init_prefix_sum_value, - query_prime, key_prime, value): - def body(p, qkv): - (q, k, v) = qkv - p += jnp.einsum('...m,...d->...md', k, v) - x_slice = jnp.einsum('...m,...md->...d', q, p) - return p, x_slice - p, w = fastmath.scan(body, init_prefix_sum_value, - (query_prime, key_prime, value)) - return w, (p, query_prime, key_prime, value) - - def favor_numerator_bwd(pqkv, w_ct): - p, qs, ks, vs = pqkv - - def body(carry, qkv_xct): - p, p_ct = carry - q, k, v, x_ct = qkv_xct - q_ct = jnp.einsum('...d,...md->...m', x_ct, p) - p_ct += jnp.einsum('...d,...m->...md', x_ct, q) - k_ct = jnp.einsum('...md,...d->...m', p_ct, v) - v_ct = jnp.einsum('...md,...m->...d', p_ct, k) - p -= jnp.einsum('...m,...d->...md', k, v) - return (p, p_ct), (q_ct, k_ct, v_ct) - - _, (qs_ct, ks_ct, vs_ct) = fastmath.scan( - body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True) - return (None, qs_ct, ks_ct, vs_ct) - - def favor_numerator(init_prefix_sum_value, query_prime, - key_prime, value): - w, _ = favor_numerator_fwd(init_prefix_sum_value, - query_prime, key_prime, value) - return w - - favor_numerator = fastmath.custom_vjp( - favor_numerator, favor_numerator_fwd, favor_numerator_bwd) - - def favor_denominator_fwd(init_prefix_sum_value, - query_prime, key_prime): - def body(p, qk): - q, k = qk - p += k - x = jnp.einsum('...m,...m->...', q, p) - return p, x - - p, r = fastmath.scan(body, init_prefix_sum_value, (query_prime, - key_prime)) - return r, (query_prime, key_prime, p) - - def favor_denominator_bwd(qkp, r_ct): - qs, ks, p = qkp - - def body(carry, qkx): - p, p_ct = carry - q, k, x_ct = qkx - q_ct = jnp.einsum('...,...m->...m', x_ct, p) - p_ct += jnp.einsum('...,...m->...m', x_ct, q) - k_ct = p_ct - p -= k - return (p, p_ct), (q_ct, k_ct) - 
- _, (qs_ct, ks_ct) = fastmath.scan( - body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True) - return (None, qs_ct, ks_ct) - - def favor_denominator(init_prefix_sum_value, query_prime, - key_prime): - r, _ = favor_denominator_fwd(init_prefix_sum_value, - query_prime, key_prime) - return r - - favor_denominator = fastmath.custom_vjp( - favor_denominator, favor_denominator_fwd, favor_denominator_bwd) - - favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd) + """Returns a layer that maps activations to activations, with causal masking. - def relu(x): - return jnp.where(x <= 0, jnp.zeros_like(x), x) - - query, key, value = inputs - query_prime = relu(query) + self._numerical_stabilizer - key_prime = relu(key) + self._numerical_stabilizer - prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1]) - t_slice_shape = (key.shape[0], key.shape[-1]) - init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape) - init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape) - - w = favor_numerator(init_prefix_sum_value_numerator, - jnp.moveaxis(query_prime, 1, 0), - jnp.moveaxis(key_prime, 1, 0), - jnp.moveaxis(value, 1, 0)) - r = favor_denominator(init_prefix_sum_value_denominator, - jnp.moveaxis(query_prime, 1, 0), - jnp.moveaxis(key_prime, 1, 0)) - w = jnp.moveaxis(w, 0, 1) - r = jnp.moveaxis(r, 0, 1) - r = jnp.reciprocal(r) - r = jnp.expand_dims(r, len(r.shape)) - renormalized_attention = w * r - return renormalized_attention - - -def CausalFavor(d_feature, n_heads=1, dropout=0.0, - numerical_stabilizer=0.001, mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - causal attention, but using FAVOR fast attention as in the following paper: - https://arxiv.org/abs/2006.03555 - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. 
- dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - numerical_stabilizer: float, small number used for numerical stability. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - del dropout - return tl.ConfigurableAttention( - core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature), - core.Dense(d_feature), n_heads=n_heads, - qkv_attention_layer=tl.CausalFavorAttention(numerical_stabilizer, - mode)) + Like `CausalAttention`, this layer type represents one pass of multi-head + causal attention, but using FAVOR fast attention as in the following paper: + https://arxiv.org/abs/2006.03555 + Layer expects three inputs (Q, K, V), and returns one output + RENORMALIZED_ATTENTION. -class _RememberInReverse(base.Layer): - """Layer remembering the input in forward pass. For reversible models.""" - - def __init__(self, output=True): - """Layer remembering the input in forward pass. For reversible models. + Attributes: + numerical_stabilizer: float, small number used for numerical stability. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ - During the first pass through the model this layer saves the input as - state, and returns the input unmodified. During the second pass through the - model the layer outputs the input from the first pass. This is used to - combat numerical stability problems in Terraformer. It doesn't do anything - in non-reversible models. 
+ def __init__(self, numerical_stabilizer=0.001, mode="train"): + super().__init__(n_in=3, n_out=1) + self._numerical_stabilizer = numerical_stabilizer + self._mode = mode + + def forward(self, inputs): + def favor_numerator_fwd(init_prefix_sum_value, query_prime, key_prime, value): + def body(p, qkv): + (q, k, v) = qkv + p += jnp.einsum("...m,...d->...md", k, v) + x_slice = jnp.einsum("...m,...md->...d", q, p) + return p, x_slice + + p, w = fastmath.scan( + body, init_prefix_sum_value, (query_prime, key_prime, value) + ) + return w, (p, query_prime, key_prime, value) + + def favor_numerator_bwd(pqkv, w_ct): + p, qs, ks, vs = pqkv + + def body(carry, qkv_xct): + p, p_ct = carry + q, k, v, x_ct = qkv_xct + q_ct = jnp.einsum("...d,...md->...m", x_ct, p) + p_ct += jnp.einsum("...d,...m->...md", x_ct, q) + k_ct = jnp.einsum("...md,...d->...m", p_ct, v) + v_ct = jnp.einsum("...md,...m->...d", p_ct, k) + p -= jnp.einsum("...m,...d->...md", k, v) + return (p, p_ct), (q_ct, k_ct, v_ct) + + _, (qs_ct, ks_ct, vs_ct) = fastmath.scan( + body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True + ) + return (None, qs_ct, ks_ct, vs_ct) + + def favor_numerator(init_prefix_sum_value, query_prime, key_prime, value): + w, _ = favor_numerator_fwd( + init_prefix_sum_value, query_prime, key_prime, value + ) + return w + + favor_numerator = fastmath.custom_vjp( + favor_numerator, favor_numerator_fwd, favor_numerator_bwd + ) + + def favor_denominator_fwd(init_prefix_sum_value, query_prime, key_prime): + def body(p, qk): + q, k = qk + p += k + x = jnp.einsum("...m,...m->...", q, p) + return p, x + + p, r = fastmath.scan(body, init_prefix_sum_value, (query_prime, key_prime)) + return r, (query_prime, key_prime, p) + + def favor_denominator_bwd(qkp, r_ct): + qs, ks, p = qkp + + def body(carry, qkx): + p, p_ct = carry + q, k, x_ct = qkx + q_ct = jnp.einsum("...,...m->...m", x_ct, p) + p_ct += jnp.einsum("...,...m->...m", x_ct, q) + k_ct = p_ct + p -= k + return (p, p_ct), (q_ct, k_ct) + 
+ _, (qs_ct, ks_ct) = fastmath.scan( + body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True + ) + return (None, qs_ct, ks_ct) + + def favor_denominator(init_prefix_sum_value, query_prime, key_prime): + r, _ = favor_denominator_fwd(init_prefix_sum_value, query_prime, key_prime) + return r + + favor_denominator = fastmath.custom_vjp( + favor_denominator, favor_denominator_fwd, favor_denominator_bwd + ) + + favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd) + + def relu(x): + return jnp.where(x <= 0, jnp.zeros_like(x), x) + + query, key, value = inputs + query_prime = relu(query) + self._numerical_stabilizer + key_prime = relu(key) + self._numerical_stabilizer + prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1]) + t_slice_shape = (key.shape[0], key.shape[-1]) + init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape) + init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape) + + w = favor_numerator( + init_prefix_sum_value_numerator, + jnp.moveaxis(query_prime, 1, 0), + jnp.moveaxis(key_prime, 1, 0), + jnp.moveaxis(value, 1, 0), + ) + r = favor_denominator( + init_prefix_sum_value_denominator, + jnp.moveaxis(query_prime, 1, 0), + jnp.moveaxis(key_prime, 1, 0), + ) + w = jnp.moveaxis(w, 0, 1) + r = jnp.moveaxis(r, 0, 1) + r = jnp.reciprocal(r) + r = jnp.expand_dims(r, len(r.shape)) + renormalized_attention = w * r + return renormalized_attention + + +def CausalFavor( + d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001, mode="train" +): + """Returns a layer that maps activations to activations, with causal masking. + + Like `CausalAttention`, this layer type represents one pass of multi-head + causal attention, but using FAVOR fast attention as in the following paper: + https://arxiv.org/abs/2006.03555 Args: - output: Whether to pass the input or not. + d_feature: Depth/dimensionality of feature embedding. + n_heads: Number of attention heads. 
+ dropout: Probababilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. + numerical_stabilizer: float, small number used for numerical stability. + mode: One of `'train'`, `'eval'`, or `'predict'`. """ - n_out = 1 if output else 0 - self._output = output - super().__init__(name='_RememberInReverse', n_out=n_out) - - def forward(self, x): - if 'running_second_time_yes' in self.state[1]: - result = self.state[0] - else: - result = x - self.state = (x, {'running_second_time': ()}) + del dropout + return tl.ConfigurableAttention( + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + n_heads=n_heads, + qkv_attention_layer=tl.CausalFavorAttention(numerical_stabilizer, mode), + ) - if self._output: - return result - else: - return tuple() - def init_weights_and_state(self, input_signature): - """Initializes this layer's weights.""" - if isinstance(input_signature, (list, tuple)): - input_signature = input_signature[0] - self.weights = () - self.state = (jnp.zeros(input_signature.shape, dtype=jnp.int32), - {'running_second_time': ()}) +class _RememberInReverse(base.Layer): + """Layer remembering the input in forward pass. For reversible models.""" + + def __init__(self, output=True): + """Layer remembering the input in forward pass. For reversible models. + + During the first pass through the model this layer saves the input as + state, and returns the input unmodified. During the second pass through the + model the layer outputs the input from the first pass. This is used to + combat numerical stability problems in Terraformer. It doesn't do anything + in non-reversible models. + + Args: + output: Whether to pass the input or not. 
+ """ + n_out = 1 if output else 0 + self._output = output + super().__init__(name="_RememberInReverse", n_out=n_out) + + def forward(self, x): + if "running_second_time_yes" in self.state[1]: + result = self.state[0] + else: + result = x + self.state = (x, {"running_second_time": ()}) + + if self._output: + return result + else: + return tuple() + + def init_weights_and_state(self, input_signature): + """Initializes this layer's weights.""" + if isinstance(input_signature, (list, tuple)): + input_signature = input_signature[0] + self.weights = () + self.state = ( + jnp.zeros(input_signature.shape, dtype=jnp.int32), + {"running_second_time": ()}, + ) class _RecallQuantMaskInReverse(base.Layer): - """Layer recalling quant mask from specific _RememberInReverse. - - This layer is needed for memory-efficient training of reversible model with - ff chunking. During forward pass it simply returns minus ones, which are - ignored in the controller. During reverse_and_grad it returns a quant_mask - which was memorized (saved to state) by a RememberInReverse layer. - - This enable us to save quant_mask right after chunking, and load it again - (when reversing) right before chunking. - """ - - def __init__(self, remember_layer, elements): - self._remember_layer = remember_layer - self._elements = elements - super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2) - - def forward(self, x): - if (self._remember_layer.state and - 'running_second_time_yes' in self._remember_layer.state[1]): - # It's reverse_and_grad, so we pull the quant_mask from remembering layer. - result = self._remember_layer.state[0] - else: - result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32) - return (x, result) - + """Layer recalling quant mask from specific _RememberInReverse. 
-class _SparseFFController(base.Layer): - """The controller part of Sparse Feed-Forward layer.""" - - def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature, - use_bfloat16, mode, kernel_initializer, bias_initializer, - also_return_nondiscrete_output): - """Returns a sparse feed-forward block.""" - n_out = 2 if also_return_nondiscrete_output else 1 - super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out) - self._use_bfloat16 = use_bfloat16 - self._d_ff = d_ff - self._d_lowrank = d_lowrank - # Q: what temperature is actually most useful in training? - self._temperature = temperature if mode == 'train' else 0.0 - self._mode = mode - self._n_elements_in_block = n_elements_in_block - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - # Helper numbers as d_ff will be divided by n_elements_in_block. - assert self._d_ff % self._n_elements_in_block == 0 - self._d1 = self._d_ff // self._n_elements_in_block - self._d2 = self._n_elements_in_block - self._also_return_nondiscrete_output = also_return_nondiscrete_output - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + This layer is needed for memory-efficient training of reversible model with + ff chunking. During forward pass it simply returns minus ones, which are + ignored in the controller. During reverse_and_grad it returns a quant_mask + which was memorized (saved to state) by a RememberInReverse layer. - Returns: - Tensor of same shape and dtype as the input. + This enable us to save quant_mask right after chunking, and load it again + (when reversing) right before chunking. """ - x, recalled_quant_mask = x - m1, m2, mb = self.weights - - x_shape = x.shape - x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - # Q: should we add bias and/or put relu after the low-rank m1 dot? 
- # Replacing multiplication and reshape by this einsum brings training speed - # improvement (see also reshape in initialization). - mask_logits = jnp.einsum('bd,dl,lxy->bxy', x, m1, m2) + mb - - if self._also_return_nondiscrete_output: - # Softmax. - mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) - log_mask = mask_logits - mask_logsumexp - mask = jnp.exp(log_mask) - # Gumbel-softmax with straight-through discretization. - if self._temperature == 0.0: - quant_mask = jnp.argmax(log_mask, axis=-1) - else: - u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6, - 1.0 - 1e-6) - g = -jnp.log(-jnp.log(u)) - quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1) - else: - quant_mask = jnp.argmax(mask_logits, axis=-1) - if self._mode == 'train': - # We use recalled_quant_mask if it's different than -1; otherwise - # we use a quant_mask which we have just computed. - quant_mask = jnp.where(recalled_quant_mask == -1, - quant_mask, recalled_quant_mask) + def __init__(self, remember_layer, elements): + self._remember_layer = remember_layer + self._elements = elements + super().__init__(name="_RecallQuantMaskInReverse", n_in=1, n_out=2) - if self._also_return_nondiscrete_output: - return quant_mask, mask - else: - return quant_mask + def forward(self, x): + if ( + self._remember_layer.state + and "running_second_time_yes" in self._remember_layer.state[1] + ): + # It's reverse_and_grad, so we pull the quant_mask from remembering layer. 
+ result = self._remember_layer.state[0] + else: + result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32) + return (x, result) - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - x_input_signature = input_signature[0] - d_model = x_input_signature.shape[-1] - shape_m1 = (d_model, self._d_lowrank) - shape_m2 = (self._d_lowrank, self._d_ff) - shape_mb = (self._d_ff,) - rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3) - m1 = self._kernel_initializer(shape_m1, rng_m1) - m2 = self._kernel_initializer(shape_m2, rng_m2) - mb = self._bias_initializer(shape_mb, rng_mb) - if self._use_bfloat16: - m1 = m1.astype(jnp.bfloat16) - m2 = m2.astype(jnp.bfloat16) - mb = mb.astype(jnp.bfloat16) +class _SparseFFController(base.Layer): + """The controller part of Sparse Feed-Forward layer.""" + + def __init__( + self, + d_ff, + n_elements_in_block, + d_lowrank, + temperature, + use_bfloat16, + mode, + kernel_initializer, + bias_initializer, + also_return_nondiscrete_output, + ): + """Returns a sparse feed-forward block.""" + n_out = 2 if also_return_nondiscrete_output else 1 + super().__init__(name=f"_SparseFFController_{d_ff}", n_in=2, n_out=n_out) + self._use_bfloat16 = use_bfloat16 + self._d_ff = d_ff + self._d_lowrank = d_lowrank + # Q: what temperature is actually most useful in training? + self._temperature = temperature if mode == "train" else 0.0 + self._mode = mode + self._n_elements_in_block = n_elements_in_block + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Helper numbers as d_ff will be divided by n_elements_in_block. + assert self._d_ff % self._n_elements_in_block == 0 + self._d1 = self._d_ff // self._n_elements_in_block + self._d2 = self._n_elements_in_block + self._also_return_nondiscrete_output = also_return_nondiscrete_output + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. 
+ + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. + """ + x, recalled_quant_mask = x + m1, m2, mb = self.weights + + x_shape = x.shape + x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. + + # Q: should we add bias and/or put relu after the low-rank m1 dot? + # Replacing multiplication and reshape by this einsum brings training speed + # improvement (see also reshape in initialization). + mask_logits = jnp.einsum("bd,dl,lxy->bxy", x, m1, m2) + mb + + if self._also_return_nondiscrete_output: + # Softmax. + mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) + log_mask = mask_logits - mask_logsumexp + mask = jnp.exp(log_mask) + # Gumbel-softmax with straight-through discretization. + if self._temperature == 0.0: + quant_mask = jnp.argmax(log_mask, axis=-1) + else: + u = fastmath.random.uniform( + self.rng, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6 + ) + g = -jnp.log(-jnp.log(u)) + quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1) + else: + quant_mask = jnp.argmax(mask_logits, axis=-1) + + if self._mode == "train": + # We use recalled_quant_mask if it's different than -1; otherwise + # we use a quant_mask which we have just computed. 
+ quant_mask = jnp.where( + recalled_quant_mask == -1, quant_mask, recalled_quant_mask + ) + + if self._also_return_nondiscrete_output: + return quant_mask, mask + else: + return quant_mask + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + x_input_signature = input_signature[0] + d_model = x_input_signature.shape[-1] + shape_m1 = (d_model, self._d_lowrank) + shape_m2 = (self._d_lowrank, self._d_ff) + shape_mb = (self._d_ff,) + + rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3) + m1 = self._kernel_initializer(shape_m1, rng_m1) + m2 = self._kernel_initializer(shape_m2, rng_m2) + mb = self._bias_initializer(shape_mb, rng_mb) + if self._use_bfloat16: + m1 = m1.astype(jnp.bfloat16) + m2 = m2.astype(jnp.bfloat16) + mb = mb.astype(jnp.bfloat16) + + # Reshapes below, with einsum in feedforward, improve the training speed. + m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2]) + mb = jnp.reshape(mb, [self._d1, self._d2]) + + self.weights = (m1, m2, mb) - # Reshapes below, with einsum in feedforward, improve the training speed. 
- m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2]) - mb = jnp.reshape(mb, [self._d1, self._d2]) - self.weights = (m1, m2, mb) +class _SparseFFMain(base.Layer): + """The main (non-controller) part of Sparse Feed-Forward layer.""" + + def __init__( + self, + d_ff, + n_elements_in_block, + d_lowrank, + quant_prob, + use_bfloat16, + big_weights_in_bfloat16, + mode, + kernel_initializer, + bias_initializer, + multiply_by_controller_output, + kernel_scaling, + ): + """Returns a sparse feed-forward block.""" + n_in = 3 if mode == "train" or multiply_by_controller_output else 2 + super().__init__(name=f"_SparseFFMain_{d_ff}", n_in=n_in, n_out=2) + self._mode = mode + self._use_bfloat16 = use_bfloat16 + self._big_weights_in_bfloat16 = big_weights_in_bfloat16 + self._d_ff = d_ff + self._d_lowrank = d_lowrank + self._quant_prob = quant_prob + self._n_elements_in_block = n_elements_in_block + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Helper numbers as d_ff will be divided by n_elements_in_block. + assert self._d_ff % self._n_elements_in_block == 0 + self._d1 = self._d_ff // self._n_elements_in_block + self._d2 = self._n_elements_in_block + self._multiply_by_controller_output = multiply_by_controller_output + self._kernel_scaling = kernel_scaling + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. 
+ """ + if self._mode == "train" or self._multiply_by_controller_output: + quant_mask, mask, x = x + else: + quant_mask, x = x + original_quant_mask = quant_mask + + w1, w2, b2 = self.weights + + if self._mode == "predict": + w1 = jnp.transpose(w1, (1, 2, 0)) # dm, d1, d2 -> d1, d2, dm + w2 = jnp.transpose(w2, (1, 0, 2)) # d2, d1, dm -> d1, d2, dm + x_shape = x.shape + x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. + + if self._mode == "train": + # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797 + quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block) + quant_mask = fastmath.stop_gradient(quant_mask) + quant_mask += mask - fastmath.stop_gradient(mask) # straight-through + # We will sometimes (quant_prob of the batches) use the soft-mask instead + # of the quantized mask to improve training stability (see paper above). + select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0) + quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask) + + # In training, run full matmul to get benefits from the above tricks. + mid = jnp.einsum("bd,dxy->bxy", x, w1) * quant_mask + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + if self._multiply_by_controller_output: + # We multiply only for quantized decisions, since for non-quantized + # decisions we've already multiplied the output. + mask_mult = jnp.where( + select < self._quant_prob, mask, jnp.ones_like(mask) + ) + # Stop-gradient is here, because we already have a pass-through gradient + # (for quantized decisions). + mask_mult = fastmath.stop_gradient(mask_mult) + relu = relu * mask_mult + res = jnp.einsum("bxy,yxd->bd", relu, w2) + b2 + elif self._mode == "predict": + # This implementation mimicks inference. It's not efficient for large + # size of joint_batch, but at inference that will be 1 most of the time. 
+ # Shapes: + # quant_mask is [joint_batch, self._d1] + # w1 is [d_model, self._d1, self._d2] + # we'll index w1 with advanced numpy indexing, first range over + # self._d1 times the batch size, second range being quant_mask + batch_size = quant_mask.shape[0] + idx1 = jnp.array([jnp.arange(self._d1)] * batch_size) + # flatten indices and select from w1 + idx1 = jnp.reshape(idx1, [-1]) + idx2 = jnp.reshape(quant_mask, [-1]) + w = w1[idx1, idx2, :] # now we have per-element weights with batch dim + w = jnp.reshape(w, [batch_size, self._d1, -1]) + mid = jnp.einsum("ai,aji->aj", x, w) + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + if self._multiply_by_controller_output: + mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0] + relu = relu * mask_mult + # w2 is [self._d1, self._d2, d_model] + v = w2[idx1, idx2, :] + v = jnp.reshape(v, [batch_size, self._d1, -1]) + res = jnp.einsum("ai,aij->aj", relu, v) + b2 + else: + quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block) + mid = jnp.einsum("bd,dxy->bxy", x, w1) * quant_mask + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + if self._multiply_by_controller_output: + relu = relu * mask + res = jnp.einsum("bxy,yxd->bd", relu, w2) + b2 + + return original_quant_mask, jnp.reshape(res, x_shape) + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + d_model = input_signature[-1].shape[-1] + shape_w1 = (d_model, self._d_ff) + shape_w2 = (self._d_ff, d_model) + shape_b2 = (d_model,) + + rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3) + if base.N_WEIGHTS_SHARDS > 1: + # In sharded-weights mode, put the weights on CPU on init + # as they will be sharded later. 
+ w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1)) + w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2)) + else: + w1 = self._kernel_initializer(shape_w1, rng_w1) + w2 = self._kernel_initializer(shape_w2, rng_w2) + + b2 = self._bias_initializer(shape_b2, rng_b2) + if self._use_bfloat16: + b2 = b2.astype(jnp.bfloat16) + if self._use_bfloat16 or self._big_weights_in_bfloat16: + w1 = w1.astype(jnp.bfloat16) + w2 = w2.astype(jnp.bfloat16) + + w1 = jnp.reshape(w1, (-1, self._d1, self._d2)) + w2 = jnp.reshape(w2, (self._d2, self._d1, -1)) + + if self._kernel_scaling: + # This keeps expected variance of the output regardless of N. + w2 = w2 * (self._n_elements_in_block**0.5) + + self.weights = (w1, w2, b2) -class _SparseFFMain(base.Layer): - """The main (non-controller) part of Sparse Feed-Forward layer.""" - - def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob, - use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer, - bias_initializer, multiply_by_controller_output, kernel_scaling): - """Returns a sparse feed-forward block.""" - n_in = 3 if mode == 'train' or multiply_by_controller_output else 2 - super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2) - self._mode = mode - self._use_bfloat16 = use_bfloat16 - self._big_weights_in_bfloat16 = big_weights_in_bfloat16 - self._d_ff = d_ff - self._d_lowrank = d_lowrank - self._quant_prob = quant_prob - self._n_elements_in_block = n_elements_in_block - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - # Helper numbers as d_ff will be divided by n_elements_in_block. - assert self._d_ff % self._n_elements_in_block == 0 - self._d1 = self._d_ff // self._n_elements_in_block - self._d2 = self._n_elements_in_block - self._multiply_by_controller_output = multiply_by_controller_output - self._kernel_scaling = kernel_scaling - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. 
+def SparseFF( + d_ff, + n_elements_in_block=32, + d_lowrank=64, + temperature=0.1, + quant_prob=0.3, + use_bfloat16=False, + big_weights_in_bfloat16=False, + mode="train", + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + dropout_rate=0.0, + dropout_shared_axes=None, + ff_chunk_size=0, + multiply_by_controller_output=False, + kernel_scaling=False, +): + """Returns Feed-forward block with sparsity. + + The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense + that takes an input, makes it of size d_ff (usually larger than it was) and + then brings it back to the original size after Relu. It is commonly used in + Transformer models where it often accounts for most of the trainable weights. + + The original block can be slow in decoding due to the need to fetch a lot of + weights from memory. This sparse block only allows one non-zero element + in a block of a specified size. This is trained with straight-through Gumbel + softmax trick. Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input. + d_ff: Depth/dimensionality of FeedForward layer. + n_elements_in_block: The sparsity level. The layer is divided into blocks of + this size, and each block has only a single element active. + d_lowrank: The dimensionality of low-rank controller. + temperature: The temperature of the controller during training. + quant_prob: During training this proportion of blocks will have quantized + mask (i.e. a single element active). The rest will use a soft mask. + use_bfloat16: Whether to use bfloat16 for weights. + big_weights_in_bfloat16: : Whether to use bfloat16 for main weights of the + FeedForward layer. + mode: One of `'train'`, `'eval'`, or `'predict'`. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. 
+ bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + dropout_rate: Probability for dropping an activation value. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks. + multiply_by_controller_output: whether to multiply the middle activation + layer of FF by controller output (i.e. softmax). + kernel_scaling: Whether to scale the kernel matrix (during init) to keep the + variance of the layer output regardless of n_elements_in_block. """ - if self._mode == 'train' or self._multiply_by_controller_output: - quant_mask, mask, x = x - else: - quant_mask, x = x - original_quant_mask = quant_mask - - w1, w2, b2 = self.weights - - if self._mode == 'predict': - w1 = jnp.transpose(w1, (1, 2, 0)) # dm, d1, d2 -> d1, d2, dm - w2 = jnp.transpose(w2, (1, 0, 2)) # d2, d1, dm -> d1, d2, dm - x_shape = x.shape - x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - if self._mode == 'train': - # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797 - quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block) - quant_mask = fastmath.stop_gradient(quant_mask) - quant_mask += mask - fastmath.stop_gradient(mask) # straight-through - # We will sometimes (quant_prob of the batches) use the soft-mask instead - # of the quantized mask to improve training stability (see paper above). - select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0) - quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask) - - # In training, run full matmul to get benefits from the above tricks. 
- mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - if self._multiply_by_controller_output: - # We multiply only for quantized decisions, since for non-quantized - # decisions we've already multiplied the output. - mask_mult = jnp.where(select < self._quant_prob, - mask, jnp.ones_like(mask)) - # Stop-gradient is here, because we already have a pass-through gradient - # (for quantized decisions). - mask_mult = fastmath.stop_gradient(mask_mult) - relu = relu * mask_mult - res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2 - elif self._mode == 'predict': - # This implementation mimicks inference. It's not efficient for large - # size of joint_batch, but at inference that will be 1 most of the time. - # Shapes: - # quant_mask is [joint_batch, self._d1] - # w1 is [d_model, self._d1, self._d2] - # we'll index w1 with advanced numpy indexing, first range over - # self._d1 times the batch size, second range being quant_mask - batch_size = quant_mask.shape[0] - idx1 = jnp.array([jnp.arange(self._d1)] * batch_size) - # flatten indices and select from w1 - idx1 = jnp.reshape(idx1, [-1]) - idx2 = jnp.reshape(quant_mask, [-1]) - w = w1[idx1, idx2, :] # now we have per-element weights with batch dim - w = jnp.reshape(w, [batch_size, self._d1, -1]) - mid = jnp.einsum('ai,aji->aj', x, w) - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - if self._multiply_by_controller_output: - mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0] - relu = relu * mask_mult - # w2 is [self._d1, self._d2, d_model] - v = w2[idx1, idx2, :] - v = jnp.reshape(v, [batch_size, self._d1, -1]) - res = jnp.einsum('ai,aij->aj', relu, v) + b2 - else: - quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block) - mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - if self._multiply_by_controller_output: - relu = relu * mask - res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2 
- - return original_quant_mask, jnp.reshape(res, x_shape) - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - d_model = input_signature[-1].shape[-1] - shape_w1 = (d_model, self._d_ff) - shape_w2 = (self._d_ff, d_model) - shape_b2 = (d_model,) - - rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3) - if base.N_WEIGHTS_SHARDS > 1: - # In sharded-weights mode, put the weights on CPU on init - # as they will be sharded later. - w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1)) - w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2)) - else: - w1 = self._kernel_initializer(shape_w1, rng_w1) - w2 = self._kernel_initializer(shape_w2, rng_w2) - - b2 = self._bias_initializer(shape_b2, rng_b2) - if self._use_bfloat16: - b2 = b2.astype(jnp.bfloat16) - if self._use_bfloat16 or self._big_weights_in_bfloat16: - w1 = w1.astype(jnp.bfloat16) - w2 = w2.astype(jnp.bfloat16) - w1 = jnp.reshape(w1, (-1, self._d1, self._d2)) - w2 = jnp.reshape(w2, (self._d2, self._d1, -1)) - - if self._kernel_scaling: - # This keeps expected variance of the output regardless of N. 
- w2 = w2 * (self._n_elements_in_block ** 0.5) + if mode == "train" or multiply_by_controller_output: + also_return_nondiscrete_output = True + else: + also_return_nondiscrete_output = False + controller = _SparseFFController( + d_ff=d_ff, + n_elements_in_block=n_elements_in_block, + d_lowrank=d_lowrank, + temperature=temperature, + use_bfloat16=use_bfloat16, + mode=mode, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + also_return_nondiscrete_output=also_return_nondiscrete_output, + ) - self.weights = (w1, w2, b2) + main = [ + _SparseFFMain( + d_ff=d_ff, + n_elements_in_block=n_elements_in_block, + d_lowrank=d_lowrank, + quant_prob=quant_prob, + use_bfloat16=use_bfloat16, + big_weights_in_bfloat16=big_weights_in_bfloat16, + mode=mode, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + multiply_by_controller_output=multiply_by_controller_output, + kernel_scaling=kernel_scaling, + ), + # quant_mask, emb + tl.Select([1, 0]), + # emb, quant_mask + tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode), + tl.Select([1, 0]), + # quant_mask, emb + ] + # We will "remember" quant_mask _after_ chunking, and "recall" this same + # quant_mask during reverse_and_grad _before_ chunking. + remembering = _RememberInReverse(output=False) + recalling = _RecallQuantMaskInReverse( + remember_layer=remembering, elements=d_ff // n_elements_in_block + ) -def SparseFF( - d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3, - use_bfloat16=False, big_weights_in_bfloat16=False, mode='train', - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0, - multiply_by_controller_output=False, kernel_scaling=False): - """Returns Feed-forward block with sparsity. 
- - The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense - that takes an input, makes it of size d_ff (usually larger than it was) and - then brings it back to the original size after Relu. It is commonly used in - Transformer models where it often accounts for most of the trainable weights. - - The original block can be slow in decoding due to the need to fetch a lot of - weights from memory. This sparse block only allows one non-zero element - in a block of a specified size. This is trained with straight-through Gumbel - softmax trick. - - Args: - d_ff: Depth/dimensionality of FeedForward layer. - n_elements_in_block: The sparsity level. The layer is divided into blocks of - this size, and each block has only a single element active. - d_lowrank: The dimensionality of low-rank controller. - temperature: The temperature of the controller during training. - quant_prob: During training this proportion of blocks will have quantized - mask (i.e. a single element active). The rest will use a soft mask. - use_bfloat16: Whether to use bfloat16 for weights. - big_weights_in_bfloat16: : Whether to use bfloat16 for main weights of the - FeedForward layer. - mode: One of `'train'`, `'eval'`, or `'predict'`. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. - bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. - dropout_rate: Probability for dropping an activation value. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks. - multiply_by_controller_output: whether to multiply the middle activation - layer of FF by controller output (i.e. softmax). 
- kernel_scaling: Whether to scale the kernel matrix (during init) to keep the - variance of the layer output regardless of n_elements_in_block. - """ - - if mode == 'train' or multiply_by_controller_output: - also_return_nondiscrete_output = True - else: - also_return_nondiscrete_output = False - controller = _SparseFFController( - d_ff=d_ff, n_elements_in_block=n_elements_in_block, - d_lowrank=d_lowrank, temperature=temperature, - use_bfloat16=use_bfloat16, mode=mode, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - also_return_nondiscrete_output=also_return_nondiscrete_output) - - main = [ - _SparseFFMain( - d_ff=d_ff, n_elements_in_block=n_elements_in_block, - d_lowrank=d_lowrank, quant_prob=quant_prob, use_bfloat16=use_bfloat16, - big_weights_in_bfloat16=big_weights_in_bfloat16, mode=mode, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - multiply_by_controller_output=multiply_by_controller_output, - kernel_scaling=kernel_scaling), - # quant_mask, emb - tl.Select([1, 0]), - # emb, quant_mask - tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode), - tl.Select([1, 0]), - # quant_mask, emb - ] - - # We will "remember" quant_mask _after_ chunking, and "recall" this same - # quant_mask during reverse_and_grad _before_ chunking. 
- remembering = _RememberInReverse(output=False) - recalling = _RecallQuantMaskInReverse( - remember_layer=remembering, elements=d_ff//n_elements_in_block) - - return tl.BatchLeadingAxes(tl.Serial( - recalling, # emb, quant_mask - tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial( - # emb, quant_mask - tl.Select((0, 1, 0)), # emb, quant_mask, emb - controller, # quant_mask, mask, emb - main, # quant_mask, emb/output - )), - remembering, # emb/output - )) + return tl.BatchLeadingAxes( + tl.Serial( + recalling, # emb, quant_mask + tl.Chunk( + chunk_size=ff_chunk_size, + layer=tl.Serial( + # emb, quant_mask + tl.Select((0, 1, 0)), # emb, quant_mask, emb + controller, # quant_mask, mask, emb + main, # quant_mask, emb/output + ), + ), + remembering, # emb/output + ) + ) class BlockSparseFF(base.Layer): - """Feed-forward block with block sparsity. - - The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense - that takes an input, makes it of size d_ff (usually larger than it was) and - then brings it back to the original size after Relu. It is commonly used in - Transformer models where it often accounts for most of the trainable weights. - - This block sparse layer mimics mixture of experts architecture. - It divides the dimension of d_ff in each weight matrix to # of blocks equal to - n_experts and activates only one non-zero block from the weights matrix. - This is trained with straight-through Gumbel softmax trick. 
- """ - - def __init__(self, - d_ff, - n_experts=64, - temperature=0.7, - mode='train', - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6)): - """Returns a block sparse feed-forward block.""" - super().__init__(name=f'BlockSparseFF_{d_ff}') - self._mode = mode - self._d_ff = d_ff - self._n_experts = n_experts - self._temperature = temperature if mode == 'train' else 0.0 - self._n_elements_in_block = d_ff // n_experts - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - assert self._d_ff % self._n_experts == 0 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + """Feed-forward block with block sparsity. - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense + that takes an input, makes it of size d_ff (usually larger than it was) and + then brings it back to the original size after Relu. It is commonly used in + Transformer models where it often accounts for most of the trainable weights. - Returns: - Tensor of same shape and dtype as the input. + This block sparse layer mimics mixture of experts architecture. + It divides the dimension of d_ff in each weight matrix to # of blocks equal to + n_experts and activates only one non-zero block from the weights matrix. + This is trained with straight-through Gumbel softmax trick. """ - m1, w1, w2, b2 = self.weights - x_shape = x.shape - x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - # Q: check if we need bias and/or put relu after the m1 dot? - mask_logits = jnp.dot(x, m1) - # Softmax. - mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) - log_mask = mask_logits - mask_logsumexp - mask = jnp.exp(log_mask) - # Gumbel-softmax with straight-through discretization. 
- # TODO(lukaszkaiser, chowdhery): Extract this block and share - rng1, rng2 = fastmath.random.split(self.rng, 2) - u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6) - g = -jnp.log(-jnp.log(u)) - selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1) - if self._mode == 'train': - # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797 - quant_mask = tl.one_hot(selected_experts, self._n_experts) - quant_mask = fastmath.stop_gradient(quant_mask) - quant_mask += mask - fastmath.stop_gradient(mask) # straight-through - # We will sometimes (50% of the batches) use the soft-mask instead of - # the quantized mask to improve training stability (see the paper above). - # Q: is selecting 50% of batches the best? Other %? Mixed in-batch? - select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0) - quant_mask = jnp.where(select > 0.0, quant_mask, mask) - else: - quant_mask = tl.one_hot(selected_experts, self._n_experts) - quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1]) - batch_size = quant_mask.shape[0] - - if self._mode == 'predict' and batch_size == 1: - # This implementation mimicks inference for batch_size 1. 
- start_idx = selected_experts[0] * self._n_elements_in_block - # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block] - w = fastmath.dynamic_slice(w1, [0, start_idx], - [w1.shape[0], self._n_elements_in_block]) - mid = jnp.dot(x, w) - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model] - v = fastmath.dynamic_slice(w2, [start_idx, 0], - [self._n_elements_in_block, w2.shape[-1]]) - v = jnp.reshape(v, [self._n_elements_in_block, -1]) - res = jnp.dot(relu, v) + b2 - else: - expanded_mask = jnp.broadcast_to( - quant_mask, - (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block)) - expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff)) - mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff] - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - res = jnp.dot(relu, w2) + b2 - - return jnp.reshape(res, x_shape) # un-flatten if needed - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - d_model = input_signature.shape[-1] - shape_m1 = (d_model, self._n_experts) - shape_w1 = (d_model, self._d_ff) - shape_w2 = (self._d_ff, d_model) - shape_b2 = (d_model,) - rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4) - m1 = self._kernel_initializer(shape_m1, rng_m1) - w1 = self._kernel_initializer(shape_w1, rng_w1) - w2 = self._kernel_initializer(shape_w2, rng_w2) - b2 = self._bias_initializer(shape_b2, rng_b2) - - self.weights = (m1, w1, w2, b2) + def __init__( + self, + d_ff, + n_experts=64, + temperature=0.7, + mode="train", + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + ): + """Returns a block sparse feed-forward block.""" + super().__init__(name=f"BlockSparseFF_{d_ff}") + self._mode = mode + self._d_ff = d_ff + self._n_experts = n_experts + self._temperature = temperature if mode == "train" else 0.0 + self._n_elements_in_block = d_ff // n_experts + 
self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + assert self._d_ff % self._n_experts == 0 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. + """ + m1, w1, w2, b2 = self.weights + x_shape = x.shape + x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. + + # Q: check if we need bias and/or put relu after the m1 dot? + mask_logits = jnp.dot(x, m1) + # Softmax. + mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) + log_mask = mask_logits - mask_logsumexp + mask = jnp.exp(log_mask) + # Gumbel-softmax with straight-through discretization. + # TODO(lukaszkaiser, chowdhery): Extract this block and share + rng1, rng2 = fastmath.random.split(self.rng, 2) + u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6) + g = -jnp.log(-jnp.log(u)) + selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1) + if self._mode == "train": + # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797 + quant_mask = tl.one_hot(selected_experts, self._n_experts) + quant_mask = fastmath.stop_gradient(quant_mask) + quant_mask += mask - fastmath.stop_gradient(mask) # straight-through + # We will sometimes (50% of the batches) use the soft-mask instead of + # the quantized mask to improve training stability (see the paper above). + # Q: is selecting 50% of batches the best? Other %? Mixed in-batch? 
+ select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0) + quant_mask = jnp.where(select > 0.0, quant_mask, mask) + else: + quant_mask = tl.one_hot(selected_experts, self._n_experts) + quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1]) + batch_size = quant_mask.shape[0] + + if self._mode == "predict" and batch_size == 1: + # This implementation mimicks inference for batch_size 1. + start_idx = selected_experts[0] * self._n_elements_in_block + # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block] + w = fastmath.dynamic_slice( + w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block] + ) + mid = jnp.dot(x, w) + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model] + v = fastmath.dynamic_slice( + w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]] + ) + v = jnp.reshape(v, [self._n_elements_in_block, -1]) + res = jnp.dot(relu, v) + b2 + else: + expanded_mask = jnp.broadcast_to( + quant_mask, + (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block), + ) + expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff)) + mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff] + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + res = jnp.dot(relu, w2) + b2 + + return jnp.reshape(res, x_shape) # un-flatten if needed + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + d_model = input_signature.shape[-1] + shape_m1 = (d_model, self._n_experts) + shape_w1 = (d_model, self._d_ff) + shape_w2 = (self._d_ff, d_model) + shape_b2 = (d_model,) + + rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4) + m1 = self._kernel_initializer(shape_m1, rng_m1) + w1 = self._kernel_initializer(shape_w1, rng_w1) + w2 = self._kernel_initializer(shape_w2, rng_w2) + b2 = self._bias_initializer(shape_b2, rng_b2) + + self.weights = (m1, w1, w2, b2) class SwitchSparseFF(base.Layer): - """Feed-forward block 
with switch-style block sparsity. - - The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense - that takes an input, makes it of size d_ff (usually larger than it was) and - then brings it back to the original size after Relu. It is commonly used in - Transformer models where it often accounts for most of the trainable weights. - - This block sparse layer mimics mixture of experts architecture. - It divides the dimension of d_ff in each weight matrix to # of blocks equal to - n_experts and activates only one non-zero block from the weights matrix. - This is trained with methods following the Switch Transformer. - """ - - def __init__(self, - d_ff, - n_experts=64, - temperature=0.1, - mode='train', - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6)): - """Returns a switch-style training block sparse feed-forward block.""" - super().__init__(name=f'SwitchSparseFF_{d_ff}') - self._mode = mode - self._d_ff = d_ff - self._n_experts = n_experts - self._temperature = temperature if mode == 'train' else 0.0 - self._n_elements_in_block = d_ff // n_experts - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - assert self._d_ff % self._n_experts == 0 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + """Feed-forward block with switch-style block sparsity. - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense + that takes an input, makes it of size d_ff (usually larger than it was) and + then brings it back to the original size after Relu. It is commonly used in + Transformer models where it often accounts for most of the trainable weights. - Returns: - Tensor of same shape and dtype as the input. + This block sparse layer mimics mixture of experts architecture. 
+ It divides the dimension of d_ff in each weight matrix to # of blocks equal to + n_experts and activates only one non-zero block from the weights matrix. + This is trained with methods following the Switch Transformer. """ - m1, w1, w2, b2 = self.weights - x_shape = x.shape - x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - # Q: check if we need bias and/or put relu after the m1 dot? - mask_logits = jnp.dot(x, m1) - # Softmax. - mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) - log_mask = mask_logits - mask_logsumexp - mask = jnp.exp(log_mask) - # Gumbel noise to allow sampling from the softmax. - rng1, _ = fastmath.random.split(self.rng, 2) - u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6) - g = -jnp.log(-jnp.log(u)) - selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1) - quant_mask = tl.one_hot(selected_experts, self._n_experts) - quant_mask = fastmath.stop_gradient(quant_mask) - quant_mask *= mask # go to just the selected expert - quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1]) - batch_size = quant_mask.shape[0] - - if self._mode == 'predict' and batch_size == 1: - mask_flat = jnp.reshape(mask, [-1, self._n_experts]) - selected_flat = jnp.reshape(selected_experts, [-1]) - selected_mask_flat = mask_flat[np.arange(selected_flat.size), - selected_flat] - # This implementation mimicks inference for batch_size 1. 
- start_idx = selected_experts[0] * self._n_elements_in_block - # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block] - w = fastmath.dynamic_slice(w1, [0, start_idx], - [w1.shape[0], self._n_elements_in_block]) - mid = jnp.dot(x, w) - mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None] - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model] - v = fastmath.dynamic_slice(w2, [start_idx, 0], - [self._n_elements_in_block, w2.shape[-1]]) - v = jnp.reshape(v, [self._n_elements_in_block, -1]) - res = jnp.dot(relu, v) + b2 - else: - expanded_mask = jnp.broadcast_to( - quant_mask, - (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block)) - expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff)) - mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff] - relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) - res = jnp.dot(relu, w2) + b2 - - return jnp.reshape(res, x_shape) # un-flatten if needed - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - d_model = input_signature.shape[-1] - shape_m1 = (d_model, self._n_experts) - shape_w1 = (d_model, self._d_ff) - shape_w2 = (self._d_ff, d_model) - shape_b2 = (d_model,) - - rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4) - m1 = self._kernel_initializer(shape_m1, rng_m1) - w1 = self._kernel_initializer(shape_w1, rng_w1) - w2 = self._kernel_initializer(shape_w2, rng_w2) - b2 = self._bias_initializer(shape_b2, rng_b2) - - self.weights = (m1, w1, w2, b2) + + def __init__( + self, + d_ff, + n_experts=64, + temperature=0.1, + mode="train", + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + ): + """Returns a switch-style training block sparse feed-forward block.""" + super().__init__(name=f"SwitchSparseFF_{d_ff}") + self._mode = mode + self._d_ff = d_ff + self._n_experts = n_experts + 
self._temperature = temperature if mode == "train" else 0.0 + self._n_elements_in_block = d_ff // n_experts + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + assert self._d_ff % self._n_experts == 0 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. + """ + m1, w1, w2, b2 = self.weights + x_shape = x.shape + x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. + + # Q: check if we need bias and/or put relu after the m1 dot? + mask_logits = jnp.dot(x, m1) + # Softmax. + mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True) + log_mask = mask_logits - mask_logsumexp + mask = jnp.exp(log_mask) + # Gumbel noise to allow sampling from the softmax. + rng1, _ = fastmath.random.split(self.rng, 2) + u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6) + g = -jnp.log(-jnp.log(u)) + selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1) + quant_mask = tl.one_hot(selected_experts, self._n_experts) + quant_mask = fastmath.stop_gradient(quant_mask) + quant_mask *= mask # go to just the selected expert + quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1]) + batch_size = quant_mask.shape[0] + + if self._mode == "predict" and batch_size == 1: + mask_flat = jnp.reshape(mask, [-1, self._n_experts]) + selected_flat = jnp.reshape(selected_experts, [-1]) + selected_mask_flat = mask_flat[np.arange(selected_flat.size), selected_flat] + # This implementation mimicks inference for batch_size 1. 
+ start_idx = selected_experts[0] * self._n_elements_in_block + # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block] + w = fastmath.dynamic_slice( + w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block] + ) + mid = jnp.dot(x, w) + mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None] + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model] + v = fastmath.dynamic_slice( + w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]] + ) + v = jnp.reshape(v, [self._n_elements_in_block, -1]) + res = jnp.dot(relu, v) + b2 + else: + expanded_mask = jnp.broadcast_to( + quant_mask, + (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block), + ) + expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff)) + mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff] + relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid) + res = jnp.dot(relu, w2) + b2 + + return jnp.reshape(res, x_shape) # un-flatten if needed + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + d_model = input_signature.shape[-1] + shape_m1 = (d_model, self._n_experts) + shape_w1 = (d_model, self._d_ff) + shape_w2 = (self._d_ff, d_model) + shape_b2 = (d_model,) + + rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4) + m1 = self._kernel_initializer(shape_m1, rng_m1) + w1 = self._kernel_initializer(shape_w1, rng_w1) + w2 = self._kernel_initializer(shape_w2, rng_w2) + b2 = self._bias_initializer(shape_b2, rng_b2) + + self.weights = (m1, w1, w2, b2) diff --git a/trax/layers/research/sparsity_test.py b/trax/layers/research/sparsity_test.py deleted file mode 100644 index dd39091aa..000000000 --- a/trax/layers/research/sparsity_test.py +++ /dev/null @@ -1,466 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.research.efficient_attention.""" - -import functools -from absl.testing import parameterized -import jax -import numpy as np -from tensorflow import test - -from trax import fastmath -from trax import shapes -import trax.layers as tl -from trax.layers import test_utils -from trax.layers.research import sparsity - - -class EfficientFeedForwardTest(test.TestCase, parameterized.TestCase): - - def test_blocksparse_ff_train(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (3, 7, d_model) - with fastmath.use_backend(fastmath.Backend.JAX): - layer = sparsity.BlockSparseFF( - d_ff=d_ff, n_experts=n_experts, temperature=0.7, mode='train') - x = np.ones(x_shape).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_blocksparse_ff_predict_equals_eval(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (1, 1, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - common_kwargs = dict( - d_ff=d_ff, - n_experts=n_experts, - temperature=temperature, - ) - eval_model = sparsity.BlockSparseFF( - mode='eval', **common_kwargs) - weights, state = eval_model.init(input_signature) - eval_out, _ = eval_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - pred_model = 
sparsity.BlockSparseFF( - mode='predict', **common_kwargs) - _, _ = pred_model.init(input_signature) - pred_out, _ = pred_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(eval_out.shape, x.shape) - # eval_out and pred_out should be identical. - np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) - - def test_sparse_ff_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - d_model = 64 - seq_len = 6 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - sparsity.SparseFF, - d_ff=256, - temperature=0.7, - n_elements_in_block=8, - ) - - configs = [ - {'multiply_by_controller_output': True}, - {'multiply_by_controller_output': False}, - {'ff_chunk_size': 2}, - ] - - test_utils.test_eval_equals_predict_configs(inp, model_fn, configs) - - @parameterized.named_parameters(('_mode_train', 'train'), - ('_mode_eval', 'eval'), - ('_mode_predict', 'predict')) - def test_sparse_ff_with_chunking(self, mode): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode=mode) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - @parameterized.named_parameters(('_mode_train', 'train'), - ('_mode_eval', 'eval'), - ('_mode_predict', 'predict')) - def test_sparse_ff_multiply(self, mode): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = 
sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode=mode, - multiply_by_controller_output=True) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - def test_sparse_ff_kernel_scaling(self): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode='train', - kernel_scaling=True) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - def test_switchsparse_ff_train(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (3, 7, d_model) - layer = sparsity.SwitchSparseFF( - d_ff=d_ff, n_experts=n_experts, mode='train') - x = np.ones(x_shape).astype(np.float32) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_switchsparse_ff_predict_equals_eval(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (1, 1, d_model) - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - eval_model = sparsity.SwitchSparseFF( - mode='eval', d_ff=d_ff, n_experts=n_experts) - weights, state = eval_model.init(input_signature) - eval_out, _ = eval_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - pred_model = sparsity.SwitchSparseFF( - mode='predict', d_ff=d_ff, n_experts=n_experts) - pred_model.init(input_signature) - pred_out, _ = pred_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(eval_out.shape, x.shape) - # eval_out and pred_out 
should be identical. - np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) - - -class ReversibleReshapePermuteTest(test.TestCase): - - def test_reversible_permute(self): - layer = sparsity.ReversibleReshapePermute() - x = np.array([[1, 2, 3, 4, 5, 6, 7, 8], - [0, 1, 2, 3, 4, 5, 6, 7]]) - layer.init(shapes.signature(x)) - ys = layer(x) - self.assertEqual(tl.to_list(ys), [ - [1, 3, 5, 7, 2, 4, 6, 8], - [0, 2, 4, 6, 1, 3, 5, 7]]) - rev_x = layer.reverse(ys, weights=layer.weights) - self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) - - -class ReversibleRandomPermuteTest(test.TestCase): - - def test_reversible_permute(self): - layer = sparsity.ReversibleRandomPermute() - x = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 11, 12, 13], - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - ]) - layer.init(shapes.signature(x)) - ys = layer(x) - # this assert will fail once per ~87B runs, but it's okay - self.assertNotEqual(tl.to_list(ys), tl.to_list(x)) - - self.assertEqual(tl.to_list(ys[0]), tl.to_list(ys[2])) - self.assertNotEqual(tl.to_list(ys[0]), tl.to_list(ys[1])) - rev_x = layer.reverse(ys, weights=layer.weights) - self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) - - -class LocallyConnectedDenseTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.LocallyConnectedDense(2, 8) - x = np.array([[2, 5, 3, 4], - [0, 1, 2, 3]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (2, 16)) - - -class SparseDenseWithOptionsTest(test.TestCase): - - def test_simple_call(self): - d_input, d_output = 16, 32 - settings = [ - (None, 0, 0, False), - (None, 0, 0, True), - ('einsum', 0, 0, False), - ('lowrank', 0, 8, False), - ('mult', 2, 0, False), - ('mult', 2, 0, True), - ('local', 2, 0, False), - ('local3', 2, 0, False), - ] - for stype, sparsity_level, d_lowrank, use_bfloat16 in settings: - layer = sparsity.SparseDenseWithOptions( - d_output, d_input=d_input, 
sparsity_type=stype, - sparsity=sparsity_level, d_lowrank=d_lowrank, - use_bfloat16=use_bfloat16) - x = np.ones((1, 1, d_input)) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 1, d_output), - msg='[{}->{}] {} - {} - {} - {}'.format( - d_input, d_output, stype, sparsity_level, d_lowrank, - use_bfloat16)) - - -class ModularCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.ModularCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class LowRankCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.LowRankCausalAttention( - d_feature=4, n_heads=2, lowrank=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeModularCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeModularCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeConvCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeConvCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = 
layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_various_calls(self): - list_kwargs = [] - for share_qk in [True, False]: - for output in ['none', 'mult', 'conv', 'multconv']: - for concat in ['original', 'fixed', 'none']: - kwargs = {'share_qk': share_qk, 'output_layer_type': output, - 'v_concat_type': concat} - list_kwargs.append(kwargs) - for kwargs in list_kwargs: - layer = sparsity.MultiplicativeConvCausalAttention( - d_feature=4, n_heads=2, sparsity=2, **kwargs) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - d_model = 32 - seq_len = 5 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - sparsity.MultiplicativeConvCausalAttention, - d_feature=d_model, - n_heads=4, - sparsity=4, - ) - - list_kwargs = [] - for share_qk in [True, False]: - for output in ['none', 'mult', 'conv', 'multconv']: - for concat in ['original', 'fixed', 'none']: - kwargs = {'share_qk': share_qk, 'output_layer_type': output, - 'v_concat_type': concat} - list_kwargs.append(kwargs) - - test_utils.test_eval_equals_predict_configs(inp, model_fn, list_kwargs) - - -class FavorTest(test.TestCase): - - def test_call_and_grad(self): - layer_partial = tl.Serial( - tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), - sparsity.Favor(d_feature=4, n_heads=2), - tl.Select([0], n_in=2), - ) - layer = tl.Serial( - tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), - sparsity.Favor(d_feature=4, n_heads=2), - tl.Select([0], n_in=2), - tl.WeightedCategoryCrossEntropy(), - ) - x = np.ones((1, 2), dtype=np.int32) - w = np.ones_like(x).astype(np.float32) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer_partial.init(x_sig) - y = layer_partial(x) - self.assertEqual(y.shape, (1, 2, 
4)) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][1][0].shape, (3, 4)) - - def test_call_and_grad_approximate_softmax(self): - layer_partial = tl.Serial( - tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), - sparsity.Favor(d_feature=12, n_heads=3, n_random_features=128, - use_approximate_softmax=True), - tl.Select([0], n_in=2), - ) - layer = tl.Serial( - tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), - sparsity.Favor(d_feature=12, n_heads=3, n_random_features=128, - use_approximate_softmax=True), - tl.Select([0], n_in=2), - tl.WeightedCategoryCrossEntropy(), - ) - x = np.ones((3, 5), dtype=np.int32) - w = np.ones_like(x).astype(np.float32) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer_partial.init(x_sig) - y = layer_partial(x) - self.assertEqual(y.shape, (3, 5, 12)) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][1][0].shape, (11, 12)) - - def test_causal_call_and_grad(self): - layer = tl.Serial( - tl.Dense(4), - sparsity.CausalFavor(d_feature=4, n_heads=2), - tl.L2Loss() - ) - x = np.random.uniform(size=(1, 2, 4)).astype(np.float32) - w = np.ones_like(x) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][0].shape, (4, 4)) - - -if __name__ 
== '__main__': - test.main() diff --git a/trax/layers/reversible.py b/trax/layers/reversible.py index 438e2bf52..2485668f4 100644 --- a/trax/layers/reversible.py +++ b/trax/layers/reversible.py @@ -36,440 +36,489 @@ class ReversibleLayer(base.Layer): - """Reversible Layer.""" - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - """Reverse this layer: compute input given output.""" - raise NotImplementedError - - def _pure_forward(self, x, weights, state, rng): - """Call self.forward in a pure way.""" - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(x) - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return res - - def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(), - rng=None): - """Backward pass: computes the inverse of a layer and propagates gradients. - - While you may choose to only implement reverse, some layers implement this - function directly as computation may be shared between reversing and - computing gradients. - - Args: - output: Output activations; can be a (possibly nested) tuple. - grad: gradient signal (cotangent) computed based on subsequent layers. - The structure and shape must match the output. - weights: layer weights - state: start state - new_state: updated state computed by the forward pass - rng: Single-use random number generator (JAX PRNG key). - - Returns: - A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, - x_grad is the gradient signal for the input, and weights_grad is the - gradient signal for the weights. 
- """ - reconstructed_x = self.reverse(output, weights, state, new_state, rng) - _, vjpfun = fastmath.vjp( - self._pure_forward, reconstructed_x, weights, state, rng) - x_grad, weights_grad, _, _ = vjpfun(grad) - return reconstructed_x, (x_grad, weights_grad) - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - del inputs - _, inputs_weights_grad = ( - self.reverse_and_grad(output, grad, weights, state, new_state, rng)) - return inputs_weights_grad + """Reversible Layer.""" + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + """Reverse this layer: compute input given output.""" + raise NotImplementedError + + def _pure_forward(self, x, weights, state, rng): + """Call self.forward in a pure way.""" + old_weights, old_state, old_rng = self.weights, self.state, self._rng + self.weights, self.state, self._rng = weights, state, rng + res = self.forward(x) + self.weights, self.state, self._rng = old_weights, old_state, old_rng + return res + + def reverse_and_grad( + self, output, grad, weights=(), state=(), new_state=(), rng=None + ): + """Backward pass: computes the inverse of a layer and propagates gradients. + + While you may choose to only implement reverse, some layers implement this + function directly as computation may be shared between reversing and + computing gradients. + + Args: + output: Output activations; can be a (possibly nested) tuple. + grad: gradient signal (cotangent) computed based on subsequent layers. + The structure and shape must match the output. + weights: layer weights + state: start state + new_state: updated state computed by the forward pass + rng: Single-use random number generator (JAX PRNG key). + + Returns: + A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, + x_grad is the gradient signal for the input, and weights_grad is the + gradient signal for the weights. 
+ """ + reconstructed_x = self.reverse(output, weights, state, new_state, rng) + _, vjpfun = fastmath.vjp( + self._pure_forward, reconstructed_x, weights, state, rng + ) + x_grad, weights_grad, _, _ = vjpfun(grad) + return reconstructed_x, (x_grad, weights_grad) + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + del inputs + _, inputs_weights_grad = self.reverse_and_grad( + output, grad, weights, state, new_state, rng + ) + return inputs_weights_grad class ReversibleConcatenatePair(ReversibleLayer): - """Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on last axis.""" + """Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on last axis.""" - def __init__(self): - super().__init__(n_in=2, n_out=2) + def __init__(self): + super().__init__(n_in=2, n_out=2) - def forward(self, inputs): - x, y = inputs - r = fastmath.numpy.concatenate((x, y), axis=-1) - return r, r + def forward(self, inputs): + x, y = inputs + r = fastmath.numpy.concatenate((x, y), axis=-1) + return r, r - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - pair, _ = outputs - x, y = fastmath.numpy.split(pair, 2, axis=-1) - return x, y + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + pair, _ = outputs + x, y = fastmath.numpy.split(pair, 2, axis=-1) + return x, y class ReversibleSelect(ReversibleLayer): - """Reversible version of the Select combinator.""" - - def __init__(self, indices, n_in=None, name=None): - if n_in is None: - n_in = max(indices) + 1 - if name is None: - name = f'ReversibleSelect{indices}'.replace(' ', '') - super().__init__(n_in=n_in, n_out=len(indices), name=name) - self._indices = indices - - # Calculate reverse indices. 
- self._reverse_indices = [] - for i in range(n_in): - if i not in indices: - raise ValueError('To be reversible, all inputs to Select must be in ' - 'indices. Did not find %d in indices.' % i) - else: - self._reverse_indices.append(indices.index(i)) - - def forward(self, inputs): - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - selected = tuple(inputs[i] for i in self._indices) - return selected[0] if len(selected) == 1 else selected - - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - if not isinstance(outputs, (tuple, list)): - outputs = (outputs,) - selected = tuple(outputs[i] for i in self._reverse_indices) - return selected[0] if len(selected) == 1 else selected + """Reversible version of the Select combinator.""" + + def __init__(self, indices, n_in=None, name=None): + if n_in is None: + n_in = max(indices) + 1 + if name is None: + name = f"ReversibleSelect{indices}".replace(" ", "") + super().__init__(n_in=n_in, n_out=len(indices), name=name) + self._indices = indices + + # Calculate reverse indices. + self._reverse_indices = [] + for i in range(n_in): + if i not in indices: + raise ValueError( + "To be reversible, all inputs to Select must be in " + "indices. Did not find %d in indices." 
% i + ) + else: + self._reverse_indices.append(indices.index(i)) + + def forward(self, inputs): + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + selected = tuple(inputs[i] for i in self._indices) + return selected[0] if len(selected) == 1 else selected + + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + if not isinstance(outputs, (tuple, list)): + outputs = (outputs,) + selected = tuple(outputs[i] for i in self._reverse_indices) + return selected[0] if len(selected) == 1 else selected def ReversibleSwap(): # pylint: disable=invalid-name - return ReversibleSelect([1, 0], name='ReversibleSwap') + return ReversibleSelect([1, 0], name="ReversibleSwap") class ReversibleReshape(ReversibleLayer): - """Reversible reshaping layer.""" - - def __init__(self, shape1, shape2, n_in=1): - self._shape1 = list(shape1) - self._shape2 = list(shape2) - name = 'ReversibleReshape_%s_%s' % (str(shape1), str(shape2)) - super().__init__(n_in=n_in, n_out=n_in, name=name) - - def forward(self, inputs): - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - res = [] - for x in inputs: - new_shape = self._shape2 + list(x.shape)[len(self._shape1):] - res.append(fastmath.numpy.reshape(x, new_shape)) - if len(res) == 1: - return res[0] - return tuple(res) - - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - if not isinstance(outputs, (tuple, list)): - outputs = (outputs,) - res = [] - for x in outputs: - new_shape = self._shape1 + list(x.shape)[len(self._shape2):] - res.append(fastmath.numpy.reshape(x, new_shape)) - if len(res) == 1: - return res[0] - return tuple(res) + """Reversible reshaping layer.""" + + def __init__(self, shape1, shape2, n_in=1): + self._shape1 = list(shape1) + self._shape2 = list(shape2) + name = "ReversibleReshape_%s_%s" % (str(shape1), str(shape2)) + super().__init__(n_in=n_in, n_out=n_in, name=name) + + def 
forward(self, inputs): + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + res = [] + for x in inputs: + new_shape = self._shape2 + list(x.shape)[len(self._shape1) :] + res.append(fastmath.numpy.reshape(x, new_shape)) + if len(res) == 1: + return res[0] + return tuple(res) + + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + if not isinstance(outputs, (tuple, list)): + outputs = (outputs,) + res = [] + for x in outputs: + new_shape = self._shape1 + list(x.shape)[len(self._shape2) :] + res.append(fastmath.numpy.reshape(x, new_shape)) + if len(res) == 1: + return res[0] + return tuple(res) class ReversiblePrintShape(ReversibleLayer): - """Reversible PrintShape for debugging reversible serial layers.""" + """Reversible PrintShape for debugging reversible serial layers.""" - def __init__(self, n_in=1, msg=''): - super().__init__(n_in=n_in, n_out=n_in) - self._msg = msg + def __init__(self, n_in=1, msg=""): + super().__init__(n_in=n_in, n_out=n_in) + self._msg = msg - def forward(self, xs): - shapes_and_dtypes = ', '.join([str(x.shape) + f'[{x.dtype}]' for x in xs]) - info = f'PrintShape: {self._msg}: [{shapes_and_dtypes}]' - print(info) - logging.info(info) - return xs + def forward(self, xs): + shapes_and_dtypes = ", ".join([str(x.shape) + f"[{x.dtype}]" for x in xs]) + info = f"PrintShape: {self._msg}: [{shapes_and_dtypes}]" + print(info) + logging.info(info) + return xs - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - return outputs + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + return outputs class ReversibleSerial(ReversibleLayer, cb.Serial): - """A reversible version of tl.Serial (requires reversible sub-layers).""" - - def __init__(self, *layers): - super().__init__(*layers) - # def __init__(self, *layers): # pylint: disable=super-init-not-called - # 
cb.Serial.__init__(self, layers) - - # Note that sublayers has already been flattened to remove nested lists. - for i, layer in enumerate(self.sublayers): - if not isinstance(layer, ReversibleLayer): - raise ValueError( - 'Sub-layer {} of ReversibleSerial is not reversible: {}'.format( - i, layer)) - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - rngs = (None,) * self._n_layers - if rng is not None: - rngs = fastmath.random.split(rng, self._n_layers) - - stack = output - for layer, p, s, ns, rng in reversed(list(zip( - self.sublayers, weights, state, new_state, rngs))): - layer_val = cb.inputs_from_stack(stack, layer.n_out) - layer_val = layer.reverse(layer_val, p, s, ns, rng=rng) - stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) - - return stack - - def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(), - rng=None): - rngs = (None,) * self._n_layers - if rng is not None: - rngs = fastmath.random.split(rng, self._n_layers) - - stack = output - stack_grad = grad - weights_grad = [] - for layer, p, s, ns, rng in reversed(list(zip( - self.sublayers, weights, state, new_state, rngs))): - layer_val = cb.inputs_from_stack(stack, layer.n_out) - layer_ct = cb.inputs_from_stack(stack_grad, layer.n_out) - layer_val, layer_ct = layer.reverse_and_grad( - layer_val, layer_ct, p, s, ns, rng=rng) - layer_ct, p_ct = layer_ct - weights_grad.insert(0, p_ct) - stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) - stack_grad = cb.outputs_onto_stack(layer_ct, stack_grad, layer.n_out) - - return stack, (stack_grad, tuple(weights_grad)) + """A reversible version of tl.Serial (requires reversible sub-layers).""" + + def __init__(self, *layers): + super().__init__(*layers) + # def __init__(self, *layers): # pylint: disable=super-init-not-called + # cb.Serial.__init__(self, layers) + + # Note that sublayers has already been flattened to remove nested lists. 
+ for i, layer in enumerate(self.sublayers): + if not isinstance(layer, ReversibleLayer): + raise ValueError( + "Sub-layer {} of ReversibleSerial is not reversible: {}".format( + i, layer + ) + ) + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + rngs = (None,) * self._n_layers + if rng is not None: + rngs = fastmath.random.split(rng, self._n_layers) + + stack = output + for layer, p, s, ns, rng in reversed( + list(zip(self.sublayers, weights, state, new_state, rngs)) + ): + layer_val = cb.inputs_from_stack(stack, layer.n_out) + layer_val = layer.reverse(layer_val, p, s, ns, rng=rng) + stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) + + return stack + + def reverse_and_grad( + self, output, grad, weights=(), state=(), new_state=(), rng=None + ): + rngs = (None,) * self._n_layers + if rng is not None: + rngs = fastmath.random.split(rng, self._n_layers) + + stack = output + stack_grad = grad + weights_grad = [] + for layer, p, s, ns, rng in reversed( + list(zip(self.sublayers, weights, state, new_state, rngs)) + ): + layer_val = cb.inputs_from_stack(stack, layer.n_out) + layer_ct = cb.inputs_from_stack(stack_grad, layer.n_out) + layer_val, layer_ct = layer.reverse_and_grad( + layer_val, layer_ct, p, s, ns, rng=rng + ) + layer_ct, p_ct = layer_ct + weights_grad.insert(0, p_ct) + stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) + stack_grad = cb.outputs_onto_stack(layer_ct, stack_grad, layer.n_out) + + return stack, (stack_grad, tuple(weights_grad)) class ReversibleHalfResidual(ReversibleLayer): - """Half of a RevNet-style residual that optionally performs attention. - - When attention_layer is None, this layer has the signature :: - - [accumulator, *context] -> [accumulator + f(context), *context] - - The attention_layer must be an instance of EfficientAttentionBase or one of - its subclasses (see efficient_attention.py), or None. 
- - Attention is special-cased for the following two reasons: - - - LSH attention needs to save bucket assignments from the forward pass to the - backward pass, for training stability. This requires special-casing it. - - We can call attention_layer.forward_and_or_backward to compute its output - (needed for inverting a reversible residual layer) while simultaneously - performing the backward pass. Sharing computation between these two - operations improves training speed. - """ - - def __init__(self, *residual_layers, attention_layer=None, name=None): - super().__init__(name=name) - - self._compute_residual = cb.Serial(*residual_layers) - self._attention_layer = attention_layer - - if self._attention_layer is None: - self._sublayers = (self._compute_residual,) - else: - if hasattr(attention_layer, 'forward_and_or_backward'): - self._forward_and_or_backward = attention_layer.forward_and_or_backward - else: - self._forward_and_or_backward = _forward_and_or_backward( - attention_layer) - self._sublayers = (self._compute_residual, self._attention_layer) - - running_max = 0 - running_total = 0 - for layer in self._sublayers: - running_total += layer.n_in - running_max = max(running_max, running_total) - running_total -= layer.n_out - self._n_in = self._n_out = running_max + 1 - - def forward(self, xs): - rngs = _split_rngs(self.rng, len(self.sublayers)) - accumulator, *context = xs - stack = context = tuple(context) - new_state = [] - for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs): - inputs = cb.inputs_from_stack(stack, layer.n_in) - if base.N_WEIGHTS_SHARDS > 1: - # With sharded weights, make sure we don't keep them concatenated - # in memory on each device by using remat. 
- outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng) - else: - outputs, s = layer.pure_fn(inputs, w, s, rng) - stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) - new_state.append(s) - residual = stack[0] if isinstance(stack, (tuple, list)) else stack - - output = accumulator + residual - stack = (output,) + context - self.state = tuple(new_state) - return stack - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - raise NotImplementedError('Only reverse_and_grad is actually used.') - - def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(), - rng=None): - rngs = _split_rngs(rng, len(self.sublayers)) - - accumulator_output, *context = output - context = tuple(context) - accumulator_output_ct, *context_ct = ct - context_ct = tuple(context_ct) - - # Forward pass through self._compute_residual. Outputs that will not receive - # a gradient signal from subsequent layers are moved to aux. - def call_compute_residual(x, weights): - state_to_pass = state[0] # old_state - - # _replace_second_time is currently used exclusively in _RememberInReverse - # layer to combat numerical instability in Terraformer when quantizing - # the mask in SparseFF. - def _replace_second_time(stt, nstt): - if (isinstance(stt, tuple) and len(stt) == 2 and - isinstance(stt[1], dict) and 'running_second_time' in stt[1]): - return (nstt[0], {'running_second_time_yes': ()}) - elif isinstance(stt, (tuple, list)): - assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt) - return type(stt)([ - _replace_second_time(s, ns) for s, ns in zip(stt, nstt)]) + """Half of a RevNet-style residual that optionally performs attention. + + When attention_layer is None, this layer has the signature :: + + [accumulator, *context] -> [accumulator + f(context), *context] + + The attention_layer must be an instance of EfficientAttentionBase or one of + its subclasses (see efficient_attention.py), or None. 
+ + Attention is special-cased for the following two reasons: + + - LSH attention needs to save bucket assignments from the forward pass to the + backward pass, for training stability. This requires special-casing it. + - We can call attention_layer.forward_and_or_backward to compute its output + (needed for inverting a reversible residual layer) while simultaneously + performing the backward pass. Sharing computation between these two + operations improves training speed. + """ + + def __init__(self, *residual_layers, attention_layer=None, name=None): + super().__init__(name=name) + + self._compute_residual = cb.Serial(*residual_layers) + self._attention_layer = attention_layer + + if self._attention_layer is None: + self._sublayers = (self._compute_residual,) + else: + if hasattr(attention_layer, "forward_and_or_backward"): + self._forward_and_or_backward = attention_layer.forward_and_or_backward + else: + self._forward_and_or_backward = _forward_and_or_backward( + attention_layer + ) + self._sublayers = (self._compute_residual, self._attention_layer) + + running_max = 0 + running_total = 0 + for layer in self._sublayers: + running_total += layer.n_in + running_max = max(running_max, running_total) + running_total -= layer.n_out + self._n_in = self._n_out = running_max + 1 + + def forward(self, xs): + rngs = _split_rngs(self.rng, len(self.sublayers)) + accumulator, *context = xs + stack = context = tuple(context) + new_state = [] + for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs): + inputs = cb.inputs_from_stack(stack, layer.n_in) + if base.N_WEIGHTS_SHARDS > 1: + # With sharded weights, make sure we don't keep them concatenated + # in memory on each device by using remat. 
+ outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng) + else: + outputs, s = layer.pure_fn(inputs, w, s, rng) + stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) + new_state.append(s) + residual = stack[0] if isinstance(stack, (tuple, list)) else stack + + output = accumulator + residual + stack = (output,) + context + self.state = tuple(new_state) + return stack + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + raise NotImplementedError("Only reverse_and_grad is actually used.") + + def reverse_and_grad( + self, output, ct, weights=(), state=(), new_state=(), rng=None + ): + rngs = _split_rngs(rng, len(self.sublayers)) + + accumulator_output, *context = output + context = tuple(context) + accumulator_output_ct, *context_ct = ct + context_ct = tuple(context_ct) + + # Forward pass through self._compute_residual. Outputs that will not receive + # a gradient signal from subsequent layers are moved to aux. + def call_compute_residual(x, weights): + state_to_pass = state[0] # old_state + + # _replace_second_time is currently used exclusively in _RememberInReverse + # layer to combat numerical instability in Terraformer when quantizing + # the mask in SparseFF. 
+ def _replace_second_time(stt, nstt): + if ( + isinstance(stt, tuple) + and len(stt) == 2 + and isinstance(stt[1], dict) + and "running_second_time" in stt[1] + ): + return (nstt[0], {"running_second_time_yes": ()}) + elif isinstance(stt, (tuple, list)): + assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt) + return type(stt)( + [_replace_second_time(s, ns) for s, ns in zip(stt, nstt)] + ) + else: + return stt + + state_to_pass = _replace_second_time(state_to_pass, new_state[0]) + res, _ = self._compute_residual.pure_fn( + x, weights=weights, state=state_to_pass, rng=rngs[0] + ) + if not isinstance(res, (tuple, list)): + return res, None + else: + n_differentiable = 1 + if self._attention_layer is not None: + n_differentiable = min(len(res), self._attention_layer.n_in) + return res[:n_differentiable], res[n_differentiable:] + + stack = context + inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) + outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp( + call_compute_residual, inputs, weights[0], has_aux=True + ) + if outputs_aux is not None: + n_differentiable_outputs = len(outputs) + outputs = outputs + outputs_aux + stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) + + stack_ct = accumulator_output_ct + if self._attention_layer is None: + residual = stack[0] if isinstance(stack, (tuple, list)) else stack + else: + inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) + ( + residual, + _, + attn_inputs_ct, + attn_weights_ct, + ) = self._forward_and_or_backward( + inputs, + weights[1], + new_state[1], + rngs[1], + output_grad=accumulator_output_ct, + compute_output=True, + update_state=False, + ) + stack_ct = cb.outputs_onto_stack( + attn_inputs_ct, stack_ct, self._attention_layer.n_out + ) + + compute_residual_ct = cb.inputs_from_stack( + stack_ct, self._compute_residual.n_out + ) + if outputs_aux is not None: + if not isinstance(compute_residual_ct, (tuple, list)): + compute_residual_ct = 
(compute_residual_ct,) + compute_residual_ct = compute_residual_ct[:n_differentiable_outputs] + assert len(compute_residual_ct) == n_differentiable_outputs + ( + compute_residual_inputs_ct, + compute_residual_weights_ct, + ) = compute_residual_vjpfun(compute_residual_ct) + stack_ct = cb.outputs_onto_stack( + compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out + ) + if not isinstance(stack_ct, (tuple, list)): + stack_ct = (stack_ct,) + + def _add(x, y): + # `None` is for TFNP backend, which uses `None` as the gradient of + # int/bool instead of an array of dtype `float0`. + if x is None or x.dtype == jax.float0: + return y + if y is None or y.dtype == jax.float0: + return x + return x + y + + stack_ct = ( + (accumulator_output_ct,) + + fastmath.nested_map_multiarg(_add, context_ct[: len(stack_ct)], stack_ct) + + context_ct[len(stack_ct) :] + ) + + reconstructed_x = accumulator_output - residual + stack = (reconstructed_x,) + context + if self._attention_layer is None: + weights_ct = (compute_residual_weights_ct,) + else: + weights_ct = (compute_residual_weights_ct, attn_weights_ct) + return stack, (stack_ct, weights_ct) + + # pylint: disable=protected-access + def init_weights_and_state(self, input_signature): + stack = input_signature[1:] + if len(stack) == 1: + stack = stack[0] + + inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) + weights, state = self._compute_residual.init(inputs) + outputs, _ = self._compute_residual._forward_abstract(inputs) + stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) + + if self._attention_layer is None: + self.state = (state,) + self.weights = (weights,) else: - return stt - - state_to_pass = _replace_second_time(state_to_pass, new_state[0]) - res, _ = self._compute_residual.pure_fn( - x, weights=weights, state=state_to_pass, rng=rngs[0]) - if not isinstance(res, (tuple, list)): - return res, None - else: - n_differentiable = 1 - if self._attention_layer is not None: - 
n_differentiable = min(len(res), self._attention_layer.n_in) - return res[:n_differentiable], res[n_differentiable:] - - stack = context - inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) - outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp( - call_compute_residual, inputs, weights[0], has_aux=True) - if outputs_aux is not None: - n_differentiable_outputs = len(outputs) - outputs = outputs + outputs_aux - stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) - - stack_ct = accumulator_output_ct - if self._attention_layer is None: - residual = stack[0] if isinstance(stack, (tuple, list)) else stack - else: - inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) - (residual, _, attn_inputs_ct, attn_weights_ct - ) = self._forward_and_or_backward( - inputs, weights[1], new_state[1], rngs[1], - output_grad=accumulator_output_ct, - compute_output=True, update_state=False) - stack_ct = cb.outputs_onto_stack( - attn_inputs_ct, stack_ct, self._attention_layer.n_out) - - compute_residual_ct = cb.inputs_from_stack( - stack_ct, self._compute_residual.n_out) - if outputs_aux is not None: - if not isinstance(compute_residual_ct, (tuple, list)): - compute_residual_ct = (compute_residual_ct,) - compute_residual_ct = compute_residual_ct[:n_differentiable_outputs] - assert len(compute_residual_ct) == n_differentiable_outputs - (compute_residual_inputs_ct, compute_residual_weights_ct - ) = compute_residual_vjpfun(compute_residual_ct) - stack_ct = cb.outputs_onto_stack( - compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out) - if not isinstance(stack_ct, (tuple, list)): - stack_ct = (stack_ct,) - def _add(x, y): - # `None` is for TFNP backend, which uses `None` as the gradient of - # int/bool instead of an array of dtype `float0`. 
- if x is None or x.dtype == jax.float0: - return y - if y is None or y.dtype == jax.float0: - return x - return x + y - stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg( - _add, context_ct[:len(stack_ct)], stack_ct) + context_ct[len(stack_ct):] - - reconstructed_x = accumulator_output - residual - stack = (reconstructed_x,) + context - if self._attention_layer is None: - weights_ct = (compute_residual_weights_ct,) - else: - weights_ct = (compute_residual_weights_ct, attn_weights_ct) - return stack, (stack_ct, weights_ct) - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - stack = input_signature[1:] - if len(stack) == 1: - stack = stack[0] - - inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) - weights, state = self._compute_residual.init(inputs) - outputs, _ = self._compute_residual._forward_abstract(inputs) - stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) - - if self._attention_layer is None: - self.state = (state,) - self.weights = (weights,) - else: - inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) - attn_weights, attn_state = self._attention_layer.init(inputs) - self.state = (state, attn_state) - self.weights = (weights, attn_weights) - # pylint: enable=protected-access + inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) + attn_weights, attn_state = self._attention_layer.init(inputs) + self.state = (state, attn_state) + self.weights = (weights, attn_weights) + + # pylint: enable=protected-access def _forward_and_or_backward(layer): - """Create forward_and_or_backward for layers that don't define it.""" - - def forward_and_or_backward(inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. 
- - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. - - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # Calculate the vector-Jacobian product of the layer pure_fn. - output, vjp_fn, new_state = fastmath.vjp( - layer.pure_fn, inputs, weights, state, rng, has_aux=True) - output = output if compute_output else None - new_state = new_state if update_state else None - - # The vjp function returns gradients with respect to inputs and weights. - if output_grad is not None: - grads_inputs, grads_weights, _, _ = vjp_fn(output_grad) - else: - grads_inputs, grads_weights = None, None - - return (output, new_state, grads_inputs, grads_weights) - return forward_and_or_backward + """Create forward_and_or_backward for layers that don't define it.""" + + def forward_and_or_backward( + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. 
+ This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # Calculate the vector-Jacobian product of the layer pure_fn. + output, vjp_fn, new_state = fastmath.vjp( + layer.pure_fn, inputs, weights, state, rng, has_aux=True + ) + output = output if compute_output else None + new_state = new_state if update_state else None + + # The vjp function returns gradients with respect to inputs and weights. + if output_grad is not None: + grads_inputs, grads_weights, _, _ = vjp_fn(output_grad) + else: + grads_inputs, grads_weights = None, None + + return (output, new_state, grads_inputs, grads_weights) + + return forward_and_or_backward diff --git a/trax/layers/rnn.py b/trax/layers/rnn.py index 3d80cfdda..3fd56eb10 100644 --- a/trax/layers/rnn.py +++ b/trax/layers/rnn.py @@ -26,203 +26,230 @@ class LSTMCell(base.Layer): - """LSTM Cell. 
- - For a nice overview of the motivation and (i, o, f) gates, see this tutorial: - https://colah.github.io/posts/2015-08-Understanding-LSTMs/ - - See this paper for a description and detailed study of all gate types: - https://arxiv.org/pdf/1503.04069.pdf - """ - - def __init__(self, - n_units, - forget_bias=1.0, - kernel_initializer=initializers.GlorotUniformInitializer(), - bias_initializer=initializers.RandomNormalInitializer(1e-6)): - super().__init__(n_in=2, n_out=2) - self._n_units = n_units - self._forget_bias = forget_bias - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - - def forward(self, inputs): - x, lstm_state = inputs - - # LSTM state consists of c and h. - c, h = jnp.split(lstm_state, 2, axis=-1) - - # Dense layer on the concatenation of x and h. - w, b = self.weights - y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = jnp.split(y, 4, axis=-1) - - new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j) - new_h = jnp.tanh(new_c) * fastmath.sigmoid(o) - return new_h, jnp.concatenate([new_c, new_h], axis=-1) - - def init_weights_and_state(self, input_signature): - # LSTM state last dimension must be twice n_units. - if input_signature[1].shape[-1] != 2 * self._n_units: - raise ValueError( - f'Last dimension of state (shape: {str(input_signature[1].shape)}) ' - f'must be equal to 2*n_units ({2 * self._n_units})') - # The dense layer input is the input and half of the lstm state. - input_shape = input_signature[0].shape[-1] + self._n_units - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer((input_shape, 4 * self._n_units), rng1) - b = self._bias_initializer((4 * self._n_units,), rng2) + self._forget_bias - self.weights = (w, b) + """LSTM Cell. 
+ + For a nice overview of the motivation and (i, o, f) gates, see this tutorial: + https://colah.github.io/posts/2015-08-Understanding-LSTMs/ + + See this paper for a description and detailed study of all gate types: + https://arxiv.org/pdf/1503.04069.pdf + """ + + def __init__( + self, + n_units, + forget_bias=1.0, + kernel_initializer=initializers.GlorotUniformInitializer(), + bias_initializer=initializers.RandomNormalInitializer(1e-6), + ): + super().__init__(n_in=2, n_out=2) + self._n_units = n_units + self._forget_bias = forget_bias + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + def forward(self, inputs): + x, lstm_state = inputs + + # LSTM state consists of c and h. + c, h = jnp.split(lstm_state, 2, axis=-1) + + # Dense layer on the concatenation of x and h. + w, b = self.weights + y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = jnp.split(y, 4, axis=-1) + + new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j) + new_h = jnp.tanh(new_c) * fastmath.sigmoid(o) + return new_h, jnp.concatenate([new_c, new_h], axis=-1) + + def init_weights_and_state(self, input_signature): + # LSTM state last dimension must be twice n_units. + if input_signature[1].shape[-1] != 2 * self._n_units: + raise ValueError( + f"Last dimension of state (shape: {str(input_signature[1].shape)}) " + f"must be equal to 2*n_units ({2 * self._n_units})" + ) + # The dense layer input is the input and half of the lstm state. 
+ input_shape = input_signature[0].shape[-1] + self._n_units + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer((input_shape, 4 * self._n_units), rng1) + b = self._bias_initializer((4 * self._n_units,), rng2) + self._forget_bias + self.weights = (w, b) def MakeZeroState(depth_multiplier=1): - """Makes zeros of shape like x but removing the length (axis 1).""" - def f(x): # pylint: disable=invalid-name - if len(x.shape) != 3: - raise ValueError(f'Layer input should be a rank 3 tensor representing' - f' (batch_size, sequence_length, feature_depth); ' - f'instead got shape {x.shape}.') - return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]), - dtype=jnp.float32) - return base.Fn('MakeZeroState', f) - - -def LSTM(n_units, mode='train', return_state=False, initial_state=False): - """LSTM running on axis 1. - - Args: - n_units: `n_units` for the `LSTMCell`. - mode: if 'predict' then we save the previous state for one-by-one inference. - return_state: Boolean. Whether to return the latest status in addition to - the output. Default: False. - initial_state: Boolean. If the state RNN (c, h) is to be obtained from the - stack. Default: False. - - Returns: - A LSTM layer. - """ - - if not initial_state: - zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter - if return_state: - return cb.Serial( - cb.Branch([], zero_state), - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - name=f'LSTM_{n_units}', - sublayers_to_print=[]) + """Makes zeros of shape like x but removing the length (axis 1).""" + + def f(x): # pylint: disable=invalid-name + if len(x.shape) != 3: + raise ValueError( + f"Layer input should be a rank 3 tensor representing" + f" (batch_size, sequence_length, feature_depth); " + f"instead got shape {x.shape}." 
+ ) + return jnp.zeros( + (x.shape[0], depth_multiplier * x.shape[-1]), dtype=jnp.float32 + ) + + return base.Fn("MakeZeroState", f) + + +def LSTM(n_units, mode="train", return_state=False, initial_state=False): + """LSTM running on axis 1. + + Args: + n_units: `n_units` for the `LSTMCell`. + mode: if 'predict' then we save the previous state for one-by-one inference. + return_state: Boolean. Whether to return the latest status in addition to + the output. Default: False. + initial_state: Boolean. If the state RNN (c, h) is to be obtained from the + stack. Default: False. + + Returns: + A LSTM layer. + """ + + if not initial_state: + zero_state = MakeZeroState( + depth_multiplier=2 + ) # pylint: disable=no-value-for-parameter + if return_state: + return cb.Serial( + cb.Branch([], zero_state), + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f"LSTM_{n_units}", + sublayers_to_print=[], + ) + else: + return cb.Serial( + cb.Branch([], zero_state), # fill state RNN with zero. + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), # Drop RNN state. + # Set the name to LSTM and don't print sublayers. + name=f"LSTM_{n_units}", + sublayers_to_print=[], + ) else: - return cb.Serial( - cb.Branch([], zero_state), # fill state RNN with zero. - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - cb.Select([0], n_in=2), # Drop RNN state. - # Set the name to LSTM and don't print sublayers. - name=f'LSTM_{n_units}', sublayers_to_print=[]) - else: - if return_state: - return cb.Serial( - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - name=f'LSTM_{n_units}', sublayers_to_print=[]) - else: - return cb.Serial( - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - cb.Select([0], n_in=2), # Drop RNN state. 
- name=f'LSTM_{n_units}', sublayers_to_print=[]) + if return_state: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f"LSTM_{n_units}", + sublayers_to_print=[], + ) + else: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), # Drop RNN state. + name=f"LSTM_{n_units}", + sublayers_to_print=[], + ) class GRUCell(base.Layer): - """Builds a traditional GRU cell with dense internal transformations. - - Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555 - """ - - def __init__(self, - n_units, - forget_bias=0.0, - kernel_initializer=initializers.RandomUniformInitializer(0.01), - bias_initializer=initializers.RandomNormalInitializer(1e-6)): - super().__init__(n_in=2, n_out=2) - self._n_units = n_units - self._forget_bias = forget_bias - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - - def forward(self, inputs): - x, gru_state = inputs - - # Dense layer on the concatenation of x and h. - w1, b1, w2, b2 = self.weights - y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1 - - # Update and reset gates. - u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1) - - # Candidate. - c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2 - - new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c) - return new_gru_state, new_gru_state - - def init_weights_and_state(self, input_signature): - if input_signature[1].shape[-1] != self._n_units: - raise ValueError( - f'Second argument in input signature should have a final dimension of' - f' {self._n_units}; instead got {input_signature[1].shape[-1]}.') - - # The dense layer input is the input and half of the GRU state. 
- input_shape = input_signature[0].shape[-1] + self._n_units - rng1, rng2, rng3, rng4 = fastmath.random.split(self.rng, 4) - w1 = self._kernel_initializer((input_shape, 2 * self._n_units), rng1) - b1 = self._bias_initializer((2 * self._n_units,), rng2) + self._forget_bias - w2 = self._kernel_initializer((input_shape, self._n_units), rng3) - b2 = self._bias_initializer((self._n_units,), rng4) - self.weights = (w1, b1, w2, b2) - - -def GRU(n_units, mode='train'): - """GRU running on axis 1.""" - zero_state = MakeZeroState(depth_multiplier=1) # pylint: disable=no-value-for-parameter - return cb.Serial( - cb.Branch([], zero_state), - cb.Scan(GRUCell(n_units=n_units), axis=1, mode=mode), - cb.Select([0], n_in=2), # Drop RNN state. - # Set the name to GRU and don't print sublayers. - name=f'GRU_{n_units}', sublayers_to_print=[] - ) + """Builds a traditional GRU cell with dense internal transformations. + + Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555 + """ + + def __init__( + self, + n_units, + forget_bias=0.0, + kernel_initializer=initializers.RandomUniformInitializer(0.01), + bias_initializer=initializers.RandomNormalInitializer(1e-6), + ): + super().__init__(n_in=2, n_out=2) + self._n_units = n_units + self._forget_bias = forget_bias + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + def forward(self, inputs): + x, gru_state = inputs + + # Dense layer on the concatenation of x and h. + w1, b1, w2, b2 = self.weights + y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1 + + # Update and reset gates. + u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1) + + # Candidate. 
+ c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2 + + new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c) + return new_gru_state, new_gru_state + + def init_weights_and_state(self, input_signature): + if input_signature[1].shape[-1] != self._n_units: + raise ValueError( + f"Second argument in input signature should have a final dimension of" + f" {self._n_units}; instead got {input_signature[1].shape[-1]}." + ) + + # The dense layer input is the input and half of the GRU state. + input_shape = input_signature[0].shape[-1] + self._n_units + rng1, rng2, rng3, rng4 = fastmath.random.split(self.rng, 4) + w1 = self._kernel_initializer((input_shape, 2 * self._n_units), rng1) + b1 = self._bias_initializer((2 * self._n_units,), rng2) + self._forget_bias + w2 = self._kernel_initializer((input_shape, self._n_units), rng3) + b2 = self._bias_initializer((self._n_units,), rng4) + self.weights = (w1, b1, w2, b2) + + +def GRU(n_units, mode="train"): + """GRU running on axis 1.""" + zero_state = MakeZeroState( + depth_multiplier=1 + ) # pylint: disable=no-value-for-parameter + return cb.Serial( + cb.Branch([], zero_state), + cb.Scan(GRUCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), # Drop RNN state. + # Set the name to GRU and don't print sublayers. + name=f"GRU_{n_units}", + sublayers_to_print=[], + ) def ConvGRUCell(n_units, kernel_size=(3, 3)): - """Builds a convolutional GRU. + """Builds a convolutional GRU. - Paper: https://arxiv.org/abs/1511.06432. + Paper: https://arxiv.org/abs/1511.06432. - Args: - n_units: Number of hidden units - kernel_size: Kernel size for convolution + Args: + n_units: Number of hidden units + kernel_size: Kernel size for convolution - Returns: - A Stax model representing a GRU cell with convolution transforms. - """ + Returns: + A Stax model representing a GRU cell with convolution transforms. 
+ """ - def BuildConv(): - return convolution.Conv( - filters=n_units, kernel_size=kernel_size, padding='SAME') + def BuildConv(): + return convolution.Conv( + filters=n_units, kernel_size=kernel_size, padding="SAME" + ) - return GeneralGRUCell( - candidate_transform=BuildConv, - memory_transform_fn=None, - gate_nonlinearity=activation_fns.Sigmoid, - candidate_nonlinearity=activation_fns.Tanh) + return GeneralGRUCell( + candidate_transform=BuildConv, + memory_transform_fn=None, + gate_nonlinearity=activation_fns.Sigmoid, + candidate_nonlinearity=activation_fns.Tanh, + ) -def GeneralGRUCell(candidate_transform, - memory_transform_fn=None, - gate_nonlinearity=activation_fns.Sigmoid, - candidate_nonlinearity=activation_fns.Tanh, - dropout_rate_c=0.1, - sigmoid_bias=0.5): - r"""Parametrized Gated Recurrent Unit (GRU) cell construction. +def GeneralGRUCell( + candidate_transform, + memory_transform_fn=None, + gate_nonlinearity=activation_fns.Sigmoid, + candidate_nonlinearity=activation_fns.Tanh, + dropout_rate_c=0.1, + sigmoid_bias=0.5, +): + r"""Parametrized Gated Recurrent Unit (GRU) cell construction. GRU update equations for update gate, reset gate, candidate memory, and new state: @@ -252,75 +279,85 @@ def GeneralGRUCell(candidate_transform, Returns: A model representing a GRU cell with specified transforms. """ - gate_block = [ # u_t - candidate_transform(), - _AddSigmoidBias(sigmoid_bias), - gate_nonlinearity(), - ] - reset_block = [ # r_t - candidate_transform(), - _AddSigmoidBias(sigmoid_bias), # Want bias to start positive. - gate_nonlinearity(), - ] - candidate_block = [ - cb.Dup(), - reset_block, - cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})) - candidate_transform(), # Final projection + tanh to get Ct - candidate_nonlinearity(), # Candidate gate - - # Only apply dropout on the C gate. Paper reports 0.1 as a good default. 
- core.Dropout(rate=dropout_rate_c) - ] - memory_transform = memory_transform_fn() if memory_transform_fn else [] - return cb.Serial( - cb.Branch(memory_transform, gate_block, candidate_block), - cb.Gate(), - ) + gate_block = [ # u_t + candidate_transform(), + _AddSigmoidBias(sigmoid_bias), + gate_nonlinearity(), + ] + reset_block = [ # r_t + candidate_transform(), + _AddSigmoidBias(sigmoid_bias), # Want bias to start positive. + gate_nonlinearity(), + ] + candidate_block = [ + cb.Dup(), + reset_block, + cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})) + candidate_transform(), # Final projection + tanh to get Ct + candidate_nonlinearity(), # Candidate gate + # Only apply dropout on the C gate. Paper reports 0.1 as a good default. + core.Dropout(rate=dropout_rate_c), + ] + memory_transform = memory_transform_fn() if memory_transform_fn else [] + return cb.Serial( + cb.Branch(memory_transform, gate_block, candidate_block), + cb.Gate(), + ) def InnerSRUCell(): - """The inner (non-parallel) computation of an SRU.""" - def f(cur_x_times_one_minus_f, cur_f, cur_state): # pylint: disable=invalid-name - res = cur_f * cur_state + cur_x_times_one_minus_f - return res, res - return base.Fn('InnerSRUCell', f, n_out=2) - - -def ScanSRUCell(mode, monkey_patched_mask=None): - """The inner (non-parallel) computation of an SRU.""" - if monkey_patched_mask is None: - return cb.Scan(InnerSRUCell(), axis=1, mode=mode) - - # This is necessary for Terraformer model. See comments there. - # The mask will only be used in Terraformer in predict mode. 
- assert mode == 'predict' - - def update_mask(mask, x_times_one_minus_f): # pylint: disable=invalid-name - initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32) - if initial.shape[1] > 1: - updated_mask = fastmath.dynamic_update_slice_in_dim( - initial != 0, mask != 0, 1, axis=1) - else: - updated_mask = initial - return updated_mask, x_times_one_minus_f + """The inner (non-parallel) computation of an SRU.""" - def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f, cur_f, # pylint: disable=invalid-name - cur_state): - res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask - + (1 - cur_mask) * cur_state) - return res, res + def f(cur_x_times_one_minus_f, cur_f, cur_state): # pylint: disable=invalid-name + res = cur_f * cur_state + cur_x_times_one_minus_f + return res, res - return cb.Serial( - monkey_patched_mask.get_layer(), - base.Fn('update_mask', update_mask, n_out=2), - cb.Scan(base.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2), - axis=1, mode=mode), - ) + return base.Fn("InnerSRUCell", f, n_out=2) -def SRU(n_units, activation=None, mode='train'): - r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. +def ScanSRUCell(mode, monkey_patched_mask=None): + """The inner (non-parallel) computation of an SRU.""" + if monkey_patched_mask is None: + return cb.Scan(InnerSRUCell(), axis=1, mode=mode) + + # This is necessary for Terraformer model. See comments there. + # The mask will only be used in Terraformer in predict mode. 
+ assert mode == "predict" + + def update_mask(mask, x_times_one_minus_f): # pylint: disable=invalid-name + initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32) + if initial.shape[1] > 1: + updated_mask = fastmath.dynamic_update_slice_in_dim( + initial != 0, mask != 0, 1, axis=1 + ) + else: + updated_mask = initial + return updated_mask, x_times_one_minus_f + + def masked_inner_sru_cell( + cur_mask, + cur_x_times_one_minus_f, + cur_f, # pylint: disable=invalid-name + cur_state, + ): + res = (cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask + ( + 1 - cur_mask + ) * cur_state + return res, res + + return cb.Serial( + monkey_patched_mask.get_layer(), + base.Fn("update_mask", update_mask, n_out=2), + cb.Scan( + base.Fn("MaskedInnerSRUCell", masked_inner_sru_cell, n_out=2), + axis=1, + mode=mode, + ), + ) + + +def SRU(n_units, activation=None, mode="train"): + r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. As defined in the paper: @@ -343,24 +380,24 @@ def SRU(n_units, activation=None, mode='train'): Returns: The SRU layer. """ - sigmoid_activation = activation_fns.Sigmoid() - return cb.Serial( # x - cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x - cb.Split(n_items=3), # r, f, y, x - cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x - base.Fn('', - lambda r, f, y: (y * (1.0 - f), f, r), # y * (1 - f), f, r, x - n_out=3), - cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), - ScanSRUCell(mode=mode), - cb.Select([0], n_in=2), # act(c), r, x - activation if activation is not None else [], - base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)), - # Set the name to SRU and don't print sublayers. 
- name=f'SRU_{n_units}', sublayers_to_print=[] - ) + sigmoid_activation = activation_fns.Sigmoid() + return cb.Serial( # x + cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x + cb.Split(n_items=3), # r, f, y, x + cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x + base.Fn( + "", lambda r, f, y: (y * (1.0 - f), f, r), n_out=3 # y * (1 - f), f, r, x + ), + cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), + ScanSRUCell(mode=mode), + cb.Select([0], n_in=2), # act(c), r, x + activation if activation is not None else [], + base.Fn("FinalSRUGate", lambda c, r, x: c * r + x * (1 - r) * (3**0.5)), + # Set the name to SRU and don't print sublayers. + name=f"SRU_{n_units}", + sublayers_to_print=[], + ) def _AddSigmoidBias(sigmoid_bias): - return base.Fn('AddSigmoidBias({sigmoid_bias})', - lambda x: x + sigmoid_bias) + return base.Fn("AddSigmoidBias({sigmoid_bias})", lambda x: x + sigmoid_bias) diff --git a/trax/layers/rnn_test.py b/trax/layers/rnn_test.py deleted file mode 100644 index 40e128785..000000000 --- a/trax/layers/rnn_test.py +++ /dev/null @@ -1,77 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for rnn layers.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class RnnTest(parameterized.TestCase): - - def test_conv_gru_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.ConvGRUCell(9, kernel_size=(3, 3)) - x = np.ones((8, 1, 7, 9)) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_gru_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.GRUCell(9) - xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))] - _, _ = layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)]) - - def test_lstm_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.LSTMCell(9) - xs = [np.ones((8, 9)), np.ones((8, 18))] - _, _ = layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)]) - - def test_sru(self, backend): - with fastmath.use_backend(backend): - layer = tl.SRU(7) - x = np.ones((8, 9, 7), np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_names(self, backend): - with fastmath.use_backend(backend): - layer = tl.LSTM(3) - self.assertEqual('LSTM_3', str(layer)) - layer = tl.GRU(5) - self.assertEqual('GRU_5', str(layer)) - layer = tl.SRU(7) - self.assertEqual('SRU_7', str(layer)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/test_utils.py b/trax/layers/test_utils.py deleted file mode 100644 index 156220314..000000000 --- a/trax/layers/test_utils.py +++ /dev/null @@ -1,283 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for testing.""" - -import copy -import functools - -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes - - -def test_eval_is_deterministic(inp, model_fn, message=''): - """Utility method for testing if eval mode is deterministic. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - message: Optional message to show when outputs of eval/predict mode don't - match. - """ - with fastmath.use_backend(fastmath.Backend.JAX): - model_eval1 = model_fn(mode='eval') - model_eval2 = model_fn(mode='eval') - - input_signature = shapes.signature(inp) - model_eval1.init(input_signature) - model_eval2.init(input_signature) - model_eval1.save_to_file('/tmp/unique_weights') - model_eval2.init_from_file('/tmp/unique_weights', weights_only=True, - input_signature=input_signature) - - rng = fastmath.random.get_prng(0) - output_eval1 = model_eval1(inp, rng=rng) - if not isinstance(output_eval1, (tuple, list)): - # We will automatically check each and every tensor returned. - output_eval1 = [output_eval1] - - output_eval2 = model_eval2(inp, rng=rng) - if not isinstance(output_eval2, (tuple, list)): - # We will automatically check each and every tensor returned. 
- output_eval2 = [output_eval2] - - np.testing.assert_equal(len(output_eval1), len(output_eval2)) - for out1, out2 in zip(output_eval1, output_eval2): - np.testing.assert_array_almost_equal( - out1, - out2, - decimal=5, - err_msg='Non-deterministic.{}'.format(message)) - - -def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None, - init_tokens=3, message=''): - """Utility method for testing equivalence of predict and eval modes. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - seq_axis: axis of sequence_length. In predict mode we iterate over this - axis. By default `1`, which is 2nd dimension. - seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor - in this tuple on which we iterate the sequence. - init_tokens: how many tokens should be passed to the first `predict` call. - message: Optional message to show when outputs of eval/predict mode don't - match. - """ - with fastmath.use_backend(fastmath.Backend.JAX): - model_eval = model_fn(mode='eval') - model_predict = model_fn(mode='predict') - - input_signature = shapes.signature(inp) - model_eval.init(input_signature) - model_predict.init(input_signature) - model_eval.save_to_file('/tmp/unique_weights') - model_predict.init_from_file('/tmp/unique_weights', weights_only=True, - input_signature=input_signature) - - rng = fastmath.random.get_prng(0) - output_eval = model_eval(inp, rng=rng) - if not isinstance(output_eval, (tuple, list)): - # We will automatically check each and every tensor returned. - output_eval = [output_eval] - - if seq_tensor is None: - length = inp.shape[seq_axis] - else: - length = inp[seq_tensor].shape[seq_axis] - - assert length >= init_tokens + 2 # Required to properly test predict mode. 
- indices_list = [(0, init_tokens)] + [(i, i+1) - for i in range(init_tokens, length)] - - for indices in indices_list: - start, end = indices - if seq_tensor is None: - new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis) - else: - new_inp = list(inp) - new_inp[seq_tensor] = new_inp[seq_tensor].take( - indices=np.arange(start, end), axis=seq_axis) - - output_predict = model_predict(new_inp, rng=rng) - if not isinstance(output_predict, (tuple, list)): - # We will automatically check each and every tensor returned. - output_predict = [output_predict] - - np.testing.assert_equal(len(output_predict), len(output_eval)) - for outp, oute in zip(output_predict, output_eval): - np.testing.assert_array_almost_equal( - oute.take(indices=np.arange(start, end), axis=seq_axis), - outp.take(indices=np.arange(0, end-start), axis=seq_axis), - decimal=5, - err_msg='Error on element {} out of {}.{}'.format(indices, length, - message)) - - -def test_eval_equals_predict_configs(inp, model_fn, configs, seq_axis=1, - seq_tensor=None, message=''): - """Utility method for testing equivalence of predict and eval modes. - - This function iterates over a list of dictionaries `confis`, and runs the test - on models with each configuration. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - configs: List of dictionaries, which contain configs to be fed into - `model_fn`. - seq_axis: axis of sequence_length. In predict mode we iterate over this - axis. By default `1`, which is 2nd dimension. - seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor - in this tuple on which we iterate the sequence. - message: Optional message to show when outputs of eval/predict mode don't - match. 
- """ - for config in configs: - model_fn_configured = functools.partial(model_fn, **config) - test_eval_equals_predict(inp, model_fn_configured, seq_axis=seq_axis, - seq_tensor=seq_tensor, - message=' Config: {}.{}'.format(config, message)) - - -def test_eval_equals_predict_discrete( - model_fn, vocab_size=10, length=5, batch_size=3 -): - """Tests the equivalence of eval and predict modes for discrete models.""" - with fastmath.use_backend(fastmath.Backend.JAX): - model_slow = model_fn(mode='eval', vocab_size=vocab_size) - model_fast = model_fn(mode='predict', vocab_size=vocab_size) - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((batch_size, 1), np.int32) - # Given the same rng, both models initialize with the same parameters. - model_slow.init(input_signature, rng) - model_fast.init(input_signature, rng) - - buf = np.zeros((batch_size, length), dtype=np.int32) - next_sym = np.zeros((batch_size, 1), dtype=np.int32) - - for index in range(length): - logits_slow = model_slow(buf, rng=rng) - logits_fast = model_fast(next_sym, rng=rng) - np.testing.assert_array_almost_equal( - logits_slow[:, index, :], logits_fast[:, 0, :], - decimal=5, - ) - next_sym = np.random.randint(vocab_size, size=(batch_size, 1)) - buf[:, index] = next_sym[:, 0] - - -class MockTransformerLM(tl.Layer): - r"""Mock TransformerLM for testing autoregressive sampling routines. - - Mimics the behavior of a perfectly-trained, deterministic TransformerLM. - Allows to specify the \sigma^* -> \sigma function implemented by the model - and to make assertions about the input sequence passed to the model. - - Supports two modes: stateful "predict" for fast inference, and stateless - non-"predict" ("train", "eval" etc). - - Useful for testing any logic that relies on autoregressive sampling, as it - removes the additional layer of complexity related to training a model or - maintaining a pretrained one. Makes the tests run MUCH faster. - - Does not support acceleration. 
Do not wrap in tl.Accelerate(). - """ - - def __init__(self, sequence_fn, mode, vocab_size): - super().__init__() - - self._sequence_fn = sequence_fn - self._mode = mode - self._vocab_size = vocab_size - - self._prediction_buffers = None - - @property - def state(self): - return copy.deepcopy(self._prediction_buffers) - - @state.setter - def state(self, state): - self._prediction_buffers = copy.deepcopy(state) - - def _output_symbol_predict(self, input_symbols, prediction_buffer): - prediction_buffer.extend(input_symbols) - output_symbol = self._sequence_fn(np.array(prediction_buffer)) - return np.array([output_symbol]) - - def _output_symbols_eval(self, input_symbols, prediction_buffer): - del prediction_buffer - - # Add a leading 0 token to imitate ShiftRight. - input_symbols = np.concatenate(([0], input_symbols)) - - # Call sequence_fn repeatedly along the input sequence. - return np.array([ - self._sequence_fn(input_symbols[:end]) - for end in range(1, len(input_symbols)) - ]) - - def _symbols_to_logits(self, symbols): - # Assert that symbols are discrete. - assert np.issubdtype(symbols.dtype, np.integer) - # Assert that 0 <= symbols < vocab_size. - np.testing.assert_array_less(-1, symbols) - np.testing.assert_array_less(symbols, self._vocab_size) - - # Return almost-determinisitc logits: - # e^1000 / (e^1000 + vocab_size) ~= 1 - return tl.one_hot(symbols, n_categories=self._vocab_size) * 1000.0 - - def __call__(self, inputs, rng=None): - del rng - - assert inputs.ndim == 2, ( - 'The input sequences should have exactly two axes.' - ) - - if self._prediction_buffers is None: - # Initialize the buffer. - batch_size = inputs.shape[0] - # [[]] * batch_size would create multiple references to the same - # list, and we want separate lists. 
- self._prediction_buffers = [[] for _ in range(batch_size)] - - if self._mode == 'predict': - output_fn = self._output_symbol_predict - else: - output_fn = self._output_symbols_eval - - # Calculate the output separately for each sequence in the batch. - output_symbols = np.array([ - output_fn(input_seq, pred_buffer) - for (input_seq, pred_buffer) in zip( - inputs, self._prediction_buffers - ) - ]) - return self._symbols_to_logits(output_symbols) - - def assert_prediction_buffers_equal(self, expected_buffers): - if self._prediction_buffers is None: - batch_size = expected_buffers.shape[0] - actual_buffers = np.empty((batch_size, 0)) - else: - actual_buffers = np.array(self._prediction_buffers) - - np.testing.assert_array_equal(actual_buffers, expected_buffers) diff --git a/trax/layers/test_utils_test.py b/trax/layers/test_utils_test.py deleted file mode 100644 index e21a50dbd..000000000 --- a/trax/layers/test_utils_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.test_utils.""" - -import functools - -from absl.testing import absltest -import numpy as np - -from trax.layers import test_utils -from trax.supervised import decoding - - -def arithmetic_sequence(input_seq, limit=10): - # Increment the last symbol. Wrap to [0, 10). 
- return (input_seq[-1] + 1) % limit - - -class TestUtilsTest(absltest.TestCase): - - def test_mock_transformer_lm_eval_equals_predict(self): - model_fn = functools.partial( - test_utils.MockTransformerLM, - sequence_fn=arithmetic_sequence, - vocab_size=10, - ) - test_utils.test_eval_equals_predict_discrete(model_fn, vocab_size=10) - - def test_mock_transformer_lm_decodes_arithmetic_sequence(self): - model = test_utils.MockTransformerLM( - sequence_fn=arithmetic_sequence, - vocab_size=10, - mode='predict', - ) - output = decoding.autoregressive_sample( - model, max_length=5, start_id=0, eos_id=-1, accelerate=False - ) - - # Sequence including the leading 0 and the last predicted symbol. - full_seq = list(range(6)) - # decoding.autoregressive_sample doesn't return the leading 0. - np.testing.assert_array_equal(output, [full_seq[1:]]) - # The prediction buffers don't include the last predicted symbol. - model.assert_prediction_buffers_equal([full_seq[:-1]]) - - def test_mock_transformer_lm_rewinds(self): - model = test_utils.MockTransformerLM( - sequence_fn=arithmetic_sequence, - vocab_size=10, - mode='predict', - ) - sample_3 = functools.partial( - decoding.autoregressive_sample, - max_length=3, - eos_id=-1, - accelerate=False, - ) - - # Generate the 3 initial symbols. - init_output = sample_3(model, start_id=0) - np.testing.assert_array_equal(init_output, [[1, 2, 3]]) - state = model.state - - # Generate the next 3 symbols. - next_output = sample_3(model, start_id=init_output[0, -1]) - np.testing.assert_array_equal(next_output, [[4, 5, 6]]) - - # Rewind and generate the last 3 symbols again. - model.state = state - next_output = sample_3(model, start_id=init_output[0, -1]) - np.testing.assert_array_equal(next_output, [[4, 5, 6]]) - - # Check the buffers. 
- model.assert_prediction_buffers_equal([[0, 1, 2, 3, 4, 5]]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/Attention_Visualization_in_Trax.ipynb b/trax/models/Attention_Visualization_in_Trax.ipynb deleted file mode 100644 index 040b6d676..000000000 --- a/trax/models/Attention_Visualization_in_Trax.ipynb +++ /dev/null @@ -1,1601 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "7yuytuIllsv1" - }, - "source": [ - "# Attention Visualization in Trax\n", - "\n", - "For more information see the [tenso2tensor](https://trax-ml.readthedocs.io/en/latest/) visualization colab. All js tools are taken from the tensor2tensor version along with attention processing methods. The \"viz\" mode for a Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with the attention visualization in mind. The colab re-uses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BIl27504La0G" - }, - "source": [ - "**General Setup**\n", - "\n", - "Execute the following few cells (once) before running of visualization codes." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": {}, - "colab_type": "code", - "id": "oILRLCWN_16u" - }, - "outputs": [], - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", - "import json\n", - "import numpy as np\n", - "import os\n", - "import IPython.display as display\n", - "import gin" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 466 - }, - "colab_type": "code", - "id": "vlGjGoGMTt-D", - "outputId": "28f4556b-caef-47a1-bddd-7f51ecc064d8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 368kB 2.8MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1.5MB 13.0MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2.6MB 20.1MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 163kB 33.1MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 
194kB 19.4MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 983kB 30.6MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 655kB 56.6MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 81kB 11.7MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5.3MB 45.0MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 368kB 57.1MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 307kB 55.8MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 358kB 58.6MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1.1MB 59.0MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3.5MB 58.4MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 778kB 59.4MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 51kB 8.7MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 51kB 8.6MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 235kB 54.2MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3.0MB 62.4MB/s \n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 890kB 58.2MB/s \n", - "\u001b[?25h 
Building wheel for bz2file (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for pypng (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[31mERROR: kfac 0.2.2 has requirement tensorflow-probability==0.8, but you'll have tensorflow-probability 0.7.0 which is incompatible.\u001b[0m\n", - "INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 \n" - ] - } - ], - "source": [ - "#@title\n", - "# Import Trax\n", - "\n", - "!pip install -q -U trax\n", - "import trax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "VCBjVMrZRS6q" - }, - "outputs": [], - "source": [ - "#@title Some cool tooling for attention (make sure that you run the cell)\n", - "def resize(att_mat, max_length=None):\n", - " \"\"\"Normalize attention matrices and reshape as necessary.\"\"\"\n", - " for i, att in enumerate(att_mat):\n", - " # Add extra batch dim for viz code to work.\n", - " if att.ndim == 3:\n", - " att = np.expand_dims(att, axis=0)\n", - " if max_length is not None:\n", - " # Sum across different attention values for each token.\n", - " att = att[:, :, :max_length, :max_length]\n", - " row_sums = np.sum(att, axis=2)\n", - " # Normalize\n", - " att /= row_sums[:, :, np.newaxis]\n", - " att_mat[i] = att\n", - " return att_mat\n", - "\n", - "\n", - "def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):\n", - " \"\"\"Compute representation of the attention ready for the d3 visualization.\n", - "\n", - " Args:\n", - " inp_text: list of strings, words to be displayed on the left of the vis\n", - " out_text: list of strings, words to be displayed on the right of the vis\n", - " enc_atts: numpy array, encoder self-attentions\n", - " [num_layers, batch_size, num_heads, enc_length, enc_length]\n", - " dec_atts: numpy 
array, decoder self-attentions\n", - " [num_layers, batch_size, num_heads, dec_length, dec_length]\n", - " encdec_atts: numpy array, encoder-decoder attentions\n", - " [num_layers, batch_size, num_heads, dec_length, enc_length]\n", - "\n", - " Returns:\n", - " Dictionary of attention representations with the structure:\n", - " {\n", - " 'all': Representations for showing all attentions at the same time.\n", - " 'inp_inp': Representations for showing encoder self-attentions\n", - " 'inp_out': Representations for showing encoder-decoder attentions\n", - " 'out_out': Representations for showing decoder self-attentions\n", - " }\n", - " and each sub-dictionary has structure:\n", - " {\n", - " 'att': list of inter attentions matrices, one for each attention head\n", - " 'top_text': list of strings, words to be displayed on the left of the vis\n", - " 'bot_text': list of strings, words to be displayed on the right of the vis\n", - " }\n", - " \"\"\"\n", - " def get_full_attention(layer):\n", - " \"\"\"Get the full input+output - input+output attentions.\"\"\"\n", - " enc_att = enc_atts[layer][0]\n", - " dec_att = dec_atts[layer][0]\n", - " encdec_att = encdec_atts[layer][0]\n", - " enc_att = np.transpose(enc_att, [0, 2, 1])\n", - " dec_att = np.transpose(dec_att, [0, 2, 1])\n", - " encdec_att = np.transpose(encdec_att, [0, 2, 1])\n", - " # [heads, query_length, memory_length]\n", - " enc_length = enc_att.shape[1]\n", - " dec_length = dec_att.shape[1]\n", - " num_heads = enc_att.shape[0]\n", - " first = np.concatenate([enc_att, encdec_att], axis=2)\n", - " second = np.concatenate(\n", - " [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2)\n", - " full_att = np.concatenate([first, second], axis=1)\n", - " return [ha.T.tolist() for ha in full_att]\n", - "\n", - " def get_inp_inp_attention(layer):\n", - " att = np.transpose(enc_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_out_inp_attention(layer):\n", - " att 
= np.transpose(encdec_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_out_out_attention(layer):\n", - " att = np.transpose(dec_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_attentions(get_attention_fn):\n", - " num_layers = len(enc_atts)\n", - " return [get_attention_fn(i) for i in range(num_layers)]\n", - "\n", - " attentions = {\n", - " 'all': {\n", - " 'att': get_attentions(get_full_attention),\n", - " 'top_text': inp_text + out_text,\n", - " 'bot_text': inp_text + out_text,\n", - " },\n", - " 'inp_inp': {\n", - " 'att': get_attentions(get_inp_inp_attention),\n", - " 'top_text': inp_text,\n", - " 'bot_text': inp_text,\n", - " },\n", - " 'inp_out': {\n", - " 'att': get_attentions(get_out_inp_attention),\n", - " 'top_text': inp_text,\n", - " 'bot_text': out_text,\n", - " },\n", - " 'out_out': {\n", - " 'att': get_attentions(get_out_out_attention),\n", - " 'top_text': out_text,\n", - " 'bot_text': out_text,\n", - " },\n", - " }\n", - "\n", - " return attentions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "47lzWIH5THcw" - }, - "outputs": [], - "source": [ - "#@title Some cool HTML and js stuff (make sure that you run the cell)\n", - "vis_html = \"\"\"\n", - " \u003cspan style=\"user-select:none\"\u003e\n", - " Layer: \u003cselect id=\"layer\"\u003e\u003c/select\u003e\n", - " Attention: \u003cselect id=\"att_type\"\u003e\n", - " \u003coption value=\"all\"\u003eAll\u003c/option\u003e\n", - " \u003coption value=\"inp_inp\"\u003eInput - Input\u003c/option\u003e\n", - " \u003coption value=\"inp_out\"\u003eInput - Output\u003c/option\u003e\n", - " \u003coption value=\"out_out\"\u003eOutput - Output\u003c/option\u003e\n", - " \u003c/select\u003e\n", - " \u003c/span\u003e\n", - " \u003cdiv id='vis'\u003e\u003c/div\u003e\n", - "\"\"\"\n", - "def call_html():\n", - " import 
IPython\n", - " display.display(display.HTML('''\n", - " \u003cscript src=\"/static/components/requirejs/require.js\"\u003e\u003c/script\u003e\n", - " \u003cscript\u003e\n", - " requirejs.config({\n", - " paths: {\n", - " base: '/static/base',\n", - " \"d3\": \"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min\",\n", - " jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n", - " },\n", - " });\n", - " \u003c/script\u003e\n", - " '''))\n", - "vis_js = \"\"\"\n", - "/**\n", - " * @fileoverview Transformer Visualization D3 javascript code.\n", - " */\n", - "\n", - "requirejs(['jquery', 'd3'],\n", - "function($, d3) {\n", - "\n", - "var attention = window.attention;\n", - "\n", - "const TEXT_SIZE = 15;\n", - "const BOXWIDTH = TEXT_SIZE * 8;\n", - "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", - "const WIDTH = 2000;\n", - "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", - "const MATRIX_WIDTH = 150;\n", - "const head_colours = d3.scale.category10();\n", - "const CHECKBOX_SIZE = 20;\n", - "\n", - "function lighten(colour) {\n", - " var c = d3.hsl(colour);\n", - " var increment = (1 - c.l) * 0.6;\n", - " c.l += increment;\n", - " c.s -= increment;\n", - " return c;\n", - "}\n", - "\n", - "function transpose(mat) {\n", - " return mat[0].map(function(col, i) {\n", - " return mat.map(function(row) {\n", - " return row[i];\n", - " });\n", - " });\n", - "}\n", - "\n", - "function zip(a, b) {\n", - " return a.map(function (e, i) {\n", - " return [e, b[i]];\n", - " });\n", - "}\n", - "\n", - "\n", - "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", - " $(id).empty();\n", - " var svg = d3.select(id)\n", - " .append('svg')\n", - " .attr(\"width\", WIDTH)\n", - " .attr(\"height\", HEIGHT);\n", - "\n", - " var att_data = [];\n", - " for (var i=0; i \u003c attention_heads.length; i++) {\n", - " var att_trans = transpose(attention_heads[i]);\n", - " att_data.push(zip(attention_heads[i], att_trans));\n", - " }\n", - "\n", - 
" renderText(svg, top_text, true, att_data, 0);\n", - " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", - "\n", - " renderAttentionHighlights(svg, att_data);\n", - "\n", - " svg.append(\"g\").classed(\"attention_heads\", true);\n", - "\n", - " renderAttention(svg, attention_heads);\n", - "\n", - " draw_checkboxes(config, 0, svg, attention_heads);\n", - "}\n", - "\n", - "\n", - "function renderText(svg, text, is_top, att_data, left_pos) {\n", - " var id = is_top ? \"top\" : \"bottom\";\n", - " var textContainer = svg.append(\"svg:g\")\n", - " .attr(\"id\", id);\n", - "\n", - " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", - " .selectAll(\"g\")\n", - " .data(att_data)\n", - " .enter()\n", - " .append(\"g\")\n", - " .selectAll(\"rect\")\n", - " .data(function(d) {return d;})\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"x\", function(d, i, j) {\n", - " return left_pos + box_offset(j);\n", - " })\n", - " .attr(\"y\", function(d, i) {\n", - " return (+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .attr(\"fill\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .style(\"opacity\", 0.0);\n", - "\n", - "\n", - " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", - " .data(text)\n", - " .enter()\n", - " .append(\"g\");\n", - "\n", - " tokenContainer.append(\"rect\")\n", - " .classed(\"background\", true)\n", - " .style(\"opacity\", 0.0)\n", - " .attr(\"fill\", \"lightgray\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH)\n", - " .attr(\"height\", BOXHEIGHT);\n", - "\n", - " var theText = tokenContainer.append(\"text\")\n", - " .text(function(d) { return d; })\n", - " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", - " .style(\"cursor\", \"default\")\n", - " 
.style(\"-webkit-user-select\", \"none\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " });\n", - "\n", - " if (is_top) {\n", - " theText.style(\"text-anchor\", \"end\")\n", - " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " } else {\n", - " theText.style(\"text-anchor\", \"start\")\n", - " .attr(\"dx\", + TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " }\n", - "\n", - " tokenContainer.on(\"mouseover\", function(d, index) {\n", - " textContainer.selectAll(\".background\")\n", - " .style(\"opacity\", function(d, i) {\n", - " return i == index ? 1.0 : 0.0;\n", - " });\n", - "\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", - "\n", - " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", - " .selectAll(\".att_lines\")\n", - " .attr(\"stroke-opacity\", function(d) {\n", - " return 1.0;\n", - " })\n", - " .attr(\"y1\", function(d, i) {\n", - " if (is_top) {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", function(d, i) {\n", - " if (is_top) {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .attr(\"stroke-opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j]) {\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " });\n", - "\n", - "\n", - " function updateAttentionBoxes() {\n", - " var id = is_top ? \"bottom\" : \"top\";\n", - " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", - " svg.select(\"#\" + id)\n", - " .selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", - " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .style(\"opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j])\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " else\n", - " return 0.0;\n", - "\n", - " });\n", - " }\n", - "\n", - " updateAttentionBoxes();\n", - " });\n", - "\n", - " textContainer.on(\"mouseleave\", function() {\n", - " d3.select(this).selectAll(\".background\")\n", - " .style(\"opacity\", 0.0);\n", - "\n", - " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", - " svg.selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .style(\"opacity\", 0.0);\n", - " });\n", - "}\n", - "\n", - "function renderAttentionHighlights(svg, attention) {\n", - " var line_container = svg.append(\"g\");\n", - " line_container.selectAll(\"g\")\n", - " .data(attention)\n", - " .enter()\n", - " .append(\"g\")\n", - " .classed(\"line_heads\", true)\n", - " .selectAll(\"line\")\n", - " .data(function(d){return d;})\n", - " .enter()\n", - " .append(\"line\").classed(\"att_lines\", true);\n", - "}\n", - "\n", - "function renderAttention(svg, attention_heads) {\n", - " var line_container = svg.selectAll(\".attention_heads\");\n", - " line_container.html(null);\n", - " for(var h=0; h\u003cattention_heads.length; h++) {\n", - " for(var a=0; a\u003cattention_heads[h].length; a++) {\n", - " for(var s=0; s\u003cattention_heads[h][a].length; s++) {\n", - " 
line_container.append(\"line\")\n", - " .attr(\"y1\", (s+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", (a+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", head_colours(h))\n", - " .attr(\"stroke-opacity\", function() {\n", - " if (config.head_vis[h]) {\n", - " return attention_heads[h][a][s]/active_heads();\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " }());\n", - " }\n", - " }\n", - " }\n", - "}\n", - "\n", - "// Checkboxes\n", - "function box_offset(i) {\n", - " var num_head_above = config.head_vis.reduce(\n", - " function(acc, val, cur) {return val \u0026\u0026 cur \u003c i ? acc + 1: acc;}, 0);\n", - " return num_head_above*(BOXWIDTH / active_heads());\n", - "}\n", - "\n", - "function active_heads() {\n", - " return config.head_vis.reduce(function(acc, val) {\n", - " return val ? acc + 1: acc;\n", - " }, 0);\n", - "}\n", - "\n", - "function draw_checkboxes(config, top, svg, attention_heads) {\n", - " var checkboxContainer = svg.append(\"g\");\n", - " var checkbox = checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"fill\", function(d, i) {\n", - " return head_colours(i);\n", - " })\n", - " .attr(\"x\", function(d, i) {\n", - " return (i+1) * CHECKBOX_SIZE;\n", - " })\n", - " .attr(\"y\", top)\n", - " .attr(\"width\", CHECKBOX_SIZE)\n", - " .attr(\"height\", CHECKBOX_SIZE);\n", - "\n", - " function update_checkboxes() {\n", - " checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .attr(\"fill\", function(d, i) {\n", - " var head_colour = head_colours(i);\n", - " var colour = d ? 
head_colour : lighten(head_colour);\n", - " return colour;\n", - " });\n", - " }\n", - "\n", - " update_checkboxes();\n", - "\n", - " checkbox.on(\"click\", function(d, i) {\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) return;\n", - " config.head_vis[i] = !config.head_vis[i];\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "\n", - " checkbox.on(\"dblclick\", function(d, i) {\n", - " // If we double click on the only active head then reset\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) {\n", - " config.head_vis = new Array(config.num_heads).fill(true);\n", - " } else {\n", - " config.head_vis = new Array(config.num_heads).fill(false);\n", - " config.head_vis[i] = true;\n", - " }\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "}\n", - "\n", - "var config = {\n", - " layer: 0,\n", - " att_type: 'all',\n", - "};\n", - "\n", - "function visualize() {\n", - " var num_heads = attention['all']['att'][0].length;\n", - " config.head_vis = new Array(num_heads).fill(true);\n", - " config.num_heads = num_heads;\n", - " config.attention = attention;\n", - "\n", - " render();\n", - "}\n", - "\n", - "function render() {\n", - " var conf = config.attention[config.att_type];\n", - "\n", - " var top_text = conf.top_text;\n", - " var bot_text = conf.bot_text;\n", - " var attention = conf.att[config.layer];\n", - "\n", - " $(\"#vis svg\").empty();\n", - " renderVis(\"#vis\", top_text, bot_text, attention, config);\n", - "}\n", - "\n", - "$(\"#layer\").empty();\n", - "for(var i=0; i\u003c6; i++) {\n", - " $(\"#layer\").append($(\"\u003coption /\u003e\").val(i).text(i));\n", - "}\n", - "\n", - "$(\"#layer\").on('change', function(e) {\n", - " config.layer = +e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"#att_type\").on('change', function(e) {\n", - " config.att_type = e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - 
"$(\"button\").on('click', visualize);\n", - "\n", - "visualize();\n", - "\n", - "});\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-LQ89rFFsEdk" - }, - "source": [ - "## 1. Run a pre-trained Transformer\n", - "\n", - "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n", - "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n", - "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n", - "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n", - "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "djTiSLcaNFGa", - "outputId": "b5ad2955-5e1d-47aa-97bb-5d72a25ed76d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Es ist schΓΆn, heute neue Dinge zu lernen!\n" - ] - } - ], - "source": [ - "# Create a Transformer model.\n", - "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", - "model = trax.models.Transformer(\n", - " input_vocab_size=33300,\n", - " d_model=512, d_ff=2048,\n", - " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", - " max_len=2048, mode='predict')\n", - "\n", - "# Initialize using pre-trained weights.\n", - 
"model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)\n", - "\n", - "# Tokenize a sentence.\n", - "sentence = 'It is nice to learn new things today!'\n", - "tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword'))[0]\n", - "\n", - "# Decode from the Transformer.\n", - "tokenized = tokenized[None, :] # Add batch dimension.\n", - "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n", - " model, tokenized, temperature=0.0) # Higher temperature: more diverse results.\n", - "\n", - "# De-tokenize,\n", - "tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.\n", - "translation = trax.data.detokenize(tokenized_translation,\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword')\n", - "print(translation)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "id": "pWDPwZfSJeD3", - "outputId": "050d40bf-f28d-49ea-b69a-af2886cf92a4" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([[ 118, 16, 1902, 9, 3197, 141, 1059, 420, 207]]),\n", - " array([ 168, 24, 9358, 2, 352, 367, 2427, 18, 3580, 207]))" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized, tokenized_translation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Lu6URNjbXIHv" - }, - "source": [ - "## 2. 
Prepare the tokens for visualization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "kqNWMpNdMg9z" - }, - "outputs": [], - "source": [ - "def decode(single_token):\n", - " return trax.data.detokenize(single_token,\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "H2fbJB_BMeRw" - }, - "outputs": [], - "source": [ - "def get_tokens_str(integers):\n", - " token_strs = []\n", - " for i in range(integers.shape[1]):\n", - " token_strs.append(decode(integers[:,i]))\n", - " return token_strs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "YkNT8rbgKM5-" - }, - "outputs": [], - "source": [ - "tokenized_translation_with_start = np.array([0]+list(tokenized_translation), dtype=np.int64)\n", - "tokenized_translation_with_start = tokenized_translation_with_start[np.newaxis, ...]\n", - "tokenized_translation = np.array(tokenized_translation, dtype=np.int64)\n", - "tokenized_translation = tokenized_translation[np.newaxis, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "r-FVdSZPKQhs" - }, - "outputs": [], - "source": [ - "tokenized_str = get_tokens_str(tokenized)\n", - "tokenized_translation_str = get_tokens_str(tokenized_translation_with_start)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 223 - }, - "colab_type": "code", - "id": "Cy7edKBuKash", - "outputId": "c1e00dbe-f467-48df-eaaf-579f68ef788f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(['It', 'is', 'nice', 'to', 'learn', 'new', 'things', 'today', '!'],\n", - " ['\u003cpad\u003e',\n", - " 'Es',\n", - " 'ist',\n", - " 
'schΓΆn',\n", - " ', ',\n", - " 'heute',\n", - " 'neue',\n", - " 'Dinge',\n", - " 'zu',\n", - " 'lernen',\n", - " '!'])" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized_str, tokenized_translation_str" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1XxJSqAsOTBe" - }, - "outputs": [], - "source": [ - "max_len = max(tokenized.shape[1], tokenized_translation.shape[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Qju-9pPHOV6G" - }, - "outputs": [], - "source": [ - "tokenized_translation_pad = np.zeros((1,max_len), dtype=np.int64)\n", - "tokenized_translation_pad[:,:tokenized_translation.shape[1]] = tokenized_translation\n", - "\n", - "tokenized_pad = np.zeros((1,max_len), dtype=np.int64)\n", - "tokenized_pad[:,:tokenized.shape[1]] = tokenized" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "zGxBSk0gOfYi", - "outputId": "d83328fa-eec8-4631-d2b6-4fffc3f0b933" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((1, 10), (1, 10))" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized_translation_pad.shape, tokenized_pad.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WqvjmRaCXign" - }, - "source": [ - "## 3. Create the same pre-trained model in the \"viz\" mode." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Qb2F4Pj_OLMZ" - }, - "outputs": [], - "source": [ - "# Create a Transformer model in the \"viz\" mode\n", - "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", - "model_viz = trax.models.Transformer(\n", - " input_vocab_size=33300,\n", - " d_model=512, d_ff=2048,\n", - " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", - " max_len=2048, mode='viz')\n", - "\n", - "# Initialize using pre-trained weights.\n", - "model_viz.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "AxcrAfprO0rD" - }, - "outputs": [], - "source": [ - "# We run the viz model because later we want to inspect its state\n", - "_ = model_viz((tokenized_pad, tokenized_translation_pad))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lVCYSQSuXw6f" - }, - "source": [ - "## 4. 
Find the attention weights (aka dots)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dsGuqdgnO2Lf" - }, - "outputs": [], - "source": [ - "attention_weights = []\n", - "def attention_sublayers(layer):\n", - " if 'Attention' in layer.name:\n", - " print(\"Found layer {}\".format(layer.name))\n", - " attention_weights.append(layer.state)\n", - " if layer.sublayers:\n", - " for sublayer in layer.sublayers:\n", - " attention_sublayers(sublayer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 326 - }, - "colab_type": "code", - "id": "FA3ba2-DO5l4", - "outputId": "f66756b1-fa86-4582-bd04-9b464ae132eb" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n" - ] - } - ], - "source": [ - "attention_sublayers(model_viz)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "q36-o98QO7HC", - "outputId": "445fe1ce-f1fa-484a-9db4-b37f56915d7c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "18" - ] - }, - "execution_count": 19, - "metadata": { - "tags": [] - }, - "output_type": 
"execute_result" - } - ], - "source": [ - "len(attention_weights)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "LahOE6q6PB1B" - }, - "outputs": [], - "source": [ - "# Manually identification of layers would be difficult, hence we rely on attention_sublayers function\n", - "enc_atts = attention_weights[:6]\n", - "dec_atts = attention_weights[6::2] # these are the DotProductCausalAttention layers\n", - "encdec_atts = attention_weights[7::2] # these are the PureAttention layers starting from the 6th layer on\n", - "\n", - "# Here we use a number of python utils inherited from tensor2tensor\n", - "enc_atts_res = resize(enc_atts)\n", - "dec_atts_res = resize(dec_atts)\n", - "encdec_atts_res = resize(encdec_atts)\n", - "attention_dict = _get_attention(tokenized_str, tokenized_translation_str, enc_atts_res, dec_atts_res, encdec_atts_res)\n", - "attention_json = json.dumps(attention_dict)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1DgBBfg-X6-d" - }, - "source": [ - "## 5. 
Display attention" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "resources": { - "http://localhost:8080/static/components/requirejs/require.js": { - "data": "LyoqIHZpbTogZXQ6dHM9NDpzdz00OnN0cz00CiAqIEBsaWNlbnNlIFJlcXVpcmVKUyAyLjEuMjIgQ29weXJpZ2h0IChjKSAyMDEwLTIwMTUsIFRoZSBEb2pvIEZvdW5kYXRpb24gQWxsIFJpZ2h0cyBSZXNlcnZlZC4KICogQXZhaWxhYmxlIHZpYSB0aGUgTUlUIG9yIG5ldyBCU0QgbGljZW5zZS4KICogc2VlOiBodHRwOi8vZ2l0aHViLmNvbS9qcmJ1cmtlL3JlcXVpcmVqcyBmb3IgZGV0YWlscwogKi8KLy9Ob3QgdXNpbmcgc3RyaWN0OiB1bmV2ZW4gc3RyaWN0IHN1cHBvcnQgaW4gYnJvd3NlcnMsICMzOTIsIGFuZCBjYXVzZXMKLy9wcm9ibGVtcyB3aXRoIHJlcXVpcmVqcy5leGVjKCkvdHJhbnNwaWxlciBwbHVnaW5zIHRoYXQgbWF5IG5vdCBiZSBzdHJpY3QuCi8qanNsaW50IHJlZ2V4cDogdHJ1ZSwgbm9tZW46IHRydWUsIHNsb3BweTogdHJ1ZSAqLwovKmdsb2JhbCB3aW5kb3csIG5hdmlnYXRvciwgZG9jdW1lbnQsIGltcG9ydFNjcmlwdHMsIHNldFRpbWVvdXQsIG9wZXJhICovCgp2YXIgcmVxdWlyZWpzLCByZXF1aXJlLCBkZWZpbmU7CihmdW5jdGlvbiAoZ2xvYmFsKSB7CiAgICB2YXIgcmVxLCBzLCBoZWFkLCBiYXNlRWxlbWVudCwgZGF0YU1haW4sIHNyYywKICAgICAgICBpbnRlcmFjdGl2ZVNjcmlwdCwgY3VycmVudGx5QWRkaW5nU2NyaXB0LCBtYWluU2NyaXB0LCBzdWJQYXRoLAogICAgICAgIHZlcnNpb24gPSAnMi4xLjIyJywKICAgICAgICBjb21tZW50UmVnRXhwID0gLyhcL1wqKFtcc1xTXSo/KVwqXC98KFteOl18XilcL1wvKC4qKSQpL21nLAogICAgICAgIGNqc1JlcXVpcmVSZWdFeHAgPSAvW14uXVxzKnJlcXVpcmVccypcKFxzKlsiJ10oW14nIlxzXSspWyInXVxzKlwpL2csCiAgICAgICAganNTdWZmaXhSZWdFeHAgPSAvXC5qcyQvLAogICAgICAgIGN1cnJEaXJSZWdFeHAgPSAvXlwuXC8vLAogICAgICAgIG9wID0gT2JqZWN0LnByb3RvdHlwZSwKICAgICAgICBvc3RyaW5nID0gb3AudG9TdHJpbmcsCiAgICAgICAgaGFzT3duID0gb3AuaGFzT3duUHJvcGVydHksCiAgICAgICAgYXAgPSBBcnJheS5wcm90b3R5cGUsCiAgICAgICAgaXNCcm93c2VyID0gISEodHlwZW9mIHdpbmRvdyAhPT0gJ3VuZGVmaW5lZCcgJiYgdHlwZW9mIG5hdmlnYXRvciAhPT0gJ3VuZGVmaW5lZCcgJiYgd2luZG93LmRvY3VtZW50KSwKICAgICAgICBpc1dlYldvcmtlciA9ICFpc0Jyb3dzZXIgJiYgdHlwZW9mIGltcG9ydFNjcmlwdHMgIT09ICd1bmRlZmluZWQnLAogICAgICAgIC8vUFMzIGluZGljYXRlcyBsb2FkZWQgYW5kIGNvbXBsZXRlLCBidXQgbmVlZCB0byB3YWl0IGZvciBjb21wbG
V0ZQogICAgICAgIC8vc3BlY2lmaWNhbGx5LiBTZXF1ZW5jZSBpcyAnbG9hZGluZycsICdsb2FkZWQnLCBleGVjdXRpb24sCiAgICAgICAgLy8gdGhlbiAnY29tcGxldGUnLiBUaGUgVUEgY2hlY2sgaXMgdW5mb3J0dW5hdGUsIGJ1dCBub3Qgc3VyZSBob3cKICAgICAgICAvL3RvIGZlYXR1cmUgdGVzdCB3L28gY2F1c2luZyBwZXJmIGlzc3Vlcy4KICAgICAgICByZWFkeVJlZ0V4cCA9IGlzQnJvd3NlciAmJiBuYXZpZ2F0b3IucGxhdGZvcm0gPT09ICdQTEFZU1RBVElPTiAzJyA/CiAgICAgICAgICAgICAgICAgICAgICAvXmNvbXBsZXRlJC8gOiAvXihjb21wbGV0ZXxsb2FkZWQpJC8sCiAgICAgICAgZGVmQ29udGV4dE5hbWUgPSAnXycsCiAgICAgICAgLy9PaCB0aGUgdHJhZ2VkeSwgZGV0ZWN0aW5nIG9wZXJhLiBTZWUgdGhlIHVzYWdlIG9mIGlzT3BlcmEgZm9yIHJlYXNvbi4KICAgICAgICBpc09wZXJhID0gdHlwZW9mIG9wZXJhICE9PSAndW5kZWZpbmVkJyAmJiBvcGVyYS50b1N0cmluZygpID09PSAnW29iamVjdCBPcGVyYV0nLAogICAgICAgIGNvbnRleHRzID0ge30sCiAgICAgICAgY2ZnID0ge30sCiAgICAgICAgZ2xvYmFsRGVmUXVldWUgPSBbXSwKICAgICAgICB1c2VJbnRlcmFjdGl2ZSA9IGZhbHNlOwoKICAgIGZ1bmN0aW9uIGlzRnVuY3Rpb24oaXQpIHsKICAgICAgICByZXR1cm4gb3N0cmluZy5jYWxsKGl0KSA9PT0gJ1tvYmplY3QgRnVuY3Rpb25dJzsKICAgIH0KCiAgICBmdW5jdGlvbiBpc0FycmF5KGl0KSB7CiAgICAgICAgcmV0dXJuIG9zdHJpbmcuY2FsbChpdCkgPT09ICdbb2JqZWN0IEFycmF5XSc7CiAgICB9CgogICAgLyoqCiAgICAgKiBIZWxwZXIgZnVuY3Rpb24gZm9yIGl0ZXJhdGluZyBvdmVyIGFuIGFycmF5LiBJZiB0aGUgZnVuYyByZXR1cm5zCiAgICAgKiBhIHRydWUgdmFsdWUsIGl0IHdpbGwgYnJlYWsgb3V0IG9mIHRoZSBsb29wLgogICAgICovCiAgICBmdW5jdGlvbiBlYWNoKGFyeSwgZnVuYykgewogICAgICAgIGlmIChhcnkpIHsKICAgICAgICAgICAgdmFyIGk7CiAgICAgICAgICAgIGZvciAoaSA9IDA7IGkgPCBhcnkubGVuZ3RoOyBpICs9IDEpIHsKICAgICAgICAgICAgICAgIGlmIChhcnlbaV0gJiYgZnVuYyhhcnlbaV0sIGksIGFyeSkpIHsKICAgICAgICAgICAgICAgICAgICBicmVhazsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQogICAgICAgIH0KICAgIH0KCiAgICAvKioKICAgICAqIEhlbHBlciBmdW5jdGlvbiBmb3IgaXRlcmF0aW5nIG92ZXIgYW4gYXJyYXkgYmFja3dhcmRzLiBJZiB0aGUgZnVuYwogICAgICogcmV0dXJucyBhIHRydWUgdmFsdWUsIGl0IHdpbGwgYnJlYWsgb3V0IG9mIHRoZSBsb29wLgogICAgICovCiAgICBmdW5jdGlvbiBlYWNoUmV2ZXJzZShhcnksIGZ1bmMpIHsKICAgICAgICBpZiAoYXJ5KSB7CiAgICAgICAgICAgIHZhciBpOwogICAgICAgICAgICBmb3IgKGkgPSBhcnkubGVuZ3RoIC0gMTsgaSA+IC0xOyBpIC09IDEpIHsKICAgICAgICAgICAgIC
AgIGlmIChhcnlbaV0gJiYgZnVuYyhhcnlbaV0sIGksIGFyeSkpIHsKICAgICAgICAgICAgICAgICAgICBicmVhazsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQogICAgICAgIH0KICAgIH0KCiAgICBmdW5jdGlvbiBoYXNQcm9wKG9iaiwgcHJvcCkgewogICAgICAgIHJldHVybiBoYXNPd24uY2FsbChvYmosIHByb3ApOwogICAgfQoKICAgIGZ1bmN0aW9uIGdldE93bihvYmosIHByb3ApIHsKICAgICAgICByZXR1cm4gaGFzUHJvcChvYmosIHByb3ApICYmIG9ialtwcm9wXTsKICAgIH0KCiAgICAvKioKICAgICAqIEN5Y2xlcyBvdmVyIHByb3BlcnRpZXMgaW4gYW4gb2JqZWN0IGFuZCBjYWxscyBhIGZ1bmN0aW9uIGZvciBlYWNoCiAgICAgKiBwcm9wZXJ0eSB2YWx1ZS4gSWYgdGhlIGZ1bmN0aW9uIHJldHVybnMgYSB0cnV0aHkgdmFsdWUsIHRoZW4gdGhlCiAgICAgKiBpdGVyYXRpb24gaXMgc3RvcHBlZC4KICAgICAqLwogICAgZnVuY3Rpb24gZWFjaFByb3Aob2JqLCBmdW5jKSB7CiAgICAgICAgdmFyIHByb3A7CiAgICAgICAgZm9yIChwcm9wIGluIG9iaikgewogICAgICAgICAgICBpZiAoaGFzUHJvcChvYmosIHByb3ApKSB7CiAgICAgICAgICAgICAgICBpZiAoZnVuYyhvYmpbcHJvcF0sIHByb3ApKSB7CiAgICAgICAgICAgICAgICAgICAgYnJlYWs7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9CiAgICB9CgogICAgLyoqCiAgICAgKiBTaW1wbGUgZnVuY3Rpb24gdG8gbWl4IGluIHByb3BlcnRpZXMgZnJvbSBzb3VyY2UgaW50byB0YXJnZXQsCiAgICAgKiBidXQgb25seSBpZiB0YXJnZXQgZG9lcyBub3QgYWxyZWFkeSBoYXZlIGEgcHJvcGVydHkgb2YgdGhlIHNhbWUgbmFtZS4KICAgICAqLwogICAgZnVuY3Rpb24gbWl4aW4odGFyZ2V0LCBzb3VyY2UsIGZvcmNlLCBkZWVwU3RyaW5nTWl4aW4pIHsKICAgICAgICBpZiAoc291cmNlKSB7CiAgICAgICAgICAgIGVhY2hQcm9wKHNvdXJjZSwgZnVuY3Rpb24gKHZhbHVlLCBwcm9wKSB7CiAgICAgICAgICAgICAgICBpZiAoZm9yY2UgfHwgIWhhc1Byb3AodGFyZ2V0LCBwcm9wKSkgewogICAgICAgICAgICAgICAgICAgIGlmIChkZWVwU3RyaW5nTWl4aW4gJiYgdHlwZW9mIHZhbHVlID09PSAnb2JqZWN0JyAmJiB2YWx1ZSAmJgogICAgICAgICAgICAgICAgICAgICAgICAhaXNBcnJheSh2YWx1ZSkgJiYgIWlzRnVuY3Rpb24odmFsdWUpICYmCiAgICAgICAgICAgICAgICAgICAgICAgICEodmFsdWUgaW5zdGFuY2VvZiBSZWdFeHApKSB7CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoIXRhcmdldFtwcm9wXSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgdGFyZ2V0W3Byb3BdID0ge307CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgbWl4aW4odGFyZ2V0W3Byb3BdLCB2YWx1ZSwgZm9yY2UsIGRlZXBTdHJpbmdNaXhpbik7CiAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgIC
AgICAgICAgICAgICAgICAgdGFyZ2V0W3Byb3BdID0gdmFsdWU7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9KTsKICAgICAgICB9CiAgICAgICAgcmV0dXJuIHRhcmdldDsKICAgIH0KCiAgICAvL1NpbWlsYXIgdG8gRnVuY3Rpb24ucHJvdG90eXBlLmJpbmQsIGJ1dCB0aGUgJ3RoaXMnIG9iamVjdCBpcyBzcGVjaWZpZWQKICAgIC8vZmlyc3QsIHNpbmNlIGl0IGlzIGVhc2llciB0byByZWFkL2ZpZ3VyZSBvdXQgd2hhdCAndGhpcycgd2lsbCBiZS4KICAgIGZ1bmN0aW9uIGJpbmQob2JqLCBmbikgewogICAgICAgIHJldHVybiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgIHJldHVybiBmbi5hcHBseShvYmosIGFyZ3VtZW50cyk7CiAgICAgICAgfTsKICAgIH0KCiAgICBmdW5jdGlvbiBzY3JpcHRzKCkgewogICAgICAgIHJldHVybiBkb2N1bWVudC5nZXRFbGVtZW50c0J5VGFnTmFtZSgnc2NyaXB0Jyk7CiAgICB9CgogICAgZnVuY3Rpb24gZGVmYXVsdE9uRXJyb3IoZXJyKSB7CiAgICAgICAgdGhyb3cgZXJyOwogICAgfQoKICAgIC8vQWxsb3cgZ2V0dGluZyBhIGdsb2JhbCB0aGF0IGlzIGV4cHJlc3NlZCBpbgogICAgLy9kb3Qgbm90YXRpb24sIGxpa2UgJ2EuYi5jJy4KICAgIGZ1bmN0aW9uIGdldEdsb2JhbCh2YWx1ZSkgewogICAgICAgIGlmICghdmFsdWUpIHsKICAgICAgICAgICAgcmV0dXJuIHZhbHVlOwogICAgICAgIH0KICAgICAgICB2YXIgZyA9IGdsb2JhbDsKICAgICAgICBlYWNoKHZhbHVlLnNwbGl0KCcuJyksIGZ1bmN0aW9uIChwYXJ0KSB7CiAgICAgICAgICAgIGcgPSBnW3BhcnRdOwogICAgICAgIH0pOwogICAgICAgIHJldHVybiBnOwogICAgfQoKICAgIC8qKgogICAgICogQ29uc3RydWN0cyBhbiBlcnJvciB3aXRoIGEgcG9pbnRlciB0byBhbiBVUkwgd2l0aCBtb3JlIGluZm9ybWF0aW9uLgogICAgICogQHBhcmFtIHtTdHJpbmd9IGlkIHRoZSBlcnJvciBJRCB0aGF0IG1hcHMgdG8gYW4gSUQgb24gYSB3ZWIgcGFnZS4KICAgICAqIEBwYXJhbSB7U3RyaW5nfSBtZXNzYWdlIGh1bWFuIHJlYWRhYmxlIGVycm9yLgogICAgICogQHBhcmFtIHtFcnJvcn0gW2Vycl0gdGhlIG9yaWdpbmFsIGVycm9yLCBpZiB0aGVyZSBpcyBvbmUuCiAgICAgKgogICAgICogQHJldHVybnMge0Vycm9yfQogICAgICovCiAgICBmdW5jdGlvbiBtYWtlRXJyb3IoaWQsIG1zZywgZXJyLCByZXF1aXJlTW9kdWxlcykgewogICAgICAgIHZhciBlID0gbmV3IEVycm9yKG1zZyArICdcbmh0dHA6Ly9yZXF1aXJlanMub3JnL2RvY3MvZXJyb3JzLmh0bWwjJyArIGlkKTsKICAgICAgICBlLnJlcXVpcmVUeXBlID0gaWQ7CiAgICAgICAgZS5yZXF1aXJlTW9kdWxlcyA9IHJlcXVpcmVNb2R1bGVzOwogICAgICAgIGlmIChlcnIpIHsKICAgICAgICAgICAgZS5vcmlnaW5hbEVycm9yID0gZXJyOwogICAgICAgIH0KICAgICAgICByZXR1cm4gZTsKICAgIH0KCiAgICBpZiAodHlwZW9mIGRlZmluZSAhPT0gJ3
VuZGVmaW5lZCcpIHsKICAgICAgICAvL0lmIGEgZGVmaW5lIGlzIGFscmVhZHkgaW4gcGxheSB2aWEgYW5vdGhlciBBTUQgbG9hZGVyLAogICAgICAgIC8vZG8gbm90IG92ZXJ3cml0ZS4KICAgICAgICByZXR1cm47CiAgICB9CgogICAgaWYgKHR5cGVvZiByZXF1aXJlanMgIT09ICd1bmRlZmluZWQnKSB7CiAgICAgICAgaWYgKGlzRnVuY3Rpb24ocmVxdWlyZWpzKSkgewogICAgICAgICAgICAvL0RvIG5vdCBvdmVyd3JpdGUgYW4gZXhpc3RpbmcgcmVxdWlyZWpzIGluc3RhbmNlLgogICAgICAgICAgICByZXR1cm47CiAgICAgICAgfQogICAgICAgIGNmZyA9IHJlcXVpcmVqczsKICAgICAgICByZXF1aXJlanMgPSB1bmRlZmluZWQ7CiAgICB9CgogICAgLy9BbGxvdyBmb3IgYSByZXF1aXJlIGNvbmZpZyBvYmplY3QKICAgIGlmICh0eXBlb2YgcmVxdWlyZSAhPT0gJ3VuZGVmaW5lZCcgJiYgIWlzRnVuY3Rpb24ocmVxdWlyZSkpIHsKICAgICAgICAvL2Fzc3VtZSBpdCBpcyBhIGNvbmZpZyBvYmplY3QuCiAgICAgICAgY2ZnID0gcmVxdWlyZTsKICAgICAgICByZXF1aXJlID0gdW5kZWZpbmVkOwogICAgfQoKICAgIGZ1bmN0aW9uIG5ld0NvbnRleHQoY29udGV4dE5hbWUpIHsKICAgICAgICB2YXIgaW5DaGVja0xvYWRlZCwgTW9kdWxlLCBjb250ZXh0LCBoYW5kbGVycywKICAgICAgICAgICAgY2hlY2tMb2FkZWRUaW1lb3V0SWQsCiAgICAgICAgICAgIGNvbmZpZyA9IHsKICAgICAgICAgICAgICAgIC8vRGVmYXVsdHMuIERvIG5vdCBzZXQgYSBkZWZhdWx0IGZvciBtYXAKICAgICAgICAgICAgICAgIC8vY29uZmlnIHRvIHNwZWVkIHVwIG5vcm1hbGl6ZSgpLCB3aGljaAogICAgICAgICAgICAgICAgLy93aWxsIHJ1biBmYXN0ZXIgaWYgdGhlcmUgaXMgbm8gZGVmYXVsdC4KICAgICAgICAgICAgICAgIHdhaXRTZWNvbmRzOiA3LAogICAgICAgICAgICAgICAgYmFzZVVybDogJy4vJywKICAgICAgICAgICAgICAgIHBhdGhzOiB7fSwKICAgICAgICAgICAgICAgIGJ1bmRsZXM6IHt9LAogICAgICAgICAgICAgICAgcGtnczoge30sCiAgICAgICAgICAgICAgICBzaGltOiB7fSwKICAgICAgICAgICAgICAgIGNvbmZpZzoge30KICAgICAgICAgICAgfSwKICAgICAgICAgICAgcmVnaXN0cnkgPSB7fSwKICAgICAgICAgICAgLy9yZWdpc3RyeSBvZiBqdXN0IGVuYWJsZWQgbW9kdWxlcywgdG8gc3BlZWQKICAgICAgICAgICAgLy9jeWNsZSBicmVha2luZyBjb2RlIHdoZW4gbG90cyBvZiBtb2R1bGVzCiAgICAgICAgICAgIC8vYXJlIHJlZ2lzdGVyZWQsIGJ1dCBub3QgYWN0aXZhdGVkLgogICAgICAgICAgICBlbmFibGVkUmVnaXN0cnkgPSB7fSwKICAgICAgICAgICAgdW5kZWZFdmVudHMgPSB7fSwKICAgICAgICAgICAgZGVmUXVldWUgPSBbXSwKICAgICAgICAgICAgZGVmaW5lZCA9IHt9LAogICAgICAgICAgICB1cmxGZXRjaGVkID0ge30sCiAgICAgICAgICAgIGJ1bmRsZXNNYXAgPSB7fSwKICAgICAgICAgICAgcmVxdWlyZUNvdW50ZXIgPSAxLA
ogICAgICAgICAgICB1bm5vcm1hbGl6ZWRDb3VudGVyID0gMTsKCiAgICAgICAgLyoqCiAgICAgICAgICogVHJpbXMgdGhlIC4gYW5kIC4uIGZyb20gYW4gYXJyYXkgb2YgcGF0aCBzZWdtZW50cy4KICAgICAgICAgKiBJdCB3aWxsIGtlZXAgYSBsZWFkaW5nIHBhdGggc2VnbWVudCBpZiBhIC4uIHdpbGwgYmVjb21lCiAgICAgICAgICogdGhlIGZpcnN0IHBhdGggc2VnbWVudCwgdG8gaGVscCB3aXRoIG1vZHVsZSBuYW1lIGxvb2t1cHMsCiAgICAgICAgICogd2hpY2ggYWN0IGxpa2UgcGF0aHMsIGJ1dCBjYW4gYmUgcmVtYXBwZWQuIEJ1dCB0aGUgZW5kIHJlc3VsdCwKICAgICAgICAgKiBhbGwgcGF0aHMgdGhhdCB1c2UgdGhpcyBmdW5jdGlvbiBzaG91bGQgbG9vayBub3JtYWxpemVkLgogICAgICAgICAqIE5PVEU6IHRoaXMgbWV0aG9kIE1PRElGSUVTIHRoZSBpbnB1dCBhcnJheS4KICAgICAgICAgKiBAcGFyYW0ge0FycmF5fSBhcnkgdGhlIGFycmF5IG9mIHBhdGggc2VnbWVudHMuCiAgICAgICAgICovCiAgICAgICAgZnVuY3Rpb24gdHJpbURvdHMoYXJ5KSB7CiAgICAgICAgICAgIHZhciBpLCBwYXJ0OwogICAgICAgICAgICBmb3IgKGkgPSAwOyBpIDwgYXJ5Lmxlbmd0aDsgaSsrKSB7CiAgICAgICAgICAgICAgICBwYXJ0ID0gYXJ5W2ldOwogICAgICAgICAgICAgICAgaWYgKHBhcnQgPT09ICcuJykgewogICAgICAgICAgICAgICAgICAgIGFyeS5zcGxpY2UoaSwgMSk7CiAgICAgICAgICAgICAgICAgICAgaSAtPSAxOwogICAgICAgICAgICAgICAgfSBlbHNlIGlmIChwYXJ0ID09PSAnLi4nKSB7CiAgICAgICAgICAgICAgICAgICAgLy8gSWYgYXQgdGhlIHN0YXJ0LCBvciBwcmV2aW91cyB2YWx1ZSBpcyBzdGlsbCAuLiwKICAgICAgICAgICAgICAgICAgICAvLyBrZWVwIHRoZW0gc28gdGhhdCB3aGVuIGNvbnZlcnRlZCB0byBhIHBhdGggaXQgbWF5CiAgICAgICAgICAgICAgICAgICAgLy8gc3RpbGwgd29yayB3aGVuIGNvbnZlcnRlZCB0byBhIHBhdGgsIGV2ZW4gdGhvdWdoCiAgICAgICAgICAgICAgICAgICAgLy8gYXMgYW4gSUQgaXQgaXMgbGVzcyB0aGFuIGlkZWFsLiBJbiBsYXJnZXIgcG9pbnQKICAgICAgICAgICAgICAgICAgICAvLyByZWxlYXNlcywgbWF5IGJlIGJldHRlciB0byBqdXN0IGtpY2sgb3V0IGFuIGVycm9yLgogICAgICAgICAgICAgICAgICAgIGlmIChpID09PSAwIHx8IChpID09PSAxICYmIGFyeVsyXSA9PT0gJy4uJykgfHwgYXJ5W2kgLSAxXSA9PT0gJy4uJykgewogICAgICAgICAgICAgICAgICAgICAgICBjb250aW51ZTsKICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKGkgPiAwKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGFyeS5zcGxpY2UoaSAtIDEsIDIpOwogICAgICAgICAgICAgICAgICAgICAgICBpIC09IDI7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvKioKICAgICAgICAgKiBHaXZlbiBhIHJlbG
F0aXZlIG1vZHVsZSBuYW1lLCBsaWtlIC4vc29tZXRoaW5nLCBub3JtYWxpemUgaXQgdG8KICAgICAgICAgKiBhIHJlYWwgbmFtZSB0aGF0IGNhbiBiZSBtYXBwZWQgdG8gYSBwYXRoLgogICAgICAgICAqIEBwYXJhbSB7U3RyaW5nfSBuYW1lIHRoZSByZWxhdGl2ZSBuYW1lCiAgICAgICAgICogQHBhcmFtIHtTdHJpbmd9IGJhc2VOYW1lIGEgcmVhbCBuYW1lIHRoYXQgdGhlIG5hbWUgYXJnIGlzIHJlbGF0aXZlCiAgICAgICAgICogdG8uCiAgICAgICAgICogQHBhcmFtIHtCb29sZWFufSBhcHBseU1hcCBhcHBseSB0aGUgbWFwIGNvbmZpZyB0byB0aGUgdmFsdWUuIFNob3VsZAogICAgICAgICAqIG9ubHkgYmUgZG9uZSBpZiB0aGlzIG5vcm1hbGl6YXRpb24gaXMgZm9yIGEgZGVwZW5kZW5jeSBJRC4KICAgICAgICAgKiBAcmV0dXJucyB7U3RyaW5nfSBub3JtYWxpemVkIG5hbWUKICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiBub3JtYWxpemUobmFtZSwgYmFzZU5hbWUsIGFwcGx5TWFwKSB7CiAgICAgICAgICAgIHZhciBwa2dNYWluLCBtYXBWYWx1ZSwgbmFtZVBhcnRzLCBpLCBqLCBuYW1lU2VnbWVudCwgbGFzdEluZGV4LAogICAgICAgICAgICAgICAgZm91bmRNYXAsIGZvdW5kSSwgZm91bmRTdGFyTWFwLCBzdGFySSwgbm9ybWFsaXplZEJhc2VQYXJ0cywKICAgICAgICAgICAgICAgIGJhc2VQYXJ0cyA9IChiYXNlTmFtZSAmJiBiYXNlTmFtZS5zcGxpdCgnLycpKSwKICAgICAgICAgICAgICAgIG1hcCA9IGNvbmZpZy5tYXAsCiAgICAgICAgICAgICAgICBzdGFyTWFwID0gbWFwICYmIG1hcFsnKiddOwoKICAgICAgICAgICAgLy9BZGp1c3QgYW55IHJlbGF0aXZlIHBhdGhzLgogICAgICAgICAgICBpZiAobmFtZSkgewogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuc3BsaXQoJy8nKTsKICAgICAgICAgICAgICAgIGxhc3RJbmRleCA9IG5hbWUubGVuZ3RoIC0gMTsKCiAgICAgICAgICAgICAgICAvLyBJZiB3YW50aW5nIG5vZGUgSUQgY29tcGF0aWJpbGl0eSwgc3RyaXAgLmpzIGZyb20gZW5kCiAgICAgICAgICAgICAgICAvLyBvZiBJRHMuIEhhdmUgdG8gZG8gdGhpcyBoZXJlLCBhbmQgbm90IGluIG5hbWVUb1VybAogICAgICAgICAgICAgICAgLy8gYmVjYXVzZSBub2RlIGFsbG93cyBlaXRoZXIgLmpzIG9yIG5vbiAuanMgdG8gbWFwCiAgICAgICAgICAgICAgICAvLyB0byBzYW1lIGZpbGUuCiAgICAgICAgICAgICAgICBpZiAoY29uZmlnLm5vZGVJZENvbXBhdCAmJiBqc1N1ZmZpeFJlZ0V4cC50ZXN0KG5hbWVbbGFzdEluZGV4XSkpIHsKICAgICAgICAgICAgICAgICAgICBuYW1lW2xhc3RJbmRleF0gPSBuYW1lW2xhc3RJbmRleF0ucmVwbGFjZShqc1N1ZmZpeFJlZ0V4cCwgJycpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vIFN0YXJ0cyB3aXRoIGEgJy4nIHNvIG5lZWQgdGhlIGJhc2VOYW1lCiAgICAgICAgICAgICAgICBpZiAobmFtZVswXS5jaGFyQXQoMCkgPT09ICcuJyAmJiBiYXNlUGFydHMpIH
sKICAgICAgICAgICAgICAgICAgICAvL0NvbnZlcnQgYmFzZU5hbWUgdG8gYXJyYXksIGFuZCBsb3Agb2ZmIHRoZSBsYXN0IHBhcnQsCiAgICAgICAgICAgICAgICAgICAgLy9zbyB0aGF0IC4gbWF0Y2hlcyB0aGF0ICdkaXJlY3RvcnknIGFuZCBub3QgbmFtZSBvZiB0aGUgYmFzZU5hbWUncwogICAgICAgICAgICAgICAgICAgIC8vbW9kdWxlLiBGb3IgaW5zdGFuY2UsIGJhc2VOYW1lIG9mICdvbmUvdHdvL3RocmVlJywgbWFwcyB0bwogICAgICAgICAgICAgICAgICAgIC8vJ29uZS90d28vdGhyZWUuanMnLCBidXQgd2Ugd2FudCB0aGUgZGlyZWN0b3J5LCAnb25lL3R3bycgZm9yCiAgICAgICAgICAgICAgICAgICAgLy90aGlzIG5vcm1hbGl6YXRpb24uCiAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZEJhc2VQYXJ0cyA9IGJhc2VQYXJ0cy5zbGljZSgwLCBiYXNlUGFydHMubGVuZ3RoIC0gMSk7CiAgICAgICAgICAgICAgICAgICAgbmFtZSA9IG5vcm1hbGl6ZWRCYXNlUGFydHMuY29uY2F0KG5hbWUpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHRyaW1Eb3RzKG5hbWUpOwogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuam9pbignLycpOwogICAgICAgICAgICB9CgogICAgICAgICAgICAvL0FwcGx5IG1hcCBjb25maWcgaWYgYXZhaWxhYmxlLgogICAgICAgICAgICBpZiAoYXBwbHlNYXAgJiYgbWFwICYmIChiYXNlUGFydHMgfHwgc3Rhck1hcCkpIHsKICAgICAgICAgICAgICAgIG5hbWVQYXJ0cyA9IG5hbWUuc3BsaXQoJy8nKTsKCiAgICAgICAgICAgICAgICBvdXRlckxvb3A6IGZvciAoaSA9IG5hbWVQYXJ0cy5sZW5ndGg7IGkgPiAwOyBpIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICBuYW1lU2VnbWVudCA9IG5hbWVQYXJ0cy5zbGljZSgwLCBpKS5qb2luKCcvJyk7CgogICAgICAgICAgICAgICAgICAgIGlmIChiYXNlUGFydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9GaW5kIHRoZSBsb25nZXN0IGJhc2VOYW1lIHNlZ21lbnQgbWF0Y2ggaW4gdGhlIGNvbmZpZy4KICAgICAgICAgICAgICAgICAgICAgICAgLy9TbywgZG8gam9pbnMgb24gdGhlIGJpZ2dlc3QgdG8gc21hbGxlc3QgbGVuZ3RocyBvZiBiYXNlUGFydHMuCiAgICAgICAgICAgICAgICAgICAgICAgIGZvciAoaiA9IGJhc2VQYXJ0cy5sZW5ndGg7IGogPiAwOyBqIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG1hcFZhbHVlID0gZ2V0T3duKG1hcCwgYmFzZVBhcnRzLnNsaWNlKDAsIGopLmpvaW4oJy8nKSk7CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9iYXNlTmFtZSBzZWdtZW50IGhhcyBjb25maWcsIGZpbmQgaWYgaXQgaGFzIG9uZSBmb3IKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vdGhpcyBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1hcFZhbHVlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbWFwVmFsdWUgPSBnZXRPd24obWFwVmFsdWUsIG
5hbWVTZWdtZW50KTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAobWFwVmFsdWUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9NYXRjaCwgdXBkYXRlIG5hbWUgdG8gdGhlIG5ldyB2YWx1ZS4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZm91bmRNYXAgPSBtYXBWYWx1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZm91bmRJID0gaTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYWsgb3V0ZXJMb29wOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgLy9DaGVjayBmb3IgYSBzdGFyIG1hcCBtYXRjaCwgYnV0IGp1c3QgaG9sZCBvbiB0byBpdCwKICAgICAgICAgICAgICAgICAgICAvL2lmIHRoZXJlIGlzIGEgc2hvcnRlciBzZWdtZW50IG1hdGNoIGxhdGVyIGluIGEgbWF0Y2hpbmcKICAgICAgICAgICAgICAgICAgICAvL2NvbmZpZywgdGhlbiBmYXZvciBvdmVyIHRoaXMgc3RhciBtYXAuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFmb3VuZFN0YXJNYXAgJiYgc3Rhck1hcCAmJiBnZXRPd24oc3Rhck1hcCwgbmFtZVNlZ21lbnQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGZvdW5kU3Rhck1hcCA9IGdldE93bihzdGFyTWFwLCBuYW1lU2VnbWVudCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHN0YXJJID0gaTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFmb3VuZE1hcCAmJiBmb3VuZFN0YXJNYXApIHsKICAgICAgICAgICAgICAgICAgICBmb3VuZE1hcCA9IGZvdW5kU3Rhck1hcDsKICAgICAgICAgICAgICAgICAgICBmb3VuZEkgPSBzdGFySTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBpZiAoZm91bmRNYXApIHsKICAgICAgICAgICAgICAgICAgICBuYW1lUGFydHMuc3BsaWNlKDAsIGZvdW5kSSwgZm91bmRNYXApOwogICAgICAgICAgICAgICAgICAgIG5hbWUgPSBuYW1lUGFydHMuam9pbignLycpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CgogICAgICAgICAgICAvLyBJZiB0aGUgbmFtZSBwb2ludHMgdG8gYSBwYWNrYWdlJ3MgbmFtZSwgdXNlCiAgICAgICAgICAgIC8vIHRoZSBwYWNrYWdlIG1haW4gaW5zdGVhZC4KICAgICAgICAgICAgcGtnTWFpbiA9IGdldE93bihjb25maWcucGtncywgbmFtZSk7CgogICAgICAgICAgICByZXR1cm4gcGtnTWFpbiA/IHBrZ01haW4gOiBuYW1lOwogICAgICAgIH0KCiAgICAgICAgZnVuY3Rpb24gcmVtb3ZlU2NyaXB0KG5hbWUpIHsKICAgICAgICAgICAgaWYgKGlzQnJvd3NlcikgewogICAgICAgICAgICAgICAgZWFjaChzY3JpcHRzKCksIGZ1bmN0aW9uIChzY3JpcHROb2RlKSB7Ci
AgICAgICAgICAgICAgICAgICAgaWYgKHNjcmlwdE5vZGUuZ2V0QXR0cmlidXRlKCdkYXRhLXJlcXVpcmVtb2R1bGUnKSA9PT0gbmFtZSAmJgogICAgICAgICAgICAgICAgICAgICAgICAgICAgc2NyaXB0Tm9kZS5nZXRBdHRyaWJ1dGUoJ2RhdGEtcmVxdWlyZWNvbnRleHQnKSA9PT0gY29udGV4dC5jb250ZXh0TmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICBzY3JpcHROb2RlLnBhcmVudE5vZGUucmVtb3ZlQ2hpbGQoc2NyaXB0Tm9kZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBoYXNQYXRoRmFsbGJhY2soaWQpIHsKICAgICAgICAgICAgdmFyIHBhdGhDb25maWcgPSBnZXRPd24oY29uZmlnLnBhdGhzLCBpZCk7CiAgICAgICAgICAgIGlmIChwYXRoQ29uZmlnICYmIGlzQXJyYXkocGF0aENvbmZpZykgJiYgcGF0aENvbmZpZy5sZW5ndGggPiAxKSB7CiAgICAgICAgICAgICAgICAvL1BvcCBvZmYgdGhlIGZpcnN0IGFycmF5IHZhbHVlLCBzaW5jZSBpdCBmYWlsZWQsIGFuZAogICAgICAgICAgICAgICAgLy9yZXRyeQogICAgICAgICAgICAgICAgcGF0aENvbmZpZy5zaGlmdCgpOwogICAgICAgICAgICAgICAgY29udGV4dC5yZXF1aXJlLnVuZGVmKGlkKTsKCiAgICAgICAgICAgICAgICAvL0N1c3RvbSByZXF1aXJlIHRoYXQgZG9lcyBub3QgZG8gbWFwIHRyYW5zbGF0aW9uLCBzaW5jZQogICAgICAgICAgICAgICAgLy9JRCBpcyAiYWJzb2x1dGUiLCBhbHJlYWR5IG1hcHBlZC9yZXNvbHZlZC4KICAgICAgICAgICAgICAgIGNvbnRleHQubWFrZVJlcXVpcmUobnVsbCwgewogICAgICAgICAgICAgICAgICAgIHNraXBNYXA6IHRydWUKICAgICAgICAgICAgICAgIH0pKFtpZF0pOwoKICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvL1R1cm5zIGEgcGx1Z2luIXJlc291cmNlIHRvIFtwbHVnaW4sIHJlc291cmNlXQogICAgICAgIC8vd2l0aCB0aGUgcGx1Z2luIGJlaW5nIHVuZGVmaW5lZCBpZiB0aGUgbmFtZQogICAgICAgIC8vZGlkIG5vdCBoYXZlIGEgcGx1Z2luIHByZWZpeC4KICAgICAgICBmdW5jdGlvbiBzcGxpdFByZWZpeChuYW1lKSB7CiAgICAgICAgICAgIHZhciBwcmVmaXgsCiAgICAgICAgICAgICAgICBpbmRleCA9IG5hbWUgPyBuYW1lLmluZGV4T2YoJyEnKSA6IC0xOwogICAgICAgICAgICBpZiAoaW5kZXggPiAtMSkgewogICAgICAgICAgICAgICAgcHJlZml4ID0gbmFtZS5zdWJzdHJpbmcoMCwgaW5kZXgpOwogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuc3Vic3RyaW5nKGluZGV4ICsgMSwgbmFtZS5sZW5ndGgpOwogICAgICAgICAgICB9CiAgICAgICAgICAgIHJldHVybiBbcHJlZml4LCBuYW1lXTsKICAgICAgICB9CgogICAgICAgIC8qKgogICAgICAgICAqIENyZWF0ZXMgYS
Btb2R1bGUgbWFwcGluZyB0aGF0IGluY2x1ZGVzIHBsdWdpbiBwcmVmaXgsIG1vZHVsZQogICAgICAgICAqIG5hbWUsIGFuZCBwYXRoLiBJZiBwYXJlbnRNb2R1bGVNYXAgaXMgcHJvdmlkZWQgaXQgd2lsbAogICAgICAgICAqIGFsc28gbm9ybWFsaXplIHRoZSBuYW1lIHZpYSByZXF1aXJlLm5vcm1hbGl6ZSgpCiAgICAgICAgICoKICAgICAgICAgKiBAcGFyYW0ge1N0cmluZ30gbmFtZSB0aGUgbW9kdWxlIG5hbWUKICAgICAgICAgKiBAcGFyYW0ge1N0cmluZ30gW3BhcmVudE1vZHVsZU1hcF0gcGFyZW50IG1vZHVsZSBtYXAKICAgICAgICAgKiBmb3IgdGhlIG1vZHVsZSBuYW1lLCB1c2VkIHRvIHJlc29sdmUgcmVsYXRpdmUgbmFtZXMuCiAgICAgICAgICogQHBhcmFtIHtCb29sZWFufSBpc05vcm1hbGl6ZWQ6IGlzIHRoZSBJRCBhbHJlYWR5IG5vcm1hbGl6ZWQuCiAgICAgICAgICogVGhpcyBpcyB0cnVlIGlmIHRoaXMgY2FsbCBpcyBkb25lIGZvciBhIGRlZmluZSgpIG1vZHVsZSBJRC4KICAgICAgICAgKiBAcGFyYW0ge0Jvb2xlYW59IGFwcGx5TWFwOiBhcHBseSB0aGUgbWFwIGNvbmZpZyB0byB0aGUgSUQuCiAgICAgICAgICogU2hvdWxkIG9ubHkgYmUgdHJ1ZSBpZiB0aGlzIG1hcCBpcyBmb3IgYSBkZXBlbmRlbmN5LgogICAgICAgICAqCiAgICAgICAgICogQHJldHVybnMge09iamVjdH0KICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiBtYWtlTW9kdWxlTWFwKG5hbWUsIHBhcmVudE1vZHVsZU1hcCwgaXNOb3JtYWxpemVkLCBhcHBseU1hcCkgewogICAgICAgICAgICB2YXIgdXJsLCBwbHVnaW5Nb2R1bGUsIHN1ZmZpeCwgbmFtZVBhcnRzLAogICAgICAgICAgICAgICAgcHJlZml4ID0gbnVsbCwKICAgICAgICAgICAgICAgIHBhcmVudE5hbWUgPSBwYXJlbnRNb2R1bGVNYXAgPyBwYXJlbnRNb2R1bGVNYXAubmFtZSA6IG51bGwsCiAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWUgPSBuYW1lLAogICAgICAgICAgICAgICAgaXNEZWZpbmUgPSB0cnVlLAogICAgICAgICAgICAgICAgbm9ybWFsaXplZE5hbWUgPSAnJzsKCiAgICAgICAgICAgIC8vSWYgbm8gbmFtZSwgdGhlbiBpdCBtZWFucyBpdCBpcyBhIHJlcXVpcmUgY2FsbCwgZ2VuZXJhdGUgYW4KICAgICAgICAgICAgLy9pbnRlcm5hbCBuYW1lLgogICAgICAgICAgICBpZiAoIW5hbWUpIHsKICAgICAgICAgICAgICAgIGlzRGVmaW5lID0gZmFsc2U7CiAgICAgICAgICAgICAgICBuYW1lID0gJ19AcicgKyAocmVxdWlyZUNvdW50ZXIgKz0gMSk7CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIG5hbWVQYXJ0cyA9IHNwbGl0UHJlZml4KG5hbWUpOwogICAgICAgICAgICBwcmVmaXggPSBuYW1lUGFydHNbMF07CiAgICAgICAgICAgIG5hbWUgPSBuYW1lUGFydHNbMV07CgogICAgICAgICAgICBpZiAocHJlZml4KSB7CiAgICAgICAgICAgICAgICBwcmVmaXggPSBub3JtYWxpemUocHJlZml4LCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CiAgICAgICAgICAgICAgICBwbHVnaW5Nb2
R1bGUgPSBnZXRPd24oZGVmaW5lZCwgcHJlZml4KTsKICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9BY2NvdW50IGZvciByZWxhdGl2ZSBwYXRocyBpZiB0aGVyZSBpcyBhIGJhc2UgbmFtZS4KICAgICAgICAgICAgaWYgKG5hbWUpIHsKICAgICAgICAgICAgICAgIGlmIChwcmVmaXgpIHsKICAgICAgICAgICAgICAgICAgICBpZiAocGx1Z2luTW9kdWxlICYmIHBsdWdpbk1vZHVsZS5ub3JtYWxpemUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9QbHVnaW4gaXMgbG9hZGVkLCB1c2UgaXRzIG5vcm1hbGl6ZSBtZXRob2QuCiAgICAgICAgICAgICAgICAgICAgICAgIG5vcm1hbGl6ZWROYW1lID0gcGx1Z2luTW9kdWxlLm5vcm1hbGl6ZShuYW1lLCBmdW5jdGlvbiAobmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vIElmIG5lc3RlZCBwbHVnaW4gcmVmZXJlbmNlcywgdGhlbiBkbyBub3QgdHJ5IHRvCiAgICAgICAgICAgICAgICAgICAgICAgIC8vIG5vcm1hbGl6ZSwgYXMgaXQgd2lsbCBub3Qgbm9ybWFsaXplIGNvcnJlY3RseS4gVGhpcwogICAgICAgICAgICAgICAgICAgICAgICAvLyBwbGFjZXMgYSByZXN0cmljdGlvbiBvbiByZXNvdXJjZUlkcywgYW5kIHRoZSBsb25nZXIKICAgICAgICAgICAgICAgICAgICAgICAgLy8gdGVybSBzb2x1dGlvbiBpcyBub3QgdG8gbm9ybWFsaXplIHVudGlsIHBsdWdpbnMgYXJlCiAgICAgICAgICAgICAgICAgICAgICAgIC8vIGxvYWRlZCBhbmQgYWxsIG5vcm1hbGl6YXRpb25zIHRvIGFsbG93IGZvciBhc3luYwogICAgICAgICAgICAgICAgICAgICAgICAvLyBsb2FkaW5nIG9mIGEgbG9hZGVyIHBsdWdpbi4gQnV0IGZvciBub3csIGZpeGVzIHRoZQogICAgICAgICAgICAgICAgICAgICAgICAvLyBjb21tb24gdXNlcy4gRGV0YWlscyBpbiAjMTEzMQogICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTmFtZSA9IG5hbWUuaW5kZXhPZignIScpID09PSAtMSA/CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplKG5hbWUsIHBhcmVudE5hbWUsIGFwcGx5TWFwKSA6CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbmFtZTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgIC8vQSByZWd1bGFyIG1vZHVsZS4KICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTmFtZSA9IG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CgogICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplZCBuYW1lIG1heSBiZSBhIHBsdWdpbiBJRCBkdWUgdG8gbWFwIGNvbmZpZwogICAgICAgICAgICAgICAgICAgIC
8vYXBwbGljYXRpb24gaW4gbm9ybWFsaXplLiBUaGUgbWFwIGNvbmZpZyB2YWx1ZXMgbXVzdAogICAgICAgICAgICAgICAgICAgIC8vYWxyZWFkeSBiZSBub3JtYWxpemVkLCBzbyBkbyBub3QgbmVlZCB0byByZWRvIHRoYXQgcGFydC4KICAgICAgICAgICAgICAgICAgICBuYW1lUGFydHMgPSBzcGxpdFByZWZpeChub3JtYWxpemVkTmFtZSk7CiAgICAgICAgICAgICAgICAgICAgcHJlZml4ID0gbmFtZVBhcnRzWzBdOwogICAgICAgICAgICAgICAgICAgIG5vcm1hbGl6ZWROYW1lID0gbmFtZVBhcnRzWzFdOwogICAgICAgICAgICAgICAgICAgIGlzTm9ybWFsaXplZCA9IHRydWU7CgogICAgICAgICAgICAgICAgICAgIHVybCA9IGNvbnRleHQubmFtZVRvVXJsKG5vcm1hbGl6ZWROYW1lKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9JZiB0aGUgaWQgaXMgYSBwbHVnaW4gaWQgdGhhdCBjYW5ub3QgYmUgZGV0ZXJtaW5lZCBpZiBpdCBuZWVkcwogICAgICAgICAgICAvL25vcm1hbGl6YXRpb24sIHN0YW1wIGl0IHdpdGggYSB1bmlxdWUgSUQgc28gdHdvIG1hdGNoaW5nIHJlbGF0aXZlCiAgICAgICAgICAgIC8vaWRzIHRoYXQgbWF5IGNvbmZsaWN0IGNhbiBiZSBzZXBhcmF0ZS4KICAgICAgICAgICAgc3VmZml4ID0gcHJlZml4ICYmICFwbHVnaW5Nb2R1bGUgJiYgIWlzTm9ybWFsaXplZCA/CiAgICAgICAgICAgICAgICAgICAgICdfdW5ub3JtYWxpemVkJyArICh1bm5vcm1hbGl6ZWRDb3VudGVyICs9IDEpIDoKICAgICAgICAgICAgICAgICAgICAgJyc7CgogICAgICAgICAgICByZXR1cm4gewogICAgICAgICAgICAgICAgcHJlZml4OiBwcmVmaXgsCiAgICAgICAgICAgICAgICBuYW1lOiBub3JtYWxpemVkTmFtZSwKICAgICAgICAgICAgICAgIHBhcmVudE1hcDogcGFyZW50TW9kdWxlTWFwLAogICAgICAgICAgICAgICAgdW5ub3JtYWxpemVkOiAhIXN1ZmZpeCwKICAgICAgICAgICAgICAgIHVybDogdXJsLAogICAgICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBvcmlnaW5hbE5hbWUsCiAgICAgICAgICAgICAgICBpc0RlZmluZTogaXNEZWZpbmUsCiAgICAgICAgICAgICAgICBpZDogKHByZWZpeCA/CiAgICAgICAgICAgICAgICAgICAgICAgIHByZWZpeCArICchJyArIG5vcm1hbGl6ZWROYW1lIDoKICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZE5hbWUpICsgc3VmZml4CiAgICAgICAgICAgIH07CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBnZXRNb2R1bGUoZGVwTWFwKSB7CiAgICAgICAgICAgIHZhciBpZCA9IGRlcE1hcC5pZCwKICAgICAgICAgICAgICAgIG1vZCA9IGdldE93bihyZWdpc3RyeSwgaWQpOwoKICAgICAgICAgICAgaWYgKCFtb2QpIHsKICAgICAgICAgICAgICAgIG1vZCA9IHJlZ2lzdHJ5W2lkXSA9IG5ldyBjb250ZXh0Lk1vZHVsZShkZXBNYXApOwogICAgICAgICAgICB9CgogICAgICAgICAgICByZXR1cm4gbW9kOwogICAgICAgIH0KCiAgICAgICAgZnVuY3
Rpb24gb24oZGVwTWFwLCBuYW1lLCBmbikgewogICAgICAgICAgICB2YXIgaWQgPSBkZXBNYXAuaWQsCiAgICAgICAgICAgICAgICBtb2QgPSBnZXRPd24ocmVnaXN0cnksIGlkKTsKCiAgICAgICAgICAgIGlmIChoYXNQcm9wKGRlZmluZWQsIGlkKSAmJgogICAgICAgICAgICAgICAgICAgICghbW9kIHx8IG1vZC5kZWZpbmVFbWl0Q29tcGxldGUpKSB7CiAgICAgICAgICAgICAgICBpZiAobmFtZSA9PT0gJ2RlZmluZWQnKSB7CiAgICAgICAgICAgICAgICAgICAgZm4oZGVmaW5lZFtpZF0pOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgbW9kID0gZ2V0TW9kdWxlKGRlcE1hcCk7CiAgICAgICAgICAgICAgICBpZiAobW9kLmVycm9yICYmIG5hbWUgPT09ICdlcnJvcicpIHsKICAgICAgICAgICAgICAgICAgICBmbihtb2QuZXJyb3IpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICBtb2Qub24obmFtZSwgZm4pOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBvbkVycm9yKGVyciwgZXJyYmFjaykgewogICAgICAgICAgICB2YXIgaWRzID0gZXJyLnJlcXVpcmVNb2R1bGVzLAogICAgICAgICAgICAgICAgbm90aWZpZWQgPSBmYWxzZTsKCiAgICAgICAgICAgIGlmIChlcnJiYWNrKSB7CiAgICAgICAgICAgICAgICBlcnJiYWNrKGVycik7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBlYWNoKGlkcywgZnVuY3Rpb24gKGlkKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIG1vZCA9IGdldE93bihyZWdpc3RyeSwgaWQpOwogICAgICAgICAgICAgICAgICAgIGlmIChtb2QpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9TZXQgZXJyb3Igb24gbW9kdWxlLCBzbyBpdCBza2lwcyB0aW1lb3V0IGNoZWNrcy4KICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmVycm9yID0gZXJyOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAobW9kLmV2ZW50cy5lcnJvcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgbm90aWZpZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgIGlmICghbm90aWZpZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXEub25FcnJvcihlcnIpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvKioKICAgICAgICAgKiBJbnRlcm5hbCBtZXRob2QgdG8gdHJhbnNmZXIgZ2xvYmFsUXVldWUgaXRlbXMgdG8gdGhpcyBjb250ZXh0J3MKICAgICAgICAgKiBkZWZRdWV1ZS4KICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiB0YWtlR2xvYmFsUXVldWUoKSB7CiAgICAgICAgICAgIC8vUHVzaCBhbGwgdG
hlIGdsb2JhbERlZlF1ZXVlIGl0ZW1zIGludG8gdGhlIGNvbnRleHQncyBkZWZRdWV1ZQogICAgICAgICAgICBpZiAoZ2xvYmFsRGVmUXVldWUubGVuZ3RoKSB7CiAgICAgICAgICAgICAgICBlYWNoKGdsb2JhbERlZlF1ZXVlLCBmdW5jdGlvbihxdWV1ZUl0ZW0pIHsKICAgICAgICAgICAgICAgICAgICB2YXIgaWQgPSBxdWV1ZUl0ZW1bMF07CiAgICAgICAgICAgICAgICAgICAgaWYgKHR5cGVvZiBpZCA9PT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZU1hcFtpZF0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICBkZWZRdWV1ZS5wdXNoKHF1ZXVlSXRlbSk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIGdsb2JhbERlZlF1ZXVlID0gW107CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIGhhbmRsZXJzID0gewogICAgICAgICAgICAncmVxdWlyZSc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIGlmIChtb2QucmVxdWlyZSkgewogICAgICAgICAgICAgICAgICAgIHJldHVybiBtb2QucmVxdWlyZTsKICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChtb2QucmVxdWlyZSA9IGNvbnRleHQubWFrZVJlcXVpcmUobW9kLm1hcCkpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAogICAgICAgICAgICAnZXhwb3J0cyc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIG1vZC51c2luZ0V4cG9ydHMgPSB0cnVlOwogICAgICAgICAgICAgICAgaWYgKG1vZC5tYXAuaXNEZWZpbmUpIHsKICAgICAgICAgICAgICAgICAgICBpZiAobW9kLmV4cG9ydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChkZWZpbmVkW21vZC5tYXAuaWRdID0gbW9kLmV4cG9ydHMpOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiAobW9kLmV4cG9ydHMgPSBkZWZpbmVkW21vZC5tYXAuaWRdID0ge30pOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKICAgICAgICAgICAgJ21vZHVsZSc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIGlmIChtb2QubW9kdWxlKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG1vZC5tb2R1bGU7CiAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgIHJldHVybiAobW9kLm1vZHVsZSA9IHsKICAgICAgICAgICAgICAgICAgICAgICAgaWQ6IG1vZC5tYXAuaWQsCiAgICAgICAgICAgICAgICAgICAgICAgIHVyaTogbW9kLm1hcC51cmwsCiAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZzogZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGdldE93bihjb25maWcuY29uZmlnLCBtb2QubWFwLmlkKSB8fCB7fTsKICAgICAgICAgICAgIC
AgICAgICAgICAgfSwKICAgICAgICAgICAgICAgICAgICAgICAgZXhwb3J0czogbW9kLmV4cG9ydHMgfHwgKG1vZC5leHBvcnRzID0ge30pCiAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9OwoKICAgICAgICBmdW5jdGlvbiBjbGVhblJlZ2lzdHJ5KGlkKSB7CiAgICAgICAgICAgIC8vQ2xlYW4gdXAgbWFjaGluZXJ5IHVzZWQgZm9yIHdhaXRpbmcgbW9kdWxlcy4KICAgICAgICAgICAgZGVsZXRlIHJlZ2lzdHJ5W2lkXTsKICAgICAgICAgICAgZGVsZXRlIGVuYWJsZWRSZWdpc3RyeVtpZF07CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBicmVha0N5Y2xlKG1vZCwgdHJhY2VkLCBwcm9jZXNzZWQpIHsKICAgICAgICAgICAgdmFyIGlkID0gbW9kLm1hcC5pZDsKCiAgICAgICAgICAgIGlmIChtb2QuZXJyb3IpIHsKICAgICAgICAgICAgICAgIG1vZC5lbWl0KCdlcnJvcicsIG1vZC5lcnJvcik7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICB0cmFjZWRbaWRdID0gdHJ1ZTsKICAgICAgICAgICAgICAgIGVhY2gobW9kLmRlcE1hcHMsIGZ1bmN0aW9uIChkZXBNYXAsIGkpIHsKICAgICAgICAgICAgICAgICAgICB2YXIgZGVwSWQgPSBkZXBNYXAuaWQsCiAgICAgICAgICAgICAgICAgICAgICAgIGRlcCA9IGdldE93bihyZWdpc3RyeSwgZGVwSWQpOwoKICAgICAgICAgICAgICAgICAgICAvL09ubHkgZm9yY2UgdGhpbmdzIHRoYXQgaGF2ZSBub3QgY29tcGxldGVkCiAgICAgICAgICAgICAgICAgICAgLy9iZWluZyBkZWZpbmVkLCBzbyBzdGlsbCBpbiB0aGUgcmVnaXN0cnksCiAgICAgICAgICAgICAgICAgICAgLy9hbmQgb25seSBpZiBpdCBoYXMgbm90IGJlZW4gbWF0Y2hlZCB1cAogICAgICAgICAgICAgICAgICAgIC8vaW4gdGhlIG1vZHVsZSBhbHJlYWR5LgogICAgICAgICAgICAgICAgICAgIGlmIChkZXAgJiYgIW1vZC5kZXBNYXRjaGVkW2ldICYmICFwcm9jZXNzZWRbZGVwSWRdKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChnZXRPd24odHJhY2VkLCBkZXBJZCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG1vZC5kZWZpbmVEZXAoaSwgZGVmaW5lZFtkZXBJZF0pOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmNoZWNrKCk7IC8vcGFzcyBmYWxzZT8KICAgICAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGJyZWFrQ3ljbGUoZGVwLCB0cmFjZWQsIHByb2Nlc3NlZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIHByb2Nlc3NlZFtpZF0gPSB0cnVlOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBjaGVja0xvYWRlZCgpIHsKICAgICAgICAgICAgdmFyIGVyciwgdXNpbmdQYXRoRmFsbGJhY2ssCiAgICAgICAgICAgICAgICB3YW
l0SW50ZXJ2YWwgPSBjb25maWcud2FpdFNlY29uZHMgKiAxMDAwLAogICAgICAgICAgICAgICAgLy9JdCBpcyBwb3NzaWJsZSB0byBkaXNhYmxlIHRoZSB3YWl0IGludGVydmFsIGJ5IHVzaW5nIHdhaXRTZWNvbmRzIG9mIDAuCiAgICAgICAgICAgICAgICBleHBpcmVkID0gd2FpdEludGVydmFsICYmIChjb250ZXh0LnN0YXJ0VGltZSArIHdhaXRJbnRlcnZhbCkgPCBuZXcgRGF0ZSgpLmdldFRpbWUoKSwKICAgICAgICAgICAgICAgIG5vTG9hZHMgPSBbXSwKICAgICAgICAgICAgICAgIHJlcUNhbGxzID0gW10sCiAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSBmYWxzZSwKICAgICAgICAgICAgICAgIG5lZWRDeWNsZUNoZWNrID0gdHJ1ZTsKCiAgICAgICAgICAgIC8vRG8gbm90IGJvdGhlciBpZiB0aGlzIGNhbGwgd2FzIGEgcmVzdWx0IG9mIGEgY3ljbGUgYnJlYWsuCiAgICAgICAgICAgIGlmIChpbkNoZWNrTG9hZGVkKSB7CiAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIGluQ2hlY2tMb2FkZWQgPSB0cnVlOwoKICAgICAgICAgICAgLy9GaWd1cmUgb3V0IHRoZSBzdGF0ZSBvZiBhbGwgdGhlIG1vZHVsZXMuCiAgICAgICAgICAgIGVhY2hQcm9wKGVuYWJsZWRSZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCkgewogICAgICAgICAgICAgICAgdmFyIG1hcCA9IG1vZC5tYXAsCiAgICAgICAgICAgICAgICAgICAgbW9kSWQgPSBtYXAuaWQ7CgogICAgICAgICAgICAgICAgLy9Ta2lwIHRoaW5ncyB0aGF0IGFyZSBub3QgZW5hYmxlZCBvciBpbiBlcnJvciBzdGF0ZS4KICAgICAgICAgICAgICAgIGlmICghbW9kLmVuYWJsZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFtYXAuaXNEZWZpbmUpIHsKICAgICAgICAgICAgICAgICAgICByZXFDYWxscy5wdXNoKG1vZCk7CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFtb2QuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICAvL0lmIHRoZSBtb2R1bGUgc2hvdWxkIGJlIGV4ZWN1dGVkLCBhbmQgaXQgaGFzIG5vdAogICAgICAgICAgICAgICAgICAgIC8vYmVlbiBpbml0ZWQgYW5kIHRpbWUgaXMgdXAsIHJlbWVtYmVyIGl0LgogICAgICAgICAgICAgICAgICAgIGlmICghbW9kLmluaXRlZCAmJiBleHBpcmVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChoYXNQYXRoRmFsbGJhY2sobW9kSWQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB1c2luZ1BhdGhGYWxsYmFjayA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgbm9Mb2Fkcy5wdXNoKG1vZElkKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlbW92ZVNjcmlwdChtb2RJZCk7CiAgICAgICAgICAgICAgIC
AgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKCFtb2QuaW5pdGVkICYmIG1vZC5mZXRjaGVkICYmIG1hcC5pc0RlZmluZSkgewogICAgICAgICAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAoIW1hcC5wcmVmaXgpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vTm8gcmVhc29uIHRvIGtlZXAgbG9va2luZyBmb3IgdW5maW5pc2hlZAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9sb2FkaW5nLiBJZiB0aGUgb25seSBzdGlsbExvYWRpbmcgaXMgYQogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9wbHVnaW4gcmVzb3VyY2UgdGhvdWdoLCBrZWVwIGdvaW5nLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9iZWNhdXNlIGl0IG1heSBiZSB0aGF0IGEgcGx1Z2luIHJlc291cmNlCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL2lzIHdhaXRpbmcgb24gYSBub24tcGx1Z2luIGN5Y2xlLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChuZWVkQ3ljbGVDaGVjayA9IGZhbHNlKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSk7CgogICAgICAgICAgICBpZiAoZXhwaXJlZCAmJiBub0xvYWRzLmxlbmd0aCkgewogICAgICAgICAgICAgICAgLy9JZiB3YWl0IHRpbWUgZXhwaXJlZCwgdGhyb3cgZXJyb3Igb2YgdW5sb2FkZWQgbW9kdWxlcy4KICAgICAgICAgICAgICAgIGVyciA9IG1ha2VFcnJvcigndGltZW91dCcsICdMb2FkIHRpbWVvdXQgZm9yIG1vZHVsZXM6ICcgKyBub0xvYWRzLCBudWxsLCBub0xvYWRzKTsKICAgICAgICAgICAgICAgIGVyci5jb250ZXh0TmFtZSA9IGNvbnRleHQuY29udGV4dE5hbWU7CiAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihlcnIpOwogICAgICAgICAgICB9CgogICAgICAgICAgICAvL05vdCBleHBpcmVkLCBjaGVjayBmb3IgYSBjeWNsZS4KICAgICAgICAgICAgaWYgKG5lZWRDeWNsZUNoZWNrKSB7CiAgICAgICAgICAgICAgICBlYWNoKHJlcUNhbGxzLCBmdW5jdGlvbiAobW9kKSB7CiAgICAgICAgICAgICAgICAgICAgYnJlYWtDeWNsZShtb2QsIHt9LCB7fSk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9JZiBzdGlsbCB3YWl0aW5nIG9uIGxvYWRzLCBhbmQgdGhlIHdhaXRpbmcgbG9hZCBpcyBzb21ldGhpbmcKICAgICAgICAgICAgLy9vdGhlciB0aGFuIGEgcGx1Z2luIHJlc291cmNlLCBvciB0aGVyZSBhcmUgc3RpbGwgb3V0c3RhbmRpbmcKICAgICAgICAgICAgLy9zY3JpcHRzLCB0aGVuIGp1c3QgdHJ5IGJhY2sgbGF0ZXIuCiAgICAgICAgICAgIGlmICgoIWV4cGlyZWQgfHwgdXNpbmdQYXRoRmFsbGJhY2spICYmIHN0aWxsTG9hZGluZykgewogICAgICAgICAgICAgICAgLy9Tb21ldGhpbmcgaXMgc3RpbGwgd2FpdG
luZyB0byBsb2FkLiBXYWl0IGZvciBpdCwgYnV0IG9ubHkKICAgICAgICAgICAgICAgIC8vaWYgYSB0aW1lb3V0IGlzIG5vdCBhbHJlYWR5IGluIGVmZmVjdC4KICAgICAgICAgICAgICAgIGlmICgoaXNCcm93c2VyIHx8IGlzV2ViV29ya2VyKSAmJiAhY2hlY2tMb2FkZWRUaW1lb3V0SWQpIHsKICAgICAgICAgICAgICAgICAgICBjaGVja0xvYWRlZFRpbWVvdXRJZCA9IHNldFRpbWVvdXQoZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICBjaGVja0xvYWRlZFRpbWVvdXRJZCA9IDA7CiAgICAgICAgICAgICAgICAgICAgICAgIGNoZWNrTG9hZGVkKCk7CiAgICAgICAgICAgICAgICAgICAgfSwgNTApOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CgogICAgICAgICAgICBpbkNoZWNrTG9hZGVkID0gZmFsc2U7CiAgICAgICAgfQoKICAgICAgICBNb2R1bGUgPSBmdW5jdGlvbiAobWFwKSB7CiAgICAgICAgICAgIHRoaXMuZXZlbnRzID0gZ2V0T3duKHVuZGVmRXZlbnRzLCBtYXAuaWQpIHx8IHt9OwogICAgICAgICAgICB0aGlzLm1hcCA9IG1hcDsKICAgICAgICAgICAgdGhpcy5zaGltID0gZ2V0T3duKGNvbmZpZy5zaGltLCBtYXAuaWQpOwogICAgICAgICAgICB0aGlzLmRlcEV4cG9ydHMgPSBbXTsKICAgICAgICAgICAgdGhpcy5kZXBNYXBzID0gW107CiAgICAgICAgICAgIHRoaXMuZGVwTWF0Y2hlZCA9IFtdOwogICAgICAgICAgICB0aGlzLnBsdWdpbk1hcHMgPSB7fTsKICAgICAgICAgICAgdGhpcy5kZXBDb3VudCA9IDA7CgogICAgICAgICAgICAvKiB0aGlzLmV4cG9ydHMgdGhpcy5mYWN0b3J5CiAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwcyA9IFtdLAogICAgICAgICAgICAgICB0aGlzLmVuYWJsZWQsIHRoaXMuZmV0Y2hlZAogICAgICAgICAgICAqLwogICAgICAgIH07CgogICAgICAgIE1vZHVsZS5wcm90b3R5cGUgPSB7CiAgICAgICAgICAgIGluaXQ6IGZ1bmN0aW9uIChkZXBNYXBzLCBmYWN0b3J5LCBlcnJiYWNrLCBvcHRpb25zKSB7CiAgICAgICAgICAgICAgICBvcHRpb25zID0gb3B0aW9ucyB8fCB7fTsKCiAgICAgICAgICAgICAgICAvL0RvIG5vdCBkbyBtb3JlIGluaXRzIGlmIGFscmVhZHkgZG9uZS4gQ2FuIGhhcHBlbiBpZiB0aGVyZQogICAgICAgICAgICAgICAgLy9hcmUgbXVsdGlwbGUgZGVmaW5lIGNhbGxzIGZvciB0aGUgc2FtZSBtb2R1bGUuIFRoYXQgaXMgbm90CiAgICAgICAgICAgICAgICAvL2Egbm9ybWFsLCBjb21tb24gY2FzZSwgYnV0IGl0IGlzIGFsc28gbm90IHVuZXhwZWN0ZWQuCiAgICAgICAgICAgICAgICBpZiAodGhpcy5pbml0ZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgdGhpcy5mYWN0b3J5ID0gZmFjdG9yeTsKCiAgICAgICAgICAgICAgICBpZiAoZXJyYmFjaykgewogICAgICAgICAgICAgICAgICAgIC8vUmVnaXN0ZXIgZm9yIGVycm9ycyBvbiB0aGlzIG1vZHVsZS4KICAgICAgICAgIC
AgICAgICAgICB0aGlzLm9uKCdlcnJvcicsIGVycmJhY2spOwogICAgICAgICAgICAgICAgfSBlbHNlIGlmICh0aGlzLmV2ZW50cy5lcnJvcikgewogICAgICAgICAgICAgICAgICAgIC8vSWYgbm8gZXJyYmFjayBhbHJlYWR5LCBidXQgdGhlcmUgYXJlIGVycm9yIGxpc3RlbmVycwogICAgICAgICAgICAgICAgICAgIC8vb24gdGhpcyBtb2R1bGUsIHNldCB1cCBhbiBlcnJiYWNrIHRvIHBhc3MgdG8gdGhlIGRlcHMuCiAgICAgICAgICAgICAgICAgICAgZXJyYmFjayA9IGJpbmQodGhpcywgZnVuY3Rpb24gKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL0RvIGEgY29weSBvZiB0aGUgZGVwZW5kZW5jeSBhcnJheSwgc28gdGhhdAogICAgICAgICAgICAgICAgLy9zb3VyY2UgaW5wdXRzIGFyZSBub3QgbW9kaWZpZWQuIEZvciBleGFtcGxlCiAgICAgICAgICAgICAgICAvLyJzaGltIiBkZXBzIGFyZSBwYXNzZWQgaW4gaGVyZSBkaXJlY3RseSwgYW5kCiAgICAgICAgICAgICAgICAvL2RvaW5nIGEgZGlyZWN0IG1vZGlmaWNhdGlvbiBvZiB0aGUgZGVwTWFwcyBhcnJheQogICAgICAgICAgICAgICAgLy93b3VsZCBhZmZlY3QgdGhhdCBjb25maWcuCiAgICAgICAgICAgICAgICB0aGlzLmRlcE1hcHMgPSBkZXBNYXBzICYmIGRlcE1hcHMuc2xpY2UoMCk7CgogICAgICAgICAgICAgICAgdGhpcy5lcnJiYWNrID0gZXJyYmFjazsKCiAgICAgICAgICAgICAgICAvL0luZGljYXRlIHRoaXMgbW9kdWxlIGhhcyBiZSBpbml0aWFsaXplZAogICAgICAgICAgICAgICAgdGhpcy5pbml0ZWQgPSB0cnVlOwoKICAgICAgICAgICAgICAgIHRoaXMuaWdub3JlID0gb3B0aW9ucy5pZ25vcmU7CgogICAgICAgICAgICAgICAgLy9Db3VsZCBoYXZlIG9wdGlvbiB0byBpbml0IHRoaXMgbW9kdWxlIGluIGVuYWJsZWQgbW9kZSwKICAgICAgICAgICAgICAgIC8vb3IgY291bGQgaGF2ZSBiZWVuIHByZXZpb3VzbHkgbWFya2VkIGFzIGVuYWJsZWQuIEhvd2V2ZXIsCiAgICAgICAgICAgICAgICAvL3RoZSBkZXBlbmRlbmNpZXMgYXJlIG5vdCBrbm93biB1bnRpbCBpbml0IGlzIGNhbGxlZC4gU28KICAgICAgICAgICAgICAgIC8vaWYgZW5hYmxlZCBwcmV2aW91c2x5LCBub3cgdHJpZ2dlciBkZXBlbmRlbmNpZXMgYXMgZW5hYmxlZC4KICAgICAgICAgICAgICAgIGlmIChvcHRpb25zLmVuYWJsZWQgfHwgdGhpcy5lbmFibGVkKSB7CiAgICAgICAgICAgICAgICAgICAgLy9FbmFibGUgdGhpcyBtb2R1bGUgYW5kIGRlcGVuZGVuY2llcy4KICAgICAgICAgICAgICAgICAgICAvL1dpbGwgY2FsbCB0aGlzLmNoZWNrKCkKICAgICAgICAgICAgICAgICAgICB0aGlzLmVuYWJsZSgpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmNoZWNrKCk7CiAgICAgICAgICAgICAgICB9Ci
AgICAgICAgICAgIH0sCgogICAgICAgICAgICBkZWZpbmVEZXA6IGZ1bmN0aW9uIChpLCBkZXBFeHBvcnRzKSB7CiAgICAgICAgICAgICAgICAvL0JlY2F1c2Ugb2YgY3ljbGVzLCBkZWZpbmVkIGNhbGxiYWNrIGZvciBhIGdpdmVuCiAgICAgICAgICAgICAgICAvL2V4cG9ydCBjYW4gYmUgY2FsbGVkIG1vcmUgdGhhbiBvbmNlLgogICAgICAgICAgICAgICAgaWYgKCF0aGlzLmRlcE1hdGNoZWRbaV0pIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcE1hdGNoZWRbaV0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIHRoaXMuZGVwQ291bnQgLT0gMTsKICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcEV4cG9ydHNbaV0gPSBkZXBFeHBvcnRzOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAoKICAgICAgICAgICAgZmV0Y2g6IGZ1bmN0aW9uICgpIHsKICAgICAgICAgICAgICAgIGlmICh0aGlzLmZldGNoZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB0aGlzLmZldGNoZWQgPSB0cnVlOwoKICAgICAgICAgICAgICAgIGNvbnRleHQuc3RhcnRUaW1lID0gKG5ldyBEYXRlKCkpLmdldFRpbWUoKTsKCiAgICAgICAgICAgICAgICB2YXIgbWFwID0gdGhpcy5tYXA7CgogICAgICAgICAgICAgICAgLy9JZiB0aGUgbWFuYWdlciBpcyBmb3IgYSBwbHVnaW4gbWFuYWdlZCByZXNvdXJjZSwKICAgICAgICAgICAgICAgIC8vYXNrIHRoZSBwbHVnaW4gdG8gbG9hZCBpdCBub3cuCiAgICAgICAgICAgICAgICBpZiAodGhpcy5zaGltKSB7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5tYWtlUmVxdWlyZSh0aGlzLm1hcCwgewogICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVCdWlsZENhbGxiYWNrOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgfSkodGhpcy5zaGltLmRlcHMgfHwgW10sIGJpbmQodGhpcywgZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gbWFwLnByZWZpeCA/IHRoaXMuY2FsbFBsdWdpbigpIDogdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICAgICAgfSkpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAvL1JlZ3VsYXIgZGVwZW5kZW5jeS4KICAgICAgICAgICAgICAgICAgICByZXR1cm4gbWFwLnByZWZpeCA/IHRoaXMuY2FsbFBsdWdpbigpIDogdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICBsb2FkOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICB2YXIgdXJsID0gdGhpcy5tYXAudXJsOwoKICAgICAgICAgICAgICAgIC8vUmVndWxhciBkZXBlbmRlbmN5LgogICAgICAgICAgICAgICAgaWYgKCF1cmxGZXRjaGVkW3VybF0pIHsKICAgICAgICAgICAgICAgICAgICB1cmxGZXRjaGVkW3VybF0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIGNvbnRleHQubG9hZCh0aGlzLm1hcC5pZCwgdXJsKTsKIC
AgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIC8qKgogICAgICAgICAgICAgKiBDaGVja3MgaWYgdGhlIG1vZHVsZSBpcyByZWFkeSB0byBkZWZpbmUgaXRzZWxmLCBhbmQgaWYgc28sCiAgICAgICAgICAgICAqIGRlZmluZSBpdC4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGNoZWNrOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICBpZiAoIXRoaXMuZW5hYmxlZCB8fCB0aGlzLmVuYWJsaW5nKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHZhciBlcnIsIGNqc01vZHVsZSwKICAgICAgICAgICAgICAgICAgICBpZCA9IHRoaXMubWFwLmlkLAogICAgICAgICAgICAgICAgICAgIGRlcEV4cG9ydHMgPSB0aGlzLmRlcEV4cG9ydHMsCiAgICAgICAgICAgICAgICAgICAgZXhwb3J0cyA9IHRoaXMuZXhwb3J0cywKICAgICAgICAgICAgICAgICAgICBmYWN0b3J5ID0gdGhpcy5mYWN0b3J5OwoKICAgICAgICAgICAgICAgIGlmICghdGhpcy5pbml0ZWQpIHsKICAgICAgICAgICAgICAgICAgICAvLyBPbmx5IGZldGNoIGlmIG5vdCBhbHJlYWR5IGluIHRoZSBkZWZRdWV1ZS4KICAgICAgICAgICAgICAgICAgICBpZiAoIWhhc1Byb3AoY29udGV4dC5kZWZRdWV1ZU1hcCwgaWQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuZmV0Y2goKTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgdGhpcy5lcnJvcik7CiAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKCF0aGlzLmRlZmluaW5nKSB7CiAgICAgICAgICAgICAgICAgICAgLy9UaGUgZmFjdG9yeSBjb3VsZCB0cmlnZ2VyIGFub3RoZXIgcmVxdWlyZSBjYWxsCiAgICAgICAgICAgICAgICAgICAgLy90aGF0IHdvdWxkIHJlc3VsdCBpbiBjaGVja2luZyB0aGlzIG1vZHVsZSB0bwogICAgICAgICAgICAgICAgICAgIC8vZGVmaW5lIGl0c2VsZiBhZ2Fpbi4gSWYgYWxyZWFkeSBpbiB0aGUgcHJvY2VzcwogICAgICAgICAgICAgICAgICAgIC8vb2YgZG9pbmcgdGhhdCwgc2tpcCB0aGlzIHdvcmsuCiAgICAgICAgICAgICAgICAgICAgdGhpcy5kZWZpbmluZyA9IHRydWU7CgogICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLmRlcENvdW50IDwgMSAmJiAhdGhpcy5kZWZpbmVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChpc0Z1bmN0aW9uKGZhY3RvcnkpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0cnkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBjb250ZXh0LmV4ZWNDYihpZCwgZmFjdG9yeSwgZGVwRXhwb3J0cywgZXhwb3J0cyk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZXJyID0gZTsKICAgIC
AgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBGYXZvciByZXR1cm4gdmFsdWUgb3ZlciBleHBvcnRzLiBJZiBub2RlL2NqcyBpbiBwbGF5LAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdGhlbiB3aWxsIG5vdCBoYXZlIGEgcmV0dXJuIHZhbHVlIGFueXdheS4gRmF2b3IKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIG1vZHVsZS5leHBvcnRzIGFzc2lnbm1lbnQgb3ZlciBleHBvcnRzIG9iamVjdC4KICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLm1hcC5pc0RlZmluZSAmJiBleHBvcnRzID09PSB1bmRlZmluZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjanNNb2R1bGUgPSB0aGlzLm1vZHVsZTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAoY2pzTW9kdWxlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBjanNNb2R1bGUuZXhwb3J0czsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMudXNpbmdFeHBvcnRzKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vZXhwb3J0cyBhbHJlYWR5IHNldCB0aGUgZGVmaW5lZCB2YWx1ZS4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZXhwb3J0cyA9IHRoaXMuZXhwb3J0czsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIElmIHRoZXJlIGlzIGFuIGVycm9yIGxpc3RlbmVyLCBmYXZvciBwYXNzaW5nCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdG8gdGhhdCBpbnN0ZWFkIG9mIHRocm93aW5nIGFuIGVycm9yLiBIb3dldmVyLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIG9ubHkgZG8gaXQgZm9yIGRlZmluZSgpJ2QgIG1vZHVsZXMuIHJlcXVpcmUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBlcnJiYWNrcyBzaG91bGQgbm90IGJlIGNhbGxlZCBmb3IgZmFpbHVyZXMgaW4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyB0aGVpciBjYWxsYmFja3MgKCM2OTkpLiBIb3dldmVyIGlmIGEgZ2xvYmFsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gb25FcnJvciBpcyBzZXQsIHVzZSB0aGF0LgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICgodGhpcy5ldmVudHMuZXJyb3IgJiYgdGhpcy5tYXAuaXNEZWZpbmUpIHx8CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlcS5vbkVycm9yICE9PSBkZWZhdWx0T25FcnJvcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZU1hcCA9IHRoaXMubWFwOwogIC
AgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZU1vZHVsZXMgPSB0aGlzLm1hcC5pc0RlZmluZSA/IFt0aGlzLm1hcC5pZF0gOiBudWxsOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZVR5cGUgPSB0aGlzLm1hcC5pc0RlZmluZSA/ICdkZWZpbmUnIDogJ3JlcXVpcmUnOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcigodGhpcy5lcnJvciA9IGVycikpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0gZWxzZSBpZiAodHlwZW9mIGNvbnNvbGUgIT09ICd1bmRlZmluZWQnICYmCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjb25zb2xlLmVycm9yKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIExvZyB0aGUgZXJyb3IgZm9yIGRlYnVnZ2luZy4gSWYgcHJvbWlzZXMgY291bGQgYmUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdXNlZCwgdGhpcyB3b3VsZCBiZSBkaWZmZXJlbnQsIGJ1dCBtYWtpbmcgZG8uCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbnNvbGUuZXJyb3IoZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBEbyBub3Qgd2FudCB0byBjb21wbGV0ZWx5IGxvc2UgdGhlIGVycm9yLiBXaGlsZSB0aGlzCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIHdpbGwgbWVzcyB1cCBwcm9jZXNzaW5nIGFuZCBsZWFkIHRvIHNpbWlsYXIgcmVzdWx0cwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBhcyBidWcgMTQ0MCwgaXQgYXQgbGVhc3Qgc3VyZmFjZXMgdGhlIGVycm9yLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXEub25FcnJvcihlcnIpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vSnVzdCBhIGxpdGVyYWwgdmFsdWUKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBmYWN0b3J5OwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmV4cG9ydHMgPSBleHBvcnRzOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRoaXMubWFwLmlzRGVmaW5lICYmICF0aGlzLmlnbm9yZSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgZGVmaW5lZFtpZF0gPSBleHBvcnRzOwoKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmIChyZXEub25SZXNvdXJjZUxvYWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB2YXIgcmVzTG9hZE1hcHMgPSBbXTsKICAgICAgICAgIC
AgICAgICAgICAgICAgICAgICAgICBlYWNoKHRoaXMuZGVwTWFwcywgZnVuY3Rpb24gKGRlcE1hcCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXNMb2FkTWFwcy5wdXNoKGRlcE1hcC5ub3JtYWxpemVkTWFwIHx8IGRlcE1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcmVxLm9uUmVzb3VyY2VMb2FkKGNvbnRleHQsIHRoaXMubWFwLCByZXNMb2FkTWFwcyk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vQ2xlYW4gdXAKICAgICAgICAgICAgICAgICAgICAgICAgY2xlYW5SZWdpc3RyeShpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgLy9GaW5pc2hlZCB0aGUgZGVmaW5lIHN0YWdlLiBBbGxvdyBjYWxsaW5nIGNoZWNrIGFnYWluCiAgICAgICAgICAgICAgICAgICAgLy90byBhbGxvdyBkZWZpbmUgbm90aWZpY2F0aW9ucyBiZWxvdyBpbiB0aGUgY2FzZSBvZiBhCiAgICAgICAgICAgICAgICAgICAgLy9jeWNsZS4KICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluaW5nID0gZmFsc2U7CgogICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLmRlZmluZWQgJiYgIXRoaXMuZGVmaW5lRW1pdHRlZCkgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZUVtaXR0ZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2RlZmluZWQnLCB0aGlzLmV4cG9ydHMpOwogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZUVtaXRDb21wbGV0ZSA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIGNhbGxQbHVnaW46IGZ1bmN0aW9uICgpIHsKICAgICAgICAgICAgICAgIHZhciBtYXAgPSB0aGlzLm1hcCwKICAgICAgICAgICAgICAgICAgICBpZCA9IG1hcC5pZCwKICAgICAgICAgICAgICAgICAgICAvL01hcCBhbHJlYWR5IG5vcm1hbGl6ZWQgdGhlIHByZWZpeC4KICAgICAgICAgICAgICAgICAgICBwbHVnaW5NYXAgPSBtYWtlTW9kdWxlTWFwKG1hcC5wcmVmaXgpOwoKICAgICAgICAgICAgICAgIC8vTWFyayB0aGlzIGFzIGEgZGVwZW5kZW5jeSBmb3IgdGhpcyBwbHVnaW4sIHNvIGl0CiAgICAgICAgICAgICAgICAvL2NhbiBiZSB0cmFjZWQgZm9yIGN5Y2xlcy4KICAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwcy5wdXNoKHBsdWdpbk1hcCk7CgogICAgICAgICAgICAgICAgb24ocGx1Z2luTWFwLCAnZGVmaW5lZCcsIGJpbmQodGhpcywgZnVuY3Rpb24gKHBsdWdpbikgewogICAgICAgICAgICAgICAgICAgIHZhciBsb2FkLCBub3JtYWxpemVkTWFwLCBub3JtYWxpemVkTW9kLAogICAgIC
AgICAgICAgICAgICAgICAgICBidW5kbGVJZCA9IGdldE93bihidW5kbGVzTWFwLCB0aGlzLm1hcC5pZCksCiAgICAgICAgICAgICAgICAgICAgICAgIG5hbWUgPSB0aGlzLm1hcC5uYW1lLAogICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnROYW1lID0gdGhpcy5tYXAucGFyZW50TWFwID8gdGhpcy5tYXAucGFyZW50TWFwLm5hbWUgOiBudWxsLAogICAgICAgICAgICAgICAgICAgICAgICBsb2NhbFJlcXVpcmUgPSBjb250ZXh0Lm1ha2VSZXF1aXJlKG1hcC5wYXJlbnRNYXAsIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGVuYWJsZUJ1aWxkQ2FsbGJhY2s6IHRydWUKICAgICAgICAgICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgICAgIC8vSWYgY3VycmVudCBtYXAgaXMgbm90IG5vcm1hbGl6ZWQsIHdhaXQgZm9yIHRoYXQKICAgICAgICAgICAgICAgICAgICAvL25vcm1hbGl6ZWQgbmFtZSB0byBsb2FkIGluc3RlYWQgb2YgY29udGludWluZy4KICAgICAgICAgICAgICAgICAgICBpZiAodGhpcy5tYXAudW5ub3JtYWxpemVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplIHRoZSBJRCBpZiB0aGUgcGx1Z2luIGFsbG93cyBpdC4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHBsdWdpbi5ub3JtYWxpemUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG5hbWUgPSBwbHVnaW4ubm9ybWFsaXplKG5hbWUsIGZ1bmN0aW9uIChuYW1lKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCB0cnVlKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0pIHx8ICcnOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL3ByZWZpeCBhbmQgbmFtZSBzaG91bGQgYWxyZWFkeSBiZSBub3JtYWxpemVkLCBubyBuZWVkCiAgICAgICAgICAgICAgICAgICAgICAgIC8vZm9yIGFwcGx5aW5nIG1hcCBjb25maWcgYWdhaW4gZWl0aGVyLgogICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTWFwID0gbWFrZU1vZHVsZU1hcChtYXAucHJlZml4ICsgJyEnICsgbmFtZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5tYXAucGFyZW50TWFwKTsKICAgICAgICAgICAgICAgICAgICAgICAgb24obm9ybWFsaXplZE1hcCwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICdkZWZpbmVkJywgYmluZCh0aGlzLCBmdW5jdGlvbiAodmFsdWUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLm1hcC5ub3JtYWxpemVkTWFwID0gbm9ybWFsaXplZE1hcDsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmluaXQoW10sIGZ1bmN0aW9uICgpIHsgcmV0dXJuIHZhbHVlOyB9LCBudWxsLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGVuYWJsZWQ6IHRydWUsCiAgICAgIC
AgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlnbm9yZTogdHJ1ZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfSkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZE1vZCA9IGdldE93bihyZWdpc3RyeSwgbm9ybWFsaXplZE1hcC5pZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChub3JtYWxpemVkTW9kKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL01hcmsgdGhpcyBhcyBhIGRlcGVuZGVuY3kgZm9yIHRoaXMgcGx1Z2luLCBzbyBpdAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9jYW4gYmUgdHJhY2VkIGZvciBjeWNsZXMuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcE1hcHMucHVzaChub3JtYWxpemVkTWFwKTsKCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAodGhpcy5ldmVudHMuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTW9kLm9uKCdlcnJvcicsIGJpbmQodGhpcywgZnVuY3Rpb24gKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9KSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTW9kLmVuYWJsZSgpOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAvL0lmIGEgcGF0aHMgY29uZmlnLCB0aGVuIGp1c3QgbG9hZCB0aGF0IGZpbGUgaW5zdGVhZCB0bwogICAgICAgICAgICAgICAgICAgIC8vcmVzb2x2ZSB0aGUgcGx1Z2luLCBhcyBpdCBpcyBidWlsdCBpbnRvIHRoYXQgcGF0aHMgbGF5ZXIuCiAgICAgICAgICAgICAgICAgICAgaWYgKGJ1bmRsZUlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMubWFwLnVybCA9IGNvbnRleHQubmFtZVRvVXJsKGJ1bmRsZUlkKTsKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybjsKICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGxvYWQgPSBiaW5kKHRoaXMsIGZ1bmN0aW9uICh2YWx1ZSkgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmluaXQoW10sIGZ1bmN0aW9uICgpIHsgcmV0dXJuIHZhbHVlOyB9LCBudWxsLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVkOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgICAgICBsb2FkLmVycm9yID0gYmluZCh0aGlzLCBmdW5jdGlvbiAoZXJyKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIH
RoaXMuaW5pdGVkID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5lcnJvciA9IGVycjsKICAgICAgICAgICAgICAgICAgICAgICAgZXJyLnJlcXVpcmVNb2R1bGVzID0gW2lkXTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vUmVtb3ZlIHRlbXAgdW5ub3JtYWxpemVkIG1vZHVsZXMgZm9yIHRoaXMgbW9kdWxlLAogICAgICAgICAgICAgICAgICAgICAgICAvL3NpbmNlIHRoZXkgd2lsbCBuZXZlciBiZSByZXNvbHZlZCBvdGhlcndpc2Ugbm93LgogICAgICAgICAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZC5tYXAuaWQuaW5kZXhPZihpZCArICdfdW5ub3JtYWxpemVkJykgPT09IDApIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjbGVhblJlZ2lzdHJ5KG1vZC5tYXAuaWQpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgICAgIG9uRXJyb3IoZXJyKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgLy9BbGxvdyBwbHVnaW5zIHRvIGxvYWQgb3RoZXIgY29kZSB3aXRob3V0IGhhdmluZyB0byBrbm93IHRoZQogICAgICAgICAgICAgICAgICAgIC8vY29udGV4dCBvciBob3cgdG8gJ2NvbXBsZXRlJyB0aGUgbG9hZC4KICAgICAgICAgICAgICAgICAgICBsb2FkLmZyb21UZXh0ID0gYmluZCh0aGlzLCBmdW5jdGlvbiAodGV4dCwgdGV4dEFsdCkgewogICAgICAgICAgICAgICAgICAgICAgICAvKmpzbGludCBldmlsOiB0cnVlICovCiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBtb2R1bGVOYW1lID0gbWFwLm5hbWUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVNYXAgPSBtYWtlTW9kdWxlTWFwKG1vZHVsZU5hbWUpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaGFzSW50ZXJhY3RpdmUgPSB1c2VJbnRlcmFjdGl2ZTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vQXMgb2YgMi4xLjAsIHN1cHBvcnQganVzdCBwYXNzaW5nIHRoZSB0ZXh0LCB0byByZWluZm9yY2UKICAgICAgICAgICAgICAgICAgICAgICAgLy9mcm9tVGV4dCBvbmx5IGJlaW5nIGNhbGxlZCBvbmNlIHBlciByZXNvdXJjZS4gU3RpbGwKICAgICAgICAgICAgICAgICAgICAgICAgLy9zdXBwb3J0IG9sZCBzdHlsZSBvZiBwYXNzaW5nIG1vZHVsZU5hbWUgYnV0IGRpc2NhcmQKICAgICAgICAgICAgICAgICAgICAgICAgLy90aGF0IG1vZHVsZU5hbWUgaW4gZmF2b3Igb2YgdGhlIGludGVybmFsIHJlZi4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRleHRBbHQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHRleHQgPSB0ZXh0QWx0OwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL1R1cm4gb2ZmIGludGVyYWN0aXZlIH
NjcmlwdCBtYXRjaGluZyBmb3IgSUUgZm9yIGFueSBkZWZpbmUKICAgICAgICAgICAgICAgICAgICAgICAgLy9jYWxscyBpbiB0aGUgdGV4dCwgdGhlbiB0dXJuIGl0IGJhY2sgb24gYXQgdGhlIGVuZC4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGhhc0ludGVyYWN0aXZlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB1c2VJbnRlcmFjdGl2ZSA9IGZhbHNlOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL1ByaW1lIHRoZSBzeXN0ZW0gYnkgY3JlYXRpbmcgYSBtb2R1bGUgaW5zdGFuY2UgZm9yCiAgICAgICAgICAgICAgICAgICAgICAgIC8vaXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGdldE1vZHVsZShtb2R1bGVNYXApOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9UcmFuc2ZlciBhbnkgY29uZmlnIHRvIHRoaXMgb3RoZXIgbW9kdWxlLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFzUHJvcChjb25maWcuY29uZmlnLCBpZCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZy5jb25maWdbbW9kdWxlTmFtZV0gPSBjb25maWcuY29uZmlnW2lkXTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgdHJ5IHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlcS5leGVjKHRleHQpOwogICAgICAgICAgICAgICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRXJyb3IoJ2Zyb210ZXh0ZXZhbCcsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICdmcm9tVGV4dCBldmFsIGZvciAnICsgaWQgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICcgZmFpbGVkOiAnICsgZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgW2lkXSkpOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFzSW50ZXJhY3RpdmUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHVzZUludGVyYWN0aXZlID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9NYXJrIHRoaXMgYXMgYSBkZXBlbmRlbmN5IGZvciB0aGUgcGx1Z2luCiAgICAgICAgICAgICAgICAgICAgICAgIC8vcmVzb3VyY2UKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBNYXBzLnB1c2gobW9kdWxlTWFwKTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vU3VwcG9ydCBhbm9ueW1vdXMgbW9kdWxlcy4KICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5jb21wbGV0ZUxvYWQobW9kdWxlTmFtZSk7CgogICAgICAgICAgICAgICAgICAgICAgICAvL0JpbmQgdGhlIHZhbHVlIG
9mIHRoYXQgbW9kdWxlIHRvIHRoZSB2YWx1ZSBmb3IgdGhpcwogICAgICAgICAgICAgICAgICAgICAgICAvL3Jlc291cmNlIElELgogICAgICAgICAgICAgICAgICAgICAgICBsb2NhbFJlcXVpcmUoW21vZHVsZU5hbWVdLCBsb2FkKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgLy9Vc2UgcGFyZW50TmFtZSBoZXJlIHNpbmNlIHRoZSBwbHVnaW4ncyBuYW1lIGlzIG5vdCByZWxpYWJsZSwKICAgICAgICAgICAgICAgICAgICAvL2NvdWxkIGJlIHNvbWUgd2VpcmQgc3RyaW5nIHdpdGggbm8gcGF0aCB0aGF0IGFjdHVhbGx5IHdhbnRzIHRvCiAgICAgICAgICAgICAgICAgICAgLy9yZWZlcmVuY2UgdGhlIHBhcmVudE5hbWUncyBwYXRoLgogICAgICAgICAgICAgICAgICAgIHBsdWdpbi5sb2FkKG1hcC5uYW1lLCBsb2NhbFJlcXVpcmUsIGxvYWQsIGNvbmZpZyk7CiAgICAgICAgICAgICAgICB9KSk7CgogICAgICAgICAgICAgICAgY29udGV4dC5lbmFibGUocGx1Z2luTWFwLCB0aGlzKTsKICAgICAgICAgICAgICAgIHRoaXMucGx1Z2luTWFwc1twbHVnaW5NYXAuaWRdID0gcGx1Z2luTWFwOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgZW5hYmxlOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICBlbmFibGVkUmVnaXN0cnlbdGhpcy5tYXAuaWRdID0gdGhpczsKICAgICAgICAgICAgICAgIHRoaXMuZW5hYmxlZCA9IHRydWU7CgogICAgICAgICAgICAgICAgLy9TZXQgZmxhZyBtZW50aW9uaW5nIHRoYXQgdGhlIG1vZHVsZSBpcyBlbmFibGluZywKICAgICAgICAgICAgICAgIC8vc28gdGhhdCBpbW1lZGlhdGUgY2FsbHMgdG8gdGhlIGRlZmluZWQgY2FsbGJhY2tzCiAgICAgICAgICAgICAgICAvL2ZvciBkZXBlbmRlbmNpZXMgZG8gbm90IHRyaWdnZXIgaW5hZHZlcnRlbnQgbG9hZAogICAgICAgICAgICAgICAgLy93aXRoIHRoZSBkZXBDb3VudCBzdGlsbCBiZWluZyB6ZXJvLgogICAgICAgICAgICAgICAgdGhpcy5lbmFibGluZyA9IHRydWU7CgogICAgICAgICAgICAgICAgLy9FbmFibGUgZWFjaCBkZXBlbmRlbmN5CiAgICAgICAgICAgICAgICBlYWNoKHRoaXMuZGVwTWFwcywgYmluZCh0aGlzLCBmdW5jdGlvbiAoZGVwTWFwLCBpKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIGlkLCBtb2QsIGhhbmRsZXI7CgogICAgICAgICAgICAgICAgICAgIGlmICh0eXBlb2YgZGVwTWFwID09PSAnc3RyaW5nJykgewogICAgICAgICAgICAgICAgICAgICAgICAvL0RlcGVuZGVuY3kgbmVlZHMgdG8gYmUgY29udmVydGVkIHRvIGEgZGVwTWFwCiAgICAgICAgICAgICAgICAgICAgICAgIC8vYW5kIHdpcmVkIHVwIHRvIHRoaXMgbW9kdWxlLgogICAgICAgICAgICAgICAgICAgICAgICBkZXBNYXAgPSBtYWtlTW9kdWxlTWFwKGRlcE1hcCwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAodGhpcy5tYXAuaXNEZWZpbmUgPyB0aGlzLm1hcCA6IHRoaXMubWFwLnBhcmVudE
1hcCksCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZmFsc2UsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIXRoaXMuc2tpcE1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwc1tpXSA9IGRlcE1hcDsKCiAgICAgICAgICAgICAgICAgICAgICAgIGhhbmRsZXIgPSBnZXRPd24oaGFuZGxlcnMsIGRlcE1hcC5pZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFuZGxlcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBFeHBvcnRzW2ldID0gaGFuZGxlcih0aGlzKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybjsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBDb3VudCArPSAxOwoKICAgICAgICAgICAgICAgICAgICAgICAgb24oZGVwTWFwLCAnZGVmaW5lZCcsIGJpbmQodGhpcywgZnVuY3Rpb24gKGRlcEV4cG9ydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLnVuZGVmZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZURlcChpLCBkZXBFeHBvcnRzKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuY2hlY2soKTsKICAgICAgICAgICAgICAgICAgICAgICAgfSkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRoaXMuZXJyYmFjaykgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgb24oZGVwTWFwLCAnZXJyb3InLCBiaW5kKHRoaXMsIHRoaXMuZXJyYmFjaykpOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMuZXZlbnRzLmVycm9yKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBObyBkaXJlY3QgZXJyYmFjayBvbiB0aGlzIG1vZHVsZSwgYnV0IHNvbWV0aGluZwogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gZWxzZSBpcyBsaXN0ZW5pbmcgZm9yIGVycm9ycywgc28gYmUgc3VyZSB0bwogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gcHJvcGFnYXRlIHRoZSBlcnJvciBjb3JyZWN0bHkuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBvbihkZXBNYXAsICdlcnJvcicsIGJpbmQodGhpcywgZnVuY3Rpb24oZXJyKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5lbWl0KCdlcnJvcicsIGVycik7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9KSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGlkID0gZGVwTWFwLmlkOwogICAgICAgICAgICAgICAgICAgIG1vZCA9IHJlZ2lzdHJ5W2lkXTsKCiAgICAgICAgICAgICAgICAgICAgLy9Ta2lwIHNwZWNpYWwgbW9kdW
xlcyBsaWtlICdyZXF1aXJlJywgJ2V4cG9ydHMnLCAnbW9kdWxlJwogICAgICAgICAgICAgICAgICAgIC8vQWxzbywgZG9uJ3QgY2FsbCBlbmFibGUgaWYgaXQgaXMgYWxyZWFkeSBlbmFibGVkLAogICAgICAgICAgICAgICAgICAgIC8vaW1wb3J0YW50IGluIGNpcmN1bGFyIGRlcGVuZGVuY3kgY2FzZXMuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFoYXNQcm9wKGhhbmRsZXJzLCBpZCkgJiYgbW9kICYmICFtb2QuZW5hYmxlZCkgewogICAgICAgICAgICAgICAgICAgICAgICBjb250ZXh0LmVuYWJsZShkZXBNYXAsIHRoaXMpOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pKTsKCiAgICAgICAgICAgICAgICAvL0VuYWJsZSBlYWNoIHBsdWdpbiB0aGF0IGlzIHVzZWQgaW4KICAgICAgICAgICAgICAgIC8vYSBkZXBlbmRlbmN5CiAgICAgICAgICAgICAgICBlYWNoUHJvcCh0aGlzLnBsdWdpbk1hcHMsIGJpbmQodGhpcywgZnVuY3Rpb24gKHBsdWdpbk1hcCkgewogICAgICAgICAgICAgICAgICAgIHZhciBtb2QgPSBnZXRPd24ocmVnaXN0cnksIHBsdWdpbk1hcC5pZCk7CiAgICAgICAgICAgICAgICAgICAgaWYgKG1vZCAmJiAhbW9kLmVuYWJsZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5lbmFibGUocGx1Z2luTWFwLCB0aGlzKTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9KSk7CgogICAgICAgICAgICAgICAgdGhpcy5lbmFibGluZyA9IGZhbHNlOwoKICAgICAgICAgICAgICAgIHRoaXMuY2hlY2soKTsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIG9uOiBmdW5jdGlvbiAobmFtZSwgY2IpIHsKICAgICAgICAgICAgICAgIHZhciBjYnMgPSB0aGlzLmV2ZW50c1tuYW1lXTsKICAgICAgICAgICAgICAgIGlmICghY2JzKSB7CiAgICAgICAgICAgICAgICAgICAgY2JzID0gdGhpcy5ldmVudHNbbmFtZV0gPSBbXTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNicy5wdXNoKGNiKTsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIGVtaXQ6IGZ1bmN0aW9uIChuYW1lLCBldnQpIHsKICAgICAgICAgICAgICAgIGVhY2godGhpcy5ldmVudHNbbmFtZV0sIGZ1bmN0aW9uIChjYikgewogICAgICAgICAgICAgICAgICAgIGNiKGV2dCk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIGlmIChuYW1lID09PSAnZXJyb3InKSB7CiAgICAgICAgICAgICAgICAgICAgLy9Ob3cgdGhhdCB0aGUgZXJyb3IgaGFuZGxlciB3YXMgdHJpZ2dlcmVkLCByZW1vdmUKICAgICAgICAgICAgICAgICAgICAvL3RoZSBsaXN0ZW5lcnMsIHNpbmNlIHRoaXMgYnJva2VuIE1vZHVsZSBpbnN0YW5jZQogICAgICAgICAgICAgICAgICAgIC8vY2FuIHN0YXkgYXJvdW5kIGZvciBhIHdoaWxlIGluIHRoZSByZWdpc3RyeS4KICAgICAgICAgICAgICAgICAgICBkZWxldGUgdGhpcy5ldmVudHNbbmFtZV07CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KIC
AgICAgICB9OwoKICAgICAgICBmdW5jdGlvbiBjYWxsR2V0TW9kdWxlKGFyZ3MpIHsKICAgICAgICAgICAgLy9Ta2lwIG1vZHVsZXMgYWxyZWFkeSBkZWZpbmVkLgogICAgICAgICAgICBpZiAoIWhhc1Byb3AoZGVmaW5lZCwgYXJnc1swXSkpIHsKICAgICAgICAgICAgICAgIGdldE1vZHVsZShtYWtlTW9kdWxlTWFwKGFyZ3NbMF0sIG51bGwsIHRydWUpKS5pbml0KGFyZ3NbMV0sIGFyZ3NbMl0pOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiByZW1vdmVMaXN0ZW5lcihub2RlLCBmdW5jLCBuYW1lLCBpZU5hbWUpIHsKICAgICAgICAgICAgLy9GYXZvciBkZXRhY2hFdmVudCBiZWNhdXNlIG9mIElFOQogICAgICAgICAgICAvL2lzc3VlLCBzZWUgYXR0YWNoRXZlbnQvYWRkRXZlbnRMaXN0ZW5lciBjb21tZW50IGVsc2V3aGVyZQogICAgICAgICAgICAvL2luIHRoaXMgZmlsZS4KICAgICAgICAgICAgaWYgKG5vZGUuZGV0YWNoRXZlbnQgJiYgIWlzT3BlcmEpIHsKICAgICAgICAgICAgICAgIC8vUHJvYmFibHkgSUUuIElmIG5vdCBpdCB3aWxsIHRocm93IGFuIGVycm9yLCB3aGljaCB3aWxsIGJlCiAgICAgICAgICAgICAgICAvL3VzZWZ1bCB0byBrbm93LgogICAgICAgICAgICAgICAgaWYgKGllTmFtZSkgewogICAgICAgICAgICAgICAgICAgIG5vZGUuZGV0YWNoRXZlbnQoaWVOYW1lLCBmdW5jKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgIG5vZGUucmVtb3ZlRXZlbnRMaXN0ZW5lcihuYW1lLCBmdW5jLCBmYWxzZSk7CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIC8qKgogICAgICAgICAqIEdpdmVuIGFuIGV2ZW50IGZyb20gYSBzY3JpcHQgbm9kZSwgZ2V0IHRoZSByZXF1aXJlanMgaW5mbyBmcm9tIGl0LAogICAgICAgICAqIGFuZCB0aGVuIHJlbW92ZXMgdGhlIGV2ZW50IGxpc3RlbmVycyBvbiB0aGUgbm9kZS4KICAgICAgICAgKiBAcGFyYW0ge0V2ZW50fSBldnQKICAgICAgICAgKiBAcmV0dXJucyB7T2JqZWN0fQogICAgICAgICAqLwogICAgICAgIGZ1bmN0aW9uIGdldFNjcmlwdERhdGEoZXZ0KSB7CiAgICAgICAgICAgIC8vVXNpbmcgY3VycmVudFRhcmdldCBpbnN0ZWFkIG9mIHRhcmdldCBmb3IgRmlyZWZveCAyLjAncyBzYWtlLiBOb3QKICAgICAgICAgICAgLy9hbGwgb2xkIGJyb3dzZXJzIHdpbGwgYmUgc3VwcG9ydGVkLCBidXQgdGhpcyBvbmUgd2FzIGVhc3kgZW5vdWdoCiAgICAgICAgICAgIC8vdG8gc3VwcG9ydCBhbmQgc3RpbGwgbWFrZXMgc2Vuc2UuCiAgICAgICAgICAgIHZhciBub2RlID0gZXZ0LmN1cnJlbnRUYXJnZXQgfHwgZXZ0LnNyY0VsZW1lbnQ7CgogICAgICAgICAgICAvL1JlbW92ZSB0aGUgbGlzdGVuZXJzIG9uY2UgaGVyZS4KICAgICAgICAgICAgcmVtb3ZlTGlzdGVuZXIobm9kZSwgY29udGV4dC5vblNjcmlwdExvYWQsICdsb2FkJywgJ29ucmVhZHlzdGF0ZWNoYW5nZScpOwogICAgICAgICAgICByZW1vdm
VMaXN0ZW5lcihub2RlLCBjb250ZXh0Lm9uU2NyaXB0RXJyb3IsICdlcnJvcicpOwoKICAgICAgICAgICAgcmV0dXJuIHsKICAgICAgICAgICAgICAgIG5vZGU6IG5vZGUsCiAgICAgICAgICAgICAgICBpZDogbm9kZSAmJiBub2RlLmdldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlbW9kdWxlJykKICAgICAgICAgICAgfTsKICAgICAgICB9CgogICAgICAgIGZ1bmN0aW9uIGludGFrZURlZmluZXMoKSB7CiAgICAgICAgICAgIHZhciBhcmdzOwoKICAgICAgICAgICAgLy9BbnkgZGVmaW5lZCBtb2R1bGVzIGluIHRoZSBnbG9iYWwgcXVldWUsIGludGFrZSB0aGVtIG5vdy4KICAgICAgICAgICAgdGFrZUdsb2JhbFF1ZXVlKCk7CgogICAgICAgICAgICAvL01ha2Ugc3VyZSBhbnkgcmVtYWluaW5nIGRlZlF1ZXVlIGl0ZW1zIGdldCBwcm9wZXJseSBwcm9jZXNzZWQuCiAgICAgICAgICAgIHdoaWxlIChkZWZRdWV1ZS5sZW5ndGgpIHsKICAgICAgICAgICAgICAgIGFyZ3MgPSBkZWZRdWV1ZS5zaGlmdCgpOwogICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IG51bGwpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRXJyb3IoJ21pc21hdGNoJywgJ01pc21hdGNoZWQgYW5vbnltb3VzIGRlZmluZSgpIG1vZHVsZTogJyArCiAgICAgICAgICAgICAgICAgICAgICAgIGFyZ3NbYXJncy5sZW5ndGggLSAxXSkpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAvL2FyZ3MgYXJlIGlkLCBkZXBzLCBmYWN0b3J5LiBTaG91bGQgYmUgbm9ybWFsaXplZCBieSB0aGUKICAgICAgICAgICAgICAgICAgICAvL2RlZmluZSgpIGZ1bmN0aW9uLgogICAgICAgICAgICAgICAgICAgIGNhbGxHZXRNb2R1bGUoYXJncyk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZU1hcCA9IHt9OwogICAgICAgIH0KCiAgICAgICAgY29udGV4dCA9IHsKICAgICAgICAgICAgY29uZmlnOiBjb25maWcsCiAgICAgICAgICAgIGNvbnRleHROYW1lOiBjb250ZXh0TmFtZSwKICAgICAgICAgICAgcmVnaXN0cnk6IHJlZ2lzdHJ5LAogICAgICAgICAgICBkZWZpbmVkOiBkZWZpbmVkLAogICAgICAgICAgICB1cmxGZXRjaGVkOiB1cmxGZXRjaGVkLAogICAgICAgICAgICBkZWZRdWV1ZTogZGVmUXVldWUsCiAgICAgICAgICAgIGRlZlF1ZXVlTWFwOiB7fSwKICAgICAgICAgICAgTW9kdWxlOiBNb2R1bGUsCiAgICAgICAgICAgIG1ha2VNb2R1bGVNYXA6IG1ha2VNb2R1bGVNYXAsCiAgICAgICAgICAgIG5leHRUaWNrOiByZXEubmV4dFRpY2ssCiAgICAgICAgICAgIG9uRXJyb3I6IG9uRXJyb3IsCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogU2V0IGEgY29uZmlndXJhdGlvbiBmb3IgdGhlIGNvbnRleHQuCiAgICAgICAgICAgICAqIEBwYXJhbSB7T2JqZWN0fSBjZmcgY29uZmlnIG9iamVjdCB0byBpbnRlZ3JhdGUuCiAgICAgICAgICAgICAqLwogICAgIC
AgICAgICBjb25maWd1cmU6IGZ1bmN0aW9uIChjZmcpIHsKICAgICAgICAgICAgICAgIC8vTWFrZSBzdXJlIHRoZSBiYXNlVXJsIGVuZHMgaW4gYSBzbGFzaC4KICAgICAgICAgICAgICAgIGlmIChjZmcuYmFzZVVybCkgewogICAgICAgICAgICAgICAgICAgIGlmIChjZmcuYmFzZVVybC5jaGFyQXQoY2ZnLmJhc2VVcmwubGVuZ3RoIC0gMSkgIT09ICcvJykgewogICAgICAgICAgICAgICAgICAgICAgICBjZmcuYmFzZVVybCArPSAnLyc7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vU2F2ZSBvZmYgdGhlIHBhdGhzIHNpbmNlIHRoZXkgcmVxdWlyZSBzcGVjaWFsIHByb2Nlc3NpbmcsCiAgICAgICAgICAgICAgICAvL3RoZXkgYXJlIGFkZGl0aXZlLgogICAgICAgICAgICAgICAgdmFyIHNoaW0gPSBjb25maWcuc2hpbSwKICAgICAgICAgICAgICAgICAgICBvYmpzID0gewogICAgICAgICAgICAgICAgICAgICAgICBwYXRoczogdHJ1ZSwKICAgICAgICAgICAgICAgICAgICAgICAgYnVuZGxlczogdHJ1ZSwKICAgICAgICAgICAgICAgICAgICAgICAgY29uZmlnOiB0cnVlLAogICAgICAgICAgICAgICAgICAgICAgICBtYXA6IHRydWUKICAgICAgICAgICAgICAgICAgICB9OwoKICAgICAgICAgICAgICAgIGVhY2hQcm9wKGNmZywgZnVuY3Rpb24gKHZhbHVlLCBwcm9wKSB7CiAgICAgICAgICAgICAgICAgICAgaWYgKG9ianNbcHJvcF0pIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKCFjb25maWdbcHJvcF0pIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZ1twcm9wXSA9IHt9OwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIG1peGluKGNvbmZpZ1twcm9wXSwgdmFsdWUsIHRydWUsIHRydWUpOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZ1twcm9wXSA9IHZhbHVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgIC8vUmV2ZXJzZSBtYXAgdGhlIGJ1bmRsZXMKICAgICAgICAgICAgICAgIGlmIChjZmcuYnVuZGxlcykgewogICAgICAgICAgICAgICAgICAgIGVhY2hQcm9wKGNmZy5idW5kbGVzLCBmdW5jdGlvbiAodmFsdWUsIHByb3ApIHsKICAgICAgICAgICAgICAgICAgICAgICAgZWFjaCh2YWx1ZSwgZnVuY3Rpb24gKHYpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh2ICE9PSBwcm9wKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgYnVuZGxlc01hcFt2XSA9IHByb3A7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vTWVyZ2Ugc2hpbQogICAgICAgICAgICAgICAgaWYgKGNmZy5zaGltKSB7CiAgICAgIC
AgICAgICAgICAgICAgZWFjaFByb3AoY2ZnLnNoaW0sIGZ1bmN0aW9uICh2YWx1ZSwgaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9Ob3JtYWxpemUgdGhlIHN0cnVjdHVyZQogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaXNBcnJheSh2YWx1ZSkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHZhbHVlID0gewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGRlcHM6IHZhbHVlCiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9OwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIGlmICgodmFsdWUuZXhwb3J0cyB8fCB2YWx1ZS5pbml0KSAmJiAhdmFsdWUuZXhwb3J0c0ZuKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB2YWx1ZS5leHBvcnRzRm4gPSBjb250ZXh0Lm1ha2VTaGltRXhwb3J0cyh2YWx1ZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgc2hpbVtpZF0gPSB2YWx1ZTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICBjb25maWcuc2hpbSA9IHNoaW07CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgLy9BZGp1c3QgcGFja2FnZXMgaWYgbmVjZXNzYXJ5LgogICAgICAgICAgICAgICAgaWYgKGNmZy5wYWNrYWdlcykgewogICAgICAgICAgICAgICAgICAgIGVhY2goY2ZnLnBhY2thZ2VzLCBmdW5jdGlvbiAocGtnT2JqKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBsb2NhdGlvbiwgbmFtZTsKCiAgICAgICAgICAgICAgICAgICAgICAgIHBrZ09iaiA9IHR5cGVvZiBwa2dPYmogPT09ICdzdHJpbmcnID8ge25hbWU6IHBrZ09ian0gOiBwa2dPYmo7CgogICAgICAgICAgICAgICAgICAgICAgICBuYW1lID0gcGtnT2JqLm5hbWU7CiAgICAgICAgICAgICAgICAgICAgICAgIGxvY2F0aW9uID0gcGtnT2JqLmxvY2F0aW9uOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAobG9jYXRpb24pIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZy5wYXRoc1tuYW1lXSA9IHBrZ09iai5sb2NhdGlvbjsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9TYXZlIHBvaW50ZXIgdG8gbWFpbiBtb2R1bGUgSUQgZm9yIHBrZyBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICAvL1JlbW92ZSBsZWFkaW5nIGRvdCBpbiBtYWluLCBzbyBtYWluIHBhdGhzIGFyZSBub3JtYWxpemVkLAogICAgICAgICAgICAgICAgICAgICAgICAvL2FuZCByZW1vdmUgYW55IHRyYWlsaW5nIC5qcywgc2luY2UgZGlmZmVyZW50IHBhY2thZ2UKICAgICAgICAgICAgICAgICAgICAgICAgLy9lbnZzIGhhdmUgZGlmZmVyZW50IGNvbnZlbnRpb25zOiBzb21lIHVzZSBhIG1vZHVsZSBuYW1lLAogICAgICAgICAgICAgICAgICAgICAgICAvL3NvbWUgdXNlIGEgZmlsZSBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICBjb25maW
cucGtnc1tuYW1lXSA9IHBrZ09iai5uYW1lICsgJy8nICsgKHBrZ09iai5tYWluIHx8ICdtYWluJykKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC5yZXBsYWNlKGN1cnJEaXJSZWdFeHAsICcnKQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLnJlcGxhY2UoanNTdWZmaXhSZWdFeHAsICcnKTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL0lmIHRoZXJlIGFyZSBhbnkgIndhaXRpbmcgdG8gZXhlY3V0ZSIgbW9kdWxlcyBpbiB0aGUgcmVnaXN0cnksCiAgICAgICAgICAgICAgICAvL3VwZGF0ZSB0aGUgbWFwcyBmb3IgdGhlbSwgc2luY2UgdGhlaXIgaW5mbywgbGlrZSBVUkxzIHRvIGxvYWQsCiAgICAgICAgICAgICAgICAvL21heSBoYXZlIGNoYW5nZWQuCiAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCwgaWQpIHsKICAgICAgICAgICAgICAgICAgICAvL0lmIG1vZHVsZSBhbHJlYWR5IGhhcyBpbml0IGNhbGxlZCwgc2luY2UgaXQgaXMgdG9vCiAgICAgICAgICAgICAgICAgICAgLy9sYXRlIHRvIG1vZGlmeSB0aGVtLCBhbmQgaWdub3JlIHVubm9ybWFsaXplZCBvbmVzCiAgICAgICAgICAgICAgICAgICAgLy9zaW5jZSB0aGV5IGFyZSB0cmFuc2llbnQuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFtb2QuaW5pdGVkICYmICFtb2QubWFwLnVubm9ybWFsaXplZCkgewogICAgICAgICAgICAgICAgICAgICAgICBtb2QubWFwID0gbWFrZU1vZHVsZU1hcChpZCwgbnVsbCwgdHJ1ZSk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgLy9JZiBhIGRlcHMgYXJyYXkgb3IgYSBjb25maWcgY2FsbGJhY2sgaXMgc3BlY2lmaWVkLCB0aGVuIGNhbGwKICAgICAgICAgICAgICAgIC8vcmVxdWlyZSB3aXRoIHRob3NlIGFyZ3MuIFRoaXMgaXMgdXNlZnVsIHdoZW4gcmVxdWlyZSBpcyBkZWZpbmVkIGFzIGEKICAgICAgICAgICAgICAgIC8vY29uZmlnIG9iamVjdCBiZWZvcmUgcmVxdWlyZS5qcyBpcyBsb2FkZWQuCiAgICAgICAgICAgICAgICBpZiAoY2ZnLmRlcHMgfHwgY2ZnLmNhbGxiYWNrKSB7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5yZXF1aXJlKGNmZy5kZXBzIHx8IFtdLCBjZmcuY2FsbGJhY2spOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAoKICAgICAgICAgICAgbWFrZVNoaW1FeHBvcnRzOiBmdW5jdGlvbiAodmFsdWUpIHsKICAgICAgICAgICAgICAgIGZ1bmN0aW9uIGZuKCkgewogICAgICAgICAgICAgICAgICAgIHZhciByZXQ7CiAgICAgICAgICAgICAgICAgICAgaWYgKHZhbHVlLmluaXQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0ID0gdmFsdWUuaW5pdC5hcHBseShnbG9iYWwsIGFyZ3VtZW50cyk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIHJldHVybiByZXQgfHwgKH
ZhbHVlLmV4cG9ydHMgJiYgZ2V0R2xvYmFsKHZhbHVlLmV4cG9ydHMpKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIHJldHVybiBmbjsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIG1ha2VSZXF1aXJlOiBmdW5jdGlvbiAocmVsTWFwLCBvcHRpb25zKSB7CiAgICAgICAgICAgICAgICBvcHRpb25zID0gb3B0aW9ucyB8fCB7fTsKCiAgICAgICAgICAgICAgICBmdW5jdGlvbiBsb2NhbFJlcXVpcmUoZGVwcywgY2FsbGJhY2ssIGVycmJhY2spIHsKICAgICAgICAgICAgICAgICAgICB2YXIgaWQsIG1hcCwgcmVxdWlyZU1vZDsKCiAgICAgICAgICAgICAgICAgICAgaWYgKG9wdGlvbnMuZW5hYmxlQnVpbGRDYWxsYmFjayAmJiBjYWxsYmFjayAmJiBpc0Z1bmN0aW9uKGNhbGxiYWNrKSkgewogICAgICAgICAgICAgICAgICAgICAgICBjYWxsYmFjay5fX3JlcXVpcmVKc0J1aWxkID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGlmICh0eXBlb2YgZGVwcyA9PT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGlzRnVuY3Rpb24oY2FsbGJhY2spKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL0ludmFsaWQgY2FsbAogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdyZXF1aXJlYXJncycsICdJbnZhbGlkIHJlcXVpcmUgY2FsbCcpLCBlcnJiYWNrKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9JZiByZXF1aXJlfGV4cG9ydHN8bW9kdWxlIGFyZSByZXF1ZXN0ZWQsIGdldCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy92YWx1ZSBmb3IgdGhlbSBmcm9tIHRoZSBzcGVjaWFsIGhhbmRsZXJzLiBDYXZlYXQ6CiAgICAgICAgICAgICAgICAgICAgICAgIC8vdGhpcyBvbmx5IHdvcmtzIHdoaWxlIG1vZHVsZSBpcyBiZWluZyBkZWZpbmVkLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAocmVsTWFwICYmIGhhc1Byb3AoaGFuZGxlcnMsIGRlcHMpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gaGFuZGxlcnNbZGVwc10ocmVnaXN0cnlbcmVsTWFwLmlkXSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vU3luY2hyb25vdXMgYWNjZXNzIHRvIG9uZSBtb2R1bGUuIElmIHJlcXVpcmUuZ2V0IGlzCiAgICAgICAgICAgICAgICAgICAgICAgIC8vYXZhaWxhYmxlIChhcyBpbiB0aGUgTm9kZSBhZGFwdGVyKSwgcHJlZmVyIHRoYXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChyZXEuZ2V0KSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gcmVxLmdldChjb250ZXh0LCBkZXBzLCByZWxNYXAsIGxvY2FsUmVxdWlyZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplIG1vZHVsZSBuYW1lLCBpZi
BpdCBjb250YWlucyAuIG9yIC4uCiAgICAgICAgICAgICAgICAgICAgICAgIG1hcCA9IG1ha2VNb2R1bGVNYXAoZGVwcywgcmVsTWFwLCBmYWxzZSwgdHJ1ZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIGlkID0gbWFwLmlkOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKCFoYXNQcm9wKGRlZmluZWQsIGlkKSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdub3Rsb2FkZWQnLCAnTW9kdWxlIG5hbWUgIicgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgaWQgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJyIgaGFzIG5vdCBiZWVuIGxvYWRlZCB5ZXQgZm9yIGNvbnRleHQ6ICcgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dE5hbWUgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgKHJlbE1hcCA/ICcnIDogJy4gVXNlIHJlcXVpcmUoW10pJykpKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gZGVmaW5lZFtpZF07CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAvL0dyYWIgZGVmaW5lcyB3YWl0aW5nIGluIHRoZSBnbG9iYWwgcXVldWUuCiAgICAgICAgICAgICAgICAgICAgaW50YWtlRGVmaW5lcygpOwoKICAgICAgICAgICAgICAgICAgICAvL01hcmsgYWxsIHRoZSBkZXBlbmRlbmNpZXMgYXMgbmVlZGluZyB0byBiZSBsb2FkZWQuCiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5uZXh0VGljayhmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vU29tZSBkZWZpbmVzIGNvdWxkIGhhdmUgYmVlbiBhZGRlZCBzaW5jZSB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy9yZXF1aXJlIGNhbGwsIGNvbGxlY3QgdGhlbS4KICAgICAgICAgICAgICAgICAgICAgICAgaW50YWtlRGVmaW5lcygpOwoKICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZCA9IGdldE1vZHVsZShtYWtlTW9kdWxlTWFwKG51bGwsIHJlbE1hcCkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9TdG9yZSBpZiBtYXAgY29uZmlnIHNob3VsZCBiZSBhcHBsaWVkIHRvIHRoaXMgcmVxdWlyZQogICAgICAgICAgICAgICAgICAgICAgICAvL2NhbGwgZm9yIGRlcGVuZGVuY2llcy4KICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZC5za2lwTWFwID0gb3B0aW9ucy5za2lwTWFwOwoKICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZC5pbml0KGRlcHMsIGNhbGxiYWNrLCBlcnJiYWNrLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVkOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgICAgICAgICAgY2hlY2tMb2FkZWQoKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgIC
AgICAgICAgcmV0dXJuIGxvY2FsUmVxdWlyZTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBtaXhpbihsb2NhbFJlcXVpcmUsIHsKICAgICAgICAgICAgICAgICAgICBpc0Jyb3dzZXI6IGlzQnJvd3NlciwKCiAgICAgICAgICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAgICAgICAgICogQ29udmVydHMgYSBtb2R1bGUgbmFtZSArIC5leHRlbnNpb24gaW50byBhbiBVUkwgcGF0aC4KICAgICAgICAgICAgICAgICAgICAgKiAqUmVxdWlyZXMqIHRoZSB1c2Ugb2YgYSBtb2R1bGUgbmFtZS4gSXQgZG9lcyBub3Qgc3VwcG9ydCB1c2luZwogICAgICAgICAgICAgICAgICAgICAqIHBsYWluIFVSTHMgbGlrZSBuYW1lVG9VcmwuCiAgICAgICAgICAgICAgICAgICAgICovCiAgICAgICAgICAgICAgICAgICAgdG9Vcmw6IGZ1bmN0aW9uIChtb2R1bGVOYW1lUGx1c0V4dCkgewogICAgICAgICAgICAgICAgICAgICAgICB2YXIgZXh0LAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaW5kZXggPSBtb2R1bGVOYW1lUGx1c0V4dC5sYXN0SW5kZXhPZignLicpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgc2VnbWVudCA9IG1vZHVsZU5hbWVQbHVzRXh0LnNwbGl0KCcvJylbMF0sCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBpc1JlbGF0aXZlID0gc2VnbWVudCA9PT0gJy4nIHx8IHNlZ21lbnQgPT09ICcuLic7CgogICAgICAgICAgICAgICAgICAgICAgICAvL0hhdmUgYSBmaWxlIGV4dGVuc2lvbiBhbGlhcywgYW5kIGl0IGlzIG5vdCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy9kb3RzIGZyb20gYSByZWxhdGl2ZSBwYXRoLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaW5kZXggIT09IC0xICYmICghaXNSZWxhdGl2ZSB8fCBpbmRleCA+IDEpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBleHQgPSBtb2R1bGVOYW1lUGx1c0V4dC5zdWJzdHJpbmcoaW5kZXgsIG1vZHVsZU5hbWVQbHVzRXh0Lmxlbmd0aCk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVOYW1lUGx1c0V4dCA9IG1vZHVsZU5hbWVQbHVzRXh0LnN1YnN0cmluZygwLCBpbmRleCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiBjb250ZXh0Lm5hbWVUb1VybChub3JtYWxpemUobW9kdWxlTmFtZVBsdXNFeHQsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlbE1hcCAmJiByZWxNYXAuaWQsIHRydWUpLCBleHQsICB0cnVlKTsKICAgICAgICAgICAgICAgICAgICB9LAoKICAgICAgICAgICAgICAgICAgICBkZWZpbmVkOiBmdW5jdGlvbiAoaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGhhc1Byb3AoZGVmaW5lZCwgbWFrZU1vZHVsZU1hcChpZCwgcmVsTWFwLCBmYWxzZSwgdHJ1ZSkuaWQpOwogICAgICAgICAgICAgICAgICAgIH0sCgogICAgICAgICAgICAgICAgICAgIHNwZWNpZmllZDogZnVuY3
Rpb24gKGlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlkID0gbWFrZU1vZHVsZU1hcChpZCwgcmVsTWFwLCBmYWxzZSwgdHJ1ZSkuaWQ7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiBoYXNQcm9wKGRlZmluZWQsIGlkKSB8fCBoYXNQcm9wKHJlZ2lzdHJ5LCBpZCk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgLy9Pbmx5IGFsbG93IHVuZGVmIG9uIHRvcCBsZXZlbCByZXF1aXJlIGNhbGxzCiAgICAgICAgICAgICAgICBpZiAoIXJlbE1hcCkgewogICAgICAgICAgICAgICAgICAgIGxvY2FsUmVxdWlyZS51bmRlZiA9IGZ1bmN0aW9uIChpZCkgewogICAgICAgICAgICAgICAgICAgICAgICAvL0JpbmQgYW55IHdhaXRpbmcgZGVmaW5lKCkgY2FsbHMgdG8gdGhpcyBjb250ZXh0LAogICAgICAgICAgICAgICAgICAgICAgICAvL2ZpeCBmb3IgIzQwOAogICAgICAgICAgICAgICAgICAgICAgICB0YWtlR2xvYmFsUXVldWUoKTsKCiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBtYXAgPSBtYWtlTW9kdWxlTWFwKGlkLCByZWxNYXAsIHRydWUpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kID0gZ2V0T3duKHJlZ2lzdHJ5LCBpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBtb2QudW5kZWZlZCA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgICAgIHJlbW92ZVNjcmlwdChpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBkZWxldGUgZGVmaW5lZFtpZF07CiAgICAgICAgICAgICAgICAgICAgICAgIGRlbGV0ZSB1cmxGZXRjaGVkW21hcC51cmxdOwogICAgICAgICAgICAgICAgICAgICAgICBkZWxldGUgdW5kZWZFdmVudHNbaWRdOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9DbGVhbiBxdWV1ZWQgZGVmaW5lcyB0b28uIEdvIGJhY2t3YXJkcwogICAgICAgICAgICAgICAgICAgICAgICAvL2luIGFycmF5IHNvIHRoYXQgdGhlIHNwbGljZXMgZG8gbm90CiAgICAgICAgICAgICAgICAgICAgICAgIC8vbWVzcyB1cCB0aGUgaXRlcmF0aW9uLgogICAgICAgICAgICAgICAgICAgICAgICBlYWNoUmV2ZXJzZShkZWZRdWV1ZSwgZnVuY3Rpb24oYXJncywgaSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IGlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZGVmUXVldWUuc3BsaWNlKGksIDEpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICAgICAgZGVsZXRlIGNvbnRleHQuZGVmUXVldWVNYXBbaWRdOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9Ib2xkIG9uIHRvIGxpc3RlbmVycyBpbiBjYXNlIHRoZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9tb2R1bGUgd2lsbCBiZSBhdHRlbXB0ZWQgdG8gYmUgcmVsb2FkZWQKIC
AgICAgICAgICAgICAgICAgICAgICAgICAgIC8vdXNpbmcgYSBkaWZmZXJlbnQgY29uZmlnLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZC5ldmVudHMuZGVmaW5lZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHVuZGVmRXZlbnRzW2lkXSA9IG1vZC5ldmVudHM7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgY2xlYW5SZWdpc3RyeShpZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9OwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHJldHVybiBsb2NhbFJlcXVpcmU7CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogQ2FsbGVkIHRvIGVuYWJsZSBhIG1vZHVsZSBpZiBpdCBpcyBzdGlsbCBpbiB0aGUgcmVnaXN0cnkKICAgICAgICAgICAgICogYXdhaXRpbmcgZW5hYmxlbWVudC4gQSBzZWNvbmQgYXJnLCBwYXJlbnQsIHRoZSBwYXJlbnQgbW9kdWxlLAogICAgICAgICAgICAgKiBpcyBwYXNzZWQgaW4gZm9yIGNvbnRleHQsIHdoZW4gdGhpcyBtZXRob2QgaXMgb3ZlcnJpZGRlbiBieQogICAgICAgICAgICAgKiB0aGUgb3B0aW1pemVyLiBOb3Qgc2hvd24gaGVyZSB0byBrZWVwIGNvZGUgY29tcGFjdC4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGVuYWJsZTogZnVuY3Rpb24gKGRlcE1hcCkgewogICAgICAgICAgICAgICAgdmFyIG1vZCA9IGdldE93bihyZWdpc3RyeSwgZGVwTWFwLmlkKTsKICAgICAgICAgICAgICAgIGlmIChtb2QpIHsKICAgICAgICAgICAgICAgICAgICBnZXRNb2R1bGUoZGVwTWFwKS5lbmFibGUoKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIC8qKgogICAgICAgICAgICAgKiBJbnRlcm5hbCBtZXRob2QgdXNlZCBieSBlbnZpcm9ubWVudCBhZGFwdGVycyB0byBjb21wbGV0ZSBhIGxvYWQgZXZlbnQuCiAgICAgICAgICAgICAqIEEgbG9hZCBldmVudCBjb3VsZCBiZSBhIHNjcmlwdCBsb2FkIG9yIGp1c3QgYSBsb2FkIHBhc3MgZnJvbSBhIHN5bmNocm9ub3VzCiAgICAgICAgICAgICAqIGxvYWQgY2FsbC4KICAgICAgICAgICAgICogQHBhcmFtIHtTdHJpbmd9IG1vZHVsZU5hbWUgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZSB0byBwb3RlbnRpYWxseSBjb21wbGV0ZS4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGNvbXBsZXRlTG9hZDogZnVuY3Rpb24gKG1vZHVsZU5hbWUpIHsKICAgICAgICAgICAgICAgIHZhciBmb3VuZCwgYXJncywgbW9kLAogICAgICAgICAgICAgICAgICAgIHNoaW0gPSBnZXRPd24oY29uZmlnLnNoaW0sIG1vZHVsZU5hbWUpIHx8IHt9LAogICAgICAgICAgICAgICAgICAgIHNoRXhwb3J0cyA9IHNoaW0uZXhwb3J0czsKCiAgICAgICAgICAgICAgICB0YWtlR2xvYmFsUXVldWUoKTsKCiAgICAgICAgICAgICAgICB3aGlsZSAoZGVmUXVldWUubGVuZ3RoKSB7Ci
AgICAgICAgICAgICAgICAgICAgYXJncyA9IGRlZlF1ZXVlLnNoaWZ0KCk7CiAgICAgICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IG51bGwpIHsKICAgICAgICAgICAgICAgICAgICAgICAgYXJnc1swXSA9IG1vZHVsZU5hbWU7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vSWYgYWxyZWFkeSBmb3VuZCBhbiBhbm9ueW1vdXMgbW9kdWxlIGFuZCBib3VuZCBpdAogICAgICAgICAgICAgICAgICAgICAgICAvL3RvIHRoaXMgbmFtZSwgdGhlbiB0aGlzIGlzIHNvbWUgb3RoZXIgYW5vbiBtb2R1bGUKICAgICAgICAgICAgICAgICAgICAgICAgLy93YWl0aW5nIGZvciBpdHMgY29tcGxldGVMb2FkIHRvIGZpcmUuCiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChmb3VuZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYWs7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgZm91bmQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSBpZiAoYXJnc1swXSA9PT0gbW9kdWxlTmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICAvL0ZvdW5kIG1hdGNoaW5nIGRlZmluZSBjYWxsIGZvciB0aGlzIHNjcmlwdCEKICAgICAgICAgICAgICAgICAgICAgICAgZm91bmQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgY2FsbEdldE1vZHVsZShhcmdzKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNvbnRleHQuZGVmUXVldWVNYXAgPSB7fTsKCiAgICAgICAgICAgICAgICAvL0RvIHRoaXMgYWZ0ZXIgdGhlIGN5Y2xlIG9mIGNhbGxHZXRNb2R1bGUgaW4gY2FzZSB0aGUgcmVzdWx0CiAgICAgICAgICAgICAgICAvL29mIHRob3NlIGNhbGxzL2luaXQgY2FsbHMgY2hhbmdlcyB0aGUgcmVnaXN0cnkuCiAgICAgICAgICAgICAgICBtb2QgPSBnZXRPd24ocmVnaXN0cnksIG1vZHVsZU5hbWUpOwoKICAgICAgICAgICAgICAgIGlmICghZm91bmQgJiYgIWhhc1Byb3AoZGVmaW5lZCwgbW9kdWxlTmFtZSkgJiYgbW9kICYmICFtb2QuaW5pdGVkKSB7CiAgICAgICAgICAgICAgICAgICAgaWYgKGNvbmZpZy5lbmZvcmNlRGVmaW5lICYmICghc2hFeHBvcnRzIHx8ICFnZXRHbG9iYWwoc2hFeHBvcnRzKSkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGhhc1BhdGhGYWxsYmFjayhtb2R1bGVOYW1lKSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdub2RlZmluZScsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICdObyBkZWZpbmUgY2FsbCBmb3IgJyArIG1vZHVsZU5hbWUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIG51bGwsCiAgICAgICAgICAgICAgICAgICAgICAgICAgIC
AgICAgICAgICAgICAgICAgIFttb2R1bGVOYW1lXSkpOwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9BIHNjcmlwdCB0aGF0IGRvZXMgbm90IGNhbGwgZGVmaW5lKCksIHNvIGp1c3Qgc2ltdWxhdGUKICAgICAgICAgICAgICAgICAgICAgICAgLy90aGUgY2FsbCBmb3IgaXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGNhbGxHZXRNb2R1bGUoW21vZHVsZU5hbWUsIChzaGltLmRlcHMgfHwgW10pLCBzaGltLmV4cG9ydHNGbl0pOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBjaGVja0xvYWRlZCgpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIENvbnZlcnRzIGEgbW9kdWxlIG5hbWUgdG8gYSBmaWxlIHBhdGguIFN1cHBvcnRzIGNhc2VzIHdoZXJlCiAgICAgICAgICAgICAqIG1vZHVsZU5hbWUgbWF5IGFjdHVhbGx5IGJlIGp1c3QgYW4gVVJMLgogICAgICAgICAgICAgKiBOb3RlIHRoYXQgaXQgKipkb2VzIG5vdCoqIGNhbGwgbm9ybWFsaXplIG9uIHRoZSBtb2R1bGVOYW1lLAogICAgICAgICAgICAgKiBpdCBpcyBhc3N1bWVkIHRvIGhhdmUgYWxyZWFkeSBiZWVuIG5vcm1hbGl6ZWQuIFRoaXMgaXMgYW4KICAgICAgICAgICAgICogaW50ZXJuYWwgQVBJLCBub3QgYSBwdWJsaWMgb25lLiBVc2UgdG9VcmwgZm9yIHRoZSBwdWJsaWMgQVBJLgogICAgICAgICAgICAgKi8KICAgICAgICAgICAgbmFtZVRvVXJsOiBmdW5jdGlvbiAobW9kdWxlTmFtZSwgZXh0LCBza2lwRXh0KSB7CiAgICAgICAgICAgICAgICB2YXIgcGF0aHMsIHN5bXMsIGksIHBhcmVudE1vZHVsZSwgdXJsLAogICAgICAgICAgICAgICAgICAgIHBhcmVudFBhdGgsIGJ1bmRsZUlkLAogICAgICAgICAgICAgICAgICAgIHBrZ01haW4gPSBnZXRPd24oY29uZmlnLnBrZ3MsIG1vZHVsZU5hbWUpOwoKICAgICAgICAgICAgICAgIGlmIChwa2dNYWluKSB7CiAgICAgICAgICAgICAgICAgICAgbW9kdWxlTmFtZSA9IHBrZ01haW47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgYnVuZGxlSWQgPSBnZXRPd24oYnVuZGxlc01hcCwgbW9kdWxlTmFtZSk7CgogICAgICAgICAgICAgICAgaWYgKGJ1bmRsZUlkKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGNvbnRleHQubmFtZVRvVXJsKGJ1bmRsZUlkLCBleHQsIHNraXBFeHQpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vSWYgYSBjb2xvbiBpcyBpbiB0aGUgVVJMLCBpdCBpbmRpY2F0ZXMgYSBwcm90b2NvbCBpcyB1c2VkIGFuZCBpdCBpcyBqdXN0CiAgICAgICAgICAgICAgICAvL2FuIFVSTCB0byBhIGZpbGUsIG9yIGlmIGl0IHN0YXJ0cyB3aXRoIGEgc2xhc2gsIGNvbnRhaW5zIGEgcXVlcnkgYXJnIChpLmUuID8pCiAgICAgICAgICAgICAgICAvL29yIGVuZHMgd2l0aCAuanMsIHRoZW4gYX
NzdW1lIHRoZSB1c2VyIG1lYW50IHRvIHVzZSBhbiB1cmwgYW5kIG5vdCBhIG1vZHVsZSBpZC4KICAgICAgICAgICAgICAgIC8vVGhlIHNsYXNoIGlzIGltcG9ydGFudCBmb3IgcHJvdG9jb2wtbGVzcyBVUkxzIGFzIHdlbGwgYXMgZnVsbCBwYXRocy4KICAgICAgICAgICAgICAgIGlmIChyZXEuanNFeHRSZWdFeHAudGVzdChtb2R1bGVOYW1lKSkgewogICAgICAgICAgICAgICAgICAgIC8vSnVzdCBhIHBsYWluIHBhdGgsIG5vdCBtb2R1bGUgbmFtZSBsb29rdXAsIHNvIGp1c3QgcmV0dXJuIGl0LgogICAgICAgICAgICAgICAgICAgIC8vQWRkIGV4dGVuc2lvbiBpZiBpdCBpcyBpbmNsdWRlZC4gVGhpcyBpcyBhIGJpdCB3b25reSwgb25seSBub24tLmpzIHRoaW5ncyBwYXNzCiAgICAgICAgICAgICAgICAgICAgLy9hbiBleHRlbnNpb24sIHRoaXMgbWV0aG9kIHByb2JhYmx5IG5lZWRzIHRvIGJlIHJld29ya2VkLgogICAgICAgICAgICAgICAgICAgIHVybCA9IG1vZHVsZU5hbWUgKyAoZXh0IHx8ICcnKTsKICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgLy9BIG1vZHVsZSB0aGF0IG5lZWRzIHRvIGJlIGNvbnZlcnRlZCB0byBhIHBhdGguCiAgICAgICAgICAgICAgICAgICAgcGF0aHMgPSBjb25maWcucGF0aHM7CgogICAgICAgICAgICAgICAgICAgIHN5bXMgPSBtb2R1bGVOYW1lLnNwbGl0KCcvJyk7CiAgICAgICAgICAgICAgICAgICAgLy9Gb3IgZWFjaCBtb2R1bGUgbmFtZSBzZWdtZW50LCBzZWUgaWYgdGhlcmUgaXMgYSBwYXRoCiAgICAgICAgICAgICAgICAgICAgLy9yZWdpc3RlcmVkIGZvciBpdC4gU3RhcnQgd2l0aCBtb3N0IHNwZWNpZmljIG5hbWUKICAgICAgICAgICAgICAgICAgICAvL2FuZCB3b3JrIHVwIGZyb20gaXQuCiAgICAgICAgICAgICAgICAgICAgZm9yIChpID0gc3ltcy5sZW5ndGg7IGkgPiAwOyBpIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcGFyZW50TW9kdWxlID0gc3ltcy5zbGljZSgwLCBpKS5qb2luKCcvJyk7CgogICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnRQYXRoID0gZ2V0T3duKHBhdGhzLCBwYXJlbnRNb2R1bGUpOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAocGFyZW50UGF0aCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9JZiBhbiBhcnJheSwgaXQgbWVhbnMgdGhlcmUgYXJlIGEgZmV3IGNob2ljZXMsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL0Nob29zZSB0aGUgb25lIHRoYXQgaXMgZGVzaXJlZAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGlzQXJyYXkocGFyZW50UGF0aCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnRQYXRoID0gcGFyZW50UGF0aFswXTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIHN5bXMuc3BsaWNlKDAsIGksIHBhcmVudFBhdGgpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYW
s7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIC8vSm9pbiB0aGUgcGF0aCBwYXJ0cyB0b2dldGhlciwgdGhlbiBmaWd1cmUgb3V0IGlmIGJhc2VVcmwgaXMgbmVlZGVkLgogICAgICAgICAgICAgICAgICAgIHVybCA9IHN5bXMuam9pbignLycpOwogICAgICAgICAgICAgICAgICAgIHVybCArPSAoZXh0IHx8ICgvXmRhdGFcOnxcPy8udGVzdCh1cmwpIHx8IHNraXBFeHQgPyAnJyA6ICcuanMnKSk7CiAgICAgICAgICAgICAgICAgICAgdXJsID0gKHVybC5jaGFyQXQoMCkgPT09ICcvJyB8fCB1cmwubWF0Y2goL15bXHdcK1wuXC1dKzovKSA/ICcnIDogY29uZmlnLmJhc2VVcmwpICsgdXJsOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHJldHVybiBjb25maWcudXJsQXJncyA/IHVybCArCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAoKHVybC5pbmRleE9mKCc/JykgPT09IC0xID8gJz8nIDogJyYnKSArCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgY29uZmlnLnVybEFyZ3MpIDogdXJsOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLy9EZWxlZ2F0ZXMgdG8gcmVxLmxvYWQuIEJyb2tlbiBvdXQgYXMgYSBzZXBhcmF0ZSBmdW5jdGlvbiB0bwogICAgICAgICAgICAvL2FsbG93IG92ZXJyaWRpbmcgaW4gdGhlIG9wdGltaXplci4KICAgICAgICAgICAgbG9hZDogZnVuY3Rpb24gKGlkLCB1cmwpIHsKICAgICAgICAgICAgICAgIHJlcS5sb2FkKGNvbnRleHQsIGlkLCB1cmwpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIEV4ZWN1dGVzIGEgbW9kdWxlIGNhbGxiYWNrIGZ1bmN0aW9uLiBCcm9rZW4gb3V0IGFzIGEgc2VwYXJhdGUgZnVuY3Rpb24KICAgICAgICAgICAgICogc29sZWx5IHRvIGFsbG93IHRoZSBidWlsZCBzeXN0ZW0gdG8gc2VxdWVuY2UgdGhlIGZpbGVzIGluIHRoZSBidWlsdAogICAgICAgICAgICAgKiBsYXllciBpbiB0aGUgcmlnaHQgc2VxdWVuY2UuCiAgICAgICAgICAgICAqCiAgICAgICAgICAgICAqIEBwcml2YXRlCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBleGVjQ2I6IGZ1bmN0aW9uIChuYW1lLCBjYWxsYmFjaywgYXJncywgZXhwb3J0cykgewogICAgICAgICAgICAgICAgcmV0dXJuIGNhbGxiYWNrLmFwcGx5KGV4cG9ydHMsIGFyZ3MpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIGNhbGxiYWNrIGZvciBzY3JpcHQgbG9hZHMsIHVzZWQgdG8gY2hlY2sgc3RhdHVzIG9mIGxvYWRpbmcuCiAgICAgICAgICAgICAqCiAgICAgICAgICAgICAqIEBwYXJhbSB7RXZlbnR9IGV2dCB0aGUgZXZlbnQgZnJvbSB0aGUgYnJvd3NlciBmb3IgdGhlIHNjcmlwdAogICAgICAgICAgICAgKiB0aGF0IHdhcyBsb2FkZWQuCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBvblNjcmlwdExvYWQ6IG
Z1bmN0aW9uIChldnQpIHsKICAgICAgICAgICAgICAgIC8vVXNpbmcgY3VycmVudFRhcmdldCBpbnN0ZWFkIG9mIHRhcmdldCBmb3IgRmlyZWZveCAyLjAncyBzYWtlLiBOb3QKICAgICAgICAgICAgICAgIC8vYWxsIG9sZCBicm93c2VycyB3aWxsIGJlIHN1cHBvcnRlZCwgYnV0IHRoaXMgb25lIHdhcyBlYXN5IGVub3VnaAogICAgICAgICAgICAgICAgLy90byBzdXBwb3J0IGFuZCBzdGlsbCBtYWtlcyBzZW5zZS4KICAgICAgICAgICAgICAgIGlmIChldnQudHlwZSA9PT0gJ2xvYWQnIHx8CiAgICAgICAgICAgICAgICAgICAgICAgIChyZWFkeVJlZ0V4cC50ZXN0KChldnQuY3VycmVudFRhcmdldCB8fCBldnQuc3JjRWxlbWVudCkucmVhZHlTdGF0ZSkpKSB7CiAgICAgICAgICAgICAgICAgICAgLy9SZXNldCBpbnRlcmFjdGl2ZSBzY3JpcHQgc28gYSBzY3JpcHQgbm9kZSBpcyBub3QgaGVsZCBvbnRvIGZvcgogICAgICAgICAgICAgICAgICAgIC8vdG8gbG9uZy4KICAgICAgICAgICAgICAgICAgICBpbnRlcmFjdGl2ZVNjcmlwdCA9IG51bGw7CgogICAgICAgICAgICAgICAgICAgIC8vUHVsbCBvdXQgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZSBhbmQgdGhlIGNvbnRleHQuCiAgICAgICAgICAgICAgICAgICAgdmFyIGRhdGEgPSBnZXRTY3JpcHREYXRhKGV2dCk7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5jb21wbGV0ZUxvYWQoZGF0YS5pZCk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogQ2FsbGJhY2sgZm9yIHNjcmlwdCBlcnJvcnMuCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBvblNjcmlwdEVycm9yOiBmdW5jdGlvbiAoZXZ0KSB7CiAgICAgICAgICAgICAgICB2YXIgZGF0YSA9IGdldFNjcmlwdERhdGEoZXZ0KTsKICAgICAgICAgICAgICAgIGlmICghaGFzUGF0aEZhbGxiYWNrKGRhdGEuaWQpKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIHBhcmVudHMgPSBbXTsKICAgICAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24odmFsdWUsIGtleSkgewogICAgICAgICAgICAgICAgICAgICAgICBpZiAoa2V5LmluZGV4T2YoJ19AcicpICE9PSAwKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlYWNoKHZhbHVlLmRlcE1hcHMsIGZ1bmN0aW9uKGRlcE1hcCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmIChkZXBNYXAuaWQgPT09IGRhdGEuaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcGFyZW50cy5wdXNoKGtleSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRX
Jyb3IoJ3NjcmlwdGVycm9yJywgJ1NjcmlwdCBlcnJvciBmb3IgIicgKyBkYXRhLmlkICsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgKHBhcmVudHMubGVuZ3RoID8KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJyIsIG5lZWRlZCBieTogJyArIHBhcmVudHMuam9pbignLCAnKSA6CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICciJyksIGV2dCwgW2RhdGEuaWRdKSk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9OwoKICAgICAgICBjb250ZXh0LnJlcXVpcmUgPSBjb250ZXh0Lm1ha2VSZXF1aXJlKCk7CiAgICAgICAgcmV0dXJuIGNvbnRleHQ7CiAgICB9CgogICAgLyoqCiAgICAgKiBNYWluIGVudHJ5IHBvaW50LgogICAgICoKICAgICAqIElmIHRoZSBvbmx5IGFyZ3VtZW50IHRvIHJlcXVpcmUgaXMgYSBzdHJpbmcsIHRoZW4gdGhlIG1vZHVsZSB0aGF0CiAgICAgKiBpcyByZXByZXNlbnRlZCBieSB0aGF0IHN0cmluZyBpcyBmZXRjaGVkIGZvciB0aGUgYXBwcm9wcmlhdGUgY29udGV4dC4KICAgICAqCiAgICAgKiBJZiB0aGUgZmlyc3QgYXJndW1lbnQgaXMgYW4gYXJyYXksIHRoZW4gaXQgd2lsbCBiZSB0cmVhdGVkIGFzIGFuIGFycmF5CiAgICAgKiBvZiBkZXBlbmRlbmN5IHN0cmluZyBuYW1lcyB0byBmZXRjaC4gQW4gb3B0aW9uYWwgZnVuY3Rpb24gY2FsbGJhY2sgY2FuCiAgICAgKiBiZSBzcGVjaWZpZWQgdG8gZXhlY3V0ZSB3aGVuIGFsbCBvZiB0aG9zZSBkZXBlbmRlbmNpZXMgYXJlIGF2YWlsYWJsZS4KICAgICAqCiAgICAgKiBNYWtlIGEgbG9jYWwgcmVxIHZhcmlhYmxlIHRvIGhlbHAgQ2FqYSBjb21wbGlhbmNlIChpdCBhc3N1bWVzIHRoaW5ncwogICAgICogb24gYSByZXF1aXJlIHRoYXQgYXJlIG5vdCBzdGFuZGFyZGl6ZWQpLCBhbmQgdG8gZ2l2ZSBhIHNob3J0CiAgICAgKiBuYW1lIGZvciBtaW5pZmljYXRpb24vbG9jYWwgc2NvcGUgdXNlLgogICAgICovCiAgICByZXEgPSByZXF1aXJlanMgPSBmdW5jdGlvbiAoZGVwcywgY2FsbGJhY2ssIGVycmJhY2ssIG9wdGlvbmFsKSB7CgogICAgICAgIC8vRmluZCB0aGUgcmlnaHQgY29udGV4dCwgdXNlIGRlZmF1bHQKICAgICAgICB2YXIgY29udGV4dCwgY29uZmlnLAogICAgICAgICAgICBjb250ZXh0TmFtZSA9IGRlZkNvbnRleHROYW1lOwoKICAgICAgICAvLyBEZXRlcm1pbmUgaWYgaGF2ZSBjb25maWcgb2JqZWN0IGluIHRoZSBjYWxsLgogICAgICAgIGlmICghaXNBcnJheShkZXBzKSAmJiB0eXBlb2YgZGVwcyAhPT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgLy8gZGVwcyBpcyBhIGNvbmZpZyBvYmplY3QKICAgICAgICAgICAgY29uZmlnID0gZGVwczsKICAgICAgICAgICAgaWYgKGlzQXJyYXkoY2FsbGJhY2spKSB7CiAgICAgICAgICAgICAgICAvLyBBZGp1c3QgYXJncyBpZiB0aGVyZSBhcmUgZGVwZW5kZW5jaWVzCiAgICAgICAgIC
AgICAgICBkZXBzID0gY2FsbGJhY2s7CiAgICAgICAgICAgICAgICBjYWxsYmFjayA9IGVycmJhY2s7CiAgICAgICAgICAgICAgICBlcnJiYWNrID0gb3B0aW9uYWw7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBkZXBzID0gW107CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIGlmIChjb25maWcgJiYgY29uZmlnLmNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dE5hbWUgPSBjb25maWcuY29udGV4dDsKICAgICAgICB9CgogICAgICAgIGNvbnRleHQgPSBnZXRPd24oY29udGV4dHMsIGNvbnRleHROYW1lKTsKICAgICAgICBpZiAoIWNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dCA9IGNvbnRleHRzW2NvbnRleHROYW1lXSA9IHJlcS5zLm5ld0NvbnRleHQoY29udGV4dE5hbWUpOwogICAgICAgIH0KCiAgICAgICAgaWYgKGNvbmZpZykgewogICAgICAgICAgICBjb250ZXh0LmNvbmZpZ3VyZShjb25maWcpOwogICAgICAgIH0KCiAgICAgICAgcmV0dXJuIGNvbnRleHQucmVxdWlyZShkZXBzLCBjYWxsYmFjaywgZXJyYmFjayk7CiAgICB9OwoKICAgIC8qKgogICAgICogU3VwcG9ydCByZXF1aXJlLmNvbmZpZygpIHRvIG1ha2UgaXQgZWFzaWVyIHRvIGNvb3BlcmF0ZSB3aXRoIG90aGVyCiAgICAgKiBBTUQgbG9hZGVycyBvbiBnbG9iYWxseSBhZ3JlZWQgbmFtZXMuCiAgICAgKi8KICAgIHJlcS5jb25maWcgPSBmdW5jdGlvbiAoY29uZmlnKSB7CiAgICAgICAgcmV0dXJuIHJlcShjb25maWcpOwogICAgfTsKCiAgICAvKioKICAgICAqIEV4ZWN1dGUgc29tZXRoaW5nIGFmdGVyIHRoZSBjdXJyZW50IHRpY2sKICAgICAqIG9mIHRoZSBldmVudCBsb29wLiBPdmVycmlkZSBmb3Igb3RoZXIgZW52cwogICAgICogdGhhdCBoYXZlIGEgYmV0dGVyIHNvbHV0aW9uIHRoYW4gc2V0VGltZW91dC4KICAgICAqIEBwYXJhbSAge0Z1bmN0aW9ufSBmbiBmdW5jdGlvbiB0byBleGVjdXRlIGxhdGVyLgogICAgICovCiAgICByZXEubmV4dFRpY2sgPSB0eXBlb2Ygc2V0VGltZW91dCAhPT0gJ3VuZGVmaW5lZCcgPyBmdW5jdGlvbiAoZm4pIHsKICAgICAgICBzZXRUaW1lb3V0KGZuLCA0KTsKICAgIH0gOiBmdW5jdGlvbiAoZm4pIHsgZm4oKTsgfTsKCiAgICAvKioKICAgICAqIEV4cG9ydCByZXF1aXJlIGFzIGEgZ2xvYmFsLCBidXQgb25seSBpZiBpdCBkb2VzIG5vdCBhbHJlYWR5IGV4aXN0LgogICAgICovCiAgICBpZiAoIXJlcXVpcmUpIHsKICAgICAgICByZXF1aXJlID0gcmVxOwogICAgfQoKICAgIHJlcS52ZXJzaW9uID0gdmVyc2lvbjsKCiAgICAvL1VzZWQgdG8gZmlsdGVyIG91dCBkZXBlbmRlbmNpZXMgdGhhdCBhcmUgYWxyZWFkeSBwYXRocy4KICAgIHJlcS5qc0V4dFJlZ0V4cCA9IC9eXC98OnxcP3xcLmpzJC87CiAgICByZXEuaXNCcm93c2VyID0gaXNCcm93c2VyOwogICAgcyA9IHJlcS5zID0gewogICAgICAgIGNvbnRleHRzOiBjb250ZXh0cywKICAgICAgICBuZXdDb250ZXh0OiBuZXdDb250ZXh0Ci
AgICB9OwoKICAgIC8vQ3JlYXRlIGRlZmF1bHQgY29udGV4dC4KICAgIHJlcSh7fSk7CgogICAgLy9FeHBvcnRzIHNvbWUgY29udGV4dC1zZW5zaXRpdmUgbWV0aG9kcyBvbiBnbG9iYWwgcmVxdWlyZS4KICAgIGVhY2goWwogICAgICAgICd0b1VybCcsCiAgICAgICAgJ3VuZGVmJywKICAgICAgICAnZGVmaW5lZCcsCiAgICAgICAgJ3NwZWNpZmllZCcKICAgIF0sIGZ1bmN0aW9uIChwcm9wKSB7CiAgICAgICAgLy9SZWZlcmVuY2UgZnJvbSBjb250ZXh0cyBpbnN0ZWFkIG9mIGVhcmx5IGJpbmRpbmcgdG8gZGVmYXVsdCBjb250ZXh0LAogICAgICAgIC8vc28gdGhhdCBkdXJpbmcgYnVpbGRzLCB0aGUgbGF0ZXN0IGluc3RhbmNlIG9mIHRoZSBkZWZhdWx0IGNvbnRleHQKICAgICAgICAvL3dpdGggaXRzIGNvbmZpZyBnZXRzIHVzZWQuCiAgICAgICAgcmVxW3Byb3BdID0gZnVuY3Rpb24gKCkgewogICAgICAgICAgICB2YXIgY3R4ID0gY29udGV4dHNbZGVmQ29udGV4dE5hbWVdOwogICAgICAgICAgICByZXR1cm4gY3R4LnJlcXVpcmVbcHJvcF0uYXBwbHkoY3R4LCBhcmd1bWVudHMpOwogICAgICAgIH07CiAgICB9KTsKCiAgICBpZiAoaXNCcm93c2VyKSB7CiAgICAgICAgaGVhZCA9IHMuaGVhZCA9IGRvY3VtZW50LmdldEVsZW1lbnRzQnlUYWdOYW1lKCdoZWFkJylbMF07CiAgICAgICAgLy9JZiBCQVNFIHRhZyBpcyBpbiBwbGF5LCB1c2luZyBhcHBlbmRDaGlsZCBpcyBhIHByb2JsZW0gZm9yIElFNi4KICAgICAgICAvL1doZW4gdGhhdCBicm93c2VyIGRpZXMsIHRoaXMgY2FuIGJlIHJlbW92ZWQuIERldGFpbHMgaW4gdGhpcyBqUXVlcnkgYnVnOgogICAgICAgIC8vaHR0cDovL2Rldi5qcXVlcnkuY29tL3RpY2tldC8yNzA5CiAgICAgICAgYmFzZUVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50c0J5VGFnTmFtZSgnYmFzZScpWzBdOwogICAgICAgIGlmIChiYXNlRWxlbWVudCkgewogICAgICAgICAgICBoZWFkID0gcy5oZWFkID0gYmFzZUVsZW1lbnQucGFyZW50Tm9kZTsKICAgICAgICB9CiAgICB9CgogICAgLyoqCiAgICAgKiBBbnkgZXJyb3JzIHRoYXQgcmVxdWlyZSBleHBsaWNpdGx5IGdlbmVyYXRlcyB3aWxsIGJlIHBhc3NlZCB0byB0aGlzCiAgICAgKiBmdW5jdGlvbi4gSW50ZXJjZXB0L292ZXJyaWRlIGl0IGlmIHlvdSB3YW50IGN1c3RvbSBlcnJvciBoYW5kbGluZy4KICAgICAqIEBwYXJhbSB7RXJyb3J9IGVyciB0aGUgZXJyb3Igb2JqZWN0LgogICAgICovCiAgICByZXEub25FcnJvciA9IGRlZmF1bHRPbkVycm9yOwoKICAgIC8qKgogICAgICogQ3JlYXRlcyB0aGUgbm9kZSBmb3IgdGhlIGxvYWQgY29tbWFuZC4gT25seSB1c2VkIGluIGJyb3dzZXIgZW52cy4KICAgICAqLwogICAgcmVxLmNyZWF0ZU5vZGUgPSBmdW5jdGlvbiAoY29uZmlnLCBtb2R1bGVOYW1lLCB1cmwpIHsKICAgICAgICB2YXIgbm9kZSA9IGNvbmZpZy54aHRtbCA/CiAgICAgICAgICAgICAgICBkb2N1bWVudC5jcmVhdGVFbGVtZW50TlMoJ2h0dHA6Ly
93d3cudzMub3JnLzE5OTkveGh0bWwnLCAnaHRtbDpzY3JpcHQnKSA6CiAgICAgICAgICAgICAgICBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdzY3JpcHQnKTsKICAgICAgICBub2RlLnR5cGUgPSBjb25maWcuc2NyaXB0VHlwZSB8fCAndGV4dC9qYXZhc2NyaXB0JzsKICAgICAgICBub2RlLmNoYXJzZXQgPSAndXRmLTgnOwogICAgICAgIG5vZGUuYXN5bmMgPSB0cnVlOwogICAgICAgIHJldHVybiBub2RlOwogICAgfTsKCiAgICAvKioKICAgICAqIERvZXMgdGhlIHJlcXVlc3QgdG8gbG9hZCBhIG1vZHVsZSBmb3IgdGhlIGJyb3dzZXIgY2FzZS4KICAgICAqIE1ha2UgdGhpcyBhIHNlcGFyYXRlIGZ1bmN0aW9uIHRvIGFsbG93IG90aGVyIGVudmlyb25tZW50cwogICAgICogdG8gb3ZlcnJpZGUgaXQuCiAgICAgKgogICAgICogQHBhcmFtIHtPYmplY3R9IGNvbnRleHQgdGhlIHJlcXVpcmUgY29udGV4dCB0byBmaW5kIHN0YXRlLgogICAgICogQHBhcmFtIHtTdHJpbmd9IG1vZHVsZU5hbWUgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZS4KICAgICAqIEBwYXJhbSB7T2JqZWN0fSB1cmwgdGhlIFVSTCB0byB0aGUgbW9kdWxlLgogICAgICovCiAgICByZXEubG9hZCA9IGZ1bmN0aW9uIChjb250ZXh0LCBtb2R1bGVOYW1lLCB1cmwpIHsKICAgICAgICB2YXIgY29uZmlnID0gKGNvbnRleHQgJiYgY29udGV4dC5jb25maWcpIHx8IHt9LAogICAgICAgICAgICBub2RlOwogICAgICAgIGlmIChpc0Jyb3dzZXIpIHsKICAgICAgICAgICAgLy9JbiB0aGUgYnJvd3NlciBzbyB1c2UgYSBzY3JpcHQgdGFnCiAgICAgICAgICAgIG5vZGUgPSByZXEuY3JlYXRlTm9kZShjb25maWcsIG1vZHVsZU5hbWUsIHVybCk7CiAgICAgICAgICAgIGlmIChjb25maWcub25Ob2RlQ3JlYXRlZCkgewogICAgICAgICAgICAgICAgY29uZmlnLm9uTm9kZUNyZWF0ZWQobm9kZSwgY29uZmlnLCBtb2R1bGVOYW1lLCB1cmwpOwogICAgICAgICAgICB9CgogICAgICAgICAgICBub2RlLnNldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlY29udGV4dCcsIGNvbnRleHQuY29udGV4dE5hbWUpOwogICAgICAgICAgICBub2RlLnNldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlbW9kdWxlJywgbW9kdWxlTmFtZSk7CgogICAgICAgICAgICAvL1NldCB1cCBsb2FkIGxpc3RlbmVyLiBUZXN0IGF0dGFjaEV2ZW50IGZpcnN0IGJlY2F1c2UgSUU5IGhhcwogICAgICAgICAgICAvL2Egc3VidGxlIGlzc3VlIGluIGl0cyBhZGRFdmVudExpc3RlbmVyIGFuZCBzY3JpcHQgb25sb2FkIGZpcmluZ3MKICAgICAgICAgICAgLy90aGF0IGRvIG5vdCBtYXRjaCB0aGUgYmVoYXZpb3Igb2YgYWxsIG90aGVyIGJyb3dzZXJzIHdpdGgKICAgICAgICAgICAgLy9hZGRFdmVudExpc3RlbmVyIHN1cHBvcnQsIHdoaWNoIGZpcmUgdGhlIG9ubG9hZCBldmVudCBmb3IgYQogICAgICAgICAgICAvL3NjcmlwdCByaWdodCBhZnRlciB0aGUgc2NyaXB0IGV4ZWN1dGlvbi4gU2VlOgogICAgICAgICAgICAvL2h0dHBzOi8vY29ubm
VjdC5taWNyb3NvZnQuY29tL0lFL2ZlZWRiYWNrL2RldGFpbHMvNjQ4MDU3L3NjcmlwdC1vbmxvYWQtZXZlbnQtaXMtbm90LWZpcmVkLWltbWVkaWF0ZWx5LWFmdGVyLXNjcmlwdC1leGVjdXRpb24KICAgICAgICAgICAgLy9VTkZPUlRVTkFURUxZIE9wZXJhIGltcGxlbWVudHMgYXR0YWNoRXZlbnQgYnV0IGRvZXMgbm90IGZvbGxvdyB0aGUgc2NyaXB0CiAgICAgICAgICAgIC8vc2NyaXB0IGV4ZWN1dGlvbiBtb2RlLgogICAgICAgICAgICBpZiAobm9kZS5hdHRhY2hFdmVudCAmJgogICAgICAgICAgICAgICAgICAgIC8vQ2hlY2sgaWYgbm9kZS5hdHRhY2hFdmVudCBpcyBhcnRpZmljaWFsbHkgYWRkZWQgYnkgY3VzdG9tIHNjcmlwdCBvcgogICAgICAgICAgICAgICAgICAgIC8vbmF0aXZlbHkgc3VwcG9ydGVkIGJ5IGJyb3dzZXIKICAgICAgICAgICAgICAgICAgICAvL3JlYWQgaHR0cHM6Ly9naXRodWIuY29tL2pyYnVya2UvcmVxdWlyZWpzL2lzc3Vlcy8xODcKICAgICAgICAgICAgICAgICAgICAvL2lmIHdlIGNhbiBOT1QgZmluZCBbbmF0aXZlIGNvZGVdIHRoZW4gaXQgbXVzdCBOT1QgbmF0aXZlbHkgc3VwcG9ydGVkLgogICAgICAgICAgICAgICAgICAgIC8vaW4gSUU4LCBub2RlLmF0dGFjaEV2ZW50IGRvZXMgbm90IGhhdmUgdG9TdHJpbmcoKQogICAgICAgICAgICAgICAgICAgIC8vTm90ZSB0aGUgdGVzdCBmb3IgIltuYXRpdmUgY29kZSIgd2l0aCBubyBjbG9zaW5nIGJyYWNlLCBzZWU6CiAgICAgICAgICAgICAgICAgICAgLy9odHRwczovL2dpdGh1Yi5jb20vanJidXJrZS9yZXF1aXJlanMvaXNzdWVzLzI3MwogICAgICAgICAgICAgICAgICAgICEobm9kZS5hdHRhY2hFdmVudC50b1N0cmluZyAmJiBub2RlLmF0dGFjaEV2ZW50LnRvU3RyaW5nKCkuaW5kZXhPZignW25hdGl2ZSBjb2RlJykgPCAwKSAmJgogICAgICAgICAgICAgICAgICAgICFpc09wZXJhKSB7CiAgICAgICAgICAgICAgICAvL1Byb2JhYmx5IElFLiBJRSAoYXQgbGVhc3QgNi04KSBkbyBub3QgZmlyZQogICAgICAgICAgICAgICAgLy9zY3JpcHQgb25sb2FkIHJpZ2h0IGFmdGVyIGV4ZWN1dGluZyB0aGUgc2NyaXB0LCBzbwogICAgICAgICAgICAgICAgLy93ZSBjYW5ub3QgdGllIHRoZSBhbm9ueW1vdXMgZGVmaW5lIGNhbGwgdG8gYSBuYW1lLgogICAgICAgICAgICAgICAgLy9Ib3dldmVyLCBJRSByZXBvcnRzIHRoZSBzY3JpcHQgYXMgYmVpbmcgaW4gJ2ludGVyYWN0aXZlJwogICAgICAgICAgICAgICAgLy9yZWFkeVN0YXRlIGF0IHRoZSB0aW1lIG9mIHRoZSBkZWZpbmUgY2FsbC4KICAgICAgICAgICAgICAgIHVzZUludGVyYWN0aXZlID0gdHJ1ZTsKCiAgICAgICAgICAgICAgICBub2RlLmF0dGFjaEV2ZW50KCdvbnJlYWR5c3RhdGVjaGFuZ2UnLCBjb250ZXh0Lm9uU2NyaXB0TG9hZCk7CiAgICAgICAgICAgICAgICAvL0l0IHdvdWxkIGJlIGdyZWF0IHRvIGFkZCBhbiBlcnJvciBoYW5kbGVyIGhlcmUgdG8gY2F0Y2gKICAgICAgICAgICAgICAgIC8vNDA0cy
BpbiBJRTkrLiBIb3dldmVyLCBvbnJlYWR5c3RhdGVjaGFuZ2Ugd2lsbCBmaXJlIGJlZm9yZQogICAgICAgICAgICAgICAgLy90aGUgZXJyb3IgaGFuZGxlciwgc28gdGhhdCBkb2VzIG5vdCBoZWxwLiBJZiBhZGRFdmVudExpc3RlbmVyCiAgICAgICAgICAgICAgICAvL2lzIHVzZWQsIHRoZW4gSUUgd2lsbCBmaXJlIGVycm9yIGJlZm9yZSBsb2FkLCBidXQgd2UgY2Fubm90CiAgICAgICAgICAgICAgICAvL3VzZSB0aGF0IHBhdGh3YXkgZ2l2ZW4gdGhlIGNvbm5lY3QubWljcm9zb2Z0LmNvbSBpc3N1ZQogICAgICAgICAgICAgICAgLy9tZW50aW9uZWQgYWJvdmUgYWJvdXQgbm90IGRvaW5nIHRoZSAnc2NyaXB0IGV4ZWN1dGUsCiAgICAgICAgICAgICAgICAvL3RoZW4gZmlyZSB0aGUgc2NyaXB0IGxvYWQgZXZlbnQgbGlzdGVuZXIgYmVmb3JlIGV4ZWN1dGUKICAgICAgICAgICAgICAgIC8vbmV4dCBzY3JpcHQnIHRoYXQgb3RoZXIgYnJvd3NlcnMgZG8uCiAgICAgICAgICAgICAgICAvL0Jlc3QgaG9wZTogSUUxMCBmaXhlcyB0aGUgaXNzdWVzLAogICAgICAgICAgICAgICAgLy9hbmQgdGhlbiBkZXN0cm95cyBhbGwgaW5zdGFsbHMgb2YgSUUgNi05LgogICAgICAgICAgICAgICAgLy9ub2RlLmF0dGFjaEV2ZW50KCdvbmVycm9yJywgY29udGV4dC5vblNjcmlwdEVycm9yKTsKICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgIG5vZGUuYWRkRXZlbnRMaXN0ZW5lcignbG9hZCcsIGNvbnRleHQub25TY3JpcHRMb2FkLCBmYWxzZSk7CiAgICAgICAgICAgICAgICBub2RlLmFkZEV2ZW50TGlzdGVuZXIoJ2Vycm9yJywgY29udGV4dC5vblNjcmlwdEVycm9yLCBmYWxzZSk7CiAgICAgICAgICAgIH0KICAgICAgICAgICAgbm9kZS5zcmMgPSB1cmw7CgogICAgICAgICAgICAvL0ZvciBzb21lIGNhY2hlIGNhc2VzIGluIElFIDYtOCwgdGhlIHNjcmlwdCBleGVjdXRlcyBiZWZvcmUgdGhlIGVuZAogICAgICAgICAgICAvL29mIHRoZSBhcHBlbmRDaGlsZCBleGVjdXRpb24sIHNvIHRvIHRpZSBhbiBhbm9ueW1vdXMgZGVmaW5lCiAgICAgICAgICAgIC8vY2FsbCB0byB0aGUgbW9kdWxlIG5hbWUgKHdoaWNoIGlzIHN0b3JlZCBvbiB0aGUgbm9kZSksIGhvbGQgb24KICAgICAgICAgICAgLy90byBhIHJlZmVyZW5jZSB0byB0aGlzIG5vZGUsIGJ1dCBjbGVhciBhZnRlciB0aGUgRE9NIGluc2VydGlvbi4KICAgICAgICAgICAgY3VycmVudGx5QWRkaW5nU2NyaXB0ID0gbm9kZTsKICAgICAgICAgICAgaWYgKGJhc2VFbGVtZW50KSB7CiAgICAgICAgICAgICAgICBoZWFkLmluc2VydEJlZm9yZShub2RlLCBiYXNlRWxlbWVudCk7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBoZWFkLmFwcGVuZENoaWxkKG5vZGUpOwogICAgICAgICAgICB9CiAgICAgICAgICAgIGN1cnJlbnRseUFkZGluZ1NjcmlwdCA9IG51bGw7CgogICAgICAgICAgICByZXR1cm4gbm9kZTsKICAgICAgICB9IGVsc2UgaWYgKGlzV2ViV29ya2VyKSB7Ci
AgICAgICAgICAgIHRyeSB7CiAgICAgICAgICAgICAgICAvL0luIGEgd2ViIHdvcmtlciwgdXNlIGltcG9ydFNjcmlwdHMuIFRoaXMgaXMgbm90IGEgdmVyeQogICAgICAgICAgICAgICAgLy9lZmZpY2llbnQgdXNlIG9mIGltcG9ydFNjcmlwdHMsIGltcG9ydFNjcmlwdHMgd2lsbCBibG9jayB1bnRpbAogICAgICAgICAgICAgICAgLy9pdHMgc2NyaXB0IGlzIGRvd25sb2FkZWQgYW5kIGV2YWx1YXRlZC4gSG93ZXZlciwgaWYgd2ViIHdvcmtlcnMKICAgICAgICAgICAgICAgIC8vYXJlIGluIHBsYXksIHRoZSBleHBlY3RhdGlvbiBpcyB0aGF0IGEgYnVpbGQgaGFzIGJlZW4gZG9uZSBzbwogICAgICAgICAgICAgICAgLy90aGF0IG9ubHkgb25lIHNjcmlwdCBuZWVkcyB0byBiZSBsb2FkZWQgYW55d2F5LiBUaGlzIG1heSBuZWVkCiAgICAgICAgICAgICAgICAvL3RvIGJlIHJlZXZhbHVhdGVkIGlmIG90aGVyIHVzZSBjYXNlcyBiZWNvbWUgY29tbW9uLgogICAgICAgICAgICAgICAgaW1wb3J0U2NyaXB0cyh1cmwpOwoKICAgICAgICAgICAgICAgIC8vQWNjb3VudCBmb3IgYW5vbnltb3VzIG1vZHVsZXMKICAgICAgICAgICAgICAgIGNvbnRleHQuY29tcGxldGVMb2FkKG1vZHVsZU5hbWUpOwogICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICBjb250ZXh0Lm9uRXJyb3IobWFrZUVycm9yKCdpbXBvcnRzY3JpcHRzJywKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnaW1wb3J0U2NyaXB0cyBmYWlsZWQgZm9yICcgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVOYW1lICsgJyBhdCAnICsgdXJsLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgW21vZHVsZU5hbWVdKSk7CiAgICAgICAgICAgIH0KICAgICAgICB9CiAgICB9OwoKICAgIGZ1bmN0aW9uIGdldEludGVyYWN0aXZlU2NyaXB0KCkgewogICAgICAgIGlmIChpbnRlcmFjdGl2ZVNjcmlwdCAmJiBpbnRlcmFjdGl2ZVNjcmlwdC5yZWFkeVN0YXRlID09PSAnaW50ZXJhY3RpdmUnKSB7CiAgICAgICAgICAgIHJldHVybiBpbnRlcmFjdGl2ZVNjcmlwdDsKICAgICAgICB9CgogICAgICAgIGVhY2hSZXZlcnNlKHNjcmlwdHMoKSwgZnVuY3Rpb24gKHNjcmlwdCkgewogICAgICAgICAgICBpZiAoc2NyaXB0LnJlYWR5U3RhdGUgPT09ICdpbnRlcmFjdGl2ZScpIHsKICAgICAgICAgICAgICAgIHJldHVybiAoaW50ZXJhY3RpdmVTY3JpcHQgPSBzY3JpcHQpOwogICAgICAgICAgICB9CiAgICAgICAgfSk7CiAgICAgICAgcmV0dXJuIGludGVyYWN0aXZlU2NyaXB0OwogICAgfQoKICAgIC8vTG9vayBmb3IgYSBkYXRhLW1haW4gc2NyaXB0IGF0dHJpYnV0ZSwgd2hpY2ggY291bGQgYWxzbyBhZGp1c3QgdGhlIGJhc2VVcmwuCiAgICBpZiAoaXNCcm93c2VyICYmICFjZmcuc2tpcERhdGFNYWluKSB7CiAgICAgICAgLy9GaWd1cmUgb3V0IGJhc2VVcmwuIEdldCBpdC
Bmcm9tIHRoZSBzY3JpcHQgdGFnIHdpdGggcmVxdWlyZS5qcyBpbiBpdC4KICAgICAgICBlYWNoUmV2ZXJzZShzY3JpcHRzKCksIGZ1bmN0aW9uIChzY3JpcHQpIHsKICAgICAgICAgICAgLy9TZXQgdGhlICdoZWFkJyB3aGVyZSB3ZSBjYW4gYXBwZW5kIGNoaWxkcmVuIGJ5CiAgICAgICAgICAgIC8vdXNpbmcgdGhlIHNjcmlwdCdzIHBhcmVudC4KICAgICAgICAgICAgaWYgKCFoZWFkKSB7CiAgICAgICAgICAgICAgICBoZWFkID0gc2NyaXB0LnBhcmVudE5vZGU7CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIC8vTG9vayBmb3IgYSBkYXRhLW1haW4gYXR0cmlidXRlIHRvIHNldCBtYWluIHNjcmlwdCBmb3IgdGhlIHBhZ2UKICAgICAgICAgICAgLy90byBsb2FkLiBJZiBpdCBpcyB0aGVyZSwgdGhlIHBhdGggdG8gZGF0YSBtYWluIGJlY29tZXMgdGhlCiAgICAgICAgICAgIC8vYmFzZVVybCwgaWYgaXQgaXMgbm90IGFscmVhZHkgc2V0LgogICAgICAgICAgICBkYXRhTWFpbiA9IHNjcmlwdC5nZXRBdHRyaWJ1dGUoJ2RhdGEtbWFpbicpOwogICAgICAgICAgICBpZiAoZGF0YU1haW4pIHsKICAgICAgICAgICAgICAgIC8vUHJlc2VydmUgZGF0YU1haW4gaW4gY2FzZSBpdCBpcyBhIHBhdGggKGkuZS4gY29udGFpbnMgJz8nKQogICAgICAgICAgICAgICAgbWFpblNjcmlwdCA9IGRhdGFNYWluOwoKICAgICAgICAgICAgICAgIC8vU2V0IGZpbmFsIGJhc2VVcmwgaWYgdGhlcmUgaXMgbm90IGFscmVhZHkgYW4gZXhwbGljaXQgb25lLgogICAgICAgICAgICAgICAgaWYgKCFjZmcuYmFzZVVybCkgewogICAgICAgICAgICAgICAgICAgIC8vUHVsbCBvZmYgdGhlIGRpcmVjdG9yeSBvZiBkYXRhLW1haW4gZm9yIHVzZSBhcyB0aGUKICAgICAgICAgICAgICAgICAgICAvL2Jhc2VVcmwuCiAgICAgICAgICAgICAgICAgICAgc3JjID0gbWFpblNjcmlwdC5zcGxpdCgnLycpOwogICAgICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBzcmMucG9wKCk7CiAgICAgICAgICAgICAgICAgICAgc3ViUGF0aCA9IHNyYy5sZW5ndGggPyBzcmMuam9pbignLycpICArICcvJyA6ICcuLyc7CgogICAgICAgICAgICAgICAgICAgIGNmZy5iYXNlVXJsID0gc3ViUGF0aDsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL1N0cmlwIG9mZiBhbnkgdHJhaWxpbmcgLmpzIHNpbmNlIG1haW5TY3JpcHQgaXMgbm93CiAgICAgICAgICAgICAgICAvL2xpa2UgYSBtb2R1bGUgbmFtZS4KICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBtYWluU2NyaXB0LnJlcGxhY2UoanNTdWZmaXhSZWdFeHAsICcnKTsKCiAgICAgICAgICAgICAgICAvL0lmIG1haW5TY3JpcHQgaXMgc3RpbGwgYSBwYXRoLCBmYWxsIGJhY2sgdG8gZGF0YU1haW4KICAgICAgICAgICAgICAgIGlmIChyZXEuanNFeHRSZWdFeHAudGVzdChtYWluU2NyaXB0KSkgewogICAgICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBkYXRhTWFpbjsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgIC
AgICAgICAvL1B1dCB0aGUgZGF0YS1tYWluIHNjcmlwdCBpbiB0aGUgZmlsZXMgdG8gbG9hZC4KICAgICAgICAgICAgICAgIGNmZy5kZXBzID0gY2ZnLmRlcHMgPyBjZmcuZGVwcy5jb25jYXQobWFpblNjcmlwdCkgOiBbbWFpblNjcmlwdF07CgogICAgICAgICAgICAgICAgcmV0dXJuIHRydWU7CiAgICAgICAgICAgIH0KICAgICAgICB9KTsKICAgIH0KCiAgICAvKioKICAgICAqIFRoZSBmdW5jdGlvbiB0aGF0IGhhbmRsZXMgZGVmaW5pdGlvbnMgb2YgbW9kdWxlcy4gRGlmZmVycyBmcm9tCiAgICAgKiByZXF1aXJlKCkgaW4gdGhhdCBhIHN0cmluZyBmb3IgdGhlIG1vZHVsZSBzaG91bGQgYmUgdGhlIGZpcnN0IGFyZ3VtZW50LAogICAgICogYW5kIHRoZSBmdW5jdGlvbiB0byBleGVjdXRlIGFmdGVyIGRlcGVuZGVuY2llcyBhcmUgbG9hZGVkIHNob3VsZAogICAgICogcmV0dXJuIGEgdmFsdWUgdG8gZGVmaW5lIHRoZSBtb2R1bGUgY29ycmVzcG9uZGluZyB0byB0aGUgZmlyc3QgYXJndW1lbnQncwogICAgICogbmFtZS4KICAgICAqLwogICAgZGVmaW5lID0gZnVuY3Rpb24gKG5hbWUsIGRlcHMsIGNhbGxiYWNrKSB7CiAgICAgICAgdmFyIG5vZGUsIGNvbnRleHQ7CgogICAgICAgIC8vQWxsb3cgZm9yIGFub255bW91cyBtb2R1bGVzCiAgICAgICAgaWYgKHR5cGVvZiBuYW1lICE9PSAnc3RyaW5nJykgewogICAgICAgICAgICAvL0FkanVzdCBhcmdzIGFwcHJvcHJpYXRlbHkKICAgICAgICAgICAgY2FsbGJhY2sgPSBkZXBzOwogICAgICAgICAgICBkZXBzID0gbmFtZTsKICAgICAgICAgICAgbmFtZSA9IG51bGw7CiAgICAgICAgfQoKICAgICAgICAvL1RoaXMgbW9kdWxlIG1heSBub3QgaGF2ZSBkZXBlbmRlbmNpZXMKICAgICAgICBpZiAoIWlzQXJyYXkoZGVwcykpIHsKICAgICAgICAgICAgY2FsbGJhY2sgPSBkZXBzOwogICAgICAgICAgICBkZXBzID0gbnVsbDsKICAgICAgICB9CgogICAgICAgIC8vSWYgbm8gbmFtZSwgYW5kIGNhbGxiYWNrIGlzIGEgZnVuY3Rpb24sIHRoZW4gZmlndXJlIG91dCBpZiBpdCBhCiAgICAgICAgLy9Db21tb25KUyB0aGluZyB3aXRoIGRlcGVuZGVuY2llcy4KICAgICAgICBpZiAoIWRlcHMgJiYgaXNGdW5jdGlvbihjYWxsYmFjaykpIHsKICAgICAgICAgICAgZGVwcyA9IFtdOwogICAgICAgICAgICAvL1JlbW92ZSBjb21tZW50cyBmcm9tIHRoZSBjYWxsYmFjayBzdHJpbmcsCiAgICAgICAgICAgIC8vbG9vayBmb3IgcmVxdWlyZSBjYWxscywgYW5kIHB1bGwgdGhlbSBpbnRvIHRoZSBkZXBlbmRlbmNpZXMsCiAgICAgICAgICAgIC8vYnV0IG9ubHkgaWYgdGhlcmUgYXJlIGZ1bmN0aW9uIGFyZ3MuCiAgICAgICAgICAgIGlmIChjYWxsYmFjay5sZW5ndGgpIHsKICAgICAgICAgICAgICAgIGNhbGxiYWNrCiAgICAgICAgICAgICAgICAgICAgLnRvU3RyaW5nKCkKICAgICAgICAgICAgICAgICAgICAucmVwbGFjZShjb21tZW50UmVnRXhwLCAnJykKICAgICAgICAgICAgICAgICAgICAucmVwbGFjZShjanNSZX
F1aXJlUmVnRXhwLCBmdW5jdGlvbiAobWF0Y2gsIGRlcCkgewogICAgICAgICAgICAgICAgICAgICAgICBkZXBzLnB1c2goZGVwKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAvL01heSBiZSBhIENvbW1vbkpTIHRoaW5nIGV2ZW4gd2l0aG91dCByZXF1aXJlIGNhbGxzLCBidXQgc3RpbGwKICAgICAgICAgICAgICAgIC8vY291bGQgdXNlIGV4cG9ydHMsIGFuZCBtb2R1bGUuIEF2b2lkIGRvaW5nIGV4cG9ydHMgYW5kIG1vZHVsZQogICAgICAgICAgICAgICAgLy93b3JrIHRob3VnaCBpZiBpdCBqdXN0IG5lZWRzIHJlcXVpcmUuCiAgICAgICAgICAgICAgICAvL1JFUVVJUkVTIHRoZSBmdW5jdGlvbiB0byBleHBlY3QgdGhlIENvbW1vbkpTIHZhcmlhYmxlcyBpbiB0aGUKICAgICAgICAgICAgICAgIC8vb3JkZXIgbGlzdGVkIGJlbG93LgogICAgICAgICAgICAgICAgZGVwcyA9IChjYWxsYmFjay5sZW5ndGggPT09IDEgPyBbJ3JlcXVpcmUnXSA6IFsncmVxdWlyZScsICdleHBvcnRzJywgJ21vZHVsZSddKS5jb25jYXQoZGVwcyk7CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIC8vSWYgaW4gSUUgNi04IGFuZCBoaXQgYW4gYW5vbnltb3VzIGRlZmluZSgpIGNhbGwsIGRvIHRoZSBpbnRlcmFjdGl2ZQogICAgICAgIC8vd29yay4KICAgICAgICBpZiAodXNlSW50ZXJhY3RpdmUpIHsKICAgICAgICAgICAgbm9kZSA9IGN1cnJlbnRseUFkZGluZ1NjcmlwdCB8fCBnZXRJbnRlcmFjdGl2ZVNjcmlwdCgpOwogICAgICAgICAgICBpZiAobm9kZSkgewogICAgICAgICAgICAgICAgaWYgKCFuYW1lKSB7CiAgICAgICAgICAgICAgICAgICAgbmFtZSA9IG5vZGUuZ2V0QXR0cmlidXRlKCdkYXRhLXJlcXVpcmVtb2R1bGUnKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNvbnRleHQgPSBjb250ZXh0c1tub2RlLmdldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlY29udGV4dCcpXTsKICAgICAgICAgICAgfQogICAgICAgIH0KCiAgICAgICAgLy9BbHdheXMgc2F2ZSBvZmYgZXZhbHVhdGluZyB0aGUgZGVmIGNhbGwgdW50aWwgdGhlIHNjcmlwdCBvbmxvYWQgaGFuZGxlci4KICAgICAgICAvL1RoaXMgYWxsb3dzIG11bHRpcGxlIG1vZHVsZXMgdG8gYmUgaW4gYSBmaWxlIHdpdGhvdXQgcHJlbWF0dXJlbHkKICAgICAgICAvL3RyYWNpbmcgZGVwZW5kZW5jaWVzLCBhbmQgYWxsb3dzIGZvciBhbm9ueW1vdXMgbW9kdWxlIHN1cHBvcnQsCiAgICAgICAgLy93aGVyZSB0aGUgbW9kdWxlIG5hbWUgaXMgbm90IGtub3duIHVudGlsIHRoZSBzY3JpcHQgb25sb2FkIGV2ZW50CiAgICAgICAgLy9vY2N1cnMuIElmIG5vIGNvbnRleHQsIHVzZSB0aGUgZ2xvYmFsIHF1ZXVlLCBhbmQgZ2V0IGl0IHByb2Nlc3NlZAogICAgICAgIC8vaW4gdGhlIG9uc2NyaXB0IGxvYWQgY2FsbGJhY2suCiAgICAgICAgaWYgKGNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZS5wdXNoKFtuYW1lLCBkZXBzLCBjYW
xsYmFja10pOwogICAgICAgICAgICBjb250ZXh0LmRlZlF1ZXVlTWFwW25hbWVdID0gdHJ1ZTsKICAgICAgICB9IGVsc2UgewogICAgICAgICAgICBnbG9iYWxEZWZRdWV1ZS5wdXNoKFtuYW1lLCBkZXBzLCBjYWxsYmFja10pOwogICAgICAgIH0KICAgIH07CgogICAgZGVmaW5lLmFtZCA9IHsKICAgICAgICBqUXVlcnk6IHRydWUKICAgIH07CgogICAgLyoqCiAgICAgKiBFeGVjdXRlcyB0aGUgdGV4dC4gTm9ybWFsbHkganVzdCB1c2VzIGV2YWwsIGJ1dCBjYW4gYmUgbW9kaWZpZWQKICAgICAqIHRvIHVzZSBhIGJldHRlciwgZW52aXJvbm1lbnQtc3BlY2lmaWMgY2FsbC4gT25seSB1c2VkIGZvciB0cmFuc3BpbGluZwogICAgICogbG9hZGVyIHBsdWdpbnMsIG5vdCBmb3IgcGxhaW4gSlMgbW9kdWxlcy4KICAgICAqIEBwYXJhbSB7U3RyaW5nfSB0ZXh0IHRoZSB0ZXh0IHRvIGV4ZWN1dGUvZXZhbHVhdGUuCiAgICAgKi8KICAgIHJlcS5leGVjID0gZnVuY3Rpb24gKHRleHQpIHsKICAgICAgICAvKmpzbGludCBldmlsOiB0cnVlICovCiAgICAgICAgcmV0dXJuIGV2YWwodGV4dCk7CiAgICB9OwoKICAgIC8vU2V0IHVwIHdpdGggY29uZmlnIGluZm8uCiAgICByZXEoY2ZnKTsKfSh0aGlzKSk7Cg==", - "headers": [ - [ - "content-type", - "application/javascript" - ] - ], - "ok": true, - "status": 200, - "status_text": "" - } - } - }, - "colab_type": "code", - "id": "k0j5zzpAPSFn", - "outputId": "cb5b1d88-054b-413e-d303-428e63bce694" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \u003cscript src=\"/static/components/requirejs/require.js\"\u003e\u003c/script\u003e\n", - " \u003cscript\u003e\n", - " requirejs.config({\n", - " paths: {\n", - " base: '/static/base',\n", - " \"d3\": \"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min\",\n", - " jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n", - " },\n", - " });\n", - " \u003c/script\u003e\n", - " " - ], - "text/plain": [ - "\u003cIPython.core.display.HTML object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \u003cspan style=\"user-select:none\"\u003e\n", - " Layer: \u003cselect id=\"layer\"\u003e\u003c/select\u003e\n", - " Attention: \u003cselect id=\"att_type\"\u003e\n", - " \u003coption value=\"all\"\u003eAll\u003c/option\u003e\n", - " \u003coption 
value=\"inp_inp\"\u003eInput - Input\u003c/option\u003e\n", - " \u003coption value=\"inp_out\"\u003eInput - Output\u003c/option\u003e\n", - " \u003coption value=\"out_out\"\u003eOutput - Output\u003c/option\u003e\n", - " \u003c/select\u003e\n", - " \u003c/span\u003e\n", - " \u003cdiv id='vis'\u003e\u003c/div\u003e\n" - ], - "text/plain": [ - "\u003cIPython.core.display.HTML object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window.attention = {\"all\": {\"att\": [[[[0.05334341153502464, 0.025828205049037933, 0.062369391322135925, 0.043252814561128616, 0.4045393764972687, 0.06697215139865875, 0.09001608937978745, 0.14983074367046356, 0.10384786874055862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11816457659006119, 0.03106253407895565, 0.01979171112179756, 0.16624291241168976, 0.3321376442909241, 0.020051123574376106, 0.08730963617563248, 0.18211135268211365, 0.04312858730554581, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05936884880065918, 0.02174757793545723, 0.016160180792212486, 0.010601435787975788, 0.43925121426582336, 0.03876951336860657, 0.19815810024738312, 0.07065817713737488, 0.14528508484363556, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15478025376796722, 0.16446512937545776, 0.0578744001686573, 0.21637752652168274, 0.03835854306817055, 0.09130414575338364, 0.11191156506538391, 0.08360221982002258, 0.08132638782262802, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2183060646057129, 0.1704275906085968, 0.0827711746096611, 0.1202380359172821, 0.05203341320157051, 0.05958092212677002, 0.12280035018920898, 0.09366822242736816, 0.08017415553331375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05084313824772835, 0.026207493618130684, 0.13631564378738403, 0.012270472943782806, 0.16236551105976105, 0.02548854425549507, 0.03909383341670036, 0.03172134608030319, 0.5156941413879395, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03615221381187439, 0.04799472168087959, 0.04255519434809685, 0.04762651398777962, 0.5117892622947693, 0.016304347664117813, 0.005770198069512844, 0.10897397249937057, 0.18283340334892273, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03243544325232506, 0.025252558290958405, 0.11733424663543701, 0.0250592939555645, 0.20289097726345062, 0.08240236341953278, 0.18285907804965973, 0.011341268196702003, 0.3204246759414673, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22355543076992035, 0.1260528564453125, 0.03741241991519928, 0.16813479363918304, 0.09858733415603638, 0.035831648856401443, 0.16361697018146515, 0.07236126810312271, 0.07444748282432556, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08996112644672394, 0.0921943336725235, 0.22672457993030548, 0.12702998518943787, 0.05907799303531647, 0.10712798684835434, 0.16789256036281586, 0.055181413888931274, 0.07481010258197784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9198169708251953, 0.0801829993724823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9412446618080139, 0.05875528231263161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8846490979194641, 0.10308036208152771, 0.012270578183233738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7461972832679749, 0.18569768965244293, 0.06810508668422699, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9307316541671753, 0.03309628367424011, 0.027538668364286423, 0.008633385412395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4299372434616089, 0.16845084726810455, 0.2029547393321991, 0.19865721464157104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9335180521011353, 0.020782457664608955, 0.008113296702504158, 0.029529055580496788, 0.008057110011577606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215166807174683, 0.16121163964271545, 0.19463112950325012, 0.09347883611917496, 0.029161658138036728, 0.0, 0.0, 0.0, 0.0, 0.0], [0.923790454864502, 
0.01269624661654234, 0.004588128533214331, 0.020286502316594124, 0.018672045320272446, 0.019966628402471542, 0.0, 0.0, 0.0, 0.0, 0.26405569911003113, 0.04358615726232529, 0.10687251389026642, 0.1710020899772644, 0.4105237126350403, 0.0039598336443305016, 0.0, 0.0, 0.0, 0.0], [0.5214514136314392, 0.051599469035863876, 0.007387364283204079, 0.04305899888277054, 0.0632161945104599, 0.07775087654590607, 0.2355356514453888, 0.0, 0.0, 0.0, 0.29189321398735046, 0.19170531630516052, 0.11295431852340698, 0.08274418860673904, 0.12850242853164673, 0.09739833325147629, 0.09480219334363937, 0.0, 0.0, 0.0], [0.9122877717018127, 0.007671441417187452, 0.0012418286642059684, 0.005250561982393265, 0.001960531808435917, 0.032091617584228516, 0.03012256510555744, 0.009373520500957966, 0.0, 0.0, 0.3496137857437134, 0.03085259348154068, 0.0195528082549572, 0.45414459705352783, 0.09152030944824219, 0.008845902979373932, 0.02992299199104309, 0.01554702315479517, 0.0, 0.0], [0.012450892478227615, 0.0001350480888504535, 0.0001820741599658504, 0.0018266986589878798, 0.00022605709091294557, 0.0032795630395412445, 0.005876350682228804, 0.012136856094002724, 0.9638864398002625, 0.0, 0.4675538241863251, 0.03941410034894943, 0.05400091037154198, 0.17985978722572327, 0.20104949176311493, 0.030323797836899757, 0.010615098290145397, 0.015154700726270676, 0.002028239192441106, 0.0], [0.907938539981842, 0.003707215888425708, 0.003004483412951231, 0.0008324749651364982, 0.0015859504928812385, 0.008079104125499725, 0.010460118763148785, 0.005838368553668261, 0.038938846439123154, 0.019614921882748604, 0.053565241396427155, 0.029699191451072693, 0.0156599972397089, 0.016939852386713028, 0.04015244543552399, 0.21933501958847046, 0.1449035257101059, 0.4037321209907532, 0.019583676010370255, 0.056428998708724976]], [[0.040477100759744644, 0.20988762378692627, 0.4869004786014557, 0.03505674749612808, 0.0558856800198555, 0.025423096492886543, 0.12231241166591644, 0.007062799762934446, 0.016993943601846695, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8996549844741821, 0.02599872276186943, 0.049097247421741486, 0.0040262676775455475, 0.0039152717217803, 0.0049644638784229755, 0.010553319938480854, 0.001352570834569633, 0.0004369009402580559, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.33065715432167053, 0.2687782049179077, 0.03312753140926361, 0.22958999872207642, 0.01851547136902809, 0.046473052352666855, 0.053183481097221375, 0.007113412953913212, 0.012561764568090439, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1589452475309372, 0.47470128536224365, 0.12878550589084625, 0.14158962666988373, 0.04442765936255455, 0.022274963557720184, 0.013780632056295872, 0.0024951419327408075, 0.012999956496059895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2559169828891754, 0.033451542258262634, 0.15095548331737518, 0.024318046867847443, 0.10824166238307953, 0.03234097361564636, 0.36475417017936707, 0.012823408469557762, 0.017197895795106888, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021462664008140564, 0.010474847629666328, 0.007213775999844074, 0.02227940410375595, 0.21737068891525269, 0.4960675537586212, 0.014628118835389614, 0.20502059161663055, 0.005482145119458437, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06734316051006317, 0.09532227367162704, 0.1127309575676918, 0.009542002342641354, 0.0678786113858223, 0.12933993339538574, 0.03809814900159836, 0.44453269243240356, 0.035212237387895584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10458365827798843, 0.02846018597483635, 0.029760979115962982, 0.014774680137634277, 0.022077379748225212, 0.1553817093372345, 0.3539015054702759, 0.19523507356643677, 0.09582491964101791, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021077070385217667, 0.010932122357189655, 0.05088815093040466, 0.028641115874052048, 0.0881260335445404, 0.12014731019735336, 0.3900885581970215, 0.09544514119625092, 0.1946544349193573, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02552945166826248, 0.05594164505600929, 0.045791901648044586, 0.093170166015625, 0.03584437444806099, 0.0969511866569519, 0.18585819005966187, 0.17433671653270721, 0.28657644987106323, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4050312936306, 0.5949686765670776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5249735116958618, 0.4750264883041382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2333158701658249, 0.39531010389328003, 0.37137407064437866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3563348054885864, 0.5701623558998108, 0.07350286096334457, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.52278733253479, 0.11893566697835922, 0.28584957122802734, 0.07242746651172638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3398579955101013, 0.23167477548122406, 0.1957632154226303, 0.23270410299301147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23179638385772705, 0.09258762001991272, 0.103512242436409, 0.19472002983093262, 0.37738385796546936, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4351256191730499, 0.09737284481525421, 0.08845506608486176, 0.06574707478284836, 0.31329941749572754, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3839746117591858, 0.05338669568300247, 0.09416119009256363, 0.09689370542764664, 0.24871283769607544, 0.12287086993455887, 0.0, 0.0, 0.0, 0.0, 0.360861599445343, 0.02136792428791523, 0.005633710417896509, 0.009215844795107841, 0.15762653946876526, 0.4452943205833435, 0.0, 0.0, 0.0, 0.0], [0.5838866233825684, 0.02439245954155922, 0.042716383934020996, 0.03342103213071823, 0.08018141984939575, 0.15234005451202393, 0.08306187391281128, 0.0, 0.0, 0.0, 0.009015758521854877, 0.0013937305193394423, 0.00017763266805559397, 0.00016997012426145375, 0.010879353620111942, 0.0024589570239186287, 0.9759047627449036, 0.0, 0.0, 0.0], [0.639571487903595, 0.016348807141184807, 0.038869310170412064, 0.02800355665385723, 0.0377902127802372, 0.0529697984457016, 
0.07620508968830109, 0.11024164408445358, 0.0, 0.0, 0.014776602387428284, 0.0001805058855097741, 1.6896785382414237e-05, 0.0003442507586441934, 0.006220621056854725, 0.0012393802171573043, 0.9433164596557617, 0.033905431628227234, 0.0, 0.0], [0.5836893320083618, 0.011862898245453835, 0.02550557814538479, 0.009363977238535881, 0.0196645837277174, 0.018125057220458984, 0.07040998339653015, 0.2077602595090866, 0.053618304431438446, 0.0, 0.005810329224914312, 0.002043980173766613, 0.0003433740057516843, 0.001522325212135911, 0.0030212807469069958, 0.00817712489515543, 0.5456522107124329, 0.10564129799604416, 0.32778817415237427, 0.0], [0.49946048855781555, 0.04904361814260483, 0.04135226085782051, 0.015084759332239628, 0.018269173800945282, 0.020069265738129616, 0.05080949887633324, 0.09452320635318756, 0.06869905441999435, 0.14268863201141357, 0.3754594326019287, 0.030579065904021263, 0.028458155691623688, 0.035943739116191864, 0.28040432929992676, 0.0202159583568573, 0.0396210215985775, 0.05075624957680702, 0.13473623991012573, 0.0038258912973105907]], [[0.18220090866088867, 0.25508272647857666, 0.2721964120864868, 0.04886331781744957, 0.010257811285555363, 0.07344724237918854, 0.08866558223962784, 0.037977367639541626, 0.0313086174428463, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5722172260284424, 0.09567929804325104, 0.1448327898979187, 0.033306267112493515, 0.0031244128476828337, 0.020944159477949142, 0.012691132724285126, 0.061001092195510864, 0.05620381608605385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049244701862335205, 0.5266616344451904, 0.27518483996391296, 0.09334208071231842, 0.005858665332198143, 0.005467486567795277, 0.02565312758088112, 0.005746132228523493, 0.012841282412409782, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13445906341075897, 0.13356590270996094, 0.6041688919067383, 0.01878039538860321, 0.06342840194702148, 0.03677675500512123, 0.008389262482523918, 0.0002739423362072557, 
0.00015757972141727805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03273050859570503, 0.0697193592786789, 0.19719526171684265, 0.41500693559646606, 0.13721567392349243, 0.05743291601538658, 0.06517775356769562, 0.010865128599107265, 0.014656689018011093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031571000814437866, 0.014337136410176754, 0.06860436499118805, 0.09357307106256485, 0.10011686384677887, 0.07827721536159515, 0.5866308212280273, 0.011440092697739601, 0.015449290163815022, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006158333271741867, 0.001533387927338481, 0.05427416041493416, 0.005477452650666237, 0.02694696933031082, 0.8134917616844177, 0.02643686905503273, 0.050265438854694366, 0.015415593050420284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008847472257912159, 0.0066053420305252075, 0.036443497985601425, 0.021455924957990646, 0.019254589453339577, 0.11543811857700348, 0.1138116791844368, 0.20307059586048126, 0.4750728905200958, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017603449523448944, 0.008448019623756409, 0.004260394722223282, 0.006066101603209972, 0.013470137491822243, 0.01876576989889145, 0.16350960731506348, 0.1980665624141693, 0.5698099732398987, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10490093380212784, 0.014168650843203068, 0.0247807614505291, 0.018330294638872147, 0.009348674677312374, 0.02287398651242256, 0.032268356531858444, 0.10571902245283127, 0.6676092147827148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9956012964248657, 0.00439875153824687, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9630448818206787, 0.036955028772354126, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8920916318893433, 0.017498359084129333, 0.09041006118059158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8940342664718628, 0.015322646126151085, 
0.09064316004514694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8103601336479187, 0.011479738168418407, 0.14884205162525177, 0.029318034648895264, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4866876006126404, 0.028273453935980797, 0.4569007158279419, 0.028138065710663795, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9073429107666016, 0.017702236771583557, 0.0008831396116875112, 0.017153160646557808, 0.05691858008503914, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7252220511436462, 0.10817205905914307, 0.07890959084033966, 0.017715180292725563, 0.06998112797737122, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7007134556770325, 0.00013011474220547825, 0.0017889889422804117, 0.00429273396730423, 0.20973503589630127, 0.08333952724933624, 0.0, 0.0, 0.0, 0.0, 0.8598019480705261, 0.012843498960137367, 0.014502299018204212, 0.004056263715028763, 0.10580158233642578, 0.0029942472465336323, 0.0, 0.0, 0.0, 0.0], [0.8020992279052734, 0.0005838978104293346, 0.0002877263759728521, 0.000665249943267554, 0.00924165453761816, 0.10947777330875397, 0.07764454185962677, 0.0, 0.0, 0.0, 0.8686293363571167, 0.024889284744858742, 0.013860221020877361, 0.00703870365396142, 0.07120370119810104, 0.003939351066946983, 0.010439489968121052, 0.0, 0.0, 0.0], [0.936653733253479, 0.00026242269086651504, 0.0004762547614518553, 0.000683068297803402, 0.0005867508007213473, 0.008624686859548092, 0.044821251183748245, 0.00789186917245388, 0.0, 0.0, 0.8572709560394287, 0.018014011904597282, 0.008267350494861603, 0.0022140766959637403, 0.1038530021905899, 0.004275611136108637, 0.0009780752006918192, 0.005126776173710823, 0.0, 0.0], [0.638530433177948, 0.00012756754586007446, 2.6267471184837632e-05, 0.035790614783763885, 0.00038457714254036546, 0.0026843701489269733, 0.0740678533911705, 0.21536435186862946, 0.03302408382296562, 0.0, 0.35013046860694885, 0.0037752145435661077, 0.0071558705531060696, 0.01608894392848015, 0.6097922325134277, 0.002463925164192915, 0.0005387101555243134, 0.005540961865335703, 0.004513624589890242, 0.0], [0.9069857597351074, 
0.0010905838571488857, 0.0003166680980939418, 0.0021527763456106186, 0.00019805191550403833, 0.0004849489778280258, 0.025774035602808, 0.02642407827079296, 0.01662513054907322, 0.01994791068136692, 0.1888049989938736, 0.12293454259634018, 0.5947631597518921, 0.009457849897444248, 0.07291270792484283, 0.008950368501245975, 0.0004109511792194098, 0.000914009811822325, 0.0006959570455364883, 0.00015547229850199074]], [[0.2071455419063568, 0.637531578540802, 0.06835082173347473, 0.011966697871685028, 0.0017193991225212812, 0.04911382868885994, 0.009478496387600899, 0.008040529675781727, 0.00665308628231287, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07411027699708939, 0.15093472599983215, 0.2656005620956421, 0.05758262053132057, 0.05194409564137459, 0.23625947535037994, 0.019166678190231323, 0.04010465368628502, 0.10429693013429642, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1540999412536621, 0.10598444193601608, 0.22474077343940735, 0.32441702485084534, 0.1116243302822113, 0.054135363548994064, 0.008848286233842373, 0.004088098648935556, 0.012061581946909428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019440434873104095, 0.00560638727620244, 0.0035774046555161476, 0.0888679027557373, 0.7120485901832581, 0.14891275763511658, 0.011600993573665619, 0.008666431531310081, 0.0012791723711416125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08580154180526733, 0.02444172091782093, 0.08060747385025024, 0.05198557302355766, 0.2700504660606384, 0.34216371178627014, 0.11280739307403564, 0.006445358972996473, 0.02569655328989029, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0424385629594326, 0.029667967930436134, 0.006252861116081476, 0.020168066024780273, 0.03000665083527565, 0.2812231779098511, 0.49279165267944336, 0.09351769089698792, 0.003933228086680174, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006467411294579506, 0.0076894015073776245, 0.008325580507516861, 0.0010907554533332586, 
0.01040297094732523, 0.19462232291698456, 0.013263629749417305, 0.24681615829467773, 0.5113216042518616, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028696376830339432, 0.014982450753450394, 0.011884906329214573, 0.0011242942418903112, 0.01692844182252884, 0.12885364890098572, 0.028225399553775787, 0.6451764106750488, 0.12412811070680618, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16117365658283234, 0.06794824451208115, 0.06173194944858551, 0.00451233983039856, 0.05306624248623848, 0.0510348416864872, 0.04402391240000725, 0.12432018667459488, 0.4321887195110321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1690559983253479, 0.043453093618154526, 0.036818861961364746, 0.017293656244874, 0.11775903403759003, 0.07970321178436279, 0.043801818042993546, 0.06849095970392227, 0.4236232340335846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9964158535003662, 0.0035840808413922787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.91131192445755, 0.08868805319070816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.603236198425293, 0.29069802165031433, 0.10606581717729568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.786292314529419, 0.09286607056856155, 0.1208416074514389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7401933073997498, 0.005742713809013367, 0.18690980970859528, 0.06715414673089981, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1722075194120407, 0.10747934877872467, 0.1462225317955017, 0.5740904808044434, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9087624549865723, 0.0078224902972579, 0.003505129599943757, 0.0673881471157074, 0.012521738186478615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1893281787633896, 0.1733204573392868, 0.06838839501142502, 0.47577211260795593, 0.09319086372852325, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7394620180130005, 0.0234938096255064, 0.009907918982207775, 0.01616108976304531, 0.1237591803073883, 0.08721596747636795, 0.0, 0.0, 0.0, 0.0, 
0.08935888856649399, 0.012517428956925869, 0.017112966626882553, 0.08479276299476624, 0.7640082240104675, 0.03220977261662483, 0.0, 0.0, 0.0, 0.0], [0.9526587724685669, 0.007287254091352224, 0.0013716809917241335, 0.0023222684394568205, 0.007607423700392246, 0.009167732670903206, 0.01958492584526539, 0.0, 0.0, 0.0, 0.824190616607666, 0.008810147643089294, 0.002143737394362688, 0.002297793049365282, 0.11996792256832123, 0.005709697026759386, 0.036880046129226685, 0.0, 0.0, 0.0], [0.9270981550216675, 0.004809631034731865, 0.0030887839384377003, 0.005205564666539431, 0.018441975116729736, 0.006030889227986336, 0.03003735840320587, 0.0052877976559102535, 0.0, 0.0, 0.1513449102640152, 0.015725232660770416, 0.02784004621207714, 0.01800909824669361, 0.6534391641616821, 0.016422629356384277, 0.09054289758205414, 0.026676079258322716, 0.0, 0.0], [0.603268563747406, 0.009098237380385399, 0.00021995518181938678, 0.07179546356201172, 0.0017328117974102497, 0.01055157370865345, 0.020978767424821854, 0.2736198902130127, 0.008734744042158127, 0.0, 0.1625923067331314, 0.016224535182118416, 0.06514906883239746, 0.003223034320399165, 0.6737184524536133, 0.014129054732620716, 0.036937959492206573, 0.023035621270537376, 0.004990031942725182, 0.0], [0.6497007608413696, 0.0906025841832161, 0.0100435521453619, 0.007925360463559628, 0.013416239991784096, 0.0018666544929146767, 0.02140365168452263, 0.08128199726343155, 0.04188578948378563, 0.08187359571456909, 0.06836045533418655, 0.01236770860850811, 0.008784784935414791, 0.014186863787472248, 0.09790214896202087, 0.046204064041376114, 0.1703491061925888, 0.1878211945295334, 0.0703599750995636, 0.32366377115249634]], [[0.03085354156792164, 0.12322185933589935, 0.13651973009109497, 0.050716523081064224, 0.2999139726161957, 0.09802427887916565, 0.06620478630065918, 0.0782310962677002, 0.11631430685520172, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06789751350879669, 0.058182138949632645, 0.3129631578922272, 
0.04353875666856766, 0.09142065048217773, 0.10271093249320984, 0.026392055675387383, 0.09630800783634186, 0.2005866914987564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07152411341667175, 0.3454192876815796, 0.11299439519643784, 0.18012462556362152, 0.07151429355144501, 0.052652161568403244, 0.0567985400557518, 0.09459780901670456, 0.014374655671417713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10420235246419907, 0.21845531463623047, 0.19832336902618408, 0.022119704633951187, 0.13572701811790466, 0.07722532749176025, 0.0508468933403492, 0.045597679913043976, 0.14750221371650696, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07030870020389557, 0.10706955939531326, 0.02791348285973072, 0.02260597050189972, 0.12725059688091278, 0.07336997240781784, 0.26662203669548035, 0.16957008838653564, 0.13528966903686523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05156806856393814, 0.04327721148729324, 0.07664787024259567, 0.06931594759225845, 0.1889398992061615, 0.09515503793954849, 0.07227510958909988, 0.2641449272632599, 0.13867592811584473, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02184019424021244, 0.11184182018041611, 0.36672860383987427, 0.013787303119897842, 0.07600502669811249, 0.0389828234910965, 0.040494974702596664, 0.12485849112272263, 0.20546066761016846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013738485053181648, 0.05187288299202919, 0.03463537245988846, 0.03627979755401611, 0.048659998923540115, 0.02440205216407776, 0.07256433367729187, 0.024731382727622986, 0.6931155323982239, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02671198360621929, 0.4013687074184418, 0.01132842618972063, 0.14022575318813324, 0.026275552809238434, 0.08107840269804001, 0.04189194366335869, 0.25432130694389343, 0.0167979933321476, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14228780567646027, 0.07866450399160385, 0.08390624076128006, 0.09396661072969437, 
0.087954580783844, 0.14498625695705414, 0.13517630100250244, 0.1169552430510521, 0.11610251665115356, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9857779741287231, 0.014221975579857826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.961704432964325, 0.038295578211545944, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9197340607643127, 0.07413885742425919, 0.0061270855367183685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.37462106347084045, 0.2157517969608307, 0.40962719917297363, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8673564195632935, 0.016403868794441223, 0.1017053872346878, 0.014534366317093372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.48521965742111206, 0.031020229682326317, 0.3760664165019989, 0.10769358277320862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044595908373594284, 0.010755550116300583, 0.002565854461863637, 0.9345642328262329, 0.007518457714468241, 0.0, 0.0, 0.0, 0.0, 0.0, 0.914044201374054, 0.004715718794614077, 0.006151301320642233, 0.005079128313809633, 0.07000966370105743, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4605148434638977, 0.007289387751370668, 0.009601963683962822, 0.08598940074443817, 0.4091304838657379, 0.027473902329802513, 0.0, 0.0, 0.0, 0.0, 0.060511741787195206, 0.006127620115876198, 0.00728148128837347, 0.013585635460913181, 0.9084653854370117, 0.004028240218758583, 0.0, 0.0, 0.0, 0.0], [0.8714936971664429, 0.002528996206820011, 0.0021269593853503466, 0.0052809687331318855, 0.02593054249882698, 0.07010670751333237, 0.022532090544700623, 0.0, 0.0, 0.0, 0.23348243534564972, 0.03748093172907829, 0.055222347378730774, 0.014132470823824406, 0.27614685893058777, 0.017582375556230545, 0.3659524619579315, 0.0, 0.0, 0.0], [0.507957398891449, 0.003823956474661827, 0.004157013725489378, 0.018131878226995468, 0.06916838884353638, 0.047881923615932465, 0.2798653542995453, 0.06901402771472931, 0.0, 0.0, 0.06461911648511887, 0.003781915409490466, 0.002705940278246999, 
0.016099220141768456, 0.8774597644805908, 0.012668337672948837, 0.0088069261983037, 0.013858767226338387, 0.0, 0.0], [0.4575899839401245, 0.005646431352943182, 0.0004441867640707642, 0.03129462152719498, 0.014414624311029911, 0.0058625745587050915, 0.09207130968570709, 0.34311652183532715, 0.04955975338816643, 0.0, 0.05451222136616707, 0.014412143267691135, 0.00208102585747838, 0.011283651925623417, 0.02552390843629837, 0.02239326573908329, 0.031104939058423042, 0.20777365565299988, 0.630915105342865, 0.0], [0.8105311393737793, 0.0010255038505420089, 0.0001402802881784737, 0.0005781117943115532, 0.00122542935423553, 0.000594198820181191, 0.02804729714989662, 0.01081023644655943, 0.13665232062339783, 0.010395429097115993, 0.5451503992080688, 0.014764615334570408, 0.2503703534603119, 0.037022024393081665, 0.0935375839471817, 0.022694993764162064, 0.0037449353840202093, 0.0053339023143053055, 0.007315538357943296, 0.020065704360604286]], [[0.02165721170604229, 0.018354326486587524, 0.6383510828018188, 0.042513273656368256, 0.10956817120313644, 0.10717540234327316, 0.030344119295477867, 0.015826348215341568, 0.01621006615459919, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4647374749183655, 0.07284841686487198, 0.28081396222114563, 0.014013433828949928, 0.03169411048293114, 0.02214456908404827, 0.058711059391498566, 0.036629818379879, 0.01840737834572792, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07372704148292542, 0.12858736515045166, 0.4501189887523651, 0.054217785596847534, 0.07096204906702042, 0.05748127028346062, 0.06541819125413895, 0.04703349620103836, 0.05245373025536537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04684445261955261, 0.019098779186606407, 0.008431704714894295, 0.0010175607167184353, 0.9129327535629272, 0.004866998642683029, 0.006678053177893162, 8.096762758214027e-05, 4.903498847852461e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08239725232124329, 0.02813413366675377, 
0.16611848771572113, 0.1532817929983139, 0.07408729940652847, 0.10856874287128448, 0.047752734273672104, 0.02563621662557125, 0.31402355432510376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17959792912006378, 0.02262653037905693, 0.10724494606256485, 0.022216446697711945, 0.1862414926290512, 0.14705143868923187, 0.15912717580795288, 0.15293282270431519, 0.02296125516295433, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038375359028577805, 0.0038853511214256287, 0.06201936677098274, 0.005828780122101307, 0.22059503197669983, 0.36631014943122864, 0.020396992564201355, 0.20976856350898743, 0.07282061129808426, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014258276671171188, 0.005652762018144131, 0.025611618533730507, 0.15294744074344635, 0.06760217249393463, 0.2498260736465454, 0.1669282466173172, 0.2265811711549759, 0.09059228003025055, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15833799540996552, 0.1228356659412384, 0.10147804021835327, 0.0284584891051054, 0.27955442667007446, 0.06763719022274017, 0.08874277770519257, 0.1152903363108635, 0.037665050476789474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09844867885112762, 0.0919492095708847, 0.028445947915315628, 0.03726689890027046, 0.035665158182382584, 0.06817072629928589, 0.29930955171585083, 0.09819743037223816, 0.2425464242696762, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8512031435966492, 0.14879685640335083, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9904667735099792, 0.009533224627375603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10041537135839462, 0.8953256011009216, 0.0042589944787323475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9818503260612488, 0.007338901981711388, 0.010810752399265766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6295948624610901, 0.2121732085943222, 0.10306572169065475, 0.055166181176900864, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.9738979935646057, 0.007647394668310881, 0.015154722146689892, 0.0032999368850141764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9503376483917236, 0.007425909396260977, 0.0019253676291555166, 0.025024304166436195, 0.015286784619092941, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6611008644104004, 0.04138284549117088, 0.1119912639260292, 0.0262944046407938, 0.15923058986663818, 0.0, 0.0, 0.0, 0.0, 0.0], [0.24298420548439026, 0.06981680542230606, 0.030552756041288376, 0.020666545256972313, 0.46177101135253906, 0.1742086559534073, 0.0, 0.0, 0.0, 0.0, 0.9380988478660583, 0.005562208592891693, 0.01078465860337019, 0.004562946502119303, 0.033130958676338196, 0.007860423997044563, 0.0, 0.0, 0.0, 0.0], [0.8132306933403015, 0.003601218806579709, 0.01019350253045559, 0.009439423680305481, 0.040081463754177094, 0.07570415735244751, 0.04774952307343483, 0.0, 0.0, 0.0, 0.9377894997596741, 0.003691342193633318, 0.002771170577034354, 0.0017416415503248572, 0.04246653988957405, 0.002464305842295289, 0.009075501933693886, 0.0, 0.0, 0.0], [0.6454712152481079, 0.006356438156217337, 0.006696825381368399, 0.0020169378258287907, 0.11416922509670258, 0.11139311641454697, 0.07912010699510574, 0.03477614000439644, 0.0, 0.0, 0.9083399176597595, 0.005597027484327555, 0.02609928511083126, 0.005710097029805183, 0.017865832895040512, 0.0029857312329113483, 0.002900469582527876, 0.030501706525683403, 0.0, 0.0], [0.22032444179058075, 0.0006508066435344517, 0.006827942095696926, 0.028858821839094162, 0.0022757677361369133, 0.006474251858890057, 0.09447979182004929, 0.6212162375450134, 0.018891895189881325, 0.0, 0.8338009119033813, 0.00436164066195488, 0.006190306507050991, 0.0008050849428400397, 0.015337309800088406, 0.00863864365965128, 0.010715007781982422, 0.1143304780125618, 0.005820483900606632, 0.0], [0.03250038996338844, 0.0005526043241843581, 2.807211239996832e-05, 0.00014761221245862544, 0.00482193985953927, 7.781770545989275e-05, 0.00014718669990543276, 0.0008632297394797206, 
0.959712028503418, 0.0011490467004477978, 0.9085996747016907, 0.00676243519410491, 0.02013525180518627, 0.009278967045247555, 0.02104269526898861, 0.009343095123767853, 0.0009470531367696822, 0.0018253516172990203, 0.003784958738833666, 0.018280424177646637]], [[0.02519470639526844, 0.006357265170663595, 0.14269335567951202, 0.023629529401659966, 0.3124701976776123, 0.13565225899219513, 0.2595662772655487, 0.07959114015102386, 0.014845297671854496, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04550129547715187, 0.011541971005499363, 0.1165909469127655, 0.02512240968644619, 0.01843150518834591, 0.05711649730801582, 0.44489097595214844, 0.033205363899469376, 0.24759893119335175, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13528011739253998, 0.06777236610651016, 0.14429129660129547, 0.04697401076555252, 0.1738385707139969, 0.014099549502134323, 0.38417065143585205, 0.01158357597887516, 0.02199004776775837, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.21356959640979767, 0.1638900637626648, 0.10595463216304779, 0.06925727427005768, 0.167257159948349, 0.04259340837597847, 0.10967854410409927, 0.03570139408111572, 0.09209771454334259, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20140984654426575, 0.04755665361881256, 0.15174560248851776, 0.11619894206523895, 0.21928974986076355, 0.07600340992212296, 0.05828682705760002, 0.10010629147291183, 0.029402663931250572, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.024259669706225395, 0.02116699516773224, 0.21201731264591217, 0.019622934982180595, 0.4893963038921356, 0.021304504945874214, 0.16948339343070984, 0.022949064150452614, 0.01979990489780903, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022248759865760803, 0.01183647196739912, 0.0633181631565094, 0.029095010831952095, 0.07090882211923599, 0.4614315629005432, 0.020150773227214813, 0.18720205128192902, 0.1338084638118744, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.003461656626313925, 0.01603432185947895, 0.009874427691102028, 0.014947548508644104, 0.2953553795814514, 0.3502987027168274, 0.08878874033689499, 0.036094941198825836, 0.18514421582221985, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005101516842842102, 0.022985950112342834, 0.007523353211581707, 0.026773063465952873, 0.01009095273911953, 0.014858697541058064, 0.15149906277656555, 0.028601571917533875, 0.7325656414031982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12995873391628265, 0.07769863307476044, 0.02032659947872162, 0.13720010221004486, 0.011713794432580471, 0.054615918546915054, 0.23920413851737976, 0.13190706074237823, 0.19737498462200165, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9700191020965576, 0.029980869963765144, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.972051739692688, 0.027948210015892982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7072298526763916, 0.2173422873020172, 0.07542789727449417, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7552067041397095, 0.17251533269882202, 0.0722779706120491, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5017270445823669, 0.10517530888319016, 0.32087045907974243, 0.07222715020179749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6455309987068176, 0.23265127837657928, 0.10187581926584244, 0.01994187943637371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39005738496780396, 0.2261916995048523, 0.1838584840297699, 0.10916081070899963, 0.09073163568973541, 0.0, 0.0, 0.0, 0.0, 0.0, 0.470674991607666, 0.26442891359329224, 0.14268451929092407, 0.03363766148686409, 0.08857394009828568, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11122927069664001, 0.04386316239833832, 0.023478534072637558, 0.07375308126211166, 0.5692906379699707, 0.17838534712791443, 0.0, 0.0, 0.0, 0.0, 0.6457618474960327, 0.011289404705166817, 0.008832731284201145, 0.01570025272667408, 0.2588561475276947, 0.059559762477874756, 0.0, 0.0, 0.0, 0.0], 
[0.16762810945510864, 0.030268238857388496, 0.015392551198601723, 0.05242612585425377, 0.21519990265369415, 0.34948840737342834, 0.16959665715694427, 0.0, 0.0, 0.0, 0.4916176497936249, 0.07200384140014648, 0.0701020285487175, 0.019148536026477814, 0.0833231583237648, 0.12199999392032623, 0.14180481433868408, 0.0, 0.0, 0.0], [0.15348000824451447, 0.03554287180304527, 0.008979924954473972, 0.07115276902914047, 0.08698276430368423, 0.24143245816230774, 0.28553345799446106, 0.11689584702253342, 0.0, 0.0, 0.11119699478149414, 0.002801541704684496, 0.0021932011004537344, 0.0016493132570758462, 0.06827285885810852, 0.22499483823776245, 0.5049597024917603, 0.08393163233995438, 0.0, 0.0], [0.09456975758075714, 0.010759694501757622, 0.0067994119599461555, 0.01042863354086876, 0.05627141892910004, 0.11228546500205994, 0.14361944794654846, 0.3204572796821594, 0.2448090761899948, 0.0, 0.13208742439746857, 0.0035411729477345943, 0.0015305017586797476, 0.002489483682438731, 0.06612236052751541, 0.213859423995018, 0.5324232578277588, 0.03503565117716789, 0.012910734862089157, 0.0], [0.057867951691150665, 0.02229062095284462, 0.016399098560214043, 0.02521427348256111, 0.047808028757572174, 0.03428687900304794, 0.05170976370573044, 0.19979508221149445, 0.41991233825683594, 0.12471600621938705, 0.20209012925624847, 0.05223073810338974, 0.03088257648050785, 0.036374326795339584, 0.014660456217825413, 0.03045688569545746, 0.03597142919898033, 0.16862399876117706, 0.022359324619174004, 0.40635016560554504]], [[0.21207179129123688, 0.11920439451932907, 0.4251355528831482, 0.014464439824223518, 0.20776884257793427, 0.01428140513598919, 0.0027938869316130877, 0.001743048895150423, 0.002536489861086011, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.046175818890333176, 0.026793524622917175, 0.8552185297012329, 0.04517081379890442, 0.010388500988483429, 0.004191457759588957, 0.0036751439329236746, 0.0013485046802088618, 0.007037981878966093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0], [0.013186579570174217, 0.020899420604109764, 0.6900137662887573, 0.0480119027197361, 0.15360434353351593, 0.02344118244946003, 0.03952033817768097, 0.0038994532078504562, 0.007422822527587414, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006273405160754919, 0.00015674144378863275, 0.000751359446439892, 0.00447711581364274, 0.9859057664871216, 0.002212332095950842, 0.00014360185014083982, 4.957199053023942e-05, 2.9913859179941937e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.001047183177433908, 0.0003636489564087242, 0.009283728897571564, 0.016805388033390045, 0.42387446761131287, 0.4776095747947693, 0.06253702938556671, 0.005590841174125671, 0.002888289513066411, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0018647151300683618, 0.0002549054042901844, 2.6050107408082113e-05, 2.586200753285084e-05, 0.0024472770746797323, 0.006814199965447187, 0.9776560664176941, 0.010138182900846004, 0.000773087958805263, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.047241877764463425, 0.006076885852962732, 0.04534892365336418, 0.00081661093281582, 0.087706059217453, 0.41394293308258057, 0.21876952052116394, 0.17005810141563416, 0.0100388890132308, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0019138919888064265, 0.006189406383782625, 0.010115097276866436, 8.508542669005692e-05, 0.008424345403909683, 0.003492203773930669, 0.13495568931102753, 0.4890870749950409, 0.34573695063591003, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016032341867685318, 0.005025702994316816, 0.009520799852907658, 0.0008855267078615725, 0.026489384472370148, 0.0020503124687820673, 0.032939448952674866, 0.09461060166358948, 0.8124459385871887, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25683313608169556, 0.02960006147623062, 0.11211041361093521, 0.09736908972263336, 0.17546677589416504, 0.032068025320768356, 0.017857572063803673, 0.025635067373514175, 0.25305992364883423, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9535994529724121, 0.04640045389533043, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9218347668647766, 0.0781652107834816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8665578961372375, 0.09402694553136826, 0.03941517323255539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4189925193786621, 0.4865715503692627, 0.09443587809801102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8201385140419006, 0.07587680220603943, 0.05075912922620773, 0.053225547075271606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.48251789808273315, 0.34758540987968445, 0.13321316242218018, 0.036683470010757446, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6245242953300476, 0.093341164290905, 0.11281723529100418, 0.1092497780919075, 0.06006752699613571, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8504839539527893, 0.033341050148010254, 0.053517427295446396, 0.012789242900907993, 0.049868300557136536, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5755861401557922, 0.0864969864487648, 0.10001320391893387, 0.12654373049736023, 0.06871193647384644, 0.04264802858233452, 0.0, 0.0, 0.0, 0.0, 0.4515743553638458, 0.03267433121800423, 0.019386781379580498, 0.024256065487861633, 0.17900733649730682, 0.29310107231140137, 0.0, 0.0, 0.0, 0.0], [0.6500274538993835, 0.06470640748739243, 0.047299426048994064, 0.08855419605970383, 0.06197808310389519, 0.04487667977809906, 0.04255769029259682, 0.0, 0.0, 0.0, 0.5910289883613586, 0.0027754076290875673, 0.004533650353550911, 0.0023315453436225653, 0.08002334088087082, 0.06913208961486816, 0.2501751184463501, 0.0, 0.0, 0.0], [0.5771223902702332, 0.0491044707596302, 0.09411156177520752, 0.06903567165136337, 0.04109871760010719, 0.06523709744215012, 0.06637011468410492, 0.03792000934481621, 0.0, 0.0, 0.1626552939414978, 0.0011573631782084703, 0.00017211545491591096, 0.0007665579323656857, 0.03241841867566109, 0.34369325637817383, 0.2890424132347107, 0.17009468376636505, 0.0, 0.0], [0.4695849120616913, 
0.017787985503673553, 0.06290572881698608, 0.06516575813293457, 0.09894091635942459, 0.03647425398230553, 0.051347069442272186, 0.08907806128263474, 0.10871540009975433, 0.0, 0.10835989564657211, 0.0007107920246198773, 0.00030798258376307786, 0.005807099863886833, 0.04662986099720001, 0.1659584492444992, 0.3522194027900696, 0.30094781517982483, 0.019058646634221077, 0.0], [0.18501408398151398, 0.040740884840488434, 0.10466982424259186, 0.07660976052284241, 0.17033715546131134, 0.05819392204284668, 0.0898737907409668, 0.09184892475605011, 0.10470453649759293, 0.0780070349574089, 0.5449283123016357, 0.01310307253152132, 0.008020865730941296, 0.006764447782188654, 0.16009773313999176, 0.06950337439775467, 0.0024397175293415785, 0.014089844189584255, 0.013654321432113647, 0.1673980951309204]]], [[[0.10487863421440125, 0.7106320858001709, 0.1635318249464035, 0.011256101541221142, 0.0012767312582582235, 0.00310636218637228, 0.0013001860352233052, 0.0012553841806948185, 0.002762428717687726, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021650908514857292, 0.0030605364590883255, 0.6595932245254517, 0.2987315356731415, 0.012945608235895634, 0.0028472936246544123, 7.557096250820905e-05, 0.00029089683084748685, 0.0008047237643040717, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014272261410951614, 0.040512338280677795, 0.8595607280731201, 0.038314104080200195, 0.037397123873233795, 0.006795509252697229, 0.001303989440202713, 0.001011757180094719, 0.0008321924251504242, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031783342361450195, 0.007319662719964981, 0.7663278579711914, 0.0010118860518559813, 0.1672297865152359, 0.02513650804758072, 0.000853335193824023, 0.0002817189379129559, 5.600590884569101e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002136597875505686, 0.00037253598566167057, 0.07588302344083786, 0.2252500057220459, 0.33551687002182007, 0.35751965641975403, 0.0027331046294420958, 0.00018122239271178842, 
0.0004068210837431252, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0004353485128376633, 0.0003557991876732558, 0.0003262429090682417, 0.003819868667051196, 0.33603885769844055, 0.2681770920753479, 0.3838857412338257, 0.0068349516950547695, 0.00012614508159458637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [6.71677480568178e-05, 3.9912600186653435e-05, 0.00047830803669057786, 5.937727837590501e-05, 0.0014537296956405044, 0.6413838863372803, 0.29047340154647827, 0.06565171480178833, 0.0003929881495423615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00047039391938596964, 0.0007891620043665171, 0.0007817292353138328, 0.0010076714679598808, 0.00965806283056736, 0.003733346238732338, 0.35330116748809814, 0.5722718238830566, 0.05798657611012459, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006178696174174547, 0.009340841323137283, 0.0005589249776676297, 0.005146770738065243, 0.0033258567564189434, 0.0016933922888711095, 0.06414961069822311, 0.3291752338409424, 0.5804308652877808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006624103523790836, 0.001978900283575058, 0.0081730792298913, 0.0030846702866256237, 0.0018904987955465913, 0.0014340116176754236, 0.005187559872865677, 0.029854312539100647, 0.9417726993560791, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10875418037176132, 0.15107707679271698, 0.07560893893241882, 0.11182637512683868, 0.051575273275375366, 0.1800614595413208, 0.13901139795780182, 0.11257244646549225, 0.06951297074556351, 0.0, 0.03246883675456047, 0.020431363955140114, 0.06294436007738113, 0.08282972872257233, 0.047490958124399185, 0.03976213559508324, 0.01868664100766182, 0.5054241418838501, 0.18996170163154602, 0.0], [0.04530828073620796, 0.11530135571956635, 0.03132164478302002, 0.12301183491945267, 0.01339547149837017, 0.009322633035480976, 0.0069213854148983955, 0.181557297706604, 0.47386014461517334, 0.0, 0.0334412157535553, 0.45350977778434753, 
0.23828978836536407, 0.07703227549791336, 0.02545342594385147, 0.019935714080929756, 0.007961008697748184, 0.08864670246839523, 0.05572996661067009, 0.0], [0.08671615272760391, 0.21926835179328918, 0.11249969899654388, 0.05250205472111702, 0.044286634773015976, 0.006910341326147318, 0.004434189759194851, 0.00961831770837307, 0.4637643098831177, 0.0, 0.008816813118755817, 0.009350132197141647, 0.09488566964864731, 0.022458655759692192, 0.001578008639626205, 0.01768183708190918, 0.0012928039068356156, 0.7889453768730164, 0.05499071627855301, 0.0], [0.016148164868354797, 0.08668603748083115, 0.1414848268032074, 0.024200299754738808, 0.018711188808083534, 0.02537006139755249, 0.017450006678700447, 0.039331331849098206, 0.6306182146072388, 0.0, 0.0037117439787834883, 0.00603569345548749, 0.019362367689609528, 0.06632085889577866, 0.02251342497766018, 0.048607613891363144, 0.00711278198286891, 0.7890322804450989, 0.03730323165655136, 0.0], [0.024489276111125946, 0.03301851078867912, 0.03003605268895626, 0.03562680631875992, 0.06981870532035828, 0.022592445835471153, 0.025447512045502663, 0.03545365110039711, 0.7235170006752014, 0.0, 0.0017165049212053418, 0.0031809706706553698, 0.00569736585021019, 0.027958940714597702, 0.001130971242673695, 0.006313299294561148, 0.004051794297993183, 0.9312260150909424, 0.018723946064710617, 0.0], [0.05760658532381058, 0.08793947100639343, 0.053903114050626755, 0.0679689273238182, 0.007038408424705267, 0.007889931090176105, 0.010035911574959755, 0.019540006294846535, 0.6880777478218079, 0.0, 0.0028915719594806433, 0.007050157990306616, 0.004614752251654863, 0.0017270235111936927, 0.0016248916508629918, 0.06901240348815918, 0.005150379613041878, 0.13293159008026123, 0.7749972939491272, 0.0], [0.045610494911670685, 0.042210742831230164, 0.14248158037662506, 0.03233090415596962, 0.03048519603908062, 0.011738738045096397, 0.014284060336649418, 0.006383211817592382, 0.6744750738143921, 0.0, 0.005032604560256004, 0.005055313929915428, 
0.0030569147784262896, 0.0010687477188184857, 0.012304573319852352, 0.013984610326588154, 0.3489484190940857, 0.012370014563202858, 0.5981789827346802, 0.0], [0.096277616918087, 0.030696624889969826, 0.10220203548669815, 0.04915016517043114, 0.047845132648944855, 0.05814794450998306, 0.06954183429479599, 0.028650736436247826, 0.5174878835678101, 0.0, 0.0019784842152148485, 0.009333183988928795, 0.005381024908274412, 0.0002465381403453648, 0.0013898308388888836, 0.005461550783365965, 0.0012134313583374023, 0.001065099611878395, 0.9739308953285217, 0.0], [0.009306053631007671, 0.02153283730149269, 0.009718294255435467, 0.005953253246843815, 0.011703923344612122, 0.017902903258800507, 0.011090915650129318, 0.01645584963262081, 0.8963360786437988, 0.0, 0.005657540168613195, 0.006781480740755796, 0.00696007814258337, 0.0009338636882603168, 0.02429838851094246, 0.03842600807547569, 0.00286328443326056, 0.03579647094011307, 0.8782829642295837, 0.0], [0.009895006194710732, 0.026821313425898552, 0.16079027950763702, 0.01761648990213871, 0.01726638339459896, 0.08361288905143738, 0.039622098207473755, 0.14411716163158417, 0.5002583861351013, 0.0, 0.007395321968942881, 0.012293249368667603, 0.006963892374187708, 0.00022730379714630544, 0.0005401583621278405, 0.005707587581127882, 0.0028992195148020983, 0.0027063635643571615, 0.9612669944763184, 0.0]], [[0.17277710139751434, 0.13871003687381744, 0.020699918270111084, 0.04190761595964432, 0.17760643362998962, 0.1702892780303955, 0.16168300807476044, 0.10000763088464737, 0.01631900854408741, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9987638592720032, 0.0011447033612057567, 1.5495901607209817e-05, 2.3805538096333123e-10, 1.1166920899086108e-07, 4.81009180930414e-07, 2.3257289285538718e-05, 3.4320622944505885e-05, 1.812833215808496e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029870687052607536, 0.9668734669685364, 0.0031853404361754656, 3.7420595617732033e-06, 1.0481591772304455e-07, 
4.711453893690987e-09, 4.051101996083162e-07, 1.359390239485947e-06, 6.518688314827159e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.9839180569979362e-05, 0.0008244949858635664, 0.9990562796592712, 6.778111855965108e-05, 2.14482715819031e-05, 5.3428358959273226e-11, 7.202954205309808e-11, 7.697720239008277e-11, 1.422941551254553e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [9.680035873316228e-05, 4.205659934086725e-05, 0.0021876851096749306, 0.9926192164421082, 0.0050464412197470665, 7.330636890401365e-06, 4.7689670878980905e-08, 8.238330573284713e-10, 9.979119397485192e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.136659183335723e-06, 6.750806136324172e-08, 8.17252839624416e-06, 0.008817464113235474, 0.9640147089958191, 0.027066770941019058, 8.771067950874567e-05, 3.571775764044105e-09, 3.5257423647294672e-09, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.115869043947896e-07, 1.0059281407848175e-08, 1.3136859422502312e-07, 9.641905052149013e-08, 0.001335342414677143, 0.9957214593887329, 0.0029362423811107874, 7.136273325158982e-06, 1.1521567699901425e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.561131961760111e-06, 2.727877870256634e-07, 8.369554507225985e-07, 1.214864764342849e-09, 4.873449597653234e-06, 0.024909861385822296, 0.9680997133255005, 0.006879042834043503, 0.00010210835171164945, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00021467455371748656, 9.040503209689632e-05, 3.369562909938395e-05, 1.9265097961351785e-08, 9.727973520057276e-07, 2.4095537810353562e-05, 0.0040859803557395935, 0.8618475794792175, 0.1337023377418518, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.289768872287823e-06, 6.284429400693625e-05, 0.0001214230724144727, 2.809870807141124e-07, 1.092972157223926e-09, 1.0671180605825725e-09, 1.2438744079190656e-06, 0.024907555431127548, 0.9749038219451904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0543275885283947, 
0.01742306910455227, 0.05347726121544838, 0.18824619054794312, 0.09003543108701706, 0.08433128148317337, 0.1953076422214508, 0.206686869263649, 0.11016455292701721, 0.0, 0.02470340207219124, 0.02512546442449093, 0.11353036016225815, 0.35132649540901184, 0.20412008464336395, 0.027150044217705727, 0.015305055305361748, 0.05760098248720169, 0.1811380535364151, 0.0], [0.00859006680548191, 0.02184058353304863, 0.02418440766632557, 0.03131486475467682, 0.03273439407348633, 0.06774082779884338, 0.1731010377407074, 0.09275981038808823, 0.5477339029312134, 0.0, 0.009894105605781078, 0.02192404493689537, 0.3007009029388428, 0.13983333110809326, 0.03682582825422287, 0.08908118307590485, 0.27657952904701233, 0.026430398225784302, 0.09873086214065552, 0.0], [0.02145911566913128, 0.046526145190000534, 0.014734850265085697, 0.026213468983769417, 0.04904777929186821, 0.08567024767398834, 0.13810616731643677, 0.03392839804291725, 0.5843138694763184, 0.0, 0.011459765024483204, 0.044317521154880524, 0.5289616584777832, 0.19549138844013214, 0.03426412120461464, 0.017797794193029404, 0.030613277107477188, 0.0163635965436697, 0.12073105573654175, 0.0], [0.019245177507400513, 0.01515401341021061, 0.027409562841057777, 0.0068243746645748615, 0.07997982203960419, 0.0921224057674408, 0.04510754346847534, 0.04373685643076897, 0.670420229434967, 0.0, 0.011578483507037163, 0.0029169816989451647, 0.00455811433494091, 0.01625976897776127, 0.018393559381365776, 0.11749742925167084, 0.32938554883003235, 0.41049671173095703, 0.08891336619853973, 0.0], [0.04381020739674568, 0.06711422652006149, 0.07609888166189194, 0.021496189758181572, 0.05042967572808266, 0.15614424645900726, 0.11071597784757614, 0.14296749234199524, 0.3312230408191681, 0.0, 0.0033444140572100878, 0.0011373214656487107, 0.0019445078214630485, 0.02781311236321926, 0.0049105980433523655, 0.05221953243017197, 0.09222303330898285, 0.3644186854362488, 0.45198866724967957, 0.0], [0.04100082442164421, 0.030313873663544655, 
0.032653506845235825, 0.0695231482386589, 0.12672685086727142, 0.12515434622764587, 0.08855390548706055, 0.05835743993520737, 0.4277162253856659, 0.0, 0.002199131529778242, 0.0006913270917721093, 0.002652444876730442, 0.017487458884716034, 0.18746966123580933, 0.39171290397644043, 0.26989367604255676, 0.017002178356051445, 0.11089123785495758, 0.0], [0.14112897217273712, 0.06592341512441635, 0.06986766308546066, 0.06311382353305817, 0.12678426504135132, 0.04950721934437752, 0.08025017380714417, 0.03467738255858421, 0.36874714493751526, 0.0, 0.01051913108676672, 0.003755246289074421, 0.0008555634994991124, 0.002675057854503393, 0.0025919810868799686, 0.02418649010360241, 0.018060903996229172, 0.003447937313467264, 0.9339075684547424, 0.0], [0.02841436117887497, 0.022568009793758392, 0.014519155025482178, 0.019271234050393105, 0.018120555207133293, 0.036434635519981384, 0.014109926298260689, 0.24622198939323425, 0.6003400683403015, 0.0, 0.029951948672533035, 0.006547479424625635, 0.030934682115912437, 0.0036260345950722694, 0.1420958936214447, 0.19529034197330475, 0.1491098254919052, 0.009723717346787453, 0.43272000551223755, 0.0], [0.05730762332677841, 0.07724729180335999, 0.030861826613545418, 0.04063780978322029, 0.08539344370365143, 0.029541905969381332, 0.02964094467461109, 0.028206804767251015, 0.6211622953414917, 0.0, 0.017757408320903778, 0.006832967512309551, 0.028906390070915222, 0.00921954121440649, 0.054915353655815125, 0.028632348403334618, 0.03646676614880562, 0.01978384144604206, 0.7974854707717896, 0.0], [0.20915710926055908, 0.193747878074646, 0.11181499063968658, 0.07680925726890564, 0.04479793831706047, 0.03787367418408394, 0.04819086939096451, 0.11330965161323547, 0.1642986238002777, 0.0, 0.06588920205831528, 0.05552517622709274, 0.18546447157859802, 0.007839588448405266, 0.020484987646341324, 0.01699826307594776, 0.01947665773332119, 0.017759086564183235, 0.6105626821517944, 0.0]], [[0.058097392320632935, 0.00935883168131113, 0.04822169989347458, 
0.0048278868198394775, 0.191309854388237, 0.28154584765434265, 0.09391050785779953, 0.24126385152339935, 0.07146408408880234, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10414423793554306, 0.027566324919462204, 0.021727869287133217, 0.033647697418928146, 0.026882247999310493, 0.17782779037952423, 0.05685214698314667, 0.45095938444137573, 0.10039239376783371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.44215551018714905, 0.049670565873384476, 0.014098896645009518, 0.029011834412813187, 0.01834075152873993, 0.1358453929424286, 0.04072042554616928, 0.2330295443534851, 0.03712712228298187, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10425814986228943, 0.06979154050350189, 0.036334071308374405, 0.028995294123888016, 0.015532439574599266, 0.1330128014087677, 0.063407763838768, 0.23157192766666412, 0.3170958459377289, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3384562134742737, 0.055937401950359344, 0.038792647421360016, 0.00819220207631588, 0.03063569962978363, 0.09386011958122253, 0.07227522879838943, 0.30926018953323364, 0.05259038880467415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3519401550292969, 0.1823827177286148, 0.06509842723608017, 0.030452275648713112, 0.08377533406019211, 0.09469012171030045, 0.04247477278113365, 0.11751312017440796, 0.03167306259274483, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3634622097015381, 0.14048337936401367, 0.08374395966529846, 0.038946691900491714, 0.03473563492298126, 0.06442954391241074, 0.019375532865524292, 0.22685663402080536, 0.027966352179646492, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.18070067465305328, 0.04645215719938278, 0.0992647334933281, 0.005799622740596533, 0.47514480352401733, 0.12094692885875702, 0.030788421630859375, 0.025236092507839203, 0.015666494145989418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5453059673309326, 0.10054859519004822, 0.01722547970712185, 0.06704734265804291, 
0.007780902087688446, 0.07263857871294022, 0.022086072713136673, 0.1394840031862259, 0.027883058413863182, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15028028190135956, 0.17163224518299103, 0.06043723225593567, 0.10140684247016907, 0.10512865334749222, 0.06778015196323395, 0.06512691080570221, 0.23085294663906097, 0.04735487326979637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038908280432224274, 0.07760688662528992, 0.062413811683654785, 0.0023113787174224854, 0.0021746077109128237, 0.015095214359462261, 0.003646473865956068, 0.038165315985679626, 0.759678065776825, 0.0, 0.14391662180423737, 0.11156481504440308, 0.4162432849407196, 0.07845085859298706, 0.04067624360322952, 0.016916701570153236, 0.012291320599615574, 0.10670017451047897, 0.07323983311653137, 0.0], [0.015742339193820953, 0.029524141922593117, 0.0550379604101181, 0.16926467418670654, 0.035933610051870346, 0.03279981389641762, 0.03188418969511986, 0.5383173227310181, 0.09149592369794846, 0.0, 0.0171683169901371, 0.03512553498148918, 0.4936983287334442, 0.18945446610450745, 0.020571058616042137, 0.011469473131000996, 0.04002959281206131, 0.08968089520931244, 0.10280223935842514, 0.0], [0.022741766646504402, 0.013864121399819851, 0.06161126494407654, 0.06985131651163101, 0.03954875469207764, 0.02864447981119156, 0.036658816039562225, 0.05774570629000664, 0.6693336963653564, 0.0, 0.2093620002269745, 0.11281707882881165, 0.25891542434692383, 0.14515942335128784, 0.0042000748217105865, 0.006485591176897287, 0.005525505635887384, 0.14364667236804962, 0.11388827115297318, 0.0], [0.06077639013528824, 0.053226571530103683, 0.05544588342308998, 0.08368532359600067, 0.04779139161109924, 0.028960514813661575, 0.03463221713900566, 0.42419588565826416, 0.21128588914871216, 0.0, 0.0109701631590724, 0.0007525839027948678, 0.011503712274134159, 0.03920656442642212, 0.2449047565460205, 0.048431187868118286, 0.12996943295001984, 0.4081973731517792, 0.10606419295072556, 0.0], 
[0.03320460394024849, 0.07872876524925232, 0.0791814923286438, 0.008506255224347115, 0.010383618995547295, 0.021636927500367165, 0.009444555267691612, 0.026183925569057465, 0.7327298521995544, 0.0, 0.004995591007173061, 0.0001893905719043687, 0.0009439413552172482, 0.03207648918032646, 0.08267047256231308, 0.015983520075678825, 0.02033340558409691, 0.8191123604774475, 0.023694908246397972, 0.0], [0.14095324277877808, 0.17195045948028564, 0.04960065335035324, 0.02801741287112236, 0.02789357118308544, 0.0246508177369833, 0.027228642255067825, 0.008449538610875607, 0.521255612373352, 0.0, 0.0022357299458235502, 0.000793653482105583, 0.0010144039988517761, 0.2958794832229614, 0.3394852876663208, 0.07495945692062378, 0.06856833398342133, 0.06118563562631607, 0.15587811172008514, 0.0], [0.01678302139043808, 0.02193976752460003, 0.13912786543369293, 0.05168221518397331, 0.06239692494273186, 0.008615943603217602, 0.037501659244298935, 0.02482585795223713, 0.6371266841888428, 0.0, 0.0020441634114831686, 0.00032311712857335806, 0.0006899640429764986, 0.03996479511260986, 0.38782593607902527, 0.05503879860043526, 0.24750953912734985, 0.004524962045252323, 0.26207876205444336, 0.0], [0.03396642208099365, 0.07778684049844742, 0.18657010793685913, 0.11281172931194305, 0.019890569150447845, 0.012303605675697327, 0.0494060292840004, 0.11448060721158981, 0.39278414845466614, 0.0, 0.0012333561899140477, 0.0002747838443610817, 0.0023864947725087404, 0.10253860056400299, 0.4721597135066986, 0.04103615880012512, 0.03782818093895912, 0.026908699423074722, 0.31563398241996765, 0.0], [0.02684134803712368, 0.03310805931687355, 0.163743257522583, 0.014529252424836159, 0.10077258199453354, 0.044357266277074814, 0.04152251034975052, 0.10173188894987106, 0.4733937382698059, 0.0, 0.004791810177266598, 0.0015037101693451405, 0.004669447895139456, 0.38809871673583984, 0.13379721343517303, 0.024320820346474648, 0.03647102415561676, 0.013309511356055737, 0.3930378258228302, 0.0], 
[0.01862592063844204, 0.022009190171957016, 0.028925148770213127, 0.006837732624262571, 0.006956242956221104, 0.010202805511653423, 0.015325144864618778, 0.11640346795320511, 0.7747144103050232, 0.0, 0.00849083997309208, 0.003579143201932311, 0.0033037925604730844, 0.006032468285411596, 0.017621049657464027, 0.0234503336250782, 0.018282314762473106, 0.02657976746559143, 0.8926602602005005, 0.0]], [[0.11086989939212799, 0.14517885446548462, 0.17419463396072388, 0.060936953872442245, 0.08783368766307831, 0.11005676537752151, 0.03251044824719429, 0.07983692735433578, 0.19858187437057495, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16660544276237488, 0.29352903366088867, 0.1008867621421814, 0.023942291736602783, 0.15022507309913635, 0.06581585109233856, 0.02344084158539772, 0.05208655819296837, 0.12346797436475754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1683349758386612, 0.22478938102722168, 0.06976605206727982, 0.1032773107290268, 0.16255290806293488, 0.08890064060688019, 0.03925151377916336, 0.023706944659352303, 0.11942004412412643, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19914905726909637, 0.1368866264820099, 0.178489089012146, 0.11241752654314041, 0.06187256798148155, 0.0768556222319603, 0.01627686619758606, 0.07274915277957916, 0.14530348777770996, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08000901341438293, 0.20181676745414734, 0.21235129237174988, 0.05340588092803955, 0.12758778035640717, 0.11278047412633896, 0.06906574964523315, 0.08596791326999664, 0.05701539292931557, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14153669774532318, 0.10432923585176468, 0.09881750494241714, 0.08603313565254211, 0.10391980409622192, 0.06189347058534622, 0.06772381067276001, 0.08503933250904083, 0.25070688128471375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06525713205337524, 0.07869093865156174, 0.11366366595029831, 0.044226594269275665, 0.05455174669623375, 
0.23646420240402222, 0.09933798015117645, 0.1198185384273529, 0.1879890412092209, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09450254589319229, 0.027017319574952126, 0.06480545550584793, 0.10929621011018753, 0.11382008343935013, 0.17441418766975403, 0.11898359656333923, 0.06495486199855804, 0.23220552504062653, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07681684195995331, 0.0671391412615776, 0.0905177965760231, 0.06064317002892494, 0.06652072072029114, 0.09855856746435165, 0.07360702753067017, 0.13956283032894135, 0.3266339898109436, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12179998308420181, 0.07977079600095749, 0.08405954390764236, 0.1456507444381714, 0.14551174640655518, 0.07862778753042221, 0.09882251918315887, 0.14300917088985443, 0.1027478501200676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0830092504620552, 0.0839436799287796, 0.10106679797172546, 0.11154499650001526, 0.045070260763168335, 0.1284436285495758, 0.1161414161324501, 0.19574469327926636, 0.1350351870059967, 0.0, 0.8417463898658752, 0.05951714888215065, 0.012198105454444885, 0.03180553764104843, 0.02919766865670681, 0.0096508814021945, 0.003031272441148758, 0.0009100366733036935, 0.011942943558096886, 0.0], [0.0006529411766678095, 0.0018492193194106221, 0.018439743667840958, 0.004895282443612814, 0.0036929987836629152, 0.05041775107383728, 0.03271673619747162, 0.4425412714481354, 0.4447941780090332, 0.0, 0.00569154741242528, 0.979739785194397, 0.012030904181301594, 0.0001143000990850851, 9.368032624479383e-05, 0.0008171445806510746, 0.00012590458209160715, 0.0005024938145652413, 0.0008843241375871003, 0.0], [0.015919672325253487, 0.02172437310218811, 0.013682822696864605, 0.028371846303343773, 0.017258556559681892, 0.014516759663820267, 0.033475372940301895, 0.45419326424598694, 0.40085726976394653, 0.0, 0.005223963409662247, 0.005622355733066797, 0.9848889708518982, 0.002582893241196871, 0.0003334738139528781, 
0.0005618981667794287, 3.256636409787461e-05, 0.00024550766102038324, 0.0005086653982289135, 0.0], [0.006064589135348797, 0.006147248670458794, 0.06902536749839783, 0.011021673679351807, 0.0062199062667787075, 0.17622654139995575, 0.00982236210256815, 0.46262383460998535, 0.25284844636917114, 0.0, 0.0032260464504361153, 0.007557107135653496, 0.0651315227150917, 0.6094849109649658, 0.008782745338976383, 0.2748804986476898, 0.015592943876981735, 0.008143502287566662, 0.007200630847364664, 0.0], [0.018328940495848656, 0.034908927977085114, 0.027539005503058434, 0.04494883120059967, 0.03695090860128403, 0.18224696815013885, 0.04204700142145157, 0.09570277482271194, 0.5173265337944031, 0.0, 0.01683628372848034, 0.0020552987698465586, 0.00783018209040165, 0.008005303330719471, 0.0011927365558221936, 0.9284406900405884, 0.03478293865919113, 0.00030738895293325186, 0.0005490221083164215, 0.0], [0.06838149577379227, 0.025893883779644966, 0.06412170827388763, 0.11039282381534576, 0.12848982214927673, 0.09953469038009644, 0.09056522697210312, 0.12723064422607422, 0.28538966178894043, 0.0, 0.0004254023951943964, 7.111614831956103e-05, 0.0008891545585356653, 1.880968193290755e-05, 6.570573896169662e-05, 0.9941434860229492, 0.0025632327888160944, 9.733852493809536e-06, 0.0018130606040358543, 0.0], [0.07893572002649307, 0.0734885111451149, 0.06503137946128845, 0.04291535168886185, 0.08502060174942017, 0.04846649244427681, 0.07035838067531586, 0.14812934398651123, 0.38765427470207214, 0.0, 7.936867405078374e-06, 1.8136512153432705e-05, 4.5569290705316234e-06, 1.071940641850233e-05, 3.808495648627286e-06, 0.0008168917265720665, 0.9974388480186462, 1.4373016711033415e-05, 0.0016848900122568011, 0.0], [0.007445929106324911, 0.004103729501366615, 0.05411284416913986, 0.006074799690395594, 0.07146289199590683, 0.5494692921638489, 0.05009504780173302, 0.058794084936380386, 0.1984413117170334, 0.0, 0.0014213839313015342, 0.003971228376030922, 0.008488249033689499, 2.0282970581320114e-05, 
8.774230809649453e-05, 0.030342059209942818, 0.010436602868139744, 0.013138609007000923, 0.9320940375328064, 0.0], [0.0037151367869228125, 0.005083263851702213, 0.02171880006790161, 0.01245985459536314, 0.012914983555674553, 0.14437292516231537, 0.026943473145365715, 0.17420484125614166, 0.5985866785049438, 0.0, 9.058997966349125e-05, 0.0009022729936987162, 0.0017266678623855114, 1.3629892237077001e-05, 0.000727150880265981, 0.002379553159698844, 0.0010508937994018197, 0.012508089654147625, 0.9806011319160461, 0.0], [0.02579679898917675, 0.0645768865942955, 0.03225725144147873, 0.044467855244874954, 0.04297630116343498, 0.06060377135872841, 0.030930038541555405, 0.03278812766075134, 0.6656030416488647, 0.0, 0.0003429521748330444, 0.001905322540551424, 0.0005013775080442429, 1.1471392099338118e-05, 0.00017356597527395934, 0.0029742273036390543, 0.003938945475965738, 0.028075864538550377, 0.9620763063430786, 0.0]], [[0.0261031873524189, 0.9575563073158264, 0.006272038444876671, 0.0037288309540599585, 0.0038619006518274546, 0.0007324732141569257, 0.0005133527447469532, 0.0003637235495261848, 0.0008679544553160667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02134888991713524, 0.08473973721265793, 0.6753177642822266, 0.028721673414111137, 0.14432094991207123, 0.027568204328417778, 0.0057298606261610985, 0.004451636224985123, 0.007801060564815998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03883299231529236, 0.030284319072961807, 0.5620493292808533, 0.09062989801168442, 0.17362907528877258, 0.08253934979438782, 0.010801085270941257, 0.00978847872465849, 0.0014453904004767537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002180949319154024, 0.003013473702594638, 0.16569769382476807, 0.008050205186009407, 0.7580646276473999, 0.061441101133823395, 0.001020166208036244, 0.0001067533012246713, 0.0004249440098647028, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004150479566305876, 0.00034606645931489766, 
0.3802972435951233, 0.06855826079845428, 0.29045602679252625, 0.1767650991678238, 0.06603583693504333, 0.0014808314153924584, 0.011909942142665386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006170187145471573, 0.0012396957026794553, 0.0354800671339035, 0.0032299698796123266, 0.03240001201629639, 0.5543311238288879, 0.30418315529823303, 0.051339369267225266, 0.01162647269666195, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0035115755163133144, 0.0011483307462185621, 0.017956364899873734, 0.003783614607527852, 0.030611976981163025, 0.3673596978187561, 0.20627115666866302, 0.3506667912006378, 0.01869054324924946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0021685126703232527, 0.0006909942603670061, 0.010240452364087105, 0.01958688348531723, 0.004634156823158264, 0.11485372483730316, 0.04815557599067688, 0.7050773501396179, 0.0945921242237091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049201104789972305, 0.02397306263446808, 0.02337191067636013, 0.31066185235977173, 0.06433572620153427, 0.12544430792331696, 0.0786852017045021, 0.25179895758628845, 0.07252778857946396, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.010841209441423416, 0.0041772774420678616, 0.01548130251467228, 0.036074474453926086, 0.033387064933776855, 0.08192819356918335, 0.04784044623374939, 0.10195028781890869, 0.668319821357727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13460709154605865, 0.15298102796077728, 0.06546170264482498, 0.14220191538333893, 0.11837887763977051, 0.09888823330402374, 0.10630416870117188, 0.08867054432630539, 0.09250646829605103, 0.0, 0.23634016513824463, 0.09021607041358948, 0.12040459364652634, 0.01354933436959982, 0.0019137230701744556, 0.009001325815916061, 0.028688833117485046, 0.2612648904323578, 0.23862121999263763, 0.0], [0.9316296577453613, 0.016095036640763283, 0.0020372711587697268, 0.0019596514757722616, 2.8437656510504894e-05, 6.708989531034604e-05, 
0.0004955903859809041, 3.0113247703411616e-05, 0.047657083719968796, 0.0, 0.2307557761669159, 0.2812652289867401, 0.30346915125846863, 0.05031246319413185, 0.006193350534886122, 0.01668362505733967, 0.012607063166797161, 0.07951408624649048, 0.019199388101696968, 0.0], [0.043201129883527756, 0.9419298768043518, 0.0003410913050174713, 0.003313146298751235, 7.506452675443143e-06, 1.9570916265365668e-05, 2.5470235414104536e-05, 2.1080213628010824e-05, 0.011141069233417511, 0.0, 0.29960742592811584, 0.20819564163684845, 0.27825382351875305, 0.007396433036774397, 0.0007608149899169803, 0.0260151494294405, 0.012685009278357029, 0.12934625148773193, 0.03773954138159752, 0.0], [3.7581870856229216e-05, 0.00022979748609941453, 0.9982534646987915, 8.70372386998497e-05, 5.87535805607331e-06, 2.5239218302886002e-05, 6.597588708245894e-06, 2.193619138779468e-06, 0.001352491439320147, 0.0, 0.035675279796123505, 0.035874202847480774, 0.007117687724530697, 0.018771182745695114, 0.010206644423305988, 0.06527784466743469, 0.03775254264473915, 0.7770709991455078, 0.012253628112375736, 0.0], [0.0019612079486250877, 0.011641290038824081, 0.010358362458646297, 0.8346317410469055, 0.00641160923987627, 0.0007435380248352885, 0.0018172020791098475, 7.255822129081935e-05, 0.1323624849319458, 0.0, 0.012017791159451008, 0.0028583300299942493, 0.0024127706419676542, 0.002610970288515091, 0.001820205245167017, 0.04092223569750786, 0.016621166840195656, 0.9115477800369263, 0.009188669733703136, 0.0], [4.077299308846705e-05, 0.00016088274423964322, 3.1180113637674367e-06, 5.9685276937671006e-05, 6.661444786004722e-06, 0.0006764131248928607, 5.4107837058836594e-05, 0.9797272086143494, 0.01927126571536064, 0.0, 0.03447290509939194, 0.013388306833803654, 0.08488336205482483, 0.015237652696669102, 0.19176845252513885, 0.3472833037376404, 0.10885429382324219, 0.192628413438797, 0.011483324691653252, 0.0], [2.7792530090664513e-06, 1.1777839063142892e-05, 1.0386434951215051e-05, 0.0006807934259995818, 
0.00028749846387654543, 0.9563493728637695, 2.4335316993528977e-05, 0.001297356327995658, 0.041335828602313995, 0.0, 0.0005363536183722317, 0.0001964608090929687, 0.0017719777533784509, 0.003164003835991025, 0.27662715315818787, 0.05286016687750816, 0.648875892162323, 0.007890382781624794, 0.00807751715183258, 0.0], [0.00033864984288811684, 0.00016234541544690728, 0.00011107163300039247, 7.639558316441253e-05, 9.851753566181287e-05, 0.00046863980242051184, 0.9855522513389587, 0.00012009339843643829, 0.013071970082819462, 0.0, 0.001257028547115624, 0.00020761204359587282, 0.0024441492278128862, 0.003374723019078374, 0.9062062501907349, 0.0712839737534523, 0.0032159662805497646, 0.009974849410355091, 0.0020355340093374252, 0.0], [0.001446103909984231, 0.0026176422834396362, 0.0005430445889942348, 0.5833504796028137, 0.08298782259225845, 0.01277364045381546, 0.008405186235904694, 0.028461067005991936, 0.2794148921966553, 0.0, 0.0008205634076148272, 0.00019305139721836895, 0.002098840195685625, 0.004588909447193146, 0.9688709378242493, 0.01628950424492359, 0.0038415545132011175, 0.0016231476329267025, 0.0016735766548663378, 0.0], [8.301706202473724e-07, 1.612889263924444e-06, 3.859615389956161e-06, 0.0015496612759307027, 0.9884966611862183, 0.0003321043332107365, 1.1829011782538146e-05, 3.7258676002238644e-06, 0.00959983840584755, 0.0, 0.03610469028353691, 0.046298399567604065, 0.04650943726301193, 0.02111651562154293, 0.06683006882667542, 0.37146270275115967, 0.174205482006073, 0.15773150324821472, 0.07974111288785934, 0.0]], [[0.005738695617765188, 0.0068999892100691795, 0.4274883270263672, 0.08288666605949402, 0.1445126235485077, 0.04382907599210739, 0.10957401990890503, 0.05347184091806412, 0.1255987584590912, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0025263649877160788, 0.00471830926835537, 0.13454590737819672, 0.4177793860435486, 0.28839975595474243, 0.029358303174376488, 0.017654288560152054, 0.0047735795378685, 0.10024390369653702, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009192855097353458, 0.007133236154913902, 0.03149157017469406, 0.1856081485748291, 0.5691666603088379, 0.07386670261621475, 0.029819192364811897, 0.03683711960911751, 0.05688462406396866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00297820963896811, 0.0015070328954607248, 0.0025649494491517544, 0.0011051844339817762, 0.04088710993528366, 0.1953955888748169, 0.34000417590141296, 0.3367410898208618, 0.07881659269332886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003951869439333677, 0.009354526177048683, 0.007010620087385178, 0.0025927696842700243, 0.09962604194879532, 0.10909298062324524, 0.4455967843532562, 0.15358439087867737, 0.16918975114822388, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0038829154800623655, 0.0036434896755963564, 0.006399825215339661, 0.000760377966798842, 0.010139851830899715, 0.038725122809410095, 0.10014155507087708, 0.48370444774627686, 0.35260239243507385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.001297087874263525, 0.0014563009608536959, 0.013839880004525185, 0.0004286184557713568, 0.012207024730741978, 0.028704902157187462, 0.046600911766290665, 0.26406532526016235, 0.6313998103141785, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0033481158316135406, 0.0038099782541394234, 0.0031049775425344706, 0.00033546099439263344, 0.0031272985506802797, 0.008788534440100193, 0.021183660253882408, 0.12157405912876129, 0.8347280025482178, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3364367187023163, 0.17456969618797302, 0.051038213074207306, 0.006790165323764086, 0.024106895551085472, 0.0694134384393692, 0.02184627763926983, 0.061508405953645706, 0.25429028272628784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10536088049411774, 0.07750789821147919, 0.0850178673863411, 0.08725376427173615, 0.2586125433444977, 0.16756391525268555, 0.054291605949401855, 0.030132828280329704, 0.13425879180431366, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03624086081981659, 0.008591840974986553, 0.01890810765326023, 0.010947922244668007, 0.5211313366889954, 0.04890615865588188, 0.13394898176193237, 0.08554741740226746, 0.13577744364738464, 0.0, 0.03425053879618645, 0.026130978018045425, 0.3080751299858093, 0.027706336230039597, 0.12989944219589233, 0.29902005195617676, 0.0305496696382761, 0.03879137709736824, 0.1055762991309166, 0.0], [0.09101090580224991, 0.15663929283618927, 0.2008313536643982, 0.13744188845157623, 0.16349081695079803, 0.01479706447571516, 0.04576689749956131, 0.05515507981181145, 0.1348666250705719, 0.0, 0.004509713500738144, 0.02305547706782818, 0.939035952091217, 0.006188178434967995, 0.020785806700587273, 0.00040150884888134897, 0.00018676061881706119, 0.00013036451127845794, 0.005706076975911856, 0.0], [0.10898119956254959, 0.19741322100162506, 0.12774543464183807, 0.07097428292036057, 0.033309608697891235, 0.016726871952414513, 0.019306309521198273, 0.09155051410198212, 0.3339925706386566, 0.0, 0.0005241778562776744, 0.009561678394675255, 0.988527774810791, 2.2495760276797228e-05, 4.7274414100684226e-05, 0.00013538387429434806, 4.543165232462343e-06, 6.27172994427383e-05, 0.001113483915105462, 0.0], [0.051247891038656235, 0.06952031701803207, 0.3243081271648407, 0.04820195212960243, 0.05462171137332916, 0.04280935227870941, 0.03801479935646057, 0.07710513472557068, 0.2941707372665405, 0.0, 0.06551901996135712, 0.0800878182053566, 0.06342226266860962, 0.00974376779049635, 0.5160938501358032, 0.02204274758696556, 0.004013149533420801, 0.0735243633389473, 0.1655530482530594, 0.0], [0.22540897130966187, 0.04426601901650429, 0.13483746349811554, 0.09052211791276932, 0.036632657051086426, 0.06078784167766571, 0.09962243586778641, 0.04597063735127449, 0.2619517743587494, 0.0, 0.0013552415184676647, 0.0004213388019707054, 0.002606122987344861, 0.0010090378345921636, 0.24638326466083527, 0.6568374633789062, 0.01604411192238331, 
0.04806208983063698, 0.027281243354082108, 0.0], [0.08315062522888184, 0.10649015009403229, 0.15254046022891998, 0.0728936716914177, 0.10388997197151184, 0.04998103529214859, 0.0675109326839447, 0.17524446547031403, 0.18829864263534546, 0.0, 0.0002145337639376521, 0.00018796027870848775, 0.0008407118148170412, 0.0029629908967763186, 0.28427600860595703, 0.6725634336471558, 0.023870857432484627, 0.00339014851488173, 0.011693413369357586, 0.0], [0.09407053142786026, 0.04335644096136093, 0.04757237061858177, 0.023308007046580315, 0.14141318202018738, 0.017728488892316818, 0.02331509254872799, 0.07266414165496826, 0.5365718007087708, 0.0, 0.0009873382514342666, 0.0005485343281179667, 6.628077971981838e-05, 0.0029302756302058697, 0.23183174431324005, 0.05256076529622078, 0.5701138377189636, 0.005792138632386923, 0.13516920804977417, 0.0], [0.08477651327848434, 0.026448125019669533, 0.013684368692338467, 0.1331702470779419, 0.16824185848236084, 0.007634431589394808, 0.025501158088445663, 0.035930439829826355, 0.5046128630638123, 0.0, 2.471696279826574e-05, 2.0868348656222224e-05, 4.437468305695802e-05, 0.002024284563958645, 0.9655042886734009, 0.024176988750696182, 0.001284845289774239, 0.00018083618488162756, 0.006738840136677027, 0.0], [0.03296202793717384, 0.01823815330862999, 0.025750160217285156, 0.08325016498565674, 0.1596710979938507, 0.010502922348678112, 0.01792057603597641, 0.05097610503435135, 0.6007286906242371, 0.0, 0.0007289832574315369, 7.746354822302237e-05, 0.00018428664770908654, 0.014176051132380962, 0.9112405180931091, 0.013280178420245647, 0.003417921019718051, 0.02014165185391903, 0.03675319626927376, 0.0], [0.04370357468724251, 0.02250431850552559, 0.016271278262138367, 0.019842427223920822, 0.12028838694095612, 0.03933797404170036, 0.043740611523389816, 0.08045370131731033, 0.6138576865196228, 0.0, 0.00874137319624424, 0.03438721224665642, 0.17507928609848022, 0.007159235887229443, 0.0029199302662163973, 0.023628318682312965, 0.007933209650218487, 
0.004559694789350033, 0.7355918884277344, 0.0]], [[0.034539882093667984, 0.0018589550163596869, 0.9604092836380005, 1.3120608855388127e-05, 2.1815638319822028e-05, 0.00012517283903434873, 8.019943197723478e-05, 0.0021589084062725306, 0.0007928607519716024, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.048832912914804e-07, 1.7815009414334781e-06, 0.9998455047607422, 0.0001518452918389812, 4.1070780554264275e-08, 2.7954746156799715e-11, 9.231376947582692e-12, 9.901777175969073e-09, 2.5545642756696907e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [6.695767496012195e-08, 2.089915795977504e-07, 0.005368041805922985, 0.9945066571235657, 0.0001248170156031847, 2.304766155702964e-09, 2.762512718579302e-10, 3.973758211373024e-09, 9.372820954922645e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.018761014413675e-13, 1.4841802622529476e-16, 4.663825770023777e-09, 3.820862737313746e-09, 0.9999942183494568, 4.988648925063899e-06, 4.967477167452938e-13, 1.416252587396787e-16, 2.1775358895380023e-16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [4.666895758731471e-09, 7.292542437975502e-12, 2.898993545219497e-11, 4.2817244194637283e-10, 0.00027504604076966643, 0.9995728731155396, 0.00015239788626786321, 1.9082661839586734e-10, 2.232514032581706e-13, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.7137297136926577e-10, 5.3312285142048665e-12, 2.2368220760327594e-14, 4.904942142678549e-17, 8.726878775178193e-09, 0.004644036293029785, 0.9953435659408569, 1.324965796811739e-05, 6.982896899598856e-12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [4.877224735189145e-10, 1.5497924055196677e-09, 6.021576987036426e-11, 8.955144165463396e-19, 1.7180077889825118e-13, 6.163505759104737e-07, 0.001256544259376824, 0.9987285733222961, 1.4209075743565336e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.25698863434809e-08, 7.313030323530256e-07, 1.412931510458293e-06, 1.1662047555981733e-16, 
8.495708612521816e-14, 1.1933978653379251e-13, 1.3303619539328793e-07, 0.01294001005589962, 0.9870572686195374, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.6884889646462398e-06, 2.6281904865754768e-05, 0.001122217159718275, 6.101166945882142e-06, 4.424501298672112e-08, 5.172042264953158e-13, 5.508820136168602e-11, 5.942968346062116e-05, 0.9987838268280029, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [4.288114359951578e-05, 6.015944563841913e-06, 0.004432132933288813, 0.025997335091233253, 0.000731422973331064, 6.87844434188456e-11, 8.199346692057408e-13, 7.098316245901515e-08, 0.9687905311584473, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1783323585987091, 0.3813028037548065, 0.2072289139032364, 0.06766574084758759, 0.053963109850883484, 0.030795719474554062, 0.023536406457424164, 0.03921645134687424, 0.01795845478773117, 0.0, 0.01947755739092827, 0.007096209097653627, 0.03225293010473251, 0.0123430285602808, 0.10373923927545547, 0.44083938002586365, 0.04899014160037041, 0.25500863790512085, 0.08025286346673965, 0.0], [0.8837893009185791, 0.07202983647584915, 0.03646722435951233, 0.0004511935112532228, 0.0007272462244145572, 0.0008432198665104806, 0.0031319037079811096, 0.0004143840924371034, 0.0021455709356814623, 0.0, 0.018974049016833305, 0.05092930048704147, 0.38670486211776733, 0.05532746762037277, 0.02096201851963997, 0.23439037799835205, 0.029592081904411316, 0.06233520433306694, 0.1407845914363861, 0.0], [0.3973897695541382, 0.14911939203739166, 0.3486334979534149, 0.012645252980291843, 0.00675938231870532, 0.00483374297618866, 0.010028100572526455, 0.012036854401230812, 0.058554183691740036, 0.0, 0.009641589596867561, 0.009545106440782547, 0.19981582462787628, 0.009672220796346664, 0.003704657079651952, 0.04582780599594116, 0.006998295895755291, 0.5789687037467957, 0.13582585752010345, 0.0], [0.005409032106399536, 0.005906772334128618, 0.13379110395908356, 0.15247586369514465, 0.06559418141841888, 
0.15356750786304474, 0.04085409641265869, 0.029147597029805183, 0.41325387358665466, 0.0, 0.00450306897982955, 0.0034239809028804302, 0.012258612550795078, 0.005700208712369204, 0.04511384665966034, 0.4419432282447815, 0.12840862572193146, 0.13075105845928192, 0.22789721190929413, 0.0], [0.0013326199259608984, 0.0014979635598137975, 0.011986319907009602, 0.7730216383934021, 0.06901827454566956, 0.05895080044865608, 0.016383536159992218, 0.015771687030792236, 0.052037257701158524, 0.0, 0.00048664878704585135, 0.00010348611976951361, 0.0010980216320604086, 0.0006185582024045289, 0.028226494789123535, 0.37447214126586914, 0.09456676244735718, 0.48241522908210754, 0.018012629821896553, 0.0], [0.0012038598069921136, 0.0033955213148146868, 0.025528373196721077, 0.03136582672595978, 0.10901585966348648, 0.3851255178451538, 0.0182026457041502, 0.13982580602169037, 0.2863365411758423, 0.0, 8.0467427324038e-05, 3.9275117160286754e-05, 0.00016763176245149225, 0.00013412459520623088, 0.009092556312680244, 0.7851189374923706, 0.16675172746181488, 0.0029041438829153776, 0.03571125119924545, 0.0], [0.008065885864198208, 0.004362722393125296, 0.06363680213689804, 0.023311397060751915, 0.06106392294168472, 0.1357712298631668, 0.03965916484594345, 0.06073852628469467, 0.6033903956413269, 0.0, 0.0007275060634128749, 0.00015159584290813655, 0.00037383963353931904, 0.0005468691233545542, 0.01837681420147419, 0.03491391986608505, 0.7517433166503906, 0.00028147027478553355, 0.19288486242294312, 0.0], [0.0003142715140711516, 0.0005578870768658817, 0.0015481057344004512, 0.0887022390961647, 0.06383900344371796, 0.2639910578727722, 0.049384135752916336, 0.12241825461387634, 0.40924492478370667, 0.0, 0.0005560970166698098, 0.0002987806510645896, 0.0021934551186859608, 0.00023410467838402838, 0.023030919954180717, 0.05263887345790863, 0.01838914304971695, 0.0007265828317031264, 0.9019319415092468, 0.0], [0.0003916181158274412, 0.0003099135938100517, 0.0024421222042292356, 
0.016801349818706512, 0.18835966289043427, 0.025843605399131775, 0.08458039909601212, 0.20884136855602264, 0.4724300503730774, 0.0, 0.007445591501891613, 0.0020796440076082945, 0.012208829633891582, 0.001590645289979875, 0.09274771064519882, 0.017371611669659615, 0.04761578515172005, 0.004260089714080095, 0.8146799802780151, 0.0], [5.865378989255987e-05, 7.253760122694075e-05, 0.0007906460668891668, 0.025103986263275146, 0.0753612071275711, 0.04038592055439949, 0.011871143244206905, 0.05808362737298012, 0.7882723212242126, 0.0, 0.014990360476076603, 0.004210897721350193, 0.002848376054316759, 0.0006518716691061854, 0.0007818753365427256, 0.0019951288122683764, 0.0036728696431964636, 0.0004030312702525407, 0.9704453349113464, 0.0]], [[0.02526121959090233, 0.9527671933174133, 0.014345486648380756, 0.0014051493490114808, 0.003839265089482069, 0.00014350644778460264, 0.0006356940139085054, 0.00025237957015633583, 0.0013501241337507963, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004122408106923103, 0.023777475580573082, 0.9002965688705444, 0.0682864859700203, 0.0017659803852438927, 0.0001271881628781557, 0.00011044178245356306, 0.0001890352723421529, 0.0013242338318377733, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.841444650897756e-05, 0.0002895947836805135, 0.06307922303676605, 0.9069769978523254, 0.028407124802470207, 0.000558151863515377, 0.00022284295118879527, 0.00018588549573905766, 0.00019132612214889377, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.889026179924258e-06, 3.9712713260087185e-06, 0.001210480579175055, 0.003201226470991969, 0.8290116786956787, 0.16640713810920715, 0.00015829727635718882, 4.0429063119518105e-06, 9.256136763724498e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.000399262469727546, 5.1438626542221755e-05, 0.0001944842515513301, 0.0007700449787080288, 0.4879837930202484, 0.4847603738307953, 0.025640420615673065, 0.00018376839580014348, 1.6383723050239496e-05, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [4.30414620495867e-05, 1.017293288896326e-05, 8.407413588429336e-06, 5.451946094581217e-07, 0.000544070964679122, 0.021075371652841568, 0.9573339819908142, 0.0208626389503479, 0.00012169074034318328, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00043880229350179434, 0.0004488519043661654, 0.000600603292696178, 1.4583132212919736e-07, 3.6701523640658706e-05, 0.010162030346691608, 0.37363454699516296, 0.559087336063385, 0.0555914081633091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0010709260823205113, 0.0006920771556906402, 0.0016655249055474997, 0.00010216240480076522, 1.0821948308148421e-05, 2.6151516067329794e-05, 0.01446994487196207, 0.2987785339355469, 0.6831837296485901, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0002485924051143229, 0.00016839140153024346, 0.019545644521713257, 0.016785046085715294, 0.005671702325344086, 0.00014030851889401674, 0.001185068627819419, 0.04272715002298355, 0.9135279655456543, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0039028520695865154, 0.0008621322922408581, 0.02400260791182518, 0.35541704297065735, 0.048350416123867035, 0.00013779231812804937, 0.00015075977717060596, 0.0015127401566132903, 0.5656636953353882, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01597539149224758, 0.027860743924975395, 0.08824922889471054, 0.011547067202627659, 0.02896539680659771, 0.03845160827040672, 0.011409634724259377, 0.043791815638542175, 0.7337491512298584, 0.0, 0.21779413521289825, 0.08220235258340836, 0.04201545566320419, 0.07069981843233109, 0.041075702756643295, 0.13784317672252655, 0.1975526064634323, 0.04344295710325241, 0.16737376153469086, 0.0], [0.0371943861246109, 0.014876782894134521, 0.02253115549683571, 0.10164438933134079, 0.029471710324287415, 0.040005166083574295, 0.020577073097229004, 0.07326765358448029, 0.6604316830635071, 0.0, 0.23605762422084808, 0.07441659271717072, 0.04143041744828224, 
0.05435749515891075, 0.0077708023600280285, 0.0960790365934372, 0.4399828016757965, 0.006641789805144072, 0.04326343908905983, 0.0], [0.06676606088876724, 0.1320837438106537, 0.02368331328034401, 0.09289334714412689, 0.06407851725816727, 0.007657648529857397, 0.014540987089276314, 0.018603011965751648, 0.5796933174133301, 0.0, 0.06337786465883255, 0.03357791155576706, 0.03929098695516586, 0.5017232298851013, 0.0066258725710213184, 0.009236367419362068, 0.1690734624862671, 0.0422079935669899, 0.13488635420799255, 0.0], [0.029496638104319572, 0.013616771437227726, 0.030488401651382446, 0.021259615197777748, 0.13049498200416565, 0.06418323516845703, 0.050123173743486404, 0.1609034240245819, 0.4994336664676666, 0.0, 0.006272959988564253, 0.0007428607787005603, 0.0011506476439535618, 0.007357995491474867, 0.0006080326274968684, 0.05679970234632492, 0.8685706257820129, 0.03271445259451866, 0.025782890617847443, 0.0], [0.010230573825538158, 0.015954630449414253, 0.007779641076922417, 0.018425902351737022, 0.021085364744067192, 0.0588817335665226, 0.013979516923427582, 0.0252523310482502, 0.828410267829895, 0.0, 0.041861388832330704, 0.004794578067958355, 0.0024879220873117447, 0.015253551304340363, 0.0005973980878479779, 0.08281483501195908, 0.814189076423645, 0.006639576051384211, 0.03136153519153595, 0.0], [0.02648993395268917, 0.0214377511292696, 0.03494586795568466, 0.05471349507570267, 0.09140968322753906, 0.04952282831072807, 0.05564551055431366, 0.11169540882110596, 0.5541394948959351, 0.0, 0.010862020775675774, 0.0008270516409538686, 0.00023008826246950775, 0.006298262160271406, 0.0022151959128677845, 0.09469958394765854, 0.8416994214057922, 0.0006256845663301647, 0.04254243150353432, 0.0], [0.03231878578662872, 0.018621357157826424, 0.05183127149939537, 0.03979233279824257, 0.13804322481155396, 0.03567919135093689, 0.047386858612298965, 0.13114488124847412, 0.505182147026062, 0.0, 0.00024508681963197887, 3.835038296529092e-05, 2.0304802092141472e-05, 
0.00012946058996021748, 0.0003255259362049401, 0.0026247953064739704, 0.9805192947387695, 0.00014136231038719416, 0.01595580205321312, 0.0], [0.04592716693878174, 0.010993612930178642, 0.01772226020693779, 0.05332585424184799, 0.15264220535755157, 0.22139224410057068, 0.048004403710365295, 0.12396018952131271, 0.3260320723056793, 0.0, 0.001919803791679442, 0.0005674636922776699, 0.0002780239738058299, 0.0008655164856463671, 0.0013816945720463991, 0.010561172850430012, 0.05357982590794563, 0.0009362901910208166, 0.9299100637435913, 0.0], [0.03168570622801781, 0.026294516399502754, 0.025469979271292686, 0.03026771917939186, 0.058515094220638275, 0.13361068069934845, 0.026259208098053932, 0.0612059161067009, 0.6066910624504089, 0.0, 0.00319756381213665, 0.0005108749028295279, 0.00043022894533351064, 0.005312783177942038, 0.005197612568736076, 0.008492776192724705, 0.05858352780342102, 0.01401757076382637, 0.9042569398880005, 0.0], [0.07492455840110779, 0.06428299844264984, 0.07022737711668015, 0.0507473424077034, 0.0447908453643322, 0.060839906334877014, 0.14463475346565247, 0.054812539368867874, 0.4347396492958069, 0.0, 0.00021474930690601468, 0.0004951281007379293, 0.00032367443782277405, 0.0001866286911536008, 6.129321263870224e-05, 0.00016246296581812203, 0.0016925180098041892, 0.000427676277467981, 0.996435821056366, 0.0]]], [[[0.09929531812667847, 0.3125585615634918, 0.26699960231781006, 0.036189958453178406, 0.01689508929848671, 0.05626463145017624, 0.014853590168058872, 0.021625356748700142, 0.17531771957874298, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6598999500274658, 0.04883529245853424, 0.24573534727096558, 0.008949915878474712, 0.008034803904592991, 0.0058951652608811855, 0.001835338887758553, 0.0024289200082421303, 0.018385181203484535, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.28377673029899597, 0.4307016134262085, 0.19275489449501038, 0.05968217924237251, 0.007509235758334398, 0.00627214927226305, 
0.0010254314402118325, 0.0010938378982245922, 0.017183959484100342, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00751571636646986, 0.01881357654929161, 0.9318985342979431, 0.014481762424111366, 0.02105659246444702, 0.0032304797787219286, 0.00013498679618351161, 2.4857494281604886e-05, 0.0028432777617126703, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08691340684890747, 0.01259385235607624, 0.21131311357021332, 0.15839329361915588, 0.3931293189525604, 0.10845079272985458, 0.004768806044012308, 0.0032348930835723877, 0.021202562376856804, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029192518442869186, 0.06438057869672775, 0.033022571355104446, 0.04279496520757675, 0.6011855006217957, 0.17385539412498474, 0.03754284232854843, 0.006468524225056171, 0.011557108722627163, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006125382613390684, 0.006982659921050072, 0.004575703293085098, 0.0037440320011228323, 0.36007580161094666, 0.5409486889839172, 0.0626324936747551, 0.00843171589076519, 0.006483553443104029, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0017123871948570013, 0.017555760219693184, 0.012620777823030949, 0.00947127677500248, 0.08178496360778809, 0.2538650631904602, 0.19189175963401794, 0.255443274974823, 0.17565478384494781, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02615528553724289, 0.002552631078287959, 0.01957615464925766, 0.021708596497774124, 0.008856788277626038, 0.021813882514834404, 0.052812058478593826, 0.19690369069576263, 0.6496209502220154, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004899451043456793, 0.005663626827299595, 0.012920243665575981, 0.007757777348160744, 0.014441648498177528, 0.021742597222328186, 0.05050418898463249, 0.35952994227409363, 0.5225404500961304, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9642227292060852, 
0.035777393728494644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9262088537216187, 0.07379112392663956, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9523521065711975, 0.027811188250780106, 0.019836684688925743, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2983383536338806, 0.576672375202179, 0.12498921155929565, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.849480152130127, 0.03536543622612953, 0.019422976300120354, 0.09573143720626831, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3100782334804535, 0.1274886280298233, 0.5286650061607361, 0.033768050372600555, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.741925060749054, 0.05566684901714325, 0.024736514315009117, 0.08595114946365356, 0.09172046929597855, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3118414282798767, 0.11087317764759064, 0.12077098339796066, 0.10916762799024582, 0.34734681248664856, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6503966450691223, 0.0582728385925293, 0.0236701387912035, 0.0691222995519638, 0.0758395791053772, 0.12269847840070724, 0.0, 0.0, 0.0, 0.0, 0.1361667662858963, 0.0034004957415163517, 0.00320720998570323, 0.0056303562596440315, 0.013746269047260284, 0.8378488421440125, 0.0, 0.0, 0.0, 0.0], [0.4914315342903137, 0.11739180237054825, 0.02309434488415718, 0.07889512181282043, 0.05101678892970085, 0.12367808818817139, 0.11449223756790161, 0.0, 0.0, 0.0, 0.9168469905853271, 0.009582683444023132, 0.002923850901424885, 0.009140468202531338, 0.0233402531594038, 0.01968987099826336, 0.01847577467560768, 0.0, 0.0, 0.0], [0.4262734055519104, 0.07066749036312103, 0.024391667917370796, 0.04879573732614517, 0.051445234566926956, 0.1276569813489914, 0.11843930184841156, 0.13233007490634918, 0.0, 0.0, 0.4528708755970001, 0.012551077641546726, 0.013286955654621124, 0.003301329677924514, 0.024005549028515816, 0.0439622700214386, 0.03865182027220726, 0.41137006878852844, 0.0, 0.0], [0.589878499507904, 0.026613032445311546, 0.020459800958633423, 0.028271155431866646, 0.03679497539997101, 0.07860217243432999, 0.08500825613737106, 0.09285575151443481, 0.04151623696088791, 
0.0, 0.06380993872880936, 0.0008893097401596606, 0.0011801879154518247, 0.0013187900185585022, 0.0034512828569859266, 0.0014297974994406104, 0.0023058890365064144, 0.041651248931884766, 0.8839635848999023, 0.0], [0.2743179202079773, 0.06089583784341812, 0.03565794974565506, 0.044920988380908966, 0.03933599591255188, 0.18495218455791473, 0.09192009270191193, 0.13160176575183868, 0.04121606424450874, 0.09518115967512131, 0.5330018997192383, 0.012773798778653145, 0.01854255609214306, 0.022641947492957115, 0.1288023591041565, 0.01178218238055706, 0.020595960319042206, 0.08756020665168762, 0.09921147674322128, 0.06508753448724747]], [[0.8470081686973572, 0.043761640787124634, 0.000660977209918201, 0.00018918802379630506, 0.01478277612477541, 0.00942840613424778, 0.06798462569713593, 0.011217072606086731, 0.004967056680470705, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9998846054077148, 9.298400982515886e-05, 7.557733283647394e-08, 4.2952964861113496e-13, 4.9295836510032665e-12, 3.2098330660090824e-09, 5.042555585532682e-06, 1.7450745872338302e-05, 2.33268380611662e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.118646625604015e-05, 0.9999122619628906, 6.629392737522721e-05, 1.312590147684034e-09, 2.7011800782239526e-11, 6.488713510726871e-14, 1.250517189799183e-10, 3.650779589747799e-08, 2.9122876554765753e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.1949000816580124e-11, 3.2456850362905243e-07, 1.0, 3.0732459777027543e-07, 4.943382370115046e-10, 1.2582140899967535e-17, 7.485076299292317e-18, 2.998638596002183e-14, 1.3861908843004755e-10, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.382360668271247e-10, 8.056646905174603e-09, 0.00035429277340881526, 0.9995232820510864, 0.00012279135989956558, 1.6631793720023325e-09, 1.8857353897253244e-14, 9.284229879032505e-15, 1.8321206097376974e-12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.614902194392648e-12, 3.5818106835540375e-13, 
4.029543365646759e-09, 3.1193526410788763e-06, 0.9959417581558228, 0.004055640660226345, 2.0883923923520342e-08, 1.5150488692381933e-14, 1.8145465705242968e-17, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.3006167283734502e-12, 4.150501252094593e-15, 2.9068709245239077e-12, 2.726213081238188e-13, 1.0724114645199734e-06, 0.9999104142189026, 8.954491204349324e-05, 3.77386955019432e-10, 8.537545242676776e-16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.656632632941808e-10, 2.8593680201360883e-10, 4.910126749635424e-10, 3.37084723469553e-15, 1.3075121541028523e-10, 0.0003027402563020587, 0.999218225479126, 0.00047932929010130465, 1.4258912273135138e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0133464911632473e-07, 1.7307414168499236e-07, 2.3342326471720298e-07, 4.688030020606748e-13, 1.5028331227032177e-12, 5.3876938466146385e-09, 0.00158107269089669, 0.994592010974884, 0.0038271904923021793, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.33300490037891e-10, 1.2628836998374027e-07, 1.2948551102454076e-06, 3.169647599943204e-10, 1.5141217069741288e-14, 8.21656009561151e-15, 2.347289251858342e-09, 0.0025180077645927668, 0.9974797964096069, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9842625260353088, 0.015737490728497505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9422653913497925, 0.057734500616788864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8382691144943237, 0.11647694557905197, 0.04525385797023773, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.37070432305336, 0.2449311465024948, 0.3843645751476288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4638526439666748, 0.1585947573184967, 0.3189436197280884, 0.0586090050637722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5423898100852966, 0.11884469538927078, 0.1850128471851349, 0.15375272929668427, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2375488132238388, 0.07284080982208252, 
0.20766110718250275, 0.3110494017601013, 0.1708998829126358, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7452426552772522, 0.024770371615886688, 0.025099167600274086, 0.014617366716265678, 0.19027042388916016, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20615516602993011, 0.03705071657896042, 0.05929475650191307, 0.08692343533039093, 0.5564662218093872, 0.05410974845290184, 0.0, 0.0, 0.0, 0.0, 0.4940005838871002, 0.026306116953492165, 0.014163044281303883, 0.022562485188245773, 0.43185216188430786, 0.011115492321550846, 0.0, 0.0, 0.0, 0.0], [0.31913095712661743, 0.011343744583427906, 0.01675090566277504, 0.013238506391644478, 0.06746862828731537, 0.3789318799972534, 0.19313538074493408, 0.0, 0.0, 0.0, 0.8323472142219543, 0.005361876450479031, 0.001218354911543429, 0.0017811520956456661, 0.06672050058841705, 0.0179598405957222, 0.07461105287075043, 0.0, 0.0, 0.0], [0.4113273322582245, 0.003934106323868036, 0.003564919577911496, 0.005882325116544962, 0.018547017127275467, 0.18534934520721436, 0.3216978907585144, 0.04969710111618042, 0.0, 0.0, 0.5900163650512695, 0.0016051119891926646, 0.00041884748497977853, 0.002425695303827524, 0.09076588600873947, 0.005809221416711807, 0.03928956016898155, 0.2696692943572998, 0.0, 0.0], [0.07648876309394836, 0.0013769177021458745, 0.001890459912829101, 0.006597061175853014, 0.007926206104457378, 0.013261871412396431, 0.15683594346046448, 0.7190074324607849, 0.016615279018878937, 0.0, 0.14191001653671265, 0.0026981914415955544, 0.000433926354162395, 0.0025318085681647062, 0.0752185806632042, 0.041030533611774445, 0.10226735472679138, 0.6134982705116272, 0.020411266013979912, 0.0], [0.08104224503040314, 0.00045554721145890653, 0.00038501128437928855, 0.0009405335295014083, 0.005597654264420271, 0.0034990713465958834, 0.009850292466580868, 0.0463707260787487, 0.7366765141487122, 0.11518235504627228, 0.9951959252357483, 0.000172812317032367, 0.0011272057890892029, 0.0002565488684922457, 0.001650187186896801, 0.0010172545444220304, 3.585639569791965e-05, 
0.00030177918961271644, 2.7251116989646107e-05, 0.00021514984837267548]], [[0.011770328506827354, 0.014021093025803566, 0.10656744986772537, 0.04667313024401665, 0.13704808056354523, 0.04681243374943733, 0.08347266167402267, 0.3310377299785614, 0.22259721159934998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009583584032952785, 0.010384900495409966, 0.09424954652786255, 0.09874095767736435, 0.2214881330728531, 0.08727390319108963, 0.09998933970928192, 0.16299772262573242, 0.21529172360897064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.040493443608284, 0.05296378955245018, 0.12471148371696472, 0.04822944849729538, 0.2201310694217682, 0.13458549976348877, 0.16853223741054535, 0.12866733968257904, 0.08168572932481766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014574799686670303, 0.015747353434562683, 0.011357909068465233, 0.008449763990938663, 0.024292636662721634, 0.06141809746623039, 0.10683716088533401, 0.6414783596992493, 0.1158437430858612, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0041047134436666965, 0.010159346275031567, 0.006441198755055666, 0.009530052542686462, 0.061682768166065216, 0.07391326874494553, 0.3019707202911377, 0.45178085565567017, 0.08041701465845108, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013634801842272282, 0.03774101287126541, 0.015713637694716454, 0.01436087116599083, 0.06650711596012115, 0.06899012625217438, 0.1819150745868683, 0.376579225063324, 0.2245580554008484, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03166442736983299, 0.07015468180179596, 0.1104653850197792, 0.016236137598752975, 0.18190902471542358, 0.08141329884529114, 0.15690769255161285, 0.22899281978607178, 0.12225660681724548, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10994787514209747, 0.08447018265724182, 0.05270976573228836, 0.013435273431241512, 0.06919412314891815, 0.04981343820691109, 0.24833135306835175, 0.2721446752548218, 0.09995320439338684, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39435869455337524, 0.21061576902866364, 0.1085209921002388, 0.004411425907164812, 0.06908565759658813, 0.04562678933143616, 0.02559957653284073, 0.06842028349637985, 0.0733608528971672, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2682938873767853, 0.18270419538021088, 0.12741044163703918, 0.03156330808997154, 0.10574271529912949, 0.0955348014831543, 0.052997197955846786, 0.0821281224489212, 0.05362524837255478, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9800853133201599, 0.019914645701646805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9959792494773865, 0.004020644351840019, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9159882068634033, 0.02969631738960743, 0.05431551858782768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8763805031776428, 0.06819441169500351, 0.05542506277561188, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6467475295066833, 0.08892705291509628, 0.19796258211135864, 0.06636285036802292, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6675543785095215, 0.035431310534477234, 0.2554236948490143, 0.04159051924943924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9833061099052429, 0.004010406322777271, 0.004914217162877321, 0.0015858567785471678, 0.006183335091918707, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8250302076339722, 0.013232334516942501, 0.10887149721384048, 0.016031241044402122, 0.03683457896113396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9524497389793396, 0.0022862900514155626, 0.000848656112793833, 0.00408557103946805, 0.028177350759506226, 0.012152665294706821, 0.0, 0.0, 0.0, 0.0, 0.14042839407920837, 0.005938003305345774, 0.04128086566925049, 0.01834655925631523, 0.7866368293762207, 0.007369248662143946, 0.0, 0.0, 0.0, 0.0], [0.1907505989074707, 0.026542214676737785, 0.01945381611585617, 0.029287727549672127, 0.057166602462530136, 0.11766232550144196, 0.5591367483139038, 0.0, 0.0, 0.0, 0.3567042350769043, 
0.0165000781416893, 0.015264611691236496, 0.010309864766895771, 0.38396307826042175, 0.025359012186527252, 0.1918991357088089, 0.0, 0.0, 0.0], [0.4022328555583954, 0.017193131148815155, 0.01565318927168846, 0.01915702596306801, 0.01739031821489334, 0.16459040343761444, 0.18205313384532928, 0.18172988295555115, 0.0, 0.0, 0.03735272213816643, 0.0005555232055485249, 0.0009066119673661888, 0.003488750196993351, 0.4253699481487274, 0.039391178637742996, 0.3313658535480499, 0.1615692675113678, 0.0, 0.0], [0.9652498960494995, 0.0010482663055881858, 0.0012260396033525467, 0.0009098293376155198, 0.0013901795027777553, 0.0028189055155962706, 0.007343438919633627, 0.018731823191046715, 0.0012814495712518692, 0.0, 0.0020103107672184706, 0.0002689870889298618, 0.0004340466111898422, 0.0009705349220894277, 0.03535917028784752, 0.014057940803468227, 0.07802704721689224, 0.8683921694755554, 0.0004796571738552302, 0.0], [0.18471455574035645, 0.018054824322462082, 0.08812589198350906, 0.00762907462194562, 0.018057269975543022, 0.05247756093740463, 0.03497685119509697, 0.5025416612625122, 0.052323222160339355, 0.04109897091984749, 0.21001528203487396, 0.008917403407394886, 0.08127831667661667, 0.6020672917366028, 0.0504239983856678, 0.01106872595846653, 0.002271559089422226, 0.009885885752737522, 0.013363776728510857, 0.010707534849643707]], [[8.027511648833752e-05, 0.0010475717717781663, 0.9977908730506897, 0.0002747455728240311, 0.000536168459802866, 9.231048170477152e-05, 0.00010586588905425742, 1.1979215742030647e-05, 5.969347330392338e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00012679747305810452, 5.715776205761358e-05, 0.922791600227356, 0.07177212089300156, 0.002934361109510064, 0.0005548547487705946, 0.001313770073466003, 2.2278460164670832e-05, 0.0004267726035322994, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0063565499149262905, 0.0009426671313121915, 0.23976103961467743, 0.6402719020843506, 0.019077658653259277, 0.04590805247426033, 
0.0423574335873127, 0.00055616011377424, 0.0047685266472399235, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00012164804502390325, 1.1780298336816486e-05, 0.0001827587402658537, 0.00020120454428251833, 0.9978508353233337, 0.0014421044616028666, 6.411068170564249e-05, 4.628768147085793e-05, 7.896547322161496e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03763079643249512, 0.00208932813256979, 0.0006042887107469141, 0.5138440728187561, 0.19755180180072784, 0.029773280024528503, 0.15554653108119965, 0.015671545639634132, 0.0472884401679039, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.8805592339485884e-05, 1.2464041901694145e-05, 9.030352521222085e-05, 1.7544094589538872e-05, 0.0006991567788645625, 0.039246365427970886, 0.9305517077445984, 0.02403487078845501, 0.005308609921485186, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003011370776221156, 0.005974559113383293, 0.003425326431170106, 0.001937237335368991, 0.01794668287038803, 0.06517820060253143, 0.25853174924850464, 0.28359606862068176, 0.3603990077972412, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0019687232561409473, 0.0019828693475574255, 0.0009621239732950926, 0.0017320939805358648, 0.008526722900569439, 0.012685983441770077, 0.060781437903642654, 0.38653799891471863, 0.524821937084198, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06319467723369598, 0.3812802731990814, 0.07775641977787018, 0.0546053946018219, 0.0410320870578289, 0.010218034498393536, 0.022281788289546967, 0.04868403077125549, 0.30094724893569946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06465335935354233, 0.0841824859380722, 0.028003698214888573, 0.01470992248505354, 0.013160775415599346, 0.006258893292397261, 0.003528257366269827, 0.022525515407323837, 0.7629771828651428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.9911633133888245, 0.008836665190756321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8274853825569153, 0.1725146621465683, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9641951322555542, 0.023474374786019325, 0.012330451980233192, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.39722761511802673, 0.5465205311775208, 0.05625181272625923, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6152319312095642, 0.28041696548461914, 0.04906271770596504, 0.05528838559985161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7089572548866272, 0.12511004507541656, 0.08669630438089371, 0.0792364850640297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6057276725769043, 0.1235719844698906, 0.06170117110013962, 0.11151555925607681, 0.0974835753440857, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9339975714683533, 0.013466393575072289, 0.00928713008761406, 0.00507207540795207, 0.03817704692482948, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6386814713478088, 0.07927443087100983, 0.06004401296377182, 0.06398510187864304, 0.06341437995433807, 0.09460049122571945, 0.0, 0.0, 0.0, 0.0, 0.7470325231552124, 0.0030789184384047985, 0.0006101431790739298, 0.009402818977832794, 0.23476918041706085, 0.005106179974973202, 0.0, 0.0, 0.0, 0.0], [0.13321073353290558, 0.0565485954284668, 0.20425985753536224, 0.10307760536670685, 0.17957380414009094, 0.26328328251838684, 0.06004612147808075, 0.0, 0.0, 0.0, 0.21711143851280212, 0.003716376842930913, 0.00037448908551596105, 0.0019620254170149565, 0.018900232389569283, 0.009617134928703308, 0.7483181953430176, 0.0, 0.0, 0.0], [0.19694660604000092, 0.027736904099583626, 0.05790374055504799, 0.10621010512113571, 0.15510229766368866, 0.2214440256357193, 0.18680275976657867, 0.04785352945327759, 0.0, 0.0, 0.010075456462800503, 5.468959716381505e-05, 5.17756825502147e-06, 5.762913860962726e-05, 0.0005752856959588826, 0.0004235330270603299, 0.004707484506070614, 0.9841007590293884, 0.0, 0.0], [0.08537944406270981, 0.033881768584251404, 0.03968465328216553, 0.08240006119012833, 0.15350975096225739, 0.23219235241413116, 0.22240297496318817, 
0.11620921641588211, 0.034339725971221924, 0.0, 0.0014721885090693831, 9.766960283741355e-05, 9.390318155055866e-06, 9.01468301890418e-05, 0.00026504675042815506, 0.0001477079640608281, 0.0007441531051881611, 0.9970147013664246, 0.00015886487381067127, 0.0], [0.06051333248615265, 0.012086840346455574, 0.028373999521136284, 0.07542525231838226, 0.10199770331382751, 0.15039192140102386, 0.20426926016807556, 0.16016273200511932, 0.06537677347660065, 0.14140206575393677, 0.9506397247314453, 0.010028047487139702, 0.0004243685398250818, 0.012790095992386341, 0.006212451495230198, 0.0008045415161177516, 0.0008908100426197052, 0.0004145564162172377, 0.0002187698701163754, 0.01757662557065487]], [[0.00496841873973608, 0.010829150676727295, 0.03283568099141121, 0.009884797036647797, 0.047239795327186584, 0.06476759165525436, 0.11417313665151596, 0.6207002401351929, 0.09460126608610153, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014457895420491695, 0.06253711134195328, 0.10527490824460983, 0.051058270037174225, 0.04873393103480339, 0.058862265199422836, 0.13390113413333893, 0.44425415992736816, 0.0809202790260315, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09337731450796127, 0.22848238050937653, 0.11594945937395096, 0.04185759648680687, 0.012283656746149063, 0.1264774352312088, 0.19395124912261963, 0.16978387534618378, 0.017837027087807655, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7125841975212097, 0.21987739205360413, 0.020619483664631844, 0.02881826087832451, 0.009833384305238724, 0.004124533850699663, 0.0008098671096377075, 0.0004809961246792227, 0.0028517041355371475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029080189764499664, 0.33611080050468445, 0.12628716230392456, 0.0817737877368927, 0.1908877044916153, 0.0943109318614006, 0.05712011829018593, 0.06781000643968582, 0.016619542613625526, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07309448719024658, 0.07739713788032532, 
0.0567743182182312, 0.03291132301092148, 0.16455504298210144, 0.1779973953962326, 0.2714528441429138, 0.13868720829486847, 0.007130389101803303, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2111189365386963, 0.06559138745069504, 0.041267942637205124, 0.009358389303088188, 0.20342323184013367, 0.1869427114725113, 0.19775718450546265, 0.07797932624816895, 0.006560905836522579, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08770362287759781, 0.12808790802955627, 0.023038268089294434, 0.17453545331954956, 0.09798892587423325, 0.11677049100399017, 0.09396524727344513, 0.26174578070640564, 0.01616443321108818, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.35409674048423767, 0.0420590415596962, 0.00930203776806593, 0.3349112272262573, 0.03967892378568649, 0.15319538116455078, 0.022175630554556847, 0.0432865284383297, 0.0012946304632350802, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10030248761177063, 0.08145220577716827, 0.053510215133428574, 0.08076464384794235, 0.07446140050888062, 0.13495147228240967, 0.2503055930137634, 0.17467214167118073, 0.04957977309823036, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5400503277778625, 0.4599496126174927, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9158000946044922, 0.0841999277472496, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04321815073490143, 0.9357689023017883, 0.02101275697350502, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9424960017204285, 0.02535107545554638, 0.032153017818927765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48035699129104614, 0.12913382053375244, 0.27151036262512207, 0.11899882555007935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22060541808605194, 0.18997374176979065, 0.08500542491674423, 0.5044154524803162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6920371055603027, 0.019891848787665367, 0.1885785609483719, 0.06273186951875687, 0.036760613322257996, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.7531844973564148, 0.02070058509707451, 0.008920542895793915, 0.016695866361260414, 0.20049844682216644, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8527964949607849, 0.08059625327587128, 0.0037265238352119923, 0.008582950569689274, 0.042790722101926804, 0.01150701567530632, 0.0, 0.0, 0.0, 0.0, 0.759453296661377, 0.0056156679056584835, 0.008695651777088642, 0.014426307752728462, 0.16163751482963562, 0.05017174035310745, 0.0, 0.0, 0.0, 0.0], [0.900881826877594, 0.012710069306194782, 0.000794807099737227, 0.00424413476139307, 0.02110898308455944, 0.01962616853415966, 0.04063420742750168, 0.0, 0.0, 0.0, 0.2527230679988861, 0.0006535803549923003, 0.00037003192119300365, 0.00041730765951797366, 0.057080648839473724, 0.06757333129644394, 0.6211821436882019, 0.0, 0.0, 0.0], [0.713775098323822, 0.003081131726503372, 0.000918463512789458, 0.009338468313217163, 0.013423318043351173, 0.019161174073815346, 0.10174864530563354, 0.13855360448360443, 0.0, 0.0, 0.6996693015098572, 0.00526623846963048, 0.003115275641903281, 0.001864676014520228, 0.019210346043109894, 0.022201303392648697, 0.16487717628479004, 0.08379579335451126, 0.0, 0.0], [0.4800099730491638, 0.0009553784620948136, 0.00013007478264626116, 0.020002998411655426, 0.0032414987217634916, 0.002101779682561755, 0.028948260471224785, 0.46123453974723816, 0.0033754503820091486, 0.0, 0.01643717661499977, 0.001304203411564231, 0.00015219511988107115, 8.364384120795876e-05, 0.0027460975106805563, 0.005807426758110523, 0.02910688892006874, 0.054244525730609894, 0.8901176452636719, 0.0], [0.7501513361930847, 0.019767694175243378, 0.0020619838032871485, 0.0038300605956465006, 0.0023455689661204815, 0.023803891614079475, 0.011456847190856934, 0.045016106218099594, 0.08813992142677307, 0.05342674255371094, 0.03737838938832283, 0.0008823095704428852, 0.00013810240488965064, 0.0003819032572209835, 0.0009168537217192352, 0.017434338107705116, 0.0524771511554718, 0.5634113550186157, 0.05003770440816879, 0.27694204449653625]], 
[[0.140123188495636, 0.010056160390377045, 0.0845566838979721, 0.03108036518096924, 0.16015855967998505, 0.30321791768074036, 0.04101235046982765, 0.0719088688492775, 0.1578858345746994, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6134085655212402, 0.1547522246837616, 0.03818102553486824, 0.001013039844110608, 0.013297338038682938, 0.008754062466323376, 0.005134810693562031, 0.0324203222990036, 0.13303862512111664, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6891250014305115, 0.17779399454593658, 0.09809523820877075, 0.006996517535299063, 0.007719202898442745, 0.0016296659596264362, 0.010662317276000977, 0.004304768517613411, 0.0036729834973812103, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04376668110489845, 0.09640005975961685, 0.8100467324256897, 0.018579678609967232, 0.017539000138640404, 0.0008903089328669012, 0.0009985471842810512, 0.003613307373598218, 0.008165487088263035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03085213713347912, 0.025543441995978355, 0.6937543153762817, 0.17392684519290924, 0.03124413825571537, 0.02177071012556553, 0.007475809659808874, 0.003389933379366994, 0.012042560614645481, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.020024498924613, 0.002941351616755128, 0.05481509119272232, 0.183584526181221, 0.4182366132736206, 0.25923243165016174, 0.05362166836857796, 0.0045484029687941074, 0.002995501272380352, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006091661751270294, 0.0012010806240141392, 0.008193010464310646, 0.009258490055799484, 0.15450483560562134, 0.7388086915016174, 0.06675267219543457, 0.01373466569930315, 0.0014547830214723945, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0014694302808493376, 0.0017220929730683565, 0.005703628528863192, 0.0032696493435651064, 0.01713697426021099, 0.49356934428215027, 0.3729664385318756, 0.05505490303039551, 0.04910748079419136, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.0052343131974339485, 0.004969605710357428, 0.005609327927231789, 0.0007064095698297024, 0.005421568639576435, 0.045942794531583786, 0.22256441414356232, 0.43683722615242004, 0.27271413803100586, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011939328163862228, 0.019054703414440155, 0.010745645500719547, 0.006908759940415621, 0.009522099047899246, 0.006889646407216787, 0.12289831787347794, 0.2292226105928421, 0.5828191637992859, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03494315221905708, 0.965056836605072, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9822245836257935, 0.017775410786271095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.020348060876131058, 0.8944171071052551, 0.08523476868867874, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9783667922019958, 0.004186260513961315, 0.01744689606130123, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015979396412149072, 0.6347042918205261, 0.09008561074733734, 0.27361196279525757, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8277915120124817, 0.0035995396319776773, 0.1268300712108612, 0.04177885130047798, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01025437843054533, 0.17247439920902252, 0.3664330542087555, 0.4087805449962616, 0.04205762594938278, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9593387246131897, 0.001320014358498156, 0.002763292985036969, 0.002305841539055109, 0.03427214175462723, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012186901643872261, 0.3028968572616577, 0.12117700278759003, 0.3522109389305115, 0.06255244463682175, 0.14897578954696655, 0.0, 0.0, 0.0, 0.0, 0.5380056500434875, 0.00011044789425795898, 0.001150083844549954, 0.002725756261497736, 0.45681822299957275, 0.0011898496886715293, 0.0, 0.0, 0.0, 0.0], [0.010822800919413567, 0.2333739995956421, 0.11113002151250839, 0.15861180424690247, 0.11286703497171402, 0.2766783833503723, 0.0965159684419632, 0.0, 0.0, 0.0, 0.16147758066654205, 0.001678255619481206, 0.004225697834044695, 
0.012547606602311134, 0.4120558202266693, 0.030565770342946053, 0.37744930386543274, 0.0, 0.0, 0.0], [0.00965114776045084, 0.19982098042964935, 0.054301097989082336, 0.13056904077529907, 0.03828747197985649, 0.4827912747859955, 0.05511533096432686, 0.029463520273566246, 0.0, 0.0, 0.07655133306980133, 0.00011485892173368484, 0.0004792730906046927, 0.0037317569367587566, 0.9091346859931946, 0.005207230802625418, 0.003226343309506774, 0.0015543886693194509, 0.0, 0.0], [0.014548483304679394, 0.07520423084497452, 0.1090526208281517, 0.14237697422504425, 0.030428709462285042, 0.5021095275878906, 0.026151562109589577, 0.04390878602862358, 0.05621904134750366, 0.0, 0.0006837816908955574, 6.692374881822616e-05, 3.2170661143027246e-05, 0.017242103815078735, 0.9703013896942139, 0.0009919245494529605, 0.00010187587758991867, 0.00012404048175085336, 0.01045528706163168, 0.0], [0.000422637298470363, 0.17123113572597504, 0.04347287863492966, 0.10408183932304382, 0.013075248338282108, 0.5476951003074646, 0.020964276045560837, 0.019243689253926277, 0.0612923838198185, 0.018520813435316086, 0.8681296706199646, 0.004244405776262283, 0.0034055972937494516, 0.0032342004124075174, 0.11890427023172379, 0.00032322408515028656, 1.7166490579256788e-05, 8.356601756531745e-05, 0.00016651467012707144, 0.0014914675848558545]], [[0.0014003654941916466, 0.00935011450201273, 0.8996742963790894, 0.029868578538298607, 0.05752851441502571, 0.0008847691351547837, 0.0005429417942650616, 0.0004143548430874944, 0.00033632174017839134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0005502321291714907, 0.003854800947010517, 0.8475468754768372, 0.06876953691244125, 0.07909266650676727, 5.498397149494849e-05, 2.1647396351909265e-05, 6.648269391007489e-06, 0.00010276718239765614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0025599629152566195, 0.010113149881362915, 0.21385346353054047, 0.26065483689308167, 0.44287386536598206, 0.0458405464887619, 0.013329384848475456, 
0.0076821851544082165, 0.0030928871128708124, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0002600199659354985, 3.3608048397582024e-05, 0.0020931970793753862, 0.007768034934997559, 0.9780486822128296, 0.011327453888952732, 0.00041993538616225123, 4.125805935473181e-05, 8.07127889856929e-06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0010751935187727213, 0.00017567894246894866, 0.004301255568861961, 0.0010412797564640641, 0.012584774754941463, 0.5903621912002563, 0.36841556429862976, 0.021853862330317497, 0.00019013854034710675, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00036065353197045624, 0.00041391997365280986, 0.00018344201089348644, 1.21664334074012e-05, 0.0008204621262848377, 0.02300320193171501, 0.7380199432373047, 0.23411831259727478, 0.0030676021706312895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0007766868220642209, 0.00179819215554744, 0.0031821478623896837, 1.569229607412126e-05, 0.001023828866891563, 0.004582487046718597, 0.04412461444735527, 0.8326310515403748, 0.11186514794826508, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002560202032327652, 0.0021961459424346685, 0.0012966376962140203, 3.874531466863118e-05, 0.00012789985339622945, 0.00017348439723718911, 0.06046983227133751, 0.07663179188966751, 0.856505274772644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05078713223338127, 0.09524610638618469, 0.03648101165890694, 0.050540339201688766, 0.009611092507839203, 0.0027538249269127846, 0.009690326638519764, 0.015156174078583717, 0.7297340035438538, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017420543357729912, 0.009016300551593304, 0.008660875260829926, 0.04713813588023186, 0.042011067271232605, 0.003162879729643464, 0.00040178498602472246, 0.005153133533895016, 0.8670352697372437, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.9947329163551331, 0.005267037078738213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9673911333084106, 0.032608743757009506, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7284466028213501, 0.21829284727573395, 0.05326057970523834, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8945506811141968, 0.048047225922346115, 0.05740200728178024, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7024527192115784, 0.0454108789563179, 0.10381712764501572, 0.14831924438476562, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8226539492607117, 0.025171183049678802, 0.033602889627218246, 0.1185719221830368, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2374107390642166, 0.04589728266000748, 0.2683154046535492, 0.3902822434902191, 0.0580943301320076, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7488189339637756, 0.022310951724648476, 0.03220387548208237, 0.05049983412027359, 0.14616648852825165, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7228419780731201, 0.007619804237037897, 0.013993922621011734, 0.04429992660880089, 0.020430808886885643, 0.19081364572048187, 0.0, 0.0, 0.0, 0.0, 0.5947939157485962, 0.009725339710712433, 0.01194794476032257, 0.06678443402051926, 0.22137242555618286, 0.09537594765424728, 0.0, 0.0, 0.0, 0.0], [0.4783930778503418, 0.005506142508238554, 0.008406496606767178, 0.012424511834979057, 0.04335693642497063, 0.17542317509651184, 0.27648961544036865, 0.0, 0.0, 0.0, 0.5493549704551697, 0.010730843059718609, 0.013811847195029259, 0.01375968661159277, 0.13386781513690948, 0.031593821942806244, 0.2468811273574829, 0.0, 0.0, 0.0], [0.056768160313367844, 0.001066300319507718, 0.0015203694347292185, 0.004650356248021126, 0.004999558907002211, 0.17368057370185852, 0.7387632131576538, 0.018551528453826904, 0.0, 0.0, 0.44999176263809204, 0.0022518665064126253, 0.007128801662474871, 0.06941325962543488, 0.11436374485492706, 0.06527625769376755, 0.25339174270629883, 0.038182370364665985, 0.0, 0.0], [0.14709600806236267, 0.007261540275067091, 0.001291902968659997, 0.012605146504938602, 0.005232691299170256, 0.08098926395177841, 0.5304067134857178, 
0.207069993019104, 0.00804678164422512, 0.0, 0.6273319125175476, 0.0019851899705827236, 0.014608433470129967, 0.053566914051771164, 0.10037831962108612, 0.05395424738526344, 0.09709113836288452, 0.020020073279738426, 0.031063806265592575, 0.0], [0.15080930292606354, 0.014301316812634468, 0.002821019385010004, 0.02008463814854622, 0.004475536290556192, 0.05297520384192467, 0.27036672830581665, 0.407105028629303, 0.007729486562311649, 0.06933178007602692, 0.13732852041721344, 0.005784862674772739, 0.011142567731440067, 0.3659982979297638, 0.03412118926644325, 0.191008523106575, 0.02493627928197384, 0.01782877929508686, 0.005097466055303812, 0.2067534178495407]], [[0.22553573548793793, 0.2680850327014923, 0.019470686092972755, 0.14175784587860107, 0.053468361496925354, 0.02777918614447117, 0.05628729239106178, 0.04874898120760918, 0.15886712074279785, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.28905513882637024, 0.12247822433710098, 0.046002231538295746, 0.1958596557378769, 0.10771062225103378, 0.06661061197519302, 0.07628067582845688, 0.02713944762945175, 0.06886337697505951, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04905243590474129, 0.05268532782793045, 0.11285670101642609, 0.09091109782457352, 0.24185867607593536, 0.20752739906311035, 0.04222555831074715, 0.05885446071624756, 0.14402832090854645, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06971512734889984, 0.14066818356513977, 0.05942149832844734, 0.21028849482536316, 0.10966084897518158, 0.08002462983131409, 0.10722756385803223, 0.1377343237400055, 0.08525940030813217, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1429702192544937, 0.26978883147239685, 0.12360350787639618, 0.05825580656528473, 0.022957824170589447, 0.2193503975868225, 0.0713224932551384, 0.06461618840694427, 0.02713468112051487, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07554306834936142, 0.051579318940639496, 0.2103901356458664, 0.03246254473924637, 
0.12347473949193954, 0.20594589412212372, 0.10415074229240417, 0.14436782896518707, 0.05208563804626465, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10752540081739426, 0.08459899574518204, 0.07340764254331589, 0.019914846867322922, 0.048802055418491364, 0.2628321945667267, 0.23049965500831604, 0.11754198372364044, 0.05487721040844917, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.054300110787153244, 0.03522595763206482, 0.19028180837631226, 0.11526520550251007, 0.043804410845041275, 0.1941872388124466, 0.12765192985534668, 0.19942660629749298, 0.03985673561692238, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13462598621845245, 0.09648311138153076, 0.08205218613147736, 0.241444393992424, 0.024601474404335022, 0.03336581960320473, 0.09252338856458664, 0.0673752948641777, 0.22752824425697327, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1438782811164856, 0.15257491171360016, 0.11015111207962036, 0.2259429395198822, 0.11582648009061813, 0.06522659957408905, 0.06865230947732925, 0.07465960830450058, 0.04308782145380974, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9945669174194336, 0.005433134268969297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9590145349502563, 0.0409853532910347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9554939270019531, 0.02177131362259388, 0.0227347444742918, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13186156749725342, 0.7104970812797546, 0.15764127671718597, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19059398770332336, 0.7459079623222351, 0.05105874687433243, 0.012439398095011711, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1307007521390915, 0.4791290760040283, 0.2198515087366104, 0.1703186184167862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.062025006860494614, 0.7277394533157349, 0.13110491633415222, 0.028790757060050964, 0.050339892506599426, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25735223293304443, 
0.03605807572603226, 0.08834479749202728, 0.21978884935379028, 0.398455947637558, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7678350806236267, 0.007377212401479483, 0.020054306834936142, 0.11815592646598816, 0.07254840433597565, 0.014029012061655521, 0.0, 0.0, 0.0, 0.0, 0.014754761941730976, 0.016280202195048332, 0.010505245067179203, 0.26496851444244385, 0.6780229210853577, 0.015468388795852661, 0.0, 0.0, 0.0, 0.0], [0.8187481760978699, 0.009394909255206585, 0.015446240082383156, 0.012167787179350853, 0.10175905376672745, 0.02721206098794937, 0.01527167297899723, 0.0, 0.0, 0.0, 0.0561433881521225, 0.00821017101407051, 0.013592599891126156, 0.04250938817858696, 0.20505541563034058, 0.637790322303772, 0.03669866546988487, 0.0, 0.0, 0.0], [0.7012083530426025, 0.12151088565587997, 0.03808446228504181, 0.01883355714380741, 0.0837249755859375, 0.006598148960620165, 0.006499246694147587, 0.023540453985333443, 0.0, 0.0, 0.02288638986647129, 0.0031705975998193026, 0.0010986417764797807, 0.1258203089237213, 0.13997967541217804, 0.6275703310966492, 0.004779829643666744, 0.07469423860311508, 0.0, 0.0], [0.5152325630187988, 0.054241329431533813, 0.17093418538570404, 0.020541386678814888, 0.17657014727592468, 0.012641755864024162, 0.01802964322268963, 0.023539982736110687, 0.008269038051366806, 0.0, 0.04480466619133949, 0.007826470769941807, 0.0012622721260413527, 0.18829701840877533, 0.1579897105693817, 0.4087865948677063, 0.0030938636045902967, 0.17715193331241608, 0.010787548497319221, 0.0], [0.9131196141242981, 0.0010915634920820594, 0.006193474866449833, 0.006082434672862291, 0.03542511910200119, 0.006826554890722036, 0.0028478680178523064, 0.004068343434482813, 0.014553201384842396, 0.009791722521185875, 0.2647387683391571, 0.0023117128293961287, 0.5836825370788574, 0.022214042022824287, 0.05302866920828819, 0.05609899014234543, 0.0002153095556423068, 0.0012429821072146297, 0.012765316292643547, 0.0037017168942838907]]], [[[0.008583037182688713, 0.007665919605642557, 
0.023932937532663345, 0.013663848862051964, 0.00724611384794116, 0.01780843734741211, 0.04220886155962944, 0.035630952566862106, 0.8432599306106567, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005249040201306343, 0.006725347600877285, 0.022601336240768433, 0.004061485640704632, 0.003380684182047844, 0.05792760103940964, 0.08571713417768478, 0.017759306356310844, 0.796578049659729, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014741344377398491, 0.08626628667116165, 0.11416944116353989, 0.06755448132753372, 0.010767532512545586, 0.037519536912441254, 0.13943251967430115, 0.03284287825226784, 0.4967060387134552, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8946033120155334, 0.07520093768835068, 0.007621173746883869, 0.004705401603132486, 0.005715447012335062, 0.0016736779361963272, 0.0011882666731253266, 0.0005322583019733429, 0.008759708143770695, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17331360280513763, 0.32618802785873413, 0.1865183413028717, 0.12219864875078201, 0.08427056670188904, 0.017049826681613922, 0.027256622910499573, 0.011689829640090466, 0.05151442065834999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.024287043139338493, 0.22289688885211945, 0.2742122411727905, 0.1883603185415268, 0.1339159905910492, 0.04209006950259209, 0.04496186599135399, 0.03600992262363434, 0.033265650272369385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01142946071922779, 0.05564042925834656, 0.055694323033094406, 0.5140662789344788, 0.1435396671295166, 0.038738954812288284, 0.06230159476399422, 0.07060025632381439, 0.047988954931497574, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03956271708011627, 0.0978141501545906, 0.053332336246967316, 0.4993227422237396, 0.15091775357723236, 0.05724353715777397, 0.05616844817996025, 0.014285729266703129, 0.03135249391198158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04081583395600319, 0.017569201067090034, 
0.031049959361553192, 0.07860688865184784, 0.1978374421596527, 0.3013133406639099, 0.2561938464641571, 0.010236106812953949, 0.06637723743915558, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005346705671399832, 0.017637349665164948, 0.01670711860060692, 0.027819450944662094, 0.014111858792603016, 0.15744496881961823, 0.29349666833877563, 0.10989060997962952, 0.357545405626297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16448259353637695, 0.17219680547714233, 0.09987642616033554, 0.09012344479560852, 0.06534503400325775, 0.08456553518772125, 0.06690192222595215, 0.08019057661294937, 0.17631761729717255, 0.0, 0.18620921671390533, 0.0449230894446373, 0.15743261575698853, 0.0027164025232195854, 0.000954743183683604, 0.10880818217992783, 0.004260051064193249, 0.4840531051158905, 0.010642877779901028, 0.0], [0.49537378549575806, 0.03979916125535965, 0.09498286247253418, 0.0017974335933104157, 0.028368383646011353, 0.0015277893980965018, 0.014851069077849388, 0.0003722719266079366, 0.3229270279407501, 0.0, 0.10068266838788986, 0.8361198902130127, 0.05278307944536209, 0.003077939385548234, 0.0006954235723242164, 0.001363753923214972, 0.00026539582177065313, 0.004202431067824364, 0.0008096573874354362, 0.0], [0.0031106590759009123, 0.8318147659301758, 0.0329316072165966, 0.00014872441533952951, 0.000739947019610554, 0.0009879706194624305, 0.0012947155628353357, 0.00040531408740207553, 0.128566175699234, 0.0, 0.012129311449825764, 0.01155073568224907, 0.9600933194160461, 8.282387716462836e-05, 1.0725593710958492e-05, 0.0005505315493792295, 8.825069380691275e-05, 0.015057343989610672, 0.00043726651347242296, 0.0], [3.727031798916869e-05, 0.00033458907273598015, 0.9051278829574585, 0.014809494838118553, 0.0013665216974914074, 0.0009820980485528708, 0.0004274636448826641, 0.0006300737150013447, 0.07628484070301056, 0.0, 8.100323611870408e-05, 0.0004598332743626088, 0.004657193087041378, 0.000634010590147227, 0.00027469659107737243, 
0.005632649641484022, 0.000647437758743763, 0.9867796301841736, 0.0008332319557666779, 0.0], [2.789895370369777e-05, 7.413508137688041e-05, 0.00011113573418697342, 0.9593441486358643, 0.023210706189274788, 0.00043970797560177743, 0.00011651179374894127, 0.0001221746060764417, 0.016553271561861038, 0.0, 0.00010327257041353732, 8.895192149793729e-05, 0.0004001102061010897, 3.5898548958357424e-05, 8.903054549591616e-06, 0.002168947132304311, 0.0003314291825518012, 0.9968016743659973, 6.082480831537396e-05, 0.0], [5.518151283467887e-06, 4.040239218738861e-06, 4.706911568064243e-06, 0.0001475349417887628, 0.0011833186727017164, 0.007331210654228926, 0.0003812467912212014, 0.7072276473045349, 0.28371480107307434, 0.0, 0.0006819640402682126, 0.0025551444850862026, 0.029635878279805183, 0.0007182788685895503, 0.0009121407056227326, 0.9391846656799316, 0.0023257755674421787, 0.020892569795250893, 0.0030933902598917484, 0.0], [2.1062598989374237e-06, 1.0153020184588968e-06, 9.153064297606761e-07, 2.3557351596537046e-05, 0.0019158869981765747, 0.9726926684379578, 0.0003360892878845334, 0.008161749690771103, 0.01686590164899826, 0.0, 0.0006610184791497886, 0.004029686562716961, 0.03350083529949188, 0.0028945906087756157, 0.06891647726297379, 0.0361749529838562, 0.6805889010429382, 0.0015104033518582582, 0.17172299325466156, 0.0], [1.876308124337811e-05, 3.1762643629917875e-05, 7.612020908709383e-06, 4.369785983726615e-06, 0.00035698129795491695, 0.006292039528489113, 0.9372867941856384, 0.0028216273058205843, 0.0531802624464035, 0.0, 0.00011510718468343839, 0.00041600633994676173, 0.007651225198060274, 0.0003919293521903455, 0.048794399946928024, 0.12390702962875366, 0.005600529722869396, 0.0008058404200710356, 0.8123176097869873, 0.0], [0.00017082327394746244, 0.0008267413941211998, 0.0010992212919518352, 0.016357675194740295, 0.03317699581384659, 0.013446258381009102, 0.022417983040213585, 0.0993492603302002, 0.813154935836792, 0.0, 0.0003188557457178831, 
0.0017433647299185395, 0.0013032852439209819, 0.008202485740184784, 0.26753997802734375, 0.1699969321489334, 0.02015369012951851, 0.026912324130535126, 0.5038290619850159, 0.0], [2.095436911986326e-06, 1.0510404990782263e-06, 8.745904779061675e-06, 9.465758921578526e-05, 0.9096792936325073, 0.004888555034995079, 0.00019891942793037742, 0.00012723646068479866, 0.08499950170516968, 0.0, 0.020566454157233238, 0.12752646207809448, 0.13235142827033997, 8.515831723343581e-05, 0.0007726486655883491, 0.005525102838873863, 0.002064254367724061, 0.0015006973408162594, 0.7096077799797058, 0.0]], [[0.14326919615268707, 0.06937730312347412, 0.4621289074420929, 0.06899607926607132, 0.20691490173339844, 0.03204977884888649, 0.010433961637318134, 0.001572124194353819, 0.005257652141153812, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7372201681137085, 0.03819188475608826, 0.19263039529323578, 0.00509582320228219, 0.014029700309038162, 0.004338367842137814, 0.0016640998655930161, 0.0023727945517748594, 0.004456941969692707, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6392468810081482, 0.09436309337615967, 0.23124097287654877, 0.009032140485942364, 0.016629014164209366, 0.004053707234561443, 0.0011662752367556095, 0.0013368013314902782, 0.0029307324439287186, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15959776937961578, 0.060010410845279694, 0.6323540210723877, 0.04208587482571602, 0.09941276162862778, 0.001314919558353722, 0.0003186642425134778, 0.00045829309965483844, 0.004447522107511759, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06331828236579895, 0.03697410970926285, 0.6882537603378296, 0.04094800353050232, 0.1500014215707779, 0.014815385453402996, 0.0006663103122264147, 0.0014023728435859084, 0.0036205528303980827, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02740752510726452, 0.007235638331621885, 0.2575177550315857, 0.2825733423233032, 0.26921361684799194, 0.13694509863853455, 0.012512636370956898, 
0.00419765617698431, 0.0023968773894011974, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026527998968958855, 0.0014296816661953926, 0.0034867397043854, 0.11850380897521973, 0.15826237201690674, 0.4342584013938904, 0.21162042021751404, 0.04376554489135742, 0.0021449460182338953, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0008783259545452893, 0.0010965524706989527, 0.006981557235121727, 0.007060014642775059, 0.27200379967689514, 0.45634904503822327, 0.1935150921344757, 0.03130912408232689, 0.030806703492999077, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012816469185054302, 0.004784241784363985, 0.007290879264473915, 0.0027244724333286285, 0.0388973169028759, 0.12052476406097412, 0.3920805752277374, 0.10759556293487549, 0.3132855296134949, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0021361028775572777, 0.003133963793516159, 0.003311034757643938, 0.0013810866512358189, 0.004479007329791784, 0.007041627541184425, 0.09507600963115692, 0.5596640706062317, 0.32377713918685913, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09510962665081024, 0.13984361290931702, 0.01835908181965351, 0.05623754486441612, 0.05484192445874214, 0.02751241996884346, 0.023350151255726814, 0.02046714909374714, 0.5642784833908081, 0.0, 0.08830718696117401, 0.003260435536503792, 0.007942354306578636, 0.007197668310254812, 0.023230358958244324, 0.6884769797325134, 0.13524922728538513, 0.013760159723460674, 0.03257569298148155, 0.0], [0.32246580719947815, 0.12212380021810532, 0.0033711090218275785, 0.41883695125579834, 0.0010050723794847727, 0.00026374190929345787, 0.00840060692280531, 0.0003199145139660686, 0.12321317940950394, 0.0, 0.01410764642059803, 0.011476421728730202, 0.655226469039917, 0.029443562030792236, 0.17404575645923615, 0.04738258570432663, 0.035108331590890884, 0.004049936309456825, 0.02915901131927967, 0.0], [0.1343918889760971, 0.42756012082099915, 0.03016146458685398, 0.27197346091270447, 
0.0008738918695598841, 0.00041738885920494795, 0.0011337834876030684, 0.0017680631717666984, 0.13172008097171783, 0.0, 0.006112441886216402, 0.010383019223809242, 0.9739192724227905, 0.0017695348942652345, 0.0007649966282770038, 0.001380802714265883, 0.0003705607377924025, 0.00034036929719150066, 0.004958811681717634, 0.0], [4.970032023265958e-05, 0.0002945268643088639, 0.9929893612861633, 0.006102537736296654, 1.304412307945313e-06, 7.552243459940655e-06, 2.0433815279830014e-06, 1.4308750905911438e-05, 0.0005390164442360401, 0.0, 0.025388794019818306, 0.006199578754603863, 0.10192698240280151, 0.0023500584065914154, 0.009979050606489182, 0.5388055443763733, 0.29305511713027954, 0.002850176068022847, 0.0194447822868824, 0.0], [0.0006735534407198429, 0.0037932321429252625, 0.014864870347082615, 0.9520841240882874, 0.0031083461362868547, 0.0014454165939241648, 0.000881638377904892, 0.00042032121564261615, 0.02272843010723591, 0.0, 0.0011180925648659468, 3.349311737110838e-05, 0.00020844468963332474, 0.00016400347521994263, 0.001158660277724266, 0.5398337244987488, 0.4514371454715729, 0.00012239665375091136, 0.005924074444919825, 0.0], [1.054488166118972e-06, 5.819076250190847e-06, 3.686256491164386e-07, 5.7184315664926544e-05, 1.600286668690387e-05, 0.0002979082928504795, 5.8259040088159963e-05, 0.997514009475708, 0.0020495890639722347, 0.0, 4.934398384648375e-05, 6.905893883413228e-07, 5.809057256556116e-06, 1.44853029269143e-05, 0.0013859024038538337, 0.62599116563797, 0.3719564974308014, 0.0002632574178278446, 0.00033293903106823564, 0.0], [1.2081607110303594e-06, 1.8248301785206422e-06, 3.5412674037615943e-07, 0.00017610432405490428, 0.0004308871575631201, 0.9919483065605164, 0.001251595327630639, 0.004008213523775339, 0.002181792864575982, 0.0, 1.8935834305011667e-05, 5.593590231001144e-06, 9.02482042874908e-06, 4.666295353672467e-05, 0.00140501803252846, 0.0024830379988998175, 0.9939435124397278, 0.00030495785176754, 0.0017833412857726216, 0.0], 
[1.3394396773946937e-06, 1.858925656961219e-06, 8.99223309147601e-08, 5.498410246218555e-06, 4.1167979361489415e-05, 0.003499603597447276, 0.9961592555046082, 8.322765097545926e-06, 0.0002831367892213166, 0.0, 0.00015082204481586814, 9.979225069400854e-06, 0.00013493606820702553, 0.0006857623811811209, 0.9507938623428345, 0.013522839173674583, 0.004887807182967663, 0.001293701701797545, 0.028520429506897926, 0.0], [0.0011697824811562896, 0.00207342766225338, 0.0001985222043003887, 0.24218614399433136, 0.2580603361129761, 0.03422079235315323, 0.3017951250076294, 0.0700761154294014, 0.09021952003240585, 0.0, 0.00021830093464814126, 1.1190621080459096e-05, 0.0010014179861173034, 0.0016852812841534615, 0.9693949818611145, 0.003066261066123843, 0.002616706071421504, 0.006246546749025583, 0.015759343281388283, 0.0], [4.897859540164973e-08, 1.9182496657776937e-07, 1.6890984966266842e-07, 0.00012898082786705345, 0.9986647963523865, 0.0003688811557367444, 8.465539576718584e-05, 1.2611121746886056e-05, 0.0007397857843898237, 0.0, 0.033513687551021576, 0.047761499881744385, 0.1371326446533203, 0.027179328724741936, 0.07905351370573044, 0.04665757715702057, 0.017991477623581886, 0.0258343443274498, 0.5848759412765503, 0.0]], [[0.001748488168232143, 0.011698327027261257, 0.047558922320604324, 0.7770814299583435, 0.15215088427066803, 0.0056790816597640514, 0.0010312696686014533, 0.0011229184456169605, 0.0019287114264443517, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.000820137036498636, 0.0007328591891564429, 0.012266330420970917, 0.94822758436203, 0.02221596986055374, 0.006038068328052759, 0.0018012026557698846, 0.002194090047851205, 0.0057037402875721455, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0017187671037390828, 0.0012595502194017172, 0.00971528235822916, 0.8996129631996155, 0.03184645250439644, 0.026646586135029793, 0.01671759784221649, 0.005960865877568722, 0.006522092968225479, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.010048117488622665, 0.003920346032828093, 0.01464000903069973, 0.028398782014846802, 0.047600653022527695, 0.6803404688835144, 0.07394693046808243, 0.046145662665367126, 0.09495888650417328, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0020061242394149303, 0.0010488562984392047, 0.0021137045696377754, 0.03403143212199211, 0.040159616619348526, 0.4656003415584564, 0.16990402340888977, 0.16164875030517578, 0.12348736822605133, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0023888982832431793, 0.0010238748509436846, 0.0031129145063459873, 0.00400560162961483, 0.005227341782301664, 0.050918273627758026, 0.28773385286331177, 0.5181463956832886, 0.12744267284870148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0057381619699299335, 0.0037375285755842924, 0.006655727047473192, 0.0010085925459861755, 0.005980721674859524, 0.02943945676088333, 0.05893365666270256, 0.6100658774375916, 0.2784405052661896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003593636676669121, 0.0024473541416227818, 0.002264569513499737, 0.00914584007114172, 0.0013253247598186135, 0.010908454656600952, 0.07958614826202393, 0.12585432827472687, 0.7648744583129883, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031058229506015778, 0.02174283377826214, 0.012145284563302994, 0.010826506651937962, 0.01352943666279316, 0.021966811269521713, 0.055832888931035995, 0.11603516340255737, 0.7168627977371216, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20383700728416443, 0.06762446463108063, 0.042199794203042984, 0.021983252838253975, 0.11625738441944122, 0.013579235412180424, 0.025292381644248962, 0.08914806693792343, 0.4200783669948578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008738831616938114, 0.010689073242247105, 0.010104849003255367, 0.025418052449822426, 0.008787600323557854, 0.018541773781180382, 0.01414045225828886, 0.009587875567376614, 0.8939914107322693, 0.0, 0.3675236701965332, 0.22013956308364868, 
0.3048599064350128, 0.045011524111032486, 0.013697491027414799, 0.012050136923789978, 0.009531261399388313, 0.0020223394967615604, 0.025163909420371056, 0.0], [0.050771377980709076, 0.08173098415136337, 0.03076810948550701, 0.6816214919090271, 0.04326915368437767, 0.0030209666583687067, 0.006032166071236134, 0.007633579429239035, 0.09515213221311569, 0.0, 0.013416368514299393, 0.7244334816932678, 0.22923606634140015, 0.004823721945285797, 0.0007022434147074819, 0.0012150612892583013, 0.001360778696835041, 0.00021415007358882576, 0.024598030373454094, 0.0], [0.04749365150928497, 0.07148067653179169, 0.018722670152783394, 0.5845115184783936, 0.03816590458154678, 0.003933309111744165, 0.006466464139521122, 0.021205652505159378, 0.20802012085914612, 0.0, 0.03640636429190636, 0.024720389395952225, 0.8944843411445618, 0.0018058173591271043, 0.00014742508938070387, 0.002046161564067006, 0.0012721297098323703, 0.0010774562833830714, 0.0380399152636528, 0.0], [0.021572547033429146, 0.11727327853441238, 0.03622674569487572, 0.4274545907974243, 0.05620160698890686, 0.01161592174321413, 0.010393376462161541, 0.014363090507686138, 0.30489882826805115, 0.0, 0.032080236822366714, 0.02157183177769184, 0.017530914396047592, 0.21374234557151794, 0.5176447033882141, 0.021586988121271133, 0.06124785542488098, 0.004810539539903402, 0.10978466272354126, 0.0], [0.015270093455910683, 0.10013995319604874, 0.006727923639118671, 0.19538360834121704, 0.1119888573884964, 0.027630485594272614, 0.0700199231505394, 0.01868581771850586, 0.4541531801223755, 0.0, 0.16469916701316833, 0.0144515885040164, 0.007452514488250017, 0.029052020981907845, 0.2643658220767975, 0.1970161497592926, 0.2818319797515869, 0.016781603917479515, 0.024349281564354897, 0.0], [0.00540963327512145, 0.07916348427534103, 0.01957465149462223, 0.49324244260787964, 0.10871188342571259, 0.02422497235238552, 0.008650544099509716, 0.16292543709278107, 0.0980970561504364, 0.0, 0.025996195152401924, 0.005627068690955639, 
0.007119623012840748, 0.004898787476122379, 0.5349600911140442, 0.05678911507129669, 0.3094601333141327, 0.008422048762440681, 0.04672713205218315, 0.0], [0.027941647917032242, 0.005471521522849798, 0.006384703796356916, 0.03924928605556488, 0.22657036781311035, 0.21837352216243744, 0.3372570872306824, 0.05897291377186775, 0.07977905124425888, 0.0, 0.004280757624655962, 0.0006373892538249493, 9.946383943315595e-05, 0.00030879577388986945, 0.02805289998650551, 0.008433223702013493, 0.9252934455871582, 0.001439885818399489, 0.03145414590835571, 0.0], [0.009049936197698116, 0.005020579323172569, 0.014692768454551697, 0.15799382328987122, 0.4401932656764984, 0.1766415536403656, 0.03136269003152847, 0.12063619494438171, 0.044409021735191345, 0.0, 0.04426492750644684, 0.0032368048559874296, 0.0014763016952201724, 0.0021763627883046865, 0.5636131763458252, 0.010265699587762356, 0.08146306872367859, 0.003517861943691969, 0.289985716342926, 0.0], [0.0007816475699655712, 0.0003147682291455567, 0.0032215022947639227, 0.4467180669307709, 0.3918246924877167, 0.00227341428399086, 0.004370422102510929, 0.14414219558238983, 0.006353371310979128, 0.0, 0.012160537764430046, 0.00020874926121905446, 0.0005602578166872263, 0.0007960868533700705, 0.9389106035232544, 0.005963308271020651, 0.005384649150073528, 0.0009963578777387738, 0.035019390285015106, 0.0], [0.0005489268223755062, 0.016601460054516792, 0.01341363787651062, 0.2753817141056061, 0.13981539011001587, 0.04711242765188217, 0.08167178928852081, 0.11951272189617157, 0.30594193935394287, 0.0, 0.006462599150836468, 0.006167746149003506, 0.00141435069963336, 0.00035615835804492235, 0.0002947094908449799, 0.002378113567829132, 0.011835698038339615, 0.0024426754098385572, 0.968647837638855, 0.0]], [[0.022736268118023872, 0.02286626398563385, 0.14116300642490387, 0.13108347356319427, 0.23994718492031097, 0.1924150437116623, 0.01816762052476406, 0.04976898059248924, 0.18185211718082428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0], [0.05882957577705383, 0.028569074347615242, 0.23305171728134155, 0.053790394216775894, 0.18451730906963348, 0.2002667486667633, 0.015585620887577534, 0.052768219262361526, 0.17262138426303864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09136874228715897, 0.08459936082363129, 0.05023255571722984, 0.21660202741622925, 0.1335863471031189, 0.10654665529727936, 0.02717875875532627, 0.06888726353645325, 0.22099831700325012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04131297022104263, 0.05848437175154686, 0.3077566921710968, 0.040097035467624664, 0.16343727707862854, 0.11984208226203918, 0.06441103667020798, 0.0850440189242363, 0.11961443722248077, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06447532773017883, 0.05503746494650841, 0.11529060453176498, 0.13719302415847778, 0.0843825414776802, 0.22279226779937744, 0.11870565265417099, 0.05292103812098503, 0.14920207858085632, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.061820220202207565, 0.03663187846541405, 0.08412205427885056, 0.386857271194458, 0.1083698719739914, 0.1462787538766861, 0.03903358429670334, 0.026668915525078773, 0.11021733283996582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08746915310621262, 0.025642354041337967, 0.16437062621116638, 0.19346435368061066, 0.10867251455783844, 0.12237238138914108, 0.06722743809223175, 0.0922309011220932, 0.13855047523975372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10294228792190552, 0.07313423603773117, 0.18607352674007416, 0.09769721329212189, 0.1089077964425087, 0.26933327317237854, 0.06555335968732834, 0.061070602387189865, 0.03528755530714989, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12094805389642715, 0.14730192720890045, 0.09877816587686539, 0.21085986495018005, 0.06241541728377342, 0.22994481027126312, 0.04595630243420601, 0.04531335458159447, 0.0384821854531765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11032164841890335, 
0.07897982746362686, 0.08231978863477707, 0.2677886188030243, 0.1231643408536911, 0.0929633229970932, 0.08270144462585449, 0.06097007542848587, 0.10079105943441391, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11438923329114914, 0.12380287796258926, 0.23573537170886993, 0.19010169804096222, 0.15611350536346436, 0.031749427318573, 0.02482231892645359, 0.05017237365245819, 0.07311322540044785, 0.0, 0.013161101378500462, 0.01350532379001379, 0.39494189620018005, 0.007352527230978012, 0.12711142003536224, 0.14605116844177246, 0.03487401455640793, 0.15623201429843903, 0.10677067190408707, 0.0], [0.002549531403928995, 0.03178577870130539, 0.17347589135169983, 0.2232668697834015, 0.49775105714797974, 0.018238944932818413, 0.005651220679283142, 0.03368452191352844, 0.013595964759588242, 0.0, 0.021876059472560883, 0.4906902313232422, 0.4596463143825531, 0.004091671667993069, 0.004464378114789724, 0.001156727666966617, 0.000353646173607558, 0.000146497564855963, 0.017574656754732132, 0.0], [0.0032994491048157215, 0.026504727080464363, 0.41210347414016724, 0.24245016276836395, 0.18897436559200287, 0.012874660082161427, 0.006452939473092556, 0.10089367628097534, 0.00644671730697155, 0.0, 0.005734701175242662, 0.026843877509236336, 0.9321272969245911, 0.00021884289162699133, 0.00045866103027947247, 0.0010309598874300718, 0.00017261962057091296, 0.003054215107113123, 0.030358724296092987, 0.0], [0.002998506650328636, 0.048583757132291794, 0.28224417567253113, 0.0846971943974495, 0.013445784337818623, 0.02188579924404621, 0.017656570300459862, 0.5155076384544373, 0.012980557046830654, 0.0, 0.0482722632586956, 0.14050070941448212, 0.4546079635620117, 0.0072937230579555035, 0.023873258382081985, 0.09857403486967087, 0.0516686774790287, 0.11766187101602554, 0.05754747614264488, 0.0], [0.004188622813671827, 0.028234833851456642, 0.022820167243480682, 0.058492597192525864, 0.19205521047115326, 0.08343320339918137, 0.07119973003864288, 0.4843534827232361, 
0.0552222914993763, 0.0, 0.0020078516099601984, 0.002228439087048173, 0.111594557762146, 0.0033910104539245367, 0.08423032611608505, 0.17691271007061005, 0.14758752286434174, 0.4346924424171448, 0.037355244159698486, 0.0], [0.0038351663388311863, 0.015353971160948277, 0.01755588687956333, 0.06245748698711395, 0.1218588799238205, 0.07207991182804108, 0.02867230959236622, 0.5455195903778076, 0.13266700506210327, 0.0, 0.0008274781284853816, 0.0016531302826479077, 0.047970183193683624, 0.0006053023971617222, 0.22220103442668915, 0.6234129071235657, 0.05364101752638817, 0.012585645541548729, 0.03710317984223366, 0.0], [0.004144841339439154, 0.0048835063353180885, 0.0035110898315906525, 0.06276324391365051, 0.04069552943110466, 0.3603023290634155, 0.1472603678703308, 0.2116946280002594, 0.16474448144435883, 0.0, 2.7583497285377234e-05, 1.1631378583842888e-05, 4.4259006244828925e-05, 0.0006730516324751079, 0.599366307258606, 0.006597205530852079, 0.3886081576347351, 0.0003169252013321966, 0.004354946780949831, 0.0], [0.024624889716506004, 0.016127971932291985, 0.0073340879753232, 0.023849278688430786, 0.042295511811971664, 0.5078635215759277, 0.2884303331375122, 0.011452756822109222, 0.07802165299654007, 0.0, 2.752073669398669e-06, 2.0648456029448425e-06, 8.536147106497083e-06, 6.34281532256864e-05, 0.9992840886116028, 0.00028667543665505946, 7.951273437356576e-05, 3.5721727726922836e-06, 0.00026920961681753397, 0.0], [0.00880166981369257, 0.002673782641068101, 0.001370548619888723, 0.0061265453696250916, 0.02490534819662571, 0.2073771357536316, 0.3818575143814087, 0.1663341522216797, 0.20055335760116577, 0.0, 3.3996084312093444e-06, 2.1497796751646092e-06, 7.304265182028757e-06, 0.00018760550301522017, 0.99969482421875, 2.4790026145637967e-05, 3.4293629141757265e-05, 6.942725121916737e-06, 3.892222957802005e-05, 0.0], [0.012253189459443092, 0.02221212349832058, 0.002282155444845557, 0.10455729067325592, 0.4111727774143219, 0.08308815956115723, 0.045707643032073975, 
0.03711223974823952, 0.2816142141819, 0.0, 0.0005689842510037124, 0.002939490834251046, 0.019829533994197845, 0.0003717679646797478, 0.01646142266690731, 0.011912180110812187, 0.001234701368957758, 0.0013870754046365619, 0.945294976234436, 0.0]], [[0.008687321096658707, 0.012162125669419765, 0.02774685248732567, 0.0013578477082774043, 0.052177976816892624, 0.027187975123524666, 0.05590689554810524, 0.020962538197636604, 0.7938104867935181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005042325239628553, 0.015503124333918095, 0.010042164474725723, 0.0008876739302650094, 0.011308688670396805, 0.010491759516298771, 0.03130592033267021, 0.04934320226311684, 0.8660751581192017, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013016406446695328, 0.03886239603161812, 0.027493299916386604, 0.029101338237524033, 0.009947741404175758, 0.00769558921456337, 0.035501737147569656, 0.023772817105054855, 0.8146085143089294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.018851714208722115, 0.05105733126401901, 0.8005384206771851, 0.01116525661200285, 0.09583853930234909, 0.0015093896072357893, 0.005055624525994062, 0.0006665397086180747, 0.015317671000957489, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01609102450311184, 0.023716216906905174, 0.5135837197303772, 0.10603100061416626, 0.26668840646743774, 0.019648341462016106, 0.01755940169095993, 0.01368130836635828, 0.023000601679086685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01718730293214321, 0.02692273259162903, 0.05480796471238136, 0.010818017646670341, 0.7150712013244629, 0.0585104264318943, 0.04717297852039337, 0.030360547825694084, 0.039148781448602676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006439396180212498, 0.012697076424956322, 0.014188298024237156, 0.000897688849363476, 0.7481768727302551, 0.15047557651996613, 0.03333613649010658, 0.01207506563514471, 0.021714046597480774, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.009459104388952255, 0.022298788651823997, 0.013802104629576206, 0.011955137364566326, 0.03879927098751068, 0.1585427075624466, 0.07075291126966476, 0.329448938369751, 0.3449409306049347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04810584336519241, 0.017975708469748497, 0.025123968720436096, 0.023182567209005356, 0.020010611042380333, 0.04571577161550522, 0.1801854819059372, 0.06764508783817291, 0.5720548629760742, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026153914630413055, 0.0356404148042202, 0.10573611408472061, 0.06201518699526787, 0.06006328761577606, 0.09286139905452728, 0.2927103638648987, 0.20419549942016602, 0.12062377482652664, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5821239352226257, 0.14550858736038208, 0.031251534819602966, 0.030760297551751137, 0.02147754468023777, 0.013665237464010715, 0.009087015874683857, 0.01557532325387001, 0.15055041015148163, 0.0, 0.00632825493812561, 0.011520092375576496, 0.08263711631298065, 0.006356080062687397, 0.022936103865504265, 0.03108564019203186, 0.013897407799959183, 0.697504997253418, 0.12773430347442627, 0.0], [0.12817564606666565, 0.33913177251815796, 0.07241326570510864, 0.41213902831077576, 0.0326012559235096, 0.0031606394331902266, 0.0006341012776829302, 0.007317711599171162, 0.0044263736344873905, 0.0, 0.008715116418898106, 0.015272715128958225, 0.10463730990886688, 0.08011683076620102, 0.13045108318328857, 0.05373600497841835, 0.015578814782202244, 0.4212273955345154, 0.1702648103237152, 0.0], [0.08047150820493698, 0.06199575960636139, 0.5555182099342346, 0.2858560383319855, 0.008700164034962654, 0.003758196486160159, 0.001155794132500887, 0.0007424709619954228, 0.0018020549323409796, 0.0, 0.004959889687597752, 0.007777809165418148, 0.14492008090019226, 0.02459821291267872, 0.014704479835927486, 0.016136664897203445, 0.008129375986754894, 0.7319321036338806, 0.0468413271009922, 0.0], [0.010044030845165253, 0.018482256680727005, 0.6269924640655518, 
0.32439544796943665, 0.01023165788501501, 0.007641270756721497, 0.0008933563949540257, 0.0010311403311789036, 0.00028844154439866543, 0.0, 0.005315575283020735, 0.0021190166007727385, 0.007080279756337404, 0.006970370654016733, 0.010002117604017258, 0.007610250264406204, 0.004703941754996777, 0.8570073246955872, 0.09919113665819168, 0.0], [0.0007911038701422513, 0.0008549468475393951, 0.015090622939169407, 0.8270009160041809, 0.11969847232103348, 0.032614268362522125, 0.0024233118165284395, 0.0011481117689982057, 0.0003779604157898575, 0.0, 0.0016317280242219567, 0.0005414763581939042, 0.004523266106843948, 0.0019645043648779392, 0.010821727104485035, 0.008883371017873287, 0.00927714817225933, 0.920802652835846, 0.041554201394319534, 0.0], [0.017773190513253212, 0.008623103611171246, 0.0020072387997061014, 0.08177924901247025, 0.13816505670547485, 0.6801413297653198, 0.02186667174100876, 0.024107687175273895, 0.025536518543958664, 0.0, 0.002020488725975156, 0.0007793906843289733, 0.022791940718889236, 0.005821499973535538, 0.1932065784931183, 0.30031588673591614, 0.08197023719549179, 0.12508654594421387, 0.2680076062679291, 0.0], [0.000318053673254326, 5.6540200603194535e-05, 1.071194674295839e-05, 0.0009494975674897432, 0.0034297029487788677, 0.032661326229572296, 0.9588278532028198, 0.003185966284945607, 0.0005602877936325967, 0.0, 0.007396090775728226, 0.0032474161125719547, 0.00692824088037014, 0.007240207865834236, 0.42384257912635803, 0.04473983123898506, 0.013007782399654388, 0.007779541425406933, 0.4858182966709137, 0.0], [0.0017862697131931782, 0.0002347631088923663, 2.1297884813975543e-05, 0.0004797980946023017, 0.0018031852087005973, 0.024247879162430763, 0.45456385612487793, 0.5099425911903381, 0.006920217536389828, 0.0, 0.0026900237426161766, 0.0007204422145150602, 0.005861051380634308, 0.003422616282477975, 0.46744993329048157, 0.10402297228574753, 0.05837857723236084, 0.0177029799669981, 0.3397515118122101, 0.0], [0.0006541880429722369, 
0.0009561541373841465, 7.73017163737677e-05, 0.00942671112716198, 0.04198922589421272, 0.04971348121762276, 0.32961171865463257, 0.4513629972934723, 0.11620841920375824, 0.0, 0.005906206555664539, 0.002057044068351388, 0.0031123505905270576, 0.008901549503207207, 0.43650564551353455, 0.08504725992679596, 0.0923796221613884, 0.009556618519127369, 0.3565336763858795, 0.0], [0.017209511250257492, 0.004475452937185764, 3.128392927465029e-05, 0.00047953161993063986, 0.00448839133605361, 0.03360708802938461, 0.11509764194488525, 0.5398797988891602, 0.2847314178943634, 0.0, 0.013360978104174137, 0.04520300775766373, 0.09048072248697281, 0.012179902754724026, 0.030064363032579422, 0.023480970412492752, 0.008669134229421616, 0.03746046498417854, 0.7391002178192139, 0.0]], [[0.02415475994348526, 0.0027711745351552963, 0.003856832394376397, 0.0957413911819458, 0.02159286104142666, 0.03336814045906067, 0.009564127773046494, 0.03954486921429634, 0.7694058418273926, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9052021503448486, 0.02053658291697502, 0.0014916026266291738, 0.00022646080469712615, 4.7710393118904904e-05, 0.000383042759494856, 0.014123834669589996, 0.0205638837069273, 0.03742456063628197, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37607336044311523, 0.6030705571174622, 0.0068079219199717045, 0.0036466827150434256, 9.876023250399157e-05, 2.0246809071977623e-05, 0.0007042856304906309, 0.002560489112511277, 0.007017510011792183, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.0091031880583614e-05, 0.00024915943504311144, 0.9895205497741699, 0.006273698527365923, 0.0016484790248796344, 4.1711446101544425e-05, 7.522702958340233e-07, 1.2660359971050639e-05, 0.002202932955697179, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.009441080503166e-05, 9.311464236816391e-05, 0.006593613885343075, 0.9913647770881653, 0.0018261962104588747, 1.6436462829005904e-05, 8.038865075832291e-07, 1.0318336762793479e-06, 
2.3524326024926268e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.1561212381348014e-05, 1.8178753862230224e-06, 0.00011904581333510578, 0.027105441316962242, 0.8800897598266602, 0.09253741800785065, 0.00010895416926359758, 5.953493655397324e-06, 1.9602707368449046e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.7160528553716858e-09, 1.4191656530493368e-11, 3.274841375855431e-08, 2.1219284462858923e-07, 1.9925082597183064e-05, 0.9999751448631287, 3.130498271275428e-06, 1.9788064946624218e-06, 3.1215499074477293e-09, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.2861962204624433e-05, 5.737682045037218e-07, 2.0471109110076213e-06, 1.0477544492459856e-05, 6.581651632586727e-06, 0.02534269355237484, 0.16125597059726715, 0.5878354907035828, 0.22553342580795288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0009172551217488945, 7.270056084962562e-05, 2.2026280930731446e-05, 4.6261970965133514e-06, 4.921669642499182e-06, 4.060195351485163e-05, 0.027831047773361206, 0.33271971344947815, 0.6383873224258423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.3075091374048498e-05, 6.147480598883703e-05, 4.768987855641171e-05, 2.045959490715177e-06, 1.1152823553572944e-08, 3.07468525306831e-07, 0.0007055726600810885, 0.02803119830787182, 0.9711382985115051, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20143046975135803, 0.41116827726364136, 0.09215858578681946, 0.10672477632761002, 0.06125285103917122, 0.017610367387533188, 0.01457523088902235, 0.02514597773551941, 0.06993352621793747, 0.0, 0.023652182891964912, 0.008639940991997719, 0.08203616738319397, 0.035750582814216614, 0.050224509090185165, 0.3533262312412262, 0.03081362321972847, 0.28302860260009766, 0.1325281411409378, 0.0], [0.026864346116781235, 0.037146128714084625, 0.08411292731761932, 0.02904331497848034, 0.0955604761838913, 0.05886658653616905, 0.08584483712911606, 0.4076027572154999, 0.17495866119861603, 0.0, 0.016670020297169685, 
0.1283574253320694, 0.836423397064209, 0.0042742472141981125, 0.0022883012425154448, 0.00297459471039474, 0.00022807312780059874, 0.0012588471872732043, 0.007524838205426931, 0.0], [0.073190838098526, 0.07998740673065186, 0.05594569817185402, 0.03243006020784378, 0.10037493705749512, 0.13878461718559265, 0.15250830352306366, 0.25721096992492676, 0.10956726223230362, 0.0, 0.031559381633996964, 0.02045642025768757, 0.8176267743110657, 0.006169404834508896, 0.0014412011951208115, 0.0069603933952748775, 0.0010916722239926457, 0.011522608809173107, 0.10317197442054749, 0.0], [0.0438627265393734, 0.04628896340727806, 0.4038660526275635, 0.005475929472595453, 0.03436022624373436, 0.11165640503168106, 0.02260321006178856, 0.28233063220977783, 0.04955587536096573, 0.0, 0.004598122555762529, 0.004610949195921421, 0.01865001954138279, 0.020574036985635757, 0.0137012405321002, 0.7973257303237915, 0.01646837778389454, 0.023596635088324547, 0.1004747673869133, 0.0], [0.2377929538488388, 0.08882997930049896, 0.12371516227722168, 0.08651548624038696, 0.015416872687637806, 0.04211122542619705, 0.16403844952583313, 0.11833071708679199, 0.12324906885623932, 0.0, 0.0005213705007918179, 0.00018707667186390609, 0.0016978917410597205, 0.019619440659880638, 0.009308884851634502, 0.8590161800384521, 0.024511896073818207, 0.06970686465501785, 0.015430280938744545, 0.0], [0.023254310712218285, 0.0034057339653372765, 0.036038532853126526, 0.009054891765117645, 0.0329253226518631, 0.05284882336854935, 0.15671837329864502, 0.6067742109298706, 0.07897992432117462, 0.0, 0.0001481063081882894, 2.072651477647014e-05, 0.00035672096419148147, 0.00033358228392899036, 0.00040588833508081734, 0.9861487746238708, 0.00651955883949995, 0.00443643843755126, 0.0016300288261845708, 0.0], [0.015282228589057922, 0.008608018048107624, 0.08339564502239227, 0.032651614397764206, 0.21303850412368774, 0.22661514580249786, 0.21832069754600525, 0.1323210895061493, 0.06976725161075592, 0.0, 0.0010996124474331737, 
0.0011850595474243164, 0.0075045316480100155, 0.004539311397820711, 0.05570072680711746, 0.18870605528354645, 0.23963898420333862, 0.013960372656583786, 0.487665593624115, 0.0], [0.019424932077527046, 0.008587736636400223, 0.014951083809137344, 0.01159222237765789, 0.2890152633190155, 0.2543036639690399, 0.2561561167240143, 0.0882645845413208, 0.05770434811711311, 0.0, 0.0003884119214490056, 0.0004658032557927072, 0.028157439082860947, 0.0002352961164433509, 0.1278570294380188, 0.08260466903448105, 0.02582997828722, 0.022790132090449333, 0.7116712927818298, 0.0], [0.020595766603946686, 0.015824340283870697, 0.008689227513968945, 0.03796549141407013, 0.3004503846168518, 0.16956602036952972, 0.10506420582532883, 0.05004280060529709, 0.2918018400669098, 0.0, 0.0015414542285725474, 0.0007310948567464948, 0.010464987717568874, 0.0012846259633079171, 0.45206302404403687, 0.029316790401935577, 0.04706822335720062, 0.018986493349075317, 0.4385431706905365, 0.0], [0.18154361844062805, 0.0977708026766777, 0.20556335151195526, 0.05251142755150795, 0.13640889525413513, 0.06629360467195511, 0.06030320003628731, 0.08172836154699326, 0.11787670105695724, 0.0, 0.0005072542116977274, 0.0011837932979688048, 0.01220926083624363, 8.532252832083032e-05, 0.0018606879748404026, 0.010199862532317638, 0.0016309961210936308, 0.010775143280625343, 0.9615475535392761, 0.0]], [[0.060361556708812714, 0.015829458832740784, 0.05784451961517334, 0.3351474404335022, 0.06477320939302444, 0.04427827522158623, 0.09356044977903366, 0.03362266346812248, 0.2945823669433594, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.051239900290966034, 0.0459107868373394, 0.10656695812940598, 0.4080160856246948, 0.16381530463695526, 0.044977184385061264, 0.05972094088792801, 0.009804679080843925, 0.10994797199964523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019088272005319595, 0.05349855497479439, 0.4389742910861969, 0.022328443825244904, 0.03395729511976242, 0.20592069625854492, 
0.007582489866763353, 0.08437496423721313, 0.13427504897117615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03275543451309204, 0.01311502419412136, 0.038520246744155884, 0.47789818048477173, 0.04586595296859741, 0.01380465179681778, 0.03337283805012703, 0.07212045043706894, 0.27254730463027954, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04071904346346855, 0.043366871774196625, 0.1190471276640892, 0.18268215656280518, 0.2763146162033081, 0.029253922402858734, 0.017268449068069458, 0.0670313611626625, 0.22431644797325134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04853136092424393, 0.0034203159157186747, 0.17822766304016113, 0.005087696481496096, 0.02670232392847538, 0.5734196305274963, 0.06478680670261383, 0.04684215411543846, 0.05298209935426712, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016102498397231102, 0.0006646174006164074, 0.00315408268943429, 0.003398373955860734, 0.01210782676935196, 0.07864897698163986, 0.743419349193573, 0.023116787895560265, 0.11938738822937012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0031801864970475435, 0.0032259617000818253, 0.027063841000199318, 0.0018325509736314416, 0.006064774002879858, 0.017839375883340836, 0.05006564408540726, 0.8002738952636719, 0.0904538482427597, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02500138245522976, 0.016465606167912483, 0.02692888118326664, 0.01824249140918255, 0.047875918447971344, 0.06556686758995056, 0.15585453808307648, 0.21941381692886353, 0.42465049028396606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07641319185495377, 0.017753547057509422, 0.039497166872024536, 0.014236720278859138, 0.03872253745794296, 0.1210501492023468, 0.17305448651313782, 0.2333979308605194, 0.28587427735328674, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07673492282629013, 0.03585591912269592, 0.0804624855518341, 0.05707075819373131, 0.16190174221992493, 0.1288135051727295, 
0.1235240250825882, 0.06807681918144226, 0.2675597667694092, 0.0, 0.29744189977645874, 0.04770943149924278, 0.09888078272342682, 0.19768767058849335, 0.048243775963783264, 0.12058595567941666, 0.05976371467113495, 0.03847452625632286, 0.09121233224868774, 0.0], [0.005086997989565134, 0.014635499566793442, 0.013461720198392868, 0.6349815726280212, 0.14714521169662476, 0.015218403190374374, 0.01605474203824997, 0.018318237736821175, 0.1350976973772049, 0.0, 0.04126456007361412, 0.6604095697402954, 0.028894882649183273, 0.20104490220546722, 0.0014044500421732664, 0.0009343607816845179, 0.00244489056058228, 0.007453228812664747, 0.05614929273724556, 0.0], [0.03515003249049187, 0.049813926219940186, 0.04029693454504013, 0.4151618778705597, 0.24873343110084534, 0.009437951259315014, 0.008381601423025131, 0.020832136273384094, 0.17219208180904388, 0.0, 0.008357543498277664, 0.0022072584833949804, 0.9876156449317932, 8.841200906317681e-05, 1.4883004041621462e-05, 0.00011741811613319442, 2.7020510970032774e-05, 0.00016062626673374325, 0.001411277218721807, 0.0], [0.06722414493560791, 0.13528113067150116, 0.06224377825856209, 0.18915168941020966, 0.17580503225326538, 0.07229694724082947, 0.012536793015897274, 0.09137610346078873, 0.19408434629440308, 0.0, 0.06216944754123688, 0.48559242486953735, 0.042546145617961884, 0.034007471054792404, 0.047574639320373535, 0.12490913271903992, 0.07922931015491486, 0.013364763930439949, 0.11060672253370285, 0.0], [0.09099949151277542, 0.09548961371183395, 0.04829362779855728, 0.1739831268787384, 0.06667517125606537, 0.05157051607966423, 0.05465595796704292, 0.06177656352519989, 0.3565560579299927, 0.0, 0.05222959443926811, 0.025416702032089233, 0.02865077182650566, 0.17457211017608643, 0.03144511207938194, 0.3907364010810852, 0.19607771933078766, 0.05274118855595589, 0.04813018813729286, 0.0], [0.09822985529899597, 0.05441536381840706, 0.039150238037109375, 0.06369251012802124, 0.05292840674519539, 0.050128646194934845, 
0.044398434460163116, 0.04042055085301399, 0.5566359758377075, 0.0, 0.0037726862356066704, 0.0031579534988850355, 0.0029440780635923147, 0.0017320584738627076, 0.060473062098026276, 0.761774480342865, 0.1523173600435257, 0.0058823637664318085, 0.007945872843265533, 0.0], [0.012019939720630646, 0.0076602306216955185, 0.02716030552983284, 0.03984800726175308, 0.09776019304990768, 0.05175628885626793, 0.08536165207624435, 0.0944109782576561, 0.5840223431587219, 0.0, 0.0020738786552101374, 0.0012752892216667533, 0.0004058163322042674, 0.020963717252016068, 0.39340031147003174, 0.012434415519237518, 0.4783190190792084, 0.011497312225401402, 0.0796302929520607, 0.0], [0.036716632544994354, 0.021969007328152657, 0.010507079772651196, 0.012404722161591053, 0.040125522762537, 0.010736462660133839, 0.018730206415057182, 0.030387653037905693, 0.8184227347373962, 0.0, 5.31752230017446e-05, 1.4492364243778866e-05, 7.312332309084013e-05, 0.0023682843893766403, 0.9866323471069336, 0.0009243910317309201, 0.0011850211303681135, 0.0017622504383325577, 0.0069872229360044, 0.0], [0.04769879952073097, 0.19333122670650482, 0.02803504839539528, 0.016029207035899162, 0.11119306832551956, 0.03845509514212608, 0.011404097080230713, 0.0836206004023552, 0.4702327847480774, 0.0, 4.074166645295918e-05, 1.823456841520965e-05, 0.0001418270985595882, 0.007263784296810627, 0.9604514241218567, 0.0001852070417953655, 0.00034164052340202034, 0.0018497714772820473, 0.029707150533795357, 0.0], [0.05245642364025116, 0.013315027579665184, 0.012056763283908367, 0.004825723823159933, 0.015483945608139038, 0.032884638756513596, 0.027794960886240005, 0.07057305425405502, 0.7706093788146973, 0.0, 0.0133396340534091, 0.03136875480413437, 0.6319980621337891, 0.0033722908701747656, 0.04728742688894272, 0.03541773557662964, 0.009523973800241947, 0.03100484237074852, 0.1966874897480011, 0.0]], [[0.15564993023872375, 0.3264511823654175, 0.08247561007738113, 0.04047680273652077, 0.04636594280600548, 
0.03705644607543945, 0.05653020739555359, 0.08808662742376328, 0.16690711677074432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6047166585922241, 0.08402378112077713, 0.11650887131690979, 0.004807815421372652, 0.02726476825773716, 0.0609126091003418, 0.02905944734811783, 0.012920884415507317, 0.059785205870866776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5938906669616699, 0.07300958037376404, 0.08890929818153381, 0.008111076429486275, 0.04038470610976219, 0.07353192567825317, 0.03085281327366829, 0.08706387132406235, 0.004246041644364595, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2591831088066101, 0.17658700048923492, 0.44177621603012085, 0.01689036749303341, 0.0653892457485199, 0.01502177957445383, 0.02055797167122364, 0.0024378441739827394, 0.0021566858049482107, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.33400091528892517, 0.03927909955382347, 0.27614372968673706, 0.009977479465305805, 0.12025652825832367, 0.1713484674692154, 0.04292818158864975, 0.004225345328450203, 0.00184013566467911, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06147114187479019, 0.019044799730181694, 0.059415291994810104, 0.05198045074939728, 0.12181691080331802, 0.419679194688797, 0.1140735000371933, 0.14551687240600586, 0.00700181070715189, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006845483556389809, 0.002091927919536829, 0.01196279563009739, 0.014390786178410053, 0.02692629024386406, 0.8455513715744019, 0.07174734026193619, 0.017689114436507225, 0.0027949714567512274, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00039940490387380123, 0.00013551976007875055, 0.020663700997829437, 0.008696838282048702, 0.021915050223469734, 0.1381293535232544, 0.0347108468413353, 0.7650054097175598, 0.010343861766159534, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02615724503993988, 0.0051858089864254, 0.038734134286642075, 0.021585455164313316, 0.19684533774852753, 
0.17548950016498566, 0.1665634661912918, 0.2796759307384491, 0.08976294845342636, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.043001022189855576, 0.016749290749430656, 0.04958483204245567, 0.06659381091594696, 0.0702962800860405, 0.27735820412635803, 0.14212922751903534, 0.20686522126197815, 0.12742231786251068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05745904520153999, 0.06613133102655411, 0.11319872736930847, 0.031750500202178955, 0.0641264021396637, 0.07090476900339127, 0.053613319993019104, 0.1108509749174118, 0.4319649040699005, 0.0, 0.03367111459374428, 0.018932543694972992, 0.09506545215845108, 0.04718795791268349, 0.028798582032322884, 0.33658939599990845, 0.02586139366030693, 0.29842811822891235, 0.11546547710895538, 0.0], [0.12783250212669373, 0.16847258806228638, 0.08126984536647797, 0.10575822740793228, 0.03301985561847687, 0.2111520618200302, 0.10687874257564545, 0.06316707283258438, 0.10244929045438766, 0.0, 0.006203038617968559, 0.0906001627445221, 0.6977949738502502, 0.018352899700403214, 0.06787873804569244, 0.04403599724173546, 0.001631368650123477, 0.024296771734952927, 0.049206044524908066, 0.0], [0.1413263976573944, 0.38601601123809814, 0.16798537969589233, 0.14611834287643433, 0.015951359644532204, 0.042198505252599716, 0.016183707863092422, 0.06246974319219589, 0.021750787273049355, 0.0, 0.006243667099624872, 0.010453532449901104, 0.7879610657691956, 0.004093538969755173, 0.0008473669877275825, 0.027760563418269157, 0.0003080451278947294, 0.14831961691379547, 0.014012438245117664, 0.0], [0.020376645028591156, 0.008152640424668789, 0.04579228535294533, 0.022974595427513123, 0.007921000011265278, 0.11700868606567383, 0.010826223529875278, 0.7216546535491943, 0.04529344290494919, 0.0, 0.004387176129966974, 0.023410169407725334, 0.17247918248176575, 0.03958609700202942, 0.023799436166882515, 0.43659475445747375, 0.014754846692085266, 0.2318120151758194, 0.05317622795701027, 0.0], [0.04728184640407562, 
0.041129130870103836, 0.12847241759300232, 0.038289085030555725, 0.07389654964208603, 0.11478690057992935, 0.04442784935235977, 0.41169247031211853, 0.1000237911939621, 0.0, 0.0020952164195477962, 0.0024118656292557716, 0.028229335322976112, 0.007075420115143061, 0.019164882600307465, 0.5397294163703918, 0.034580815583467484, 0.3465326428413391, 0.020180128514766693, 0.0], [0.016180921345949173, 0.005130380857735872, 0.21081623435020447, 0.00797765702009201, 0.04691680520772934, 0.052309177815914154, 0.2947923243045807, 0.34133997559547424, 0.02453651838004589, 0.0, 0.00020744462381117046, 0.00036016973899677396, 0.004934145137667656, 0.0004664760490413755, 0.008187839761376381, 0.9661812782287598, 0.009987047873437405, 0.003882928751409054, 0.005792597308754921, 0.0], [0.006579844746738672, 0.001606129459105432, 0.206822007894516, 0.017204096540808678, 0.13898226618766785, 0.09910376369953156, 0.4235020577907562, 0.05497713387012482, 0.051222700625658035, 0.0, 3.4081476769642904e-05, 1.7181657312903553e-05, 5.4824478866066784e-05, 0.00045897584641352296, 0.0043338024988770485, 0.001544477418065071, 0.9909620881080627, 2.356152981519699e-05, 0.0025708049070090055, 0.0], [0.00896216370165348, 0.0023249718360602856, 0.0226416178047657, 0.05458173528313637, 0.07694459706544876, 0.29436299204826355, 0.36870595812797546, 0.12525610625743866, 0.046219732612371445, 0.0, 0.0001047314508468844, 0.0001599654060555622, 0.001310097286477685, 0.001540280063636601, 0.833267331123352, 0.044754061847925186, 0.0028599577490240335, 0.0006454077665694058, 0.11535807698965073, 0.0], [0.027829669415950775, 0.014619122259318829, 0.014550572261214256, 0.048137370496988297, 0.15001901984214783, 0.11716196686029434, 0.34159788489341736, 0.1513865739107132, 0.13469791412353516, 0.0, 8.819431968731806e-05, 6.364465662045404e-05, 0.00022057128080632538, 0.001112746773287654, 0.9560981392860413, 0.003599100047722459, 0.0002217600413132459, 0.0006697923527099192, 0.03792598471045494, 0.0], 
[0.0014273751294240355, 0.003807784290984273, 0.3760293126106262, 0.002253596903756261, 0.11343870311975479, 0.12883712351322174, 0.04242479428648949, 0.28902071714401245, 0.042760640382766724, 0.0, 0.0018130787648260593, 0.022020958364009857, 0.12822051346302032, 0.0005810249131172895, 0.03168048337101936, 0.014293116517364979, 0.002500524278730154, 0.0212943647056818, 0.7775959372520447, 0.0]]], [[[0.13086311519145966, 0.049477167427539825, 0.10100015252828598, 0.03843620419502258, 0.27287009358406067, 0.20078831911087036, 0.16546384990215302, 0.03368193656206131, 0.007419050205498934, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1137659102678299, 0.11250672489404678, 0.21935509145259857, 0.09974226355552673, 0.22245454788208008, 0.11022598296403885, 0.0977952778339386, 0.010162456892430782, 0.013991687446832657, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09118296205997467, 0.0991944894194603, 0.31555840373039246, 0.16625922918319702, 0.1399575173854828, 0.0926588773727417, 0.021735703572630882, 0.056496523320674896, 0.016956249251961708, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.35773080587387085, 0.19870112836360931, 0.026073846966028214, 0.07347559928894043, 0.09251826256513596, 0.0859094187617302, 0.06421677768230438, 0.06334269791841507, 0.0380314365029335, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02230222336947918, 0.0210218857973814, 0.024334343150258064, 0.36442241072654724, 0.2750929892063141, 0.13295342028141022, 0.06824173033237457, 0.0036951478105038404, 0.0879359245300293, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.018942566588521004, 0.011805560439825058, 0.04696377366781235, 0.09440026432275772, 0.39890599250793457, 0.17608429491519928, 0.10613365471363068, 0.10454639047384262, 0.04221746698021889, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0475851334631443, 0.008668179623782635, 0.011950161308050156, 0.0786907747387886, 0.09432563930749893, 
0.07653870433568954, 0.4287588894367218, 0.13403372466564178, 0.1194487139582634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008243327029049397, 0.006908380892127752, 0.04044030234217644, 0.08380357921123505, 0.1593569815158844, 0.1858288198709488, 0.0890916958451271, 0.40247857570648193, 0.02384827472269535, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09753390401601791, 0.04787491634488106, 0.10570236295461655, 0.09989321976900101, 0.07242950052022934, 0.16000299155712128, 0.13195638358592987, 0.12870465219020844, 0.15590202808380127, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3338638246059418, 0.05386793985962868, 0.15485166013240814, 0.05483235418796539, 0.052468191832304, 0.12754301726818085, 0.13515245914459229, 0.06475869566202164, 0.022661946713924408, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9917634725570679, 0.008236419409513474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6252409815788269, 0.3747589886188507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.711856484413147, 0.20838035643100739, 0.07976315170526505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8520486354827881, 0.010580658912658691, 0.13737063109874725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6327172517776489, 0.1227935329079628, 0.21565596759319305, 0.028833283111453056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05910082906484604, 0.011589597910642624, 0.877491295337677, 0.051818281412124634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3586137592792511, 0.038762304931879044, 0.08015953004360199, 0.4233120083808899, 0.09915236383676529, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3626183867454529, 0.026959313079714775, 0.07612177729606628, 0.13077552616596222, 0.4035249352455139, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7095601558685303, 0.03453405201435089, 0.02220289036631584, 0.009008818306028843, 0.201883926987648, 0.022810086607933044, 0.0, 0.0, 0.0, 0.0, 0.21979263424873352, 
0.001410112832672894, 0.007092535495758057, 0.13166557252407074, 0.626970648765564, 0.013068560510873795, 0.0, 0.0, 0.0, 0.0], [0.5828825831413269, 0.02795644849538803, 0.054448600858449936, 0.01975347101688385, 0.11504233628511429, 0.08908692002296448, 0.11082970350980759, 0.0, 0.0, 0.0, 0.08148042857646942, 0.001490423921495676, 0.004908325150609016, 0.01383854728192091, 0.7959722876548767, 0.05201547220349312, 0.05029459297657013, 0.0, 0.0, 0.0], [0.4315364956855774, 0.020537925884127617, 0.01659376546740532, 0.014654956758022308, 0.13063199818134308, 0.27319464087486267, 0.08869150280952454, 0.024158723652362823, 0.0, 0.0, 0.03934427723288536, 5.908778257435188e-05, 0.00014962907880544662, 0.005592166446149349, 0.7025003433227539, 0.1675100177526474, 0.03920353576540947, 0.04564077779650688, 0.0, 0.0], [0.26020547747612, 0.014821716584265232, 0.01224969606846571, 0.0724530965089798, 0.10939211398363113, 0.19152909517288208, 0.10495918244123459, 0.1680101454257965, 0.06637949496507645, 0.0, 0.4660189151763916, 0.00034756408422254026, 9.701005183160305e-05, 0.008154522627592087, 0.08121690154075623, 0.15592943131923676, 0.11426379531621933, 0.17044323682785034, 0.0035288764629513025, 0.0], [0.6687084436416626, 0.04345089942216873, 0.009689688682556152, 0.0018685735994949937, 0.0738394483923912, 0.12735962867736816, 0.025320274755358696, 0.026545442640781403, 0.020931225270032883, 0.0022863498888909817, 0.3707294762134552, 0.0020887483842670918, 0.23984688520431519, 0.07748916745185852, 0.18109895288944244, 0.03584783151745796, 0.005205830093473196, 0.005058187525719404, 0.0050886403769254684, 0.0775463655591011]], [[0.011833908967673779, 0.03545977920293808, 0.03510122373700142, 0.06200635805726051, 0.09438431262969971, 0.06055876612663269, 0.053256530314683914, 0.30701303482055664, 0.3403860926628113, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03663749620318413, 0.06511621922254562, 0.05716057866811752, 0.07533077895641327, 0.10846659541130066, 
0.037432827055454254, 0.04480022192001343, 0.18166707456111908, 0.39338818192481995, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06557667255401611, 0.03966936469078064, 0.008358842693269253, 0.06794404983520508, 0.05668830871582031, 0.02720261737704277, 0.07913517951965332, 0.20437636971473694, 0.45104852318763733, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044038429856300354, 0.07477934658527374, 0.10143070667982101, 0.16204005479812622, 0.06265459954738617, 0.10170722752809525, 0.08676454424858093, 0.0699862688779831, 0.2965989410877228, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06005045771598816, 0.046840403228998184, 0.06629239022731781, 0.04125581681728363, 0.007815167307853699, 0.20412082970142365, 0.1083299070596695, 0.04942404478788376, 0.41587093472480774, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03666035085916519, 0.028792625293135643, 0.06887229532003403, 0.18481910228729248, 0.15058831870555878, 0.048441674560308456, 0.0780390277504921, 0.13469383120536804, 0.26909276843070984, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03408746421337128, 0.026394939050078392, 0.05409233644604683, 0.06951043754816055, 0.1446777582168579, 0.09970070421695709, 0.05472328141331673, 0.16119606792926788, 0.35561704635620117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12936006486415863, 0.04621516913175583, 0.10149524360895157, 0.14774896204471588, 0.45855623483657837, 0.033130910247564316, 0.031401973217725754, 0.02012830227613449, 0.031963150948286057, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1214270144701004, 0.04088712856173515, 0.05250505730509758, 0.07924661785364151, 0.05337269604206085, 0.10527284443378448, 0.08820997178554535, 0.17732012271881104, 0.28175854682922363, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13074854016304016, 0.06475767493247986, 0.07325490564107895, 0.0625966489315033, 0.14061231911182404, 0.07830052822828293, 
0.12438739091157913, 0.21453101933002472, 0.11081094294786453, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9482711553573608, 0.051728855818510056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15256483852863312, 0.8474349975585938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8711318373680115, 0.04994085431098938, 0.07892734557390213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08618302643299103, 0.30268052220344543, 0.6111364364624023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7221198678016663, 0.040686361491680145, 0.06532222777605057, 0.17187155783176422, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6251113414764404, 0.14608541131019592, 0.21724094450473785, 0.011562197469174862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5948007702827454, 0.036634139716625214, 0.02264709398150444, 0.035541336983442307, 0.3103766441345215, 0.0, 0.0, 0.0, 0.0, 0.0, 0.31851068139076233, 0.11805614084005356, 0.02926168404519558, 0.0854775682091713, 0.44869405031204224, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6650473475456238, 0.01644211634993553, 0.019737746566534042, 0.0375308021903038, 0.10231779515743256, 0.15892422199249268, 0.0, 0.0, 0.0, 0.0, 0.23099647462368011, 0.015003926120698452, 0.0028121687937527895, 0.025386620312929153, 0.5829272270202637, 0.14287345111370087, 0.0, 0.0, 0.0, 0.0], [0.36675524711608887, 0.04118315875530243, 0.02765432558953762, 0.03228116035461426, 0.11875578761100769, 0.12892943620681763, 0.2844408452510834, 0.0, 0.0, 0.0, 0.2648485600948334, 0.01456066407263279, 0.008421574719250202, 0.01653379574418068, 0.25845009088516235, 0.35933130979537964, 0.07785411924123764, 0.0, 0.0, 0.0], [0.19659309089183807, 0.015950728207826614, 0.02453998662531376, 0.039237309247255325, 0.037656329572200775, 0.34599894285202026, 0.23759640753269196, 0.10242718458175659, 0.0, 0.0, 0.21031156182289124, 0.00652333116158843, 0.005756322760134935, 0.019128819927573204, 0.2526819407939911, 
0.49096593260765076, 0.008809886872768402, 0.00582215515896678, 0.0, 0.0], [0.3881740868091583, 0.012267092242836952, 0.01897304505109787, 0.013982790522277355, 0.030991200357675552, 0.10819684714078903, 0.20157809555530548, 0.14642520248889923, 0.07941170781850815, 0.0, 0.11555754393339157, 0.00475481478497386, 0.0013921409845352173, 0.045808907598257065, 0.29882168769836426, 0.3024459183216095, 0.0483231395483017, 0.18265680968761444, 0.0002390409354120493, 0.0], [0.11410266160964966, 0.03479800745844841, 0.043540675193071365, 0.021180409938097, 0.03197954222559929, 0.2248576581478119, 0.12852585315704346, 0.2089216560125351, 0.039846520870923996, 0.1522471308708191, 0.8451279401779175, 0.021679740399122238, 0.035543736070394516, 0.005811640061438084, 0.04445958510041237, 0.018052000552415848, 0.0015424924204126, 0.013668404892086983, 0.012673787772655487, 0.0014405279653146863]], [[0.0022766904439777136, 0.00227623013779521, 0.027263110503554344, 0.7988243699073792, 0.12335250526666641, 0.012830986641347408, 0.008179515600204468, 0.004631126299500465, 0.020365260541439056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022365765646100044, 0.0197063609957695, 0.08540411293506622, 0.7100865840911865, 0.10288897156715393, 0.023861246183514595, 0.009303209371864796, 0.012690575793385506, 0.013693095184862614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.023093748837709427, 0.013999207876622677, 0.09048538655042648, 0.10519850999116898, 0.12126202881336212, 0.34847554564476013, 0.057331401854753494, 0.0919070839881897, 0.14824725687503815, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03627682104706764, 0.0323517769575119, 0.06003699079155922, 0.04609783738851547, 0.3189731240272522, 0.3202785551548004, 0.06900984793901443, 0.021341597661376, 0.0956336110830307, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026664189994335175, 0.018690558150410652, 0.01473171729594469, 0.003785684471949935, 0.012891196645796299, 
0.6301508545875549, 0.1024516150355339, 0.10377107560634613, 0.08686315268278122, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.010066811926662922, 0.005272349342703819, 0.019913937896490097, 0.005584465805441141, 0.0479762889444828, 0.06466472148895264, 0.2978198528289795, 0.22872935235500336, 0.31997203826904297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.054553788155317307, 0.011876759119331837, 0.005296430550515652, 0.008171333000063896, 0.17499762773513794, 0.29638832807540894, 0.22286026179790497, 0.017016055062413216, 0.20883934199810028, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03061697818338871, 0.020777547731995583, 0.27117541432380676, 0.010558649897575378, 0.16651615500450134, 0.3011224865913391, 0.026109976693987846, 0.048922766000032425, 0.12420005351305008, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16545239090919495, 0.03877135366201401, 0.007565324194729328, 0.015141250565648079, 0.03747279569506645, 0.3241279125213623, 0.26990416646003723, 0.043362975120544434, 0.09820175170898438, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22949647903442383, 0.0972394198179245, 0.02905140444636345, 0.03182214871048927, 0.025490015745162964, 0.08278947323560715, 0.15009135007858276, 0.031098822131752968, 0.3229208290576935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.993086576461792, 0.0069133141078054905, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9927853345870972, 0.007214863318949938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9852874875068665, 0.011381878517568111, 0.0033306065015494823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011021426878869534, 0.007158290129154921, 0.9818204641342163, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4834398031234741, 0.011301998049020767, 0.48758530616760254, 0.017672834917902946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007071706000715494, 
0.026167649775743484, 0.19316613674163818, 0.773594319820404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9851425886154175, 0.0010397545993328094, 0.00470126885920763, 0.0012236799811944366, 0.007892588153481483, 0.0, 0.0, 0.0, 0.0, 0.0, 0.320003479719162, 0.03976304829120636, 0.22334550321102142, 0.24320250749588013, 0.17368540167808533, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6588926315307617, 0.005506628658622503, 0.021607331931591034, 0.010738613083958626, 0.07747143507003784, 0.2257833182811737, 0.0, 0.0, 0.0, 0.0, 0.10932182520627975, 0.001151762087829411, 0.007792286574840546, 0.18981949985027313, 0.6517421007156372, 0.04017229378223419, 0.0, 0.0, 0.0, 0.0], [0.13557791709899902, 0.018924091011285782, 0.02187344618141651, 0.015362304635345936, 0.11512601375579834, 0.14739760756492615, 0.5457385182380676, 0.0, 0.0, 0.0, 0.02538878843188286, 0.005211540497839451, 0.03069700486958027, 0.13252338767051697, 0.4279623329639435, 0.0899164006114006, 0.28830063343048096, 0.0, 0.0, 0.0], [0.38992705941200256, 0.021535715088248253, 0.005403842777013779, 0.0032997699454426765, 0.4358868896961212, 0.06306594610214233, 0.03204012289643288, 0.04884066432714462, 0.0, 0.0, 0.010537173599004745, 0.0007831656257621944, 0.0007035965682007372, 0.015162549912929535, 0.9050821661949158, 0.05248205363750458, 0.01132790744304657, 0.00392116466537118, 0.0, 0.0], [0.81478351354599, 0.022238636389374733, 0.0008386021945625544, 0.01924033649265766, 0.06109088659286499, 0.020853841677308083, 0.014834966510534286, 0.028932424262166023, 0.017186695709824562, 0.0, 0.005222301464527845, 0.003575690556317568, 0.0029950442258268595, 0.00018454395467415452, 0.0012630765559151769, 0.01364975143224001, 0.09376595914363861, 0.853415846824646, 0.02592780999839306, 0.0], [0.011323019862174988, 0.004743177909404039, 0.004908193834125996, 0.04389021545648575, 0.9175272583961487, 0.008399821817874908, 0.00010120288789039478, 0.0007724545430392027, 0.001946530188433826, 0.006388010922819376, 0.14979584515094757, 
0.0004723063320852816, 0.4970340430736542, 0.03214645013213158, 0.022075939923524857, 0.006538126152008772, 0.0013381451135501266, 0.0030305178370326757, 0.0008045822032727301, 0.28676414489746094]], [[0.023217031732201576, 0.015444980934262276, 0.33269768953323364, 0.4809305965900421, 0.08491171896457672, 0.027504485100507736, 0.007655052933841944, 0.015150148421525955, 0.012488299049437046, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003814368275925517, 0.0054845609702169895, 0.005400203168392181, 0.34217125177383423, 0.010647634975612164, 0.00044525362318381667, 0.00011972449283348396, 0.00042839962407015264, 0.6314883828163147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013448912650346756, 0.01028169970959425, 0.4982297718524933, 0.3182436525821686, 0.01780710555613041, 0.024587348103523254, 0.0009282209794037044, 0.11607228964567184, 0.0004009671974927187, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0027270291466265917, 0.01338754128664732, 0.019254636019468307, 0.11856623739004135, 0.0025901400949805975, 0.0012062221067026258, 0.0006161375786177814, 0.0012282256502658129, 0.8404240608215332, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.802536098693963e-05, 0.0005015733768232167, 2.3977232558536343e-05, 0.00012258262722752988, 0.00013862864580005407, 1.9367420463822782e-05, 1.2695372788584791e-05, 2.8395381377777085e-05, 0.9991349577903748, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.045823611319065094, 0.0060311248525977135, 0.11489683389663696, 0.011397628113627434, 0.14236140251159668, 0.31853923201560974, 0.18707275390625, 0.16781283915042877, 0.006064609158784151, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031908370554447174, 0.0013231962220743299, 0.03774190694093704, 0.014869065955281258, 0.08836144208908081, 0.662682056427002, 0.1095389723777771, 0.05017231032252312, 0.0034025281202048063, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.0061959377489984035, 0.012075785547494888, 0.28881579637527466, 0.0719127431511879, 0.08756363391876221, 0.0848873034119606, 0.027471251785755157, 0.404219388961792, 0.016858302056789398, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0946543961763382, 0.0623893216252327, 0.18748056888580322, 0.1788652539253235, 0.03208017721772194, 0.1587594598531723, 0.05469479411840439, 0.17047303915023804, 0.06060296297073364, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019481608644127846, 0.068674735724926, 0.13537795841693878, 0.2137300968170166, 0.031131863594055176, 0.02376358024775982, 0.030956387519836426, 0.04989796131849289, 0.4269856810569763, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9621535539627075, 0.037846412509679794, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9691458940505981, 0.03085414692759514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5398231148719788, 0.4385344386100769, 0.021642372012138367, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9338735938072205, 0.02144204080104828, 0.04468445107340813, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6502059698104858, 0.16868625581264496, 0.04876677691936493, 0.13234086334705353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4091326594352722, 0.1788463294506073, 0.3530478775501251, 0.058973249047994614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5965072512626648, 0.06637387722730637, 0.1054789125919342, 0.1866345852613449, 0.04500538855791092, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8083640336990356, 0.0245783980935812, 0.02959858626127243, 0.02002020739018917, 0.11743883788585663, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3253602683544159, 0.03396952152252197, 0.02178867906332016, 0.07780158519744873, 0.04822142422199249, 0.49285849928855896, 0.0, 0.0, 0.0, 0.0, 0.6256738901138306, 0.03313886746764183, 0.03255102410912514, 0.015011090785264969, 0.27659764885902405, 0.017027597874403, 0.0, 0.0, 0.0, 0.0], [0.2524598240852356, 
0.04065639525651932, 0.06012948602437973, 0.022925280034542084, 0.0371418297290802, 0.17370767891407013, 0.41297948360443115, 0.0, 0.0, 0.0, 0.2970131039619446, 0.01776941865682602, 0.015323061496019363, 0.014444534666836262, 0.2387886643409729, 0.36828577518463135, 0.048375438898801804, 0.0, 0.0, 0.0], [0.03411499038338661, 0.003937003668397665, 0.005961195565760136, 0.01710909977555275, 0.011033114977180958, 0.7081340551376343, 0.13750500977039337, 0.08220544457435608, 0.0, 0.0, 0.16347570717334747, 0.01386126596480608, 0.012116431258618832, 0.006670618429780006, 0.5951986312866211, 0.1577492356300354, 0.024585027247667313, 0.02634291537106037, 0.0, 0.0], [0.42400264739990234, 0.02131979539990425, 0.017963027581572533, 0.01083337515592575, 0.019156770780682564, 0.14712399244308472, 0.1343262642621994, 0.19853995740413666, 0.02673417516052723, 0.0, 0.1568753868341446, 0.002166055142879486, 0.0014692704426124692, 0.009539359249174595, 0.7249224781990051, 0.0696585550904274, 0.02269914373755455, 0.010646837763488293, 0.0020231890957802534, 0.0], [0.010900852270424366, 0.01643177680671215, 0.007438827771693468, 0.037741534411907196, 0.0038807683158665895, 0.513563871383667, 0.17121337354183197, 0.14364023506641388, 0.04466766491532326, 0.050521109253168106, 0.6687246561050415, 0.003988182172179222, 0.00992897991091013, 0.00877397134900093, 0.07160260528326035, 0.14080072939395905, 0.01739262230694294, 0.04941429942846298, 0.01782085746526718, 0.011553076095879078]], [[0.00896595511585474, 0.001820763573050499, 0.0036846648436039686, 0.8942996859550476, 0.002699120668694377, 0.0018430916825309396, 0.00023619653075002134, 0.0008667120710015297, 0.08558366447687149, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011139868758618832, 0.00517098605632782, 0.03486357256770134, 0.92783522605896, 0.010794212110340595, 0.0029791113920509815, 0.0008399260113947093, 0.0003134821599815041, 0.006063643377274275, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.07888396829366684, 0.0272236131131649, 0.0322146937251091, 0.791079044342041, 0.03133838623762131, 0.009372375905513763, 0.002263500588014722, 0.0005359782953746617, 0.02708848938345909, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008838528767228127, 0.0009813528740778565, 0.014693140052258968, 0.00012726498243864626, 0.013269715011119843, 0.06431703269481659, 0.0039668334648013115, 0.8607616424560547, 0.0330444760620594, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028727378696203232, 0.001701394678093493, 0.0009593431605026126, 0.0036824517883360386, 0.009683175943791866, 0.2589351236820221, 0.040837112814188004, 0.01649528741836548, 0.6389787197113037, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009239337407052517, 0.0011580593418329954, 0.0009623299702070653, 0.000996780814602971, 0.00493139773607254, 0.04319336265325546, 0.859686553478241, 0.012395362369716167, 0.06743697822093964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.024199873208999634, 0.007249501068145037, 0.02041051909327507, 0.008800184354186058, 0.02760438062250614, 0.1116553395986557, 0.030366744846105576, 0.03851965814828873, 0.7311937808990479, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06881897896528244, 0.21671976149082184, 0.02303808182477951, 0.0017656114650890231, 0.09897635877132416, 0.04207116737961769, 0.012660021893680096, 0.25307658314704895, 0.2828734517097473, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09324429929256439, 0.059572815895080566, 0.021969754248857498, 0.008625463582575321, 0.022502752020955086, 0.07016356289386749, 0.033860694617033005, 0.03514377400279045, 0.6549169421195984, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04541633278131485, 0.01696496643126011, 0.003866765182465315, 0.00941139180213213, 0.006640681531280279, 0.024550199508666992, 0.009012367576360703, 0.009869653731584549, 0.8742677569389343, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4730486273765564, 0.5269513726234436, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.497504860162735, 0.502495288848877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39858773350715637, 0.07930062711238861, 0.5221116542816162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.028444888070225716, 0.01678420603275299, 0.9547709822654724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5825604200363159, 0.08404675871133804, 0.15067298710346222, 0.182719886302948, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02853180095553398, 0.022399114444851875, 0.7835201025009155, 0.1655489057302475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.29498350620269775, 0.03899451717734337, 0.00506106112152338, 0.006130008026957512, 0.6548308730125427, 0.0, 0.0, 0.0, 0.0, 0.0, 0.023048963397741318, 0.055082567036151886, 0.3371332883834839, 0.25099456310272217, 0.33374062180519104, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13055028021335602, 0.007264712825417519, 0.014658198691904545, 0.03852052241563797, 0.6908979415893555, 0.11810839176177979, 0.0, 0.0, 0.0, 0.0, 0.013693265616893768, 0.057373203337192535, 0.02566814236342907, 0.11711565405130386, 0.13761301338672638, 0.6485366225242615, 0.0, 0.0, 0.0, 0.0], [0.6701509952545166, 0.016114505007863045, 0.009837295860052109, 0.013812566176056862, 0.10121432691812515, 0.04637172445654869, 0.14249859750270844, 0.0, 0.0, 0.0, 0.5831283926963806, 0.0857725590467453, 0.06227085366845131, 0.03169894590973854, 0.06183577701449394, 0.01752074435353279, 0.15777261555194855, 0.0, 0.0, 0.0], [0.15980258584022522, 0.02680308185517788, 0.03885137289762497, 0.01341771800071001, 0.16442187130451202, 0.12716332077980042, 0.3698134124279022, 0.09972671419382095, 0.0, 0.0, 0.0033312023151665926, 0.003545752028003335, 0.0018331086030229926, 0.05265560373663902, 0.047756411135196686, 0.045255228877067566, 0.20667387545108795, 0.6389486193656921, 0.0, 0.0], [0.5671898722648621, 0.0029452391900122166, 0.0006932761170901358, 
0.0009682640084065497, 0.008882325142621994, 0.018135691061615944, 0.19489231705665588, 0.1878870278596878, 0.01840599626302719, 0.0, 0.02047032117843628, 0.03542931377887726, 0.01270933635532856, 0.46998995542526245, 0.035482652485370636, 0.015606570988893509, 0.1128709465265274, 0.03180817514657974, 0.26563259959220886, 0.0], [0.10793960839509964, 0.02733222208917141, 0.05983218923211098, 0.007959540002048016, 0.012123869732022285, 0.0992540642619133, 0.031409986317157745, 0.1074245497584343, 0.5389924645423889, 0.007731476798653603, 0.027955254539847374, 0.024354776367545128, 0.4609973132610321, 0.07958999276161194, 0.34062448143959045, 0.0068156360648572445, 0.000798556546214968, 0.0009541919571347535, 0.00023223790049087256, 0.05767740309238434]], [[0.007143852766603231, 0.26111796498298645, 0.053768061101436615, 0.022731401026248932, 0.014146089553833008, 0.012985849753022194, 0.007359612733125687, 0.0043042986653745174, 0.6164429783821106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6188079118728638, 0.11004422605037689, 0.07541824132204056, 0.010463211685419083, 0.003863272722810507, 0.016659650951623917, 0.028880171477794647, 0.010046081617474556, 0.1258174628019333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1673126220703125, 0.6515482068061829, 0.016748156398534775, 0.042502570897340775, 0.016912223771214485, 0.011716129258275032, 0.04548521339893341, 0.0008787817787379026, 0.04689598083496094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6625117659568787, 0.049922335892915726, 0.2738172709941864, 0.004228150937706232, 0.0033112652599811554, 0.001177642960101366, 0.0005330604617483914, 0.00011132613872177899, 0.0043872324749827385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0040039438754320145, 0.00112480903044343, 0.04353015124797821, 0.9313303232192993, 0.010056668892502785, 0.0007567661814391613, 0.0006773694767616689, 0.00016374654660467058, 0.008356312289834023, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015825676964595914, 0.001574154943227768, 0.001225732616148889, 0.27774307131767273, 0.47191065549850464, 0.041899941861629486, 0.10331469774246216, 0.0047262245789170265, 0.0960230901837349, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0012044805334880948, 0.001744594657793641, 0.010911357589066029, 0.035235531628131866, 0.12406003475189209, 0.49639585614204407, 0.02129644714295864, 0.07618547230958939, 0.23296628892421722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006665610242635012, 0.008957373909652233, 0.0028928713873028755, 0.7268922924995422, 0.10707614570856094, 0.01201178040355444, 0.013845101930201054, 0.022992080077528954, 0.09866661578416824, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04134861007332802, 0.0526767373085022, 0.04131396487355232, 0.023087071254849434, 0.04077164828777313, 0.027765633538365364, 0.05679082125425339, 0.025407245382666588, 0.6908383369445801, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0021411015186458826, 0.012145284563302994, 0.008635377511382103, 0.004571457393467426, 0.009789393283426762, 0.022923681885004044, 0.019266795367002487, 0.15913596749305725, 0.7613908052444458, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9029706120491028, 0.09702935069799423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.909331738948822, 0.09066825360059738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0243820920586586, 0.026593990623950958, 0.9490237236022949, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3865880072116852, 0.017979737371206284, 0.5954321622848511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002445725491270423, 0.01137782447040081, 0.2685152590274811, 0.7176609635353088, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5815560817718506, 0.15706834197044373, 0.052335821092128754, 0.2090395838022232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013372139073908329, 
0.017163371667265892, 0.023703746497631073, 0.029362313449382782, 0.916398286819458, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7268465161323547, 0.0363004133105278, 0.07873083651065826, 0.06576839834451675, 0.09235385805368423, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009483089670538902, 0.0015323974657803774, 0.016186771914362907, 0.02369842305779457, 0.15252061188220978, 0.7965786457061768, 0.0, 0.0, 0.0, 0.0, 0.6550694704055786, 0.019533857703208923, 0.042362816631793976, 0.07321250438690186, 0.06519921869039536, 0.14462217688560486, 0.0, 0.0, 0.0, 0.0], [0.9718639850616455, 0.0004310244112275541, 0.00011954548244830221, 0.007853196933865547, 0.005200029816478491, 0.0034086843952536583, 0.011123435571789742, 0.0, 0.0, 0.0, 0.48597243428230286, 0.05253118649125099, 0.06572883576154709, 0.06831242144107819, 0.06681334227323532, 0.09225586801767349, 0.168385848402977, 0.0, 0.0, 0.0], [0.24338993430137634, 0.0009381886338815093, 0.001691973302513361, 0.004991883412003517, 0.06480661034584045, 0.02667633630335331, 0.5911706686019897, 0.06633439660072327, 0.0, 0.0, 0.2283225953578949, 0.01085133571177721, 0.0076954541727900505, 0.03403906524181366, 0.05505141243338585, 0.11318682134151459, 0.23008716106414795, 0.3207661509513855, 0.0, 0.0], [0.9644694328308105, 0.00020931981271132827, 0.00022034233552403748, 0.001116775325499475, 0.0005140798166394234, 0.011200232431292534, 0.006607241928577423, 0.015303434804081917, 0.00035933865001425147, 0.0, 0.31019407510757446, 0.01576145552098751, 0.006604246329516172, 0.1025082990527153, 0.11805430799722672, 0.0999068170785904, 0.17944715917110443, 0.09494999051094055, 0.07257375121116638, 0.0], [6.446504994528368e-05, 5.223282641964033e-05, 4.761212403536774e-05, 0.0026887860149145126, 0.9879595041275024, 8.169181819539517e-05, 3.4432316169841215e-05, 0.00022215544595383108, 0.008540215902030468, 0.0003085293574258685, 0.028495613485574722, 0.00728303287178278, 0.028978589922189713, 0.21746259927749634, 0.0312367994338274, 0.01134485937654972, 
0.002138715935871005, 0.0005697175511159003, 0.00012198994954815134, 0.6723678112030029]], [[0.1481553614139557, 0.14691436290740967, 0.5575758218765259, 0.02441403828561306, 0.058879025280475616, 0.011832842603325844, 0.01016098354011774, 0.015505112707614899, 0.026562504470348358, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20737145841121674, 0.2658809721469879, 0.4251604974269867, 0.03998560830950737, 0.012661930173635483, 0.003662273520603776, 0.0006891252705827355, 0.004390099551528692, 0.040197838097810745, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09591562300920486, 0.13717111945152283, 0.23219715058803558, 0.020156029611825943, 0.031411558389663696, 0.04842779412865639, 0.003137993859127164, 0.03623202070593834, 0.3953508734703064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.039530280977487564, 0.0770052894949913, 0.4637334942817688, 0.3284752666950226, 0.018390586599707603, 0.021701356396079063, 0.0038800504989922047, 0.01712900958955288, 0.03015456721186638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11945435404777527, 0.064958855509758, 0.07506249845027924, 0.03312050923705101, 0.045947931706905365, 0.21168209612369537, 0.1585550606250763, 0.15941208600997925, 0.13180643320083618, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.027811916545033455, 0.005367752630263567, 0.022701909765601158, 0.02928026206791401, 0.042085714638233185, 0.23124846816062927, 0.3448639512062073, 0.17164644598960876, 0.12499356269836426, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04207267239689827, 0.0036673955619335175, 0.013221505098044872, 0.04823020473122597, 0.018784234300255775, 0.15617071092128754, 0.5762590169906616, 0.08817830681800842, 0.053416069597005844, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016374358907341957, 0.01844160258769989, 0.0564517118036747, 0.008724928833544254, 0.031119121238589287, 0.08068697899580002, 0.028958600014448166, 
0.08762799203395844, 0.6716147661209106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01953587494790554, 0.025442129001021385, 0.033712055534124374, 0.054231878370046616, 0.046861547976732254, 0.038379911333322525, 0.03105914779007435, 0.027592265978455544, 0.723185122013092, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13706038892269135, 0.05296454578638077, 0.06056801974773407, 0.10271193832159042, 0.10989244282245636, 0.11971112340688705, 0.10623226314783096, 0.11503037810325623, 0.19582894444465637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9361864924430847, 0.06381344050168991, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9973775148391724, 0.0026223897002637386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8988155722618103, 0.08033642917871475, 0.020848000422120094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8086934685707092, 0.08078567683696747, 0.11052089184522629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9107179045677185, 0.0414138063788414, 0.01669401116669178, 0.031174303963780403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14967022836208344, 0.05171789228916168, 0.3914002478122711, 0.40721163153648376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7451518774032593, 0.09319417923688889, 0.038068220019340515, 0.07867664098739624, 0.04490913450717926, 0.0, 0.0, 0.0, 0.0, 0.0, 0.40780311822891235, 0.04434635117650032, 0.05232110992074013, 0.3448564112186432, 0.15067294239997864, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7448095083236694, 0.053781699389219284, 0.02206255868077278, 0.051568105816841125, 0.050503022968769073, 0.07727508991956711, 0.0, 0.0, 0.0, 0.0, 0.3417491614818573, 0.023165758699178696, 0.008621969260275364, 0.03819064050912857, 0.566249430179596, 0.022023199126124382, 0.0, 0.0, 0.0, 0.0], [0.49530187249183655, 0.08432565629482269, 0.024537190794944763, 0.03536847233772278, 0.04101351276040077, 0.2942921817302704, 0.025161173194646835, 0.0, 
0.0, 0.0, 0.13283461332321167, 0.0027981544844806194, 0.001892031985335052, 0.057958006858825684, 0.4807162284851074, 0.22431829571723938, 0.09948258846998215, 0.0, 0.0, 0.0], [0.33749768137931824, 0.0470278225839138, 0.025994539260864258, 0.11184448003768921, 0.035708073526620865, 0.3288814127445221, 0.052594561129808426, 0.06045151129364967, 0.0, 0.0, 0.5702553391456604, 0.005225116387009621, 0.0014312443090602756, 0.028526127338409424, 0.15899939835071564, 0.05284468084573746, 0.022491520270705223, 0.16022635996341705, 0.0, 0.0], [0.5827996730804443, 0.037185750901699066, 0.025691334158182144, 0.040444474667310715, 0.032313525676727295, 0.15237869322299957, 0.02532070316374302, 0.06300554424524307, 0.040860243141651154, 0.0, 0.005209033377468586, 6.901475717313588e-05, 5.760595013271086e-05, 0.006149875931441784, 0.006613760255277157, 0.010193211026489735, 0.013639912940561771, 0.9578513503074646, 0.00021631908020935953, 0.0], [0.17712005972862244, 0.06858222186565399, 0.023361776024103165, 0.06553570926189423, 0.015878353267908096, 0.40178799629211426, 0.03335757926106453, 0.09457883983850479, 0.06679747253656387, 0.05300001800060272, 0.5839820504188538, 0.007275882177054882, 0.03890826180577278, 0.2169828861951828, 0.02285575494170189, 0.0033320344518870115, 0.0027764069382101297, 0.032896872609853745, 0.038299210369586945, 0.05269058048725128]], [[0.1995955854654312, 0.1365700364112854, 0.06844333559274673, 0.10430964082479477, 0.06450515240430832, 0.046256136149168015, 0.1181989535689354, 0.11867640167474747, 0.14344464242458344, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11168574541807175, 0.19221879541873932, 0.09424428641796112, 0.10450402647256851, 0.06917304545640945, 0.0600862056016922, 0.14199501276016235, 0.09375911951065063, 0.1323338896036148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1609925925731659, 0.1396498680114746, 0.177944153547287, 0.03334498405456543, 0.016808874905109406, 0.10536731034517288, 
0.1187783032655716, 0.03365077078342438, 0.2134632021188736, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1381153017282486, 0.04422784969210625, 0.04791303351521492, 0.16848880052566528, 0.14531251788139343, 0.08485772460699081, 0.03650972992181778, 0.08906612545251846, 0.24550898373126984, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05420248210430145, 0.02658572979271412, 0.05446610227227211, 0.10749125480651855, 0.22097598016262054, 0.16638338565826416, 0.0331658273935318, 0.035041358321905136, 0.30168798565864563, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13837963342666626, 0.02727973647415638, 0.1299334168434143, 0.0896206796169281, 0.11551950126886368, 0.13963927328586578, 0.07841819524765015, 0.02172034978866577, 0.25948914885520935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08648376911878586, 0.022083481773734093, 0.023758457973599434, 0.19388236105442047, 0.1724909394979477, 0.02776450663805008, 0.04756799340248108, 0.03839871287345886, 0.38756975531578064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06476524472236633, 0.0030945511534810066, 0.016289785504341125, 0.013512126170098782, 0.007217712234705687, 0.047962453216314316, 0.03755675256252289, 0.02134101279079914, 0.788260281085968, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06007339805364609, 0.038077425211668015, 0.01732070930302143, 0.04335314407944679, 0.09849875420331955, 0.07123422622680664, 0.12536978721618652, 0.17402620613574982, 0.37204641103744507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.31619662046432495, 0.10629935562610626, 0.051193755120038986, 0.08206456899642944, 0.08056272566318512, 0.05727463215589523, 0.13476009666919708, 0.03839832916855812, 0.13325001299381256, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5274816155433655, 0.47251835465431213, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.050016310065984726, 0.9499835968017578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0672292709350586, 0.8769893646240234, 0.05578138306736946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1501082479953766, 0.3363426625728607, 0.5135491490364075, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13304242491722107, 0.3016340434551239, 0.1132093071937561, 0.45211419463157654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3278971016407013, 0.3615517318248749, 0.08450257778167725, 0.22604861855506897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3050541281700134, 0.12257003039121628, 0.15977424383163452, 0.12758392095565796, 0.2850176990032196, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25667694211006165, 0.40780505537986755, 0.15422186255455017, 0.10097997635602951, 0.08031607419252396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22356735169887543, 0.03928647190332413, 0.007754397578537464, 0.009327426552772522, 0.12143179029226303, 0.5986325740814209, 0.0, 0.0, 0.0, 0.0, 0.06381947547197342, 0.026328660547733307, 0.009785240516066551, 0.001955200219526887, 0.6504424810409546, 0.24766895174980164, 0.0, 0.0, 0.0, 0.0], [0.21462960541248322, 0.06222677230834961, 0.03770677372813225, 0.020617984235286713, 0.1298619657754898, 0.16450734436511993, 0.3704494833946228, 0.0, 0.0, 0.0, 0.20312389731407166, 0.04875075817108154, 0.01619674637913704, 0.028123438358306885, 0.08579143136739731, 0.20417368412017822, 0.41383999586105347, 0.0, 0.0, 0.0], [0.03640613704919815, 0.010346643626689911, 0.00673291739076376, 0.007102567236870527, 0.047351155430078506, 0.07502260059118271, 0.1735789030790329, 0.6434589624404907, 0.0, 0.0, 0.0028631098102778196, 0.00043423930765129626, 0.00021322182146832347, 0.0004598038794938475, 0.009248310700058937, 0.010348621755838394, 0.14874719083309174, 0.8276853561401367, 0.0, 0.0], [0.10655857622623444, 0.005401281639933586, 0.008467022329568863, 0.004935698118060827, 0.02920999936759472, 0.06761414557695389, 0.11367721855640411, 0.6410130262374878, 0.02312297187745571, 0.0, 
0.0018436884274706244, 0.00014077842934057117, 0.00019285999587737024, 0.0006260477821342647, 0.010675687342882156, 0.007219970691949129, 0.05410425364971161, 0.9211149215698242, 0.0040815602988004684, 0.0], [0.022663118317723274, 0.01638328656554222, 0.016234278678894043, 0.013438239693641663, 0.13762539625167847, 0.1316443532705307, 0.10126813501119614, 0.1815005987882614, 0.10885223746299744, 0.27039045095443726, 0.0996592566370964, 0.00012745348794851452, 0.0005518114776350558, 0.00026576913660392165, 0.00016320105351042002, 9.30886235437356e-05, 0.00024295282491948456, 0.000212193961488083, 0.0026789114344865084, 0.8960053324699402]]], [[[0.0077307759784162045, 0.013184988871216774, 0.016869038343429565, 0.013336911797523499, 0.01304439827799797, 0.013718237169086933, 0.0296618789434433, 0.02448520064353943, 0.8679684996604919, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1374177634716034, 0.018056754022836685, 0.029763542115688324, 0.004862301517277956, 0.00231130956672132, 0.006278112530708313, 0.012106452137231827, 0.033879589289426804, 0.755324125289917, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1539316475391388, 0.23461903631687164, 0.033691998571157455, 0.026462335139513016, 0.0030949951615184546, 0.0038835303857922554, 0.009438932873308659, 0.0025479686446487904, 0.5323294997215271, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.29215019941329956, 0.05790534242987633, 0.18934868276119232, 0.018473153933882713, 0.002999690594151616, 0.004652327857911587, 0.010374259203672409, 0.0072145056910812855, 0.4168816804885864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16539201140403748, 0.05819307267665863, 0.12084146589040756, 0.1738077849149704, 0.004504370968788862, 0.006831282749772072, 0.02180996537208557, 0.012287246994674206, 0.4363327622413635, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017602156847715378, 0.026805447414517403, 0.07671570032835007, 0.5152483582496643, 
0.21202509105205536, 0.041201505810022354, 0.02207496576011181, 0.00952092744410038, 0.07880578190088272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.017750855535268784, 0.01654047518968582, 0.07482129335403442, 0.23223723471164703, 0.3542158007621765, 0.16141267120838165, 0.02249749004840851, 0.044686269015073776, 0.07583795487880707, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01755565032362938, 0.008823209442198277, 0.013083218596875668, 0.5061533451080322, 0.02344801276922226, 0.0075739468447864056, 0.07187878340482712, 0.029167035594582558, 0.3223167061805725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07037408649921417, 0.05453738570213318, 0.05508268624544144, 0.02769530564546585, 0.038050826638936996, 0.20446287095546722, 0.19980187714099884, 0.19835616648197174, 0.15163862705230713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003375247586518526, 0.008793321438133717, 0.001001630094833672, 0.002094975672662258, 0.0032946632709354162, 0.01792662777006626, 0.10988471657037735, 0.21093924343585968, 0.6426896452903748, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5484673976898193, 0.11615661531686783, 0.018765881657600403, 0.05580288916826248, 0.007166780531406403, 0.0032917018979787827, 0.014802427962422371, 0.009165346622467041, 0.2263808399438858, 0.0, 0.11433855444192886, 0.04686390608549118, 0.05050662159919739, 0.01692698895931244, 0.014815520495176315, 0.0768260732293129, 0.017432983964681625, 0.6101956367492676, 0.052093639969825745, 0.0], [0.05529153719544411, 0.9114729762077332, 0.02160518430173397, 0.004567866213619709, 0.0020856212358921766, 0.002110318513587117, 7.9427394666709e-05, 0.0003330546023789793, 0.002454147208482027, 0.0, 0.38972416520118713, 0.22944767773151398, 0.07904881238937378, 0.052096493542194366, 0.007011168636381626, 0.006784902419894934, 0.0014207185013219714, 0.13426420092582703, 0.10020176321268082, 0.0], [0.05249117314815521, 0.7437799572944641, 
0.20206855237483978, 0.000259216787526384, 5.300459815771319e-05, 0.0010736108524724841, 1.3093422239762731e-05, 5.472580232890323e-05, 0.00020689083612523973, 0.0, 0.11186019331216812, 0.017331605777144432, 0.04338392615318298, 0.012996343895792961, 0.0037596735637634993, 0.04186789691448212, 0.010188347660005093, 0.40130615234375, 0.35730576515197754, 0.0], [0.00011124516458949074, 0.00014391898002941161, 0.002859711181372404, 0.9545366168022156, 0.03663090988993645, 0.0007743826135993004, 9.755761857377365e-05, 0.004461521748453379, 0.00038416660390794277, 0.0, 0.10591340810060501, 0.20223784446716309, 0.2886512279510498, 0.06966178864240646, 0.009836114011704922, 0.0229511596262455, 0.008547846227884293, 0.14636646211147308, 0.14583423733711243, 0.0], [0.00743946572765708, 0.0006693374016322196, 0.030706975609064102, 0.7693524360656738, 0.13636630773544312, 0.010656076483428478, 0.00020547783060465008, 0.04430907592177391, 0.0002948205510620028, 0.0, 0.07172418385744095, 0.2936038672924042, 0.03526498004794121, 0.13891401886940002, 0.06139945611357689, 0.03925776481628418, 0.05349786579608917, 0.1478489190340042, 0.15848881006240845, 0.0], [0.34355977177619934, 0.05352164804935455, 0.28276264667510986, 0.0013746530748903751, 0.09131773561239243, 0.14369648694992065, 0.012657100334763527, 0.0018211203860118985, 0.06928855180740356, 0.0, 0.06896000355482101, 0.07670743763446808, 0.04746192321181297, 0.0077508557587862015, 0.0064402432180941105, 0.0690828412771225, 0.07239065319299698, 0.3534849286079407, 0.2977212965488434, 0.0], [0.00012065855844411999, 1.9836104911519215e-05, 1.09456195787061e-05, 1.777518991730176e-05, 5.799566861242056e-05, 0.00024955562548711896, 0.9993764758110046, 2.622428155518719e-06, 0.00014374956663232297, 0.0, 0.07854402810335159, 0.05315924435853958, 0.006970829796046019, 0.01197806466370821, 0.15678070485591888, 0.059328123927116394, 0.0844779834151268, 0.058106619864702225, 0.4906543493270874, 0.0], [0.07937772572040558, 
0.06896578520536423, 0.046331409364938736, 0.0006753505440428853, 0.10991709679365158, 0.07862550020217896, 0.015291865915060043, 0.024935398250818253, 0.575879693031311, 0.0, 0.03681005910038948, 0.0140389958396554, 0.007565703243017197, 0.004180380143225193, 0.03550711274147034, 0.08768045157194138, 0.04672156274318695, 0.05331620201468468, 0.7141793966293335, 0.0], [0.0013021372724324465, 0.0014324801741167903, 0.001721968175843358, 0.0011953436769545078, 0.025524066761136055, 0.0017154657980427146, 0.004751213360577822, 0.06856247782707214, 0.8937948942184448, 0.0, 0.011276278644800186, 0.0043571279384195805, 0.0015699869254603982, 0.009309794753789902, 0.5466312766075134, 0.06633520126342773, 0.012565904296934605, 0.036605555564165115, 0.3113488256931305, 0.0], [0.19237911701202393, 0.08248087018728256, 0.005975060164928436, 0.0005637910799123347, 0.0032617340330034494, 0.08159960061311722, 0.10672623664140701, 0.06415636837482452, 0.4628572165966034, 0.0, 0.06600549072027206, 0.025541657581925392, 0.15132947266101837, 0.052793603390455246, 0.07684693485498428, 0.05682613328099251, 0.01590665802359581, 0.05306769907474518, 0.501682460308075, 0.0]], [[0.07041527330875397, 0.15367093682289124, 0.3963199257850647, 0.03077671490609646, 0.0928598940372467, 0.04086732864379883, 0.018142100423574448, 0.012120239436626434, 0.18482762575149536, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0367966964840889, 0.022482391446828842, 0.35830163955688477, 0.02875097654759884, 0.03547174483537674, 0.026731541380286217, 0.005365677177906036, 0.038472291082143784, 0.4476269483566284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.035722482949495316, 0.013633755035698414, 0.09877835214138031, 0.0896211713552475, 0.16777706146240234, 0.10725134611129761, 0.05053357034921646, 0.10712091624736786, 0.32956135272979736, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026592494919896126, 0.002020884770900011, 0.010739283636212349, 
0.015951883047819138, 0.18538028001785278, 0.16766443848609924, 0.03731367364525795, 0.3853055238723755, 0.16903170943260193, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0038862666115164757, 0.0004870722477789968, 0.0013956124894320965, 0.001421120367012918, 0.013834443874657154, 0.18579153716564178, 0.18213258683681488, 0.5258954763412476, 0.08515587449073792, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002724156714975834, 0.004415807779878378, 0.003638928523287177, 0.00862019695341587, 0.010569852776825428, 0.18068262934684753, 0.2256886065006256, 0.3616458773612976, 0.20201392471790314, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004798010922968388, 0.006960091646760702, 0.005558326840400696, 0.015252271667122841, 0.010058294981718063, 0.16163121163845062, 0.19844789803028107, 0.089196115732193, 0.508097767829895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006148195825517178, 0.009931370615959167, 0.0022139709908515215, 0.003481896361336112, 0.00199966412037611, 0.011451391503214836, 0.018514955416321754, 0.10389390587806702, 0.8423647284507751, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1474374532699585, 0.1116698756814003, 0.10040155798196793, 0.07832593470811844, 0.051569730043411255, 0.103182852268219, 0.0987909808754921, 0.06264416873455048, 0.24597744643688202, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0029639359563589096, 0.0008504446013830602, 0.002215511864051223, 0.0016108372947201133, 0.001786046545021236, 0.003435377962887287, 0.000923731888178736, 0.009203500114381313, 0.9770104885101318, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02611132338643074, 0.03929625079035759, 0.13154497742652893, 0.0362381711602211, 0.41634607315063477, 0.056842975318431854, 0.07852181792259216, 0.04581860452890396, 0.16927982866764069, 0.0, 0.02032626047730446, 0.020203230902552605, 0.6020291447639465, 0.030009541660547256, 0.018654465675354004, 
0.18802858889102936, 0.05143868178129196, 0.02303317002952099, 0.04627683013677597, 0.0], [0.02255222387611866, 0.10361888259649277, 0.8158259987831116, 0.004235901869833469, 0.006766538135707378, 0.000986771541647613, 0.0032120062969624996, 0.001885716337710619, 0.04091576859354973, 0.0, 0.0019176346249878407, 0.016540443524718285, 0.9535443782806396, 0.023291945457458496, 0.00010982116509694606, 0.0009689715225249529, 0.00013850984396412969, 0.00020587266772054136, 0.0032822038047015667, 0.0], [0.017373425886034966, 0.13537107408046722, 0.8077713847160339, 0.0038886782713234425, 0.000276644917903468, 0.0006381930434145033, 0.00045153097016736865, 0.00025059780455194414, 0.0339784249663353, 0.0, 0.0021921033039689064, 0.01942657120525837, 0.9639623165130615, 0.007353355176746845, 5.936318120802753e-05, 0.0048356493934988976, 6.20722203166224e-05, 0.00038220148417167366, 0.0017266407376155257, 0.0], [0.005007221829146147, 0.01780957728624344, 0.01267488207668066, 0.04065018519759178, 0.30516332387924194, 0.0026367236860096455, 0.0019572318997234106, 0.004150545224547386, 0.6099502444267273, 0.0, 0.009460263885557652, 0.051493529230356216, 0.3948231041431427, 0.009338779374957085, 0.006950095761567354, 0.48816555738449097, 0.01867900975048542, 0.0028053568676114082, 0.018284155055880547, 0.0], [0.00225767120718956, 0.0031865497585386038, 0.001125291339121759, 0.016497144475579262, 0.8971690535545349, 0.00343481102026999, 0.003961168695241213, 0.006191920954734087, 0.0661763921380043, 0.0, 0.003424277761951089, 0.009652957320213318, 0.10005363076925278, 0.026136387139558792, 0.09182560443878174, 0.6324647068977356, 0.07658552378416061, 0.005459657870233059, 0.05439731851220131, 0.0], [0.009628614410758018, 0.010054895654320717, 0.001336919842287898, 0.0704738199710846, 0.6877674460411072, 0.03301373869180679, 0.05187760666012764, 0.005273953080177307, 0.1305730938911438, 0.0, 0.0026005429681390524, 0.0029943317640572786, 0.26352012157440186, 0.02426978573203087, 
0.05801504850387573, 0.49633610248565674, 0.11849544942378998, 0.009708826430141926, 0.02405967190861702, 0.0], [0.004103087354451418, 0.0010862533235922456, 0.0006940921302884817, 0.005870609078556299, 0.43826234340667725, 0.030803751200437546, 0.2956492602825165, 0.002342070685699582, 0.2211885154247284, 0.0, 0.0027453943621367216, 0.0021119171287864447, 0.0030521987937390804, 0.09308812767267227, 0.28145554661750793, 0.015254919417202473, 0.530491828918457, 0.007408978417515755, 0.06439103186130524, 0.0], [0.0007035931921564043, 0.0015657383482903242, 0.0003329406026750803, 0.025085464119911194, 0.8715798258781433, 0.006046876311302185, 0.002586639951914549, 0.00011169366916874424, 0.09198720753192902, 0.0, 0.0012973180273547769, 0.002199073787778616, 0.0031004296615719795, 0.024488963186740875, 0.8535729050636292, 0.016068320721387863, 0.029179612174630165, 0.012250186875462532, 0.05784311145544052, 0.0], [0.0021814475767314434, 0.0018482444575056434, 0.02461252734065056, 0.02290530502796173, 0.17733190953731537, 0.007551506161689758, 0.026218494400382042, 0.1859409213066101, 0.5514096021652222, 0.0, 0.0010010729311034083, 0.0008253253763541579, 0.0028483583591878414, 0.028342707082629204, 0.860925018787384, 0.0038871155120432377, 0.006998666562139988, 0.01413769368082285, 0.08103384077548981, 0.0], [0.002560347318649292, 0.0069580040872097015, 0.0021583843044936657, 0.002428637584671378, 0.010794135741889477, 0.002866419730708003, 0.010929176583886147, 0.004671781323850155, 0.9566330909729004, 0.0, 0.007383578456938267, 0.056256651878356934, 0.5807297825813293, 0.01667044125497341, 0.03810223564505577, 0.07880110293626785, 0.009197888895869255, 0.12926581501960754, 0.08359251171350479, 0.0]], [[0.022387586534023285, 0.045972827821969986, 0.05835629999637604, 0.22869053483009338, 0.010770916007459164, 0.006216464098542929, 0.018148910254240036, 0.006308646872639656, 0.6031478047370911, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.0035997454542666674, 0.002674269489943981, 0.016009783372282982, 0.05554450675845146, 0.0013587778666988015, 0.0032801039051264524, 0.00560772093012929, 0.00799081102013588, 0.9039342403411865, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0018151046242564917, 0.001049908110871911, 0.0005912692868150771, 0.005136367864906788, 0.0005621784366667271, 0.00844560656696558, 0.017937110736966133, 0.008342047221958637, 0.9561205506324768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004429707303643227, 0.005516116041690111, 0.003033371875062585, 0.012963998131453991, 0.0034379358403384686, 0.003276604227721691, 0.0140963364392519, 0.005416945554316044, 0.9478288888931274, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013230006210505962, 0.011804360896348953, 0.009972047992050648, 0.004975683055818081, 0.008386109955608845, 0.18977868556976318, 0.1806434541940689, 0.03204761818051338, 0.549161970615387, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006766628473997116, 0.008349079638719559, 0.003925195895135403, 0.0006033667596057057, 0.006175691727548838, 0.2236345112323761, 0.03405819088220596, 0.07976362109184265, 0.6367236971855164, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03860252723097801, 0.01646261475980282, 0.02104821614921093, 0.0021387943997979164, 0.005319601856172085, 0.2400989532470703, 0.03188503161072731, 0.005558657925575972, 0.6388856768608093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0738457515835762, 0.018826894462108612, 0.0069308048114180565, 0.0074225678108632565, 0.004789229016751051, 0.046955253928899765, 0.11907684803009033, 0.18744726479053497, 0.5347052812576294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09479796141386032, 0.11939200013875961, 0.0752992108464241, 0.061374519020318985, 0.08638977259397507, 0.12459041178226471, 0.16023214161396027, 0.0879756435751915, 0.1899482160806656, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], [0.027537798509001732, 0.06296242028474808, 0.014751194976270199, 0.0011882808757945895, 0.016387099400162697, 0.15830224752426147, 0.03707461059093475, 0.028470970690250397, 0.6533253788948059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10015174746513367, 0.09576243162155151, 0.03824060782790184, 0.020538825541734695, 0.027732321992516518, 0.017240623012185097, 0.011470633558928967, 0.4632270634174347, 0.22563566267490387, 0.0, 0.023356424644589424, 0.012650059536099434, 0.05017145350575447, 0.05590398982167244, 0.05159280076622963, 0.01602507382631302, 0.014807065948843956, 0.654244601726532, 0.12124844640493393, 0.0], [0.018642960116267204, 0.04822106286883354, 0.0140958521515131, 0.13092586398124695, 0.031955283135175705, 0.05195324495434761, 0.024238048121333122, 0.35591286420822144, 0.32405486702919006, 0.0, 0.030835414305329323, 0.04180247709155083, 0.029645785689353943, 0.20071062445640564, 0.010328685864806175, 0.03208288922905922, 0.026780622079968452, 0.457701712846756, 0.17011170089244843, 0.0], [0.04842180013656616, 0.06019889563322067, 0.016854524612426758, 0.166769877076149, 0.016003064811229706, 0.024013301357626915, 0.03686072677373886, 0.44016721844673157, 0.1907106339931488, 0.0, 0.02856343612074852, 0.03459611535072327, 0.15441730618476868, 0.04662194848060608, 0.0013040672056376934, 0.017847269773483276, 0.02464178577065468, 0.5969575643539429, 0.09505032747983932, 0.0], [0.0036558371502906084, 0.0033082098234444857, 0.0009605223312973976, 0.017093589529395103, 0.019939379766583443, 0.08280718326568604, 0.031923551112413406, 0.703184187412262, 0.1371275782585144, 0.0, 0.03350958973169327, 0.02514287829399109, 0.027676144614815712, 0.11052078753709793, 0.15496152639389038, 0.08862635493278503, 0.027723105624318123, 0.24766287207603455, 0.28417670726776123, 0.0], [0.0018930931109935045, 0.002881440566852689, 0.00019882648484781384, 0.00406575808301568, 0.0021070034708827734, 0.011610278859734535, 0.0074381148442626, 
0.9341073036193848, 0.03569793701171875, 0.0, 0.014355039224028587, 0.005383878946304321, 0.002517768880352378, 0.09422861039638519, 0.06622537225484848, 0.046315327286720276, 0.08473969250917435, 0.4999735355377197, 0.18626095354557037, 0.0], [0.00691588269546628, 0.02838265150785446, 0.015397720038890839, 0.031874921172857285, 0.04765379801392555, 0.22744230926036835, 0.06624653190374374, 0.10724947601556778, 0.4688366651535034, 0.0, 0.002460801973938942, 0.0016284199664369226, 0.005857668351382017, 0.006880565080791712, 0.7626023292541504, 0.025456121191382408, 0.021016357466578484, 0.06090177595615387, 0.11319592595100403, 0.0], [0.009144916199147701, 0.012914983555674553, 0.0114166010171175, 0.010616455227136612, 0.03852293640375137, 0.11398687958717346, 0.23996756970882416, 0.03855413943529129, 0.5248754620552063, 0.0, 0.007633878383785486, 0.002682786202058196, 0.0008938225219026208, 0.006808742880821228, 0.17231638729572296, 0.049100711941719055, 0.32851701974868774, 0.0061601921916007996, 0.4258863925933838, 0.0], [0.0165528766810894, 0.08396174013614655, 0.03695421293377876, 0.012792840600013733, 0.05054211989045143, 0.004681664984673262, 0.006349458359181881, 0.0059485542587935925, 0.7822163701057434, 0.0, 0.003303236560896039, 0.0015338786179199815, 0.0017581325955688953, 0.0052335225045681, 0.24177710711956024, 0.09136255830526352, 0.06603478640317917, 0.0047843558713793755, 0.5842124223709106, 0.0], [0.01941962167620659, 0.0720844566822052, 0.06703408807516098, 0.0024893011432141066, 0.09017500281333923, 0.01547347940504551, 0.011082785204052925, 0.036743972450494766, 0.6854971051216125, 0.0, 0.011694137938320637, 0.0015430846251547337, 0.00043408613419160247, 0.005433904007077217, 0.03723231703042984, 0.1666216105222702, 0.04878358170390129, 0.024785596877336502, 0.7034717798233032, 0.0], [0.010559813119471073, 0.7681021094322205, 0.01782229356467724, 0.0007385257631540298, 0.000383153063012287, 0.00014055910287424922, 0.00037340103881433606, 
0.000453647633548826, 0.2014264017343521, 0.0, 0.028515880927443504, 0.0183264147490263, 0.011487613432109356, 0.03205259144306183, 0.06179385632276535, 0.041277043521404266, 0.014015565626323223, 0.06198226660490036, 0.7305486798286438, 0.0]], [[0.044569190591573715, 0.00917287077754736, 0.004391324240714312, 0.8386606574058533, 0.06130588799715042, 0.003870139131322503, 0.007488539442420006, 0.028126200661063194, 0.002415221417322755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08009635657072067, 0.016815535724163055, 0.012093844823539257, 0.1592065542936325, 0.5643750429153442, 0.02920410968363285, 0.0919446051120758, 0.036902546882629395, 0.009361499920487404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08603464066982269, 0.01933746039867401, 0.05900268629193306, 0.2806539237499237, 0.22094620764255524, 0.08643656224012375, 0.026435989886522293, 0.1974046230316162, 0.023747902363538742, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011109462939202785, 0.008883769623935223, 0.006091873627156019, 0.9400036931037903, 0.020445559173822403, 0.0056496066972613335, 0.0019461432239040732, 0.005268549080938101, 0.000601345207542181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12766019999980927, 0.021774157881736755, 0.08726640790700912, 0.0718328207731247, 0.053083695471286774, 0.3031027019023895, 0.06321869790554047, 0.2611844837665558, 0.010876962915062904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06407223641872406, 0.11627303808927536, 0.1807759404182434, 0.0054795523174107075, 0.026687098667025566, 0.09637009352445602, 0.052303463220596313, 0.4456423819065094, 0.012396130710840225, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09725438803434372, 0.15948796272277832, 0.05173082649707794, 0.01153761800378561, 0.0721999853849411, 0.059252724051475525, 0.11923323571681976, 0.05380275845527649, 0.3755004107952118, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.0974307581782341, 0.23218762874603271, 0.06967660784721375, 0.031012043356895447, 0.04906507954001427, 0.31767621636390686, 0.08231117576360703, 0.07159094512462616, 0.04904941841959953, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16388627886772156, 0.17526376247406006, 0.07081529498100281, 0.17886894941329956, 0.07944575697183609, 0.07640470564365387, 0.0757102444767952, 0.04333823174238205, 0.13626690208911896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09265941381454468, 0.11633585393428802, 0.04908691346645355, 0.0062498715706169605, 0.07016508281230927, 0.012818480841815472, 0.0484321266412735, 0.015437646768987179, 0.5888146162033081, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15394526720046997, 0.03155883774161339, 0.013263775035738945, 0.6156436800956726, 0.07578981667757034, 0.016770629212260246, 0.041555847972631454, 0.008898256346583366, 0.042573899030685425, 0.0, 0.10071786493062973, 0.017111245542764664, 0.07246935367584229, 0.01480931881815195, 0.14864948391914368, 0.20273517072200775, 0.054981958121061325, 0.25890761613845825, 0.12961813807487488, 0.0], [0.0076131573878228664, 0.0034139170311391354, 0.013692762702703476, 0.9489110708236694, 0.0023491643369197845, 0.004504398908466101, 0.018228650093078613, 6.88576401444152e-05, 0.0012180121848359704, 0.0, 0.049787674099206924, 0.02108882926404476, 0.20989678800106049, 0.006962155923247337, 0.21569682657718658, 0.1622857302427292, 0.016771212220191956, 0.2403237521648407, 0.07718709856271744, 0.0], [0.002593724289909005, 0.0010450187837705016, 0.004842442460358143, 0.9895163178443909, 0.0010734394891187549, 6.863682210678235e-05, 0.00040535334846936166, 0.00029785462538711727, 0.00015763564442750067, 0.0, 0.024618864059448242, 0.010488898493349552, 0.4834355115890503, 0.015693388879299164, 0.07393413037061691, 0.055557433515787125, 0.007495412603020668, 0.27800077199935913, 0.05077548325061798, 0.0], [0.007949098013341427, 0.0007930149440653622, 
0.0010613474296405911, 0.913150429725647, 0.017281265929341316, 0.01414033118635416, 0.03622613847255707, 0.0036093932576477528, 0.0057889497838914394, 0.0, 0.006324393209069967, 0.0006586945382878184, 0.02188086323440075, 0.003439908614382148, 0.055277179926633835, 0.5423230528831482, 0.1656835526227951, 0.12264314293861389, 0.08176910877227783, 0.0], [0.002646723063662648, 0.0005160199943929911, 0.0002488561731297523, 0.0025351925287395716, 0.0016247049206867814, 0.09429789334535599, 0.8856176137924194, 0.005861275363713503, 0.006651633884757757, 0.0, 0.008246216922998428, 0.000647101376671344, 0.018551276996731758, 0.0031310885678976774, 0.04379039630293846, 0.34376823902130127, 0.2999532222747803, 0.13205647468566895, 0.1498558670282364, 0.0], [0.010305220261216164, 0.0041244118474423885, 0.0009454450337216258, 0.011387528851628304, 0.006450551562011242, 0.09920497238636017, 0.7582080960273743, 0.0005519646219909191, 0.1088215708732605, 0.0, 0.001800144906155765, 0.00032634654780849814, 0.02560480497777462, 0.0014933178899809718, 0.04328969866037369, 0.48067817091941833, 0.22867664694786072, 0.008819987997412682, 0.20931090414524078, 0.0], [0.02700764685869217, 0.004230276681482792, 0.0004602614790201187, 0.0022337904665619135, 0.001628970610909164, 0.01760227419435978, 0.5739604234695435, 0.0034094173461198807, 0.3694668710231781, 0.0, 0.0023069612216204405, 0.0018136217258870602, 0.006447605788707733, 0.005140945315361023, 0.046570103615522385, 0.045606330037117004, 0.3236173987388611, 0.014286459423601627, 0.5542104840278625, 0.0], [0.008519366383552551, 0.005846879445016384, 0.00031929058604873717, 0.00022687541786581278, 0.0001488836423959583, 0.0012441301951184869, 0.007195098325610161, 0.000364138453733176, 0.9761351943016052, 0.0, 0.0025225167628377676, 0.000774701707996428, 0.0168449804186821, 0.0014132045907899737, 0.1692919135093689, 0.21547472476959229, 0.19468647241592407, 0.00621472392231226, 0.392776757478714, 0.0], [0.10429845005273819, 
0.062129467725753784, 0.0009245545952580869, 0.00015166438242886215, 0.00031537580071017146, 0.00040291156619787216, 0.006900690030306578, 0.009933815337717533, 0.8149431943893433, 0.0, 0.006893941201269627, 0.0026040272787213326, 0.036687299609184265, 0.0016275923699140549, 0.13132861256599426, 0.15552441775798798, 0.23651301860809326, 0.023025648668408394, 0.4057953953742981, 0.0], [0.021318454295396805, 0.024646490812301636, 0.0006273255567066371, 3.4892458643298596e-05, 0.0002248335222247988, 0.001184952794574201, 0.005942351184785366, 0.01648845337331295, 0.9295321702957153, 0.0, 0.004257934633642435, 0.008543262258172035, 0.05716743320226669, 0.0024442216381430626, 0.027526315301656723, 0.08828678727149963, 0.025276461616158485, 0.2843557894229889, 0.5021417737007141, 0.0]], [[0.005850312765687704, 0.017421673983335495, 0.004798548296093941, 0.008814580738544464, 0.00403921864926815, 0.015260725282132626, 0.03377071022987366, 0.009620469063520432, 0.9004237651824951, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013436811044812202, 0.03272867575287819, 0.005969575606286526, 0.02213078737258911, 0.008325905539095402, 0.015314633026719093, 0.027177294716238976, 0.017041552811861038, 0.8578747510910034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012305106967687607, 0.03383316844701767, 0.010593314655125141, 0.027156231924891472, 0.00306991720572114, 0.004844812210649252, 0.018964877352118492, 0.05307865887880325, 0.8361539244651794, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01255274098366499, 0.03600494936108589, 0.010369472205638885, 0.019021298736333847, 0.0032906190026551485, 0.0037067385856062174, 0.017627976834774017, 0.01037716306746006, 0.8870489597320557, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0020879805088043213, 0.013872731477022171, 0.0018324662232771516, 0.006437606178224087, 0.013170951046049595, 0.011930068954825401, 0.0030771365854889154, 0.018353432416915894, 0.9292376041412354, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022865016013383865, 0.0674654096364975, 0.00996339786797762, 0.01914660632610321, 0.014956261031329632, 0.026097828522324562, 0.018910687416791916, 0.06562207639217377, 0.7549726963043213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019515078514814377, 0.07521340996026993, 0.03206341341137886, 0.0070005785673856735, 0.0066195218823850155, 0.03877842426300049, 0.01228683814406395, 0.032381508499383926, 0.7761411666870117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0022571254521608353, 0.00383052253164351, 0.0012509305961430073, 0.005982697941362858, 0.001252268673852086, 0.0028570422437042, 0.00556317949667573, 0.7337145805358887, 0.24329175055027008, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.132036030292511, 0.1683972030878067, 0.13758207857608795, 0.14189518988132477, 0.03147142380475998, 0.047566916793584824, 0.07834812998771667, 0.12177446484565735, 0.1409287303686142, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038695693016052246, 0.06063547730445862, 0.020152689889073372, 0.1101006418466568, 0.021127784624695778, 0.02848564088344574, 0.03705665469169617, 0.108894944190979, 0.5748504996299744, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.18210574984550476, 0.33179324865341187, 0.17356830835342407, 0.09634877741336823, 0.13200198113918304, 0.013823754154145718, 0.003925282042473555, 0.0030049949418753386, 0.06342781335115433, 0.0, 0.2613511085510254, 0.024888625368475914, 0.11462423205375671, 0.021279124543070793, 0.1065509021282196, 0.23139707744121552, 0.07117345929145813, 0.09822205454111099, 0.07051338255405426, 0.0], [0.0020119324326515198, 0.018535200506448746, 0.9658629894256592, 0.007450602483004332, 0.004517300054430962, 0.0010996937053278089, 0.00011890578753082082, 8.778373739914969e-05, 0.0003156243183184415, 0.0, 0.08973123878240585, 0.07256940752267838, 0.3644520342350006, 0.09313907474279404, 0.10501276701688766, 
0.026235496625304222, 0.035534195601940155, 0.05646198242902756, 0.15686386823654175, 0.0], [0.00025491390260867774, 0.010184208862483501, 0.989091694355011, 0.0003856455150526017, 4.125349732930772e-05, 1.630889528314583e-05, 4.754766450787429e-06, 9.06635705177905e-06, 1.2339231943769846e-05, 0.0, 0.12105944007635117, 0.03531542792916298, 0.18099160492420197, 0.04576702043414116, 0.03264385089278221, 0.04934798926115036, 0.0072426870465278625, 0.2739674150943756, 0.25366437435150146, 0.0], [0.0005263744969852269, 0.0021082928869873285, 0.03339724615216255, 0.9491472840309143, 0.010056117549538612, 0.0008323733345605433, 0.00041247360059060156, 0.003171485150232911, 0.0003482940956018865, 0.0, 0.049101557582616806, 0.02436317875981331, 0.1119280532002449, 0.019082490354776382, 0.23333144187927246, 0.12024182081222534, 0.09606382250785828, 0.03866123780608177, 0.3072265088558197, 0.0], [0.004013302735984325, 0.003607808379456401, 0.10117900371551514, 0.3848154544830322, 0.4549750089645386, 0.02184862084686756, 0.015023248270154, 0.013938427902758121, 0.0005991325015202165, 0.0, 0.011761177331209183, 0.004259902983903885, 0.019396282732486725, 0.010304590687155724, 0.5410462021827698, 0.1548439860343933, 0.1577453315258026, 0.022628072649240494, 0.07801424711942673, 0.0], [0.006630904506891966, 0.001555793103761971, 0.01566290855407715, 0.005377574823796749, 0.0545264296233654, 0.7578195929527283, 0.15542279183864594, 0.00011175184772582725, 0.002892365213483572, 0.0, 0.011012338101863861, 0.006456742994487286, 0.03514476120471954, 0.01111147552728653, 0.3646441400051117, 0.06045660004019737, 0.22725869715213776, 0.030072104185819626, 0.2538430392742157, 0.0], [0.0013706677127629519, 0.0003565592342056334, 0.0006504033226519823, 0.0008717189775779843, 0.023110924288630486, 0.16852477192878723, 0.8020843863487244, 0.0004564319388009608, 0.0025740356650203466, 0.0, 0.0009554855059832335, 0.0010365764610469341, 0.000539954868145287, 0.013481645844876766, 
0.6702913641929626, 0.013201623223721981, 0.06565960496664047, 0.008186675608158112, 0.2266470193862915, 0.0], [0.0005740747437812388, 0.0018384596332907677, 0.015691960230469704, 0.0004515495093073696, 0.04004881531000137, 0.8668573498725891, 0.03566786274313927, 0.01278533972799778, 0.0260846596211195, 0.0, 0.007978711277246475, 0.0019918852485716343, 0.0007363414042629302, 0.010062554851174355, 0.10717969387769699, 0.01258536335080862, 0.08278501033782959, 0.02946571074426174, 0.7472147941589355, 0.0], [0.0008355869795195758, 0.003608973463997245, 0.04490630701184273, 0.009341607801616192, 0.007649072911590338, 0.10034366697072983, 0.06446904689073563, 0.7009655237197876, 0.06788014620542526, 0.0, 0.014369996264576912, 0.00412968173623085, 0.002898097038269043, 0.0381503589451313, 0.28382056951522827, 0.03412872180342674, 0.2624143660068512, 0.04523473232984543, 0.3148534893989563, 0.0], [0.00805425550788641, 0.039243537932634354, 0.05003930628299713, 0.0007152591715566814, 0.00863983016461134, 0.4756345748901367, 0.24407540261745453, 0.1204291433095932, 0.053168926388025284, 0.0, 0.0025127469561994076, 0.0030011499766260386, 0.0036209137178957462, 0.0006047216593287885, 0.01094596553593874, 0.0023283734917640686, 0.003409643191844225, 0.009625249542295933, 0.9639512896537781, 0.0]], [[0.0842718631029129, 0.33930304646492004, 0.1421334594488144, 0.18528752028942108, 0.05815916135907173, 0.022830937057733536, 0.01860896497964859, 0.009871570393443108, 0.13953347504138947, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1295168399810791, 0.08884051442146301, 0.06592670828104019, 0.2686370015144348, 0.02522267960011959, 0.03633918985724449, 0.021549394354224205, 0.051057688891887665, 0.31290990114212036, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20177747309207916, 0.039902716875076294, 0.053595658391714096, 0.09988140314817429, 0.01657777465879917, 0.07154539972543716, 0.024320384487509727, 0.12353017926216125, 0.36886897683143616, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20979411900043488, 0.30265647172927856, 0.20804975926876068, 0.007371237967163324, 0.0033807901199907064, 0.02442527562379837, 0.017248263582587242, 0.022337088361382484, 0.2047368586063385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20739784836769104, 0.05559583380818367, 0.12535981833934784, 0.009768493473529816, 0.015522800385951996, 0.024528708308935165, 0.03864477947354317, 0.08712086826562881, 0.4360608160495758, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.054182324558496475, 0.0038395673036575317, 0.019914912059903145, 0.014234894886612892, 0.012772555463016033, 0.019022708758711815, 0.04023807495832443, 0.36886361241340637, 0.46693122386932373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10943998396396637, 0.007069440558552742, 0.019821595400571823, 0.012627309188246727, 0.016869045794010162, 0.05302179232239723, 0.05124732851982117, 0.14304620027542114, 0.5868573188781738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23190192878246307, 0.060554418712854385, 0.05880776792764664, 0.00438718032091856, 0.00454165181145072, 0.15464532375335693, 0.09585105627775192, 0.02281157113611698, 0.3664989471435547, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13208958506584167, 0.15704363584518433, 0.07176639884710312, 0.08554346114397049, 0.0733223557472229, 0.0956358015537262, 0.07472448796033859, 0.09573546797037125, 0.21413877606391907, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.28021925687789917, 0.045415956526994705, 0.04812552034854889, 0.00880114920437336, 0.012029618956148624, 0.04001859948039055, 0.0577121265232563, 0.02487611398100853, 0.4828015863895416, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7289455533027649, 0.07586073875427246, 0.09869885444641113, 0.029881592839956284, 0.013988524675369263, 0.006547953933477402, 0.011115974746644497, 0.01961168274283409, 0.01534893549978733, 0.0, 
0.023785226047039032, 0.024275904521346092, 0.6168470978736877, 0.01581703871488571, 0.026939542964100838, 0.1783975064754486, 0.04853774979710579, 0.02762567065656185, 0.03777410835027695, 0.0], [0.0068419137969613075, 0.9909167289733887, 0.0013876587618142366, 0.00011136479588458315, 6.257302447920665e-05, 6.40047510387376e-05, 1.7212767488672398e-05, 0.00025805848417803645, 0.0003405954339541495, 0.0, 0.006065902300179005, 0.04932599142193794, 0.8176359534263611, 0.018976736813783646, 0.008159944787621498, 0.011068272404372692, 0.010428683832287788, 0.014124251902103424, 0.06421414017677307, 0.0], [0.0458948090672493, 0.09657034277915955, 0.8496794104576111, 0.005955036263912916, 0.00011324687511660159, 0.000537818530574441, 0.00024974963162094355, 6.34682146483101e-05, 0.0009358474053442478, 0.0, 0.0037236923817545176, 0.007844064384698868, 0.9502744674682617, 0.0048003061674535275, 0.00022506865207105875, 0.004834793973714113, 0.0015490480000153184, 0.0026021157391369343, 0.024146683514118195, 0.0], [0.0026473035104572773, 0.007075308356434107, 0.0509142205119133, 0.9043685793876648, 0.01687121018767357, 0.0027590824756771326, 0.0040096985176205635, 0.004973408300429583, 0.006381142418831587, 0.0, 0.0007564057596027851, 0.0017460802337154746, 0.15768493711948395, 0.004074132069945335, 0.015430302359163761, 0.7368869781494141, 0.028010869398713112, 0.013945921324193478, 0.04146439954638481, 0.0], [0.00491793267428875, 0.0006714555202051997, 0.0047717769630253315, 0.09624139964580536, 0.8609471917152405, 0.020327016711235046, 0.008984015323221684, 0.0013932499568909407, 0.0017457373905926943, 0.0, 0.0008445779676549137, 0.0015138997696340084, 0.17073306441307068, 0.0074179465882480145, 0.08121992647647858, 0.5853323936462402, 0.09402737021446228, 0.024092217907309532, 0.034818582236766815, 0.0], [0.003878936870023608, 0.005480882711708546, 0.00011314810399198905, 0.0003485401102807373, 0.006120527163147926, 0.0029893070459365845, 0.0006264422554522753, 
0.004414959345012903, 0.9760271906852722, 0.0, 0.0006014688406139612, 0.0016882645431905985, 0.16094569861888885, 0.003698966233059764, 0.034668561071157455, 0.5876308679580688, 0.09562253206968307, 0.05209798738360405, 0.06304588913917542, 0.0], [2.4277944248751737e-05, 4.029595402244013e-06, 3.533453991622082e-07, 0.0002488488098606467, 2.0782925275852904e-05, 0.0004858894390054047, 0.9990906119346619, 4.08113919547759e-05, 8.422240352956578e-05, 0.0, 0.00015283364336937666, 0.0004516944463830441, 0.003205003682523966, 0.0049727726727724075, 0.10853080451488495, 0.03262018784880638, 0.6125266551971436, 0.005719948559999466, 0.2318202555179596, 0.0], [0.0008877408690750599, 0.00558891985565424, 8.855570922605693e-05, 8.779557902016677e-06, 0.00010457105963723734, 0.00017662049503996968, 0.002778601599857211, 0.09916532039642334, 0.8912010192871094, 0.0, 9.933842375176027e-05, 0.00020211786613799632, 0.0037883264012634754, 0.0051808832213282585, 0.6936825513839722, 0.10089477151632309, 0.023457802832126617, 0.011726793833076954, 0.16096755862236023, 0.0], [0.0001188771057059057, 0.0008492054184898734, 9.383026917930692e-05, 5.974515715934103e-06, 0.002050562761723995, 8.90250812517479e-05, 8.933644130593166e-05, 0.002921103034168482, 0.9937818646430969, 0.0, 0.0002526468597352505, 0.0010056017199531198, 0.003837066935375333, 0.034950658679008484, 0.5882559418678284, 0.029549231752753258, 0.030938459560275078, 0.01461110170930624, 0.2965993583202362, 0.0], [0.013618292286992073, 0.005976812914013863, 3.6079491110285744e-05, 4.805085336556658e-05, 0.00010178168304264545, 0.03545643016695976, 0.04239484667778015, 0.35667654871940613, 0.5456912517547607, 0.0, 0.0029390468262135983, 0.005815382581204176, 0.06488344818353653, 0.008705642074346542, 0.010130577720701694, 0.012970774434506893, 0.019612692296504974, 0.007819950580596924, 0.8671225309371948, 0.0]], [[0.057871319353580475, 0.09845025092363358, 0.03600643575191498, 0.06401734054088593, 0.07263048738241196, 
0.014885936863720417, 0.07473781704902649, 0.1193607747554779, 0.4620397090911865, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04736582189798355, 0.06951819360256195, 0.039210546761751175, 0.040616557002067566, 0.05645532160997391, 0.01900673843920231, 0.063181072473526, 0.23291724920272827, 0.43172842264175415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05375257506966591, 0.04091374948620796, 0.01263821218162775, 0.04125160351395607, 0.014244006015360355, 0.012229752726852894, 0.029117466881871223, 0.07314542680978775, 0.7227071523666382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04062311723828316, 0.2312910407781601, 0.20085060596466064, 0.03848586603999138, 0.04763459786772728, 0.013425372540950775, 0.027237186208367348, 0.03882591798901558, 0.3616262376308441, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04558584839105606, 0.04037311673164368, 0.043737076222896576, 0.02740027941763401, 0.005366531666368246, 0.014126299880445004, 0.07268305867910385, 0.014923120848834515, 0.7358046770095825, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04308110475540161, 0.037618398666381836, 0.054927192628383636, 0.045146394520998, 0.02157701551914215, 0.014024189673364162, 0.03546718508005142, 0.04130468890070915, 0.7068538665771484, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049693867564201355, 0.040712278336286545, 0.011129319667816162, 0.08677691221237183, 0.24132831394672394, 0.028864668682217598, 0.04710082337260246, 0.028962818905711174, 0.46543097496032715, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0256647989153862, 0.026453843340277672, 0.1064542606472969, 0.07867259532213211, 0.03285365179181099, 0.056291256099939346, 0.026517342776060104, 0.014768523164093494, 0.6323237419128418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07950135320425034, 0.04906205087900162, 0.09099037200212479, 0.10450085997581482, 0.06846266984939575, 0.21755923330783844, 
0.14818403124809265, 0.14456483721733093, 0.0971745029091835, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0029191188514232635, 0.0026039828080683947, 0.000987313687801361, 0.027727283537387848, 0.007311245426535606, 0.0033244043588638306, 0.023969389498233795, 0.00596341909840703, 0.9251939058303833, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5123088955879211, 0.04897892847657204, 0.0054785809479653835, 0.037157464772462845, 0.033040400594472885, 0.0287709329277277, 0.020658228546380997, 0.005767439026385546, 0.3078390061855316, 0.0, 0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.02116605080664158, 0.45384567975997925, 0.5185156464576721, 0.002682638820260763, 0.000782136688940227, 0.0011598592391237617, 9.265153494197875e-05, 0.0013774167746305466, 0.00037764458102174103, 0.0, 0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.06292606890201569, 0.18043853342533112, 0.7547035217285156, 0.0011577572440728545, 6.746899998688605e-06, 0.00012623713701032102, 8.539375994587317e-05, 0.0005039210664108396, 5.15973697474692e-05, 0.0, 0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.0030907110776752234, 0.0035500964149832726, 0.4530283808708191, 0.5221376419067383, 0.009557867422699928, 0.008033420890569687, 0.0001996932114707306, 0.0003745300055015832, 2.7415442673373036e-05, 0.0, 0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 0.1017894372344017, 0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], 
[0.00426989421248436, 0.0004077394842170179, 0.010691642761230469, 0.016011668369174004, 0.5530933737754822, 0.12423280626535416, 0.0053755901753902435, 0.28551235795021057, 0.00040478314622305334, 0.0, 0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 0.0], [0.00216855201870203, 0.009595326147973537, 0.007803121581673622, 0.04625817388296127, 0.24702508747577667, 0.2669595181941986, 0.024053409695625305, 0.24639348685741425, 0.14974308013916016, 0.0, 0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [1.953603486981592e-06, 5.726202516598278e-07, 5.0084551617146644e-08, 6.896097602293594e-06, 0.0001788837107596919, 0.0027895711828023195, 0.9969833493232727, 1.1296948287053965e-05, 2.7187752493773587e-05, 0.0, 0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.00041725789196789265, 0.0019265476148575544, 6.523763295263052e-05, 0.00018337361689191312, 0.011946662329137325, 0.04555974155664444, 0.15744170546531677, 0.025624049827456474, 0.7568355202674866, 0.0, 0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.0004706757317762822, 0.0027324894908815622, 0.0007427418022416532, 0.00934627279639244, 0.17134670913219452, 0.030644211918115616, 0.08413954824209213, 0.2513456642627716, 0.4492316246032715, 0.0, 0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 
0.5316829681396484, 0.0], [0.0002758087939582765, 0.0016877831658348441, 2.4452297111565713e-06, 0.0004533462051767856, 0.001545731327496469, 0.008134560659527779, 0.010873721912503242, 0.026235109195113182, 0.950791597366333, 0.0, 0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 0.5926627516746521, 0.0]], [[0.3116825819015503, 0.20666195452213287, 0.1363646388053894, 0.07141851633787155, 0.029045483097434044, 0.04730900749564171, 0.037391580641269684, 0.10128220915794373, 0.05884409323334694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16883184015750885, 0.1785622239112854, 0.23318006098270416, 0.05258537083864212, 0.10740725696086884, 0.09185276180505753, 0.022670285776257515, 0.08943870663642883, 0.05547139048576355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13844002783298492, 0.030681077390909195, 0.12107612937688828, 0.007168712094426155, 0.05214103311300278, 0.10045275092124939, 0.006991118658334017, 0.2955043315887451, 0.2475447952747345, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17306901514530182, 0.19963578879833221, 0.148520827293396, 0.2046978771686554, 0.06720028817653656, 0.019652126356959343, 0.05792365223169327, 0.07665500044822693, 0.05264541134238243, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15115797519683838, 0.06555280834436417, 0.02683498151600361, 0.20794028043746948, 0.17434173822402954, 0.11980342864990234, 0.04239796847105026, 0.03961418569087982, 0.1723567098379135, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022970011457800865, 0.003322059055790305, 0.03333919122815132, 0.03161727264523506, 0.3007935583591461, 0.20675496757030487, 0.037206847220659256, 0.30415406823158264, 0.059842076152563095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029802093282341957, 0.006723356898874044, 0.00844306219369173, 0.09911961853504181, 
0.2257867157459259, 0.22737178206443787, 0.28318148851394653, 0.05687837302684784, 0.0626935064792633, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008192392997443676, 0.002076511736959219, 0.010627061128616333, 0.01573592983186245, 0.01893553137779236, 0.042316026985645294, 0.02403445728123188, 0.868257999420166, 0.009823988191783428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08033955097198486, 0.18780502676963806, 0.02817567251622677, 0.041370414197444916, 0.02824225462973118, 0.038844767957925797, 0.11732563376426697, 0.06162348762154579, 0.4162730574607849, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0563269779086113, 0.007248359732329845, 0.008029782213270664, 0.003040844574570656, 0.007221699226647615, 0.01730128563940525, 0.050128430128097534, 0.3587413430213928, 0.491961270570755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23423711955547333, 0.3770143389701843, 0.036408282816410065, 0.05249679461121559, 0.12246986478567123, 0.07398416101932526, 0.040104977786540985, 0.05500312149524689, 0.008280999027192593, 0.0, 0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.032638370990753174, 0.0975130945444107, 0.07180608063936234, 0.1075734794139862, 0.025216424837708473, 0.0218534916639328, 0.0376754030585289, 0.19710108637809753, 0.40862247347831726, 0.0, 0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.06665100902318954, 0.01976630836725235, 0.13041609525680542, 0.1772802174091339, 0.07561768591403961, 0.0061133550480008125, 0.022857116535305977, 0.3516288995742798, 0.14966930449008942, 0.0, 0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 
0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.007724895142018795, 0.007534464355558157, 0.020593322813510895, 0.32147932052612305, 0.059838466346263885, 0.017819387838244438, 0.13181470334529877, 0.15524767339229584, 0.27794769406318665, 0.0, 0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 0.004741673823446035, 0.46527597308158875, 0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0028419073205441236, 0.0006648481939919293, 0.0018234169110655785, 0.01609039306640625, 0.0009005893371067941, 0.09726841002702713, 0.11035522073507309, 0.6978457570075989, 0.07220931351184845, 0.0, 0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.0027223427314311266, 0.011157790198922157, 0.0052745467983186245, 0.00438346853479743, 0.010011360049247742, 0.38358205556869507, 0.2294219732284546, 0.057655833661556244, 0.29579076170921326, 0.0, 0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.0007371494430117309, 0.00023907626746222377, 8.450529276160523e-05, 0.002850313438102603, 0.002168564358726144, 0.02191324159502983, 0.9514637589454651, 0.0002505093871150166, 0.02029278129339218, 0.0, 0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.004062887746840715, 0.007024500984698534, 0.007399421185255051, 0.011675640009343624, 0.0680280476808548, 0.02557964250445366, 0.07043837755918503, 0.01946563646197319, 0.7863259315490723, 0.0, 0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 0.010398894548416138, 
0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.0007200208492577076, 0.0024253681767731905, 0.029840486124157906, 0.00014908696175552905, 0.11268167197704315, 0.03336171433329582, 0.007834793999791145, 0.08127990365028381, 0.7317068576812744, 0.0, 0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.008707311004400253, 0.08642490208148956, 0.11372587829828262, 0.004973042756319046, 0.13256150484085083, 0.16557356715202332, 0.0817113071680069, 0.006879410706460476, 0.3994430899620056, 0.0, 0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}, \"inp_inp\": {\"att\": [[[[0.05334341153502464, 0.025828205049037933, 0.062369391322135925, 0.043252814561128616, 0.4045393764972687, 0.06697215139865875, 0.09001608937978745, 0.14983074367046356, 0.10384786874055862, 0.0], [0.11816457659006119, 0.03106253407895565, 0.01979171112179756, 0.16624291241168976, 0.3321376442909241, 0.020051123574376106, 0.08730963617563248, 0.18211135268211365, 0.04312858730554581, 0.0], [0.05936884880065918, 0.02174757793545723, 0.016160180792212486, 0.010601435787975788, 0.43925121426582336, 0.03876951336860657, 0.19815810024738312, 0.07065817713737488, 0.14528508484363556, 0.0], [0.15478025376796722, 
0.16446512937545776, 0.0578744001686573, 0.21637752652168274, 0.03835854306817055, 0.09130414575338364, 0.11191156506538391, 0.08360221982002258, 0.08132638782262802, 0.0], [0.2183060646057129, 0.1704275906085968, 0.0827711746096611, 0.1202380359172821, 0.05203341320157051, 0.05958092212677002, 0.12280035018920898, 0.09366822242736816, 0.08017415553331375, 0.0], [0.05084313824772835, 0.026207493618130684, 0.13631564378738403, 0.012270472943782806, 0.16236551105976105, 0.02548854425549507, 0.03909383341670036, 0.03172134608030319, 0.5156941413879395, 0.0], [0.03615221381187439, 0.04799472168087959, 0.04255519434809685, 0.04762651398777962, 0.5117892622947693, 0.016304347664117813, 0.005770198069512844, 0.10897397249937057, 0.18283340334892273, 0.0], [0.03243544325232506, 0.025252558290958405, 0.11733424663543701, 0.0250592939555645, 0.20289097726345062, 0.08240236341953278, 0.18285907804965973, 0.011341268196702003, 0.3204246759414673, 0.0], [0.22355543076992035, 0.1260528564453125, 0.03741241991519928, 0.16813479363918304, 0.09858733415603638, 0.035831648856401443, 0.16361697018146515, 0.07236126810312271, 0.07444748282432556, 0.0], [0.08996112644672394, 0.0921943336725235, 0.22672457993030548, 0.12702998518943787, 0.05907799303531647, 0.10712798684835434, 0.16789256036281586, 0.055181413888931274, 0.07481010258197784, 0.0]], [[0.040477100759744644, 0.20988762378692627, 0.4869004786014557, 0.03505674749612808, 0.0558856800198555, 0.025423096492886543, 0.12231241166591644, 0.007062799762934446, 0.016993943601846695, 0.0], [0.8996549844741821, 0.02599872276186943, 0.049097247421741486, 0.0040262676775455475, 0.0039152717217803, 0.0049644638784229755, 0.010553319938480854, 0.001352570834569633, 0.0004369009402580559, 0.0], [0.33065715432167053, 0.2687782049179077, 0.03312753140926361, 0.22958999872207642, 0.01851547136902809, 0.046473052352666855, 0.053183481097221375, 0.007113412953913212, 0.012561764568090439, 0.0], [0.1589452475309372, 0.47470128536224365, 
0.12878550589084625, 0.14158962666988373, 0.04442765936255455, 0.022274963557720184, 0.013780632056295872, 0.0024951419327408075, 0.012999956496059895, 0.0], [0.2559169828891754, 0.033451542258262634, 0.15095548331737518, 0.024318046867847443, 0.10824166238307953, 0.03234097361564636, 0.36475417017936707, 0.012823408469557762, 0.017197895795106888, 0.0], [0.021462664008140564, 0.010474847629666328, 0.007213775999844074, 0.02227940410375595, 0.21737068891525269, 0.4960675537586212, 0.014628118835389614, 0.20502059161663055, 0.005482145119458437, 0.0], [0.06734316051006317, 0.09532227367162704, 0.1127309575676918, 0.009542002342641354, 0.0678786113858223, 0.12933993339538574, 0.03809814900159836, 0.44453269243240356, 0.035212237387895584, 0.0], [0.10458365827798843, 0.02846018597483635, 0.029760979115962982, 0.014774680137634277, 0.022077379748225212, 0.1553817093372345, 0.3539015054702759, 0.19523507356643677, 0.09582491964101791, 0.0], [0.021077070385217667, 0.010932122357189655, 0.05088815093040466, 0.028641115874052048, 0.0881260335445404, 0.12014731019735336, 0.3900885581970215, 0.09544514119625092, 0.1946544349193573, 0.0], [0.02552945166826248, 0.05594164505600929, 0.045791901648044586, 0.093170166015625, 0.03584437444806099, 0.0969511866569519, 0.18585819005966187, 0.17433671653270721, 0.28657644987106323, 0.0]], [[0.18220090866088867, 0.25508272647857666, 0.2721964120864868, 0.04886331781744957, 0.010257811285555363, 0.07344724237918854, 0.08866558223962784, 0.037977367639541626, 0.0313086174428463, 0.0], [0.5722172260284424, 0.09567929804325104, 0.1448327898979187, 0.033306267112493515, 0.0031244128476828337, 0.020944159477949142, 0.012691132724285126, 0.061001092195510864, 0.05620381608605385, 0.0], [0.049244701862335205, 0.5266616344451904, 0.27518483996391296, 0.09334208071231842, 0.005858665332198143, 0.005467486567795277, 0.02565312758088112, 0.005746132228523493, 0.012841282412409782, 0.0], [0.13445906341075897, 0.13356590270996094, 
0.6041688919067383, 0.01878039538860321, 0.06342840194702148, 0.03677675500512123, 0.008389262482523918, 0.0002739423362072557, 0.00015757972141727805, 0.0], [0.03273050859570503, 0.0697193592786789, 0.19719526171684265, 0.41500693559646606, 0.13721567392349243, 0.05743291601538658, 0.06517775356769562, 0.010865128599107265, 0.014656689018011093, 0.0], [0.031571000814437866, 0.014337136410176754, 0.06860436499118805, 0.09357307106256485, 0.10011686384677887, 0.07827721536159515, 0.5866308212280273, 0.011440092697739601, 0.015449290163815022, 0.0], [0.006158333271741867, 0.001533387927338481, 0.05427416041493416, 0.005477452650666237, 0.02694696933031082, 0.8134917616844177, 0.02643686905503273, 0.050265438854694366, 0.015415593050420284, 0.0], [0.008847472257912159, 0.0066053420305252075, 0.036443497985601425, 0.021455924957990646, 0.019254589453339577, 0.11543811857700348, 0.1138116791844368, 0.20307059586048126, 0.4750728905200958, 0.0], [0.017603449523448944, 0.008448019623756409, 0.004260394722223282, 0.006066101603209972, 0.013470137491822243, 0.01876576989889145, 0.16350960731506348, 0.1980665624141693, 0.5698099732398987, 0.0], [0.10490093380212784, 0.014168650843203068, 0.0247807614505291, 0.018330294638872147, 0.009348674677312374, 0.02287398651242256, 0.032268356531858444, 0.10571902245283127, 0.6676092147827148, 0.0]], [[0.2071455419063568, 0.637531578540802, 0.06835082173347473, 0.011966697871685028, 0.0017193991225212812, 0.04911382868885994, 0.009478496387600899, 0.008040529675781727, 0.00665308628231287, 0.0], [0.07411027699708939, 0.15093472599983215, 0.2656005620956421, 0.05758262053132057, 0.05194409564137459, 0.23625947535037994, 0.019166678190231323, 0.04010465368628502, 0.10429693013429642, 0.0], [0.1540999412536621, 0.10598444193601608, 0.22474077343940735, 0.32441702485084534, 0.1116243302822113, 0.054135363548994064, 0.008848286233842373, 0.004088098648935556, 0.012061581946909428, 0.0], [0.019440434873104095, 0.00560638727620244, 
0.0035774046555161476, 0.0888679027557373, 0.7120485901832581, 0.14891275763511658, 0.011600993573665619, 0.008666431531310081, 0.0012791723711416125, 0.0], [0.08580154180526733, 0.02444172091782093, 0.08060747385025024, 0.05198557302355766, 0.2700504660606384, 0.34216371178627014, 0.11280739307403564, 0.006445358972996473, 0.02569655328989029, 0.0], [0.0424385629594326, 0.029667967930436134, 0.006252861116081476, 0.020168066024780273, 0.03000665083527565, 0.2812231779098511, 0.49279165267944336, 0.09351769089698792, 0.003933228086680174, 0.0], [0.006467411294579506, 0.0076894015073776245, 0.008325580507516861, 0.0010907554533332586, 0.01040297094732523, 0.19462232291698456, 0.013263629749417305, 0.24681615829467773, 0.5113216042518616, 0.0], [0.028696376830339432, 0.014982450753450394, 0.011884906329214573, 0.0011242942418903112, 0.01692844182252884, 0.12885364890098572, 0.028225399553775787, 0.6451764106750488, 0.12412811070680618, 0.0], [0.16117365658283234, 0.06794824451208115, 0.06173194944858551, 0.00451233983039856, 0.05306624248623848, 0.0510348416864872, 0.04402391240000725, 0.12432018667459488, 0.4321887195110321, 0.0], [0.1690559983253479, 0.043453093618154526, 0.036818861961364746, 0.017293656244874, 0.11775903403759003, 0.07970321178436279, 0.043801818042993546, 0.06849095970392227, 0.4236232340335846, 0.0]], [[0.03085354156792164, 0.12322185933589935, 0.13651973009109497, 0.050716523081064224, 0.2999139726161957, 0.09802427887916565, 0.06620478630065918, 0.0782310962677002, 0.11631430685520172, 0.0], [0.06789751350879669, 0.058182138949632645, 0.3129631578922272, 0.04353875666856766, 0.09142065048217773, 0.10271093249320984, 0.026392055675387383, 0.09630800783634186, 0.2005866914987564, 0.0], [0.07152411341667175, 0.3454192876815796, 0.11299439519643784, 0.18012462556362152, 0.07151429355144501, 0.052652161568403244, 0.0567985400557518, 0.09459780901670456, 0.014374655671417713, 0.0], [0.10420235246419907, 0.21845531463623047, 0.19832336902618408, 
0.022119704633951187, 0.13572701811790466, 0.07722532749176025, 0.0508468933403492, 0.045597679913043976, 0.14750221371650696, 0.0], [0.07030870020389557, 0.10706955939531326, 0.02791348285973072, 0.02260597050189972, 0.12725059688091278, 0.07336997240781784, 0.26662203669548035, 0.16957008838653564, 0.13528966903686523, 0.0], [0.05156806856393814, 0.04327721148729324, 0.07664787024259567, 0.06931594759225845, 0.1889398992061615, 0.09515503793954849, 0.07227510958909988, 0.2641449272632599, 0.13867592811584473, 0.0], [0.02184019424021244, 0.11184182018041611, 0.36672860383987427, 0.013787303119897842, 0.07600502669811249, 0.0389828234910965, 0.040494974702596664, 0.12485849112272263, 0.20546066761016846, 0.0], [0.013738485053181648, 0.05187288299202919, 0.03463537245988846, 0.03627979755401611, 0.048659998923540115, 0.02440205216407776, 0.07256433367729187, 0.024731382727622986, 0.6931155323982239, 0.0], [0.02671198360621929, 0.4013687074184418, 0.01132842618972063, 0.14022575318813324, 0.026275552809238434, 0.08107840269804001, 0.04189194366335869, 0.25432130694389343, 0.0167979933321476, 0.0], [0.14228780567646027, 0.07866450399160385, 0.08390624076128006, 0.09396661072969437, 0.087954580783844, 0.14498625695705414, 0.13517630100250244, 0.1169552430510521, 0.11610251665115356, 0.0]], [[0.02165721170604229, 0.018354326486587524, 0.6383510828018188, 0.042513273656368256, 0.10956817120313644, 0.10717540234327316, 0.030344119295477867, 0.015826348215341568, 0.01621006615459919, 0.0], [0.4647374749183655, 0.07284841686487198, 0.28081396222114563, 0.014013433828949928, 0.03169411048293114, 0.02214456908404827, 0.058711059391498566, 0.036629818379879, 0.01840737834572792, 0.0], [0.07372704148292542, 0.12858736515045166, 0.4501189887523651, 0.054217785596847534, 0.07096204906702042, 0.05748127028346062, 0.06541819125413895, 0.04703349620103836, 0.05245373025536537, 0.0], [0.04684445261955261, 0.019098779186606407, 0.008431704714894295, 0.0010175607167184353, 
0.9129327535629272, 0.004866998642683029, 0.006678053177893162, 8.096762758214027e-05, 4.903498847852461e-05, 0.0], [0.08239725232124329, 0.02813413366675377, 0.16611848771572113, 0.1532817929983139, 0.07408729940652847, 0.10856874287128448, 0.047752734273672104, 0.02563621662557125, 0.31402355432510376, 0.0], [0.17959792912006378, 0.02262653037905693, 0.10724494606256485, 0.022216446697711945, 0.1862414926290512, 0.14705143868923187, 0.15912717580795288, 0.15293282270431519, 0.02296125516295433, 0.0], [0.038375359028577805, 0.0038853511214256287, 0.06201936677098274, 0.005828780122101307, 0.22059503197669983, 0.36631014943122864, 0.020396992564201355, 0.20976856350898743, 0.07282061129808426, 0.0], [0.014258276671171188, 0.005652762018144131, 0.025611618533730507, 0.15294744074344635, 0.06760217249393463, 0.2498260736465454, 0.1669282466173172, 0.2265811711549759, 0.09059228003025055, 0.0], [0.15833799540996552, 0.1228356659412384, 0.10147804021835327, 0.0284584891051054, 0.27955442667007446, 0.06763719022274017, 0.08874277770519257, 0.1152903363108635, 0.037665050476789474, 0.0], [0.09844867885112762, 0.0919492095708847, 0.028445947915315628, 0.03726689890027046, 0.035665158182382584, 0.06817072629928589, 0.29930955171585083, 0.09819743037223816, 0.2425464242696762, 0.0]], [[0.02519470639526844, 0.006357265170663595, 0.14269335567951202, 0.023629529401659966, 0.3124701976776123, 0.13565225899219513, 0.2595662772655487, 0.07959114015102386, 0.014845297671854496, 0.0], [0.04550129547715187, 0.011541971005499363, 0.1165909469127655, 0.02512240968644619, 0.01843150518834591, 0.05711649730801582, 0.44489097595214844, 0.033205363899469376, 0.24759893119335175, 0.0], [0.13528011739253998, 0.06777236610651016, 0.14429129660129547, 0.04697401076555252, 0.1738385707139969, 0.014099549502134323, 0.38417065143585205, 0.01158357597887516, 0.02199004776775837, 0.0], [0.21356959640979767, 0.1638900637626648, 0.10595463216304779, 0.06925727427005768, 0.167257159948349, 
0.04259340837597847, 0.10967854410409927, 0.03570139408111572, 0.09209771454334259, 0.0], [0.20140984654426575, 0.04755665361881256, 0.15174560248851776, 0.11619894206523895, 0.21928974986076355, 0.07600340992212296, 0.05828682705760002, 0.10010629147291183, 0.029402663931250572, 0.0], [0.024259669706225395, 0.02116699516773224, 0.21201731264591217, 0.019622934982180595, 0.4893963038921356, 0.021304504945874214, 0.16948339343070984, 0.022949064150452614, 0.01979990489780903, 0.0], [0.022248759865760803, 0.01183647196739912, 0.0633181631565094, 0.029095010831952095, 0.07090882211923599, 0.4614315629005432, 0.020150773227214813, 0.18720205128192902, 0.1338084638118744, 0.0], [0.003461656626313925, 0.01603432185947895, 0.009874427691102028, 0.014947548508644104, 0.2953553795814514, 0.3502987027168274, 0.08878874033689499, 0.036094941198825836, 0.18514421582221985, 0.0], [0.005101516842842102, 0.022985950112342834, 0.007523353211581707, 0.026773063465952873, 0.01009095273911953, 0.014858697541058064, 0.15149906277656555, 0.028601571917533875, 0.7325656414031982, 0.0], [0.12995873391628265, 0.07769863307476044, 0.02032659947872162, 0.13720010221004486, 0.011713794432580471, 0.054615918546915054, 0.23920413851737976, 0.13190706074237823, 0.19737498462200165, 0.0]], [[0.21207179129123688, 0.11920439451932907, 0.4251355528831482, 0.014464439824223518, 0.20776884257793427, 0.01428140513598919, 0.0027938869316130877, 0.001743048895150423, 0.002536489861086011, 0.0], [0.046175818890333176, 0.026793524622917175, 0.8552185297012329, 0.04517081379890442, 0.010388500988483429, 0.004191457759588957, 0.0036751439329236746, 0.0013485046802088618, 0.007037981878966093, 0.0], [0.013186579570174217, 0.020899420604109764, 0.6900137662887573, 0.0480119027197361, 0.15360434353351593, 0.02344118244946003, 0.03952033817768097, 0.0038994532078504562, 0.007422822527587414, 0.0], [0.006273405160754919, 0.00015674144378863275, 0.000751359446439892, 0.00447711581364274, 0.9859057664871216, 
0.002212332095950842, 0.00014360185014083982, 4.957199053023942e-05, 2.9913859179941937e-05, 0.0], [0.001047183177433908, 0.0003636489564087242, 0.009283728897571564, 0.016805388033390045, 0.42387446761131287, 0.4776095747947693, 0.06253702938556671, 0.005590841174125671, 0.002888289513066411, 0.0], [0.0018647151300683618, 0.0002549054042901844, 2.6050107408082113e-05, 2.586200753285084e-05, 0.0024472770746797323, 0.006814199965447187, 0.9776560664176941, 0.010138182900846004, 0.000773087958805263, 0.0], [0.047241877764463425, 0.006076885852962732, 0.04534892365336418, 0.00081661093281582, 0.087706059217453, 0.41394293308258057, 0.21876952052116394, 0.17005810141563416, 0.0100388890132308, 0.0], [0.0019138919888064265, 0.006189406383782625, 0.010115097276866436, 8.508542669005692e-05, 0.008424345403909683, 0.003492203773930669, 0.13495568931102753, 0.4890870749950409, 0.34573695063591003, 0.0], [0.016032341867685318, 0.005025702994316816, 0.009520799852907658, 0.0008855267078615725, 0.026489384472370148, 0.0020503124687820673, 0.032939448952674866, 0.09461060166358948, 0.8124459385871887, 0.0], [0.25683313608169556, 0.02960006147623062, 0.11211041361093521, 0.09736908972263336, 0.17546677589416504, 0.032068025320768356, 0.017857572063803673, 0.025635067373514175, 0.25305992364883423, 0.0]]], [[[0.10487863421440125, 0.7106320858001709, 0.1635318249464035, 0.011256101541221142, 0.0012767312582582235, 0.00310636218637228, 0.0013001860352233052, 0.0012553841806948185, 0.002762428717687726, 0.0], [0.021650908514857292, 0.0030605364590883255, 0.6595932245254517, 0.2987315356731415, 0.012945608235895634, 0.0028472936246544123, 7.557096250820905e-05, 0.00029089683084748685, 0.0008047237643040717, 0.0], [0.014272261410951614, 0.040512338280677795, 0.8595607280731201, 0.038314104080200195, 0.037397123873233795, 0.006795509252697229, 0.001303989440202713, 0.001011757180094719, 0.0008321924251504242, 0.0], [0.031783342361450195, 0.007319662719964981, 0.7663278579711914, 
0.0010118860518559813, 0.1672297865152359, 0.02513650804758072, 0.000853335193824023, 0.0002817189379129559, 5.600590884569101e-05, 0.0], [0.002136597875505686, 0.00037253598566167057, 0.07588302344083786, 0.2252500057220459, 0.33551687002182007, 0.35751965641975403, 0.0027331046294420958, 0.00018122239271178842, 0.0004068210837431252, 0.0], [0.0004353485128376633, 0.0003557991876732558, 0.0003262429090682417, 0.003819868667051196, 0.33603885769844055, 0.2681770920753479, 0.3838857412338257, 0.0068349516950547695, 0.00012614508159458637, 0.0], [6.71677480568178e-05, 3.9912600186653435e-05, 0.00047830803669057786, 5.937727837590501e-05, 0.0014537296956405044, 0.6413838863372803, 0.29047340154647827, 0.06565171480178833, 0.0003929881495423615, 0.0], [0.00047039391938596964, 0.0007891620043665171, 0.0007817292353138328, 0.0010076714679598808, 0.00965806283056736, 0.003733346238732338, 0.35330116748809814, 0.5722718238830566, 0.05798657611012459, 0.0], [0.006178696174174547, 0.009340841323137283, 0.0005589249776676297, 0.005146770738065243, 0.0033258567564189434, 0.0016933922888711095, 0.06414961069822311, 0.3291752338409424, 0.5804308652877808, 0.0], [0.006624103523790836, 0.001978900283575058, 0.0081730792298913, 0.0030846702866256237, 0.0018904987955465913, 0.0014340116176754236, 0.005187559872865677, 0.029854312539100647, 0.9417726993560791, 0.0]], [[0.17277710139751434, 0.13871003687381744, 0.020699918270111084, 0.04190761595964432, 0.17760643362998962, 0.1702892780303955, 0.16168300807476044, 0.10000763088464737, 0.01631900854408741, 0.0], [0.9987638592720032, 0.0011447033612057567, 1.5495901607209817e-05, 2.3805538096333123e-10, 1.1166920899086108e-07, 4.81009180930414e-07, 2.3257289285538718e-05, 3.4320622944505885e-05, 1.812833215808496e-05, 0.0], [0.029870687052607536, 0.9668734669685364, 0.0031853404361754656, 3.7420595617732033e-06, 1.0481591772304455e-07, 4.711453893690987e-09, 4.051101996083162e-07, 1.359390239485947e-06, 6.518688314827159e-05, 0.0], 
[2.9839180569979362e-05, 0.0008244949858635664, 0.9990562796592712, 6.778111855965108e-05, 2.14482715819031e-05, 5.3428358959273226e-11, 7.202954205309808e-11, 7.697720239008277e-11, 1.422941551254553e-07, 0.0], [9.680035873316228e-05, 4.205659934086725e-05, 0.0021876851096749306, 0.9926192164421082, 0.0050464412197470665, 7.330636890401365e-06, 4.7689670878980905e-08, 8.238330573284713e-10, 9.979119397485192e-08, 0.0], [5.136659183335723e-06, 6.750806136324172e-08, 8.17252839624416e-06, 0.008817464113235474, 0.9640147089958191, 0.027066770941019058, 8.771067950874567e-05, 3.571775764044105e-09, 3.5257423647294672e-09, 0.0], [5.115869043947896e-07, 1.0059281407848175e-08, 1.3136859422502312e-07, 9.641905052149013e-08, 0.001335342414677143, 0.9957214593887329, 0.0029362423811107874, 7.136273325158982e-06, 1.1521567699901425e-08, 0.0], [3.561131961760111e-06, 2.727877870256634e-07, 8.369554507225985e-07, 1.214864764342849e-09, 4.873449597653234e-06, 0.024909861385822296, 0.9680997133255005, 0.006879042834043503, 0.00010210835171164945, 0.0], [0.00021467455371748656, 9.040503209689632e-05, 3.369562909938395e-05, 1.9265097961351785e-08, 9.727973520057276e-07, 2.4095537810353562e-05, 0.0040859803557395935, 0.8618475794792175, 0.1337023377418518, 0.0], [2.289768872287823e-06, 6.284429400693625e-05, 0.0001214230724144727, 2.809870807141124e-07, 1.092972157223926e-09, 1.0671180605825725e-09, 1.2438744079190656e-06, 0.024907555431127548, 0.9749038219451904, 0.0]], [[0.058097392320632935, 0.00935883168131113, 0.04822169989347458, 0.0048278868198394775, 0.191309854388237, 0.28154584765434265, 0.09391050785779953, 0.24126385152339935, 0.07146408408880234, 0.0], [0.10414423793554306, 0.027566324919462204, 0.021727869287133217, 0.033647697418928146, 0.026882247999310493, 0.17782779037952423, 0.05685214698314667, 0.45095938444137573, 0.10039239376783371, 0.0], [0.44215551018714905, 0.049670565873384476, 0.014098896645009518, 0.029011834412813187, 0.01834075152873993, 
0.1358453929424286, 0.04072042554616928, 0.2330295443534851, 0.03712712228298187, 0.0], [0.10425814986228943, 0.06979154050350189, 0.036334071308374405, 0.028995294123888016, 0.015532439574599266, 0.1330128014087677, 0.063407763838768, 0.23157192766666412, 0.3170958459377289, 0.0], [0.3384562134742737, 0.055937401950359344, 0.038792647421360016, 0.00819220207631588, 0.03063569962978363, 0.09386011958122253, 0.07227522879838943, 0.30926018953323364, 0.05259038880467415, 0.0], [0.3519401550292969, 0.1823827177286148, 0.06509842723608017, 0.030452275648713112, 0.08377533406019211, 0.09469012171030045, 0.04247477278113365, 0.11751312017440796, 0.03167306259274483, 0.0], [0.3634622097015381, 0.14048337936401367, 0.08374395966529846, 0.038946691900491714, 0.03473563492298126, 0.06442954391241074, 0.019375532865524292, 0.22685663402080536, 0.027966352179646492, 0.0], [0.18070067465305328, 0.04645215719938278, 0.0992647334933281, 0.005799622740596533, 0.47514480352401733, 0.12094692885875702, 0.030788421630859375, 0.025236092507839203, 0.015666494145989418, 0.0], [0.5453059673309326, 0.10054859519004822, 0.01722547970712185, 0.06704734265804291, 0.007780902087688446, 0.07263857871294022, 0.022086072713136673, 0.1394840031862259, 0.027883058413863182, 0.0], [0.15028028190135956, 0.17163224518299103, 0.06043723225593567, 0.10140684247016907, 0.10512865334749222, 0.06778015196323395, 0.06512691080570221, 0.23085294663906097, 0.04735487326979637, 0.0]], [[0.11086989939212799, 0.14517885446548462, 0.17419463396072388, 0.060936953872442245, 0.08783368766307831, 0.11005676537752151, 0.03251044824719429, 0.07983692735433578, 0.19858187437057495, 0.0], [0.16660544276237488, 0.29352903366088867, 0.1008867621421814, 0.023942291736602783, 0.15022507309913635, 0.06581585109233856, 0.02344084158539772, 0.05208655819296837, 0.12346797436475754, 0.0], [0.1683349758386612, 0.22478938102722168, 0.06976605206727982, 0.1032773107290268, 0.16255290806293488, 0.08890064060688019, 
0.03925151377916336, 0.023706944659352303, 0.11942004412412643, 0.0], [0.19914905726909637, 0.1368866264820099, 0.178489089012146, 0.11241752654314041, 0.06187256798148155, 0.0768556222319603, 0.01627686619758606, 0.07274915277957916, 0.14530348777770996, 0.0], [0.08000901341438293, 0.20181676745414734, 0.21235129237174988, 0.05340588092803955, 0.12758778035640717, 0.11278047412633896, 0.06906574964523315, 0.08596791326999664, 0.05701539292931557, 0.0], [0.14153669774532318, 0.10432923585176468, 0.09881750494241714, 0.08603313565254211, 0.10391980409622192, 0.06189347058534622, 0.06772381067276001, 0.08503933250904083, 0.25070688128471375, 0.0], [0.06525713205337524, 0.07869093865156174, 0.11366366595029831, 0.044226594269275665, 0.05455174669623375, 0.23646420240402222, 0.09933798015117645, 0.1198185384273529, 0.1879890412092209, 0.0], [0.09450254589319229, 0.027017319574952126, 0.06480545550584793, 0.10929621011018753, 0.11382008343935013, 0.17441418766975403, 0.11898359656333923, 0.06495486199855804, 0.23220552504062653, 0.0], [0.07681684195995331, 0.0671391412615776, 0.0905177965760231, 0.06064317002892494, 0.06652072072029114, 0.09855856746435165, 0.07360702753067017, 0.13956283032894135, 0.3266339898109436, 0.0], [0.12179998308420181, 0.07977079600095749, 0.08405954390764236, 0.1456507444381714, 0.14551174640655518, 0.07862778753042221, 0.09882251918315887, 0.14300917088985443, 0.1027478501200676, 0.0]], [[0.0261031873524189, 0.9575563073158264, 0.006272038444876671, 0.0037288309540599585, 0.0038619006518274546, 0.0007324732141569257, 0.0005133527447469532, 0.0003637235495261848, 0.0008679544553160667, 0.0], [0.02134888991713524, 0.08473973721265793, 0.6753177642822266, 0.028721673414111137, 0.14432094991207123, 0.027568204328417778, 0.0057298606261610985, 0.004451636224985123, 0.007801060564815998, 0.0], [0.03883299231529236, 0.030284319072961807, 0.5620493292808533, 0.09062989801168442, 0.17362907528877258, 0.08253934979438782, 0.010801085270941257, 
0.00978847872465849, 0.0014453904004767537, 0.0], [0.002180949319154024, 0.003013473702594638, 0.16569769382476807, 0.008050205186009407, 0.7580646276473999, 0.061441101133823395, 0.001020166208036244, 0.0001067533012246713, 0.0004249440098647028, 0.0], [0.004150479566305876, 0.00034606645931489766, 0.3802972435951233, 0.06855826079845428, 0.29045602679252625, 0.1767650991678238, 0.06603583693504333, 0.0014808314153924584, 0.011909942142665386, 0.0], [0.006170187145471573, 0.0012396957026794553, 0.0354800671339035, 0.0032299698796123266, 0.03240001201629639, 0.5543311238288879, 0.30418315529823303, 0.051339369267225266, 0.01162647269666195, 0.0], [0.0035115755163133144, 0.0011483307462185621, 0.017956364899873734, 0.003783614607527852, 0.030611976981163025, 0.3673596978187561, 0.20627115666866302, 0.3506667912006378, 0.01869054324924946, 0.0], [0.0021685126703232527, 0.0006909942603670061, 0.010240452364087105, 0.01958688348531723, 0.004634156823158264, 0.11485372483730316, 0.04815557599067688, 0.7050773501396179, 0.0945921242237091, 0.0], [0.049201104789972305, 0.02397306263446808, 0.02337191067636013, 0.31066185235977173, 0.06433572620153427, 0.12544430792331696, 0.0786852017045021, 0.25179895758628845, 0.07252778857946396, 0.0], [0.010841209441423416, 0.0041772774420678616, 0.01548130251467228, 0.036074474453926086, 0.033387064933776855, 0.08192819356918335, 0.04784044623374939, 0.10195028781890869, 0.668319821357727, 0.0]], [[0.005738695617765188, 0.0068999892100691795, 0.4274883270263672, 0.08288666605949402, 0.1445126235485077, 0.04382907599210739, 0.10957401990890503, 0.05347184091806412, 0.1255987584590912, 0.0], [0.0025263649877160788, 0.00471830926835537, 0.13454590737819672, 0.4177793860435486, 0.28839975595474243, 0.029358303174376488, 0.017654288560152054, 0.0047735795378685, 0.10024390369653702, 0.0], [0.009192855097353458, 0.007133236154913902, 0.03149157017469406, 0.1856081485748291, 0.5691666603088379, 0.07386670261621475, 0.029819192364811897, 
0.03683711960911751, 0.05688462406396866, 0.0], [0.00297820963896811, 0.0015070328954607248, 0.0025649494491517544, 0.0011051844339817762, 0.04088710993528366, 0.1953955888748169, 0.34000417590141296, 0.3367410898208618, 0.07881659269332886, 0.0], [0.003951869439333677, 0.009354526177048683, 0.007010620087385178, 0.0025927696842700243, 0.09962604194879532, 0.10909298062324524, 0.4455967843532562, 0.15358439087867737, 0.16918975114822388, 0.0], [0.0038829154800623655, 0.0036434896755963564, 0.006399825215339661, 0.000760377966798842, 0.010139851830899715, 0.038725122809410095, 0.10014155507087708, 0.48370444774627686, 0.35260239243507385, 0.0], [0.001297087874263525, 0.0014563009608536959, 0.013839880004525185, 0.0004286184557713568, 0.012207024730741978, 0.028704902157187462, 0.046600911766290665, 0.26406532526016235, 0.6313998103141785, 0.0], [0.0033481158316135406, 0.0038099782541394234, 0.0031049775425344706, 0.00033546099439263344, 0.0031272985506802797, 0.008788534440100193, 0.021183660253882408, 0.12157405912876129, 0.8347280025482178, 0.0], [0.3364367187023163, 0.17456969618797302, 0.051038213074207306, 0.006790165323764086, 0.024106895551085472, 0.0694134384393692, 0.02184627763926983, 0.061508405953645706, 0.25429028272628784, 0.0], [0.10536088049411774, 0.07750789821147919, 0.0850178673863411, 0.08725376427173615, 0.2586125433444977, 0.16756391525268555, 0.054291605949401855, 0.030132828280329704, 0.13425879180431366, 0.0]], [[0.034539882093667984, 0.0018589550163596869, 0.9604092836380005, 1.3120608855388127e-05, 2.1815638319822028e-05, 0.00012517283903434873, 8.019943197723478e-05, 0.0021589084062725306, 0.0007928607519716024, 0.0], [7.048832912914804e-07, 1.7815009414334781e-06, 0.9998455047607422, 0.0001518452918389812, 4.1070780554264275e-08, 2.7954746156799715e-11, 9.231376947582692e-12, 9.901777175969073e-09, 2.5545642756696907e-07, 0.0], [6.695767496012195e-08, 2.089915795977504e-07, 0.005368041805922985, 0.9945066571235657, 0.0001248170156031847, 
2.304766155702964e-09, 2.762512718579302e-10, 3.973758211373024e-09, 9.372820954922645e-07, 0.0], [5.018761014413675e-13, 1.4841802622529476e-16, 4.663825770023777e-09, 3.820862737313746e-09, 0.9999942183494568, 4.988648925063899e-06, 4.967477167452938e-13, 1.416252587396787e-16, 2.1775358895380023e-16, 0.0], [4.666895758731471e-09, 7.292542437975502e-12, 2.898993545219497e-11, 4.2817244194637283e-10, 0.00027504604076966643, 0.9995728731155396, 0.00015239788626786321, 1.9082661839586734e-10, 2.232514032581706e-13, 0.0], [1.7137297136926577e-10, 5.3312285142048665e-12, 2.2368220760327594e-14, 4.904942142678549e-17, 8.726878775178193e-09, 0.004644036293029785, 0.9953435659408569, 1.324965796811739e-05, 6.982896899598856e-12, 0.0], [4.877224735189145e-10, 1.5497924055196677e-09, 6.021576987036426e-11, 8.955144165463396e-19, 1.7180077889825118e-13, 6.163505759104737e-07, 0.001256544259376824, 0.9987285733222961, 1.4209075743565336e-05, 0.0], [3.25698863434809e-08, 7.313030323530256e-07, 1.412931510458293e-06, 1.1662047555981733e-16, 8.495708612521816e-14, 1.1933978653379251e-13, 1.3303619539328793e-07, 0.01294001005589962, 0.9870572686195374, 0.0], [1.6884889646462398e-06, 2.6281904865754768e-05, 0.001122217159718275, 6.101166945882142e-06, 4.424501298672112e-08, 5.172042264953158e-13, 5.508820136168602e-11, 5.942968346062116e-05, 0.9987838268280029, 0.0], [4.288114359951578e-05, 6.015944563841913e-06, 0.004432132933288813, 0.025997335091233253, 0.000731422973331064, 6.87844434188456e-11, 8.199346692057408e-13, 7.098316245901515e-08, 0.9687905311584473, 0.0]], [[0.02526121959090233, 0.9527671933174133, 0.014345486648380756, 0.0014051493490114808, 0.003839265089482069, 0.00014350644778460264, 0.0006356940139085054, 0.00025237957015633583, 0.0013501241337507963, 0.0], [0.004122408106923103, 0.023777475580573082, 0.9002965688705444, 0.0682864859700203, 0.0017659803852438927, 0.0001271881628781557, 0.00011044178245356306, 0.0001890352723421529, 0.0013242338318377733, 0.0], 
[8.841444650897756e-05, 0.0002895947836805135, 0.06307922303676605, 0.9069769978523254, 0.028407124802470207, 0.000558151863515377, 0.00022284295118879527, 0.00018588549573905766, 0.00019132612214889377, 0.0], [1.889026179924258e-06, 3.9712713260087185e-06, 0.001210480579175055, 0.003201226470991969, 0.8290116786956787, 0.16640713810920715, 0.00015829727635718882, 4.0429063119518105e-06, 9.256136763724498e-07, 0.0], [0.000399262469727546, 5.1438626542221755e-05, 0.0001944842515513301, 0.0007700449787080288, 0.4879837930202484, 0.4847603738307953, 0.025640420615673065, 0.00018376839580014348, 1.6383723050239496e-05, 0.0], [4.30414620495867e-05, 1.017293288896326e-05, 8.407413588429336e-06, 5.451946094581217e-07, 0.000544070964679122, 0.021075371652841568, 0.9573339819908142, 0.0208626389503479, 0.00012169074034318328, 0.0], [0.00043880229350179434, 0.0004488519043661654, 0.000600603292696178, 1.4583132212919736e-07, 3.6701523640658706e-05, 0.010162030346691608, 0.37363454699516296, 0.559087336063385, 0.0555914081633091, 0.0], [0.0010709260823205113, 0.0006920771556906402, 0.0016655249055474997, 0.00010216240480076522, 1.0821948308148421e-05, 2.6151516067329794e-05, 0.01446994487196207, 0.2987785339355469, 0.6831837296485901, 0.0], [0.0002485924051143229, 0.00016839140153024346, 0.019545644521713257, 0.016785046085715294, 0.005671702325344086, 0.00014030851889401674, 0.001185068627819419, 0.04272715002298355, 0.9135279655456543, 0.0], [0.0039028520695865154, 0.0008621322922408581, 0.02400260791182518, 0.35541704297065735, 0.048350416123867035, 0.00013779231812804937, 0.00015075977717060596, 0.0015127401566132903, 0.5656636953353882, 0.0]]], [[[0.09929531812667847, 0.3125585615634918, 0.26699960231781006, 0.036189958453178406, 0.01689508929848671, 0.05626463145017624, 0.014853590168058872, 0.021625356748700142, 0.17531771957874298, 0.0], [0.6598999500274658, 0.04883529245853424, 0.24573534727096558, 0.008949915878474712, 0.008034803904592991, 0.0058951652608811855, 
0.001835338887758553, 0.0024289200082421303, 0.018385181203484535, 0.0], [0.28377673029899597, 0.4307016134262085, 0.19275489449501038, 0.05968217924237251, 0.007509235758334398, 0.00627214927226305, 0.0010254314402118325, 0.0010938378982245922, 0.017183959484100342, 0.0], [0.00751571636646986, 0.01881357654929161, 0.9318985342979431, 0.014481762424111366, 0.02105659246444702, 0.0032304797787219286, 0.00013498679618351161, 2.4857494281604886e-05, 0.0028432777617126703, 0.0], [0.08691340684890747, 0.01259385235607624, 0.21131311357021332, 0.15839329361915588, 0.3931293189525604, 0.10845079272985458, 0.004768806044012308, 0.0032348930835723877, 0.021202562376856804, 0.0], [0.029192518442869186, 0.06438057869672775, 0.033022571355104446, 0.04279496520757675, 0.6011855006217957, 0.17385539412498474, 0.03754284232854843, 0.006468524225056171, 0.011557108722627163, 0.0], [0.006125382613390684, 0.006982659921050072, 0.004575703293085098, 0.0037440320011228323, 0.36007580161094666, 0.5409486889839172, 0.0626324936747551, 0.00843171589076519, 0.006483553443104029, 0.0], [0.0017123871948570013, 0.017555760219693184, 0.012620777823030949, 0.00947127677500248, 0.08178496360778809, 0.2538650631904602, 0.19189175963401794, 0.255443274974823, 0.17565478384494781, 0.0], [0.02615528553724289, 0.002552631078287959, 0.01957615464925766, 0.021708596497774124, 0.008856788277626038, 0.021813882514834404, 0.052812058478593826, 0.19690369069576263, 0.6496209502220154, 0.0], [0.004899451043456793, 0.005663626827299595, 0.012920243665575981, 0.007757777348160744, 0.014441648498177528, 0.021742597222328186, 0.05050418898463249, 0.35952994227409363, 0.5225404500961304, 0.0]], [[0.8470081686973572, 0.043761640787124634, 0.000660977209918201, 0.00018918802379630506, 0.01478277612477541, 0.00942840613424778, 0.06798462569713593, 0.011217072606086731, 0.004967056680470705, 0.0], [0.9998846054077148, 9.298400982515886e-05, 7.557733283647394e-08, 4.2952964861113496e-13, 4.9295836510032665e-12, 
3.2098330660090824e-09, 5.042555585532682e-06, 1.7450745872338302e-05, 2.33268380611662e-07, 0.0], [2.118646625604015e-05, 0.9999122619628906, 6.629392737522721e-05, 1.312590147684034e-09, 2.7011800782239526e-11, 6.488713510726871e-14, 1.250517189799183e-10, 3.650779589747799e-08, 2.9122876554765753e-08, 0.0], [1.1949000816580124e-11, 3.2456850362905243e-07, 1.0, 3.0732459777027543e-07, 4.943382370115046e-10, 1.2582140899967535e-17, 7.485076299292317e-18, 2.998638596002183e-14, 1.3861908843004755e-10, 0.0], [5.382360668271247e-10, 8.056646905174603e-09, 0.00035429277340881526, 0.9995232820510864, 0.00012279135989956558, 1.6631793720023325e-09, 1.8857353897253244e-14, 9.284229879032505e-15, 1.8321206097376974e-12, 0.0], [8.614902194392648e-12, 3.5818106835540375e-13, 4.029543365646759e-09, 3.1193526410788763e-06, 0.9959417581558228, 0.004055640660226345, 2.0883923923520342e-08, 1.5150488692381933e-14, 1.8145465705242968e-17, 0.0], [2.3006167283734502e-12, 4.150501252094593e-15, 2.9068709245239077e-12, 2.726213081238188e-13, 1.0724114645199734e-06, 0.9999104142189026, 8.954491204349324e-05, 3.77386955019432e-10, 8.537545242676776e-16, 0.0], [8.656632632941808e-10, 2.8593680201360883e-10, 4.910126749635424e-10, 3.37084723469553e-15, 1.3075121541028523e-10, 0.0003027402563020587, 0.999218225479126, 0.00047932929010130465, 1.4258912273135138e-08, 0.0], [1.0133464911632473e-07, 1.7307414168499236e-07, 2.3342326471720298e-07, 4.688030020606748e-13, 1.5028331227032177e-12, 5.3876938466146385e-09, 0.00158107269089669, 0.994592010974884, 0.0038271904923021793, 0.0], [2.33300490037891e-10, 1.2628836998374027e-07, 1.2948551102454076e-06, 3.169647599943204e-10, 1.5141217069741288e-14, 8.21656009561151e-15, 2.347289251858342e-09, 0.0025180077645927668, 0.9974797964096069, 0.0]], [[0.011770328506827354, 0.014021093025803566, 0.10656744986772537, 0.04667313024401665, 0.13704808056354523, 0.04681243374943733, 0.08347266167402267, 0.3310377299785614, 0.22259721159934998, 0.0], 
[0.009583584032952785, 0.010384900495409966, 0.09424954652786255, 0.09874095767736435, 0.2214881330728531, 0.08727390319108963, 0.09998933970928192, 0.16299772262573242, 0.21529172360897064, 0.0], [0.040493443608284, 0.05296378955245018, 0.12471148371696472, 0.04822944849729538, 0.2201310694217682, 0.13458549976348877, 0.16853223741054535, 0.12866733968257904, 0.08168572932481766, 0.0], [0.014574799686670303, 0.015747353434562683, 0.011357909068465233, 0.008449763990938663, 0.024292636662721634, 0.06141809746623039, 0.10683716088533401, 0.6414783596992493, 0.1158437430858612, 0.0], [0.0041047134436666965, 0.010159346275031567, 0.006441198755055666, 0.009530052542686462, 0.061682768166065216, 0.07391326874494553, 0.3019707202911377, 0.45178085565567017, 0.08041701465845108, 0.0], [0.013634801842272282, 0.03774101287126541, 0.015713637694716454, 0.01436087116599083, 0.06650711596012115, 0.06899012625217438, 0.1819150745868683, 0.376579225063324, 0.2245580554008484, 0.0], [0.03166442736983299, 0.07015468180179596, 0.1104653850197792, 0.016236137598752975, 0.18190902471542358, 0.08141329884529114, 0.15690769255161285, 0.22899281978607178, 0.12225660681724548, 0.0], [0.10994787514209747, 0.08447018265724182, 0.05270976573228836, 0.013435273431241512, 0.06919412314891815, 0.04981343820691109, 0.24833135306835175, 0.2721446752548218, 0.09995320439338684, 0.0], [0.39435869455337524, 0.21061576902866364, 0.1085209921002388, 0.004411425907164812, 0.06908565759658813, 0.04562678933143616, 0.02559957653284073, 0.06842028349637985, 0.0733608528971672, 0.0], [0.2682938873767853, 0.18270419538021088, 0.12741044163703918, 0.03156330808997154, 0.10574271529912949, 0.0955348014831543, 0.052997197955846786, 0.0821281224489212, 0.05362524837255478, 0.0]], [[8.027511648833752e-05, 0.0010475717717781663, 0.9977908730506897, 0.0002747455728240311, 0.000536168459802866, 9.231048170477152e-05, 0.00010586588905425742, 1.1979215742030647e-05, 5.969347330392338e-05, 0.0], 
[0.00012679747305810452, 5.715776205761358e-05, 0.922791600227356, 0.07177212089300156, 0.002934361109510064, 0.0005548547487705946, 0.001313770073466003, 2.2278460164670832e-05, 0.0004267726035322994, 0.0], [0.0063565499149262905, 0.0009426671313121915, 0.23976103961467743, 0.6402719020843506, 0.019077658653259277, 0.04590805247426033, 0.0423574335873127, 0.00055616011377424, 0.0047685266472399235, 0.0], [0.00012164804502390325, 1.1780298336816486e-05, 0.0001827587402658537, 0.00020120454428251833, 0.9978508353233337, 0.0014421044616028666, 6.411068170564249e-05, 4.628768147085793e-05, 7.896547322161496e-05, 0.0], [0.03763079643249512, 0.00208932813256979, 0.0006042887107469141, 0.5138440728187561, 0.19755180180072784, 0.029773280024528503, 0.15554653108119965, 0.015671545639634132, 0.0472884401679039, 0.0], [3.8805592339485884e-05, 1.2464041901694145e-05, 9.030352521222085e-05, 1.7544094589538872e-05, 0.0006991567788645625, 0.039246365427970886, 0.9305517077445984, 0.02403487078845501, 0.005308609921485186, 0.0], [0.003011370776221156, 0.005974559113383293, 0.003425326431170106, 0.001937237335368991, 0.01794668287038803, 0.06517820060253143, 0.25853174924850464, 0.28359606862068176, 0.3603990077972412, 0.0], [0.0019687232561409473, 0.0019828693475574255, 0.0009621239732950926, 0.0017320939805358648, 0.008526722900569439, 0.012685983441770077, 0.060781437903642654, 0.38653799891471863, 0.524821937084198, 0.0], [0.06319467723369598, 0.3812802731990814, 0.07775641977787018, 0.0546053946018219, 0.0410320870578289, 0.010218034498393536, 0.022281788289546967, 0.04868403077125549, 0.30094724893569946, 0.0], [0.06465335935354233, 0.0841824859380722, 0.028003698214888573, 0.01470992248505354, 0.013160775415599346, 0.006258893292397261, 0.003528257366269827, 0.022525515407323837, 0.7629771828651428, 0.0]], [[0.00496841873973608, 0.010829150676727295, 0.03283568099141121, 0.009884797036647797, 0.047239795327186584, 0.06476759165525436, 0.11417313665151596, 
0.6207002401351929, 0.09460126608610153, 0.0], [0.014457895420491695, 0.06253711134195328, 0.10527490824460983, 0.051058270037174225, 0.04873393103480339, 0.058862265199422836, 0.13390113413333893, 0.44425415992736816, 0.0809202790260315, 0.0], [0.09337731450796127, 0.22848238050937653, 0.11594945937395096, 0.04185759648680687, 0.012283656746149063, 0.1264774352312088, 0.19395124912261963, 0.16978387534618378, 0.017837027087807655, 0.0], [0.7125841975212097, 0.21987739205360413, 0.020619483664631844, 0.02881826087832451, 0.009833384305238724, 0.004124533850699663, 0.0008098671096377075, 0.0004809961246792227, 0.0028517041355371475, 0.0], [0.029080189764499664, 0.33611080050468445, 0.12628716230392456, 0.0817737877368927, 0.1908877044916153, 0.0943109318614006, 0.05712011829018593, 0.06781000643968582, 0.016619542613625526, 0.0], [0.07309448719024658, 0.07739713788032532, 0.0567743182182312, 0.03291132301092148, 0.16455504298210144, 0.1779973953962326, 0.2714528441429138, 0.13868720829486847, 0.007130389101803303, 0.0], [0.2111189365386963, 0.06559138745069504, 0.041267942637205124, 0.009358389303088188, 0.20342323184013367, 0.1869427114725113, 0.19775718450546265, 0.07797932624816895, 0.006560905836522579, 0.0], [0.08770362287759781, 0.12808790802955627, 0.023038268089294434, 0.17453545331954956, 0.09798892587423325, 0.11677049100399017, 0.09396524727344513, 0.26174578070640564, 0.01616443321108818, 0.0], [0.35409674048423767, 0.0420590415596962, 0.00930203776806593, 0.3349112272262573, 0.03967892378568649, 0.15319538116455078, 0.022175630554556847, 0.0432865284383297, 0.0012946304632350802, 0.0], [0.10030248761177063, 0.08145220577716827, 0.053510215133428574, 0.08076464384794235, 0.07446140050888062, 0.13495147228240967, 0.2503055930137634, 0.17467214167118073, 0.04957977309823036, 0.0]], [[0.140123188495636, 0.010056160390377045, 0.0845566838979721, 0.03108036518096924, 0.16015855967998505, 0.30321791768074036, 0.04101235046982765, 0.0719088688492775, 
0.1578858345746994, 0.0], [0.6134085655212402, 0.1547522246837616, 0.03818102553486824, 0.001013039844110608, 0.013297338038682938, 0.008754062466323376, 0.005134810693562031, 0.0324203222990036, 0.13303862512111664, 0.0], [0.6891250014305115, 0.17779399454593658, 0.09809523820877075, 0.006996517535299063, 0.007719202898442745, 0.0016296659596264362, 0.010662317276000977, 0.004304768517613411, 0.0036729834973812103, 0.0], [0.04376668110489845, 0.09640005975961685, 0.8100467324256897, 0.018579678609967232, 0.017539000138640404, 0.0008903089328669012, 0.0009985471842810512, 0.003613307373598218, 0.008165487088263035, 0.0], [0.03085213713347912, 0.025543441995978355, 0.6937543153762817, 0.17392684519290924, 0.03124413825571537, 0.02177071012556553, 0.007475809659808874, 0.003389933379366994, 0.012042560614645481, 0.0], [0.020024498924613, 0.002941351616755128, 0.05481509119272232, 0.183584526181221, 0.4182366132736206, 0.25923243165016174, 0.05362166836857796, 0.0045484029687941074, 0.002995501272380352, 0.0], [0.006091661751270294, 0.0012010806240141392, 0.008193010464310646, 0.009258490055799484, 0.15450483560562134, 0.7388086915016174, 0.06675267219543457, 0.01373466569930315, 0.0014547830214723945, 0.0], [0.0014694302808493376, 0.0017220929730683565, 0.005703628528863192, 0.0032696493435651064, 0.01713697426021099, 0.49356934428215027, 0.3729664385318756, 0.05505490303039551, 0.04910748079419136, 0.0], [0.0052343131974339485, 0.004969605710357428, 0.005609327927231789, 0.0007064095698297024, 0.005421568639576435, 0.045942794531583786, 0.22256441414356232, 0.43683722615242004, 0.27271413803100586, 0.0], [0.011939328163862228, 0.019054703414440155, 0.010745645500719547, 0.006908759940415621, 0.009522099047899246, 0.006889646407216787, 0.12289831787347794, 0.2292226105928421, 0.5828191637992859, 0.0]], [[0.0014003654941916466, 0.00935011450201273, 0.8996742963790894, 0.029868578538298607, 0.05752851441502571, 0.0008847691351547837, 0.0005429417942650616, 
0.0004143548430874944, 0.00033632174017839134, 0.0], [0.0005502321291714907, 0.003854800947010517, 0.8475468754768372, 0.06876953691244125, 0.07909266650676727, 5.498397149494849e-05, 2.1647396351909265e-05, 6.648269391007489e-06, 0.00010276718239765614, 0.0], [0.0025599629152566195, 0.010113149881362915, 0.21385346353054047, 0.26065483689308167, 0.44287386536598206, 0.0458405464887619, 0.013329384848475456, 0.0076821851544082165, 0.0030928871128708124, 0.0], [0.0002600199659354985, 3.3608048397582024e-05, 0.0020931970793753862, 0.007768034934997559, 0.9780486822128296, 0.011327453888952732, 0.00041993538616225123, 4.125805935473181e-05, 8.07127889856929e-06, 0.0], [0.0010751935187727213, 0.00017567894246894866, 0.004301255568861961, 0.0010412797564640641, 0.012584774754941463, 0.5903621912002563, 0.36841556429862976, 0.021853862330317497, 0.00019013854034710675, 0.0], [0.00036065353197045624, 0.00041391997365280986, 0.00018344201089348644, 1.21664334074012e-05, 0.0008204621262848377, 0.02300320193171501, 0.7380199432373047, 0.23411831259727478, 0.0030676021706312895, 0.0], [0.0007766868220642209, 0.00179819215554744, 0.0031821478623896837, 1.569229607412126e-05, 0.001023828866891563, 0.004582487046718597, 0.04412461444735527, 0.8326310515403748, 0.11186514794826508, 0.0], [0.002560202032327652, 0.0021961459424346685, 0.0012966376962140203, 3.874531466863118e-05, 0.00012789985339622945, 0.00017348439723718911, 0.06046983227133751, 0.07663179188966751, 0.856505274772644, 0.0], [0.05078713223338127, 0.09524610638618469, 0.03648101165890694, 0.050540339201688766, 0.009611092507839203, 0.0027538249269127846, 0.009690326638519764, 0.015156174078583717, 0.7297340035438538, 0.0], [0.017420543357729912, 0.009016300551593304, 0.008660875260829926, 0.04713813588023186, 0.042011067271232605, 0.003162879729643464, 0.00040178498602472246, 0.005153133533895016, 0.8670352697372437, 0.0]], [[0.22553573548793793, 0.2680850327014923, 0.019470686092972755, 0.14175784587860107, 
0.053468361496925354, 0.02777918614447117, 0.05628729239106178, 0.04874898120760918, 0.15886712074279785, 0.0], [0.28905513882637024, 0.12247822433710098, 0.046002231538295746, 0.1958596557378769, 0.10771062225103378, 0.06661061197519302, 0.07628067582845688, 0.02713944762945175, 0.06886337697505951, 0.0], [0.04905243590474129, 0.05268532782793045, 0.11285670101642609, 0.09091109782457352, 0.24185867607593536, 0.20752739906311035, 0.04222555831074715, 0.05885446071624756, 0.14402832090854645, 0.0], [0.06971512734889984, 0.14066818356513977, 0.05942149832844734, 0.21028849482536316, 0.10966084897518158, 0.08002462983131409, 0.10722756385803223, 0.1377343237400055, 0.08525940030813217, 0.0], [0.1429702192544937, 0.26978883147239685, 0.12360350787639618, 0.05825580656528473, 0.022957824170589447, 0.2193503975868225, 0.0713224932551384, 0.06461618840694427, 0.02713468112051487, 0.0], [0.07554306834936142, 0.051579318940639496, 0.2103901356458664, 0.03246254473924637, 0.12347473949193954, 0.20594589412212372, 0.10415074229240417, 0.14436782896518707, 0.05208563804626465, 0.0], [0.10752540081739426, 0.08459899574518204, 0.07340764254331589, 0.019914846867322922, 0.048802055418491364, 0.2628321945667267, 0.23049965500831604, 0.11754198372364044, 0.05487721040844917, 0.0], [0.054300110787153244, 0.03522595763206482, 0.19028180837631226, 0.11526520550251007, 0.043804410845041275, 0.1941872388124466, 0.12765192985534668, 0.19942660629749298, 0.03985673561692238, 0.0], [0.13462598621845245, 0.09648311138153076, 0.08205218613147736, 0.241444393992424, 0.024601474404335022, 0.03336581960320473, 0.09252338856458664, 0.0673752948641777, 0.22752824425697327, 0.0], [0.1438782811164856, 0.15257491171360016, 0.11015111207962036, 0.2259429395198822, 0.11582648009061813, 0.06522659957408905, 0.06865230947732925, 0.07465960830450058, 0.04308782145380974, 0.0]]], [[[0.008583037182688713, 0.007665919605642557, 0.023932937532663345, 0.013663848862051964, 0.00724611384794116, 
0.01780843734741211, 0.04220886155962944, 0.035630952566862106, 0.8432599306106567, 0.0], [0.005249040201306343, 0.006725347600877285, 0.022601336240768433, 0.004061485640704632, 0.003380684182047844, 0.05792760103940964, 0.08571713417768478, 0.017759306356310844, 0.796578049659729, 0.0], [0.014741344377398491, 0.08626628667116165, 0.11416944116353989, 0.06755448132753372, 0.010767532512545586, 0.037519536912441254, 0.13943251967430115, 0.03284287825226784, 0.4967060387134552, 0.0], [0.8946033120155334, 0.07520093768835068, 0.007621173746883869, 0.004705401603132486, 0.005715447012335062, 0.0016736779361963272, 0.0011882666731253266, 0.0005322583019733429, 0.008759708143770695, 0.0], [0.17331360280513763, 0.32618802785873413, 0.1865183413028717, 0.12219864875078201, 0.08427056670188904, 0.017049826681613922, 0.027256622910499573, 0.011689829640090466, 0.05151442065834999, 0.0], [0.024287043139338493, 0.22289688885211945, 0.2742122411727905, 0.1883603185415268, 0.1339159905910492, 0.04209006950259209, 0.04496186599135399, 0.03600992262363434, 0.033265650272369385, 0.0], [0.01142946071922779, 0.05564042925834656, 0.055694323033094406, 0.5140662789344788, 0.1435396671295166, 0.038738954812288284, 0.06230159476399422, 0.07060025632381439, 0.047988954931497574, 0.0], [0.03956271708011627, 0.0978141501545906, 0.053332336246967316, 0.4993227422237396, 0.15091775357723236, 0.05724353715777397, 0.05616844817996025, 0.014285729266703129, 0.03135249391198158, 0.0], [0.04081583395600319, 0.017569201067090034, 0.031049959361553192, 0.07860688865184784, 0.1978374421596527, 0.3013133406639099, 0.2561938464641571, 0.010236106812953949, 0.06637723743915558, 0.0], [0.005346705671399832, 0.017637349665164948, 0.01670711860060692, 0.027819450944662094, 0.014111858792603016, 0.15744496881961823, 0.29349666833877563, 0.10989060997962952, 0.357545405626297, 0.0]], [[0.14326919615268707, 0.06937730312347412, 0.4621289074420929, 0.06899607926607132, 0.20691490173339844, 
0.03204977884888649, 0.010433961637318134, 0.001572124194353819, 0.005257652141153812, 0.0], [0.7372201681137085, 0.03819188475608826, 0.19263039529323578, 0.00509582320228219, 0.014029700309038162, 0.004338367842137814, 0.0016640998655930161, 0.0023727945517748594, 0.004456941969692707, 0.0], [0.6392468810081482, 0.09436309337615967, 0.23124097287654877, 0.009032140485942364, 0.016629014164209366, 0.004053707234561443, 0.0011662752367556095, 0.0013368013314902782, 0.0029307324439287186, 0.0], [0.15959776937961578, 0.060010410845279694, 0.6323540210723877, 0.04208587482571602, 0.09941276162862778, 0.001314919558353722, 0.0003186642425134778, 0.00045829309965483844, 0.004447522107511759, 0.0], [0.06331828236579895, 0.03697410970926285, 0.6882537603378296, 0.04094800353050232, 0.1500014215707779, 0.014815385453402996, 0.0006663103122264147, 0.0014023728435859084, 0.0036205528303980827, 0.0], [0.02740752510726452, 0.007235638331621885, 0.2575177550315857, 0.2825733423233032, 0.26921361684799194, 0.13694509863853455, 0.012512636370956898, 0.00419765617698431, 0.0023968773894011974, 0.0], [0.026527998968958855, 0.0014296816661953926, 0.0034867397043854, 0.11850380897521973, 0.15826237201690674, 0.4342584013938904, 0.21162042021751404, 0.04376554489135742, 0.0021449460182338953, 0.0], [0.0008783259545452893, 0.0010965524706989527, 0.006981557235121727, 0.007060014642775059, 0.27200379967689514, 0.45634904503822327, 0.1935150921344757, 0.03130912408232689, 0.030806703492999077, 0.0], [0.012816469185054302, 0.004784241784363985, 0.007290879264473915, 0.0027244724333286285, 0.0388973169028759, 0.12052476406097412, 0.3920805752277374, 0.10759556293487549, 0.3132855296134949, 0.0], [0.0021361028775572777, 0.003133963793516159, 0.003311034757643938, 0.0013810866512358189, 0.004479007329791784, 0.007041627541184425, 0.09507600963115692, 0.5596640706062317, 0.32377713918685913, 0.0]], [[0.001748488168232143, 0.011698327027261257, 0.047558922320604324, 0.7770814299583435, 
0.15215088427066803, 0.0056790816597640514, 0.0010312696686014533, 0.0011229184456169605, 0.0019287114264443517, 0.0], [0.000820137036498636, 0.0007328591891564429, 0.012266330420970917, 0.94822758436203, 0.02221596986055374, 0.006038068328052759, 0.0018012026557698846, 0.002194090047851205, 0.0057037402875721455, 0.0], [0.0017187671037390828, 0.0012595502194017172, 0.00971528235822916, 0.8996129631996155, 0.03184645250439644, 0.026646586135029793, 0.01671759784221649, 0.005960865877568722, 0.006522092968225479, 0.0], [0.010048117488622665, 0.003920346032828093, 0.01464000903069973, 0.028398782014846802, 0.047600653022527695, 0.6803404688835144, 0.07394693046808243, 0.046145662665367126, 0.09495888650417328, 0.0], [0.0020061242394149303, 0.0010488562984392047, 0.0021137045696377754, 0.03403143212199211, 0.040159616619348526, 0.4656003415584564, 0.16990402340888977, 0.16164875030517578, 0.12348736822605133, 0.0], [0.0023888982832431793, 0.0010238748509436846, 0.0031129145063459873, 0.00400560162961483, 0.005227341782301664, 0.050918273627758026, 0.28773385286331177, 0.5181463956832886, 0.12744267284870148, 0.0], [0.0057381619699299335, 0.0037375285755842924, 0.006655727047473192, 0.0010085925459861755, 0.005980721674859524, 0.02943945676088333, 0.05893365666270256, 0.6100658774375916, 0.2784405052661896, 0.0], [0.003593636676669121, 0.0024473541416227818, 0.002264569513499737, 0.00914584007114172, 0.0013253247598186135, 0.010908454656600952, 0.07958614826202393, 0.12585432827472687, 0.7648744583129883, 0.0], [0.031058229506015778, 0.02174283377826214, 0.012145284563302994, 0.010826506651937962, 0.01352943666279316, 0.021966811269521713, 0.055832888931035995, 0.11603516340255737, 0.7168627977371216, 0.0], [0.20383700728416443, 0.06762446463108063, 0.042199794203042984, 0.021983252838253975, 0.11625738441944122, 0.013579235412180424, 0.025292381644248962, 0.08914806693792343, 0.4200783669948578, 0.0]], [[0.022736268118023872, 0.02286626398563385, 0.14116300642490387, 
0.13108347356319427, 0.23994718492031097, 0.1924150437116623, 0.01816762052476406, 0.04976898059248924, 0.18185211718082428, 0.0], [0.05882957577705383, 0.028569074347615242, 0.23305171728134155, 0.053790394216775894, 0.18451730906963348, 0.2002667486667633, 0.015585620887577534, 0.052768219262361526, 0.17262138426303864, 0.0], [0.09136874228715897, 0.08459936082363129, 0.05023255571722984, 0.21660202741622925, 0.1335863471031189, 0.10654665529727936, 0.02717875875532627, 0.06888726353645325, 0.22099831700325012, 0.0], [0.04131297022104263, 0.05848437175154686, 0.3077566921710968, 0.040097035467624664, 0.16343727707862854, 0.11984208226203918, 0.06441103667020798, 0.0850440189242363, 0.11961443722248077, 0.0], [0.06447532773017883, 0.05503746494650841, 0.11529060453176498, 0.13719302415847778, 0.0843825414776802, 0.22279226779937744, 0.11870565265417099, 0.05292103812098503, 0.14920207858085632, 0.0], [0.061820220202207565, 0.03663187846541405, 0.08412205427885056, 0.386857271194458, 0.1083698719739914, 0.1462787538766861, 0.03903358429670334, 0.026668915525078773, 0.11021733283996582, 0.0], [0.08746915310621262, 0.025642354041337967, 0.16437062621116638, 0.19346435368061066, 0.10867251455783844, 0.12237238138914108, 0.06722743809223175, 0.0922309011220932, 0.13855047523975372, 0.0], [0.10294228792190552, 0.07313423603773117, 0.18607352674007416, 0.09769721329212189, 0.1089077964425087, 0.26933327317237854, 0.06555335968732834, 0.061070602387189865, 0.03528755530714989, 0.0], [0.12094805389642715, 0.14730192720890045, 0.09877816587686539, 0.21085986495018005, 0.06241541728377342, 0.22994481027126312, 0.04595630243420601, 0.04531335458159447, 0.0384821854531765, 0.0], [0.11032164841890335, 0.07897982746362686, 0.08231978863477707, 0.2677886188030243, 0.1231643408536911, 0.0929633229970932, 0.08270144462585449, 0.06097007542848587, 0.10079105943441391, 0.0]], [[0.008687321096658707, 0.012162125669419765, 0.02774685248732567, 0.0013578477082774043, 
0.052177976816892624, 0.027187975123524666, 0.05590689554810524, 0.020962538197636604, 0.7938104867935181, 0.0], [0.005042325239628553, 0.015503124333918095, 0.010042164474725723, 0.0008876739302650094, 0.011308688670396805, 0.010491759516298771, 0.03130592033267021, 0.04934320226311684, 0.8660751581192017, 0.0], [0.013016406446695328, 0.03886239603161812, 0.027493299916386604, 0.029101338237524033, 0.009947741404175758, 0.00769558921456337, 0.035501737147569656, 0.023772817105054855, 0.8146085143089294, 0.0], [0.018851714208722115, 0.05105733126401901, 0.8005384206771851, 0.01116525661200285, 0.09583853930234909, 0.0015093896072357893, 0.005055624525994062, 0.0006665397086180747, 0.015317671000957489, 0.0], [0.01609102450311184, 0.023716216906905174, 0.5135837197303772, 0.10603100061416626, 0.26668840646743774, 0.019648341462016106, 0.01755940169095993, 0.01368130836635828, 0.023000601679086685, 0.0], [0.01718730293214321, 0.02692273259162903, 0.05480796471238136, 0.010818017646670341, 0.7150712013244629, 0.0585104264318943, 0.04717297852039337, 0.030360547825694084, 0.039148781448602676, 0.0], [0.006439396180212498, 0.012697076424956322, 0.014188298024237156, 0.000897688849363476, 0.7481768727302551, 0.15047557651996613, 0.03333613649010658, 0.01207506563514471, 0.021714046597480774, 0.0], [0.009459104388952255, 0.022298788651823997, 0.013802104629576206, 0.011955137364566326, 0.03879927098751068, 0.1585427075624466, 0.07075291126966476, 0.329448938369751, 0.3449409306049347, 0.0], [0.04810584336519241, 0.017975708469748497, 0.025123968720436096, 0.023182567209005356, 0.020010611042380333, 0.04571577161550522, 0.1801854819059372, 0.06764508783817291, 0.5720548629760742, 0.0], [0.026153914630413055, 0.0356404148042202, 0.10573611408472061, 0.06201518699526787, 0.06006328761577606, 0.09286139905452728, 0.2927103638648987, 0.20419549942016602, 0.12062377482652664, 0.0]], [[0.02415475994348526, 0.0027711745351552963, 0.003856832394376397, 0.0957413911819458, 
0.02159286104142666, 0.03336814045906067, 0.009564127773046494, 0.03954486921429634, 0.7694058418273926, 0.0], [0.9052021503448486, 0.02053658291697502, 0.0014916026266291738, 0.00022646080469712615, 4.7710393118904904e-05, 0.000383042759494856, 0.014123834669589996, 0.0205638837069273, 0.03742456063628197, 0.0], [0.37607336044311523, 0.6030705571174622, 0.0068079219199717045, 0.0036466827150434256, 9.876023250399157e-05, 2.0246809071977623e-05, 0.0007042856304906309, 0.002560489112511277, 0.007017510011792183, 0.0], [5.0091031880583614e-05, 0.00024915943504311144, 0.9895205497741699, 0.006273698527365923, 0.0016484790248796344, 4.1711446101544425e-05, 7.522702958340233e-07, 1.2660359971050639e-05, 0.002202932955697179, 0.0], [8.009441080503166e-05, 9.311464236816391e-05, 0.006593613885343075, 0.9913647770881653, 0.0018261962104588747, 1.6436462829005904e-05, 8.038865075832291e-07, 1.0318336762793479e-06, 2.3524326024926268e-05, 0.0], [3.1561212381348014e-05, 1.8178753862230224e-06, 0.00011904581333510578, 0.027105441316962242, 0.8800897598266602, 0.09253741800785065, 0.00010895416926359758, 5.953493655397324e-06, 1.9602707368449046e-07, 0.0], [1.7160528553716858e-09, 1.4191656530493368e-11, 3.274841375855431e-08, 2.1219284462858923e-07, 1.9925082597183064e-05, 0.9999751448631287, 3.130498271275428e-06, 1.9788064946624218e-06, 3.1215499074477293e-09, 0.0], [1.2861962204624433e-05, 5.737682045037218e-07, 2.0471109110076213e-06, 1.0477544492459856e-05, 6.581651632586727e-06, 0.02534269355237484, 0.16125597059726715, 0.5878354907035828, 0.22553342580795288, 0.0], [0.0009172551217488945, 7.270056084962562e-05, 2.2026280930731446e-05, 4.6261970965133514e-06, 4.921669642499182e-06, 4.060195351485163e-05, 0.027831047773361206, 0.33271971344947815, 0.6383873224258423, 0.0], [1.3075091374048498e-05, 6.147480598883703e-05, 4.768987855641171e-05, 2.045959490715177e-06, 1.1152823553572944e-08, 3.07468525306831e-07, 0.0007055726600810885, 0.02803119830787182, 
0.9711382985115051, 0.0]], [[0.060361556708812714, 0.015829458832740784, 0.05784451961517334, 0.3351474404335022, 0.06477320939302444, 0.04427827522158623, 0.09356044977903366, 0.03362266346812248, 0.2945823669433594, 0.0], [0.051239900290966034, 0.0459107868373394, 0.10656695812940598, 0.4080160856246948, 0.16381530463695526, 0.044977184385061264, 0.05972094088792801, 0.009804679080843925, 0.10994797199964523, 0.0], [0.019088272005319595, 0.05349855497479439, 0.4389742910861969, 0.022328443825244904, 0.03395729511976242, 0.20592069625854492, 0.007582489866763353, 0.08437496423721313, 0.13427504897117615, 0.0], [0.03275543451309204, 0.01311502419412136, 0.038520246744155884, 0.47789818048477173, 0.04586595296859741, 0.01380465179681778, 0.03337283805012703, 0.07212045043706894, 0.27254730463027954, 0.0], [0.04071904346346855, 0.043366871774196625, 0.1190471276640892, 0.18268215656280518, 0.2763146162033081, 0.029253922402858734, 0.017268449068069458, 0.0670313611626625, 0.22431644797325134, 0.0], [0.04853136092424393, 0.0034203159157186747, 0.17822766304016113, 0.005087696481496096, 0.02670232392847538, 0.5734196305274963, 0.06478680670261383, 0.04684215411543846, 0.05298209935426712, 0.0], [0.016102498397231102, 0.0006646174006164074, 0.00315408268943429, 0.003398373955860734, 0.01210782676935196, 0.07864897698163986, 0.743419349193573, 0.023116787895560265, 0.11938738822937012, 0.0], [0.0031801864970475435, 0.0032259617000818253, 0.027063841000199318, 0.0018325509736314416, 0.006064774002879858, 0.017839375883340836, 0.05006564408540726, 0.8002738952636719, 0.0904538482427597, 0.0], [0.02500138245522976, 0.016465606167912483, 0.02692888118326664, 0.01824249140918255, 0.047875918447971344, 0.06556686758995056, 0.15585453808307648, 0.21941381692886353, 0.42465049028396606, 0.0], [0.07641319185495377, 0.017753547057509422, 0.039497166872024536, 0.014236720278859138, 0.03872253745794296, 0.1210501492023468, 0.17305448651313782, 0.2333979308605194, 
0.28587427735328674, 0.0]], [[0.15564993023872375, 0.3264511823654175, 0.08247561007738113, 0.04047680273652077, 0.04636594280600548, 0.03705644607543945, 0.05653020739555359, 0.08808662742376328, 0.16690711677074432, 0.0], [0.6047166585922241, 0.08402378112077713, 0.11650887131690979, 0.004807815421372652, 0.02726476825773716, 0.0609126091003418, 0.02905944734811783, 0.012920884415507317, 0.059785205870866776, 0.0], [0.5938906669616699, 0.07300958037376404, 0.08890929818153381, 0.008111076429486275, 0.04038470610976219, 0.07353192567825317, 0.03085281327366829, 0.08706387132406235, 0.004246041644364595, 0.0], [0.2591831088066101, 0.17658700048923492, 0.44177621603012085, 0.01689036749303341, 0.0653892457485199, 0.01502177957445383, 0.02055797167122364, 0.0024378441739827394, 0.0021566858049482107, 0.0], [0.33400091528892517, 0.03927909955382347, 0.27614372968673706, 0.009977479465305805, 0.12025652825832367, 0.1713484674692154, 0.04292818158864975, 0.004225345328450203, 0.00184013566467911, 0.0], [0.06147114187479019, 0.019044799730181694, 0.059415291994810104, 0.05198045074939728, 0.12181691080331802, 0.419679194688797, 0.1140735000371933, 0.14551687240600586, 0.00700181070715189, 0.0], [0.006845483556389809, 0.002091927919536829, 0.01196279563009739, 0.014390786178410053, 0.02692629024386406, 0.8455513715744019, 0.07174734026193619, 0.017689114436507225, 0.0027949714567512274, 0.0], [0.00039940490387380123, 0.00013551976007875055, 0.020663700997829437, 0.008696838282048702, 0.021915050223469734, 0.1381293535232544, 0.0347108468413353, 0.7650054097175598, 0.010343861766159534, 0.0], [0.02615724503993988, 0.0051858089864254, 0.038734134286642075, 0.021585455164313316, 0.19684533774852753, 0.17548950016498566, 0.1665634661912918, 0.2796759307384491, 0.08976294845342636, 0.0], [0.043001022189855576, 0.016749290749430656, 0.04958483204245567, 0.06659381091594696, 0.0702962800860405, 0.27735820412635803, 0.14212922751903534, 0.20686522126197815, 0.12742231786251068, 
0.0]]], [[[0.13086311519145966, 0.049477167427539825, 0.10100015252828598, 0.03843620419502258, 0.27287009358406067, 0.20078831911087036, 0.16546384990215302, 0.03368193656206131, 0.007419050205498934, 0.0], [0.1137659102678299, 0.11250672489404678, 0.21935509145259857, 0.09974226355552673, 0.22245454788208008, 0.11022598296403885, 0.0977952778339386, 0.010162456892430782, 0.013991687446832657, 0.0], [0.09118296205997467, 0.0991944894194603, 0.31555840373039246, 0.16625922918319702, 0.1399575173854828, 0.0926588773727417, 0.021735703572630882, 0.056496523320674896, 0.016956249251961708, 0.0], [0.35773080587387085, 0.19870112836360931, 0.026073846966028214, 0.07347559928894043, 0.09251826256513596, 0.0859094187617302, 0.06421677768230438, 0.06334269791841507, 0.0380314365029335, 0.0], [0.02230222336947918, 0.0210218857973814, 0.024334343150258064, 0.36442241072654724, 0.2750929892063141, 0.13295342028141022, 0.06824173033237457, 0.0036951478105038404, 0.0879359245300293, 0.0], [0.018942566588521004, 0.011805560439825058, 0.04696377366781235, 0.09440026432275772, 0.39890599250793457, 0.17608429491519928, 0.10613365471363068, 0.10454639047384262, 0.04221746698021889, 0.0], [0.0475851334631443, 0.008668179623782635, 0.011950161308050156, 0.0786907747387886, 0.09432563930749893, 0.07653870433568954, 0.4287588894367218, 0.13403372466564178, 0.1194487139582634, 0.0], [0.008243327029049397, 0.006908380892127752, 0.04044030234217644, 0.08380357921123505, 0.1593569815158844, 0.1858288198709488, 0.0890916958451271, 0.40247857570648193, 0.02384827472269535, 0.0], [0.09753390401601791, 0.04787491634488106, 0.10570236295461655, 0.09989321976900101, 0.07242950052022934, 0.16000299155712128, 0.13195638358592987, 0.12870465219020844, 0.15590202808380127, 0.0], [0.3338638246059418, 0.05386793985962868, 0.15485166013240814, 0.05483235418796539, 0.052468191832304, 0.12754301726818085, 0.13515245914459229, 0.06475869566202164, 0.022661946713924408, 0.0]], [[0.011833908967673779, 
0.03545977920293808, 0.03510122373700142, 0.06200635805726051, 0.09438431262969971, 0.06055876612663269, 0.053256530314683914, 0.30701303482055664, 0.3403860926628113, 0.0], [0.03663749620318413, 0.06511621922254562, 0.05716057866811752, 0.07533077895641327, 0.10846659541130066, 0.037432827055454254, 0.04480022192001343, 0.18166707456111908, 0.39338818192481995, 0.0], [0.06557667255401611, 0.03966936469078064, 0.008358842693269253, 0.06794404983520508, 0.05668830871582031, 0.02720261737704277, 0.07913517951965332, 0.20437636971473694, 0.45104852318763733, 0.0], [0.044038429856300354, 0.07477934658527374, 0.10143070667982101, 0.16204005479812622, 0.06265459954738617, 0.10170722752809525, 0.08676454424858093, 0.0699862688779831, 0.2965989410877228, 0.0], [0.06005045771598816, 0.046840403228998184, 0.06629239022731781, 0.04125581681728363, 0.007815167307853699, 0.20412082970142365, 0.1083299070596695, 0.04942404478788376, 0.41587093472480774, 0.0], [0.03666035085916519, 0.028792625293135643, 0.06887229532003403, 0.18481910228729248, 0.15058831870555878, 0.048441674560308456, 0.0780390277504921, 0.13469383120536804, 0.26909276843070984, 0.0], [0.03408746421337128, 0.026394939050078392, 0.05409233644604683, 0.06951043754816055, 0.1446777582168579, 0.09970070421695709, 0.05472328141331673, 0.16119606792926788, 0.35561704635620117, 0.0], [0.12936006486415863, 0.04621516913175583, 0.10149524360895157, 0.14774896204471588, 0.45855623483657837, 0.033130910247564316, 0.031401973217725754, 0.02012830227613449, 0.031963150948286057, 0.0], [0.1214270144701004, 0.04088712856173515, 0.05250505730509758, 0.07924661785364151, 0.05337269604206085, 0.10527284443378448, 0.08820997178554535, 0.17732012271881104, 0.28175854682922363, 0.0], [0.13074854016304016, 0.06475767493247986, 0.07325490564107895, 0.0625966489315033, 0.14061231911182404, 0.07830052822828293, 0.12438739091157913, 0.21453101933002472, 0.11081094294786453, 0.0]], [[0.0022766904439777136, 0.00227623013779521, 
0.027263110503554344, 0.7988243699073792, 0.12335250526666641, 0.012830986641347408, 0.008179515600204468, 0.004631126299500465, 0.020365260541439056, 0.0], [0.022365765646100044, 0.0197063609957695, 0.08540411293506622, 0.7100865840911865, 0.10288897156715393, 0.023861246183514595, 0.009303209371864796, 0.012690575793385506, 0.013693095184862614, 0.0], [0.023093748837709427, 0.013999207876622677, 0.09048538655042648, 0.10519850999116898, 0.12126202881336212, 0.34847554564476013, 0.057331401854753494, 0.0919070839881897, 0.14824725687503815, 0.0], [0.03627682104706764, 0.0323517769575119, 0.06003699079155922, 0.04609783738851547, 0.3189731240272522, 0.3202785551548004, 0.06900984793901443, 0.021341597661376, 0.0956336110830307, 0.0], [0.026664189994335175, 0.018690558150410652, 0.01473171729594469, 0.003785684471949935, 0.012891196645796299, 0.6301508545875549, 0.1024516150355339, 0.10377107560634613, 0.08686315268278122, 0.0], [0.010066811926662922, 0.005272349342703819, 0.019913937896490097, 0.005584465805441141, 0.0479762889444828, 0.06466472148895264, 0.2978198528289795, 0.22872935235500336, 0.31997203826904297, 0.0], [0.054553788155317307, 0.011876759119331837, 0.005296430550515652, 0.008171333000063896, 0.17499762773513794, 0.29638832807540894, 0.22286026179790497, 0.017016055062413216, 0.20883934199810028, 0.0], [0.03061697818338871, 0.020777547731995583, 0.27117541432380676, 0.010558649897575378, 0.16651615500450134, 0.3011224865913391, 0.026109976693987846, 0.048922766000032425, 0.12420005351305008, 0.0], [0.16545239090919495, 0.03877135366201401, 0.007565324194729328, 0.015141250565648079, 0.03747279569506645, 0.3241279125213623, 0.26990416646003723, 0.043362975120544434, 0.09820175170898438, 0.0], [0.22949647903442383, 0.0972394198179245, 0.02905140444636345, 0.03182214871048927, 0.025490015745162964, 0.08278947323560715, 0.15009135007858276, 0.031098822131752968, 0.3229208290576935, 0.0]], [[0.023217031732201576, 0.015444980934262276, 
0.33269768953323364, 0.4809305965900421, 0.08491171896457672, 0.027504485100507736, 0.007655052933841944, 0.015150148421525955, 0.012488299049437046, 0.0], [0.003814368275925517, 0.0054845609702169895, 0.005400203168392181, 0.34217125177383423, 0.010647634975612164, 0.00044525362318381667, 0.00011972449283348396, 0.00042839962407015264, 0.6314883828163147, 0.0], [0.013448912650346756, 0.01028169970959425, 0.4982297718524933, 0.3182436525821686, 0.01780710555613041, 0.024587348103523254, 0.0009282209794037044, 0.11607228964567184, 0.0004009671974927187, 0.0], [0.0027270291466265917, 0.01338754128664732, 0.019254636019468307, 0.11856623739004135, 0.0025901400949805975, 0.0012062221067026258, 0.0006161375786177814, 0.0012282256502658129, 0.8404240608215332, 0.0], [1.802536098693963e-05, 0.0005015733768232167, 2.3977232558536343e-05, 0.00012258262722752988, 0.00013862864580005407, 1.9367420463822782e-05, 1.2695372788584791e-05, 2.8395381377777085e-05, 0.9991349577903748, 0.0], [0.045823611319065094, 0.0060311248525977135, 0.11489683389663696, 0.011397628113627434, 0.14236140251159668, 0.31853923201560974, 0.18707275390625, 0.16781283915042877, 0.006064609158784151, 0.0], [0.031908370554447174, 0.0013231962220743299, 0.03774190694093704, 0.014869065955281258, 0.08836144208908081, 0.662682056427002, 0.1095389723777771, 0.05017231032252312, 0.0034025281202048063, 0.0], [0.0061959377489984035, 0.012075785547494888, 0.28881579637527466, 0.0719127431511879, 0.08756363391876221, 0.0848873034119606, 0.027471251785755157, 0.404219388961792, 0.016858302056789398, 0.0], [0.0946543961763382, 0.0623893216252327, 0.18748056888580322, 0.1788652539253235, 0.03208017721772194, 0.1587594598531723, 0.05469479411840439, 0.17047303915023804, 0.06060296297073364, 0.0], [0.019481608644127846, 0.068674735724926, 0.13537795841693878, 0.2137300968170166, 0.031131863594055176, 0.02376358024775982, 0.030956387519836426, 0.04989796131849289, 0.4269856810569763, 0.0]], [[0.00896595511585474, 
0.001820763573050499, 0.0036846648436039686, 0.8942996859550476, 0.002699120668694377, 0.0018430916825309396, 0.00023619653075002134, 0.0008667120710015297, 0.08558366447687149, 0.0], [0.011139868758618832, 0.00517098605632782, 0.03486357256770134, 0.92783522605896, 0.010794212110340595, 0.0029791113920509815, 0.0008399260113947093, 0.0003134821599815041, 0.006063643377274275, 0.0], [0.07888396829366684, 0.0272236131131649, 0.0322146937251091, 0.791079044342041, 0.03133838623762131, 0.009372375905513763, 0.002263500588014722, 0.0005359782953746617, 0.02708848938345909, 0.0], [0.008838528767228127, 0.0009813528740778565, 0.014693140052258968, 0.00012726498243864626, 0.013269715011119843, 0.06431703269481659, 0.0039668334648013115, 0.8607616424560547, 0.0330444760620594, 0.0], [0.028727378696203232, 0.001701394678093493, 0.0009593431605026126, 0.0036824517883360386, 0.009683175943791866, 0.2589351236820221, 0.040837112814188004, 0.01649528741836548, 0.6389787197113037, 0.0], [0.009239337407052517, 0.0011580593418329954, 0.0009623299702070653, 0.000996780814602971, 0.00493139773607254, 0.04319336265325546, 0.859686553478241, 0.012395362369716167, 0.06743697822093964, 0.0], [0.024199873208999634, 0.007249501068145037, 0.02041051909327507, 0.008800184354186058, 0.02760438062250614, 0.1116553395986557, 0.030366744846105576, 0.03851965814828873, 0.7311937808990479, 0.0], [0.06881897896528244, 0.21671976149082184, 0.02303808182477951, 0.0017656114650890231, 0.09897635877132416, 0.04207116737961769, 0.012660021893680096, 0.25307658314704895, 0.2828734517097473, 0.0], [0.09324429929256439, 0.059572815895080566, 0.021969754248857498, 0.008625463582575321, 0.022502752020955086, 0.07016356289386749, 0.033860694617033005, 0.03514377400279045, 0.6549169421195984, 0.0], [0.04541633278131485, 0.01696496643126011, 0.003866765182465315, 0.00941139180213213, 0.006640681531280279, 0.024550199508666992, 0.009012367576360703, 0.009869653731584549, 0.8742677569389343, 0.0]], 
[[0.007143852766603231, 0.26111796498298645, 0.053768061101436615, 0.022731401026248932, 0.014146089553833008, 0.012985849753022194, 0.007359612733125687, 0.0043042986653745174, 0.6164429783821106, 0.0], [0.6188079118728638, 0.11004422605037689, 0.07541824132204056, 0.010463211685419083, 0.003863272722810507, 0.016659650951623917, 0.028880171477794647, 0.010046081617474556, 0.1258174628019333, 0.0], [0.1673126220703125, 0.6515482068061829, 0.016748156398534775, 0.042502570897340775, 0.016912223771214485, 0.011716129258275032, 0.04548521339893341, 0.0008787817787379026, 0.04689598083496094, 0.0], [0.6625117659568787, 0.049922335892915726, 0.2738172709941864, 0.004228150937706232, 0.0033112652599811554, 0.001177642960101366, 0.0005330604617483914, 0.00011132613872177899, 0.0043872324749827385, 0.0], [0.0040039438754320145, 0.00112480903044343, 0.04353015124797821, 0.9313303232192993, 0.010056668892502785, 0.0007567661814391613, 0.0006773694767616689, 0.00016374654660467058, 0.008356312289834023, 0.0], [0.0015825676964595914, 0.001574154943227768, 0.001225732616148889, 0.27774307131767273, 0.47191065549850464, 0.041899941861629486, 0.10331469774246216, 0.0047262245789170265, 0.0960230901837349, 0.0], [0.0012044805334880948, 0.001744594657793641, 0.010911357589066029, 0.035235531628131866, 0.12406003475189209, 0.49639585614204407, 0.02129644714295864, 0.07618547230958939, 0.23296628892421722, 0.0], [0.006665610242635012, 0.008957373909652233, 0.0028928713873028755, 0.7268922924995422, 0.10707614570856094, 0.01201178040355444, 0.013845101930201054, 0.022992080077528954, 0.09866661578416824, 0.0], [0.04134861007332802, 0.0526767373085022, 0.04131396487355232, 0.023087071254849434, 0.04077164828777313, 0.027765633538365364, 0.05679082125425339, 0.025407245382666588, 0.6908383369445801, 0.0], [0.0021411015186458826, 0.012145284563302994, 0.008635377511382103, 0.004571457393467426, 0.009789393283426762, 0.022923681885004044, 0.019266795367002487, 0.15913596749305725, 
0.7613908052444458, 0.0]], [[0.1481553614139557, 0.14691436290740967, 0.5575758218765259, 0.02441403828561306, 0.058879025280475616, 0.011832842603325844, 0.01016098354011774, 0.015505112707614899, 0.026562504470348358, 0.0], [0.20737145841121674, 0.2658809721469879, 0.4251604974269867, 0.03998560830950737, 0.012661930173635483, 0.003662273520603776, 0.0006891252705827355, 0.004390099551528692, 0.040197838097810745, 0.0], [0.09591562300920486, 0.13717111945152283, 0.23219715058803558, 0.020156029611825943, 0.031411558389663696, 0.04842779412865639, 0.003137993859127164, 0.03623202070593834, 0.3953508734703064, 0.0], [0.039530280977487564, 0.0770052894949913, 0.4637334942817688, 0.3284752666950226, 0.018390586599707603, 0.021701356396079063, 0.0038800504989922047, 0.01712900958955288, 0.03015456721186638, 0.0], [0.11945435404777527, 0.064958855509758, 0.07506249845027924, 0.03312050923705101, 0.045947931706905365, 0.21168209612369537, 0.1585550606250763, 0.15941208600997925, 0.13180643320083618, 0.0], [0.027811916545033455, 0.005367752630263567, 0.022701909765601158, 0.02928026206791401, 0.042085714638233185, 0.23124846816062927, 0.3448639512062073, 0.17164644598960876, 0.12499356269836426, 0.0], [0.04207267239689827, 0.0036673955619335175, 0.013221505098044872, 0.04823020473122597, 0.018784234300255775, 0.15617071092128754, 0.5762590169906616, 0.08817830681800842, 0.053416069597005844, 0.0], [0.016374358907341957, 0.01844160258769989, 0.0564517118036747, 0.008724928833544254, 0.031119121238589287, 0.08068697899580002, 0.028958600014448166, 0.08762799203395844, 0.6716147661209106, 0.0], [0.01953587494790554, 0.025442129001021385, 0.033712055534124374, 0.054231878370046616, 0.046861547976732254, 0.038379911333322525, 0.03105914779007435, 0.027592265978455544, 0.723185122013092, 0.0], [0.13706038892269135, 0.05296454578638077, 0.06056801974773407, 0.10271193832159042, 0.10989244282245636, 0.11971112340688705, 0.10623226314783096, 0.11503037810325623, 
0.19582894444465637, 0.0]], [[0.1995955854654312, 0.1365700364112854, 0.06844333559274673, 0.10430964082479477, 0.06450515240430832, 0.046256136149168015, 0.1181989535689354, 0.11867640167474747, 0.14344464242458344, 0.0], [0.11168574541807175, 0.19221879541873932, 0.09424428641796112, 0.10450402647256851, 0.06917304545640945, 0.0600862056016922, 0.14199501276016235, 0.09375911951065063, 0.1323338896036148, 0.0], [0.1609925925731659, 0.1396498680114746, 0.177944153547287, 0.03334498405456543, 0.016808874905109406, 0.10536731034517288, 0.1187783032655716, 0.03365077078342438, 0.2134632021188736, 0.0], [0.1381153017282486, 0.04422784969210625, 0.04791303351521492, 0.16848880052566528, 0.14531251788139343, 0.08485772460699081, 0.03650972992181778, 0.08906612545251846, 0.24550898373126984, 0.0], [0.05420248210430145, 0.02658572979271412, 0.05446610227227211, 0.10749125480651855, 0.22097598016262054, 0.16638338565826416, 0.0331658273935318, 0.035041358321905136, 0.30168798565864563, 0.0], [0.13837963342666626, 0.02727973647415638, 0.1299334168434143, 0.0896206796169281, 0.11551950126886368, 0.13963927328586578, 0.07841819524765015, 0.02172034978866577, 0.25948914885520935, 0.0], [0.08648376911878586, 0.022083481773734093, 0.023758457973599434, 0.19388236105442047, 0.1724909394979477, 0.02776450663805008, 0.04756799340248108, 0.03839871287345886, 0.38756975531578064, 0.0], [0.06476524472236633, 0.0030945511534810066, 0.016289785504341125, 0.013512126170098782, 0.007217712234705687, 0.047962453216314316, 0.03755675256252289, 0.02134101279079914, 0.788260281085968, 0.0], [0.06007339805364609, 0.038077425211668015, 0.01732070930302143, 0.04335314407944679, 0.09849875420331955, 0.07123422622680664, 0.12536978721618652, 0.17402620613574982, 0.37204641103744507, 0.0], [0.31619662046432495, 0.10629935562610626, 0.051193755120038986, 0.08206456899642944, 0.08056272566318512, 0.05727463215589523, 0.13476009666919708, 0.03839832916855812, 0.13325001299381256, 0.0]]], 
[[[0.0077307759784162045, 0.013184988871216774, 0.016869038343429565, 0.013336911797523499, 0.01304439827799797, 0.013718237169086933, 0.0296618789434433, 0.02448520064353943, 0.8679684996604919, 0.0], [0.1374177634716034, 0.018056754022836685, 0.029763542115688324, 0.004862301517277956, 0.00231130956672132, 0.006278112530708313, 0.012106452137231827, 0.033879589289426804, 0.755324125289917, 0.0], [0.1539316475391388, 0.23461903631687164, 0.033691998571157455, 0.026462335139513016, 0.0030949951615184546, 0.0038835303857922554, 0.009438932873308659, 0.0025479686446487904, 0.5323294997215271, 0.0], [0.29215019941329956, 0.05790534242987633, 0.18934868276119232, 0.018473153933882713, 0.002999690594151616, 0.004652327857911587, 0.010374259203672409, 0.0072145056910812855, 0.4168816804885864, 0.0], [0.16539201140403748, 0.05819307267665863, 0.12084146589040756, 0.1738077849149704, 0.004504370968788862, 0.006831282749772072, 0.02180996537208557, 0.012287246994674206, 0.4363327622413635, 0.0], [0.017602156847715378, 0.026805447414517403, 0.07671570032835007, 0.5152483582496643, 0.21202509105205536, 0.041201505810022354, 0.02207496576011181, 0.00952092744410038, 0.07880578190088272, 0.0], [0.017750855535268784, 0.01654047518968582, 0.07482129335403442, 0.23223723471164703, 0.3542158007621765, 0.16141267120838165, 0.02249749004840851, 0.044686269015073776, 0.07583795487880707, 0.0], [0.01755565032362938, 0.008823209442198277, 0.013083218596875668, 0.5061533451080322, 0.02344801276922226, 0.0075739468447864056, 0.07187878340482712, 0.029167035594582558, 0.3223167061805725, 0.0], [0.07037408649921417, 0.05453738570213318, 0.05508268624544144, 0.02769530564546585, 0.038050826638936996, 0.20446287095546722, 0.19980187714099884, 0.19835616648197174, 0.15163862705230713, 0.0], [0.003375247586518526, 0.008793321438133717, 0.001001630094833672, 0.002094975672662258, 0.0032946632709354162, 0.01792662777006626, 0.10988471657037735, 0.21093924343585968, 0.6426896452903748, 0.0]], 
[[0.07041527330875397, 0.15367093682289124, 0.3963199257850647, 0.03077671490609646, 0.0928598940372467, 0.04086732864379883, 0.018142100423574448, 0.012120239436626434, 0.18482762575149536, 0.0], [0.0367966964840889, 0.022482391446828842, 0.35830163955688477, 0.02875097654759884, 0.03547174483537674, 0.026731541380286217, 0.005365677177906036, 0.038472291082143784, 0.4476269483566284, 0.0], [0.035722482949495316, 0.013633755035698414, 0.09877835214138031, 0.0896211713552475, 0.16777706146240234, 0.10725134611129761, 0.05053357034921646, 0.10712091624736786, 0.32956135272979736, 0.0], [0.026592494919896126, 0.002020884770900011, 0.010739283636212349, 0.015951883047819138, 0.18538028001785278, 0.16766443848609924, 0.03731367364525795, 0.3853055238723755, 0.16903170943260193, 0.0], [0.0038862666115164757, 0.0004870722477789968, 0.0013956124894320965, 0.001421120367012918, 0.013834443874657154, 0.18579153716564178, 0.18213258683681488, 0.5258954763412476, 0.08515587449073792, 0.0], [0.002724156714975834, 0.004415807779878378, 0.003638928523287177, 0.00862019695341587, 0.010569852776825428, 0.18068262934684753, 0.2256886065006256, 0.3616458773612976, 0.20201392471790314, 0.0], [0.004798010922968388, 0.006960091646760702, 0.005558326840400696, 0.015252271667122841, 0.010058294981718063, 0.16163121163845062, 0.19844789803028107, 0.089196115732193, 0.508097767829895, 0.0], [0.006148195825517178, 0.009931370615959167, 0.0022139709908515215, 0.003481896361336112, 0.00199966412037611, 0.011451391503214836, 0.018514955416321754, 0.10389390587806702, 0.8423647284507751, 0.0], [0.1474374532699585, 0.1116698756814003, 0.10040155798196793, 0.07832593470811844, 0.051569730043411255, 0.103182852268219, 0.0987909808754921, 0.06264416873455048, 0.24597744643688202, 0.0], [0.0029639359563589096, 0.0008504446013830602, 0.002215511864051223, 0.0016108372947201133, 0.001786046545021236, 0.003435377962887287, 0.000923731888178736, 0.009203500114381313, 0.9770104885101318, 0.0]], 
[[0.022387586534023285, 0.045972827821969986, 0.05835629999637604, 0.22869053483009338, 0.010770916007459164, 0.006216464098542929, 0.018148910254240036, 0.006308646872639656, 0.6031478047370911, 0.0], [0.0035997454542666674, 0.002674269489943981, 0.016009783372282982, 0.05554450675845146, 0.0013587778666988015, 0.0032801039051264524, 0.00560772093012929, 0.00799081102013588, 0.9039342403411865, 0.0], [0.0018151046242564917, 0.001049908110871911, 0.0005912692868150771, 0.005136367864906788, 0.0005621784366667271, 0.00844560656696558, 0.017937110736966133, 0.008342047221958637, 0.9561205506324768, 0.0], [0.004429707303643227, 0.005516116041690111, 0.003033371875062585, 0.012963998131453991, 0.0034379358403384686, 0.003276604227721691, 0.0140963364392519, 0.005416945554316044, 0.9478288888931274, 0.0], [0.013230006210505962, 0.011804360896348953, 0.009972047992050648, 0.004975683055818081, 0.008386109955608845, 0.18977868556976318, 0.1806434541940689, 0.03204761818051338, 0.549161970615387, 0.0], [0.006766628473997116, 0.008349079638719559, 0.003925195895135403, 0.0006033667596057057, 0.006175691727548838, 0.2236345112323761, 0.03405819088220596, 0.07976362109184265, 0.6367236971855164, 0.0], [0.03860252723097801, 0.01646261475980282, 0.02104821614921093, 0.0021387943997979164, 0.005319601856172085, 0.2400989532470703, 0.03188503161072731, 0.005558657925575972, 0.6388856768608093, 0.0], [0.0738457515835762, 0.018826894462108612, 0.0069308048114180565, 0.0074225678108632565, 0.004789229016751051, 0.046955253928899765, 0.11907684803009033, 0.18744726479053497, 0.5347052812576294, 0.0], [0.09479796141386032, 0.11939200013875961, 0.0752992108464241, 0.061374519020318985, 0.08638977259397507, 0.12459041178226471, 0.16023214161396027, 0.0879756435751915, 0.1899482160806656, 0.0], [0.027537798509001732, 0.06296242028474808, 0.014751194976270199, 0.0011882808757945895, 0.016387099400162697, 0.15830224752426147, 0.03707461059093475, 0.028470970690250397, 0.6533253788948059, 
0.0]], [[0.044569190591573715, 0.00917287077754736, 0.004391324240714312, 0.8386606574058533, 0.06130588799715042, 0.003870139131322503, 0.007488539442420006, 0.028126200661063194, 0.002415221417322755, 0.0], [0.08009635657072067, 0.016815535724163055, 0.012093844823539257, 0.1592065542936325, 0.5643750429153442, 0.02920410968363285, 0.0919446051120758, 0.036902546882629395, 0.009361499920487404, 0.0], [0.08603464066982269, 0.01933746039867401, 0.05900268629193306, 0.2806539237499237, 0.22094620764255524, 0.08643656224012375, 0.026435989886522293, 0.1974046230316162, 0.023747902363538742, 0.0], [0.011109462939202785, 0.008883769623935223, 0.006091873627156019, 0.9400036931037903, 0.020445559173822403, 0.0056496066972613335, 0.0019461432239040732, 0.005268549080938101, 0.000601345207542181, 0.0], [0.12766019999980927, 0.021774157881736755, 0.08726640790700912, 0.0718328207731247, 0.053083695471286774, 0.3031027019023895, 0.06321869790554047, 0.2611844837665558, 0.010876962915062904, 0.0], [0.06407223641872406, 0.11627303808927536, 0.1807759404182434, 0.0054795523174107075, 0.026687098667025566, 0.09637009352445602, 0.052303463220596313, 0.4456423819065094, 0.012396130710840225, 0.0], [0.09725438803434372, 0.15948796272277832, 0.05173082649707794, 0.01153761800378561, 0.0721999853849411, 0.059252724051475525, 0.11923323571681976, 0.05380275845527649, 0.3755004107952118, 0.0], [0.0974307581782341, 0.23218762874603271, 0.06967660784721375, 0.031012043356895447, 0.04906507954001427, 0.31767621636390686, 0.08231117576360703, 0.07159094512462616, 0.04904941841959953, 0.0], [0.16388627886772156, 0.17526376247406006, 0.07081529498100281, 0.17886894941329956, 0.07944575697183609, 0.07640470564365387, 0.0757102444767952, 0.04333823174238205, 0.13626690208911896, 0.0], [0.09265941381454468, 0.11633585393428802, 0.04908691346645355, 0.0062498715706169605, 0.07016508281230927, 0.012818480841815472, 0.0484321266412735, 0.015437646768987179, 0.5888146162033081, 0.0]], 
[[0.005850312765687704, 0.017421673983335495, 0.004798548296093941, 0.008814580738544464, 0.00403921864926815, 0.015260725282132626, 0.03377071022987366, 0.009620469063520432, 0.9004237651824951, 0.0], [0.013436811044812202, 0.03272867575287819, 0.005969575606286526, 0.02213078737258911, 0.008325905539095402, 0.015314633026719093, 0.027177294716238976, 0.017041552811861038, 0.8578747510910034, 0.0], [0.012305106967687607, 0.03383316844701767, 0.010593314655125141, 0.027156231924891472, 0.00306991720572114, 0.004844812210649252, 0.018964877352118492, 0.05307865887880325, 0.8361539244651794, 0.0], [0.01255274098366499, 0.03600494936108589, 0.010369472205638885, 0.019021298736333847, 0.0032906190026551485, 0.0037067385856062174, 0.017627976834774017, 0.01037716306746006, 0.8870489597320557, 0.0], [0.0020879805088043213, 0.013872731477022171, 0.0018324662232771516, 0.006437606178224087, 0.013170951046049595, 0.011930068954825401, 0.0030771365854889154, 0.018353432416915894, 0.9292376041412354, 0.0], [0.022865016013383865, 0.0674654096364975, 0.00996339786797762, 0.01914660632610321, 0.014956261031329632, 0.026097828522324562, 0.018910687416791916, 0.06562207639217377, 0.7549726963043213, 0.0], [0.019515078514814377, 0.07521340996026993, 0.03206341341137886, 0.0070005785673856735, 0.0066195218823850155, 0.03877842426300049, 0.01228683814406395, 0.032381508499383926, 0.7761411666870117, 0.0], [0.0022571254521608353, 0.00383052253164351, 0.0012509305961430073, 0.005982697941362858, 0.001252268673852086, 0.0028570422437042, 0.00556317949667573, 0.7337145805358887, 0.24329175055027008, 0.0], [0.132036030292511, 0.1683972030878067, 0.13758207857608795, 0.14189518988132477, 0.03147142380475998, 0.047566916793584824, 0.07834812998771667, 0.12177446484565735, 0.1409287303686142, 0.0], [0.038695693016052246, 0.06063547730445862, 0.020152689889073372, 0.1101006418466568, 0.021127784624695778, 0.02848564088344574, 0.03705665469169617, 0.108894944190979, 0.5748504996299744, 0.0]], 
[[0.0842718631029129, 0.33930304646492004, 0.1421334594488144, 0.18528752028942108, 0.05815916135907173, 0.022830937057733536, 0.01860896497964859, 0.009871570393443108, 0.13953347504138947, 0.0], [0.1295168399810791, 0.08884051442146301, 0.06592670828104019, 0.2686370015144348, 0.02522267960011959, 0.03633918985724449, 0.021549394354224205, 0.051057688891887665, 0.31290990114212036, 0.0], [0.20177747309207916, 0.039902716875076294, 0.053595658391714096, 0.09988140314817429, 0.01657777465879917, 0.07154539972543716, 0.024320384487509727, 0.12353017926216125, 0.36886897683143616, 0.0], [0.20979411900043488, 0.30265647172927856, 0.20804975926876068, 0.007371237967163324, 0.0033807901199907064, 0.02442527562379837, 0.017248263582587242, 0.022337088361382484, 0.2047368586063385, 0.0], [0.20739784836769104, 0.05559583380818367, 0.12535981833934784, 0.009768493473529816, 0.015522800385951996, 0.024528708308935165, 0.03864477947354317, 0.08712086826562881, 0.4360608160495758, 0.0], [0.054182324558496475, 0.0038395673036575317, 0.019914912059903145, 0.014234894886612892, 0.012772555463016033, 0.019022708758711815, 0.04023807495832443, 0.36886361241340637, 0.46693122386932373, 0.0], [0.10943998396396637, 0.007069440558552742, 0.019821595400571823, 0.012627309188246727, 0.016869045794010162, 0.05302179232239723, 0.05124732851982117, 0.14304620027542114, 0.5868573188781738, 0.0], [0.23190192878246307, 0.060554418712854385, 0.05880776792764664, 0.00438718032091856, 0.00454165181145072, 0.15464532375335693, 0.09585105627775192, 0.02281157113611698, 0.3664989471435547, 0.0], [0.13208958506584167, 0.15704363584518433, 0.07176639884710312, 0.08554346114397049, 0.0733223557472229, 0.0956358015537262, 0.07472448796033859, 0.09573546797037125, 0.21413877606391907, 0.0], [0.28021925687789917, 0.045415956526994705, 0.04812552034854889, 0.00880114920437336, 0.012029618956148624, 0.04001859948039055, 0.0577121265232563, 0.02487611398100853, 0.4828015863895416, 0.0]], 
[[0.057871319353580475, 0.09845025092363358, 0.03600643575191498, 0.06401734054088593, 0.07263048738241196, 0.014885936863720417, 0.07473781704902649, 0.1193607747554779, 0.4620397090911865, 0.0], [0.04736582189798355, 0.06951819360256195, 0.039210546761751175, 0.040616557002067566, 0.05645532160997391, 0.01900673843920231, 0.063181072473526, 0.23291724920272827, 0.43172842264175415, 0.0], [0.05375257506966591, 0.04091374948620796, 0.01263821218162775, 0.04125160351395607, 0.014244006015360355, 0.012229752726852894, 0.029117466881871223, 0.07314542680978775, 0.7227071523666382, 0.0], [0.04062311723828316, 0.2312910407781601, 0.20085060596466064, 0.03848586603999138, 0.04763459786772728, 0.013425372540950775, 0.027237186208367348, 0.03882591798901558, 0.3616262376308441, 0.0], [0.04558584839105606, 0.04037311673164368, 0.043737076222896576, 0.02740027941763401, 0.005366531666368246, 0.014126299880445004, 0.07268305867910385, 0.014923120848834515, 0.7358046770095825, 0.0], [0.04308110475540161, 0.037618398666381836, 0.054927192628383636, 0.045146394520998, 0.02157701551914215, 0.014024189673364162, 0.03546718508005142, 0.04130468890070915, 0.7068538665771484, 0.0], [0.049693867564201355, 0.040712278336286545, 0.011129319667816162, 0.08677691221237183, 0.24132831394672394, 0.028864668682217598, 0.04710082337260246, 0.028962818905711174, 0.46543097496032715, 0.0], [0.0256647989153862, 0.026453843340277672, 0.1064542606472969, 0.07867259532213211, 0.03285365179181099, 0.056291256099939346, 0.026517342776060104, 0.014768523164093494, 0.6323237419128418, 0.0], [0.07950135320425034, 0.04906205087900162, 0.09099037200212479, 0.10450085997581482, 0.06846266984939575, 0.21755923330783844, 0.14818403124809265, 0.14456483721733093, 0.0971745029091835, 0.0], [0.0029191188514232635, 0.0026039828080683947, 0.000987313687801361, 0.027727283537387848, 0.007311245426535606, 0.0033244043588638306, 0.023969389498233795, 0.00596341909840703, 0.9251939058303833, 0.0]], 
[[0.3116825819015503, 0.20666195452213287, 0.1363646388053894, 0.07141851633787155, 0.029045483097434044, 0.04730900749564171, 0.037391580641269684, 0.10128220915794373, 0.05884409323334694, 0.0], [0.16883184015750885, 0.1785622239112854, 0.23318006098270416, 0.05258537083864212, 0.10740725696086884, 0.09185276180505753, 0.022670285776257515, 0.08943870663642883, 0.05547139048576355, 0.0], [0.13844002783298492, 0.030681077390909195, 0.12107612937688828, 0.007168712094426155, 0.05214103311300278, 0.10045275092124939, 0.006991118658334017, 0.2955043315887451, 0.2475447952747345, 0.0], [0.17306901514530182, 0.19963578879833221, 0.148520827293396, 0.2046978771686554, 0.06720028817653656, 0.019652126356959343, 0.05792365223169327, 0.07665500044822693, 0.05264541134238243, 0.0], [0.15115797519683838, 0.06555280834436417, 0.02683498151600361, 0.20794028043746948, 0.17434173822402954, 0.11980342864990234, 0.04239796847105026, 0.03961418569087982, 0.1723567098379135, 0.0], [0.022970011457800865, 0.003322059055790305, 0.03333919122815132, 0.03161727264523506, 0.3007935583591461, 0.20675496757030487, 0.037206847220659256, 0.30415406823158264, 0.059842076152563095, 0.0], [0.029802093282341957, 0.006723356898874044, 0.00844306219369173, 0.09911961853504181, 0.2257867157459259, 0.22737178206443787, 0.28318148851394653, 0.05687837302684784, 0.0626935064792633, 0.0], [0.008192392997443676, 0.002076511736959219, 0.010627061128616333, 0.01573592983186245, 0.01893553137779236, 0.042316026985645294, 0.02403445728123188, 0.868257999420166, 0.009823988191783428, 0.0], [0.08033955097198486, 0.18780502676963806, 0.02817567251622677, 0.041370414197444916, 0.02824225462973118, 0.038844767957925797, 0.11732563376426697, 0.06162348762154579, 0.4162730574607849, 0.0], [0.0563269779086113, 0.007248359732329845, 0.008029782213270664, 0.003040844574570656, 0.007221699226647615, 0.01730128563940525, 0.050128430128097534, 0.3587413430213928, 0.491961270570755, 0.0]]]], \"top_text\": [\"It\", 
\"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\"], \"bot_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\"]}, \"inp_out\": {\"att\": [[[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9198169708251953, 0.0801829993724823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8846490979194641, 0.10308036208152771, 0.012270578183233738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9307316541671753, 0.03309628367424011, 0.027538668364286423, 0.008633385412395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9335180521011353, 0.020782457664608955, 0.008113296702504158, 0.029529055580496788, 0.008057110011577606, 0.0, 0.0, 0.0, 0.0, 0.0], [0.923790454864502, 0.01269624661654234, 0.004588128533214331, 0.020286502316594124, 0.018672045320272446, 0.019966628402471542, 0.0, 0.0, 0.0, 0.0], [0.5214514136314392, 0.051599469035863876, 0.007387364283204079, 0.04305899888277054, 0.0632161945104599, 0.07775087654590607, 0.2355356514453888, 0.0, 0.0, 0.0], [0.9122877717018127, 0.007671441417187452, 0.0012418286642059684, 0.005250561982393265, 0.001960531808435917, 0.032091617584228516, 0.03012256510555744, 0.009373520500957966, 0.0, 0.0], [0.012450892478227615, 0.0001350480888504535, 0.0001820741599658504, 0.0018266986589878798, 0.00022605709091294557, 0.0032795630395412445, 0.005876350682228804, 0.012136856094002724, 0.9638864398002625, 0.0], [0.907938539981842, 0.003707215888425708, 0.003004483412951231, 0.0008324749651364982, 0.0015859504928812385, 0.008079104125499725, 0.010460118763148785, 0.005838368553668261, 0.038938846439123154, 0.019614921882748604]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4050312936306, 0.5949686765670776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2333158701658249, 0.39531010389328003, 0.37137407064437866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.52278733253479, 0.11893566697835922, 0.28584957122802734, 0.07242746651172638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23179638385772705, 
0.09258762001991272, 0.103512242436409, 0.19472002983093262, 0.37738385796546936, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3839746117591858, 0.05338669568300247, 0.09416119009256363, 0.09689370542764664, 0.24871283769607544, 0.12287086993455887, 0.0, 0.0, 0.0, 0.0], [0.5838866233825684, 0.02439245954155922, 0.042716383934020996, 0.03342103213071823, 0.08018141984939575, 0.15234005451202393, 0.08306187391281128, 0.0, 0.0, 0.0], [0.639571487903595, 0.016348807141184807, 0.038869310170412064, 0.02800355665385723, 0.0377902127802372, 0.0529697984457016, 0.07620508968830109, 0.11024164408445358, 0.0, 0.0], [0.5836893320083618, 0.011862898245453835, 0.02550557814538479, 0.009363977238535881, 0.0196645837277174, 0.018125057220458984, 0.07040998339653015, 0.2077602595090866, 0.053618304431438446, 0.0], [0.49946048855781555, 0.04904361814260483, 0.04135226085782051, 0.015084759332239628, 0.018269173800945282, 0.020069265738129616, 0.05080949887633324, 0.09452320635318756, 0.06869905441999435, 0.14268863201141357]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9956012964248657, 0.00439875153824687, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8920916318893433, 0.017498359084129333, 0.09041006118059158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8103601336479187, 0.011479738168418407, 0.14884205162525177, 0.029318034648895264, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9073429107666016, 0.017702236771583557, 0.0008831396116875112, 0.017153160646557808, 0.05691858008503914, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7007134556770325, 0.00013011474220547825, 0.0017889889422804117, 0.00429273396730423, 0.20973503589630127, 0.08333952724933624, 0.0, 0.0, 0.0, 0.0], [0.8020992279052734, 0.0005838978104293346, 0.0002877263759728521, 0.000665249943267554, 0.00924165453761816, 0.10947777330875397, 0.07764454185962677, 0.0, 0.0, 0.0], [0.936653733253479, 0.00026242269086651504, 0.0004762547614518553, 0.000683068297803402, 0.0005867508007213473, 0.008624686859548092, 0.044821251183748245, 0.00789186917245388, 0.0, 
0.0], [0.638530433177948, 0.00012756754586007446, 2.6267471184837632e-05, 0.035790614783763885, 0.00038457714254036546, 0.0026843701489269733, 0.0740678533911705, 0.21536435186862946, 0.03302408382296562, 0.0], [0.9069857597351074, 0.0010905838571488857, 0.0003166680980939418, 0.0021527763456106186, 0.00019805191550403833, 0.0004849489778280258, 0.025774035602808, 0.02642407827079296, 0.01662513054907322, 0.01994791068136692]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9964158535003662, 0.0035840808413922787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.603236198425293, 0.29069802165031433, 0.10606581717729568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7401933073997498, 0.005742713809013367, 0.18690980970859528, 0.06715414673089981, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9087624549865723, 0.0078224902972579, 0.003505129599943757, 0.0673881471157074, 0.012521738186478615, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7394620180130005, 0.0234938096255064, 0.009907918982207775, 0.01616108976304531, 0.1237591803073883, 0.08721596747636795, 0.0, 0.0, 0.0, 0.0], [0.9526587724685669, 0.007287254091352224, 0.0013716809917241335, 0.0023222684394568205, 0.007607423700392246, 0.009167732670903206, 0.01958492584526539, 0.0, 0.0, 0.0], [0.9270981550216675, 0.004809631034731865, 0.0030887839384377003, 0.005205564666539431, 0.018441975116729736, 0.006030889227986336, 0.03003735840320587, 0.0052877976559102535, 0.0, 0.0], [0.603268563747406, 0.009098237380385399, 0.00021995518181938678, 0.07179546356201172, 0.0017328117974102497, 0.01055157370865345, 0.020978767424821854, 0.2736198902130127, 0.008734744042158127, 0.0], [0.6497007608413696, 0.0906025841832161, 0.0100435521453619, 0.007925360463559628, 0.013416239991784096, 0.0018666544929146767, 0.02140365168452263, 0.08128199726343155, 0.04188578948378563, 0.08187359571456909]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9857779741287231, 0.014221975579857826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9197340607643127, 
0.07413885742425919, 0.0061270855367183685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8673564195632935, 0.016403868794441223, 0.1017053872346878, 0.014534366317093372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044595908373594284, 0.010755550116300583, 0.002565854461863637, 0.9345642328262329, 0.007518457714468241, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4605148434638977, 0.007289387751370668, 0.009601963683962822, 0.08598940074443817, 0.4091304838657379, 0.027473902329802513, 0.0, 0.0, 0.0, 0.0], [0.8714936971664429, 0.002528996206820011, 0.0021269593853503466, 0.0052809687331318855, 0.02593054249882698, 0.07010670751333237, 0.022532090544700623, 0.0, 0.0, 0.0], [0.507957398891449, 0.003823956474661827, 0.004157013725489378, 0.018131878226995468, 0.06916838884353638, 0.047881923615932465, 0.2798653542995453, 0.06901402771472931, 0.0, 0.0], [0.4575899839401245, 0.005646431352943182, 0.0004441867640707642, 0.03129462152719498, 0.014414624311029911, 0.0058625745587050915, 0.09207130968570709, 0.34311652183532715, 0.04955975338816643, 0.0], [0.8105311393737793, 0.0010255038505420089, 0.0001402802881784737, 0.0005781117943115532, 0.00122542935423553, 0.000594198820181191, 0.02804729714989662, 0.01081023644655943, 0.13665232062339783, 0.010395429097115993]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8512031435966492, 0.14879685640335083, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10041537135839462, 0.8953256011009216, 0.0042589944787323475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6295948624610901, 0.2121732085943222, 0.10306572169065475, 0.055166181176900864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9503376483917236, 0.007425909396260977, 0.0019253676291555166, 0.025024304166436195, 0.015286784619092941, 0.0, 0.0, 0.0, 0.0, 0.0], [0.24298420548439026, 0.06981680542230606, 0.030552756041288376, 0.020666545256972313, 0.46177101135253906, 0.1742086559534073, 0.0, 0.0, 0.0, 0.0], [0.8132306933403015, 0.003601218806579709, 0.01019350253045559, 0.009439423680305481, 0.040081463754177094, 
0.07570415735244751, 0.04774952307343483, 0.0, 0.0, 0.0], [0.6454712152481079, 0.006356438156217337, 0.006696825381368399, 0.0020169378258287907, 0.11416922509670258, 0.11139311641454697, 0.07912010699510574, 0.03477614000439644, 0.0, 0.0], [0.22032444179058075, 0.0006508066435344517, 0.006827942095696926, 0.028858821839094162, 0.0022757677361369133, 0.006474251858890057, 0.09447979182004929, 0.6212162375450134, 0.018891895189881325, 0.0], [0.03250038996338844, 0.0005526043241843581, 2.807211239996832e-05, 0.00014761221245862544, 0.00482193985953927, 7.781770545989275e-05, 0.00014718669990543276, 0.0008632297394797206, 0.959712028503418, 0.0011490467004477978]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9700191020965576, 0.029980869963765144, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7072298526763916, 0.2173422873020172, 0.07542789727449417, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5017270445823669, 0.10517530888319016, 0.32087045907974243, 0.07222715020179749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39005738496780396, 0.2261916995048523, 0.1838584840297699, 0.10916081070899963, 0.09073163568973541, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11122927069664001, 0.04386316239833832, 0.023478534072637558, 0.07375308126211166, 0.5692906379699707, 0.17838534712791443, 0.0, 0.0, 0.0, 0.0], [0.16762810945510864, 0.030268238857388496, 0.015392551198601723, 0.05242612585425377, 0.21519990265369415, 0.34948840737342834, 0.16959665715694427, 0.0, 0.0, 0.0], [0.15348000824451447, 0.03554287180304527, 0.008979924954473972, 0.07115276902914047, 0.08698276430368423, 0.24143245816230774, 0.28553345799446106, 0.11689584702253342, 0.0, 0.0], [0.09456975758075714, 0.010759694501757622, 0.0067994119599461555, 0.01042863354086876, 0.05627141892910004, 0.11228546500205994, 0.14361944794654846, 0.3204572796821594, 0.2448090761899948, 0.0], [0.057867951691150665, 0.02229062095284462, 0.016399098560214043, 0.02521427348256111, 0.047808028757572174, 0.03428687900304794, 0.05170976370573044, 
0.19979508221149445, 0.41991233825683594, 0.12471600621938705]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9535994529724121, 0.04640045389533043, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8665578961372375, 0.09402694553136826, 0.03941517323255539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8201385140419006, 0.07587680220603943, 0.05075912922620773, 0.053225547075271606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6245242953300476, 0.093341164290905, 0.11281723529100418, 0.1092497780919075, 0.06006752699613571, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5755861401557922, 0.0864969864487648, 0.10001320391893387, 0.12654373049736023, 0.06871193647384644, 0.04264802858233452, 0.0, 0.0, 0.0, 0.0], [0.6500274538993835, 0.06470640748739243, 0.047299426048994064, 0.08855419605970383, 0.06197808310389519, 0.04487667977809906, 0.04255769029259682, 0.0, 0.0, 0.0], [0.5771223902702332, 0.0491044707596302, 0.09411156177520752, 0.06903567165136337, 0.04109871760010719, 0.06523709744215012, 0.06637011468410492, 0.03792000934481621, 0.0, 0.0], [0.4695849120616913, 0.017787985503673553, 0.06290572881698608, 0.06516575813293457, 0.09894091635942459, 0.03647425398230553, 0.051347069442272186, 0.08907806128263474, 0.10871540009975433, 0.0], [0.18501408398151398, 0.040740884840488434, 0.10466982424259186, 0.07660976052284241, 0.17033715546131134, 0.05819392204284668, 0.0898737907409668, 0.09184892475605011, 0.10470453649759293, 0.0780070349574089]]], [[[0.10875418037176132, 0.15107707679271698, 0.07560893893241882, 0.11182637512683868, 0.051575273275375366, 0.1800614595413208, 0.13901139795780182, 0.11257244646549225, 0.06951297074556351, 0.0], [0.04530828073620796, 0.11530135571956635, 0.03132164478302002, 0.12301183491945267, 0.01339547149837017, 0.009322633035480976, 0.0069213854148983955, 0.181557297706604, 0.47386014461517334, 0.0], [0.08671615272760391, 0.21926835179328918, 0.11249969899654388, 0.05250205472111702, 0.044286634773015976, 0.006910341326147318, 0.004434189759194851, 
0.00961831770837307, 0.4637643098831177, 0.0], [0.016148164868354797, 0.08668603748083115, 0.1414848268032074, 0.024200299754738808, 0.018711188808083534, 0.02537006139755249, 0.017450006678700447, 0.039331331849098206, 0.6306182146072388, 0.0], [0.024489276111125946, 0.03301851078867912, 0.03003605268895626, 0.03562680631875992, 0.06981870532035828, 0.022592445835471153, 0.025447512045502663, 0.03545365110039711, 0.7235170006752014, 0.0], [0.05760658532381058, 0.08793947100639343, 0.053903114050626755, 0.0679689273238182, 0.007038408424705267, 0.007889931090176105, 0.010035911574959755, 0.019540006294846535, 0.6880777478218079, 0.0], [0.045610494911670685, 0.042210742831230164, 0.14248158037662506, 0.03233090415596962, 0.03048519603908062, 0.011738738045096397, 0.014284060336649418, 0.006383211817592382, 0.6744750738143921, 0.0], [0.096277616918087, 0.030696624889969826, 0.10220203548669815, 0.04915016517043114, 0.047845132648944855, 0.05814794450998306, 0.06954183429479599, 0.028650736436247826, 0.5174878835678101, 0.0], [0.009306053631007671, 0.02153283730149269, 0.009718294255435467, 0.005953253246843815, 0.011703923344612122, 0.017902903258800507, 0.011090915650129318, 0.01645584963262081, 0.8963360786437988, 0.0], [0.009895006194710732, 0.026821313425898552, 0.16079027950763702, 0.01761648990213871, 0.01726638339459896, 0.08361288905143738, 0.039622098207473755, 0.14411716163158417, 0.5002583861351013, 0.0]], [[0.0543275885283947, 0.01742306910455227, 0.05347726121544838, 0.18824619054794312, 0.09003543108701706, 0.08433128148317337, 0.1953076422214508, 0.206686869263649, 0.11016455292701721, 0.0], [0.00859006680548191, 0.02184058353304863, 0.02418440766632557, 0.03131486475467682, 0.03273439407348633, 0.06774082779884338, 0.1731010377407074, 0.09275981038808823, 0.5477339029312134, 0.0], [0.02145911566913128, 0.046526145190000534, 0.014734850265085697, 0.026213468983769417, 0.04904777929186821, 0.08567024767398834, 0.13810616731643677, 0.03392839804291725, 
0.5843138694763184, 0.0], [0.019245177507400513, 0.01515401341021061, 0.027409562841057777, 0.0068243746645748615, 0.07997982203960419, 0.0921224057674408, 0.04510754346847534, 0.04373685643076897, 0.670420229434967, 0.0], [0.04381020739674568, 0.06711422652006149, 0.07609888166189194, 0.021496189758181572, 0.05042967572808266, 0.15614424645900726, 0.11071597784757614, 0.14296749234199524, 0.3312230408191681, 0.0], [0.04100082442164421, 0.030313873663544655, 0.032653506845235825, 0.0695231482386589, 0.12672685086727142, 0.12515434622764587, 0.08855390548706055, 0.05835743993520737, 0.4277162253856659, 0.0], [0.14112897217273712, 0.06592341512441635, 0.06986766308546066, 0.06311382353305817, 0.12678426504135132, 0.04950721934437752, 0.08025017380714417, 0.03467738255858421, 0.36874714493751526, 0.0], [0.02841436117887497, 0.022568009793758392, 0.014519155025482178, 0.019271234050393105, 0.018120555207133293, 0.036434635519981384, 0.014109926298260689, 0.24622198939323425, 0.6003400683403015, 0.0], [0.05730762332677841, 0.07724729180335999, 0.030861826613545418, 0.04063780978322029, 0.08539344370365143, 0.029541905969381332, 0.02964094467461109, 0.028206804767251015, 0.6211622953414917, 0.0], [0.20915710926055908, 0.193747878074646, 0.11181499063968658, 0.07680925726890564, 0.04479793831706047, 0.03787367418408394, 0.04819086939096451, 0.11330965161323547, 0.1642986238002777, 0.0]], [[0.038908280432224274, 0.07760688662528992, 0.062413811683654785, 0.0023113787174224854, 0.0021746077109128237, 0.015095214359462261, 0.003646473865956068, 0.038165315985679626, 0.759678065776825, 0.0], [0.015742339193820953, 0.029524141922593117, 0.0550379604101181, 0.16926467418670654, 0.035933610051870346, 0.03279981389641762, 0.03188418969511986, 0.5383173227310181, 0.09149592369794846, 0.0], [0.022741766646504402, 0.013864121399819851, 0.06161126494407654, 0.06985131651163101, 0.03954875469207764, 0.02864447981119156, 0.036658816039562225, 0.05774570629000664, 0.6693336963653564, 
0.0], [0.06077639013528824, 0.053226571530103683, 0.05544588342308998, 0.08368532359600067, 0.04779139161109924, 0.028960514813661575, 0.03463221713900566, 0.42419588565826416, 0.21128588914871216, 0.0], [0.03320460394024849, 0.07872876524925232, 0.0791814923286438, 0.008506255224347115, 0.010383618995547295, 0.021636927500367165, 0.009444555267691612, 0.026183925569057465, 0.7327298521995544, 0.0], [0.14095324277877808, 0.17195045948028564, 0.04960065335035324, 0.02801741287112236, 0.02789357118308544, 0.0246508177369833, 0.027228642255067825, 0.008449538610875607, 0.521255612373352, 0.0], [0.01678302139043808, 0.02193976752460003, 0.13912786543369293, 0.05168221518397331, 0.06239692494273186, 0.008615943603217602, 0.037501659244298935, 0.02482585795223713, 0.6371266841888428, 0.0], [0.03396642208099365, 0.07778684049844742, 0.18657010793685913, 0.11281172931194305, 0.019890569150447845, 0.012303605675697327, 0.0494060292840004, 0.11448060721158981, 0.39278414845466614, 0.0], [0.02684134803712368, 0.03310805931687355, 0.163743257522583, 0.014529252424836159, 0.10077258199453354, 0.044357266277074814, 0.04152251034975052, 0.10173188894987106, 0.4733937382698059, 0.0], [0.01862592063844204, 0.022009190171957016, 0.028925148770213127, 0.006837732624262571, 0.006956242956221104, 0.010202805511653423, 0.015325144864618778, 0.11640346795320511, 0.7747144103050232, 0.0]], [[0.0830092504620552, 0.0839436799287796, 0.10106679797172546, 0.11154499650001526, 0.045070260763168335, 0.1284436285495758, 0.1161414161324501, 0.19574469327926636, 0.1350351870059967, 0.0], [0.0006529411766678095, 0.0018492193194106221, 0.018439743667840958, 0.004895282443612814, 0.0036929987836629152, 0.05041775107383728, 0.03271673619747162, 0.4425412714481354, 0.4447941780090332, 0.0], [0.015919672325253487, 0.02172437310218811, 0.013682822696864605, 0.028371846303343773, 0.017258556559681892, 0.014516759663820267, 0.033475372940301895, 0.45419326424598694, 0.40085726976394653, 0.0], 
[0.006064589135348797, 0.006147248670458794, 0.06902536749839783, 0.011021673679351807, 0.0062199062667787075, 0.17622654139995575, 0.00982236210256815, 0.46262383460998535, 0.25284844636917114, 0.0], [0.018328940495848656, 0.034908927977085114, 0.027539005503058434, 0.04494883120059967, 0.03695090860128403, 0.18224696815013885, 0.04204700142145157, 0.09570277482271194, 0.5173265337944031, 0.0], [0.06838149577379227, 0.025893883779644966, 0.06412170827388763, 0.11039282381534576, 0.12848982214927673, 0.09953469038009644, 0.09056522697210312, 0.12723064422607422, 0.28538966178894043, 0.0], [0.07893572002649307, 0.0734885111451149, 0.06503137946128845, 0.04291535168886185, 0.08502060174942017, 0.04846649244427681, 0.07035838067531586, 0.14812934398651123, 0.38765427470207214, 0.0], [0.007445929106324911, 0.004103729501366615, 0.05411284416913986, 0.006074799690395594, 0.07146289199590683, 0.5494692921638489, 0.05009504780173302, 0.058794084936380386, 0.1984413117170334, 0.0], [0.0037151367869228125, 0.005083263851702213, 0.02171880006790161, 0.01245985459536314, 0.012914983555674553, 0.14437292516231537, 0.026943473145365715, 0.17420484125614166, 0.5985866785049438, 0.0], [0.02579679898917675, 0.0645768865942955, 0.03225725144147873, 0.044467855244874954, 0.04297630116343498, 0.06060377135872841, 0.030930038541555405, 0.03278812766075134, 0.6656030416488647, 0.0]], [[0.13460709154605865, 0.15298102796077728, 0.06546170264482498, 0.14220191538333893, 0.11837887763977051, 0.09888823330402374, 0.10630416870117188, 0.08867054432630539, 0.09250646829605103, 0.0], [0.9316296577453613, 0.016095036640763283, 0.0020372711587697268, 0.0019596514757722616, 2.8437656510504894e-05, 6.708989531034604e-05, 0.0004955903859809041, 3.0113247703411616e-05, 0.047657083719968796, 0.0], [0.043201129883527756, 0.9419298768043518, 0.0003410913050174713, 0.003313146298751235, 7.506452675443143e-06, 1.9570916265365668e-05, 2.5470235414104536e-05, 2.1080213628010824e-05, 0.011141069233417511, 
0.0], [3.7581870856229216e-05, 0.00022979748609941453, 0.9982534646987915, 8.70372386998497e-05, 5.87535805607331e-06, 2.5239218302886002e-05, 6.597588708245894e-06, 2.193619138779468e-06, 0.001352491439320147, 0.0], [0.0019612079486250877, 0.011641290038824081, 0.010358362458646297, 0.8346317410469055, 0.00641160923987627, 0.0007435380248352885, 0.0018172020791098475, 7.255822129081935e-05, 0.1323624849319458, 0.0], [4.077299308846705e-05, 0.00016088274423964322, 3.1180113637674367e-06, 5.9685276937671006e-05, 6.661444786004722e-06, 0.0006764131248928607, 5.4107837058836594e-05, 0.9797272086143494, 0.01927126571536064, 0.0], [2.7792530090664513e-06, 1.1777839063142892e-05, 1.0386434951215051e-05, 0.0006807934259995818, 0.00028749846387654543, 0.9563493728637695, 2.4335316993528977e-05, 0.001297356327995658, 0.041335828602313995, 0.0], [0.00033864984288811684, 0.00016234541544690728, 0.00011107163300039247, 7.639558316441253e-05, 9.851753566181287e-05, 0.00046863980242051184, 0.9855522513389587, 0.00012009339843643829, 0.013071970082819462, 0.0], [0.001446103909984231, 0.0026176422834396362, 0.0005430445889942348, 0.5833504796028137, 0.08298782259225845, 0.01277364045381546, 0.008405186235904694, 0.028461067005991936, 0.2794148921966553, 0.0], [8.301706202473724e-07, 1.612889263924444e-06, 3.859615389956161e-06, 0.0015496612759307027, 0.9884966611862183, 0.0003321043332107365, 1.1829011782538146e-05, 3.7258676002238644e-06, 0.00959983840584755, 0.0]], [[0.03624086081981659, 0.008591840974986553, 0.01890810765326023, 0.010947922244668007, 0.5211313366889954, 0.04890615865588188, 0.13394898176193237, 0.08554741740226746, 0.13577744364738464, 0.0], [0.09101090580224991, 0.15663929283618927, 0.2008313536643982, 0.13744188845157623, 0.16349081695079803, 0.01479706447571516, 0.04576689749956131, 0.05515507981181145, 0.1348666250705719, 0.0], [0.10898119956254959, 0.19741322100162506, 0.12774543464183807, 0.07097428292036057, 0.033309608697891235, 0.016726871952414513, 
0.019306309521198273, 0.09155051410198212, 0.3339925706386566, 0.0], [0.051247891038656235, 0.06952031701803207, 0.3243081271648407, 0.04820195212960243, 0.05462171137332916, 0.04280935227870941, 0.03801479935646057, 0.07710513472557068, 0.2941707372665405, 0.0], [0.22540897130966187, 0.04426601901650429, 0.13483746349811554, 0.09052211791276932, 0.036632657051086426, 0.06078784167766571, 0.09962243586778641, 0.04597063735127449, 0.2619517743587494, 0.0], [0.08315062522888184, 0.10649015009403229, 0.15254046022891998, 0.0728936716914177, 0.10388997197151184, 0.04998103529214859, 0.0675109326839447, 0.17524446547031403, 0.18829864263534546, 0.0], [0.09407053142786026, 0.04335644096136093, 0.04757237061858177, 0.023308007046580315, 0.14141318202018738, 0.017728488892316818, 0.02331509254872799, 0.07266414165496826, 0.5365718007087708, 0.0], [0.08477651327848434, 0.026448125019669533, 0.013684368692338467, 0.1331702470779419, 0.16824185848236084, 0.007634431589394808, 0.025501158088445663, 0.035930439829826355, 0.5046128630638123, 0.0], [0.03296202793717384, 0.01823815330862999, 0.025750160217285156, 0.08325016498565674, 0.1596710979938507, 0.010502922348678112, 0.01792057603597641, 0.05097610503435135, 0.6007286906242371, 0.0], [0.04370357468724251, 0.02250431850552559, 0.016271278262138367, 0.019842427223920822, 0.12028838694095612, 0.03933797404170036, 0.043740611523389816, 0.08045370131731033, 0.6138576865196228, 0.0]], [[0.1783323585987091, 0.3813028037548065, 0.2072289139032364, 0.06766574084758759, 0.053963109850883484, 0.030795719474554062, 0.023536406457424164, 0.03921645134687424, 0.01795845478773117, 0.0], [0.8837893009185791, 0.07202983647584915, 0.03646722435951233, 0.0004511935112532228, 0.0007272462244145572, 0.0008432198665104806, 0.0031319037079811096, 0.0004143840924371034, 0.0021455709356814623, 0.0], [0.3973897695541382, 0.14911939203739166, 0.3486334979534149, 0.012645252980291843, 0.00675938231870532, 0.00483374297618866, 0.010028100572526455, 
0.012036854401230812, 0.058554183691740036, 0.0], [0.005409032106399536, 0.005906772334128618, 0.13379110395908356, 0.15247586369514465, 0.06559418141841888, 0.15356750786304474, 0.04085409641265869, 0.029147597029805183, 0.41325387358665466, 0.0], [0.0013326199259608984, 0.0014979635598137975, 0.011986319907009602, 0.7730216383934021, 0.06901827454566956, 0.05895080044865608, 0.016383536159992218, 0.015771687030792236, 0.052037257701158524, 0.0], [0.0012038598069921136, 0.0033955213148146868, 0.025528373196721077, 0.03136582672595978, 0.10901585966348648, 0.3851255178451538, 0.0182026457041502, 0.13982580602169037, 0.2863365411758423, 0.0], [0.008065885864198208, 0.004362722393125296, 0.06363680213689804, 0.023311397060751915, 0.06106392294168472, 0.1357712298631668, 0.03965916484594345, 0.06073852628469467, 0.6033903956413269, 0.0], [0.0003142715140711516, 0.0005578870768658817, 0.0015481057344004512, 0.0887022390961647, 0.06383900344371796, 0.2639910578727722, 0.049384135752916336, 0.12241825461387634, 0.40924492478370667, 0.0], [0.0003916181158274412, 0.0003099135938100517, 0.0024421222042292356, 0.016801349818706512, 0.18835966289043427, 0.025843605399131775, 0.08458039909601212, 0.20884136855602264, 0.4724300503730774, 0.0], [5.865378989255987e-05, 7.253760122694075e-05, 0.0007906460668891668, 0.025103986263275146, 0.0753612071275711, 0.04038592055439949, 0.011871143244206905, 0.05808362737298012, 0.7882723212242126, 0.0]], [[0.01597539149224758, 0.027860743924975395, 0.08824922889471054, 0.011547067202627659, 0.02896539680659771, 0.03845160827040672, 0.011409634724259377, 0.043791815638542175, 0.7337491512298584, 0.0], [0.0371943861246109, 0.014876782894134521, 0.02253115549683571, 0.10164438933134079, 0.029471710324287415, 0.040005166083574295, 0.020577073097229004, 0.07326765358448029, 0.6604316830635071, 0.0], [0.06676606088876724, 0.1320837438106537, 0.02368331328034401, 0.09289334714412689, 0.06407851725816727, 0.007657648529857397, 
0.014540987089276314, 0.018603011965751648, 0.5796933174133301, 0.0], [0.029496638104319572, 0.013616771437227726, 0.030488401651382446, 0.021259615197777748, 0.13049498200416565, 0.06418323516845703, 0.050123173743486404, 0.1609034240245819, 0.4994336664676666, 0.0], [0.010230573825538158, 0.015954630449414253, 0.007779641076922417, 0.018425902351737022, 0.021085364744067192, 0.0588817335665226, 0.013979516923427582, 0.0252523310482502, 0.828410267829895, 0.0], [0.02648993395268917, 0.0214377511292696, 0.03494586795568466, 0.05471349507570267, 0.09140968322753906, 0.04952282831072807, 0.05564551055431366, 0.11169540882110596, 0.5541394948959351, 0.0], [0.03231878578662872, 0.018621357157826424, 0.05183127149939537, 0.03979233279824257, 0.13804322481155396, 0.03567919135093689, 0.047386858612298965, 0.13114488124847412, 0.505182147026062, 0.0], [0.04592716693878174, 0.010993612930178642, 0.01772226020693779, 0.05332585424184799, 0.15264220535755157, 0.22139224410057068, 0.048004403710365295, 0.12396018952131271, 0.3260320723056793, 0.0], [0.03168570622801781, 0.026294516399502754, 0.025469979271292686, 0.03026771917939186, 0.058515094220638275, 0.13361068069934845, 0.026259208098053932, 0.0612059161067009, 0.6066910624504089, 0.0], [0.07492455840110779, 0.06428299844264984, 0.07022737711668015, 0.0507473424077034, 0.0447908453643322, 0.060839906334877014, 0.14463475346565247, 0.054812539368867874, 0.4347396492958069, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9642227292060852, 0.035777393728494644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9523521065711975, 0.027811188250780106, 0.019836684688925743, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.849480152130127, 0.03536543622612953, 0.019422976300120354, 0.09573143720626831, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.741925060749054, 0.05566684901714325, 0.024736514315009117, 0.08595114946365356, 0.09172046929597855, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6503966450691223, 0.0582728385925293, 0.0236701387912035, 
0.0691222995519638, 0.0758395791053772, 0.12269847840070724, 0.0, 0.0, 0.0, 0.0], [0.4914315342903137, 0.11739180237054825, 0.02309434488415718, 0.07889512181282043, 0.05101678892970085, 0.12367808818817139, 0.11449223756790161, 0.0, 0.0, 0.0], [0.4262734055519104, 0.07066749036312103, 0.024391667917370796, 0.04879573732614517, 0.051445234566926956, 0.1276569813489914, 0.11843930184841156, 0.13233007490634918, 0.0, 0.0], [0.589878499507904, 0.026613032445311546, 0.020459800958633423, 0.028271155431866646, 0.03679497539997101, 0.07860217243432999, 0.08500825613737106, 0.09285575151443481, 0.04151623696088791, 0.0], [0.2743179202079773, 0.06089583784341812, 0.03565794974565506, 0.044920988380908966, 0.03933599591255188, 0.18495218455791473, 0.09192009270191193, 0.13160176575183868, 0.04121606424450874, 0.09518115967512131]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9842625260353088, 0.015737490728497505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8382691144943237, 0.11647694557905197, 0.04525385797023773, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4638526439666748, 0.1585947573184967, 0.3189436197280884, 0.0586090050637722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2375488132238388, 0.07284080982208252, 0.20766110718250275, 0.3110494017601013, 0.1708998829126358, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20615516602993011, 0.03705071657896042, 0.05929475650191307, 0.08692343533039093, 0.5564662218093872, 0.05410974845290184, 0.0, 0.0, 0.0, 0.0], [0.31913095712661743, 0.011343744583427906, 0.01675090566277504, 0.013238506391644478, 0.06746862828731537, 0.3789318799972534, 0.19313538074493408, 0.0, 0.0, 0.0], [0.4113273322582245, 0.003934106323868036, 0.003564919577911496, 0.005882325116544962, 0.018547017127275467, 0.18534934520721436, 0.3216978907585144, 0.04969710111618042, 0.0, 0.0], [0.07648876309394836, 0.0013769177021458745, 0.001890459912829101, 0.006597061175853014, 0.007926206104457378, 0.013261871412396431, 0.15683594346046448, 0.7190074324607849, 
0.016615279018878937, 0.0], [0.08104224503040314, 0.00045554721145890653, 0.00038501128437928855, 0.0009405335295014083, 0.005597654264420271, 0.0034990713465958834, 0.009850292466580868, 0.0463707260787487, 0.7366765141487122, 0.11518235504627228]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9800853133201599, 0.019914645701646805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9159882068634033, 0.02969631738960743, 0.05431551858782768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6467475295066833, 0.08892705291509628, 0.19796258211135864, 0.06636285036802292, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9833061099052429, 0.004010406322777271, 0.004914217162877321, 0.0015858567785471678, 0.006183335091918707, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9524497389793396, 0.0022862900514155626, 0.000848656112793833, 0.00408557103946805, 0.028177350759506226, 0.012152665294706821, 0.0, 0.0, 0.0, 0.0], [0.1907505989074707, 0.026542214676737785, 0.01945381611585617, 0.029287727549672127, 0.057166602462530136, 0.11766232550144196, 0.5591367483139038, 0.0, 0.0, 0.0], [0.4022328555583954, 0.017193131148815155, 0.01565318927168846, 0.01915702596306801, 0.01739031821489334, 0.16459040343761444, 0.18205313384532928, 0.18172988295555115, 0.0, 0.0], [0.9652498960494995, 0.0010482663055881858, 0.0012260396033525467, 0.0009098293376155198, 0.0013901795027777553, 0.0028189055155962706, 0.007343438919633627, 0.018731823191046715, 0.0012814495712518692, 0.0], [0.18471455574035645, 0.018054824322462082, 0.08812589198350906, 0.00762907462194562, 0.018057269975543022, 0.05247756093740463, 0.03497685119509697, 0.5025416612625122, 0.052323222160339355, 0.04109897091984749]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9911633133888245, 0.008836665190756321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9641951322555542, 0.023474374786019325, 0.012330451980233192, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6152319312095642, 0.28041696548461914, 0.04906271770596504, 0.05528838559985161, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0], [0.6057276725769043, 0.1235719844698906, 0.06170117110013962, 0.11151555925607681, 0.0974835753440857, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6386814713478088, 0.07927443087100983, 0.06004401296377182, 0.06398510187864304, 0.06341437995433807, 0.09460049122571945, 0.0, 0.0, 0.0, 0.0], [0.13321073353290558, 0.0565485954284668, 0.20425985753536224, 0.10307760536670685, 0.17957380414009094, 0.26328328251838684, 0.06004612147808075, 0.0, 0.0, 0.0], [0.19694660604000092, 0.027736904099583626, 0.05790374055504799, 0.10621010512113571, 0.15510229766368866, 0.2214440256357193, 0.18680275976657867, 0.04785352945327759, 0.0, 0.0], [0.08537944406270981, 0.033881768584251404, 0.03968465328216553, 0.08240006119012833, 0.15350975096225739, 0.23219235241413116, 0.22240297496318817, 0.11620921641588211, 0.034339725971221924, 0.0], [0.06051333248615265, 0.012086840346455574, 0.028373999521136284, 0.07542525231838226, 0.10199770331382751, 0.15039192140102386, 0.20426926016807556, 0.16016273200511932, 0.06537677347660065, 0.14140206575393677]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5400503277778625, 0.4599496126174927, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04321815073490143, 0.9357689023017883, 0.02101275697350502, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48035699129104614, 0.12913382053375244, 0.27151036262512207, 0.11899882555007935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6920371055603027, 0.019891848787665367, 0.1885785609483719, 0.06273186951875687, 0.036760613322257996, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8527964949607849, 0.08059625327587128, 0.0037265238352119923, 0.008582950569689274, 0.042790722101926804, 0.01150701567530632, 0.0, 0.0, 0.0, 0.0], [0.900881826877594, 0.012710069306194782, 0.000794807099737227, 0.00424413476139307, 0.02110898308455944, 0.01962616853415966, 0.04063420742750168, 0.0, 0.0, 0.0], [0.713775098323822, 0.003081131726503372, 0.000918463512789458, 0.009338468313217163, 0.013423318043351173, 0.019161174073815346, 0.10174864530563354, 
0.13855360448360443, 0.0, 0.0], [0.4800099730491638, 0.0009553784620948136, 0.00013007478264626116, 0.020002998411655426, 0.0032414987217634916, 0.002101779682561755, 0.028948260471224785, 0.46123453974723816, 0.0033754503820091486, 0.0], [0.7501513361930847, 0.019767694175243378, 0.0020619838032871485, 0.0038300605956465006, 0.0023455689661204815, 0.023803891614079475, 0.011456847190856934, 0.045016106218099594, 0.08813992142677307, 0.05342674255371094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03494315221905708, 0.965056836605072, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.020348060876131058, 0.8944171071052551, 0.08523476868867874, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015979396412149072, 0.6347042918205261, 0.09008561074733734, 0.27361196279525757, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01025437843054533, 0.17247439920902252, 0.3664330542087555, 0.4087805449962616, 0.04205762594938278, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012186901643872261, 0.3028968572616577, 0.12117700278759003, 0.3522109389305115, 0.06255244463682175, 0.14897578954696655, 0.0, 0.0, 0.0, 0.0], [0.010822800919413567, 0.2333739995956421, 0.11113002151250839, 0.15861180424690247, 0.11286703497171402, 0.2766783833503723, 0.0965159684419632, 0.0, 0.0, 0.0], [0.00965114776045084, 0.19982098042964935, 0.054301097989082336, 0.13056904077529907, 0.03828747197985649, 0.4827912747859955, 0.05511533096432686, 0.029463520273566246, 0.0, 0.0], [0.014548483304679394, 0.07520423084497452, 0.1090526208281517, 0.14237697422504425, 0.030428709462285042, 0.5021095275878906, 0.026151562109589577, 0.04390878602862358, 0.05621904134750366, 0.0], [0.000422637298470363, 0.17123113572597504, 0.04347287863492966, 0.10408183932304382, 0.013075248338282108, 0.5476951003074646, 0.020964276045560837, 0.019243689253926277, 0.0612923838198185, 0.018520813435316086]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9947329163551331, 0.005267037078738213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.7284466028213501, 0.21829284727573395, 0.05326057970523834, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7024527192115784, 0.0454108789563179, 0.10381712764501572, 0.14831924438476562, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2374107390642166, 0.04589728266000748, 0.2683154046535492, 0.3902822434902191, 0.0580943301320076, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7228419780731201, 0.007619804237037897, 0.013993922621011734, 0.04429992660880089, 0.020430808886885643, 0.19081364572048187, 0.0, 0.0, 0.0, 0.0], [0.4783930778503418, 0.005506142508238554, 0.008406496606767178, 0.012424511834979057, 0.04335693642497063, 0.17542317509651184, 0.27648961544036865, 0.0, 0.0, 0.0], [0.056768160313367844, 0.001066300319507718, 0.0015203694347292185, 0.004650356248021126, 0.004999558907002211, 0.17368057370185852, 0.7387632131576538, 0.018551528453826904, 0.0, 0.0], [0.14709600806236267, 0.007261540275067091, 0.001291902968659997, 0.012605146504938602, 0.005232691299170256, 0.08098926395177841, 0.5304067134857178, 0.207069993019104, 0.00804678164422512, 0.0], [0.15080930292606354, 0.014301316812634468, 0.002821019385010004, 0.02008463814854622, 0.004475536290556192, 0.05297520384192467, 0.27036672830581665, 0.407105028629303, 0.007729486562311649, 0.06933178007602692]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9945669174194336, 0.005433134268969297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9554939270019531, 0.02177131362259388, 0.0227347444742918, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19059398770332336, 0.7459079623222351, 0.05105874687433243, 0.012439398095011711, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.062025006860494614, 0.7277394533157349, 0.13110491633415222, 0.028790757060050964, 0.050339892506599426, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7678350806236267, 0.007377212401479483, 0.020054306834936142, 0.11815592646598816, 0.07254840433597565, 0.014029012061655521, 0.0, 0.0, 0.0, 0.0], [0.8187481760978699, 0.009394909255206585, 0.015446240082383156, 0.012167787179350853, 0.10175905376672745, 
0.02721206098794937, 0.01527167297899723, 0.0, 0.0, 0.0], [0.7012083530426025, 0.12151088565587997, 0.03808446228504181, 0.01883355714380741, 0.0837249755859375, 0.006598148960620165, 0.006499246694147587, 0.023540453985333443, 0.0, 0.0], [0.5152325630187988, 0.054241329431533813, 0.17093418538570404, 0.020541386678814888, 0.17657014727592468, 0.012641755864024162, 0.01802964322268963, 0.023539982736110687, 0.008269038051366806, 0.0], [0.9131196141242981, 0.0010915634920820594, 0.006193474866449833, 0.006082434672862291, 0.03542511910200119, 0.006826554890722036, 0.0028478680178523064, 0.004068343434482813, 0.014553201384842396, 0.009791722521185875]]], [[[0.16448259353637695, 0.17219680547714233, 0.09987642616033554, 0.09012344479560852, 0.06534503400325775, 0.08456553518772125, 0.06690192222595215, 0.08019057661294937, 0.17631761729717255, 0.0], [0.49537378549575806, 0.03979916125535965, 0.09498286247253418, 0.0017974335933104157, 0.028368383646011353, 0.0015277893980965018, 0.014851069077849388, 0.0003722719266079366, 0.3229270279407501, 0.0], [0.0031106590759009123, 0.8318147659301758, 0.0329316072165966, 0.00014872441533952951, 0.000739947019610554, 0.0009879706194624305, 0.0012947155628353357, 0.00040531408740207553, 0.128566175699234, 0.0], [3.727031798916869e-05, 0.00033458907273598015, 0.9051278829574585, 0.014809494838118553, 0.0013665216974914074, 0.0009820980485528708, 0.0004274636448826641, 0.0006300737150013447, 0.07628484070301056, 0.0], [2.789895370369777e-05, 7.413508137688041e-05, 0.00011113573418697342, 0.9593441486358643, 0.023210706189274788, 0.00043970797560177743, 0.00011651179374894127, 0.0001221746060764417, 0.016553271561861038, 0.0], [5.518151283467887e-06, 4.040239218738861e-06, 4.706911568064243e-06, 0.0001475349417887628, 0.0011833186727017164, 0.007331210654228926, 0.0003812467912212014, 0.7072276473045349, 0.28371480107307434, 0.0], [2.1062598989374237e-06, 1.0153020184588968e-06, 9.153064297606761e-07, 2.3557351596537046e-05, 
0.0019158869981765747, 0.9726926684379578, 0.0003360892878845334, 0.008161749690771103, 0.01686590164899826, 0.0], [1.876308124337811e-05, 3.1762643629917875e-05, 7.612020908709383e-06, 4.369785983726615e-06, 0.00035698129795491695, 0.006292039528489113, 0.9372867941856384, 0.0028216273058205843, 0.0531802624464035, 0.0], [0.00017082327394746244, 0.0008267413941211998, 0.0010992212919518352, 0.016357675194740295, 0.03317699581384659, 0.013446258381009102, 0.022417983040213585, 0.0993492603302002, 0.813154935836792, 0.0], [2.095436911986326e-06, 1.0510404990782263e-06, 8.745904779061675e-06, 9.465758921578526e-05, 0.9096792936325073, 0.004888555034995079, 0.00019891942793037742, 0.00012723646068479866, 0.08499950170516968, 0.0]], [[0.09510962665081024, 0.13984361290931702, 0.01835908181965351, 0.05623754486441612, 0.05484192445874214, 0.02751241996884346, 0.023350151255726814, 0.02046714909374714, 0.5642784833908081, 0.0], [0.32246580719947815, 0.12212380021810532, 0.0033711090218275785, 0.41883695125579834, 0.0010050723794847727, 0.00026374190929345787, 0.00840060692280531, 0.0003199145139660686, 0.12321317940950394, 0.0], [0.1343918889760971, 0.42756012082099915, 0.03016146458685398, 0.27197346091270447, 0.0008738918695598841, 0.00041738885920494795, 0.0011337834876030684, 0.0017680631717666984, 0.13172008097171783, 0.0], [4.970032023265958e-05, 0.0002945268643088639, 0.9929893612861633, 0.006102537736296654, 1.304412307945313e-06, 7.552243459940655e-06, 2.0433815279830014e-06, 1.4308750905911438e-05, 0.0005390164442360401, 0.0], [0.0006735534407198429, 0.0037932321429252625, 0.014864870347082615, 0.9520841240882874, 0.0031083461362868547, 0.0014454165939241648, 0.000881638377904892, 0.00042032121564261615, 0.02272843010723591, 0.0], [1.054488166118972e-06, 5.819076250190847e-06, 3.686256491164386e-07, 5.7184315664926544e-05, 1.600286668690387e-05, 0.0002979082928504795, 5.8259040088159963e-05, 0.997514009475708, 0.0020495890639722347, 0.0], 
[1.2081607110303594e-06, 1.8248301785206422e-06, 3.5412674037615943e-07, 0.00017610432405490428, 0.0004308871575631201, 0.9919483065605164, 0.001251595327630639, 0.004008213523775339, 0.002181792864575982, 0.0], [1.3394396773946937e-06, 1.858925656961219e-06, 8.99223309147601e-08, 5.498410246218555e-06, 4.1167979361489415e-05, 0.003499603597447276, 0.9961592555046082, 8.322765097545926e-06, 0.0002831367892213166, 0.0], [0.0011697824811562896, 0.00207342766225338, 0.0001985222043003887, 0.24218614399433136, 0.2580603361129761, 0.03422079235315323, 0.3017951250076294, 0.0700761154294014, 0.09021952003240585, 0.0], [4.897859540164973e-08, 1.9182496657776937e-07, 1.6890984966266842e-07, 0.00012898082786705345, 0.9986647963523865, 0.0003688811557367444, 8.465539576718584e-05, 1.2611121746886056e-05, 0.0007397857843898237, 0.0]], [[0.008738831616938114, 0.010689073242247105, 0.010104849003255367, 0.025418052449822426, 0.008787600323557854, 0.018541773781180382, 0.01414045225828886, 0.009587875567376614, 0.8939914107322693, 0.0], [0.050771377980709076, 0.08173098415136337, 0.03076810948550701, 0.6816214919090271, 0.04326915368437767, 0.0030209666583687067, 0.006032166071236134, 0.007633579429239035, 0.09515213221311569, 0.0], [0.04749365150928497, 0.07148067653179169, 0.018722670152783394, 0.5845115184783936, 0.03816590458154678, 0.003933309111744165, 0.006466464139521122, 0.021205652505159378, 0.20802012085914612, 0.0], [0.021572547033429146, 0.11727327853441238, 0.03622674569487572, 0.4274545907974243, 0.05620160698890686, 0.01161592174321413, 0.010393376462161541, 0.014363090507686138, 0.30489882826805115, 0.0], [0.015270093455910683, 0.10013995319604874, 0.006727923639118671, 0.19538360834121704, 0.1119888573884964, 0.027630485594272614, 0.0700199231505394, 0.01868581771850586, 0.4541531801223755, 0.0], [0.00540963327512145, 0.07916348427534103, 0.01957465149462223, 0.49324244260787964, 0.10871188342571259, 0.02422497235238552, 0.008650544099509716, 
0.16292543709278107, 0.0980970561504364, 0.0], [0.027941647917032242, 0.005471521522849798, 0.006384703796356916, 0.03924928605556488, 0.22657036781311035, 0.21837352216243744, 0.3372570872306824, 0.05897291377186775, 0.07977905124425888, 0.0], [0.009049936197698116, 0.005020579323172569, 0.014692768454551697, 0.15799382328987122, 0.4401932656764984, 0.1766415536403656, 0.03136269003152847, 0.12063619494438171, 0.044409021735191345, 0.0], [0.0007816475699655712, 0.0003147682291455567, 0.0032215022947639227, 0.4467180669307709, 0.3918246924877167, 0.00227341428399086, 0.004370422102510929, 0.14414219558238983, 0.006353371310979128, 0.0], [0.0005489268223755062, 0.016601460054516792, 0.01341363787651062, 0.2753817141056061, 0.13981539011001587, 0.04711242765188217, 0.08167178928852081, 0.11951272189617157, 0.30594193935394287, 0.0]], [[0.11438923329114914, 0.12380287796258926, 0.23573537170886993, 0.19010169804096222, 0.15611350536346436, 0.031749427318573, 0.02482231892645359, 0.05017237365245819, 0.07311322540044785, 0.0], [0.002549531403928995, 0.03178577870130539, 0.17347589135169983, 0.2232668697834015, 0.49775105714797974, 0.018238944932818413, 0.005651220679283142, 0.03368452191352844, 0.013595964759588242, 0.0], [0.0032994491048157215, 0.026504727080464363, 0.41210347414016724, 0.24245016276836395, 0.18897436559200287, 0.012874660082161427, 0.006452939473092556, 0.10089367628097534, 0.00644671730697155, 0.0], [0.002998506650328636, 0.048583757132291794, 0.28224417567253113, 0.0846971943974495, 0.013445784337818623, 0.02188579924404621, 0.017656570300459862, 0.5155076384544373, 0.012980557046830654, 0.0], [0.004188622813671827, 0.028234833851456642, 0.022820167243480682, 0.058492597192525864, 0.19205521047115326, 0.08343320339918137, 0.07119973003864288, 0.4843534827232361, 0.0552222914993763, 0.0], [0.0038351663388311863, 0.015353971160948277, 0.01755588687956333, 0.06245748698711395, 0.1218588799238205, 0.07207991182804108, 0.02867230959236622, 
0.5455195903778076, 0.13266700506210327, 0.0], [0.004144841339439154, 0.0048835063353180885, 0.0035110898315906525, 0.06276324391365051, 0.04069552943110466, 0.3603023290634155, 0.1472603678703308, 0.2116946280002594, 0.16474448144435883, 0.0], [0.024624889716506004, 0.016127971932291985, 0.0073340879753232, 0.023849278688430786, 0.042295511811971664, 0.5078635215759277, 0.2884303331375122, 0.011452756822109222, 0.07802165299654007, 0.0], [0.00880166981369257, 0.002673782641068101, 0.001370548619888723, 0.0061265453696250916, 0.02490534819662571, 0.2073771357536316, 0.3818575143814087, 0.1663341522216797, 0.20055335760116577, 0.0], [0.012253189459443092, 0.02221212349832058, 0.002282155444845557, 0.10455729067325592, 0.4111727774143219, 0.08308815956115723, 0.045707643032073975, 0.03711223974823952, 0.2816142141819, 0.0]], [[0.5821239352226257, 0.14550858736038208, 0.031251534819602966, 0.030760297551751137, 0.02147754468023777, 0.013665237464010715, 0.009087015874683857, 0.01557532325387001, 0.15055041015148163, 0.0], [0.12817564606666565, 0.33913177251815796, 0.07241326570510864, 0.41213902831077576, 0.0326012559235096, 0.0031606394331902266, 0.0006341012776829302, 0.007317711599171162, 0.0044263736344873905, 0.0], [0.08047150820493698, 0.06199575960636139, 0.5555182099342346, 0.2858560383319855, 0.008700164034962654, 0.003758196486160159, 0.001155794132500887, 0.0007424709619954228, 0.0018020549323409796, 0.0], [0.010044030845165253, 0.018482256680727005, 0.6269924640655518, 0.32439544796943665, 0.01023165788501501, 0.007641270756721497, 0.0008933563949540257, 0.0010311403311789036, 0.00028844154439866543, 0.0], [0.0007911038701422513, 0.0008549468475393951, 0.015090622939169407, 0.8270009160041809, 0.11969847232103348, 0.032614268362522125, 0.0024233118165284395, 0.0011481117689982057, 0.0003779604157898575, 0.0], [0.017773190513253212, 0.008623103611171246, 0.0020072387997061014, 0.08177924901247025, 0.13816505670547485, 0.6801413297653198, 
0.02186667174100876, 0.024107687175273895, 0.025536518543958664, 0.0], [0.000318053673254326, 5.6540200603194535e-05, 1.071194674295839e-05, 0.0009494975674897432, 0.0034297029487788677, 0.032661326229572296, 0.9588278532028198, 0.003185966284945607, 0.0005602877936325967, 0.0], [0.0017862697131931782, 0.0002347631088923663, 2.1297884813975543e-05, 0.0004797980946023017, 0.0018031852087005973, 0.024247879162430763, 0.45456385612487793, 0.5099425911903381, 0.006920217536389828, 0.0], [0.0006541880429722369, 0.0009561541373841465, 7.73017163737677e-05, 0.00942671112716198, 0.04198922589421272, 0.04971348121762276, 0.32961171865463257, 0.4513629972934723, 0.11620841920375824, 0.0], [0.017209511250257492, 0.004475452937185764, 3.128392927465029e-05, 0.00047953161993063986, 0.00448839133605361, 0.03360708802938461, 0.11509764194488525, 0.5398797988891602, 0.2847314178943634, 0.0]], [[0.20143046975135803, 0.41116827726364136, 0.09215858578681946, 0.10672477632761002, 0.06125285103917122, 0.017610367387533188, 0.01457523088902235, 0.02514597773551941, 0.06993352621793747, 0.0], [0.026864346116781235, 0.037146128714084625, 0.08411292731761932, 0.02904331497848034, 0.0955604761838913, 0.05886658653616905, 0.08584483712911606, 0.4076027572154999, 0.17495866119861603, 0.0], [0.073190838098526, 0.07998740673065186, 0.05594569817185402, 0.03243006020784378, 0.10037493705749512, 0.13878461718559265, 0.15250830352306366, 0.25721096992492676, 0.10956726223230362, 0.0], [0.0438627265393734, 0.04628896340727806, 0.4038660526275635, 0.005475929472595453, 0.03436022624373436, 0.11165640503168106, 0.02260321006178856, 0.28233063220977783, 0.04955587536096573, 0.0], [0.2377929538488388, 0.08882997930049896, 0.12371516227722168, 0.08651548624038696, 0.015416872687637806, 0.04211122542619705, 0.16403844952583313, 0.11833071708679199, 0.12324906885623932, 0.0], [0.023254310712218285, 0.0034057339653372765, 0.036038532853126526, 0.009054891765117645, 0.0329253226518631, 0.05284882336854935, 
0.15671837329864502, 0.6067742109298706, 0.07897992432117462, 0.0], [0.015282228589057922, 0.008608018048107624, 0.08339564502239227, 0.032651614397764206, 0.21303850412368774, 0.22661514580249786, 0.21832069754600525, 0.1323210895061493, 0.06976725161075592, 0.0], [0.019424932077527046, 0.008587736636400223, 0.014951083809137344, 0.01159222237765789, 0.2890152633190155, 0.2543036639690399, 0.2561561167240143, 0.0882645845413208, 0.05770434811711311, 0.0], [0.020595766603946686, 0.015824340283870697, 0.008689227513968945, 0.03796549141407013, 0.3004503846168518, 0.16956602036952972, 0.10506420582532883, 0.05004280060529709, 0.2918018400669098, 0.0], [0.18154361844062805, 0.0977708026766777, 0.20556335151195526, 0.05251142755150795, 0.13640889525413513, 0.06629360467195511, 0.06030320003628731, 0.08172836154699326, 0.11787670105695724, 0.0]], [[0.07673492282629013, 0.03585591912269592, 0.0804624855518341, 0.05707075819373131, 0.16190174221992493, 0.1288135051727295, 0.1235240250825882, 0.06807681918144226, 0.2675597667694092, 0.0], [0.005086997989565134, 0.014635499566793442, 0.013461720198392868, 0.6349815726280212, 0.14714521169662476, 0.015218403190374374, 0.01605474203824997, 0.018318237736821175, 0.1350976973772049, 0.0], [0.03515003249049187, 0.049813926219940186, 0.04029693454504013, 0.4151618778705597, 0.24873343110084534, 0.009437951259315014, 0.008381601423025131, 0.020832136273384094, 0.17219208180904388, 0.0], [0.06722414493560791, 0.13528113067150116, 0.06224377825856209, 0.18915168941020966, 0.17580503225326538, 0.07229694724082947, 0.012536793015897274, 0.09137610346078873, 0.19408434629440308, 0.0], [0.09099949151277542, 0.09548961371183395, 0.04829362779855728, 0.1739831268787384, 0.06667517125606537, 0.05157051607966423, 0.05465595796704292, 0.06177656352519989, 0.3565560579299927, 0.0], [0.09822985529899597, 0.05441536381840706, 0.039150238037109375, 0.06369251012802124, 0.05292840674519539, 0.050128646194934845, 0.044398434460163116, 
0.04042055085301399, 0.5566359758377075, 0.0], [0.012019939720630646, 0.0076602306216955185, 0.02716030552983284, 0.03984800726175308, 0.09776019304990768, 0.05175628885626793, 0.08536165207624435, 0.0944109782576561, 0.5840223431587219, 0.0], [0.036716632544994354, 0.021969007328152657, 0.010507079772651196, 0.012404722161591053, 0.040125522762537, 0.010736462660133839, 0.018730206415057182, 0.030387653037905693, 0.8184227347373962, 0.0], [0.04769879952073097, 0.19333122670650482, 0.02803504839539528, 0.016029207035899162, 0.11119306832551956, 0.03845509514212608, 0.011404097080230713, 0.0836206004023552, 0.4702327847480774, 0.0], [0.05245642364025116, 0.013315027579665184, 0.012056763283908367, 0.004825723823159933, 0.015483945608139038, 0.032884638756513596, 0.027794960886240005, 0.07057305425405502, 0.7706093788146973, 0.0]], [[0.05745904520153999, 0.06613133102655411, 0.11319872736930847, 0.031750500202178955, 0.0641264021396637, 0.07090476900339127, 0.053613319993019104, 0.1108509749174118, 0.4319649040699005, 0.0], [0.12783250212669373, 0.16847258806228638, 0.08126984536647797, 0.10575822740793228, 0.03301985561847687, 0.2111520618200302, 0.10687874257564545, 0.06316707283258438, 0.10244929045438766, 0.0], [0.1413263976573944, 0.38601601123809814, 0.16798537969589233, 0.14611834287643433, 0.015951359644532204, 0.042198505252599716, 0.016183707863092422, 0.06246974319219589, 0.021750787273049355, 0.0], [0.020376645028591156, 0.008152640424668789, 0.04579228535294533, 0.022974595427513123, 0.007921000011265278, 0.11700868606567383, 0.010826223529875278, 0.7216546535491943, 0.04529344290494919, 0.0], [0.04728184640407562, 0.041129130870103836, 0.12847241759300232, 0.038289085030555725, 0.07389654964208603, 0.11478690057992935, 0.04442784935235977, 0.41169247031211853, 0.1000237911939621, 0.0], [0.016180921345949173, 0.005130380857735872, 0.21081623435020447, 0.00797765702009201, 0.04691680520772934, 0.052309177815914154, 0.2947923243045807, 0.34133997559547424, 
0.02453651838004589, 0.0], [0.006579844746738672, 0.001606129459105432, 0.206822007894516, 0.017204096540808678, 0.13898226618766785, 0.09910376369953156, 0.4235020577907562, 0.05497713387012482, 0.051222700625658035, 0.0], [0.00896216370165348, 0.0023249718360602856, 0.0226416178047657, 0.05458173528313637, 0.07694459706544876, 0.29436299204826355, 0.36870595812797546, 0.12525610625743866, 0.046219732612371445, 0.0], [0.027829669415950775, 0.014619122259318829, 0.014550572261214256, 0.048137370496988297, 0.15001901984214783, 0.11716196686029434, 0.34159788489341736, 0.1513865739107132, 0.13469791412353516, 0.0], [0.0014273751294240355, 0.003807784290984273, 0.3760293126106262, 0.002253596903756261, 0.11343870311975479, 0.12883712351322174, 0.04242479428648949, 0.28902071714401245, 0.042760640382766724, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9917634725570679, 0.008236419409513474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.711856484413147, 0.20838035643100739, 0.07976315170526505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6327172517776489, 0.1227935329079628, 0.21565596759319305, 0.028833283111453056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3586137592792511, 0.038762304931879044, 0.08015953004360199, 0.4233120083808899, 0.09915236383676529, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7095601558685303, 0.03453405201435089, 0.02220289036631584, 0.009008818306028843, 0.201883926987648, 0.022810086607933044, 0.0, 0.0, 0.0, 0.0], [0.5828825831413269, 0.02795644849538803, 0.054448600858449936, 0.01975347101688385, 0.11504233628511429, 0.08908692002296448, 0.11082970350980759, 0.0, 0.0, 0.0], [0.4315364956855774, 0.020537925884127617, 0.01659376546740532, 0.014654956758022308, 0.13063199818134308, 0.27319464087486267, 0.08869150280952454, 0.024158723652362823, 0.0, 0.0], [0.26020547747612, 0.014821716584265232, 0.01224969606846571, 0.0724530965089798, 0.10939211398363113, 0.19152909517288208, 0.10495918244123459, 0.1680101454257965, 0.06637949496507645, 0.0], 
[0.6687084436416626, 0.04345089942216873, 0.009689688682556152, 0.0018685735994949937, 0.0738394483923912, 0.12735962867736816, 0.025320274755358696, 0.026545442640781403, 0.020931225270032883, 0.0022863498888909817]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9482711553573608, 0.051728855818510056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8711318373680115, 0.04994085431098938, 0.07892734557390213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7221198678016663, 0.040686361491680145, 0.06532222777605057, 0.17187155783176422, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5948007702827454, 0.036634139716625214, 0.02264709398150444, 0.035541336983442307, 0.3103766441345215, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6650473475456238, 0.01644211634993553, 0.019737746566534042, 0.0375308021903038, 0.10231779515743256, 0.15892422199249268, 0.0, 0.0, 0.0, 0.0], [0.36675524711608887, 0.04118315875530243, 0.02765432558953762, 0.03228116035461426, 0.11875578761100769, 0.12892943620681763, 0.2844408452510834, 0.0, 0.0, 0.0], [0.19659309089183807, 0.015950728207826614, 0.02453998662531376, 0.039237309247255325, 0.037656329572200775, 0.34599894285202026, 0.23759640753269196, 0.10242718458175659, 0.0, 0.0], [0.3881740868091583, 0.012267092242836952, 0.01897304505109787, 0.013982790522277355, 0.030991200357675552, 0.10819684714078903, 0.20157809555530548, 0.14642520248889923, 0.07941170781850815, 0.0], [0.11410266160964966, 0.03479800745844841, 0.043540675193071365, 0.021180409938097, 0.03197954222559929, 0.2248576581478119, 0.12852585315704346, 0.2089216560125351, 0.039846520870923996, 0.1522471308708191]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.993086576461792, 0.0069133141078054905, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9852874875068665, 0.011381878517568111, 0.0033306065015494823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4834398031234741, 0.011301998049020767, 0.48758530616760254, 0.017672834917902946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9851425886154175, 
0.0010397545993328094, 0.00470126885920763, 0.0012236799811944366, 0.007892588153481483, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6588926315307617, 0.005506628658622503, 0.021607331931591034, 0.010738613083958626, 0.07747143507003784, 0.2257833182811737, 0.0, 0.0, 0.0, 0.0], [0.13557791709899902, 0.018924091011285782, 0.02187344618141651, 0.015362304635345936, 0.11512601375579834, 0.14739760756492615, 0.5457385182380676, 0.0, 0.0, 0.0], [0.38992705941200256, 0.021535715088248253, 0.005403842777013779, 0.0032997699454426765, 0.4358868896961212, 0.06306594610214233, 0.03204012289643288, 0.04884066432714462, 0.0, 0.0], [0.81478351354599, 0.022238636389374733, 0.0008386021945625544, 0.01924033649265766, 0.06109088659286499, 0.020853841677308083, 0.014834966510534286, 0.028932424262166023, 0.017186695709824562, 0.0], [0.011323019862174988, 0.004743177909404039, 0.004908193834125996, 0.04389021545648575, 0.9175272583961487, 0.008399821817874908, 0.00010120288789039478, 0.0007724545430392027, 0.001946530188433826, 0.006388010922819376]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9621535539627075, 0.037846412509679794, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5398231148719788, 0.4385344386100769, 0.021642372012138367, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6502059698104858, 0.16868625581264496, 0.04876677691936493, 0.13234086334705353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5965072512626648, 0.06637387722730637, 0.1054789125919342, 0.1866345852613449, 0.04500538855791092, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3253602683544159, 0.03396952152252197, 0.02178867906332016, 0.07780158519744873, 0.04822142422199249, 0.49285849928855896, 0.0, 0.0, 0.0, 0.0], [0.2524598240852356, 0.04065639525651932, 0.06012948602437973, 0.022925280034542084, 0.0371418297290802, 0.17370767891407013, 0.41297948360443115, 0.0, 0.0, 0.0], [0.03411499038338661, 0.003937003668397665, 0.005961195565760136, 0.01710909977555275, 0.011033114977180958, 0.7081340551376343, 0.13750500977039337, 0.08220544457435608, 
0.0, 0.0], [0.42400264739990234, 0.02131979539990425, 0.017963027581572533, 0.01083337515592575, 0.019156770780682564, 0.14712399244308472, 0.1343262642621994, 0.19853995740413666, 0.02673417516052723, 0.0], [0.010900852270424366, 0.01643177680671215, 0.007438827771693468, 0.037741534411907196, 0.0038807683158665895, 0.513563871383667, 0.17121337354183197, 0.14364023506641388, 0.04466766491532326, 0.050521109253168106]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4730486273765564, 0.5269513726234436, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39858773350715637, 0.07930062711238861, 0.5221116542816162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5825604200363159, 0.08404675871133804, 0.15067298710346222, 0.182719886302948, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.29498350620269775, 0.03899451717734337, 0.00506106112152338, 0.006130008026957512, 0.6548308730125427, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13055028021335602, 0.007264712825417519, 0.014658198691904545, 0.03852052241563797, 0.6908979415893555, 0.11810839176177979, 0.0, 0.0, 0.0, 0.0], [0.6701509952545166, 0.016114505007863045, 0.009837295860052109, 0.013812566176056862, 0.10121432691812515, 0.04637172445654869, 0.14249859750270844, 0.0, 0.0, 0.0], [0.15980258584022522, 0.02680308185517788, 0.03885137289762497, 0.01341771800071001, 0.16442187130451202, 0.12716332077980042, 0.3698134124279022, 0.09972671419382095, 0.0, 0.0], [0.5671898722648621, 0.0029452391900122166, 0.0006932761170901358, 0.0009682640084065497, 0.008882325142621994, 0.018135691061615944, 0.19489231705665588, 0.1878870278596878, 0.01840599626302719, 0.0], [0.10793960839509964, 0.02733222208917141, 0.05983218923211098, 0.007959540002048016, 0.012123869732022285, 0.0992540642619133, 0.031409986317157745, 0.1074245497584343, 0.5389924645423889, 0.007731476798653603]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9029706120491028, 0.09702935069799423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0243820920586586, 0.026593990623950958, 
0.9490237236022949, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002445725491270423, 0.01137782447040081, 0.2685152590274811, 0.7176609635353088, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013372139073908329, 0.017163371667265892, 0.023703746497631073, 0.029362313449382782, 0.916398286819458, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009483089670538902, 0.0015323974657803774, 0.016186771914362907, 0.02369842305779457, 0.15252061188220978, 0.7965786457061768, 0.0, 0.0, 0.0, 0.0], [0.9718639850616455, 0.0004310244112275541, 0.00011954548244830221, 0.007853196933865547, 0.005200029816478491, 0.0034086843952536583, 0.011123435571789742, 0.0, 0.0, 0.0], [0.24338993430137634, 0.0009381886338815093, 0.001691973302513361, 0.004991883412003517, 0.06480661034584045, 0.02667633630335331, 0.5911706686019897, 0.06633439660072327, 0.0, 0.0], [0.9644694328308105, 0.00020931981271132827, 0.00022034233552403748, 0.001116775325499475, 0.0005140798166394234, 0.011200232431292534, 0.006607241928577423, 0.015303434804081917, 0.00035933865001425147, 0.0], [6.446504994528368e-05, 5.223282641964033e-05, 4.761212403536774e-05, 0.0026887860149145126, 0.9879595041275024, 8.169181819539517e-05, 3.4432316169841215e-05, 0.00022215544595383108, 0.008540215902030468, 0.0003085293574258685]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9361864924430847, 0.06381344050168991, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8988155722618103, 0.08033642917871475, 0.020848000422120094, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9107179045677185, 0.0414138063788414, 0.01669401116669178, 0.031174303963780403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7451518774032593, 0.09319417923688889, 0.038068220019340515, 0.07867664098739624, 0.04490913450717926, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7448095083236694, 0.053781699389219284, 0.02206255868077278, 0.051568105816841125, 0.050503022968769073, 0.07727508991956711, 0.0, 0.0, 0.0, 0.0], [0.49530187249183655, 0.08432565629482269, 0.024537190794944763, 0.03536847233772278, 0.04101351276040077, 
0.2942921817302704, 0.025161173194646835, 0.0, 0.0, 0.0], [0.33749768137931824, 0.0470278225839138, 0.025994539260864258, 0.11184448003768921, 0.035708073526620865, 0.3288814127445221, 0.052594561129808426, 0.06045151129364967, 0.0, 0.0], [0.5827996730804443, 0.037185750901699066, 0.025691334158182144, 0.040444474667310715, 0.032313525676727295, 0.15237869322299957, 0.02532070316374302, 0.06300554424524307, 0.040860243141651154, 0.0], [0.17712005972862244, 0.06858222186565399, 0.023361776024103165, 0.06553570926189423, 0.015878353267908096, 0.40178799629211426, 0.03335757926106453, 0.09457883983850479, 0.06679747253656387, 0.05300001800060272]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5274816155433655, 0.47251835465431213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0672292709350586, 0.8769893646240234, 0.05578138306736946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13304242491722107, 0.3016340434551239, 0.1132093071937561, 0.45211419463157654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3050541281700134, 0.12257003039121628, 0.15977424383163452, 0.12758392095565796, 0.2850176990032196, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22356735169887543, 0.03928647190332413, 0.007754397578537464, 0.009327426552772522, 0.12143179029226303, 0.5986325740814209, 0.0, 0.0, 0.0, 0.0], [0.21462960541248322, 0.06222677230834961, 0.03770677372813225, 0.020617984235286713, 0.1298619657754898, 0.16450734436511993, 0.3704494833946228, 0.0, 0.0, 0.0], [0.03640613704919815, 0.010346643626689911, 0.00673291739076376, 0.007102567236870527, 0.047351155430078506, 0.07502260059118271, 0.1735789030790329, 0.6434589624404907, 0.0, 0.0], [0.10655857622623444, 0.005401281639933586, 0.008467022329568863, 0.004935698118060827, 0.02920999936759472, 0.06761414557695389, 0.11367721855640411, 0.6410130262374878, 0.02312297187745571, 0.0], [0.022663118317723274, 0.01638328656554222, 0.016234278678894043, 0.013438239693641663, 0.13762539625167847, 0.1316443532705307, 0.10126813501119614, 0.1815005987882614, 
0.10885223746299744, 0.27039045095443726]]], [[[0.5484673976898193, 0.11615661531686783, 0.018765881657600403, 0.05580288916826248, 0.007166780531406403, 0.0032917018979787827, 0.014802427962422371, 0.009165346622467041, 0.2263808399438858, 0.0], [0.05529153719544411, 0.9114729762077332, 0.02160518430173397, 0.004567866213619709, 0.0020856212358921766, 0.002110318513587117, 7.9427394666709e-05, 0.0003330546023789793, 0.002454147208482027, 0.0], [0.05249117314815521, 0.7437799572944641, 0.20206855237483978, 0.000259216787526384, 5.300459815771319e-05, 0.0010736108524724841, 1.3093422239762731e-05, 5.472580232890323e-05, 0.00020689083612523973, 0.0], [0.00011124516458949074, 0.00014391898002941161, 0.002859711181372404, 0.9545366168022156, 0.03663090988993645, 0.0007743826135993004, 9.755761857377365e-05, 0.004461521748453379, 0.00038416660390794277, 0.0], [0.00743946572765708, 0.0006693374016322196, 0.030706975609064102, 0.7693524360656738, 0.13636630773544312, 0.010656076483428478, 0.00020547783060465008, 0.04430907592177391, 0.0002948205510620028, 0.0], [0.34355977177619934, 0.05352164804935455, 0.28276264667510986, 0.0013746530748903751, 0.09131773561239243, 0.14369648694992065, 0.012657100334763527, 0.0018211203860118985, 0.06928855180740356, 0.0], [0.00012065855844411999, 1.9836104911519215e-05, 1.09456195787061e-05, 1.777518991730176e-05, 5.799566861242056e-05, 0.00024955562548711896, 0.9993764758110046, 2.622428155518719e-06, 0.00014374956663232297, 0.0], [0.07937772572040558, 0.06896578520536423, 0.046331409364938736, 0.0006753505440428853, 0.10991709679365158, 0.07862550020217896, 0.015291865915060043, 0.024935398250818253, 0.575879693031311, 0.0], [0.0013021372724324465, 0.0014324801741167903, 0.001721968175843358, 0.0011953436769545078, 0.025524066761136055, 0.0017154657980427146, 0.004751213360577822, 0.06856247782707214, 0.8937948942184448, 0.0], [0.19237911701202393, 0.08248087018728256, 0.005975060164928436, 0.0005637910799123347, 
0.0032617340330034494, 0.08159960061311722, 0.10672623664140701, 0.06415636837482452, 0.4628572165966034, 0.0]], [[0.02611132338643074, 0.03929625079035759, 0.13154497742652893, 0.0362381711602211, 0.41634607315063477, 0.056842975318431854, 0.07852181792259216, 0.04581860452890396, 0.16927982866764069, 0.0], [0.02255222387611866, 0.10361888259649277, 0.8158259987831116, 0.004235901869833469, 0.006766538135707378, 0.000986771541647613, 0.0032120062969624996, 0.001885716337710619, 0.04091576859354973, 0.0], [0.017373425886034966, 0.13537107408046722, 0.8077713847160339, 0.0038886782713234425, 0.000276644917903468, 0.0006381930434145033, 0.00045153097016736865, 0.00025059780455194414, 0.0339784249663353, 0.0], [0.005007221829146147, 0.01780957728624344, 0.01267488207668066, 0.04065018519759178, 0.30516332387924194, 0.0026367236860096455, 0.0019572318997234106, 0.004150545224547386, 0.6099502444267273, 0.0], [0.00225767120718956, 0.0031865497585386038, 0.001125291339121759, 0.016497144475579262, 0.8971690535545349, 0.00343481102026999, 0.003961168695241213, 0.006191920954734087, 0.0661763921380043, 0.0], [0.009628614410758018, 0.010054895654320717, 0.001336919842287898, 0.0704738199710846, 0.6877674460411072, 0.03301373869180679, 0.05187760666012764, 0.005273953080177307, 0.1305730938911438, 0.0], [0.004103087354451418, 0.0010862533235922456, 0.0006940921302884817, 0.005870609078556299, 0.43826234340667725, 0.030803751200437546, 0.2956492602825165, 0.002342070685699582, 0.2211885154247284, 0.0], [0.0007035931921564043, 0.0015657383482903242, 0.0003329406026750803, 0.025085464119911194, 0.8715798258781433, 0.006046876311302185, 0.002586639951914549, 0.00011169366916874424, 0.09198720753192902, 0.0], [0.0021814475767314434, 0.0018482444575056434, 0.02461252734065056, 0.02290530502796173, 0.17733190953731537, 0.007551506161689758, 0.026218494400382042, 0.1859409213066101, 0.5514096021652222, 0.0], [0.002560347318649292, 0.0069580040872097015, 0.0021583843044936657, 
0.002428637584671378, 0.010794135741889477, 0.002866419730708003, 0.010929176583886147, 0.004671781323850155, 0.9566330909729004, 0.0]], [[0.10015174746513367, 0.09576243162155151, 0.03824060782790184, 0.020538825541734695, 0.027732321992516518, 0.017240623012185097, 0.011470633558928967, 0.4632270634174347, 0.22563566267490387, 0.0], [0.018642960116267204, 0.04822106286883354, 0.0140958521515131, 0.13092586398124695, 0.031955283135175705, 0.05195324495434761, 0.024238048121333122, 0.35591286420822144, 0.32405486702919006, 0.0], [0.04842180013656616, 0.06019889563322067, 0.016854524612426758, 0.166769877076149, 0.016003064811229706, 0.024013301357626915, 0.03686072677373886, 0.44016721844673157, 0.1907106339931488, 0.0], [0.0036558371502906084, 0.0033082098234444857, 0.0009605223312973976, 0.017093589529395103, 0.019939379766583443, 0.08280718326568604, 0.031923551112413406, 0.703184187412262, 0.1371275782585144, 0.0], [0.0018930931109935045, 0.002881440566852689, 0.00019882648484781384, 0.00406575808301568, 0.0021070034708827734, 0.011610278859734535, 0.0074381148442626, 0.9341073036193848, 0.03569793701171875, 0.0], [0.00691588269546628, 0.02838265150785446, 0.015397720038890839, 0.031874921172857285, 0.04765379801392555, 0.22744230926036835, 0.06624653190374374, 0.10724947601556778, 0.4688366651535034, 0.0], [0.009144916199147701, 0.012914983555674553, 0.0114166010171175, 0.010616455227136612, 0.03852293640375137, 0.11398687958717346, 0.23996756970882416, 0.03855413943529129, 0.5248754620552063, 0.0], [0.0165528766810894, 0.08396174013614655, 0.03695421293377876, 0.012792840600013733, 0.05054211989045143, 0.004681664984673262, 0.006349458359181881, 0.0059485542587935925, 0.7822163701057434, 0.0], [0.01941962167620659, 0.0720844566822052, 0.06703408807516098, 0.0024893011432141066, 0.09017500281333923, 0.01547347940504551, 0.011082785204052925, 0.036743972450494766, 0.6854971051216125, 0.0], [0.010559813119471073, 0.7681021094322205, 0.01782229356467724, 
0.0007385257631540298, 0.000383153063012287, 0.00014055910287424922, 0.00037340103881433606, 0.000453647633548826, 0.2014264017343521, 0.0]], [[0.15394526720046997, 0.03155883774161339, 0.013263775035738945, 0.6156436800956726, 0.07578981667757034, 0.016770629212260246, 0.041555847972631454, 0.008898256346583366, 0.042573899030685425, 0.0], [0.0076131573878228664, 0.0034139170311391354, 0.013692762702703476, 0.9489110708236694, 0.0023491643369197845, 0.004504398908466101, 0.018228650093078613, 6.88576401444152e-05, 0.0012180121848359704, 0.0], [0.002593724289909005, 0.0010450187837705016, 0.004842442460358143, 0.9895163178443909, 0.0010734394891187549, 6.863682210678235e-05, 0.00040535334846936166, 0.00029785462538711727, 0.00015763564442750067, 0.0], [0.007949098013341427, 0.0007930149440653622, 0.0010613474296405911, 0.913150429725647, 0.017281265929341316, 0.01414033118635416, 0.03622613847255707, 0.0036093932576477528, 0.0057889497838914394, 0.0], [0.002646723063662648, 0.0005160199943929911, 0.0002488561731297523, 0.0025351925287395716, 0.0016247049206867814, 0.09429789334535599, 0.8856176137924194, 0.005861275363713503, 0.006651633884757757, 0.0], [0.010305220261216164, 0.0041244118474423885, 0.0009454450337216258, 0.011387528851628304, 0.006450551562011242, 0.09920497238636017, 0.7582080960273743, 0.0005519646219909191, 0.1088215708732605, 0.0], [0.02700764685869217, 0.004230276681482792, 0.0004602614790201187, 0.0022337904665619135, 0.001628970610909164, 0.01760227419435978, 0.5739604234695435, 0.0034094173461198807, 0.3694668710231781, 0.0], [0.008519366383552551, 0.005846879445016384, 0.00031929058604873717, 0.00022687541786581278, 0.0001488836423959583, 0.0012441301951184869, 0.007195098325610161, 0.000364138453733176, 0.9761351943016052, 0.0], [0.10429845005273819, 0.062129467725753784, 0.0009245545952580869, 0.00015166438242886215, 0.00031537580071017146, 0.00040291156619787216, 0.006900690030306578, 0.009933815337717533, 0.8149431943893433, 0.0], 
[0.021318454295396805, 0.024646490812301636, 0.0006273255567066371, 3.4892458643298596e-05, 0.0002248335222247988, 0.001184952794574201, 0.005942351184785366, 0.01648845337331295, 0.9295321702957153, 0.0]], [[0.18210574984550476, 0.33179324865341187, 0.17356830835342407, 0.09634877741336823, 0.13200198113918304, 0.013823754154145718, 0.003925282042473555, 0.0030049949418753386, 0.06342781335115433, 0.0], [0.0020119324326515198, 0.018535200506448746, 0.9658629894256592, 0.007450602483004332, 0.004517300054430962, 0.0010996937053278089, 0.00011890578753082082, 8.778373739914969e-05, 0.0003156243183184415, 0.0], [0.00025491390260867774, 0.010184208862483501, 0.989091694355011, 0.0003856455150526017, 4.125349732930772e-05, 1.630889528314583e-05, 4.754766450787429e-06, 9.06635705177905e-06, 1.2339231943769846e-05, 0.0], [0.0005263744969852269, 0.0021082928869873285, 0.03339724615216255, 0.9491472840309143, 0.010056117549538612, 0.0008323733345605433, 0.00041247360059060156, 0.003171485150232911, 0.0003482940956018865, 0.0], [0.004013302735984325, 0.003607808379456401, 0.10117900371551514, 0.3848154544830322, 0.4549750089645386, 0.02184862084686756, 0.015023248270154, 0.013938427902758121, 0.0005991325015202165, 0.0], [0.006630904506891966, 0.001555793103761971, 0.01566290855407715, 0.005377574823796749, 0.0545264296233654, 0.7578195929527283, 0.15542279183864594, 0.00011175184772582725, 0.002892365213483572, 0.0], [0.0013706677127629519, 0.0003565592342056334, 0.0006504033226519823, 0.0008717189775779843, 0.023110924288630486, 0.16852477192878723, 0.8020843863487244, 0.0004564319388009608, 0.0025740356650203466, 0.0], [0.0005740747437812388, 0.0018384596332907677, 0.015691960230469704, 0.0004515495093073696, 0.04004881531000137, 0.8668573498725891, 0.03566786274313927, 0.01278533972799778, 0.0260846596211195, 0.0], [0.0008355869795195758, 0.003608973463997245, 0.04490630701184273, 0.009341607801616192, 0.007649072911590338, 0.10034366697072983, 0.06446904689073563, 
0.7009655237197876, 0.06788014620542526, 0.0], [0.00805425550788641, 0.039243537932634354, 0.05003930628299713, 0.0007152591715566814, 0.00863983016461134, 0.4756345748901367, 0.24407540261745453, 0.1204291433095932, 0.053168926388025284, 0.0]], [[0.7289455533027649, 0.07586073875427246, 0.09869885444641113, 0.029881592839956284, 0.013988524675369263, 0.006547953933477402, 0.011115974746644497, 0.01961168274283409, 0.01534893549978733, 0.0], [0.0068419137969613075, 0.9909167289733887, 0.0013876587618142366, 0.00011136479588458315, 6.257302447920665e-05, 6.40047510387376e-05, 1.7212767488672398e-05, 0.00025805848417803645, 0.0003405954339541495, 0.0], [0.0458948090672493, 0.09657034277915955, 0.8496794104576111, 0.005955036263912916, 0.00011324687511660159, 0.000537818530574441, 0.00024974963162094355, 6.34682146483101e-05, 0.0009358474053442478, 0.0], [0.0026473035104572773, 0.007075308356434107, 0.0509142205119133, 0.9043685793876648, 0.01687121018767357, 0.0027590824756771326, 0.0040096985176205635, 0.004973408300429583, 0.006381142418831587, 0.0], [0.00491793267428875, 0.0006714555202051997, 0.0047717769630253315, 0.09624139964580536, 0.8609471917152405, 0.020327016711235046, 0.008984015323221684, 0.0013932499568909407, 0.0017457373905926943, 0.0], [0.003878936870023608, 0.005480882711708546, 0.00011314810399198905, 0.0003485401102807373, 0.006120527163147926, 0.0029893070459365845, 0.0006264422554522753, 0.004414959345012903, 0.9760271906852722, 0.0], [2.4277944248751737e-05, 4.029595402244013e-06, 3.533453991622082e-07, 0.0002488488098606467, 2.0782925275852904e-05, 0.0004858894390054047, 0.9990906119346619, 4.08113919547759e-05, 8.422240352956578e-05, 0.0], [0.0008877408690750599, 0.00558891985565424, 8.855570922605693e-05, 8.779557902016677e-06, 0.00010457105963723734, 0.00017662049503996968, 0.002778601599857211, 0.09916532039642334, 0.8912010192871094, 0.0], [0.0001188771057059057, 0.0008492054184898734, 9.383026917930692e-05, 5.974515715934103e-06, 
0.002050562761723995, 8.90250812517479e-05, 8.933644130593166e-05, 0.002921103034168482, 0.9937818646430969, 0.0], [0.013618292286992073, 0.005976812914013863, 3.6079491110285744e-05, 4.805085336556658e-05, 0.00010178168304264545, 0.03545643016695976, 0.04239484667778015, 0.35667654871940613, 0.5456912517547607, 0.0]], [[0.5123088955879211, 0.04897892847657204, 0.0054785809479653835, 0.037157464772462845, 0.033040400594472885, 0.0287709329277277, 0.020658228546380997, 0.005767439026385546, 0.3078390061855316, 0.0], [0.02116605080664158, 0.45384567975997925, 0.5185156464576721, 0.002682638820260763, 0.000782136688940227, 0.0011598592391237617, 9.265153494197875e-05, 0.0013774167746305466, 0.00037764458102174103, 0.0], [0.06292606890201569, 0.18043853342533112, 0.7547035217285156, 0.0011577572440728545, 6.746899998688605e-06, 0.00012623713701032102, 8.539375994587317e-05, 0.0005039210664108396, 5.15973697474692e-05, 0.0], [0.0030907110776752234, 0.0035500964149832726, 0.4530283808708191, 0.5221376419067383, 0.009557867422699928, 0.008033420890569687, 0.0001996932114707306, 0.0003745300055015832, 2.7415442673373036e-05, 0.0], [0.00426989421248436, 0.0004077394842170179, 0.010691642761230469, 0.016011668369174004, 0.5530933737754822, 0.12423280626535416, 0.0053755901753902435, 0.28551235795021057, 0.00040478314622305334, 0.0], [0.00216855201870203, 0.009595326147973537, 0.007803121581673622, 0.04625817388296127, 0.24702508747577667, 0.2669595181941986, 0.024053409695625305, 0.24639348685741425, 0.14974308013916016, 0.0], [1.953603486981592e-06, 5.726202516598278e-07, 5.0084551617146644e-08, 6.896097602293594e-06, 0.0001788837107596919, 0.0027895711828023195, 0.9969833493232727, 1.1296948287053965e-05, 2.7187752493773587e-05, 0.0], [0.00041725789196789265, 0.0019265476148575544, 6.523763295263052e-05, 0.00018337361689191312, 0.011946662329137325, 0.04555974155664444, 0.15744170546531677, 0.025624049827456474, 0.7568355202674866, 0.0], [0.0004706757317762822, 
0.0027324894908815622, 0.0007427418022416532, 0.00934627279639244, 0.17134670913219452, 0.030644211918115616, 0.08413954824209213, 0.2513456642627716, 0.4492316246032715, 0.0], [0.0002758087939582765, 0.0016877831658348441, 2.4452297111565713e-06, 0.0004533462051767856, 0.001545731327496469, 0.008134560659527779, 0.010873721912503242, 0.026235109195113182, 0.950791597366333, 0.0]], [[0.23423711955547333, 0.3770143389701843, 0.036408282816410065, 0.05249679461121559, 0.12246986478567123, 0.07398416101932526, 0.040104977786540985, 0.05500312149524689, 0.008280999027192593, 0.0], [0.032638370990753174, 0.0975130945444107, 0.07180608063936234, 0.1075734794139862, 0.025216424837708473, 0.0218534916639328, 0.0376754030585289, 0.19710108637809753, 0.40862247347831726, 0.0], [0.06665100902318954, 0.01976630836725235, 0.13041609525680542, 0.1772802174091339, 0.07561768591403961, 0.0061133550480008125, 0.022857116535305977, 0.3516288995742798, 0.14966930449008942, 0.0], [0.007724895142018795, 0.007534464355558157, 0.020593322813510895, 0.32147932052612305, 0.059838466346263885, 0.017819387838244438, 0.13181470334529877, 0.15524767339229584, 0.27794769406318665, 0.0], [0.0028419073205441236, 0.0006648481939919293, 0.0018234169110655785, 0.01609039306640625, 0.0009005893371067941, 0.09726841002702713, 0.11035522073507309, 0.6978457570075989, 0.07220931351184845, 0.0], [0.0027223427314311266, 0.011157790198922157, 0.0052745467983186245, 0.00438346853479743, 0.010011360049247742, 0.38358205556869507, 0.2294219732284546, 0.057655833661556244, 0.29579076170921326, 0.0], [0.0007371494430117309, 0.00023907626746222377, 8.450529276160523e-05, 0.002850313438102603, 0.002168564358726144, 0.02191324159502983, 0.9514637589454651, 0.0002505093871150166, 0.02029278129339218, 0.0], [0.004062887746840715, 0.007024500984698534, 0.007399421185255051, 0.011675640009343624, 0.0680280476808548, 0.02557964250445366, 0.07043837755918503, 0.01946563646197319, 0.7863259315490723, 0.0], 
[0.0007200208492577076, 0.0024253681767731905, 0.029840486124157906, 0.00014908696175552905, 0.11268167197704315, 0.03336171433329582, 0.007834793999791145, 0.08127990365028381, 0.7317068576812744, 0.0], [0.008707311004400253, 0.08642490208148956, 0.11372587829828262, 0.004973042756319046, 0.13256150484085083, 0.16557356715202332, 0.0817113071680069, 0.006879410706460476, 0.3994430899620056, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\"], \"bot_text\": [\"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}, \"out_out\": {\"att\": [[[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9412446618080139, 0.05875528231263161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7461972832679749, 0.18569768965244293, 0.06810508668422699, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4299372434616089, 0.16845084726810455, 0.2029547393321991, 0.19865721464157104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5215166807174683, 0.16121163964271545, 0.19463112950325012, 0.09347883611917496, 0.029161658138036728, 0.0, 0.0, 0.0, 0.0, 0.0], [0.26405569911003113, 0.04358615726232529, 0.10687251389026642, 0.1710020899772644, 0.4105237126350403, 0.0039598336443305016, 0.0, 0.0, 0.0, 0.0], [0.29189321398735046, 0.19170531630516052, 0.11295431852340698, 0.08274418860673904, 0.12850242853164673, 0.09739833325147629, 0.09480219334363937, 0.0, 0.0, 0.0], [0.3496137857437134, 0.03085259348154068, 0.0195528082549572, 0.45414459705352783, 0.09152030944824219, 0.008845902979373932, 0.02992299199104309, 0.01554702315479517, 0.0, 0.0], [0.4675538241863251, 0.03941410034894943, 0.05400091037154198, 0.17985978722572327, 0.20104949176311493, 0.030323797836899757, 0.010615098290145397, 0.015154700726270676, 0.002028239192441106, 0.0], [0.053565241396427155, 0.029699191451072693, 0.0156599972397089, 0.016939852386713028, 0.04015244543552399, 0.21933501958847046, 0.1449035257101059, 
0.4037321209907532, 0.019583676010370255, 0.056428998708724976]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5249735116958618, 0.4750264883041382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3563348054885864, 0.5701623558998108, 0.07350286096334457, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3398579955101013, 0.23167477548122406, 0.1957632154226303, 0.23270410299301147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4351256191730499, 0.09737284481525421, 0.08845506608486176, 0.06574707478284836, 0.31329941749572754, 0.0, 0.0, 0.0, 0.0, 0.0], [0.360861599445343, 0.02136792428791523, 0.005633710417896509, 0.009215844795107841, 0.15762653946876526, 0.4452943205833435, 0.0, 0.0, 0.0, 0.0], [0.009015758521854877, 0.0013937305193394423, 0.00017763266805559397, 0.00016997012426145375, 0.010879353620111942, 0.0024589570239186287, 0.9759047627449036, 0.0, 0.0, 0.0], [0.014776602387428284, 0.0001805058855097741, 1.6896785382414237e-05, 0.0003442507586441934, 0.006220621056854725, 0.0012393802171573043, 0.9433164596557617, 0.033905431628227234, 0.0, 0.0], [0.005810329224914312, 0.002043980173766613, 0.0003433740057516843, 0.001522325212135911, 0.0030212807469069958, 0.00817712489515543, 0.5456522107124329, 0.10564129799604416, 0.32778817415237427, 0.0], [0.3754594326019287, 0.030579065904021263, 0.028458155691623688, 0.035943739116191864, 0.28040432929992676, 0.0202159583568573, 0.0396210215985775, 0.05075624957680702, 0.13473623991012573, 0.0038258912973105907]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9630448818206787, 0.036955028772354126, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8940342664718628, 0.015322646126151085, 0.09064316004514694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4866876006126404, 0.028273453935980797, 0.4569007158279419, 0.028138065710663795, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7252220511436462, 0.10817205905914307, 0.07890959084033966, 0.017715180292725563, 0.06998112797737122, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8598019480705261, 
0.012843498960137367, 0.014502299018204212, 0.004056263715028763, 0.10580158233642578, 0.0029942472465336323, 0.0, 0.0, 0.0, 0.0], [0.8686293363571167, 0.024889284744858742, 0.013860221020877361, 0.00703870365396142, 0.07120370119810104, 0.003939351066946983, 0.010439489968121052, 0.0, 0.0, 0.0], [0.8572709560394287, 0.018014011904597282, 0.008267350494861603, 0.0022140766959637403, 0.1038530021905899, 0.004275611136108637, 0.0009780752006918192, 0.005126776173710823, 0.0, 0.0], [0.35013046860694885, 0.0037752145435661077, 0.0071558705531060696, 0.01608894392848015, 0.6097922325134277, 0.002463925164192915, 0.0005387101555243134, 0.005540961865335703, 0.004513624589890242, 0.0], [0.1888049989938736, 0.12293454259634018, 0.5947631597518921, 0.009457849897444248, 0.07291270792484283, 0.008950368501245975, 0.0004109511792194098, 0.000914009811822325, 0.0006959570455364883, 0.00015547229850199074]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.91131192445755, 0.08868805319070816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.786292314529419, 0.09286607056856155, 0.1208416074514389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1722075194120407, 0.10747934877872467, 0.1462225317955017, 0.5740904808044434, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1893281787633896, 0.1733204573392868, 0.06838839501142502, 0.47577211260795593, 0.09319086372852325, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08935888856649399, 0.012517428956925869, 0.017112966626882553, 0.08479276299476624, 0.7640082240104675, 0.03220977261662483, 0.0, 0.0, 0.0, 0.0], [0.824190616607666, 0.008810147643089294, 0.002143737394362688, 0.002297793049365282, 0.11996792256832123, 0.005709697026759386, 0.036880046129226685, 0.0, 0.0, 0.0], [0.1513449102640152, 0.015725232660770416, 0.02784004621207714, 0.01800909824669361, 0.6534391641616821, 0.016422629356384277, 0.09054289758205414, 0.026676079258322716, 0.0, 0.0], [0.1625923067331314, 0.016224535182118416, 0.06514906883239746, 0.003223034320399165, 0.6737184524536133, 
0.014129054732620716, 0.036937959492206573, 0.023035621270537376, 0.004990031942725182, 0.0], [0.06836045533418655, 0.01236770860850811, 0.008784784935414791, 0.014186863787472248, 0.09790214896202087, 0.046204064041376114, 0.1703491061925888, 0.1878211945295334, 0.0703599750995636, 0.32366377115249634]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.961704432964325, 0.038295578211545944, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37462106347084045, 0.2157517969608307, 0.40962719917297363, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48521965742111206, 0.031020229682326317, 0.3760664165019989, 0.10769358277320862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.914044201374054, 0.004715718794614077, 0.006151301320642233, 0.005079128313809633, 0.07000966370105743, 0.0, 0.0, 0.0, 0.0, 0.0], [0.060511741787195206, 0.006127620115876198, 0.00728148128837347, 0.013585635460913181, 0.9084653854370117, 0.004028240218758583, 0.0, 0.0, 0.0, 0.0], [0.23348243534564972, 0.03748093172907829, 0.055222347378730774, 0.014132470823824406, 0.27614685893058777, 0.017582375556230545, 0.3659524619579315, 0.0, 0.0, 0.0], [0.06461911648511887, 0.003781915409490466, 0.002705940278246999, 0.016099220141768456, 0.8774597644805908, 0.012668337672948837, 0.0088069261983037, 0.013858767226338387, 0.0, 0.0], [0.05451222136616707, 0.014412143267691135, 0.00208102585747838, 0.011283651925623417, 0.02552390843629837, 0.02239326573908329, 0.031104939058423042, 0.20777365565299988, 0.630915105342865, 0.0], [0.5451503992080688, 0.014764615334570408, 0.2503703534603119, 0.037022024393081665, 0.0935375839471817, 0.022694993764162064, 0.0037449353840202093, 0.0053339023143053055, 0.007315538357943296, 0.020065704360604286]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9904667735099792, 0.009533224627375603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9818503260612488, 0.007338901981711388, 0.010810752399265766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9738979935646057, 0.007647394668310881, 
0.015154722146689892, 0.0032999368850141764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6611008644104004, 0.04138284549117088, 0.1119912639260292, 0.0262944046407938, 0.15923058986663818, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9380988478660583, 0.005562208592891693, 0.01078465860337019, 0.004562946502119303, 0.033130958676338196, 0.007860423997044563, 0.0, 0.0, 0.0, 0.0], [0.9377894997596741, 0.003691342193633318, 0.002771170577034354, 0.0017416415503248572, 0.04246653988957405, 0.002464305842295289, 0.009075501933693886, 0.0, 0.0, 0.0], [0.9083399176597595, 0.005597027484327555, 0.02609928511083126, 0.005710097029805183, 0.017865832895040512, 0.0029857312329113483, 0.002900469582527876, 0.030501706525683403, 0.0, 0.0], [0.8338009119033813, 0.00436164066195488, 0.006190306507050991, 0.0008050849428400397, 0.015337309800088406, 0.00863864365965128, 0.010715007781982422, 0.1143304780125618, 0.005820483900606632, 0.0], [0.9085996747016907, 0.00676243519410491, 0.02013525180518627, 0.009278967045247555, 0.02104269526898861, 0.009343095123767853, 0.0009470531367696822, 0.0018253516172990203, 0.003784958738833666, 0.018280424177646637]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.972051739692688, 0.027948210015892982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7552067041397095, 0.17251533269882202, 0.0722779706120491, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6455309987068176, 0.23265127837657928, 0.10187581926584244, 0.01994187943637371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.470674991607666, 0.26442891359329224, 0.14268451929092407, 0.03363766148686409, 0.08857394009828568, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6457618474960327, 0.011289404705166817, 0.008832731284201145, 0.01570025272667408, 0.2588561475276947, 0.059559762477874756, 0.0, 0.0, 0.0, 0.0], [0.4916176497936249, 0.07200384140014648, 0.0701020285487175, 0.019148536026477814, 0.0833231583237648, 0.12199999392032623, 0.14180481433868408, 0.0, 0.0, 0.0], [0.11119699478149414, 0.002801541704684496, 0.0021932011004537344, 
0.0016493132570758462, 0.06827285885810852, 0.22499483823776245, 0.5049597024917603, 0.08393163233995438, 0.0, 0.0], [0.13208742439746857, 0.0035411729477345943, 0.0015305017586797476, 0.002489483682438731, 0.06612236052751541, 0.213859423995018, 0.5324232578277588, 0.03503565117716789, 0.012910734862089157, 0.0], [0.20209012925624847, 0.05223073810338974, 0.03088257648050785, 0.036374326795339584, 0.014660456217825413, 0.03045688569545746, 0.03597142919898033, 0.16862399876117706, 0.022359324619174004, 0.40635016560554504]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9218347668647766, 0.0781652107834816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4189925193786621, 0.4865715503692627, 0.09443587809801102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48251789808273315, 0.34758540987968445, 0.13321316242218018, 0.036683470010757446, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8504839539527893, 0.033341050148010254, 0.053517427295446396, 0.012789242900907993, 0.049868300557136536, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4515743553638458, 0.03267433121800423, 0.019386781379580498, 0.024256065487861633, 0.17900733649730682, 0.29310107231140137, 0.0, 0.0, 0.0, 0.0], [0.5910289883613586, 0.0027754076290875673, 0.004533650353550911, 0.0023315453436225653, 0.08002334088087082, 0.06913208961486816, 0.2501751184463501, 0.0, 0.0, 0.0], [0.1626552939414978, 0.0011573631782084703, 0.00017211545491591096, 0.0007665579323656857, 0.03241841867566109, 0.34369325637817383, 0.2890424132347107, 0.17009468376636505, 0.0, 0.0], [0.10835989564657211, 0.0007107920246198773, 0.00030798258376307786, 0.005807099863886833, 0.04662986099720001, 0.1659584492444992, 0.3522194027900696, 0.30094781517982483, 0.019058646634221077, 0.0], [0.5449283123016357, 0.01310307253152132, 0.008020865730941296, 0.006764447782188654, 0.16009773313999176, 0.06950337439775467, 0.0024397175293415785, 0.014089844189584255, 0.013654321432113647, 0.1673980951309204]]], [[[0.03246883675456047, 0.020431363955140114, 
0.06294436007738113, 0.08282972872257233, 0.047490958124399185, 0.03976213559508324, 0.01868664100766182, 0.5054241418838501, 0.18996170163154602, 0.0], [0.0334412157535553, 0.45350977778434753, 0.23828978836536407, 0.07703227549791336, 0.02545342594385147, 0.019935714080929756, 0.007961008697748184, 0.08864670246839523, 0.05572996661067009, 0.0], [0.008816813118755817, 0.009350132197141647, 0.09488566964864731, 0.022458655759692192, 0.001578008639626205, 0.01768183708190918, 0.0012928039068356156, 0.7889453768730164, 0.05499071627855301, 0.0], [0.0037117439787834883, 0.00603569345548749, 0.019362367689609528, 0.06632085889577866, 0.02251342497766018, 0.048607613891363144, 0.00711278198286891, 0.7890322804450989, 0.03730323165655136, 0.0], [0.0017165049212053418, 0.0031809706706553698, 0.00569736585021019, 0.027958940714597702, 0.001130971242673695, 0.006313299294561148, 0.004051794297993183, 0.9312260150909424, 0.018723946064710617, 0.0], [0.0028915719594806433, 0.007050157990306616, 0.004614752251654863, 0.0017270235111936927, 0.0016248916508629918, 0.06901240348815918, 0.005150379613041878, 0.13293159008026123, 0.7749972939491272, 0.0], [0.005032604560256004, 0.005055313929915428, 0.0030569147784262896, 0.0010687477188184857, 0.012304573319852352, 0.013984610326588154, 0.3489484190940857, 0.012370014563202858, 0.5981789827346802, 0.0], [0.0019784842152148485, 0.009333183988928795, 0.005381024908274412, 0.0002465381403453648, 0.0013898308388888836, 0.005461550783365965, 0.0012134313583374023, 0.001065099611878395, 0.9739308953285217, 0.0], [0.005657540168613195, 0.006781480740755796, 0.00696007814258337, 0.0009338636882603168, 0.02429838851094246, 0.03842600807547569, 0.00286328443326056, 0.03579647094011307, 0.8782829642295837, 0.0], [0.007395321968942881, 0.012293249368667603, 0.006963892374187708, 0.00022730379714630544, 0.0005401583621278405, 0.005707587581127882, 0.0028992195148020983, 0.0027063635643571615, 0.9612669944763184, 0.0]], [[0.02470340207219124, 
0.02512546442449093, 0.11353036016225815, 0.35132649540901184, 0.20412008464336395, 0.027150044217705727, 0.015305055305361748, 0.05760098248720169, 0.1811380535364151, 0.0], [0.009894105605781078, 0.02192404493689537, 0.3007009029388428, 0.13983333110809326, 0.03682582825422287, 0.08908118307590485, 0.27657952904701233, 0.026430398225784302, 0.09873086214065552, 0.0], [0.011459765024483204, 0.044317521154880524, 0.5289616584777832, 0.19549138844013214, 0.03426412120461464, 0.017797794193029404, 0.030613277107477188, 0.0163635965436697, 0.12073105573654175, 0.0], [0.011578483507037163, 0.0029169816989451647, 0.00455811433494091, 0.01625976897776127, 0.018393559381365776, 0.11749742925167084, 0.32938554883003235, 0.41049671173095703, 0.08891336619853973, 0.0], [0.0033444140572100878, 0.0011373214656487107, 0.0019445078214630485, 0.02781311236321926, 0.0049105980433523655, 0.05221953243017197, 0.09222303330898285, 0.3644186854362488, 0.45198866724967957, 0.0], [0.002199131529778242, 0.0006913270917721093, 0.002652444876730442, 0.017487458884716034, 0.18746966123580933, 0.39171290397644043, 0.26989367604255676, 0.017002178356051445, 0.11089123785495758, 0.0], [0.01051913108676672, 0.003755246289074421, 0.0008555634994991124, 0.002675057854503393, 0.0025919810868799686, 0.02418649010360241, 0.018060903996229172, 0.003447937313467264, 0.9339075684547424, 0.0], [0.029951948672533035, 0.006547479424625635, 0.030934682115912437, 0.0036260345950722694, 0.1420958936214447, 0.19529034197330475, 0.1491098254919052, 0.009723717346787453, 0.43272000551223755, 0.0], [0.017757408320903778, 0.006832967512309551, 0.028906390070915222, 0.00921954121440649, 0.054915353655815125, 0.028632348403334618, 0.03646676614880562, 0.01978384144604206, 0.7974854707717896, 0.0], [0.06588920205831528, 0.05552517622709274, 0.18546447157859802, 0.007839588448405266, 0.020484987646341324, 0.01699826307594776, 0.01947665773332119, 0.017759086564183235, 0.6105626821517944, 0.0]], [[0.14391662180423737, 
0.11156481504440308, 0.4162432849407196, 0.07845085859298706, 0.04067624360322952, 0.016916701570153236, 0.012291320599615574, 0.10670017451047897, 0.07323983311653137, 0.0], [0.0171683169901371, 0.03512553498148918, 0.4936983287334442, 0.18945446610450745, 0.020571058616042137, 0.011469473131000996, 0.04002959281206131, 0.08968089520931244, 0.10280223935842514, 0.0], [0.2093620002269745, 0.11281707882881165, 0.25891542434692383, 0.14515942335128784, 0.0042000748217105865, 0.006485591176897287, 0.005525505635887384, 0.14364667236804962, 0.11388827115297318, 0.0], [0.0109701631590724, 0.0007525839027948678, 0.011503712274134159, 0.03920656442642212, 0.2449047565460205, 0.048431187868118286, 0.12996943295001984, 0.4081973731517792, 0.10606419295072556, 0.0], [0.004995591007173061, 0.0001893905719043687, 0.0009439413552172482, 0.03207648918032646, 0.08267047256231308, 0.015983520075678825, 0.02033340558409691, 0.8191123604774475, 0.023694908246397972, 0.0], [0.0022357299458235502, 0.000793653482105583, 0.0010144039988517761, 0.2958794832229614, 0.3394852876663208, 0.07495945692062378, 0.06856833398342133, 0.06118563562631607, 0.15587811172008514, 0.0], [0.0020441634114831686, 0.00032311712857335806, 0.0006899640429764986, 0.03996479511260986, 0.38782593607902527, 0.05503879860043526, 0.24750953912734985, 0.004524962045252323, 0.26207876205444336, 0.0], [0.0012333561899140477, 0.0002747838443610817, 0.0023864947725087404, 0.10253860056400299, 0.4721597135066986, 0.04103615880012512, 0.03782818093895912, 0.026908699423074722, 0.31563398241996765, 0.0], [0.004791810177266598, 0.0015037101693451405, 0.004669447895139456, 0.38809871673583984, 0.13379721343517303, 0.024320820346474648, 0.03647102415561676, 0.013309511356055737, 0.3930378258228302, 0.0], [0.00849083997309208, 0.003579143201932311, 0.0033037925604730844, 0.006032468285411596, 0.017621049657464027, 0.0234503336250782, 0.018282314762473106, 0.02657976746559143, 0.8926602602005005, 0.0]], [[0.8417463898658752, 
0.05951714888215065, 0.012198105454444885, 0.03180553764104843, 0.02919766865670681, 0.0096508814021945, 0.003031272441148758, 0.0009100366733036935, 0.011942943558096886, 0.0], [0.00569154741242528, 0.979739785194397, 0.012030904181301594, 0.0001143000990850851, 9.368032624479383e-05, 0.0008171445806510746, 0.00012590458209160715, 0.0005024938145652413, 0.0008843241375871003, 0.0], [0.005223963409662247, 0.005622355733066797, 0.9848889708518982, 0.002582893241196871, 0.0003334738139528781, 0.0005618981667794287, 3.256636409787461e-05, 0.00024550766102038324, 0.0005086653982289135, 0.0], [0.0032260464504361153, 0.007557107135653496, 0.0651315227150917, 0.6094849109649658, 0.008782745338976383, 0.2748804986476898, 0.015592943876981735, 0.008143502287566662, 0.007200630847364664, 0.0], [0.01683628372848034, 0.0020552987698465586, 0.00783018209040165, 0.008005303330719471, 0.0011927365558221936, 0.9284406900405884, 0.03478293865919113, 0.00030738895293325186, 0.0005490221083164215, 0.0], [0.0004254023951943964, 7.111614831956103e-05, 0.0008891545585356653, 1.880968193290755e-05, 6.570573896169662e-05, 0.9941434860229492, 0.0025632327888160944, 9.733852493809536e-06, 0.0018130606040358543, 0.0], [7.936867405078374e-06, 1.8136512153432705e-05, 4.5569290705316234e-06, 1.071940641850233e-05, 3.808495648627286e-06, 0.0008168917265720665, 0.9974388480186462, 1.4373016711033415e-05, 0.0016848900122568011, 0.0], [0.0014213839313015342, 0.003971228376030922, 0.008488249033689499, 2.0282970581320114e-05, 8.774230809649453e-05, 0.030342059209942818, 0.010436602868139744, 0.013138609007000923, 0.9320940375328064, 0.0], [9.058997966349125e-05, 0.0009022729936987162, 0.0017266678623855114, 1.3629892237077001e-05, 0.000727150880265981, 0.002379553159698844, 0.0010508937994018197, 0.012508089654147625, 0.9806011319160461, 0.0], [0.0003429521748330444, 0.001905322540551424, 0.0005013775080442429, 1.1471392099338118e-05, 0.00017356597527395934, 0.0029742273036390543, 
0.003938945475965738, 0.028075864538550377, 0.9620763063430786, 0.0]], [[0.23634016513824463, 0.09021607041358948, 0.12040459364652634, 0.01354933436959982, 0.0019137230701744556, 0.009001325815916061, 0.028688833117485046, 0.2612648904323578, 0.23862121999263763, 0.0], [0.2307557761669159, 0.2812652289867401, 0.30346915125846863, 0.05031246319413185, 0.006193350534886122, 0.01668362505733967, 0.012607063166797161, 0.07951408624649048, 0.019199388101696968, 0.0], [0.29960742592811584, 0.20819564163684845, 0.27825382351875305, 0.007396433036774397, 0.0007608149899169803, 0.0260151494294405, 0.012685009278357029, 0.12934625148773193, 0.03773954138159752, 0.0], [0.035675279796123505, 0.035874202847480774, 0.007117687724530697, 0.018771182745695114, 0.010206644423305988, 0.06527784466743469, 0.03775254264473915, 0.7770709991455078, 0.012253628112375736, 0.0], [0.012017791159451008, 0.0028583300299942493, 0.0024127706419676542, 0.002610970288515091, 0.001820205245167017, 0.04092223569750786, 0.016621166840195656, 0.9115477800369263, 0.009188669733703136, 0.0], [0.03447290509939194, 0.013388306833803654, 0.08488336205482483, 0.015237652696669102, 0.19176845252513885, 0.3472833037376404, 0.10885429382324219, 0.192628413438797, 0.011483324691653252, 0.0], [0.0005363536183722317, 0.0001964608090929687, 0.0017719777533784509, 0.003164003835991025, 0.27662715315818787, 0.05286016687750816, 0.648875892162323, 0.007890382781624794, 0.00807751715183258, 0.0], [0.001257028547115624, 0.00020761204359587282, 0.0024441492278128862, 0.003374723019078374, 0.9062062501907349, 0.0712839737534523, 0.0032159662805497646, 0.009974849410355091, 0.0020355340093374252, 0.0], [0.0008205634076148272, 0.00019305139721836895, 0.002098840195685625, 0.004588909447193146, 0.9688709378242493, 0.01628950424492359, 0.0038415545132011175, 0.0016231476329267025, 0.0016735766548663378, 0.0], [0.03610469028353691, 0.046298399567604065, 0.04650943726301193, 0.02111651562154293, 0.06683006882667542, 
0.37146270275115967, 0.174205482006073, 0.15773150324821472, 0.07974111288785934, 0.0]], [[0.03425053879618645, 0.026130978018045425, 0.3080751299858093, 0.027706336230039597, 0.12989944219589233, 0.29902005195617676, 0.0305496696382761, 0.03879137709736824, 0.1055762991309166, 0.0], [0.004509713500738144, 0.02305547706782818, 0.939035952091217, 0.006188178434967995, 0.020785806700587273, 0.00040150884888134897, 0.00018676061881706119, 0.00013036451127845794, 0.005706076975911856, 0.0], [0.0005241778562776744, 0.009561678394675255, 0.988527774810791, 2.2495760276797228e-05, 4.7274414100684226e-05, 0.00013538387429434806, 4.543165232462343e-06, 6.27172994427383e-05, 0.001113483915105462, 0.0], [0.06551901996135712, 0.0800878182053566, 0.06342226266860962, 0.00974376779049635, 0.5160938501358032, 0.02204274758696556, 0.004013149533420801, 0.0735243633389473, 0.1655530482530594, 0.0], [0.0013552415184676647, 0.0004213388019707054, 0.002606122987344861, 0.0010090378345921636, 0.24638326466083527, 0.6568374633789062, 0.01604411192238331, 0.04806208983063698, 0.027281243354082108, 0.0], [0.0002145337639376521, 0.00018796027870848775, 0.0008407118148170412, 0.0029629908967763186, 0.28427600860595703, 0.6725634336471558, 0.023870857432484627, 0.00339014851488173, 0.011693413369357586, 0.0], [0.0009873382514342666, 0.0005485343281179667, 6.628077971981838e-05, 0.0029302756302058697, 0.23183174431324005, 0.05256076529622078, 0.5701138377189636, 0.005792138632386923, 0.13516920804977417, 0.0], [2.471696279826574e-05, 2.0868348656222224e-05, 4.437468305695802e-05, 0.002024284563958645, 0.9655042886734009, 0.024176988750696182, 0.001284845289774239, 0.00018083618488162756, 0.006738840136677027, 0.0], [0.0007289832574315369, 7.746354822302237e-05, 0.00018428664770908654, 0.014176051132380962, 0.9112405180931091, 0.013280178420245647, 0.003417921019718051, 0.02014165185391903, 0.03675319626927376, 0.0], [0.00874137319624424, 0.03438721224665642, 0.17507928609848022, 
0.007159235887229443, 0.0029199302662163973, 0.023628318682312965, 0.007933209650218487, 0.004559694789350033, 0.7355918884277344, 0.0]], [[0.01947755739092827, 0.007096209097653627, 0.03225293010473251, 0.0123430285602808, 0.10373923927545547, 0.44083938002586365, 0.04899014160037041, 0.25500863790512085, 0.08025286346673965, 0.0], [0.018974049016833305, 0.05092930048704147, 0.38670486211776733, 0.05532746762037277, 0.02096201851963997, 0.23439037799835205, 0.029592081904411316, 0.06233520433306694, 0.1407845914363861, 0.0], [0.009641589596867561, 0.009545106440782547, 0.19981582462787628, 0.009672220796346664, 0.003704657079651952, 0.04582780599594116, 0.006998295895755291, 0.5789687037467957, 0.13582585752010345, 0.0], [0.00450306897982955, 0.0034239809028804302, 0.012258612550795078, 0.005700208712369204, 0.04511384665966034, 0.4419432282447815, 0.12840862572193146, 0.13075105845928192, 0.22789721190929413, 0.0], [0.00048664878704585135, 0.00010348611976951361, 0.0010980216320604086, 0.0006185582024045289, 0.028226494789123535, 0.37447214126586914, 0.09456676244735718, 0.48241522908210754, 0.018012629821896553, 0.0], [8.0467427324038e-05, 3.9275117160286754e-05, 0.00016763176245149225, 0.00013412459520623088, 0.009092556312680244, 0.7851189374923706, 0.16675172746181488, 0.0029041438829153776, 0.03571125119924545, 0.0], [0.0007275060634128749, 0.00015159584290813655, 0.00037383963353931904, 0.0005468691233545542, 0.01837681420147419, 0.03491391986608505, 0.7517433166503906, 0.00028147027478553355, 0.19288486242294312, 0.0], [0.0005560970166698098, 0.0002987806510645896, 0.0021934551186859608, 0.00023410467838402838, 0.023030919954180717, 0.05263887345790863, 0.01838914304971695, 0.0007265828317031264, 0.9019319415092468, 0.0], [0.007445591501891613, 0.0020796440076082945, 0.012208829633891582, 0.001590645289979875, 0.09274771064519882, 0.017371611669659615, 0.04761578515172005, 0.004260089714080095, 0.8146799802780151, 0.0], [0.014990360476076603, 
0.004210897721350193, 0.002848376054316759, 0.0006518716691061854, 0.0007818753365427256, 0.0019951288122683764, 0.0036728696431964636, 0.0004030312702525407, 0.9704453349113464, 0.0]], [[0.21779413521289825, 0.08220235258340836, 0.04201545566320419, 0.07069981843233109, 0.041075702756643295, 0.13784317672252655, 0.1975526064634323, 0.04344295710325241, 0.16737376153469086, 0.0], [0.23605762422084808, 0.07441659271717072, 0.04143041744828224, 0.05435749515891075, 0.0077708023600280285, 0.0960790365934372, 0.4399828016757965, 0.006641789805144072, 0.04326343908905983, 0.0], [0.06337786465883255, 0.03357791155576706, 0.03929098695516586, 0.5017232298851013, 0.0066258725710213184, 0.009236367419362068, 0.1690734624862671, 0.0422079935669899, 0.13488635420799255, 0.0], [0.006272959988564253, 0.0007428607787005603, 0.0011506476439535618, 0.007357995491474867, 0.0006080326274968684, 0.05679970234632492, 0.8685706257820129, 0.03271445259451866, 0.025782890617847443, 0.0], [0.041861388832330704, 0.004794578067958355, 0.0024879220873117447, 0.015253551304340363, 0.0005973980878479779, 0.08281483501195908, 0.814189076423645, 0.006639576051384211, 0.03136153519153595, 0.0], [0.010862020775675774, 0.0008270516409538686, 0.00023008826246950775, 0.006298262160271406, 0.0022151959128677845, 0.09469958394765854, 0.8416994214057922, 0.0006256845663301647, 0.04254243150353432, 0.0], [0.00024508681963197887, 3.835038296529092e-05, 2.0304802092141472e-05, 0.00012946058996021748, 0.0003255259362049401, 0.0026247953064739704, 0.9805192947387695, 0.00014136231038719416, 0.01595580205321312, 0.0], [0.001919803791679442, 0.0005674636922776699, 0.0002780239738058299, 0.0008655164856463671, 0.0013816945720463991, 0.010561172850430012, 0.05357982590794563, 0.0009362901910208166, 0.9299100637435913, 0.0], [0.00319756381213665, 0.0005108749028295279, 0.00043022894533351064, 0.005312783177942038, 0.005197612568736076, 0.008492776192724705, 0.05858352780342102, 0.01401757076382637, 
0.9042569398880005, 0.0], [0.00021474930690601468, 0.0004951281007379293, 0.00032367443782277405, 0.0001866286911536008, 6.129321263870224e-05, 0.00016246296581812203, 0.0016925180098041892, 0.000427676277467981, 0.996435821056366, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9262088537216187, 0.07379112392663956, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2983383536338806, 0.576672375202179, 0.12498921155929565, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3100782334804535, 0.1274886280298233, 0.5286650061607361, 0.033768050372600555, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3118414282798767, 0.11087317764759064, 0.12077098339796066, 0.10916762799024582, 0.34734681248664856, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1361667662858963, 0.0034004957415163517, 0.00320720998570323, 0.0056303562596440315, 0.013746269047260284, 0.8378488421440125, 0.0, 0.0, 0.0, 0.0], [0.9168469905853271, 0.009582683444023132, 0.002923850901424885, 0.009140468202531338, 0.0233402531594038, 0.01968987099826336, 0.01847577467560768, 0.0, 0.0, 0.0], [0.4528708755970001, 0.012551077641546726, 0.013286955654621124, 0.003301329677924514, 0.024005549028515816, 0.0439622700214386, 0.03865182027220726, 0.41137006878852844, 0.0, 0.0], [0.06380993872880936, 0.0008893097401596606, 0.0011801879154518247, 0.0013187900185585022, 0.0034512828569859266, 0.0014297974994406104, 0.0023058890365064144, 0.041651248931884766, 0.8839635848999023, 0.0], [0.5330018997192383, 0.012773798778653145, 0.01854255609214306, 0.022641947492957115, 0.1288023591041565, 0.01178218238055706, 0.020595960319042206, 0.08756020665168762, 0.09921147674322128, 0.06508753448724747]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9422653913497925, 0.057734500616788864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37070432305336, 0.2449311465024948, 0.3843645751476288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5423898100852966, 0.11884469538927078, 0.1850128471851349, 0.15375272929668427, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.7452426552772522, 0.024770371615886688, 0.025099167600274086, 0.014617366716265678, 0.19027042388916016, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4940005838871002, 0.026306116953492165, 0.014163044281303883, 0.022562485188245773, 0.43185216188430786, 0.011115492321550846, 0.0, 0.0, 0.0, 0.0], [0.8323472142219543, 0.005361876450479031, 0.001218354911543429, 0.0017811520956456661, 0.06672050058841705, 0.0179598405957222, 0.07461105287075043, 0.0, 0.0, 0.0], [0.5900163650512695, 0.0016051119891926646, 0.00041884748497977853, 0.002425695303827524, 0.09076588600873947, 0.005809221416711807, 0.03928956016898155, 0.2696692943572998, 0.0, 0.0], [0.14191001653671265, 0.0026981914415955544, 0.000433926354162395, 0.0025318085681647062, 0.0752185806632042, 0.041030533611774445, 0.10226735472679138, 0.6134982705116272, 0.020411266013979912, 0.0], [0.9951959252357483, 0.000172812317032367, 0.0011272057890892029, 0.0002565488684922457, 0.001650187186896801, 0.0010172545444220304, 3.585639569791965e-05, 0.00030177918961271644, 2.7251116989646107e-05, 0.00021514984837267548]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9959792494773865, 0.004020644351840019, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8763805031776428, 0.06819441169500351, 0.05542506277561188, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6675543785095215, 0.035431310534477234, 0.2554236948490143, 0.04159051924943924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8250302076339722, 0.013232334516942501, 0.10887149721384048, 0.016031241044402122, 0.03683457896113396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14042839407920837, 0.005938003305345774, 0.04128086566925049, 0.01834655925631523, 0.7866368293762207, 0.007369248662143946, 0.0, 0.0, 0.0, 0.0], [0.3567042350769043, 0.0165000781416893, 0.015264611691236496, 0.010309864766895771, 0.38396307826042175, 0.025359012186527252, 0.1918991357088089, 0.0, 0.0, 0.0], [0.03735272213816643, 0.0005555232055485249, 0.0009066119673661888, 0.003488750196993351, 0.4253699481487274, 0.039391178637742996, 
0.3313658535480499, 0.1615692675113678, 0.0, 0.0], [0.0020103107672184706, 0.0002689870889298618, 0.0004340466111898422, 0.0009705349220894277, 0.03535917028784752, 0.014057940803468227, 0.07802704721689224, 0.8683921694755554, 0.0004796571738552302, 0.0], [0.21001528203487396, 0.008917403407394886, 0.08127831667661667, 0.6020672917366028, 0.0504239983856678, 0.01106872595846653, 0.002271559089422226, 0.009885885752737522, 0.013363776728510857, 0.010707534849643707]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8274853825569153, 0.1725146621465683, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39722761511802673, 0.5465205311775208, 0.05625181272625923, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7089572548866272, 0.12511004507541656, 0.08669630438089371, 0.0792364850640297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9339975714683533, 0.013466393575072289, 0.00928713008761406, 0.00507207540795207, 0.03817704692482948, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7470325231552124, 0.0030789184384047985, 0.0006101431790739298, 0.009402818977832794, 0.23476918041706085, 0.005106179974973202, 0.0, 0.0, 0.0, 0.0], [0.21711143851280212, 0.003716376842930913, 0.00037448908551596105, 0.0019620254170149565, 0.018900232389569283, 0.009617134928703308, 0.7483181953430176, 0.0, 0.0, 0.0], [0.010075456462800503, 5.468959716381505e-05, 5.17756825502147e-06, 5.762913860962726e-05, 0.0005752856959588826, 0.0004235330270603299, 0.004707484506070614, 0.9841007590293884, 0.0, 0.0], [0.0014721885090693831, 9.766960283741355e-05, 9.390318155055866e-06, 9.01468301890418e-05, 0.00026504675042815506, 0.0001477079640608281, 0.0007441531051881611, 0.9970147013664246, 0.00015886487381067127, 0.0], [0.9506397247314453, 0.010028047487139702, 0.0004243685398250818, 0.012790095992386341, 0.006212451495230198, 0.0008045415161177516, 0.0008908100426197052, 0.0004145564162172377, 0.0002187698701163754, 0.01757662557065487]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9158000946044922, 0.0841999277472496, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9424960017204285, 0.02535107545554638, 0.032153017818927765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22060541808605194, 0.18997374176979065, 0.08500542491674423, 0.5044154524803162, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7531844973564148, 0.02070058509707451, 0.008920542895793915, 0.016695866361260414, 0.20049844682216644, 0.0, 0.0, 0.0, 0.0, 0.0], [0.759453296661377, 0.0056156679056584835, 0.008695651777088642, 0.014426307752728462, 0.16163751482963562, 0.05017174035310745, 0.0, 0.0, 0.0, 0.0], [0.2527230679988861, 0.0006535803549923003, 0.00037003192119300365, 0.00041730765951797366, 0.057080648839473724, 0.06757333129644394, 0.6211821436882019, 0.0, 0.0, 0.0], [0.6996693015098572, 0.00526623846963048, 0.003115275641903281, 0.001864676014520228, 0.019210346043109894, 0.022201303392648697, 0.16487717628479004, 0.08379579335451126, 0.0, 0.0], [0.01643717661499977, 0.001304203411564231, 0.00015219511988107115, 8.364384120795876e-05, 0.0027460975106805563, 0.005807426758110523, 0.02910688892006874, 0.054244525730609894, 0.8901176452636719, 0.0], [0.03737838938832283, 0.0008823095704428852, 0.00013810240488965064, 0.0003819032572209835, 0.0009168537217192352, 0.017434338107705116, 0.0524771511554718, 0.5634113550186157, 0.05003770440816879, 0.27694204449653625]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9822245836257935, 0.017775410786271095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9783667922019958, 0.004186260513961315, 0.01744689606130123, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8277915120124817, 0.0035995396319776773, 0.1268300712108612, 0.04177885130047798, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9593387246131897, 0.001320014358498156, 0.002763292985036969, 0.002305841539055109, 0.03427214175462723, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5380056500434875, 0.00011044789425795898, 0.001150083844549954, 0.002725756261497736, 0.45681822299957275, 0.0011898496886715293, 0.0, 0.0, 0.0, 0.0], [0.16147758066654205, 
0.001678255619481206, 0.004225697834044695, 0.012547606602311134, 0.4120558202266693, 0.030565770342946053, 0.37744930386543274, 0.0, 0.0, 0.0], [0.07655133306980133, 0.00011485892173368484, 0.0004792730906046927, 0.0037317569367587566, 0.9091346859931946, 0.005207230802625418, 0.003226343309506774, 0.0015543886693194509, 0.0, 0.0], [0.0006837816908955574, 6.692374881822616e-05, 3.2170661143027246e-05, 0.017242103815078735, 0.9703013896942139, 0.0009919245494529605, 0.00010187587758991867, 0.00012404048175085336, 0.01045528706163168, 0.0], [0.8681296706199646, 0.004244405776262283, 0.0034055972937494516, 0.0032342004124075174, 0.11890427023172379, 0.00032322408515028656, 1.7166490579256788e-05, 8.356601756531745e-05, 0.00016651467012707144, 0.0014914675848558545]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9673911333084106, 0.032608743757009506, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8945506811141968, 0.048047225922346115, 0.05740200728178024, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8226539492607117, 0.025171183049678802, 0.033602889627218246, 0.1185719221830368, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7488189339637756, 0.022310951724648476, 0.03220387548208237, 0.05049983412027359, 0.14616648852825165, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5947939157485962, 0.009725339710712433, 0.01194794476032257, 0.06678443402051926, 0.22137242555618286, 0.09537594765424728, 0.0, 0.0, 0.0, 0.0], [0.5493549704551697, 0.010730843059718609, 0.013811847195029259, 0.01375968661159277, 0.13386781513690948, 0.031593821942806244, 0.2468811273574829, 0.0, 0.0, 0.0], [0.44999176263809204, 0.0022518665064126253, 0.007128801662474871, 0.06941325962543488, 0.11436374485492706, 0.06527625769376755, 0.25339174270629883, 0.038182370364665985, 0.0, 0.0], [0.6273319125175476, 0.0019851899705827236, 0.014608433470129967, 0.053566914051771164, 0.10037831962108612, 0.05395424738526344, 0.09709113836288452, 0.020020073279738426, 0.031063806265592575, 0.0], [0.13732852041721344, 
0.005784862674772739, 0.011142567731440067, 0.3659982979297638, 0.03412118926644325, 0.191008523106575, 0.02493627928197384, 0.01782877929508686, 0.005097466055303812, 0.2067534178495407]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9590145349502563, 0.0409853532910347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13186156749725342, 0.7104970812797546, 0.15764127671718597, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1307007521390915, 0.4791290760040283, 0.2198515087366104, 0.1703186184167862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25735223293304443, 0.03605807572603226, 0.08834479749202728, 0.21978884935379028, 0.398455947637558, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014754761941730976, 0.016280202195048332, 0.010505245067179203, 0.26496851444244385, 0.6780229210853577, 0.015468388795852661, 0.0, 0.0, 0.0, 0.0], [0.0561433881521225, 0.00821017101407051, 0.013592599891126156, 0.04250938817858696, 0.20505541563034058, 0.637790322303772, 0.03669866546988487, 0.0, 0.0, 0.0], [0.02288638986647129, 0.0031705975998193026, 0.0010986417764797807, 0.1258203089237213, 0.13997967541217804, 0.6275703310966492, 0.004779829643666744, 0.07469423860311508, 0.0, 0.0], [0.04480466619133949, 0.007826470769941807, 0.0012622721260413527, 0.18829701840877533, 0.1579897105693817, 0.4087865948677063, 0.0030938636045902967, 0.17715193331241608, 0.010787548497319221, 0.0], [0.2647387683391571, 0.0023117128293961287, 0.5836825370788574, 0.022214042022824287, 0.05302866920828819, 0.05609899014234543, 0.0002153095556423068, 0.0012429821072146297, 0.012765316292643547, 0.0037017168942838907]]], [[[0.18620921671390533, 0.0449230894446373, 0.15743261575698853, 0.0027164025232195854, 0.000954743183683604, 0.10880818217992783, 0.004260051064193249, 0.4840531051158905, 0.010642877779901028, 0.0], [0.10068266838788986, 0.8361198902130127, 0.05278307944536209, 0.003077939385548234, 0.0006954235723242164, 0.001363753923214972, 0.00026539582177065313, 0.004202431067824364, 0.0008096573874354362, 0.0], 
[0.012129311449825764, 0.01155073568224907, 0.9600933194160461, 8.282387716462836e-05, 1.0725593710958492e-05, 0.0005505315493792295, 8.825069380691275e-05, 0.015057343989610672, 0.00043726651347242296, 0.0], [8.100323611870408e-05, 0.0004598332743626088, 0.004657193087041378, 0.000634010590147227, 0.00027469659107737243, 0.005632649641484022, 0.000647437758743763, 0.9867796301841736, 0.0008332319557666779, 0.0], [0.00010327257041353732, 8.895192149793729e-05, 0.0004001102061010897, 3.5898548958357424e-05, 8.903054549591616e-06, 0.002168947132304311, 0.0003314291825518012, 0.9968016743659973, 6.082480831537396e-05, 0.0], [0.0006819640402682126, 0.0025551444850862026, 0.029635878279805183, 0.0007182788685895503, 0.0009121407056227326, 0.9391846656799316, 0.0023257755674421787, 0.020892569795250893, 0.0030933902598917484, 0.0], [0.0006610184791497886, 0.004029686562716961, 0.03350083529949188, 0.0028945906087756157, 0.06891647726297379, 0.0361749529838562, 0.6805889010429382, 0.0015104033518582582, 0.17172299325466156, 0.0], [0.00011510718468343839, 0.00041600633994676173, 0.007651225198060274, 0.0003919293521903455, 0.048794399946928024, 0.12390702962875366, 0.005600529722869396, 0.0008058404200710356, 0.8123176097869873, 0.0], [0.0003188557457178831, 0.0017433647299185395, 0.0013032852439209819, 0.008202485740184784, 0.26753997802734375, 0.1699969321489334, 0.02015369012951851, 0.026912324130535126, 0.5038290619850159, 0.0], [0.020566454157233238, 0.12752646207809448, 0.13235142827033997, 8.515831723343581e-05, 0.0007726486655883491, 0.005525102838873863, 0.002064254367724061, 0.0015006973408162594, 0.7096077799797058, 0.0]], [[0.08830718696117401, 0.003260435536503792, 0.007942354306578636, 0.007197668310254812, 0.023230358958244324, 0.6884769797325134, 0.13524922728538513, 0.013760159723460674, 0.03257569298148155, 0.0], [0.01410764642059803, 0.011476421728730202, 0.655226469039917, 0.029443562030792236, 0.17404575645923615, 0.04738258570432663, 
0.035108331590890884, 0.004049936309456825, 0.02915901131927967, 0.0], [0.006112441886216402, 0.010383019223809242, 0.9739192724227905, 0.0017695348942652345, 0.0007649966282770038, 0.001380802714265883, 0.0003705607377924025, 0.00034036929719150066, 0.004958811681717634, 0.0], [0.025388794019818306, 0.006199578754603863, 0.10192698240280151, 0.0023500584065914154, 0.009979050606489182, 0.5388055443763733, 0.29305511713027954, 0.002850176068022847, 0.0194447822868824, 0.0], [0.0011180925648659468, 3.349311737110838e-05, 0.00020844468963332474, 0.00016400347521994263, 0.001158660277724266, 0.5398337244987488, 0.4514371454715729, 0.00012239665375091136, 0.005924074444919825, 0.0], [4.934398384648375e-05, 6.905893883413228e-07, 5.809057256556116e-06, 1.44853029269143e-05, 0.0013859024038538337, 0.62599116563797, 0.3719564974308014, 0.0002632574178278446, 0.00033293903106823564, 0.0], [1.8935834305011667e-05, 5.593590231001144e-06, 9.02482042874908e-06, 4.666295353672467e-05, 0.00140501803252846, 0.0024830379988998175, 0.9939435124397278, 0.00030495785176754, 0.0017833412857726216, 0.0], [0.00015082204481586814, 9.979225069400854e-06, 0.00013493606820702553, 0.0006857623811811209, 0.9507938623428345, 0.013522839173674583, 0.004887807182967663, 0.001293701701797545, 0.028520429506897926, 0.0], [0.00021830093464814126, 1.1190621080459096e-05, 0.0010014179861173034, 0.0016852812841534615, 0.9693949818611145, 0.003066261066123843, 0.002616706071421504, 0.006246546749025583, 0.015759343281388283, 0.0], [0.033513687551021576, 0.047761499881744385, 0.1371326446533203, 0.027179328724741936, 0.07905351370573044, 0.04665757715702057, 0.017991477623581886, 0.0258343443274498, 0.5848759412765503, 0.0]], [[0.3675236701965332, 0.22013956308364868, 0.3048599064350128, 0.045011524111032486, 0.013697491027414799, 0.012050136923789978, 0.009531261399388313, 0.0020223394967615604, 0.025163909420371056, 0.0], [0.013416368514299393, 0.7244334816932678, 0.22923606634140015, 
0.004823721945285797, 0.0007022434147074819, 0.0012150612892583013, 0.001360778696835041, 0.00021415007358882576, 0.024598030373454094, 0.0], [0.03640636429190636, 0.024720389395952225, 0.8944843411445618, 0.0018058173591271043, 0.00014742508938070387, 0.002046161564067006, 0.0012721297098323703, 0.0010774562833830714, 0.0380399152636528, 0.0], [0.032080236822366714, 0.02157183177769184, 0.017530914396047592, 0.21374234557151794, 0.5176447033882141, 0.021586988121271133, 0.06124785542488098, 0.004810539539903402, 0.10978466272354126, 0.0], [0.16469916701316833, 0.0144515885040164, 0.007452514488250017, 0.029052020981907845, 0.2643658220767975, 0.1970161497592926, 0.2818319797515869, 0.016781603917479515, 0.024349281564354897, 0.0], [0.025996195152401924, 0.005627068690955639, 0.007119623012840748, 0.004898787476122379, 0.5349600911140442, 0.05678911507129669, 0.3094601333141327, 0.008422048762440681, 0.04672713205218315, 0.0], [0.004280757624655962, 0.0006373892538249493, 9.946383943315595e-05, 0.00030879577388986945, 0.02805289998650551, 0.008433223702013493, 0.9252934455871582, 0.001439885818399489, 0.03145414590835571, 0.0], [0.04426492750644684, 0.0032368048559874296, 0.0014763016952201724, 0.0021763627883046865, 0.5636131763458252, 0.010265699587762356, 0.08146306872367859, 0.003517861943691969, 0.289985716342926, 0.0], [0.012160537764430046, 0.00020874926121905446, 0.0005602578166872263, 0.0007960868533700705, 0.9389106035232544, 0.005963308271020651, 0.005384649150073528, 0.0009963578777387738, 0.035019390285015106, 0.0], [0.006462599150836468, 0.006167746149003506, 0.00141435069963336, 0.00035615835804492235, 0.0002947094908449799, 0.002378113567829132, 0.011835698038339615, 0.0024426754098385572, 0.968647837638855, 0.0]], [[0.013161101378500462, 0.01350532379001379, 0.39494189620018005, 0.007352527230978012, 0.12711142003536224, 0.14605116844177246, 0.03487401455640793, 0.15623201429843903, 0.10677067190408707, 0.0], [0.021876059472560883, 
0.4906902313232422, 0.4596463143825531, 0.004091671667993069, 0.004464378114789724, 0.001156727666966617, 0.000353646173607558, 0.000146497564855963, 0.017574656754732132, 0.0], [0.005734701175242662, 0.026843877509236336, 0.9321272969245911, 0.00021884289162699133, 0.00045866103027947247, 0.0010309598874300718, 0.00017261962057091296, 0.003054215107113123, 0.030358724296092987, 0.0], [0.0482722632586956, 0.14050070941448212, 0.4546079635620117, 0.0072937230579555035, 0.023873258382081985, 0.09857403486967087, 0.0516686774790287, 0.11766187101602554, 0.05754747614264488, 0.0], [0.0020078516099601984, 0.002228439087048173, 0.111594557762146, 0.0033910104539245367, 0.08423032611608505, 0.17691271007061005, 0.14758752286434174, 0.4346924424171448, 0.037355244159698486, 0.0], [0.0008274781284853816, 0.0016531302826479077, 0.047970183193683624, 0.0006053023971617222, 0.22220103442668915, 0.6234129071235657, 0.05364101752638817, 0.012585645541548729, 0.03710317984223366, 0.0], [2.7583497285377234e-05, 1.1631378583842888e-05, 4.4259006244828925e-05, 0.0006730516324751079, 0.599366307258606, 0.006597205530852079, 0.3886081576347351, 0.0003169252013321966, 0.004354946780949831, 0.0], [2.752073669398669e-06, 2.0648456029448425e-06, 8.536147106497083e-06, 6.34281532256864e-05, 0.9992840886116028, 0.00028667543665505946, 7.951273437356576e-05, 3.5721727726922836e-06, 0.00026920961681753397, 0.0], [3.3996084312093444e-06, 2.1497796751646092e-06, 7.304265182028757e-06, 0.00018760550301522017, 0.99969482421875, 2.4790026145637967e-05, 3.4293629141757265e-05, 6.942725121916737e-06, 3.892222957802005e-05, 0.0], [0.0005689842510037124, 0.002939490834251046, 0.019829533994197845, 0.0003717679646797478, 0.01646142266690731, 0.011912180110812187, 0.001234701368957758, 0.0013870754046365619, 0.945294976234436, 0.0]], [[0.00632825493812561, 0.011520092375576496, 0.08263711631298065, 0.006356080062687397, 0.022936103865504265, 0.03108564019203186, 0.013897407799959183, 0.697504997253418, 
0.12773430347442627, 0.0], [0.008715116418898106, 0.015272715128958225, 0.10463730990886688, 0.08011683076620102, 0.13045108318328857, 0.05373600497841835, 0.015578814782202244, 0.4212273955345154, 0.1702648103237152, 0.0], [0.004959889687597752, 0.007777809165418148, 0.14492008090019226, 0.02459821291267872, 0.014704479835927486, 0.016136664897203445, 0.008129375986754894, 0.7319321036338806, 0.0468413271009922, 0.0], [0.005315575283020735, 0.0021190166007727385, 0.007080279756337404, 0.006970370654016733, 0.010002117604017258, 0.007610250264406204, 0.004703941754996777, 0.8570073246955872, 0.09919113665819168, 0.0], [0.0016317280242219567, 0.0005414763581939042, 0.004523266106843948, 0.0019645043648779392, 0.010821727104485035, 0.008883371017873287, 0.00927714817225933, 0.920802652835846, 0.041554201394319534, 0.0], [0.002020488725975156, 0.0007793906843289733, 0.022791940718889236, 0.005821499973535538, 0.1932065784931183, 0.30031588673591614, 0.08197023719549179, 0.12508654594421387, 0.2680076062679291, 0.0], [0.007396090775728226, 0.0032474161125719547, 0.00692824088037014, 0.007240207865834236, 0.42384257912635803, 0.04473983123898506, 0.013007782399654388, 0.007779541425406933, 0.4858182966709137, 0.0], [0.0026900237426161766, 0.0007204422145150602, 0.005861051380634308, 0.003422616282477975, 0.46744993329048157, 0.10402297228574753, 0.05837857723236084, 0.0177029799669981, 0.3397515118122101, 0.0], [0.005906206555664539, 0.002057044068351388, 0.0031123505905270576, 0.008901549503207207, 0.43650564551353455, 0.08504725992679596, 0.0923796221613884, 0.009556618519127369, 0.3565336763858795, 0.0], [0.013360978104174137, 0.04520300775766373, 0.09048072248697281, 0.012179902754724026, 0.030064363032579422, 0.023480970412492752, 0.008669134229421616, 0.03746046498417854, 0.7391002178192139, 0.0]], [[0.023652182891964912, 0.008639940991997719, 0.08203616738319397, 0.035750582814216614, 0.050224509090185165, 0.3533262312412262, 0.03081362321972847, 
0.28302860260009766, 0.1325281411409378, 0.0], [0.016670020297169685, 0.1283574253320694, 0.836423397064209, 0.0042742472141981125, 0.0022883012425154448, 0.00297459471039474, 0.00022807312780059874, 0.0012588471872732043, 0.007524838205426931, 0.0], [0.031559381633996964, 0.02045642025768757, 0.8176267743110657, 0.006169404834508896, 0.0014412011951208115, 0.0069603933952748775, 0.0010916722239926457, 0.011522608809173107, 0.10317197442054749, 0.0], [0.004598122555762529, 0.004610949195921421, 0.01865001954138279, 0.020574036985635757, 0.0137012405321002, 0.7973257303237915, 0.01646837778389454, 0.023596635088324547, 0.1004747673869133, 0.0], [0.0005213705007918179, 0.00018707667186390609, 0.0016978917410597205, 0.019619440659880638, 0.009308884851634502, 0.8590161800384521, 0.024511896073818207, 0.06970686465501785, 0.015430280938744545, 0.0], [0.0001481063081882894, 2.072651477647014e-05, 0.00035672096419148147, 0.00033358228392899036, 0.00040588833508081734, 0.9861487746238708, 0.00651955883949995, 0.00443643843755126, 0.0016300288261845708, 0.0], [0.0010996124474331737, 0.0011850595474243164, 0.0075045316480100155, 0.004539311397820711, 0.05570072680711746, 0.18870605528354645, 0.23963898420333862, 0.013960372656583786, 0.487665593624115, 0.0], [0.0003884119214490056, 0.0004658032557927072, 0.028157439082860947, 0.0002352961164433509, 0.1278570294380188, 0.08260466903448105, 0.02582997828722, 0.022790132090449333, 0.7116712927818298, 0.0], [0.0015414542285725474, 0.0007310948567464948, 0.010464987717568874, 0.0012846259633079171, 0.45206302404403687, 0.029316790401935577, 0.04706822335720062, 0.018986493349075317, 0.4385431706905365, 0.0], [0.0005072542116977274, 0.0011837932979688048, 0.01220926083624363, 8.532252832083032e-05, 0.0018606879748404026, 0.010199862532317638, 0.0016309961210936308, 0.010775143280625343, 0.9615475535392761, 0.0]], [[0.29744189977645874, 0.04770943149924278, 0.09888078272342682, 0.19768767058849335, 0.048243775963783264, 
0.12058595567941666, 0.05976371467113495, 0.03847452625632286, 0.09121233224868774, 0.0], [0.04126456007361412, 0.6604095697402954, 0.028894882649183273, 0.20104490220546722, 0.0014044500421732664, 0.0009343607816845179, 0.00244489056058228, 0.007453228812664747, 0.05614929273724556, 0.0], [0.008357543498277664, 0.0022072584833949804, 0.9876156449317932, 8.841200906317681e-05, 1.4883004041621462e-05, 0.00011741811613319442, 2.7020510970032774e-05, 0.00016062626673374325, 0.001411277218721807, 0.0], [0.06216944754123688, 0.48559242486953735, 0.042546145617961884, 0.034007471054792404, 0.047574639320373535, 0.12490913271903992, 0.07922931015491486, 0.013364763930439949, 0.11060672253370285, 0.0], [0.05222959443926811, 0.025416702032089233, 0.02865077182650566, 0.17457211017608643, 0.03144511207938194, 0.3907364010810852, 0.19607771933078766, 0.05274118855595589, 0.04813018813729286, 0.0], [0.0037726862356066704, 0.0031579534988850355, 0.0029440780635923147, 0.0017320584738627076, 0.060473062098026276, 0.761774480342865, 0.1523173600435257, 0.0058823637664318085, 0.007945872843265533, 0.0], [0.0020738786552101374, 0.0012752892216667533, 0.0004058163322042674, 0.020963717252016068, 0.39340031147003174, 0.012434415519237518, 0.4783190190792084, 0.011497312225401402, 0.0796302929520607, 0.0], [5.31752230017446e-05, 1.4492364243778866e-05, 7.312332309084013e-05, 0.0023682843893766403, 0.9866323471069336, 0.0009243910317309201, 0.0011850211303681135, 0.0017622504383325577, 0.0069872229360044, 0.0], [4.074166645295918e-05, 1.823456841520965e-05, 0.0001418270985595882, 0.007263784296810627, 0.9604514241218567, 0.0001852070417953655, 0.00034164052340202034, 0.0018497714772820473, 0.029707150533795357, 0.0], [0.0133396340534091, 0.03136875480413437, 0.6319980621337891, 0.0033722908701747656, 0.04728742688894272, 0.03541773557662964, 0.009523973800241947, 0.03100484237074852, 0.1966874897480011, 0.0]], [[0.03367111459374428, 0.018932543694972992, 0.09506545215845108, 
0.04718795791268349, 0.028798582032322884, 0.33658939599990845, 0.02586139366030693, 0.29842811822891235, 0.11546547710895538, 0.0], [0.006203038617968559, 0.0906001627445221, 0.6977949738502502, 0.018352899700403214, 0.06787873804569244, 0.04403599724173546, 0.001631368650123477, 0.024296771734952927, 0.049206044524908066, 0.0], [0.006243667099624872, 0.010453532449901104, 0.7879610657691956, 0.004093538969755173, 0.0008473669877275825, 0.027760563418269157, 0.0003080451278947294, 0.14831961691379547, 0.014012438245117664, 0.0], [0.004387176129966974, 0.023410169407725334, 0.17247918248176575, 0.03958609700202942, 0.023799436166882515, 0.43659475445747375, 0.014754846692085266, 0.2318120151758194, 0.05317622795701027, 0.0], [0.0020952164195477962, 0.0024118656292557716, 0.028229335322976112, 0.007075420115143061, 0.019164882600307465, 0.5397294163703918, 0.034580815583467484, 0.3465326428413391, 0.020180128514766693, 0.0], [0.00020744462381117046, 0.00036016973899677396, 0.004934145137667656, 0.0004664760490413755, 0.008187839761376381, 0.9661812782287598, 0.009987047873437405, 0.003882928751409054, 0.005792597308754921, 0.0], [3.4081476769642904e-05, 1.7181657312903553e-05, 5.4824478866066784e-05, 0.00045897584641352296, 0.0043338024988770485, 0.001544477418065071, 0.9909620881080627, 2.356152981519699e-05, 0.0025708049070090055, 0.0], [0.0001047314508468844, 0.0001599654060555622, 0.001310097286477685, 0.001540280063636601, 0.833267331123352, 0.044754061847925186, 0.0028599577490240335, 0.0006454077665694058, 0.11535807698965073, 0.0], [8.819431968731806e-05, 6.364465662045404e-05, 0.00022057128080632538, 0.001112746773287654, 0.9560981392860413, 0.003599100047722459, 0.0002217600413132459, 0.0006697923527099192, 0.03792598471045494, 0.0], [0.0018130787648260593, 0.022020958364009857, 0.12822051346302032, 0.0005810249131172895, 0.03168048337101936, 0.014293116517364979, 0.002500524278730154, 0.0212943647056818, 0.7775959372520447, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6252409815788269, 0.3747589886188507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8520486354827881, 0.010580658912658691, 0.13737063109874725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05910082906484604, 0.011589597910642624, 0.877491295337677, 0.051818281412124634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3626183867454529, 0.026959313079714775, 0.07612177729606628, 0.13077552616596222, 0.4035249352455139, 0.0, 0.0, 0.0, 0.0, 0.0], [0.21979263424873352, 0.001410112832672894, 0.007092535495758057, 0.13166557252407074, 0.626970648765564, 0.013068560510873795, 0.0, 0.0, 0.0, 0.0], [0.08148042857646942, 0.001490423921495676, 0.004908325150609016, 0.01383854728192091, 0.7959722876548767, 0.05201547220349312, 0.05029459297657013, 0.0, 0.0, 0.0], [0.03934427723288536, 5.908778257435188e-05, 0.00014962907880544662, 0.005592166446149349, 0.7025003433227539, 0.1675100177526474, 0.03920353576540947, 0.04564077779650688, 0.0, 0.0], [0.4660189151763916, 0.00034756408422254026, 9.701005183160305e-05, 0.008154522627592087, 0.08121690154075623, 0.15592943131923676, 0.11426379531621933, 0.17044323682785034, 0.0035288764629513025, 0.0], [0.3707294762134552, 0.0020887483842670918, 0.23984688520431519, 0.07748916745185852, 0.18109895288944244, 0.03584783151745796, 0.005205830093473196, 0.005058187525719404, 0.0050886403769254684, 0.0775463655591011]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15256483852863312, 0.8474349975585938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08618302643299103, 0.30268052220344543, 0.6111364364624023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6251113414764404, 0.14608541131019592, 0.21724094450473785, 0.011562197469174862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.31851068139076233, 0.11805614084005356, 0.02926168404519558, 0.0854775682091713, 0.44869405031204224, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23099647462368011, 0.015003926120698452, 0.0028121687937527895, 0.025386620312929153, 0.5829272270202637, 0.14287345111370087, 0.0, 0.0, 
0.0, 0.0], [0.2648485600948334, 0.01456066407263279, 0.008421574719250202, 0.01653379574418068, 0.25845009088516235, 0.35933130979537964, 0.07785411924123764, 0.0, 0.0, 0.0], [0.21031156182289124, 0.00652333116158843, 0.005756322760134935, 0.019128819927573204, 0.2526819407939911, 0.49096593260765076, 0.008809886872768402, 0.00582215515896678, 0.0, 0.0], [0.11555754393339157, 0.00475481478497386, 0.0013921409845352173, 0.045808907598257065, 0.29882168769836426, 0.3024459183216095, 0.0483231395483017, 0.18265680968761444, 0.0002390409354120493, 0.0], [0.8451279401779175, 0.021679740399122238, 0.035543736070394516, 0.005811640061438084, 0.04445958510041237, 0.018052000552415848, 0.0015424924204126, 0.013668404892086983, 0.012673787772655487, 0.0014405279653146863]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9927853345870972, 0.007214863318949938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011021426878869534, 0.007158290129154921, 0.9818204641342163, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.007071706000715494, 0.026167649775743484, 0.19316613674163818, 0.773594319820404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.320003479719162, 0.03976304829120636, 0.22334550321102142, 0.24320250749588013, 0.17368540167808533, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10932182520627975, 0.001151762087829411, 0.007792286574840546, 0.18981949985027313, 0.6517421007156372, 0.04017229378223419, 0.0, 0.0, 0.0, 0.0], [0.02538878843188286, 0.005211540497839451, 0.03069700486958027, 0.13252338767051697, 0.4279623329639435, 0.0899164006114006, 0.28830063343048096, 0.0, 0.0, 0.0], [0.010537173599004745, 0.0007831656257621944, 0.0007035965682007372, 0.015162549912929535, 0.9050821661949158, 0.05248205363750458, 0.01132790744304657, 0.00392116466537118, 0.0, 0.0], [0.005222301464527845, 0.003575690556317568, 0.0029950442258268595, 0.00018454395467415452, 0.0012630765559151769, 0.01364975143224001, 0.09376595914363861, 0.853415846824646, 0.02592780999839306, 0.0], [0.14979584515094757, 
0.0004723063320852816, 0.4970340430736542, 0.03214645013213158, 0.022075939923524857, 0.006538126152008772, 0.0013381451135501266, 0.0030305178370326757, 0.0008045822032727301, 0.28676414489746094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9691458940505981, 0.03085414692759514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9338735938072205, 0.02144204080104828, 0.04468445107340813, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4091326594352722, 0.1788463294506073, 0.3530478775501251, 0.058973249047994614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8083640336990356, 0.0245783980935812, 0.02959858626127243, 0.02002020739018917, 0.11743883788585663, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6256738901138306, 0.03313886746764183, 0.03255102410912514, 0.015011090785264969, 0.27659764885902405, 0.017027597874403, 0.0, 0.0, 0.0, 0.0], [0.2970131039619446, 0.01776941865682602, 0.015323061496019363, 0.014444534666836262, 0.2387886643409729, 0.36828577518463135, 0.048375438898801804, 0.0, 0.0, 0.0], [0.16347570717334747, 0.01386126596480608, 0.012116431258618832, 0.006670618429780006, 0.5951986312866211, 0.1577492356300354, 0.024585027247667313, 0.02634291537106037, 0.0, 0.0], [0.1568753868341446, 0.002166055142879486, 0.0014692704426124692, 0.009539359249174595, 0.7249224781990051, 0.0696585550904274, 0.02269914373755455, 0.010646837763488293, 0.0020231890957802534, 0.0], [0.6687246561050415, 0.003988182172179222, 0.00992897991091013, 0.00877397134900093, 0.07160260528326035, 0.14080072939395905, 0.01739262230694294, 0.04941429942846298, 0.01782085746526718, 0.011553076095879078]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.497504860162735, 0.502495288848877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028444888070225716, 0.01678420603275299, 0.9547709822654724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02853180095553398, 0.022399114444851875, 0.7835201025009155, 0.1655489057302475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.023048963397741318, 0.055082567036151886, 0.3371332883834839, 
0.25099456310272217, 0.33374062180519104, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013693265616893768, 0.057373203337192535, 0.02566814236342907, 0.11711565405130386, 0.13761301338672638, 0.6485366225242615, 0.0, 0.0, 0.0, 0.0], [0.5831283926963806, 0.0857725590467453, 0.06227085366845131, 0.03169894590973854, 0.06183577701449394, 0.01752074435353279, 0.15777261555194855, 0.0, 0.0, 0.0], [0.0033312023151665926, 0.003545752028003335, 0.0018331086030229926, 0.05265560373663902, 0.047756411135196686, 0.045255228877067566, 0.20667387545108795, 0.6389486193656921, 0.0, 0.0], [0.02047032117843628, 0.03542931377887726, 0.01270933635532856, 0.46998995542526245, 0.035482652485370636, 0.015606570988893509, 0.1128709465265274, 0.03180817514657974, 0.26563259959220886, 0.0], [0.027955254539847374, 0.024354776367545128, 0.4609973132610321, 0.07958999276161194, 0.34062448143959045, 0.0068156360648572445, 0.000798556546214968, 0.0009541919571347535, 0.00023223790049087256, 0.05767740309238434]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.909331738948822, 0.09066825360059738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3865880072116852, 0.017979737371206284, 0.5954321622848511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5815560817718506, 0.15706834197044373, 0.052335821092128754, 0.2090395838022232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7268465161323547, 0.0363004133105278, 0.07873083651065826, 0.06576839834451675, 0.09235385805368423, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6550694704055786, 0.019533857703208923, 0.042362816631793976, 0.07321250438690186, 0.06519921869039536, 0.14462217688560486, 0.0, 0.0, 0.0, 0.0], [0.48597243428230286, 0.05253118649125099, 0.06572883576154709, 0.06831242144107819, 0.06681334227323532, 0.09225586801767349, 0.168385848402977, 0.0, 0.0, 0.0], [0.2283225953578949, 0.01085133571177721, 0.0076954541727900505, 0.03403906524181366, 0.05505141243338585, 0.11318682134151459, 0.23008716106414795, 0.3207661509513855, 0.0, 0.0], [0.31019407510757446, 0.01576145552098751, 
0.006604246329516172, 0.1025082990527153, 0.11805430799722672, 0.0999068170785904, 0.17944715917110443, 0.09494999051094055, 0.07257375121116638, 0.0], [0.028495613485574722, 0.00728303287178278, 0.028978589922189713, 0.21746259927749634, 0.0312367994338274, 0.01134485937654972, 0.002138715935871005, 0.0005697175511159003, 0.00012198994954815134, 0.6723678112030029]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9973775148391724, 0.0026223897002637386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8086934685707092, 0.08078567683696747, 0.11052089184522629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14967022836208344, 0.05171789228916168, 0.3914002478122711, 0.40721163153648376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.40780311822891235, 0.04434635117650032, 0.05232110992074013, 0.3448564112186432, 0.15067294239997864, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3417491614818573, 0.023165758699178696, 0.008621969260275364, 0.03819064050912857, 0.566249430179596, 0.022023199126124382, 0.0, 0.0, 0.0, 0.0], [0.13283461332321167, 0.0027981544844806194, 0.001892031985335052, 0.057958006858825684, 0.4807162284851074, 0.22431829571723938, 0.09948258846998215, 0.0, 0.0, 0.0], [0.5702553391456604, 0.005225116387009621, 0.0014312443090602756, 0.028526127338409424, 0.15899939835071564, 0.05284468084573746, 0.022491520270705223, 0.16022635996341705, 0.0, 0.0], [0.005209033377468586, 6.901475717313588e-05, 5.760595013271086e-05, 0.006149875931441784, 0.006613760255277157, 0.010193211026489735, 0.013639912940561771, 0.9578513503074646, 0.00021631908020935953, 0.0], [0.5839820504188538, 0.007275882177054882, 0.03890826180577278, 0.2169828861951828, 0.02285575494170189, 0.0033320344518870115, 0.0027764069382101297, 0.032896872609853745, 0.038299210369586945, 0.05269058048725128]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.050016310065984726, 0.9499835968017578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1501082479953766, 0.3363426625728607, 0.5135491490364075, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0], [0.3278971016407013, 0.3615517318248749, 0.08450257778167725, 0.22604861855506897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25667694211006165, 0.40780505537986755, 0.15422186255455017, 0.10097997635602951, 0.08031607419252396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06381947547197342, 0.026328660547733307, 0.009785240516066551, 0.001955200219526887, 0.6504424810409546, 0.24766895174980164, 0.0, 0.0, 0.0, 0.0], [0.20312389731407166, 0.04875075817108154, 0.01619674637913704, 0.028123438358306885, 0.08579143136739731, 0.20417368412017822, 0.41383999586105347, 0.0, 0.0, 0.0], [0.0028631098102778196, 0.00043423930765129626, 0.00021322182146832347, 0.0004598038794938475, 0.009248310700058937, 0.010348621755838394, 0.14874719083309174, 0.8276853561401367, 0.0, 0.0], [0.0018436884274706244, 0.00014077842934057117, 0.00019285999587737024, 0.0006260477821342647, 0.010675687342882156, 0.007219970691949129, 0.05410425364971161, 0.9211149215698242, 0.0040815602988004684, 0.0], [0.0996592566370964, 0.00012745348794851452, 0.0005518114776350558, 0.00026576913660392165, 0.00016320105351042002, 9.30886235437356e-05, 0.00024295282491948456, 0.000212193961488083, 0.0026789114344865084, 0.8960053324699402]]], [[[0.11433855444192886, 0.04686390608549118, 0.05050662159919739, 0.01692698895931244, 0.014815520495176315, 0.0768260732293129, 0.017432983964681625, 0.6101956367492676, 0.052093639969825745, 0.0], [0.38972416520118713, 0.22944767773151398, 0.07904881238937378, 0.052096493542194366, 0.007011168636381626, 0.006784902419894934, 0.0014207185013219714, 0.13426420092582703, 0.10020176321268082, 0.0], [0.11186019331216812, 0.017331605777144432, 0.04338392615318298, 0.012996343895792961, 0.0037596735637634993, 0.04186789691448212, 0.010188347660005093, 0.40130615234375, 0.35730576515197754, 0.0], [0.10591340810060501, 0.20223784446716309, 0.2886512279510498, 0.06966178864240646, 0.009836114011704922, 0.0229511596262455, 0.008547846227884293, 0.14636646211147308, 0.14583423733711243, 
0.0], [0.07172418385744095, 0.2936038672924042, 0.03526498004794121, 0.13891401886940002, 0.06139945611357689, 0.03925776481628418, 0.05349786579608917, 0.1478489190340042, 0.15848881006240845, 0.0], [0.06896000355482101, 0.07670743763446808, 0.04746192321181297, 0.0077508557587862015, 0.0064402432180941105, 0.0690828412771225, 0.07239065319299698, 0.3534849286079407, 0.2977212965488434, 0.0], [0.07854402810335159, 0.05315924435853958, 0.006970829796046019, 0.01197806466370821, 0.15678070485591888, 0.059328123927116394, 0.0844779834151268, 0.058106619864702225, 0.4906543493270874, 0.0], [0.03681005910038948, 0.0140389958396554, 0.007565703243017197, 0.004180380143225193, 0.03550711274147034, 0.08768045157194138, 0.04672156274318695, 0.05331620201468468, 0.7141793966293335, 0.0], [0.011276278644800186, 0.0043571279384195805, 0.0015699869254603982, 0.009309794753789902, 0.5466312766075134, 0.06633520126342773, 0.012565904296934605, 0.036605555564165115, 0.3113488256931305, 0.0], [0.06600549072027206, 0.025541657581925392, 0.15132947266101837, 0.052793603390455246, 0.07684693485498428, 0.05682613328099251, 0.01590665802359581, 0.05306769907474518, 0.501682460308075, 0.0]], [[0.02032626047730446, 0.020203230902552605, 0.6020291447639465, 0.030009541660547256, 0.018654465675354004, 0.18802858889102936, 0.05143868178129196, 0.02303317002952099, 0.04627683013677597, 0.0], [0.0019176346249878407, 0.016540443524718285, 0.9535443782806396, 0.023291945457458496, 0.00010982116509694606, 0.0009689715225249529, 0.00013850984396412969, 0.00020587266772054136, 0.0032822038047015667, 0.0], [0.0021921033039689064, 0.01942657120525837, 0.9639623165130615, 0.007353355176746845, 5.936318120802753e-05, 0.0048356493934988976, 6.20722203166224e-05, 0.00038220148417167366, 0.0017266407376155257, 0.0], [0.009460263885557652, 0.051493529230356216, 0.3948231041431427, 0.009338779374957085, 0.006950095761567354, 0.48816555738449097, 0.01867900975048542, 0.0028053568676114082, 
0.018284155055880547, 0.0], [0.003424277761951089, 0.009652957320213318, 0.10005363076925278, 0.026136387139558792, 0.09182560443878174, 0.6324647068977356, 0.07658552378416061, 0.005459657870233059, 0.05439731851220131, 0.0], [0.0026005429681390524, 0.0029943317640572786, 0.26352012157440186, 0.02426978573203087, 0.05801504850387573, 0.49633610248565674, 0.11849544942378998, 0.009708826430141926, 0.02405967190861702, 0.0], [0.0027453943621367216, 0.0021119171287864447, 0.0030521987937390804, 0.09308812767267227, 0.28145554661750793, 0.015254919417202473, 0.530491828918457, 0.007408978417515755, 0.06439103186130524, 0.0], [0.0012973180273547769, 0.002199073787778616, 0.0031004296615719795, 0.024488963186740875, 0.8535729050636292, 0.016068320721387863, 0.029179612174630165, 0.012250186875462532, 0.05784311145544052, 0.0], [0.0010010729311034083, 0.0008253253763541579, 0.0028483583591878414, 0.028342707082629204, 0.860925018787384, 0.0038871155120432377, 0.006998666562139988, 0.01413769368082285, 0.08103384077548981, 0.0], [0.007383578456938267, 0.056256651878356934, 0.5807297825813293, 0.01667044125497341, 0.03810223564505577, 0.07880110293626785, 0.009197888895869255, 0.12926581501960754, 0.08359251171350479, 0.0]], [[0.023356424644589424, 0.012650059536099434, 0.05017145350575447, 0.05590398982167244, 0.05159280076622963, 0.01602507382631302, 0.014807065948843956, 0.654244601726532, 0.12124844640493393, 0.0], [0.030835414305329323, 0.04180247709155083, 0.029645785689353943, 0.20071062445640564, 0.010328685864806175, 0.03208288922905922, 0.026780622079968452, 0.457701712846756, 0.17011170089244843, 0.0], [0.02856343612074852, 0.03459611535072327, 0.15441730618476868, 0.04662194848060608, 0.0013040672056376934, 0.017847269773483276, 0.02464178577065468, 0.5969575643539429, 0.09505032747983932, 0.0], [0.03350958973169327, 0.02514287829399109, 0.027676144614815712, 0.11052078753709793, 0.15496152639389038, 0.08862635493278503, 0.027723105624318123, 
0.24766287207603455, 0.28417670726776123, 0.0], [0.014355039224028587, 0.005383878946304321, 0.002517768880352378, 0.09422861039638519, 0.06622537225484848, 0.046315327286720276, 0.08473969250917435, 0.4999735355377197, 0.18626095354557037, 0.0], [0.002460801973938942, 0.0016284199664369226, 0.005857668351382017, 0.006880565080791712, 0.7626023292541504, 0.025456121191382408, 0.021016357466578484, 0.06090177595615387, 0.11319592595100403, 0.0], [0.007633878383785486, 0.002682786202058196, 0.0008938225219026208, 0.006808742880821228, 0.17231638729572296, 0.049100711941719055, 0.32851701974868774, 0.0061601921916007996, 0.4258863925933838, 0.0], [0.003303236560896039, 0.0015338786179199815, 0.0017581325955688953, 0.0052335225045681, 0.24177710711956024, 0.09136255830526352, 0.06603478640317917, 0.0047843558713793755, 0.5842124223709106, 0.0], [0.011694137938320637, 0.0015430846251547337, 0.00043408613419160247, 0.005433904007077217, 0.03723231703042984, 0.1666216105222702, 0.04878358170390129, 0.024785596877336502, 0.7034717798233032, 0.0], [0.028515880927443504, 0.0183264147490263, 0.011487613432109356, 0.03205259144306183, 0.06179385632276535, 0.041277043521404266, 0.014015565626323223, 0.06198226660490036, 0.7305486798286438, 0.0]], [[0.10071786493062973, 0.017111245542764664, 0.07246935367584229, 0.01480931881815195, 0.14864948391914368, 0.20273517072200775, 0.054981958121061325, 0.25890761613845825, 0.12961813807487488, 0.0], [0.049787674099206924, 0.02108882926404476, 0.20989678800106049, 0.006962155923247337, 0.21569682657718658, 0.1622857302427292, 0.016771212220191956, 0.2403237521648407, 0.07718709856271744, 0.0], [0.024618864059448242, 0.010488898493349552, 0.4834355115890503, 0.015693388879299164, 0.07393413037061691, 0.055557433515787125, 0.007495412603020668, 0.27800077199935913, 0.05077548325061798, 0.0], [0.006324393209069967, 0.0006586945382878184, 0.02188086323440075, 0.003439908614382148, 0.055277179926633835, 0.5423230528831482, 
0.1656835526227951, 0.12264314293861389, 0.08176910877227783, 0.0], [0.008246216922998428, 0.000647101376671344, 0.018551276996731758, 0.0031310885678976774, 0.04379039630293846, 0.34376823902130127, 0.2999532222747803, 0.13205647468566895, 0.1498558670282364, 0.0], [0.001800144906155765, 0.00032634654780849814, 0.02560480497777462, 0.0014933178899809718, 0.04328969866037369, 0.48067817091941833, 0.22867664694786072, 0.008819987997412682, 0.20931090414524078, 0.0], [0.0023069612216204405, 0.0018136217258870602, 0.006447605788707733, 0.005140945315361023, 0.046570103615522385, 0.045606330037117004, 0.3236173987388611, 0.014286459423601627, 0.5542104840278625, 0.0], [0.0025225167628377676, 0.000774701707996428, 0.0168449804186821, 0.0014132045907899737, 0.1692919135093689, 0.21547472476959229, 0.19468647241592407, 0.00621472392231226, 0.392776757478714, 0.0], [0.006893941201269627, 0.0026040272787213326, 0.036687299609184265, 0.0016275923699140549, 0.13132861256599426, 0.15552441775798798, 0.23651301860809326, 0.023025648668408394, 0.4057953953742981, 0.0], [0.004257934633642435, 0.008543262258172035, 0.05716743320226669, 0.0024442216381430626, 0.027526315301656723, 0.08828678727149963, 0.025276461616158485, 0.2843557894229889, 0.5021417737007141, 0.0]], [[0.2613511085510254, 0.024888625368475914, 0.11462423205375671, 0.021279124543070793, 0.1065509021282196, 0.23139707744121552, 0.07117345929145813, 0.09822205454111099, 0.07051338255405426, 0.0], [0.08973123878240585, 0.07256940752267838, 0.3644520342350006, 0.09313907474279404, 0.10501276701688766, 0.026235496625304222, 0.035534195601940155, 0.05646198242902756, 0.15686386823654175, 0.0], [0.12105944007635117, 0.03531542792916298, 0.18099160492420197, 0.04576702043414116, 0.03264385089278221, 0.04934798926115036, 0.0072426870465278625, 0.2739674150943756, 0.25366437435150146, 0.0], [0.049101557582616806, 0.02436317875981331, 0.1119280532002449, 0.019082490354776382, 0.23333144187927246, 0.12024182081222534, 
0.09606382250785828, 0.03866123780608177, 0.3072265088558197, 0.0], [0.011761177331209183, 0.004259902983903885, 0.019396282732486725, 0.010304590687155724, 0.5410462021827698, 0.1548439860343933, 0.1577453315258026, 0.022628072649240494, 0.07801424711942673, 0.0], [0.011012338101863861, 0.006456742994487286, 0.03514476120471954, 0.01111147552728653, 0.3646441400051117, 0.06045660004019737, 0.22725869715213776, 0.030072104185819626, 0.2538430392742157, 0.0], [0.0009554855059832335, 0.0010365764610469341, 0.000539954868145287, 0.013481645844876766, 0.6702913641929626, 0.013201623223721981, 0.06565960496664047, 0.008186675608158112, 0.2266470193862915, 0.0], [0.007978711277246475, 0.0019918852485716343, 0.0007363414042629302, 0.010062554851174355, 0.10717969387769699, 0.01258536335080862, 0.08278501033782959, 0.02946571074426174, 0.7472147941589355, 0.0], [0.014369996264576912, 0.00412968173623085, 0.002898097038269043, 0.0381503589451313, 0.28382056951522827, 0.03412872180342674, 0.2624143660068512, 0.04523473232984543, 0.3148534893989563, 0.0], [0.0025127469561994076, 0.0030011499766260386, 0.0036209137178957462, 0.0006047216593287885, 0.01094596553593874, 0.0023283734917640686, 0.003409643191844225, 0.009625249542295933, 0.9639512896537781, 0.0]], [[0.023785226047039032, 0.024275904521346092, 0.6168470978736877, 0.01581703871488571, 0.026939542964100838, 0.1783975064754486, 0.04853774979710579, 0.02762567065656185, 0.03777410835027695, 0.0], [0.006065902300179005, 0.04932599142193794, 0.8176359534263611, 0.018976736813783646, 0.008159944787621498, 0.011068272404372692, 0.010428683832287788, 0.014124251902103424, 0.06421414017677307, 0.0], [0.0037236923817545176, 0.007844064384698868, 0.9502744674682617, 0.0048003061674535275, 0.00022506865207105875, 0.004834793973714113, 0.0015490480000153184, 0.0026021157391369343, 0.024146683514118195, 0.0], [0.0007564057596027851, 0.0017460802337154746, 0.15768493711948395, 0.004074132069945335, 0.015430302359163761, 
0.7368869781494141, 0.028010869398713112, 0.013945921324193478, 0.04146439954638481, 0.0], [0.0008445779676549137, 0.0015138997696340084, 0.17073306441307068, 0.0074179465882480145, 0.08121992647647858, 0.5853323936462402, 0.09402737021446228, 0.024092217907309532, 0.034818582236766815, 0.0], [0.0006014688406139612, 0.0016882645431905985, 0.16094569861888885, 0.003698966233059764, 0.034668561071157455, 0.5876308679580688, 0.09562253206968307, 0.05209798738360405, 0.06304588913917542, 0.0], [0.00015283364336937666, 0.0004516944463830441, 0.003205003682523966, 0.0049727726727724075, 0.10853080451488495, 0.03262018784880638, 0.6125266551971436, 0.005719948559999466, 0.2318202555179596, 0.0], [9.933842375176027e-05, 0.00020211786613799632, 0.0037883264012634754, 0.0051808832213282585, 0.6936825513839722, 0.10089477151632309, 0.023457802832126617, 0.011726793833076954, 0.16096755862236023, 0.0], [0.0002526468597352505, 0.0010056017199531198, 0.003837066935375333, 0.034950658679008484, 0.5882559418678284, 0.029549231752753258, 0.030938459560275078, 0.01461110170930624, 0.2965993583202362, 0.0], [0.0029390468262135983, 0.005815382581204176, 0.06488344818353653, 0.008705642074346542, 0.010130577720701694, 0.012970774434506893, 0.019612692296504974, 0.007819950580596924, 0.8671225309371948, 0.0]], [[0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 
0.1017894372344017, 0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], [0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 0.0], [0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 0.5316829681396484, 0.0], [0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 0.5926627516746521, 0.0]], [[0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 
0.004741673823446035, 0.46527597308158875, 0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 0.010398894548416138, 0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}}" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "\n", - "/**\n", - " * @fileoverview Transformer Visualization D3 javascript code.\n", - " */\n", - "\n", - "requirejs(['jquery', 'd3'],\n", - "function($, d3) {\n", - "\n", - "var attention = 
window.attention;\n", - "\n", - "const TEXT_SIZE = 15;\n", - "const BOXWIDTH = TEXT_SIZE * 8;\n", - "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", - "const WIDTH = 2000;\n", - "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", - "const MATRIX_WIDTH = 150;\n", - "const head_colours = d3.scale.category10();\n", - "const CHECKBOX_SIZE = 20;\n", - "\n", - "function lighten(colour) {\n", - " var c = d3.hsl(colour);\n", - " var increment = (1 - c.l) * 0.6;\n", - " c.l += increment;\n", - " c.s -= increment;\n", - " return c;\n", - "}\n", - "\n", - "function transpose(mat) {\n", - " return mat[0].map(function(col, i) {\n", - " return mat.map(function(row) {\n", - " return row[i];\n", - " });\n", - " });\n", - "}\n", - "\n", - "function zip(a, b) {\n", - " return a.map(function (e, i) {\n", - " return [e, b[i]];\n", - " });\n", - "}\n", - "\n", - "\n", - "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", - " $(id).empty();\n", - " var svg = d3.select(id)\n", - " .append('svg')\n", - " .attr(\"width\", WIDTH)\n", - " .attr(\"height\", HEIGHT);\n", - "\n", - " var att_data = [];\n", - " for (var i=0; i \u003c attention_heads.length; i++) {\n", - " var att_trans = transpose(attention_heads[i]);\n", - " att_data.push(zip(attention_heads[i], att_trans));\n", - " }\n", - "\n", - " renderText(svg, top_text, true, att_data, 0);\n", - " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", - "\n", - " renderAttentionHighlights(svg, att_data);\n", - "\n", - " svg.append(\"g\").classed(\"attention_heads\", true);\n", - "\n", - " renderAttention(svg, attention_heads);\n", - "\n", - " draw_checkboxes(config, 0, svg, attention_heads);\n", - "}\n", - "\n", - "\n", - "function renderText(svg, text, is_top, att_data, left_pos) {\n", - " var id = is_top ? 
\"top\" : \"bottom\";\n", - " var textContainer = svg.append(\"svg:g\")\n", - " .attr(\"id\", id);\n", - "\n", - " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", - " .selectAll(\"g\")\n", - " .data(att_data)\n", - " .enter()\n", - " .append(\"g\")\n", - " .selectAll(\"rect\")\n", - " .data(function(d) {return d;})\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"x\", function(d, i, j) {\n", - " return left_pos + box_offset(j);\n", - " })\n", - " .attr(\"y\", function(d, i) {\n", - " return (+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .attr(\"fill\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .style(\"opacity\", 0.0);\n", - "\n", - "\n", - " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", - " .data(text)\n", - " .enter()\n", - " .append(\"g\");\n", - "\n", - " tokenContainer.append(\"rect\")\n", - " .classed(\"background\", true)\n", - " .style(\"opacity\", 0.0)\n", - " .attr(\"fill\", \"lightgray\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH)\n", - " .attr(\"height\", BOXHEIGHT);\n", - "\n", - " var theText = tokenContainer.append(\"text\")\n", - " .text(function(d) { return d; })\n", - " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", - " .style(\"cursor\", \"default\")\n", - " .style(\"-webkit-user-select\", \"none\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " });\n", - "\n", - " if (is_top) {\n", - " theText.style(\"text-anchor\", \"end\")\n", - " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " } else {\n", - " theText.style(\"text-anchor\", \"start\")\n", - " .attr(\"dx\", + TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " }\n", - "\n", - " tokenContainer.on(\"mouseover\", 
function(d, index) {\n", - " textContainer.selectAll(\".background\")\n", - " .style(\"opacity\", function(d, i) {\n", - " return i == index ? 1.0 : 0.0;\n", - " });\n", - "\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", - "\n", - " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", - " .selectAll(\".att_lines\")\n", - " .attr(\"stroke-opacity\", function(d) {\n", - " return 1.0;\n", - " })\n", - " .attr(\"y1\", function(d, i) {\n", - " if (is_top) {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", function(d, i) {\n", - " if (is_top) {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .attr(\"stroke-opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j]) {\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " });\n", - "\n", - "\n", - " function updateAttentionBoxes() {\n", - " var id = is_top ? \"bottom\" : \"top\";\n", - " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", - " svg.select(\"#\" + id)\n", - " .selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", - " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .style(\"opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j])\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " else\n", - " return 0.0;\n", - "\n", - " });\n", - " }\n", - "\n", - " updateAttentionBoxes();\n", - " });\n", - "\n", - " textContainer.on(\"mouseleave\", function() {\n", - " d3.select(this).selectAll(\".background\")\n", - " .style(\"opacity\", 0.0);\n", - "\n", - " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", - " svg.selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .style(\"opacity\", 0.0);\n", - " });\n", - "}\n", - "\n", - "function renderAttentionHighlights(svg, attention) {\n", - " var line_container = svg.append(\"g\");\n", - " line_container.selectAll(\"g\")\n", - " .data(attention)\n", - " .enter()\n", - " .append(\"g\")\n", - " .classed(\"line_heads\", true)\n", - " .selectAll(\"line\")\n", - " .data(function(d){return d;})\n", - " .enter()\n", - " .append(\"line\").classed(\"att_lines\", true);\n", - "}\n", - "\n", - "function renderAttention(svg, attention_heads) {\n", - " var line_container = svg.selectAll(\".attention_heads\");\n", - " line_container.html(null);\n", - " for(var h=0; h\u003cattention_heads.length; h++) {\n", - " for(var a=0; a\u003cattention_heads[h].length; a++) {\n", - " for(var s=0; s\u003cattention_heads[h][a].length; s++) {\n", - " 
line_container.append(\"line\")\n", - " .attr(\"y1\", (s+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", (a+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", head_colours(h))\n", - " .attr(\"stroke-opacity\", function() {\n", - " if (config.head_vis[h]) {\n", - " return attention_heads[h][a][s]/active_heads();\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " }());\n", - " }\n", - " }\n", - " }\n", - "}\n", - "\n", - "// Checkboxes\n", - "function box_offset(i) {\n", - " var num_head_above = config.head_vis.reduce(\n", - " function(acc, val, cur) {return val \u0026\u0026 cur \u003c i ? acc + 1: acc;}, 0);\n", - " return num_head_above*(BOXWIDTH / active_heads());\n", - "}\n", - "\n", - "function active_heads() {\n", - " return config.head_vis.reduce(function(acc, val) {\n", - " return val ? acc + 1: acc;\n", - " }, 0);\n", - "}\n", - "\n", - "function draw_checkboxes(config, top, svg, attention_heads) {\n", - " var checkboxContainer = svg.append(\"g\");\n", - " var checkbox = checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"fill\", function(d, i) {\n", - " return head_colours(i);\n", - " })\n", - " .attr(\"x\", function(d, i) {\n", - " return (i+1) * CHECKBOX_SIZE;\n", - " })\n", - " .attr(\"y\", top)\n", - " .attr(\"width\", CHECKBOX_SIZE)\n", - " .attr(\"height\", CHECKBOX_SIZE);\n", - "\n", - " function update_checkboxes() {\n", - " checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .attr(\"fill\", function(d, i) {\n", - " var head_colour = head_colours(i);\n", - " var colour = d ? 
head_colour : lighten(head_colour);\n", - " return colour;\n", - " });\n", - " }\n", - "\n", - " update_checkboxes();\n", - "\n", - " checkbox.on(\"click\", function(d, i) {\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) return;\n", - " config.head_vis[i] = !config.head_vis[i];\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "\n", - " checkbox.on(\"dblclick\", function(d, i) {\n", - " // If we double click on the only active head then reset\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) {\n", - " config.head_vis = new Array(config.num_heads).fill(true);\n", - " } else {\n", - " config.head_vis = new Array(config.num_heads).fill(false);\n", - " config.head_vis[i] = true;\n", - " }\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "}\n", - "\n", - "var config = {\n", - " layer: 0,\n", - " att_type: 'all',\n", - "};\n", - "\n", - "function visualize() {\n", - " var num_heads = attention['all']['att'][0].length;\n", - " config.head_vis = new Array(num_heads).fill(true);\n", - " config.num_heads = num_heads;\n", - " config.attention = attention;\n", - "\n", - " render();\n", - "}\n", - "\n", - "function render() {\n", - " var conf = config.attention[config.att_type];\n", - "\n", - " var top_text = conf.top_text;\n", - " var bot_text = conf.bot_text;\n", - " var attention = conf.att[config.layer];\n", - "\n", - " $(\"#vis svg\").empty();\n", - " renderVis(\"#vis\", top_text, bot_text, attention, config);\n", - "}\n", - "\n", - "$(\"#layer\").empty();\n", - "for(var i=0; i\u003c6; i++) {\n", - " $(\"#layer\").append($(\"\u003coption /\u003e\").val(i).text(i));\n", - "}\n", - "\n", - "$(\"#layer\").on('change', function(e) {\n", - " config.layer = +e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"#att_type\").on('change', function(e) {\n", - " config.att_type = e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - 
"$(\"button\").on('click', visualize);\n", - "\n", - "visualize();\n", - "\n", - "});\n" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "call_html()\n", - "display.display(display.HTML(vis_html))\n", - "display.display(display.Javascript('window.attention = %s' % attention_json))\n", - "display.display(display.Javascript(vis_js))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "lydjSs3hgDVF" - }, - "outputs": [], - "source": [ - "" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "Attention_Visualization_in_Trax.ipynb", - "provenance": [ - { - "file_id": "1bJu3Qx37FY9UpHqVMyXCTNb64v4Iw_v7", - "timestamp": 1598692842045 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/models/atari_cnn.py b/trax/models/atari_cnn.py index 99464d527..2cb15e960 100644 --- a/trax/models/atari_cnn.py +++ b/trax/models/atari_cnn.py @@ -19,83 +19,102 @@ def _FrameStack(n_frames): - """Stacks successive game frames along their last dimension.""" - # Input shape: (B, T, ..., C). - # Output shape: (B, T, ..., C * n_frames). - assert n_frames >= 1 - if n_frames == 1: - return [] # No-op; just let the data flow through. - return [ - # Create copies of input sequence, shift right by [0, ..., n_frames - 1] - # frames, and concatenate along the channel dimension. - tl.Branch(*map(_shift_right, range(n_frames))), - tl.Concatenate(n_items=n_frames, axis=-1) - ] + """Stacks successive game frames along their last dimension.""" + # Input shape: (B, T, ..., C). + # Output shape: (B, T, ..., C * n_frames). + assert n_frames >= 1 + if n_frames == 1: + return [] # No-op; just let the data flow through. 
+ return [ + # Create copies of input sequence, shift right by [0, ..., n_frames - 1] + # frames, and concatenate along the channel dimension. + tl.Branch(*map(_shift_right, range(n_frames))), + tl.Concatenate(n_items=n_frames, axis=-1), + ] def _BytesToFloats(): - """Layer that converts unsigned bytes to floats.""" - return tl.Fn('BytesToFloats', lambda x: x / 255.0) - - -def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'): - """An Atari CNN.""" - del mode - - # TODO(jonni): Include link to paper? - # Input shape: (B, T, H, W, C) - # Output shape: (B, T, output_size) - return tl.Serial( - _BytesToFloats(), - _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) - tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), - tl.Relu(), - tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), - tl.Relu(), - tl.Flatten(n_axes_to_keep=2), # B, T and rest. - tl.Dense(output_size), - tl.Relu(), - ) - - -def AtariCnnBody(n_frames=4, hidden_sizes=(32, 64, 64), - output_size=512, mode='train', - kernel_initializer=None, padding='VALID'): - """An Atari CNN.""" - del mode - - # TODO(jonni): Include link to paper? - # Input shape: (B, T, H, W, C) - # Output shape: (B, T, output_size) - return tl.Serial( - _BytesToFloats(), - _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) - tl.Conv(hidden_sizes[0], (8, 8), (4, 4), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Conv(hidden_sizes[1], (4, 4), (2, 2), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Conv(hidden_sizes[2], (3, 3), (1, 1), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Flatten(n_axes_to_keep=2), # B, T and rest. 
- tl.Dense(output_size), - tl.Relu(), - ) - - -def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64, - mode='train'): - """MLP operating on a fixed number of last frames.""" - del mode - - return tl.Serial( - _FrameStack(n_frames=n_frames), - [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes], - tl.Dense(output_size), - ) + """Layer that converts unsigned bytes to floats.""" + return tl.Fn("BytesToFloats", lambda x: x / 255.0) + + +def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode="train"): + """An Atari CNN.""" + del mode + + # TODO(jonni): Include link to paper? + # Input shape: (B, T, H, W, C) + # Output shape: (B, T, output_size) + return tl.Serial( + _BytesToFloats(), + _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) + tl.Conv(hidden_sizes[0], (5, 5), (2, 2), "SAME"), + tl.Relu(), + tl.Conv(hidden_sizes[1], (5, 5), (2, 2), "SAME"), + tl.Relu(), + tl.Flatten(n_axes_to_keep=2), # B, T and rest. + tl.Dense(output_size), + tl.Relu(), + ) + + +def AtariCnnBody( + n_frames=4, + hidden_sizes=(32, 64, 64), + output_size=512, + mode="train", + kernel_initializer=None, + padding="VALID", +): + """An Atari CNN.""" + del mode + + # TODO(jonni): Include link to paper? + # Input shape: (B, T, H, W, C) + # Output shape: (B, T, output_size) + return tl.Serial( + _BytesToFloats(), + _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) + tl.Conv( + hidden_sizes[0], + (8, 8), + (4, 4), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Conv( + hidden_sizes[1], + (4, 4), + (2, 2), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Conv( + hidden_sizes[2], + (3, 3), + (1, 1), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Flatten(n_axes_to_keep=2), # B, T and rest. 
+ tl.Dense(output_size), + tl.Relu(), + ) + + +def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64, mode="train"): + """MLP operating on a fixed number of last frames.""" + del mode + + return tl.Serial( + _FrameStack(n_frames=n_frames), + [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes], + tl.Dense(output_size), + ) def _shift_right(n): # pylint: disable=invalid-name - return [tl.ShiftRight()] * n + return [tl.ShiftRight()] * n diff --git a/trax/models/atari_cnn_test.py b/trax/models/atari_cnn_test.py deleted file mode 100644 index fe3ded66d..000000000 --- a/trax/models/atari_cnn_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.models.atari_cnn.""" - -import functools -import operator as op -import numpy as np -from tensorflow import test -from trax.models import atari_cnn -from trax.shapes import ShapeDtype - - -class AtariCnnTest(test.TestCase): - - def test_computes(self): - hidden_size = (4, 4) - output_size = 6 - model = atari_cnn.AtariCnn( - hidden_sizes=hidden_size, output_size=output_size) - B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name - input_signature = ShapeDtype((1, 1) + OBS) - _, _ = model.init(input_signature) - x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape( - B, T + 1, *OBS) - y = model(x) - self.assertEqual((B, T + 1, output_size), y.shape) - - -class FrameStackMLPTest(test.TestCase): - - def test_computes(self): - hidden_size = (4, 4) - output_size = 6 - model = atari_cnn.FrameStackMLP( - hidden_sizes=hidden_size, output_size=output_size) - B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name - input_signature = ShapeDtype((1, 1, OBS)) - _, _ = model.init(input_signature) - x = np.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS) - y = model(x) - self.assertEqual((B, T + 1, output_size), y.shape) - - -if __name__ == '__main__': - test.main() diff --git a/trax/models/reformer/image_generation.ipynb b/trax/models/reformer/image_generation.ipynb deleted file mode 100644 index 626a99cae..000000000 --- a/trax/models/reformer/image_generation.ipynb +++ /dev/null @@ -1,414 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Image Generation", - "provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Image Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX. 
This custom build raises the TPU timeout threshold, because the\n", - "# default limit of 2 minutes is too short for sampling very long sequences.\n", - "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", - "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", - "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", - "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import jax\n", - "import trax\n", - "from trax.models.beam_search import Search\n", - "from trax.supervised import inputs\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yyxRk75iaAap", - "colab_type": "code", - "colab": {} - }, - "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt" - ], - "execution_count": 0, - 
"outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Load example data and model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qBvuw2h85WXE", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Normally we train on the full imagenet64 training set, which is quite large so\n", - "# we won't be loading it from this notebook. Instead, let's just load a few PNG\n", - "# images to use in our data pipeline.\n", - "DATA = []\n", - "for i in range(8):\n", - " img = plt.imread(GFile('gs://trax-ml/reformer/img{}.png'.format(i), 'rb'))\n", - " # Convert from RGBA floating-point to RGB integer representation.\n", - " img = np.asarray(img[:, :, :3] * 255, dtype=np.int32)\n", - " DATA.append(img)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "oBZh0Q2UEiaB", - "colab_type": "code", - "outputId": "d5adcac0-6f76-4c56-e6ef-74becaca87be", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 130 - } - }, - "source": [ - "# We can examine one of the images to make sure we've loaded it correctly.\n", - "plt.figure(figsize=(1.5, 1.5))\n", - "plt.axis('off')\n", - "plt.imshow(DATA[0])" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 5 - }, - { - "output_type": "display_data", - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAF8AAABfCAYAAACOTBv1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO29eaxk2X3f9znn3KXq1v72pdfp6Vk4\nC3tmyBluoklKpixSlk0pcbzJQhIkkaXAiWUpSOTEie0IiOEkBmwkCBAHhg1FSGzSlkjKDCVRpLgP\nZzjDmeF0z0z36/3129+rV3vde885+eN3q17T0jQ9DQQtIH2ABrrq3brL7/zOb/n+vud3lfee++Pe\nDH2vb+D/z+O+8O/huC/8ezjuC/8ejvvCv4cjuNMf/4+/9Re9wwHgnEZpj3MK4yRCyrVHe4V3coz3\n4HAYZ7HFvGqlOH/+Im988xWSigFAeYdTClNcxwIKj0eR2+JLL/+0kfMoLV8EPiBTclA7s6RDeODX\nPwtAZTlm41aXkw8sYm9clntkQDL4PsumDUCvvU8zewNIcPkYgDTz9IYZaVexsyXPEgaas6djMiu/\nG3RLHB7mHF6LKPcrcq5RjjKKB463ADA5ZB6+8p03+NDTpwCoJyF/+/98Xr1j4Vtl8E5+pxTgFRpP\n8RXay39cEa56QKGwKkBNJkQ5bJ4TBIDyhbAVyikwvjgPODTGO1Ivwg60xzqPLc4TBeCsJjOWsJih\nuoXmv/gXtMMhAMN2Sm25wnD9IsuJnLvX26MZ3CLubcr95BuMDkZcvbHF7uZAvhs/wGzyDGdOP0Zr\nVp5tc32DnUspYeMlAPYP1mlv1bi53uZULRKZaMVBd8wfvHQLgDj0rMyXeO7xFZK4UDQTvq1875ud\nezjuqPmB9+SFlis8hRIWy0A033uHLj5775jkbNN1pjRZnlFYjeKiigyPniZ4Bo9j7DRm8kvlMcqQ\nOdFyoxUOR+QduSg6x/7Zb3HDtinbKgBltUV9Y4t6Y4Z6+n0A0rZmPILdXgpAKQh48flNnnv4V/n5\nH38GgPm5OplJuHZzm+u3RIvPnDnF5vYup1Y/AcC1Nz7Dh37iU1y9dInf/eK/AmD7cBtGKXYogun0\nNWk6IptJ6I7lu3oluzvh57cJ0XtV2OWj75xy8mFih9ByjPfTCVIefGpxSmw7gHeTvxezqRzeycFa\nTa4HWnkCM/lCodCk1x0P/eN/BMCVwFH3muXahkzQ/qs0yOm+uYaviK1uGkdS2sLq7wIQph/nv/6r\nf4/V+ZNcu/o1AP7mf/s5fvEX/iK7Gylzhd2pNpt465hZmAHgodM/z/rFN3juxz7CtWuvifDOH5IR\nk/VFwN1xSr/f59Z6B63E7Ayz/t0JXwFq4nDRKOVxqImpBpQIuvhsvMK6XFbE5BCtyNOcQIN1E8Eq\ncuuJYrl87hxGabz15MVq0AZy5wkKr+xx0LHM/Dd/i/yxh+SYzZs08gvosTx8KT9AVc8ys7JP3O/I\nHY5eJ7F/gp/46G/Ic+SQlGP2Ll/BeLn+uSdO4dIRTz55ju+/dQmA2TDi+LFVgqgIHDwsnz5OHkac\nfebDAFx68wXqcRVdEQWZVR5nZxinOWkmqpal47sTfiFSeQjlUV6jsfjCKYoThomtcYDXCuudOFTA\nKMhthlIQFEqcA9oolJ9ESWKTPBAUBzmnQHnc5Fo+wycJJz/0LJ2OCLYZ3KLVvUqpslJM2AkWVYc0\nfYXlkWj+2XzAuHeF73z12wAsLZ7kyuWL1GoVlhdOA/CpP/soPqnzjedfZHHhGACDfofXXnyen/0P\n/wMAtq9fZn17l+6Vq+hcTFh/aOiPB0zwsdAYSqEmCkNKsTha6ytvK9n7Dvcejh+q+ZMwUikxMUop\nXBEyKjTgsEc+Eu8UGo2f2m5PlmcozXSqIw8aP/UBaIVRkkMEx
bnHTkyP1sXquAFP/M+/xltlg+ls\nAdDK9ghqyygt2rU8/h3ePdjg0VKXqHB0eZrz5sENLp8Xmx+VyqycOkZmHSqR+HxtZ5evffUL/Kmf\n+tPkA/HmX/nil/jYJ3+My1evAPC3/+avknvNw+96FxUt97jQ8OggYmJ3rfUMh552b0xWRCeTYOSu\nhK8mkQy+iNPVbcvF45VDTeJ9BQaFxaPcZIIUfpjKxEyOm7jtyXJF4RUY5cTcMIluQNtcHuJnPsXN\nH3uOK9fXWRrK71qt4/TJWfaHADxCxLl6H22H5L3i/sc5c5Ut3rv0IAAf/uk/iwlCxsMeX/29LwPQ\nnKvziU/8BLtbGyS1BID/9Jd+gd/67Of5L//aXwNg9vgCS4sn2b5ykbnSjtxTnmOUIVAiRm0slYqh\nisarifmchIjvUPhOia0H8A6MEoFMbJxxFtC44kLGWazyGCt2XSYPxi4j9RBPbLxSZNZL4oWsLu81\nTmsmWuScRQUB2TU55uFf+HmujFIeW1omUiKg9OB1FnWVObsOwJOsobMe6cCQ9cUut/cq/N5bi1wu\nvQLA73znl1mYn2fY7tJNRwA8+4EP8MwHP8gHVj9CurcNwE/+1CepJiXOPfu03EA+5NRixHh3DWdl\nzaqghPMON9FGb0B5PBr9b4Tlf9S4b/Pv4bhzkoWaLhsBFo7sP4AzQaEFogkOCU8z5THFzDtvSccD\nkuAozs8dxFozwY0AXK5QgcYVq0Mph9rJqPwXvwJA7dy7SK68SKOySHtX7PIxdYYkfY1nh2sAlPw2\ndqQhs1NT0EsjvnztOH/nf5HznDpzCus8SoEJi4jEOrTN6e9u8iu/8ssAtJIaOSFnapJ0NQPFW5df\npJwYFOJjPB7t7TRqQ2kU4J0jKBIW696+UnhH4XtAKTP97HyRYN3mTJUWkyQCU1jnUMpNbZ73Dj8C\nHwVoJk4IlJGUDMQHBKFnaC2mCC197jBPvpvV/+wXAfj25i3y9pDWeMQwbopgwwEPre9yxv42IFBR\nThN8D1uEg7+/NsfH//zPcuLBUwCM05RARzjvyYbDqRC8tfzK3/hVPvcb/zcASw+e4sMf/hgXXrwO\nQDI/wOsyJmCaG3qX4ZxCqeK5AgUeBqOcJBLDO1t5e+Ny5yRLqamgrRMYQcx9YfPw4BR+ik/maAzW\nH4FvzgssYa2/zchJLjDJZjWKwEseYIoHSbdg4X/7O7StaM78wT7naHJzdZ6VfXF4ycEOy/oaRYBC\nIzLYPCNwKU6LVn/7fIc/96eb06hDEeByh/MWClQzKpf4n/7+3+dLX/gsKw8/DEC5UuPmzavMnX4c\ngK3rXyQbDrAGOkXeVE7m0cGRWQ8kFaRSDklzkcm1g7uNdiZJFBAojXUOoxSTONLi8UodZcFOS/Tj\nQE2QTu/RFlykpriN8x6fa8KgMEReMfZOVs4N+e7c//D36K4ssteTsKWysUtt6QS0D9nfFKd4OtxD\nlxVqAp8oSznqieMO5cuFxohKkpBnoolZmmO0IR2PSALRhm989Wv8r3/312iePIErnHCne0iiF3ng\nKUngWmf/CuW4TKNeZ2dTENIL3/wNRuMIncu5U6VRoSYOA6KSCD+I3l689x3uPRw/BNtR0zVlXSr4\nCx6vJniHwgo6Lz/w4ApzMvEzucvILSTakbmjwkiExvsj6MKiCDOL/VMfA2D9Q8+Rr7RYvXJVTt1s\nsjcYkAURp5uCwy/tvsAp/30CNS/3Y/bETzlHruoAvLg5h/30Z3j47LsAaM7P0e+PMHbMMJOb/K9+\n+ZdoHFshIMBq0WITKcaDIV//g28A8N6PfoSPPvEMLz7/LQ5vXQCgXK8RmQRd2E+nFc46RsNsaq6d\nv0uzI1anALp0UAhUkEs5QEzFtJKFBa9wePQkAsodGrFUxSon9QqtISvu0GuBivub8PD/+Gvy3fiA\nE99fp312DoDLVy8SLZ0mH
d4kOhQj3/RDgrCO82IabP51jDboCqhDuadGvcJL332Jv/HXxXF/8Ec+\njDIljh9b5dqVNwC4dfMyQZwQGD8NFBSek48eZ3FOJrZVCvjcP/1HmPwaSdgQ4UVV0LpIGgEnGW2S\nmCnSO5HfOxa+vw32VV6jdA5YnJvYc4EcpuGod3ibg/PkxepIxzmDQzAlRxSa6aRmgpPKhWyAu5HS\n+Gf/hH0lWr39ygXU4x9gqyP2NQkStjsd7CjmrVQmZHZ8nGbnJrOVr8v9lMv4NMX2mab3T54J+dZa\nxN6+AG2//dnPkeUjtnd2prHv7EyTU0shEUp8GlApa5ZbfdKewNWvXtjCK8PqwhLuKNyTRL2IfjRS\n3XKFXORh7zLa8d7jVZGruhDwAjNPYnjtMCYiKpfk4oFBB2WiJKFaEwEd7G/y8f+4ypULF7jxhsC1\ncQjKmClG5Dsp0S/95zy6tMjogmSrs9XjKDKWOhK1DCslgjwlz9rEsTjTNg121AlS+6aclxyvPWSg\njTzaRx8f8fyVEO8LiFcHuCigXmmRu8JT+5hTczmRzbGTe/I53b1rAjAh8XoQG7x3KHWE6lrv0BPX\nqTzDUcpSyzBM5bthOsn1//C473Dv4fghoabBhDEAppRgggAdRuigLH/WHoXBFlplrccEHu8DTCin\n9qbM6cfOcfrRJ+jvS7b45X/1WQYH+5hQ5r63D3/moz/C9nyJN8Xq8J4PvZ/zl7/NfGkRgD0/oDzy\n+Laj48TmX2s7Hqts0S5LLO663yMuGyIDg26h6eOI5QXH9r6UGo23jK3HGTUNP8slQ57nGHWbffaA\nNtMCvs0dldji3BiFxI/OaLTzR6bAK3QQcH1zTFIU8FfmkrsTfn35QUlFEbvmvJXkpBjWOtBqavMM\nKdlAk9shm/2u3KDNKUUVUueYhOO9zj6BiXAF+LX4V/99vvPUA3QurlFZEef52sYGcS/moC02N7CW\nuBJjqxnNkRRTZs8sUt02dGJxiuw1aLBLWFbUyjJBy6MxT52q86/b8jkKK+h+Sqg12YTeUqqg1Bjl\nBZEF0GgCHGmBlztrGWdD9vdyjBYNiZKYUmDQgZhGrUNCownrCXlRtruyeZcO11qLK8IojcdZi7Vj\nUitCw2lwYHRRDvQ5SivKcYh1ReLhIXdjXArpKCvOC0GQkkqiiv+Zv8S181doLc4xzGVVNXb36Xzn\n2/T3DwA4e+YsG7ZOTERpRlbDYTbgm7P/Hu8ZSUFbVR9nvrKD1Y5SLOHgI6e2+f52ix/9UA2A9f2U\nV75rCSOYK4sGzwYdYlJAHUHovqg754XwvCcMAsplBb7AhHLoeVAF7B0HniiIQXnCovZsjtCZPzTu\n2/x7OO6o+YP+cPp/rQ2B1linCEPRTo0jzVLGmRQzlI8wYUQ6Gk39QKgTgrhEmvemKyQuVbAbfar/\n/X8nJz//XcLHniJVM5T6Eu3s7dwiWqmjHjgJwG5jCZPtk4YV9iNRp9JoDCj20+MAPDvfQ6ffQY0H\nZKFEYGF5j0899TJr7WcBeP1YTKmmOdgfsxLKCjZX98gxKO3QhQ11WmF0wMRY6rBENYnQpQCjZMUY\nZcBP0Ray3NEdDX/gu+AO+MKdIeXAoApq3sHWJuVGgyCI8Lnc4Mh60CFjLxfQTpH2utTiEma63sbY\nPEahyDPBTfprfUo/9xfQD0kB+8yZIZeCOq3OW9hMHjZspozcLP0dEdDmy6/TeO8DVDo3sPtC59Bl\nRWheYa6yC0C+8wUy06SaDHCDIskzitJsyEOzki+cGW3xZPkMaTrim1fkvm84h4kMJkiICmGpwBDq\ngL6Tey77EXG5KJYUyKvFSwI5wb+MolqO8TiUE9FOWAzvWPhprhn3xebur62x+vRzpGnKqLCLoXZo\nH1IORMs8llK5iVEBFBpjswxPiveW7p4IqfTX/yNqn/hxGscEF99MQwYuxdiArMDY48iQbbf
xuayq\n5LGTxMaS12bo2uK7/pCazqkENwE46DdYXszZS+eZj8ShdHstSjom7Ug221iBBxe/x1B/mFdfvQqA\nQWMAYxS6iNJ0WBLuqBfhG6Uko0dPgUSLx2s1FT5eSbkUmCSnd2AL3ln4h70e/bY86DhqEEYxzmpU\n4d3DwACOvED1tMtxNifNh2QFFBxHISooQ9Zm8MRzctyf+AjxMGSzJFFKMNilsn4d62Hu5Am5+MZr\n+GRMqkXLoyTkYGdMFEK9qNENjeayOsFuXyb1TLfGI0GfWrZBQ2rjxIsOZxyRFlOZ5ylYg+q+QJhL\nIuiM4EsedVTM8RLJ5dkEXgGtFdbpKRSuvEPfFszkCJzgUUd81jvI977DvYfjjpqflOt853vfA6Bq\nSjyMRRkwrqBWjyF3ClOogjcx1mY4DFHhFLPcYvtDvq5CnvtLPwfAIWOO54pLh4LVl3yf8twA26lS\nLRhie84wCldoHJewMk23mKtCevU1RolUslr+IpkrMYjkmFcrs1w+zHhP7GmNBNVcMtch9ISzslrT\n60OM9lQjw1ZbTIpyIcobKNgXIPUI4/QUgggMBWXRC90RpujtBLdU0zKfnxCypxDKOxb+Gzc3WN8v\n6Nglx2CsiEsRuRInqHOI9JHDGac5oAmjEJfJBFnn+M7738/euaf59YN9AD4wv8C4v89MqUhOQo0e\ntFg/UKg14eQEgwrdEDQyQXNhD/SIuBFQNSKQzD9AMF7HjiQLP6Y9zXLE74zeR70rcf5C5xIqSEHm\nAm0g23WMzYATDcl6d7oxQRgThjHKFA5XCaVxkuGWjEZ5gybFTUqrhamawO7eiRly6ojRqt1dAms3\ndnrs9OTBNtZ3+dSPOlym8boAloISqIy8EHSII3Cefn/MtRlJq7daq7jHz+HeWuPRZKIOnkvZkOG2\nZK8zpsTC7CKPrGxz+Xe/CUDr7DniGpTCglfvcvygj9kbE8fih3quykJ+QHcok/Hs6SF5ZZlu+xhf\nqT0BwPdKfxkXlliKRWHOlL/GA+PfRKURplhlXml0EGB8gFbhVLCpzadFf6U93jtQ+kibvSroR0fE\nAqc8yukpNWDiQ/6ocd/m38PxQzT/kPmmLMPz+7DXSWk1ckxRMNfjDpn1pCXR8n6rwjYD2t1Dbg0F\n/1h99t1cOf8qS9mABx//KADf2LrFUgBLi6sArAcxmUs5mRp+9Kf/PABfu/ga49xSGYldHt+6SPvC\nTZ4LDlheEU3vJE0eXC2xPnpAzpNU2LgJ26niMJXVMVpoYULLsPIIAK+WHyVY+nf5udf/AZ2XvyPP\noRO8AmscalIQ9prM5tNsySiH0x68meJoaIVyfppQOS+lE40XwA2mhLJ3LvxbXRYKvvpBJ+PSlZu8\n772PkhbbctaXZthujjksWADd7as4X6WjqyycOiO/23yLeVMi+cCH+Obr50WQzRbxKGR/LODb8Xib\nNKpQXl3lextX5bnoEW91p0I8Vq8QzVT5Sn6c+UVRiLnqPDeXHqJXOEU7PGR8RhMrSxKLQnQGPWpp\nzvUNKbrXW7OUlOI3z/4VVr56Ua7V0Wh0kbGqiVyxmZ9WqZQuODKe6fYm7cTm3wZqisA52p/A28v+\nh6CaIdMMN6qGnFhokKcpX3zsLADGDYn31okLwr6ZO015XCNqzTGel/i8PNylfPxdjLfXqV+Vhw1V\nyPGzdfq6gIbdAn1dQu9eZy4pqv7xPLulPmpWoIM03aR0pkL1Sk7FiELoqIkfHOIKfr4qNYj8EGMM\nfiAKUUnm2d3dZaEqSKhPIUtTuuUmN899XM7z+S+hUDivMbcZ4pG1mCnKqdAukOJVcYxTRVI1JQ4f\nlQ2nROE7yPfO8EKs0VqWfaV3QFKqMAxiakUE0N7eJjCLjPblwQbff5m9Rx+m2tHoumSvo8VTDLIR\npfXrtA5Ei+vPvYduNKS7K5pfCQ7IMweNJd54Qbbz+JM
tllaWCdMCvlU1Bm7A4hPzuIHs9qiYLn7Y\nx0VSUx2T4rTBWodJJOsedIfM4qmMpYy4HyU0SyGmv44992MA2M98Flepo7yfbgD0GrLMipMFlDbk\nHiEFTzZwFBK30xRXTfclTN3sHXpb3He493DcUfOz4YBwVRKYxf/kk3z+yWPsZT2a1wQnaV68yTAs\nkRYAWfz4Y8zbnHQ5Io1F87i5waOrc6SDEVfPyHaeua0N0uEGw/ISAPrkAstpTvviGtFpccKDSoVR\n74BRYTQjbYjLIVWV4aoSn6u4ycAobGUBAJulKJtRDkP2C/JTvrtFuLLC7ljw/Pl8n3Sk6FeWWSo2\njeyXa8KYw0GxVcjhSa0jLBJI5RQGNwXVQLYqOXdkYhRCOvC3cVBvP/4dCb9WqxKcEPsaxHXmtnss\nvnWd/p6Yi3h2lujMaezrwmccPHaMwbFjqDyiWxIsJZqJWV/fZen0aT5s3wJg1OmyveXQPUmEfLXO\ndmrpzSzTjMXHJDnkPqBZFaGFY41xJTZtwFIsDndvOECX68ThxAhDbiO8tjQKa5vPNaA7JE0kX9mt\nL2Nu7VHVkFkBDX1UQpOjVHxkv60iTy1xgc5aJRs/UFN/i/KglMYUpsUqfjDlhaMS4zsV/vb6Fvam\noIPReJs/87Of4iqazceEz+jrdRpXXuHwOSE6NecVw+0+1g9Z6Iitbr+5R/PkAg+GPXqHYndt1sCV\n21NK3cXc0lxappZ5jBcnPDPq09El9ouKWDkqUXcZy9mQdsEIyOeWcSNHfyQPfywO0GmfnazEwUgm\nZDaMsblnqTjvMW154UsvE/30jxMuS/hZXj6DX7+Ei/wUKslzR+Yd5cIDey9O1/rbmDheFZPlp8eg\nJp8mvJ37SdYfy3HnGq4xjArayYPvepxxFBGfmefcoIiArr7Etfd9gsaizO7++VvYSo3l1XkSL3H1\nMOgzaB1jbf8mVSNw8aX9DnpmBhufA6C1sMxcbxNVbdHrTgrWcn1yCSMrYZnD1FIKI8bFuk96bfoq\nIayIadpQHpdoWtmQaFAU8GsxvdzQmRPf1c0Ve/GIoHOJWlf8i4/LgEQ6qgibc+dRqKM9YVMO3m1w\ngpbI53a6pBeC/pQuqe4Q6N9R+ItPP4hqSxi5WK8TDm4xPKjxWksepPTsaWZblv3LYpqau9vw2GlU\n3dO3kuRUT1fJ1tbo1xN6+wKS7Y0dw3SGxpL4k4Vxn3G5SXXQ5cGK2PN+rvEuJI4L5ttgwIEN0eWY\ngn1NyQSUAArbHVFFdTP6lIiKPKNuLIclTb8rZsfbLnNnKxzOPETckGMqvS0iZ3HkWCcnH+cWr/yU\nWm4oGnS4I2FOMJ7ppkEv/1dKTTF/9/aFrDsL/+mVKsvPSlXic40neKU2Q0yflYJ/k7814HquODUW\nLT/2yZ/gYKaFjmKGm1cBWM1GZLVZRkNH51BWzCOVCv0TZ2kUe2WbaYdtX8VVEyb7fkvVFicV1Atm\nwM64z/FQkZYr6KEoxLLR7I/HDJ2ELW1lOWlGbGSarNihuAssNg3HuwLQrQUNFpIHGISO+UC0unv5\nJr6mcC6fopp5Jpu5j1obUBjzfMpY05ZiP9YPQglCp7xDavtvI/z0mSf43KKEh/F+m+buLZLEc1jM\n+G6Ws6hTlv4d2SisGoaNcYljt26QZ/IQ6+Um+QjKgz5JS3D4w/kVtjv7JDW5vFERD+Qjrm4dcLMI\nIxtJhUGqOCzg47g0Q9V3yYY9fFUqUONsQKDH1PsCVadJnQNVIlcWWyCvjWxEd2i4ZZbloWoxm1rR\nCebYX5MVWxqvQfVh2TlZmAtJnCy6CD29mtACdUGaBK8Ft7fTXhTgnCALfspvut915I/luKPmv9TO\nCC7+gRxYKTN6eIWg22WwXYRfKwtUdjpcjiWmDza20Yc77Ngyh3VJsip7HWpJBe89dlaSKhM6ymPL\nWipaMWMCwhBqWtE
cil8IUk1l1MfGshI65RLlYJmIjL2BHLNYqdAfj8niyQY1RVaq0KLHQSoqvNSs\nkhvFYVOuHd66TN5cJsoGuOfl2XQ4J30enEcVxOAsywt8nuLcmsBDpuQTFHvSUNO4XzaJC/Lp3aTg\ncpfsBWZalMsioFKtQd4/JLpwjfTUowCcuXmdlZOrRF5ArKBaY7NSpdvfJXhdGMm1dz9Ca5xzo7HI\nXiSJzrn2derecGMkmMxNDccTQ2YzKBflv1LMqFRlPpGJ3tzdpDcYMAyr1Iud471uDxUHLJcKjlA/\n48rBmNmlEp3vS1H90tIT7O3sYiUgotJqsJXH1OszJBvCEbImIPQe78AW6eooy2TTw1SyhcNV4Ow0\nlBGaSGFAPEo2ens//d1kA8g7Fv5MluOrRdTS30Ot3eDazpCTg9cBmH/30yx+4GFu9iSSaA/GMOxQ\nXtuk8ZyEkWzvsD5S7C1XeeSa/K5SC8mTGk9bcYIXXYMbOxYbxcRGNG9kDLrvuVrEuk3nqJfFHYeN\nwsGu71M9fpor+1IRO39rB3vmNGFnh9GsrMberVvUg5yDDcnCcxtRmlOUD/sMnv+/AEhmTh7R1QvB\nWusJOQoVZUfOJNmSMdmT42/ffaKcAHQTX3E/yfrjOe4MrM3N0EolWTGf/l0WZyCbOUd1VTD2+Jl3\n8aby5J2CAXywi9rvkz7+EFcvijb2nzzD2a1rvG/vdbpV4emocsB+rgiLDRTLeweYHPZ0SLNolVLL\nU7TqkRoBzWqNBczhHldR1NdkM0QyO8+t3S3WhwIxN2cquJ0txlkP3RIb3wgNajzgsSVJ8C6+fINy\ndIB5/XWuS3mByqkBS7M1yvWjBnu580SRm2LDmiMG8+3FcadzVNHswLkAlMNhyJW77ci7EL4dD6i8\nJQjmzKkGW2ffT32oOFiQZb83PGC2nVK5LvZ1YDT9uUUaL3yP8H3CjWy4MeWyJjZldvpy82utBcrj\nAde6EvcvRRW8GRPYHkMnIWoQheg44eDKFQA6K/P0rCHt7DGzIILdc2P2PdSLsDIoRSSHB7BynO2i\nGliqBNSbi1y7KPc4zCEeG07/P/+cF8vFc3ZHDHo51caYWk2e7fqNXeZn65Tn40Ia0uDD3VY8QaUo\nb44g+4LNYL0/2vJ0tzvQZ9sDDgNxgPaBj9IbDll76glmFuSuq7t7tF9ap3FCNHpnt81iZOGRM+wX\n3fXy3OBKLcL2iNlCsP7KNVZnI0aZMAp6yQJV3ed4SdMbykrbzeukozYLzSLD1AOSxhytSs4rI3Hw\nw7hE4By1omVkK8+4UK0y4z2HRTFlM8uor/fpdguWm+vxjJ6nVeqyelLue2OzTeYSBt2MdlsmqTMY\ns9HeZe2m3PPqXMLyXImkHFuFgA8AAA7TSURBVEwZCd4Lb2ZScHFMGkIddT67a82PUvAFNXDv5DE2\nkjorowGlL79aCKhMOj+L78vFK6dXKe9epufniS6Kg3v6RImNNKZ3fZ/qKbncvB2QVGepdUTQpaRC\nrBS622Zc3O3YdcjzEenyKRGGtZyNyjA4pN8X4TeCEuXITCHmvOdolsoc1qq0ilU22uwx7GbTaCe3\nLVZvdVg89gB5KsvjPU8+wc5hj/Nv3GR/Q4TvrCMOK7iiicT6wYD17UMq1YCVlijkTLOECW4jTVmF\n1x7rHUdNiN7erd53uPdw3FHzt06dpjMr4Fd964DTVrH1jecZTeiC73k/s+Mcl8jqCK5dovLQInZk\nCbV4rsbcPOUbWxyszOJ7ggkFMytcdGVqWjRvr9NmPlKouMlhYYoqWUqjkjAeyOeRj9hKUuilHBYF\njvfPNOmMumwVzex0r0s8u8RONIP69vMArJaXua4zcoHgyMclzvTf4PqNm5giOTRhxKljxzl94ji7\nbVmN59+6ySuvXiCw8rtyHJOakNHAcaknmFR+0zLbqLJcEMSatRBw0qiy0Pj8DqSpO
2M7RtPcl2Wo\nN9Y5vHyZaK6CP/6YTIhSDPM+1V1BFZsPHGMjConckBOrckPXtlLo7DPyJepFUT10OcHhLvWCRt6P\nG2weDrF+SK1eUP+WZti7eovasizfza0N6uESb7S7nC0il26e09VQLwouB60mb3YtJ8cHXK+K0sz0\n+5iyolN0Vjq5vofb/z2GvQOa85OMO8R78M4zW5f7/sizD/O+9zzIjetCX3z1/BX2d7ZJoioTvGZk\nPQeHfXb3Cs5poJmdr7HQLFOOi/Lj3XI14zTHviaJ0ZUAji20qOsydlYclb25RjPukS5KyJj2PKXZ\nEnkfhiW5aG3jRcauykFnxPqqUPjifEQjH+IK+nWSHeArdUoaRonACbFO6JiY8Vg0/+FGwog+fZdC\nkQjtpCmnyxV2vEy+6Q3YHVXIzl8iOSHZc7djyUoRUUlAvTPBW2xcuYTWRzvntfJFkyg13QCYeWk3\n+eBp2aD30OkT7B4ccP7SDd54/c1CeIYoiBnriZbDre091jc09UT80Px87e6Ev/3iC1On8ODWPkmp\nTudjT+NelSU9v9xkuLTMY1uyVBs31vjm4Bid62ucft+TAPRHDlsKaR6bRRU7C7U2uCjnVtGua5xb\nas2I0miIHohgX965TmlukUFB+WiEit3RCB0GbGSiaatpnW3XoVGYoes2ZvbaHj5JaBdtWXo2JSrP\n0hnLMadKu+x2tonD2tTjOTSRNnjnpECC7LJRRYEcwKmcVqvOh599kg++R/o4nF9b5/z5y2xvi3WI\n4wRtyozHnvZIlGb3xt7byve+w72H446aby5coBXJ8jl8+Al6T72b0td/n+GymJ2tlSVObHW5PF/s\n8Pjffx375z5CZWWFziQPacyztdnl9NIyUUe27wz1LPNBjC9KlINqndF4TGxzMi3EqjTQDJWjVmDl\nN7yj7y0HI095UZbyMNZ093foFTa4eyNjMa2xX4OsKyfv1iMCp3lsIBo4XruIVgFRqUpcoKFhEEm3\nFK0whYO0SssWoCnXUqopzjtU0RLgiUdO8tSjJ6Y2/9W3rvD6y28SmpCFgr0xuluHm7z4KtmflEbO\nwbvfzfDLn4dHTlOfFbvcczmDvUOyVJZd/4PnqM3NUentcLMjNjfvj6lGMUp7xrkIKauUGQWKdlFw\n2Oz2KUeGUQZJSRZjpV6hZ9y0++Z+WGJ3b8Cwt8d4QxRif2mVVe/ovS6VtNJggf5ii8P2LUZVEVC/\nnBAN4diBZOrp61/GRHXKlTqlAjQMAiM9/b3CTfr1e4/CQrHZz3uL1j9oKJzzZECriPs/9v6n+ZFn\nH+etq+u89LKgup1bu3cn/N7cKvFHPgDA4NZ5Sk8+RrWqOVgSbEf969+h/+gjZBuCTkbvfZK4f5Vm\nM6K9d00e3kPQWmYvzdGJTIga9Hk5mMUMRWiVTHPy5CJ+0GW/4OC4KCYa9YgKcKVaSvjulVcIZ+v4\nolYQH94imatx4YpEJKWlId3DPmHF4iIRbITmPQsBM9+9CoDFEEclojhChyJY6x2hNqA0pphsoQ1G\nxUsTZEe69zlOKbw96mxulZqSpqzNMDrgsQeP8eRDIqNbe4dvK9/7Nv8ejjtqfvVD76OrBTH0SYm8\nHjHeOkTPS3Sjv/ES2bNP0esU6KRJiZKInp4jjCXxMaZOHA/ZPwhZrYnGxPN1rrx5DWxBZHr8XWzd\nuk6yuMqw6FlQH2UMcku1qPu+9dobuASyZoVSkdCX6yX6FzbIC7TxgCHdSkhSbhIV52lozanf/Me4\n81+Qh4pqhKUKQTkhKDZle2+lV8JtVSqtZWvrtFOUUgTa4JzHmMnvJlx8OcRrieudYxolLTdbdyf8\nwXILWyzDWqgY5DmkKZXrYlIWZ2e42j8kiIpKUjqgmcxwbXtIXLAAqM1QiXdx5YgqEv9pEzEkl8YZ\nQCOuc33s0HFMrYCH8+4eG8kMvatS5O5lfcZJD
ZMsYBE2XLNcZudqG1/8ZhDUqA0d1u/hTkg38Pde\n+zbx938LXxIFCeOYWr1JqVw+4lhajTEBjnQqEu89yvjpq0nwORaDVky3Remiu+Lt6NmkH8/ET9s7\nsJTvLPzFhKzYVZhf3ufhfI/umUfZuSJRy+Hx4/ir2wwWC45MnNDOHd3UkxY9iW/s7dKYy0kqMcMJ\njz+HQTdlYUG04nBtnXB2Dpdm6MKetqqKfH+DjZHYzEE5glqdmc+/wtonJc6ONofYPYUKJRGq7Wvy\nlSHJ8gLvtzJpK1/4JxDPYoouV0mjRVxOUJhpn0+MJFzGG9l9AtOXNUxJTzoi955A+yPat5cVMtk0\nh5IcRpLAH8T+37HwjQ7onRct18tNtj/9DdJnnqVVvOhlpdNh71vfYvTTPwPA1tIit8YDZlSX/YI/\nGdkRo7RGkOSoWELErd6ITm9IswgZ3/znn2b+F38OM3SMJhyYpMXG1ov0Z0Rje5WEmS9c52p4ktk1\nyWjDzNLWEd2KhIy2PqDZmqHpco594Z/KM9gBPkxozgo7rbawgjEB1uXT9pNaadBidoy/DY38gWK5\nI/TgpGR+2yFHxCqlFBZHYJh2a8G9vfDvO9x7OO4MrO3s00qLziMvb7Nz7ifxlw9ZWpPexbvnX2f0\nF/4y5aI43X7oFOVuD6/BFVxJpXNcskJSrmCKrGpjc4M532dcnLt2fIa9L1+k9MgcutD0vYtXGdZK\njCO5xfqr22xvz+CfydDLckz2rbc4CA15Sa7VOtGiVlJ8/NXfJ7z2glw/apA050iK85pAg/PShbDQ\n8kwpAqewKPSkkK7k9SR2ykyQ1sUa92/YeI26zeOqoh29spPtoXcp/OSwDd+SZOH6yU/wyImA65tt\n+q8XLwL42CdYtwNay2Lz+84y53N6xDAsYm+lsSomMgZfgGRB1iOvlqkVgu0MctbrA07EY0Y3hLM/\n3yrT9iWSwnF3X9AMziaUV0vEa5If7Kyvc3h6lYViMpqjER9Z+yrBN/8lLhZ/UirXmJlfJi4yTuXA\nGVPsIhTJBFqyW+kTOpGqRRk9BfGMDvAuxSk9LREqr7Aqm86FUcUuROWZNmW4W+rI+FvrdE7+lJwj\nsbwxSjm2cwFTNJoKEk9crzGeExg4thloj8tTTLExelyts3vYx820MDviqHPtaAzG7F0UfL9barGd\nGObDhJlZ+d2eN6hmg8OvSOq+MVOhfgaiUcDWV0SrD04vs1KPOBaJ5p9746s0X/pdXNwgjCQRay2u\nEJWr05BRa1A4jIG8YLEqL90HtTIE0x7/Hu3dtEu097k4U9RUm713hATYIkiQFzxQvGWj+OEdSFP3\nbf49HHcONY/9CK2qRDb9JCeIUpL1AYunBMvYaMyRd9r4R4XBFo4HWG1I+23yYtkN21B2W2AzGgWc\n0GnNEf3275EuCNV8o7XISj2SuLsA2w4rCfkNw+4NOU/lqRwqi7Sfv0S0UtATFxJWF+DMC/K+q/k3\nv4gLqygdMrMoBZdya17YZkVbGqc9RkeyTd9MXl1RRI+G6VshvDcoE0x9AF7hfC7U79sp4dofYT7K\n4b3CFf35OTr0nQu/p0ccluUsse8yHrVgY4vqqaKvQi9lbrlEu3iIik1RNidPx6RFA6T9gePxsEsc\nKIbzMmmbL1zjzPY2W3PCHC6vVKklmlI2YLMuoN1hO6bzpTb6QTEpdu4YnZduUF+pMLNV9OB5+Sbl\ntuLExS+LDKMKQRAxt3qSWuFghTF8FB56K2ZHaY0rNl54kCTLMaX+TfdDT6gf3smRzsuGaJjSBaeH\nFK2Qpq+wgj8Exv1bC9+WR2g/4a3UYbdP2e3ST+TBgryLXl3BOXECNks5GFsIPIOunDqql7BVTe4D\ntrfEns9/7it0HjnG4YoIv5aU8NWQYalMuyMONvuXe3ByDKdE0OZCm0HT4bpQuSDR1rlFzfLVF8Qx\nAjouM3/8D
M3GHOnEUXqN0RpfcPFDpfCqeKFMsVNQGgG6H7TVTppc+CJeN55pows38RXFb6cvp9Hg\nbdGmeNIA/G57KatyCiNZ4q40ItZdopoiVZLU6MVZulGAKXrcZ5llPExRukS/yPoWKwE2H5P7KvzD\nz8gNrZ7k+uoS0bysoFYZwlqF1w7LLH1GSnTDuRi3HFO6KTD0brbLM43jLK2/wjCQ5Mxc+jRmBlwR\nEbUWVqg3i92Tk67ZKJz3hMXqdBZQHusySa4AcsiNx0xeLTiRrM+ZbNawThpyO+cmc4Z30tQ1nxQm\n8sLZHvX6mvbm+aPGfYd7D8ed4YVI44puf9m4RxQPqZZDLgvLj3huARuG6K7Y5V5uMQwZdDVUC6qG\nSSF37D7/Bk8/9ScBiFZ3eGNukbAkS7NWjtjqRAS/v84QMTNu8RDXn8VclwTu/Y0RC1/6Kje7mudi\nKYysPvWTjHYusCRlAmqzyyiMNNi+zQ57xbRjlDGa3DpCZaYvk7HFNk/nme6lst5hNLjb3o6cadAF\nlXwyrM+mWi70cYdz7o62fjKUv5M7vj/+Px33zc49HPeFfw/HfeHfw3Ff+Pdw3Bf+PRz3hX8Px/8L\nmha/p4Qii9cAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "VXjtCPxl3I82", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# We'll be using a pre-trained 12-layer Reformer model.\n", - "# First, load the config (which sets all needed hyperparameters).\n", - "!gsutil cp gs://trax-ml/reformer/imgnet64/config.gin ./config.gin\n", - "gin.parse_config_file('./config.gin')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NhiTshPPbvLY", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Now we construct a ReformerLM instance and load the pre-trained weights.\n", - "# The 'predict' mode configures the model to accept single tokens at a time,\n", - "# instead of feeding in a complete image all at once.\n", - "model_infer = trax.models.ReformerLM(mode='predict')\n", - "model_infer.init_from_file(\n", - " 'gs://trax-ml/reformer/imgnet64/model.pkl', weights_only=True)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Sample from the model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PnzRPCzFqIVi", - "colab_type": "text" - }, - "source": [ - "Now we're ready to sample from the pre-trained Reformer model. Unlike during training, sampling processes the images one pixel and channel value at a time. The TPU colab runtime has 8 cores so we can sample 8 images in parallel." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "W9ZetV91PujO", - "colab_type": "code", - "colab": {} - }, - "source": [ - "sampling_decoder = Search(\n", - " trax.models.ReformerLM,\n", - " model_infer.weights,\n", - " temperature=1.0,\n", - " max_decode_len=32*64*3,\n", - " )" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HOLawc5dB7QV", - "colab_type": "text" - }, - "source": [ - "Sampling is an inherently serial process and will take up to 9 minutes to run. A good chunk of that time will be spent on JIT-compiling the code, though, so the code cell below will finish faster when re-run for a second time." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "We9Jj9Rap3cB", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 214 - }, - "outputId": "10b6142b-11f1-414d-9b63-353f721a6a82" - }, - "source": [ - "flat_prompt = []\n", - "for i, img in enumerate(DATA[:trax.fastmath.device_count()]):\n", - " img = img.reshape((-1, 64, 3))[:32, :, :]\n", - " flat_prompt.append(img.reshape((-1,)))\n", - "prompt = np.stack(flat_prompt, 0)\n", - "\n", - "print(\"Prompt:\")\n", - "plt.figure(figsize=(10, 10*8))\n", - "for i in range(prompt.shape[0]):\n", - " plt.subplot(1, 8, i+1)\n", - " plt.axis('off')\n", - " plt.imshow(prompt[i].reshape((-1, 64, 3)), aspect='equal')\n", - "plt.show()\n", - "\n", - "seqs, scores = sampling_decoder.decode(targets_prefix=prompt, batch_size=8)\n", - "\n", - "print(\"Sampled completions:\")\n", - "plt.figure(figsize=(10, 10*8))\n", - "for i in range(prompt.shape[0]):\n", - " plt.subplot(1, 8, i+1)\n", - " plt.axis('off')\n", - " plt.imshow(seqs[i, -1].reshape((-1, 64, 3)), aspect='equal')\n", - "\n", - "plt.figure(figsize=(10, 10*8))\n", - "for i in range(prompt.shape[0]):\n", - " plt.subplot(1, 8, i+1)\n", - " plt.axis('off')\n", - " img = jnp.concatenate([prompt[i], seqs[i, -1]], -1)\n", - " plt.imshow(img.reshape((-1, 64, 3)), 
aspect='equal')" - ], - "execution_count": 12, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Prompt:\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAsCAYAAABhRmIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8eZRlV3Xm+dvn3OGNMUdkRM6DlKkh\nUxISCBCDEGCEwWAwduPC5bmqGozbhbtNuQu77Ha5DMvluapcVe62jSdsbFw2BmMMhUACoQFSQkNK\nqVTOGZkZkTG+ePMdztn9x30RKcal8Fq9ei1WfGtpKfIN97633z7nfOfb3z6iqmxhC1vYwha2sIUt\nfDvD/P/9AbawhS1sYQtb2MIW/r/GFuHZwha2sIUtbGEL3/bYIjxb2MIWtrCFLWzh2x5bhGcLW9jC\nFrawhS1822OL8GxhC1vYwha2sIVve2wRni1sYQtb2MIWtvBtj+BbPfkHv/AO9Xi8N4hRvBesV3Kj\niFNUweOx3uEwGBGefvokzzzwOJWqRdTjRbCAAwRFEfIcUDC2uC4ogQZk4mhkjrQH+//sY1RnYuYu\nt9izfxtu9gw5XQ4k99NurDCSPQNU8HlCmintXkbaEhaveMLAcO2+mMw16LZKrK3lrJ2PcMsxYoX9\nu0axOWQK937pGV5+616GKiG/9KGH5fkG7s4T5zRAMCJ4MTgRqkbIEJwJcMYxjKVhhNBYBEHE0MUR\nilAFVIQcsGIQASNCLEIPQECKMGEwZCgWg4jHqgcxRCheDF6VFDAIXkFVEVFyBVElFRDvEQUPJN6D\ngFclQvBAnueUMOTeY0VQVYwqHgjUc8/ebc87NgAPzGVqg5A/+/Vf4/Dhmzl89+tI+w7nhMwLKDgP\n6hWv4L3gvKIevC+uYQffHyAQAVM8r8og98BpkVVGipgZo4gIYoq8M0YRAyIQWUGMYkSxBsTIIO4e\nwRfX0Kv3lEEceM7jKlLksQ5uCIDyih3h847PK1/zcm2srHDo4AEaK8sM1Wp0W022TY5w+4tupNts\nUqtU2b1zL6+5624ee+RRTKycnr3IleUVQjUcO/0sf/OxewjDEcJSwr5DY4yMVhFRvDPkuafZalIp\n1VhZbdHr9jES4r0DySmVQ+JSSL1WoRxFbNteI6yME9gqWepBHMZk9JOEytDt7Nl1F8uNZzl79l4C\nGmT9RSqVnUTxfrq+yo1HXk2302Zp7gFWlo8xs+s13HnX23njCyfQIljPCz/7m2/T1cUuzliWly9y\n5tnTRBVL6pW5C02WOvCHH/hPXLN/ih95z/dDUieoZAR5lUo9pi3zlMeUzqqgudLtClEAzimBEWrb\nIEuUsckIu3SISydXEZ/SaLaZmBzmrte+iIPXHGL3rkNIXOIrxz/N7/7WnxD1lNf90Cv57u/6fpaW\nT/Inf/aXHL3/MkFJCMqKZDBhyrTbKaUZJfWe0SnD6iKM7ILOGrg2jE1BJOD6wlpLOX3UPe/Y/MgP\n7tMrjRSLYsXQbHbo93p4p0QlIYosy3mXP7r9Vdxyx6043UGWXKY7VOFif436noMENka9J0vXSPoN\n2mtLfPz+x1hb6/HI4xcgVH7+p9+JsauMDteIKxWsGhBDril/9/cP8ME/vQ8JTDGn4QjjiDe8+S7G\nhsuExlMKDHgHRlBvMXiyLEMUFEHFk+eet976aiIT8lu/90GG+6dp2WEe65R49omzmAguLj7
/vAH4\nnff/hCZJk9Ba4iiimCGELPVcWrhML03odLu0kg69tAveERhBtUkt7FKNAypRndhGxVyAFvODxqS5\n0s8cqYIJY0rlCuVomDiuUYpiRALyzJHnOVnuaLd7NNbWmL18kU4/oZ8N5jh1pE5p9RMEC2LBWKJA\nCEyItUJgLNYKobUIivcg3oAWc5IVQ2CF+x85+bzjs7BwVkUVsASBLeYBoFQdx4iQ9BuIhARhBICx\nFhEBVYwJsdZSzHdXbyny1bf/2n8D9JMOF44/SHN+juHhIXYduYNyffKrrgOg6nFPfhgrLdj1chSD\nX75I84mHWF1p0bI1/Pg0k3sPEg2Pgjii0ghhqY6YEFDCqII1IemTH6F8y9u/YWy+JeFxYlFfLAqo\nYFC8gFHBqUehSHkJEO9R8bg8JwgAURyCeAGrGAWPwarHm2Jxc94TBeCdIbOOMHcMORj5yEdohD16\njZT6TJXepZPMVJR2e5m4cwbN5+iv9jk3e4Wl+S6a7Ge8chsH9t3I6DjMX5pj8VRKOPwoK6uXaFyp\nc/FSgz2Vbay2Eu579DJxqGyfLPHiw9upxBax4bcKxddDlVwFa8AKhN6DsYQUREO9YdUW5CN3DmMs\nRjJiLLkqXYEICpIxWFQjDLl6YmMQ9WQYUgrSGKqAOEIvZCLFxKHQQwm0mERVlQwIEDJVrAoBQugc\noYGeGJx3uAEJNUDbe8oiqFj66hiSggjmKGZAUS1+c7GhIBHqMn7wZ9/LJ/7v3+PMb/9H7n7Xewkj\nQTKHV4M1oCp4D06VwIP6gsRIkVwF11AAxYjgBt+TAQ+x6jdeWJCXgjyKASMFEbIDYmOtYkxxnXWi\ngxQEqBgdxfeFb0x2RAoCWbyieM/6KNgMZrZNUauUCAPDgf17aawuE8UBQWjodruEYYXdu/bz2te+\nmiceO0ovaWF8xNNPPcULX3oHcxcuEYchB6/Zx6lT85SqFovFuZRKtYzLDc73qFYN3ncJwh71YUu3\nreRZhjHQ76UEYYDgqIZC1usRhE2iMCSKI5J+Rp7lBMYSRiGVkRGCkQOk+UXaSydo5ys0W1dwzYyw\nPMajj36YWsWztjJPEG9j14FDjEyObjpvvvPF/5JXvez1fPSe/4d/+6v/irEZeOub3sM9X/gfXFpo\n8iv/6heY2FbmX//s9/PyW76X1HdoN5dZWVkgCVeoVz3Li9BdFeJIKJcgS2FoSrCBkuVCdVjodhNq\nQZPxaUskNZqdFlNTYwRRCZtcIbhyAr/jjST9nNwJsVW6rQWePvEZnjo2S7NZ/OpRCHsOxAgJrbMp\n7bZjeqJMP+2Rpp5e2zAFpA6cQq8BrTa0VhWXbi42KoYoFEIxtHsOYy3eeZxTDBHOQQgcW1rmlgtr\ndCZCTkmbZt5mrdMmfOZJGs0mIzXDQ48dx6qn0/d0Us/liy1cpvSSJqn2uXCmx+//8Qd55UsOs7jU\nZdtMjbtediPX7BmnXotIk5Q4hJmd0xx5wXVsmyozXI6AjMwZRAXvDVme4xHCUoRzHnWeUikiTRxP\nLz6LX/LYzjxjwxHPXF6iPLKf0bGA5Ua+6dwBS5p4JBKseESKzY/3xaYwd0ruldwZvA9JcynmCC1h\n1BGYiDiI8RKBCM47AgyIRTQgMAYjJeJSjXK5TlwqE8cxcRRhjMU5T5Z5kn6KupgsC4jiHr3eGiGO\nVD24AK8O5xy5E1QUsZDmBms81lishcga4qiYkQTBeIfFbsxtTjdXnOl3VpGsRxCVkVKVMKqAGFRd\n8f3EYIxBZHBdLe6rBeVCB/crOI0MXqLfkORcfVwwYrGBJev3cKUIzZPBvFoIHcUqNBADZu5g/qkv\nMD0REtXHycZqLO3NmcuPkTnDiAgu62KlTpbngw1ugLUW9Q5BwRh6k7dR/iZx+JaEJ1AlFzZY5vo3\nNlqwdFXP+rmFRRwMWZ4xEG2KhZdCKQCL4km8wRjFiiX
zDmsEjydST96DnX/yd8y6BmVXoyxXGJq7\nwtDwGEPpMdKGYWkppRQEHH14nhcfeh/vvPs2JieGyGyF8xcXuHD5MgcO7GV+YYm9O97A+Wf+By//\nzrdy7tQpPvTnvwf9FNfzNDuGNO2TjVVoJZ6havY8U6eAxaCmIIGhV7AGdQ4VgxhHIAa8p2/sgNR4\nYlUSFDWQekENZKIEKKEKGY4RI6Qe0sHSG4vBKwieQBWHEnpIjMOpYHBYAlItJtQizRSvHotskKmy\nQoAnFiFTUDyOgiAEIoDDqZJq8VkThFAVCySb22gVUEA8Pk148zv/Vz7xB3/M7777+/nn/9d/YWRy\nrCDIWmSOGrAq5AP1xvgB6diQWgbXG8REtFClioFjNl4iAipXiY4xYEWK/xvF2mIgGqMY3EAVKvap\nxQGc+tybbXyGDUqjxfWLnVbxaURls3yHMDSMjw7RbjZpra5yzbW7GanvoVIp022vMT5UY2pqis9/\n/l4uzZ7l5iNHCOIhvudtb2P/wYOcePIZtu/ezuT0Tn7nmf+KS2tkSaFiuFzp9xN6HSXPoiIPZBhj\nhHpdCUNLEASICLnL6fUSdu/dQz9PWFlrUomHMFjwDpelSFyl12+z1rxCaUiZnBwhzIYIzTCNZhMC\nx/btB1lenePcyUeIokle/Yo3cvsdd1Apb75iPjE5CcDdr/hnfOJzf0Z1bIq/+uhvs7IEK5fg4aNf\nYeYa2HH99Rx54S3cfOBGvuPOtzC/dILdd1zHCw6NMhY42r0WYUVwfkBKRVi6INQnIYhgbUlYaZ4j\nTcD3wZZKjI4PEWVL7F66h+0pPH3yfh454YhKCh1otFY5fuox7rvvDHNXILBCXFb27t1LJcp4fP4i\nYeToNxK+70ffyT987I9oxQm9ppDlIB66C4oGggQwdc3mYhObgH6vQ8crViDtZTjnAIP3inolzWF6\nz3Wc3zPOQytPMb84z/U7DvGr/+1DnL8CIxW4sgAzY/Dmu1/C9FSZrzx5gTRLSbKMeq1Gc+UKu3dt\n54ljOU8ce4xaDaIIPvH3jzI9BjfsHWNqxw7GxuqEwyVqlRHCIKKXZsUGAkcUGILQUy6FqBp6/QQj\nFmMDjDGUSxHzrRWaF+cYkmVCKREFyrZaj5tfvZ0vHG1sOndUIXOOIA/wAVgtlHdFcD4A8VhTIrBC\nLx8oJ1Ko4i73OBfiXIxKiA7UDYfgvEUkwpoIE8QEpkwcVqlX6tg4xiJ4VQILgQFrIrwL6KdCHFQJ\nbILmCcZCF4+mildL7hWPwXiwBnJjgOI6mYVMPZEtFHYrxZxtjUG0EBA2A5d06J07hs8SSqMzDO++\nntLQJLgUNWXwKd5niDEIHh2sISBYG6JhQbaKeVI2iI4qz/n76mZx8IuACFFUIjDKMD3i3iwaOLzL\nwSXgE8T3IBqhs7ZKOD7D0MxBAIK0RCwpExXLmgtRn6IuxQQRtFdJxKLqKVlL6JNCtY8qlLX9TePw\nLQlPzvoOWzb2vgJ48QOqN9AAim9dLA6pw0tRwlI/WBnwIJ4ifsUKYkQJbLFCCYb0gufg7/9nzgae\nITXM1OewK08wTE7rxGm02mDEeuKhRwjT1/Hz7/pVdkzu4fy5L/Bzv/hx3v0T72BpLmVidJzayAjq\nPGNTYxzc904unXyGF7/2VXzuk39ERkzWyWglKZ1Oh8uXmhix9LLO5jLIKwOVEW8c1hfBTEULyRlH\ngiVWV6glIqgaAlG8g9RCmYJ8RFLEKhKlj0V9MQIsgOYYAlI8BiFW6BplGKFnoeYNoq7QG0RRNWSq\nhGoQih1KIkpjQCYsHqNCn6IkaRBy78kHyoUTqCMY9agY1HtiNjm6YJ0BA0reT/muf/HDNAn4m3//\nL/nJP/goSd8Xk/TgvkYHMQLUy4DfFO9fV3kKYUdxqli9en0FzAbZWf9PBmWrAfkRxRgw4gpFR77u\nYw7IT/HA+uB97jf
X5z6gz6FGmzytPEv6lKKQ0AovfNGtOE0olQz1epWpsTFGh8pkeZulpQUmJ6dZ\nWmpSHYLFtRYnTp9m/sIc23ZO8rKXvojHH38lTzz5OK3VDplaWs2U+bllmmtd6vUhxBQEq16vEsWW\nUjnCmJDARHgjBD6lakrUaqOsrC7SS0NEPL2sS+b6DMczTG0bpSRtTHuNwK8xMVGhVhonigKaPYjK\nNcbtLi6eOcvkyLUcvu42to2Vn8NYnz8azRUAem6Nmw69ifFto1yav0KQLjMfPMPnH/o4/doz7Jm6\nnsnhYYKgBMD52ePsGq9wzYE7eOALn2BifIhm2sQjRCUIQs/k3kIGXDyvNOYUa4SwBJ1MKfmQKFni\nVv8su0eh5wzX1Jb5jgguTBuSC0qvmXC606G1BqEVAhHmF5RLZ9vcfffr+PI/fJgoiuiser7y5b8n\nyfqUx0ACiJ1lbcnRawn1Mcv+w4oZdpuLTTshywtFMs0dYSz0+x4xhszlhEFISWDkroMslC3TtX2I\nVx768hmWlqEMuD7cdM0kb3nDzVxeWeXhJy9z+tQynTRHA0Op1+PYo49w5xsmeNubb+b++x5n+zRM\nj8fM7Jxi7/7d1IbriLVAQGN1lU7epFItI2IJpFBROg6cyykFnnolZKgOvZ6jn3isqZJnSikMeel1\nN/DA7CJRxXPy1FnS8xfYdV3Inl3jm84dURA/UIc9hXqCIKoIIZEYiCy5F5IsLxQNA4EarAiBDYEY\nlQART2gKwmxtDARgS1gbEcUVKpUKpbiENSEeQXSgMCBYG4CHNMkohQFl6/HpMjYeIsibeNcprAeA\nMSFhVMcaAwYUQ2AAo0VZ3hYleQuE1hCJxYgF8y2X7q+DFUMcVsgTjzaXSRbOE1WHEZ8gYQmftsnX\nVrBxCc0SXNLEtxcLhXdiL/HMTdgwHgS2UGXWoSobis6GMk5BfATI8WT9FuOjq5h+A6ovx+LQs3+M\nzJ2G2WMsm1tZu/77GL/2JrzrIjbE+x4SVxg68ALKTlmbPUOeZqRpl6nsDG5lhVwisuoOOvEEoSi1\nbI5S/yTsvOUbxuFbRq34Cr5YakXxyNWFZlAycT4vFB8AI+RpTmAG/gwVcqdEcTDwhhjUKWqKQRHY\ngrXSdIz9u18gv/EgZv4iw/lxTJJRyleR2rWMbV8h7jSR/lO87VUn8DlUyjHLZ85iNeCWI3vxaZ+b\nbrqFY8+eYjyM2LVzB0FkMAoz+3aRhxHeeobiGqYqjIvi3RhJmpNmjixNNpVAqkroDVlRsSMb1HtF\nhY56Im8R4xGESDyBt/TFgzEYhMgrLYFQPE6FUChYv3pUit0JeIxCb1DKKgd5wfTV0336ONnKCjI9\nAwcP4t2gRIjBU3h4rBaqWnnwm63gCWTglRGw6wv7QOXwWpTnmqqEYojUFGTtn8R3riY+oqS9lB/4\nFz/Ax+pD/OEHfocf+OmfIgjB5W6QT+uqCYVXhwGP0Ks7CVXBIAW50XWFZTBxFMEvatxGB3K2bpCf\novadb3h9njswN5Jdr5anNnYtzzH1bPAinkOE5Op7ni/q9SGeevJxDt9wPb1+j+XVOcbG6gzV69xz\nz2d4/d13USmX8F5ZXV1jrdFkaHSSRrvFth3bictlvHM89OB93H77YTK6PH3iGSZH9uHzPlOTo2zf\nPkUUW1RT+v0uxuSEEoMEDNfHmJyYZmJigjKelfPPsrh8npGDhzBxjU7SLsqexhT+kDCm05onbZ0l\n6V+kVq/hNcJITESXIExotRuMT+/jhiOvZGpmqvBN6SZrNsCH/+I/c/+zf8H9932QyanX8Nhf38Py\nEtx551u5fOE0nSShvdxlz80Hef13vIOHj36MhfYsf/PpD/NDP/Ie/u6vf59zC7B/f5vhcUOnBd0V\nz75rx1hbXuPMcce2qYgX3LaDY4+dBQM37NlGvpTyum3Pcv12sIEwFEHXGO44DH95y
tATw003v5Iv\nPnQfXiGOBJ9CyQjHj11idfGjJPmg7J97cmkws2uEfteycqFHc7ZLlgr1aagOZ7TWIMg3p4AFxlCK\nQ/IsQY3QaafkDmqRkKmQZV2u2b+L3KZkPaVSG2Fk2zif+9PPsLIExsLth7fx+te+gNnFBdY6Kc7b\nwnqAILmSmBLnn13g/pGHeMNrDjJcabJn7w7GxkcwQbHtjYISWR4iGEpBk3arg08dRjxhYMnynNx7\nwijCRAFLSYIVpVYTdkzX6XegkaSMlGfoN1a47cUv4+BNh5HD9/HiFx1k1+4xqpV407mTu3wwHh3q\ni9g6zfHYYtwbQ4QhDi1RWPhjLCCDclFgI2wwKGeJIQgsOCUwESqGIAzBRsSliFKpRGgDrISFCqRF\necggGGvwVUNZQ7a9YD8vuPY2hoeGERE67SbHnzlGrAFJ2uXy2ae5cOnLXG6epxTUsSYq5nijlIzF\nDDYNoREqAYTGFD5G2RxZro/vQSpD+PYa2ZXHyVZncTuuRVwHWxqHXgN99mO4zKFZH8lTAgNBXEE1\nw49dg9hgUOrSQgGDDaJztczFxnoCppg/JaCfZ6S+Rrw0i0anwIaQT9BniBWzl/bYfkZ23EBgLUm/\nRRiE+CwhKpcxQZnI56yhuDyhtTDLtnP/BVP1hHu+h3J1CPIu2liE2YfpzS9SeeE3jsPzoIm2WDzU\nYHComoGnpzC0qin8POIFK5C7DBEIpFCIjBVE/UbJoChJGBDFq0E0QysV9rz8dprNJiPBZUZb5yhV\nt2PsbrZJkzR9nJl+g2vzLl/6/ENMb9vD2TMnqderzEzt461vuR6tDPHFh4+ybWon3U6TJ48+zA/+\n+I+xcOEMlxaWaJ09R6dn6STdgqxYSyk0RGFIKQ5xWt1UAhk8ESAY8gGhSMmJgSE19HDEanE4RAyZ\n97jAIt5jxKAihWFYBY+wosqwKVQcBHJ1GDVk4hlWoSfQVYNbXCD/4IdI8hx1jr6D0vV7Cd/+vUW9\nWXMsQqqQSk6IxXhwKA4IVXFGMN7jKRSeSAvZX6VQl0pi6Kpi1JGr4/nbca9C1yUQ1geBkvQS3vz2\nN/H60Yinjj7Ie/7jb7L7wHbSJB3kBaiaQtPRor6rqhtkR9UTitkQVHSwpzIwIDtFOa8wLlPUeKXw\nPhmKEtfX15x1oDIONEy5SnCuDl5Fv/odX32FTSo8Z86eZ3hsnAPXXsNjXznKjYcPMj05Tr/X48A1\n17DW6GNo0u526bba3HnXnRx/5gJHjtzEDTcfpt1pk/dzLv7t3xLamJHhUXbu3FPsBAPHxHiFzBdK\nbBSUqFbqGK2g/QBUGK/NsLbQpeRWKFcTaqNVxqerXGo26WlGHhpCE2B6XTrLlzgvI1RrQzQXTqL9\nBp3WCJQquFwxzrFy+SStbo/h8euwtSkoGdIMOs3Ne78+9rGPI4/CUABh5QJhJaRSdXzxwb9lrQlZ\nBlme8aVjH+UrP3sf1x7Zw2/8wTs5PZewd+dOLjYWqATC3HllZBhsKDiFcT0MXCbPTxHGnubaMkEF\nDkzv5kC9xu27n+ZFhwRjCzobVyHJCo/Oj96g/PFRw/33f5YL59uQmw0/kA0hLBnmllaZimJSVfpN\nT70yzdhMyqVjU5y9dJTMCcN7hMqkZ6Q8wsRUicRvTlWOrdDredaz1TtPLQ7JfeFXwcA/+7434H0J\nl7fITEoUl/mNX3w3n7//KY4fn+XIjbtpt/usLvdYaXXp9yFJc3RQYs69cK5jWXvyAlF1nCMvvAEr\nFqeF0dhIQDftEscVJBiiUo2olaC3do6+nSbNAnKXEZiQfjeh3+kX446A1aWMuaDJTYe2U45KdPtL\nJGvwkpftZmTvAq8bqXDq7FFOPLFCp5fyI+9776bi00syHBnOF17FXME7h8ejxhfzgxOcdZiwULvx\nHjHFAi6RwQRm3cWHwxbPBRawSBiAGGwYIkYwY
jBBUCgc3qJkBPEQmasT9RJ21CpMTV5LYJWZoTrT\n06MElRqvetGL6Pf6zJ48h9x4G0tnbubUlcs8PPsQp1efpmojxIRExhdz3KCaIKGgpIXvcZM2g6zX\nQOceQa+cQNOUcOIwvttA0gZ+9BpMYCkFOUYTlB4SerABGii+t4DvrSCl6iAy6zo4Az4gz1F5WK9z\nbdSExBjyqMwpP0Z5OWaqto2oNoy5/mWsjj7F8aVPMT4xw/ahOr2F45QroxiJi6YmIGsuQX+NvN/G\npS2Glx7D7/8Z9Oy9SDdERncj9WFUFmlluzjHCt9Y33kehMcPDEjrRiQvimBwMlgbvGAw6KC+l+VZ\nYaswEGnB8RyAEawUnV7eK8aCMR6dhSO/+Ss8W7bY5hVGs2WC+gxiqswkn+bm7hzXl1pE1Yw8zfnH\nxx8hKpXZvncnmfNIZZTTi0t84fOf5PVvfhN5t8e9n7qHV7/xtZw5d5Zf+rn3kavh0A03MDWsRf0P\nxTml19NCJva+WBg3AdFBbIwnY6AiYIouKBQD9FyONwF171FriNRvlFf6XgmM0lchMxAZIfPFjtEP\nSoh9yTc6qzKU2DuW/vBPcb0uGlhCl6CBof/UCfS3f5fof/sJMjGoFSK1pCgl5+hLYV6uqOBUyRgo\nTYOF2g9q1Tk58WByC1CSgWel+09QePS5g0LWhRAl7SX827/9Ij/9lttZecdJfuD//DW+83tfTdLP\nBsTGb0w4Xt2A6FwlPmw8v3719T/9hhmZAfEpLDqFV0fka8nOern1OR6dr7ruIO+/QTlvXeFZn3M2\nmTq43FEervGlL32Zaw/soVyq0G536LQ6jI2M0en2WFpa4Jr9e4hsgPeOiW2jfOGLn+P85dPcetut\ndBo5tXiE2dnL7Jvex/6d1/Kl4w+Re0O3m+ICxUYBuQ8pmTqVaJhyyVApl8lzR95f48C+a0m7c0zW\nh1hZWWH/vp08/MQTRLUyEnoqJQNJjysXHsXaGMlWca5L1OtRG5skCMvktgauTdbvE8Vr+KTF0twa\n5VKVC+cWNxcYoBzXSPpt1rxw7NhJKpUqoQ1ZazQIawbj4PzpBc6eWQAHDz72ZbIMwlB45thFjAjT\n+4S4orRXBd+BZsMQVSpErQgxcHk2p1xrcsu1B9gzPMYL7CNcPw3HTxuOHCp+zwuzymhNEAO7JuDH\nXuT489N9WghBDmkOJiw6DHvdQj3BC1Fg6TRzuk3hlptv5eP3fRQfCNG4MHEttNcgI6WTZTTbmyM8\nq82EgS8Ar0I1DnF5gnphaCjgv37gl1nud1lpLhKFAaJQr40zFo9wcNciQ7Ual+cXSZOMIIpoNBqs\nrLQL9d7nlOpjDNUCwmwJ7cVonhKKpRRHhe8LoVoZpp8WHY82HGIlXyXPLWFlO4EpkeYeY1zR4Wn9\nc0zVBs0M3W7OvfefZu/OYcLIcfedt3Bl+fP84x+u8v7/9iRvvHOIiYk6eWlk07nTyZuopoRI4a3z\nWjRpkOFJQJTceDJSnEnAesTmg5gqmfFkJkfFAA5jDQaLNQKDTSpGCwIFRSkLQARvhdVGTOv4LPHi\nWWKXQXWCZXeUtLVM8oLbyTyJux8AACAASURBVA4dZMdttxFFEZU4IZU1Tn3p7xlrX2afH+KWQy9j\ntnOEv5v9IivZJQiiwjSihauGEEQC0KLEtRm45gqSlfHhTqRSxffbZE/dh9GMYMftXFW2czAgPgWf\nI07Q7gq+eQWpTyPGDjpVZcNKgCrWJ4V4oZ6iG8yjPsPkCVWrDFVqrDU7rLbn2f2y7ybLu5QqNUra\nZZttsTR7nIWhKsNj04iYjVnXBhFOuviki01biBHq6UNcuhRzubudPY0y5RqYWo6WZljzSxw/+g/c\n8tZ3f8M4fEvC44vfGe+LNl7PwDDlc4wYUAcC1rHhU+l3OqQeYik6ujIHQVCUfTwGsYqIhyDAnc84\n9LlPcbweM
2ICovo+0tUusamyw53gbb1ZquVl0q6l30xpLFdpNDs8+MX7mZqcpNdo0Ur73H7HHbzz\np9/Dzh0zpMsLvP+X38cTx75Iq51A3mPvzBhrxz7FUC3EIAS2mPgqNaVCUCggsrndaCkvGsXFGCoG\ncuMJJKA7KFEFGIwRSi5HrVDTnMAZEhFScRgxWCdYHJEvSoZOhASLNVB2GW52Frk4S2thkaTX5ekv\nH4PhiL/+Nz9ItSqQ9sgTg0Zlfu/TX+SpRx5n+EW3kmaORAqSpAPVJMHjFZKBh8eipDhQoasQeI8b\nkIyuKqFCSR2JGhLd/E7duYHQuU76pVDDFM9LXnyER1ccT504zw++9NV8+NerVKZu4z3v/w9cd3gn\nSVIMnI2Sm/qNPjG/7hreoCjF0FgvDqx7edbv/XVlKF0nOrJxBRlc5mpn1oaEtDGhrctK8tx6ln7N\nv58nev0enY5F85RWq02r3ebMyZPMTG/jRbfeSm91lb1791Kt1QnEsHBlkWp9Gwf2HSQqlzn6pSfo\npp4Lq4tM7dhOlqa4LGVyKKSVCDkG5zwhIbsnd3H93v20Vhc5d+YSy4tnuf7GQxzYvoPG4mni6ihz\nC02yNKCM4/vvfjtO4DMP/E8uLM1SiiJEHZ1mk4nJfQT9jHf987dz/0Of5czCHFe6XerRHly7z2rz\nGGe6TVoXHyZzhosLFzYXGCAvtQm7kA9J4XVzPVprPXodg3aETtszXDGEAvE49JeFLPPEk4LrQ3NV\nSdaE0XGL9Y5oTNk7Bq3lPgE1ggC+743fxU37d3Hl2Ge5e/rLdHuGl/9CMMiiQdcfAYVGXSyG33Mz\nfOC7PZ8/JnzgIciN4DLFBoKNio1cMdcZ4qrhoU+f5PhXTuIzqO9RahNC1ofKiEFdn7mLRYlqM1hu\n9LGhZWK4TmOtwVriEC1y9v3/7idp9ZZYvbTCqTNPYwKQ3HHp3CJDY7sZHa+j3jO9bYR7v/Q0H/rI\nKSKB0Ql41csPkC2dZnpnmd2H9hCVbqTTyymHIc12n9wZSpU6qpa1lS55lmCtYE2bnBgTlokmxsBn\nRJphdFBuUcWrKwh2ltFrKbMXGqSp47aJ3Ry5ZogrrQan5keZ3nct3/n6PmtLF+ilOZeba5vOnYa/\nQGgElTKRBIg1eONwPicNu3j1pC4liTpkkmJLDMrbFu+hazJSYwiMEJqAvvSohEU3pWpClzYQ4Oij\nPqFEldZ8iUtPzFLRNttCxx5SRkpKPQ7AXGa1K+R1xc1+kXi6jj9xP9WDN3Hpyw+wduJJJpqPEpse\nh6ZiOs1HucYIR0bKLMtr+E/p3yHY4ngRlL61RAPCs1noZ36VvJdB5rCFExJbqSGjE7jlC5goBNeF\nrIPkffBpsbnUDtge+dkHkPoMZmSmaDgp5HPCc78L4/uRkVsK9//8n0BrDlmeRc7cC9ktBNk09WQE\nt/OF7LrlDlRTrFHU9ajMHGHnnUNMd9eQzgru9JO05h2loEt1YoZxey1ahV51hnzmBuLWIn1/mF3l\nM+z2BpdEpKc+giYp3nnq849y+w2v/KZx+NZdWgje+8FCddWI5G0ALtuYGjJRrAevjjTpUgkK03Lu\nITaG9eXK51Kc3yCKLGZU/817qd9yA5WzRxmubqOx1GOnHKCSPsntvdOUdAHXN5A5AglopxE//lPv\nZe+BvThflC1sGOKcx7icztI8733vzzBaqZMTcqB+mZFAePbMUcoVizKGUYcMzrERQL0nGLTJbwbO\nQ2gU8QMPigckJxZwgUVxqC8Yuc09BAYjEKnDiNCRHKuGmkAsDgc4DPrscdLHHmeumREd2I/dt5fy\n6Cjn/uaj/Pr7XsG+6e3Ebhkny9htbcJ2Fdcf4V1vO8zPHa8zjGFNlEwUgyUUJcGQ+ZxU1826Bbkx\naogGpcl8IPd6CrITeA9iQR32n9Cl5ZFCZh+YBov+Ax20PXrSNOGGg3v4yFc
e4Kfe8mZ04RF+8Z0/\nzo/9zC9x91teWpS51snJwCc1KHIBXM3HIilZV2QK21/x4qKIWviqdL0N8mvPjuCrrDtf13HwNY6d\njeuu31a02BhsBp12l5HhOnEcc/SRR6jXa1x36DqSpM/c8gLnnpnlhutu4sSzpxidGOLkxUtcf0ON\ndgKXzpwmimJOnLuIBCGEder1GlF5hD0Hr+fpZ76M5AlhEBNplRPHZ1mba/DiW68hrgpRFrHSaCAN\nz/DIEJHJqYxUuHRllSwLePTxx9m5axevf9lryZIO9x59kDPnjlGvTbJ9ch8uafL3n/mfHLnuANt2\nHOBjn7qXtWyV8alJgjDg0uwFOs02Ya1Onvc2FxjgJ3/85/iDP/wVMjy5M6Qr0FqBfg/qkcEbBa+U\nx4WoDv0GBJGQOsi6yraRmNVewtKi5+ANu1AajFSmSJIO1juMQnt1jsuzGeeXhUcTeOkh5XM/78mc\nx3nIcgiilHYP+m3FqVArw2pLeHq+MORGAxM8gDqhVldK4pB2hAi4VGhdFib31zj8kiGayUVWlwTv\nheFJ8KmwcmVzG4lqNaJeiVluNMnynDgKyPIc8TBe3cZf/e3HGRsdpz6yg9FqlcWFecamhTRt4PMO\nYXWE8/MJ23fv4Vd++SChVYLCSEcv2UcQRajLWGv2CIOAJE+JwxjnApJejtcUAUIRkl4OEZTE00kd\nfaXYyDrHwmKLdidhba1DnilXFlaZX1xgYmKCH/ruF3LzdeM8e/4yc1dgredotHMudU6wY+848bVT\nHH/yBOKbm84dDRt4E5Ebh7FRodp7wfkUifuI5ohLCW0forRYP8SCFJYLrw4VwRVyME4go0MUCNZG\nRRec9sjo4JyjsTDM5ccvUcmbDNcCJvMlKr6HqMXYiNRbqtUaUaVKEOeMjhny3hrp6hwxPXYPreJ2\nhpDmmDBnanuJvJ9i3RVK3R43l1/MUf9wcVwHkOERm21a3QEIywEmL9qQrCiYgniT9snnnoHaMEE7\nwWZtEAd5BurwPiUHXHIB05jHDG1DTbF2ikJnbYxa60Ek/St8OFWUA+xhkt4+svh21vpXWKVKoz6K\nSbuEjfP4bTvIvaNUGiEIQmz7HP7kZ2HlNOXlL4HmRCMH4MAoGrZg4mVUKjvZ4+rkEpG3r8Az/0iy\nNsRCA8rXHUIqU1CdIt99J6W4/k3j8C0JT2FhKMLrN3a7VxcDEcF5j4hHpfBYaB80CjCDg9zEDg7E\nQwlCpeccJvPYm25mx79+Nw/NXyZv9BhN+vTiEdphl4OXljjgPoFVyBkBbePylM+enuBd1+wlSdOi\ny0S1OD8EUOd47//xPj7+53/J9DV7eeUrX83xoxeoTHZRU8YGoD7De0HEYwIBhW4/pxLljFc3t9sy\neNQ5gkF8FENOjqKEqnhjSKyhKxkGQ8UbclOcLeERylgCgU6vS2AtLjBUTj7NylyXymvfQDhSpkxA\neuo4Nzz4ed7y5sM02wmPPTBHffkKLuzxwncM8+QDC7QuLWMnxnjrzX0+eTyjYg29VhvmF2m22oxO\njJG/4hVYG2IVnFe6CKH4DVOy8UpPlKr3WKAjQh3FqmI3W7MBnDOD2n1hkPYKVrXoPpCi5JVmKft2\nTfL7n/o0P/HW7yFbeoY//62f4YHPvon3/of/nUrNkmXuOQyDDcKyrsqs/yHr9Gbgu9H1s3X0ardG\n8UPpoBT1nHKWDMjR13pxnqP0XH3sOQ8Pyo2bRafbJwhiev02d9xxByh88pOf4uDBA/STPjt3T7G8\nusKFS3Ps2LuP7sUlHnn6aZaXlqlVh1hud6nUaxgb0Fhr4LyjVCpjzAhZf5TcJWR5n8uNRegFLFxZ\nIvMJteowmDKLyytUywH9fod2pUOaZ3QzhXyeLAfNuuzf9hImZyZ5yY//L1y48nJq47v40Ic/w/zi\nGhfPneIz9zzC9h27GZnczcXZM7AstLo
dFpcXeNn1+9Ewprf6zdtDvxne9a5foJcu8/7f+e+ghToc\n1wqC22znjIwI07vK9LMeQVUwcfEj5GueajDEW77vbv7qwx9hYQV8NsubX/8uJkd28plP/QNR7Bmb\nggcef4TkwCG0b/jo5RmMznHXrYINIEvgLz4Bb3od7K8qrZaQO1hZho88oNx7wTBiHWkg+BxqdQUj\njM9AvjJo46UocflMIerw6EM9RnYJ4oROE4bGYGib0lzb3LiqhpbZSw0kKM6PsWKwJiBzGYmBN735\nO9i74wCffeheHjtxhivzy1ycnefUqXniGF54807GxiepVYdYutJhdHyUWhhgSyVqZYN6Jc/61KwF\nB1nu6fcs6ruI8UQo3VRRJ2QuIcky1la7NJp9et2UTqdPu5NgAktgDEFgyDPPxOQQ73jzazmwu8Yz\nZ+f463vOk2WGaqXBgV0ziE/JndBPPUnmuOXGAzz2xMObzp0g7oCkeJPgpFR0mSI4n4HpF2e1+Jwo\nTLDqEDxiAkRzvOY4Xyh6InYwPYQ4EVYiB1QYs+OU0iG6aU63W2HlsVMESYt6JWIqaxD7DoFJ6SYh\n+IxqPWZy5za8c5CuUa948rhCb+4MfvEk/bP3EEUlslaDuGYwPqY+PEllaDvl+ct8x/IOqL6YU/Zh\n3OAcJ2sKr+xm4bME8rQ4hiR0hCEFiU3apCfvJyuNEl2+ggZ9rCnOUFMjQEyuPbxeQi4+hR3fhalP\nUoxIRW/4bq6s3kze7jGx7whhuUK/12D26S/Q4ArB5LVcevooQWyoJSlrj34Ggjp2ejfUU9Sn2KyJ\n685jQ0cc5RgFpyV86wCm7vB6E6STuLU5TBhRSto05Q7avslaLWZo7hPYsT3kzRJ+LUFWFuHWt3/j\nHPlWQZLB7toNPC7FerDefmeBHIPFaVG+8lqcbeCcrjtJceu+FYRACzOzuwJT//3f03DK5OoKtzDC\nxR2TbF9ZpLK6yIw5T7cHw5HF5RmBT/Em5KGnm7xbBCHA5x6vDvKEqFziN37t17jnkx9j+6FDlKt1\nLl48x8S+w1y58CmyXhdnIbOThdQrEFAoG9VySJpbzq9u0gTmcqyxiM9I1OA1IRu04ds8pxSEoJa+\nFRIrdNXjcyUTKFkh8IKSI3lCqdvlykc/xdqBg4zffB3Gd0gfP0k2MUrj0UepTu4lKXfp9Rw325DH\nGj32v7LC4pMTyK7LXHi6y0zJ0D03z3LrBWT3foGs3aZ0x21sv+kmzn/pUQKKLoKKConxZFq0qZat\noZU5nEDklUwhFqGsHq+FPynwm+sIAMhzIQgKRUcGhnVvC7M0dr0CXJzwPDVe44P/+A/8zA//KAun\nH8E88Tf88Bsf45d+5wMcuW0fSb8/IBqF2rJOUgZJunE2BBsHBepXqS9mvTQmMugu+JrfevD2r/L4\n6HPcO+vP6XPI/nOe26Q4iJiQZ06cJrBKq9lieKjGXXe9muXlRcIwwpuUo08eZXR0mi8+9BgmCCHv\ncHb2IgcPXk/fOcqlKnv27OHBBx8sTPhhxMLCFbZNbefipbOosSROWF5e5saD13PszEUunfoKo2MR\n+/eMUysZfO4Io2UmJicYHhtltFymUq4TxhUW584xEmxDG/b/5ey9oyy77jrfz94n31w5dFV1Tmp1\nS2pFS7JlgbOxMDAegzFjDx7SwGOB4b034DfADDCD15Ae89YwYwawCTYDtmVjjCzZVrCCZcVudc6p\ncrh184l77/fHudWSGdtQ7LW6+96qWn1v7bvPPr/9/X0DbhLzpc9/jiQWuL5FZXQIY8CpBjx35Bkm\nRye5fPUKC4sL3Hnv/QwMb2e52cTxRzY3McCP/utbefnocUZquVzcDQRWYOg2BFoZpAupFWMywfoy\nue2AgW4PFC1u2vc6ih+weMNd72L7zA6Myvh/P/bLHLihypkzq8gs37mOnb/E1uFJIuXwl8cG8Kx1\n7rl
F4npw2374vU8J7tojuOcmQ7sFXzyqeXgBfGkjLVhTijQVOfnVzlutymgsacj6NycUdOLceLR9\nCooVKJUhiqDbAGlvbuFcXmwwWK0RRRFJnJCJnBtjWS6s1tk6NsbymZd46IkXqflT7Jrcx90HRxkY\nGKBSDZCWxPYEWHDm7BXOn7/MK0dP0+y2mJoYZWZqEM+1SWJoNFq5SaUlSJKYNDFESW72lqaKNO3L\npS0Hv2Dh+h6VWok0VSQqI0kyhioB99++hbHA4rnzczx/UhEEDu1Oyu03jDG3mlFvhFgWOBYEjqDZ\ny9Aln9/8N+/a9NqR7jpCuGjhkGIhRb7La61BhhiVIIzGAdz+IUlriRYJ0uTimo3WtSVdMh2DSUjM\nDg7KOxEL6zzx9DMUtuzFt5oM0kbIDBGF9HRC0TEUXcFASYEf0Isyav46zYVZmusd5MJV/OFJotnz\nrDz/BcT6CkFNYjuS9UuGsbEm4wVFefIgSTTIRHeN+9xp4tKNzMXH8fSriQX/cAv7x4aOkrwbYbKc\nLuAJHEuRRG1M/QpZeBpaa1B0cYpF5PBeilt3k6zNk107DVFIcvEZ5NAk7vbbEX4FhMEi5Yk/+RXG\nBiapDxZxtx6gvOUmqsMzeN1F0qOfYWk2xd5VxAqGsB2XzsJV0gwGRmaQTglZHseyNcL2uV5a+Cnr\n0++gsPNNFGpjqM4i1s5BVp78c0i3sTQ5jZmA8oUHEW2XxtHzJK0eNb+N4dsLkL4zabmP5ttCovqR\nAxiBwvSJbhIjNLlhbe6pIhVoV1w3YzKZxLFzrkhsNEIIbv6tj9KeHGOt06G4sEp5fAYaTeqLy2x3\n1pCBQKSAUARuBy0lOCmj1YgszUiTvNhI4oiCLXn6a0/y3379N6ltnUEnEa12k4IcY8ctkwzs/lcE\nXkC1UuFrn/1dothFZhmJkAhH4jk2rm9hu5tbQKlRpJkhFgKZxsRKIywLSxpiCYExiMzBKziI1NCy\nNyhuuWwVwPciOHmECy9eZurdb8snu91Btzus/dnfUDl0gIEbdjO6ZYr6/Es4pTnOrmsO76xSGyug\n0wMMJiGHDoesL8L05F4G/uZlBraOMOdOo184zjVjU/ngDxLHKY6CTFqgcyfnBJErbcjbBJJc2ZNo\nQaoVgZA4/XbUZoeAPppm+kqp/IaFJfr+QnlFYgClNOWCxe9/8s/5Pz/0E5x+9kvMTGs+/IEf4f3/\n9hd4349/H2maIoXCthy0NjiuTZoY0gwKBYFKM5TW3+yO3C96jDCv8nCu83pe3THMN8NF/V/gVY6P\n6ZsAXUeDNtpar616NjHeeP+beeyxr+TITFAiSTVXr1zj0qVLTM9M4DoT7D9wkIXlOo3uOqVShayT\nsHPHLhrrTcqlCkEQ0Gq1GBkZoVQqYVmSTq+BUlneWnUcysUiUanH0eOvcOvNdzAzvYNW/RJGdNDK\nRRgbzwuo1+skImGwuI04TZGOZnB4iFK1xqW5y8yv97gyt04kfBrtBoduvplemrIwd4WDB2e4eH6d\n1XqdHbv3sG//PnqdBEv4OHLzbrlrcpaZGwZZ6NVxQ0Hcy5We5SGwOwKvCFdPaAYLgso02MOCXh1i\nJYhTzece/yi1kSX+9itNPvAvfo5jp75Eu/0CfrAHSxaw3VyJZLkJS+FlHDVMhs9/e6aITY+7bxUc\n2A9DLxl+5VOCv98Pj57WLBUmObTH5ciRywSuS9FoMmGIeoKgYEibMDomWI5DtACTgbYgSfPuAEIQ\n9cArQGdF0KhrvGBz62Z0aDBXeCKRlo3KFGncodOGbSNb8cduwBm+i9//d+9FZBFJ2EJFLVTUQaUh\nQtpI28eybW687SDOvXdhFcqcml3k0a89wy//yu/S6sHP//hhCoUqvdjQabcBCKMYneZ7nrQlhcDO\nuYFxkneKpSZKFc1OyOSgz523TiEswRNHLrDDH8HziyiadLop46MFx
oerzC6v0+6ElAKbRjvFCIUl\nDZ2uxtsyuOm1I7xuzjmBvnK2X7/YeVyNkBu+vv0hc3NYaTbuszagUcYgSbGMQcpR7k7eQnzlJKcv\nL1D2QWZt9hYF3WaDRmJRCBxMpujpmHXVo1yUDPsGWSpw+dhLrC9kGNfHLM9jeT6NMy8xf26OSuDg\n2xpv0MIfgTiUdNbqeNUVytWA7Ow8Y6ng7cEBni3VOJ0+hWPMN+1d/9ShuiF2/1CoyVCuyT2LYoXv\nDTNw6EbsqYMkta0Iu8CKM4jQLSrUKd1pMO3TdF58hujkVxB+GWfqRoQd0Fo+y65dhxkvV7jy8MeJ\nmj0OH3gLfqnCwtFPUeudZfeBtzDf7YAcxqqMIoanSY1GxXWwHfBrCK+I6K7ka8mA2zxL49Hf5Suf\n+iPuOLwX2VnC9Dp0FpfAL7HcTji77nDnjVtx2h7h3EK+T/s2VtH/tvPwj/jw9HuZOsk9dDAYIZFG\n5D69BnQfwdEGMp2SKShITdrP33KRGJNL2xUCJ1XM3Xsn2eQAWy5dxtRqrPV6pLbL9lqP8dXn2WaO\nY4sRpLWW32S0JhMVXlgcZunaCrWRYbrdCEvFhKnhl37xw1SnJrGxUTLDcgVxL+SpJ57m9vvfyP0H\nb+WFb3ydoFLGtQpIKdBSoJUmCtP+KX1zi8hLc+WV0imFTGMbTYShoEAHHj2jiU2KneayPJSFkjmy\n40pDz1Gox5+ktZIy9cB3kfV6SCkoKoMsFNjyI+/l1Cf+EnPkRe564C3Mzy1zy8ECO+8ex2o2WXsu\nZdU7R2st4sa9Du6tVX7v9x+hZxw6toNz33ayn/hRSsMDmF6EEZJM5kRQQ4o0eVKVFJpBrZg1OYLn\nmtwd2zWSjtEUtMayNs+S0zpHcfKsMfqyWXG9SEaIPh8n/6OUwLcFf/AXf8J/+Ll/xxOf/iOmt2/l\nbz/+UV545ll+7bc/QmmowlNPvIAXlHj2S5/hlSMnsFHsO3Qfb/q+tzIxOcLwRIUszlBCI64bX268\nTn9N58QeXsWZ8u/8w/EPfXo2ECZe819qxHXb9X/q+Jfvex+f+cLnmR4dIk4UUdhjdbXO6Ng4WZag\ntMsrx89SqlTZMjXJ6lqdanmAWm2AbdMl2q0OR08co1KpMDo6ilKKTqdDuThKpVJG6WkuXDoJKqFU\n8il4RU6dPEmxVGD39mFs45P2eti4dHs9lEmpWlUSragUPRRwdX6RXq9Lmikyu0Y7Br/iMz4yRdTW\nZMZCRZIkTqjX17jr7ruYmtlCt7NKvdWkm2bIf8a6oRKRNiWmC5YLMoHuOiSRwWSCbtsQ90AX8riG\nzipUhmGoCKorMd4SozMwIq4QnniaYaNxgnHi2KVSrTDWBS18VldiFpY1BVap+MP0lMtvPZzy/9gJ\ntx+S/PDb4P6b8hbXi13wSzaNuINXBhXmPLfhKvgBIHIfrTCE5mpuaJghcMowtgWWZzW9UFAuCjwB\nqTJYDhQqm5uaNDMImRJGMUnSZX4u4/bbbubnf+YXYHg3q/OzxJ010qiDVimoNG9NaA06xXJ8LNvv\nt3JyUYVjOUwXivzkO9/AT7//PTx99By//Z/+C0dfeIztN04yMFhEWBau5+AXHQpoXNelFyekaUaW\nmnz9hTEjNZd3fNdOVtbbPPLSFbR2saVkVnTYPlDFCiWVQYtb9o7y9LHLWKKCo23SLMWVklDl6iff\nd/nkw0/zln+7ufkR9gYKyzfT7jYQXItvPsRgkHbfMNcYDBnGiPwwnWUYA4OlG3HPXaDVW8cSMZPT\nM7zhhnGm1AUeXMk9cTrtDoHdxpZtWjYEDiw3JKmVsjaXcXXVY9SNGNURZAn1xUW6IYxXcqWwE2r8\nwDA6rHFLNmmvTTcx9DprlAcd9tgNJoMD/F1gcz56ikxluX3JJkbWSZEWCFvm5O2Wzt3Uw4TRH31P\nLr1XMS6LoDPsrmCVKSLLpVJ/i
eLap3FqNxKev0B66lGEX8Ye3YVXnKAwNkGaxpRueRve+Haay2fp\nNRy8HffD6I0UgiHG68uo1SWGx8dxp4cQRpOlMXG3gddaxs402do1TApoiI3Dte4kd73p7UwWVzCX\nzpB11hibGSbrrOCHBr19O421BrXWGp5tMK6XAzDRP9NpeaOBIKXdh+1lH07T/QWSIzea3E5cZzon\np4rcYjsxua1/isDIPD6iuwhTQY+Z43M0dg9z8fI53PHtJOEsbjOkZkJsp4I2k6jsKSxpIYsgmppq\npcgv/PxPc8/r34CwfKantnDl0mnmZy9iewVsy1wPd9y6f5qx4REGfJsvfOK/YmVXKBbHQObGfOhc\nSl4oWKC/tfz4Ow3fKJJE4WeGyOTWXW5+9oIoRdl59ISlDLaUCK2xTd5SqWlNcOEi9YWImXsPY4Ux\nQa2A0gK13sFxJE7ZJ884sXjslXPs0gkPf3KZn/reEpcaHaarRcJTZ3B9h+MvQmp1WVprUfED1sOE\nQuE2pseGaUUxoQAPg0kNtp1/Hpkx+edmoNkPhRUijwxJ9cbnn5tEtjff0ULpfLXQ5/HkeTC5e+iG\nLcF1TEXkCKE2BplF/Ic/+C3+6/AAl57/Mu1eSP3q03zgDW/g7R/8Sb7+0F+QdNeYmCoRrXYpFhVf\n+cIzPP/VX6U6NMhNh9/Dv/ypn2ZgbIJMZRj1D1g2140W/8En/i0iIv53c8IcNtrgJRnAaLHpltbd\n97+ehx/5Eu9+1zswqkSmBUFpkChW7N4zTZIo6o02rU4MWGzftZsoiplfWqHbvUyhUODee+/hscce\no1qt0Gq2KBQKZCpDhKWqqQAAIABJREFUSM3pE+eYmpokzSKE5yKUhSccGu0Ox45d49ZbDqLsNXpR\nGxUnWJ6m3ogoBC0Gh4ZxHY8kgXonYc+2LTzx3DFKlRpeUCSKQ4xIyMKI+mqd2dl5ur0WtmNYb8yT\nZjHtRos0TQiCwuYmBlicjUkjg12UtGZzwnJtFIoDgm4j71HWJgzRiqHXNThFSDqSmf0GXTY01uDF\n5+DAnpM4ye+wd/oD6I6HPV7DEgEyCGjWNa02BIFAZLDWXaVWqNIIC/z3F0oEXp1tM4LxUc3XzwpU\nMUeGAtdmZHCQpfkWjhaoLvRSg2ML7Cr0mnnkjJACrRReSTBaEkzs9ZlbTiiPGGRP0OxAGguaq5tb\nOLnDcka90WSkVuNzf/3H3H7PWwivvMz6tdMYnZFH+rlIKUHaaCEwKkNYNpYb4NgOeZme+4MJSxK1\nV4h6dazZoxyuDPDg3/x3nnrpNL/x678AysV2c/GD61goDa12jyzTdKKEVitm/7Yqt+4ao9EO+cyT\n53EsB60dbAuGBooEBYtCUGbHRIDthywu9+h2JdWSTZJpqkWPRtwmNRmu7aC15Kkvn9/02nGc/BJV\n5jXdbXiNS/urIoXrKInJkZ8NYY7EQRiP6eBWqnKMkhikJ0KMEujiIO+69xA31y4TdT3a9Q6JsPGt\nKM+RlEUaWUa3rtk2nLEW5mrMdmzTWoq5Ierhiohr5y8zNVKikYLnKtxYMbso6PQUW2Y8LBnTXu/S\n7sZUBiIqhYxqOePd8m4e9gc50/1bhNoceqrjFGVz3Yk6DmOkUDiDO9HNP4fgrpwwp7qgG7iiyFTz\nzyFU6DWbbDkjPPP36KGdMP8K2cguZGUcrzhALXBwB4coz+xEjuxCSpssS1F+BacyQLF7hS0HdiPk\nVozngj2OjiM61hS2NUdQmkSPfwT32J+RPPmHGBtMllJMVyiVJMm55+HyMUQwgtE9rOoIDjAwuYtg\n7jjBwjKWV0VbNiopoJNvH0vynUnLQgAaYSRCZoDqxwH01VtGY1QG2pAJSRJn+UXv676TZc4sl4g8\nNO1aQvXP/pQzR08hbrybpdYiBbvAcquFijzOJsMMxdPUWrMMFZ/CBAEmSVBdSLXm0E6Hb1xo8MW
/\n/QJpFrG8sgIKhgZrbBt3cBFYQlAMJBMDXZLOAq+cWsIIiy2j4/mNSYDRfWiz73nz2kymf+rIlMbJ\nDCpTWJaFFuBoQywMrlbYyhAZQ+zYCJXiSotYaDwEa7ZAHL1AcOdN+FGGX/BodnuozJBlCYVMkkYJ\nez74Q5QCn+6jXyGaXSbsdfnGpQWmtu/lyU7K8M33cuH8JdTyJQq+RCYJsZDgOax+8itYL13Cfuub\nCccmKSpFz8qjIkKtrwe5KmNIEGTCkKFxtek7Z+cZaJHJ40M3O7TKJ1sYcst0ARtQoOlbJYsNwnHf\nqVsYk/s8JSE/+6v/N7/+cy1Ky2fIsJkYS3jswY8xVJL4JR9PFBgbtBidsLhth6DjFGnHGa3uc/zO\nR57lwIH7uOvN72Rix0Ecz8qzW+gXMfnLXmcfi421fv0o2F//bKBAr70xvfpM9zfOzar2syTjpkM3\n8/jjT/Hhn/0ZTp04RtF1KFcqzM43KQY227Zvo9FqMzo6jmPbtJI2ruMyOjrK6uoqH//4J3jHO97B\nxQsXOHTTTX1PlTUajSb79u+jXl+kXCmRdRqUC0Ookk3QbRP2DM88d56De3fi2QHrnYvUSgVGh4cJ\n3CE6LY0xLTzPZXF+lrWVJVq9DNcLcG2LKALPDWist2g2WriOh1OrEicd4gxcx8FxXWzH+RYmj//4\n6KwKem2BtCEMDTqB+rxAWgbHE5Rr4HYgcwUH7hxnrR0yv9DA8uBNB29m0m8icMGTuEM1vvTkQxx/\ntsG+uyaYHA9YvBISJfm1P71NkjUtFhczItFkbHyGnlvir85V+f7oEg3b4X+eSjmwYxw7LlEREOsW\ntpOhsh6uEGRdQ2Z5RInG9lMsB1SiQYCnYHXJUAo0diborYATgBvAcMmg083NT6OVsLi0zI+9/4P8\n+1/6CMn6GuvHH0VrhbTsvrAky1WDKleROU6G5WhEvEqhMEPcOIURAVpnSLuCMS2EXUHKAhpBr71G\nfORL3LllN3/914/w0z/zk7Ray6jApxvmflhpmpJkGq0V3//GbVyer/P8xTppZnAshyzVeL5NqWDl\nxY7rsNxpYGOzd7TMkXOr+K6NJSRxlmB5Dq0oxHMcfMvh5aOzXNi8wA/PMf34nA3UJv+6Ma/y7K6b\nv7NRAOWPFKbP+Uuw9CghC5R1Fa1XifytNGrw5u0DOM2z+COGuLVCxesxt6SxfEFB2KjSIAXZYWzU\nJU0bVEs2bgEaHc2pDkg3IM00rguBm1JvGXZO24Q9TZJlXF0tIoIaI6UZEl9geJrF+SWmbjyEcCyG\nVZvXewdY1SeIkjObmhuTphgtURqEyXDGdxHsex3Zxa8gD3wOTQ8ZXcE0T8D6LIy/l8xZZP3FL+Je\nO0L32BGujDzAuLVKMWyhls6Rbb0NHJve818kvPA8Ztth7Ps+iDMwCtIi7nWoNU9Q3Gahk0UwLUxc\nQDiTdOIZEq9Iof4NdPJV6Cwg0imEO4aJl/AE7Hnjd9M0NnLfAzC+D91aIVpdxrgDsGOc8V13EBQz\nrOazyNokxvJxfEFw6/d823n4zgWPMRiRgc7j1wUgNGip8UsDSNvKbZ8LBUrlYdbri7zlx0tcOnWK\na6fP4zkgLAstDKaV4H7459g/PkajUUSQMt5yCIs+dpaQpQ08L6VBlRUxQ6LO4JFhpIEUpGVz/40R\nz150SKWNdm0qxQEynYLx2Dac4aoMJQzGZLTXroC0UNpge3nImBQyJ1GTOz2HUcL4gEWYSMJkcxVz\nlihkppFSIHQ/yoCcD5M5FiiFbcAyCqnz26QlLBIyCklGmGWUmy3U6AArUYjKDEblKFCcZaQSTLvN\n6c9+Frm4TMPz8ITkq6fXGNXXqJKRzF5m5417+dTLLTqL1+hkGVUswjglkxaXj5/FfuEEI7/yf9Eb\nGsakikxu3MI1icnzgIzO878E0DMKY2y00fhmg56+eS2STnM
5urbyOAghN2DjDdtxcrSk3yq9zqsx\nAqUFWS/lJz7yH/nUv/8xPCvlwnLCzbcc5sKZo+hMsbLeplD0uHSuzfFuyt6tKcMDKWnLp1S0iVYe\nJ158kRcv7WJk5gPsOnAApVOuR0W8tkWFyBf2BoJ5/W/Rr39eG5OR/17KACb379Cb1KVblkOqNFMz\nW/nTv/wk/+tTf8Vvf/SjCGFjpKAYVFhYWCTJMmZnF5icnqFUKtNqt+gu9ti5cwff8z0PcPXKNW6/\n43U0m00s6TA0NIRlSxYX5/EDj1phJM/jscs41Qlumj7M6PgkX33kQS6ffIQdU0M4fgUhBUXHZd+e\nvZQrFRYWFsiyBMd1SXGIsxQ7cOh02mzbup1Gq82pk2e5dOkqxWKZ2+/aS5J2sB0faQUUCjbdbrcf\nbLm50esBmaEXQq8rUImhMgRxS5A2DCIFrSCzNUtLLVY6IT0FY4Vb+eAPf5z04sMk9QsgLK5dnWPf\nAKztMATCxnE8wgikC2FTgFKkSDxPstbUTE141Ea2Uxv0eV7sY2zLNsKv/CFZO6Adxaw3I1bW29TX\nQ2yhqBTyZOt2JyROwfGh6OWOs0rlN9ChLZDGGZOD0OmB4+avjxSb5sadOL/M//jPv8L7P/ABuvOX\nSZMULAejDVIKLCvD98ro5gI6i7Ecl3D+AlnrLGm0xrV5C8tqULzhR1CXH8fyixSqGbLgI6SFXdmJ\n9EsYb4jeyhW8Xos//ZNP8GM/9sM02x2SLEEbQZgm7Joqs2XI49Nfv8KA7yGkRmW5atFxbDzPwvUk\nUdgjCm186bB/psCLZ+aJE5sw7OKOV0iVoRfFBAWHLFFkqeHRR84wvcl2H9APoqbPydkg2tEvgnLa\nQp4a8M2HGAv6Bqe5kMWRHTy1m0XaoGfo1HvcMjXChYsXuanSobWm8DyL+dU1WnGR8S2TmFKZkekJ\n9uwZYTw6S9yAbn2ZXkGgB1POXYbM8VhaXqVlHLraMDpYIE57LLUFEzvGmNw2hDO0n/LkboYcm0sn\nztJqrmBSje61aDQSyr7ihvIBjpnNFTxCgJYuwmRYts/gu38JESRcPfkc2W+9iYHdu1h8+Rjulmmi\nSx28LV2SHePUVx1WO1to7PsQY3KV6toVUqcM7SVUr4Uol0mlj+cOIObOorttTG0YCztvnyqNaT8P\ng/swwetBljFLn6AcNKgsfwp6VyG4m/RCl97poySxTa3QL1CvvER5h4Oz9jLi0qOojoVILTICNJI0\nbaOuHcdvNSDqIKwM8a/+I1cmPsj2bzMP/whp2cJyPCy/gGXbSMdF2gFC5g6jShks22CMjeXYGCtg\n+4Gb2b7/IN36PI89+Lf01utYjqRTh++9//Usj/hMHz7IyYvPMuKPsWZ6BJHBNDQtHXKloTlQXKIR\n3IhuH8ELLFwLem1FGrvYTgHLKGKV30yzNCPwLbIsy/0FIF/o0kJpjco0RU+hdQyWnydxCw1GIG2b\nq4sxhYJhcnhz8LubxigjSRQ4dr/UEQK0QlsCpOgTuaFvuYhRGsdkpEISHLqR9ZePsPjYCo4lUb5D\n4LikWYatNYnRhLPzFKKQUEA3jFjPUt542ySPPPQoE9UiExMTXFtcJol6hG1DsWgRGYPMUoTlEmUZ\nluWjjr2CfOP9GJEnZFv9WIu4vyn4wiIxitTkQXoCjRaGngEHk6sXNjmyLOfsCJkjghs2OAj6LdAN\np1LTd+vcaG/lhWmqDGEieeA976N1+WG8Y9e4sNzED4pcW2+yZXwU23ZZaK4yv7DKfD1iZHiQg1sz\nZsZtOoOav394mYIzxz2lGja7URv8HfpFTL9F9c005tfG5L7mUui//dxMDYyRKNO3fd8k/0v0XzPV\nKUGpzPt+4ieZOXQrP/+hDzFZc4njiGqtComkXM7RNaUUnuexuLjI4tIi27buZO/e/cRRyuzVOdqd\nFnfceZh2p8ny8iJ79h7
I1UNphgxG2XXgfu69+y5qI2N0ki5Xz3yd5aUmxpIUiy5Ga+YXL7C3to/D\ntx7khedeZs/eAzx35BhK2JBl+L5HsVhkvdnh9OkzedSAjICMJO1iux5aGYzQeJ636cgNyDk7rVZO\nLBaJwRcQLhmENCglWZ0z2EXDjpESh2rDnIln6QjNA2/4XtbWlzl1ZZnVM2fpdFJGPMPA2AC7vSEK\nlSFaPUmxCt0OlMqaoVKVlbRHXMxImyBsi8rAAOPjNXZunQYtKXhw5Pw52msJWQyuDSUBy12BY2kK\ngUXg5ER/nRmMk2/WxYJAJjmIWAjytHFPwEDZptPNaLfACjY3N7/+iz/D+z/4o/TmrmBMntmkTN6e\ncizonHiS6Mpj6Oo9MHIDlb2HaJw4weUXzpJhs/tdv0znyjHqYpyqW6O+tsz6YoilVwiqRaR7lOL4\nHopTt2CVpsmiFva1l/it3/kYh2+/i127J5mUDveObmNCFKi0KuzfNcYXFs5wrWvwbPAch8CXuI6k\nGyY4QuKg6YqI9U4TnQjedhOsegf4+nNXqPgucZzvCOWSx3PPXKYOlP4Z6CD0jyobSD79h330diOX\n6nrkQf8n9PUWtwHjk5pJMrmfJN1D2PQYKZ/HdWzKQZF2awkha8w3uiQioDI4QmXHreybqrFrqspA\nQeG2G1w6fgURd/CFy85Jn2vzLTKVMTpawxocZD5RXF7rsnUQDt0gGdwyQFTcSm1kkOpQSq+5zq0/\n9OM8/+n/idKKpLVGMRjhyLNPcOit7+K8M7CpebEDH+GWcT2J0AY5PEV07BN4l57j4vv+EEe2Wb31\njVSGJtj/vv1cfPIhku4K8wtz1LbNMPfI5xi67200mwOU4nWyXhMTtpFpAqUK7vgYOoowhTKyb/ro\n+D5JO6R7soXv/BrWfR+Dmbchpj4C3Xmyyy/S+epDiOghlt0HiEcOw4RmYOGTICE79TTyypNQqiBF\nDaNjrF4PVIiQCeELX8AxIU4RjM4Q4zfg1J9k6/xvwlujbz0P32mSKhO7QFiYvqmg7jtoKpUvUIuE\ntCfJVMhit41WGb5bJNG5nLnTqmNbLrqbMPZT/5rnbtlB69wFXHsBr+Ox3ljAVgqv6KFKKbWoxdDO\nMUrLFi1vBNaqVFnFCQTlIGQiilEyxHWKyG6CIyUpEukXESJG9BVkEomNJlF5Lz1OQ+prGUFJ4dsW\n0nbylFpL4lQKZEpzaXFzm3OQZmhhIWwboTSpyFtDSIGnAWPIXAsXQ6IV0lg5kTZLMGGGZ6UEtx4k\n0wYrVqRJiIpSbAkiU0THjhMurdFKYvYd3M/C/DWSMOXLX/4GtjAkMsNPe6hQYZa7VPxRGtE1Bioe\nvTTFMhk6yxGNxQf/nvKRk5Te/4PooWGUys2rJILYCKTIUNoQidx3JyAPXbVN3vPdvNYG0BqjrW86\nZV13Q95AVsyrmxEbHSUEQmpsAQMFQ8NxmZ4aZWU9Znb5DK6dMTk+glIJa/U2a6urbN15C2nSIola\nXFwpsdBssW3XIIM7PbxSjVb4DU6+9PvsvPnD3+y5c12m3i92+jla4lsVPf2fM7rvvWvyYE2tc5+W\nzY18o7UtC2NAZRl33naYF4+8xL13HETiMb+wQG1wiFKpiCPzvJowjNi7dy+tZotut0u1UmNubo79\nN+zHsiRz81eo1Wq8+c3v4OSps1RKZaqFLQyNTUK4yje+/nne8tYf4ND+XXyxWKXb6tBrNhkbnmZ1\naZY4aTAzUaVbrTC9dydLVy8zu7jG9NbdtDttxsaqtMMm6/U6cSfDKfgYWxD2UoJSCQsFogcCfK+E\n42xS+kjO/SqVBWli0AVBFoPRkrGRMrOLTZCGqgP33n4nd04Ncqixj7X1Ot0XnuDvjjzOqvIQXU06\ne5kdd+/lsVN13GKFpHmR6Yl9FKTPehJRMD4q9rB1hmVpZKxIe7nAIE0iVJZSKlWxHJskT
BiSIMuC\nNCjgRT3WIggTQ6eXYTtQCERe5GORaY0rBUEA0ZIBFyInJz7HoUPcy/A82Kwv44d/8ReJlhcQ0uqj\nyhIp81if5ac+TyyHuTY7gli4SO9yTONv/j+C6gjISSZHB2md/iwXT5xB2F9CpYZCIMAqMThwN62V\nqwSsosI50ladytab8IZ2E61dZmzLjfz2r/4CZ/76c8xUx9m2dZyBQpXBbXsZtlu8d8+9/PHnv8B/\neegse8aHMDolS8EWgjQzZCbl5t1DHLu8jhGS9dlldt93gBPHfIxKAQehYGmpwb4Du7Go8cgLm0Mw\nALLrBqcCrV/l7+SXrugrLTeUSvSx5j4HtX/tp+zGNe9F9aZpr67QXl/lzl2DFOMVQq1RtqRnyjx/\nsQFD27nntjvZsaXC5KCN62ouXjjHi6+cZ4ctGa2OIrMGjhezZwxGRstMbN/F7TftobmygLe6yr5b\nt3Ph0hxfvKi4Uj9CVrnGdx8+yA+8YRIrs6lsPQw0yLSPl6wRWCH2+jK3jHz7ts23GtJ1cQYncEsO\navYqILCHdyNdwZnH/oLWhXPcNewigjJnP9fD1hnutoOM7Zim6GruPriTVZVyZiXlNi/C9FrosIVU\nKaLTQPR6GJUinZwjZtIUnaasVbaRTRQZDA9T+tM/hC0PsRoLaDeovP6nkB/8UZbTkJcef5BpN+b1\n6Sfh8Nsxpx6C0EWoDLs0ggjKqGtXMUmYG0YGLrayMLIMdPN9Pe5iwh1Y06//tvPwHQsepRRa9gMB\nlEKpmEQloCUWNpnJEFIQeE5u320g0zE6gSRKUQpsOyFZAfMDP8yVk5cYGBumsFqn9dyzdOvr7N65\nmwVVwcPFHxyjmfZ4Zui93BY9iCjdyEhxBSU1vneKfduW+e57p5irJxx9UeG4MBy4DNktPJJ8CQuB\nMQotQGV5I9exbYJAoDLoGBAqw7MNru2BMDiWwNokTSVWCmVrrFQTk7srKyHR2oDOsHwXKSDQGgtD\npFMSAYctiw+M+fz2pTbNOMZRCmUErnTJAolsNpn/6uOEjSYYxT33vY4Xv/EcrTjBM5qh0m66WRs3\n0SytJSSZTVnU6BkbO07RcUSv18VyXSKjsQ2EwiI8f4mlX/vPbHvPA8jX30s3M0hL4BhDqKGHRmuB\nNKBFnmYv0bm8fvMH9TyHxUjQfT5wv/axBH05url+CruuCt/4VwiUlkS9DF8kCG+cm28uMlXTfPnZ\na7xy8hJYFq1mg2K5hK0XCQIb7dgYZeFbis5Kl93jRYLpkCiUKHU0L3CMfs0maF5D2elTq/tffzVH\nK/+u6bMeNWB07j1ljCDTgkz974jQdx79k2a/uLIFYDSW7fCeD/0sD//lH2ESQ2O9zfhYAc/1cH2f\nK1euAAbP84jCiMmJIkNDQywuLrKyuoJScd/JWjI2MkxjbY1tO3ax3qiTNVdZadVZuLyM0BZbp0d5\n5cg5MDbdTsrM1lFGywHx3DyyNsFXH/kar7vlBh649zAnTl1lS2mQwwdv5cS5YzTWlhFRD+E7KGXw\nfQ/HBs8T4Ci0yhEZKTe/cIQAy4NeAmEjt7kQSpOmhjTJ5znRhsdf+CqZ+S5Ov3CaW6eH6KQpwwdf\nR8ktU47mCMYzXphtc34xIosbjI4NsGPGB+2RNiLwq5w7k9BOuygj8YqABMsSmH7r27IlYxM1ulfr\njBegg6AtJZYUuMJcp3xpRa46ycDyIcugFymyroB1gSjAUqSZGfeIPEOWGkRd4Jc3OT/dEKNFzoHT\n+cWjwxXoKepmkt7sRVy5TjcxxN119t1yCE9kNKIS1y6fZKQaMrrjAGcuXSVsdDm05y6ihVO02rOE\nvQQTlxmv7ECeO4FIPo1M78UZvYd09hXuvfvNtD72O2w7NIFU52m3qyRnFykevAe9IvjQu+5n19A4\nv/HIESpuQBInKC1IEs34kM9yo4tB4NqS09cKbFtaJ
U1iSgWPTOXWJqVije9+99088NaE4+/bfMGj\ndR4UuzE2AEa9AfVsXMv9FnQeGtJvpwtQYgyZfT/N+Rpr155FpyG2MQh2o6MO0sRMbt9NdVCy3lih\nvd5gbXGOPeM2vUgyMrqV85bP+eV1vrbW49CAYJ/UbIszdr/hEKPjW7DKgxy8700sHX2Uhc41MmHz\n7Nw0jy/ELNj7uGt6io/88cPo5G7uPbSFsLGMt3uClZMnkYNVRmuGUmAzxcSm5kZIB9VbJV5eRWRA\ndwV77CCVN76XnVEJSyaUrhwjTSwaa22StEtpp4teWCTZthVnbAoZa2bKhqSZkHWbyO46RmXYaUjW\naaMsDx2FKGFwvICwsUDQvMzCy49xvDPAzjf/LDN79hItr+LKCCHaLL3wRZJzL7N6qou1Y5roru+i\n+IYPYyrjtP/mT6kNekjHwmQppt1GxAk4BpEaHBGRaa9PT4C2v435pUHs41e54cC3nofvWPD0uvkR\nRMrcOVNpgeMESDTdXgNhXCzHJYkilFE4soDt+SRZB0vauffCQpfSb/wanHwR58AtJGKQtZXncScr\niB1bWa2OY6V1EqdI3bXwoxgQ1JNp7hjpIJPnEHGP1PFxgjXeuSPhxJSHX5as12MmnQTr8hoZFkJq\npNFoKbCkjSZFOj6lgov0bYoyyIEFkbuItqPw+nN7k0Y8wrIwlotIY2whMdImUSmlJEPVyigBxUwR\nofEMOMKilcS8PV3jufmIt4yOoaKUTy22KBtNKjW2Fpz9zOewpMBWCa7j8uxjT+FJKGUZQkpa4RVs\newilHAIryHOOhEBni2gjaLa67N6xlfVml2StTiYTHMfGMYKu6zP76b+jfOIE1R95PzgBYT8IzzJ5\nMSK1JhV5H9xGkKLAbJ60bNIYbBtpSay++YV4DZryavxcPv/SbJCHTZ9IDpWKTVoawFo/wfriOTAR\n9+4S/ND3vZcrCw0++9BLPHvkFKrZpug5jNTKVALNYr3LpdkWS6seb1wdZOedkqFwDiEcjInzvC24\nzsfZcH4W11Ef+o9f+wvl6g+trbylpQWZyr2FMrXJltY3xVfk0n1p5Zv197zn/ezeMsbv/adfx7UF\na2t1CuUySml2bN/BkaNHuOGGGxgaHqDba/H8C99gdHSE6aktHDv+Cnv37qVUKqIyh6TTYG1pjpGx\nMVpRj+bSKrVgmFKpTNlzqPglYtdjdbVDc3CAsZEC0oOlpTmanRZHT55ifHSUgSEXXItzFy+zNL/I\n2+6/G99oHvrSlynVRgjDLkZK/GIBy5Z0wi62FeC6zqbXjS0NcSjoroOt82Kwi2H2WgNcgfRBuDA8\nOUqhGrB/tMSWisPUjgHC5AoDIqDipxyb17xwdpVuq4OlYtyaRWpydcpIFdrNdQYqI2jTxfETYkUu\n01U6j5kxeZxA4FQpVlpU0MSJodNNKbh5JLLoL5FUgaMFgWvIVJb7/AiBbYFwDLEGT0LWTug2XEzR\nUCkIkk12bdK4h7Ry3piwDapxkdbsIisnvk6WhKwuzTNx8Hu58NSDlIqScK3OWlbALktGR8oIfwer\niwssLzQ4sLvG6nKdxkKDGw/fTBB2abZjenaN6vBhrl48zoy8xlBlgaw9xPjewwRbYLn5LNLycAKf\nSrHAhTNP0WltZdu+Q9x38w5OX1ziwfPzOW9Pg7Thhu1VXr7QxnMspGWzrgJWzy3g+xGDA1UanYQ4\n7VEtVNg6HJBUA5557KObXjtK8Wo7q1/kbBxFrj/vPzH9NrQRG+IJUNk76V2waDeOUQwkzW4IwjAg\nu7R6TXZOegwP+5hyQJjEaL/M0OAgp6+u8+4ffBu//Is/wyePD/Mb/8c7+cZfPcza+lVGthbYM1ll\n+o7XMXjTzTTbGrdYpjp7mnjLFo7Wx3HGNeeOnSWY7jA6qLih1KI1vp2lhXlkb56CXaXTu8bgcAmv\nFODRY+fI5kw9t
bSgF5KuNbGziObf/QH+rgOYOMY0OtSGy7RPRZjuAuWBIVrrXZzlOUbveSelvbeC\nSKiefhwxZ5EmM8m1AAAgAElEQVQlCh32MN0mMkvyRIVM0TMJ+uRnsFnDymLsTgtPOSTpAI25qyR/\n9WtY0yFbChXQEqd5mR3+FhKzTmfHLuzRCs8/fZ77y/8DvXiFRsMlbMdMiUV0nJKuxxApTLWAiTNE\n1kTFDmYgP1BX559grdPDf//vfdt5+M5ZWraFEIr1pUWCahXbdjGZJlKG2LhILUg6bcqej2VZQIzK\nPASCLI3oXujif+CHkHu2s3NnyHm7wkDrLLKWEOkhuisJiy+foHr7Doqta6j6IDIQONZRhourZCsP\nkVo1SoUeupdzY/bsXGRntMShYCdJEvHMJZdrWmO5FpZdwLVdhG3hSJuujghMhBdYIEy/3ZUvftsS\nlAIPg0ZomyTdZF9CJQRCENkCLSRKGrwwRagMtxeC7RD3ughpkdXKpDrGXW+yUNQMl32KYZP5C3OU\nijWCLKTjVrHCNiXLJk6iPIW528EohUCSKY0vLDITkcXXkF6RtUSiMk1HdZFCosm4+dBBlldX2Ld7\nG2EUU19fJzOQWLkTNp6PuLRA62Mfx/+x9+M7BZI+3BtJhUee75Uog9VPvvU2GayaD43Q+YxvtIsc\nwbdIpc9xlA1VuDZ5XpZtaYigXT8Piy/jF4coVxxWOqtkCyeZDAb4yR+8g/vu2Mpn/+5Z5tc7LLR7\ndFPJntECwoZYZ3RMTK8+TOop4m4Xr+Bi+pyk/H1tqPQ2vmBek70lrr9HTe4hpA19N/E8fy1Rfb7S\nJhCe16qXrrtEAxjN2ECZbQ+8m6999XGOfv0ZZoanSLKMgcEC9Xqdm266iSzLuHbtKnEccc89r+OF\n519gy5YJ7rjjDlqtFuPj45w7O4slbNq9GL8TMrNzB5XqGLYbsLK8yOpqHc8tksYZ0nVZaba49tI8\nh7ZWUWKWhhGcOX6OCxcfZmpqK29569vpLs2Rdtt4MuF73nYvt9y0jz/7X5+j3QqpVEevI4QITRh1\n8b3Nt7R6sSBuQVAAoww2AtnNzUyrFWi2DHt3TfADb/p+1s++zMzUMCNTVUzJJewlRHg88/IlHvn8\nizRiOHCgmitKOxkiU1TKNqkF44UiW3YNcuWiYnTQZjFdJ+z0GVxGoozCKI20bISt8YTEVRrXcxBS\nI2TOm3Hy+whpCtWCRAtJN8soOrkhodaAnZsnKmmoOA6ptskihbXJ6ckVsnnUy/rxr3Dt0b+Abe+g\npwZpr5yhMlDh9Nc+y8j4JMKGuUaCKAZYq8t0u5qyfInMGuemQzfSSiWe6jF24z3MXTiGxyVq276L\nTKUc///be9MgS6+zzvN3zrve/ebNfa2sylpUKklV2mXZ8m7wAtjGTTdg0zTGHQ66GWCiGSCIYWa6\niQbCDZhmbGgGQ8PYdNtsxrslS7YWa9+lUlWpstas3DPvfu973+2cMx/eLMnD2B6nvzWR/4iMyg8Z\nlZknz/ue5zzPf/nmVznxw++n13iMYWGh0hQpDI0BOM0p/OIwtfIY/bU2l6MrbHZeZLVXR4sh3jKf\n4zMnTZY/ZcOxQ0OcutJiuFwkShS24+JVbfJmFNXfoCHOUq2MEIiIr3zjEre+boWxqTJ1Zdi3y72j\nDaBfjX+F7BNtzCshq2bnZSP4llEXhkH8DqKzw6j+GXzfp765iYgl+bEJtpttJobLFH1FeXSS9mCb\nwlAZlta48tLj/Mjbb+U//eeP8WDrAExYTIWr3PeZ3+Y1b34/W/2A0uR15GcOIfffSqm+TWt9m9zk\nQQr1LbwrHV48s87RuRqdrVP85Z8/yY/987fRbnRZPPdN0o01ykM3ot0hlJPHuCG1MgSD3cW26P4A\nWShBbQzVqtP75hfpPfgZbL/I7E/+NvaBg5RvehtpfR3bshgXKbrTRC09zub2JSx
P4538PLEq4SQp\nKuhDp45MEmRnCT24SHlqlMLm3Vj2EMLNY/kunWZAKyhz5503UK2fpxg8jJVuZpy52IIZH2kkw9MT\nrG6uM5E2SB/8W5SAkm3h+CXCrR4Ig/Y8hOcjbBDVGiIIobuNUa8eHu3I4qu//AF+8e8vftt1+K4F\nT5xKon6TxvnzTN90O3EcEwqBIzU528eg8HNVLGEDCSpJMMQYo+jWt/H/539N6Z0/SGWmwHrsEOgY\nS9nI3DDJZguTtskf24dnKdJSja5qk+8PKMmUgr1Ms19hcjylHo8y6m7R7Q0he2eoTMHB8ecYyNfz\nwguXsJBYZO1o6dhIx8cgSE2YuUOLLAAVNEaKVze9uTpc0Fi7vIyaRCFMhGtZBLZFrDSOLbETQ9rp\n4wCRJbBERNqG94+XuZRYnA0SfiANKVmCK2Gfg0OjnNEWOdew9vmvE0chaZpgCXClJBWZ7N82EqEU\ntjAUXIcwDnAtm16aZpwqW6CV4PSpM4TdgHOnLyJzHtIYvEGE9jWRSolNyk+8841sFIr8zV9+mpkP\nfRBXKTQCXwsSkaJU1nvRZK3z4Ptg8UiRjRev8nYssfOCkZkl/VVJOlc7KwACLCF2ih5BqANUf5VU\nVqi6FsnGCqWyhxkM0L0eoRZMOy6//ct/SL0R84nPPs3JS+ucNobhsuTokEvaDmmlfZRbwlXPA7fv\nDK8y8rS4Ogbl6s/Dq1fCq50YTWaYaMQrxU6qyD5SQZx8H6Obne9/9btqBBaaip3SjzX//F/8BE/c\n9yDtdpdG0CEKI5rNJlubmxw/cQLHsYmjiM2NdfbNz2HbNoPBgJGREer1On6+RLPRot/qceHKGo0o\nJu9XcHMJjp9D2x5GOGjdR+kUxysxNXENOatPPwiQqcXKUgfXOsyJY7cStQbEnYCFuUmaW2sM14ao\nlnPcdOJ6nnv5AkkiGM4NY+yIVtLcIa1+e+Lgd0OqAQt6q+BUsiJhdFQwGIBKIZ8acirk0//3HzE3\ncoB6I6S4VOIH3/MOBqbP2OwCc4HFG+cWeawZcf0Nh6jXA/JKIyxJTdVYXelRqLkYy6FQKlG0JFVt\n0U9ShNFE4YCgNyDIBfiyiKtzIAYU8x6TRpAMDCUbUm0yW4y8oBEZEi0wWpEkgGUI+2BssCqCgqsp\nGAfX82gHBaIwYKSyu0vWVSPYuHWFxfvuoRfWkIuPg51nfXUVP2czun+BxOTpbS3hVSaZGonZ3hLU\n1+pURkcQGOqXnqE6fpgwCigFl4gq06StBLF2N5E8ysF3fZDWxgXy0zdjVIyVq7JxeYVHvgFH7nCp\n9WMWSi32LSxQGn4jF88/wcW4x4MPPsGNx6a5oZrj2UaX2aE8/X5CFAMixJYurpWwvyaYOPBuPqQ1\n975wH/c+/Q8kHZ/nXg750Q98gk/9l5/m2qO793CSAvRO4P0rPD2TPb+v+BNf5RSSFT+egIG8i/jC\ntcTdVaQNjlGQCkq1Ca7ZP8eJQz7DyWUaWxv41TxXlpvMTFcIHjjPwnUT/OWn/poHkmPcYJ9ieRka\n2wEVO6U2P83cTVOM3/Im8kfuBO0jvBGq88OYiUn6jTa1+jOMTMyyvrHNVtfiXT/5HkaHi8xd/jy5\n1Kbfh9baCsLNIb08Q47HoLPOYODtam10vw+OjfBymHwB02ogUx+VJOjnH0CGF0lPPoSDJFg8CX6F\nROdJCiXi2xfID1ZwREIUbmPiBvmZm7D3HaBcjnDe/+/QeOjkCk6aolaeRp99FtKUIdHjYGUM0biI\nTJvEgcEvgSz4GDePGnRQXpVCocL0O3+E+Yd/B70NRgrKeYUs2zjDI1jVoSwaQ0p0p4NqNHBVjF/d\n+QVTw3YXngkdfuD2a7/jOnzXgidRkgefPU3RyjHv2hlR2WhSc1UI5aKShETFWJZFmoLU8M005fbf\n+SOmiJhNBefaffykzv74AmpQJFIdKFUZ3ne
ION7AaifEl17EzVcxZpGz2mfJfQ1G30rxQsIt3vPc\nUqwwMbeEkysTLnWwpKYkH6OxXcEI5xWVlDAGtEIKizCKcR2y7ocRWEZmxly8MgLnqgR5t0QVE4WE\nqUVi2YSOpmpZhEKgCz5OlIBOkdIgLYtxV/O1bpfy+gYfPjaHrWIa568wnfepBS3eduNx/v1vfZRg\nbQsv1rgqQduZedPsvjk2t7Yp5iU6CAnDAaFS5D0bkhhShS8E6SAkjUJsUSJNY3zPoRf0yFsWxrah\nn2K7NoQxn/rC1/iZn/hRDpdrXP7Tv2DoX70f20CsDakx9CVIBUiBpaH/fcjSbSmwZfav3CnAs8nW\nzu2Kq/wZnRUg/6hBIqTGtotstWosP/13mG5AbbTK0kqbkg9uoYSRgs0AWr2Y3nabD7x+hvSN8/S7\nA+4/vcnlnqKbas7dN+D0mRy/tPAEdvk1GKMwr9wD/7FVu3hFMWYgC1Q0WYGjNahUEClDmkIUZx9h\ntDuBsTLgpRA7kma/y1iuhJCQ2hb1pMNQrsyhE9czefwQ9ZWzjPkloqCDSQfMz02SRn2aWwHDwyPo\nRFEq5sk5Hto3tNttkiSl0WiQhCEH5heoVUcYHhrBKRbo9ftcvnyZ17zubSTJgAfu+VvefNtRThyZ\nYW39Eq2NlM4gxiv53PnaOzE7YxllOXSTlHOLp7ju4BTjY4e5tN6iF4Q0NhoM+gnNeo/b77iV4sx+\nUsvLRkN8eVdrM2TD5DWC1rjhwvMQ6qywEEKQbGgmFxzOdZr8+Ps+wMzEGCfv/Tz19YRKYQJDj6Cn\nOHLkBAdufobN57eZHB/Gtnyiep+861CoTbBgVxgeLtIOFZb2uP+5S5y/1OP4TfsZxDEFz8HxLFwv\nTzE/zEZXs2kVOTDiMFMrsr1VZ63T4Q13nmBkpMY9X/06FV/QCwy+a5FzFZV9Bm90Z+8XoRNKNpMB\nedEnHABSsL62q6XBLfmEGyssfvaPaTS2GTSXqczsp728ROzPYeiQ3z6Dqt3E/iPjbC2eZr3Vp7jv\nVibKb8b2LTprG5QPHWK9GVJLXyZe+DFKzYuMXVOjZd6HiAasLS8z4vTpnv1rvNd9lLQ0w1uvfRf+\nUJnlc8tMFAucWs8z7m3x+kPnGB0b57pClc6Ra5kfjnnn/HmCmmBjq0+oDbbjUHQcXnvUxRf7mUkO\nc+6xh5gvVTmeFHnvW34VYzs88JY6v/Z//TG/97uf4qUVRRT85q7W52qUDDvFjNHZBYWMypmNAtl5\nrkUmUQ+Tt6DPHoR4FSmKlMujlHM+agCuFZOTmpe//MeMH7qB0WoZW3Q4POuw/tBDXO67/OWXzvK2\n41V+qpzy8mqZ181YfOHZgE+/7xfIlUq84YP/CyulfTiDMvlmD4sUJ5dDFsfY/zO/wdiLD1O4527y\negnfHUetfpVcQRKM3kj7pYfZrFncf+9JOlpxZP8Ydm6E08sR/a2n+Q40lW+PNEa3mgjbQbgOlGvo\nIMYqlQlOPoa7dZKzz10i7EpGRgsIEyBJyR0twcVTVN98C+V3fwi/t4FrF3GiM6TuAolqI8rDJGkX\nW2sGWiALo3jH309yehnnof/K9CAm2P8WVD8mnrwT6idx4g62E0JZ4LiT+ME3WfnTv6IlYGxYolPI\nl8CyW4h2E1oXszuoAUuBWwIxdRtieBZZzBG4R6jk9vFOr0axfuk7LsN3LXjOLK+x0lCUfU0QCTzf\nJRUxMs0OryhOAYnjOugkQmnNE695DfUTN/GpZoM7R8eI+g1qvoN0JDIYYqUpcHoFug5Ieow4PZAh\nXsWmaCUk5gB2tIIKPWakoZpzuSe8g3L3NGOdc2DHSAuSbU1kBcxVimx1PWzHw3E8hOUiREYkVVrj\nWxJhLCTxK+m5CIHZyVDRO6MLqXd3aBmjiJXARmCJlCYKmYLvSvpGoS2JTZZM3u72yXs2p/B4bqvD\njYT0Uk1
sSa69boEvfeFzbF5eo2JZJDuOWUalpCqlP+gzUS2ztL6BrVOCNCEnLAZpSm5Heq9llt5e\nKRRIohjfkkijMa5LmCbIJMG3HERqSI2m3mrzax//c8YqNYqzU0QXXiaYPUReCpKsIUaAIaezl4bk\n++lgZG8arbN2smRnXHWVSMhV/sw/+r9NlgRtuw5P3XMf//X3f5Z88QAnDo7y2btPc+O+Ks+swdn6\nJpHJukGDCF5/qMrbKzkGVo5mvcMd1+/nwZMr0GgyVHEplUK2t15g+mBmsojJUof1Tnvb7BCVd36E\nbPhpTCY93/HbUUqQaINKIU6znKRECc7e98fwhp/7ntdG6xSwUUA/AVXUGJ2glEfULdMrwgMPPM/L\np5aYGR2lP0golwzT09P0ej2KxQI5v4jv57h8+TLz8/O0u23CdEC5VGF2dh+piqjMTOBaNmE0oB/0\nGS1XQBlmp6bxPJ+zi6vYMmVra43CjQe47YbDLJ5e4tRaEyuXY6veZGp6Et/3CQcDVtfWcOMGMu1T\nHqpge1UsRBZmqTRh0geTUixWmdp/DcVScdf7xnMEUShB+tiVPvGmQA0gVwR3RGCXUsIYwqBBvRXi\nFyyGKwpXajbWzhEKh/NJSqup6ZWLPPj8Mq12wHF9EefUBsrMceK2m1g4ssDTzz/N9UfHWW0HLG/2\nXvHUwnGIowilUqTMJLYryy3inoW10WXEUyQh3PPoc8wetChfB811kEFm7CYlkBf0bIh6BrsPUShI\nY+jbmbDBpALx3U1B/j9ofOP3KN/xYaLmWXJCERbHGEQ5xo7fRP+pb2CP3sx6VMXrdFhZewk5cRtD\nfht78gauPPs0YzfcTnV7kWT1JFV3H0Gax5z6IpE9zpWz2xR4lGoux9iJ96HXDPnpHyV36C7+3Q/d\nwjIw1e0gc5KGdOhowUBYHExybC9tMTczSrezwqmNFB/N1kaIEiCU4boDFe66YxrVy1PeOkTa2ma4\nOkKaDChVqjR6A6o1l7smJnj0Y7/Pv/j13yY2W7veO5hvFRpctQR5lbwsZHaBwQjQWeC1ak+ikoC0\nG2DlHKyoy+VmkygWXH/oAG86NsRKc5LhoSrlikvS2CZOEkrT+2gEl6nhsbY5YF8+5NDsPEOrL9CW\no6ixOfYfPsrLwRAzJUmqEuI46xp32hvYpJTHJ3HnjjJ/0wb9c5Lxa/ejencwWHqebstDKouxEjx8\nxTA2nKPRSnHsPnGjgeft7rwSRmGCIFsjvwCOjTtzAH/fUZL6Iqqzhp0v4Pf6WYGYJMRhhDh7lqn3\nHGJy5F709uPk3EOYdAgsgaufzS6u0UWceAWwMNYkJArd3KLTrZGzhjFhG178OuVcE9fTyNo8Ymgc\nEoXpt9GNLkk3wrVkNlVI9CthzVlnJfNxE8ZAAPKud2G97YPI/gW29Bw96wB4Y1x1edPzU99xHb7r\nI3dlq8dWz2NtZZv3vkWjE4mREmH7pGkPB42tDf1+xOVano2hafR1J9Bnz3M0b0AbziUDBptr1Cyf\nseFxrpnapHkJvBL4zjpGp5igj1WP8Lw2PV1kLG3SHSTctn9AWpik25rh/tL1POd/gDFPspB7iAPR\nPyBiF8uVWb6XbWMZGykyf/FYpVmyrDQZZ0NkvikZoS073LQwCC3RZJlgu0GapLg2CClxkqzAshKN\niCWp46IlGGnwpCHRknYvJI/h7s0Ox6crTI4ZLNfHtgzVQo6ZoSJho0diVKaIixVawr6ZKosnVxDR\nAK0lvpTINCXVCca2iRKNY0EiHZJeF69SQESCVBt0kmIrg29LemkKQmPZLmWl8KQgGrQIzoZ4h2ax\npw7QJZNZy6vDHiMYGDK58S6hjcoyfHYIgtoIpDGZb83VQ0Fk/kVCvKqXECLriq0s9/jNn/8ZJsfm\n2AhivvnSFTSSL59r48ss6TtvZ3+7Sglebsac/foFBv0B0xWfsYt1brgmz
ze3DP2OYCKR5KoRcS/A\nLtpozI7zdlaQ/eMWkzEZeTobZYFWAqVAp4JEQRxnvjBLp5/kwoN/AHzvBQ8q+32VEQSxIpaSxHg8\n8fhTXFo8T7lU494v/z0VN6JYGKbVDSnk88RxFtRYr2+jUkmhUMR1XRr1BoevOUy9vcH6+jpKK4Zr\nY/iOS7fT4cTxBVzb5vCRBZ58+in6vQB0QhSHhKki0YKtVpswVDilGtotkQD5HfdnCYSDgPX1TfaN\nFIiVphe0KFg5bCEplUoYadHptwh6HQbdDpFKmZqa2fW+SWKF5+XwPIcrsg95wAEVglMyDA17zEzP\ncHC0xMblp3ndNQeJgpAnv/4Z7nniDGOTM/R6EanRBL2AMErpR4q75iFut1naavL1R77Ihz7wdrY7\nCQ0RUq+3iQIQRqJUgtaaKLaIopQ01UxMjdJpdYn6Csc2WHmL6hSYcUGAImpIoqbBc7KRujGK7qZB\nbwmMBZbe4Y8kgjQF42YBurtV7X/uk3/Oe5wK02/+t5z+0n8jabUxdsqFJ+5Gjt2BpZpsrHQYyvXI\nHX4N3eWzlPcf4MIjXydUksZ9/wmvdoh0EFGdP0ihd540MlRnh7myepraXT9Jb/UCrS98hGvf+n6O\n/uS/4Td+4g08/BIMAVpItJYM0gQ3EqzFLb5xOuLo1BgbGwMO75tj8dSTiAIoIxGW5g3H53nTm/fR\nafZpni8xNuqQ5PIIpSFfJUoiiiOjxNEA33Vw2wF/+7/+Mu/8zY/teu+onUwJbdjh7PBKhyfr7rxa\nDkmZkZxNZ0ASe6SWi0bSbG7hS8MtN9zEwbEyfnARJ19C6i469gjrCaOvew3ysecoc5kYQ6oSNhpN\n5mYLXCoeodkc8Ka3vpX9199IbPm0m3W++dwicn2TtcUlnti6hO16/OzP/hiv++F3U7nt7VQPH6Px\n0tOUJm5AmyGs3hNsrGsOTgmc1FAtKgaRJkxiiBTV0bFdrY0IA0g0JtVopRFejrS+SHh+Dfe6SYxt\nUau6BM0eBVJ6xRyJl2JGKlQne+jqD2HiKPNS0C3QIagWBgfcBczw25GNryGXH0eceYYrZ29lo2/w\nZm5mNrkHVwzQAxfLCbHnr8W67V9hWltEX/0D9NIiQkmkY6FTAylQyBSzWfj0zi8RgfX+D8PCKNtd\nxXL8OsomRlbHyaUDnPBFhj0Btg/c+G3X4f+n4GkzWnU51YB6J2aokmJhIaMOAydHf6jAJgGtbpvV\nQcD0bce5eOoFJpKAg9e9iYc3VpmwYWJ8mhXbI9Ex+2KLvp8QpYpCGBKtLtI6vcztdpPJqYROvsrB\naZ+V8AAr+QJry7AZC9pxm3BsiH7hBl7IHcWe+DF++qWP0nn2CaTMZ7JCSyNEAkaSqHTnFq/R0oCx\nECK7wZkdIpshI8hKnZlP7QqpBjSRVKQ6q0hLRjDQKRYSV8lM2aMNOVKIIrRwaHkez6yscXyogrZS\nLpw6R6ffp7XdIictkkGMQGAdmGNeDPiV/3Ajd38q4uOfaJD3IG/biFSw78gBls6v8SNvu4mtZg/L\ntnjwqRewjSDRGpkmpFGCkIJ+rIgB47j4RpOEMcp2sI0msQ3x0hIpKV5q4YksNE9j00fjf59p6Vor\nlNZIrTFaZpESgBLgGPHKtSsbcclMLUHmGFtvRPzrH7yLoZrFaN7i5e2EsgcTNZfGekitUqSEYqAV\n22GCJSwcqXAcQdyVBEKythWSyxtkztATKeW8prfu4Ezb2N/K1cGwY7ywsx/EjrFgJo03JuPqJCp7\noUYqI6imSlJfX6fz0G+i7V0mOzuaTmJoDwx/9PGP8bP/8ucII83f/M2f0ms+iyfHCLYblCoxkdsl\nND3qTY9iySOIBhQLZYTILP6PHz/OCy+8mAXhBgPm52ZZXDyPjU1uYpKR4SGWr1zk9ttu4cxLz+EJ\nxVpzkzBR+J7L2PAoqeVxaWmdg7Pj4
JZYrbfIlUdIkxhJSnOzQ6c/YNBoMXtsnnp9idX1DsUgT7k6\nxL5DR1jfWKYTJAThgInxMbZXN2lu1He9b3IeWCLH1PQUcWR45vEWri2J+oa8C2kSMTkxQ9+a5Eoz\nz3p9k35vwOKFdcIg5NKZc0yNF7NRbtolUTAuNc0eLLgDtDFsrNbpY1MbG2P/sZt56ZHHuXCujtaK\noNVCjo5hCUMcB5RLBYwReK5LHMbk8jZewaabJnQagkEvS0P3EvBklk3nuhadvkIpUFJgGYNtC0hA\nkTn7RglYu/Th+dqiS+FPPsrtP/ReTvzMr/D0332WzoWvkqvdQL3bx4pbzFU6aGVTiC9i77+esHOS\n4ZmbaK0tYo2/g81Tj1KpVolXHmQrPsD8TIW1M+eI7BFO//ffpVyDO3/+E4yM2/zqu17LUjvPagK2\nI3FtCdIiNTYqVGgPLm006EYhE4mkXV8hTCFqtrkch5iggbrlGK1ul6WXA7aXBPkISmqAlStDv0dx\nfJzm2hqzo1Wifg9KRTzH5c9+4YO73jv6qibd7MS+GF7x1Mo4rTty9J0vSxMfTBFFgvGqCAGd+joj\nM/PcfGSKo5OC9Re7/MPnvsIv/uJP092+gj06zeDCJfqlA3R5ngNlUEKy3e6QWA36kWTmujfS9qfp\nlfdxcTuh8bkv8Z61bRYcifBs3j12E18tbvB//O4n+dztN1MYnsDMXstIbYTNL/8FqRboqM/immQg\nBE+tSA7MxBTsmJdfrnP89huJvlV//70gSrLLpjLoKEZaDsH6FsGLlxixQ3K37MNdXWUwP0J/CHQl\nR0n2cU2HzskGo+VnkPtfC7kFdBgi/DzGriIL4xjpIRrnSHrXU//sx1k5C8Htd3KmOOAu7zxGg+XJ\n7OVvQJ35MvadH0YcfDOW/+dkd/HMC2mzLZgcA+FKSFQmMtIgjUFMjcBkm03zK3j1T+NM3Iqqn2Sh\n8WdQXgDPxqg2pE0E7/y2y/DdC57VLmNDwzQ7CecuLnPHrUeJU8XKRI3FYofu5iW0KdKRRcbmF2iu\nn2XU8snf+ToeeekUUXUIL3RoRF1mvU1it0BuehrZuIK30aUdt5kpF3BrRe5PZxkddxkpjrI8cZie\nTlCDNtGCxBOKvJenE/Rorm1SHhrGF4J/OPQvmXpwEdmRSCSWsHZGIqCS7AgTUl5lrWVjnh255FVz\nXM0OOXWX9Y6TJiQobGPwhGFgBJ4E26SkQuKZENspoJAII7FVyqBQwFjwpPKpbG7R6XXoRCF3P/IC\nqcm8KM0FG44AABPMSURBVGw7o6+Ovf1tNP74Uyze20DEilo+h2NnstqC7TAxVGH29VNYfhGjesSu\nxVCxSHu7jmsJEgO+bRGhkdqQA+I0JTUKIS3yQjBQYJHQO7PIcJJgSOlZgqK2cE2MBjwtvq+0dEeC\nMjobaSmzo8LKRluplNgiGykakR0KAp3J4i2H//KR38FN10jdMssDxcGZPBVpqGAoeBbNfp/UsgmV\nJhqklEdhEGryfp6D15WQwrC0btH3bCQOo0nMl7/Z5cRtCb6dZLcSXi249I7nM4AiuxkqnWVkKZ2R\nwdPUEKdZsZMkknY7YOvz/xP1Tpvdmjj1U4f1Zp/zF86zeulF/vITf8iRhVs4MjdDODZJqiVPXfwK\na/UWtUKOfM5hY6vO8nqP6687BsYlGvRxXZcXXniBUqlEs9nEEhbbG1uMj46wb98MQT9gZGIKWwou\nnj/PeG2G6YX9bKxeIUwVaRjjqATPcXEsl62NDVZ6K+SLeTzbpjg6Qtjv4GLRaDa5684T+HaK4zr0\nI8kzjz7NyPgEI2MTdMM2U/kjTEwusLJ8ERNHDILdKUkA3CIEgwGdbh2SjCyolQHbULEcupcMp8Ul\n7r3wAO/4oXdREgl//6V7GOQcFuZHCBqao/MjbDQDPMvGSQTlqM9qdZKV6RvJNS4A8NADT3Lw0D5y\n5
RW6rU0KRUhVil8ooI1GyATfy1MbGeX8mfOkIdgCVKpRSpO0oZVknEUwWDLbK9I2eJZNEibEcdaK\nTwCZGEyaxUroQCBSg9ylUOKei5rrKx7ii59l8um7ueMDH2Gt+WbOfPmvKEhBp/kihbnbKIiEeP5t\nbH/jExgRUx1+iJx3gCAIqE5M009dLCvH2NQQ3d4WcaipTU5x3Qf+lP0Hp7j7kx/n8//9yxSmDpKr\nwfalcwxbGq0zr64oifFcjzDR5IShGQ+wty/RtTXb9TZBbYIfee8bqVU9Xjx5gf1zmpce38YxEe2i\nR9Dv45Q0s0NFoq1lKpbF9uIKhZwkWDG4E4ep1HYnu4bMh+cVns7VV9ZVry2xc3EzIosfSi0kknay\niNbjhN2QYt6nWBvjwPgI46WU7uWzeKLPD7zzbZh0wEtnznPXyCi52ijnzj/GwWGHVGk2+ploRtsd\n/JFDtLsBuptQbMWcf/Zlrl3Zor12lpX9k9Sm5hi7eYKfLp3g3LlP8NRjT/KGd7072xiWTengDaw/\n9HmkFIzkFf+w6DJWzsb2g8I083Mt/FKJQbu9q7UxqcZo80rRYxxN6nms2wG1zS2swQRxwaFQaDEk\nGoiBjesU8BzF08tlXvi9e5m5+Tyj0+OUDh0j1gpfpWxdXCdfzHPvpz8Ga3DgPb/K+msOs3zucQ6v\nP8a19gu0XBfbM5hB1q1R2xD96c8gp29H1TfABpEIXG3oDAxKC0Sssk5P5iSZndfl/Vjri0T69/Gq\nNzK3/FuU821Mcg66xwANzgykne+4Dt/dadkBIRRu0WFurEIax9x97BCWHlBuLWON7CcXlXCHRohG\na+QG2+RmryXaXKF8aRFHOMweKtOXAy7rMfrSR24vUfFG2fb7iOFZ4ngdf6FA8WJKwRpGulVM0EZH\nCcKv4JoBlmVhgohCfhQ7bGNiSOKYbq7K8okfQH7xvuzgMjueL0CoFNaOKFpqO+tayB1llria3bSj\nknmFwPy9wzKKRNtkqcMGT0NsGWg30ZM5/HaArnhoKYgNuGmKMSlOanFmfQvvyjlGfI8LW1tsrG0j\nMSS2jQkVtbe8nmhphbt+ZJZ9N8Tc93CHfqoYsi1EqlBC8NBjJ7nm2PV0o3WGKxUunV7Hsh1QWdq5\nsEAmKfZO98SyJcJxEEmMIQsKFcrgCYuw3sseBGHIxRrfKAbC4CJJjUKa3fd4HGmQZL4mSMPVkD5p\nRHbFRe4Y02ksbZDSAguee3GTZ7/0F1SHKiAlUz4cr/k8fLnDhjYcHnNYbCe0+hHzwyWK01USA11v\nwOyMxC5pOudi3jTnkpMwd42HZ3zoWTz8wAbH39qkzTivyjV2eDvwiv9PuhOyp3eKHaUzrk6SGtJE\n0O6EXP7cL1PMSUQgcWWyq7XpRRYPP/g4Lz13PyUn4vzpe6n5EMcNrrvpJ9GWy1MPvIAwqwzlXHK+\nzeqldWZnphgdnUBgsRrFSCEJwxDHcajVali2j1KKXC6H7/nk/Cwd+5FvPsi73vVOfNfjG9+4jzte\ncxt9JXn0kUepC3BJCQd9WnGL3PA0hw4e4MyZJWpDw3ieS6fdwRAwPFYkbreZ3TfN2tYGlaEi9eY2\nlm+hkew7eIwDB+/iltdaPPiNTxLHXRYXd8fMrW9JXKtPu90DA56URF0DMczfci333/c8b3jLtahG\nm5uPv4ZqET73R/cQkXDtkSnOv7TN5MQQzSDBtSRxamOrAdr1sfJjFHOXmByGpx8/xYXHT/HFv/sK\nt94wjGNDmiToJCTodWk2NBurj/Hyiy+ju2A7WUyK1IZgoBEpuBJild2XjAItBIosYkIrIM1UWkKD\nic2OVHHnc3YURbuAtgxfu6AoHPKJlGb9Iz/P5MHD3PG+f8PmWouVh/usrb5IcXaB9tOfozJ3iOZ2\nnV6UUHNbjI7naW7lcPOzJIMO2yvrzO2v8Iaf+0XioMujn/skH/u
FL0IR8mNzREGTM40+A8Aam6W5\ndoV2lDAKFHIGY0uMVFQdh0a3halM8csf/zC337FAqiNSLH5c3sX25hZ/9Tt/wMKBPN3BANlsUg17\ndOMew0UwsSZOFIWxGYJ2ndWTz1CaXeDg7pYnU1Mq+H9fZjLl56tfI3ZGXAqjDaWqha3zNJOQ+ZlZ\niq7gxusWEJ0rrC8vUSumXLlwliNzN/Hgc5e55ZZjmDTk4kaMXyiyttQi52Zd9/FRD2lrakM+uWCJ\nUN3M2nNPc52VMPmmt2A7CXHOpbUW4qUD3nn8BL4AJQSWUqgIHN+nPH+QpSceYSBtHluXvPeaEOWN\nUUYgx8axPAdp7U7FliaZzcLO3T8jIPoOlguxaaMHdcZHXXy/hp0fxVgOwraQpIxPVXjQuo2ZcItv\n/sZfc8sb4cXNEmvWKPc+doEffuM8J37hYTatDi9deJHS8klYXqJdm6HfeQFLJuDYJH2JTkAzhBlE\n9P/+K1gjFmJ0HJGEWI02YzmZuXSH2Tmtkmz8KAQEpkJue4T5IxpT05jacYx3A0TboLsgdm4Q8jur\nir+7D48nkTKk0GuS9wsMbI+S1rQ2Nwl7eYKTz1I/eoRiRyLLBcLxeYIkxF9ZYqjZpnz7LXTdAd3t\nLgW7SZpoqEyw1l9hYmoSJw5wRYlAB4xfP4oO+hSsLmbQR7sVImK0tFBKY+V9gu6ACd2i4eap+g5W\nfwV14q2ov/s8ulDOsqu0wEhIEoUxOvPB2Znpyp1qX+14n4hvcdx89UrwvSFRCkskaOmQhhFKSmSs\nEFrhx4okjUmUztJxpSA1JrNg1QZdr/PYs4tUXYvtdh+hFBpDmiqEBeO33cilP/kES/tcvvH5iPNn\n++TtLE/LE2CkpGAbTp85y6/9/E9xcb3JS089RRSn2I4kNlkxE+kUS9oIJCqFtokoXR0txQNSHJBg\nWxZCKXLaEEqBnSryFoRGI4Dw+1BpSa0xtkLplDQVYNyMuGwbLCV2Zu074aHSYKGxhctn/vA/UPIT\njJFMiZRHL/ZZ7kXcPp/HuJJuKCmGCQU35S3H5+imkshyefT0MhubitsnLDZnDEHPUKmlfPWBdRJl\nSLG487YRnrr/bo6962eJo+hb/vYCvRMIqo3cuSmaTH6eGpSCNIY4EcRKcPKL/5FicJmOk6MfhFTK\nu0s6vPsrn+XUc0/Sqy9TKRWpb6+wfOVFcr4HJsHLV5nfvw/bu0Kh4GCFhutvWODsmcvUt5tIS5Cm\nKUZDLpcnCAasr6+TL7jEccShQ4dYWV3hwP4F2p0OBxYWCPoBKlQsLMzzpS9/kYVrb8B2PYRbZGV1\nhUG5hPQdiLo88cxpDh69Bte3iOOIfr9PGA3w8z7byz1i4VFvdoiVhV+osNFcR3pyJzSyzOT0HG94\n04+yuXmFR7761K7WptsylIYFnhT4ecPELFw+mz2b+/dNcj/PM+IX6I+OsPXy8+A5FIZBFzx69QGN\nVsD2SpvqUImw02eqWqTd3ebRe0+RL9ToNTtUgbojcIvQbBqSZo+yAwjJ8HAVbUks22VmZpLNlU1C\nso6x5wusSGMShVQCP68ZpAIiiD3I+wZHClKdHSyYnUJfQnaygY4z/qB4RVG0CxjFswMYX5FUXcNU\nJUfEEhc++ksced07OP6+D3KdU0AnmoESmDRgbblJHEQMun3SnMORY0P45RFwLHKeJOj3+ZOffyeb\nm5C4oMsOQvqsNerYVkrt8Ov43z/wWkqlCkUH1q6c55Fnn+Vrjz7K4QJYvsXSep/5W67hj/7Pd5DP\nQat9EcsuISwX2fDxPIfYHpAMeqytrjDmOERRSqsxIOgqbFcSSUHr4oDtdkSn06cQGO7c5fIkO/YQ\nxoC4mmVI9mxfdXhHGCzhYrCQTg7FOs0tm/kDd6Kkw4lDY1w/YTi/GCOsPJWxCo2Vz6D0zQQdQGk2\nVmICcpSKRTbtJrEwaJ3i54q
4hTJTE6O858ffzYV4kmfIEVTKbKeSm47O0+r0SYwmjmKGpsZZunKZ\n9splisNjCMvDGj+Mu3Ka3uKj/NkZi1JO0ejDzZUC1ZFhet0BrWabaq2yq7VRsXolQV5pIIwxjoPt\nW+BIfHWapG9h/ApEVrZ3tYI05GDyEnFyiK/Yt+H+0q/QyklWtjtshpqf+uBRXj77Er3Ok0ysv8jE\npRdpRy6XBoZJbfO5C/O8b/YS2ovRqUW4BXHaxPEsmPUQjsDoLtZolThJ8aMIIxVktFtQO7xl4EI4\nw/SlmNOf/DRv+N+mMGkLZB3hDmXPZHEcY3x0+zR8B4rTd5elDwKc6XHGP/wuvnjDDPWkR/XyGaqL\nyzSPHcW77hijKiWedIk9H5bXODo9QhyEXFo4zMjGGvFgjUFuArlvjMk4pbV4HjUzS9hrEiJwpYWX\ncyiKBF0sIrwqgSVQhTFUEiNUQs5xaMQh6fYG28MTjKYN4lDQL0wyUYBGroRtwEaDyQipsdI4Mitq\nLDTGZDwNfTXqgEyaaHYOc7PLLoYgxU5gAPhGoFSmFEljhTRpNvcPu3jSxSCxbRu706NvewjLod0L\naIssHLGXRlRzHkSa4h23svqpv0J3Oyw+q3npyXVynoN2JR6ZRFpqjbRt8rrH7//hn+P6HiXfYZAk\n6DiBVGG5DjYCV0BqVOaZpC08x6MXhUhj4ckE21j0jQJliNIEXwoSY3C0IW8sukZjyd13eLIcLysr\naHTWfVNYWAhSmXV6rG/psliWyyP3P4WJEgKj+aHpHEXLpudIbjxUJohSigJOrbUp5QvcPD/BTE3y\n9IUul1shQXvAQEoeeKhPvmCYqjn0tcfMfJ4Xrwja9Q7ttYC//fR/Zt8Nb8Id34c2eif8Myt0jLna\n2QGtJKkyJCqrU6MEkliwdvob+J172R7koN+mVnSZGNpdCuQ9X/kojnZQA4EZmeWa6+9g9eIZOp0N\nvviFjzE0doROa5Hh8Rz9NEIZC2lp/JzP1ladYsnHsi08x6XdblOr1VBKUSjkCfp9tra26PUGXLmy\nhBSGQwsHeOmlkxy/7jj5fJ75+TkajTrVkTFSJ08UwOmLq1y6skW3FXL98RvQqQKR0mw26fUDktjQ\nrA8IQ0GjV6dcGSVMJKOT07RWLhAnbeKwydbmBabnphmtHkNHhV3vG2TGlUotQRIKvDxUyoLYGC4/\n/SgAW4/fT+ponn3oy+yPLQ7ddoKT6+cJ15q0t1tsr0uGJ4ZptTv4UQCJoTYCGo1XyJEvSmSkSSwL\ncor2DoFdpSmDfhevNIYUCs93qVRKzM14NIII6Wf7wVhQdh0m8zatQYAxAt/OChnflkTKoEJ2pKxk\nNuY7nYVX+PHiajfie8dbpx1e7sP9rZgfrNmsdRLakcWR0RwvP/IVzj78FYQE14NyZZbc8CTV8gjS\nzWPyEA0GrD+5Tmv7Chvrm3Q6ECkIgcC16Qw0lbxmqxux1IxhaIa3X3OM9vIFClMTDO+fY+jIYU4c\nPcJ73/wmfuk//hZxX3HTjQf4yL+/haC5iFFDKJFjoLs4TgGZuJx/4utYuSEubbYISwkhglreIYl6\npOkA23cZqhap1xv0RI52J8FrNHe9dVIld9h4ZAcmAil3GDwW2NImRaOVAqMQRjOIUo4ffg/7a+Pc\ncf1+Vl94gs/f3WRfvsv+2SEKQ1X6Goo5m9m5YVR5P8+fbpIYg+VoHBfaQUpxfJJer8NQJeDydsRT\nV+DM9hX+2fwUYb9NkR52muAV8rhpSsSAQ/v20Q0u0106hxSG0tgsJlfFnbmeoaNv4tTf3cMPTGRF\n5fbmRVYWR5k6PE398jrVodKu1iaUDl4Uv2q1EaWABTmJTDOer8wp0A1ktjxZN9IC40A/mWCudY5i\nOWbf2hL54RLGLyAvNrnLbGHu/RxWQWNkREU7iJnb8GoFXvjCc/yzWRAqRY+OY+k+xUIpc2ken
0W6\nFlpJjEpxaxInXEHmQEciC3PcgRbgNs5y+bU/zbHbP8RXPvLrvP1Dt7P2zKOU5g6QqxbZuudvUSZP\nX0kOHf/1b7sO4vtJNN7DHvawhz3sYQ97+B8J348AZw972MMe9rCHPezhfyjsFTx72MMe9rCHPezh\nnzz2Cp497GEPe9jDHvbwTx57Bc8e9rCHPexhD3v4J4+9gmcPe9jDHvawhz38k8dewbOHPexhD3vY\nwx7+yeP/AaEKzhrzt/cfAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "Sampled completions:\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAsCAYAAABhRmIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8eYyl2Xne9zvnfPvdqu6tvaur1+np\n6ekZkkMOyeHIpESKpmRttmQkRrRAlgIoieMEiuVYjuNEDgwhAWTYiZ0/EgOKbUTRFkmwSVGkqXCT\nOBLF2dee7um9a7/79m1nyR9fzUh/iAOWgCAAUQ9QqOq6BfR33/ue8z7v8z7nCOccJzjBCU5wghOc\n4ATfzpD/fz/ACU5wghOc4AQnOMH/1zghPCc4wQlOcIITnODbHieE5wQnOMEJTnCCE3zb44TwnOAE\nJzjBCU5wgm97nBCeE5zgBCc4wQlO8G2PE8JzghOc4AQnOMEJvu3hvduLP/sHv+3K1gIdL2T3+ptM\nhwPipSYzL2ZfKNovvURSq+MuXMCLWyTNOo/FQ76YNlB5xEX/FodFSRCtQCEo9oeI9SYyihh3ezSb\nCotPljriwEfNumzIEdolZPE6o8GIWjRExg2GNsHdegP9yOO423tsPfows/4tYqMZf/5zNF5/jbBW\nRymwzuP6vT6hZ2gmCmHBOoETFoHCCMCCEQ6cBQcOyb/43WviWw3cz/6zX3SlUHRWVtntDphaiJ1g\nNpmRr6wgR2NMLaKe1DClZlxqikyz0Aq489JLZM9eI0sntKKQUApKB7k2bP3wX2X26d9lMB6hpMA5\ngS9hqKEuJdo5AiWx1mKVj5MegQJjNGmhmZc5i3HMLM2Q1mCdQAOxFMwsFIDyfTypKJ1ECij9mOX/\n8Rew1uFhcdainCTEUeCYCsm17//4txwbgM/+wW0nRYCWERqPUnjgh0gl8UKBUOBL8DxBEMKd195g\nvn+DL/7el7jzJ7/Gpy4uc2ua01mug81pL9bYP5yxPTTMrGUw1mTWkijFSqy43S8hrhGFivWlJmWR\nc+vWIWdPBbS2Ev74jyb8lfcs4gLYevQH+b6/9Q8oyhKHwNrqy1je+bnUDmuh0FAUkKaSLLc89798\nksxO8KOIjdVl0nTMfh7zq7/y77713PnH/6UrZ5JOc53dwztM84w4cNx843mEbNJeO0M6OKC9lGGk\nYjzIKbKShcY6d27v0Om0aC906B30cM7SWlig02kjlePatTe4fPkRrBVkaUoc+Tz1oSeZjsZ09w/Z\nOLVOs9Xg2ZdfZ/P8Zb70lS9QTnd48fnbaCfYWNlkpbOIF0LYsOjSMOumzA3U5YzlZo1xlqNlQpIs\nUhhLsrzEQfcOzfoSteYFzl+4ikkNaTrgf/6f/h7OuW85Np0t5QLPIRwoITAGZhOBlzjOJpbnb8CP\nXQa3GrK2GDLv13jsB/8aX3z+ee78yasc9KY8/eRpOg2fN1+/S6uV0OsX3N3NOXP1PIPxkGy/z7Up\ndGqCpUByruPzxmGGizt8/498gloSE4c+y0trvPKNr6OvvYDzBC/embHy6EW8JCDUGXce3Of5myMi\nI2n4oJQiDn3KsuTBnkYocD5IHAiBtSAcVZvpHE4LnLbfet68r+ZKLK8PDTfGhr+0IIl9R+grrp7y\n8D2J50vqkY8QJX5QQyiPIpvjXE5WeJSuQbT0ELZImVOj6Xb56lsx1+69wWAyI/NqXOvOAHh
4fZlz\nnVWW2m1On7vIYj3hcDhhc2ONtaaPCCI+98bX+C9+ep10vku93mBuHLV4jVRrEApPBcymQ/73f/SH\n3L05IlaQJDFFmpMwpxkHGCHJ8hQ/iuilGaawRGHA79zoHWvP+al/rJx0DgcoJZAShBAIHJ60OEf1\nJUQVf1vjseBTPNb6ILvb93jpwYxb/ZLvvxRzcdWjLsY89JEP8cv/9P/kOz94mms3e7znqaf4/Fdf\n5plX98mmA8amxtKlpzh9+T2M/t1/Tm31w2x+73/M2uWrPD24S30v485en+79OxxMd/mV117jf/ux\nH6eIIAlq2IuOvrMsXnqMlbMPI1SImm+z8+u/xM/+D79GMxEYpfmOyw02l9c5+8hZhv0ptU6bR//u\nMfacjzXcjxpNiMM4gfUCtOeRDnM2WzNOfw+YGshZSLX4a5CVoA1eNuB6+WFU5zzPXtvDFimnQkOc\nzbj48CP0tg+pj79Os5khY5hOY/ZaHyDwSn7zV/6Yn/8R2P7u/5WVx5/E3/kitrTIm19CXvlO5Pnv\nROx9Edd/Bfugy/6vfpHmmkDkILy3P0OH5xxv1j9E8fOfYTXy+b2//d1sbD/Lpb/yg2ycW6L/5j6z\nnQe0ojr5/Qds/cqdPzc270p4dkrwXnmFrL1GY6mDvvgw01qN+WjG6d/7Hbh0EW91ncOlVR6ZD7mz\ntM5z9TXkXo/B4S0O2mdYi/YY+DFReZ+eLCmmddaDPqdWWvRNm0k+I7Y99GSEkiV9tYj06nh5j3AR\n7FyQqhqB1YSPv4/uNCNTPqNySM2WTMMYc2oT8/LzmNIH4WOMwViNJyVgEXhoz6GswDmHdnD0EhYQ\nVLznONiZCTwcw9ku1g8RWcYUg9CW3FhCY/A0jDKHCgLSWYkMPPanOe210+zoF6kJhTSWkXEIa9EO\nls19rvzlLfrbGZPJnOHhmMm8xA2mWBlilSS1FicqcoLNSa0kweJ0SegkeZZjrSO31QqXEnKnsA6s\n1SjPw1oDOGQQYJyhnM8RnkcuqvclhGXiwDgInTlecACsoxQOi8EicEKidYlHABoCIXEe9A4Oef6z\nv86zX/63nH/4Efbv3aW90uTAwsV2jYG1RPjcuDWg8DzONh1JEtBPJBNjGY4t49LiPAPaMpxlHO6P\nePRKk+aKT2d1Ba9ueP+ThkmrQEn4zd/6VX7gP/05rPNwDsyfITzGCIxxGC0ojaPUUJSCorTsvfUc\nSTPHFRFeVEM4RxgkJPP58XJnO+dn/uZ/S6Pe4BvP/hZf+Pf/NyZs0WisY7Ipdjaj152hfB+VlKT5\nDGk99vd3abdb1Gs1+r0BSnlk2Rzf9+l2e6xvLPPkkx+k3+9XuTMYcPr0Br1eH19Izp+/yOvXXuH8\n+XNYq8FoonqLUvcJ4zrbN+/Q6SzTXKpjTMbG2ir3HtxHKcmp1U36+29Q8VaftbVTzLKSzZVl3rq7\nB6VkPh0h5H32d+ssLy4RBsfPG3lUjCIpWF2BQQpe7FAW3jgQ1BcE11PHyl7BRl1x7VaX4kufZX9/\nRnulxZnzC1x87AqD7Xvo0uAHBonFiwSHOw8YFyUCSS1xqAicdVitwYHWmhuvXOPUmdNMJ2M87xaT\nbpdEW5TyyQ/mNJ+q01lbwrMFOgq43XseNQVFVVytE1Q7CiAcnhIIWeWWOCq4wgGuIv3HyptCsCAl\n5xNHagzPjx2P1sEvLS8/MJxZtPgyI15dIUzajGZjgppCFxLtb1BrzVlYvMJ0+zXqa1exd79KufAE\nH7qyyNWzS7xw4x6TUnNmOeWFN+7jEMy1ZWc4ovfaG2idc2p5ndt3b+E5w+b6En//v/oOutNXCeOY\n3EIUJuTGsNxqMy8KkAs06xtc/cg2r3z189iVFlmWYpDYWkRaCNLpBJQilJq7I0jHOaut4ti5Y43D\niYooKwnKk1TdrMRYBzikEAQKPJlwlR+g7T3OZz7/DC/
sTGnUY65sdRAixsar1JMednqXqF7DOcuV\nh1d54Y++QX1hncDcwPkxFy8/QXzuPRA0udH4UR7veLzvY3+Jdjlj5e4IS8HZ5Qg7XmRtIeGTH/kI\nWTpDGoXVBQvJKqmdUc5nSClBGPBjChew1obxzFGLIM9zXFxn0C+RQUw5TY+XOwPHOJJ0pMGK6u49\n58A6h7UCjEMu/xh84uO4a59GZLvIT/0iYu93cb/9T+i+9YDTjRZi3KeeDrnw8aeYDubM9w5xwx5z\n12JRZCglUNLDTgeUzZjpHNDghjv4X/pH2OEYpn2Ev4/Yew79hX+ACGsQBAgX4VLACKyoyImrPjYs\nQABy2qcgQS6fZ7OVsvPqTaSxXL+zxxNrHTwtERff803j8K6Ex2vFJLKNnox5KwNSgevts9Q/wDxy\nAd1ZJ22tMB9MGI/GnFt33OuO2Wg0mbZWkLN73G+t4q+sU7w1oLFZ4yD3GFNnWLRIbMpyeYByM9Jo\nCW/WJUzqWGexnodLLdNwmViERA2PaZriZYbIU+xlPssrl+gM7rFf9yjzjDBpghPkxlA6iy8lOIVF\n4FuHRWGdrd70USAR4JxAHKuXAM8YhHX4zpAVJc4YjFJ46QxtLZ6xeLZkbnJqJqBZTpkSEjqL12qR\nTjIC30N6EqUdxloWI583P/cKX7rXJYp9lHMIUTWEYeBjrEWWBmT1XpAWg8ETioEzSGdASqwBJRxS\nQOYcpRDgDM4apFVIY3BKYZ3BGUvgCfSgD+1lfGGQQiIExIA2gD0mG6RaSDiDddUzGuFw1lEiEH7E\n/t4Dvvabv8y9l79ELUhpRiGzwW1GszENzyKNz/1hip8oCimZpo5unnL24ZhYggwUk14OXkiZQ+gL\naouWNJOM+xF37qR8/HLCS3cnrOqE5vmYnbcM585LHv+o4htf/jqPf+yj5IXFmCOVx4C2R2THOooS\nSi0oNRhr0Nd/C3yBzQSNRoulZsib93eIk+h4uSN9wqROnkNRFmg7oMwdSyun2HnrFVoNhy8l87mg\nFiuajQbTQUpZFDTqkjwv3tkJGs0m0+kUpRTj0ZgoipBSMplMAYFUksXFBW5cu87mxiYXL5wjKwqm\nszlFkZPOUzJTcPr0Eh/96AfJdMpsNkQKD1NaLm1dYEfucfdgl2YtZmllkYP+mCyf4Xke48kBYZyg\nXUoQl+wdvsFg1OOR7/tJnGkdO2+EhMiXtFsgfEctBD2CYiJIRw4ix58MAQTPvjnnzBbMend4/+Yq\n1+8NWdp4mEYtYq+wGAdOwFwqljo1droj+imcqUl2tKGpYNyDsdEUU9BNQ6kdYb3FQqfF5Yce5Stf\n+AI7t97g3NlFZBxw59YdXnrzVdpJg9kwpxhBcCQUIzW+UlhrAYc8IjqU4CRIV205iKOtxx7v0tfn\nM8FFpVnzYcmTFL7jK334QNPhK81bA8lqLSDd6bMQHKD8mASJ1ZJGzTIYrxCnt1Fqgf6D10kal+g+\nuEYWnsF5Po88cobbb1zjY+9/gu/72NM06g1u3N+j25/woYfWWG6vsrSwyHw65a37d/HrdbTXR1iD\nI8balFbjEsL3GI53iPyYVqPDPCv4nh/4GP/yFz+Pl89QcRPp+wwmc5w1CN/D5iW9WUZROtLCkhbH\nJ8tlUeUP0iEAaw1CgpQSQaWKCwlZ6bBmzmuzGfdeeoa1jTN43ZssNWs0Qh+nZ1x+7BEWliXXP/cr\nvO9Ch4P9Lu1mjSCpsdc7QHoeH//xn+PGNKG7d8iHPvgQT33s/SwFmnmvx/msixMOnY4JQ5+rVxZA\nxOyMuwjP4QmJsILxzSmL711koDUiUFBa8CTh+kPMclhbEGy0fRqtFvPhHs1WE5cXKHU8N4o3E+ig\nWhNGCIRzYG3VDBsJU4PwCkQzwkbL8Ce/g+8+Xu2JRcRwbFmeFzx9WtPsbLLwvd/B9c99EfPaLrVm\nwEjXsWYfz0EgS+q
RpcgmjKbQK1dpvfZvwLuPzMCU4DzIRqB8ULUZBBY8UwkQDpzinb4BK6rP1BQY\nISnwiX2L0zEv/vGz3C0afOLDF9GDEflMkyrLyjeLw7sFKRARZbCEOCtZtB7m7g1WkhB97jSeiNhx\nJd3ZhHP7t1lb2mBsBCqo03vzdTaXT1FEgqC1gLm/Q2oiTNnEL4boaR3mz6LOnoVeTiYVyegW0itJ\npyX1UOL8VdyCYp7DkmcwWUo3BTtLiRIf8pR4lHHgtVg+fZ6D0idxVVEtykq+FFJWvEa8vXiqDuCI\n4GJwOCv4i1w2HWRTKqXe4juHsaIiL3mO0I5Ya+xoRGvBx+gU3e1SeD61zhLkGdZTRIGPsBZjNdIJ\nsjSjJiEMHZ40KCvRHniiKsCZtSSeJDWCSFmsUkgtKI0mkBLpBNoYAk9WBVuIo/FX9b4NEiUs1hqk\nscRRRKZL8DzcwSFhu43SjlIZpKnyrXQw4/iERwqJEwLloHDV5mOFxQGvvvIi//LvfpKHz29ybmOF\nL7y8S+B7nFpus7GYsCljClPw8Pk2X7nWIzWafm/OqCi4dl/QqimmxsMpR5RojBG0lxzWOJYuJtSa\nUE4tr90q+P73tHl1W7P3jZILTwUEZQRFj9t3+jz6UYE2AmuqwqRNNcbKtUNr0CVoXS3Qspzh8tfx\nvCZrLY9Op8b9B9u0mnUGw+N1WwEln//sr9KIz3M47BJGglI6SiyDYcHiYo4fGAb9MeOsYLWzSJGX\nLLQWwcHy8jIHe4cYbRkNh9QbDaIoQmvDzs4u0+kEZ+07v2s2GmxsrJOmKZ7nc/f6DZyFPCuIw4jx\nzHDh/FmC0IfSEEdLxFHCUtIgFD7zVotukRF6isj3iaMQv17DCNBOE9QCMuswYsj6aZhODrl/7wVO\nb106dt5QVhLIrBTsHzi8QDAfwcGhBR/W12D3DoDDAz7yXR9gdWMZr5wx1w8oyqJSD5w4UlQqIusL\nTaig9MCzlo6BdV/QuGx54sm/xNLdAc+8cJP2+lmc57HUWSBJEmpxxGBQ8sgVD08XfNcHLjANYpaW\nFnn55Ru8st1FSAc5RHHVOMW+B2isoepWBAjnsAaEEjhTjbiEOl6XNcoLrgmFNiV1HIlwnArhpbFj\nMxecb9p3lOqthYBsmrOYd/GUR6lnRP4hon4VVU/wJhHz0T1qrQUmD95kafMM+e6bPHG6xfz2v2Xt\nwodRRZsPX9qkFj3MdNAlu30LubFJLH0uLS4SNRLeeO4uZy5ptPBpJB0sKbZQtFunUQHsHtwiipq0\n1pf5D/72p/i1f/55lsKC6SRHOIvyQE8ztHBIYxEypBbkDPN3LU1/LrQDoatiajQoK1DSIoRFqSoP\njAFtBDp3zPSnUa2/QW+yx+bmJsXsECMWUXENaXKyIRTGJwpKJqWmGbapJRp7MOG7f+q/R61eYOOw\ny9XLZykNZGmBjB2XdBd/0KW0GqksQjq8qEF/NMCWBbYAGUU4Z1EzQ2Yknq8QKJwnwSREZy4BAqct\nN7vw9HqIsCXp3LBxqo2eT44VmyAvsEZhVNWMKjgi5lAWAjcD/Y3fwLv3NaTeRrbBDAADCI/Uxhz2\nC7bQSKfJvvw5tva/zL37jmTTg1qHLIegDn5UUC/GzLIhY8/nCzdzfqy+j8kr1u954LRD1kCUAqdB\nihThHGVZrQn39ujlqEbjAFMghEAZzRoPWFur89Cj51i5tIKYDylEwnxyH93ofNM4vGtW7QmPc1vr\nDLsj6uMhrG9SKMGgO2ceZaRZxoKFpSc+zIGWzOox+VtvcbETMwg0QbRBfXmN3s5nmQVNlucTimyA\n8PaJF9p416+RyAdYU+C1HsWPIozXwMUxByhqFLTsISJeoSwNVvgUCwm+nSKlwpqUIIuYJ010vYFz\nDmU9LHOwDk9YnJMIIZBHC0ICpbB4TuCUxGlbkaBjKjx7uzt4nk8UxSwGAbk1BDIi1
5a4KLA6R2Oh\nTMFYRJkTpxmHMoQ8RZUFJvLxHFghwDpyY6mVDmkMnpRMTUHgFBoolcBHYLQlFA7rPDAFUnmkxqKM\nxRmDJ2FSGgJVCTPWWRCSwhoCJIFSZM5S4jBFiRKS0mqYDBGFoRSW0AqsAqstoYOYYwaHP+2unHU4\nazDOoQEI+bV/+AkefughDiczXv76NYIwYDH2WN1Y4fa1Nzn0FB9Yi+mOCmazgnbk0bWOehQyKufM\nJx5XLzRJxAJW+Xh1x/2Ro1abM91P2VhfxJQa1xZ84eY2W2tNzgWKojdnr6tY22hz7dnf5BN/469h\n9dvenaog5RpKDVpXSpnWYHRJNuojXclBL+PK1jL37r5FHK9z/c59WguLx8ude28SBw3SJGKUpczn\nJatnN8lnmnHe42AEh70eDp/INemmc06vLnP7wQ5In9ObZxBCMptPaLcXKYqCOI44ODhgaanDma1z\nJEkNaw3bO/d56ZWXMYVmc2Mdaz3CMKTlJQhXKUlxkFAUJXsHu6xtnKLb7VILc5IFwU5/iA0jeuMJ\nj37gPfTvv4kR0Ns/QIYRXhySiQHj7JCtMzWybEhrUfLya7/H7dt3jp03pXFkJfi5IPKrfc4TAnCg\nIEiq7qQdC37gPZDeeJFnxw9x8ZHHEI0ZgRH4UmHw0Q6cFExmJe2kUlhyAzMB7RC0hDSDpNVmqkd4\nnmQyPsQJjZn3yeYFO9v75D7cGzkmBq69+TK1lVWSSNEbTMCANUdWQCdwzpKLP9NBvf2jEAjlcM5V\nREeAE8frtN4fw9dnhrul4pQoCQTEQtDxYKzhlYHj4QbUAsHB2OCQFKUj9DWjWQlC0xp8nU67AX6L\nnJh5d5d6u045vIGTAePZhNbGkzTOPoGxmjI3TAb7WCs4ffkKvhM0FWgZ4JzlzV+9zuiHGlx+r0E1\nOpS6IEnWKIoZ8+mQKOnQH/Xpp47/8Ee/g3/9zz/PbJ5V4xTlo3KDcQYcyCCgTEtQqlIxjwmrqxop\nJSCPVGZbkVBj3jbwVD49J0Aqj3k25+LmBQb9HoEf0+31+KFPPMF0XvDmH/wBl8+uY4qU/P6MuNak\n3tDcydZobJ7HWLh66SzDucZqw/pqAz+dMP2jZ8jOnSNI6sz2BoTWUboxs9kEa330kcJS83yENvgm\npH35PK7q0kEImlsXWA5KhC84GEnevH2fpz/4PpTUGA3J8jfTMP58hEd5qgEhHc4eeZqOJgLOgVBg\nt3cQAVjfQ1mDKx3kCpcERK2Y+e0MPylp2B7T+AmS915DDAb4RYOhaNBUE6QFX6ZE5ZS9VNJIAAPO\nOETNIc89gvG2CL/rp3F6ivmtn4HDEiczdHb0PByZ3Y78tkhQLifCYYRg0+1hpx3OX1gjKrfJZzXm\nGKxzJFtr3zQO76qLNVROdmebZnZAUuzjBY6uisgCB5M5siyxrToHu0N2NIwmU9RoTGOhzeLCArLZ\nYnawy3KrSWv1FK7pky6dI1rdJLYRtTNX2GtdZrD4KKmAnlpk6vnMywyRTZjPC2r1BvN5wTRZoxkK\nammJXjmD8GoYr0naCAgbq/ibW5hSgzJkReVBQSgQAke1ESEExjlUZSOsJGVn35mtHweNsM5y3MC3\nDqM1aMM8S8lMCrpkmuVYY5C5RszmKCFY1I5iMmfuwE9ifAHWWAIL0lqUktSXa3zPDz9B6iRPf+Qq\n61srLJ/q4FmH7yzWWAoNnrOQO4wxhAL8o/FceaTMOOOqPJEC4zTCWFJnGFtDqTVGG5wHKRbyAntw\ngLEGZao2SBXVuKswoP8CIy1jNMZYjDE4ZzHWUWiDlRKZLHPrcMRsntOqJazFitI5lO9R6JL5vEQC\ns1lJPRD0pilLiUJFAY+eTXjvxRYLYcDeIOX5nZLhqOTymWVQFr/tOLw3w2sGrJx2LK41mdShecZj\n+1XNC8/uIkpwvT9iOsgrcmMc5ZFB2WgwWlAUj
qKw6LKkLKcwO+D63X2SWkhaZuQ6ZDRPQTm2dx4c\nL3cWamTzEuHqJLUmToTMs4KsLHnk6iZREjLLhggyrM4ZDecoFVPkBolCIphOp9TrdYQQzOdz0jTl\nzJktlpeXWVhosbf3gOef+zr7O3tsbZ0hiWK6B4cg4ODggFocMZ+MmQz7GAPb2/dZX9tEIHHO0UwS\nnn35ZQ6zKff2t2ksLJBpyWgwYTrOmExSjJFY6yNUxtraKkm4yHQYMBvFuNJwuHPj2HkjBCAFTjoM\njtnUcdA/WpwlOA1nT9f5gY+uEcZ1dp/VqL0ujcYiixunkL6HEg5jDcKCwENJR70WISIf3zpyHOsN\nWK7B3MIzf3KNfD4iTQvOX7jEysoilx66zKVLl7l0+QIbMawDRQjOSpzyefbFa/zBH9+ATGBL8H1B\nEAjCwH9HZXHWvTPCQgKeqDpmqIjQMZfVeiB4r2+5lxnulpKZgUA6IgGTUjDT8OIAdmeOnYljd2LZ\nmTr6M0tpHd2Z4u7A8frdCbdu32XUu8usSGnECXLlaRY2rnL+8U9g+rfovfYV9m7cw4wHDLdvQTqn\n39tlPLzH3b3bjG68xGT7De5sv8x/83e+xuvPjwFBHNRwpkDIAGMkRelhtKMoUmZK81M/90nSsaa0\nljwvybVBKYV2itk8Bxypc6TH33Kq8b86MiQfpYxzVaG3BkojeNsj7knJ6DBB+QGNJGRx7SzSD0g6\na7z2ym1ckRK3mgjpoV3AQj0m1YrSenzfT/w07eUlgiRBeAFJ6LO50WahFhHUY+4ub3DwzPPIWYYX\nBWht6QpNZh02jAgbDYRUCCEIlI83mBGuXKpmcghcMSMfz2h6cL0nSXxDEARk8znNRh1dZkcezG8d\nYdVTU9rqezUQrxoJqwXOgK/qqPoa1q5QjhOK2xKzDZQCfIUnLCx3cLUY01pBDHbRLkEqCERGqRJc\nCVhB6GlM7pGXsBkbpF/VIrHchKd+CPnIFYwNcMUYKcrqeY68O1a/3TxUD+pcpbxrGTE00Jul7F/4\nSaRnufDDT7P1C79O8P7vJE81ur7IxkrwTePwroQn6R1Q69/FHfYxnUtMMkPZ61FkOYP/5/fZDD1q\n05ICTT7sE+/vs7p1ltvE3Hch6yrD4nij68OdXcjGFM4h0hKnJsyzQ7TWOAL6QYuJGVPD4kufhVrE\nYmwJshzhCaaTDCMUpefQE00966K1JtEK8hxqDaSqEjwvDL6oTkeIyiGIQyCtrUyRAK6SmaUTCCeQ\nx2Q8RpeM8wzlLOO0xBaaUmuKvEQXOTadM3npVYb3b5OPRpRBzJ3JFD2fI2YTPOHIS402lsKWKCnB\nOq5d32atnaDHObosEQJOnVoiKzWZsVgsvnAo5ZOZkqK0mFyTWY0sNViNZy2lFUgHWIcpNQpBTcqq\nGACRcOg0JTz6+8mDXZQxSG3RWpMX+p15jirLY8UGwDiDtpbCGkptMaXGWMe9Gy+zslDD6JISSJSl\nFviks4w7D+6hneSJrQbrrRDrCZTvc/5Ui0Ip2oHH51/IePnalG9cG6C9AjEfcKqhIR3z1NmzPLG1\nxaX1JRaTgmTTZ2MdWqOcO5TSVi8AACAASURBVLdnpFmJcwX1SHF2I2Ha28MejTxMWZHFvBQUBWCg\nKA1FMaEopph0By8AV2TsHIxZ9AX7BwcESrDYbBwvNqqk1m4TNRTNhQ5Qo9SWIp8zGN9j8/Qpzp29\nCsKnlkSk84w7O9toC5cuPcz6+jqe56GUotlsUhQFBwcHCFFtEqPRGK2PcmdzkxvXbzKbp9y4fZev\nP/sci51lgtBHSIPEEYd10lnG7vY29XqDs+fOsru/R4HhYDrG+ArtCnZ3brG4ukijtUhRlAwGXW7d\nvE4nruEZj927QxK/QzNaphEsEVv/2HlTWIEXgC1gOBHMcoHWR4bTmsBoOHd1k2L5Ib6x73ixgDv7\nGbdvbnP/Y
Mb6xQvgDIPRsCL/ErCW8STDc45SwMyBkIJTNUFDSW7cukfkCTzP47mvfplXv/ECd29v\no3WJ0Ya0hF3pqHUUZ69c4aELD1FoAylHZudKSTUGCmvR+qgYvV10jxRkAchY4N4+rWWOOdLS0A4k\nD0WOoRE8sJKxgUDBSugwtipo18fQzaGfQ28O2xPH/aFgmlcK5jy3DDPFKPdo1D10mRG5A6LmOvv3\n7pCsXGEyBpGOmRweoPMZe/s3Odh5nb3D++hScmM052sv3OWlssku8PJrJaQeg8kOxmp0kVMPlkmz\nnCCqs1BrUotbfPJ7P8A+oLVFWINxglQLSm2wuErpzkqWT586du4YK7C46hQWR6qtqwq7fdssfvTB\nOGeJl2D1dAMcxJ5lvb1Aux5z7sIqN/ZTLl55hOFoxMHeAc3IIYo+fZ3QtXVu3uuR1GPAsbhYJw49\nAl/hKcl7P/HdfGXUIx0PkKZgWCuZqxykV40z/YAwjLBS4XkSlR4deghjUA6nDSoI6CyHeM7gnODR\nC2cJdEFeljQaTfJcHys2wlYWjrdVLmdd9SzOoU01/pU1gdp6CP/9HyD81Pfi/fWfgI98Py6fcv1w\niPIEdnWd0cZH0Ebi0jH+d/40niuIVYH1E8qS6qBMNaAjErCUaIw8mjakDvvl38L8/i/j/s1fRXz6\n50FValV1NubIoez+1MLD0XoRRlZWCecI1x6jP+whpg/43f/uH/JPP/MqeXuJRt1jcLf3TePwriMt\nP2qiT50m2x1QZI7Y5chmiJooGv/JTzKzIeNhnyJKEN1DlLXslSWeJ7mwtcFh95C8n7OxtoXwh7j+\ngDOj1yjb5yjGKdHiGmvlmDxIkEDdSWZxC2VKUpez5AJ2jMHu9Wk0WxTZlKXOGomcMvJblLnCTjSR\nf4A+s4l761UK7ZhlBfVAgZQ4V42sXKVxYmUVUInF8mdmWe54JjAfh8MhtSGRkhCBtI7SCYbdA+x8\nRlJPKPYPsFKRtRewozFu1Idai0YtJp+lSClAhcxnGdLB3/yxD3HlQ1f4pX91Ceeg3x/yG//6mUrg\nM64yGnuSeZ4RSonCopUAXTIXVEZnq6vOBktxdCJYOktWGjxPoaQgdQJPWKQFbTPcXpegyPDLEi0U\nTio8HFMH87+Ah2eWFhgshbNoq9AICgcPXvkK97oTlushYVhJ10MPzm018Vsxo70h55uC3jQnzw2X\nliJmpWOtFnJ6JSaIchCgAstkrlhfS9gbpmSHU7KDCBlFLK4J8kyTBBEMAtqJR52SSx8K+Z7mEjcP\nuzRW2uTzlCyvTmFpDWUJpnSUZUFRjtDFlLLIMCZntv8WG4sd+qOMcWYolWBro83t7R7t2vHk9/Za\nwivXPkMSfA1rfNIsx4kuGxstGuEmWiUki8tsBDFCp3zH049zMOpz4dIZbt+6xu3b91hZ7lCWJXEc\nkyQJnU6HNM24efMmQRCglGF9Y429/QM6nSVWlxbxA8lbb92g1Wyxs3Mfa+HM6S3ubO+ytrJBaUv6\nvR6NWszpzVNIP0AEIfN0RqcWoJxAq5i9/UPanU20mHHx4U263T7bD0b4KmFjfZ3nXnyJjspY7hxP\ndgco5w4tBYV1DLqg/swOZXHMZpDUE7r9MUMv4m48Ix9MuJDlXHjoAouLbUzvFm+8uUNSQm4Eh3NJ\nmpcsJ5LvuxihZEhLRuxd38evh3zkiYcqJbS8w+Pvv0qy0OHq5UtVZ1nmxCF88PFHKYIHGCXoj8bc\nuX5YzQfU20fOBaWxkObMCgdBVUQqtZV3ZHl3pMTao+nFcfDixBBKS2LhlDCkDvapxmmnJSwHVWuX\nGrg/h9IKFgNHKGGQW3COQAo8CaHnCJTj1sASqwNqUZfo+psIHApDqxnS3QOb5WydWibWc4bGIy1m\nPJgqnnmQ82IP/uu/9UH+xY+sEQYNrJcSeAnOzYjjFtoG1Bw0mw9hLWTZgHp
nkd//6t/hwx/9J5wN\nASkptCUQIIVinmqe+sufoh3Ex84dnEMgEOKI6DiB4E/JJkffrRNHR9en7Jn/g66IudD5Z6yc8ol7\n13n+5bs8/f6LBJ5gNhjhshl70Xn0lZ/AhTWuvb7HYO+Qb/zRc9iiZOviOdbX13jysTUatUrh+qFf\n+iU++xu/zZPLBjspyYoQz6uUB5fnWGOphQEqCoiaMRYD2QybTpFRggwNl68+SX/yh1zbdxxs3+P0\nuXNEgUdpSvJZdqzQJFRXjRjkO2THVvwSbR3OKTATGH8VBkcemgLEPGZwqDicl1gn+Pozb/HcquUX\nL2Ts5DHLb3waHS3gFzlSCNIZyLrDi1MWOh6rS46thZTbh0u8tfAU39X9NGI4QQYSV1Z+ZBcUoKu1\nbo54nDl6DQHSCZxwCJeB0QgkxDG7H/z73PiNf8XtXo/3Lizy2gtdnn5ojT/+zB/yH/3cnx+HdyU8\nunQUqcFRIsZDBsvLDNIxC50mtvToDw9ZKKuFVydnob2IDhSdmgMzIW+uIRghBndIXIkfQB+PNTVm\nXG+QzjLqUQtPgZERraRBNx2hw0XkaAfRWafTWmSmdyjjBkmjSek36alV6tMHTL2cTstH2w0WtoaV\nxG40tjR4UXUkUSIQFhwGIxXKvC2Xy2qUdSTumWNuPtZVxmTnwJOWHIGSglIIojBgPjZ4rho5+b5H\nMZ/jrKNmDSqbM5nMCZUkL0tkAb6sxjnlbMbowW3iRkwUBZxeT+h0PO5dc2jfoiwYU90fg6s64sCT\nRzK6wViLLxVaOJwFicIJjRWu6nhsdUIg8T1AkOsS4Rwi8rFaExiDwVJ6GoHAtxL/uDszYHVJ7hSF\nq8YLGjAostEe7zvfoT+Y4ZCUiURPCwokDz+yTv/2Hp9/a8x71xqkFlYaC+xPJjy03mCnP+FD5xcY\n93MGJkNLzXQ+p9NU9EvLg1lOMC/ZGfosNUL6gSHyPYw09A4c+Z0xq+dTvEZIVDgMIeUR0anELEep\nC4piRFlMKMuMMp8dHeGe4HzBre6cK5fOotIx03SCJxWD7Hidup63aTVSth/cY6HZYtgf4wenyeaQ\npwW7e3dYaLZITUo99sjyEUtLMXkxRJuMdKrZPLXBfD5nNBqRpinT6RTfVyRJjWazyXQ6ZDrN6LSX\naDVbNBoNOp0WS51Fbt68RSdqM5tl3L23zWg8JR/3WF1fqY6/mpLpaIDwInw/JNKamh8i8RlNp2Rl\nynZ3SFyTbJ1ZIlnw2fTWGPQ0K6fOUbt7G5OXPOgfHjtvPFV5XZSiuiurgHekkqP7kbwoIh1P0Lq6\nV6JAMOj2eWUyZmVpka2OwvMlOrfgLG0PhKdYWYpoJqCNZDrN8CKIWzEPDscEQhD4AcPBFC0SprOU\nTjshqdcpDPQebHPn+i4mHVMIj52dyqgubHWM0lPQCiQEkqnWR3ftVD46AeDzzghLUB0f+9ZvJ6rQ\n1ZaakqTW4VlHjeq/yaxjiiAW1ZCi7oE4Uo8GhSOQUPMECkmoHOFRkSsU+AZyzzEtDYlfqWmZkUST\nEicUTgumdw+JPMnu1HJtCK8P/lRd+J5PnmE6nRG0HUakKCtJU8N4NqReO4UxAdlsn1RXqtp83mN9\nfYG/9599F7/+y18lLwzaOjyv8gaZCxdZatXoHXzzLv2b4W2y886/RfW7txV+QXXq1UF1WMWCkgIh\nUl65/zOsrP4CW63H8Lweo7njlLXM0zHTGYw++OPEXkTkSU5vdDjc7tJqtVleb9NZXmDz1DIWSxT4\nZDNH5gyf/Os/yPQz/xezucVJgSkNKgwwZUkAqFoIvoVhisCC0bhSQyRQUUT70qPsf+YPuTUIuTwX\n1Ho9WotL4DcIwuOdDHVH5Oao9T9i3Eem/rdDJj0wBhtJCCoPpydT9rJFclvxgf5ck9QbKDdmPi9o\n1VqUxRApQpQtMdWQAWksyvOohSmy0eK
c12VRP0PXnWbF3kccLQynDU7Zqkbb6mQjR55Wd1Sdofqs\n/LyPcoJ67ONqEYcPXqOQCfPte3g+PLbeIssF7eVv7ql8V8Izmw+pr9ZxrQ5Zy8ft7pH7glrnNAMf\nhPIR9Ro2y6lduspwNuWUnBL4IeloSiAK9NxQtNcgnSHzObQ2GThFLZ5jsh6ZW8SFEsZDekZTqyVk\nkz1UrUaZ5lDkTIVPYCR5mZJLHzubUfohXlxjpFNiQuYrp4niuBpv4fCVOLpjpwqZO9qBjKjGWMoe\nmYWrke+Rq+cYMIYSkBiyHKJAkmpJ4AzTLEWUmshKMinQRUnkhcxCiT+2dEddgkKTxAGhlBSi8org\nHP/+c7f57GfewqaOXEFZWuJEonxx5N+ySCfwqS4RdM6RGVuNxIxjLgTG2Io5G4uU1dhMW4snvcpT\nIx2mrI6tW2PxfR8yTTabITyvujAssxgFwsl3VKJjhUdbtHOUxlC66j4fqQR+GDNFksqAjZqPnDpq\ngY8fSFpBjBAB3VlBUAuop45cwlyX9IYls9zh9aasLERcrS3yYFjghIcpFC/3xrQWNMK35POc0HN0\nx5a7U0tRahaaCZcuR4yBXtfSeRisk0dkx1GWlrLI0OWUPJ9QFnN0kR0pPJpsPmcyz1hq1JjPJyxE\nHosu5sKy5ebu6FixmXUHzCZDOott5tMxtVpE/v9y9ibBliX3ed8vM8945zcPNXX1iO4GQACEg5RM\nB2UNDMumltLKG9vhrVb2xhtFaGHvHN544/DWEbLDCzFkmwxZIimTJgmCAwg0Gg10dVXX9F698c5n\nyNGLPPdVk3R38OFEVNd79Ya+N0+ezC+///d9/8qRyhErLalXM7YGCu/nQEmRl9TNirqxFHnO4rrB\nWsN0OuXw8BAp5U0pa7GYo5Rie3uHsixxFvqDIZPxmCRRfPbZCz799BFvPHyT1WpNmqW0TUN/0EcI\nwdXlJYN+yfHREWcXU9q6RSnFyfMXTCaHTFdLimFCYTzW1kjlSXOYTht2946ompYkTXEWVtXtTqEA\nKhGUvUic8Fefya5EIVSCDx7rJMMhpErw+7/1pyyBr7094td/7QNGvQLdaedQ0C9SBB5jPcYa1p2w\n/+lFxfj+GJGAlIG8lzMc5PR7JUY3tE1D28DJySlNs+CnjxacrwAbc3c2V79UGCkopML7jvHrSmfQ\nnZgdf5lquOWSE4Jg7QItkIdoaOsJGCC49mCEYNw5W0YqkAhBZSNbNmsDiYTCCwoZSCXk0QlM5iCR\ngmUTUAicCCxbcMHTGkHlBSdLODeScb/Ht9/eR3vLZK9H055TlCVJ3idR0HpLlvQo831kMib3axCO\nVGY07ZThYMLp9Jr/9D//D/hf/qffwQcZbeMhcJ1k/OK77/P24QQRbhlSRMxw6vTKQHh9fwI3YOcm\ni2TzdRFdunkaWCz/ex7bD3hw8E+Z3Btx/fT3WNeBk2/91wyEJE0TpFJkueLDX3yfLJPMF52sYWU4\n7OdoY+kP+7RVhQ2OK5mTBkviYxZc0AZpAxQ5vq4hpNirCnF9SpomUcjeVAirmaQtx/dKntaGQZHQ\nGIH3BiladHM71j2ETt4hvqCjDwJJPBgHK8BaxN4AuZVDliKkgsJxVnmyoodQijYIcAJFy7oGefwu\n7gf/GqGiucCLCCZjlIlnoEAIh0MyVleo4GltTpG0iOMDyAp49jQeDqwALwgdAgthA1hjWbhOBmgJ\nhVSoJOHyL/6CK7NkUC/YbST9411cI9l/ePSl4/DVtvQ7D2iCJDl9SZjsIoaK+3pBpme0155JmZKG\nnONBn6xasVVfwe4RYvuA0n/C2UKS7WbUs4qenpEGSbZ6jt17E7925K1lYB/TZN+kLjxpb4KqZ4gs\nwyQDWmPIyjGpmpNf/ozVes1s/+tsT59TH72DCbAlIUklKj+kGh0Snn5CIiBT6gbMROAT/6s69Og7\n54c
ghobdspxO5hxeCaSxOJXhrEd6SyslxXqNUhInLMfjLV5eXyGbml5esG6usIsF4yLduO5QQmBF\nXGBEqqh0S5kLshCLmd5bPNFGaEO3KMn4uiWB1rqoE0gSUmdwUtI55uNER+AINN6RJwrv40naeg8q\nPgFCeMTlJcnhYXQ4bcbO25/DowWN1ZiQYJxDO4kDlJdkecHjV9coIVimfY4nBS9rx9wGfu/ffY+D\ncZ/z2ZKV9WR5hvExZ+jyesrB7oDFSuCswXrD1jDl0YuK66XHt4aTp2C9I8sVVJ7do4REOgY9z86g\nprI9SikY7ikWly1BSLQOWBMwpka3S7RZYZoVWte4tsUaHYXUi1OkC1iRkPiU3ckus8U15yvYG2/f\nbu4oR9A1QYwo8x3miyUP3zgmUxlFvk1TfcZykXJ8dMhiOqNpNCYEpFRYZyl7JScnJxweHrJarUiS\nhDRNSRLFgwdvcHV1xXqdUtcNvd4Ao1t+/NGPGI16vDo9ZXtrQtO2TGczWu3Z298lV5GJHI6HGK05\nP7tge2ePk4sZB3fvsDMecX5+zmjc5+TVI7a2jiD0efHsFTZtOD58i8urhlfnJ7TakogCqW6fpeI7\nnr0xm1n3BVTgQdfgjEO3Dm0twzGoWpLuK5bnlnI4YrHQLOsofLfW4oRiXObYUBMQpJnCB4P14POU\nv/MPfoXRoMfTk/+V7cNjsjQjUZI8H1DkOW99+D7buwmjq5+SFimLumKVbDQQMdrgYu7Ym8iYI5Qo\nmtp3vEL0TgTXuWBs3GhFAsltWWUfEAScFKyFoCVgApQExggqDxfAXhq1LIWKDHftQHtwITDXgbWE\nXAnySEhReKJ9G5AiCn5dEMy05+Va0FhBlkGeQVWvODlfsVrAP3j7ASoIiiKlbedYN4zaMEoSSpp2\nTggW51NMEGyPJ9SNYWd0F2ta5hbKqDSgNo69d97izTu7nK8a9sZfLjz9ssv5QLCRLUJJ5E2sLEix\nId1kpxsPsUwiI+BRqcDZBXX1Rzz6/I/4H5/+XX6lPyTzBclgTK/Xw0M8PIaU1niKImU0KPDGkkgw\nrSEfDDDGdnloLjplfRJFKNaQZj1EIiF4nHe4VUvzyROGIolltt4QVEqorjn5/DF3797FfvIps7XE\nhRaVp9imohzeTjdINxKbw70KEPBIKbEenIewVARdwfPV68fOwdNmh51JQQhgnEU7Dd5TVUCz7IJ7\nYwnXdCGuXntUkVDKFuzqhgkdlVNMk8fnwUsIZdTC+S++yPh52NBxImqQEhGihMN7qus5yf13ufj9\nf8kv7xSUWU5bRbdf2f/yufOVgMdVNbme4e5sk0jJWGzh/Ijq4pLDvUMuL1dM8hY1SMlbw3D/kDKp\nuLo4Y+V2GY80Zr1g1C4xyRhlpoitO+QukPYLZsO36FUz5npOnibkrmKVlbTWUsqU0fJz2vaSbHxI\nur2FGO3gFxV2vMNWWXJ18RwjMnQ/wdZr5Ptvsvjoe+RFL05kIopVCAIWGeKU76qAUQV+M7C329Zd\n8OQIXKJIgkMKQZAC5T0i+FhaaxoutGZ7e4vp1TVpr0DagLCWoldEIOIdMkBBQiUMynuc1TiVYAQU\nQlJ5Q9I9sEIQ8x1QBBFhXColIThWTkda12hQyU25LoSADKCkRCBIpMAaS8CiUJAqggdXLbH+ACfj\nQpkEiZbh5woedNZinMZ6ifUSFwIuSFTRIxEpk0Ly+HyKfHCf7V6FWLb4Xo/aet45GPPTkzVlafnk\n3COTwLLMCTUc7ARqArVTXJxafvLCUbWeug0MS4VDYLTHJZJ1JdgbSIyXCFWyuLLIPGeUJaQubkfO\nOoxpMabC2Rara6xp0U2NMxprDdZYnFmSyQmuvSTbznnx6oy7ByOm84bFLfVfLmhkgMvLKf3BDqC4\nvj5htSqYjLfYGm8xvVogUQx6JVfn13gpSdKcshgQjKYse8znc6SUl
GXJfD7n4GD/xq21Wq346KMf\n8c7b77FcrijznF7ZQ0nFaDTger5iuVxgnGQwnrBYLJBSMBqPSaVkd2eHsj+gXLZIIblaTEFYtnfG\nrOshw16GFAn90ZDr9QWLZYP38OaDuzx//DN8EBTF7RdlqWKZ1ru/SoIIktDZagEfHASHUqAkpIMU\nzh0/e/qKrVFOpaFwIKXicun5hQcKLwoEmjRVXFk4SGDdWn7jN34XZ+HxowXf+90/ZLA1JkMzHA75\n9LPPaVYzRvuHZJOSb779BmfXH7O6Mp0GJ77Cfi5IVTxQqeQLDEJnx4UO9GxYh9svOd3WLTrtRUBI\nWAHWw7YM9AKYIHjVwF4OiqjXGSCoBWjfBfIFWNlAIyCVgr6Phy7RhZUKoLKB0wpM8AQFqza+BiWh\n0oI1AaGg198hSzXGgleaQuQkacl8fY4VgiIZsTY1xlYEPyBVLdYOmC+vWRIrfdZaevsH3L9/n8mw\nz3TuWITbC96de+22FT7EsEfJTXgrIgKf1EfjEUHg4CbnJQhwNr6vuftt/pUHVx3wHwtP3TakSU6a\np9SNZTjIAAmuJUsVqYrrf9vaKO4l4OdTamNQSYo0AZGkeDyZkmChzApUdY3b3UFt3cHPniNVSgge\nkffYv3PA1XSOc/Bi5vj2lsc0GpUXrKufw8a2YU39jS44Otp8wLWe0EKYA04g0w4Ilh4nJHnicc5j\ntYluN6dQ0rJ68jGq61og8QQUxjjSNI6v31BuKpZ/gwOl2jjmL06B08jqdILqjTsrdG7RzXwMHrwR\ngGTVGjJTMdk/5tXjK+QHB6hU0SwbirJgffXlyfdfuVIHIZjkS9JqzrBZI2YXtEmfUBaYtCDvO+ph\nStbf5s5uhq+X6AX0lSJtLsiaJaZeI1XCTtpgsj4qz8jQNHIASE4MbPX7IAV52sPowK7UeFtR10vS\ntsKHnEruIHtHhMufMtAWMV8w7o8ZiTW+LFCFx3/zA66nS4pMAQERQidY9vguhDACoI1zYoMibl3Q\nIrgoeo5pzpKl952lF4Ru8QbKBHqJoqljqSjxHkcgSyW60RjraYwF49DORooxkQyShJZ4gqxDGyeC\n8zEGnIASMkK44PB4GmcJPran8C6eWggBGWJPIikFUkBwFms0tdUEEeiVA1SWdjVci18usd7gnScx\nBmNrgulCRm551a2l1pq61bRtg9Ytuq0pxwf0tyTXreNga8TbOzLGCGQZj5+dsyawbi2zeUNpDdPZ\nmucnC3RjuVi2nJwbPnvZ8i9+d8m/+3jNqmlBeD54d0i/p/AWVCpx1tHzgY+errlaaD56vKJpNcNB\nQnVZ41YOpTKc11hbY/Ua3S6xek1brzG6xdjouNNGU+uG0bhgNMqY9DOUCiAMeb9g75b6yhfnZxiv\nmE3XEAJGtySJZ70+xxjF9niPLO2RJyOadUCQ8fjJC9Zrw7qqKMuSsiyZTqdsbW3hu5DBk5MTvPc8\nfvyEpmn5+te/gbGWPM9J0wzvPe+99y77+3tMJhOc91xdX/HRRz+iyHOGwyHVehUddNbig+fg4AAf\nPMtqDsKxWk+ZbA9ZracxNM5YVmuL1prRcMhHH/0pu7sj0jSh17u98DQVnXAxhL/yTHafN/GUHUIE\ntst5TAPfzNDtcUHZyxgPs5jo6qPbrpdv5nkEJmnockiEoj8s6Q8LigLuP7zH2++9y4MHb/Lwza/x\n3rtv0eiKpRXoVeDPf/KC6+tOwxLipiAiodCdCzpK+QtHySBFZFKlYKNkDi7mPt3q6s4e7sZdE8FL\nJaDqmOKeCAyTwFUjaDpGSEooFRRSUG4AYiecrm1gZgJTE1homOvo7jprBEEqiizDWCjyPt/+4EMO\ntoZsjXaAIe+9MYmljGSEZ02RDZm37ub+9ESKcwbdLhHK0Zg52uUs1+fsjCb83b99n2vvUbngrXfe\n48HuBIQg70/YK/PbTp14P72IQ
as+RAdm1GrjQszfsd3YbZQ9oRtHF6LLyzlF00jWRtBWUOydoY3H\n+0Cvl6MU5FmCEhKCi+G2wSOVREhFkiqkSkjzgubJ5wihEEHirYkOIx8IrYZOmiDyHulej2A1Mklj\nmK3VWONx9KMVvS+Y5ApnLNfXs2jGcbd0aX0BXAc6PU8HnEOIImEpAnIQkPtlbGPUxDLgLChyJQne\n4xBcXFyzXDfIvI+cv0SIJM7Jjo6xrhv31kTgI6Ij2nvwbSxbdXREVFZ1z6XvdFWv72F8na5DZ8Fb\nnPfs5Ak7x3dpFnMqAUJ62spyfTGHTBK+omfLVzI82+s5vpeSDnfplSn1fEaoK8a5ZtbWZIMx272U\nfn1FI+NJuswzqvUpAxNwIcVnI7RoWbYVaZqROU2QBWNzSuZ6WJaYfItJmTGTfZQ9IzRLfL6Hv/N1\nivEIu6gwwbEKkuLf+/ssm5pRO6VgwEKnzNsWuc7ZPXoTefQGmTBd4GBcX4KINVwrYn1adguTDZ1W\nQHRhVbe4QvBRRIlDADke6Sw6KKRQZMFQJSnWeFTwjNKU6+Ua3TTYRmOSCFAyAa33N/Sd1iZOBOdR\nwZGrmLLsfWwR4XxACo8LggKBEZC4gMwVofXI4NDOoxTYTrmUipQklTTG4EMgEwopJU5rXAggFQHJ\nlmmx2iCCwBA1ATmaW/Q3vLn2BoLpsmKuA41TMc5cOpJygrCBDx4eU60q+mXC3Hj2Mvjue4e4IOkP\nclpt2e4pJqOCi0VD3VoGw4SgYG09w96SXMJwOMT4QOYSBpnDFy1FkpKVkidXNe/fy5ivLFILRJ7z\n4qxmMko43hL8xR//bazHFwAAIABJREFU78zPTvnm3/svMLrGmLoDZw1Wa4zROGOw1vLOW/dYvjzl\n3TfuoGuDlJrpyjCUt5882mWY9ZrDw2MW8znDfkZoLMpGC2qW9vBO4hzMF0v29rb57sF3cVbSNoZ+\nNmQ2m3N4eMh6vSbPc6y1DAbDmKw8GsVTUoDj42Nmsxn3DncYDPsYrVmvVhRFDgSyLEPWDUIIBv0B\nx8fHKBFYzKacn58j8wFSeIaDIUWhqM2aPE/oDfq4EEhVwv7WFnu7uzSN42h/RLVumIyO0XV763nT\ntDGYT4buMHKDtUO3qNI11Q0Y45hVIFOL6+bo9cWKz7KXNJUmEWDxiASqpsFah8oEMlHY1mNzaKqK\nhw/fZDLu8+MffMb9N+6ztTOh3++jpIxlZusYDHqgYLmq0KaTwH4h7CV0S3jsUMdrqUg8e930BBI+\n/lvwX/j5v+HlO6Fp3BzitNvc5ysRtTrbKjJhozxQ28iWDVTEWmUSreuJjPZ1E+J6omMlgLYDQcZF\nR5wQlmkF3/2FX6BICzJamizhfp5yfJxTMmc43MOLFYkqma2fU+TbOK8p8h5KBJKkoHEBY1O2xgfM\nlpdsj+8jZeCXf/UDfvsPnjHav8vRnXsc7Y6YriqGvV5sXnnLy28scC4KwpN4wr2hCf2G894Me3dP\nVJQ/dmMZf4cPUdw+vfD837//n/HBO/+cb3/wDcp+jzxTGOsRQtA6TS+T9DJJbRyNCxQEVCJoXz5h\nKDK8N92G3TFsRoMqsM6Qb02QZSDYGpRCJJJQB4RKGB8fk0nBu4egHRFwJRkqFZh6fbvBCdzolXyA\npKO0RPdv9TKQLWNIo0gUYdmgVJxbRqQoXPzYB9ZVS23i5ipWJwRShJIdIM9xrLBWkPgY/EsHgNGB\nkIiYf9dJKW4eIfeFrgdfuD9RxxMz5YTTyM5lvDae33n0gncPBzf3zdmYZWe+4oD+lYCnaZ8zOv4l\nDoc9ripNPT5EXl3iR8eEZUXpA4PRDn7Qwy9eYfoTliHQ2zmifvkZl65kq5TMQ0kvLfHNFK0tl16y\nn24zHPe5vhT0tEMl4O0FvTKnrj2lX5KaBCUd5wQGaQ9bVTTVNT3tCNs72JCyGjjyVlMkNZXNSN5
6\nH3nyAwQKEUS0oSMQXpJ0iD50fUykCPggEUFsJOx/88v5ri+XxIZ4AnBBUihoQ6B1gTQY0u7hWUtF\nv0gpRXRUJHmCCpGOTl1sHJorifIBLyTGa+7d2eXR8zOyLn1TESiUoDWeVAlM8GgTA9aSEAt3KSCV\nxIdOx9QlHdfaRIpRCGwwYBOEgCyJJ9/gHOuqIqsbGilwIdLNrVSkP4eI59GjH9FUCwbDCY3pY9U2\nQuYkQrNYW8rjwP0HIzyK/UHJujUom/JqtqRMBFu5ZLLVZ+4bPnx7wr/9k1OqtuXD93MSAou54o3j\nAeXdPWRZcf50QaMsJo3Wy9XCRLF0I/ng4ZCs5/jDH62YrxKSN1Ns22f14vs8vf4+78x/HS8Vpq0x\nbY01Da1uI9hpW5z1TFfXLKqGiWmx2nN5teR4Z8RZfcnDw71bjY3M+qyvZhg9RQFF0kNXNQ8f3mVu\nLWU5QgjJarVi/3CH9XqJcgWvTq8piwFJP2O1WvHw4UO01jx+/Jg0Tfnwww+5vr7GWstqtUJrTZpm\nPLh3j8ViCcKyXq9pmobecNI5CGPpoGnamIkVAtYZVssVdd2Ahb29bRbLC1bLhryXMJr0MUays7PP\n6emUUsD16QVae3qjXaSH5Tr+vtte1sT+Ov7G6f+F5zIRYAMegQ8B0/VGyvOcg8M9Ts9eIIJgOBrg\ntUYHg/SCXEqKtKAya6wP8fnoMlush7OXr/j4oylNA7/7W7/DeHePr3/4kP29PT7//AVXFyse/x+/\nz+dnmxcSefbNqTlRsUyMlAjAWA+24+I3LE7noNq0mmDDMN/i2mwGoSuXed+tbV3vqAUxqXovmmwY\npnGDmprI7PRU1OxI4t+hK4EtiUDIdcxHogQ2wKKF73z4Ndp1xaNHf8EbWxn/1d87YtLvo/oFv/W8\nQsgabxuEKnCuxnmBdoZhknM6uyZVQ7a23mC2eIn0liwb0tanCJHxn/zaN/lv/7vf4vD4Dr1M8eNP\nn3G8U3I2u8Do27VrgRguKLqEZYmITTK7ZrSbWoYg2p29j99PBwDw3UzbyEZEzDFzTmLsT/iTj/8x\n947/iLxq2d2dkKcpxhrSJEWHOI96uULKBONjWxFnHYIcb+J7US7KFwgBKWQEaBLE/WNU0SOsmwg4\nVIISUO7cRaUZwQVSqej1B9jWUDWBejG71dhsJByv8cTruedCFAzbijhPwxIlwBhB0gbqAMNuD1pV\nmt5bE4y+wnROZ+8FwcW2JsaKyOo4CDesZ1diJO6zQcQIlyDETWRD8LHFxM1hoYsOEB68iAd35y2J\nh8Z5Vuua65ef8639Mc5b2tYhhSeEwPr6y0taX+3S2nqP5mxN1h/TsqRYOabLBaK/x9o6jkoBraNa\nLuh//TvY1YrxckpzvaZMTQxpu/RkZYmxNZoE0Xju+mdI28fqFpFper6l9ttMlqfMekOGSWBJghcV\nue5TZgl6ren1C4Ztg88GBG0w6wsmMmo41NYBfbtkuT1CvYi31IvwWpMfy+mobtDtTVG383LdUoeB\nCHhrYjR3gMQLrHSIoHABslTRekXdtvQSQVFXhDyDJCXBk0iFs5ZMglYCbWI/LSGjRkc4z08evaQs\n0o5+9ogQ6+5BdLEem/ekwLk4czwy1u4TsNZFrZyU0c1EQAZH8CICzBBorEEh8N7RrCpSY8B3zUdl\ntHTfmv4CHv3Jv0ZJuLsz4vR6TbL9HWT/Dpcn3+fBwTbzlWWvL/nR00vOrtbc3xsxyAyvtMYLyaVO\nOVk0ZCl89njG3qhAKcn0zPB87tguBOsAW7uBxWXGwds7+HbNbNbi1rCYOUYEViLhyWnAmoaMAuUa\nvvPOMWYxZSIu8fNAszrHiQLnWrRu0E2L0S1WdwyP0eA016s15SynyDKClGRFn1oVnFe3y+G5f+9r\nPDx6wN2DHT776ccsZufs7Ix5/OinFHsPCblk0B+ilIXgKfK
cV+dTIOCd5fLykrfffpv5fE6SJOzt\n7aGUYjqdorXGe8/Ozi7eO05PT7l355j79++zXs+RsuHOnbtczhYopUhC1DEdHh6ytTUhBI+xlslk\nDCrBkjIajVkuL8mLjLxMMdZzeTkjzXLKPGOSDfn82ROslVxPz9jeHrKYVlxd3d5arDIR++w48de1\nY90a3TQt61WDsQaBoNUt1/Oa/hCKMuHtt+7yuK15edZgutgG7UEbj5SBtvVkWULdxpP3eGtMUAJ4\nzvvfeIdiOOYXvvVNtrf3yZXn8z//HipkQB11Ox0bsNkhkzxuqmkSn31tOyZGBEgEwhHTo0PoDlfx\n1HrbMjqvyYrYCLPT3dxE8csYqmi9YFcEBiIm7OZpzOa50hH0FB3jEwgxgTeJombTsWTaQ9MEvvH2\nm4DkR48+AeC//NX77I9KZGIYb6W88azh7OwlB7vbGKHwvmBQ3keKlFpb9idvkhYlrVkjRU7T1iya\ncwbFiOASdnfG/MP/6F2sHTMqMmrpyIoRtYKJvPXoEEKXpC86b27oSiRdIGcXhdSxLdyAXlwHBG7y\nIrtdoyNvrYU8hX/z+/8Nf/sX/xk7O+OooRISmQDe44PEWocPDumj/KBtPalwBCUQxsWmzkmCsCZq\nJFOJKEvEG/eB6FRDBIJMECHyUTt7R4wGnzLs9Th5ds4v/tIvYZsVvdFtG/PGktIGbEftavc+O1bP\nmmgLj76eWPRrKtAyapSa2tDagHexBrjp1sEm70gpWgO95DWIVCIyPJtQHWE9QYnY6y6J94JNTlVX\nfgw3qKe7V93T4r1FSIH2kmZ2xfUP/4zsb73FslpQDgXj/W2sdnj/5czyVwKeUE4Y7k1weUkpSox+\nRZoXzC+ec3D3AfXsKVMUxc4Rl9MF6eoSVfZI21c0o/fZbRc8tTMKuYIkwS/P6DlN/tb7XBhBcfmc\nkZXoYkQVHGuZs93bR+WW7ZCyqBfMxC6ZfoVU0T4c8n28NMyXDYM8w6kxSlyw1pp20VA0K7xUSNEZ\n2kIAIZFhgxIlUgQIEkHU3VgCStzW5kfnAJMk0uGcJ0VhhCBYg5eKHI+XgdZ6ijyjtZYkSUgQaG3R\n3tLL0ljbJWBswKQB73yMw09ltMv7eFIJCHLi4dGEqCEaoFg6fZPcqkVMmdYmkISAlLJLMHWIEHN6\nkKprYCijWl9JVBBIbZC2RlpBKpPYc0xIWnl7iif4wGzVUirJ5bMTJlenvJqukDLwYuX4+oMJ52c1\nD7f7bE9SDvKCy+s1eV9xeVHx4bsjXJ5ileJlvebb7+/y2aNr3jkc8eCNPkVjeOks26M+qmoYbEmq\nVUKWNFyKaz64O4TKILUjAK8uClRWczAZMKs8zUyw0j+jv0h59K/+Odvf+SfY5gqb7iFkjnOCtm2w\nRmONZrFecDDssT/oYYTku197wGfPX0EiyEa3o99DCPzxH/8hH2WKrWGfyXDAYn7NYFDwyaMnfP39\nbyCkJM97nQNGY82UPMvJs4LJzjbOOebzOU3T8O1vf5vPPvsM5xzvv/8+T548QQg6IJRwenpKu15w\ndLTP0dERWmtm0yl5XoD03Lt3l+vra87OTjk4OKCp16wWU1aVYbx7yHw+o8hzrIs7gtWGoiyYzRbs\n7dzj1atL5vOKujbs7N8FlaKSFP1zdLxWEkaFoN2IFW++Im4qXLHzddyWvI+J/EWScHiwx4cf3sPp\nhpdnS0wAKwPSe/JUIaXCuQbtFKvG0U+gaWEw6JEVUW8UgMnWsHO3abwNKCk5mAh0UvL8UYxXFkKg\nVMytCZ2GyBpNIlKsj44sIQVBv65uiU2Nvdt1btu/z7MBWpsfDJ2GSNy4MgUCTeBMCFoX2JMCJaFP\noFSdRV0LchXzeZSAnhJkMqB9/LoXgopAnuf8yY8/BiQHpWJvpBC5JMmGUU8YDGlimdcLtvbuULuE\ndX1Gmgwp8y20WSETSSI
zdrYGWO1JBCRFybyqqN2Sf/j33+M3/01CmQe+++GbfPb5K0iGZKPbu7R8\niCF1LgAuhsAmaXjdksnHeRS/3LH9fqPz2ST1xLU2rraiK4FJWh1ozG/zB3+aI+U/Y3tryGAQy1tS\ninjo9HFejXKBNw2+MbgkwXfGmVhLNBGoEkXnvH2IPNiL9kOVEIKLbIhzyGxEWhTsDPucXK1JESxW\nC44OdxHhduxpxy9G3BE2IuH4FR86wXfXSFncsJddynkISBHdwD4IqtWS3r5CiuiENETwKIHaSSbh\n9f6olL8pHYZOTxYssdRLBJk3zGV4/dAHH/PyQkfJxS85TPAkBLSzzC/AaEeOpF5rtg8ThAg3tvb/\nv+srAY/Pe+hByaJfYmbPwUicbhhmJVwvMIMxKu+T6ilH9ZLry2vYuofYOkKkgmubcbAtMStFtTT0\nkoLhYExVSwbzp/T2jzHrcyRRDDX2M3rXF5h7f4vpxXMmRZ/1/AUhz0iCJ5EB7ypma420FSuRkrg5\nSX+XZHlKkymy5Qr8ZrK6eOe8v5nISoAVArWp1QpP4sOts4R98KigaJ2JOQ4hZvwoF9tXaGvZEQqZ\npMyMpW1bhlnOMhUsuhqqUoqcCHATIQgitnXwIsQgJikIQSCFxwaPJD7MCiBElsqGGDQYOkidAhZJ\nKgKVdyTWgYwCbktE8HGee5SI2hofHMoZKmsp1wakohEOJSRWKZS6fSaG94J+pliuFuwfjPnmO/sM\nJts8fnnBb/7bP+Xz85pvvrXLvKqZLRqKHYVNISNB5Rl1Y3h5taLIUlIJP/t8yr39PmtjmD6b0u/n\nyJ2UJ58+x6uMoh1TrWts1fLW4YTKN1TaYrWi1wuMj3L2R2OG2oDRPHpVM+wrBgcS73JO/uw3CE6j\n8iHrumXn4bdZmX50r9mWrWEf5zxaBEzreDI/RwvFcNAnSW8XAuZtTaYUEmLQX78kS3us5heYNlBV\nNffvv0FVrZhdXyElvPvO1zg5OSXLcuq65uXLl2xtxbLUz372M+7du0fbtjx//pzxeIwQ8OTJE7wP\n3Dk+5s7hLk275gc/eMRka4KQEu8dWhvG4y1yFdjd2QIBRVGQyG0QK7xzNHVNr19gq4oXL16wczBi\nPJ5weTEnz0sW66eUgz4ycxhv2dk74POnp8wXt8snAqibwGIl8HZj6v5iQb9jIEyDbrt2JxbI4OBw\nm7PLOZ988oxhEZ8bgKbxrK1HEMizFOdapIRUKVrnURL+/Ps/IcnjUlgva5787Bn7kx36vYLHjx7H\nhoSDnHcORyxX10irub5s8d2eU6SQpcSQUxnLKqEDQqJjbUOIIIiNNunnuIJ/rXmIJa3XFZvIPIWb\nwEPrAlMhaB3shkCpujVHQC4DrRfUNq6JAwVp19yUIGg6Zk3ebByef/TdNzEu6Q5WLYI+3/yFXYq8\nJB8e0jQtk3KbJC0h6WFMi1SSJB1Rtw3GrCnSPo1cYkxLIjPqVvGtb93lf/uXP0XLBzx5foIWPYaD\nkqQ3uPX4eN8543y4ie3wTtyUAENn+tiIzb2LDJAnxJZDN1Ot2xvobOsibsAyBOaL3+T//N1P+fe/\n+z/wjfe/jreBNJMMctkdVEVcV6slSkqMjS16pAgEZ5GtQWYpMnTRKHvbiDwn1FWcK0HQ2fnI+kOC\nzBiOt9nt9GveaB598pSvffjGzzmJeB1ZcoObBdZwoyn7wjRC3di5wGsPIrBqzE33edH9UikkLgiM\nVx17BsF31QHrI3vWAUxBHE827kW/AVzx4xsG1cdviSaczeuQJCEgRMI1MFtqjkpIEglK8uj7P+Jr\nf+eXvvTtf7VLqwL/4nOaJyfYa01YvaKXZbjRiHYwJMxr8tWKEsmqv8twqGgnAyotqJ+ecmg1sveA\ndPE544Fja+uIfDigNZoLsUNYVWi1Qyszes6QDnZY9e+xOn2EkDlLmdKb7ECzhiyWXCrrK
PWMO9u7\n7G7vE3p76HSAUSN2dye0rmNDhO9uRqw9+iCQHoyIziVPIHamChASNgFHf+N5EwLGGRIvyIVAJQrV\nCaekCKRKssRFC6mLjoHWWyb9Iv4sEYnWNhaSlfcoAso58hDPIcIHBB7T0YZJiHofQ2ReVCculAGE\n2li/QXpHY7tOmEoSiH1qEhFTk1V3mtHORBstEuc9rtYIpxHdz3prUdYS9O17aS0WU4JwzJYrrs9m\n/NkPn/LxRz8j9Uu+/uYR33jriDwXPLs0eCd4Y3vIpFdyd6fkeJyRZoGLpaWvACmQSH74dIrTnt1S\nMEqgJ/skeY/jnQLjaoKqKMeBpxc1L55ahMooR5DniiLXhFBzJWvOg+LgjSGWlKoCL3NIFdoLrLME\nAs8++j2WL79Ps/gcW1+RqRSVZtSrlvVyTZqUeATKW67Wt9MbPHv0MaNeyf7ePv3BEJWkIFICJWUx\n4OLikjwvefbsOd5HKvnk5Iy7dx9gjSPNUqSSFB0rIaXk5cuXlGXJ7u4uWmt6vR5JknB8fESaJKgk\nIYTAnTvHLJcrnjx5Qln2uHv3Hp8/fUYIgaura+7cucvu7i4+wO7eHt4HrqdXtG1L07S8+867VKsa\n76BposDe+IplPYPEUzUVP/74x1hneeutN289b1wd+z017i/zO/BaZudNwAV3s3hJCdVqyfnZCbqd\ncbG8pm0sxgNE0YpDxEyezrljbKBMYuPPd792j6998DYAw8mEO/fvcPfeXQ6P7/P2u29RrywhSZCJ\nYDZfMppEu/3Glp4k8ZUmSpInCdZ0WTsbNAJx5daho4K6P7fVDQa6DfH15rT5eNOt/aZzu4+5NMsQ\neBng0gqMj6UsJaFUgX4ChQzMbWDtopBZSEiVRAIfPXp0879+ezcjeEvwPrILIlD0DpAqQ0lJ1b7C\nhiXeLxFuQZEP8N7jzBpnZgz6e6RlSZJsQQhoM6UoJty/MwKvWa9XpPmYfpmivOdqdnuwTKftsp0m\nxLm49gYXx8KHeBCLVvSurNXpljYMBHyhrCK4ccYGIXAheqyFe8Tv/cGv8+zZc/Iio9GOurYMhzmp\nUljvaK/OSZSKgNQ6WDegLd5ZZJIipSA9mMRIEOOiw1d0zEta4IVECMH24RHbo5IyVzRaMt7e5s13\nHrJarW49Nq+3uE1z1ddlVWs3LrWobXLd+BgbS4LWe+ym86izgEQbhyPpuq1HQGm9imxa9ztEJyCO\n4YbEykKkKm/m7WbObkqzIUTw4zf3oGNznbUIAjp4mtYyArJg4r1xgfVc8+Z3vslq9uWC7q8EPJl+\nhU4EbvWCNmlpJmOMUPSLLYIINFqTppZK5WQiYT18E6VbQlsjegV+ssvs9Dm93XdZhRQpA/P5mtyc\ncTdfkuyOsUWBG2+TNa/AJvT9iqAKiiTFt4aqdfQGfWamT1ZKEmFI+rtMRR/tNam5ppg/QWYpl1aR\n9wrCTd5mJMGF39SsJTLE2uWNhCdIEP72gMc7UiFAOFrvaLRGG4sIniRI0hDQITIwibQQAsJ5GmtZ\ne0cKKDyFFFjvMCEQRDxBOSlIhMBIMCEunpbY+DDWowXGRYFpQWdk8Y5URCu6J24CqYxiMhHi6VPF\nEjFKyk47KQjeda/ZI5ynXS7BW5SLGiVlLdLfHvBcmhVpIrmqK/qjgixPOF2uWM0FgwyadsUPny44\n3Ck4OpiwNoaPn81IgV98cw/XeL715oBKBqarlp1Rj1/7zj36vZS7R2Men9esZyt2J0OmVbRFl3mf\ntglIJXjjwZisdMgkofUK1c8ZZ57rc49wKUk54PBgi8EwYV1dc3D3Q6wQrNYxcHBQCKSeU198zsWz\nj5g1KwqV0ViJUZL+IKUs+lBresVXEqV/7bImakGurq7o9wbUjcYYh/cZw/4QIQQ//OEPOTw8Yr2u\naFoLQXF5ec14soNznqPDI2azGcfHx+zsbOO9Z7FYM
J1OWSwWrNdrdnd3mU6nbJoJHR0fRlHfekWW\n5VxfX/P8+XP6/T7jrS3eePgw6oCMwTjL9fWUdbXm4PAgltPyDOc8TWPI0oKD/UPWq4q8t01vOOK9\nDz5ApYKj40OyXNA0t3SSAIRusQ3AX30mN3u8CNAtzpt1fLFY0StGvPfOAf/hL78TSzMBtPPUOgr5\nlZSxi7oQGBcFviRQ9odsbcfwyLKfc3Cwi1LRDQICbeJi3C8yEHD26hqIDieAxsYDBEpivCNTrzeX\nzYYiNkFaX3xbt0w7jSfnL8TtbzbmjVYldsp87XzpdEbawZWHV16wdvH7BVEDnkliGi7QeFjaQGsg\nTwSViRTWYb+kUAJJA97hQhLLHNJTDg5jxEU+Jk16nbZFQdAcHX4XVI/tyZsYfcZqcUmrK7K0T54f\n0Mv7aHHIew8NptH0x0NsSEFJesXtbenWhhshsvUB5zsxtt+wOVHjIzZjttGzhE3pamOC22hGvjj4\nAReihd06iZTwf/3OP+XsbIrWhsaD1hbvHUmSsf6LP8HJTnDp4s4dulBNCch+huz3EPu70C6I0d0S\nZBobNmcZIQR0W2FFj1nleDWdMZ0uEViEul1OkejcfBsrfvw7HiqCDzgLQQe8DTcg0bvIvAgRDwu1\nsQREFzsQ4yFwG3AZEEhMNwd9iG9JSl5bF70gOHEDNm/m7V9idjZMZrixqEN8HQSHCIFcSer1kh6R\nh2tah8pSVKIQSvDq0cmXjsNXAh7bHzCoZsjJDrLM2PNrXJLTXM0xTUXrFZUfkVQXSL9g4M7R/TFS\nKQrzAy5PnnFYtqSzM+5kmsG4j0wzDvfukfZ2kNkA5mvSqoZigiwz5skIFzzBGco+BD3l0hTsDgKN\nGOFlhuj1WYcU7TJEOYJ8zEh5RldXmLNLhIoqqC/E7OCUwKkNsOlq6BtqE9GZy//mlwoBEzwy+JtA\nwFQKrAgY72i9J8OTSk8uEgrho6Auwi6cC2gfotCto6ad0zgCZQjUwZF5GV1ezsXFy3mMd6Q3AWIO\nI6KYy/n4INsQOwGnIS6AInhcAOM9XgRciIK6ICRJl02kiAhetga9XFI4B9ogvac1FmFu77bZ2uoz\n2BozHI0oxwV7O2O++/4DnlxcYVygrT39NIlx7SLw6fka4wLf+3zKTy+WrFXGJ4/nnJ03fOuNfU6X\nK2QSeO/uLmmesTfM2BaOV1cLVBop1YuziuGwZHs7xYYW4VOUEJQjhQ2KpetxfL/HpKxIK4NAUqox\nGZIffv//Qa8arNN4o2l0XDCraonwNUlW8PnLUzIMpZJcXS0YpYHhKIvHoFuNTQQow0HsdC6FQIgE\nJWM68t07d+j3e6RpRn84ZP/wEGMDz56d4JxnvV5jjCHLMtI05fT0FePxmCzLODk5YW9vj6urK169\neoVSCU+ePOHq6prZdM719ZTt7R2EFCRJitaaEALLxYJXr14xn885efESgaAsS7I8o6oqjI2hlufn\n57z77gfM5yuW8zWrVcP5VcXu/h2sN1hfMd7qs1hekme3L4WCoOhBkf/153HTJsnZqEVTgmjxJgIa\nKSXrynJ+folz8SQvBdQaqnUbGZAQN8E0lViAABfnF3z608hm/PDPfswf/b/f4+Mf/5hPf/oT/vzP\nfkRtIVjPso6Na+vVDb8ORNNA/FXx35zgJroqtpbo9Bqd1iZsdtbbDs+m7Xp4vUH4zSax0T188XNE\n1zE8RLeWDbwKgrmPTjchIuhJRMwMGyrBKBFk6nXgG8C9UUqOIRUpQUR2QvsUXbfd4VLgXEWiBhAU\nQeQ0RrNaPUEJyenFY0CQl9uMBnt4kbIzPiBLeyzWFb/yq2/x7OVnXJ0+Y5RahtKCuX1bEtEx+f5m\nHOLnzkWg6B03QDmE2K9Nbu5Jp9mJPx+63/F6PP0GJBHXVeclzvw5//O/+Mc8fX5KpaNLdzwo8PNL\n9OkpiBTXGrwx0
bHlPd7FTVumCrHVjwdOY0FG/YlvKpx1YA0hJAwnE0bDHkc7Q4a9HuNBgXUB095u\nfMINmiDqlIMHfqZZAAAgAElEQVS4eZ4iewLOCrwVOBM/9y72sgvEJa7RFilidIq1vmPDo7Vqk33U\n+AgUA5uDS/hL8xU27rjXZarwBcC5GXPvYmnSuw6QOfDWYkPgsjEs5xW7dLqsEFiuWozRWOOYXn85\n+/WVR9PEe4zPEa1kK9XM7IT+AFaLFf9fe2cOY9mRpecvIu767tsy82Vm7WSR7G72Oj0aoTE9gkae\nLNkCZAiQMDJly5YlT7YMWXLkyhEgDATMoCFDmhn1MtMryWazKllVub3Mt7+7xCYj7n1ZZJOcTnpq\n5AESWZXLy3fjxo3445z//0+W9ujdn6BVRloMSPMcnVfE2RC1+Bir/pBjfcH+cMIsH6J9wqpeMTjc\nY2YLfLQim1+xQTLwFT5OcCIi0xafKMT2OZV+kyQXCHNJWR+hBhnm5Bl1b5+0OSPKPVodIOKMjavQ\n5QIhQmpQtMSsgOI9kRXomyJXuNGdQ6S/PYfHOE9sg5cPOCIZTnxKBBKX8FBah3A2HLBbByzhHMoZ\nrI8DilXB/8JYSyYlwjtCwlAS4dDeE/lQioukak8sDiOCwMwS5JOO0HYC4ZEu9NuRsltXg6LDuKBq\nkSiMCP23HJIah7Qaaw1isaEebHEqIbKB7xTsMm8XRwdH/OY3J0SpYDpbMpttGPVS8lzx/HLNHz4Y\nUxjN/b2U80XJ6WzLvKwZZBF1aYkiwfFeAd7ys+dTjvZ7/PmPT/mz779BXRuuNg2DImF7uWIw7DMs\nQNiEsoF+njM+yrm+XlMZSyE8qIr1NsaZCF/U9I4kH/38kiLqE6uYPNU03pNJIFWkiefjizWpinES\ndOlQUYxwgahXa4tznvmiRg7HtxscL+n3C4SXaNOEViHekecFqdtydnbG/v6Ys7Mz6rpmsVjibPAc\nquuGKI53Ke3Oh0dKSVVVHB0dcXV1xdOnT/nVr37FYDBkOBxSbreMhj0eP37Mh7/5Ddvtltl8jooz\nmrrm+PgQYwxXV1N6RUGepVgL66rBOUuaFvTygrSX8+GHH4GPePLmO5y9uiAfHqKyAYgSFW05+fhv\nOTocslndTr3WxWoeTm5d7Mik7TSUsjVw84AFqyBKIlRVEcWSvBeRZYpmHUpYAEophLVEKpBInbE4\nGQDR4dERaZYA/5c//v4fkfR7fPVrX6NXDBDW8MO/+F9467mcLVuZ+SfZRbJda6z1of+iaxMHHRH2\ntY0F1/HzOmLo7x6esEG717gWu9d9/Yd8i42g/SMdQJOstaeSMJKhLU+xcyK+aayZScEoDgeotYEn\nkz6xCHwnSQkuxdsty3XNPVNjRUUapxi7wgqJb5bEcRZUjmKGiipEPMGYBTio6iW4hihOiKKC737r\nmH/7g//KX7aXEEcJeVHwH//Tf77dAH2qZOOdD/xI2WZX2i7qOwJ5u/kHvkoYMNEByTZjJgShy31b\nXwm4R+AaT6RA+J/zg7/+DxD/e/aHbzAyjuriFUhFow1GV0jnAvjUGunasY4jeOcN/HwGaRbeum5C\n2wknsLoiGubko32yNMOYmvGwoDGaB+MRLz5e3W5odvNCtNLw8EWPbDk07f7FzXUigiGhpfW9cqGF\nkRXB+LDRpi0XOpwL1ZPG3GRtnISuR5F3hPYSrS9S4LfRytJ96xXQgsuOQL1TRIb7IVpK0ONezkc4\nFMF2QdcWOYxoSk1vLIm/IOP+hYCnt7mgzB9QvXpGEjUkB2+x1bB/fJ9q/orh4QGrRpF6iynn5PGA\nl1fn7NUrXNPA0VM2xuOufsEiOmYgPGex5qj5CN/fw0rJIE+5rqb0kERx4FMoVyPSx8R7PcxyRS+O\nMXbJ9tlzhpMjit4B9eIZZZSQrF+gYk+9jrhal6RnJ8j9NNw26RFOYIVESIcQEmG
DjZGTYaEKGTPR\n9l353cM5h5Wmtdh2NEaQqojSWJwNJawMSSNEILMhcFKQpBmT0ZB6u8XJkC73KnBrgqOkIzCKAkx2\ngJMSrKHxoU+WEMFUy7YlLIMLaquW1ehdIC93c8kKggpLhKnsRSv0cwInDEooMBbblJSbNVxP6fWH\nmKAtCE3kbhl+vmG51TzOhrz5tXu899ErTl7NSLxgr8gY9hJELUilIEcwHuQMUsV8U3FdNmRxzCBW\nnC5LDoY9/vidY64WS/7bj0/ZG0cMJilnqy3f+eZX6EnN2XJK1EvJnGC2XrNYzYAIZz2lSoijGCEa\nZCRwNubsrKR/HNOsNqyuFXhDESVczrYYC3vDgsN+TpZHlKVmvt1wOBiRRBpHjPKSJI1ZG8k7e7cj\nWB72Mz786CMev/lVlssFcZqyXC5pnGdvb8TZ+SlFkZEkCbPZjPv372GNYD6fs1yuONwfB8fssiSO\nY4bDIZvNBmMMe3t7DAZ9Xrx40RJbLU1dMxgc4b1nvV7xne98m0obLi6vmM8XHB4eEUUxh5MJ+3tj\nTk9PKXoFF9NZ2LSlZLPdgFdEaY80HZGmBVUJ+5NHfPA3f81+P6ZJNoyGKVlqKTdrvoS4D4Ej60uc\nFtStff6OadC+njbBzbfW7BbPzXzJRsOLV3AxC9kda9mZFfbSmMVmG4zhXGiP0E8lSSI5e/kKorAU\nPn/+krRXcDAYEcUxJ8+et414PV//2kOmrzynH84/ATJctyCLAKCi1siu61jeNRAVkiBPt6/tzbcI\n59yOjN3F6+8jgKyb/3Qwp/WzDXxFwt5yiedSQBLBN9KWACBAekiVp4gEeSxg4bjeWkqjcInBGpCq\ngUbz8sWSR9trNME7ylHTmC3Hk29R1hLjIvrFkNVqynp+ikwKirggiwpqVyHFMUl8xf6jJ7z7Rwf8\n8IdXIATGala39JkJF9Z6uxDWOteVSVrXZdEq5KQMjTSVCuoq15Fsvce27Q92GR0PkQw/61SwS9BN\n+Du1FqHEdfnn/Pf/8ef8xQ++xz/7k3/H93/5l+joXjCdrSym0SRxDJstqj8gynKKf/o9GBaB0JwV\neCK88ki9DW86ycFDXVeM7t8n/iDD+hohI6qyYr2+HeDpUo67DuTt9dEeHLR1mErio/BN5UTgG/uw\nj2hr2XiNEpKrdYWpPVaLXSbIWYnbCGrrsHXoV2YjiQqE2TCm5kYJJwjPlNOBd+bMDbBxbWmySwo5\n2u8nAZDW1lLXdRD8GIsR4DSIWKErzfjg81vafOHRPRscYOorxuMerijYOEOeZERGMMj32DQ19fKU\nsi7ZxGNeXF6iBBwPYlSWwmZJNMjY7n+N/t6EOK4Ybl7hlCTeLIL5XjMlywpsNMaVK7xe0rdL8s05\n5WlNGqfkyRjnItK9Q8q6YbGaoZ/8Ab2DN0gO3qR4+A9ZOU/y7ARl2g1dgHCKTmLoRTBJEiJMVrhp\nO4F3re/A7x7WWxQShcO70KOqthbrHco6GgSRCEaAuQBhPU4bBIZG12xaYGKcCwThcAYIKNlbEgJ/\nRwqPsSGVGON3lviyPa147wLfB79zZA5lqu6E191kt1v0jDVBlhmwO03rQSq1xtcNotLUdYU2DVrX\neHP7k/r5csv+0ZDptuRnP/mQPo48jXk8Loil5Ocv59TOURrNlbPkWcLDyYC3jwc02tLgWVUNX3uw\nT5Ipfvnskrp2PBjF5ElC5Wr0KCZLPWeLFddTzXpVYmwJtiFPU7x1DPsxHk1VO4SPcdZRl4ZYSFIk\ncZTQm6TEqeR6WTHa67E3HhApKBuNd5ZGa9I4Zb4quaoiNqahJx2rTYUSnnJ9O4Ll+cUV+0f3mF5d\nIpVivliwLUuu5teslzMePnrI4yeP2Ww2fO2rXyHv9Vitl/R6GdvthnK7YbVacXx8zGq14vr6mtFo\nxIMHD8jznKqq8Z5QEmvdkwN/Nnjs/OhHPwY
heeedt0mShHJbBpn7conWmgcPHlBWVQA7IoB0FUU0\nRmOsZ7GseX5yznyx5b33PsBVS/qppZ9H7I/3cUYSqxz3JTKDRIKmhKaBT8OCHam0baHyejbDWMNw\nkPDGkwO+8nRCEon2+Qi/U9U6lDcIi7sk8Ia194wP9nn06AEADx8f8/Sdxzx68oivfvWrfP1bXwcB\nWlvqxqI+wx9GyGBwsUvZdGen7kH07R9rU/ti93DeLjpiaPfSvq23vP5599GWGHYn5LaEILrsRfte\nGi04bRu1ClrQIyCWUCjYL+BHJzMqF8bKE9YfJeCNp5PQMy0pEHiSKEZFQzbbBVV1FqwJ6jW9Yh/t\nFXl2gIzH9Hr76EaH8ojesKkF//pf/Gm4/RGoqCOE3y5cR1Jur60rR4W2ES1nxbcTprsF3e7fKog6\n8ncYs278ROt7JogVxInosALCw6aWVCVsFn/Nrz/8AddXK2rjgjrJmKCAM5ZExERRHNbmcR9fN8g4\nbRuLa/AGohRMg3cWJyRxOqReLRjvH1CkSXAZl57jR/duOTqvI2Cxu2xoidk+0Cxcy8nxLkjUnQ1z\npzaOsi3dW2txjaU9n+NaLo61Htv+23Vy847L5ggcHk/bu+zmo5vXHdHZOdG+Tlue7F7LhV1MO89g\nb0xFsF/p5r8SHhVLyurzOadfOK1kVfNANcQ+QcqcB0WEd7ClptQ15caS1zVxr2C9WsCwx0A6PuaI\n2bZi7VIq1Sf3Jb5aY4tH1HtfIUkTmsGY5XSOTA/wzRa/eUkvUuSiB/sPaUZDjvg1pGMsMdZbyAYM\nc0UalejnH7JuUiqZMRU56tFj4hcvUIlAydBYE0G7ELVgBx/cN9s7Ebx5ZJue/u2F7AunjxNUxgRk\n7CzGORwW7wwuiii8ZWks2jqstzjpkVHwY+gVPXzLOA/ydkHSnkQkDuk81jki4Ul9yMYkxlGKUPpQ\nwmK9JcUSdyuD8GgczodSS+U8ot0UhGPn6iy9bxuR+iDbdJ7Eh5Odcx5XN0G91ZjwYVp6/S2jX6R8\n/fiIP/3OY7733cc82O8xGWVstWd/lFKMM96/WPBXH1wym254/+SCHz9fYKzk7YMel+cznjyY8ORg\nwCSJEJGisZ4oVVgL9yZjnuwPWJWOJRtc3uDWjuuZpawiXr5oyNOENEkQTmG0pXGeyjRUtUVIwXxR\nUaQFOZLZwpKnMbX2JCrYAGjrWS8bhoOMnnMcHQ4ZxZrYKl4t1tRVw6QYcXZ1u9Nov+iFurh3zJYL\naqMRkUJGEfPlHO8d77//ASpSzOZz3n/vPcpyy+HhIb084969ezx58oTVasVkMmEymSCEYDabYW0w\nEVwul/R6PZbLJefn53jvGY2G5FnG/fv32Gy3vHz5ktFwiJQCbTTb7ZbNdst6vW4dmUuUksRRRN1o\nmkZzcvIx8/kcrTUff/yC4WDI8fGIk+e/Zr1acX52RZr2SNM+St1Org+A86iI0PH6U9FxDkS7YXVn\nFGchTSSHBwPuHY3oJVkoG3kwVjDMIUpT4kiFBobhxIDWnohgFpckgQS6WFRs1iXWGrQO2THnoKoM\nVRUyhO27uHlj3oUyhQoL+o6L3BkH4UP5vP36bjO/Nejx7Sm4XcPaDWy3pH3WR0uU7T524Kgt2Xjn\nuTTwIgjudu14IGwgvVix0A2zVYM1Fd5YgoQCvvH1+2yrC5J8xLq6wtgmlPbchkH/MXGckiQD9gb3\nSOIBziqECmtULx1T19dkSZ+yrPiT778d7mVLzOaW6zHctIdwzt2AunaDdU60/ned1w67Pk3Oth2+\nO7VQe/2+LUcGQNKOswygLEk8KiYoSAVsS0ndwIuf/hfWySQcspt6VyITTY1QMgDK+0PEYAhZD5Ks\nBTseGYd2Lw4LaR+BQ6QSW9comTKfT0myHuvVmji6nVACbsajUwd2VDLvOxVbB3Jk4PPYwOWRUtAY\nS6WDglU
72wJLj+kyPG2fMk/gjHkbvI66FI1v+2t1nBzXtpu4Aedt/yz8LhtkO+J5m3ELwMfSU5LG\ne06AygafNWODItpq+4WtH78YR6cp2+yArashyWhMhHcVtvbM1ltiVSLvP2a9bcgzQTp9xUrHTBJN\nlMaMfv4/mb44QZZzuHqGbirs1Xvo/IAMh7QzembDSMYUaY+Lyduo4/uopkblD5APv0MkNY0y5MrT\nMyt0sYdXA6KsoTn/CV4vUKc/Q5YLxE//DhlFrTnZjbNmR0gLh7MAcQSiBUVtku+W+XfTUs3jzu/B\nW4wFhSR2Gm09qXdYZzBe4DWtE6dHZAm4Vj3mw+ZaiVA/td4FPx8haQjKiQiPVhB5R0zosSSFxIrg\nYGRbcpdsT3fmNRWady74HHlPQgA+ljARjfekIii2jLNBrm41Zl2y2WwRukLbBqFvy3CC1abhVy9f\n4pIe/YMR/fuHXGy3vPdqSrWu2S5K3j3a47tvHHK432OYJ0hXc7IqGQwK/tE3H3N+teTkes3RaIi3\njtJY1pVlqjWnesOmJ9HJkr1Bwl6/z+i4TzGISWTKYKRIC0WjG+IkwhqJ14EnNRpmVBbyPGZdr5hu\ntvT7Cuk9faXQ2uC8Ylyk5IOUdQX7x4eMe5KPztf0Rwn9oo+1FR++OiNYht5mbEqkVIzHe8RJwmg8\nZnI4Ybw35pvf/g6Xl5f0ej3KssQ5x3A4ZH9/n5OPT3j8+DEgOD8/YzAYYFtCuzGGPM+ZTqecnp7y\n9OlTkiRhb2+PouhhTFBdxXHgjvWynEcPH7Fer4njmCSOqaqK87NzmqYhSRIePX7M9OqK9WpFvyiw\n1qK1ppdljEcjrNFMpxccjMcMih5NbTg6vI8nRhBTlbcHyqiwMRnz23Oua8XgCMT7DgA5H1pSXE7X\nPHt2wfnZjG2lcQQJdpYEg7KyttSdR5f1LTcDZtNrPvrwGQCzq2tmV0ENc352ysmzk1CecjDqZ+Rp\nd6/9628MpYIIQMpuEb/JMkDnxeNveAuvnX5/5+gyF7+laLnFS+3AUJct8xjtmRp41YjdVQkgEdBT\njqMEfvqqRGuNcR4nIqyF5mTBd0cDrK4ZZhMGvQmJKhAMUQp0U3J6/WvOph8yGh4hlWCzuqaslwjh\nSKIcIRUqGTE+GPHwUbHL3Hn/Oe//CyIcXEPWoCPR0krPOyAVOCcCZ9qms+6mJGm9D9fXZsREV3rx\nfocdRTg/E8eCOAUZudb5RLBYQ09+hTgbgANb1eE2W4NwDiEF0lnibzwCGSOcw2/XgTIQp2A0ZnWN\nUK1CzWqkSFBpjzxLONgfc3H2ijhWlOvPb5/wufe9A3rc4GUIGSzrOkDSgc72a11pyXq07UYZjLbB\n4sEEno/TnbIrjLUNhYuwt7Z2Cfhwdg6E5OCo/gmSvQ8jbl13v3gt6wPtxo7FE7eCG63dbr1YXq1b\nQcDnPw1fCHhWS1g0EhdnZEkoFSW6wixfcv/eAWNpiZWh3hg2l2eIZsmmWXJux4wPjjD/5J+jzBRl\nNc34IbGpSNMj4nqG1hVFcZ9KeLyPsVIzuvgVs+WKRiq82bLxfayPSZsNMs7QIsVePqeqNenem0wG\nI+JygRk9ZiAjVtOr19QFqn0I/K5BqBXBk8OJm6fJO4GVis84UH7x/GlXHd0aFjggtgbhLHUTFFHa\nOhKhsMYipUM6ENogjaYRAWR4OkJzm1Vomfzah0yPFC50UW9JX85bEu8xzoXNrpXtKedDH67u6fU+\nSOIJpxcA4TVVy/dRbauKxvtgh46jdjpI9psG2YRMT2YczZfI8Lx5b0JiPKJp0JXjarbhw5dz3nx0\nSB1JvvnWfYZpRG08S+M4ngzJs5RUJMzKip++mmFc6D2jpCXPUpSSTLcV+4c5w9EA7zXWGeIoZIDW\nqxJvPP39iMlen/OLNVmRYrUnU5ZBP8ZLh7GOSGk2VYNTkCtBHKWIKEZFg
gcH/eCuXGrWlcYIzS9+\n84ym8oz7Oe+9OAdvGOURvVQyu17eamzW2y11XWOtZW9vj6Zpdv+vq4r1es3JyQlJHKOiiOPj41Cz\nThKqquSnP/07jDFIKanrmul0ilKK6XTK/v7+DgQVRUEcx4xGI46ODsmyjOvrK1arFSfPTzg/vyDL\nMpRSXFxMkVLy7rvv4r2nqkp+8fOfs91sQQimF1M2my0PHjwKi8tyhXOOKIqx2pLlfZrGsy0dL04u\nmF7O6eW9W88bZDhRquh1WnCI7uTmndtt9kC7SDqcs6w2Fca6mzkfBZl1FqvAw9CByeLaRdUamBxO\nODgM/dAmh3s8fuMBx/cOefj4CV//5rth87Jh3vhPLZmCNjPbWvIjQumjbT20Mwa8kea2QEN+qQTP\n7nc+TVjuNo1dZgM+tZm8dsJ/7fUgvCer4bTxvGhEew3hW5GAgww+vCrRLqY2ZrcTzZZLDocZDyWk\n+UOm84+xdsOmWaCtQKohaZQjRUJZrUBAv39EGhfk6aQ1lJMoGUN8zL/5V8EwznF7QvfNmPiWLP7J\n0t3OVbnNpGtL27E+ZBA6ECm5ATqdfH83lu1nKUO/tzj2xIlARR4Rhfl0sliG/k9Vg9QGYTSqapC2\nJUQ7T/b4Md668CJR6yhtGrxzRHGOSHPCg6CI8xwlFcvliqZ2DEcTVJKQF7d7ttrOVLtMoG8vZqcs\ndK35X5tpcdYHMNNmapwPZrpBrm9xrWLYm/ZZsiJYAXCTpbMu9DQLcvQwx5xtQU8LfnwLfujKYC70\n49oBL9uSmYN/IwIYxBHLzTY4O/tAmm6MxWiDaTSm+fzWEl8IeJrmCmE1cdqDeJ/YWXySocbHlJdT\nqmiAc5ZYXVGMxkRPvknBhnp2wXI2Q16+JE4G9Mf3mMQW4dYkqUQm++hSU8cFfnCEzCUqHqPimHj5\niroyRFlOZM8ReokcHTJfNyh9TdI/RtqS+fUaV0yoZUE2e8a6rsCGUkVYeIIyaZfCDMWt1pMnZHYC\nx0UQAf6WfAPtglLLOYe2FuUctXdoG8wMpQ5ZFWM0CZ5GOGIZZOJFJIlVoElr64KiopWaRwKylmwa\ni+B0WRH4RlqEJqNWQGeVWDuL96GRqceTEtpFGEJX+NpbhA+GYY0TCBdKKdpZmlD03p1gIkLbAGMM\n2mhc1WCdubXsGuDjkxfEzvLyfMHZ1YJYRDw9ntAYzziPqRrNyXTD//7glMTAtx8ccjRMOB4K+nHC\nu/f3eDTO2QrPj0+v8QJW64ZvPJwQ5Z5NWVFtt1Sl5Wph2G7KQLKznmGSYMWWg72M9cIgnKQ/zEky\nSSQiitQirSCyBfE2QwlBGsHeIKPoRaysIU9ihq3k3FaGWEbMVhVFntCLUiovqFqFwDq+HYdntLfH\ner1hdn3Nq5cv+Qd/+N0d4Lm4uODdd9/l6dOnrNZr+kVB0zRorTm+d4+6rjk8OmJ/f5/tdstyueTR\no0esViu+8Y2vE0URSZJQVRVXV1dcXV3xy1/8gtVqxXq9DgDFWg4ODnDOkmUZy+WCwWDAeDRmPl/s\nMkPj8R5HR4cMBwO01hRFn9Vy1fbsquj3eyyXoadWudWoqEev2OPp02/gnGI2vx0QBPB16Omjm89I\ngbT/tdZ++ogKhFLY4aRgfy8hSlqDOR88U9I0IlKKPFOBryYgVtDLJFIJ+v2ifW2D0cE93VqLbjTW\nC2wD68U2kDBff7+wU4tDeJ7dblMJGQYhAAVCifCAd9ya244Nr/1O+49PAIMOBP09FaHPyp6EsgJM\ntees4/S0wCeL4MXlgum6Ai9p6uBSGA9zOLvmXm9FT09Jkj693pC90RtIqbi4/Dv62QH90RFlvUZ4\nzeXVS2SUMl+dI71DmyqASa350z/59u46vkyGx3c3ozWXdG0WG8ROsm9tS0xuS1y7zAa729LWA0J2\nUMiwae+AAiHDIxVEbZYnigVKCqIEz
uwH1NUGtlXgbDa6Fe6HDLzcK5C9HkgZJOlKgnN4GYxBfRzt\nuHN4UGmBNQ1Zr49KU1abFeWm5oOPTm83NogbAChaANjyZ/CB7tDokPXyrYzfdmDQB0NW01ocWOva\nA0BLLjYiGBcaF36/BSq2bXTQtazYlcss0AKaTxoPBrC1K5MZ2r/XyuRNuMnzKuxRMVC2hpll4/FK\nEqUJsfr8vfwLd3mVpozHB8RCBMXO3oD1/ALlIyJvkLZCN4rh+B7IHKoVvcnbDPoJPkkxQjC0axbW\nEQ3fwMoxa7nHxXxOfviQ9bpBWE3lYpSoaRqJGgzwMiIzG66qGJMOOb+cMUxLbDHBqZh4f0LPXXCy\nsLi8h+ntIV98RNQuRh6QrlNggRCSzne560nlnQw4vi1peWFvNYGwEDmLcIZYeIzxRM5DWxpqhCDC\ntUDDIozBGxcclq3b+TIIJW68GVzYsGvnSRCB/4MPUjof5OkJoJ3FWIPz3fdcK8Vn57ysXCAldxbm\ngrb1BUEaH3lBZj3O6dAjqc0MeatxRiMaA42lKSv8FyDmz4u3H004nAxZrbdcTa8RZYl2no3WLLaW\neakRieKNeyPiLGK2XeC9ZL7VPDjqcbFYsawN0jRsF5rpuuT+pCBKPWnhOdjzCGXYbkuqckmSe5IE\n9lPJq+mSQb6PlCll3bDVTXCS1R7dwKvLksnxiPEoQg08w35EZRyrTYnxsFpWZP2UTdmQFxl5HFP0\nEi4WW3oKJqMeWRRxMbPsTybs9W8nS5+t1hjnODg4oCgKfvKTvyWJE5pGc3x8zOXFBZv1mjeePOHk\n5ITT01MePLjPfD4LHAljMMYipWS73TKdTrl//z4/+tGPWa1WHBwc7Jx++/0+WZaFBSKOybKM7bZE\nKkW/P2A+n7O/f4CUiuVqRa+XM5td45yjqkqGwwHb7YZiMMAYy8XlFBVF5HmPPE+RUlD0+xTDISpJ\nmBw+5tXpFU/efIdHj5/eet6IlusiPqPEfINvbvgVYRMSiChms9ah2asLJeGwiYXTYiRDixbXgnsc\nRFJSNo7ZdMbpyzMAnj9/xdnpJR8/P+HlyUf88me/wrbPTlB7/vaSKQinTEQg88o2FSOEuFEPuJBR\n8DbsOB2341Zj033uwJL/5MeuHPXpofOf/NzxNj4dvt3AXtWeuQmbviT4i0368PEcrGvaEoOh1vCT\nH7yC5yc8UhXWbJhe/pI0KhBScbD/JnEyYLM1KCmYraYIpanWJ0SRwDqDVD2sL2mc4lvf+irDvfT2\nSLAN9xhZ1UwAAAUjSURBVPrmGW5HS9K+KTHuwJAL/K4AgF4rdXU/R8c9F+29bJtluq60JZAqNGGO\nU1CxJ4qhVPBqeYXTJqyjpl3bfVh3kzcPYdjHV+XNjRDBjiQIbDpAbEHFCDz9yX3wJUWeYustkfL0\ne7fr3xeux7ccp84rpy03dSW9djyMuyEPGxtSBdb5nZmjdxZtQwVDtyDJttkgZz3WSLQWWCODt5Dx\nbTanuz8tX8qIXcPWbtz9a2A1lB5boLQrkbUtyRoT7EFM+JvbxqLiCNM4+oPPz359MeBxjrWOkEmK\n2syppxt8ccB2dk0zOqCO++jVlNx71nVJ1TK9dXGMJ2IbFWwl9MyWbXnN1q/J9BV721Ocrcn7MbJa\ncTAaEg/v0RsPyOOMYeEox28y0guEsBR+Rd1/BKIhtQ3GpER7b3BfnJG3KLJ6dhomoWBn8hW7YKrn\nCQuR2HlSgFBBoRTtOD236+wsZEC52oTSUmQtjQ1GhJU1OFOFmmzbOVdYT+kM3jqSKKKDV8I4VDv5\nKhd6ETkXskNRJNsJ5oJplXfUzoaMFCFbhAugKN4pCsLq571DtmRt57ufCSee8PuBK4QP91k4G1xo\nnSNyGudCc1O0xX4Gn+Lvi2a+Yr0oKRRU1nFytWQ2nzFAoVTDs8s1e2nE9x7ts9pqau05GKRYLzm7\n3
pD0UoaZJJaKw37GW/sFk37GS2+JWq6SijNEbNk7GGJshiVlriXjfsJ6s6ZpSrJMYLXA1LBee+JI\nMBnn2KZCiwisZLBXMM5TojwjSyXaCUzVkAtJ1Ri086QqZdSLWJQWKRQphidHMdeLJWwvbzU2UX/C\nVguWywVNU3NwcIC1ltFoyPn5eehiHimMMRRFwf7+/o6QnCQJSZIwHA5ZLBYcHh7y1ltv8ezZMx4+\nfMi9e/faklSFEAJrLYNBH601m82WsqwYjUY0TcX19TVlWXF1dY2UAucc17MZh4dHWGsZj8cM+n3K\nbUnRy0OPrSjm4YMH5FnGcrlkMOiT5ikOx3A45q/+6v9QFEPef/8jfv7TX9x63uyIlJ8x5bp93BnH\nJzgwHqw29LKIzUZzPatuXHJdMJgzhMOENiHz21WYIgmTowlvvv0EgDefPuLBowc8efqEN956hz/4\no2/v/pYQnjj+DDJ1q4AJqqj237bNmgQP1PBzUgQg1yGJLzM2HWj5FNj5RPrHfcbXXvv8Otj5dJmr\nM3l7WQnq9n1HEgYxvH8NtXFY3YRNTmtq2efX769wJ88ZX00ZDd9gU53jnKbcznk5/QDvSmqzppf1\nSbIDBqN3SJIU6zOiqCBSY4yp8DLmz/7lP27H9EuMT6D5YttSlNspfTpD1puN1bSIMWz2olVc3Qxa\n1wncA1K2QmrflW/C60gliOOQ2YnjUKFKYvjh+Qd4bfGN2WVqpFJIKcnefhiaN+sKbw2uafAqDkpY\n3UCU4p0BROiY3tToasto/4gIwfG9h8Rpn6Ojv7nt4LzOc3+t5BkeONe6JduWsGw7kLHLfgUhjfOh\nW7ox4VkyBpyRGAPGORrrW26UwLTs/a6U5doS4utOzrtMTpfV6cCXuWlv4TtOT1toSKTE6AYJNCZk\neLT11JXGaAM8/9xhEP6zoP5d3MVd3MVd3MVd3MXvUXzJvr13cRd3cRd3cRd3cRf//8Qd4LmLu7iL\nu7iLu7iL3/u4Azx3cRd3cRd3cRd38Xsfd4DnLu7iLu7iLu7iLn7v4w7w3MVd3MVd3MVd3MXvfdwB\nnru4i7u4i7u4i7v4vY//B8JgNN16YgdXAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAABJCAYAAAA5f/zBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8ebBlV3Xm+Vt7n+GObx5zHqRMDakB\nCQSIQQgwwmAwGLtx4fJcVQ3G7cLdptyFXXa7XIZwea4qV5W7beMJGxuXzWCMoRBIIDRASmhIZSqV\nc77MfJlvvO/OZ9h79R/nvpdiDD1HdHSE430RCr28wzn3rrv23t/+1re2qCpb2MIWtrCFLWxhC/+U\nYf7//gBb2MIWtrCFLWxhC/9fY4vwbGELW9jCFrawhX/y2CI8W9jCFrawhS1s4Z88tgjPFrawhS1s\nYQtb+CePLcKzhS1sYQtb2MIW/skj+HZP/sEvvEM9Hu8NYhTvBeuV3CjiFFXweKx3OAxGhKNHT/DM\ng09QqVpEPV4ECzhAUBQhzwEFY4vrghJoQCaORuZIe7Dvzz5OdTZm/lKL3fumcXOnyemyP3mAdmOF\nkewZoILPE9JMafcy0paweMUTBoZr98ZkrkG3VWJtLWftXIRbjhEr7Ns5is0hU7jvy8/w8tv2MFQJ\n+aUPPSLPN3B3HT+rAYIRwYvBiVA1QobgTIAzjmEsDSOExiIIIoYujlCEKqAi5IAVgwgYEWIRegAC\nUoQJgyFDsRhEPFY9iCFC8WLwqqSAQfAKqoqIkiuIKqmAeI8oeCDxHgS8KhGCB/I8p4Qh9x4rgqpi\nVPFAoJ5790w/79gAPDifqQ1C/uzXf41Dh27h0D2vI+07nBMyL6DgPKhXvIL3gvOKevC+uIYdfH+A\nQARM8bwqg9wDp0VWGSliZowiIogp8s4YRQyIQGQFMYoRxRoQI4O4ewRfXEOv3lMGceA5j6tIkcc6\nuCEAyiu2h887Pq98zcu1sbLCwQP7aawsM1Sr0W01mZ4c4Y4X3Ui32aRWqbJrxx5ec/c9PP7oY5hY\nOTV3gSvLK4RqOHLqWf7m4/cShiOEpYS9B8cYGa0ionhnyHNPs9WkUqqxstqi1+1jJMR7B5JTKofE\npZB6rUI5ipjeViOsjBPYKlnqQRzGZPSThMrQHezeeTfLjWc5c+Y+Ahpk/UUqlR1E8T66vsqNN72a\nbqfN0vyDrCwfYXbna7jr7rfzxhdOoEWwnhd+9jffpquLXZyxLC9f4PSzp4gqltQr8+ebLHXgDz/w\nn7hm3xQ/8p7vh6ROUMkI8iqVekxbLlMeUzqrguZKtytEATinBEaoTUOWKGOTEXbpIBdPrCI+pdFs\nMzE5zN2vfREHrjnIrp0HkbjEV499ht/9rT8h6imv+6FX8t3f9f0sLZ/gT/7sLzn8wCWCkhCUFclg\nwpRpt1NKs0rqPaNThtVFGNkJnTVwbRibgkjA9YW1lnLqsHvesfmRH9yrVxopFsWKodns0O/18E6J\nSkIUWZbzLn90x6u49c7bcLqdLLlEd6jChf4a9d0HCGyMek+WrpH0G7TXlvjEA4+zttbj0SfOQ6j8\n/E+/E2NXGR2uEVcqWDUghlxTPvZ3D/LBP70fCUwxp+EI44g3vPluxobLhMZTCgx4B0ZQbzF4sixD\nFBRBxZPnnrfe9moiE/Jbv/dBhvunaNlhHu+UePbJM5gILiw+/7wB+J33/4QmSZPQWuIoopghhCz1\nXFy4RC9N6HS7tJIOvbQL3hEYQbVJLexSjQMqUZ3YRsVcgBbzg8akudLPHKmCCWNK5QrlaJg4rlGK\nYkQC8syR5zlZ7m
i3ezTW1pi7dIFOP6GfDeY4daROafUTBAtiwViiQAhMiLVCYCzWCqG1CIr3IN6A\nFnOSFUNghQcePfG847OwcEZFFbAEgS3mAaBUHceIkPQbiIQEYQSAsRYRAVWMCbHWUsx3V28p8rW3\n//p/A/STDuePPUTz8jzDw0PsvOlOyvXJr7kOgKrHPfVhrLRg58tRDH75As0nH2Z1pUXL1vDjM0zu\nOUA0PAriiEojhKU6YkJACaMK1oSkT32E8q1v/6ax+baEx4lFfbEooIJB8QJGBacehSLlJUC8R8Xj\n8pwgAERxCOIFrGIUPAarHm+Kxc15TxSAd4bMOsLcMeRg5CMfoRH26DVS6rNVehdPMFtR2u1l4s5p\nNJ+nv9rn7NwVli530WQf45Xb2b/3RkbH4fLFeRZPpoTDj7GyepHGlToXLjbYXZlmtZVw/2OXiENl\n22SJFx/aRiW2iA2/XSi+EarkKlgDViD0HowlpCAa6g2rtiAfuXMYYzGSEWPJVekKRFCQjMGiGmHI\n1RMbg6gnw5BSkMZQBcQReiETKSYOhR5KoMUkqqpkQICQqWJVCBBC5wgN9MTgvMMNSKgB2t5TFkHF\n0lfHkBREMEcxA4pq8ZuLDQWJUJfxgz/7Xj75f/8ep3/7P3LPu95LGAmSObwarAFVwXtwqgQe1Bck\nRorkKriGAihGBDf4ngx4iFW/8cKCvBTkUQwYKYiQHRAbaxVjiuusEx2kIEDF6Ci+L3xzsiNSEMji\nFcV71kfBZjA7PUWtUiIMDPv37aGxukwUBwShodvtEoYVdu3cx2tf+2qefPwwvaSF8RFHn36aF770\nTubPXyQOQw5cs5eTJy9TqlosFudSKtUyLjc436NaNXjfJQh71Ict3baSZxnGQL+XEoQBgqMaClmv\nRxA2icKQKI5I+hl5lhMYSxiFVEZGCEb2k+YXaC8dp52v0GxdwTUzwvIYjz32YWoVz9rKZYJ4mp37\nDzIyObrpvPnOF/9LXvWy1/PRe/8f/u2v/ivGZuGtb3oP937xf3Bxocmv/KtfYGK6zL/+2e/n5bd+\nL6nv0G4us7KyQBKuUK96lhehuyrEkVAuQZbC0JRgAyXLheqw0O0m1IIm4zOWSGo0Oy2mpsYIohI2\nuUJw5Th++xtJ+jm5E2KrdFsLHD3+WZ4+MkezWfzqUQi798cICa0zKe22Y2aiTD/tkaaeXtswBaQO\nnEKvAa02tFYVl24uNiqGKBRCMbR7DmMt3nmcUwwRzkEIHFla5tbza3QmQk5Km2beZq3TJnzmKRrN\nJiM1w8OPH8Oqp9P3dFLPpQstXKb0kiap9jl/usfv//EHeeVLDrG41GV6tsbdL7uRa3aPU69FpElK\nHMLsjhluesF1TE+VGS5HQEbmDKKC94Ysz/EIYSnCOY86T6kUkSaOo4vP4pc8tnOZseGIZy4tUR7Z\nx+hYwHIj33TugCVNPBIJVjwixebH+2JTmDsl90ruDN6HpLkUc4SWMOoITEQcxHiJQATnHQEGxCIa\nEBiDkRJxqUa5XCculYnjmDiKMMbinCfLPEk/RV1MlgVEcY9eb40QR6oeXIBXh3OO3AkqilhIc4M1\nHmss1kJkDXFUzEiCYLzDYjfmNqebK870O6tI1iOIykipShhVQAyqrvh+YjDGIDK4rhb31YJyoYP7\nFZxGBi/Rb0pyrj4uGLHYwJL1e7hShObJYF4thI5iFRqIAbN3cvnpLzIzERLVx8nGaiztyZnPj5A5\nw4gILutipU6W54MNboC1FvUOQcEYepO3U/4Wcfi2hCdQJRc2WOb6NzZasHRVz/oxPkUcDFmeMRBt\nioWXQikAi+JJvMEYxYol8w5rBI8nUk/egx1/8jHmXIOyq1GWKwzNX2FoeIyh9Ahpw7C0lFIKAg4/\ncpkXH3wf77zndiYnhshshXMXFjh/6RL79+/h8sISe7a/gXPP/A9e/p1v5ezJk3zo
z38P+imu52l2\nDGnaJxur0Eo8Q9XseaZOAYtBTUECQ69gDeocKgYxjkAMeE/f2AGp8cSqJChqIPWCGshECVBCFTIc\nI0ZIPaSDpTcWg1cQPIEqDiX0kBiHU8HgsASkWkyoRZopXj0W2SBTZYUATyxCpqB4HAVBCEQAh1Ml\n1eKzJgihKhZINrfRKqCAeHya8OZ3/q988g/+mN999/fzz/+v/8LI5FhBkLXIHDVgVcgH6o3xA9Kx\nIbUMrjeIiWihShUDx2y8RARUrhIdY8CKFP83irXFQDRGMbiBKlTsU4vzqPS5N9v4DBuURovrFzut\n4tOIymb5DmFoGB8dot1s0lpd5ZprdzFS302lUqbbXmN8qMbU1BRf+MJ9XJw7wy033UQQD/E9b3sb\n+w4c4PhTz7Bt1zYmZ3bwO8/8V1xaI0sKFcPlSr+f0OsoeRYVeSDDGCPU60oYWoIgQETIXU6vl7Br\nz276ecLKWpNKPITBgne4LEXiKr1+m7XmFUpDyuTkCGE2RGiGaTSbEDi2bTvA8uo8Z088ShRN8upX\nvJE77ryTSnnzFfOJyUkA7nnFP+OTn/8zqmNT/NVHf5uVJVi5CI8c/iqz18D266/nphfeyi37b+Q7\n7noLl5eOs+vO63jBwVHGAke71yKsCM4PSKkIS+eF+iQEEawtCSvNs6QJ+D7YUonR8SGibIldS/ey\nLYWjJx7g0eOOqKTQgUZrlWMnH+f++08zfwUCK8RlZc+ePVSijCcuXyCMHP1Gwvf96Dv5+4//Ea04\nodcUshzEQ3dB0UCQAKau2VxsYhPQ73XoeMUKpL0M5xxg8F5Rr6Q5zOy+jnO7x3l45WkuL17m+u0H\n+dX/9iHOXYGRClxZgNkxePM9L2FmqsxXnzpPmqUkWUa9VqO5coVdO7fx5JGcJ488Tq0GUQSf/LvH\nmBmDG/aMMbV9O2NjdcLhErXKCGEQ0UuzYgOBIwoMQegpl0JUDb1+ghGLsQHGGMqliMutFZoX5hmS\nZUIpEQXKdK3HLa/exhcPNzadO6qQOUeQB/gArBbKuyI4H4B4rCkRWKGXD5QTKVRxl3ucC3EuRiVE\nB+qGQ3DeIhJhTYQJYgJTJg6r1Ct1bBxjEbwqgYXAgDUR3gX0UyEOqgQ2QfMEY6GLR1PFqyX3isdg\nPFgDuTFAcZ3MQqaeyBYKu5VizrbGIFoICJuBSzr0zh7BZwml0VmGd11PaWgSXIqaMvgU7zPEGASP\nDtYQEKwN0bAgW8U8KRtER5Xn/H11szj4RUCEKCoRGGWYHnFvDg0c3uXgEvAJ4nsQjdBZWyUcn2Vo\n9gAAQVoilpSJimXNhahPUZdiggjaqyRiUfWUrCX0SaHaRxXK2v6Wcfi2hCdnfYctG3tfAbz4AdUb\naADFty4Wh9ThpShhqR+sDHgQTxG/YgUxogS2WKEEQ3rec+D3/zNnAs+QGmbr89iVJxkmp3X8FFpt\nMGI98dCjhOnr+Pl3/SrbJ3dz7uwX+blf/ATv/ol3sDSfMjE6Tm1kBHWesakxDux9JxdPPMOLX/sq\nPv+pPyIjJutktJKUTqfDpYtNjFh6WWdzGeSVgcqINw7ri2CmooXkjCPBEqsr1BIRVA2BKN5BaqFM\nQT4iKWIVidLHor4YARZAcwwBKR6DECt0jTKM0LNQ8wZRV+gNoqgaMlVCNQjFDiURpTEgExaPUaFP\nUZI0CLn35APlwgnUEYx6VAzqPTGbHF2wzoABJe+nfNe/+GGaBPzNv/+X/OQffJSk74tJenBfo4MY\nAeplwG+K96+rPIWwozhVrF69vgJmg+ys/yeDstWA/IhiDBhxhaIj3/AxB+SneGB98D73m+tzH9Dn\nUKNNHt6ZJX1KUUhohRe+6DacJpRKhnq9ytTYGKNDZbK8zdLSApOTMywtNakOweJai+OnTnH5/DzT\nOyZ52UtfxBNPvJInn3qC1mqHTC2tZsrl+WWa
a13q9SHEFASrXq8SxZZSOcKYkMBEeCMEPqVqStRq\no6ysLtJLQ0Q8vaxL5voMx7NMTY9SkjamvUbg15iYqFArjRNFAc0eROUa43YnF06fYXLkWg5ddzvT\nY+XnMNbnj0ZzBYCeW+Pmg29ifHqUi5evEKTLXA6e4QsPf4J+7Rl2T13P5PAwQVAC4NzcMXaOV7hm\n/508+MVPMjE+RDNt4hGiEgShZ3JPIQMunlMa84o1QliCTqaUfEiULHGbf5Zdo9Bzhmtqy3xHBOdn\nDMl5pddMONXp0FqD0AqBCJcXlItn2txzz+v4yt9/mCiK6Kx6vvqVvyPJ+pTHQAKInWVtydFrCfUx\ny75Dihl2m4tNOyHLC0UyzR1hLPT7HjGGzOWEQUhJYOTuAyyULTO1vYhXHv7KaZaWoQy4Ptx8zSRv\necMtXFpZ5ZGnLnHq5DKdNEcDQ6nX48hjj3LXGyZ425tv4YH7n2DbDMyMx8zumGLPvl3UhuuItUBA\nY3WVTt6kUi0jYgmkUFE6DpzLKQWeeiVkqA69nqOfeKypkmdKKQx56XU38ODcIlHFc+LkGdJz59l5\nXcjuneObzh1RED9Qhz2FeoIgqgghkRiILLkXkiwvFA0DgRqsCIENgRiVABFPaArCbG0MBGBLWBsR\nxRUqlQqluIQ1IR5BdKAwIFgbgIc0ySiFAWXr8ekyNh4iyJt41ymsB4AxIWFUxxoDBhRDYACjRVne\nFiV5C4TWEInFiAXzbZfub4AVQxxWyBOPNpdJFs4RVYcRnyBhCZ+2yddWsHEJzRJc0sS3FwuFd2IP\n8ezN2DAeBLZQZdahKhuKzoYyTkF8BMjxZP0W46OrmH4Dqi/H4tAzf4zMn4K5Iyyb21i7/vsYv/Zm\nvOsiNsT7HhJXGNr/AspOWZs7TZ5mpGmXqew0bmWFXCKy6nY68QShKLVsnlL/BOy49ZvG4dtGrfgK\nvlhqRfHI1YVmUDJxPi8UHwAj5GlOYAb+DBVyp0RxMPCGGNQpaopBEdiCtdJ0jP27XyC/8QDm8gWG\n82OYJKOUryK1axnbtkLcaSL9p3nbq47jc6iUY5ZPn8FqwK037cGnfW6++VaOPHuS8TBi547tBJHB\nKMzu3UkeRnjrGYprmKowLop3YyRpTpo5sjTZVAKpKqE3ZEXFjmxQ7xUVOuqJvEWMRxAi8QTe0hcP\nxmAQIq+0BELxOBVCoWD96lEpdifgMQq9QSmrHOQF01dP9+gxspUVZGYWDhzAu0GJEIOn8PBYLVS1\n8uA3W8ETyMArI2DXF/aByuG1KM81VQnFEKkpyNo/iu9cTXxESXspP/AvfoCP14f4ww/8Dj/w0z9F\nEILL3SCf1lUTCq8OAx6hV3cSqoJBCnKj6wrLYOIogl/UuI0O5GzdID9F7Tvf8Po8d2BuJLteLU9t\n7FqeY+rZ4EU8hwjJ1fc8X9TrQzz91BMcuuF6ev0ey6vzjI3VGarXuffez/L6e+6mUi7hvbK6usZa\no8nQ6CSNdovp7duIy2W8czz80P3cccchMrocPf4MkyN78XmfqclRtm2bIootqin9fhdjckKJQQKG\n62NMTswwMTFBGc/KuWdZXD7HyIGDmLhGJ2kXZU9jCn9IGNNpXSZtnSHpX6BWr+E1wkhMRJcgTGi1\nG4zP7OWGm17J1OxU4ZvSTdZsgA//xX/mgWf/ggfu/yCTU6/h8b++l+UluOuut3Lp/Ck6SUJ7ucvu\nWw7w+u94B48c/jgL7Tn+5jMf5od+5D187K9/n7MLsG9fm+FxQ6cF3RXP3mvHWFte4/Qxx/RUxAtu\n386Rx8+AgRt2T5Mvpbxu+lmu3wY2EIYi6BrDnYfgL08aemK4+ZZX8qWH78crxJHgUygZ4diRi6wu\nfpQkH5T9c08uDWZ3jtDvWlbO92jOdclSoT4D1eGM1hoE+eYUsMAYSnFIniWoETrtlNxBLRIyFbKs\nyzX7dpLb
lKynVGojjEyP8/k//SwrS2As3HFomte/9gXMLS6w1klx3hbWAwTJlcSUOPfsAg+MPMwb\nXnOA4UqT3Xu2MzY+ggmKbW8UlMjyEMFQCpq0Wx186jDiCQNLlufk3hNGESYKWEoSrCi1mrB9pk6/\nA40kZaQ8S7+xwu0vfhkHbj6EHLqfF7/oADt3jVGtxJvOndzlg/HoUF/E1mmOxxbj3hgiDHFoicLC\nH2MBGZSLAhthg0E5SwxBYMEpgYlQMQRhCDYiLkWUSiVCG2AlLFQgLcpDBsFYg68ayhoy/YJ9vODa\n2xkeGkZE6LSbHHvmCLEGJGmXS2eOcv7iV7jUPEcpqGNNVMzxRikZixlsGkIjVAIIjSl8jLI5slwf\n341UhvDtNbIrT5CtzuG2X4u4DrY0Dr0G+uzHcZlDsz6SpwQGgriCaoYfuwaxwaDUpYUCBhtE52qZ\ni431BEwxf0pAP89IfY14aQ6NToINIZ+gzxArZg/tsX2MbL+BwFqSfoswCPFZQlQuY4Iykc9ZQ3F5\nQmthjumz/wVT9YS7v4dydQjyLtpYhLlH6F1epPLCbx6H50ETbbF4qMHgUDUDT09haFVT+HnEC1Yg\ndxkiEEihEBkriPqNkkFRkjAgileDaIZWKux++R00m01GgkuMts5Sqm7D2F1MS5M0fYLZfoNr8y5f\n/sLDzEzv5szpE9TrVWan9vLWt1yPVob40iOHmZ7aQbfT5KnDj/CDP/5jLJw/zcWFJVpnztLpWTpJ\ntyAr1lIKDVEYUopDnFY3lUAGTwQIhnxAKFJyYmBIDT0csVocDhFD5j0usIj3GDGoSGEYVsEjrKgy\nbAoVB4FcHUYNmXiGVegJdNXgFhfIP/ghkjxHnaPvoHT9HsK3f29Rb9Yci5AqpJITYjEeHIoDQlWc\nEYz3eAqFJ9JC9lcp1KWSGLqqGHXk6nj+dtyr0HUJhPVBoCS9hDe//U28fjTi6cMP8Z7/+Jvs2r+N\nNEkHeQGqptB0tKjvquoG2VH1hGI2BBUd7KkMDMhOUc4rjMsUNV4pvE+GosT1jTVnHaiMAw1TrhKc\nq4NX0a99x9deYZMKz+kz5xgeG2f/tdfw+FcPc+OhA8xMjtPv9dh/zTWsNfoYmrS7XbqtNnfdfRfH\nnjnPTTfdzA23HKLdaZP3cy787d8S2piR4VF27Nhd7AQDx8R4hcwXSmwUlKhW6hitoP0AVBivzbK2\n0KXkVihXE2qjVcZnqlxsNulpRh4aQhNgel06yxc5JyNUa0M0F06g/Qad1giUKrhcMc6xcukErW6P\n4fHrsLUpKBnSDDrNzXu/Pv7xTyCPwVAAYeU8YSWkUnV86aG/Za0JWQZZnvHlIx/lqz97P9fetJvf\n+IN3cmo+Yc+OHVxoLFAJhPlzysgw2FBwCuN6CLhEnp8kjD3NtWWCCuyf2cX+eo07dh3lRQcFYws6\nG1chyQqPzo/eoPzxYcMDD3yO8+fakJsNP5ANISwZ5pdWmYpiUlX6TU+9MsPYbMrFI1OcuXiYzAnD\nu4XKpGekPMLEVInEb05Vjq3Q63nWs9U7Ty0OyX3hV8HAP/u+N+B9CZe3yExKFJf5jV98N1944GmO\nHZvjpht30W73WV3usdLq0u9DkubooMSce+Fsx7L21Hmi6jg3vfAGrFicFkZjIwHdtEscV5BgiEo1\nolaC3tpZ+naGNAvIXUZgQvrdhH6nX4w7AlaXMuaDJjcf3EY5KtHtL5GswUtetouRPQu8bqTCyTOH\nOf7kCp1eyo+8772bik8vyXBkOF94FXMF7xwejxpfzA9OcNZhwkLtxnvEFAu4RAYTmHUXHw5bPBdY\nwCJhAGKwYYgYwYjBBEGhcHiLkhHEQ2SuTtRL2F6rMDV5LYFVZofqzMyMElRqvOpFL6Lf6zN34ixy\n4+0snb6Fk1cu8cjcw5xaPUrVRogJiYwv5rhBNUFCQUkL3+MmbQZZr4HOP4
peOY6mKeHEIXy3gaQN\n/Og1mMBSCnKMJig9JPRgAzRQfG8B31tBStVBZNZ1cAZ8QJ6j8rBe59qoCYkx5FGZk36M8nLMVG2a\nqDaMuf5lrI4+zbGlTzM+Mcu2oTq9hWOUK6MYiYumJiBrLkF/jbzfxqUthpcex+/7GfTMfUg3REZ3\nIfVhVBZpZTs5ywrfXN95HoTHDwxI60YkL4pgcDJYG7xgMOigvpflWWGrMBBpwfEcgBGsFJ1e3ivG\ngjEenYObfvNXeLZssc0rjGbLBPVZxFSZTT7DLd15ri+1iKoZeZrzD088SlQqs23PDjLnkcoopxaX\n+OIXPsXr3/wm8m6P+z59L69+42s5ffYMv/Rz7yNXw8EbbmBqWIv6H4pzSq+nhUzsfbEwbgKig9gY\nT8ZARcAUXVAoBui5HG8C6t6j1hCp3yiv9L0SGKWvQmYgMkLmix2jH5QQ+5JvdFZlKLF3LP3hn+J6\nXTSwhC5BA0P/6ePob/8u0f/2E2RiUCtEaklRSs7Rl8K8XFHBqZIxUJoGC7Uf1KpzcuLB5BagJAPP\nSvcfofDocweFrAshStpL+Ld/+yV++i13sPKOE/zA//lrfOf3vpqknw2Ijd+YcLy6AdG5SnzYeH79\n6ut/+g0zMgPiU1h0Cq+OyNeTnfVy63M8Ol9z3UHef5Ny3rrCsz7nbDJ1cLmjPFzjy1/+Ctfu3025\nVKHd7tBpdRgbGaPT7bG0tMA1+3YT2QDvHRPTo3zxS5/n3KVT3Hb7bXQaObV4hLm5S+yd2cu+Hdfy\n5WMPk3tDt5viAsVGAbkPKZk6lWiYcslQKZfJc0feX2P/3mtJu/NM1odYWVlh394dPPLkk0S1MhJ6\nKiUDSY8r5x/D2hjJVnGuS9TrURubJAjL5LYGrk3W7xPFa/ikxdL8GuVSlfNnFzcXGKAc10j6bda8\ncOTICSqVKqENWWs0CGsG4+DcqQXOnF4ABw89/hWyDMJQeObIBYwIM3uFuKK0VwXfgWbDEFUqRK0I\nMXBpLqdca3LrtfvZPTzGC+yjXD8Dx04ZbjpY/J7n55TRmiAGdk7Aj73I8een+rQQghzSHExYdBj2\nuoV6gheiwNJp5nSbwq233MYn7v8oPhCicWHiWmivQUZKJ8totjdHeFabCQNfAF6Fahzi8gT1wtBQ\nwH/9wC+z3O+y0lwkCgNEoV4bZywe4cDORYZqNS5dXiRNMoIootFosLLSLtR7n1OqjzFUCwizJbQX\no3lKKJZSHBW+L4RqZZh+WnQ82nCIlXyVPLeElW0EpkSae4xxRYen9c8xVRs0M3S7Ofc9cIo9O4YJ\nI8c9d93KleUv8A9/uMr7/9tTvPGuISYm6uSlkU3nTidvopoSIoW3zmvRpEGGJwFRcuPJSHEmAesR\nmw9iqmTGk5kcFQM4jDUYLNYIDDapGC0IFBSlLAARvBVWGzGtY3PEi2eIXQbVCZbdYdLWMskL7iA7\neIDtt99OFEVU4oRU1jj55VQFs2QAACAASURBVL9jrH2JvX6IWw++jLnOTXxs7kusZBchiArTiBau\nGkIQCUCLEtdm4JorSFbGhzuQShXfb5M9fT9GM4Ltd3BV2c7BgPgUfI44Qbsr+OYVpD6DGDvoVJUN\nKwGqWJ8U4oV6im4wj/oMkydUrTJUqbHW7LDavsyul303Wd6lVKlR0i7TtsXS3DEWhqoMj80gYjZm\nXRtEOOniky42bSFGqKcPc/FizKXuNnY3ypRrYGo5WpplzS9x7PDfc+tb3/1N4/BtCY8vfme8L9p4\nPQPDlM8xYkAdCFjHhk+l3+mQeoil6OjKHARBUfbxGMQqIh6CAHcu4+DnP82xesyICYjqe0lXu8Sm\nynZ3nLf15qiWl0m7ln4zpbFcpdHs8NCXHmBqcpJeo0Ur7XPHnXfyzp9+Dzu2z5IuL/D+X34fTx75\nEq12AnmPPbNjrB35NEO1EIMQ2GLiq9
SUCkGhgMjmdqOlvGgUF2OoGMiNJ5CA7qBEFWAwRii5HLVC\nTXMCZ0hESMVhxGCdYHFEvigZOhESLNZA2WW4uTnkwhythUWSXpejXzkCwxF//W9+kGpVIO2RJwaN\nyvzeZ77E048+wfCLbiPNHIkUJEkHqkmCxyskAw+PRUlxoEJXIfAeNyAZXVVChZI6EjUkuvmdunMD\noXOd9Euhhimel7z4Jh5bcTx9/Bw/+NJX8+Ffr1KZup33vP8/cN2hHSRJMXA2Sm7qN/rE/LpreIOi\nFENjvTiw7uVZv/c3lKF0nejIxhVkcJmrnVkbEtLGhLYuK8lz61n6df9+nuj1e3Q6Fs1TWq02rXab\n0ydOMDszzYtuu43e6ip79uyhWqsTiGHhyiLV+jT79x4gKpc5/OUn6aae86uLTG3fRpamuCxlciik\nlQg5Buc8ISG7Jndy/Z59tFYXOXv6IsuLZ7j+xoPs37adxuIp4uoo8wtNsjSgjOP773k7TuCzD/5P\nzi/NUYoiRB2dZpOJyb0E/Yx3/fO388DDn+P0wjxXul3q0W5cu89q8winu01aFx4hc4YLC+c3Fxgg\nL7UJu5APSeF1cz1aaz16HYN2hE7bM1wxhALxOPSXhSzzxJOC60NzVUnWhNFxi/WOaEzZMwat5T4B\nNYIAvu+N38XN+3Zy5cjnuGfmK3R7hpf/QjDIokHXHwGFRl0sht9zC3zguz1fOCJ84GHIjeAyxQaC\njYqNXDHXGeKq4eHPnODYV0/gM6jvVmoTQtaHyohBXZ/5C0WJajNYbvSxoWViuE5jrcFa4hAtcvb9\n/+4nafWWWL24wsnTRzEBSO64eHaRobFdjI7XUe+ZmR7hvi8f5UMfOUkkMDoBr3r5frKlU8zsKLPr\n4G6i0o10ejnlMKTZ7pM7Q6lSR9WyttIlzxKsFaxpkxNjwjLRxBj4jEgzjA7KLap4dQXBzjJ6LWXu\nfIM0ddw+sYubrhniSqvBycujzOy9lu98fZ+1pfP00pxLzbVN507Dnyc0gkqZSALEGrxxOJ+Thl28\nelKXkkQdMkmxJQblbYv30DUZqTEERghNQF96VMKim1I1oUsbCHD0UZ9QokrrcomLT85R0TbToWM3\nKSMlpR4HYC6x2hXyuuLmvkQ8U8cff4DqgZu5+JUHWTv+FBPNx4hNj4NTMZ3mY1xjhJtGyizLa/hP\n6ccQbHG8CErfWqIB4dks9LO/St7LIHPYwgmJrdSQ0Qnc8nlMFILrQtZB8j74tNhcagdsj/zMg0h9\nFjMyWzScFPI54dnfhfF9yMithfv/8p9Aax5ZnkNO3wfZrQTZDPVkBLfjhey89U5UU6xR1PWozN7E\njruGmOmuIZ0V3KmnaF12lIIu1YlZxu21aBV61Vny2RuIW4v0/SF2lk+zyxtcEpGe/AiapHjnqV9+\njDtueOW3jMO379JC8N4PFqqrRiRvA3DZxtSQiWI9eHWkSZdKUJiWcw+xMawvVz6X4vwGUWQxo/pv\n3kv91huonDnMcHWaxlKPHbKfSvoUd/ROUdIFXN9A5ggkoJ1G/PhPvZc9+/fgfFG2sGGIcx7jcjpL\nl3nve3+G0UqdnJD99UuMBMKzpw9TrliUMYw6ZHCOjQDqPcGgTX4zcB5Co4gfeFA8IDmxgAssikN9\nwcht7iEwGIFIHUaEjuRYNdQEYnE4wGHQZ4+RPv4E882MaP8+7N49lEdHOfs3H+XX3/cK9s5sI3bL\nOFnGTrcJ21Vcf4R3ve0QP3eszjCGNVEyUQyWUJQEQ+ZzUl036xbkxqghGpQm84Hc6ynITuA9iAV1\n2H9El5ZHCpl9YBos+g900PboSdOEGw7s5iNffZCfesub0YVH+cV3/jg/9jO/xD1veWlR5lonJwOf\n1KDIBXA1H4ukZF2RKWx/xYuLImrhq9L1NsivPzuCr7HufEPHwdc5djauu35b0WJjsBl02l1GhuvE\ncc
zhRx+lXq9x3cHrSJI+88sLnH1mjhuuu5njz55kdGKIExcucv0NNdoJXDx9iiiKOX72AhKEENap\n12tE5RF2H7ieo898BckTwiAm0irHj82xNt/gxbddQ1wVoixipdFAGp7hkSEik1MZqXDxyipZFvDY\nE0+wY+dOXv+y15IlHe47/BCnzx6hXptk2+ReXNLk7z77P7npuv1Mb9/Pxz99H2vZKuNTkwRhwMW5\n83SabcJanTzvbS4wwE/++M/xB3/4K2R4cmdIV6C1Av0e1CODNwpeKY8LUR36DQgiIXWQdZXpkZjV\nXsLSoufADTtRGoxUpkiSDtY7jEJ7dZ5LcxnnloXHEnjpQeXzP+/JnMd5yHIIopR2D/ptxalQK8Nq\nSzh6uTDkRgMTPIA6oVZXSuKQdoQIuFRoXRIm99U49JIhmskFVpcE74XhSfCpsHJlcxuJajWiXolZ\nbjTJ8pw4CsjyHPEwXp3mr/72E4yNjlMf2c5otcriwmXGZoQ0beDzDmF1hHOXE7bt2s2v/PIBQqsE\nhZGOXrKXIIpQl7HW7BEGAUmeEocxzgUkvRyvKQKEIiS9HCIoiaeTOvpKsZF1joXFFu1OwtpahzxT\nriyscnlxgYmJCX7ou1/ILdeN8+y5S8xfgbWeo9HOudg5zvY948TXTnHsqeOIb246dzRs4E1EbhzG\nRoVq7wXnUyTuI5ojLiW0fYjSYv0QC1JYLrw6VARXyME4gYwOUSBYGxVdcNojo4NzjsbCMJeeuEgl\nbzJcC5jMl6j4HqIWYyNSb6lWa0SVKkGcMzpmyHtrpKvzxPTYNbSK2xFCmmPCnKltJfJ+inVXKHV7\n3FJ+MYf9I8VxHUCGR2y2aXUHICwHmLxoQ7KiYAriTdonn38GasME7QSbtUEc5Bmow/uUHHDJeUzj\nMmZoGjXF2ikKnbUxaq2HkPSv8OFUUQ6wh0h6e8niO1jrX2GVKo36KCbtEjbO4ae3k3tHqTRCEITY\n9ln8ic/ByinKy18GzYlG9sP+UTRswcTLqFR2sNvVySUib1+BZ/6BZG2IhQaUrzuIVKagOkW+6y5K\ncf1bxuHbEp7CwlCE12/sdq8uBiKC8x4Rj0rhsdA+aBRgBge5iR0ciIcShErPOUzmsTffwvZ//W4e\nvnyJvNFjNOnTi0doh10OXFxiv/skViFnBLSNy1M+d2qCd12zhyRNiy4T1eL8EECd473/x/v4xJ//\nJTPX7OGVr3w1xw6fpzLZRU0ZG4D6DO8FEY8JBBS6/ZxKlDNe3dxuy+BR5wgG8VEMOTmKEqrijSGx\nhq5kGAwVb8hNcbaERyhjCQQ6vS6BtbjAUDlxlJX5LpXXvoFwpEyZgPTkMW546Au85c2HaLYTHn9w\nnvryFVzY44XvGOapBxdoXVzGTozx1lv6fOpYRsUaeq02XF6k2WozOjFG/opXYG2IVXBe6SKE4jdM\nycYrPVGq3mOBjgh1FKuK3WzNBnDODGr3hUHaK1jVovtAipJXmqXs3TnJ73/6M/zEW7+HbOkZ/vy3\nfoYHP/cm3vsf/ncqNUuWuecwDDYIy7oqs/6HrNObge9G18/W0avdGsUPpYNS1HPKWTIgR1/vxXmO\n0nP1sec8PCg3bhadbp8giOn129x5552g8KlPfZoDB/bTT/rs2DXF8uoK5y/Os33PXroXlnj06FGW\nl5apVYdYbnep1GsYG9BYa+C8o1QqY8wIWX+U3CVkeZ9LjUXoBSxcWSLzCbXqMJgyi8srVMsB/X6H\ndqVDmmd0M4X8MlkOmnXZN/0SJmcnecmP/y+cv/JyauM7+dCHP8vlxTUunD3JZ+99lG3bdzEyuYsL\nc6dhWWh1OywuL/Cy6/ehYUxv9Vu3h34rvOtdv0AvXeb9v/PfQQt1OK4VBLfZzhkZEWZ2lulnPYKq\nYOLiR8jXPNVgiLd83z381Yc/wsIK+GyON7/+XUyO7OCzn/57otgz
NgUPPvEoyf6DaN/w0UuzGJ3n\n7tsEG0CWwF98Et70OthXVVotIXewsgwfeVC577xhxDrSQPA51OoKRhifhXxl0MZLUeLymULU4bGH\ne4zsFMQJnSYMjcHQtNJc29y4qoaWuYsNJCjOj7FisCYgcxmJgTe9+TvYs30/n3v4Ph4/fporl5e5\nMHeZkycvE8fwwlt2MDY+Sa06xNKVDqPjo9TCAFsqUSsb1Ct51qdmLTjIck+/Z1HfRYwnQummijoh\ncwlJlrG22qXR7NPrpnQ6fdqdBBNYAmMIAkOeeSYmh3jHm1/L/l01njkzz1/fe44sM1QrDfbvnEV8\nSu6EfupJMsetN+7n8Scf2XTuBHEHJMWbBCelossUwfkMTL84q8XnRGGCVYfgERMgmuM1x/lC0ROx\ng+khxImwEjmgwpgdp5QO0U1zut0KK4+fJEha1CsRU1mD2HcITEo3CcFnVOsxkzum8c5Buka94snj\nCr350/jFE/TP3EsUlchaDeKawfiY+vAklaFtlC9f4juWt0P1xZy0j+AG5zhZU3hlNwufJZCnxTEk\noSMMKUhs0iY98QBZaZTo0hU06GNNcYaaGgFicu3h9SJy4Wns+E5MfZJiRCp6w3dzZfUW8naPib03\nEZYr9HsN5o5+kQZXCCav5eLRwwSxoZakrD32WQjq2JldUE9Rn2KzJq57GRs64ijHKDgt4Vv7MXWH\n15shncStzWPCiFLSpil30vZN1moxQ/OfxI7tJm+W8GsJsrIIt739m+fItwuSDHbXbuBxKdaD9fY7\nC+QYLE6L8pXX4mwD53TdSYpb960gBFqYmd0VmPrv/56GUyZXV7iVES5sn2TbyiKV1UVmzTm6PRiO\nLC7PCHyKNyEPH23ybhGEAJ97vDrIE6Jyid/4tV/j3k99nG0HD1Ku1rlw4SwTew9x5fynyXpdnIXM\nThZSr0BAoWxUyyFpbjm3ukkTmMuxxiI+I1GD14Rs0IZv85xSEIJa+lZIrNBVj8+VTKBkhcALSo7k\nCaVulysf/TRr+w8wfst1GN8hfeIE2cQojcceozq5h6Tcpddz3GJDHm/02PfKCotPTSA7L3H+aJfZ\nkqF79jLLrReQ3fdFsnab0p23s+3mmzn35ccIKLoIKiokxpNp0aZatoZW5nACkVcyhViEsnq8Fv6k\nwG+uIwAgz4UgKBQdGRjWvS3M0tj1CnBxwvPUeI0P/sPf8zM//KMsnHoU8+Tf8MNvfJxf+p0PcNPt\ne0n6/QHRKNSWdZIySNKNsyHYOChQv0Z9MeulMZFBd8HX/daDt3+Nx0ef495Zf06fQ/af89wmxUHE\nhDxz/BSBVVrNFsNDNe6++9UsLy8ShhHepBx+6jCjozN86eHHMUEIeYczcxc4cOB6+s5RLlXZvXs3\nDz30UGHCDyMWFq4wPbWNCxfPoMaSOGF5eZkbD1zPkdMXuHjyq4yORezbPU6tZPC5I4yWmZicYHhs\nlNFymUq5ThhXWJw/y0gwjTYsUZrwDx/7KP8vZ+8dZdl11/l+9j755sqhq6pzUqtbUssKlmTLAmdj\nYeB5DMaMPfiRBh4LDO+9Ab9hmAFm8BrSY94aZswANsFmwLYMxsiSbQUrWJZaUrc651Q53Lr5xL33\n++PcasmMbSj2Wt19b1Wtvrf23Wef3/7+viGJBa5vURkdwhhwqgHPH32WydFJrly7ysLiAnfd9wAD\nw9tZbjZx/JHNTQzwo//qdl4+doKRWi4XdwOBFRi6DYFWBulCasWYTLC+TG47YKDbA0WLW/a9nuIH\nLd5497vZPrMDozL+34//MgduqnL27Coyy3eu4xcus3V4kkg5/MXxATxrnXtvk7gevG4//O6nBXfv\nEdx7i6Hdgi8e0zyyAL60kRasKUWaipz8auetVmU0ljRk/ZsTCjpxbjzaPg3FCpTKEEXQbYC0N7dw\nriw2GKzWiKKIJE7IRM6NsSwX
VutsHRtj+exLPPzki9T8KXZN7uOeg6MMDAxQqQZIS2J7Aiw4e+4q\nFy5c4ZVjZ2h2W0xNjDIzNYjn2iQxNBqt3KTSEiRJTJoYoiQ3e0tTRZr25dKWg1+wcH2PSq1EmioS\nlZEkGUOVgAfu2MJYYPH8hTleOKUIAod2J+WOm8aYW82oN0IsCxwLAkfQ7GXoks9v/O/v3vTake46\nQrho4ZBiIUW+y2utQYYYlSCMxgHc/iFJa4kWCdLk4pqN1rUlXTIdg0lIzA4OyrsQC+s8+cyzFLbs\nxbeaDNJGyAwRhfR0QtExFF3BQEmBH9CLMmr+Os2FWZrrHeTCNfzhSaLZC6y88AXE+gpBTWI7kvXL\nhrGxJuMFRXnyIEk0yER3jfvdaeLSzczFJ/D0q4kF/3AL+8eGjpK8G2GynC7gCRxLkURtTP0qWXgG\nWmtQdHGKReTwXopbd5OszZNdPwNRSHLpWeTQJO72OxB+BYTBIuXJP/4VxgYmqQ8WcbceoLzlFqrD\nM3jdRdJjn2VpNsXeVcQKhrAdl87CNdIMBkZmkE4JWR7HsjXC9rlRWvgp69PvpLDzzRRqY6jOItbO\nQVae+jNIt7E0OY2ZgPLFhxBtl8axCyStHjW/jeHbC5C+M2m5j+bbQqL6kQMYgcL0iW4SIzS5YW3u\nqSIVaFfcMGMymcSxc65IbDRCCG79zY/RnhxjrdOhuLBKeXwGGk3qi8tsd9aQgUCkgFAEbgctJTgp\no9WILM1Ik7zYSOKIgi155mtP8V9/7TeobZ1BJxGtdpOCHGPHbZMM7P6XBF5AtVLha5/7HaLYRWYZ\niZAIR+I5Nq5vYbubW0CpUaSZIRYCmcbESiMsC0saYgmBMYjMwSs4iNTQsjcobrlsFcD3Ijh1lIsv\nXmHqPW/PJ7vdQbc7rP3pX1M5dICBm3YzumWK+vxLOKU5zq1rDu+sUhsroNMDDCYhhw6HrC/C9ORe\nBv76ZQa2jjDnTqOPnOC6sal86AeJ4xRHQSYt0LmTc4LIlTbkbQJJruxJtCDVikBInH47arNDQB9N\nM32lVH7DwhJ9f6G8IjGAUppyweL3PvVn/J8f/gnOPPclZqY1H/ngj/CBf/0LvP/Hv480TZFCYVsO\nWhsc1yZNDGkGhYJApRlK6292R+4XPUaYV3k4N3g9r+4Y5pvhov4v8CrHx/RNgG6gQRttrddWPZsY\nb3rgLTz++FdyZCYokaSaa1evc/nyZaZnJnCdCfYfOMjCcp1Gd51SqULWSdi5YxeN9SblUoUgCGi1\nWoyMjFAqlbAsSafXQKksb606DuVikajU49iJV7j91juZmd5Bq34ZIzpo5SKMjecF1Ot1EpEwWNxG\nnKZIRzM4PESpWuPy3BXm13tcnVsnEj6NdoNDt95KL01ZmLvKwYMzXLqwzmq9zo7de9i3fx+9ToIl\nfBy5ebfcNTnLzE2DLPTquKEg7uVKz/IQ2B2BV4RrJzWDBUFlGuxhQa8OsRLEqebzT3yM2sgSf/uV\nJh/8336O46e/RLt9BD/YgyUL2G6uRLLchKXwCo4aJsPnvz5bxKbHPbcLDuyHoZcMv/Jpwd/vh8fO\naJYKkxza43L06BUC16VoNJkwRD1BUDCkTRgdEyzHIVqAyUBbkKR5dwAhiHrgFaCzImjUNV6wuXUz\nOjSYKzyRSMtGZYo07tBpw7aRrfhjN+EM383v/Zv3IbKIJGyhohYq6qDSECFtpO1j2TY3v+4gzn13\nYxXKnJ5d5LGvPcsv/8rv0OrBz//4YQqFKr3Y0Gm3AQijGJ3me560JYXAzrmBcZJ3iqUmShXNTsjk\noM9dt08hLMGTRy+ywx/B84somnS6KeOjBcaHq8wur9PuhJQCm0Y7xQiFJQ2drsbbMrjptSO8bs45\ngb5ytl+/2HlcjZAbvr79IXNzWGk27rM2oFHGIEmxjEHKUe5J3kp89RRnrixQ9kFmbfYWBd1mg0
Zi\nUQgcTKbo6Zh11aNclAz7BlkqcOX4S6wvZBjXxyzPY3k+jbMvMX9+jkrg4Nsab9DCH4E4lHTW6njV\nFcrVgOzcPGOp4B3BAZ4r1TiTPo1jzDftXf/Uobohdv9QqMlQrsk9i2KF7w0zcOhm7KmDJLWtCLvA\nijOI0C0q1CndZTDtM3RefJbo1FcQfhln6maEHdBaPseuXYcZL1e4+sgniJo9Dh94K36pwsKxT1Pr\nnWP3gbcy3+2AHMaqjCKGp0mNRsV1sB3wawiviOiu5GvJgNs8R+Ox3+Ern/5D7jy8F9lZwvQ6dBaX\nwC+x3E44t+5w181bcdoe4dxCvk/7NlbR/7bz8I/48PR7mTrJPXQwGCGRRuQ+vQZ0H8HRBjKdkiko\nSE3az99ykRiTS9sVAidVzN13F9nkAFsuX8HUaqz1eqS2y/Zaj/HVF9hmTmCLEaS1lt9ktCYTFY4s\nDrN0fYXayDDdboSlYsLU8Eu/+BGqU5PY2CiZYbmCuBfy9JPPcMcDb+KBg7dz5BtfJ6iUca0CUgq0\nFGilicK0f0rf3CLy0lx5pXRKIdPYRhNhKCjQgUfPaGKTYqe5LA9loWSO7LjS0HMU6omnaK2kTD34\nXWS9HlIKisogCwW2/Mj7OP3Jv8AcfZG7H3wr83PL3HawwM57xrGaTdaeT1n1ztNai7h5r4N7e5Xf\n/b1H6RmHju3g3L+d7Cd+lNLwAKYXYYQkkzkR1JAiTZ5UJYVmUCtmTY7guSZ3x3aNpGM0Ba2xrM2z\n5LTOUZw8a4y+bFbcKJIRos/Hyf8oJfBtwe//+R/z73/u3/DkZ/6Q6e1b+dtPfIwjzz7Hr/7WRykN\nVXj6ySN4QYnnvvRZXjl6EhvFvkP38+bvexsTkyMMT1TI4gwlNOKG8eXG6/TXdE7s4VWcKf/OPxz/\n0KdnA2HiNf+lRtywXf+njn/x/vfz2S/8DdOjQ8SJIgp7rK7WGR0bJ8sSlHZ55cQ5SpUqW6YmWV2r\nUy0PUKsNsG26RLvV4djJ41QqFUZHR1FK0el0KBdHqVTKKD3NxcunQCWUSj4Fr8jpU6colgrs3j6M\nbXzSXg8bl26vhzIpVatKohWVoocCrs0v0ut1STNFZtdox+BXfMZHpojamsxYqEiSxAn1+hp333M3\nUzNb6HZWqbeadNMM+c9YN1Qi0qbEdMFyQSbQXYckMphM0G0b4h7oQh7X0FmFyjAMFUF1JcZbYnQG\nRsRVwpPPMGw0TjBOHLtUqhXGuqCFz+pKzMKypsAqFX+YnnL5zUdS/h874Y5Dkh9+OzxwS97ierEL\nfsmmEXfwyqDCnOc2XAU/AETuoxWG0FzNDQ0zBE4ZxrbA8qymFwrKRYEnIFUGy4FCZXNTk2YGIVPC\nKCZJuszPZdzxulv5+Z/5BRjezer8LHFnjTTqoFUKKs1bE1qDTrEcH8v2+62cXFThWA7ThSI/+a43\n8tMfeC/PHDvPb/3H/8yxI4+z/eZJBgaLCMvC9Rz8okMBjeu69OKENM3IUpOvvzBmpObyzu/aycp6\nm0dfuorWLraUzIoO2weqWKGkMmhx295Rnjl+BUtUcLRNmqW4UhKqXP3k+y6feuQZ3vqvNzc/wt5A\nYflm2t0GgmvxzYcYDNLuG+YagyHDGJEfprMMY2CwdDPu+Yu0eutYImZyeoY33jTOlLrIQyu5J06n\n3SGw29iyTcuGwIHlhiS1UtbmMq6teoy6EaM6giyhvrhIN4TxSq4UdkKNHxhGhzVuySbttekmhl5n\njfKgwx67wWRwgL8LbC5ET5OpLLcv2cTIOinSAmHLnLzd0rmbepgw+qPvzaX3KsZlEXSG3RWsMkVk\nuVTqL1Fc+wxO7WbCCxdJTz+G8MvYo7vwihMUxiZI05jSbW/HG99Oc/kcvYaDt+MBGL2ZQjDEeH0Z\ntbrE8Pg47vQQwmiyNCbuNvBay9iZJlu7jkkBDbFxuN6d5O
43v4PJ4grm8lmyzhpjM8NknRX80KC3\nb6ex1qDWWsOzDcb1cgAm+mc6LW80EKS0+7C97MNpur9AcuRGk9uJ60zn5FSRW2wnJrf1TxEYmcdH\ndBdhKugxc2KOxu5hLl05jzu+nSScxW2G1EyI7VTQZhKVPY0lLWQRRFNTrRT5hZ//ae59wxsRls/0\n1BauXj7D/OwlbK+AbZkb4Y5b908zNjzCgG/zhU/+F6zsKsXiGMjcmA+dS8kLBQv0t5Yff6fhG0WS\nKPzMEJncusvNz14QpSg7j56wlMGWEqE1tslbKjWtCS5eor4QMXPfYawwJqgVUFqg1js4jsQp++QZ\nJxaPv3KeXTrhkU8t81PfW+Jyo8N0tUh4+iyu73DiRUitLktrLSp+wHqYUCi8jumxYVpRTCjAw2BS\ng23nn0dmTP65GWj2Q2GFyCNDUr3x+ecmke3Nd7RQOl8t9Hk8eR5M7h66YUtwA1MROUKojUFmEf/+\n93+T/zI8wOUXvky7F1K/9gwffOMbeceHfpKvP/znJN01JqZKRKtdikXFV77wLC989d9RHRrklsPv\n5V/81E8zMDZBpjKM+gcsmxtGi//gE/8WERH/qzlhDhtt8JIMYLTYdEvrngfewCOPfon3vPudGFUi\n04KgNEgUK3bvmSZJFPVGm1YnBiy279pNFMXML63Q7V6hUChw33338vjjj1OtVmg1WxQKBTKVIaTm\nzMnzTE1NkmYRwnMRSKngdQAAIABJREFUysITDo12h+PHr3P7bQdR9hq9qI2KEyxPU29EFIIWg0PD\nuI5HkkC9k7Bn2xaefP44pUoNLygSxSFGJGRhRH21zuzsPN1eC9sxrDfmSbOYdqNFmiYEQWFzEwMs\nzsakkcEuSlqzOWG5NgrFAUG3kfcoaxOGaMXQ6xqcIiQdycx+gy4bGmvw4vNwYM8pnOS32Tv9QXTH\nwx6vYYkAGQQ065pWG4JAIDJY665SK1RphAX+25ESgVdn24xgfFTz9XMCVcyRocC1GRkcZGm+haMF\nqgu91ODYArsKvWYeOSOkQCuFVxKMlgQTe33mlhPKIwbZEzQ7kMaC5urmFk7usJxRbzQZqdX4/F/9\nEXfc+1bCqy+zfv0MRmfkkX4uUkqQNloIjMoQlo3lBji2Q16m5/5gwpJE7RWiXh1r9hiHKwM89Nf/\njadfOsOv/9ovgHKx3Vz84DoWSkOr3SPLNJ0oodWK2b+tyu27xmi0Qz771AUcy0FrB9uCoYEiQcGi\nEJTZMRFg+yGLyz26XUm1ZJNkmmrRoxG3SU2GaztoLXn6yxc2vXYcJ79ElXlNdxte49L+qkjhBkpi\ncuRnQ5gjcRDGYzq4naocoyQG6YkQowS6OMi77zvErbUrRF2Pdr1DImx8K8pzJGWRRpbRrWu2DWes\nhbkasx3btJZibop6uCLi+oUrTI2UaKTguQo3VswuCjo9xZYZD0vGtNe7tLsxlYGISiGjWs54j7yH\nR/xBznb/FqE2h57qOEXZ3HCijsMYKRTO4E50888guDsnzKku6AauKDLV/DMIFXrNJlvOCM/+PXpo\nJ8y/QjayC1kZxysOUAsc3MEhyjM7kSO7kNImy1KUX8GpDFDsXmXLgd0IuRXjuWCPo+OIjjWFbc0R\nlCbR4x/FPf6nJE/9AcYGk6UU0xVKJUly/gW4chwRjGB0D6s6ggMMTO4imDtBsLCM5VXRlo1KCujk\n28eSfGfSshCARhiJkBmg+nEAffWW0RiVgTZkQpLEWX7R+7rvZJkzyyUiD027nlD90z/h7LHTiJvv\nYam1SMEusNxqoSKPc8kwQ/E0tdYsQ8WnMUGASRJUF1KtObTT4RsXG3zxb79AmkUsr6yAgqHBGtvG\nHVwElhAUA8nEQJeks8Arp5cwwmLL6Hh+YxJgdB/a7HvevDaT6Z86MqVxMoPKFJZloQU42hALg6sV\ntjJExhA7NkKluNIiFh
oPwZotEMcuEtx1C36U4Rc8mt0eKjNkWUIhk6RRwp4P/RClwKf72FeIZpcJ\ne12+cXmBqe17eaqTMnzrfVy8cBm1fJmCL5FJQiwkeA6rn/oK1kuXsd/2FsKxSYpK0bPyqIhQ6xtB\nrsoYEgSZMGRoXG36ztl5Blpk8vjQzQ6t8skWhtwyXcAGFGj6Vslig3Dcd+oWxuQ+T0nIz/67/5tf\n+7kWpeWzZNhMjCU8/tDHGSpJ/JKPJwqMDVqMTli8boeg4xRpxxmt7vP89kef48CB+7n7Le9iYsdB\nHM/Ks1voFzH5y95gH4uNtX7jKNhf/2ygQK+9Mb36TPc3zs2q9rMk45ZDt/LEE0/zkZ/9GU6fPE7R\ndShXKszONykGNtu2b6PRajM6Oo5j27SSNq7jMjo6yurqKp/4xCd55zvfyaWLFzl0yy19T5U1Go0m\n+/bvo15fpFwpkXUalAtDqJJN0G0T9gzPPn+Bg3t34tkB651L1EoFRoeHCdwhOi2NMS08z2Vxfpa1\nlSVavQzXC3BtiygCzw1orLdoNlq4jodTqxInHeIMXMfBcV1sx/kWJo//+OisCnptgbQhDA06gfq8\nQFoGxxOUa+B2IHMFB+4aZ60dMr/QwPLgzQdvZdJvInDBk7hDNb701MOceK7BvrsnmBwPWLwaEiX5\ntT+9TZI1LRYXMyLRZGx8hp5b4i/PV/n+6DIN2+F/nE45sGMcOy5RERDrFraTobIerhBkXUNmeUSJ\nxvZTLAdUokGAp2B1yVAKNHYm6K2AE4AbwHDJoNPNzU+jlbC4tMyPfeBD/Ntf+ijJ+hrrJx5Da4W0\n7L6wJMtVgypXkTlOhuVoRLxKoTBD3DiNEQFaZ0i7gjEthF1BygIaQa+9Rnz0S9y1ZTd/9VeP8tM/\n85O0WsuowKcb5n5YaZqSZBqtFd//pm1cma/zwqU6aWZwLIcs1Xi+Talg5cWO67DcaWBjs3e0zNHz\nq/iujSUkcZZgeQ6tKMRzHHzL4eVjs1zcvMAPzzH9+JwN1Cb/ujGv8uxumL+zUQDljxSmz/lLsPQo\nIQuUdRWtV4n8rTRq8JbtAzjNc/gjhri1QsXrMbeksXxBQdio0iAF2WFs1CVNG1RLNm4BGh3N6Q5I\nNyDNNK4LgZtSbxl2TtuEPU2SZVxbLSKCGiOlGRJfYHiGxfklpm4+hHAshlWbN3gHWNUniZKzm5ob\nk6YYLVEahMlwxncR7Hs92aWvIA98Hk0PGV3FNE/C+iyMv4/MWWT9xS/iXj9K9/hRro48yLi1SjFs\noZbOk219HTg2vRe+SHjxBcy2w9j3fwhnYBSkRdzrUGuepLjNQieLYFqYuIBwJunEMyRekUL9G+jk\nq9BZQKRTCHcMEy/hCdjzpu+maWzkvgdhfB+6tUK0uoxxB2DHOOO77iQoZljN55C1SYzl4/iC4Pbv\n+bbz8J0LHmMwIgOdx68LQGjQUuOXBpC2lds+FwqUysOs1xd564+XuHz6NNfPXMBzQFgWWhhMK8H9\nyM+xf3yMRqOIIGW85RAWfewsIUsbeF5KgyorYoZEncUjw0gDKUjL5oGbI5675JBKG+3aVIoDZDoF\n47FtOMNVGUoYjMlor10FaaG0wfbykDEpZE6iJnd6DqOE8QGLMJGEyeYq5ixRyEwjpUDofpQBOR8m\ncyxQCtuAZRRS57dJS1gkZBSSjDDLKDdbqNEBVqIQlRmMylGgOMtIJZh2mzOf+xxycZmG5+EJyVfP\nrDGqr1MlI5m9ws6b9/Lpl1t0Fq/TyTKqWIRxSiYtrpw4h33kJCO/8n/RGxrGpIpMbtzCNYnJ84CM\nzvO/BNAzCmNstNH4ZoOevnktkk5zObq28jgIITdg4w3bcXK0pN8qvcGrMQKlBVkv5Sc++h/49L/9\nMTwr5eJywq23Hebi2WPoTLGy3qZQ9Lh8vs2JbsrerSnDAylpy6dUtIlWniBefJEXL+9i
ZOaD7Dpw\nAKVTbkRFvLZFhcgX9gaCeeNv0a9/XhuTkf9eygAm9+/Qm9SlW5ZDqjRTM1v5k7/4FP/z03/Jb33s\nYwhhY6SgGFRYWFgkyTJmZxeYnJ6hVCrTarfoLvbYuXMH3/M9D3Lt6nXuuPP1NJtNLOkwNDSEZUsW\nF+fxA49aYSTP47HLONUJbpk+zOj4JF999CGunHqUHVNDOH4FIQVFx2Xfnr2UKxUWFhbIsgTHdUlx\niLMUO3DodNps27qdRqvN6VPnuHz5GsVimTvu3kuSdrAdH2kFFAo23W63H2y5udHrAZmhF0KvK1CJ\noTIEcUuQNgwiBa0gszVLSy1WOiE9BWOF2/nQD3+C9NIjJPWLICyuX5tj3wCs7TAEwsZxPMIIpAth\nU4BSpEg8T7LW1ExNeNRGtlMb9HlB7GNsyzbCr/wBWTugHcWsNyNW1tvU10NsoagU8mTrdickTsHx\noejljrNK5TfQoS2QxhmTg9DpgePmr48Um+bGnbywzH//T7/CBz74QbrzV0iTFCwHow1SCiwrw/fK\n6OYCOouxHJdw/iJZ6xxptMb1eQvLalC86UdQV57A8osUqhmy4COkhV3ZifRLGG+I3spVvF6LP/nj\nT/JjP/bDNNsdkixBG0GYJuyaKrNlyOMzX7/KgO8hpEZluWrRcWw8z8L1JFHYIwptfOmwf6bAi2fn\niRObMOzijldIlaEXxQQFhyxRZKnhsUfPMr3Jdh/QD6Kmz8nZINrRL4Jy2kKeGvDNhxgL+ganuZDF\nkR08tZtF2qBn6NR73DY1wsVLl7il0qG1pvA8i/nVNVpxkfEtk5hSmZHpCfbsGWE8OkfcgG59mV5B\noAdTzl+BzPFYWl6lZRy62jA6WCBOeyy1BRM7xpjcNoQztJ/y5G6GHJvLJ8/Raq5gUo3utWg0Esq+\n4qbyAY6bzRU8QoCWLsJkWLbP4Ht+CREkXDv1PNlvvpmB3btYfPk47pZpossdvC1dkh3j1FcdVjtb\naOz7MGNyleraVVKnDO0lVK+FKJdJpY/nDiDmzqG7bUxtGAs7b58qjWm/AIP7MMEbQJYxS5+kHDSo\nLH8aetcguIf0YpfemWMksU2t0C9Qr75EeYeDs/Yy4vJjqI6FSC0yAjSSNG2jrp/AbzUg6iCsDPEv\n/wNXJz7E9m8zD/8IadnCcjwsv4Bl20jHRdoBQuYOo0oZLNtgjI3l2BgrYPuBW9m+/yDd+jyPP/S3\n9NbrWI6kU4fvfeANLI/4TB8+yKlLzzHij7FmegSRwTQ0LR1ytaE5UFyiEdyMbh/FCyxcC3ptRRq7\n2E4Byyhild9MszQj8C2yLMv9BSBf6NJCaY3KNEVPoXUMlp8ncQsNRiBtm2uLMYWCYXJ4c/C7m8Yo\nI0kUOHa/1BECtEJbAqToE7mhb7mIURrHZKRCEhy6mfWXj7L4+AqOJVG+Q+C4pFmGrTWJ0YSz8xSi\nkFBAN4xYz1Le9LpJHn34MSaqRSYmJri+uEwS9QjbhmLRIjIGmaUIyyXKMizLRx1/BfmmBzAiT8i2\n+rEWcX9T8IVFYhSpyYP0BBotDD0DDiZXL2xyZFnO2REyRwQ3bHAQ9FugG06lpu/WudHeygvTVBnC\nRPLge99P68ojeMevc3G5iR8Uub7eZMv4KLbtstBcZX5hlfl6xMjwIAe3ZsyM23QGNX//yDIFZ457\nSzVsdqM2+Dv0i5h+i+qbacyvjcl9zaXQf/u5mRoYI1Gmb/u+Sf6X6L9mqlOCUpn3/8RPMnPodn7+\nwx9msuYSxxHVWhUSSbmco2tKKTzPY3FxkcWlRbZt3cnevfuJo5TZa3O0Oy3uvOsw7U6T5eVF9uw9\nkKuH0gwZjLLrwAPcd8/d1EbG6CRdrp39OstLTYwlKRZdjNbML15kb20fh28/yJHnX2bP3gM8f/Q4\nStiQZfi+R7FYZL3Z4cyZs3nUgIyAjCTtYrseWhmM
0Hiet+nIDcg5O61WTiwWicEXEC4ZhDQoJVmd\nM9hFw46REodqw5yNZ+kIzYNv/F7W1pc5fXWZ1bPn6HRSRjzDwNgAu70hCpUhWj1JsQrdDpTKmqFS\nlZW0R1zMSJsgbIvKwADj4zV2bp0GLSl4cPTCedprCVkMrg0lActdgWNpCoFF4OREf50ZjJNv1sWC\nQCY5iFgI8rRxT8BA2abTzWi3wAo2Nze/9os/wwc+9KP05q5iTJ7ZpEzennIs6Jx8iujq4+jqvTBy\nE5W9h2icPMmVI+fIsNn97l+mc/U4dTFO1a1RX1tmfTHE0isE1SLSPUZxfA/FqduwStNkUQv7+kv8\n5m9/nMN33M2u3ZNMSof7RrcxIQpUWhX27xrjCwtnud41eDZ4jkPgS1xH0g0THCFx0HRFxHqniU4E\nb78FVr0DfP35q1R8lzjOd4RyyeP5Z69QB0r/DHQQ+keVDSSf/sM+eruRS3Uj8qD/E/pGi9uA8UnN\nJJncT5LuIWx6jJQv4Do25aBIu7WEkDXmG10SEVAZHKGy43b2TdXYNVVloKBw2w0un7iKiDv4wmXn\npM/1+RaZyhgdrWENDjKfKK6sddk6CIdukgxuGSAqbqU2Mkh1KKXXXOf2H/pxXvjM/0BpRdJaoxiM\ncPS5Jzn0tndzwRnY1LzYgY9wy7ieRGiDHJ4iOv5JvMvPc+n9f4Aj26ze/iYqQxPsf/9+Lj31MEl3\nhfmFOWrbZph79PMM3f92ms0BSvE6Wa+JCdvINIFSBXd8DB1FmEIZ2Td9dHyfpB3SPdXCd34V6/6P\nw8zbEVMfhe482ZUX6Xz1YUT0MMvug8Qjh2FCM7DwKZCQnX4GefUpKFWQoobRMVavBypEyITwyBdw\nTIhTBKMzxPhNOPWn2Dr/G/C26FvPw3eapMrELhAWpm8qqPsOmkrlC9QiIe1JMhWy2G2jVYbvFkl0\nLmfutOrYlovuJoz91L/i+dt20Dp/EddewOt4rDcWsJXCK3qoUkotajG0c4zSskXLG4G1KlVWcQJB\nOQiZiGKUDHGdIrKb4EhJikT6RYSIEX0FmURio0lU3kuP05D6WkZQUvi2hbSdPKXWkjiVApnSXF7c\n3OYcpBlaWAjbRihNKvLWEFLgacAYMtfCxZBohTRWTqTNEkyY4Vkpwe0HybTBihVpEqKiFFuCyBTR\n8ROES2u0kph9B/ezMH+dJEz58pe/gS0Miczw0x4qVJjlLhV/lEZ0nYGKRy9NsUyGznJEY/Ghv6d8\n9BSlD/wgemgYpXLzKokgNgIpMpQ2RCL33QnIQ1dtk/d8N6+1AbTGaOubTlk33JA3kBXz6mbERkcJ\ngZAaW8BAwdBwXKanRllZj5ldPotrZ0yOj6BUwlq9zdrqKlt33kaatEiiFpdWSiw0W2zbNcjgTg+v\nVKMVfoNTL/0eO2/9yDd77tyQqfeLnX6OlvhWRU//54zue++aPFhT69ynZXMj32hty8IYUFnGXa87\nzItHX+K+Ow8i8ZhfWKA2OESpVMSReV5NGEbs3buXVrNFt9ulWqkxNzfH/pv2Y1mSufmr1Go13vKW\nd3Lq9DkqpTLVwhaGxiYhXOUbX/8b3vq2H+DQ/l18sVil2+rQazYZG55mdWmWOGkwM1GlW60wvXcn\nS9euMLu4xvTW3bQ7bcbGqrTDJuv1OnEnwyn4GFsQ9lKCUgkLBaIHAnyvhONsUvpIzv0qlQVpYtAF\nQRaD0ZKxkTKzi02QhqoD991xF3dNDXKosY+19TrdI0/yd0efYFV5iK4mnb3Cjnv28vjpOm6xQtK8\nxPTEPgrSZz2JKBgfFXvYOsOyNDJWpL1cYJAmESpLKZWqWI5NEiYMSZBlQRoU8KIeaxGEiaHTy7Ad\nKAQiL/KxyLTGlYIggGjJgAuRkxOf49Ah7mV4HmzWl/Ejv/iLRMsLCGn1UWWJlHmsz/LTf0Msh7k+\nO4JYuETvSkzj
r/8/guoIyEkmRwdpnfkcl06eRdhfQqWGQiDAKjE4cA+tlWsErKLCOdJWncrWW/CG\ndhOtXWFsy8381r/7Bc7+1eeZqY6zbes4A4Uqg9v2Mmy3eN+e+/ijv/kC//nhc+wZH8LolCwFWwjS\nzJCZlFt3D3H8yjpGSNZnl9l9/wFOHvcxKgUchIKlpQb7DuzGosajRzaHYABkNwxOBVq/yt/JL13R\nV1puKJXoY819Dmr/2k/ZjWveh+pN015dob2+yl27BinGK4Rao2xJz5R54VIDhrZz7+vuYseWCpOD\nNq6ruXTxPC++coEdtmS0OorMGjhezJ4xGBktM7F9F3fcsofmygLe6ir7bt/OxctzfPGS4mr9KFnl\nOt99+CA/8MZJrMymsvUw0CDTPl6yRmCF2OvL3Dby7ds232pI18UZnMAtOajZa4DAHt6NdAVnH/9z\nWhfPc/ewiwjKnPt8D1tnuNsOMrZjmqKruefgTlZVytmVlNd5EabXQoctpEoRnQai18OoFOnkHDGT\npug0Za2yjWyiyGB4mNKf/AFseZjVWEC7QeUNP4X80I+ynIa89MRDTLsxb0g/BYffgTn9MIQuQmXY\npRFEUEZdv4ZJwtwwMnCxlYWRZaCb7+txFxPuwJp+w7edh+9Y8Cil0LIfCKAUSsUkKgEtsbDJTIaQ\ngsBzcvtuA5mO0QkkUYpSYNsJyQqYH/hhrp66zMDYMIXVOq3nn6NbX2f3zt0sqAoeLv7gGM20x7ND\n7+N10UOI0s2MFFdQUuN7p9m3bZnvvm+KuXrCsRcVjgvDgcuQ3cIjyZewEBij0AJUljdyHdsmCAQq\ng44BoTI82+DaHgiDYwmsTdJUYqVQtsZKNTG5u7ISEq0N6AzLd5ECAq2xMEQ6JRFw2LL44JjPb11u\n04xjHKVQRuBKlyyQyGaT+a8+QdhoglHce//refEbz9OKEzyjGSrtppu1cRPN0lpCktmURY2esbHj\nFB1H9HpdLNclMhrbQCgswguXWfrV/8S29z6IfMN9dDODtASOMYQaemi0FkgDWuRp9hKdy+s3f1DP\nc1iMBN3nA/drH0vQl6ObG6ewG6rwjX+FQGlJ1MvwRYLwxrn11iJTNc2Xn7vOK6cug2XRajYolkvY\nepEgsNGOjVEWvqXorHTZPV4kmA6JQolSx/ICx+jXbILmNZSdPrW6//VXc7Ty75o+61EDRufeU8YI\nMi3I1P+KCH3n0T9p9osrWwBGY9kO7/3wz/LIX/whJjE01tuMjxXwXA/X97l69Spg8DyPKIyYnCgy\nNDTE4uIiK6srKBX3nawlYyPDNNbW2LZjF+uNOllzlZVWnYUrywhtsXV6lFeOngdj0+2kzGwdZbQc\nEM/NI2sTfPXRr/H6227iwfsOc/L0NbaUBjl88HZOnj9OY20ZEfUQvoNSBt/3cGzwPAGOQqsckZFy\n8wtHCLA86CUQNnKbC6E0aWpIk3yeE2144shXycx3cebIGW6fHqKTpgwffD0lt0w5miMYzzgy2+bC\nYkQWNxgdG2DHjA/aI21E4Fc5fzahnXZRRuIVAQmWJTD91rdlS8YmanSv1RkvQAdBW0osKXCFuUH5\n0opcdZKB5UOWQS9SZF0B6wJRgKVIMzPuEXmGLDWIusAvb3J+uiFGi5wDp/OLR4cr0FPUzSS92Uu4\ncp1uYoi76+y77RCeyGhEJa5fOcVINWR0xwHOXr5G2OhyaM/dRAunabVnCXsJJi4zXtmBPH8SkXwG\nmd6HM3ov6ewr3HfPW2h9/LfZdmgCqS7QbldJzi1SPHgvekXw4Xc/wK6hcX790aNU3IAkTlBakCSa\n8SGf5UYXg8C1JWeuF9i2tEqaxJQKHpnKrU1KxRrf/Z57ePBtCSfev/mCR+s8KHZjbACMegPq2biW\n+y3oPDSk304XoMQYMvt+mvM11q4/h05DbGMQ7EZHHaSJmdy+m+qgZL2xQnu9wd
riHHvGbXqRZGR0\nKxcsnwvL63xtrcehAcE+qdkWZ+x+4yFGx7dglQc5eP+bWTr2GAud62TC5rm5aZ5YiFmw93H39BQf\n/aNH0Mk93HdoC2FjGW/3BCunTiEHq4zWDKXAZoqJTc2NkA6qt0q8vIrIgO4K9thBKm96HzujEpZM\nKF09TppYNNbaJGmX0k4XvbBIsm0rztgUMtbMlA1JMyHrNpHddYzKsNOQrNNGWR46ClHC4HgBYWOB\noHmFhZcf50RngJ1v+Vlm9uwlWl7FlRFCtFk68kWS8y+zerqLtWOa6O7vovjGj2Aq47T/+k+oDXpI\nx8JkKabdRsQJOAaRGhwRkWmvT0+Atr+N+aVB7BPXuOnAt56H71jw9Lr5EUTK3DlTaYHjBEg03V4D\nYVwsxyWJIpRROLKA7fkkWQdL2rn3wkKX0q//Kpx6EefAbSRikLWVF3AnK4gdW1mtjmOldRKnSN21\n8KMYENSTae4c6SCT5xFxj9TxcYI13rUj4eSUh1+WrNdjJp0E68oaGRZCaqTRaCmwpI0mRTo+pYKL\n9G2KMsiBBZG7iLaj8MZze5NGPMKyMJaLSGNsITHSJlEppSRD1cooAcVMEaHxDDjCopXEvCNd4/n5\niLeOjqGilE8vtigbTSo1thac++znsaTAVgmu4/Lc40/jSShlGUJKWuFVbHsIpRwCK8hzjoRAZ4to\nI2i2uuzesZX1ZpdkrU4mExzHxjGCrusz+5m/o3zyJNUf+QA4AWE/CM8yeTEitSYVeR/cRpCiwGye\ntGzSGGwbaUmsvvmFeA2a8mr8XD7/0myQh02fSA6Vik1aGsBaP8n64nkwEfftEvzQ972PqwsNPvfw\nSzx39DSq2aboOYzUylQCzWK9y+XZFkurHm9aHWTnXZKhcA4hHIyJ87wtuMHH2XB+FjdQH/qPX/sL\n5eoPra28paUFmcq9hTK1yZbWN8VX5NJ9aeWb9fe89wPs3jLG7/7HX8O1BWtrdQrlMkppdmzfwdFj\nR7npppsYGh6g22vxwpFvMDo6wvTUFo6feIW9e/dSKhVRmUPSabC2NMfI2BitqEdzaZVaMEypVKbs\nOVT8ErHrsbraoTk4wNhIAenB0tIczU6LY6dOMz46ysCQC67F+UtXWJpf5O0P3INvNA9/6cuUaiOE\nYRcjJX6xgGVLOmEX2wpwXWfT68aWhjgUdNfB1nkx2MUwe70BrkD6IFwYnhylUA3YP1piS8VhascA\nYXKVARFQ8VOOz2uOnFul2+pgqRi3ZpGaXJ0yUoV2c52BygjadHH8hFiRy3SVzmNmTB4nEDhVipUW\nFTRxYuh0UwpuHoks+kskVeBoQeAaMpXlPj9CYFsgHEOswZOQtRO6DRdTNFQKgmSTXZs07iGtnDcm\nbINqXKI1u8jKya+TJSGrS/NMHPxeLj79EKWiJFyrs5YVsMuS0ZEywt/B6uICywsNDuyusbpcp7HQ\n4ObDtxKEXZrtmJ5dozp8mGuXTjAjrzNUWSBrDzG+9zDBFlhuPoe0PJzAp1IscPHs03RaW9m27xD3\n37qDM5eWeOjCfM7b0yBtuGl7lZcvtvEcC2nZrKuA1fML+H7E4ECVRichTntUCxW2Dgck1YBnH//Y\npteOUrzazuoXORtHkRvP+09Mvw1txIZ4AlT2LnoXLdqN4xQDSbMbgjAMyC6tXpOdkx7Dwz6mHBAm\nMdovMzQ4yJlr67znB9/OL//iz/CpE8P8+v/xLr7xl4+wtn6Nka0F9kxWmb7z9QzecivNtsYtlqnO\nniHesoVj9XGccc354+cIpjuMDipuKrVojW9naWEe2ZunYFfp9K4zOFzCKwV49Ng5sjlTTy0t6IWk\na03sLKL5d7+Pv+sAJo4xjQ614TLt0xGmu0B5YIjWehdneY7Re99Fae/tIBKqZ55AzFlkiUKHPUy3\nicySPFEhU/SEJAiJAAAgAElEQVRMgj71WW
zWsLIYu9PCUw5JOkBj7hrJX/4q1nTIlkIFtMRpXmGH\nv4XErNPZsQt7tMILz1zggfJ/Ry9epdFwCdsxU2IRHaek6zFEClMtYOIMkTVRsYMZyA/U1fknWev0\n8D/wu992Hr5zlpZtIYRifWmRoFrFtl1MpomUITYuUguSTpuy52NZFhCjMg+BIEsjuhe7+B/8IeSe\n7ezcGXLBrjDQOoesJUR6iO5KwuLLJ6nesYNi6zqqPogMBI51jOHiKtnKw6RWjVKhh+7l3Jg9OxfZ\nGS1xKNhJkkQ8e9nlutZYroVlF3BtF2FbONKmqyMCE+EFFgjTb3fli9+2BKXAw6AR2iZJN9mXUAmB\nEES2QAuJkgYvTBEqw+2FYDvEvS5CWmS1MqmOcdebLBQ1w2WfYthk/uIcpWKNIAvpuFWssE3JsomT\nKE9h7nYwSiGQZErjC4vMRGTxdaRXZC2RqEzTUV2kkGgybj10kOXVFfbt3kYYxdTX18kMJFbuhI3n\nIy4v0Pr4J/B/7AP4ToGkD/dGUuGR53slymD1k2+9TQar5kMjdD7jG+0iR/AtUulzHGVDFa5Nnpdl\nWxoiaNcvwOLL+MUhyhWHlc4q2cIpJoMBfvIH7+T+O7fyub97jvn1DgvtHt1Usme0gLAh1hkdE9Or\nD5N6irjbxSu4mD4nKX9fGyq9jS+Y12RviRvvUZN7CGlD3008z19LVJ+vtAmE57XqpRsu0QBGMzZQ\nZtuD7+FrX32CY19/lpnhKZIsY2CwQL1e55ZbbiHLMq5fv0YcR9x77+s58sIRtmyZ4M4776TVajE+\nPs75c7NYwqbdi/E7ITM7d1CpjmG7ASvLi6yu1vHcImmcIV2XlWaL6y/Nc2hrFSVmaRjB2RPnuXjp\nEaamtvLWt72D7tIcabeNJxO+5+33cdst+/jT//l52q2QSnX0BkKI0IRRF9/bfEurFwviFgQFMMpg\nI5Dd3My0WoFmy7B31wQ/8ObvZ/3cy8xMDTMyVcWUXMJeQoTHsy9f5tG/eZFGDAcOVHNFaSdDZIpK\n2Sa1YLxQZMuuQa5eUowO2iym64SdPoPLSJRRGKWRlo2wNZ6QuErjeg5CaoTMeTNOfh8hTaFakGgh\n6WYZRSc3JNQasHPzRCUNFcch1TZZpLA2OT25QjaPelk/8RWuP/bnsO2d9NQg7ZWzVAYqnPna5xgZ\nn0TYMNdIEMUAa3WZbldTli+RWePccuhmWqnEUz3Gbr6XuYvH8bhMbdt3kamUE09/iVvf/cN06v8/\nZ28eJdl113l+7n1r7JGR+1pZlbWoVJKqtC+2vBtsyY1t3HQDNk1j3MeHbgaY0wxwmOmZ6WYaOG7A\nNGObZjA0jE23zWbkRbZkydZi7btUmyprzco9M/blrffe+eNlSR7G1jh9z4mT8UecjIhf3Pfub/ku\nTzIsLFSaIoWhEYDTnMIvDlMrj9Ffa3Mpusxm5xVWe3W0GOKd8zm+eNxk/lM2HDkwxMnLLYbLRaJE\nYTsuXtUmb0ZR/Q0a4gzVyggDEfH1b1/k5jevMDZVpq4Me3a5d7QB9Ov2r5A90ca8ZrJqdm42gu8a\ndWEI4vcSnRlG9U/j+z71zU1ELMmPTbDdbDMxXKboK8qjk7SDbQpDZVha4/KJp/ix99zMf/rPn+KR\n1j6YsJgKV3nwi7/D7e/4MFv9AaXJa8jPHEDuvZlSfZvW+ja5yf0U6lt4lzu8cnqdw3M1Olsn+cs/\nf4af+Gfvpt3osnj2O6Qba5SHrke7Qygnj3FDamUYBLuzbdH9AFkoQW0M1arT+85X6T3yRWy/yOxP\n/w72vv2Ub3g3aX0d27IYFym600QtPcXm9kUsT+Md/zKxKuEkKWrQh04dmSTIzhI6uEB5apTC5n1Y\n9hDCzWP5Lp3mgNagzB13XEe1fo7i4DGsdDPDzMUWzPhIIxmenmB1c52JtEH6yN+iBJRsC8cvEW71\nQBi05y
E8H2GDqNYQgxC62xj1+uHRjiy+8asf4Zf//sL3jMMbJjxxKon6TRrnzjF9w63EcUwoBI7U\n5Gwfg8LPVbGEDSSoJMEQY4yiW9/G/x//FaW7fpTKTIH12GGgYyxlI3PDJJstTNomf2QPnqVISzW6\nqk2+H1CSKQV7mWa/wuR4Sj0eZdTdotsbQvZOU5mC/eMvEsi38PLLF7GQWGTtaOnYSMfHIEhNmKlD\ni8wAFTRGitc3vbkyXNBYuyxGTaIQJsK1LAa2Raw0ji2xE0Pa6eMAkSWwRETahg+Pl7mYWJwZJPxI\nGlKyBJfDPvuHRjmtLXKuYe3L3yKOQtI0wRLgSkkqMtq/bSRCKWxhKLgOYTzAtWx6aZphqmyBVoJT\nJ08TdgecPXUBmfOQxuAFEdrXRColNik/ddfb2CgU+Zu//AIzH/sorlJoBL4WJCJFqaz3osla54Mf\nAsUjRTZevILbscTODUZmkvRXKOlc6awACLCE2El6BKEeoPqrpLJC1bVINlYolT1MEKB7PUItmHZc\nfudXf5R6I+azX3qO4xfXOWUMw2XJ4SGXtB3SSvsot4SrXgJu3RleZeBpcWUMypXPw+sl4ZVOjCYT\nTDTitWQnVWSPVBAnP8ToZuf9r7yrRmChqdgp/Vjzz/75T/H0g4/QbndpDDpEYUSz2WRrc5Ojx47h\nODZxFLG5sc6e+Tls2yYIAkZGRqjX6/j5Es1Gi36rx/nLazSimLxfwc0lOH4ObXsY4aB1H6VTHK/E\n1MRV5Kw+/cEAmVqsLHVwrYMcO3IzUSsg7gxYmJukubXGcG2IajnHDceu5cVXz5MkguHcMMaOaCXN\nHdDq9wYOvtFKNWBBbxWcSpYkjI4KggBUCvnUkFMhX/i/P8PcyD7qjZDiUokf/cB7CUyfsdkF5gYW\nb5tb5MlmxLXXHaBeH5BXGmFJaqrG6kqPQs3FWA6FUomiJalqi36SIowmCgMGvYBBboAvi7g6ByKg\nmPeYNIIkMJRsSLXJZDHygkZkSLTAaEWSAJYh7IOxwaoICq6mYBxcz6M9KBCFA0YquyuyrgjBxq3L\nLD54P72whlx8Cuw866ur+Dmb0b0LJCZPb2sJrzLJ1EjM9pagvlanMjqCwFC/+DzV8YOE0YDS4CJR\nZZq0lSDW7iOSh9l/90dpbZwnP30jRsVYuSobl1Z4/Ntw6DaXWj9modRiz8ICpeG3ceHc01yIezzy\nyNNcf2Sa66o5Xmh0mR3K0+8nRDEgQmzp4loJe2uCiX3v52Na88DLD/LAc/9A0vF58dWQH//IZ/n8\nf/lZrj68ew0nKUDvGN6/htMz2fX7mj7xFUwhWfLjCQjkncTnryburiJtcIyCVFCqTXDV3jmOHfAZ\nTi7R2NrAr+a5vNxkZrrC4OFzLFwzwV9+/q95ODnCdfZJlpehsT2gYqfU5qeZu2GK8ZveTv7QHaB9\nhDdCdX4YMzFJv9GmVn+ekYlZ1je22epa3P3TH2B0uMjcpS+TS236fWitrSDcHNLLM+R4BJ11gsDb\nVWx0vw+OjfBymHwB02ogUx+VJOiXHkaGF0iPP4qDZLB4HPwKic6TFErEty6QD1ZwREIUbmPiBvmZ\nG7D37KNcjnA+/G/ReOjkMk6aolaeQ595AdKUIdFjf2UM0biATJvEA4NfAlnwMW4eFXRQXpVCocL0\nXT/G/GO/i94GIwXlvEKWbZzhEazqUGaNISW600E1Grgqxq/ufMHUsN2F50OHH7n16u8bhzdMeBIl\neeSFUxStHPOunQGVjSY1V4hQLipJSFSMZVmkKUgN30lTbv3dzzBFxGwqONvu4yd19sbnUUGRSHWg\nVGV4zwHieAOrnRBffAU3X8WYRc5onyX3doy+meL5hJu8l7ipWGFibgknVyZc6mBJTUk+SWO7ghHO\naywpYQxohRQWYRTjOmTdDyOwjMyEuXhtBM4VCvJugSomCglTi8SyCR1N
1bIIhUAXfJwoAZ0ipUFa\nFuOu5pvdLuX1DT5+ZA5bxTTOXWY671MbtHj39Uf597/9SQZrW3ixxlUJ2s7Em2b3zLG5tU0xL9GD\nkDAMCJUi79mQxJAqfCFIg5A0CrFFiTSN8T2H3qBH3rIwtg39FNu1IYz5/Fe+yc/91I9zsFzj0p/+\nBUP/8sPYBmJtSI2hL0EqQAosDf0fgpZuS4Ets79yJwHPJls71RVX8DM6S0D+UYNESI1tF9lq1Vh+\n7u8w3QG10SpLK21KPriFEkYKNgfQ6sX0ttt85C0zpG+bp98NeOjUJpd6im6qOftgwKnTOX5l4Wns\n8u0YozCv1YH/WKpdvMYYM5AZKposwdEaVCqIlCFNIYqzRxjtjmCsDHgpxI6k2e8ylishJKS2RT3p\nMJQrc+DYtUwePUB95Qxjfolo0MGkAfNzk6RRn+bWgOHhEXSiKBXz5BwP7Rva7TZJktJoNEjCkH3z\nC9SqIwwPjeAUC/T6fS5dusTtb343SRLw8P1/yztuOcyxQzOsrV+ktZHSCWK8ks8db7oDszOWUZZD\nN0k5u3iSa/ZPMT52kIvrLXqDkMZGg6Cf0Kz3uPW2mynO7CW1vGw0xL27is2QDZNXCVrjhvMvQaiz\nxEIIQbKhmVxwONtp8pMf+ggzE2Mcf+DL1NcTKoUJDD0GPcWhQ8fYd+PzbL60zeT4MLblE9X75F2H\nQm2CBbvC8HCRdqiwtMdDL17k3MUeR2/YSxDHFDwHx7NwvTzF/DAbXc2mVWTfiMNMrcj2Vp21Toe3\n3nGMkZEa93/jW1R8QW9g8F2LnKuo7DF4ozt7vwidULKZBORFnzAApGB9bVehwS35hBsrLH7pj2k0\ntgmay1Rm9tJeXiL25zB0yG+fRtVuYO+hcbYWT7He6lPcczMT5Xdg+xadtQ3KBw6w3gyppa8SL/wE\npeYFxq6q0TIfQkQBa8vLjDh9umf+Gu/NnyQtzfCuq+/GHyqzfHaZiWKBk+t5xr0t3nLgLKNj41xT\nqNI5dDXzwzF3zZ9jUBNsbPUJtcF2HIqOw5sOu/hiLzPJQc4++SjzpSpHkyIffOevY2yHh99Z5zf+\nrz/m93/v85xYUUSD39pVfK5YybCTzBidFShkUM5sFMjOdS0yinqYvBN9Zj/Eq0hRpFwepZzzUQG4\nVkxOal69948ZP3Ado9UytuhwcNZh/dFHudR3+cuvneHdR6v8TDnl1dUyb56x+MoLA77woV8iVyrx\n1o/+T6yU9uAEZfLNHhYpTi6HLI6x9+f+HWOvPEbh/vvI6yV8dxy1+g1yBclg9HraJx5js2bx0APH\n6WjFob1j2LkRTi1H9Lee4/vAVL73SmN0q4mwHYTrQLmGHsRYpTKD40/ibh3nzIsXCbuSkdECwgyQ\npOQOl+DCSarvuIny+z+G39vAtYs40WlSd4FEtRHlYZK0i601gRbIwije0Q+TnFrGefS/Mh3EDPa+\nE9WPiSfvgPpxnLiD7YRQFjjuJP7gO6z86V/REjA2LNEp5Etg2S1EuwmtC1kNasBS4JZATN2CGJ5F\nFnMM3ENUcnu4y6tRrF/8vmF4w4Tn9PIaKw1F2dcMIoHnu6QiRqbZ4RXFKSBxXAedRCitefr226kf\nu4HPNxvcMTpG1G9Q8x2kI5GDIVaaAqdXoOuApMeI0wMZ4lVsilZCYvZhRyuo0GNGGqo5l/vD2yh3\nTzHWOQt2jLQg2dZE1oC5SpGtrofteDiOh7BchMiApEprfEsijIUkfs09FyEwOx4qemd0IfXuDi1j\nFLES2AgskdJEIVPwXUnfKLQlscmcydvdPnnP5iQeL251uJ6QXqqJLcnV1yzwta/cw+alNSqWRbKj\nmGVUSqpS+kGfiWqZpfUNbJ0ySBNywiJIU3I71HstM/f2SqFAEsX4lkQajXFdwjRBJgm+5SBSQ2o0\n9Vab3/j0nzNWqVGcnSI6/yqD2QPk
pSDJGmIMMOR0dtOQ/DAdjOxOo3XWTpbsjKuuAAm5gp/5R//b\nZE7Qtuvw7P0P8l//4OfJF/dxbP8oX7rvFNfvqfL8GpypbxKZrBsURPCWA1XeU8kRWDma9Q63XbuX\nR46vQKPJUMWlVArZ3nqZ6f2ZyCImcx3WO+1tswNU3vkI2fDTmIx6vqO3o5Qg0QaVQpxmPkmJEpx5\n8I/hrb/wA8dG6xSwUUA/AVXUGJ2glEfULdMrwsMPv8SrJ5eYGR2lHySUS4bp6Wl6vR7FYoGcX8T3\nc1y6dIn5+Xna3TZhGlAuVZid3UOqIiozE7iWTRgF9Ad9RssVUIbZqWk8z+fM4iq2TNnaWqNw/T5u\nue4gi6eWOLnWxMrl2Ko3mZqexPd9wiBgdW0NN24g0z7loQq2V8VCZGaWShMmfTApxWKVqb1XUSwV\nd71vPEcQhRKkj13pE28KVAC5IrgjAruUEsYQDhrUWyF+wWK4onClZmPtLKFwOJektJqaXrnIIy8t\n02oPOKov4JzcQJk5jt1yAwuHFnjupee49vA4q+0By5u91zS1cBziKEKpFCkziu3Kcou4Z2FtdBnx\nFEkI9z/xIrP7LcrXQHMd5CATdpMSyAt6NkQ9g92HKBSkMfTtjNhgUoF4Y1GQ/89qfPv3Kd/2caLm\nGXJCERbHCKIcY0dvoP/st7FHb2Q9quJ1OqysnUBO3MKQ38aevI7LLzzH2HW3Ut1eJFk9TtXdwyDN\nY05+lcge5/KZbQo8QTWXY+zYh9Brhvz0j5M7cCf/9n03sQxMdTvInKQhHTpaEAiL/UmO7aUt5mZG\n6XZWOLmR4qPZ2ghRAoQyXLOvwp23TaN6ecpbB0hb2wxXR0iTgFKlSqMXUK253DkxwROf+gP++W/+\nDrHZ2vXewXw30eCKJMjr4GUhswIGI0BnhteqPYlKBqTdAVbOwYq6XGo2iWLBtQf28fYjQ6w0Jxke\nqlKuuCSNbeIkoTS9h8bgEjU81jYD9uRDDszOM7T6Mm05ihqbY+/Bw7w6GGKmJElVQhxnXeNOewOb\nlPL4JO7cYeZv2KB/VjJ+9V5U7zaCpZfotjykshgrwWOXDWPDORqtFMfuEzcaeN7uzithFGYwyGLk\nF8CxcWf24e85TFJfRHXWsPMF/F4/SxCThDiMEGfOMPWBA0yOPIDefoqcewCTDoElcPULWeEaXcCJ\nVwALY01CotDNLTrdGjlrGBO24ZVvUc41cT2NrM0jhsYhUZh+G93oknQjXEtmU4VEv2bWnHVWMh03\nYQwMQN55N9a7P4rsn2dLz9Gz9oE3xhWVNz0/9X3j8IaX3OWtHls9j7WVbT74To1OJEZKhO2Tpj0c\nNLY29PsRl2p5Noam0dccQ585x+G8AW04mwQEm2vULJ+x4XGumtqkeRG8EvjOOkanmEEfqx7heW16\nushY2qQbJNyyNyAtTNJtzfBQ6Vpe9D/CmCdZyD3KvugfELGL5crM38u2sYyNFJm+eKzSzFlWmgyz\nITLdlAzQlh1uWhiElmgyT7DdrDRJcW0QUuIkWYJlJRoRS1LHRUsw0uBJQ6Il7V5IHsN9mx2OTleY\nHDNYro9tGaqFHDNDRcJGj8SojBEXK7SEPTNVFo+vIKIArSW+lMg0JdUJxraJEo1jQSIdkl4Xr1JA\nRIJUG3SSYiuDb0t6aQpCY9kuZaXwpCAKWgzOhHgHZrGn9tElo1nLK8MeIwgMGd14l0sblXn47AAE\ntRFIYzLdmiuHgsj0i4R4nS8hRNYVW1nu8Vu/+HNMjs2xMYj5zonLaCT3nm3jy8zpO29nv12lBK82\nY8586zxBP2C64jN2oc51V+X5zpah3xFMJJJcNSLuDbCLNhqzo7ydJWT/uMVkTAaezkZZoJVAKdCp\nIFEQx5kuzNKpZzj/yB8CP3jCg8q+rzKCQayIpSQxHk8/9SwXF89RLtV44N6/p+JGFAvDtLohhXye\n
OM6MGuv1bVQqKRSKuK5Lo97g4FUHqbc3WF9fR2nFcG0M33HpdjocO7qAa9scPLTAM889S783AJ0Q\nxSFhqki0YKvVJgwVTqmGdkskQH5H/VkCYTBgfX2TPSMFYqXpDVoUrBy2kJRKJYy06PRbDHodgm6H\nSKVMTc3set8kscLzcniew2XZhzzggArBKRmGhj1mpmfYP1pi49JzvPmq/USDkGe+9UXuf/o0Y5Mz\n9HoRqdEMegPCKKUfKe6ch7jdZmmrybce/yof+8h72O4kNERIvd4mGoAwEqUStNZEsUUUpaSpZmJq\nlE6rS9RXOLbByltUp8CMCwYoooYkaho8JxupG6Pobhr0lsBYYOkd/EgiSFMwbmagu1vW/j2f+3M+\n4FSYfse/4dTX/htJq42xU84/fR9y7DYs1WRjpcNQrkfu4O10l89Q3ruP849/i1BJGg/+J7zaAdIg\nojq/n0LvHGlkqM4Oc3n1FLU7f5re6nlaX/kEV7/rwxz+6X/Nv/upt/LYCRgCtJBoLQnSBDcSrMUt\nvn0q4vDUGBsbAQf3zLF48hlEAZSRCEvz1qPzvP0de+g0+zTPlRgbdUhyeYTSkK8SJRHFkVHiKMB3\nHdz2gL/9X36Vu37rU7veO2rHU0IbdjA7vNbhybo7r6dDUmYgZ9MJSGKP1HLRSJrNLXxpuOm6G9g/\nVsYfXMDJl5C6i449wnrC6JtvRz75ImUuEWNIVcJGo8ncbIGLxUM0mwFvf9e72Hvt9cSWT7tZ5zsv\nLiLXN1lbXOLprYvYrsfP//xP8OZ/8n4qt7yH6sEjNE48R2niOrQZwuo9zca6Zv+UwEkN1aIiiDRh\nEkOkqI6O7So2IhxAojGpRiuN8HKk9UXCc2u410xibIta1WXQ7FEgpVfMkXgpZqRCdbKHrr4PE0eZ\nloJugQ5BtTA44C5ght+DbHwTufwU4vTzXD5zMxt9gzdzI7PJ/bgiQAculhNiz1+Ndcu/xLS2iL7x\nh+ilRYSSSMdCpwZSoJAxZjPz6Z0vEYH14Y/DwijbXcVy/GbKJkZWx8mlAU74CsOeANsHrv+ecfj/\nSXjajFZdTjag3okZqqRYWMioQ+Dk6A8V2GRAq9tmNRgwfctRLpx8mYlkwP5r3s5jG6tM2DAxPs2K\n7ZHomD2xRd9PiFJFIQyJVhdpnVrmVrvJ5FRCJ19l/7TPSriPlXyBtWXYjAXtuE04NkS/cB0v5w5j\nT/wEP3vik3ReeBop8xmt0NIIkYCRJCrdqeI1WhowFkJkFZzZAbIZMoCs1Jn41K5WqgFNJBWpzjLS\nkhEEOsVC4iqZMXu0IUcKUYQWDi3P4/mVNY4OVdBWyvmTZ+n0+7S2W+SkRRLECATWvjnmRcCv/Yfr\nue/zEZ/+bIO8B3nbRqSCPYf2sXRujR979w1sNXtYtsUjz76MbQSJ1sg0IY0ShBT0Y0UMGMfFN5ok\njFG2g200iW2Il5ZISfFSC09kpnkamz4a/4d0S9daobRGao3RMrOUAJQAx4jXyq5sxCUztgSZYmy9\nEfGvfvROhmoWo3mLV7cTyh5M1Fwa6yG1SpESikArtsMES1g4UuE4grgrGQjJ2lZILm+QOUNPpJTz\nmt66gzNtY383VgfDjvDCzn4QO8KCGTXemAyrk6jshhqpDKCaKkl9fZ3Oo7+Ftnfp7OxoOomhHRg+\n8+lP8fP/4hcII83f/M2f0mu+gCfHGGw3KFViIrdLaHrUmx7FkscgCigWygiRSfwfPXqUl19+JTPC\nHQTMz82yuHgOG5vcxCQjw0MsX77ArbfcxOkTL+IJxVpzkzBR+J7L2PAoqeVxcWmd/bPj4JZYrbfI\nlUdIkxhJSnOzQ6cfEDRazB6Zp15fYnW9Q3GQp1wdYs+BQ6xvLNMZJAzCgInxMbZXN2lu1He9b3Ie\nWCLH1PQUcWR4/qkWri2J+oa8C2kSMTkxQ9+a5HIzz3p9k34vYP
H8OuEg5OLps0yNF7NRbtolUTAu\nNc0eLLgB2hg2Vuv0samNjbH3yI2cePwpzp+to7Vi0GohR8ewhCGOB5RLBYwReK5LHMbk8jZewaab\nJnQagqCXuaF7CXgy86ZzXYtOX6EUKCmwjMG2BSSgyJR9owSsXerwfHPRpfAnn+TW932QYz/3azz3\nd1+ic/4b5GrXUe/2seIWc5UOWtkU4gvYe68l7BxneOYGWmuLWOPvZfPkE1SqVeKVR9iK9zE/U2Ht\n9Fkie4RT//33KNfgjl/8LCPjNr9+95tYaudZTcB2JK4tQVqkxkaFCu3BxY0G3ShkIpG06yuEKUTN\nNpfiEDNooG46QqvbZenVAdtLgnwEJRVg5crQ71EcH6e5tsbsaJWo34NSEc9x+bNf+uiu946+wkk3\nO7Yvhtc0tTJM6w4dfedlaeKDKaJIMF4VIaBTX2dkZp4bD01xeFKw/kqXf7jn6/zyL/8s3e3L2KPT\nBOcv0i/to8tL7CuDEpLtdofEatCPJDPXvI22P02vvIcL2wmNe77GB9a2WXAkwrN5/9gNfKO4wf/+\ne5/jnltvpDA8gZm9mpHaCJv3/gWpFuioz+KaJBCCZ1ck+2ZiCnbMq6/WOXrr9UTfzb//QVaUZMWm\nMugoRloOg/UtBq9cZMQOyd20B3d1lWB+hP4Q6EqOkuzjmg6d4w1Gy88j974JcgvoMET4eYxdRRbG\nMdJDNM6S9K6l/qVPs3IGBrfeweliwJ3eOYwGy5PZzd+AOn0v9h0fR+x/B5b/52S1eKaFtNkWTI6B\ncCUkKiMZaZDGIKZGYLLNpvk1vPoXcCZuRtWPs9D4MygvgGdjVBvSJoK7vmcY3jjhWe0yNjRMs5Nw\n9sIyt918mDhVrEzUWCx26G5eRJsiHVlkbH6B5voZRi2f/B1v5vETJ4mqQ3ihQyPqMuttErsFctPT\nyMZlvI0u7bjNTLmAWyvyUDrL6LjLSHGU5YmD9HSCCtpECxJPKPJens6gR3Ntk/LQML4Q/MOBf8HU\nI4vIjkQisYS1MxIBlWRHmJDyCmotG/Ps0CWviONqdsCpu8x3nDQhQWEbgycMgRF4EmyTkgqJZ0Js\np4BCIozEVilBoYCx4BnlU9ncotPr0IlC7nv8ZVKTaVHYdgZfHXvPu2n88edZfKCBiBW1fA7Hzmi1\nBdthYjIt6GYAACAASURBVKjC7FumsPwiRvWIXYuhYpH2dh3XEiQGfNsiQiO1IQfEaUpqFEJa5IUg\nUGCR0Du9yHCSYEjpWYKitnBNjAY8LX4ot3RHgjI6G2kps8PCykZbqZTYIhspGpEdCgKd0eIth//y\nid/FTddI3TLLgWL/TJ6KNFQwFDyLZr9PatmEShMFKeVRCEJN3s+z/5oSUhiW1i36no3EYTSJufc7\nXY7dkuDbSVaV8HrCpXc0nwEUWWWodOaRpXQGBk9TQ5xmyU6SSNrtAVtf/h+od9rsVsSpnzqsN/uc\nO3+O1Yuv8Jef/SMOLdzEobkZwrFJUi159sLXWau3qBVy5HMOG1t1ltd7XHvNETAuUdDHdV1efvll\nSqUSzWYTS1hsb2wxPjrCnj0zDPoDRiamsKXgwrlzjNdmmF7Yy8bqZcJUkYYxjkrwHBfHctna2GCl\nt0K+mMezbYqjI4T9Di4WjWaTO+84hm+nOK5DP5I8/8RzjIxPMDI2QTdsM5U/xMTkAivLFzBxRDDY\nHZMEwC3CIAjodOuQZGBBrQzYhorl0L1oOCUu8sD5h3nv++6mJBL+/mv3E+QcFuZHGDQ0h+dH2GgO\n8CwbJxGUoz6r1UlWpq8n1zgPwKMPP8P+A3vIlVfotjYpFCFVKX6hgDYaIRN8L09tZJRzp8+RhmAL\nUKlGKU3ShlaSYRbBYMlsr0jb4Fk2SZgQx1krPgFkYjBpZiuhBwKRGuQuiRL3X9BcW/EQX/0Sk8/d\nx20f+QRrzXdw+t6/oiAFne
YrFOZuoSAS4vl3s/3tz2JETHX4UXLePgaDAdWJafqpi2XlGJsaotvb\nIg41tckprvnIn7J3/xT3fe7TfPm/30thaj+5GmxfPMuwpdE60+qKkhjP9QgTTU4YmnGAvX2Rrq3Z\nrrcZ1Cb4sQ++jVrV45Xj59k7pznx1DaOiWgXPQb9Pk5JMztUJNpapmJZbC+uUMhJBisGd+Igldru\naNeQ6fC8htO5csu6orUldgo3IzL7odRCImkni2g9TtgNKeZ9irUx9o2PMF5K6V46gyf6/Mhd78ak\nASdOn+POkVFytVHOnnuS/cMOqdJs9DPSjLY7+CMHaHcH6G5CsRVz7oVXuXpli/baGVb2TlKbmmPs\nxgl+tnSMs2c/y7NPPsNb735/tjEsm9L+61h/9MtIKRjJK/5h0WWsnI3tg8I083Mt/FKJoN3eVWxM\nqjHavJb0GEeTeh7r9oDa5hZWMEFccCgUWgyJBiKwcZ0CnqN4brnMy7//ADM3nmN0epzSgSPEWuGr\nlK0L6+SLeR74wqdgDfZ94NdZv/0gy2ef4uD6k1xtv0zLdbE9gwmybo3ahuhPfw45fSuqvgE2iETg\nakMnMCgtELHKOj2ZkmR2Xpf3Yq0vEuk/wKtez9zyb1POtzHJWegeATQ4M5B2vm8c3lhp2QEhFG7R\nYW6sQhrH3HfkAJYOKLeWsUb2kotKuEMjRKM1csE2udmriTZXKF9cxBEOswfK9GXAJT1GX/rI7SUq\n3ijbfh8xPEscr+MvFCheSClYw0i3ihm00VGC8Cu4JsCyLMwgopAfxQ7bmBiSOKabq7J87EeQX30w\nO7jMjuYLECqFtUOKltrOuhZyh5klrng37bBkXgMw/+DLMopE22SuwwZPQ2wZaDfRkzn89gBd8dBS\nEBtw0xRjUpzU4vT6Ft7ls4z4Hue3tthY20ZiSGwbEypq73wL0dIKd/7YLHuui3nwsQ79VDFkW4hU\noYTg0SePc9WRa+lG6wxXKlw8tY5lO6Ayt3NhgUxS7J3uiWVLhOMgkhhDZhQqlMETFmG9l10IwpCL\nNb5RBMLgIkmNQprd93gcaZBkuiZIwxWTPmlEVuIid4TpNJY2SGmBBS++sskLX/sLqkMVkJIpH47W\nfB671GFDGw6OOSy2E1r9iPnhEsXpKomBrhcwOyOxS5rO2Zi3z7nkJMxd5eEZH3oWjz28wdF3NWkz\nzut0jR3cDrym/5PumOzpnWRH6Qyrk6SGNBG0OyGX7vlVijmJGEhcmewqNr3I4rFHnuLEiw9RciLO\nnXqAmg9x3OCaG34abbk8+/DLCLPKUM4l59usXlxndmaK0dEJBBarUYwUkjAMcRyHWq2GZfsopcjl\ncvieT87P3LEf/84j3H33Xfiux7e//SC33X4LfSV54vEnqAtwSQmDPq24RW54mgP793H69BK1oWE8\nz6XT7mAYMDxWJG63md0zzdrWBpWhIvXmNpZvoZHs2X+Effvv5KY3WTzy7c8Rx10WF3eHzK1vSVyr\nT7vdAwOelERdAzHM33Q1Dz34Em9959WoRpsbj95OtQj3fOZ+IhKuPjTFuRPbTE4M0RwkuJYkTm1s\nFaBdHys/RjF3kclheO6pk5x/6iRf/buvc/N1wzg2pEmCTkIGvS7NhmZj9UlefeVVdBdsJ7NJkdow\nCDQiBVdCrLJ6ySjQQqDILCa0AtKMpSU0mNjsUBV3nrPDKNrF0pbhm+cVhQM+kdKsf+IXmdx/kNs+\n9K/ZXGux8liftdVXKM4u0H7uHipzB2hu1+lFCTW3xeh4nuZWDjc/SxJ02F5ZZ25vhbf+wi8TD7o8\ncc/n+NQvfRWKkB+bIxo0Od3oEwDW2CzNtcu0o4RRoJAzGFtipKLqODS6LUxlil/99Me59bYFUh2R\nYvGT8k62N7f4q9/9Qxb25ekGAbLZpBr26MY9hotgYk2cKApjMwzadVaPP09pdoH9uwtPxqZU
8P8u\nZjLm5+uvETsjLoXRhlLVwtZ5mknI/MwsRVdw/TULiM5l1peXqBVTLp8/w6G5G3jkxUvcdNMRTBpy\nYSPGLxRZW2qRc7Ou+/ioh7Q1tSGf3GCJUN3I2ovPcY2VMPn2d2I7CXHOpbUW4qUBdx09hi9ACYGl\nFCoCx/cpz+9n6enHCaTNk+uSD14Vorwxygjk2DiW5yCt3bHY0iSTWdip/TMAou9guRCbNjqoMz7q\n4vs17PwoxnIQtoUkZXyqwiPWLcyEW3zn3/01N70NXtkssWaN8sCT5/knb5vn2C89xqbV4cT5Vygt\nH4flJdq1Gfqdl7FkAo5N0pfoBDRDmCCi//dfxxqxEKPjiCTEarQZy8lMpTvMzmmVZONHIWBgKuS2\nR5g/pDE1jakdxXjXQbQNugtip4KQ359V/MY6PJ5EypBCr0neLxDYHiWtaW1uEvbyDI6/QP3wIYod\niSwXCMfnGSQh/soSQ8025VtvousGdLe7FOwmaaKhMsFaf4WJqUmceIArSgz0gPFrR9GDPgWriwn6\naLdCRIyWFkpprLzPoBswoVs03DxV38Hqr6COvQv1d19GF8qZd5UWGAlJojBGZzo4OzNduZPtqx3t\nE/FdipuvlwQ/2EqUwhIJWjqkYYSSEhkrhFb4sSJJYxKlM3dcKUiNySRYtUHX6zz5wiJV12K73Uco\nhcaQpgphwfgt13PxTz7L0h6Xb3854tyZPnk789PyBBgpKdiGU6fP8Bu/+DNcWG9y4tlnieIU25HE\nJktmIp1iSRuBRKXQNhGlK6OlOCDFAQm2ZSGUIqcNoRTYqSJvQWg0Agh/CJaW1BpjK5ROSVMBxs2A\ny7bBUmJn1r5jHioNFhpbuHzxj/4DJT/BGMmUSHniQp/lXsSt83mMK+mGkmKYUHBT3nl0jm4qiSyX\nJ04ts7GpuHXCYnPGMOgZKrWUbzy8TqIMKRZ33DLCsw/dx5G7f544ir7rtxfoHUNQbeROpWgy+nlq\nUArSGOJEECvB8a/+R4qDS3ScHP1BSKW8O6fD+77+JU6++Ay9+jKVUpH69grLl18h53tgErx8lfm9\ne7C9yxQKDlZouPa6Bc6cvkR9u4m0BGmaYjTkcnkGg4D19XXyBZc4jjhw4AArqyvs27tAu9Nh38IC\ng/4AFSoWFub52r1fZeHq67BdD+EWWVldISiXkL4DUZennz/F/sNX4foWcRzR7/cJowA/77O93CMW\nHvVmh1hZ+IUKG811pCd3TCPLTE7P8da3/zibm5d5/BvP7io23ZahNCzwpMDPGyZm4dKZ7Nrcu2eS\nh3iJEb9Af3SErVdfAs+hMAy64NGrBzRaA7ZX2lSHSoSdPlPVIu3uNk88cJJ8oUav2aEK1B2BW4Rm\n05A0e5QdQEiGh6toS2LZLjMzk2yubBKSdYw9X2BFGpMopBL4eU2QCogg9iDvGxwpSHV2sGB2En0J\n2ckGOs7wg+I1RtEullG8EMD4iqTqGqYqOSKWOP/JX+HQm9/L0Q99lGucAjrRBEpg0gFry03iQUTQ\n7ZPmHA4dGcIvj4BjkfMkg36fP/nFu9jchMQFXXYQ0metUce2UmoH38z/9pE3USpVKDqwdvkcj7/w\nAt984gkOFsDyLZbW+8zfdBWf+T/fSz4HrfYFLLuEsFxkw8fzHGI7IAl6rK2uMOY4RFFKqxEw6Cps\nVxJJQetCwHY7otPpUxgY7thleJIdeQhjQFzxMiS7tq8ovCMMlnAxWEgnh2Kd5pbN/L47UNLh2IEx\nrp0wnFuMEVaeyliFxsoXUfpGBh1AaTZWYgbkKBWLbNpNYmHQOsXPFXELZaYmRvnAT76f8/Ekz5Nj\nUCmznUpuODxPq9MnMZo4ihmaGmfp8iXaK5coDo8hLA9r/CDuyil6i0/wZ6ctSjlFow83VgpUR4bp\ndQNazTbVWmVXsVGxes1BXmkgjDGOg+1b4Eh8dYqkb2H8
CkRWtne1gjRkf3KCODnA1+1bcH/l12jl\nJCvbHTZDzc989DCvnjlBr/MME+uvMHHxFdqRy8XAMKlt7jk/z4dmL6K9GJ1ahFsQp00cz4JZD+EI\njO5ijVaJkxQ/ijBSQQa7BbWDWwbOhzNMX4w59bkv8Nb/dQqTtkDWEe5Qdk0WxzHGR7dPwfeBOL0x\nLT0Y4EyPM/7xu/nqdTPUkx7VS6epLi7TPHIY75ojjKqUeNIl9nxYXuPw9AjxIOTiwkFGNtaIgzWC\n3ARyzxiTcUpr8RxqZpaw1yRE4EoLL+dQFAm6WER4VQaWQBXGUEmMUAk5x6ERh6TbG2wPTzCaNohD\nQb8wyUQBGrkStgEbDSYDpMZK48gsqbHQGJPhNPQVqwMyaqLZOczNLrsYghQ7gQDwjUCpjCmSxgpp\n0mzuH3bxpItBYts2dqdH3/YQlkO7N6AtMnPEXhpRzXkQaYq33czq5/8K3e2w+ILmxDPr5DwH7Uo8\nMoq01Bpp2+R1jz/4oz/H9T1KvkOQJOg4gVRhuQ42AldAalSmmaQtPMejF4VIY+HJBNtY9I0CZYjS\nBF8KEmNwtCFvLLpGY8ndd3gyHy8rS2h01n1TWFgIUpl1eqzv6rJYlsvjDz2LiRIGRvO+6RxFy6bn\nSK4/UGYQpRQFnFxrU8oXuHF+gpma5LnzXS61QgbtgEBKHn60T75gmKo59LXHzHyeVy4L2vUO7bUB\nf/uF/8ye696OO74HbfSO+WeW6BhzpbMDWklSZUhUlqdGCSSxYO3Ut/E7D7Ad5KDfplZ0mRjanQvk\n/V//JI52UIHAjMxy1bW3sXrhNJ3OBl/9yqcYGjtEp7XI8HiOfhqhjIW0NH7OZ2urTrHkY9kWnuPS\nbrep1WoopSgU8gz6fba2tuj1Ai5fXkIKw4GFfZw4cZyj1xwln88zPz9Ho1GnOjJG6uSJBnDqwioX\nL2/RbYVce/Q6dKpApDSbTXr9AUlsaNYDwlDQ6NUpV0YJE8no5DStlfPESZs4bLK1eZ7puWlGq0fQ\nUWHX+waZYaVSS5CEAi8PlbIgNoZLzz0BwNZTD5E6mhcevZe9scWBW45xfP0c4VqT9naL7XXJ8MQw\nrXYHPxpAYqiNgEbjFXLkixIZaRLLgpyivQNgV2lK0O/ilcaQQuH5LpVKibkZj8YgQvrZfjAWlF2H\nybxNKxhgjMC3s0TGtyWRMqiQHSormYz5TmfhNXy8uNKN+MHXu6YdXu3DQ62YH63ZrHUS2pHFodEc\nrz7+dc489nWEBNeDcmWW3PAk1fII0s1j8hAFAevPrNPavszG+iadDkQKQmDg2nQCTSWv2epGLDVj\nGJrhPVcdob18nsLUBMN75xg6dJBjhw/xwXe8nV/5j79N3FfccP0+PvHvb2LQXMSoIZTIEegujlNA\nJi7nnv4WVm6Ii5stwlJCiKCWd0iiHmkaYPsuQ9Ui9XqDnsjR7iR4jeaut06q5A4aj+zARCDlDoLH\nAlvapGi0UmAUwmiCKOXowQ+wtzbObdfuZfXlp/nyfU325LvsnR2iMFSlr6GYs5mdG0aV9/LSqSaJ\nMViOxnGhPUgpjk/S63UYqgy4tB3x7GU4vX2Zfzo/RdhvU6SHnSZ4hTxumhIRcGDPHrqDS3SXziKF\noTQ2i8lVcWeuZejw2zn5d/fzIxNZUrm9eYGVxVGmDk5Tv7ROdai0q9iE0sGL4telNqIUsCAnkWmG\n85U5BbqBzMKTdSMtMA70kwnmWmcplmP2rC2RHy5h/ALyQpM7zRbmgXuwChojIyraQczcglcr8PJX\nXuSfzoJQKXp0HEv3KRZKmUrz+CzStdBKYlSKW5M44QoyBzoSmZnjztIC3MYZLr3pZzly68f4+id+\nk/d87FbWnn+C0tw+ctUiW/f/Lcrk6SvJgaO/+T3j8IYJz/s+82sklSrDtsfamVfphYrc/Dj9/dfQ\nExbFl16CQhHHd8kN
BuRrVfawzbempnAjnzGnyVZcxveL0OzS22jhTA4zgqYTppQrBSKqtANDznWw\n+ttM6TZ5kSc0gkE3pOC3kFYJR+SxzYBktMjKhR5zRw6hG+cJeinJNdfjnzyBcVyQKVrZxJHCszPF\nXKGzUYURIIWVTTH0DiX9CvVtlxieBfok0mJ4eJy17SY9DTktMs2aNEbGMcqRFPMuKonpJClxnDKS\nE1zsbOP5OcKgi+t7THqlTPbeVVTm5+m/coIoSTKLCUcAKcEALCnRJmuX636CtpyM552GhErjSJuB\nihgqFegHIY7M1IAVkJOCVGl6KsByHGwpSIxERTGWk8NWaQbmU4pEa7SReCT4GHq77b0DCRqpUrS0\ngRSVAHhZYikFwsrAy7YtkQ6cO34KN1ilVC5TtQUF1+ZcL+LY3homjpiuFdnY6jNcLdLXmm+eXOcr\nxzV5y2IsZ6FUAm6ByCtRq5bZCCOeeGyL+WmX6T15Lq9axNqhlnN48O//G3f/m/85a/MidthYWayy\n7k42vtIakhTiGNJUEieaS/f+H4Ra4fgpU+PjBEGHc53ddcAO7b+epC8ZLk+ytnWRXrvP+MQM5049\njwjWcFyXKA4pCE3R8+kMIuIw4sD+A1y8sIptQ606TH2zjjGaKIoYHq6RJCnrG+v/D2dvFqRJdp7n\nPeec3P+t/tqrunrv6enpaQyBGSwEQAIkQXAxSdGmFbbCXIIiL2jLti4syrZsy5YdDkY4QpLtkH1h\nK0xTCtMUSVMKmRAEEBSxkAAJYgYzmLWn9+ru2v99yfUsvsi/GwiLgFTMi96qIjr/rJOZ3/m+931e\nlrpdnBM8fPiIOPK5dOEc62vrvPH662yf2aLT6XB79212GkuM85yqsrxx6xjtBNs7V3CE9I7HTDOL\nrgx5kRP6be7dfcxau0luCyZzTZJ0OTocsLN5hePeA6w2zCbH7O++i8kMWXb6l5aUAps7isyhRe2M\nc0oSJ45RWusWRNqjvRGy2Q2ZDhp83yc+gv16xIM/fZNKQdBtEASG2XCA6iTkZQkFpMeHDCcj8rkl\n19BwlmeXFDuBZKqgKA2d7jKNxK9dnNZQFXM+fr6F89q89mDO+nuv4CUBoc558PgRDojCGk9hraMy\nFlNp0PUad/4iNNdbOIWeOEUX1vTTHGcCn/XA8raAf3Fi+N4lSVFpXi0UN84kNfhUSQLPZ5qOybVm\n1DugzFOcK8hLj8q1iFZvsHY2o0GDtjvgS3diHp28w3A6J582uNmb1+u0WfDml/+A1eVl4iCAPOdk\nNGVne5PNts//9T/893zmnS/zV39xiyw9oNlsMUlzGvESaE1lUjylWbt2jaVOj93emKIYcZLElP0x\nCSntOMDkBfnRCD+K6Gd9TGn/XJTuJ5lZjgWEVoIQsuZ8OYvRumZsSYEnHM5GfGTlB3lPZ4WDvZv8\nH6+8wr1BxY9fjVnrRjA/Il66wLnL72c2r3jphecJojYP9l8nTwsCArYu3eA9Vz/M2Wvfxfj//Y9I\nvJCd932QMGnyV1q7NN97lgeHDXqPHvAHX/sav/7WW/xvP/Oz5ALkoxGXr7QZDI/x1zZoYhEyJNo4\nx8rmMj+8Du1I1HqZZpvpdEyRdgmVZD4dnera/Fqu+WljCRfuUCtAY/CiACHnyBa4Bsh5CF6ICBuQ\n15tnmQ9ZomLl2mVevnlIr5ScmU6I8wOuPPsce/uG5mRM2+TIGEQWEwQz/P6QZDYGBQcf+V9Zf+ED\n+Pt/gK0s8u7nkde/D3np+xCHf4AbvEH4uMfR248JonoCIJ5oTcwizzEzeO//KcrIZ19oPvPf/c9c\n/Tf+Al2Rc/Snu8z3Z3QiCB89/rbX4TsWPPsVeG+8Qb68SWt1BX3lWWaNBul4ztl//k/g6hW8jS1O\nVjd4Lh3xYHWLV5qbyMM+w5N7HC+fZzM6ZOjHRNUj+rKinDXZCgacWe8wMMtMizmx7a
OnY5SsGKgu\n0mviFX3Cbi3wy1SDwGrCF95Hb5aTK59xNaJhK2ZhjDmzg3n965jKB+FjTD1K8WSN3RR4aM89tYdq\nt7BFfwuG/JSudPbnAg/HaH6A9UNEnjPDILSlMJbQGDwN49yhgoBsXiEDj6NZwfLmWfb1azSEQhrL\n2DiErYGOa+YR13/oHIO9nOk0ZXQyYZpWuOEMK0OskmTW4oSrO1q2ILOSBIvTFaGTFHmBtY5iYVWQ\nEgqnFrEIGuV5WFsHz8ggwDhDlaYIz6MQ9ecSwjJ1C0ieO70tHeuohMNSU5ydkGhd4RGAhkBInAf9\n4xO+/unf5OUv/FMuPfscRw93WV5vc2zhynKDobVE+Ny+N6T0PC60HUkSMEgkU2MZTSyTqh6foS2j\nec7J0Zjnr7dpr/usbKzjNQ0vfcAw7ZQoCb/9O7/BT/wHv4x13lOBcl3wLCIjjMNoQWVcXfBUgrKy\nHN55haRd4MoIL2rUsRtBQpKmp1s7ewW/9Jf/K1rNFl97+Xf43O/9P5iwQ6u1hcln2Pmcfm+O8n1U\nUpEVc6T1ODo6YHm5Q7PRYNAfopRHnqf4vk+v12dre40PfOCDDAaDeu0Mh5w9u02/P8AXkkuXrvD2\nzTe4dOlizQIymqjZodIDwrjJ3t0HrKys0V5tYkzO9uYGDx8/QinJmY0dBkfvIDzwQ5/NzTPM84qd\n9TXu7B5CJUlnY4R8xNFBk7XuKmFw+nUjRf2wi6RgYx2GGXixQ1l451jQXBLcyhzrhyXbTcXNez3K\nz3+ao6M5y+sdzl9a4sp7rjPce4iuDH5gkFi8SHCy/5hJWSGQNBKHiurCwy5ehFprbr9xkzPnzzKb\nTvC8e0x7PRJtUcqnOE5pf7jJyuYqni3RUcD9/tdRs1qWJoRYOIEWhYxweEogZL22nowUhANcXRCd\nat2UgiUpuZQ4MmP4+sTxfBP8yvL6Y8P5rsWXOfHGOmGyzHg+IWgodCnR/jaNTspS9zqzvbdobt7A\n7n6JaulFPnS9y40Lq7x6+yHTSnN+LePVdx7hEKTasj8a03/rHbQuOLO2xf3de3jOsLO1yt/4T76H\n3uxNwjimsBCFCYUxrHWWScsS5BLt5jY3PrLHG1/6LHa9Q55nGCS2EZGVgmw2BaUIpWZ3DNmkYKNT\nnnrt2IU5QgmBkqC8J4pXuYBg1rl1gQJPJtzgJ1j2XuBTn/0Kr+7PaDVjrp9bQYgYG2/QTPrY2S5R\ns4FzluvPbvDqH3+N5tIWgbmN82OuXHuR+OJ3QdDmduuneWHF430f/16Wqznru2MsJRfWIuyky+ZS\nwic/8hHybI40CqtLlpINMjunSudIKUEY8GNKF7C5DJO5oxFBURS4uMlwUCGDmGp2Oovf/tAxiSQr\n0izYc/VatK7e5GEccu1n4BM/gLv5u4j8APnDv4I4/Ge4f/x36N15zNlWBzEZ0MxGXP6BDzMbpqSH\nJ7hRn9R16IocpRY5lrMhVTtmllJr2Ub7+J//b7GjCcwGCP8IcfgK+nP/ZV1cBQHCRbgMMHVB5i3O\nEbeQHgQgZwNKEuTaJXY6Gftv3kUay60Hh7y4uYKnJeLKd33b6/CdNTydmEQuo6cT7uRAJnD9I1YH\nx5jnLqNXtsg666TDKZPxhItbjoe9CdutNrPOOnL+kEedDfz1Lco7Q1o7DY4LjwlNRmWHxGasVcco\nNyeLVvHmPcKkiXUW63m4zDIL14hFSNTymGUZXm6IPMVh7rO2fpWV4UOOmh5VkRMmbXCCwhgqZ/Gl\nBKewCHzrsCiss/WH/pbOjnPiXyL9/qsOzxiEdfjOkJcVzhiMUnjZHG0tnrF4tiI1BQ0T0K5mzAgJ\nncXrdMimOYHvIT2J0jUksRv5vPuZN/j8wx5R7KMWwZYSCAO/tnlXBmT9WZAWg8ETiqEzSLdAbxtQ\nC35B7hzVQlXprEFahTQGp1TNyjGWwBPo4QCW1/
CFQQqJEBBTW7JrmMXpDuscOIN19TkaUTsEKgTC\njzg6fMyXf/tXefj652kEGe0oZD68z3g+oeVZpPF5NMrwE0UpJbPM0SsyLjwbE0uQgWLaL8ALqYoa\nWNfoWrJcMhlEPHiQ8QPXEr6xO2VDJ7QvxezfMVy8JHnhY4qvfeGrvPDxj1GUFmPEQqBcU32NrgGD\nZQWVFlQajDXoW78DvsDmglarw2o75N1H+8RJdLq1I33CpElRQFmVaDukKhyr62fYv/MGnZbDl5I0\nFTRiRbvVYjbMqMqSVlNSFOXTJ0Gr3WY2m6GUYjKeEEURUkqm0xkgkErS7S5x++YtdrZ3uHL5InlZ\nMpunlGVBlmbkpuTs2VU+9rEPkuuM+XyEFB6mslw9d5l9ecju8QHtRszqepfjwYS8mON5HpPpMWGc\n2wNtaQAAIABJREFUoF1GEFccnrzDcNznuR/7eZw5nc4A6rl95EuWOzXhuRGCHkM5FWRjB5HjT0cA\ngpffTTl/Dub9B7y0s8GthyNWt5+l1Yg4LC1m0dVNpWJ1pcF+b8wgg/MNyb42tBVM+jAxmnIGum2o\ntCNsdlha6XDtmef54uc+x/69d7h4oYuMAx7ce8A33n2T5aTFfFRQjiFwiw2T1PhKYRfwF7kodKhJ\nGbUrsT71+tHzxB/9r3l8PRdcUZpNH1Y9Sek7vjiA97cdvtLcGUo2GgHZ/oCl4BjlxyRIrJa0Gpbh\nZJ04u49SSwwev03Sukrv8U3y8DzO83nuufPcf+cmH3/pRX7s4x+l1Wxx+9EhvcGUDz2zydryBqtL\nXdLZjDuPdvGbTbQ3WGTmxVib0WldRfgeo8k+kR/Taa2Q5iU/8hMf5+//ymfxijkqbiN9n+E0xVmD\n8D1sUdGf55SVIystWXn6YrkqF7oPubChW4OQIKVEIOtMNQl55bAm5a35nIff+Aqb2+fxendZbTdo\nhT5Oz7n2nudYWpPc+syv877LKxwf9VhuNwiSBof9Y6Tn8QM/+8vcniX0Dk/40Aef4cMff4nVQJP2\n+1zKezjh0NmEMPS5cX0JRMz+pIfwHJ6QCCuY3J3RfW+XodaIQEFlwZOEW88wL2BzSbC97NPqdEhH\nh7Q7bVxRotTpuu7eXKCD+p4wTxMJbL0ZNhJmBuGViHaEjdbgT/8JvvuB+plYRowmlrW05KNnNe2V\nHZZ+9Hu49Zk/wLx1QKMdMNZNrDnCcxDIimZkKfMp4xn0qw06b/1D8B4hczBVLebPx6B8UI05BBY8\nUzcgaorMN6cuVtQ/U1NihKTEJ/YtTse89icvs1u2+MR3X0EPxxRzTabst5PwfOeCJxARVbCKuCDp\nWg+ze5v1JERfPIsnIvZdRW8+5eLRfTZXt5kYgQqa9N99m521M5SRIOgsYR7tk5kIU7XxyxF61oT0\nZdSFC9AvyKUiGd9DehXZrKIZSpy/gVtSpAWsegaTZ/SyOgQtSnwoMuJxzrHXYe3sJY4rn2QBtisr\nuxCuyQVG/MnN474FPw6Ghcj59K5rgny2SNy1+M4tLKkSVRQI7Yi1xo7HdJZ8jM7QvR6l59NYWYUi\nx3qKKPAR1mKsRjpBnuU0JIShw5MGZSXaA0/UL+DcWhJPkhlBpCxWKaQWVEYTSIl0Am0MgSfrF7YQ\nNb1ywR4yyJpLZA3SWOIoItcVeB7u+IRweRmlHZUySFOvt8r9+aIlpKijPpSDcuGKsKJm3bz5xmv8\n/b/+SZ69tMPF7XU+9/oBge9xZm2Z7W7CjowpTcmzl5b54s0+mdEM+injsuTmI0GnoZgZD6ccUVJn\nfy2v1nbL1SsJjTZUM8tb90p+/LuWeXNPc/i1issfDgiqCMo+9x8MeP5jTxLP6xeTNvUYq9B1dISu\nav2OqaCq5rjibTyvzWbHY2WlwaPHe3TaTYaj0+22Aio+++nfoBVf4mTUI4wElXRUWIajkm63wA8M\nw8GESV6ysd
KlLCqWOl1wsLa2xvHhCUZbxqMRzVaLKIrQ2rC/f8BsNsVZ+/Tf2q0W29tbZFmG5/ns\n3rqNs1DkJXEYMZkbLl+6QBD6UBniaJU4SlhNWoTCJ+106JU5oaeIfJ84CvGbDYwA7TRBIyC3DiNG\nbJ2F2fSERw9f5ey5q6deN1R1C2ReCY6OHV4gSMdwfGLBh61NOHgA4PCAj3z/+9nYXsOr5qT6MWVV\n1t2DJ1l5ri5kfaEJFVQeeNayYmDLF7SuWV78wPeyujvkK6/eZXnrAs7zWF1ZIkkSGnHEcFjx3HUP\nT5d8//svMwtiVle7vP76bd7Y6yGkgwKiuN44xX49xrWGJ3kqiMUoWiiBMzUKQ6jT7bLGRclNodCm\nookjEY4zIXxj4tgpBJfa9mmn+txSQD4r6BY9POVR6TmRf4Jo3kA1E7xpRDp+SKOzxPTxu6zunKc4\neJcXz3ZI7/9TNi9/N6pc5ruv7tCInmU27JHfv4fc3iGWPle7XaJWwjuv7HL+qkYLn1aygiXDlorl\nzllUAAfH94iiNp2tNf6d//iH+Ud/77OshiWzaYFwFuWBnuVo4ZCmHuk0goJRcUoMNdTmlIUzzuga\n+KikRQiLUvU6MAa0EejCMde/i+r8JfrTQ3Z2dijnJxjRRcUNpCnIR1AanyiomFaadrhMI9HY4yk/\n+Av/DWrjMtsnPW5cu0BlIM9KZOy4qnv4wx6V1UhlEdLhRS0G4yG2KrElyCjCOYuaG3Ij8XyFQOE8\nCSYhOn8VEDhtuduDj26FCFuRpYbtM8vodHqqaxMUJdYozCJEWsGiMK+1iW4O+mu/hffwy0i9h1wG\nM6QGRwmPzMacDErOoZFOk3/hM5w7+gIPHzmSHQ8aK+RFjZXwo5JmOWGej5h4Pp+7W/AzzSNMUVf9\nngdOO2SjtqM7DVJkCOeoqoVr9mkCrHvqrMOUCCFQRrPJYzY3mzzz/EXWr64j0hGlSEinj9CtlW97\nHb7jqjoUHhfPbTHqjWlORrC1Q6kEw15KGuVkec6ShdUXv5tjLZk3Y4o7d7iyEjMMNEG0TXNtk/7+\np5kHbdbSKWU+RHhHxEvLeLduksjHWFPidZ7HjyKM18LFMccoGpR07AkiXqeqDFb4lEsJvp0hpcKa\njCCPSJM2utmqk4SthyWtAzdFLVZ+EmKpXf38qYTFcwKnZM0n+JdBu//K4/BgH8/ziaKYbhBQWEMg\nIwpticsSqws0FqoMjEVUBXGWcyJDKDJUVWIiH8/VdlasozCWRuWQxuBJycyUBE6hgUoJfARGW0Lh\nsM4DUyKVR2YsylicMXgSppUhUAvx7YIyXVpDgCRQitxZKhymrFBCUlkN0xGiNFTCElqBVWC1JXQQ\nn1bgxDd3VzX7wWCcW0SQhvyjv/kJnn3mGU6mc17/6k2CMKAbe2xsr3P/5ruceIr3b8b0xiXzecly\n5NGzjmYUMq5S0qnHjcttErGEVT5e0/Fo7Gg0UmZHGdtbXUylccuCz93d49xmm4uBouynHPYUm9vL\n3Hz5t/nEX/q3sFoshMr1C6nQtW5H67pTpjUYXZGPB0hXcdzPuX5ujYe7d4jjLW49eERnqXu6tfPw\nXeKgRZZEjPOMNK3YuLBDMddMij7HYzjp93H4RK5NL0s5u7HG/cf7IH3O7pxHCMk8nbK83KUsS+I4\n4vj4mNXVFc6fu0iSNLDWsLf/iG+88Tqm1Oxsb2GtRxiGdLwE4epOUhwklGXF4fEBm9tn6PV6NMKC\nZEmwPxhhw4j+ZMrz7/8uBo/exQjoHx0jwwgvDsnFkEl+wrnzDfJ8RKcref2tf879+w9OvW4q48gr\n8AtB5NfPOa+OtQYFQVLvTpZjwU98F2S3X+PlyTNcee49iNacwAh8qTD4deafFEznFctJ3WEpDMwF\nLIegJWQ5JJ1lZnqM50mmkxOc0Jh0QJ6W7O8dUfjwcOyYGrj57us01jdIIkV/
OK2dV2aB+nKi1lR9\nq/3qyR+FQKg67Vw8saef0qb1UgxfnRt2K8UZUREIiIVgxYOJhjeGjmdb0AgExxODQ1JWjtDXjOcV\nCE1n+FVWllvgdyiISXsHNJebVKPbOBkwmU/pbH+A1oUXMVZTFYbp8AhrBWevXcd3grYCLQOcs7z7\nG7cY/2SLa+81qNYKlS5Jkk3Kck46GxElKwzGAwaZ49/96e/hH/y9zzJP83qconxUYTCutuLIIKDK\nKlCq7mKe8rD6mxRl5KLLbOsi9AmFubaki1rPqTzSPOXKzmWGgz6BH9Pr9/nJT7zILC159w//kGsX\ntjBlRvFoTtxo02xpHuSbtHYuYSzcuHqBUaqx2rC10cLPpsz++CvkFy8SJE3mh0NC66jchPl8irU+\netFhaXg+Qht8E7J87VIddrpIVm6fu8xaUCF8wfFY8u79R3z0g+9DSV2H6J6StBwu1qlmkT5g66aA\nW0wEnKtRJnZvHxGA9T2UNbjKQaFwSUDUiUnv5/hJRcv2mcUvkrz3JmI4xC9bjESLtpoiLfgyI6pm\nHGaSVkLttjIO0XDIi89hvHOE3/+LOD3D/M4vwUmFkzk6X5zPE2/WQm+LBOUKIhxGCHbcIXa2wqXL\nm0TVHsW8QYrBOkdybvPbXofv2BdrqYL8wR7t/JikPMILHD0VkQcOpimyqrCdJscHI/Y1jKcz1HhC\na2mZ7tISst1hfnzAWqdNZ+MMru2TrV4k2tghthGN89c57Fxj2H2eTEBfdZl5PmmVI/IpaVrSaLZI\n05JZskk7FDSyCr1+HuE1MF6brBUQtjbwd87VYkFlyMtag4JQi/gCsYiXEBjnUIt0KGHrXvST2fpp\njlbYZC1u4VtXi+G0Ic0zcpOBrpjlBdYYZKER8xQlBF3tKKcpqQM/ifEFWGMJbG3jVkrSXGvwIz/1\nIpmTfPQjN9g6t87amRU86/BdjQUvNXjOQuEwxhAK8BfjuWrRmXGmditIKTBOI4wlc4aJNVRaY7TB\neZBhoSixx8cYa1Cm3gapsh53lQb0n2OkZUyd4m5MjQcw1lFqg5USmaxx72TMPC3oNBI2Y7VwPXiU\nuiJNKyQwn1c0A0F/lrGaKFQU8PyFhPde6bAUBhwOM76+XzEaV1w7vwbK4i87Th7O8doB62cd3c02\n0ya0z3vsval59eUDRAWu/8fMhkVd3BhHZet8LKPrgNCydJSlRVcVVTWD+TG3do9IGiFZlVPokHGa\ngXLs7X97kdyfuXaWGuRphXBNkkYbJ0LSvCSvKp67sUOUhMzzEYIcqwvGoxSlYsrCIFFIxCJTq4kQ\ngjRNybKM8+fPsba2xtJSh8PDx3z9la9ytH/IuXPnSaKY3vEJCDg+PqYRR6TTCdPRAGNgb+8RW5s7\nNXXaOdpJwsuvv85JPuPh0R6tpSVyLRkPp8wmOdNphjESa32Eytnc3CAJu8xGAfNxjKsMJ/u3T71u\nhACkwEmHwTGfOY4Hi5uzAqfhwtkmP/GxTcK4ycHLGnXYo9Xq0t0+g/Q9lHAYaxAWBB5KOpqNCBH5\n+NZR4NhqwVoDUgtf+dObFOmYLCu5dPkq6+tdrj5zjatXr3H12mW2Y9iitp47K3HK5+XXbvKHf3Ib\ncoGtwPcFQSAIA/9pl8UtEBD1W5gFuXDxQWsE1amOrUDwXt/yMDfsVpK5gUDWTs9pJZhreG0IB3PH\n/tRxMLXszxyDuaWyjt5csTt0vL075d79Xcb9XeZlRitOkOsfZWn7Bpde+ARmcI/+W1/k8PZDzGTI\naO8eZCmD/gGT0UN2D+8zvv0Npnvv8GDvdf6Lv/Zl3v76BBDEQQNnSoQMMEZSVh5GO8oyY640v/DL\nnySbaCprKYqauK+UQjvFPC0AR+Yc2ekfOfX4X9UasG/N0bKLl31lBHqBIvGkZHySoPyAVhLS3byA\n9AOSlU3eeuM+rsyIO22E9NAuYKkZk2lF
ZT1+7Od+keW1VYIkQXgBSeizs73MUiMiaMbsrm1z/JWv\nI+c5XhSgtaUnNLl12DAibLUQsib/B8rHG84J16/WMzkErpxTTOa0PbjVlyS+IQgC8jSl3Wqiq3yh\nwfzXP8J6T01lFxRq4AmNzmqBM+CrJqq5ibXrVJOE8r7E7AGVAF/hCQtrK7hGjOmsI4YHaJcgFQQi\np1IJrgKsIPQ0pvAoKtiJDdKv30VirQ0f/knkc9cxNsCVE6So6vNZaHesfrJ5qE/UubrzrmXEyEB/\nnnF0+eeRnuXyT32Uc3/rNwle+j6KTKObXbbXv73g/TsWPEn/mMZgF3cywKxcZZobqn6fMi8Y/ovf\nZyf0aMwqSjTFaEB8dMTGuQvcJ+aRC9lSORbHOz0fHhxAPqF0DpFVODUlzU9qnggBg6DD1ExoYPGl\nz1IjohtbgrxAeILZNMcIReU59FTTzHtorUm0gqKARgtZ4wMoSoMv6tiIOpyytiZKa2tRJNRqfVfb\no58yek5xGF0xKXKUs0yyCltqKq0piwpdFtgsZfqNNxk9uk8xHlMFMQ+mM3SaIuZTPOEoKo02ltJW\ntfXbOm7e2mNzOUFPCnRVIQScObNKXmlyY7FYfOFQyic3FWVlMYUmtxpZabAaz1oqK+qgTuswlUYh\naEhZvwyASDh0lhEuvn/6+ABlDFJbtNZ1MOxinqOq04H1AIyrIzdKa6h07Vwx1vHw9uusLzUwuqrz\nmpSlEfhk85wHjx+ineTFcy22OiHWEyjf59KZDqVSLAcen3015/WbM752c4j2SkQ65ExLQzbhwxcu\n8OK5c1zdWqWblCQ7Pttb0BkXPLg/J8srnCtpRooL2wmz/mHt0DKLsZWBohKUJWCgrAxlOaUsZ5hs\nHy8AV+bsH0/o+oKj42MCJei2T2cRNaqisbxM1FK0l1aABpW2lEXKcPKQnbNnuHjhBgifRhKRpTkP\n9vfQFq5efZatrS08z0MpRbvdpixLjo+PEaJ+SIzHE7RerJ2dHW7fuss8zbh9f5evvvwK3ZU1gtBH\nSIPEEYdNsnnOwd4ezWaLCxcvcHB0SInheDbB+ArtSg7279Hd6NLqdCnLiuGwx727t1iJG3jG42B3\nROKv0I7WaAWrxPaUKGGgtAIvAFvCaCqYFzULCRyyITAaLt7YoVx7hq8dOV4r4cFRzv27ezw6nrN1\n5TI4w3A8qot/CVjLZJrjOUclYL5w6pxpCFpKcvveQyJP4Hker3zpC7z5tVfZvb+H1hVGG7IKDqSj\nsaK4cP06z1x+hlKbmknhnuC1LMZAaS1aL15G36QuPIW5y7gmeAtHjdo/xTHWsBxInokcIyN4bCUT\nA4GC9dDVgEwLtybQK2BQQD+Fvanj0UgwK+oOZlpYRrliXHi0mh66yoncMVF7i6OHD0jWrzOdgMgm\nTE+O0cWcw6O7HO+/zeHJI3QluT1O+fKru3yjanMAvP5WBZnHcLpfs7fKgmawRpYXBFGTpUabRtzh\nkz/6fo4ArS3CGowTZFpQ6ZpFhnUUecXa2TOnXjvGCiwOt3jG2wVywi3iYr7ZUKvzFeNV2DjbAgex\nZ9laXmK5GXPx8ga3jzKuXH+O0XjM8eEx7cghygEDndCzTe4+7JM0Y8DR7TaJQ4/AV3hK8t5P/CBf\nHPfJJkOkKRk1KlJVgPTqcaYfEIYRVio8T6KyhekhjEE5nDaoIGBlLcRzBucEz1++QKBLiqqi1WpT\nFN8ervdnHcLWEo4nXS5nXX0uzqFNPf6VDYE69wz+S+8n/OEfxfuLPwcf+XFcMePWyQjlCezGFuPt\nj6CNxGUT/O/7RTxXEqsS6ydUFbVRph7QEQlYTTRGLqYNmcN+4Xcwv/+ruH/4byJ+9z8HVXer3BNo\n5CL/7OndsbhfhJG1VMI5ws33MBj1EbPH/LP/+m/yP37qTYrlVVpNj+Hut4+0+Y4jLT9qo8+cJT8Y\nUuaO
2BXIdoiaKlr//s8ztyGT0YAyShC9E5S1HFYVnie5fG6bk94JxaBge/Mcwh/hBkPOj9+iWr5I\nOcmIuptsVhOKIEECTSeZxx2UqchcwaoL2DcGezig1e5Q5jNWVzZJ5Iyx36EqFHaqifxj9Pkd3J03\nKbVjnpc0AwVS4pz9Zvq1dXWulq3TxS3fMss6JYfHx+FwSG1IpCREIK2jcoJR7xibzkmaCeXRMVYq\n8uUl7HiCGw+g0aHViCnmGVIKUCHpPEc6+Ms/8yGuf+g6f/vXruIcDAYjfusffKVu8BlXC409SVrk\nhFKisGglQFekglrobJ/kYVnKBUlYOkteGTxPoaQgc6KOd7CgbY477BGUOX5VoYXCSYWHY+Yg/XNo\neOZZicFSOou2Ck1NnH78xhd52Juy1gwJw7p1PfLg4rk2fidmfDjiUlvQnxUUheHqasS8cmw2Qs6u\nxwRRAQJUYJmmiq3NhMNRRn4yIz+OkFFEd1NQ5JokiGAYsJx4NKm4+qGQH2mvcvekR2t9mSLNyIva\nhaUXsRGmclRVSVmN0eWMqswxpmB+dIft7gqDcc4kN1RKcG57mft7fZYbp2u/L28mvHHzUyTBl7HG\nJ8sLnOixvd2hFe6gVULSXWM7iBE643s++gLH4wGXr57n/r2b3L//kPW1FaqqIo5jkiRhZWWFLMu5\ne/cuQRCglGFre5PDo2NWVlbZWO3iB5I7d27TaXfY33+EtXD+7Dke7B2wub5NZSsG/T6tRszZnTNI\nP0AEIWk2Z6URoJxAq5jDoxOWV3bQYs6VZ3fo9QbsPR7jq4TtrS1eee0brKictZXTtd0BqtShpaC0\njmEP1Lc8oSyO+RySZkJvMGHkRezGc4rhlMt5weVnLtPtLmP693jn3X2SCgojOEklWVGxlkh+7EqE\nkiEdGXF46wi/GfKRF5+pO6HVA1546QbJ0go3rl2td5ZVQRzCB194njJ4jFGCwXjCg1sn9XxA1Xoc\nawWVsZAVzEsHQf0SqbutPG3LL5in9Qv4lLSH16aGUFoSC2eEIXNwRD1OOythLai3dpmBRylUVtAN\nHKGEYVEH+QZS4EkIPUegHPeGllgd04h6RLfeReBQGDrtkN4h2Lzg3Jk1Yp0yMh5ZOefxTPGVxwWv\n9eE//Q8/yP/yb28SBi2slxF4Cc7NieMO2gY0HLTbz2At5PmQ5kqX3//SX+O7P/Z3uBACUlJqS7BA\nhqSZ5sM/9MMsB6djWwE8AT0K4Z7m4Qm+JTlo8bt1YmFdn3Fo/k96Iubyyv/E+hmfuH+Lr7++y0df\nukLgCebDMS6fcxhdQl//OVzY4ObbhwwPT/jaH7+CLSvOXbnI1tYmH3jPJq1G3eH6yb/9t/n0b/1j\nPrBmsNOKvAzxvLrz4IoCayyNMEBFAVE7xmIgn2OzGTJKkKHh2o0PMJj+ETePHMd7Dzl78SJR4FGZ\nimKen+rSJFiUkxjk02Kn5o/VRg3nFJgpTL4Ew4WGpgSRxgxPFCdphXWCr37lDq9sWH7lcs5+EbP2\nzu+ioyX8skAKQTYH2XR4ccbSisfGquPcUsb9k1XuLH2Y7+/9LmI0RQYSVy3wVEEJur7XzaKOM4uv\nIUC6BV7G5WB0HVAUxxx88G9w+7d+jfv9Pu9d6vLWqz0++swmf/KpP+Lf++U/+zp8x4JHV44yMzgq\nxGTEcG2NYTZhaaWNrTwGoxOWqvrGa1KwtNxFB4qVhgMzpWhvIhgjhg9IXIUfwACPTTVh0myRzXOa\nUQdPgZERnaRFLxujwy5yvI9Y2WKl02Wu96niFkmrTeW36asNmrPHzLyClY6PttssnRvVLXajsZXB\ni2pLonySnYXBSIUyT9rlsh5lLZp75rSYd1cLk50DT1oKBEoKKiGIwoB0UudsGWPwfY8yTXHW0bAG\nladMpymhkhRVhSzBl/U4p5rPGT++T9yKiaKAs1sJKyseD286tG9RFo
yp6nakq3fEgScXbfQ6sNOX\nCi1cHbqGwgmNFa7e8djaIZD4HiAodIVwDhH5WK0JjMFgqTyNQOBbiX/aJzM1eLBwitLV4wUNGBT5\n+JD3XVphMJzjkFSJRM9KSiTPPrfF4P4hn70z4b2bLTIL660ljqZTntlqsT+Y8qFLS0wGBUOTo6Vm\nlqastBWDyvJ4XhCkFfsjn9VWyCAwRL6HkYb+saN4MGHjUobXColKhyGkWhQ6dTPLUemSshxTlVOq\nKqcq5gsL9xTnC+71Uq5fvYDKJsyyKZ5UDPPT7dR1ukynlbH3+CFL7Q6jwQQ/OEueQpGVHBw+YKnd\nITMZzdgjL8asrsYU5QhtcrKZZufMNmmaMh6PybKM2WyG7yuSpEG73WY2GzGb5awsr9Jpd2i1Wqys\ndFhd6XL37j1WomXm85zdh3uMJzOKSZ+NrfXa/moqZuMhwovw/ZBIaxp+iMRnPJuRVxl7vRFxQ3Lu\n/CrJks+Ot8mwr1k/c5HG7n1MUfF4cHLqdeOpWuuiFGDBlPC0VbLgI3lRRDaZonXNlSgRDHsD3phO\nWF/tcm5F4fkSXVhwlmUPhKdYX41oJzWgbjbL8SKIOzGPTyYEQhD4AaPhDC0SZvOMleWEpNmkNNB/\nvMeDWweYbEIpPPb3a6G6sGJBK4dOICGQzLR+mgH0JNsJn6cjLLEgELrTLRt62tJQksw6POtoUP83\nuXXMEMSLyJymB2LRPRqWjkBCwxMoJKFyhIuXXKnAN1B4jlllSPy6m5YbSTStcELhtGC2e0LkSQ5m\nlpsjeHv4ze7Cj3zyPLPZnGDZYUSGspIsM0zmI5qNMxgTkM+PyHTdVUvTPltbS/xnf+X7+c1f/RJF\nadC2DldtKzCXr7DaadA/Pn3w7JNi5+nfBU8jbepSqHa9OqjNKhaUFAiR8cajX2J9429xrvMePK/P\nOHWcsZY0mzCbw/iDP0vsRUSe5Oz2Cid7PTqdZda2lllZW2LnzBoWSxT45HNH7gyf/It/gdmn/m/m\nqcVJgakMKgwwVUUAqEYIvoVRVodgGI2rNEQCFUUsX32eo0/9EfeGIddSQaPfp9NdBb9FEJ7OGerc\nk9DURaaYdU8dyk/NgtIDY7CRhKDWcHoy4zDvUti6HhikmqTZQrkJaVrSaXSoyhFShChbYeohA9JY\nlOfRCDNkq8NFr0dXf4WeO8u6fYRY3BhOG5yqcySdrZ2NLDStjifhT/XPyi8GKCdoxj6uEXHy+C1K\nmZDuPcTz4T1bHfJCsLz27TWV37HgmacjmhtNXGeFvOPjDg4pfEFj5SxDH4TyEc0GNi9oXL3BaD7j\njJwR+CHZeEYgSnRqKJc3IZsjixQ6OwydohGnmLxP7rq4UMJkRN9oGo2EfHqIajSosgLKgpnwCYyk\nqDIK6WPncyo/xIsbjHVGTEi6fpYojuvxFg5fiQVjp75kbvEEMmJB+bULsXA98l2oek5xGFOHAmLI\nC4gCSaYlgTPM8gxRaSIryaVAlxWRFzIPJf7E0hv3CEpNEgeEUlKKWiuCc/zeZ+7z6U/dwWaOQkFV\nWeJEonyx0G9ZpBP4UCegO0duFjRk40iFwBhbV87GImU9NtPW4kmv1tRIh6lq27o1Ft/3IdcOkMF7\nAAAgAElEQVTk8znC80AovNxiFAgnn3aJTnV5tEU7R2UMlat5PlIJ/DBmhiSTAdsNHzlzNAIfP5B0\nghghAnrzkqAR0MwchYRUV/RHFfPC4fVnrC9F3Gh0eTwqccLDlIrX+xM6SxrhW4q0IPQcvYlld2Yp\nK81SO+HqtYgJ0O9ZVp6t6cp1seOoKktV5uhqRlFMqcoUXeaLDo8mT1Omac5qq0GaTlmKPLou5vKa\n5e7B6YL85r0h8+mIle4y6WxCoxFRpAZftpmVkmw2ottUWDsGYqIwJstnZLkmCkMmgxytK4bDIZub\nm0gpn46yJpMxSimWl1eI4xijod
FssdTp4HmKu3cfc/v2HS5cvMRsNscPfIo8p9FsIISg3+vRbMRs\nb21xdDKkyAqUUuw/eszS0ibD2ZSo5RFVFq0zpLL4IQyHOatrW6R5gef7GA2z9HS7UADlCeKkbpzw\n/78nFyMKoTyss2gjabXAV4I/+swrTIFrV9r8+A9dp51ElAvtHAoakY+gjnqpdMV8IezfPUnpnOsg\nPJDSESYhrWZII4mpypwizyly2N8/IM8nvHtnwvEM0DV358nRiBWVFERSYe2i47cYncFix7yg1z5t\nNZzykeOcYG4cBRC62tCWCGgiGFiohKCzcLa0lcMTglTX3bJR4fAkRFYQSYcvIaydwAQGPCmY5g5F\nHa45Lerw36ISpFawP4XjStJpJLzvyjql1SytJeTFMVEc44UNPAWF1QReQhyuI70OoZ2DMPgyIC+G\ntJpLHAwH/MwvfC+//r9/HutkbRt3joEX8NLV57iyuYRwp4QUUTOcFnplwH3z5+N4Wuw8ZZE8+bqo\nXbqh75hM/y739HXOb/xVls62Gez+IfPMsf/ev05T1NEpUimCUPH8S88RBJLxZCFrmFVsNkLKStNo\nNSjSFO0MfRniO41naxacKyukdhCF2CwD56P7KWJwgO97tZA9TxG6ZMkv2D4bs5tVNCOPvBJYWyFF\nQZmfruvu3ELeIb5FR+8Eknpj7LQArRFrTWQ3hMBHSAWR4Si1BFGCUIrCCTACRcE8A7l9FfPa7yFU\nbS6woi4ma5SJpalACINB0lF9lLMUOiTyCsT2BgQRPNytNwdagBW4RQX2JO0e6rFw5jUpJURSoTyP\n3je+Qb+a0swmrOaSxvYqJpesX9z6ttfhO9vSz5wndxLvYA+3tIpoKc6VE4JyRDGwLMU+vgvZbjYI\n0hndrA+rW4jlDWJ7k6OJJFgNyEYpSTnCd5Jg9gi9dgk7N4SFpqnvkQcvkEUWP1lCZSNEEFB5TYqq\nIog7+GpM2LvFbD5ntH6D5eEjsq1nqBx0JXi+RIWbpO1N3O5NPAGBUk+LmbrwqX9Vi+rRLpwfghoa\ndspxOoExWCWQlcaoAKMt0moKKYnmc5SSGKHZ7nTZG/SReUYSRszzPnoyoRP5T1x3KCHQon7ACF+R\nlgVxKAjqQB6s1VhqG6F2i4eSrM9b4ii0qXUCnodvKoyUT2nSbtHmNThyawg9hbX1TlpbWwcaOhDC\nIno9vM3N2uH05NpZ/efwaEGuSyrnURlDaSQGUFYShBH3DgcoIZj6DbaXIvYyw1g7/vCLX2Wj0+B4\nNGWmLUEYUNmaM9QbDNlYbTKZCYyu0Lai2/K58zhlMLXYomJ/F7Q1BKGC1LK65eFJQzOxrDQzUp0Q\nS0FrTTHpFTghKUuHrhxVlVEWU8pqRpXPKMsMUxToqqyF1JMDpHFo4eFZn9WlVUaTAcczWOssn27t\nKIMrM5xoE4crjCdTLl7YJlABUbhMnt5lOvHZ3tpkMhyR5yWVqwNWtdHEScz+/j6bm5vMZjM8z8P3\nfTxPcf78Bfr9PvO5T5blJEmTqix46803aLcTDg8OWO4ukRcFw9GIorSsra8SqroT2eq0qMqS46MT\nllfW2D8ZsbFzhpVOm+PjY9qdBvuHd+h2t8A1ePzwEO3nbG9eptfPOTzerzPdRIRUp2ep2EWfPa+e\nrLpvqQoslBmYylAWhlJrWp3/j7M3ibUsOe/8fhFxxju/+eXLOWtiVZEUSdGQ1Bagdk/ottvL9t5o\neGd45ZU3bfSivTLgjeGFvTbahgG30LYl290aLFkTKVEii8UimZVz5pvfnc8Qoxdx7suSZBb0eIBE\n5hvz3jhxIr74f/8BVC1J9xXLM0s5HLFYaJZ1JL5ba3FCMS5zbKgJCNJM4YPBevB5yt/++7/KaNDj\n+Zv/ie3DI7I0I1GSPB9Q5DnvfPwh27sJo8sfkxYpi7pilWw4ENHa4Hzu2JvI6COUKJrad7hC1E4E
\n16lgbJellUByU1TZBwQBJwVrIWgJmAAlgTGCysM5sJdGLkuhIsJdO9A+ZgjOdWAtIVeCPAJSFJ4o\n3wakiIRfFwQz7Xm9FjRWkGWQZ1DVK96crVgt4O+/ex8VBEWR0rZzrBtGbhglCSVNOycEi/MpJgi2\nxxPqxrAzuoM1LXMLZWQaUBvH3nvv8Oj2Lmerhr3xzZ2WnQ8EG9EilERe28qCFBvQTXa88RDbJDIW\nPCoVOLugrv6Ix8/+iP/2+d/hV/tDMl+QDMb0ej08xMNjSGmNpyhSRoMCbyyJBNMa8sEAY2znh+ai\nUtYnkYRiDWnWQyQSgsd5h1u1NJ89ZSiS2GbrDUGlhOqKN8+ecOfOHexnP2W2lrjQovIU21SUw5vx\nBulGYnO4VwECHikl1sd8rbBUBF3By9Xbx87B82aHnUlBCGCcRTsN3lNVQLPsjHtjC9d0Jq5ee1SR\nUMoW7OoaCR2VU0yTx+fBSwhl5ML5L77I+HHYwHEicpASETMXrfdUV3OSe+9z/vv/il/eKSiznLaK\nar+y/7PnzpcWPK6qyfUMd3ubRErGYgvnR1TnFxzuHXJxsWKSt6hBSt4ahvuHlEnF5fkpK7fLeKQx\n6wWjdolJxigzRWzdJneBtF8wG75Dr5ox13PyNCF3FauspLWWUqaMls9o2wuy8SHp9hZitINfVNjx\nDltlyeX5S4zI0P0EW6+RHz5i8ckfkxe9OJGJVaxCELBd6neg6wJGFvj1wN5sW3fBkyNwiSIJDikE\nQQqU94jgY2utaTjXmu3tLaaXV6S9AmkDwlqKXhELEe+QAQoSKmFQ3uOsxqkEI6AQksobku6BFYLo\n74AiiFjGpVISgmPldIR1jQaVXLfrQgjIEKMpBIJECqyxxEhNBakieHDVEusPcDIulEmQaBl+LuNB\nZy3GaayXMX08BFyQqKJHIlImheTJ2RR5/x7bvQqxbPG9HrX1vHcw5sdv1pSl5bMzj0wCyzIn1HCw\nE6gJ1E5xfmz50StH1XrqNjAsFQ6B0R6XSNaVYG8gMV4iVMni0iLznFGWkLq4HTnrMKbFmApnW6yu\nsaZFNzXOaKw1WGNxZkkmJ7j2gmw759XJKXcORkznDYsb8r9c0MgAFxdT+oMdQHF19YbVqmAy3mJr\nvMX0coFEMeiVXJ5d4aUkSXPKYkAwmrLsMZ/PkVJSliXz+ZyDg/1rtdZqteKTT37Ae+9+wHK5osxz\nemUPJRWj0YCr+YrlcoFxksF4wmKxQErBaDwmlZLdnR3K/oBy2SKF5HIxBWHZ3hmzrocMexlSJPRH\nQ67W5yyWDd7Do/t3ePnkJ/ggKIqbL8pSxTatd38VBBEkoZPVAj7EwB+lQElIBymcOX7y/IStUU6l\noXAgpeJi6fmF+wovCgSaNFVcWjhIYN1afv3Xfwdn4cnjBX/8O3/IYGtMhmY4HPLTz5/RrGaM9g/J\nJiVff/cBp1efsro0HQcnvsJ+LkhVPFCp5AsIQifHha7o2aAON19yuq1bdNyLgJCwAqyHbRnoBTBB\ncNLAXg6KyNcZIKgFaN8Z8gVY2UAjIJWCvo+HLtGZlQqgsoHjCkzwBAWrmLWLklBpwZqAUNDr75Cl\nGmPBK00hcpK0ZL4+wwpBkYxYmxpjK4IfkKoWawfMl1csiZ0+ay29/QPu3bvHZNhnOncsws0J7869\nVdsKH6LZY5e0LbsBlCJ2kYwAgsDBtc9LEOBsfF9z91v8aw+uOuDfF566bUiTnDRPqRvLcJARg99a\nslSRqrj+t62N5F4Cfj6lNgaVpEgTEEmKx5MpCRbKrEBVV7jdHdTWbfzsJVKlMfQ677F/+4DL6Rzn\n4NXM8c0tj2k0Ki9YVz+HjG2DmvprXnBUtPmAaz2hhTAHnECmXSFYepyQ5InHOY/VJqrdnEJJy+rp\np6gutUDiCSiMcaRpHF+/gdxUbP8GB0q1ccxfHQPHEdXpCNUb
dVbo1KKb+Rg8eCMAyao1ZKZisn/E\nyZNL5EcHqFTRLBuKsmB9+bOd7790pQ5CMMmXpNWcYbNGzM5pkz6hLDBpQd531MOUrL/N7d0MXy/R\nC+grRdqckzVLTL1GqoSdtMFkfVSekaFp5ACQvDGw1e+DFORpD6MDu1LjbUVdL0nbCh9yKrmD7N0i\nXPyYgbaI+YJxf8xIrPFlgSo8/usfcTVdUmQKCDEtV8RK1ncmhLEA2ignNlXEjRtaBBdJz9HNWbL0\nvpP0gtAt3kCZQC9RNHVsFSXe4whkqUQ3GmM9jbFgHNrZCDEmkkGS0BJPkHVo40RwPtqAE1BdPowN\nDo+ncZbgYzyFd/HUQgjIEDNIpBRIAcFZrNHUVhNEoFcOUFna9XAtfrnEeoN3nsQYjK0JpjMZueFV\nt5Zaa+pW07YNWrfotqYcH9Dfkly1joOtEe/uyGgjkGU8eXHGmsC6tczmDaU1TGdrXr5ZoBvL+bLl\nzZnh89ct//J3lvzup2tWTQvC89H7Q/o9hbegUomzjp4PfPJ8zeVC88mTFU2rGQ4Sqosat3IoleG8\nxtoaq9fodonVa9p6jdEtxkbFnTaaWjeMxgWjUcakn6FUAGHI+wV7N+RXvjo7xXjFbLqGEDC6JUk8\n6/UZxii2x3tkaY88GdGsA4KMJ09fsV4b1lVFWZaUZcl0OmVrawvfmQy+efMG7z1PnjylaVq++tWv\nYawlz3PSNMN7zwcfvM/+/h6TyQTnPZdXl3zyyQ8o8pzhcEi1XkUFnbX44Dk4OMAHz7Kag3Cs1lMm\n20NW62k0jTOW1dqitWY0HPLJJ3/K7u6INE3o9W5OPE1FR1wM4a88k93HTTxlhxAL2+U8uoFvZuj2\nuKDsZYyHWXR09VFt18s38zwWJmnofEiEoj8s6Q8LigLuPbzLux+8z/37j3j46Ct88P47NLpiaQV6\nFfjej15xddVxWELcFEQEFLpzQQcpf+EoGaSISKoUbJjMwYUbh4duOEHuWl0Ti5dKQNUhxT0RGCaB\ny0bQdIiQlFAqKKSg3BSIHXG6toGZCUxNYKFhrqO667QRBKkosgxjocj7fPOjjznYGrI12gGGfPBg\nElsZyQjPmiIbMm/d9f3piRTnDLpdIpSjMXO0y1muz9gZTfg7f+seV96jcsE7733A/d0JCEHen7BX\n5jedOvF++phpFkOBRQyL7lrq3gtsN3YbZk/oxtGFqPJyTtE0krURtBUUe6do4/E+0OvlKAV5lqCE\nhOCiuW3wSCURUpGkCqkS0rygefoMIRQiSLw1UWHkA6HV0FETRN4j3esRrEYmaTSztRprPI5+lKL3\nBZNc4Yzl6moWxTjuhiqtLxTXgbe5Y76T8DsT0T05CMj9MsYYNbENOAuKXEmC9zgE5+dXLNcNMu8j\n568RIolzsoNjrOvGvTWx8BFREe09MRTUi45qIiKzqnsufcerensP4+t0XXUWvMV5z06esHN0h2Yx\npxIgpKetLFfnc8gk4UsyW74U4dlez/G9lHS4S69MqeczQl0xzjWztiYbjNnupfTrSxoZT9JlnlGt\njxmYgAspPhuhRcuyrUjTjMxpgiwYm2My18OyxORbTMqMmeyj7CmhWeLzPfztr1KMR9hFhQmOVZAU\n/87fY9nUjNopBQMWOmXetsh1zu6tR8hbD8iE6QwH4/oS8/q6MEIRkN3CZEPHFRCdWdUNrhB8JFHi\nEEDehWXqoJBCkQVDlaRY41HBM0pTrpZrdNNgG41JYoGSCWi9v4bvtDZxIjiPCo5cRZdl72NEhPMB\nKTwuCAoERkDiAjJXhNYjg0M7j1JgO+ZSKlKSVNIYgw+BTCiklDitcSGAVAQkW6bFaoMIAkPkBOTo\na++Km1x7A8F0WTHXgcapaGcuHUk5QdjARw+PqFYV/TJhbjx7GXz7g0NckPQHOa22bPcUk1HB+aKh\nbi2DYUJQsLaeYW9JLmE4
HGJ8IHMJg8zhi5YiSclKydPLmg/vZsxXFqkFIs95dVozGSUcbQn+4k/+\nF+anx3z97/5TjK4xpu6KswarNcZonDFYa3nvnbssXx/z/oPb6NogpWa6MgzlzSePdhlmvebw8IjF\nfM6wnxEai7JRgpqlPbyTOAfzxZK9vW2+ffBtnJW0jaGfDZnN5hweHrJer8nzHGstg8EwOiuPRvGU\nFODo6IjZbMbdwx0Gwz5Ga9arFUWRA4Esy5B1gxCCQX/A0dERSgQWsylnZ2fIfIAUnuFgSFEoarMm\nzxN6gz4uBFKVsL+1xd7uLk3juLU/olo3TEZH6Lq98bxp2mjMJ0N3GLmutUO3qEaiaQgBYxyzCmRq\ncd0cvTpf8Xn2mqbSJAIsHpFA1TRY61CZQCYK23psDk1V8fDhIybjPj/888+59+AeWzsT+v0+SsrY\nZraOwaAHCparCm06CuwXzF5Ct4THhDreUkXi2es6E0j4+Lngv/Dzf8PLd0TTuDnEabe5z5cicnW2\nVUTCRnmgthEtG6hYa5VJlK4nMsrXTYjriY6dANquCDIuKuKEsEwr+PYv/AJFWpDR0mQJ9/KUo6Oc\nkjnD4R5erEhUyWz9kiLfxnlNkfdQIpAkBY0LGJuyNT5gtrxge3wPKQO//Gsf8Vt/8ILR/h1u3b7L\nrd0R01XFsNeDvH/jueM3EjgXCeFJPOFew4R+g3lvhr27JyrSH7uxjL/Dh0hun557/u/f/4/56L1/\nzjc/+hplv0eeKYz1CCFonaaXSXqZpDaOxgUKAioRtK+fMhQZ3ptuw+4QNqNBFVhnyLcmyDIQbA1K\nIRJJqANCJYyPjsik4P1D0I5YcCUZKhWYen2zwQlc85V8gKSDtET3uXoZyJbRpFEkirBsUCrOLSNS\nFC7+2wfWVUtt4uYqVm8IpAglu4I8x7HCWkHio/EvXQGMDoRERP+7jkpx/Qi5L6QefOH+RB5P9JQT\nTiM7lfHaeH778SvePxxc3zdno5ed+ZID+pcWPE37ktHRL3E47HFZaerxIfLyAj86IiwrSh8YjHbw\ngx5+cYLpT1iGQG/nFvXrz7lwJVulZB5KemmJb6Zobbnwkv10m+G4z9WFoKcdKgFvz+mVOXXtKf2S\n1CQo6TgjMEh72Kqiqa7oaUfY3sGGlNXAkbeaIqmpbEbyzofIN3+OQCGCiDJ0BMJLkq6iD12OiRQB\nH2RMZb1hrg3Od7lcEhviCcAFSaGgDYHWBdJgSLuHZy0V/SKlFFFRkeQJKkQ4OnUxODRXEuUDXkiM\n19y9vcvjl6dknfumIlAoQWs8qRKY4NEmGqwlITbuUkAqiQ8dj6lzOq61iRCjENhgwCYIAVkST77B\nOdZVRVY3NFLgQoSbW6lIfw4Sz+PHP6CpFgyGExrTx6pthMxJhGaxtpRHgXv3R3gU+4OSdWtQNuVk\ntqRMBFu5ZLLVZ+4bPn53wr/97jFV2/LxhzkJgcVc8eBoQHlnD1lWnD1f0CiLSaP0crUwkSzdSD56\nOCTrOf7wByvmq4TkUYpt+6xefYfnV9/hvfk/xkuFaWtMW2NNQ6vbWOy0Lc56pqsrFlXDxLRY7bm4\nXHK0M+K0vuDh4d6NxkZmfdaXM4yeooAi6aGrmocP7zC3lrIcIYRktVqxf7jDer1EuYKT4yvKYkDS\nz1itVjx8+BCtNU+ePCFNUz7++GOurq6w1rJardBak6YZ9+/eZbFYgrCs12uapqE3nHQKwtg6aJo2\nemKFgHWG1XJFXTdgYW9vm8XynNWyIe8ljCZ9jJHs7OxzfDylFHB1fI7Wnt5oF+lhuY6/76aXNTFf\nx18r/b/wXCYCbMAj8CFgumykPM85ONzj+PQVIgiGowFea3QwSC/IpaRICyqzxvoQn4/Os8V6OH19\nwqefTGka+J3f/G3Gu3t89eOH7O/t8ezZKy7PVzz5336fZ6ebFxJx9s2pOVGxTYyUCMBYD7
bD4jco\nTqeg2kRNsEGYb3BtNoPQtcu879a2LjtqQXSq3osiG4Zp3KCmJiI7PRU5O5L4d+haYEtiIeQ65CNR\nAhtg0cK3Pv4K7bri8eO/4MFWxn/+d28x6fdR/YLffFkhZI23DUIVOFfjvEA7wzDJOZ5dkaohW1sP\nmC1eI70ly4a09TFCZPwH/+Dr/Iv/6jc5PLpNL1P88KcvONopOZ2dY/TN4logmguKzmFZImJIZhdG\nu+llCKLc2fv4/XQFAL6baRvaiIg+Zs5JjP0R3/30n3D36I/Iq5bd3Ql5mmKsIU1SdIjzqJcrpEww\nPsaKOOsQ5HgT34tykb5ACEghY4EmQdw7QhU9wrqJBYdKUALKnTuoNCO4QCoVvf4A2xqqJlAvbpaW\nvqFwvK0n3s49FyJh2FbEeRqWKAHGCJI2UAcYdnvQqtL03plg9CWmUzp7LwguxpoYKyKq4yBco55d\ni5G4zwYRLVyCENeWDcHHiInrw0JnHSA8eBEP7s5bEg+N86zWNVevn/GN/THOW9rWIYUnhMD66me3\ntL5cpbX1Ac3pmqw/pmVJsXJMlwtEf4+1ddwqBbSOarmg/9VvYVcrxsspzdWaMjXRpO3Ck5UlxtZo\nEkTjueNfIG0fq1tEpun5ltpvM1keM+sNGSaBJQleVOS6T5kl6LWm1y8Ytg0+GxC0wazPmcjI4VBb\nB/TtkuX2CPUq3lIvwltOfmyno7pBt9dN3U7LdUMeBiLgrYnW3AESL7DSIYLCBchSResVddvSSwRF\nXRHyDJKUBE8iFc5aMglaCbSJeVpCRo6OcJ4fPX5NWaQd/OwRIfbdg+hsPTbvSYFzceZ4ZOzdJ2Ct\ni1w5KaOaiYAMjuBFLDBDoLEGhcB7R7OqSI0B34WPyijpvjH8BTz+7v+FknBnZ8Tx1Zpk+1vI/m0u\n3nyH+wfbzFeWvb7kB88vOL1cc29vxCAznGiNF5ILnfJm0ZCl8PmTGXujAqUk01PDy7ljuxCsA2zt\nBhYXGQfv7uDbNbNZi1vDYuYYEViJhKfHAWsaMgqUa/jWe0eYxZSJuMDPA83qDCcKnGvRukE3LUa3\nWN0hPEaD01yt1pSznCLLCFKSFX1qVXBW3cyH597dr/Dw1n3uHOzw+Y8/ZTE7Y2dnzJPHP6bYe0jI\nJYP+EKUsBE+R55ycTYGAd5aLiwveffdd5vM5SZKwt7eHUorpdIrWGu89Ozu7eO84Pj7m7u0j7t27\nx3o9R8qG27fvcDFboJQiCZHHdHh4yNbWhBA8xlomkzGoBEvKaDRmubwgLzLyMsVYz8XFjDTLKfOM\nSTbk2YunWCu5mp6yvT1kMa24vLy5tFhlIubsOPHXuWPdGt00LetVg7EGgaDVLVfzmv4QijLh3Xfu\n8KSteX3aYDrbBu1BG4+Ugbb1ZFlC3caT93hrTFACeMmHX3uPYjjmF77xdba398mV59n3/hgVMqCO\nvJ0ODdjskEkeN9U0ic++th0SIwIkAuGI7tEhdIereGq9aRudt2BFDMLseDfXVvwymipaL9gVgYGI\nDrt5Gr15LnUseooO8QmE6MCbRFKz6VAy7aFpAl979xEg+cHjzwD4T37tHvujEpkYxlspD140nJ6+\n5mB3GyMU3hcMyntIkVJry/7kEWlR0po1UuQ0bc2iOWNQjAguYXdnzD/6h+9j7ZhRkVFLR1aMqBVM\n5I1HhxA6J33RaXND1yLpDDk7K6QObeG66MV1hcC1X2S3a3TgrbWQp/Bvfv+/4G/94j9jZ2ccOVRC\nIhPAe3yQWOvwwSF9pB+0rScVjqAEwrgY6pwkCGsiRzKViLJEPLgHRKUaIhBkgggRj9rZu8Vo8FOG\nvR5vXpzxi7/0S9hmRW9002De2FLaFNuRu9q9zw7VsybKwqOuJzb9mgq0jBylpja0NuBd7AFu0jrY\n+B0pRWugl7wtIpWICM/GVEdYT1AiZt0l8V6w8anq2o
/huurp7lX3tHhvEVKgvaSZXXL1/T8j+5V3\nWFYLyqFgvL+N1Q7vfzay/KUFTygnDPcmuLykFCVGn5DmBfPzlxzcuU89e84URbFzi4vpgnR1gSp7\npO0JzehDdtsFz+2MQq4gSfDLU3pOk7/zIedGUFy8ZGQluhhRBcda5mz39lG5ZTukLOoFM7FLpk+Q\nKsqHQ76Pl4b5smGQZzg1Rolz1lrTLhqKZoWXCik6QVsIICQybKpEiRQBgkQQeTeWgBI3lfnRKcAk\niXQ450lRGCEI1uClIsfjZaC1niLPaK0lSRISBFpbtLf0sjT2dgkYGzBpwDsf7fBTGeXyPp5UAoKc\neHg0IXKIBiiWTl87t2oRXaa1CSQhIKXsHEwdIkSfHqTqAgxlZOsriQoCqQ3S1kgrSGUSM8eEpJU3\nh3iCD8xWLaWSXLx4w+TymJPpCikDr1aOr96fcHZa83C7z/Yk5SAvuLhak/cVF+cVH78/wuUpVile\n12u++eEunz++4r3DEfcf9Ckaw2tn2R71UVXDYEtSrRKypOFCXPHRnSFUBqkdATg5L1BZzcFkwKzy\nNDPBSv+E/iLl8b/+52x/6z/CNpfYdA8hc5wTtG2DNRprNIv1goNhj/1BDyMk3/7KfT5/eQKJIBvd\nDH4PIfAnf/KHfJIptoZ9JsMBi/kVg0HBZ4+f8tUPv4aQkjzvdQoYjTVT8iwnzwomO9s455jP5zRN\nwze/+U0+//xznHN8+OGHPH36FCHoCqGE4+Nj2vWCW7f2uXXrFlprZtMpeV6A9Ny9e4erqytOT485\nODigqdesFlNWlWG8e8h8PqPIc6yLO4LVhqIsmM0W7O3c5eTkgvm8oq4NO/t3QKWoJF3xNygAACAA\nSURBVEX/HInXSsKoELQbsuL1V8R1hysmX8dtyfvoyF8kCYcHe3z88V2cbnh9usQEsDIgvSdPFVIq\nnGvQTrFqHP0EmhYGgx5ZEflGAZhsDTt1m8bbgJKSg4lAJyUvH0d7ZSEESkXfmtBxiKzRJCLF+qjI\nElIQ9Nvultj02Ltd56b5fZ5NobX5wdBxiMS1KlMg0AROhaB1gT0pUBL6BErVSdS1IFfRn0cJ6ClB\nJgPax697IagI5HnOd3/4KSA5KBV7I4XIJUk2jHzCYEgTy7xesLV3m9olrOtT0mRImW+hzQqZSBKZ\nsbM1wGpPIiApSuZVRe2W/KO/9wG/8W8Syjzw7Y8f8fmzE0iGZKObq7R8iCZ1LgAumsAmaXgbyeTj\nPIpf7tB+v+H5bJx64lobV1vRtcAkrQ405rf4gz/NkfKfsb01ZDCI7S0pRTx0+jivRrnAmwbfGFyS\n4DvhTOwlmlioEknnvHuIPNiL8kOVEIKLaIhzyGxEWhTsDPu8uVyTIlisFtw63EWEm6GnHb4Y646w\nIQnHr/jQEb67IGVxjV52LuchIEVUA/sgqFZLevsKKaIS0hCLRwnUTjIJb/dHpfx16zB0fLJgia1e\nYpF5jVyGtw998NEvL3SQXPySwwRPQkA7y/wcjHbkSOq1ZvswQYhwLWv//7u+tODxeQ89KFn0S8zs\nJRiJ0w3DrISrBWYwRuV9Uj3lVr3k6uIKtu4itm4hUsGVzTjYlpiVoloaeknBcDCmqiWD+XN6+0eY\n9RmSSIYa+xm9q3PM3V9hev6SSdFnPX9FyDOS4ElkwLuK2VojbcVKpCRuTtLfJVke02SKbLkCv5ms\nLt45768nshJghUBterXCk/hwYy9hHzwqKFpnoo9DiB4/ysX4Cm0tO0Ihk5SZsbRtyzDLWaaCRddD\nVUqREwvcRAiCiLEOXoRoxCQFIQik8NjgkcSHWQGEiFLZEI0GQ1dSp4BFkopA5R2JdSAjgdsSK/g4\nzz1KRG6NDw7lDJW1lGsDUtEIhxISqxRK3dwTw3tBP1MsVwv2D8Z8/b19BpNtnrw+5zf+7Z/y7Kzm\n6+/sMq9qZouGYk
dhU8hIUHlG3RheX64ospRUwk+eTbm732dtDNMXU/r9HLmT8vSnL/Eqo2jHVOsa\nW7W8czih8g2Vtlit6PUC41s5+6MxQ23AaB6f1Az7isGBxLucN3/26wSnUfmQdd2y8/CbrEw/qtds\ny9awj3MeLQKmdTydn6GFYjjok6Q3MwHztiZTCgnR6K9fkqU9VvNzTBuoqpp79x5QVStmV5dICe+/\n9xXevDkmy3Lquub169dsbcW21E9+8hPu3r1L27a8fPmS8XiMEPD06VO8D9w+OuL24S5Nu+bP//wx\nk60JQkq8d2htGI+3yFVgd2cLBBRFQSK3QazwztHUNb1+ga0qXr16xc7BiPF4wsX5nDwvWayfUw76\nyMxhvGVn74Bnz4+ZL27mTwRQN4HFSuDtRtT9xYZ+h0CYBt12cScWyODgcJvTizmfffaCYRGfG4Cm\n8aytRxDIsxTnWqSEVCla51ESvvedH5HkcSmslzVPf/KC/ckO/V7Bk8dPYiDhIOe9wxHL1RXSaq4u\nWny35xQpZCnR5FTGtkroCiHRobYhxCKIDTfp57iCf8t5iC2ttx2biDyFa8ND6wJTIWgd7IZAqbo1\nR0AuA60X1DauiQMFaRduShA0HbImrzcOz3/47UcYl3QHqxZBn6//wi5FXpIPD2malkm5TZKWkPQw\npkUqSZKOqNsGY9YUaZ9GLjGmJZEZdav4xjfu8D//qx+j5X2evnyDFj2Gg5KkN7jx+HjfKeN8uLbt\n8E5ctwBDJ/rYkM29iwiQJ8TIoeup1u0NdLJ1ETdgGQLzxW/wv//OT/l3v/3f8LUPv4q3gTSTDHLZ\nHVRFXFerJUpKjI0RPVIEgrPI1iCzFBk6a5S9bUSeE+oqzpUg6OR8ZP0hQWYMx9vsdvw1bzSPP3vO\nVz5+8HNOIt5allzXzQJruOaUfWEaoa7lXOC1BxFYNeY6fV50v1QKiQsC41WHnkHwXXfA+oiedQWm\nII4nG/Wi3xRc8d/XCKqP3xJFOJvXIUlCQIiEK2C21NwqIUkkKMnj7/yAr/ztX/qZb//LVVoV+FfP\naJ6+wV5pwuqEXpbhRiPawZAwr8lXK0okq/4uw6GinQyotKB+fsyh1cjefdLFM8YDx9bWLfLhgNZo\nzsUOYVWh1Q6tzOg5QzrYYdW/y+r4MULmLGVKb7IDzRqy2HKprKPUM25v77K7vU/o7aHTAUaN2N2d\n0LoODRG+uxmx9+iDQHowIiqXPIGYTBUgJGwMjv7G8yYEjDMkXpALgUoUqiNOSRFIlWSJixJSFxUD\nrbdM+kX8WWIlWtvYSFbeowgo58hDPIcIHxB4TAcbJiHyfQwReVEduVAGEGoj/QbpHY3tkjCVJBBz\nahIRXZNVd5rRzkQZLRLnPa7WCKcR3c96a1HWEvTNs7QWiylBOGbLFVenM/7s+8/59JOfkPolX310\ni6+9c4s8F7y4MHgneLA9ZNIrubNTcjTOSLPA+dLSV4AUSCTffz7Fac9uKRgl0JN9krzH0U6BcTVB\nVZTjwPPzmlfPLUJllCPIc0WRa0KouZQ1Z0Fx8GCIJaWqwMscUoX2AussgcCLT36P5evv0CyeYetL\nMpWi0ox61bJerkmTEo9Aecvl+mZ8gxePP2XUK9nf26c/GKKSFERKoKQsBpyfX5DnJS9evMT7CCW/\neXPKnTv3scaRZilSSYoOlZBS8vr1a8qyZHd3F601vV6PJEk4OrpFmiSoJCGEwO3bRyyXK54+fUpZ\n9rhz5y7Pnr8ghMDl5RW3b99hd3cXH2B3bw/vA1fTS9q2pWla3n/vfapVjXfQNJFgb3zFsp5B4qma\nih9++kOss7zzzqMbzxtXx7ynxv1lfAfe0uy8CbjgrhcvKaFaLTk7fYNuZ5wvr2gbi/EAkbTiENGT\np1PuGBsokxj8+f5X7vKVj94FYDiZcPvebe7cvcPh0T3eff8d6pUlJAkyEczmS0aT
KLffyNKTJL7S\nREnyJMGazmtnU41AXLl16KCg7s9NeYOBbkN8uzlt/r1Ja79ObvfRl2YZAq8DXFiB8bGVpSSUKtBP\noJCBuQ2sXSQyCwmpkkjgk8ePr//rd3czgrcE7yO6IAJF7wCpMpSUVO0JNizxfolwC4p8gPceZ9Y4\nM2PQ3yMtS5JkC0JAmylFMeHe7RF4zXq9Is3H9MsU5T2Xs5sXy3TcLttxQpyLa29wcSx8iAexKEXv\n2lodb2mDQMAX2iqCa2VsEAIXosZauMf83h/8Y168eEleZDTaUdeW4TAnVQrrHe3lGYlSsSC1DtYN\naIt3FpmkSClIDybREsS4qPAVHfKSFnghEUKwfXiL7VFJmSsaLRlvb/PovYesVqsbj83bLW4Trvq2\nrWrtRqUWuU2uGx9jY0vQeo/dJI86C0i0cTiSLm09FpTWq4imdb9DdATiaG5I7CxEqPJ63m7m7KY1\nG0IsfvzmHnRorrMWQUAHT9NaRkAWTLw3LrCeax596+usZj+b0P2lBU+mT9CJwK1e0SYtzWSMEYp+\nsUUQgUZr0tRSqZxMJKyHj1C6JbQ1olfgJ7vMjl/S232fVUiRMjCfr8nNKXfyJcnuGFsUuPE2WXMC\nNqHvVwRVUCQpvjVUraM36DMzfbJSkghD0t9lKvpor0nNFcX8KTJLubCKvFcQrv02Iwgu/KZnLZEh\n9i6vKTxBgvA3L3i8IxUChKP1jkZrtLGI4EmCJA0BHSICk0gLISCcp7GWtXekgMJTSIH1DhMCQcQT\nlJOCRAiMBBPi4mmJwYexHy0wLhJMCzohi3ekIkrRPXETSGUkk4kQT58qtohRUnbcSUHwrnvNHuE8\n7XIJ3qJc5Cgpa5H+5gXPhVmRJpLLuqI/KsjyhOPlitVcMMigaVd8//mCw52CWwcT1sbw6YsZKfCL\nj/ZwjecbjwZUMjBdteyMevyDb92l30u5c2vMk7Oa9WzF7mTItIqy6DLv0zYBqQQP7o/JSodMElqv\nUP2ccea5OvMIl5KUAw4PthgME9bVFQd3PsYKwWodDQcHhUDqOfX5M85ffMKsWVGojMZKjJL0Byll\n0Yda0yu+FCj9a5c1kQtyeXlJvzegbjTGOLzPGPaHCCH4/ve/z+HhLdbriqa1EBQXF1eMJzs457l1\neIvZbMbR0RE7O9t471ksFkynUxaLBev1mt3dXabTKZswoVtHh5HUt16RZTlXV1e8fPmSfr/PeGuL\nBw8fRh6QMRhnubqasq7WHBwexHZanuGcp2kMWVpwsH/IelWR97bpDUd88NFHqFRw6+iQLBc0zQ2V\nJAChW2wD8Fefyc0eLwJ0i/NmHV8sVvSKER+8d8C/98vvxdZMAO08tY5EfiVlTFEXAuMiwZcEyv6Q\nre1oHln2cw4OdlEqqkFAoE1cjPtFBgJOT66AqHACaGw8QKAkxjsy9XZz2WwoYmOk9cW3dUO303hy\n/oLd/mZj3nBVYlLmW+VLxzPSDi49nHjB2sXvF0QOeCaJbrhA42FpA62BPBFUJkJYh/2SQgkkDXiH\nC0lsc0hPOTiMFhf5mDTpddwWBUFz6/DboHpsTx5h9CmrxQWtrsjSPnl+QC/vo8UhHzw0mEbTHw+x\nIQUl6RU3l6VbG66JyNYHnO/I2H6D5kSOj9iM2YbPEjatq40IbsMZ+eLgB1yIEnbrJFLC//Hb/xmn\np1O0NjQetLZ470iSjPVffBcnO8Klizt36Ew1JSD7GbLfQ+zvQrsgWndLkGkMbM4yQgjotsKKHrPK\ncTKdMZ0uEViEuplPkejUfBspfvw7HiqCDzgLQQe8DddFoncReREiHhZqYwmIznYg2kPgNsVlQCAx\n3Rz0Ib4lKXkrXfSC4MR1sXk9b/8SsrNBMsO1RB3i6yA4RAjkSlKvl/SIOFzTOlSWohKFUIKTx29+\n5jh8acFj+wMG1Qw52UGWGXt+jUtymss5pqlo
vaLyI5LqHOkXDNwZuj9GKkVh/pyLNy84LFvS2Sm3\nM81g3EemGYd7d0l7O8hsAPM1aVVDMUGWGfNkhAue4AxlH4KecmEKdgeBRozwMkP0+qxDinYZohxB\nPmakPKPLS8zpBUJFFtQXbHZwSuDUprDpeugbaBPRicv/5pcKARM8MvhrQ8BUCqwIGO9ovSfDk0pP\nLhIK4SOhLpZdOBfQPkSiWwdNO6dxBMoQqIMj8zKqvJyLi5fzGO9Irw3EHEZEMpfz8UG2ISYBpyEu\ngCJ4XADjPV4EXIiEuiAkSedNpIgVvGwNermkcA60QXpPayzC3Fxts7XVZ7A1ZjgaUY4L9nbGfPvD\n+zw9v8S4QFt7+mkS7dpF4Kdna4wL/PGzKT8+X7JWGZ89mXN61vCNB/scL1fIJPDBnV3SPGNvmLEt\nHCeXC1QaIdXz04rhsGR7O8WGFuFTlBCUI4UNiqXrcXSvx6SsSCuDQFKqMRmS73/n/0GvGqzTeKNp\ndFwwq2qJ8DVJVvDs9TEZhlJJLi8XjNLAcJTFY9CNxiYWKMNBTDqXQiBEgpLRHfnO7dv0+z3SNKM/\nHLJ/eIixgRcv3uCcZ71eY4whyzLSNOX4+ITxeEyWZbx584a9vT0uLy85OTlBqYSnT59yeXnFbDrn\n6mrK9vYOQgqSJEVrTQiB5WLByckJ8/mcN69eIxCUZUmWZ1RVhbHR1PLs7Iz33/+I+XzFcr5mtWo4\nu6zY3b+N9QbrK8ZbfRbLC/Ls5q1QEBQ9KPK//jxuYpKcjVw0JYgSb2JBI6VkXVnOzi5wLp7kpYBa\nQ7VuIwIS4iaYphILEOD87Jyf/jiiGd//sx/yR//vH/PpD3/IT3/8I773Zz+gthCsZ1nH4Np6dY2v\nA1E0EH9V/JwTXFtXxWiJjq/RcW3CZme96fBsYtfD2w3CbzaJDe/hix8jusTwENVaNnASBHMflW5C\nxKInEdEzbKgEo0SQqbeGbwB3Ryk5hlSkBBHRCe1TdN12h0uBcxWJGkBQBJHTGM1q9RQlJMfnTwBB\nXm4zGuzhRcrO+IAs7bFYV/zqr73Di9efc3n8glFqGUoL5uaxJKJD8v31OMSPnYuFondcF8ohxLw2\nubknHWcn/nzofsfb8fSbIom4rjovceZ7/A//8p/w/OUxlY4q3fGgwM8v0MfHIFJca/DGRMWW93gX\nN22ZKsRWPx44jQUZ+Se+qXDWgTWEkDCcTBgNe9zaGTLs9RgPCqwLmPZm4xOuqwkiTzmI6+cpoifg\nrMBbgTPxY+9ill0gLnGNtkgRrVOs9R0aHqVVG++jxsdCMbA5uIS/NF9ho45726YKXyg4N2PuXWxN\netcVZA68tdgQuGgMy3nFLh0vKwSWqxZjNNY4plc/G/360qNp4j3G54hWspVqZnZCfwCrxYoi79G7\ntYtRBXl/SF6WmLIhLUao+Uuc+iYH5ozt0S7TcoQJGct2yXBvi6nrE5IlxeySNZJhaAhphhcJhXGE\nTCGq5zTmAVkpEPacut1HDQvsi2e0vW1yfUJSBozaQaQFa99g6jlCRGhQdMSsWMUHEicwb5tc8UZv\nHCLDzTk81gdSF718wJPIeOJTIpK4RIDaeYR38YDdOWAJ71He4kIaq1gV/S+scxRSIoInAoaSBI8J\ngSTEVlwiVXdi8VgRBWaOKJ/0xNgJRED6mLcj5WZdjYoO66OqRaKwIuZveSQtHukMzlnEfE07rPAq\nI3GR7xTtMm927e/s8+TJC5JccDFdMJ2uGfdyylLx/HzFN48m9K3h1lbO6bzmeFoxq1uGRUJbO5JE\ncLDVh+D45PkF+9s9/s/vHfNPf+U+bWu5XGuG/YzqfMlwNGDUB+Eyag2DsmSyX3J1taKxjr4IoBpW\nVYq3CaHf0tuXPP3hOf1kQKpSytygQ6CQQK7Is8DLsxW5SvESTO1RSYrwkajXGof3gdm8RY4mNxuc\nIBkM+ogg
MVbHqJDgKcs+ua84OTlhe3vCyckJbdsyny/wLnoOta0mSdNrSHvjwyOlpGka9vf3uby8\n5OHDh3z22WcMhyNGoxF1VTEe9bh79y6fP3lCVVVMZzNUWqDbloODPay1XF5e0Ov3KYsc52DVaLx3\n5HmfXtkn75V8/vlTCAn3HrzLyZszytEeqhiCqFFJxYuXf8H+3oj18mbqtc21nMWT2+a6JpN201DK\nzsAtAA6cgiRLUE1DkkrKXkJRKPQqtrAAlFII50hUJJF66/AyFkR7+/vkRQZ8l1/+lV8kG/R4/4MP\n6PWHCGf509/6PYILnE8Xncz8L7OLZLfWOBdi/qLvgIMNEfYLGwt+w8/bEEP/5lcgbtD+C1yL69/7\nxW8KXW0E3X+yKdAkKxNoJIxljOXpXzsRvw3WLKRgnMYD1MrCvd0BqYh8J0kNPie4isWq5dC2ONGQ\npznWLXFCEvSCNC2iylFMUUmDSHexdg4emnYBXpOkGUnS5xtfPeA//d3/kd/u3kKaZJT9Pv/1f/ff\n32yA/krLJvgQ+ZGyQ1e6FPVrAnm3+Ue+ShwwsSkkO8RMCGLKfddfiXWPwOtAokCEH/K7f/IvIP0v\n2R7dZ2w9zdkbkAptLNY0SO9j8WkM0ndjnSbw7n3CbAp5EV+60TF2wgucaUhGJeV4myIvsLZlMuqj\nreFoMubVy+XNhuZ6XohOGh4/GZAdh6bbv3j7PhHRkNDR+V75GGHkRDQ+1MZ27UKP97F7ou1b1MZL\n2GQUBU+Ml+h8kSK/jU6WHjqvgK643BCorxWR8X6IjhJ0t1fyFI8i2i6Y1iFHCbo29CaS9EsQ9y8t\neHrrM+ryiObNM7JEk+08ojKwfXCLZvaG0d4OS63Ig8PWM8p0yOvLU7baJV5r2H/I2gb85afMkwOG\nInCSGvb1U8JgCyclwzLnqrmghyRJI59C+RaR3yXd6mEXS3ppinULqmfPGe3u0+/t0M6fUScZ2eoV\nKg20q4TLVU1+8gK5ncfbJgPCC5yQCOkRQiJctDHyMi5UETETXe7K3/zy3uOk7Sy2PdoKcpVQW4d3\nsYVVINFCRDIbAi8FWV6wOx7RVhVeRrg8qMitiY6SnsgoimWyB7yU4Cw6xJwsIaKplutaWBYf1VYd\nqzH4SF7ezCUniCosEadyEJ3Qzwu8sCihwDqcrqnXK7i6oDcYYaO2IIbI3fAKszWLynC3GPHgg0N+\n/PQNL95MyYJgq18w6mWIVpBLQYlgMiwZ5orZuuGq1hRpyjBVHC9qdkY9fvndAy7nC/7X7x2zNUkY\n7uacLCu+/vF79KThZHFB0sspvGC6WjFfToEE7wK1ykiTFCE0MhF4l3JyUjM4SNHLNcsrBcHSTzLO\npxXWwdaoz96gpCgT6towq9bsDcdkicGTooIky1NWVvLu1s0IlnuDgs+fPuXug/dZLOakec7i/2vv\nzHojO9L0/EScNfPkSiaXYrFKqlYvknqZtsdoYMbw+M5XvjbgCwM27N/g/+P/YMAYGLAxV/bMYLrb\no+nukdRyV1FVxSKZZO55lth8EXGSJbWkGerOjfyABPdkZpw4EW983/u+33JJYx3j8ZA3V5cURU6a\npsxmMx49OsVowXw+Z7lccXQw8o7ZZUmSJAwGAzabDVprxuMx/X6Ply9fBmKroalr+v1jnHOs1yt+\n8pMfUynN9c0t8/mCo6Nj4jjhaDLhYDzi8vKSoltwPZ35TVtKNtsNuIg465JlQ7KsoCrhYHLOp3/9\nVxz0Epp0w3CQkWeGcrPmW4j7EFjynsQqQR3s83dMg/B8Sns331qxWzw38yUbBS9fw/XMZ3eMYWdW\n2M0SFputN4azvj1CL5OkqeTNq9cQ+6XwxYtXZN2Cw/6QOEm4eP4iNOJ1fPCDx0xfOy4/m38BZNh2\nQRYeQMXByK7tWN42EBUSL083b+3NDwhr7Y6M3cbbr8ODrPsvWpgT/Gw9Xx
G/t9zguBGQxvBhFggA\nAqSDLHIUsaCTCFhY7raGUkfYVGM0yKiBRvHq5ZLz7R0K7x1lqWn0lpPJjyhribYxvWLAajVlPb9E\npgVFUpDHBbWtkOKENLnl4Pwp7//xIX/zN7cgBNooVg/0mfFvLHi74Nc625ZJguuyCAo5KX0jzSjy\n6irbkmydw4T2B7uMjoNY+t+1kbdLUI3/P7USvsR18+f81//25/yPv/gZ//pP/zN/8pv/iYpPvels\nZdCNIk0S2GyJen3ivEPxr34Gg8ITmvMCR4yLHFJt/YtOO+CgriuGjx6RfJpjXI2QMVVZsV4/DPC0\nKcddB/Lw/ggHB2UsupK42P8wssLzjZ3fR5QxbJwiEpLbdYWuHUaJXSbIGondCGpjMbXvV2ZiSeQJ\ns35M9b0STuDvKas878zqe2BjQ2myTQpZws9TD0hrY6jr2gt+tEELsApEEqEqxejw61vafOPRPe8f\noutbRqMutijYWE0nzYm1oN8Zs2lq6uUlZV2ySUa8vLkhEnDST4jyDDZL4n7O9uAH9MYTkqRisHmN\njSTJZuHN95opeV5g4hG2XOHUkp5Z0tlcUV7WZElGJx1hbUw2PqKsGxarGerpH9E9fIf08F2Kx/+M\nlXWkzy+IdNjQBQgb0UoMnfAmSUL4yQr3bSdwNvgO/OPDOEOEJMLirO9RVRuDcZbIWBoEsfBGgB0B\nwjis0gg0jarZBGCirfUEYX8G8CjZGVI8f0cKhzY+lZjgdpb4MpxWnLOe74PbOTL7MlV7wmsvst0t\netpoL8v02J0meJBKpXB1g6gUdV2hdINSNU4//KR+tdxycDxgui35u19+Rg9LJ0t4MipIpORXr+bU\n1lJqxa01dPKUx5M+7530aZShwbGqGn5wdkCaR/zm+Q11bTkbJnTSlMrWqGFCnjneLFbcTRXrVYk2\nJZiGTpbhjGXQS3AoqtoiXII1lrrUJEKSIUnilO4kI8kkd8uK4bjLeNQnjqBsFM4aGqXIkoz5quS2\nitnohq60rDYVkXCU64cRLK+ubzk4PmV6e4OMIuaLBduy5HZ+x3o54/H5Y548fcJms+EH3/8enW6X\n1XpJt5uz3W4otxtWqxUnJyesVivu7u4YDoecnZ3R6XSoqhrn8CWx4J7s+bPeY+fnP/8FCMl3v/se\naZpSbksvc18uUUpxdnZGWVUe7AgP0qM4ptEKbRyLZc2Liyvmiy0ff/wptlrSywy9TszB6ACrJUnU\nwX6LzCCxoCmhaeDLsGBHKg0tVN7OZmijGfRT3nl6yPeeTUhjEe4P/zdVrXx5A7+4SzxvWDnH6PCA\n8/MzAB4/OeHZd59w/vSc73//+3zwow9AgFKGujFEX+EPI6Q3uNilbNqzU3sjuvDPQmpf7G7Oh0VL\nDG2f2oV6y9sfd49QYtidkEMJQbTZi/BaGiW4DI1aBQH0CEgkFBEcFPDzixmV9WPl8OtPJOCdZxPf\nMy0tEDjSOCGKB2y2C6rqjbcmqNd0iwOUi+jkh8hkRLd7gGqUL4+oDZta8B/+7Z/5yx9DFLeE8IeF\nbUnK4b215SjfNiJwVlyYMO0laHf/oCBqyd9+zNrxE8H3TJBEkKSixQoIB5taUpWwWfwVv/3sL7i7\nXVFr69VJWnsFnDakIiGOE782j3q4ukEmWWgsrsBpiDPQDc4arJAk2YB6tWB0cEiRpd5lXDpOzk8f\nODpvI2Cxe9sQiNnO0yxs4OQ46yXq1vi5U2tLGUr3xhhsYwjnc2zg4hjjMOFz28rNWy6bxXN4HKF3\n2f2jndct0dlaEZ4nlCfb57J+F1PW0R+PqPD2K+38j4QjSiRl9fWc02+cVrKqOYsaEpciZYezIsZZ\n2FJTqppyY+jUNUm3YL1awKBLX1o+55jZtmJtM6qoR8eVuGqNKc6px98jzVKa/ojldI7MDnHNFrd5\nRTeO6IguHDymGQ445reQjTAkGGcg7z
PoRGRxiXrxGesmo5I5U9EhOn9C8vIlUSqIpG+siSAsRAHs\n4Lz7ZrgS3ptHhvT07y9k3zh9rKDS2iNja9DWYjE4q7FxTOEMS21QxmKcwUqHA/CXNwAAF01JREFU\njL0fQ7fo4gLj3MvbBWk4iUgs0jqMtcTCkTmfjUm1pRS+9BEJg3GGDEPSrgzCobBY50stlXWIsCkI\ny87VWToXGpE6L9u0jtT5k521Dls3Xr3VaP/QgV7/wOgVGR+cHPNnP3nCz376hLODLpNhzlY5DoYZ\nxSjnk+sFf/npDbPphk8urvnFiwXaSN477HJzNePp2YSnh30maYyIIxrjiLMIY+B0MuLpQZ9VaVmy\nwXYa7NpyNzOUVcyrlw2dLCVLU4SN0MrQWEelG6raIKRgvqgosoIOktnC0MkSauVII28DoIxjvWwY\n9HO61nJ8NGCYKBIT8Xqxpq4aJsWQN7cPO432iq6vizvLbLmg1goRR8g4Zr6c45zlk08+JYojZvM5\nn3z8MWW55ejoiG4n5/T0lKdPn7JarZhMJkwmE4QQzGYzjPEmgsvlkm63y3K55OrqCuccw+GATp7z\n6NEpm+2WV69eMRwMkFKgtGK73bLZblmv18GRuSSKJEkcUzeKplFcXHzOfD5HKcXnn79k0B9wcjLk\n4sVvWa9WXL25Jcu6ZFmPKHqYXB8A64hifMfrL0XLORBhw2rPKNZAlkqODvucHg/pprkvGznQRjDo\nQJxlJHHkGxj6EwNKOWK8WVyaehLoYlGxWZcYo1HKZ8esharSVJXPEIZXcf/CnPVlisgv6Dsucmsc\nhPPl8/D93Wb+YNDjwik4rGFhA9staV/1CETZ9rEDR6Fk46zjRsNLL7jbteMBv4F0k4iFapitGoyu\ncNrgJRTw4QeP2FbXpJ0h6+oWbRpf2rMb+r0nJElGmvYZ909Jkz7WRIjIr1HdbERd35GnPcqy4k//\n5D1/LQMxmweux3DfHsJaew/qwgZrrQj+d63XDrs+TdaEDt+tWii8fxfKkR6QhHGWHpSlqSNK8ApS\nAdtSUjfw8qP/wjqd+EN2U+9KZKKpEZH0gPLRANEfQN6FNA9gxyET3+7FYiDrIbCITGLqmkhmzOdT\n0rzLerUmiR8mlID78WjVgS2VzLlWxdaCHOn5PMZzeaQUNNpQKa9gVdYEYOnQbYYn9ClzeM6YM97r\nqE3RuNBfq+Xk2NBu4h6ch/5ZuF02yLTE85Bx88DH0I0kjXNcAJXxPmvaeEW0UeYbWz9+M47OMrb5\nIVtbQ5rT6BhnK0ztmK23JFGJfPSE9bahkwuy6WtWKmGSKuIsYfir/8705QWynMPtc1RTYW4/RnUO\nybFIM6OrNwxlQpF1uZ68R3TyiKipiTpnyMc/IZaKJtJ0IkdXr1DFGBf1ifOG5uqXOLUguvw7ZLlA\nfPS3yDgO5mT3zpotIc0fzjzEEYgAikKS74H5dx2o5knr9+AM2kCEJLEKZRyZsxir0U7gFMGJ0yHy\nFGxQjzm/uVbC10+Ns97PR0gavHIixqEiiJ0lwfdYkkJihHcwMoHcJcPpTr+lQnPWep8j50jxwMfg\nJ6J2jkx4xZa2xsvVjUKvSzabLUJVKNMg1EMZTrDaNPz9q1fYtEvvcEjv0RHX2y0fv55SrWu2i5L3\nj8f89J0jjg66DDop0tZcrEr6/YJ//sMnXN0uubhbczwc4Iyl1IZ1ZZgqxaXasOlKVLpk3E8Z93oM\nT3oU/YRUZvSHEVkR0aiGJI0xWuKU50kNBzmVgU4nYV2vmG629HoR0jl6UYRSGusiRkVGp5+xruDg\n5IhRV/K7qzW9YUqv6GFMxWev3+AtQx8yNiVSRoxGY5I0ZTgaMTmaMBqP+OGPf8LNzQ3dbpeyLLHW\nMhgMODg44OLzC548eQIIrq7e0O/3MYHQrrWm0+kwnU65vLzk2bNnpGnKeDymKLpo7VVXSeK5Y928\nw/
njc9brNUmSkCYJVVVx9eaKpmlI05TzJ0+Y3t6yXq3oFQXGGJRSdPOc0XCI0Yrp9JrD0Yh+0aWp\nNcdHj3AkCBKq8uFAmchvTFr//pxrWzFYPPG+BUDW+ZYUN9M1z59fc/VmxrZSWLwEO0+9QVlZG+rW\no8u4wM2A2fSO3332HIDZ7R2zW6+GuXpzycXzC1+esjDs5XSy9lq7t18YUeRFAFK2i/h9lgFaLx53\nz1t46/T7j442c/F7ipYHPNUODLXZModWjqmG143YvSsBpAK6keU4hY9elyil0NZhRYwx0Fws+Omw\nj1E1g3xCvzshjQoEA6IIVFNyefdb3kw/Yzg4RkaCzeqOsl4ihCWNOwgZEaVDRodDHp8Xu8ydc1/z\n+r8h/MHVZw1aEi1Bet4CKc85EVgdms7a+5Kkcc6/v5ARE23pxbkddhT+/EySCJIMZGyD84lgsYau\n/B5J3gcLpqr9ZTYaYS1CCqQ1JB+eg0wQ1uK2a08ZSDLQCr26Q0RBoWYUUqREWZdOnnJ4MOL6zWuS\nJKJcf337hK+97i3Q4x4vg89gGdsCkhZ0hu+1pSXjUKYdZdDKeIsH7Xk+VrXKLj/Wxhcu/N4a7BJw\n/uzsCcneUf0LJHvnR9zY9nrxVtYHwsaOwZEEwY1SdrdeLG/XQRDw9XfDNwKe1RIWjcQmOXnqS0Wp\nqtDLVzw6PWQkDUmkqTeazc0bRLNk0yy5MiNGh8fof/lviPSUyCia0WMSXZFlxyT1DKUqiuIRlXA4\nl2CkYnj998yWKxoZ4fSWjethXELWbJBJjhIZ5uYFVa3Ixu8y6Q9JygV6+IS+jFlNb99SF0ThJnC7\nBqFGeE8OK+7vJmcFRkZ8xYHym+dPWHVUMCywQGI0whrqxiuilLGkIsJog5QWaUEojdSKRniQ4WgJ\nzSGrEJj8yvlMjxTWd1EPpC/rDKlzaGv9Zhdke5F1vg9Xe/c65yXx+NMLgHCKKvB9otCqonHO26Fj\nqa3ykv2mQTY+05NrS/MtMjzvnk5ItUM0Daqy3M42fPZqzrvnR9Sx5IffecQgi6m1Y6ktJ5MBnTwj\nEymzsuKj1zO09b1nImno5BlRJJluKw6OOgyGfZxTGKtJYp8BWq9KnHb0DmIm4x5X12vyIsMoRx4Z\n+r0EJy3aWOJIsakabASdSJDEGSJOiGLB2WHPuyuXinWl0ELx6//7nKZyjHodPn55BU4z7MR0M8ns\nbvmgsVlvt9R1jTGG8XhM0zS7r+uqYr1ec3FxQZokRHHMycmJr1mnKVVV8tFHf4vWGikldV0znU6J\noojpdMrBwcEOBBVFQZIkDIdDjo+PyPOcu7tbVqsVFy8uuLq6Js9zoiji+nqKlJL3338f5xxVVfLr\nX/2K7WYLQjC9nrLZbDk7O/eLy3KFtZY4TjDKkHd6NI1jW1peXlwzvZnT7XQfPG+Q/kQZxW/Tgn20\nJzdn7W6zB8IiabHWsNpUaGPv53zsZdZ5EnkehvJMFhsWVaNhcjTh8Mj3Q5scjXnyzhknp0c8fvKU\nD374vt+8jJ837ktLpiBkZoMlP8KXPkLroZ0x4L00NwAN+a0SPLu/+TJhud00dpkN+NJm8tYJ/63n\nA/+ajILLxvGyEeE9+B/FAg5z+Oy2RNmEWuvdTjRbLjka5DyWkHUeM51/jjEbNs0CZQQyGpDFHaRI\nKasVCOj1jsmSgk42CYZykkgmkJzwn/69N4yzPJzQfT8mLpDFv1i627kqh0y6MoSO9T6D0IJIyT3Q\naeX7u7EMH6X0/d6SxJGkgih2iNjPp4vF0vd/qhqk0gitiKoGaQIh2jryJ09wxvoniYOjtG5w1hIn\nHUTWwd8IEUmnQyQjlssVTW0ZDCdEaUqneNi9FTpT7TKBLryZnbLQBvO/kGmxxnkwEzI11nkzXS/X\nN9igGHY63EtGeCsA7rN0xvqeZl6O7ueYNQH0BPDjAvihLYNZ349r
B7xMIDN7/0YE0E9ilputd3Z2\nnjTdaINWGt0odPP1rSW+EfA0zS3CKJKsC8kBiTW4NCcanVDeTKniPtYakuiWYjgifvpDCjbUs2uW\nsxny5hVJ2qc3OmWSGIRdk2YSmR6gSkWdFLj+MbIjiZIRUZKQLF9TV5o47xCbK4RaIodHzNcNkboj\n7Z0gTcn8bo0tJtSyIJ89Z11XYHypwi88Xpm0S2H64lbw5PGZHc9xEcSAeyDfQFmv1LLWoowhspba\nWZTxZoZS+ayK1ooURyMsifQy8SKWJJGnSStjvaIiSM1jAXkgmybCO11WeL6REr7JqBHQWiXW1uCc\nb2TqcGT4dhEa3xW+dgbhvGFYYwXC+lKKsobGF713J5gY3zZAa43SCls1GKsfLLsG+PziJYk1vLpa\n8OZ2QSJinp1MaLRj1EmoGsXFdMP/+vSSVMOPz444HqScDAS9JOX9R2PORx22wvGLyzucgNW64cPH\nE+KOY1NWVNstVWm4XWi2m9KT7IxjkKYYseVwnLNeaISV9AYd0lwSi5giM0gjiE1Bss2JhCCLYdzP\nKboxK6PppAmDIDk3lSaRMbNVRdFJ6cYZlRNUQSGwTh7G4RmOx6zXG2Z3d7x+9Yp/+k9+ugM819fX\nvP/++zx79ozVek2vKGiaBqUUJ6en1HXN0fExBwcHbLdblssl5+fnrFYrPvzwA+I4Jk1Tqqri9vaW\n29tbfvPrX7NarViv1x6gGMPh4SHWGvI8Z7lc0O/3GQ1HzOeLXWZoNBpzfHzEoN9HKUVR9FgtV6Fn\nV0Wv12W59D21yq0iirt0izHPnn2ItRGz+cOAIICrfU8f1XxFCiR8aYz58hEV8KWwo0nBwTglToPB\nnPOeKVkWE0cRnTzyfDUBSQTdXCIjQa9XhOfWaOXd040xqEZhnMA0sF5sPQnz7dcLO7U4+PvZ7jYV\nn2EQAohARMLf4C235qFjw1t/Ez75AjBoQdA/UBH6quyJLyvAVDnetJyeAHzyGF7eLJiuK3CSpvYu\nhcmgA2/uOO2u6Kopadqj2x0wHr6DlBHXN39LLz+kNzymrNcIp7i5fYWMM+arK6SzKF15MKkUf/an\nP969j2+T4XHtxQjmkjZksUHsJPvGBGJyKHHtMhvsLkuoB/jsoJB+094BBXyGR0YQhyxPnAgiKYhT\neGM+pa42sK08Z7NRQbjvM/ByXCC7XZDSS9IjCdbipDcGdUm8487hIMoKjG7Iuz2iLGO1WVFuaj79\n3eXDxgZxDwBFAICBP4PzdIdG+ayXCzJ+04JB5w1ZdbA4MMaGA0AgF2vhjQu19X8fgIoJjQ7alhW7\ncpkBAqD5ovGgB1u7Mpkm/L8gk9f+Is8rv0clQBkMM8vG4SJJnKUk0dfv5d+4y0dZxmh0SCKEV+yM\n+6zn10QuJnYaaSpUEzEYnYLsQLWiO3mPfi/FpRlaCAZmzcJY4sE7GDliLcdcz+d0jh6zXjcIo6hs\nQiRqmkYS9fs4GZPrDbdVgs4GXN3MGGQlpphgo4TkYELXXnOxMNhOF90dI1/+jjgsRg6QtlVggRCS\n1ne57UnlrPQ4PpS0nDAPmkAYiK1BWE0iHFo7YusglIYaIYixAWgYhNY4bb3DsrE7XwYRiXtvBus3\n7No6UoTn/+C8lM55eXoKKGvQRmNd+zMbpPjsnJcj60nJrYW5ILS+wEvjYyfIjcNa5XskhcyQMwqr\nFaLR0BiassJ9A2L+unjvfMLRZMBqveV2eocoS5R1bJRisTXMS4VII945HZLkMbPtAuck863i7LjL\n9WLFstZI3bBdKKbrkkeTgjhzZIXjcOwQkWa7LanKJWnHkaZwkEleT5f0OwdImVHWDVvVeCdZ5VAN\nvL4pmZwMGQ1jor5j0IuptGW1KdEOVsuKvJexKRs6RU4nSSi6KdeLLd0IJsMueRxzPTMcTCaMew+T\npc9Wa7S1HB4eUhQFv/zl/yFN
UppGcXJyws31NZv1mneePuXi4oLLy0vOzh4xn888R0JrtDZIKdlu\nt0ynUx49esTPf/4LVqsVh4eHO6ffXq9Hnud+gUgS8jxnuy2RUUSv12c+n3NwcIiUEcvVim63w2x2\nh7WWqioZDPpstxuKfh+tDdc3U6I4ptPp0ulkSCkoej2KwYAoTZkcPeH15S1P3/0u50+ePXjeiMB1\nEV9RYr7HN/f8Cr8JCUScsFkr3+zV+pKw38T8aTGWvkWLDeAeC7GUlI1lNp1x+eoNAC9evObN5Q2f\nv7jg1cXv+M3f/T0m3Dte7fn7S6bAnzIRnswrQypGCHGvHrA+o+CM33FabseDxqb92IIl98XHrhz1\n5aFzX/zY8ja+HC5sYK9rx1z7TV/i/cUmPfh8DsY2ocSgqRX88i9ew4sLzqMKozdMb35DFhcIGXF4\n8C5J2mez1URSMFtNEZGiWl8QxwJjNTLqYlxJYyN+9KPvMxhnD0eCIezbm6e/HIGkfV9i3IEh6/ld\nHgC9Vepqf4+Wey7CtQzNMm1b2hLIyDdhTjKIEkecQBnB6+UtVmm/juqwtju/7qbvHsGgh6vK+wsh\nvB2JF9i0gNhAlCBw9CaPwJUUnQxTb4kjR6/7sP59/v24wHFqvXJCuakt6YXx0PaePKyNTxUY63Zm\njs4alPEVDBVAkgnZIGscRkuUEhgtvbeQdiGb016fwJfSYtewtR139xZY9aXHAJR2JbLQkqzR3h5E\n+/+5bQxREqMbS6//9dmvbwY81rJWMTLNiDZz6ukGVxyynd3RDA+pkx5qNaXjHOu6pApMb1Wc4IjZ\nxgVbCV29ZVvesXVrcnXLeHuJNTWdXoKsVhwOBySDU7qjPp0kZ1BYytG7DNUCIQyFW1H3zkE0ZKZB\n64x4/A6PxBs6AUVWzy/9JBTsTL4S6031HH4hEjtPChCRVyjFO07Pwzo7C+lRrtK+tBQbQ2O8EWFl\nNFZXviYbOucK4yitxhlLGse08EpoSxQmX2V9LyJrfXYojmWYYNabVjlLbY3PSOGzRVgPipKdosCv\nfs5ZZCBrW9f+jj/x+L/3XCGcv87CGu9Cay2xVVjrm5uiDOYr+BT/UDTzFetFSRFBZSwXt0tm8xl9\nIqKo4fnNmnEW87PzA1ZbRa0ch/0M4yRv7jak3YxBLklkxFEv5zsHBZNezitniANXKUpyRGIYHw7Q\nJseQMVeSUS9lvVnTNCV5LjBKoGtYrx1JLJiMOpimQokYjKQ/Lhh1MuJOTp5JlBXoqqEjJFWjUdaR\nRRnDbsyiNEgRkaF5epxwt1jC9uZBYxP3JmyVYLlc0DQ1h4eHGGMYDgdcXV35LuZxhNaaoig4ODjY\nEZLTNCVNUwaDAYvFgqOjI77zne/w/PlzHj9+zOnpaShJVQghMMbQ7/dQSrHZbCnLiuFwSNNU3N3d\nUZYVt7d3SCmw1nI3m3F0dIwxhtFoRL/Xo9yWFN2O77EVJzw+O6OT5yyXS/r9Hlknw2IZDEb85V/+\nb4piwCef/I5fffTrB8+bHZHyK6Zcu49bbfkCB8aBUZpuHrPZKO5m1b1LrvUGcxp/mFDaZ37bClMs\nYXI84d33ngLw7rNzzs7PePrsKe9857v80R//ePe/hHAkyVeQqYMCxquiwucmZE28B6r/PSk8kGuR\nxLcZmxa0fAnsfCH9Y7/ie299fBvsfLnM1Zq8vaoEdXjdsYR+Ap/cQa0tRjV+k1OKWvb47Scr7MUL\nRrdThoN32FRXWKsot3NeTT/F2ZJar+nmPdL8kP7wu6RphnE5cVwQRyO0rnAy4T/+u38RxvRbjI+n\n+WJCKcrulD6tIev9xqoDYvSbvQiKq/tBazuBO0DKIKR2bfnGP4+MBEniMztJ4itUaQJ/c/UpThlc\no3eZGhlFSCnJ33vsmzerCmc0tmlwUeKVsKqBOMNZDQjfMb2pUdWW4cExMYKT08ckWY/j479+6O
C8\nzXN/q+Tpbzgb3JJNICybFmTssl9eSGOd75autb+XtAarJVqDtpbGuMCNEujA3m9LWTaUEN92ct5l\nctqsTgu+9H17C9dyekKhIZUSrRok0Gif4VHGUVcKrTTw4muHQbivgvr72Mc+9rGPfexjH39A8S37\n9u5jH/vYxz72sY99/P8Te8Czj33sYx/72Mc+/uBjD3j2sY997GMf+9jHH3zsAc8+9rGPfexjH/v4\ng4894NnHPvaxj33sYx9/8LEHPPvYxz72sY997OMPPv4fdIUDdUTUECYAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "olF4PpORpCTK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} diff --git a/trax/models/reformer/machine_translation.ipynb b/trax/models/reformer/machine_translation.ipynb deleted file mode 100644 index 55192bf21..000000000 --- a/trax/models/reformer/machine_translation.ipynb +++ /dev/null @@ -1,382 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Machine Translation", - "provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." 
- ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Machine Translation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX.\n", - "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", - "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", - "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", - "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - 
}, - "source": [ - "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import pickle\n", - "import jax\n", - "import trax\n", - "from trax.models.beam_search import Search\n", - "from trax.supervised import inputs\n", - "\n", - "from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "uCX88z9iXB7s", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install sacreBLEU\n", - "!pip install sacrebleu\n", - "import sacrebleu" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Load WMT14 data" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8S3h28Q9b_9B", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Download the newstest2014 English-to-German translation pairs\n", - "!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src\n", - "!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "CBv2SDnWZEI7", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Load the source text and reference translations into Python\n", - "refs = []\n", - "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):\n", - " if line.endswith('\\n'):\n", - " line = line[:-1]\n", - " refs.append(line)\n", - "srcs = []\n", - "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):\n", - " if line.endswith('\\n'):\n", - " line = line[:-1]\n", - " srcs.append(line)" - ], - "execution_count": 0, - "outputs": [] - }, - { - 
"cell_type": "code", - "metadata": { - "id": "CbYw4eMXZGKa", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up our sub-word tokenizer\n", - "tokenizer = SubwordTextEncoder(\n", - " 'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "2NbOslppZGZ0", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Encode source sentences using the tokenizer\n", - "input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)\n", - "for i, x in enumerate(srcs):\n", - " x = tokenizer.encode(x)\n", - " assert len(x) <= 127\n", - " input_ids[i, :len(x)] = x\n", - " input_ids[i, len(x)] = 1" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YwzU64GmZTb2", - "colab_type": "text" - }, - "source": [ - "## Load the pre-trained model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "VXjtCPxl3I82", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# We'll be using a pre-trained reversible transformer-base model.\n", - "# First, load the config (which sets all needed hyperparameters).\n", - "!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin\n", - "gin.parse_config_file('./config.gin')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "IediBe8MXyLf", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Now we load the pre-trained model weights.\n", - "with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:\n", - " model_weights = pickle.load(f)['weights']" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Beam search decoding" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fc_VlhrBYW0u", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up beam 
search.\n", - "beam_decoder = Search(\n", - " trax.models.Reformer, model_weights,\n", - " beam_size=4,\n", - " alpha=0.6, # For length normalization, set to 0.6 following Vaswani et al.\n", - " eos_id=1, # The stop token has id 1 in the vocabulary we use.\n", - " max_decode_len=146,\n", - " )" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "bynTpreMYXPs", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 71 - }, - "outputId": "cfd24e01-617b-4beb-a5f2-98a7ce2e1449" - }, - "source": [ - "pred_ids = []\n", - "preds = []\n", - "BATCH_SIZE = 1024\n", - "for start in range(0, input_ids.shape[0], BATCH_SIZE):\n", - " print(start, '/', input_ids.shape[0], flush=True)\n", - " batch = input_ids[start:start+BATCH_SIZE]\n", - " seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)\n", - " # Select highest scoring output.\n", - " batch_pred_ids = seqs[:, -1]\n", - " pred_ids.append(batch_pred_ids)\n", - " preds.extend([\n", - " tokenizer.decode(pred.tolist(), strip_extraneous=True)\n", - " for pred in batch_pred_ids\n", - " ])" - ], - "execution_count": 13, - "outputs": [ - { - "output_type": "stream", - "text": [ - "0 / 3003\n", - "1024 / 3003\n", - "2048 / 3003\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "c5Gq4qF_YY2i", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 35 - }, - "outputId": "37a5e24f-9264-4d7a-dd74-065758c9a7e4" - }, - "source": [ - "bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')\n", - "print(bleu)" - ], - "execution_count": 14, - "outputs": [ - { - "output_type": "stream", - "text": [ - "BLEU = 27.86 59.5/33.5/21.3/14.2 (BP = 1.000 ratio = 1.020 hyp_len = 65943 ref_len = 64676)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "olF4PpORpCTK", - "colab_type": "code", - "colab": {} - }, - "source": [ 
- "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} diff --git a/trax/models/reformer/reformer.py b/trax/models/reformer/reformer.py index 73edeae7a..286ffda9a 100644 --- a/trax/models/reformer/reformer.py +++ b/trax/models/reformer/reformer.py @@ -24,619 +24,771 @@ # pylint: disable=invalid-name -def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, - n_heads, attention_type, dropout, ff_activation, - ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity, - attention_chunk_size, n_attention_layers=1, - n_feedforward_layers=1, center_layernorm=True, - use_bfloat16=False, mode='train'): - """Reversible transformer decoder layer. - - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_heads: int: number of attention heads - attention_type: subclass of tl.BaseCausalAttention: attention class to use - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: the dropout rate in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - use_bfloat16: whether to use bfloat16 for weights (default: False). - mode: str: 'train' or 'eval' - - - Returns: - the layer. 
- """ - # pylint: disable=g-complex-comprehension - def _Attn(): - return ct.ApplyAttentionLayer( - attention_type, d_model, n_heads, d_attention_key, - d_attention_value, True, False, dropout, dropout, - attention_chunk_size, mode) - - def _FF(): - return ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, - mode, use_bfloat16) - - def _attention_half_residual(): - return [ - tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm), - attention_layer=_Attn(), - name='ReversibleHalfResidualDecoderAttn'), - tl.ReversibleSwap() +def DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type, + dropout, + ff_activation, + ff_dropout, + ff_use_sru, + ff_chunk_size, + ff_sparsity, + attention_chunk_size, + n_attention_layers=1, + n_feedforward_layers=1, + center_layernorm=True, + use_bfloat16=False, + mode="train", +): + """Reversible transformer decoder layer. 
+ + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_heads: int: number of attention heads + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: the dropout rate in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + n_attention_layers: how many residual causal attention layers should we + have before the feed-forward block (default: 1, the standard block) + n_feedforward_layers: how many FFNN layers should we have (default 1). + center_layernorm: whether to use centering in LayerNorm (default) or if + to skip it, which is known as RMS normalization. + use_bfloat16: whether to use bfloat16 for weights (default: False). + mode: str: 'train' or 'eval' + + + Returns: + the layer. 
+ """ + # pylint: disable=g-complex-comprehension + def _Attn(): + return ct.ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_attention_key, + d_attention_value, + True, + False, + dropout, + dropout, + attention_chunk_size, + mode, + ) + + def _FF(): + return ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16, + ) + + def _attention_half_residual(): + return [ + tl.ReversibleHalfResidual( + tl.LayerNorm(center=center_layernorm), + attention_layer=_Attn(), + name="ReversibleHalfResidualDecoderAttn", + ), + tl.ReversibleSwap(), + ] + + def _feed_forward(): + return [ + tl.ReversibleHalfResidual(_FF(), name="ReversibleHalfResidualDecoderFF"), + tl.ReversibleSwap(), + ] + + return [_attention_half_residual() for _ in range(n_attention_layers)] + [ + _feed_forward() for _ in range(n_feedforward_layers) ] - def _feed_forward(): - return [ - tl.ReversibleHalfResidual(_FF(), - name='ReversibleHalfResidualDecoderFF'), - tl.ReversibleSwap() + +def ReformerLM( + vocab_size, + d_model=512, + d_ff=2048, + d_attention_key=64, + d_attention_value=64, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + attention_type=tl.SelfAttention, + pos_type=None, + pos_axial_shape=(), + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + ff_activation=tl.FastGelu, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + mode="train", +): + """Reversible transformer language model (only uses a decoder, no encoder). 
+ + Args: + vocab_size: int: vocab size + d_model: int: depth of *each half* of the two-part features + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + attention_type: class: attention class to use, such as SelfAttention. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + ff_activation: the non-linearity in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + loss_sparsity_type: str, type of sparsity to used in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used. 
+ attention_chunk_size: int, if > 0 run attention chunked at this size + mode: str: 'train', 'eval', or 'predict' + + Returns: + the layer. + """ + positional_encoding = ct.PositionalEncoder( + mode, + dropout, + max_len, + pos_type, + pos_axial_shape, + pos_d_axial_embs, + pos_start_from_zero_prob, + pos_max_offset_to_add, + ) + + positional_embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + positional_encoding, + ] + + decoder_blocks = [] + + if isinstance(attention_type, (tuple, list)): + assert n_layers % len(attention_type) == 0 + else: + attention_type = [attention_type] + for layer_idx in range(n_layers): + layer_attention_type = attention_type[layer_idx % len(attention_type)] + decoder_block = DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + decoder_blocks.append(decoder_block) + + dense_loss_layer = tl.SparseDenseWithOptions( + vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, + ) + + return tl.Serial( + tl.ShiftRight(mode=mode), + positional_embedder, + tl.Dup(), + tl.ReversibleSerial(decoder_blocks), + tl.Concatenate(), + # # TODO(kitaev): Test whether dropout should go before or after the + # LayerNorm, and whether dropout broadcasting is needed here. 
+ tl.LayerNorm(), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + dense_loss_layer, + ) + + +def ReformerShortenLM( + vocab_size, + shorten_factor=1, + d_embedding=256, + d_model=512, + d_ff=2048, + d_attention_key=64, + d_attention_value=64, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + attention_type=tl.SelfAttention, + pos_type=None, + pos_axial_shape=(), + pos_d_axial_embs=None, + ff_activation=tl.FastGelu, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + attention_chunk_size=0, + mode="train", +): + """Reversible transformer language model with shortening. + + When shorten_factor is F and processing an input of shape [batch, length], + we embed the (shifted-right) input and then group each F elements (on length) + into a single vector -- so that in the end we process a tensor of shape :: + + [batch, length // F, d_model] + + almost until the end -- at the end it's un-shortend and a SRU is applied. + This reduces the length processed inside the main model body, effectively + making the model faster but possibly slightly less accurate. + + Args: + vocab_size: int: vocab size + shorten_factor: by how much to shorten, see above + d_embedding: the depth of the embedding layer and final logits + d_model: int: depth of *each half* of the two-part features + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + attention_type: class: attention class to use, such as SelfAttention. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. 
If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, values must sum to d_embedding. + ff_activation: the non-linearity in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + assert mode != "predict" # TODO(lukaszkaiser,kitaev): fast inference + + positional_encoding = ct.PositionalEncoder( + mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs + ) + + positional_embedder = [ + tl.Embedding(vocab_size, d_embedding), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + positional_encoding, ] - return ([_attention_half_residual() for _ in range(n_attention_layers)] - + [_feed_forward() for _ in range(n_feedforward_layers)]) - - -def ReformerLM(vocab_size, - d_model=512, - d_ff=2048, - d_attention_key=64, - d_attention_value=64, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - attention_type=tl.SelfAttention, - pos_type=None, - pos_axial_shape=(), - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - ff_activation=tl.FastGelu, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0, - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - mode='train'): - """Reversible transformer language model (only uses a decoder, no encoder). 
- - Args: - vocab_size: int: vocab size - d_model: int: depth of *each half* of the two-part features - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - attention_type: class: attention class to use, such as SelfAttention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - ff_activation: the non-linearity in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - loss_sparsity_type: str, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. 
- attention_chunk_size: int, if > 0 run attention chunked at this size - mode: str: 'train', 'eval', or 'predict' - - Returns: - the layer. - """ - positional_encoding = ct.PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs, - pos_start_from_zero_prob, pos_max_offset_to_add) - - positional_embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - positional_encoding, - ] - - decoder_blocks = [] - - if isinstance(attention_type, (tuple, list)): - assert n_layers % len(attention_type) == 0 - else: - attention_type = [attention_type] - for layer_idx in range(n_layers): - layer_attention_type = attention_type[layer_idx % len(attention_type)] - decoder_block = DecoderBlock( - d_model, d_ff, d_attention_key, d_attention_value, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - mode=mode) - decoder_blocks.append(decoder_block) - - dense_loss_layer = tl.SparseDenseWithOptions( - vocab_size, - d_input=d_model, - sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, - d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, - mode=mode) - - return tl.Serial( - tl.ShiftRight(mode=mode), - positional_embedder, - tl.Dup(), - tl.ReversibleSerial(decoder_blocks), - tl.Concatenate(), - # TODO(kitaev): Test whether dropout should go before or after the - # LayerNorm, and whether dropout broadcasting is needed here. 
- tl.LayerNorm(), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - dense_loss_layer, - ) - - -def ReformerShortenLM(vocab_size, - shorten_factor=1, - d_embedding=256, - d_model=512, - d_ff=2048, - d_attention_key=64, - d_attention_value=64, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - attention_type=tl.SelfAttention, - pos_type=None, - pos_axial_shape=(), - pos_d_axial_embs=None, - ff_activation=tl.FastGelu, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0, - attention_chunk_size=0, - mode='train'): - """Reversible transformer language model with shortening. - - When shorten_factor is F and processing an input of shape [batch, length], - we embed the (shifted-right) input and then group each F elements (on length) - into a single vector -- so that in the end we process a tensor of shape :: - - [batch, length // F, d_model] - - almost until the end -- at the end it's un-shortend and a SRU is applied. - This reduces the length processed inside the main model body, effectively - making the model faster but possibly slightly less accurate. - - Args: - vocab_size: int: vocab size - shorten_factor: by how much to shorten, see above - d_embedding: the depth of the embedding layer and final logits - d_model: int: depth of *each half* of the two-part features - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - attention_type: class: attention class to use, such as SelfAttention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. 
- pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, values must sum to d_embedding. - ff_activation: the non-linearity in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - mode: str: 'train' or 'eval' - - Returns: - the layer. - """ - assert mode != 'predict' # TODO(lukaszkaiser,kitaev): fast inference - - positional_encoding = ct.PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs) - - positional_embedder = [ - tl.Embedding(vocab_size, d_embedding), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - positional_encoding, - ] - - decoder_blocks = [] - - if isinstance(attention_type, (tuple, list)): - assert n_layers % len(attention_type) == 0 - else: - attention_type = [attention_type] - for layer_idx in range(n_layers): - layer_attention_type = attention_type[layer_idx % len(attention_type)] - decoder_block = DecoderBlock( - d_model, d_ff, d_attention_key, d_attention_value, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - mode=mode) - decoder_blocks.append(decoder_block) - - # pylint: disable=g-long-lambda - return tl.Serial( - tl.ShiftRight(), - positional_embedder, - tl.Dup(), # Stack has (x, x), the first will be shortened - # Before shortening, we need to pad by shorten factor so as not to leak - # information into the future. To understand why, imagine shorten factor - # of 2 and sequence of length 4, so ABCD. 
If we shift just by 1, then we - # would have 0ABC, which gets grouped to [0A][BC] on input, which is - # predicting ABCD as targets. The problem is that [0A] has access to A - # and [BC] has access to C -- it will learn to copy it, peek into - # the future. Shifting twice to [00][AB] solves the problem as the first - # "big" symbol becomes all-0 and the rest is shifted enough. - tl.ShiftRight(n_positions=shorten_factor - 1), - tl.Fn('Shorten', lambda x: jnp.reshape( # Shorten -- move to depth. - x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1), - tl.Dense(d_model), - tl.Dup(), # Stack has (short_x, short_x, x) - tl.ReversibleSerial(decoder_blocks), - tl.Select([0], n_in=2), - tl.LayerNorm(), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - tl.Dense(shorten_factor * d_embedding), - tl.Fn('ProlongBack', lambda x: jnp.reshape( # Prolong back. - x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1), - tl.Concatenate(), # Concatenate with just the embeddings. - tl.CausalConv(d_embedding), - tl.Relu(), - tl.SRU(d_embedding), # One RNN layer for conditional dependence. - tl.Dense(vocab_size), - ) - # pylint: enable=g-long-lambda - - -def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation, - ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, - attention_chunk_size=0, center_layernorm=True, - use_bfloat16=False, use_two_swaps_per_block=True, - mode='train'): - """Returns a list of layers that implements a Reformer encoder block. - - The input to the layer is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. 
- - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_heads: int: number of attention heads - attention_type: subclass of tl.BaseCausalAttention: attention class to use - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: the dropout rate in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - use_bfloat16: whether to use bfloat16 for weights (default: False) - use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder - block, otherwise use only one swap. - mode: str: 'train' or 'eval' - - Returns: - A list of layers that maps (activations, mask) to (activations, mask). - """ - if mode == 'predict': - # Mode 'predict' means that the decoder should be run one token at a time. - # The encoder only ever runs over full sequences, which is why it's switched - # to 'eval' mode instead. 
- mode = 'eval' - - def _Attn(): - return ct.ApplyAttentionLayer( - attention_type=attention_type, d_model=d_model, n_heads=n_heads, - d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False, - attention_dropout=dropout, output_dropout=dropout, - attention_chunk_size=attention_chunk_size, mode=mode) - - def _FF(): - return ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, - mode, use_bfloat16) - - # TODO(lukaszkaiser): refactor efficient attention layers to unify the API - # If we're using standard attention, we need to pass reshaped mask and not - # return the mask to be compatible with the EfficientAttention API. - attention = _Attn() - if attention.n_out == 2: - attention = tl.Serial( - tl.Parallel([], _InsertAxes12()), - attention, - tl.Select([0], n_in=2) + decoder_blocks = [] + + if isinstance(attention_type, (tuple, list)): + assert n_layers % len(attention_type) == 0 + else: + attention_type = [attention_type] + for layer_idx in range(n_layers): + layer_attention_type = attention_type[layer_idx % len(attention_type)] + decoder_block = DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + decoder_blocks.append(decoder_block) + + # pylint: disable=g-long-lambda + return tl.Serial( + tl.ShiftRight(), + positional_embedder, + tl.Dup(), # Stack has (x, x), the first will be shortened + # Before shortening, we need to pad by shorten factor so as not to leak + # information into the future. To understand why, imagine shorten factor + # of 2 and sequence of length 4, so ABCD. 
If we shift just by 1, then we + # would have 0ABC, which gets grouped to [0A][BC] on input, which is + # predicting ABCD as targets. The problem is that [0A] has access to A + # and [BC] has access to C -- it will learn to copy it, peek into + # the future. Shifting twice to [00][AB] solves the problem as the first + # "big" symbol becomes all-0 and the rest is shifted enough. + tl.ShiftRight(n_positions=shorten_factor - 1), + tl.Fn( + "Shorten", + lambda x: jnp.reshape( # Shorten -- move to depth. + x, (x.shape[0], x.shape[1] // shorten_factor, -1) + ), + n_out=1, + ), + tl.Dense(d_model), + tl.Dup(), # Stack has (short_x, short_x, x) + tl.ReversibleSerial(decoder_blocks), + tl.Select([0], n_in=2), + tl.LayerNorm(), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + tl.Dense(shorten_factor * d_embedding), + tl.Fn( + "ProlongBack", + lambda x: jnp.reshape( # Prolong back. + x, (x.shape[0], x.shape[1] * shorten_factor, -1) + ), + n_out=1, + ), + tl.Concatenate(), # Concatenate with just the embeddings. + tl.CausalConv(d_embedding), + tl.Relu(), + tl.SRU(d_embedding), # One RNN layer for conditional dependence. + tl.Dense(vocab_size), + ) + # pylint: enable=g-long-lambda + + +def EncoderBlock( + d_model, + d_ff, + n_heads, + attention_type, + dropout, + ff_activation, + ff_dropout, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + attention_chunk_size=0, + center_layernorm=True, + use_bfloat16=False, + use_two_swaps_per_block=True, + mode="train", +): + """Returns a list of layers that implements a Reformer encoder block. + + The input to the layer is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. 
+ + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: the dropout rate in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + center_layernorm: whether to use centering in LayerNorm (default) or if + to skip it, which is known as RMS normalization. + use_bfloat16: whether to use bfloat16 for weights (default: False) + use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder + block, otherwise use only one swap. + mode: str: 'train' or 'eval' + + Returns: + A list of layers that maps (activations, mask) to (activations, mask). + """ + if mode == "predict": + # Mode 'predict' means that the decoder should be run one token at a time. + # The encoder only ever runs over full sequences, which is why it's switched + # to 'eval' mode instead. 
+ mode = "eval" + + def _Attn(): + return ct.ApplyAttentionLayer( + attention_type=attention_type, + d_model=d_model, + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + masked=True, + causal=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + + def _FF(): + return ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16, + ) + + # TODO(lukaszkaiser): refactor efficient attention layers to unify the API + # If we're using standard attention, we need to pass reshaped mask and not + # return the mask to be compatible with the EfficientAttention API. + attention = _Attn() + if attention.n_out == 2: + attention = tl.Serial( + tl.Parallel([], _InsertAxes12()), attention, tl.Select([0], n_in=2) + ) + + def _attention_half_residual(): + return [ + tl.ReversibleHalfResidual( + tl.LayerNorm(center=center_layernorm), + attention_layer=attention, + name="ReversibleHalfResidualEncoderAttn", + ), + tl.ReversibleSwap(), + ] + + def _feed_forward(): + layers = [ + tl.ReversibleHalfResidual(_FF(), name="ReversibleHalfResidualEncoderFF") + ] + if use_two_swaps_per_block: + layers.append(tl.ReversibleSwap()) + return layers + + return _attention_half_residual() + _feed_forward() + + +def EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + ff_activation, + ff_dropout, + mode, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, +): + """Reversible transformer decoder layer. 
+ + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: float: (optional) separate dropout rate for feed-forward layer + mode: str: 'train' or 'eval' + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + + Returns: + the layer. + """ + enc_dec_attention = tl.EncDecAttention( + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + attention_dropout=dropout, + output_dropout=dropout, + mode=mode, + ) + enc_dec_attention_half_residual = tl.ReversibleHalfResidual( + tl.LayerNorm(), + attention_layer=enc_dec_attention, + ) + + causal_attention = tl.SelfAttention( + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + causal=True, + attention_dropout=dropout, + output_dropout=dropout, + mode=mode, + ) + causal_attention_half_residual = tl.ReversibleHalfResidual( + tl.LayerNorm(), + attention_layer=causal_attention, + ) + + feed_forward = ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, ) - def _attention_half_residual(): - return [ - tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm), - attention_layer=attention, - name='ReversibleHalfResidualEncoderAttn'), - tl.ReversibleSwap() + return [ # vec_d1 vec_d2 vec_e masks + causal_attention_half_residual, + tl.ReversibleSwap(), + enc_dec_attention_half_residual, + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(feed_forward), + tl.ReversibleSwap(), ] - def _feed_forward(): - layers = [ - tl.ReversibleHalfResidual(_FF(), - name='ReversibleHalfResidualEncoderFF') + +def Reformer( + input_vocab_size, + 
output_vocab_size=None, + d_model=512, + d_ff=2048, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + ff_activation=tl.Relu, + ff_dropout=None, + mode="train", + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, +): + """Reversible transformer encoder-decoder model. + + This model expects an input pair: target, source. + + At the moment, this model supports dot-product attention only. For the + attention types in the Reformer paper, see ReformerLM. + + Args: + input_vocab_size: int: vocab size of the source. + output_vocab_size: int (optional): vocab size of the target. If None, the + source and target are assumed to have the same vocab. + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_encoder_layers: int: number of encoder layers + n_decoder_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + ff_activation: the non-linearity in feed-forward layer + ff_dropout: float: (optional) separate dropout rate at feed-forward + nonlinearity. This is called relu_dropout in T2T. + mode: str: 'train' or 'eval' + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + + Returns: + A Reformer model as a layer that maps from a target, source pair to + activations over a vocab set. 
+ """ + in_encoder, out_encoder, output_vocab_size = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + dropout, + [-2], # dropout_shared_axes + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + tl.SelfAttention, + dropout, + ff_activation, + ff_dropout, + mode=mode, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + ) + for _ in range(n_encoder_layers) ] - if use_two_swaps_per_block: - layers.append(tl.ReversibleSwap()) - return layers - - return _attention_half_residual() + _feed_forward() - - -def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, - ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0, - ff_sparsity=0): - """Reversible transformer decoder layer. - - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: float: (optional) separate dropout rate for feed-forward layer - mode: str: 'train' or 'eval' - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - - Returns: - the layer. 
- """ - enc_dec_attention = tl.EncDecAttention( - n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads, - attention_dropout=dropout, output_dropout=dropout, - mode=mode) - enc_dec_attention_half_residual = tl.ReversibleHalfResidual( - tl.LayerNorm(), - attention_layer=enc_dec_attention, - ) - - causal_attention = tl.SelfAttention( - n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads, - causal=True, - attention_dropout=dropout, output_dropout=dropout, - mode=mode) - causal_attention_half_residual = tl.ReversibleHalfResidual( - tl.LayerNorm(), - attention_layer=causal_attention, - ) - - feed_forward = ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, True, mode) - - return [ # vec_d1 vec_d2 vec_e masks - causal_attention_half_residual, - tl.ReversibleSwap(), - enc_dec_attention_half_residual, - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(feed_forward), - tl.ReversibleSwap(), - ] - - -def Reformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - ff_activation=tl.Relu, - ff_dropout=None, - mode='train', - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0): - """Reversible transformer encoder-decoder model. - - This model expects an input pair: target, source. - - At the moment, this model supports dot-product attention only. For the - attention types in the Reformer paper, see ReformerLM. - - Args: - input_vocab_size: int: vocab size of the source. - output_vocab_size: int (optional): vocab size of the target. If None, the - source and target are assumed to have the same vocab. 
- d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_encoder_layers: int: number of encoder layers - n_decoder_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - ff_activation: the non-linearity in feed-forward layer - ff_dropout: float: (optional) separate dropout rate at feed-forward - nonlinearity. This is called relu_dropout in T2T. - mode: str: 'train' or 'eval' - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - - Returns: - A Reformer model as a layer that maps from a target, source pair to - activations over a vocab set. 
- """ - in_encoder, out_encoder, output_vocab_size = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - mode, - dropout, - [-2], # dropout_shared_axes - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - ) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock( - d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation, - ff_dropout, mode=mode, ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity) - for _ in range(n_encoder_layers)] - # pylint: enable=g-complex-comprehension - - encoder = tl.Serial([ - in_encoder, - tl.Dup(), - tl.ReversibleSerial(encoder_blocks), - _XYAvg(), - tl.LayerNorm(), - ]) - if mode == 'predict': - encoder = tl.Cache(encoder) - - # pylint: disable=g-complex-comprehension - encoder_decoder_blocks = [ - EncoderDecoderBlock( - d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode, - ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity) - for _ in range(n_decoder_layers)] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( - # Input: encoder_side_tokens, decoder_side_tokens - # Copy decoder tokens for use in loss. - tl.Select([0, 1, 1]), # tok_e tok_d tok_d - tl.Branch([], [tl.PaddingMask(), - _RemoveAxes12()]), # tok_e mask tok_d ..... - - # Encode. - encoder, # vec_e mask tok_d ..... - - # Decode. - tl.Select([2, 0, 1]), # tok_d vec_e mask ..... - tl.ShiftRight(mode=mode), # tok_d vec_e mask ..... - out_encoder, # vec_d vec_e mask ..... - tl.Dup(), # vec_d1 vec_d2 vec_e mask ..... - tl.ReversibleSerial(encoder_decoder_blocks), - _XYAvg(), # vec_d vec_e mask ..... - tl.LayerNorm(), # vec_d vec_e mask ..... - - # Map to output vocab. - tl.Select([0], n_in=3), # vec_d ..... - tl.Dense(output_vocab_size), # vec_d ..... 
- ) + # pylint: enable=g-complex-comprehension + + encoder = tl.Serial( + [ + in_encoder, + tl.Dup(), + tl.ReversibleSerial(encoder_blocks), + _XYAvg(), + tl.LayerNorm(), + ] + ) + if mode == "predict": + encoder = tl.Cache(encoder) + + # pylint: disable=g-complex-comprehension + encoder_decoder_blocks = [ + EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + ff_activation, + ff_dropout, + mode, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + ) + for _ in range(n_decoder_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( + # Input: encoder_side_tokens, decoder_side_tokens + # Copy decoder tokens for use in loss. + tl.Select([0, 1, 1]), # tok_e tok_d tok_d + tl.Branch([], [tl.PaddingMask(), _RemoveAxes12()]), # tok_e mask tok_d ..... + # Encode. + encoder, # vec_e mask tok_d ..... + # Decode. + tl.Select([2, 0, 1]), # tok_d vec_e mask ..... + tl.ShiftRight(mode=mode), # tok_d vec_e mask ..... + out_encoder, # vec_d vec_e mask ..... + tl.Dup(), # vec_d1 vec_d2 vec_e mask ..... + tl.ReversibleSerial(encoder_decoder_blocks), + _XYAvg(), # vec_d vec_e mask ..... + tl.LayerNorm(), # vec_d vec_e mask ..... + # Map to output vocab. + tl.Select([0], n_in=3), # vec_d ..... + tl.Dense(output_vocab_size), # vec_d ..... 
+ ) def _InsertAxes12(): - """Returns a layer that inserts two internal size-1 axes into an array.""" - return tl.Fn('InsertAxes12', - lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1]))) + """Returns a layer that inserts two internal size-1 axes into an array.""" + return tl.Fn( + "InsertAxes12", lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])) + ) def _RemoveAxes12(): - """Returns a layer that removes two internal size-1 axes from an array.""" - return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2))) + """Returns a layer that removes two internal size-1 axes from an array.""" + return tl.Fn("RemoveAxes12", lambda x: jnp.squeeze(x, (1, 2))) def _AsTokenIDs(): - """Returns a layer that makes mask values look like token ID ints.""" - return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32)) + """Returns a layer that makes mask values look like token ID ints.""" + return tl.Fn("AsTokenIDs", lambda x: x.astype(jnp.int32)) def _XYAvg(): - """Returns a layer that computes the element-wise average of two arrays.""" - return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0) + """Returns a layer that computes the element-wise average of two arrays.""" + return tl.Fn("XYAvg", lambda x, y: (x + y) / 2.0) def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True): - """ReversibleSerial but with a forgetting block every n_layers.""" - if not n_layers or len(layers) <= n_layers + 1: - return tl.ReversibleSerial(layers) - layers1, layers2 = layers[:n_layers], layers[n_layers:] - - if forget_dense: - forgetting_layer = tl.Serial( - _XYAvg(), - tl.Dense(d_model), - tl.Dup(), + """ReversibleSerial but with a forgetting block every n_layers.""" + if not n_layers or len(layers) <= n_layers + 1: + return tl.ReversibleSerial(layers) + layers1, layers2 = layers[:n_layers], layers[n_layers:] + + if forget_dense: + forgetting_layer = tl.Serial( + _XYAvg(), + tl.Dense(d_model), + tl.Dup(), + ) + else: + forgetting_layer = tl.Select([0, 1]) + + return tl.Serial( 
+ tl.ReversibleSerial(layers1), + forgetting_layer, + _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense), ) - else: - forgetting_layer = tl.Select([0, 1]) - - return tl.Serial( - tl.ReversibleSerial(layers1), - forgetting_layer, - _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense) - ) def _ConvertToNaNsOnAnyZero(): - def _convert_to_nans(x, y): - # if all values in y are non-zeros, return x; otherwise return 0s - return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y - return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2) + def _convert_to_nans(x, y): + # if all values in y are non-zeros, return x; otherwise return 0s + return jnp.where(jnp.all(y, keepdims=False), x, x / 0.0), y + + return tl.Fn("ConvertToNaNsOnAnyZero", _convert_to_nans, n_out=2) diff --git a/trax/models/reformer/reformer_e2e_test.py b/trax/models/reformer/reformer_e2e_test.py deleted file mode 100644 index 57b180353..000000000 --- a/trax/models/reformer/reformer_e2e_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""End to end test for Reformer.""" - -import os - -from absl.testing import absltest -import gin - -from trax import test_utils -from trax.models.reformer import reformer # pylint: disable=unused-import -from trax.supervised import trainer_lib -from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/') - - -class ReformerE2ETest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - gin.add_config_file_search_path(_CONFIG_DIR) - test_utils.ensure_flag('test_tmpdir') - - def test_reformer_wmt_ende(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('reformer_wmt_ende.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('Reformer.n_encoder_layers', n_layers) - gin.bind_parameter('Reformer.n_decoder_layers', n_layers) - gin.bind_parameter('Reformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_reformer_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - d_model = 32 - - gin.parse_config_file('reformer_copy.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ReformerLM.n_layers', n_layers) - gin.bind_parameter('ReformerLM.d_ff', d_ff) - gin.bind_parameter('ReformerLM.d_model', d_model) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/reformer/reformer_test.py b/trax/models/reformer/reformer_test.py 
deleted file mode 100644 index 5a1fce949..000000000 --- a/trax/models/reformer/reformer_test.py +++ /dev/null @@ -1,126 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Reformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.models.reformer import reformer - - -BACKENDS = [fastmath.Backend.JAX] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -class ReformerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=64, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code - ) - - def test_reformer_lm_forward_shape(self): - vocab_size = 16 - model = reformer.ReformerLM( - vocab_size, d_model=32, d_ff=64, d_attention_key=16, - d_attention_value=16, n_layers=1, n_heads=2, 
max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - - def test_reformer_lm_lsh(self): - lsh_self_attention = self._lsh_self_attention_fn() - timebin_self_attention = self._timebin_self_attention_fn() - - model = reformer.ReformerLM( - vocab_size=256, - d_model=256, - d_ff=512, - d_attention_key=64, - d_attention_value=64, - n_layers=2, - n_heads=2, - dropout=0.05, - max_len=65536, - attention_type=[timebin_self_attention, lsh_self_attention], - pos_axial_shape=(256, 256), - pos_d_axial_embs=(64, 192), - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=8192, - mode='train', - ) - x = np.ones((1, 65536)).astype(np.int32) - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits, new_state = model.pure_fn(x, weights, state, rng) - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - self.assertEqual(logits.shape, (1, 65536, 256)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/reformer/text_generation.ipynb b/trax/models/reformer/text_generation.ipynb deleted file mode 100644 index 5b67721b0..000000000 --- a/trax/models/reformer/text_generation.ipynb +++ /dev/null @@ -1,548 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Text Generation", - "provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": 
"Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX.\n", - "!pip install --upgrade jax\n", - "!pip install --upgrade jaxlib\n", - "!pip install --upgrade trax\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install --upgrade -q sentencepiece\n", - "!pip install --upgrade -q gin \n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import jax\n", - "import trax\n", - "from trax.data import inputs\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax\n", - "\n", - "from sentencepiece import SentencePieceProcessor" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Setting up data and model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9_OCIqghSyfs", - "colab_type": "text" - }, - "source": [ - "In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. 
The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of \"Crime and Punishment\" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tYSOVGR47LVL", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Import a copy of \"Crime and Punishment\", by Fyodor Dostoevsky\n", - "with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:\n", - " text = f.read()\n", - "\n", - "# The file read above includes metadata and licensing information.\n", - "# For training our language model, we will only use the actual novel text.\n", - "start = text.find('CRIME AND PUNISHMENT') # skip header\n", - "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip header\n", - "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip translator preface\n", - "end = text.rfind('End of Project') # skip extra text at the end\n", - "text = text[start:end].strip()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "mMntV3H-6OR0", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 102 - }, - "outputId": "c8d4386c-cf5d-4dc4-92d9-24391fa2f30e" - }, - "source": [ - "# Load a BPE vocabulaary with 320 types. This mostly consists of single letters\n", - "# and pairs of letters, but it has some common words and word pieces, too.\n", - "!gsutil cp gs://trax-ml/reformer/cp.320.* .\n", - "\n", - "TOKENIZER = SentencePieceProcessor()\n", - "TOKENIZER.load('cp.320.model')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Copying gs://trax-ml/reformer/cp.320.model...\n", - "Copying gs://trax-ml/reformer/cp.320.vocab...\n", - "/ [2 files][239.0 KiB/239.0 KiB] \n", - "Operation completed over 2 objects/239.0 KiB. 
\n" - ], - "name": "stdout" - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "True" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 4 - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HnJzxSi_77zP", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17 - }, - "outputId": "f8b2050b-0233-40e4-88f1-e546a1541b31" - }, - "source": [ - "# Tokenize\n", - "IDS = TOKENIZER.EncodeAsIds(text)\n", - "IDS = np.asarray(IDS, dtype=np.int32)\n", - "PAD_AMOUNT = 512 * 1024 - len(IDS)\n", - "print(\"Number of tokens:\", IDS.shape[0])" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Number of tokens: 513812\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bzQ7G9uGSga5", - "colab_type": "text" - }, - "source": [ - "As we see above, \"Crime and Punishment\" has just over half a million tokens with the BPE vocabulary we have selected.\n", - "\n", - "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.\n", - "\n", - "We have 8 TPU cores, so we will separately randomize the amount of padding for each core." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PdAwmpS220ub", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "c0919b3d-4c63-4d2f-db44-3aeccaf4d966" - }, - "source": [ - "# Set up the data pipeline.\n", - "def my_inputs(n_devices):\n", - " while True:\n", - " inputs = []\n", - " mask = []\n", - " pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)\n", - " for i in range(n_devices):\n", - " inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", - " mode='constant'))\n", - " mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),\n", - " (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", - " mode='constant'))\n", - " inputs = np.stack(inputs)\n", - " mask = np.stack(mask)\n", - " yield (inputs, inputs, mask)\n", - "\n", - "print(\"(device count, tokens per device) = \",\n", - " next(my_inputs(trax.fastmath.device_count()))[0].shape)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(device count, tokens per device) = (8, 524288)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ei90LdK024r_", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Configure hyperparameters.\n", - "gin.parse_config(\"\"\"\n", - "import trax.layers\n", - "import trax.models\n", - "import trax.optimizers\n", - "import trax.data.inputs\n", - "import trax.supervised.trainer_lib\n", - "\n", - "# Parameters that will vary between experiments:\n", - "# ==============================================================================\n", - "train.model = @trax.models.ReformerLM\n", - "# Our model will have 6 layers, alternating between the LSH attention proposed\n", - "# in the Reformer paper and local attention within a certain context window.\n", - "n_layers = 6\n", - "attn_type = [\n", - " @trax.layers.SelfAttention,\n", - " @LSHSelfAttention, \n", - " @trax.layers.SelfAttention,\n", - " 
@LSHSelfAttention,\n", - " @trax.layers.SelfAttention,\n", - " @LSHSelfAttention,\n", - " ]\n", - "share_qk = False # LSH attention ignores this flag and always shares q & k\n", - "n_heads = 2\n", - "attn_kv = 64\n", - "dropout = 0.05\n", - "n_tokens = 524288\n", - "\n", - "# Parameters for multifactor:\n", - "# ==============================================================================\n", - "multifactor.constant = 0.01\n", - "multifactor.factors = 'constant * linear_warmup * cosine_decay'\n", - "multifactor.warmup_steps = 100\n", - "multifactor.steps_per_cycle = 900\n", - "\n", - "# Parameters for Adam:\n", - "# ==============================================================================\n", - "Adam.weight_decay_rate=0.0\n", - "Adam.b1 = 0.86\n", - "Adam.b2 = 0.92\n", - "Adam.eps = 1e-9\n", - "\n", - "# Parameters for SelfAttention:\n", - "# ==============================================================================\n", - "trax.layers.SelfAttention.attention_dropout = 0.05\n", - "trax.layers.SelfAttention.chunk_len = 64\n", - "trax.layers.SelfAttention.n_chunks_before = 1\n", - "trax.layers.SelfAttention.n_parallel_heads = 1\n", - "\n", - "# Parameters for LSHSelfAttention:\n", - "# ==============================================================================\n", - "LSHSelfAttention.attention_dropout = 0.0\n", - "LSHSelfAttention.chunk_len = 64\n", - "LSHSelfAttention.n_buckets = [64, 128]\n", - "LSHSelfAttention.n_chunks_after = 0\n", - "LSHSelfAttention.n_chunks_before = 1\n", - "LSHSelfAttention.n_hashes = 1\n", - "LSHSelfAttention.n_parallel_heads = 1\n", - "LSHSelfAttention.predict_drop_len = 128\n", - "LSHSelfAttention.predict_mem_len = 1024\n", - "\n", - "# Parameters for ReformerLM:\n", - "# ==============================================================================\n", - "ReformerLM.attention_type = %attn_type\n", - "ReformerLM.d_attention_key = %attn_kv\n", - "ReformerLM.d_attention_value = %attn_kv\n", - "ReformerLM.d_model = 256\n", - 
"ReformerLM.d_ff = 512\n", - "ReformerLM.dropout = %dropout\n", - "ReformerLM.ff_activation = @trax.layers.Relu\n", - "ReformerLM.max_len = %n_tokens\n", - "ReformerLM.mode = 'train'\n", - "ReformerLM.n_heads = %n_heads\n", - "ReformerLM.n_layers = %n_layers\n", - "ReformerLM.vocab_size = 320\n", - "ReformerLM.axial_pos_shape = (512, 1024)\n", - "ReformerLM.d_axial_pos_embs= (64, 192)\n", - "\"\"\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "RGGt0WaT3a-h", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up a Trainer.\n", - "output_dir = os.path.expanduser('~/train_dir/')\n", - "!rm -f ~/train_dir/model.pkl.gz # Remove old model\n", - "\n", - "trainer = trax.supervised.Trainer(\n", - " model=trax.models.ReformerLM,\n", - " loss_fn=trax.layers.CrossEntropyLoss(),\n", - " optimizer=trax.optimizers.Adam,\n", - " lr_schedule=trax.lr.multifactor(),\n", - " inputs=trax.data.inputs.Inputs(my_inputs),\n", - " output_dir=output_dir)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "y6VQkmKO3a1L", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 255 - }, - "outputId": "3c933bab-b49d-4e18-caf6-3dfc3e220938" - }, - "source": [ - "# Run one training step, to make sure the model fits in memory.\n", - "# The first time trainer.train_epoch is called, it will JIT the entire network\n", - "# architecture, which takes around 2 minutes. 
The JIT-compiled model is saved\n", - "# so subsequent runs will be much faster than the first.\n", - "trainer.train_epoch(n_steps=1, n_eval_steps=1)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Ran 1 train steps in 155.17 secs\n", - "Step 1: Evaluation\n", - "Step 1: train accuracy | 0.00343633\n", - "Step 1: train loss | 6.36618853\n", - "Step 1: train neg_log_perplexity | -6.36618853\n", - "Step 1: train sequence_accuracy | 0.00000000\n", - "Step 1: train weights_per_batch_per_core | 513812.00000000\n", - "Step 1: eval accuracy | 0.00340154\n", - "Step 1: eval loss | 6.36649418\n", - "Step 1: eval neg_log_perplexity | -6.36649418\n", - "Step 1: eval sequence_accuracy | 0.00000000\n", - "Step 1: eval weights_per_batch_per_core | 513812.00000000\n", - "Step 1: Finished evaluation\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "EFnX4G6z3asD", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Train for 600 steps total\n", - "# The first ~20 steps are slow to run, but after that it reaches steady-state\n", - "# speed. This will take at least 30 minutes to run to completion, but can safely\n", - "# be interrupted by selecting \"Runtime > Interrupt Execution\" from the menu.\n", - "# The language model won't be exceptionally good when trained for just a few\n", - "# steps and with minimal regularization. 
However, we can still sample from it to\n", - "# see what it learns.\n", - "trainer.train_epoch(n_steps=9, n_eval_steps=1)\n", - "for _ in range(59):\n", - " trainer.train_epoch(n_steps=10, n_eval_steps=1)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Sample from the model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ffeLSbJk35pv", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# As we report in the Reformer paper, increasing the number of hashing rounds\n", - "# helps with quality. We can even increase the number of hashing rounds at\n", - "# evaluation time only.\n", - "\n", - "gin.parse_config(\"\"\"LSHSelfAttention.n_hashes = 4\"\"\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "-BwIjdl6_2tX", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Load the trained Reformer in 'predict' mode\n", - "model = trax.models.ReformerLM(mode='predict')\n", - "model.init_from_file(os.path.join(output_dir,'model.pkl.gz'),\n", - " weights_only=True)\n", - "\n", - "# Sample from ReformerLM\n", - "output_token_ids = trax.supervised.decoding.autoregressive_sample(\n", - " model, temperature=0.0)\n", - "\n", - "# Decode token IDs\n", - "# Reformer outputed a batch with one item, we access it using [0]\n", - "# tolist() converts from int64 to int, the type SentencePiece expects\n", - "TOKENIZER.DecodeIds(output_token_ids[0].tolist()) \n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "s5f5QAmZBgPj", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/trax/models/research/configurable_transformer.py b/trax/models/research/configurable_transformer.py index b0d1e3232..25d5b92ef 100644 --- 
a/trax/models/research/configurable_transformer.py +++ b/trax/models/research/configurable_transformer.py @@ -22,1005 +22,1163 @@ from trax import layers as tl -def _FeedForward(d_model, d_ff, dropout, activation, act_dropout, - use_bfloat16, mode): - """Feed-forward block with layer normalization at start.""" - if act_dropout is None: - act_dropout = dropout - return [ - tl.Dense(d_ff, use_bfloat16=use_bfloat16), - tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode), - activation(), - tl.Dense(d_model, use_bfloat16=use_bfloat16), - ] - - -def FeedForwardWithOptions(d_model, - d_ff, - dropout, - dropout_shared_axes, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - center_layernorm, - mode, - use_bfloat16=False, - ff_sparsity_type='1inN'): - """Feed-Forward block with all the options. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. 
- ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block - with this sparsity - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - use_bfloat16: whether to use bfloat16 for weights (default: False). - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - use SwitchSparseFF if ff_sparsity_type=`'Switch'` - - Returns: - A list of layers which maps vectors to vectors. - """ - if ff_sparsity and ff_sparsity_type == '1inN': - temperature, quant_prob = 0.1, 0.3 - if isinstance(ff_sparsity, str): - # This is hacky but used to pass ff_sparsity in yaml sweep files. - ff_sparsity = [(float(x) if '.' 
in x else int(x)) - for x in ff_sparsity.split()] - if isinstance(ff_sparsity, (list, tuple)): - if len(ff_sparsity) == 2: - n_elements_in_block, d_lowrank = ff_sparsity - else: - n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity - else: - assert isinstance(ff_sparsity, int) - n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity - ff = tl.SparseFF( - d_ff, - n_elements_in_block=n_elements_in_block, - d_lowrank=d_lowrank, - temperature=temperature, - quant_prob=quant_prob, - use_bfloat16=use_bfloat16, - mode=mode, - dropout_rate=dropout, - dropout_shared_axes=dropout_shared_axes, - ff_chunk_size=ff_chunk_size) - elif ff_sparsity and ff_sparsity_type == 'Block': - ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) - elif ff_sparsity and ff_sparsity_type == 'Switch': - ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) - else: - ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout, - use_bfloat16, mode) - res = [tl.LayerNorm(center=center_layernorm), ff] - if ff_sparsity_type != '1inN' or ff_sparsity == 0: - # SparseFF has Dropout and BatchLeadingAxes built-in. 
- res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode)) - if ff_chunk_size > 0: - res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) - if ff_use_sru: - if isinstance(ff_use_sru, (list, tuple)): - sru_n_layers, sru_n_units = ff_use_sru +def _FeedForward(d_model, d_ff, dropout, activation, act_dropout, use_bfloat16, mode): + """Feed-forward block with layer normalization at start.""" + if act_dropout is None: + act_dropout = dropout + return [ + tl.Dense(d_ff, use_bfloat16=use_bfloat16), + tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode), + activation(), + tl.Dense(d_model, use_bfloat16=use_bfloat16), + ] + + +def FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16=False, + ff_sparsity_type="1inN", +): + """Feed-Forward block with all the options. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. 
+ ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block + with this sparsity + center_layernorm: whether to use centering in LayerNorm (default) or if + to skip it, which is known as RMS normalization. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + use_bfloat16: whether to use bfloat16 for weights (default: False). + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + use SwitchSparseFF if ff_sparsity_type=`'Switch'` + + Returns: + A list of layers which maps vectors to vectors. + """ + if ff_sparsity and ff_sparsity_type == "1inN": + temperature, quant_prob = 0.1, 0.3 + if isinstance(ff_sparsity, str): + # This is hacky but used to pass ff_sparsity in yaml sweep files. + ff_sparsity = [ + (float(x) if "." 
in x else int(x)) for x in ff_sparsity.split() + ] + if isinstance(ff_sparsity, (list, tuple)): + if len(ff_sparsity) == 2: + n_elements_in_block, d_lowrank = ff_sparsity + else: + n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity + else: + assert isinstance(ff_sparsity, int) + n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity + ff = tl.SparseFF( + d_ff, + n_elements_in_block=n_elements_in_block, + d_lowrank=d_lowrank, + temperature=temperature, + quant_prob=quant_prob, + use_bfloat16=use_bfloat16, + mode=mode, + dropout_rate=dropout, + dropout_shared_axes=dropout_shared_axes, + ff_chunk_size=ff_chunk_size, + ) + elif ff_sparsity and ff_sparsity_type == "Block": + ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) + elif ff_sparsity and ff_sparsity_type == "Switch": + ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) else: - sru_n_layers, sru_n_units = ff_use_sru, 32 - sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)] - block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units) - ] + sru + [tl.Dense(d_model)] - res = tl.Residual(block, shortcut=res) - return [res] + ff = _FeedForward( + d_model, d_ff, dropout, ff_activation, ff_dropout, use_bfloat16, mode + ) + res = [tl.LayerNorm(center=center_layernorm), ff] + if ff_sparsity_type != "1inN" or ff_sparsity == 0: + # SparseFF has Dropout and BatchLeadingAxes built-in. 
+ res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)) + if ff_chunk_size > 0: + res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) + if ff_use_sru: + if isinstance(ff_use_sru, (list, tuple)): + sru_n_layers, sru_n_units = ff_use_sru + else: + sru_n_layers, sru_n_units = ff_use_sru, 32 + sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)] + block = ( + [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)] + + sru + + [tl.Dense(d_model)] + ) + res = tl.Residual(block, shortcut=res) + return [res] # TODO(lukaszkaiser): unify attention layers API and remove this branch -def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal, - masked, attention_dropout, output_dropout, - attention_chunk_size, mode): - """Runs the supplied attention layer.""" - try: - attention = attention_type( - n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - mode=mode) - except TypeError: # No d_qk arguments in less advanced layers. - attention = attention_type( - d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode) - return tl.Chunk(attention, attention_chunk_size) - - -@tl.assert_shape('...d->...d') -def PositionalEncoder(mode, - dropout=None, - max_len=None, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - use_bfloat16=False): - """Returns the positional encoding layer depending on the arguments. - - Args: - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout after the embedding block. - max_len: Maximum symbol length for positional encoding. - pos_type: string, the type of positional embeddings to use. 
- pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - - Returns: - A layer that will do the positional encoding. - """ - if not pos_type: - positional_encoding = tl.PositionalEncoding( - max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16, - start_from_zero_prob=pos_start_from_zero_prob, - max_offset_to_add=pos_max_offset_to_add, mode=mode) - elif pos_type == 'sin-cos': - positional_encoding = tl.SinCosPositionalEncoding(mode=mode) - elif pos_type == 'fixed-base': - positional_encoding = tl.FixedBasePositionalEncoding(mode=mode) - elif pos_type == 'infinite': - positional_encoding = tl.InfinitePositionalEncoding(affine=False) - elif pos_type == 'infinite-affine': - positional_encoding = tl.InfinitePositionalEncoding() - elif pos_type == 'time-bin': - positional_encoding = tl.TimeBinPositionalEncoding() - elif pos_type == 'no': - positional_encoding = tl.Serial() # no positional encoding at all - else: # TODO(lukaszkaiser): name this type and check for the correct name - assert pos_d_axial_embs is not None - positional_encoding = tl.AxialPositionalEncoding( - shape=pos_axial_shape, d_embs=pos_d_axial_embs, - dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)), - dropout=dropout, mode=mode) - - return positional_encoding - - -def 
EmbeddingAndPositionalEncodings(input_vocab_size, - d_model, - mode, - embedding_dropout, - dropout_shared_axes, - max_len, - output_vocab_size=None, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - use_bfloat16=False): - """Returns the embedder and positional encoder. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - embedding_dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout after the embedding block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - max_len: Maximum symbol length for positional encoding. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if None, then input and target integers (token IDs) are assumed to come - from the same vocabulary. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). 
- pos_max_offset_to_add: maximum offset to add to positions during training +def ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_qk, + d_v, + causal, + masked, + attention_dropout, + output_dropout, + attention_chunk_size, + mode, +): + """Runs the supplied attention layer.""" + try: + attention = attention_type( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + mode=mode, + ) + except TypeError: # No d_qk arguments in less advanced layers. + attention = attention_type( + d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode + ) + return tl.Chunk(attention, attention_chunk_size) + + +@tl.assert_shape("...d->...d") +def PositionalEncoder( + mode, + dropout=None, + max_len=None, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + use_bfloat16=False, +): + """Returns the positional encoding layer depending on the arguments. + + Args: + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout after the embedding block. + max_len: Maximum symbol length for positional encoding. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). 
+ pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + + Returns: + A layer that will do the positional encoding. + """ + if not pos_type: + positional_encoding = tl.PositionalEncoding( + max_len=max_len, + dropout=dropout, + use_bfloat16=use_bfloat16, + start_from_zero_prob=pos_start_from_zero_prob, + max_offset_to_add=pos_max_offset_to_add, + mode=mode, + ) + elif pos_type == "sin-cos": + positional_encoding = tl.SinCosPositionalEncoding(mode=mode) + elif pos_type == "fixed-base": + positional_encoding = tl.FixedBasePositionalEncoding(mode=mode) + elif pos_type == "infinite": + positional_encoding = tl.InfinitePositionalEncoding(affine=False) + elif pos_type == "infinite-affine": + positional_encoding = tl.InfinitePositionalEncoding() + elif pos_type == "time-bin": + positional_encoding = tl.TimeBinPositionalEncoding() + elif pos_type == "no": + positional_encoding = tl.Serial() # no positional encoding at all + else: # TODO(lukaszkaiser): name this type and check for the correct name + assert pos_d_axial_embs is not None + positional_encoding = tl.AxialPositionalEncoding( + shape=pos_axial_shape, + d_embs=pos_d_axial_embs, + dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)), + dropout=dropout, + mode=mode, + ) + + return positional_encoding + + +def EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + embedding_dropout, + dropout_shared_axes, + max_len, + output_vocab_size=None, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + use_bfloat16=False, +): + """Returns the embedder and positional encoder. 
+ + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + embedding_dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout after the embedding block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + max_len: Maximum symbol length for positional encoding. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if None, then input and target integers (token IDs) are assumed to come + from the same vocabulary. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. 
+ + Returns: + A tuple of (input encoder, output encoder, output vocab size used). + """ + + # tokens --> vectors + def Embedder(vocab_size, embedding_mode): + if vocab_size is not None: + embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16) + else: + embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16) + return [ + embedding, + tl.Dropout( + rate=embedding_dropout, + shared_axes=dropout_shared_axes, + mode=embedding_mode, + ), + ] + + # NOTE: Positional encodings are not shared between encoder and decoder. + + # Since encoder doesn't run stepwise, we do not use predict mode there. + encoder_mode = "eval" if mode == "predict" else mode + in_embedder = Embedder(input_vocab_size, encoder_mode) + in_encoder = in_embedder + [ + PositionalEncoder( + encoder_mode, + dropout=embedding_dropout, + max_len=max_len, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, + use_bfloat16=use_bfloat16, + ) + ] + + # If output_vocab_size is None, we reuse the same embedding matrix, otherwise + # we initialize one. + assert input_vocab_size or output_vocab_size + if output_vocab_size is None: + out_embedder = in_embedder + else: + out_embedder = Embedder(output_vocab_size, mode) + + out_encoder = out_embedder + [ + PositionalEncoder( + mode, + dropout=embedding_dropout, + max_len=max_len, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, + use_bfloat16=use_bfloat16, + ) + ] + + # Set this to the value actually used. 
+ if output_vocab_size is None: + output_vocab_size = input_vocab_size + + if input_vocab_size is None: + in_encoder = tl.AssertFunction("...a->...b", in_encoder) + else: + in_encoder = tl.AssertFunction("...->...d", in_encoder) + out_encoder = tl.AssertFunction("...->...d", out_encoder) + + return in_encoder, out_encoder, output_vocab_size + + +def ConfigurableTransformerEncoder( + vocab_size, + n_classes=10, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + attention_chunk_size=0, + attention_type=tl.Attention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, +): + """Returns a Transformer encoder merged with an N-way categorization head. + + This model performs text categorization: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 2 tensor representing a batch of log-probability + distributions over N categories; shape is (batch_size, `n_classes`). + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor should + be an integer in `range(vocab_size)`. These integers typically represent + token IDs from a vocabulary-based tokenizer. + n_classes: Final dimension of the output tensors, representing N-way + classification. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + block. + n_layers: Number of encoder blocks. Each block includes attention, dropout, + residual, feed-forward (`Dense`), and activation layers. + n_heads: Number of attention heads. 
+ max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each encoder block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder block; + must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use for the encoder part. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + + Returns: + A Transformer model that maps strings (conveyed via token IDs) to + probability-like activations over a range of output classes. 
+ """ + positional_encoder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + PositionalEncoder( + mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs + ), + ] + + positional_encoder = tl.AssertFunction("...->...d", positional_encoder) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + ) + for i in range(n_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( # toks + # Encode. + tl.Branch(positional_encoder, tl.PaddingMask()), # vecs masks + encoder_blocks, # vecs masks + tl.Select([0], n_in=2), # vecs + tl.LayerNorm(), # vecs + # Map to output categories. + tl.Mean(axis=1), # vecs + tl.Dense(n_classes), # vecs + ) + + +def ConfigurableTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + attention_type=tl.CausalAttention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, +): + """Returns a Transformer language model. + + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. 
+ + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor should + be an integer in `range(vocab_size)`. These integers typically represent + token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + block. + n_layers: Number of encoder blocks. Each block includes attention, dropout, + residual, feed-forward (`Dense`), and activation layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'predict'`, use fast inference. If `'train'`, each encoder block + will include dropout; else, it will pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder block; + must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. 
+ ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + loss_sparsity_type: string, type of sparsity to used in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used. + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use for the decoder part. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - - Returns: - A tuple of (input encoder, output encoder, output vocab size used). 
- """ - # tokens --> vectors - def Embedder(vocab_size, embedding_mode): - if vocab_size is not None: - embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16) - else: - embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16) - return [ - embedding, - tl.Dropout(rate=embedding_dropout, - shared_axes=dropout_shared_axes, - mode=embedding_mode), + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + positional_encoder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + PositionalEncoder( + mode, + dropout, + max_len, + pos_type, + pos_axial_shape, + pos_d_axial_embs, + pos_start_from_zero_prob, + pos_max_offset_to_add, + ), + ] + + # pylint: disable=g-complex-comprehension + decoder_blocks = [ + DecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + ) + for i in range(n_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. 
+ return tl.Serial( # tokens (or chunked tuple of tokens) + tl.ShiftRight(mode=mode), # toks + positional_encoder, # vecs + decoder_blocks, # vecs + tl.LayerNorm(), # vecs + tl.SparseDenseWithOptions( # vecs + vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, + ), + ) + + +def ConfigurableTransformer( + input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + enc_dec_attention_sparsity=0, +): + """Returns a full Transformer model. + + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(input_vocab_size)`, and `0` + values mark padding positions. + + - target: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(output_vocab_size)`, and `0` + values mark padding positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). 
+ + An example use would be to translate (tokenized) sentences from English to + German. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if None, then input and target integers (token IDs) are assumed to come + from the same vocabulary. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + and decoder block. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder/decoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + ff_activation: Type of activation function at the end of each + encoder/decoder block; must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. 
+ ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + loss_sparsity_type: str, type of sparsity to used in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used. + attention_chunk_size: int, if > 0 run attention chunked at this size + encoder_attention_type: The attention layer to use for the encoder part. + encoder_decoder_attention_type: The attention layer to use for the + encoder-decoder attention. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention. + + Returns: + A Transformer model as a layer that maps from a source-target tokenized + text pair to activations over a vocab set. 
+ """ + in_encoder, out_encoder, output_vocab_size = EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + dropout, + dropout_shared_axes, + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + encoder_attention_type, + ) + for i in range(n_encoder_layers) + ] + # pylint: enable=g-complex-comprehension + + encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm()) + if mode == "predict": + encoder = tl.Cache(encoder) + + # pylint: disable=g-complex-comprehension + encoder_decoder_blocks = [ + EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + encoder_decoder_attention_type, + enc_dec_attention_sparsity, + ) + for i in range(n_decoder_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( + # Input: encoder_side_tokens, decoder_side_tokens + # Copy decoder tokens for use in loss. + tl.Select([0, 1, 1]), # tok_e tok_d tok_d + # Encode. + tl.Branch([], tl.PaddingMask()), # tok_e masks ..... ..... + encoder, # vec_e ..... ..... ..... + # Decode. + tl.Select([2, 1, 0]), # tok_d masks vec_e ..... + tl.ShiftRight(mode=mode), # tok_d ..... ..... ..... + out_encoder, # vec_d ..... ..... ..... + tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... + encoder_decoder_blocks, # vec_d masks ..... ..... + tl.LayerNorm(), # vec_d ..... ..... ..... + # Map to output vocab. + tl.Select([0], n_in=3), # vec_d tok_d + tl.SparseDenseWithOptions( # vec_d ..... 
+ output_vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, + ), + ) + + +def EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + n_attention_layers=1, + n_feedforward_layers=1, +): + """Returns a list of layers that implements a Transformer encoder block. + + The input to the block is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. 
+ ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use. + n_attention_layers: how many residual causal attention layers should we + have before the feed-forward block (default: 1, the standard block) + n_feedforward_layers: how many FFNN layers should we have (default 1). + + Returns: + A list of layers that maps (activations, mask) to (activations, mask). + """ + # `n_attention_layers` number of residuals of attention layer + dropout. + # pylint: disable=g-complex-comprehension + residual_attentions = [ + tl.Residual( + tl.LayerNorm(), + ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=False, + masked=True, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + ) + for _ in range(n_attention_layers) + ] + + feed_forwards = [ + tl.Residual( + FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + False, + ff_sparsity_type, + ) + ) + for _ in range(n_feedforward_layers) + ] + # pylint: enable=g-complex-comprehension + + return residual_attentions + feed_forwards + + +def DecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + 
attention_chunk_size, + attention_type, + n_attention_layers=1, + n_feedforward_layers=1, +): + """Returns a list of layers that implements a Transformer decoder block. + + The input is an activation tensor. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use. + n_attention_layers: how many residual causal attention layers should we + have before the feed-forward block (default: 1, the standard block) + n_feedforward_layers: how many FFNN layers should we have (default 1). 
+ + Returns: + A list of layers that maps an activation tensor to an activation tensor. + """ + # pylint: disable=g-complex-comprehension + causal_attentions = [ + ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + for _ in range(n_attention_layers) ] - # NOTE: Positional encodings are not shared between encoder and decoder. - - # Since encoder doesn't run stepwise, we do not use predict mode there. - encoder_mode = 'eval' if mode == 'predict' else mode - in_embedder = Embedder(input_vocab_size, encoder_mode) - in_encoder = in_embedder + [ - PositionalEncoder(encoder_mode, - dropout=embedding_dropout, - max_len=max_len, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ] - - # If output_vocab_size is None, we reuse the same embedding matrix, otherwise - # we initialize one. - assert input_vocab_size or output_vocab_size - if output_vocab_size is None: - out_embedder = in_embedder - else: - out_embedder = Embedder(output_vocab_size, mode) - - out_encoder = out_embedder + [ - PositionalEncoder(mode, - dropout=embedding_dropout, - max_len=max_len, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ] - - # Set this to the value actually used. 
- if output_vocab_size is None: - output_vocab_size = input_vocab_size - - if input_vocab_size is None: - in_encoder = tl.AssertFunction('...a->...b', in_encoder) - else: - in_encoder = tl.AssertFunction('...->...d', in_encoder) - out_encoder = tl.AssertFunction('...->...d', out_encoder) - - return in_encoder, out_encoder, output_vocab_size - - -def ConfigurableTransformerEncoder(vocab_size, - n_classes=10, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - attention_chunk_size=0, - attention_type=tl.Attention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None): - """Returns a Transformer encoder merged with an N-way categorization head. - - This model performs text categorization: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 2 tensor representing a batch of log-probability - distributions over N categories; shape is (batch_size, `n_classes`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor should - be an integer in `range(vocab_size)`. These integers typically represent - token IDs from a vocabulary-based tokenizer. - n_classes: Final dimension of the output tensors, representing N-way - classification. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, feed-forward (`Dense`), and activation layers. - n_heads: Number of attention heads. 
- max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each encoder block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder block; - must be an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use for the encoder part. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - - Returns: - A Transformer model that maps strings (conveyed via token IDs) to - probability-like activations over a range of output classes. 
- """ - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs) - ] - - positional_encoder = tl.AssertFunction('...->...d', positional_encoder) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type) - for i in range(n_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( # toks - # Encode. - tl.Branch( - positional_encoder, tl.PaddingMask()), # vecs masks - encoder_blocks, # vecs masks - tl.Select([0], n_in=2), # vecs - tl.LayerNorm(), # vecs - - # Map to output categories. - tl.Mean(axis=1), # vecs - tl.Dense(n_classes), # vecs - ) - - -def ConfigurableTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - attention_type=tl.CausalAttention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. 
- - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor should - be an integer in `range(vocab_size)`. These integers typically represent - token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, feed-forward (`Dense`), and activation layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'predict'`, use fast inference. If `'train'`, each encoder block - will include dropout; else, it will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder block; - must be an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. 
- ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - loss_sparsity_type: string, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use for the decoder part. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. 
- """ - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs, - pos_start_from_zero_prob, pos_max_offset_to_add) - ] - - # pylint: disable=g-complex-comprehension - decoder_blocks = [ - DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type) - for i in range(n_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( # tokens (or chunked tuple of tokens) - tl.ShiftRight(mode=mode), # toks - positional_encoder, # vecs - decoder_blocks, # vecs - tl.LayerNorm(), # vecs - tl.SparseDenseWithOptions( # vecs - vocab_size, d_input=d_model, sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, mode=mode), - ) - - -def ConfigurableTransformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - enc_dec_attention_sparsity=0): - """Returns a full Transformer model. 
- - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(input_vocab_size)`, and `0` - values mark padding positions. - - - target: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(output_vocab_size)`, and `0` - values mark padding positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - An example use would be to translate (tokenized) sentences from English to - German. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if None, then input and target integers (token IDs) are assumed to come - from the same vocabulary. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - and decoder block. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder/decoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. 
Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - ff_activation: Type of activation function at the end of each - encoder/decoder block; must be an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - loss_sparsity_type: str, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. - attention_chunk_size: int, if > 0 run attention chunked at this size - encoder_attention_type: The attention layer to use for the encoder part. - encoder_decoder_attention_type: The attention layer to use for the - encoder-decoder attention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. 
- Tuple length must match pos_axial_shape, and values must sum to d_model. - enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention. - - Returns: - A Transformer model as a layer that maps from a source-target tokenized - text pair to activations over a vocab set. - """ - in_encoder, out_encoder, output_vocab_size = ( - EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - mode, - dropout, - dropout_shared_axes, - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - ) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, encoder_attention_type) - for i in range(n_encoder_layers) - ] - # pylint: enable=g-complex-comprehension - - encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm()) - if mode == 'predict': - encoder = tl.Cache(encoder) - - # pylint: disable=g-complex-comprehension - encoder_decoder_blocks = [ - EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation, ff_dropout, ff_chunk_size, - ff_use_sru, ff_sparsity, ff_sparsity_type, - attention_chunk_size, encoder_decoder_attention_type, - enc_dec_attention_sparsity) - for i in range(n_decoder_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( - # Input: encoder_side_tokens, decoder_side_tokens - # Copy decoder tokens for use in loss. - tl.Select([0, 1, 1]), # tok_e tok_d tok_d - - # Encode. - tl.Branch([], tl.PaddingMask()), # tok_e masks ..... ..... - encoder, # vec_e ..... ..... ..... - - # Decode. - tl.Select([2, 1, 0]), # tok_d masks vec_e ..... - tl.ShiftRight(mode=mode), # tok_d ..... ..... ..... - out_encoder, # vec_d ..... ..... ..... 
- tl.Branch( - [], tl.EncoderDecoderMask()), # vec_d masks ..... ..... - encoder_decoder_blocks, # vec_d masks ..... ..... - tl.LayerNorm(), # vec_d ..... ..... ..... - - # Map to output vocab. - tl.Select([0], n_in=3), # vec_d tok_d - tl.SparseDenseWithOptions( # vec_d ..... - output_vocab_size, d_input=d_model, sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, mode=mode), - ) - - -def EncoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - ff_sparsity_type, - attention_chunk_size, - attention_type, - n_attention_layers=1, - n_feedforward_layers=1): - """Returns a list of layers that implements a Transformer encoder block. - - The input to the block is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. 
- ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - - Returns: - A list of layers that maps (activations, mask) to (activations, mask). - """ - # `n_attention_layers` number of residuals of attention layer + dropout. - # pylint: disable=g-complex-comprehension - residual_attentions = [ - tl.Residual(tl.LayerNorm(), - ApplyAttentionLayer(attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=False, - masked=True, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode), - tl.Dropout(rate=dropout, - shared_axes=dropout_shared_axes, - mode=mode) - ) - for _ in range(n_attention_layers) - ] - - feed_forwards = [ - tl.Residual( - FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - ) - for _ in range(n_feedforward_layers) - ] - # pylint: enable=g-complex-comprehension - - return residual_attentions + feed_forwards - - -def DecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - ff_sparsity_type, - attention_chunk_size, - attention_type, - 
n_attention_layers=1, - n_feedforward_layers=1): - """Returns a list of layers that implements a Transformer decoder block. - - The input is an activation tensor. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - - Returns: - A list of layers that maps an activation tensor to an activation tensor. 
- """ - # pylint: disable=g-complex-comprehension - causal_attentions = [ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=False, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode) for _ in range(n_attention_layers)] - - residual_attentions = [ - tl.Residual( - tl.LayerNorm(), - causal_attentions[i], - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - ) for i in range(n_attention_layers)] - - feed_forwards = [ - tl.Residual( - FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - ) - for _ in range(n_feedforward_layers) - ] - # pylint: enable=g-complex-comprehension - - return residual_attentions + feed_forwards - - -def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation, ff_dropout, ff_chunk_size, - ff_use_sru, ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type, - enc_dec_attention_sparsity=0): - """Returns a list of layers implementing a Transformer encoder-decoder block. - - The input is a triple (decoder_activations, mask, encoder_activiations) where - the mask is created from the original input token IDs to prevent attending to - the padding part of the encoder. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. 
Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - enc_dec_attention_sparsity: Sparsity to use in encoder-decoder attention. - - Returns: - A list of layers which maps triples (decoder_activations, mask, - encoder_activations) to triples of the same sort. - """ - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - # TODO(afrozm): This layer isn't configurable because: We currently don't have - # any alternative for it (LSH cannot do it fundamentally, that's why we have - # NoEncDec models, and local attention doesn't make sense in the general - # setting where we don't know what in input is local to what in output; - # some variants of FAVOR can do it, so maybe in the future, - # but we don't have them yet). 
- if isinstance(enc_dec_attention_sparsity, tuple): - q_sparsity, result_sparsity = enc_dec_attention_sparsity - elif enc_dec_attention_sparsity > 0: - q_sparsity = enc_dec_attention_sparsity - result_sparsity = 'noop' # We simply skip Dense layer after attention. - else: - q_sparsity = None - result_sparsity = None - attention_qkv = tl.AttentionQKV( - d_model, n_heads=n_heads, dropout=dropout, mode=mode, - cache_KV_in_predict=True, - q_sparsity=q_sparsity, result_sparsity=result_sparsity) - - causal_attention = ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=True, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode) - - feed_forward = FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - - return [ # vec_d masks vec_e - tl.Residual( - tl.LayerNorm(), # vec_d ..... ..... - causal_attention, # vec_d ..... ..... - _Dropout(), # vec_d ..... ..... - ), - tl.Residual( - tl.LayerNorm(), # vec_d ..... ..... 
- tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e - attention_qkv, # vec_d masks vec_e - _Dropout(), # vec_d masks vec_e - ), - tl.Residual( - feed_forward # vec_d masks vec_e - ), - ] + residual_attentions = [ + tl.Residual( + tl.LayerNorm(), + causal_attentions[i], + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + ) + for i in range(n_attention_layers) + ] + + feed_forwards = [ + tl.Residual( + FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + False, + ff_sparsity_type, + ) + ) + for _ in range(n_feedforward_layers) + ] + # pylint: enable=g-complex-comprehension + + return residual_attentions + feed_forwards + + +def EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + enc_dec_attention_sparsity=0, +): + """Returns a list of layers implementing a Transformer encoder-decoder block. + + The input is a triple (decoder_activations, mask, encoder_activiations) where + the mask is created from the original input token IDs to prevent attending to + the padding part of the encoder. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. 
+ mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use. + enc_dec_attention_sparsity: Sparsity to use in encoder-decoder attention. + + Returns: + A list of layers which maps triples (decoder_activations, mask, + encoder_activations) to triples of the same sort. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + # TODO(afrozm): This layer isn't configurable because: We currently don't have + # any alternative for it (LSH cannot do it fundamentally, that's why we have + # NoEncDec models, and local attention doesn't make sense in the general + # setting where we don't know what in input is local to what in output; + # some variants of FAVOR can do it, so maybe in the future, + # but we don't have them yet). + if isinstance(enc_dec_attention_sparsity, tuple): + q_sparsity, result_sparsity = enc_dec_attention_sparsity + elif enc_dec_attention_sparsity > 0: + q_sparsity = enc_dec_attention_sparsity + result_sparsity = "noop" # We simply skip Dense layer after attention. 
+ else: + q_sparsity = None + result_sparsity = None + attention_qkv = tl.AttentionQKV( + d_model, + n_heads=n_heads, + dropout=dropout, + mode=mode, + cache_KV_in_predict=True, + q_sparsity=q_sparsity, + result_sparsity=result_sparsity, + ) + + causal_attention = ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=True, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + + feed_forward = FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + False, + ff_sparsity_type, + ) + + return [ # vec_d masks vec_e + tl.Residual( + tl.LayerNorm(), # vec_d ..... ..... + causal_attention, # vec_d ..... ..... + _Dropout(), # vec_d ..... ..... + ), + tl.Residual( + tl.LayerNorm(), # vec_d ..... ..... + tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e + attention_qkv, # vec_d masks vec_e + _Dropout(), # vec_d masks vec_e + ), + tl.Residual(feed_forward), # vec_d masks vec_e + ] diff --git a/trax/models/research/configurable_transformer_test.py b/trax/models/research/configurable_transformer_test.py deleted file mode 100644 index 0c10f078f..000000000 --- a/trax/models/research/configurable_transformer_test.py +++ /dev/null @@ -1,188 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.layers import test_utils -from trax.models.research import configurable_transformer as ct - - -class ConfigurableTransformerTest(parameterized.TestCase): - - def test_transformer_lm_forward_shape(self): - vocab_size = 16 - model = ct.ConfigurableTransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) - x = np.ones((3, 5)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, vocab_size)) - - def _test_transformer_forward_shape(self, input_vocab_size, - output_vocab_size): - model = ct.ConfigurableTransformer( - input_vocab_size, - output_vocab_size, - d_model=32, - d_ff=64, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2) - xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - y, _ = model(xs) - - vocab_size = output_vocab_size or input_vocab_size - self.assertEqual(y.shape, (3, 5, vocab_size)) - - @parameterized.named_parameters(('same_vocab', 16, None), - ('same_size', 16, 16), - ('different_size', 16, 50)) - def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): - """Run the Transformer forward and check output shape.""" - self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) - - - def test_dot_product_causal_attention_fast_inference(self): - self._test_fast_inference(length=5) - - def _test_fast_inference(self, length): - with fastmath.use_backend(fastmath.Backend.JAX): - model_fn = functools.partial( - ct.ConfigurableTransformerLM, - vocab_size=16, - d_model=4, - d_ff=8, - n_layers=2, - n_heads=2, - ) - batch_size = 2 - inp = 
np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict(inp, model_fn) - - def test_sparse_configurable_transformer_fast_inference(self): - self._test_sparse_fast_inference(length=5) - - def _test_sparse_fast_inference(self, length): - with fastmath.use_backend(fastmath.Backend.JAX): - vocab_size = 16 - d_model = 4 - batch_size = 2 - - encoder_decoder_attention_type = functools.partial( - tl.MultiplicativeConvCausalAttention, - sparsity=2, - length_kernel_size=1, - ) - - model_fn = functools.partial( - ct.ConfigurableTransformer, - input_vocab_size=vocab_size, - d_model=d_model, - d_ff=8, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2, - loss_sparsity=2, - ff_sparsity=2, - encoder_decoder_attention_type=encoder_decoder_attention_type, - ff_use_sru=(1, 4), - ) - - inp = np.random.randint(vocab_size, size=(batch_size, length)) - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict((inp, out), model_fn, seq_tensor=1) - - @parameterized.named_parameters( - ('positional_encoding', None), - ('fixed_base_positional_encoding', 'fixed-base'), - ('infinite_positional_encoding', 'infinite'), - ('infinite_affine_positional_encoding', 'infinite-affine'), - ('axial_positional_encoding', (2, 16))) - def test_positional_encoder(self, pos_axial_shape): - # dim should divide FixedBasePositionalEncoding.n_digits - batch, length, dim = 2, 32, 8 - input_shape = (batch, length, dim) - vocab_size = 32 - x = np.random.randint(0, vocab_size - 1, input_shape) - # should sum to dim - pos_d_axial_embs = (4, 4) - - positional_encoding = ct.PositionalEncoder( - 'train', dropout=0.1, max_len=length, pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - _, _ = positional_encoding.init(shapes.signature(x)) - y = positional_encoding(x) - self.assertEqual(y.shape, input_shape) - - @parameterized.named_parameters( - ('input_vocab_size_only', 32, None), - ('output_vocab_size_only', None, 32), - 
('same_input_output_vocab_size', 32, 32), - ('different_input_output_vocab_size', 32, 16), - ) - def test_embedding_and_positional_encodings(self, input_vocab_size, - output_vocab_size): - d_model = 16 - max_len = 32 - batch = 2 - input_shape = (batch, max_len) - output_vocab_size_expected = output_vocab_size or input_vocab_size - x_out = np.random.randint(0, output_vocab_size_expected - 1, input_shape) - if input_vocab_size is None: - x_in = np.random.uniform(size=list(input_shape) + [2]) - else: - x_in = np.random.randint(0, input_vocab_size - 1, input_shape) - - in_encoder, out_encoder, output_vocab_size_result = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - 'train', - 0.1, - [-2], - max_len, - output_vocab_size=output_vocab_size, - pos_axial_shape=None, - pos_d_axial_embs=None)) - - self.assertEqual(output_vocab_size_result, output_vocab_size_expected) - - model_in = tl.Serial(in_encoder) - model_out = tl.Serial(out_encoder) - - model_in.init(shapes.signature(x_in)) - model_out.init(shapes.signature(x_out)) - - y = model_in(x_in) - self.assertEqual(y.shape, input_shape + (d_model,)) - - y = model_out(x_out) - self.assertEqual(y.shape, input_shape + (d_model,)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/hourglass_test.py b/trax/models/research/hourglass_test.py deleted file mode 100644 index 9329c109e..000000000 --- a/trax/models/research/hourglass_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Hourglass model.""" - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import jax -import numpy as np -from trax import fastmath -from trax import layers as tl -from trax import shapes -import trax.layers.research.resampling as resampling -import trax.models.research.hourglass as hourglass - - -class HourglassTest(parameterized.TestCase): - - def _check_forward_shape(self, model, input_shape, output_vocab_size): - x = np.ones(input_shape).astype(np.int32) - model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (*input_shape, output_vocab_size)) - - def test_hourglass_lm_forward_shape(self): - d_model = 16 - vocab_size = 7 - model = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3 2@6 2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - batch_size, seq_len = 3, 24 - self._check_forward_shape(model, - input_shape=(batch_size, seq_len), - output_vocab_size=vocab_size) - - def test_lsh_attention_in_vanilla(self): - d_model = 16 - vocab_size = 7 - - gin.bind_parameter('PureLSHSelfAttentionWrapper.pure_lsh_implementation', - tl.PureLSHSelfAttention) - gin.bind_parameter('PureLSHSelfAttention.chunk_len', 2) - - model = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - vanilla_attn_type=tl.PureLSHSelfAttentionWrapper, - downsampling_fn=resampling.LinearPooling, - upsampling_fn=resampling.LinearUpsampling, - ) - - batch_size, seq_len = 3, 12 - self._check_forward_shape( - model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size) - - def _test_autoregressive_property(self, model, input_shape, - output_vocab_size): - rng_1 = jax.random.PRNGKey(0) - rng_2 = jax.random.PRNGKey(1) - - def _get_output_logits(unitialized_eval_model: tl.Layer, x): - input_signature = 
shapes.signature(x) - unitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False) - - output_logits, *_ = unitialized_eval_model(x, rng=rng_1) - return output_logits - - def check_autoregressive_property(model): - with fastmath.use_backend(fastmath.Backend.JAX): - x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size) - y_1 = _get_output_logits(model, x_1) - - x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size) - - for i in range(input_shape[1]): - masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1) - - y_2 = _get_output_logits(model, masked_x_2) - self.assertEqual(y_2.shape[0], input_shape[1]) - np.testing.assert_array_almost_equal(y_1[:i + 1], y_2[:i + 1]) - - check_autoregressive_property(model) - - def test_hourglass_lm_autoregressive_property(self): - d_model = 8 - vocab_size = 26 - - model_single_stage = hourglass.HourglassLM( - vocab_size, - hierarchy='2@4', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - model_multi_stage = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3 2@6 2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - input_shape = (1, 12) - self._test_autoregressive_property(model_single_stage, input_shape, - output_vocab_size=vocab_size) - self._test_autoregressive_property(model_multi_stage, input_shape, - output_vocab_size=vocab_size) - - def test_parse_hourglass_hierarchy(self): - self.assertEqual(hourglass._parse_hierarchy('6@3'), ([6], [3])) - self.assertEqual(hourglass._parse_hierarchy('3@2 2@6 5@24 2@6 3@2'), ( - [3, 2, 5], [2, 3, 4] - )) - self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3 1@2') - self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3') - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/layerdrop_transformer.py b/trax/models/research/layerdrop_transformer.py index 0709fad38..57364f7a6 100644 --- 
a/trax/models/research/layerdrop_transformer.py +++ b/trax/models/research/layerdrop_transformer.py @@ -24,278 +24,303 @@ def LargerThan(val): - """Checks if the input is larger than a certain value.""" - return tl.Fn('LargerThan', lambda x: x > val) - - -@assert_shape('...s->...sv') -def SkippingTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_fraction=0.4): - """Returns a Skipping Transformer language model. - - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) - - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_fraction: fraction of times to skip some layers - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. 
- """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - @assert_shape('...sd,->...sd,') - def ConditionedBlock(current_layer_num): - return tl.Serial( - # stack: embedding, n_layers_to_keep - tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep - tl.Cond( - # if n_layers_to_keep > current_layer_num - LargerThan(float(current_layer_num)), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)), - # else: run noop - tl.Serial() + """Checks if the input is larger than a certain value.""" + return tl.Fn("LargerThan", lambda x: x > val) + + +@assert_shape("...s->...sv") +def SkippingTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_fraction=0.4, +): + """Returns a Skipping Transformer language model. + + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_fraction: fraction of times to skip some layers + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. 
+ """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + @assert_shape("...sd,->...sd,") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding, n_layers_to_keep + tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep + tl.Cond( + # if n_layers_to_keep > current_layer_num + LargerThan(float(current_layer_num)), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ), + ), + # else: run noop + tl.Serial(), ) - # stack: embedding, n_layers_to_keep + # stack: embedding, n_layers_to_keep ) - if mode == 'train': - if skip_fraction == 0.0: - minimum_layers = float(n_layers) - maximum_layers = float(n_layers) + if mode == "train": + if skip_fraction == 0.0: + minimum_layers = float(n_layers) + maximum_layers = float(n_layers) + else: + minimum_layers = 0.0 + maximum_layers = float(n_layers) / skip_fraction else: - minimum_layers = 0.0 - maximum_layers = float(n_layers) / skip_fraction - else: - minimum_layers = maximum_layers = float(n_layers) - - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - # stack: embedding - tl.RandomUniform(minimum_layers, maximum_layers, sync=True), - # stack: n_layers_to_keep, embedding - tl.Swap(), - # stack: embedding, n_layers_to_keep - [ConditionedBlock(i) for i in range(n_layers)], - # stack: embedding, n_layers_to_keep - tl.AssertShape('...sd,'), - tl.Select([0], n_in=2), # stack: embedding - tl.AssertShape('...sd'), - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -@assert_shape('...s->...sv') -def EveryOtherLayerDropTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_mode='even', - skip_fraction=0.5, - eval_skip_fraction=0.0): - """Returns an 
"EveryOther" LayerDrop Transformer language model. - - During each training step it either runs all layers, or skips a subset of - layers. This subset is the same every time, and it is specified by - "skip_mode". - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) - - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_mode: which layers to skip when skipping: even/odd/1half/2half. - skip_fraction: fraction of times to skip layers - eval_skip_fraction: fraction of times to skip layers during eval - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - if mode == 'train': - pass - else: - skip_fraction = eval_skip_fraction - - skip_mode_funs = { # which layers should be skipped? 
- 'even': (lambda num: num%2 == 0), # 0th layer is even - 'odd': (lambda num: num%2 == 1), - '1half': (lambda num: num < (n_layers/2)), - '2half': (lambda num: num >= (n_layers/2)), - } - - skip_mode_fun = skip_mode_funs[skip_mode] - - @assert_shape('...sd,->...sd,') - def ConditionedBlock(current_layer_num): + minimum_layers = maximum_layers = float(n_layers) + return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, + # stack: embedding + tl.RandomUniform(minimum_layers, maximum_layers, sync=True), + # stack: n_layers_to_keep, embedding + tl.Swap(), # stack: embedding, n_layers_to_keep - tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep - tl.Cond( - # if random() > skip_fraction OR layer not in skip_mode ... - LargerThan(skip_fraction if skip_mode_fun(current_layer_num) - else 0.0), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)) - # else: noop (implicit) - ) + [ConditionedBlock(i) for i in range(n_layers)], # stack: embedding, n_layers_to_keep + tl.AssertShape("...sd,"), + tl.Select([0], n_in=2), # stack: embedding + tl.AssertShape("...sd"), + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +@assert_shape("...s->...sv") +def EveryOtherLayerDropTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_mode="even", + skip_fraction=0.5, + eval_skip_fraction=0.0, +): + """Returns an "EveryOther" LayerDrop Transformer language model. + + During each training step it either runs all layers, or skips a subset of + layers. This subset is the same every time, and it is specified by + "skip_mode". + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) 
+ + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_mode: which layers to skip when skipping: even/odd/1half/2half. + skip_fraction: fraction of times to skip layers + eval_skip_fraction: fraction of times to skip layers during eval + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + if mode == "train": + pass + else: + skip_fraction = eval_skip_fraction + + skip_mode_funs = { # which layers should be skipped? + "even": (lambda num: num % 2 == 0), # 0th layer is even + "odd": (lambda num: num % 2 == 1), + "1half": (lambda num: num < (n_layers / 2)), + "2half": (lambda num: num >= (n_layers / 2)), + } + + skip_mode_fun = skip_mode_funs[skip_mode] + + @assert_shape("...sd,->...sd,") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding, n_layers_to_keep + tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep + tl.Cond( + # if random() > skip_fraction OR layer not in skip_mode ... 
+ LargerThan(skip_fraction if skip_mode_fun(current_layer_num) else 0.0), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ) + ) + # else: noop (implicit) + ) + # stack: embedding, n_layers_to_keep ) - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - # stack: embedding - tl.RandomUniform(0., 1., sync=True), - # stack: n_layers_to_keep, embedding - tl.Swap(), - # stack: embedding, n_layers_to_keep - [ConditionedBlock(i) for i in range(n_layers)], - # stack: embedding, n_layers_to_keep - tl.Select([0], n_in=2), # stack: embedding - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -@assert_shape('...s->...sv') -def LayerDropTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_fraction=0.4, - eval_skip_fraction='every_other'): - """Returns a LayerDrop Transformer language model. - - Based on Fan, Grave, Joulin 2019, https://arxiv.org/abs/1909.11556 . - - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) 
- - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_fraction: probability of skipping a layer; it can be a single - probability or a list of probabilities different for each layer - eval_skip_fraction: probability of skipping a layer during eval; it can be a - single probability, or a list of probabilities different for each layer, - or a string "every other" implementing a strategy from original paper - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - if not isinstance(skip_fraction, (list, tuple)): - # If we don't get a list of skip_fractions we use the same skip_fraction - # for each layer. - skip_fraction = [skip_fraction for i in range(n_layers)] - if len(skip_fraction) != n_layers: - raise ValueError('n_layers ({}) must be equal to len(skip_fraction) ({})' - .format(n_layers, len(skip_fraction))) - - if eval_skip_fraction == 'every_other': - # 100% skipping for even-numbered layers; 0% for odd-numbered layers. - eval_skip_fraction = [(1.0 if i % int(1./skip_fraction[i]) == 0 else 0.0) - if skip_fraction[i] != 0 else 0.0 - for i in range(n_layers)] - if eval_skip_fraction == 'same': - # Same skip_fraction as in training. - eval_skip_fraction = skip_fraction - if not isinstance(eval_skip_fraction, (list, tuple)): - # If we don't get a list of eval_skip_fractions we use the same - # eval_skip_fraction for each layer. 
- eval_skip_fraction = [eval_skip_fraction for i in range(n_layers)] - if len(eval_skip_fraction) != n_layers: - raise ValueError( - 'n_layers ({}) must be equal to len(eval_skip_fraction) ({})' - .format(n_layers, len(eval_skip_fraction))) - - @assert_shape('...sd->...sd') - def ConditionedBlock(current_layer_num): return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, # stack: embedding - tl.RandomUniform(0., 1, sync=True), - # stack: random_uniform, embedding - tl.Cond( - # if random_uniform > skip_fraction - LargerThan(skip_fraction[current_layer_num] if mode == 'train' - else eval_skip_fraction[current_layer_num]), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)), - # else: run noop - tl.Serial() + tl.RandomUniform(0.0, 1.0, sync=True), + # stack: n_layers_to_keep, embedding + tl.Swap(), + # stack: embedding, n_layers_to_keep + [ConditionedBlock(i) for i in range(n_layers)], + # stack: embedding, n_layers_to_keep + tl.Select([0], n_in=2), # stack: embedding + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +@assert_shape("...s->...sv") +def LayerDropTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_fraction=0.4, + eval_skip_fraction="every_other", +): + """Returns a LayerDrop Transformer language model. + + Based on Fan, Grave, Joulin 2019, https://arxiv.org/abs/1909.11556 . + + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) 
+ + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_fraction: probability of skipping a layer; it can be a single + probability or a list of probabilities different for each layer + eval_skip_fraction: probability of skipping a layer during eval; it can be a + single probability, or a list of probabilities different for each layer, + or a string "every other" implementing a strategy from original paper + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + if not isinstance(skip_fraction, (list, tuple)): + # If we don't get a list of skip_fractions we use the same skip_fraction + # for each layer. + skip_fraction = [skip_fraction for i in range(n_layers)] + if len(skip_fraction) != n_layers: + raise ValueError( + "n_layers ({}) must be equal to len(skip_fraction) ({})".format( + n_layers, len(skip_fraction) + ) + ) + + if eval_skip_fraction == "every_other": + # 100% skipping for even-numbered layers; 0% for odd-numbered layers. + eval_skip_fraction = [ + (1.0 if i % int(1.0 / skip_fraction[i]) == 0 else 0.0) + if skip_fraction[i] != 0 + else 0.0 + for i in range(n_layers) + ] + if eval_skip_fraction == "same": + # Same skip_fraction as in training. 
+ eval_skip_fraction = skip_fraction + if not isinstance(eval_skip_fraction, (list, tuple)): + # If we don't get a list of eval_skip_fractions we use the same + # eval_skip_fraction for each layer. + eval_skip_fraction = [eval_skip_fraction for i in range(n_layers)] + if len(eval_skip_fraction) != n_layers: + raise ValueError( + "n_layers ({}) must be equal to len(eval_skip_fraction) ({})".format( + n_layers, len(eval_skip_fraction) ) - # stack: embedding ) - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - [ConditionedBlock(i) for i in range(n_layers)], - tl.LayerNorm(), - tl.Dense(vocab_size), - ) + @assert_shape("...sd->...sd") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding + tl.RandomUniform(0.0, 1, sync=True), + # stack: random_uniform, embedding + tl.Cond( + # if random_uniform > skip_fraction + LargerThan( + skip_fraction[current_layer_num] + if mode == "train" + else eval_skip_fraction[current_layer_num] + ), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ) + ), + # else: run noop + tl.Serial(), + ) + # stack: embedding + ) + + return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, + [ConditionedBlock(i) for i in range(n_layers)], + tl.LayerNorm(), + tl.Dense(vocab_size), + ) diff --git a/trax/models/research/layerdrop_transformer_test.py b/trax/models/research/layerdrop_transformer_test.py deleted file mode 100644 index 2fe41fe07..000000000 --- a/trax/models/research/layerdrop_transformer_test.py +++ /dev/null @@ -1,81 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Reformer models.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import layerdrop_transformer - - -class SkippingTransformerTest(absltest.TestCase): - - def test_skipping_transformer_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.SkippingTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class LayerDropTransformerTest(absltest.TestCase): - - def test_layerdrop_transformer_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.LayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - def test_layerdrop_layerwise_skip_fraction(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.LayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16, - skip_fraction=[0.2, 0.8]) - xs = [np.ones((1, 8)).astype(np.int32), - 
np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class EveryOtherLayerDropTransformerTest(absltest.TestCase): - - def test_everyother_layerdrop_transformer_forward(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.EveryOtherLayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16, - skip_mode='1half') - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/predict_terraformer.py b/trax/models/research/predict_terraformer.py index 79df866c5..2951bf9ee 100644 --- a/trax/models/research/predict_terraformer.py +++ b/trax/models/research/predict_terraformer.py @@ -40,99 +40,31 @@ """ +import functools -import sys -import time - -import os -import random -import time import numpy as np +import tensorflow_datasets as tfds import trax -from trax import layers as tl -from trax import fastmath -from trax.fastmath import numpy as jnp -from trax.supervised import training -from trax.layers.assert_shape import assert_shape - - -import copy -import functools -import gc -import os -import time -from jax.config import config -import numpy as np -import psutil -from tensorflow.compat.v2 import test - from trax import fastmath from trax import layers as tl from trax import models from trax import shapes -from trax.supervised import decoding -import gin - - -# from colabtools import adhoc_import -import json -import gc -import jax -import numpy as np -import os -import time -import gin - -import tensorflow_datasets as tfds - - -# from colabtools import adhoc_import -import functools - from trax.data import tf_inputs -import 
tensorflow_datasets as tfds -from t5.data import preprocessors as t5_processors -import t5.data - -from trax import data -from trax import layers as tl -from trax import models -from trax import optimizers -from trax.data import inputs -from trax.supervised import lr_schedules -from trax.supervised import trainer_lib -from trax.rl import serialization_utils -from trax.rl import space_serializer -import math from trax.fastmath import numpy as numpy_math -import trax - - -import numpy as np - -from trax import fastmath -from trax.fastmath import numpy as jnp -from trax.layers import base -from trax.layers import combinators as cb -from trax.layers import core -from trax.layers import initializers as init from trax.layers.assert_shape import assert_shape -from trax.layers.base import Fn -from trax.layers.research import sparsity +from trax.supervised import decoding -import functools -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.models.reformer import reformer -from trax.models.research import configurable_transformer as ct -from trax.models.research import transformer2 as t2 +# from colabtools import adhoc_import +# from colabtools import adhoc_import ##### og_PositionalEncoding = tl.PositionalEncoding -trax.layers.attention.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) +trax.layers.attention.PositionalEncoding = functools.partial( + og_PositionalEncoding, d_feature=64 +) trax.layers.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) tl.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) @@ -141,22 +73,24 @@ import gin + gin.enter_interactive_mode() def model_configure(*args, **kwargs): - kwargs['module'] = 'trax.models' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.models" + return gin.external_configurable(*args, **kwargs) + #### -xm2a_main = '/tmp/Terraformer/model_200000.pkl.gz' -xm2a_weights = 
'/tmp/Terraformer/model_200000.weights.npy.gz' -xm2a_opt_slots = '/tmp/Terraformer/model_200000.opt_slots0.npy.gz' -xm2a_config = '/tmp/Terraformer/config.gin' +xm2a_main = "/tmp/Terraformer/model_200000.pkl.gz" +xm2a_weights = "/tmp/Terraformer/model_200000.weights.npy.gz" +xm2a_opt_slots = "/tmp/Terraformer/model_200000.opt_slots0.npy.gz" +xm2a_config = "/tmp/Terraformer/config.gin" -VOCAB_FILE = 'en_16k.subword' -VOCAB_DIR = '/tmp/Terraformer' +VOCAB_FILE = "en_16k.subword" +VOCAB_DIR = "/tmp/Terraformer" #### @@ -169,31 +103,35 @@ def model_configure(*args, **kwargs): # ) og_DotProductCausalAttention = trax.layers.attention.DotProductCausalAttention trax.layers.attention.DotProductCausalAttention = functools.partial( - og_DotProductCausalAttention, max_inference_length=16384, + og_DotProductCausalAttention, + max_inference_length=16384, ) # gin_config.append( # '\nMixedLSHSelfAttention.std_length=16384' # ) -gin_config = [l for l in gin_config if 'mira' not in l] -gin_config = [l for l in gin_config if 'okenize' not in l] # tokenize +gin_config = [l for l in gin_config if "mira" not in l] +gin_config = [l for l in gin_config if "okenize" not in l] # tokenize -gin_config = ''.join(gin_config) +gin_config = "".join(gin_config) gin.parse_config(gin_config) -gin.operative_config_str().split('\n') +gin.operative_config_str().split("\n") print(gin_config) #### + def model(mode): - return models.ConfigurableTerraformer(mode=mode) + return models.ConfigurableTerraformer(mode=mode) + # #### -padding_fun = trax.data.PadToLength(len_map={0: 15*1024, 1: 15*1024, 2: 15*1024}, - pad_value = {0: 0, 1: 0, 2:0}) +padding_fun = trax.data.PadToLength( + len_map={0: 15 * 1024, 1: 15 * 1024, 2: 15 * 1024}, pad_value={0: 0, 1: 0, 2: 0} +) # padding_fun = lambda x: x # padding_fun = trax.data.PadToLength(len_map={0: 128, 1: 128, 2:128}, pad_value={0: 0, 1: 0, 2: 0}, multiple=True) @@ -202,48 +140,67 @@ def model(mode): dataset = tfds.summarization.scientific_papers.ScientificPapers() 
-valid = tfds.load(name='scientific_papers/arxiv:1.1.1')['test'] +valid = tfds.load(name="scientific_papers/arxiv:1.1.1")["test"] index = 0 xarts = [] for x in valid: - xarts.append(x) - index += 1 - if index == 3: - break + xarts.append(x) + index += 1 + if index == 3: + break model_file = xm2a_main shape11 = trax.shapes.ShapeDtype((1, 1), dtype=numpy_math.int32) -shape1l = trax.shapes.ShapeDtype((1, 15*1024), dtype=numpy_math.int32) +shape1l = trax.shapes.ShapeDtype((1, 15 * 1024), dtype=numpy_math.int32) with trax.fastmath.use_backend(trax.fastmath.Backend.JAX): - model = model(mode='eval') - model.init_from_file(model_file, weights_only=True) - # in mode='predict' use input_signature=(shape1l, shape11) - old_state = model.state + model = model(mode="eval") + model.init_from_file(model_file, weights_only=True) + # in mode='predict' use input_signature=(shape1l, shape11) + old_state = model.state # Decode the first article -xart = xarts[2]['article'] +xart = xarts[2]["article"] question = xart.numpy().decode() # print(question[:512]) -tokenized = next(padding_fun(trax.data.tokenize([question,], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100))) +tokenized = next( + padding_fun( + trax.data.tokenize( + [ + question, + ], + vocab_file=VOCAB_FILE, + vocab_dir=VOCAB_DIR, + n_reserved_ids=100, + ) + ) +) + def detokenize(x): - return trax.data.detokenize(x, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, - n_reserved_ids=100) + return trax.data.detokenize( + x, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100 + ) + with trax.fastmath.use_backend(trax.fastmath.Backend.JAX): - model.state = old_state - counter, tokens, max_length = 0, [], 30 - for token in decoding.autoregressive_sample_stream( - model, tokenized[None, :15*1024], batch_size=1, temperature=0.0, - eval_mode=True, eval_min_length=1024): - print(f'Token {counter}: "{detokenize(token)}" {token}') - tokens.append(token[:, None]) - counter += 1 - if counter > max_length: - break - 
tokens = np.concatenate(tokens, axis=1) - print(tokens) - print(detokenize(tokens[0])) + model.state = old_state + counter, tokens, max_length = 0, [], 30 + for token in decoding.autoregressive_sample_stream( + model, + tokenized[None, : 15 * 1024], + batch_size=1, + temperature=0.0, + eval_mode=True, + eval_min_length=1024, + ): + print(f'Token {counter}: "{detokenize(token)}" {token}') + tokens.append(token[:, None]) + counter += 1 + if counter > max_length: + break + tokens = np.concatenate(tokens, axis=1) + print(tokens) + print(detokenize(tokens[0])) diff --git a/trax/models/research/rezero_test.py b/trax/models/research/rezero_test.py deleted file mode 100644 index d6be6d32e..000000000 --- a/trax/models/research/rezero_test.py +++ /dev/null @@ -1,67 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for ReZero models.""" - -from absl.testing import absltest -import numpy as np - -from trax import layers as tl -from trax import shapes -from trax.models.research import rezero - - -class ResidualZeroTest(absltest.TestCase): - - def test_residual_layer_forward(self): - """Tests that the forward pass runs and returns the expected shape.""" - model = rezero.ResidualZero(tl.Dense(5)) - x = [np.arange(5).astype(np.float32)] - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.tolist(), [0., 1., 2., 3., 4.]) - - -class ReZeroTransformerLMTest(absltest.TestCase): - - def test_rezero_lm_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = rezero.ReZeroTransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class ReZeroTransformerTest(absltest.TestCase): - - def test_rezero_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = rezero.ReZeroTransformer( - vocab_size, d_model=32, d_ff=64, n_encoder_layers=2, n_decoder_layers=2, - n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/rse.py b/trax/models/research/rse.py index b5c51c247..5c6ba3c83 100644 --- a/trax/models/research/rse.py +++ b/trax/models/research/rse.py @@ -27,31 +27,30 @@ # pylint: disable=invalid-name def _inverse_sigmoid(x): - return np.log(x / (1 - x)) + return np.log(x / (1 - x)) -@assert_shape('...->...') +@assert_shape("...->...") class 
_ClippedScaling(tl.Layer): - """Pointwise multiplies by sigmoid(S) with a learnable vector S.""" + """Pointwise multiplies by sigmoid(S) with a learnable vector S.""" - def __init__(self, - residual_weight): - super().__init__(n_in=1, n_out=1) - self._residual_weight = residual_weight + def __init__(self, residual_weight): + super().__init__(n_in=1, n_out=1) + self._residual_weight = residual_weight - def forward(self, x): - s = self.weights - return jnp.multiply(x, fastmath.expit(s)) + def forward(self, x): + s = self.weights + return jnp.multiply(x, fastmath.expit(s)) - def init_weights_and_state(self, input_signature): - self.weights = _inverse_sigmoid(self._residual_weight) * np.ones( - (input_signature.shape[-1])).astype('float32') + def init_weights_and_state(self, input_signature): + self.weights = _inverse_sigmoid(self._residual_weight) * np.ones( + (input_signature.shape[-1]) + ).astype("float32") -@assert_shape('bld->bld') -def ResidualSwitchUnit( - d_model, dropout=0.1, mode='train', residual_weight=0.9): - r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf. +@assert_shape("bld->bld") +def ResidualSwitchUnit(d_model, dropout=0.1, mode="train", residual_weight=0.9): + r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf. As defined in the paper: @@ -75,145 +74,152 @@ def ResidualSwitchUnit( Returns: The RSU layer. 
""" - return tl.Serial( - tl.Fn( - 'Reshape2Pairs', - lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)), - n_out=1), - tl.Residual( - tl.Dense(4 * d_model, use_bias=False), - tl.LayerNorm(), - tl.Gelu(), - tl.Dense(2 * d_model), - tl.Fn('Scaling', + return tl.Serial( + tl.Fn( + "Reshape2Pairs", + lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)), + n_out=1, + ), + tl.Residual( + tl.Dense(4 * d_model, use_bias=False), + tl.LayerNorm(), + tl.Gelu(), + tl.Dense(2 * d_model), + tl.Fn( + "Scaling", lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25, - n_out=1), - shortcut=_ClippedScaling(residual_weight)), - tl.Fn( - 'UnPair', - lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)), - n_out=1), - tl.Dropout(rate=dropout, mode=mode) - ) + n_out=1, + ), + shortcut=_ClippedScaling(residual_weight), + ), + tl.Fn( + "UnPair", + lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)), + n_out=1, + ), + tl.Dropout(rate=dropout, mode=mode), + ) def _ror(x, n, p=1): - """Bitwise right rotation. + """Bitwise right rotation. - Args: - x: np.array - n: Bit count to represent each value of x - p: Bit positions to shift + Args: + x: np.array + n: Bit count to represent each value of x + p: Bit positions to shift - Returns: - np.array: x with all values shifted by p positions in n bits - """ - a = np.right_shift(x, p) - b = np.left_shift(1, p) - 1 - c = np.bitwise_and(x, b) - d = np.left_shift(c, n - p) + Returns: + np.array: x with all values shifted by p positions in n bits + """ + a = np.right_shift(x, p) + b = np.left_shift(1, p) - 1 + c = np.bitwise_and(x, b) + d = np.left_shift(c, n - p) - return a + d + return a + d def _rol(x, n, p=1): - """Bitwise left rotation. + """Bitwise left rotation. 
- Args: - x: np.array - n: Bit count to represent each value of x - p: Bit positions to shift + Args: + x: np.array + n: Bit count to represent each value of x + p: Bit positions to shift - Returns: - np.array: x with all values shifted by p positions in n bits - """ - a = np.left_shift(x, p) - b = np.left_shift(1, n) - 1 - c = np.bitwise_and(a, b) - d = np.right_shift(x, n - p) + Returns: + np.array: x with all values shifted by p positions in n bits + """ + a = np.left_shift(x, p) + b = np.left_shift(1, n) - 1 + c = np.bitwise_and(a, b) + d = np.right_shift(x, n - p) - return np.bitwise_or(c, d) + return np.bitwise_or(c, d) def _shuffle_layer(inputs, shuffle_fn): - """Shuffles the elements according to bitwise left or right rotation. + """Shuffles the elements according to bitwise left or right rotation. - Args: - inputs: Tensor input from previous layer - shuffle_fn: Shift function rol or ror + Args: + inputs: Tensor input from previous layer + shuffle_fn: Shift function rol or ror - Returns: - tf.Tensor: Inputs shifted according to shuffle_fn - """ - seq_length = inputs.shape[1] - n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 + Returns: + tf.Tensor: Inputs shifted according to shuffle_fn + """ + seq_length = inputs.shape[1] + n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 - indices = np.arange(0, seq_length).astype('int32') - rev_indices = shuffle_fn(indices, n_bits) - return jnp.take(inputs, rev_indices, axis=1, mode='clip') + indices = np.arange(0, seq_length).astype("int32") + rev_indices = shuffle_fn(indices, n_bits) + return jnp.take(inputs, rev_indices, axis=1, mode="clip") -@assert_shape('bld->bld') +@assert_shape("bld->bld") def ShuffleLayer(): - return tl.Fn( - 'ShuffleLayer', lambda x: _shuffle_layer(x, _rol), n_out=1) + return tl.Fn("ShuffleLayer", lambda x: _shuffle_layer(x, _rol), n_out=1) -@assert_shape('bld->bld') +@assert_shape("bld->bld") def ReverseShuffleLayer(): - return tl.Fn( - 'ReverseShuffleLayer', lambda x: 
_shuffle_layer(x, _ror), n_out=1) + return tl.Fn("ReverseShuffleLayer", lambda x: _shuffle_layer(x, _ror), n_out=1) -@assert_shape('...,bld->...,bld') +@assert_shape("...,bld->...,bld") def _ForwardStep(d_model, dropout, mode): - """Takes (n_layer, state) and returns (n_layer, shuffle_layer(rsu(state))).""" - return tl.Parallel([], tl.Serial( - ResidualSwitchUnit(d_model, dropout, mode), - ShuffleLayer(), - )) + """Takes (n_layer, state) and returns (n_layer, shuffle_layer(rsu(state))).""" + return tl.Parallel( + [], + tl.Serial( + ResidualSwitchUnit(d_model, dropout, mode), + ShuffleLayer(), + ), + ) -@assert_shape('...,bld->...,bld') +@assert_shape("...,bld->...,bld") def _BackwardStep(d_model, dropout, mode): - """Takes (n_layer, state) and returns (n_layer, reverse_shuffle_layer(rsu(state))).""" - return tl.Parallel([], tl.Serial( - ResidualSwitchUnit(d_model, dropout, mode), - ReverseShuffleLayer(), - )) + """Takes (n_layer, state) and returns (n_layer, reverse_shuffle_layer(rsu(state))).""" + return tl.Parallel( + [], + tl.Serial( + ResidualSwitchUnit(d_model, dropout, mode), + ReverseShuffleLayer(), + ), + ) -@assert_shape('bld->bld') +@assert_shape("bld->bld") def BenesBlock(d_model, dropout, mode): - def bit_sequence(inputs): - seq_length = inputs.shape[1] - n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 - return jnp.arange(0, n_bits) - return tl.Serial( - tl.Dup(), - tl.Fn('BitSeq', bit_sequence, n_out=1), - tl.Scan(_ForwardStep(d_model, dropout, mode)), - tl.Scan(_BackwardStep(d_model, dropout, mode)), - tl.Select([1]), - ) - - -@assert_shape('bl->blv') -def ResidualShuffleExchange(vocab_size, - d_model, - input_dropout, - dropout, - mode='train', - n_blocks=2): - """Returns a Residual Shuffle Exchange Network model.""" - benes_blocks = [BenesBlock(d_model, dropout, mode) for _ in range(n_blocks)] - return tl.Serial( - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=input_dropout, mode=mode), - # Apply Benes Block n_blocks times. 
- *benes_blocks, - ResidualSwitchUnit(d_model, dropout, mode), - # Produce probabilities. - tl.Dense(vocab_size), - tl.LogSoftmax(), - ) + def bit_sequence(inputs): + seq_length = inputs.shape[1] + n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 + return jnp.arange(0, n_bits) + + return tl.Serial( + tl.Dup(), + tl.Fn("BitSeq", bit_sequence, n_out=1), + tl.Scan(_ForwardStep(d_model, dropout, mode)), + tl.Scan(_BackwardStep(d_model, dropout, mode)), + tl.Select([1]), + ) + + +@assert_shape("bl->blv") +def ResidualShuffleExchange( + vocab_size, d_model, input_dropout, dropout, mode="train", n_blocks=2 +): + """Returns a Residual Shuffle Exchange Network model.""" + benes_blocks = [BenesBlock(d_model, dropout, mode) for _ in range(n_blocks)] + return tl.Serial( + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=input_dropout, mode=mode), + # Apply Benes Block n_blocks times. + *benes_blocks, + ResidualSwitchUnit(d_model, dropout, mode), + # Produce probabilities. + tl.Dense(vocab_size), + tl.LogSoftmax(), + ) diff --git a/trax/models/research/rse_test.py b/trax/models/research/rse_test.py deleted file mode 100644 index 36891dbe5..000000000 --- a/trax/models/research/rse_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Residual Shuffle-Exchange Networks.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import rse - - -class RSETest(absltest.TestCase): - - def test_rsu_forward_shape(self): - batch_size = 3 - seq_len = 32 - d_model = 17 - model = rse.ResidualSwitchUnit( - d_model=d_model, dropout=0.1, mode='train') - x = np.ones((batch_size, seq_len, d_model)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (batch_size, seq_len, d_model)) - - def test_shuffle_layer(self): - shuffle_layer = rse.ShuffleLayer() - x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) - print(x.shape) - _, _ = shuffle_layer.init(shapes.signature(x)) - y = shuffle_layer(x) - expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]]) - self._assert_equal_tensors(y, expected_output) - - def test_shuffle_layer_log_times_is_identity(self): - seq_len = 8 - d_model = 17 - shuffle_layer = rse.ShuffleLayer() - x = _input_with_indice_as_values(seq_len, d_model) - _, _ = shuffle_layer.init(shapes.signature(x)) - y = x - for _ in range(int(np.log2(seq_len))): - y = shuffle_layer(y) - self._assert_equal_tensors(x, y) - - def test_reverse_shuffle_layer(self): - reverse_shuffle_layer = rse.ReverseShuffleLayer() - x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) - print(x.shape) - _, _ = reverse_shuffle_layer.init(shapes.signature(x)) - y = reverse_shuffle_layer(x) - expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]]) - self._assert_equal_tensors(y, expected_output) - - def test_reverse_shuffle_layer_log_times_is_identity(self): - seq_len = 8 - d_model = 17 - reverse_shuffle_layer = rse.ReverseShuffleLayer() - x = _input_with_indice_as_values(seq_len, d_model) - _, _ = reverse_shuffle_layer.init(shapes.signature(x)) - y = x - for _ in range(int(np.log2(seq_len))): - y = reverse_shuffle_layer(y) - self._assert_equal_tensors(x, y) - - def 
test_rse_forward_shape(self): - vocab_size = 12 - seq_len = 32 - model = rse.ResidualShuffleExchange( - vocab_size=vocab_size, d_model=17, dropout=0.1, input_dropout=0.05, - mode='train') - x = np.ones((3, seq_len)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, seq_len, vocab_size)) - - def _assert_equal_tensors(self, x, y): - self.assertEqual(y.shape, x.shape) - for i in range(x.shape[0]): - for j in range(x.shape[1]): - for k in range(x.shape[2]): - self.assertEqual( - x[i][j][k], y[i][j][k], - f'Tensors differ on index [{i}][{j}][{k}].') - - -def _input_with_indice_as_values(length, dim): - """Retuns np.array of size (1, length, dim) where x[0, a, b] = a.""" - positions = [] - for i in range(length): - positions.append([i] * dim) - positions_input = np.array(positions) - positions_input = np.expand_dims(positions_input, axis=0) - return positions_input - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer.py b/trax/models/research/terraformer.py index 892c5c5d9..92af63866 100644 --- a/trax/models/research/terraformer.py +++ b/trax/models/research/terraformer.py @@ -29,462 +29,478 @@ # pylint: disable=invalid-name -def ConfigurableTerraformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - d_attention_key=None, - d_attention_value=None, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - encoder_attention_type=tl.SelfAttention, - encoder_decoder_attention_type=tl.SelfAttention, - pos_type='fixed-base', - pos_axial_shape=(), - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=0, - ff_dropout=None, - ff_sparsity=0, - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - n_layers_forget=0, - forget_dense=True, - n_decoder_attention_layers=2, - 
use_bfloat16=False, - reversible_encoder=False, - use_two_swaps_per_encoder_block=True, - center_layernorm=True, - half_before_layer=None, - double_after_layer=None, - mode='train'): - """Returns a highly configurable Terraformer encoder-decoder model. - - This model maps paired text sequences (source and target) to float-valued - losses. If ``input_vocab_size`` is not ``None``, the layer takes - two input sequences: - - - inputs (2): - - - source: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(input_vocab_size)``, and 0 values mark padding positions. - - - target: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 1-D float array of losses; shape is `(batch_size)`. - - If ``input_vocab_size`` is ``None``, the layer takes three input sequences: - - - inputs (3): - - - source: 3-D float array representing a batch of already-embedded text - strings; shape is `(batch_size, sequence_length, d_model)`, where - sequence_length <= ``max_len``. - - - mask: 2-D int array representing active versus masked positions; 0 - values mark masked (padding) positions. - - - target: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 1-D float array of losses; shape is `(batch_size)`. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in ``range(vocab_size)``. 
These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if ``None``, then input and target integers (token IDs) are assumed to - come from the same vocabulary. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - d_attention_key: Depth of key vectors in each attention head. - d_attention_value: Depth of value vectors in each attention head. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - max_len: Maximum symbol length for positional encoding. - encoder_attention_type: Type of attention to use in the encoder; must be - an attention-type subclass of :py:class:`trax.layers.Layer`. - encoder_decoder_attention_type: Type of attention to use in the decoder; - must be an attention-type subclass of :py:class:`trax.layers.Layer`. - pos_type: String indicating the type of positional embeddings to use. - pos_axial_shape: Shape (tuple of ints) to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: Tuple of ints specifying the depth of position embedding - for each axis. Tuple length must match ``pos_axial_shape``, and values - must sum to ``d_model``. - pos_start_from_zero_prob: Stochastic rate (probability) for starting - positional encoding at position 0 during training. If 1.0, always start - from position 0; if < 1.0, the non-zero starts will be uniformly - distributed up to ``pos_max_offset_to_add``. 
- pos_max_offset_to_add: Maximum offset to add to positions during training - when randomizing. This offset plus input length must be less than - ``max_len`` for all training examples. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`trax.layers.Layer`. - ff_use_sru: If > 0, use this number of SRU layers in place of feedforward - layers. - ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this - size. - ff_dropout: Stochastic rate (probability) for dropping an activation value - at feedforward nonlinearities. - ff_sparsity: If > 0, use sparse feedforward blocks with this level of - sparsity. - loss_sparsity_type: String indicating the type of sparsity to used in loss - layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``, - use no sparsity. - loss_sparsity: If > 0, use this level of sparsity in the loss layer. - loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this - dimension, in the loss. - loss_sparsity_prob: Stochastic rate (probability) for using the sparse - version of the loss. If ``None``, use the sparse version exclusively. - attention_chunk_size: If > 0, compute attention using chunks of this size. - n_layers_forget: How often to have a forgetting block between layers. - forget_dense: If True, use :py:class:`Dense` instances as forget layers; - else use no-ops. - n_decoder_attention_layers: Number of attention layers in a decoder block. - use_bfloat16: If True, use bfloat16 for weights; else use float32. - reversible_encoder: If True, make the encoder be reversible. - use_two_swaps_per_encoder_block: If True, ensure that there is a an even - number of swaps across the encoder. - center_layernorm: If True, use centering in :py:class:`LayerNorm` (the - default); else omit centering (which is known as RMS normalization). 
- half_before_layer: If not None, specifies an n'th layer such that all - layers before the n'th use half the normal values for ``d_model`` and - ``d_ff``. - double_after_layer: If not None, specifies an n'th layer such that all - layers after the n'th use double the normal values for ``d_model`` and - ``d_ff``. - mode: If ``'train'``, include dropout in each encoder/decoder block; else - dropout layers have no effect. - - Returns: - A Terraformer encoder-decoder as a layer that maps from target and source - text sequences to a scalar loss. - """ - if mode == 'predict': - portal_mask = _PortalInput() - else: - portal_mask = None - - # Set default dimensions for attention head key and value sizes. - if (d_model / 2) % n_heads != 0: - raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})') - if d_attention_key is None: - d_attention_key = d_model // n_heads - if d_attention_value is None: - d_attention_value = d_model // n_heads - - # Set values of d_model, d_ff and d_qkv for the first stage. - d_model1, d_ff1 = d_model, d_ff - d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value - if half_before_layer: - d_model1, d_ff1 = d_model / 2, d_ff / 2 - d_attention_key1 = d_attention_key / 2 - d_attention_value1 = d_attention_value / 2 - - # Set values of d_model, d_ff and d_qkv for the final stage. - d_model2, d_ff2 = d_model, d_ff - d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value - if double_after_layer: - d_model2, d_ff2 = d_model * 2, d_ff * 2 - d_attention_key2 = d_attention_key * 2 - d_attention_value2 = d_attention_value * 2 - - # Vector embeddings. 
- in_encoder, out_encoder, output_vocab_size = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model1, - mode, - dropout, - [-2], # dropout_shared_axes - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ) - - def _EncoderBlock(): - return reformer.EncoderBlock( +def ConfigurableTerraformer( + input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + d_attention_key=None, + d_attention_value=None, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + encoder_attention_type=tl.SelfAttention, + encoder_decoder_attention_type=tl.SelfAttention, + pos_type="fixed-base", + pos_axial_shape=(), + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=0, + ff_dropout=None, + ff_sparsity=0, + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + n_layers_forget=0, + forget_dense=True, + n_decoder_attention_layers=2, + use_bfloat16=False, + reversible_encoder=False, + use_two_swaps_per_encoder_block=True, + center_layernorm=True, + half_before_layer=None, + double_after_layer=None, + mode="train", +): + """Returns a highly configurable Terraformer encoder-decoder model. + + This model maps paired text sequences (source and target) to float-valued + losses. If ``input_vocab_size`` is not ``None``, the layer takes + two input sequences: + + - inputs (2): + + - source: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(input_vocab_size)``, and 0 values mark padding positions. 
+ + - target: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 1-D float array of losses; shape is `(batch_size)`. + + If ``input_vocab_size`` is ``None``, the layer takes three input sequences: + + - inputs (3): + + - source: 3-D float array representing a batch of already-embedded text + strings; shape is `(batch_size, sequence_length, d_model)`, where + sequence_length <= ``max_len``. + + - mask: 2-D int array representing active versus masked positions; 0 + values mark masked (padding) positions. + + - target: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 1-D float array of losses; shape is `(batch_size)`. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if ``None``, then input and target integers (token IDs) are assumed to + come from the same vocabulary. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + d_attention_key: Depth of key vectors in each attention head. + d_attention_value: Depth of value vectors in each attention head. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. 
+ n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder/decoder blocks. The same rate is + also used for attention dropout in encoder/decoder blocks. + max_len: Maximum symbol length for positional encoding. + encoder_attention_type: Type of attention to use in the encoder; must be + an attention-type subclass of :py:class:`trax.layers.Layer`. + encoder_decoder_attention_type: Type of attention to use in the decoder; + must be an attention-type subclass of :py:class:`trax.layers.Layer`. + pos_type: String indicating the type of positional embeddings to use. + pos_axial_shape: Shape (tuple of ints) to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: Tuple of ints specifying the depth of position embedding + for each axis. Tuple length must match ``pos_axial_shape``, and values + must sum to ``d_model``. + pos_start_from_zero_prob: Stochastic rate (probability) for starting + positional encoding at position 0 during training. If 1.0, always start + from position 0; if < 1.0, the non-zero starts will be uniformly + distributed up to ``pos_max_offset_to_add``. + pos_max_offset_to_add: Maximum offset to add to positions during training + when randomizing. This offset plus input length must be less than + ``max_len`` for all training examples. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`trax.layers.Layer`. + ff_use_sru: If > 0, use this number of SRU layers in place of feedforward + layers. + ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this + size. + ff_dropout: Stochastic rate (probability) for dropping an activation value + at feedforward nonlinearities. + ff_sparsity: If > 0, use sparse feedforward blocks with this level of + sparsity. 
+ loss_sparsity_type: String indicating the type of sparsity to used in loss + layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``, + use no sparsity. + loss_sparsity: If > 0, use this level of sparsity in the loss layer. + loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this + dimension, in the loss. + loss_sparsity_prob: Stochastic rate (probability) for using the sparse + version of the loss. If ``None``, use the sparse version exclusively. + attention_chunk_size: If > 0, compute attention using chunks of this size. + n_layers_forget: How often to have a forgetting block between layers. + forget_dense: If True, use :py:class:`Dense` instances as forget layers; + else use no-ops. + n_decoder_attention_layers: Number of attention layers in a decoder block. + use_bfloat16: If True, use bfloat16 for weights; else use float32. + reversible_encoder: If True, make the encoder be reversible. + use_two_swaps_per_encoder_block: If True, ensure that there is a an even + number of swaps across the encoder. + center_layernorm: If True, use centering in :py:class:`LayerNorm` (the + default); else omit centering (which is known as RMS normalization). + half_before_layer: If not None, specifies an n'th layer such that all + layers before the n'th use half the normal values for ``d_model`` and + ``d_ff``. + double_after_layer: If not None, specifies an n'th layer such that all + layers after the n'th use double the normal values for ``d_model`` and + ``d_ff``. + mode: If ``'train'``, include dropout in each encoder/decoder block; else + dropout layers have no effect. + + Returns: + A Terraformer encoder-decoder as a layer that maps from target and source + text sequences to a scalar loss. + """ + if mode == "predict": + portal_mask = _PortalInput() + else: + portal_mask = None + + # Set default dimensions for attention head key and value sizes. 
+ if (d_model / 2) % n_heads != 0: + raise ValueError(f"n_heads ({n_heads}) must divide d_model/2 ({d_model/2})") + if d_attention_key is None: + d_attention_key = d_model // n_heads + if d_attention_value is None: + d_attention_value = d_model // n_heads + + # Set values of d_model, d_ff and d_qkv for the first stage. + d_model1, d_ff1 = d_model, d_ff + d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value + if half_before_layer: + d_model1, d_ff1 = d_model / 2, d_ff / 2 + d_attention_key1 = d_attention_key / 2 + d_attention_value1 = d_attention_value / 2 + + # Set values of d_model, d_ff and d_qkv for the final stage. + d_model2, d_ff2 = d_model, d_ff + d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value + if double_after_layer: + d_model2, d_ff2 = d_model * 2, d_ff * 2 + d_attention_key2 = d_attention_key * 2 + d_attention_value2 = d_attention_value * 2 + + # Vector embeddings. + in_encoder, out_encoder, output_vocab_size = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, d_model1, - d_ff1, - n_heads, - encoder_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=ff_dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - center_layernorm=center_layernorm, + mode, + dropout, + [-2], # dropout_shared_axes + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, use_bfloat16=use_bfloat16, - use_two_swaps_per_block=use_two_swaps_per_encoder_block, - mode=mode) + ) - def _Encoder(): # vec_e mask_e tok_e tok_d tok_d - layers = [ - tl.ReversibleSelect([0, 0]), - _ReversibleSerialForget( - [_EncoderBlock() for _ in range(n_encoder_layers)], + def _EncoderBlock(): + return reformer.EncoderBlock( d_model1, - n_layers_forget, - 
forget_dense) - ] - if not reversible_encoder: - layers += [ - _XYAvg(), - tl.Dense(d_model1, use_bfloat16=use_bfloat16), - tl.LayerNorm(), - ] - if mode == 'predict': - return tl.Cache(tl.Serial(layers)) - else: - return tl.Serial(layers) - - if mode == 'predict': - # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict. - # In predict mode Terraformer needs masking for merged encoder-decoder - # sequence. This monkey patches the layer with a mask to neccessary places. - # This shouldn't be a permanent solution - mask should be passed through - # the stack and all the layers. - tl.attention.DotProductCausalAttention.monkey_patched_mask = ( - lambda x: portal_mask) - tl.research.sparsity._RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access - lambda x: portal_mask) - originalScanSRUCell = tl.rnn.ScanSRUCell - tl.rnn.ScanSRUCell = functools.partial(tl.rnn.ScanSRUCell, - monkey_patched_mask=portal_mask) - - decoder_blocks = [] - - if isinstance(encoder_decoder_attention_type, (tuple, list)): - assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 - else: - encoder_decoder_attention_type = [encoder_decoder_attention_type] - for layer_idx in range(n_decoder_layers): - layer_attention_type = encoder_decoder_attention_type[ - layer_idx % len(encoder_decoder_attention_type)] - # Grow d_model, d_ff, and d_qkv if requested. 
- d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 - if half_before_layer and layer_idx >= half_before_layer: - d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value - if double_after_layer and layer_idx > double_after_layer: - d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 - decoder_block = reformer.DecoderBlock( - d_m, d_f, d_k, d_v, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=ff_dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - n_attention_layers=n_decoder_attention_layers, - center_layernorm=center_layernorm, - use_bfloat16=use_bfloat16, - mode=mode) - decoder_blocks.append(decoder_block) - if half_before_layer and layer_idx == half_before_layer - 1: - decoder_blocks.append(tl.ReversibleConcatenatePair()) - if double_after_layer and layer_idx == double_after_layer: - decoder_blocks.append(tl.ReversibleConcatenatePair()) - - if mode == 'predict': - # After initializing the decoder we can revert to original state of - # previously monkey-patched classes/functions. 
- tl.attention.DotProductCausalAttention.monkey_patched_mask = ( - lambda x: None) - tl.research.sparsity._RememberPad.monkey_patched_mask = (lambda x: None) # pylint: disable=protected-access - tl.rnn.ScanSRUCell = originalScanSRUCell - - def _Loss(): - return tl.SparseDenseWithOptions( - output_vocab_size, - d_input=d_model2, - sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, - d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, - use_bfloat16=use_bfloat16, - mode=mode) - - def _enc_dec_concat(): - """Layers to merge encoder and decoder.""" - if reversible_encoder: - return [ - tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d - t2.ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d - ] - else: - return [ - tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d - t2.ConcatWithPadding(mode=mode), # v_ed tok_e tok_d - tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d - ] - - def _inp_layers(): - if input_vocab_size is not None: - return tl.AssertFunction( - 'bl,br->bld,bl,bl,br', # b: batch, l/r: enc/dec length, d: vec depth - tl.Serial( # tok_e tok_d - tl.Select([0, 0, 0, 1]), - tl.Parallel(in_encoder, [tl.PaddingMask(), - _RemoveAxes12()]) - )) # vec_e mask_e tok_e tok_d + d_ff1, + n_heads, + encoder_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=ff_dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + center_layernorm=center_layernorm, + use_bfloat16=use_bfloat16, + use_two_swaps_per_block=use_two_swaps_per_encoder_block, + mode=mode, + ) + + def _Encoder(): # vec_e mask_e tok_e tok_d tok_d + layers = [ + tl.ReversibleSelect([0, 0]), + _ReversibleSerialForget( + [_EncoderBlock() for _ in range(n_encoder_layers)], + d_model1, + n_layers_forget, + forget_dense, + ), + ] + if not reversible_encoder: + layers += [ + _XYAvg(), + tl.Dense(d_model1, use_bfloat16=use_bfloat16), + tl.LayerNorm(), + ] + if 
mode == "predict": + return tl.Cache(tl.Serial(layers)) + else: + return tl.Serial(layers) + + if mode == "predict": + # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict. + # In predict mode Terraformer needs masking for merged encoder-decoder + # sequence. This monkey patches the layer with a mask to neccessary places. + # This shouldn't be a permanent solution - mask should be passed through + # the stack and all the layers. + tl.attention.DotProductCausalAttention.monkey_patched_mask = ( + lambda x: portal_mask + ) + tl.research.sparsity._RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access + lambda x: portal_mask + ) + originalScanSRUCell = tl.rnn.ScanSRUCell + tl.rnn.ScanSRUCell = functools.partial( + tl.rnn.ScanSRUCell, monkey_patched_mask=portal_mask + ) + + decoder_blocks = [] + + if isinstance(encoder_decoder_attention_type, (tuple, list)): + assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: - # Input in this case is vec_e, mask_e, tok_d. Where all downstream - # operations expect tok_e, we give it instead mask_e, expecting that - # downstream ops only are looking for padding/not padding. - return tl.AssertFunction( - 'blf,bl,br->bld,bl,bl,br', # f: in-feature depth, d: out-vector depth - tl.Serial( # vec_e mask_e tok_d - tl.Select([0, 1, 1, 2]), - tl.Parallel(in_encoder, [], _AsTokenIDs()) - )) # vec_e mask_e tok_e tok_d - - # Assemble and return the model. - return tl.Serial( - _inp_layers(), # vec_e mask_e tok_e tok_d - tl.Parallel([], portal_mask), - - tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss. - - # Embed in and out tokens; done together as weights may be shared. - tl.Parallel([], [], [], [tl.ShiftRight(mode=mode), - out_encoder]), # vec_e mask_e tok_e vec_d tok_d - - # Encode; then concat encoder and decoder, given encoder mask. - _Encoder(), # vec_e mask_e tok_e vec_d tok_d - _enc_dec_concat(), - - # Run decoder blocks. 
- _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget, - forget_dense), # vec_ed1 vec_ed2 tok_e tok_d - _XYAvg(), # vec_ed tok_e tok_d - tl.LayerNorm(), # vec_ed tok_e tok_d - - # Separate out the encoder part from the concatenated vector, - # then compute loss. - tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d - t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d - _Loss(), # vec_d tok_d - ) + encoder_decoder_attention_type = [encoder_decoder_attention_type] + for layer_idx in range(n_decoder_layers): + layer_attention_type = encoder_decoder_attention_type[ + layer_idx % len(encoder_decoder_attention_type) + ] + # Grow d_model, d_ff, and d_qkv if requested. + d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 + if half_before_layer and layer_idx >= half_before_layer: + d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value + if double_after_layer and layer_idx > double_after_layer: + d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 + decoder_block = reformer.DecoderBlock( + d_m, + d_f, + d_k, + d_v, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=ff_dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + n_attention_layers=n_decoder_attention_layers, + center_layernorm=center_layernorm, + use_bfloat16=use_bfloat16, + mode=mode, + ) + decoder_blocks.append(decoder_block) + if half_before_layer and layer_idx == half_before_layer - 1: + decoder_blocks.append(tl.ReversibleConcatenatePair()) + if double_after_layer and layer_idx == double_after_layer: + decoder_blocks.append(tl.ReversibleConcatenatePair()) + + if mode == "predict": + # After initializing the decoder we can revert to original state of + # previously monkey-patched classes/functions. 
+ tl.attention.DotProductCausalAttention.monkey_patched_mask = lambda x: None + tl.research.sparsity._RememberPad.monkey_patched_mask = ( + lambda x: None + ) # pylint: disable=protected-access + tl.rnn.ScanSRUCell = originalScanSRUCell + + def _Loss(): + return tl.SparseDenseWithOptions( + output_vocab_size, + d_input=d_model2, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + use_bfloat16=use_bfloat16, + mode=mode, + ) + + def _enc_dec_concat(): + """Layers to merge encoder and decoder.""" + if reversible_encoder: + return [ + tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d + t2.ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d + ] + else: + return [ + tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d + t2.ConcatWithPadding(mode=mode), # v_ed tok_e tok_d + tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d + ] + + def _inp_layers(): + if input_vocab_size is not None: + return tl.AssertFunction( + "bl,br->bld,bl,bl,br", # b: batch, l/r: enc/dec length, d: vec depth + tl.Serial( # tok_e tok_d + tl.Select([0, 0, 0, 1]), + tl.Parallel(in_encoder, [tl.PaddingMask(), _RemoveAxes12()]), + ), + ) # vec_e mask_e tok_e tok_d + else: + # Input in this case is vec_e, mask_e, tok_d. Where all downstream + # operations expect tok_e, we give it instead mask_e, expecting that + # downstream ops only are looking for padding/not padding. + return tl.AssertFunction( + "blf,bl,br->bld,bl,bl,br", # f: in-feature depth, d: out-vector depth + tl.Serial( # vec_e mask_e tok_d + tl.Select([0, 1, 1, 2]), tl.Parallel(in_encoder, [], _AsTokenIDs()) + ), + ) # vec_e mask_e tok_e tok_d + + # Assemble and return the model. + return tl.Serial( + _inp_layers(), # vec_e mask_e tok_e tok_d + tl.Parallel(tl.Select([0]), portal_mask), + tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss. + # Embed in and out tokens; done together as weights may be shared. 
+ tl.Parallel( + tl.Select([0]), + tl.Select([0]), + tl.Select([0]), + [tl.ShiftRight(mode=mode), out_encoder], + ), # vec_e mask_e tok_e vec_d tok_d + # Encode; then concat encoder and decoder, given encoder mask. + _Encoder(), # vec_e mask_e tok_e vec_d tok_d + _enc_dec_concat(), + # Run decoder blocks. + _ReversibleSerialForget( + decoder_blocks, d_model2, n_layers_forget, forget_dense + ), # vec_ed1 vec_ed2 tok_e tok_d + _XYAvg(), # vec_ed tok_e tok_d + tl.LayerNorm(), # vec_ed tok_e tok_d + # Separate out the encoder part from the concatenated vector, + # then compute loss. + tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d + t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d + _Loss(), # vec_d tok_d + ) def _InsertAxes12(): - """Returns a layer that inserts two internal size-1 axes into an array.""" - return tl.Fn('InsertAxes12', - lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1]))) + """Returns a layer that inserts two internal size-1 axes into an array.""" + return tl.Fn( + "InsertAxes12", lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])) + ) def _RemoveAxes12(): - """Returns a layer that removes two internal size-1 axes from an array.""" - return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2))) + """Returns a layer that removes two internal size-1 axes from an array.""" + return tl.Fn("RemoveAxes12", lambda x: jnp.squeeze(x, (1, 2))) def _AsTokenIDs(): - """Returns a layer that makes mask values look like token ID ints.""" - return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32)) + """Returns a layer that makes mask values look like token ID ints.""" + return tl.Fn("AsTokenIDs", lambda x: x.astype(jnp.int32)) def _XYAvg(): - """Returns a layer that computes the element-wise average of two arrays.""" - return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0) + """Returns a layer that computes the element-wise average of two arrays.""" + return tl.Fn("XYAvg", lambda x, y: (x + y) / 2.0) def _ReversibleSerialForget(layers, d_model, 
n_layers, forget_dense=True): - """ReversibleSerial but with a forgetting block every n_layers.""" - if not n_layers or len(layers) <= n_layers + 1: - return tl.ReversibleSerial(layers) - layers1, layers2 = layers[:n_layers], layers[n_layers:] - - if forget_dense: - forgetting_layer = tl.Serial( - _XYAvg(), - tl.Dense(d_model), - tl.Dup(), - ) - else: - forgetting_layer = tl.Select([0, 1]) + """ReversibleSerial but with a forgetting block every n_layers.""" + if not n_layers or len(layers) <= n_layers + 1: + return tl.ReversibleSerial(layers) + layers1, layers2 = layers[:n_layers], layers[n_layers:] + + if forget_dense: + forgetting_layer = tl.Serial( + _XYAvg(), + tl.Dense(d_model), + tl.Dup(), + ) + else: + forgetting_layer = tl.Select([0, 1]) - return tl.Serial( - tl.ReversibleSerial(layers1), - forgetting_layer, - _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense) - ) + return tl.Serial( + tl.ReversibleSerial(layers1), + forgetting_layer, + _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense), + ) def _ConvertToNaNsOnAnyZero(): - def _convert_to_nans(x, y): - # if all values in y are non-zeros, return x; otherwise return 0s - return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y - return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2) + def _convert_to_nans(x, y): + # if all values in y are non-zeros, return x; otherwise return 0s + return jnp.where(jnp.all(y, keepdims=False), x, x / 0.0), y + + return tl.Fn("ConvertToNaNsOnAnyZero", _convert_to_nans, n_out=2) class _PortalInput(tl.Layer): - """Portal input for monkey-patching of mask in predict mode.""" + """Portal input for monkey-patching of mask in predict mode.""" - def __init__(self): - super().__init__(name='_PortalInput', n_out=1, n_in=1) - self._portal_output = _PortalOutput(self) + def __init__(self): + super().__init__(name="_PortalInput", n_out=1, n_in=1) + self._portal_output = _PortalOutput(self) - def forward(self, x): - if isinstance(x, (list, tuple)): - 
x = x[0] - self.state = (x,) - return x + def forward(self, x): + if isinstance(x, (list, tuple)): + x = x[0] + self.state = (x,) + return x - def init_weights_and_state(self, input_signature): - """Initializes this layer's weights.""" - if isinstance(input_signature, (list, tuple)): - input_signature = input_signature[0] - self.state = (jnp.zeros(input_signature.shape),) + def init_weights_and_state(self, input_signature): + """Initializes this layer's weights.""" + if isinstance(input_signature, (list, tuple)): + input_signature = input_signature[0] + self.state = (jnp.zeros(input_signature.shape),) - def get_value(self): - return self.state[0] + def get_value(self): + return self.state[0] - def get_layer(self): - return self._portal_output + def get_layer(self): + return self._portal_output class _PortalOutput(tl.Layer): - """Portal input for monkey-patching of mask in predict mode.""" + """Portal input for monkey-patching of mask in predict mode.""" - def __init__(self, portal_input): - super().__init__(name='_PortalOutput', n_out=1, n_in=0) - self._portal_input = portal_input + def __init__(self, portal_input): + super().__init__(name="_PortalOutput", n_out=1, n_in=0) + self._portal_input = portal_input - def forward(self, x): - return self._portal_input.get_value() + def forward(self, x): + return self._portal_input.get_value() - def get_value(self): - return self._portal_input.get_value() + def get_value(self): + return self._portal_input.get_value() diff --git a/trax/models/research/terraformer_e2e_test.py b/trax/models/research/terraformer_e2e_test.py deleted file mode 100644 index 2dfd36742..000000000 --- a/trax/models/research/terraformer_e2e_test.py +++ /dev/null @@ -1,99 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""End to end test for Reformer.""" - -import os - -from absl.testing import absltest -import gin - -from trax import test_utils -from trax.models.research import terraformer # pylint: disable=unused-import -from trax.supervised import trainer_lib -from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/') - - -class TerraformerE2ETest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - gin.add_config_file_search_path(_CONFIG_DIR) - test_utils.ensure_flag('test_tmpdir') - - def test_terraformer_wmt_ende(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_wmt_ende.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', - ([512], [batch_size_per_device, batch_size_per_device])) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_terraformer_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_copy.gin') - - 
gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', ([64], [1, 1])) # batch size 1. - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_terraformer_purelsh_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_purelsh_copy.gin') - - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', ([64], [1, 1])) # batch size 1. - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer_oom_test.py b/trax/models/research/terraformer_oom_test.py deleted file mode 100644 index 2d68819fe..000000000 --- a/trax/models/research/terraformer_oom_test.py +++ /dev/null @@ -1,129 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for OOM for Terraformer .""" - -import functools -import operator - -from absl.testing import absltest -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.models.research import terraformer - - -class TerraformerOOMTest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def test_terraformer_one_step(self): - d_model = 1024 - vocab_size = 14041 - max_len = 16384 - pos_axial = (128, 128) # should multiply to max_len - pos_d_axial_embs = (512, 512) # sum to d model - - assert operator.mul(*pos_axial) == max_len - assert sum(pos_d_axial_embs) == d_model - - d_ff = 4096 - n_heads = 8 - d_attn = d_model // n_heads - - n_buckets = 128 - encoder_chunk_len = (2 * max_len) // n_buckets # 256 - decoder_chunk_len = 2 * encoder_chunk_len # 512 - encoder_n_chunks_after = 1 # since its not causal. 
- - lsh_self_attention = functools.partial(self._lsh_self_attention_fn(), - n_buckets=n_buckets) - - encoder_lsh_self_attention = functools.partial( - lsh_self_attention, n_chunks_after=encoder_n_chunks_after, - chunk_len=encoder_chunk_len) - - decoder_lsh_self_attention = functools.partial( - lsh_self_attention, n_chunks_after=0, - chunk_len=decoder_chunk_len) - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=d_model, - d_ff=d_ff, - d_attention_key=d_attn, - d_attention_value=d_attn, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=n_heads, - dropout=0.05, - max_len=max_len, - encoder_attention_type=encoder_lsh_self_attention, - encoder_decoder_attention_type=decoder_lsh_self_attention, - pos_axial_shape=pos_axial, - pos_d_axial_embs=pos_d_axial_embs, - ff_activation=tl.Relu, - ff_use_sru=0, - mode='train', - ) - - def random_sentence(): - return np.random.randint(low=1, high=vocab_size - 1, size=(1, max_len), - dtype=np.int32) - - x = [random_sentence(), random_sentence()] - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) - # This returns [logits, decoder tokens] - logits = logits_and_dec_toks[0] - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer_test.py b/trax/models/research/terraformer_test.py deleted file mode 100644 index b5344a2f5..000000000 --- 
a/trax/models/research/terraformer_test.py +++ /dev/null @@ -1,273 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Terraformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.layers import test_utils -from trax.models.research import terraformer - - -BACKENDS = [fastmath.Backend.JAX] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -class TerraformerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=64, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code - ) - - @parameterized.named_parameters( - [('_%s_efficient' % short_name(backend), backend, tl.SelfAttention, False) - for backend in BACKENDS] + - [('_%s_causal' % short_name(backend), backend, tl.CausalAttention, 
False) - for backend in BACKENDS] + - # NOTE: tl.SelfAttention is not currently working for this case. - [('_%s_preembed' % short_name(backend), backend, tl.CausalAttention, True) - for backend in BACKENDS]) - def test_terraformer_quick(self, backend, encoder_attention_type, preembed): - with fastmath.use_backend(backend): - vocab_size = 2 - input_vocab_size = None if preembed else vocab_size - output_vocab_size = vocab_size if preembed else None - max_len = 2 - - model = terraformer.ConfigurableTerraformer( - input_vocab_size, - d_model=4, - d_ff=4, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=2, - dropout=0.05, - max_len=max_len, - pos_type=None, - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=2, - mode='train', - output_vocab_size=output_vocab_size, - encoder_attention_type=encoder_attention_type, - ) - - if preembed: - model_inputs = [np.ones((1, max_len, 3)).astype(np.float32), - np.ones((1, max_len)).astype(bool)] - else: - model_inputs = [np.ones((1, max_len)).astype(np.int32)] - x = model_inputs + [np.ones((1, max_len)).astype(np.int32)] - model.init(shapes.signature(x)) - - logits, dec_toks = model(x) - del dec_toks - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - def test_terraformer_deterministic_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - vocab_size = 16 - d_model = 4 - batch_size = 2 - length = 5 - - model_fn = functools.partial( - terraformer.ConfigurableTerraformer, - vocab_size, - d_model=d_model, - d_ff=16, - n_encoder_layers=0, - n_decoder_layers=1, - n_heads=2, - dropout=0.0, - max_len=length*2, - pos_type=None, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - ) - - inp = np.random.randint(vocab_size, size=(batch_size, length)) - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_is_deterministic((inp, out), model_fn) - - def test_terraformer_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - 
vocab_size = 16 - d_model = 8 - batch_size = 1 - length = 5 - - model_fn = functools.partial( - terraformer.ConfigurableTerraformer, - vocab_size, - d_model=d_model, - d_ff=16, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=2, - ff_use_sru=(1, 8), # ? is SRU working? - dropout=0.0, - max_len=(length+7)*2, - pos_type=None, - reversible_encoder=True, - n_decoder_attention_layers=1, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - ) - - # Token id of 0 indicates padding; and predict mode doesn't support it. - inp = np.random.randint(1, vocab_size, size=(batch_size, length)) - inp[:, -2:] = 0 - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict( - (inp, out), model_fn, seq_axis=1, seq_tensor=-1, init_tokens=1) - - def test_terraformer_doubling(self): - vocab_size = 2 - max_len = 2 - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=8, - d_ff=16, - n_encoder_layers=1, - n_decoder_layers=6, - n_heads=2, - dropout=0.05, - max_len=max_len, - pos_type=None, - half_before_layer=2, - double_after_layer=2, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - mode='train', - ) - - x = [np.ones((1, max_len)).astype(np.int32), - np.ones((1, max_len)).astype(np.int32)] - model.init(shapes.signature(x)) - - logits, dec_toks = model(x) - del dec_toks - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - def test_terraformer_one_step(self): - vocab_size = 32 - max_len = 256 - pos_axial = 16 - assert pos_axial * pos_axial == max_len - - chunk_len = 32 - - # Since 2 * chunk_len * n_buckets should be max_len. 
- n_buckets = max_len // (2 * chunk_len) - - lsh_self_attention = functools.partial(self._lsh_self_attention_fn(), - chunk_len=chunk_len, - n_buckets=n_buckets) - - timebin_self_attention = self._timebin_self_attention_fn() - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=32, - d_ff=64, - d_attention_key=64, - d_attention_value=64, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2, - dropout=0.05, - max_len=max_len, - encoder_attention_type=lsh_self_attention, - encoder_decoder_attention_type=[timebin_self_attention, - lsh_self_attention], - pos_axial_shape=(pos_axial, pos_axial), - pos_d_axial_embs=(64, 192), - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=64, - ff_sparsity=8, - mode='train', - ) - - x = [np.ones((1, max_len)).astype(np.int32), - np.ones((1, max_len)).astype(np.int32)] - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) - # This returns [logits, decoder tokens] - logits = logits_and_dec_toks[0] - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/transformer2_test.py b/trax/models/research/transformer2_test.py deleted file mode 100644 index 18a10c5d3..000000000 --- a/trax/models/research/transformer2_test.py +++ /dev/null @@ -1,377 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import transformer2 - - -class Transformer2Test(absltest.TestCase): - - def test_concat_with_padding(self): - vec_e = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - # vec_e[:,:,0] != 0 - mask_e = np.array([[True, True, False, False, False, False], - [True, True, True, False, False, False]]) - - vec_d = np.array( - [[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - layer = transformer2.ConcatWithPadding(mode='train') - inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d - layer.init(shapes.signature(inp)) - y, _, _ = layer(inp) - - np.testing.assert_equal( - y, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 
0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - ) - - def test_concat_with_padding_predict(self): - vec_e = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - # vec_e[:,:,0] != 0 - mask_e = np.array([[True, True, False, False, False, False], - [True, True, True, False, False, False]]) - - vec_d = np.array( - [[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - layer = transformer2.ConcatWithPadding(mode='predict') - inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d - _, _ = layer.init(shapes.signature(inp)) - y, _, _ = layer(inp) - - np.testing.assert_equal( - y, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - 
) - ) - - # On subsequent runs however, we should get vec_d only. - for _ in range(2): - y, _, _ = layer(inp) - np.testing.assert_equal(y, vec_d) - - def test_concat_with_padding2(self): - vec_e = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - # vec_e[:,:,0] != 0 - mask_e = np.array([[True, True, False, False, False, False], - [True, True, True, False, False, False]]) - - vec_d = np.array( - [[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - layer = transformer2.ConcatWithPadding2(mode='train') - inp = (vec_e, vec_e, vec_d, mask_e, vec_e, vec_d) - layer.init(shapes.signature(inp)) - y1, y2, _, _ = layer(inp) - - np.testing.assert_equal( - y1, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - ) - np.testing.assert_equal( - y2, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 
0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - ) - - def test_strip_from_concatenate_with_padding(self): - enc_dec = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) - tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) - - layer = transformer2.StripFromConcatenateWithPadding( - mode='train') - inp = (enc_dec, tok_e, tok_d) - _, _ = layer.init(shapes.signature(inp)) - y = layer(inp) - - np.testing.assert_equal( - y, - np.array([[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0]]])) - - def test_strip_from_concatenate_with_padding_predict(self): - enc_dec = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 
4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) - tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) - - layer = transformer2.StripFromConcatenateWithPadding( - mode='predict') - inp = (enc_dec, tok_e, tok_d) - _, _ = layer.init(shapes.signature(inp)) - y = layer(inp) - - np.testing.assert_equal( - y, - np.array([[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0]]])) - - # On subsequent runs however, we should get enc_dec only. - for _ in range(2): - y = layer(inp) - np.testing.assert_equal(y, enc_dec) - - def test_transformer_noencdec_forward_shape(self): - input_vocab_size = 16 - output_vocab_size = 16 - - model = transformer2.Transformer2( - input_vocab_size, output_vocab_size, d_model=32, d_ff=64, - n_encoder_layers=2, n_decoder_layers=2, n_heads=2) - - enc_toks = np.array( - [[6, 2, 0, 0, 0, 0], - [6, 3, 7, 0, 0, 0]]) - dec_toks = np.array( - [[4, 2, 0, 0], - [8, 5, 0, 0]]) - - xs = [enc_toks, dec_toks] - _, _ = model.init(shapes.signature(xs)) - - # decoder output, decoder mask - ys = model(xs) - - # (B, L2, H) - self.assertEqual(ys[0].shape, - (dec_toks.shape[0], dec_toks.shape[1], output_vocab_size)) - - self.assertEqual(ys[1].shape, dec_toks.shape) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/rl_test.py b/trax/models/rl_test.py deleted file mode 100644 index ac0e8b4ce..000000000 --- a/trax/models/rl_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RL.""" - -from unittest import mock -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models import rl - - -class RLTest(absltest.TestCase): - - def test_policy_forward_shape(self): - mock_dist = mock.MagicMock() - mock_dist.n_inputs = 4 - model = rl.Policy(policy_distribution=mock_dist) - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (2, 4)) - - def test_value_forward_shape(self): - model = rl.Value() - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (2, 1)) - - def test_policy_and_value_forward_shape(self): - mock_dist = mock.MagicMock() - mock_dist.n_inputs = 4 - model = rl.PolicyAndValue(policy_distribution=mock_dist) - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - ys = model(x) - self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/rnn.py b/trax/models/rnn.py index f29414a8e..e3f841bc8 100644 --- a/trax/models/rnn.py +++ b/trax/models/rnn.py @@ -19,208 +19,217 @@ from trax.fastmath import numpy as jnp -def RNNLM(vocab_size, - d_model=512, - n_layers=2, - rnn_cell=tl.LSTMCell, - rnn_cell_d_state_multiplier=2, - dropout=0.1, - mode='train'): - """Returns an RNN language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). 
The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Embedding depth throughout the model. - n_layers: Number of RNN layers. - rnn_cell: Type of RNN cell; must be a subclass of `Layer`. - rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell - state. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout. - mode: If `'predict'`, use fast inference; if `'train'` apply dropout. - - Returns: - An RNN language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - - if n_layers != 2: # TODO(jonni): Remove n_layers arg, if it can't vary? - raise ValueError(f'Number of layers must be set to 2; instead got' - f' {n_layers}.') - - def MultiRNNCell(): - """Multi-layer RNN cell.""" +def RNNLM( + vocab_size, + d_model=512, + n_layers=2, + rnn_cell=tl.LSTMCell, + rnn_cell_d_state_multiplier=2, + dropout=0.1, + mode="train", +): + """Returns an RNN language model. + + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). 
+ + Args: + vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Embedding depth throughout the model. + n_layers: Number of RNN layers. + rnn_cell: Type of RNN cell; must be a subclass of `Layer`. + rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell + state. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout. + mode: If `'predict'`, use fast inference; if `'train'` apply dropout. + + Returns: + An RNN language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + + if n_layers != 2: # TODO(jonni): Remove n_layers arg, if it can't vary? + raise ValueError( + f"Number of layers must be set to 2; instead got" f" {n_layers}." + ) + + def MultiRNNCell(): + """Multi-layer RNN cell.""" + return tl.Serial( + tl.Parallel(tl.Select([0]), tl.Split(n_items=n_layers)), + tl.SerialWithSideOutputs( + [rnn_cell(n_units=d_model) for _ in range(n_layers)] + ), + tl.Parallel(tl.Select([0]), tl.Concatenate(n_items=n_layers)), + ) + + zero_state = tl.MakeZeroState( # pylint: disable=no-value-for-parameter + depth_multiplier=n_layers * rnn_cell_d_state_multiplier + ) + return tl.Serial( - tl.Parallel([], tl.Split(n_items=n_layers)), - tl.SerialWithSideOutputs( - [rnn_cell(n_units=d_model) for _ in range(n_layers)]), - tl.Parallel([], tl.Concatenate(n_items=n_layers)) + tl.ShiftRight(mode=mode), + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.Branch(tl.Select([0]), zero_state), + tl.Scan(MultiRNNCell(), axis=1, mode=mode), + tl.Select([0], n_in=2), # Drop RNN state. 
+ tl.Dense(vocab_size), ) - zero_state = tl.MakeZeroState( # pylint: disable=no-value-for-parameter - depth_multiplier=n_layers * rnn_cell_d_state_multiplier - ) - - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.Branch([], zero_state), - tl.Scan(MultiRNNCell(), axis=1, mode=mode), - tl.Select([0], n_in=2), # Drop RNN state. - tl.Dense(vocab_size), - ) - - -def GRULM(vocab_size=256, - d_model=512, - n_layers=1, - mode='train'): - """Returns a GRU (gated recurrent unit) language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Embedding depth throughout the model. - n_layers: Number of GRU layers. - mode: If `'predict'`, use fast inference (and omit the right shift). - - Returns: - A GRU language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - [tl.GRU(d_model, mode=mode) for _ in range(n_layers)], - tl.Dense(vocab_size), - ) + +def GRULM(vocab_size=256, d_model=512, n_layers=1, mode="train"): + """Returns a GRU (gated recurrent unit) language model. 
+ + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Embedding depth throughout the model. + n_layers: Number of GRU layers. + mode: If `'predict'`, use fast inference (and omit the right shift). + + Returns: + A GRU language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + return tl.Serial( + tl.ShiftRight(mode=mode), + tl.Embedding(vocab_size, d_model), + [tl.GRU(d_model, mode=mode) for _ in range(n_layers)], + tl.Dense(vocab_size), + ) # TODO(jonni): Decide names (here and Transformer): input/source, output/target # TODO(jonni): Align with Transfomer: (attention-)dropout, n-(attention-)heads -def LSTMSeq2SeqAttn(input_vocab_size=256, - target_vocab_size=256, - d_model=512, - n_encoder_layers=2, - n_decoder_layers=2, - n_attention_heads=1, - attention_dropout=0.0, - mode='train'): - """Returns an LSTM sequence-to-sequence model with attention. - - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(input_vocab_size)`, and `0` - values mark padding positions. 
- - - target: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(output_vocab_size)`, and `0` - values mark padding positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - An example use would be to translate (tokenized) sentences from English to - German. - - The model works as follows: - - * Input encoder runs on the input tokens and creates activations that - are used as both keys and values in attention. - * Pre-attention decoder runs on the targets and creates - activations that are used as queries in attention. - * Attention runs on the queries, keys and values masking out input padding. - * Decoder runs on the result, followed by a cross-entropy loss. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - target_vocab_size: Target vocabulary size. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - n_encoder_layers: Number of LSTM layers in the encoder. - n_decoder_layers: Number of LSTM layers in the decoder after attention. - n_attention_heads: Number of attention heads. - attention_dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout within an attention block. - mode: If `'predict'`, use fast inference. If `'train'`, each attention block - will include dropout; else, it will pass all values through unaltered. - - Returns: - An LSTM sequence-to-sequence model as a layer that maps from a - source-target tokenized text pair to activations over a vocab set. 
- """ - input_encoder = tl.Serial( - tl.Embedding(input_vocab_size, d_model), - [tl.LSTM(d_model) for _ in range(n_encoder_layers)], - ) - - pre_attention_decoder = tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(target_vocab_size, d_model), - tl.LSTM(d_model, mode=mode), - ) - - def PrepareAttentionInputs(): - """Layer that prepares queries, keys, values and mask for attention.""" - def F(encoder_activations, decoder_activations, input_tokens): - keys = values = encoder_activations - queries = decoder_activations - # Mask is 1 where inputs are not padding (0) and 0 where they are padding. - mask = (input_tokens != 0) - # We need to add axes to the mask for attention heads and decoder length. - mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1])) - # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len]. - mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1)) - mask = mask.astype(jnp.float32) - return queries, keys, values, mask - return tl.Fn('PrepareAttentionInputs', F, n_out=4) - - return tl.Serial( # in-toks, target-toks - tl.Select([0, 1, 0, 1]), # in-toks, target-toks, in-toks, target-toks - tl.Parallel(input_encoder, pre_attention_decoder), - PrepareAttentionInputs(), # q, k, v, mask, target-toks - tl.Residual( - tl.AttentionQKV(d_model, n_heads=n_attention_heads, - dropout=attention_dropout, mode=mode, - cache_KV_in_predict=True) - ), # decoder-vecs, mask, target-toks - tl.Select([0, 2]), # decoder-vecs, target-toks - [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)], - tl.Dense(target_vocab_size), - tl.LogSoftmax() - ) +def LSTMSeq2SeqAttn( + input_vocab_size=256, + target_vocab_size=256, + d_model=512, + n_encoder_layers=2, + n_decoder_layers=2, + n_attention_heads=1, + attention_dropout=0.0, + mode="train", +): + """Returns an LSTM sequence-to-sequence model with attention. 
+ + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(input_vocab_size)`, and `0` + values mark padding positions. + + - target: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(output_vocab_size)`, and `0` + values mark padding positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + An example use would be to translate (tokenized) sentences from English to + German. + + The model works as follows: + + * Input encoder runs on the input tokens and creates activations that + are used as both keys and values in attention. + * Pre-attention decoder runs on the targets and creates + activations that are used as queries in attention. + * Attention runs on the queries, keys and values masking out input padding. + * Decoder runs on the result, followed by a cross-entropy loss. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + target_vocab_size: Target vocabulary size. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + n_encoder_layers: Number of LSTM layers in the encoder. + n_decoder_layers: Number of LSTM layers in the decoder after attention. + n_attention_heads: Number of attention heads. 
+ attention_dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout within an attention block. + mode: If `'predict'`, use fast inference. If `'train'`, each attention block + will include dropout; else, it will pass all values through unaltered. + + Returns: + An LSTM sequence-to-sequence model as a layer that maps from a + source-target tokenized text pair to activations over a vocab set. + """ + input_encoder = tl.Serial( + tl.Embedding(input_vocab_size, d_model), + [tl.LSTM(d_model) for _ in range(n_encoder_layers)], + ) + + pre_attention_decoder = tl.Serial( + tl.ShiftRight(mode=mode), + tl.Embedding(target_vocab_size, d_model), + tl.LSTM(d_model, mode=mode), + ) + + def PrepareAttentionInputs(): + """Layer that prepares queries, keys, values and mask for attention.""" + + def F(encoder_activations, decoder_activations, input_tokens): + keys = values = encoder_activations + queries = decoder_activations + # Mask is 1 where inputs are not padding (0) and 0 where they are padding. + mask = input_tokens != 0 + # We need to add axes to the mask for attention heads and decoder length. + mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1])) + # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len]. 
+ mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1)) + mask = mask.astype(jnp.float32) + return queries, keys, values, mask + + return tl.Fn("PrepareAttentionInputs", F, n_out=4) + + return tl.Serial( # in-toks, target-toks + tl.Select([0, 1, 0, 1]), # in-toks, target-toks, in-toks, target-toks + tl.Parallel(input_encoder, pre_attention_decoder), + PrepareAttentionInputs(), # q, k, v, mask, target-toks + tl.Residual( + tl.AttentionQKV( + d_model, + n_heads=n_attention_heads, + dropout=attention_dropout, + mode=mode, + cache_KV_in_predict=True, + ) + ), # decoder-vecs, mask, target-toks + tl.Select([0, 2]), # decoder-vecs, target-toks + [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)], + tl.Dense(target_vocab_size), + tl.LogSoftmax(), + ) diff --git a/trax/models/rnn_test.py b/trax/models/rnn_test.py deleted file mode 100644 index 6de04bea2..000000000 --- a/trax/models/rnn_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for RNNs.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.models import rnn - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class RNNTest(parameterized.TestCase): - - def test_rnnlm_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.RNNLM(vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 28, 20)) - - def test_grulm_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.GRULM(vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 28, 20)) - - def test_lstmseq2seqattn_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.LSTMSeq2SeqAttn( - input_vocab_size=20, target_vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init([shapes.signature(x), shapes.signature(x)]) - ys = model([x, x]) - self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/transformer.py b/trax/models/transformer.py index ed9917baa..f64fd35ef 100644 --- a/trax/models/transformer.py +++ b/trax/models/transformer.py @@ -23,594 +23,606 @@ # Defaults used across Transformer variants. 
-MODE = 'train' +MODE = "train" D_MODEL = 512 D_FF = 2048 N_LAYERS = 6 N_HEADS = 8 MAX_SEQUENCE_LENGTH = 2048 -DROPOUT_RATE = .1 +DROPOUT_RATE = 0.1 DROPOUT_SHARED_AXES = None FF_ACTIVATION_TYPE = tl.Relu -def TransformerEncoder(vocab_size, - n_classes=10, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer encoder suitable for N-way classification. - - This model maps tokenized text to N-way (``n_classes``) activations: - - - input: Array representing a batch of text strings via token IDs plus - padding markers; shape is (batch_size, sequence_length), where - sequence_length <= ``max_len``. Array elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. - - - output: Array representing a batch of raw (non-normalized) activations - over ``n_classes`` categories; shape is (batch_size, ``n_classes``). - - Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - n_classes: Last/innermost dimension of output arrays, suitable for N-way - classification. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. 
The same rate is also - used for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. - - Returns: - A Transformer model that maps strings (conveyed by token IDs) to - raw (non-normalized) activations over a range of output classes. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.Branch([], tl.PaddingMask()), # Creates masks from copy of the tokens. - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_EncBlock() for _ in range(n_layers)], - tl.Select([0], n_in=2), # Drops the masks. - tl.LayerNorm(), - tl.Mean(axis=1), - tl.Dense(n_classes), - ) - - -def TransformerDecoder(vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer decoder. - - This model maps sequential inputs to sequential outputs: - - - input if ``vocab_size`` is specified: array representing a batch - of text strings via token IDs plus padding markers; shape is - (batch_size, sequence_length). The tensor elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. 
- - - input if ``vocab_size`` is ``None``: 3-D array representing a batch of - sequences of activation vectors; shape is (batch_size, sequence_length, - ``d_model``). - - - output: 3-D array with shape (batch_size, sequence_length, ``d_model``). - - The model uses causal attention and does *not* shift the input to the right. - Thus, the output for position `t` is based on inputs up to and including - position `t`. - - Args: - vocab_size: If specified, gives the input vocabulary size -- each element - of the input tensor should be an integer in ``range(vocab_size)``. - If ``None``, indicates that the model expects as input sequences of - floating point vectors, each with ``d_model`` components. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. 
- - Returns: - If ``vocab_size`` is defined: a Transformer model that maps strings - (conveyed by token IDs) to sequences of activation vectors. - - If ``vocab_size`` is ``None``: a Transformer model that maps sequences of - activation vectors to sequences of activation vectors. - """ - def _EmbeddingOrDense(): - return (tl.Embedding(vocab_size, d_model) if vocab_size is not None - else tl.Dense(d_model)) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - _EmbeddingOrDense(), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - ) - - -def TransformerLM(vocab_size, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: Array representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). Array - elements are integers in ``range(vocab_size)``, and 0 values mark padding - positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``vocab_size``, suitable for decoding into a batch of token strings; - shape is (batch_size, sequence_length, ``vocab_size``). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. 
- d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each decoder - block will include dropout; else, it will pass all values through - unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. - - Returns: - A Transformer language model that maps strings (represented as token ID - sequences) to sequences of raw (non-normalized) activation vectors; each - vector in the sequence can be mapped (e.g., by `argmax`) to a token ID. 
- """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -def Transformer(input_vocab_size, - output_vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_encoder_layers=N_LAYERS, - n_decoder_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a full Transformer model. - - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(input_vocab_size)``, and 0 values mark padding positions. - - - target: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``output_vocab_size``, suitable for decoding into a batch of token - strings; shape is (batch_size, sequence_length, ``vocab_size``). - - An example use would be to translate (tokenized) sentences from English to - German. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in ``range(vocab_size)``. 
These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if ``None``, then input and target integers (token IDs) are assumed to - come from the same vocabulary. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each - encoder/decoder block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each - encoder/decoder block; must be an activation-type subclass of - :py:class:`Layer`. - - Returns: - A Transformer model as a layer that maps from a source-target tokenized - text pair to activations over a vocab set. - """ - # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise. - encoder_mode = 'eval' if mode == 'predict' else mode - - # Share embedding weights if no separate output vocab size. 
- in_embedder = tl.Embedding(input_vocab_size, d_model) - if output_vocab_size is None: - out_embedder = in_embedder - output_vocab_size = input_vocab_size - else: - out_embedder = tl.Embedding(output_vocab_size, d_model) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - def _Encoder(): - encoder = tl.Serial( - in_embedder, +def TransformerEncoder( + vocab_size, + n_classes=10, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer encoder suitable for N-way classification. + + This model maps tokenized text to N-way (``n_classes``) activations: + + - input: Array representing a batch of text strings via token IDs plus + padding markers; shape is (batch_size, sequence_length), where + sequence_length <= ``max_len``. Array elements are integers in + ``range(vocab_size)``, and 0 values mark padding positions. + + - output: Array representing a batch of raw (non-normalized) activations + over ``n_classes`` categories; shape is (batch_size, ``n_classes``). + + Args: + vocab_size: Input vocabulary size -- each element of the input array + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + n_classes: Last/innermost dimension of output arrays, suitable for N-way + classification. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of encoder blocks. 
Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder blocks. The same rate is also + used for attention dropout in encoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each encoder block will include dropout; else, it + will pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + A Transformer model that maps strings (conveyed by token IDs) to + raw (non-normalized) activations over a range of output classes. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _EncBlock(): + return _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + tl.Branch([], tl.PaddingMask()), # Creates masks from copy of the tokens. + tl.Embedding(vocab_size, d_model), _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=encoder_mode), - [_EncBlock() for _ in range(n_encoder_layers)], + tl.PositionalEncoding(max_len=max_len), + [_EncBlock() for _ in range(n_layers)], + tl.Select([0], n_in=2), # Drops the masks. 
tl.LayerNorm(), + tl.Mean(axis=1), + tl.Dense(n_classes), ) - return tl.Cache(encoder) if mode == 'predict' else encoder - - def _EncDecBlock(): - return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation) - - # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e - # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d - return tl.Serial( - tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss. - - # Encode. - tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d - _Encoder(), - - # Decode. - tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e ..... - tl.ShiftRight(mode=mode), - out_embedder, - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... - [_EncDecBlock() for _ in range(n_decoder_layers)], - tl.LayerNorm(), - tl.Select([0], n_in=3), # Drops masks and encoding vectors. - - # Map vectors to match output vocab size. - tl.Dense(output_vocab_size), - ) - - -def _EncoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer encoder block. - - The input to the block is a pair (activations, mask) where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. The block's outputs are the same type/shape as its inputs, - so that multiple blocks can be chained together. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. 
The same rate is also used - for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) encoder block. - """ - def _Attention(): - return tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ - tl.Residual( - tl.LayerNorm(), - _Attention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _DecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer decoder block. - - The input to the block is a pair (activations, mask) where the mask encodes - causal connections, preventing attention to future positions in the sequence. - The block's outputs are the same type/shape as its inputs, so that multiple - blocks can be chained together. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. 
The same rate is also used - for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) decoder block. - """ - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, - mode=mode), - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _EncoderDecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers implementing a Transformer encoder-decoder block. - - The block input is a triple (decoder_activations, mask, encoder_activations) - where the mask was created from the original input token IDs to prevent - attending to padding positions for that input. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. 
The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) encoder-decoder - block. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _AttentionQKV(): - return tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, - mode=mode, cache_KV_in_predict=True) - - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ # vec_d masks vec_e - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e - _AttentionQKV(), # vec_d masks vec_e - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): - """Returns a list of layers that implements a feedforward block. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within a block. 
- dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that maps vectors to vectors. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Dense(d_ff), - activation(), - _Dropout(), - tl.Dense(d_model), - ] + + +def TransformerDecoder( + vocab_size=None, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer decoder. + + This model maps sequential inputs to sequential outputs: + + - input if ``vocab_size`` is specified: array representing a batch + of text strings via token IDs plus padding markers; shape is + (batch_size, sequence_length). The tensor elements are integers in + ``range(vocab_size)``, and 0 values mark padding positions. + + - input if ``vocab_size`` is ``None``: 3-D array representing a batch of + sequences of activation vectors; shape is (batch_size, sequence_length, + ``d_model``). + + - output: 3-D array with shape (batch_size, sequence_length, ``d_model``). + + The model uses causal attention and does *not* shift the input to the right. + Thus, the output for position `t` is based on inputs up to and including + position `t`. + + Args: + vocab_size: If specified, gives the input vocabulary size -- each element + of the input tensor should be an integer in ``range(vocab_size)``. 
+ If ``None``, indicates that the model expects as input sequences of + floating point vectors, each with ``d_model`` components. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of decoder blocks. Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also + used for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each encoder block will include dropout; else, it + will pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + If ``vocab_size`` is defined: a Transformer model that maps strings + (conveyed by token IDs) to sequences of activation vectors. + + If ``vocab_size`` is ``None``: a Transformer model that maps sequences of + activation vectors to sequences of activation vectors. 
+ """ + + def _EmbeddingOrDense(): + return ( + tl.Embedding(vocab_size, d_model) + if vocab_size is not None + else tl.Dense(d_model) + ) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _DecBlock(): + return _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + _EmbeddingOrDense(), + _Dropout(), + tl.PositionalEncoding(max_len=max_len), + [_DecBlock() for _ in range(n_layers)], + tl.LayerNorm(), + ) + + +def TransformerLM( + vocab_size, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer language model. + + This model performs autoregressive language modeling: + + - input: Array representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). Array + elements are integers in ``range(vocab_size)``, and 0 values mark padding + positions. + + - output: 3-D array of raw activations with last/innermost dimension of + ``vocab_size``, suitable for decoding into a batch of token strings; + shape is (batch_size, sequence_length, ``vocab_size``). + + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input array + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of decoder blocks. 
Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also + used for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'predict'``, use fast inference. If ``'train'``, each decoder + block will include dropout; else, it will pass all values through + unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + A Transformer language model that maps strings (represented as token ID + sequences) to sequences of raw (non-normalized) activation vectors; each + vector in the sequence can be mapped (e.g., by `argmax`) to a token ID. 
+ """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _DecBlock(): + return _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + tl.ShiftRight(mode=mode), + tl.Embedding(vocab_size, d_model), + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=mode), + [_DecBlock() for _ in range(n_layers)], + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +def Transformer( + input_vocab_size, + output_vocab_size=None, + d_model=D_MODEL, + d_ff=D_FF, + n_encoder_layers=N_LAYERS, + n_decoder_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a full Transformer model. + + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: Array representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length), + where sequence_length <= ``max_len``. Array elements are integers in + ``range(input_vocab_size)``, and 0 values mark padding positions. + + - target: Array representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length), + where sequence_length <= ``max_len``. Array elements are integers in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 3-D array of raw activations with last/innermost dimension of + ``output_vocab_size``, suitable for decoding into a batch of token + strings; shape is (batch_size, sequence_length, ``vocab_size``). + + An example use would be to translate (tokenized) sentences from English to + German. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in ``range(vocab_size)``. 
These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if ``None``, then input and target integers (token IDs) are assumed to + come from the same vocabulary. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder/decoder blocks. The same rate is + also used for attention dropout in encoder/decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'predict'``, use fast inference. If ``'train'``, each + encoder/decoder block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each + encoder/decoder block; must be an activation-type subclass of + :py:class:`Layer`. + + Returns: + A Transformer model as a layer that maps from a source-target tokenized + text pair to activations over a vocab set. + """ + # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise. + encoder_mode = "eval" if mode == "predict" else mode + + # Share embedding weights if no separate output vocab size. 
+ in_embedder = tl.Embedding(input_vocab_size, d_model) + if output_vocab_size is None: + out_embedder = in_embedder + output_vocab_size = input_vocab_size + else: + out_embedder = tl.Embedding(output_vocab_size, d_model) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _EncBlock(): + return _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + def _Encoder(): + encoder = tl.Serial( + in_embedder, + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=encoder_mode), + [_EncBlock() for _ in range(n_encoder_layers)], + tl.LayerNorm(), + ) + return tl.Cache(encoder) if mode == "predict" else encoder + + def _EncDecBlock(): + return _EncoderDecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e + # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d + return tl.Serial( + tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss. + # Encode. + tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d + _Encoder(), + # Decode. + tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e ..... + tl.ShiftRight(mode=mode), + out_embedder, + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=mode), + tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... + [_EncDecBlock() for _ in range(n_decoder_layers)], + tl.LayerNorm(), + tl.Select([0], n_in=3), # Drops masks and encoding vectors. + # Map vectors to match output vocab size. + tl.Dense(output_vocab_size), + ) + + +def _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation +): + """Returns a list of layers that implements a Transformer encoder block. 
+ + The input to the block is a pair (activations, mask) where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. The block's outputs are the same type/shape as its inputs, + so that multiple blocks can be chained together. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder blocks. The same rate is also used + for attention dropout in encoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that act in series as a (repeatable) encoder block. 
+ """ + + def _Attention(): + return tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _FFBlock(): + return _FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + return [ + tl.Residual( + tl.LayerNorm(), + _Attention(), + _Dropout(), + ), + tl.Residual( + tl.LayerNorm(), + _FFBlock(), + _Dropout(), + ), + ] + + +def _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation +): + """Returns a list of layers that implements a Transformer decoder block. + + The input to the block is a pair (activations, mask) where the mask encodes + causal connections, preventing attention to future positions in the sequence. + The block's outputs are the same type/shape as its inputs, so that multiple + blocks can be chained together. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also used + for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that act in series as a (repeatable) decoder block. 
+ """ + + def _CausalAttention(): + return ( + tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), + ) + + def _FFBlock(): + return _FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + tl.Residual( + tl.LayerNorm(), + _CausalAttention(), + _Dropout(), + ), + tl.Residual( + tl.LayerNorm(), + _FFBlock(), + _Dropout(), + ), + ] + + +def _EncoderDecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation +): + """Returns a list of layers implementing a Transformer encoder-decoder block. + + The block input is a triple (decoder_activations, mask, encoder_activations) + where the mask was created from the original input token IDs to prevent + attending to padding positions for that input. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder/decoder blocks. The same rate is + also used for attention dropout in encoder/decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that act in series as a (repeatable) encoder-decoder + block. 
+ """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _AttentionQKV(): + return tl.AttentionQKV( + d_model, + n_heads=n_heads, + dropout=dropout, + mode=mode, + cache_KV_in_predict=True, + ) + + def _CausalAttention(): + return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) + + def _FFBlock(): + return _FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + return [ # vec_d masks vec_e + tl.Residual( + tl.LayerNorm(), + _CausalAttention(), + _Dropout(), + ), + tl.Residual( + tl.LayerNorm(), + tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e + _AttentionQKV(), # vec_d masks vec_e + _Dropout(), + ), + tl.Residual( + tl.LayerNorm(), + _FFBlock(), + _Dropout(), + ), + ] + + +def _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, activation): + """Returns a list of layers that implements a feedforward block. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that maps vectors to vectors. 
+ """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + tl.Dense(d_ff), + activation(), + _Dropout(), + tl.Dense(d_model), + ] diff --git a/trax/models/transformer_test.py b/trax/models/transformer_test.py deleted file mode 100644 index 017b1d4e0..000000000 --- a/trax/models/transformer_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.layers import test_utils -from trax.models import transformer - - -class TransformerTest(parameterized.TestCase): - - def test_transformer_lm_forward_shape(self): - vocab_size = 16 - model = transformer.TransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) - x = np.ones((3, 5)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, vocab_size)) - - def _test_transformer_forward_shape(self, input_vocab_size, - output_vocab_size): - model = transformer.Transformer( - input_vocab_size, output_vocab_size, d_model=32, d_ff=64, - n_encoder_layers=2, n_decoder_layers=2, n_heads=2) - xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - y, _ = 
model(xs) - - vocab_size = output_vocab_size or input_vocab_size - self.assertEqual(y.shape, (3, 5, vocab_size)) - - @parameterized.named_parameters( - ('same_vocab', 16, None), - ('same_size', 16, 16), - ('different_size', 16, 50)) - def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): - """Run the Transformer forward and check output shape.""" - self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) - - - def test_dot_product_causal_attention_fast_inference(self): - model_fn = functools.partial( - transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2 - ) - test_utils.test_eval_equals_predict_discrete(model_fn) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/optimizers/__init__.py b/trax/optimizers/__init__.py index 1ec623abe..825195b77 100644 --- a/trax/optimizers/__init__.py +++ b/trax/optimizers/__init__.py @@ -17,6 +17,7 @@ import gin +from trax.optimizers import sgd from trax.optimizers import adafactor from trax.optimizers import adam from trax.optimizers import base @@ -29,12 +30,13 @@ def opt_configure(*args, **kwargs): - kwargs['module'] = 'trax.optimizers' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.optimizers" + return gin.external_configurable(*args, **kwargs) + # Optimizers (using upper-case names). 
# pylint: disable=invalid-name -SGD = opt_configure(base.SGD) +SGD = opt_configure(sgd.SGD) Momentum = opt_configure(momentum.Momentum) RMSProp = opt_configure(rms_prop.RMSProp) Adam = opt_configure(adam.Adam) diff --git a/trax/optimizers/adafactor.py b/trax/optimizers/adafactor.py index 501290246..a1fbe5d6a 100644 --- a/trax/optimizers/adafactor.py +++ b/trax/optimizers/adafactor.py @@ -20,142 +20,148 @@ class Adafactor(opt_base.Optimizer): - """Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.""" + """Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.""" - def __init__(self, - learning_rate=0.05, - factored=True, - multiply_by_parameter_scale=True, - do_clipping=True, - do_momentum=False, - momentum_in_bfloat16=False, - beta1=0.0, - decay_rate=0.8, - clipping_threshold=1.0, - weight_decay_rate=1e-5, - weight_decay_n_steps=0, - epsilon1=1e-16, - epsilon2=1e-3): - """Create the Adafactor optimizer. + def __init__( + self, + learning_rate=0.05, + factored=True, + multiply_by_parameter_scale=True, + do_clipping=True, + do_momentum=False, + momentum_in_bfloat16=False, + beta1=0.0, + decay_rate=0.8, + clipping_threshold=1.0, + weight_decay_rate=1e-5, + weight_decay_n_steps=0, + epsilon1=1e-16, + epsilon2=1e-3, + ): + """Create the Adafactor optimizer. - Adafactor is described in https://arxiv.org/abs/1804.04235. + Adafactor is described in https://arxiv.org/abs/1804.04235. - Args: - learning_rate: float: trax-provided learning rate. - factored: boolean: whether to use factored second-moment estimator for 2d - variables. - multiply_by_parameter_scale: boolean: if True, then scale provided - learning_rate by parameter norm. if False, provided learning_rate is - absolute step size. - do_clipping: whether to clip gradients; if True, set clipping_theshold. - do_momentum: whether to use momentum; if True, set beta1. - momentum_in_bfloat16: if True, store momentum in bfloat16 to save memory. 
- beta1: a float value between 0 and 1, enables momentum and uses extra - memory if nonzero! Off by default. - decay_rate: float: controls second-moment exponential decay schedule. - clipping_threshold: an optional float >= 1, if None no update clipping. - weight_decay_rate: rate at which to decay weights. - weight_decay_n_steps: for how many steps to decay weights (always if None) - epsilon1: Regularization constant for squared gradient. - epsilon2: Regularization constant for parameter scale. - """ - # These 4 parameters are not configurable once the class is created. - self._factored = factored - self._multiply_by_parameter_scale = multiply_by_parameter_scale - self._do_clipping = do_clipping - self._do_momentum = do_momentum - self._momentum_in_bfloat16 = momentum_in_bfloat16 - # Dynamically configurable parameters will be passed to the update function. - super().__init__( - learning_rate=learning_rate, - beta1=beta1, - decay_rate=decay_rate, - clipping_threshold=clipping_threshold, - weight_decay_rate=weight_decay_rate, - weight_decay_n_steps=weight_decay_n_steps, - epsilon1=epsilon1, - epsilon2=epsilon2, - ) + Args: + learning_rate: float: trax-provided learning rate. + factored: boolean: whether to use factored second-moment estimator for 2d + variables. + multiply_by_parameter_scale: boolean: if True, then scale provided + learning_rate by parameter norm. if False, provided learning_rate is + absolute step size. + do_clipping: whether to clip gradients; if True, set clipping_theshold. + do_momentum: whether to use momentum; if True, set beta1. + momentum_in_bfloat16: if True, store momentum in bfloat16 to save memory. + beta1: a float value between 0 and 1, enables momentum and uses extra + memory if nonzero! Off by default. + decay_rate: float: controls second-moment exponential decay schedule. + clipping_threshold: an optional float >= 1, if None no update clipping. + weight_decay_rate: rate at which to decay weights. 
+ weight_decay_n_steps: for how many steps to decay weights (always if None) + epsilon1: Regularization constant for squared gradient. + epsilon2: Regularization constant for parameter scale. + """ + # These 4 parameters are not configurable once the class is created. + self._factored = factored + self._multiply_by_parameter_scale = multiply_by_parameter_scale + self._do_clipping = do_clipping + self._do_momentum = do_momentum + self._momentum_in_bfloat16 = momentum_in_bfloat16 + # Dynamically configurable parameters will be passed to the update function. + super().__init__( + learning_rate=learning_rate, + beta1=beta1, + decay_rate=decay_rate, + clipping_threshold=clipping_threshold, + weight_decay_rate=weight_decay_rate, + weight_decay_n_steps=weight_decay_n_steps, + epsilon1=epsilon1, + epsilon2=epsilon2, + ) - @staticmethod - def _decay_rate_pow(i, exponent=0.8): - """Default Adafactor second-moment decay schedule.""" - t = jnp.array(i, jnp.float32) + 1.0 - return 1.0 - t**(-exponent) + @staticmethod + def _decay_rate_pow(i, exponent=0.8): + """Default Adafactor second-moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return 1.0 - t ** (-exponent) - def init(self, weights): - shape = weights.shape - slots = [] - if self._factored and len(shape) >= 2: - v_row = jnp.zeros(shape[:-1], dtype=jnp.float32) - v_col = jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32) - slots.extend([v_row, v_col]) - else: - v = jnp.zeros_like(weights) - slots.append(v) - if self._do_momentum: - m = jnp.zeros_like(weights) - if self._momentum_in_bfloat16: - m = m.astype(jnp.bfloat16) - slots.append(m) - return slots + def init(self, weights): + shape = weights.shape + slots = [] + if self._factored and len(shape) >= 2: + v_row = jnp.zeros(shape[:-1], dtype=jnp.float32) + v_col = jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32) + slots.extend([v_row, v_col]) + else: + v = jnp.zeros_like(weights) + slots.append(v) + if self._do_momentum: + m = 
jnp.zeros_like(weights) + if self._momentum_in_bfloat16: + m = m.astype(jnp.bfloat16) + slots.append(m) + return slots - def update(self, step, grads, weights, slots, opt_params): - updates = [] - learning_rate = opt_params['learning_rate'] - beta1 = opt_params['beta1'] - decay_rate = opt_params['decay_rate'] - clipping_threshold = opt_params['clipping_threshold'] - weight_decay_rate = opt_params['weight_decay_rate'] - weight_decay_n_steps = opt_params['weight_decay_n_steps'] - weight_decay_rate = jnp.where( - weight_decay_n_steps < 1, # if weight_decay_n_steps == 0, ignore it - weight_decay_rate, - (weight_decay_rate * jnp.maximum(weight_decay_n_steps - step, 0.0) / - jnp.maximum(weight_decay_n_steps, 0.0))) - epsilon1 = opt_params['epsilon1'] - epsilon2 = opt_params['epsilon2'] - decay_rate = self._decay_rate_pow(step, exponent=decay_rate) - update_scale = learning_rate - if self._multiply_by_parameter_scale: - update_scale *= jnp.maximum( - jnp.sqrt(jnp.mean(weights * weights)), epsilon2) - mixing_rate = 1.0 - decay_rate + def update(self, step, grads, weights, slots, opt_params): + updates = [] + learning_rate = opt_params["learning_rate"] + beta1 = opt_params["beta1"] + decay_rate = opt_params["decay_rate"] + clipping_threshold = opt_params["clipping_threshold"] + weight_decay_rate = opt_params["weight_decay_rate"] + weight_decay_n_steps = opt_params["weight_decay_n_steps"] + weight_decay_rate = jnp.where( + weight_decay_n_steps < 1, # if weight_decay_n_steps == 0, ignore it + weight_decay_rate, + ( + weight_decay_rate + * jnp.maximum(weight_decay_n_steps - step, 0.0) + / jnp.maximum(weight_decay_n_steps, 0.0) + ), + ) + epsilon1 = opt_params["epsilon1"] + epsilon2 = opt_params["epsilon2"] + decay_rate = self._decay_rate_pow(step, exponent=decay_rate) + update_scale = learning_rate + if self._multiply_by_parameter_scale: + update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)), epsilon2) + mixing_rate = 1.0 - decay_rate - grads_sqr = grads * grads - 
if self._factored and len(weights.shape) >= 2: - v_row = slots[0] # In this case, the slots are (v_row, v_col, ...). - v_col = slots[1] - new_v_row = ( - decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1)) - new_v_col = ( - decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2)) - updates.extend([new_v_row, new_v_col]) - row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True) - row_factor = (row_mean / (new_v_row + epsilon1))**0.5 - col_factor = (new_v_col + epsilon1)**-0.5 - y = ( - grads * jnp.expand_dims(row_factor, axis=-1) * - jnp.expand_dims(col_factor, axis=-2)) - else: - v = slots[0] # In this case, the slots are (v, ...) - new_v = decay_rate * v + mixing_rate * grads_sqr - updates.append(new_v) - y = grads * (new_v + epsilon1)**-0.5 + grads_sqr = grads * grads + if self._factored and len(weights.shape) >= 2: + v_row = slots[0] # In this case, the slots are (v_row, v_col, ...). + v_col = slots[1] + new_v_row = decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1) + new_v_col = decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2) + updates.extend([new_v_row, new_v_col]) + row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True) + row_factor = (row_mean / (new_v_row + epsilon1)) ** 0.5 + col_factor = (new_v_col + epsilon1) ** -0.5 + y = ( + grads + * jnp.expand_dims(row_factor, axis=-1) + * jnp.expand_dims(col_factor, axis=-2) + ) + else: + v = slots[0] # In this case, the slots are (v, ...) + new_v = decay_rate * v + mixing_rate * grads_sqr + updates.append(new_v) + y = grads * (new_v + epsilon1) ** -0.5 - if self._do_clipping: - clipping_denom = ( - jnp.maximum(1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold)) - y /= clipping_denom + if self._do_clipping: + clipping_denom = jnp.maximum( + 1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold + ) + y /= clipping_denom - subtrahend = update_scale * y - if self._do_momentum: - m = slots[-1] # Momentum is always the last slot (if used). 
- m = m.astype(subtrahend.dtype) # Accumulate in subtrahend dtype. - new_m = beta1 * m + (1.0 - beta1) * subtrahend - subtrahend = new_m - updates.append(new_m.astype(slots[-1].dtype)) # Back to bfloat if needed. + subtrahend = update_scale * y + if self._do_momentum: + m = slots[-1] # Momentum is always the last slot (if used). + m = m.astype(subtrahend.dtype) # Accumulate in subtrahend dtype. + new_m = beta1 * m + (1.0 - beta1) * subtrahend + subtrahend = new_m + updates.append(new_m.astype(slots[-1].dtype)) # Back to bfloat if needed. - new_weights = (1 - weight_decay_rate) * weights - subtrahend - # TODO(lukaszkaiser): why is the astype needed here? Check and correct. - return new_weights.astype(weights.dtype), updates + new_weights = (1 - weight_decay_rate) * weights - subtrahend + # TODO(lukaszkaiser): why is the astype needed here? Check and correct. + return new_weights.astype(weights.dtype), updates diff --git a/trax/optimizers/adam.py b/trax/optimizers/adam.py index e950eab9f..d25a6155f 100644 --- a/trax/optimizers/adam.py +++ b/trax/optimizers/adam.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + """Adam optimizer class.""" from trax.fastmath import numpy as jnp @@ -21,62 +22,72 @@ # pylint: disable=line-too-long class Adam(opt_base.Optimizer): - r"""Adam optimizer; described in https://arxiv.org/abs/1412.6980. - - The update rule for time step :math:`t`, given gradients :math:`g_t` and - "Stepsize" :math:`\alpha`, is: - - .. math:: + r"""Adam optimizer; described in https://arxiv.org/abs/1412.6980. + The update rule for time step :math:`t`, given gradients :math:`g_t` and "Stepsize" :math:`\alpha`, is: + .. 
math:: \hat{m}_t &\leftarrow \big(\beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t\big)\ /\ (1 - \beta_1^t) \\ \hat{v}_t &\leftarrow \big(\beta_2 \cdot m_{t-1} + (1 - \beta_2) \cdot g_t^2\big)\ /\ (1 - \beta_2^t) \\ \theta_t &\leftarrow \theta_{t-1} -\ \alpha \cdot \hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big) """ - # pylint: enable=line-too-long - def __init__(self, learning_rate=0.0001, weight_decay_rate=1e-5, # pylint: disable=useless-super-delegation - b1=0.9, b2=0.999, eps=1e-5, clip_grad_norm=None): - r"""Creates an Adam optimizer. + # pylint: enable=line-too-long + def __init__( + self, + learning_rate=0.0001, + weight_decay_rate=1e-5, # pylint: disable=useless-super-delegation + b1=0.9, + b2=0.999, + eps=1e-5, + clip_grad_norm=None, + ): + r"""Creates an Adam optimizer. + + Args: + learning_rate: Initial (unadapted) learning rate :math:`\alpha`; original + paper calls this `Stepsize` and suggests .001 as a generally good + value. + weight_decay_rate: Fraction of prior weight values to subtract on each + step; equivalent to multiplying each weight element by + `1 - weight_decay_rate`. (This is not part of the core Adam + algorithm.) + b1: Exponential decay rate :math:`\beta_1` for first moment estimates. + b2: Exponential decay rate :math:`\beta_2` for second moment estimates. + eps: Small positive constant :math:`\epsilon` for numerical stability. + clip_grad_norm: Threshold value above which gradient clipping occurs. + (This is not part of the core Adam algorithm.) 
+ """ + super().__init__( + learning_rate=learning_rate, + weight_decay_rate=weight_decay_rate, + b1=b1, + b2=b2, + eps=eps, + clip_grad_norm=clip_grad_norm, + ) + + def init(self, weights): + m = jnp.zeros_like(weights) + v = jnp.zeros_like(weights) + return m, v + + def update(self, step, grads, weights, slots, opt_params): + m, v = slots + + learning_rate = opt_params["learning_rate"] + weight_decay_rate = opt_params["weight_decay_rate"] + b1 = opt_params["b1"] + b2 = opt_params["b2"] + eps = opt_params["eps"] - Args: - learning_rate: Initial (unadapted) learning rate :math:`\alpha`; original - paper calls this `Stepsize` and suggests .001 as a generally good - value. - weight_decay_rate: Fraction of prior weight values to subtract on each - step; equivalent to multiplying each weight element by - `1 - weight_decay_rate`. (This is not part of the core Adam - algorithm.) - b1: Exponential decay rate :math:`\beta_1` for first moment estimates. - b2: Exponential decay rate :math:`\beta_2` for second moment estimates. - eps: Small positive constant :math:`\epsilon` for numerical stability. - clip_grad_norm: Threshold value above which gradient clipping occurs. - (This is not part of the core Adam algorithm.) - """ - super().__init__( - learning_rate=learning_rate, - weight_decay_rate=weight_decay_rate, - b1=b1, - b2=b2, - eps=eps, - clip_grad_norm=clip_grad_norm - ) + m = (1 - b1) * grads + b1 * m # First moment estimate. + v = (1 - b2) * (grads**2) + b2 * v # Second moment estimate. + mhat = m / (1 - b1 ** (step + 1)) # Bias correction. 
+ vhat = v / (1 - b2 ** (step + 1)) - def init(self, weights): - m = jnp.zeros_like(weights) - v = jnp.zeros_like(weights) - return m, v + new_weights = ( + (1 - weight_decay_rate) * weights + - (learning_rate * mhat / (jnp.sqrt(vhat) + eps)) + ).astype(weights.dtype) - def update(self, step, grads, weights, slots, opt_params): - m, v = slots - learning_rate = opt_params['learning_rate'] - weight_decay_rate = opt_params['weight_decay_rate'] - b1 = opt_params['b1'] - b2 = opt_params['b2'] - eps = opt_params['eps'] - m = (1 - b1) * grads + b1 * m # First moment estimate. - v = (1 - b2) * (grads ** 2) + b2 * v # Second moment estimate. - mhat = m / (1 - b1 ** (step + 1)) # Bias correction. - vhat = v / (1 - b2 ** (step + 1)) - new_weights = ((1 - weight_decay_rate) * weights - ( - learning_rate * mhat / (jnp.sqrt(vhat) + eps))).astype(weights.dtype) - return new_weights, (m, v) + return new_weights, (m, v) diff --git a/trax/optimizers/base.py b/trax/optimizers/base.py index 269bc0a73..52100d263 100644 --- a/trax/optimizers/base.py +++ b/trax/optimizers/base.py @@ -20,234 +20,234 @@ class Optimizer: - """Base class for optimizers that work hand in hand with Trax layers. + """Base class for optimizers that work hand in hand with Trax layers. - To define an optimizer subclass, specify its behavior with respect to a - single node in the network (e.g., a single dense layer): + To define an optimizer subclass, specify its behavior with respect to a + single node in the network (e.g., a single dense layer): - - `init`: how to create/initialize optimizer-internal parameters ("slots"), - as a function of the node's weights. - - `update`: how to use gradient information to update node weights and - optimizer slots. + - `init`: how to create/initialize optimizer-internal parameters ("slots"), + as a function of the node's weights. + - `update`: how to use gradient information to update node weights and + optimizer slots. 
- The Trax runtime combines these node-local computations into layer weight - updates and optimizer slot updates for the whole tree of layers in the model. - """ - - def __init__(self, learning_rate=0.01, clip_grad_norm=None, - **init_opt_params): - """Sets initial hyperparameter values for this optimizer. - - Takes optimizer hyperparameters as keyword arguments. These values can - change over time (training steps), e.g., for learning rate schedules. - - To expose subclass hyperparameters for gin configuration, override this - constructor and use explicitly named keyword arguments. See - `momentum.Momentum.__init__` for one such example. - - Args: - learning_rate: Learning rate for the optimizer. This can change during - training by means of a training rate schedule. - clip_grad_norm: If specified, this scalar value is used to limit gradient - size -- all gradient elements in a training step are treated as if - they belonged to a single vector and then scaled back if needed so - that such a vector's L2 norm does not exceed `clip_grad_norm`. If - None, no clipping happens. - **init_opt_params: Initial values of any additional optimizer parameters. - """ - init_opt_params['learning_rate'] = learning_rate - self._init_opt_params = { - name: jnp.array(value) for (name, value) in init_opt_params.items() - } - self._slots = None - # Gradient clipping happens with respect to the norm of the whole gradient - # tree, so it is not passed to single-slot updates, but done in this class - # for the whole gradient tree. - self._clip_grad_norm = clip_grad_norm - - def init(self, weights): - """Creates optimizer slots that fit the given weights. - - Args: - weights: Trainable weights for one layer. Optimizer slots typically match - the data shape and type of the given layer weights. - """ - raise NotImplementedError - - def update(self, step, grads, weights, slots, opt_params): - """Computes updated layer weights and optimizer slots for one training step. 
- - Args: - step: Training step number. - grads: Gradient values for this node (from back-propagation during a - training step). - weights: Current weight values for this node (i.e., layer weights). - slots: Current slot values for this node. - opt_params: Optimizer hyperparameters (e.g. learning rate, momentum), - same across all nodes in the model. - - Returns: - Tuple of (new_weights, new_slots), which the Trax runtime will use to - update the model and optimizer within each training step. + The Trax runtime combines these node-local computations into layer weight + updates and optimizer slot updates for the whole tree of layers in the model. """ - raise NotImplementedError - - @property - def slots(self): - return self._slots - @slots.setter - def slots(self, slots): - self._slots = slots + def __init__(self, learning_rate=0.01, clip_grad_norm=None, **init_opt_params): + """Sets initial hyperparameter values for this optimizer. + + Takes optimizer hyperparameters as keyword arguments. These values can + change over time (training steps), e.g., for learning rate schedules. + + To expose subclass hyperparameters for gin configuration, override this + constructor and use explicitly named keyword arguments. See + `momentum.Momentum.__init__` for one such example. + + Args: + learning_rate: Learning rate for the optimizer. This can change during + training by means of a training rate schedule. + clip_grad_norm: If specified, this scalar value is used to limit gradient + size -- all gradient elements in a training step are treated as if + they belonged to a single vector and then scaled back if needed so + that such a vector's L2 norm does not exceed `clip_grad_norm`. If + None, no clipping happens. + **init_opt_params: Initial values of any additional optimizer parameters. 
+ """ + init_opt_params["learning_rate"] = learning_rate + self._init_opt_params = { + name: jnp.array(value) for (name, value) in init_opt_params.items() + } + self._slots = None + # Gradient clipping happens with respect to the norm of the whole gradient + # tree, so it is not passed to single-slot updates, but done in this class + # for the whole gradient tree. + self._clip_grad_norm = clip_grad_norm + + def init(self, weights): + """Creates optimizer slots that fit the given weights. + + Args: + weights: Trainable weights for one layer. Optimizer slots typically match + the data shape and type of the given layer weights. + """ + raise NotImplementedError + + def update(self, step, grads, weights, slots, opt_params): + """Computes updated layer weights and optimizer slots for one training step. + + Args: + step: Training step number. + grads: Gradient values for this node (from back-propagation during a + training step). + weights: Current weight values for this node (i.e., layer weights). + slots: Current slot values for this node. + opt_params: Optimizer hyperparameters (e.g. learning rate, momentum), + same across all nodes in the model. + + Returns: + Tuple of (new_weights, new_slots), which the Trax runtime will use to + update the model and optimizer within each training step. + """ + raise NotImplementedError + + @property + def slots(self): + return self._slots + + @slots.setter + def slots(self, slots): + self._slots = slots + + @property + def opt_params(self): + return self._init_opt_params + + @opt_params.setter + def opt_params(self, opt_params): + self._init_opt_params = opt_params + + def tree_init(self, weight_tree): + """Assembles node-local initializations into full-tree initialization. + + Args: + weight_tree: Weights for an entire model, in a tree that matches the + model's layer structure. 
+ + Returns: + Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer + slot values and `opt_params` are optimizer hyperparameters (e.g., + learning rate, momentum). + """ + self._slots = tuple( + self.init(weight) for weight in fastmath.tree_flatten(weight_tree) + ) + return (self._slots, self._init_opt_params) + + def tree_update( + self, step, grad_tree, weight_tree, slots, opt_params, store_slots=True + ): + """Assembles node-local weight and slot updates for the full layer tree. + + Args: + step: Current step number in the training process. + grad_tree: Gradients for the entire model, in a tree that matches the + model's layer structure. + weight_tree: Current weights for the entire model, in a tree that matches + the model's layer structure. + slots: Optimizer slots. + opt_params: Optimizer hyperparameters (e.g. learning rate, momentum). + store_slots: Boolean; if True, stores resulting slots in this object; + when set to False, this becomes a pure function. + + Returns: + Tuple `(weights, slots, metrics)`: the updated weight tree (matching the + model's layer structure), the updated optimizer slot values, and a dict + with the `gradients_l2` and `weights_l2` norms. 
+ """ + grads_flat = fastmath.tree_flatten(grad_tree) + grads_norm = self._l2_norm(grads_flat) + if self._clip_grad_norm is not None: + max_norm = self._clip_grad_norm + grads_flat = [ + jnp.where( + grads_norm < max_norm, # pylint: disable=g-complex-comprehension + g, + g * (max_norm / grads_norm), + ) + for g in grads_flat + ] + weights_flat = fastmath.tree_flatten(weight_tree) + weights_norm = self._l2_norm(weights_flat) + updated_pairs = [ + self._update_and_check(step, grad, weight, slot, opt_params) + for (grad, weight, slot) in zip(grads_flat, weights_flat, slots) + ] + new_weights_flat, slots = zip(*updated_pairs) + new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree) + metrics = {"gradients_l2": grads_norm, "weights_l2": weights_norm} + slots = tuple(slots) + if store_slots: + self.slots = slots + return new_weights, slots, metrics + + def _l2_norm(self, flat_list): + """Returns an L2-like norm of all elements of all tensors in `flat_list`. + + Args: + flat_list: Collection of tensors as a flat list (rather than, e.g., a + tree). + + Returns: + A scalar value computed as if all the tensors in `flat_list` were joined + and flattened into a single vector, and then the L2 norm of that vector + was calculated. + """ + if fastmath.is_backend(fastmath.Backend.JAX): + norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list)) + else: + norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list)) + + return norm + + def _update_and_check(self, step, grads, weights, slots, opt_params): + """Updates a single weight array and checks types.""" + new_weights, new_slots = self.update(step, grads, weights, slots, opt_params) + if isinstance(weights, jnp.ndarray): + if not isinstance(new_weights, jnp.ndarray): + raise ValueError( + f"New weight values should be of type jnp.ndarray or a subclass; " + f"instead got {type(new_weights)}." 
+ ) + if new_weights.dtype != weights.dtype: + raise ValueError( + f"New weight values dtype ({new_weights.dtype}) does not match " + f"the old one ({weights.dtype})." + ) + return new_weights, new_slots - @property - def opt_params(self): - return self._init_opt_params - @opt_params.setter - def opt_params(self, opt_params): - self._init_opt_params = opt_params - - def tree_init(self, weight_tree): - """Assembles node-local initializations into full-tree initialization. - - Args: - weight_tree: Weights for an entire model, in a tree that matches the - model's layer structure. +# Utilities. - Returns: - Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer - slot values and `opt_params` are optimizer hyperparameters (e.g., - learning rate, momentum). - """ - self._slots = tuple(self.init(weight) - for weight in fastmath.tree_flatten(weight_tree)) - return (self._slots, self._init_opt_params) - def tree_update(self, step, grad_tree, weight_tree, slots, opt_params, - store_slots=True): - """Assembles node-local weight and slot updates for the full layer tree. +def l2_norm(tree): + """Returns an L2 norm computed over all elements of all tensors in `tree`. Args: - step: Current step number in the training process. - grad_tree: Gradients for the entire model, in a tree that matches the - model's layer structure. - weight_tree: Current weights for the entire model, in a tree that matches + tree: Tree-structured collection of tensors, e.g., model weights matching the model's layer structure. - slots: Optimizer slots. - opt_params: Optimizer hyperparameters (e.g. learning rate, momentum). - store_slots: Boolean; if True, stores resulting slots in this object; - when set to False, this becomes a pure function. Returns: - Tuple `(weights, slots)`, where `weights` are the optimizer-updated - weights for the whole model (in a tree matching the model's layer - structure) and `slots` are the updated optimizer slot values. 
- """ - grads_flat = fastmath.tree_flatten(grad_tree) - grads_norm = self._l2_norm(grads_flat) - if self._clip_grad_norm is not None: - max_norm = self._clip_grad_norm - grads_flat = [jnp.where(grads_norm < max_norm, # pylint: disable=g-complex-comprehension - g, - g * (max_norm / grads_norm)) - for g in grads_flat] - weights_flat = fastmath.tree_flatten(weight_tree) - weights_norm = self._l2_norm(weights_flat) - updated_pairs = [ - self._update_and_check(step, grad, weight, slot, opt_params) - for (grad, weight, slot) in zip(grads_flat, weights_flat, slots) - ] - new_weights_flat, slots = zip(*updated_pairs) - new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree) - metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm} - slots = tuple(slots) - if store_slots: - self.slots = slots - return new_weights, slots, metrics - - def _l2_norm(self, flat_list): - """Returns an L2-like norm of all elements of all tensors in `flat_list`. - - Args: - flat_list: Collection of tensors as a flat list (rather than, e.g., a - tree). - - Returns: - A scalar value computed as if all the tensors in `flat_list` were joined + A scalar value computed as if all the tensors in `tree` were combined and flattened into a single vector, and then the L2 norm of that vector was calculated. 
""" - if fastmath.is_backend(fastmath.Backend.JAX): - norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list)) - else: # TODO(lukaszkaiser): add vdot to TF-numpy - norm = jnp.sqrt(sum(jnp.sum(x*x) for x in flat_list)) - return norm - - def _update_and_check(self, step, grads, weights, slots, opt_params): - """Updates a single weight array and checks types.""" - new_weights, new_slots = self.update( - step, grads, weights, slots, opt_params) - if isinstance(weights, jnp.ndarray): - if not isinstance(new_weights, jnp.ndarray): - raise ValueError( - f'New weight values should be of type jnp.ndarray or a subclass; ' - f'instead got {type(new_weights)}.') - if new_weights.dtype != weights.dtype: - raise ValueError( - f'New weight values dtype ({new_weights.dtype}) does not match ' - f'the old one ({weights.dtype}).') - return new_weights, new_slots - - -class SGD(Optimizer): - """Stochastic gradient descent (SGD) optimizer.""" - - def init(self, weights): - return None - - def update(self, step, grads, weights, slots, opt_params): - del step, slots - lr = opt_params['learning_rate'] - new_weights = weights - (lr * grads).astype(weights.dtype) - return new_weights, None - - -# Utilities. + leaves = fastmath.tree_flatten(tree) + if fastmath.is_backend(fastmath.Backend.JAX): + norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves)) + else: # Non-JAX backends: sum of squares, same as Optimizer._l2_norm. + norm = jnp.sqrt(sum(jnp.sum(x * x) for x in leaves)) -def l2_norm(tree): - """Returns an L2 norm computed over all elements of all tensors in `tree`. + return norm - Args: - tree: Tree-structured collection of tensors, e.g., model weights matching - the model's layer structure. - Returns: - A scalar value computed as if all the tensors in `tree` were combined - and flattened into a single vector, and then the L2 norm of that vector - was calculated. 
- """ - leaves = fastmath.tree_flatten(tree) - return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves)) +def clip_grads(grad_tree, max_norm): + """Proportionally reduces each gradient value to respect an aggregate limit. + Args: + grad_tree: Gradient values structured as a tree of tensors matching the + model's layer structure. + max_norm: The aggregate limit on gradient values. All gradient elements in + `grad_tree` are treated as if they belonged to a single vector and + that vector is shortened if needed so that its L2 norm does not exceed + `max_norm`. -def clip_grads(grad_tree, max_norm): - """Proportionally reduces each gradient value to respect an aggregate limit. - - Args: - grad_tree: Gradient values structured as a tree of tensors matching the - model's layer structure. - max_norm: The aggregate limit on gradient values. All gradient elements in - `grad_tree` are treated as if they belonged to a single vector and - that vector is shortened if needed so that its L2 norm does not exceed - `clip_grad_norm`. - - Returns: - A new tree of tensors matching the structure of `grad_tree`, but with - element values proportionally rescaled as needed to respect the `max_norm` - limit. - """ - norm = l2_norm(grad_tree) - normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) - return fastmath.nested_map(grad_tree, normalize) + Returns: + A new tree of tensors matching the structure of `grad_tree`, but with + element values proportionally rescaled as needed to respect the `max_norm` + limit. + """ + norm = l2_norm(grad_tree) + normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) + return fastmath.nested_map(grad_tree, normalize) diff --git a/trax/optimizers/momentum.py b/trax/optimizers/momentum.py index 625e7a2b4..5318507a6 100644 --- a/trax/optimizers/momentum.py +++ b/trax/optimizers/momentum.py @@ -21,7 +21,7 @@ # TODO(jonni): Consider renaming this class to NesterovMomentum. 
class Momentum(base.Optimizer): - r"""A momentum optimizer. + r"""A momentum optimizer. This class implements two variants of momentum stochastic gradient descent (SGD): with and without the Nesterov correction. The implementation of the @@ -41,32 +41,32 @@ class Momentum(base.Optimizer): (:math:`\alpha`) on the parameters, independent of the Nesterov momentum. """ - def __init__( - self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True - ): # pylint: disable=useless-super-delegation - super().__init__( - learning_rate=learning_rate, - mass=mass, - weight_decay_rate=weight_decay_rate, - ) - self._nesterov = nesterov + def __init__( + self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True + ): # pylint: disable=useless-super-delegation + super().__init__( + learning_rate=learning_rate, + mass=mass, + weight_decay_rate=weight_decay_rate, + ) + self._nesterov = nesterov - def init(self, weights): - return jnp.zeros_like(weights) + def init(self, weights): + return jnp.zeros_like(weights) - def update(self, step, grads, weights, velocity, opt_params): - del step - v = velocity - mu = opt_params['mass'] - alpha = opt_params['weight_decay_rate'] - epsilon = opt_params['learning_rate'] + def update(self, step, grads, weights, velocity, opt_params): + del step + v = velocity + mu = opt_params["mass"] + alpha = opt_params["weight_decay_rate"] + epsilon = opt_params["learning_rate"] - new_v = mu * v + grads - if self._nesterov: - weight_update = mu * new_v + grads - else: - weight_update = new_v - new_weights = (1 - alpha) * weights - epsilon * weight_update + new_v = mu * v + grads + if self._nesterov: + weight_update = mu * new_v + grads + else: + weight_update = new_v + new_weights = (1 - alpha) * weights - epsilon * weight_update - new_weights = new_weights.astype(weights.dtype) - return (new_weights, new_v) + new_weights = new_weights.astype(weights.dtype) + return (new_weights, new_v) diff --git a/trax/optimizers/rms_prop.py 
b/trax/optimizers/rms_prop.py index 351d05425..40786cb5b 100644 --- a/trax/optimizers/rms_prop.py +++ b/trax/optimizers/rms_prop.py @@ -20,30 +20,32 @@ class RMSProp(opt_base.Optimizer): - """RMSProp optimizer. - - Uses optimizer weights ("slots") to maintain a root-mean-square exponentially - decaying average of gradients from prior training batches. - """ - - def __init__(self, learning_rate=0.001, gamma=0.9, - eps=1e-8, clip_grad_norm=None): # pylint: disable=useless-super-delegation - super().__init__( - learning_rate=learning_rate, - gamma=gamma, - eps=eps, - clip_grad_norm=clip_grad_norm - ) - - def init(self, weights): - return jnp.ones_like(weights) - - def update(self, step, grads, weights, avg_sq_grad, opt_params): - del step - lr = opt_params['learning_rate'] - gamma = opt_params['gamma'] - eps = opt_params['eps'] - avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma) - weights = weights - (lr * grads / - (jnp.sqrt(avg_sq_grad) + eps)).astype(weights.dtype) - return weights, avg_sq_grad + """RMSProp optimizer. + + Uses optimizer weights ("slots") to maintain a root-mean-square exponentially + decaying average of gradients from prior training batches. 
+ """ + + def __init__( + self, learning_rate=0.001, gamma=0.9, eps=1e-8, clip_grad_norm=None + ): # pylint: disable=useless-super-delegation + super().__init__( + learning_rate=learning_rate, + gamma=gamma, + eps=eps, + clip_grad_norm=clip_grad_norm, + ) + + def init(self, weights): + return jnp.ones_like(weights) + + def update(self, step, grads, weights, avg_sq_grad, opt_params): + del step + lr = opt_params["learning_rate"] + gamma = opt_params["gamma"] + eps = opt_params["eps"] + avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1.0 - gamma) + weights = weights - (lr * grads / (jnp.sqrt(avg_sq_grad) + eps)).astype( + weights.dtype + ) + return weights, avg_sq_grad diff --git a/trax/optimizers/sgd.py b/trax/optimizers/sgd.py new file mode 100644 index 000000000..63f1ae228 --- /dev/null +++ b/trax/optimizers/sgd.py @@ -0,0 +1,14 @@ +from trax.optimizers import base as opt_base + + +class SGD(opt_base.Optimizer): + """Stochastic gradient descent (SGD) optimizer.""" + + def init(self, weights): + return None + + def update(self, step, grads, weights, slots, opt_params): + del step, slots + lr = opt_params["learning_rate"] + new_weights = weights - (lr * grads).astype(weights.dtype) + return new_weights, None diff --git a/trax/optimizers/sm3.py b/trax/optimizers/sm3.py index e716bd32c..2a101ffd9 100644 --- a/trax/optimizers/sm3.py +++ b/trax/optimizers/sm3.py @@ -22,173 +22,181 @@ class MomentumType(enum.IntEnum): - EMA = 1 - HEAVY_BALL = 2 - NESTEROV = 3 + EMA = 1 + HEAVY_BALL = 2 + NESTEROV = 3 class SM3(opt_base.Optimizer): - """SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.""" - - def __init__(self, - learning_rate=0.01, - momentum=0.9, - second_moment_averaging=1.0, - weight_decay=0.0, - momentum_type=MomentumType.EMA): # pylint: disable=useless-super-delegation - """Create the SM3 optimizer. - - Memory-Efficient Adaptive Optimization. 
- https://arxiv.org/abs/1901.11150 - - Args: - learning_rate: a postitive scalar value for the initial learning rate. - momentum: optional, a positive scalar value for momentum - second_moment_averaging: averaging of second moments (if 1.0, adds from - begining of time like AdaGrad). - weight_decay: Weight decay for regularizing the model. - momentum_type: Nestrov, Heavy-Ball or EMA (Default). - - """ - self._has_momentum = momentum > 0.0 - self._momentum_type = momentum_type - self._graft = second_moment_averaging != 1.0 - super().__init__( - learning_rate=learning_rate, - momentum=momentum, - second_moment_averaging=second_moment_averaging, - weight_decay=weight_decay, - ) - - def init(self, w): - momentum = [] - if self._has_momentum: - momentum = jnp.zeros_like(w) - v1s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] - v2s = [] - if self._graft: - v2s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] - return (momentum, v1s, v2s) - - def _momentum_update(self, g, m, beta1): - """Handle various types of momentum.""" - if self._momentum_type == MomentumType.EMA: - m = (1 - beta1) * g + beta1 * m - update = m - elif self._momentum_type == MomentumType.HEAVY_BALL: - m = g + beta1 * m - update = m - elif self._momentum_type == MomentumType.NESTEROV: - m = g + beta1 * m - nesterov_m = g + beta1 * m - update = nesterov_m - else: - assert False, 'Unknown momentum_type.' 
- return m, update - - def _update_diagonal(self, g, w, m, v1, v2, opt_params): - learning_rate = opt_params['learning_rate'] - beta2 = opt_params['second_moment_averaging'] - weight_decay = opt_params['weight_decay'] - - is_beta2_1 = (beta2 == 1).astype(g.dtype) - one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) - v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g - - preconditioner = jnp.where(v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16), - jnp.zeros_like(v1[0])) - - pg = preconditioner * g - if self._graft: - v2[0] += g * g - preconditioner_graft = jnp.where( - v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0])) - pg_graft = preconditioner_graft * g - pg_norm = jnp.linalg.norm(pg) - pg_graft_norm = jnp.linalg.norm(pg_graft) - pg = pg * (pg_graft_norm/(pg_norm + 1e-16)) - - pg = pg + w * weight_decay - - if self._has_momentum: - m, update = self._momentum_update(pg, m, opt_params['momentum']) - else: - update = pg - - w = w - (update * learning_rate).astype(w.dtype) - return w, (m, v1, v2) - - def _expanded_shape(self, shape, axis): - # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. - # For eg: i = 1 returns [1, N, 1]. 
- rank = len(shape) - return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) - - def _minimum(self, tensor_list): - minimum = tensor_list[0] - for i in range(1, len(tensor_list)): - minimum = jnp.minimum(minimum, tensor_list[i]) - return minimum - - def _update_sketched(self, g, w, m, v1, v2, opt_params): - """Update for higher-rank parameters.""" - learning_rate = opt_params['learning_rate'] - momentum = opt_params['momentum'] - beta2 = opt_params['second_moment_averaging'] - weight_decay = opt_params['weight_decay'] - - shape = w.shape - rank = len(shape) - reshaped_accumulators = [jnp.reshape(v1[i], self._expanded_shape(shape, i)) - for i in range(rank)] - acc = self._minimum(reshaped_accumulators) - - is_beta2_1 = (beta2 == 1).astype(g.dtype) - one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) - acc = beta2 * acc + one_minus_beta2_except1 * g * g - - preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16), - jnp.zeros_like(acc)) - pg = g * preconditioner - if self._graft: - v2_acc = self._minimum([ - jnp.reshape(v2[i], self._expanded_shape(shape, i)) - for i in range(rank) - ]) - v2_acc = v2_acc + g * g - preconditioner_graft = jnp.where(v2_acc > 0.0, - 1.0 / (jnp.sqrt(v2_acc) + 1e-16), - jnp.zeros_like(v2_acc)) - pg_graft = preconditioner_graft * g - pg_norm = jnp.linalg.norm(pg) - pg_graft_norm = jnp.linalg.norm(pg_graft) - pg = pg * (pg_graft_norm/(pg_norm + 1e-16)) - - pg = pg + w * weight_decay - - if self._has_momentum: - m, update = self._momentum_update(pg, m, momentum) - else: - update = pg - - w = w - (learning_rate * update).astype(w.dtype) - for i in range(len(v1)): - axes = list(range(int(i))) + list(range(int(i) + 1, rank)) - dim_accumulator = jnp.amax(acc, axis=axes) - v1[i] = dim_accumulator - - if self._graft: - for i in range(len(v2)): - axes = list(range(int(i))) + list(range(int(i) + 1, rank)) - dim_accumulator = jnp.amax(v2_acc, axis=axes) - v2[i] = dim_accumulator - return w, (m, v1, v2) - - def 
update(self, step, g, w, slots, opt_params): - del step - m, v1, v2 = slots - rank = len(w.shape) - if rank > 1: - return self._update_sketched(g, w, m, v1, v2, opt_params) - else: - return self._update_diagonal(g, w, m, v1, v2, opt_params) + """SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.""" + + def __init__( + self, + learning_rate=0.01, + momentum=0.9, + second_moment_averaging=1.0, + weight_decay=0.0, + momentum_type=MomentumType.EMA, + ): # pylint: disable=useless-super-delegation + """Create the SM3 optimizer. + + Memory-Efficient Adaptive Optimization. + https://arxiv.org/abs/1901.11150 + + Args: + learning_rate: a postitive scalar value for the initial learning rate. + momentum: optional, a positive scalar value for momentum + second_moment_averaging: averaging of second moments (if 1.0, adds from + begining of time like AdaGrad). + weight_decay: Weight decay for regularizing the model. + momentum_type: Nestrov, Heavy-Ball or EMA (Default). + + """ + self._has_momentum = momentum > 0.0 + self._momentum_type = momentum_type + self._graft = second_moment_averaging != 1.0 + super().__init__( + learning_rate=learning_rate, + momentum=momentum, + second_moment_averaging=second_moment_averaging, + weight_decay=weight_decay, + ) + + def init(self, w): + momentum = [] + if self._has_momentum: + momentum = jnp.zeros_like(w) + v1s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] + v2s = [] + if self._graft: + v2s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] + return (momentum, v1s, v2s) + + def _momentum_update(self, g, m, beta1): + """Handle various types of momentum.""" + if self._momentum_type == MomentumType.EMA: + m = (1 - beta1) * g + beta1 * m + update = m + elif self._momentum_type == MomentumType.HEAVY_BALL: + m = g + beta1 * m + update = m + elif self._momentum_type == MomentumType.NESTEROV: + m = g + beta1 * m + nesterov_m = g + beta1 * m + update = nesterov_m + else: + assert False, "Unknown momentum_type." 
+ return m, update + + def _update_diagonal(self, g, w, m, v1, v2, opt_params): + learning_rate = opt_params["learning_rate"] + beta2 = opt_params["second_moment_averaging"] + weight_decay = opt_params["weight_decay"] + + is_beta2_1 = (beta2 == 1).astype(g.dtype) + one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) + v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g + + preconditioner = jnp.where( + v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16), jnp.zeros_like(v1[0]) + ) + + pg = preconditioner * g + if self._graft: + v2[0] += g * g + preconditioner_graft = jnp.where( + v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0]) + ) + pg_graft = preconditioner_graft * g + pg_norm = jnp.linalg.norm(pg) + pg_graft_norm = jnp.linalg.norm(pg_graft) + pg = pg * (pg_graft_norm / (pg_norm + 1e-16)) + + pg = pg + w * weight_decay + + if self._has_momentum: + m, update = self._momentum_update(pg, m, opt_params["momentum"]) + else: + update = pg + + w = w - (update * learning_rate).astype(w.dtype) + return w, (m, v1, v2) + + def _expanded_shape(self, shape, axis): + # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. + # For eg: i = 1 returns [1, N, 1]. 
+ rank = len(shape) + return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) + + def _minimum(self, tensor_list): + minimum = tensor_list[0] + for i in range(1, len(tensor_list)): + minimum = jnp.minimum(minimum, tensor_list[i]) + return minimum + + def _update_sketched(self, g, w, m, v1, v2, opt_params): + """Update for higher-rank parameters.""" + learning_rate = opt_params["learning_rate"] + momentum = opt_params["momentum"] + beta2 = opt_params["second_moment_averaging"] + weight_decay = opt_params["weight_decay"] + + shape = w.shape + rank = len(shape) + reshaped_accumulators = [ + jnp.reshape(v1[i], self._expanded_shape(shape, i)) for i in range(rank) + ] + acc = self._minimum(reshaped_accumulators) + + is_beta2_1 = (beta2 == 1).astype(g.dtype) + one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) + acc = beta2 * acc + one_minus_beta2_except1 * g * g + + preconditioner = jnp.where( + acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16), jnp.zeros_like(acc) + ) + pg = g * preconditioner + if self._graft: + v2_acc = self._minimum( + [ + jnp.reshape(v2[i], self._expanded_shape(shape, i)) + for i in range(rank) + ] + ) + v2_acc = v2_acc + g * g + preconditioner_graft = jnp.where( + v2_acc > 0.0, 1.0 / (jnp.sqrt(v2_acc) + 1e-16), jnp.zeros_like(v2_acc) + ) + pg_graft = preconditioner_graft * g + pg_norm = jnp.linalg.norm(pg) + pg_graft_norm = jnp.linalg.norm(pg_graft) + pg = pg * (pg_graft_norm / (pg_norm + 1e-16)) + + pg = pg + w * weight_decay + + if self._has_momentum: + m, update = self._momentum_update(pg, m, momentum) + else: + update = pg + + w = w - (learning_rate * update).astype(w.dtype) + for i in range(len(v1)): + axes = list(range(int(i))) + list(range(int(i) + 1, rank)) + dim_accumulator = jnp.amax(acc, axis=axes) + v1[i] = dim_accumulator + + if self._graft: + for i in range(len(v2)): + axes = list(range(int(i))) + list(range(int(i) + 1, rank)) + dim_accumulator = jnp.amax(v2_acc, axis=axes) + v2[i] = dim_accumulator + return w, (m, 
v1, v2) + + def update(self, step, g, w, slots, opt_params): + del step + m, v1, v2 = slots + rank = len(w.shape) + if rank > 1: + return self._update_sketched(g, w, m, v1, v2, opt_params) + else: + return self._update_diagonal(g, w, m, v1, v2, opt_params) diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index c633bface..1ce73517c 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -15,15 +15,13 @@ """Multi-device accelerated optimization.""" -from concurrent import futures import functools import os -import time -from absl import logging import jax import numpy as np import psutil +from absl import logging from trax import fastmath from trax import layers as tl @@ -33,722 +31,830 @@ class Trainer: - """Multi-device accelerated trainer. - - Given an optimizer and a composite layer containing model+loss, this class - creates a multi-device accelerated function with which it can compute one step - of updates to the model's weights/state and the optimizer slots. By default - it uses all available accelerators, via JIT compilation and parallel mapping. - - The optimizer and model must be initialized prior to use by this class. - - The key `one_step` function runs one forward-backward pass through the model, - and returns the resulting loss value and updated optimizer statistics. As a - side effect, the function also modifies the model weights and optimizer slots. 
- """ - - def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False): - self._model_with_loss = model_with_loss - self._optimizer = optimizer - self._n_devices = n_devices or fastmath.local_device_count() - self._adasum = adasum - - # optimizer slots and opt_params may need to be replicated - self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices( - (self._optimizer.slots, self._optimizer.opt_params), self._n_devices)) - - # accelerated version of model+loss to replicate weights and state - self._accelerated_model_with_loss = tl.Accelerate( - model_with_loss, n_devices=n_devices) - - # Signature: - # (batch, weights, state, rng) -> ((loss, state), gradients) - self._forward_and_backward_fn = ( - fastmath.value_and_grad( - model_with_loss.pure_fn, - argnums=1, # arg1 of pure_fn: weights - has_aux=True)) # return (loss, state), gradients - - # Signature: - # (weights, slots), step, opt_params, batch, state, rng -> - # (weights, slots), state, stats - self._accelerated_update_fn = ( - _accelerate_update_fn( + """Multi-device accelerated trainer. + + Given an optimizer and a composite layer containing model+loss, this class + creates a multi-device accelerated function with which it can compute one step + of updates to the model's weights/state and the optimizer slots. By default + it uses all available accelerators, via JIT compilation and parallel mapping. + + The optimizer and model must be initialized prior to use by this class. + + The key `one_step` function runs one forward-backward pass through the model, + and returns the resulting loss value and updated optimizer statistics. As a + side effect, the function also modifies the model weights and optimizer slots. 
+ """ + + def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False): + self._model_with_loss = model_with_loss + self._optimizer = optimizer + self._n_devices = n_devices or fastmath.local_device_count() + self._adasum = adasum + + # optimizer slots and opt_params may need to be replicated + self._slots, self._opt_params = tl.on_cpu( + tl.for_n_devices( + (self._optimizer.slots, self._optimizer.opt_params), self._n_devices + ) + ) + + # accelerated version of model+loss to replicate weights and state + self._accelerated_model_with_loss = tl.Accelerate( + model_with_loss, n_devices=n_devices + ) + + # Signature: + # (batch, weights, state, rng) -> ((loss, state), gradients) + self._forward_and_backward_fn = fastmath.value_and_grad( + model_with_loss.pure_fn, argnums=1, has_aux=True # arg1 of pure_fn: weights + ) # return (loss, state), gradients + + # Signature: + # (weights, slots), step, opt_params, batch, state, rng -> + # (weights, slots), state, stats + self._accelerated_update_fn = _accelerate_update_fn( self._forward_and_backward_fn, self._optimizer, n_devices=self._n_devices, accelerate=True, - adasum=self._adasum + adasum=self._adasum, ) - ) - - @property - def model_with_loss(self): - """Returns the composite model+loss for this instance.""" - return self._model_with_loss - - @property - def accelerated_model_with_loss(self): - """Returns the accelerated composite model+loss for this instance.""" - return self._accelerated_model_with_loss - @property - def optimizer(self): - """Returns the optimizer for this instance.""" - return self._optimizer - - @property - def slots(self): - """Returns the slots of the optimizers.""" - return self._optimizer.slots - - @slots.setter - def slots(self, slots): - """Sets the slots of the optimizers and this class (replicated).""" - self._optimizer.slots = slots - self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices)) + @property + def model_with_loss(self): + """Returns the composite 
model+loss for this instance.""" + return self._model_with_loss + + @property + def accelerated_model_with_loss(self): + """Returns the accelerated composite model+loss for this instance.""" + return self._accelerated_model_with_loss + + @property + def optimizer(self): + """Returns the optimizer for this instance.""" + return self._optimizer + + @property + def slots(self): + """Returns the slots of the optimizers.""" + return self._optimizer.slots + + @slots.setter + def slots(self, slots): + """Sets the slots of the optimizers and this class (replicated).""" + self._optimizer.slots = slots + self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices)) + + def one_step(self, batch, rng, step=0, learning_rate=None): + """Runs one training step, to update model and optimizer parameters. + + Args: + batch: Batch of labeled training data. + rng: Single-use random number generator (JAX PRNG key). + step: Training step number. + learning_rate: Learning rate for the optimizer; if None, use optimizer's + default learning rate. + + Returns: + Tuple of (loss, optimizer_stats), with the newly computed loss and + updated stats as reported by the optimizer. + """ + if learning_rate is not None: + self._opt_params["learning_rate"] = tl.for_n_devices( + learning_rate, self._n_devices + ) + + # Split the batch across devices (batch_dim --> batch_dim // n_devices) + # and create new rng's 1-1 with devices. + if self._n_devices > 1: + batch = tl.reshape_by_device(batch, self._n_devices) + rng = jnp.stack(fastmath.random.split(rng, self._n_devices)) + + weights = self._accelerated_model_with_loss.weights + state = self._accelerated_model_with_loss.state + if logging.vlog_is_on(1) and ((step & step - 1) == 0): + # Prints every power of two, if debugging is enabled. 
+ logging.info("step[%d]", step) + logging.info("opt_params[%s]", self._opt_params) + logging.info("slots[%s]", self._slots) + logging.info("weights[%s]", weights) + logging.info("state[%s]", state) + + # NOTE: stats is a replicated dictionary of key to jnp arrays. + (new_weights, new_slots), new_state, stats = self._accelerated_update_fn( + (weights, self._slots), step, self._opt_params, batch, state, rng + ) - def one_step(self, batch, rng, step=0, learning_rate=None): - """Runs one training step, to update model and optimizer parameters. + if logging.vlog_is_on(1) and ((step & step - 1) == 0): + logging.info("updated weights[%s]", new_weights) + logging.info("stats[%s]", stats) - Args: - batch: Batch of labeled training data. - rng: Single-use random number generator (JAX PRNG key). - step: Training step number. - learning_rate: Learning rate for the optimizer; if None, use optimizer's - default learning rate. + self._accelerated_model_with_loss.weights = new_weights + self._accelerated_model_with_loss.state = new_state + self._slots = new_slots + self._optimizer.slots = self._unreplicate(self._slots) + return stats["loss"], stats - Returns: - Tuple of (loss, optimizer_stats), with the newly computed loss and - updated stats as reported by the optimizer. - """ - if learning_rate is not None: - self._opt_params['learning_rate'] = tl.for_n_devices( - learning_rate, self._n_devices) - - # Split the batch across devices (batch_dim --> batch_dim // n_devices) - # and create new rng's 1-1 with devices. - if self._n_devices > 1: - batch = tl.reshape_by_device(batch, self._n_devices) - rng = jnp.stack(fastmath.random.split(rng, self._n_devices)) - - weights = self._accelerated_model_with_loss.weights - state = self._accelerated_model_with_loss.state - if logging.vlog_is_on(1) and ((step & step - 1) == 0): - # Prints every power of two, if debugging is enabled. 
- logging.info('step[%d]', step) - logging.info('opt_params[%s]', self._opt_params) - logging.info('slots[%s]', self._slots) - logging.info('weights[%s]', weights) - logging.info('state[%s]', state) - - # NOTE: stats is a replicated dictionary of key to jnp arrays. - (new_weights, new_slots), new_state, stats = self._accelerated_update_fn( - (weights, self._slots), step, self._opt_params, batch, state, rng) - - if logging.vlog_is_on(1) and ((step & step - 1) == 0): - logging.info('updated weights[%s]', new_weights) - logging.info('stats[%s]', stats) - - self._accelerated_model_with_loss.weights = new_weights - self._accelerated_model_with_loss.state = new_state - self._slots = new_slots - self._optimizer.slots = self._unreplicate(self._slots) - return stats['loss'], stats - - def _unreplicate(self, x): - if self._n_devices == 1: - return x - return fastmath.nested_map(lambda x: x[0], x) + def _unreplicate(self, x): + if self._n_devices == 1: + return x + return fastmath.nested_map(lambda x: x[0], x) def _adasum_merge(g1, g2): - """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf.""" - frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30) - frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30) - return (1 - frac1) * g1 + (1 - frac2) * g2 + """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf.""" + frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30) + frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30) + return (1 - frac1) * g1 + (1 - frac2) * g2 def _average_multidevice_gradients(gradients, adasum=False): - """Averages gradients over all the devices across different hosts.""" - n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS - if adasum: - # This implements a version of the Adasum algorithm from the following - # paper: https://arxiv.org/pdf/2006.02924.pdf - lg = max([i for i in range(20) if 2**i <= n]) - for lg_i in range(lg): - shift = 2**lg_i - perm = [] - for i in range(n): - block_i = i % 
(2*shift) # we do blocks of 2*shift size - if block_i < shift: - perm.append((i, i+shift)) - else: - perm.append((i, i-shift)) - perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch') - gradients = fastmath.nested_map_multiarg( - _adasum_merge, gradients, perm_grad) - if base.N_WEIGHTS_SHARDS > 1: # only sum gradients from matching shards - groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))] - for d in range(base.N_WEIGHTS_SHARDS)] - gradients_psum = fastmath.psum(gradients, 'batch', - axis_index_groups=groups) - else: - gradients_psum = fastmath.psum(gradients, 'batch') # sum all gradients - n = jnp.array(n, dtype=jnp.float32) - return fastmath.nested_map(lambda g: g / n, gradients_psum) + """Averages gradients over all the devices across different hosts.""" + n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS + if adasum: + # This implements a version of the Adasum algorithm from the following + # paper: https://arxiv.org/pdf/2006.02924.pdf + lg = max([i for i in range(20) if 2**i <= n]) + for lg_i in range(lg): + shift = 2**lg_i + perm = [] + for i in range(n): + block_i = i % (2 * shift) # we do blocks of 2*shift size + if block_i < shift: + perm.append((i, i + shift)) + else: + perm.append((i, i - shift)) + perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name="batch") + gradients = fastmath.nested_map_multiarg( + _adasum_merge, gradients, perm_grad + ) + if base.N_WEIGHTS_SHARDS > 1: # only sum gradients from matching shards + groups = [ + [base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))] + for d in range(base.N_WEIGHTS_SHARDS) + ] + gradients_psum = fastmath.psum(gradients, "batch", axis_index_groups=groups) + else: + gradients_psum = fastmath.psum(gradients, "batch") # sum all gradients + n = jnp.array(n, dtype=jnp.float32) + return fastmath.nested_map(lambda g: g / n, gradients_psum) # Returns a function with the following signature: # (weights, slots), step, opt_params, batch, state, rng -> # (weights, 
slots), state, stats -def _accelerate_update_fn(forward_and_backward_fn, - optimizer, - n_devices, - accelerate=True, - adasum=False): - """Accelerates the given forward_and_backward_fn function.""" - if n_devices == 1: - def single_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - step = jnp.array(step, dtype=jnp.int32) # Needed in TFNP backend. - weights, slots = weights_and_slots - (loss, state), gradients = forward_and_backward_fn( - batch, weights, state, rng) - weights, slots, stats = optimizer.tree_update( - step, gradients, weights, slots, opt_params, store_slots=False) - stats['loss'] = loss - return (weights, slots), state, stats - if accelerate: - # TODO(afrozm): Find out the status of buffer donation on GPUs, then do - # donate_argnums=(0,). - single_device_update_fn = fastmath.jit(single_device_update_fn) - return single_device_update_fn - - # More than one device (core), i.e. all of TPU configurations etc. - assert n_devices > 1, f'{n_devices} should be greater than 1.' - - @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,)) - def _multi_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - # All tensors should have the first dimension = n_devices. - weights, slots = weights_and_slots - (loss, state), gradients = ( - forward_and_backward_fn(batch, weights, state, rng)) - gradients = _average_multidevice_gradients(gradients, adasum=adasum) - weights, slots, stats = optimizer.tree_update( - step, gradients, weights, slots, opt_params, store_slots=False) - stats['loss'] = loss - return (weights, slots), state, stats - - def multi_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - # Need to replicate step to n_devices leading dimension. 
- return _multi_device_update_fn(weights_and_slots, - jnp.repeat(step, n_devices), opt_params, - batch, state, rng) - - return multi_device_update_fn - +def _accelerate_update_fn( + forward_and_backward_fn, optimizer, n_devices, accelerate=True, adasum=False +): + """Accelerates the given forward_and_backward_fn function.""" + if n_devices == 1: + + def single_device_update_fn( + weights_and_slots, step, opt_params, batch, state, rng + ): + step = jnp.array(step, dtype=jnp.int32) # Needed in TFNP backend. + weights, slots = weights_and_slots + + (loss, state), gradients = forward_and_backward_fn( + batch, weights, state, rng + ) + + weights, slots, stats = optimizer.tree_update( + step, gradients, weights, slots, opt_params, store_slots=False + ) + stats["loss"] = loss + return (weights, slots), state, stats + + if accelerate: + # TODO(afrozm): Find out the status of buffer donation on GPUs, then do + # donate_argnums=(0,). + single_device_update_fn = fastmath.jit(single_device_update_fn) + + return single_device_update_fn + + # More than one device (core), i.e. all of TPU configurations etc. + assert n_devices > 1, f"{n_devices} should be greater than 1." + + @functools.partial(fastmath.pmap, axis_name="batch", donate_argnums=(0,)) + def _multi_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng): + # All tensors should have the first dimension = n_devices. + weights, slots = weights_and_slots + (loss, state), gradients = forward_and_backward_fn(batch, weights, state, rng) + gradients = _average_multidevice_gradients(gradients, adasum=adasum) + weights, slots, stats = optimizer.tree_update( + step, gradients, weights, slots, opt_params, store_slots=False + ) + stats["loss"] = loss + return (weights, slots), state, stats + + def multi_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng): + # Need to replicate step to n_devices leading dimension. 
+ return _multi_device_update_fn( + weights_and_slots, + jnp.repeat(step, n_devices), + opt_params, + batch, + state, + rng, + ) -class ReversibleSerialTrainer: - """Runs an optimizer on a series of layers, reversible and not. + return multi_device_update_fn - We provide layers to this trainer in blocks, each block consisting of - a list of standard layers and a list of reversible layers. They all run - in turn (like one huge Serial block) but in a more memory-efficient way. - The main motivation for this class is to save memory: it allows to train - models that have more weights than the memory available on accelerators. - This happens by caching the weights in CPU memory and transferring only - the weights of one layer at a time. The reversible layers are used to make - the backward pass without using additional memory for storing activations. +class ReversibleSerialTrainer: + """Runs an optimizer on a series of layers, reversible and not. - Note: we do not allow sharing weights between blocks for now. - """ + We provide layers to this trainer in blocks, each block consisting of + a list of standard layers and a list of reversible layers. They all run + in turn (like one huge Serial block) but in a more memory-efficient way. - def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None, - memoize_jit=True, free_accelerators_on_step=False, adasum=False): - """Creates a ReversibleSerialTrainer and the needed optimizers. + The main motivation for this class is to save memory: it allows to train + models that have more weights than the memory available on accelerators. + This happens by caching the weights in CPU memory and transferring only + the weights of one layer at a time. The reversible layers are used to make + the backward pass without using additional memory for storing activations. - This trainer performs updates equivalent to using the default Trainer on:: + Note: we do not allow sharing weights between blocks for now. 
+ """ - tl.Serial(blocks + [loss_layer]). + def __init__( + self, + blocks, + loss_layer, + optimizer_fn, + n_devices=None, + memoize_jit=True, + free_accelerators_on_step=False, + adasum=False, + ): + """Creates a ReversibleSerialTrainer and the needed optimizers. + + This trainer performs updates equivalent to using the default Trainer on:: + + tl.Serial(blocks + [loss_layer]). + + It is more memory-efficient though since weights are stored on CPU and only + sent to accelerator layer-by-layer. Blocks are pairs consisting of a list + of standard (arbitrary) layers and a list of reversible layers which help + save memory thanks to being reversible. + + Args: + blocks: A list of pairs of lists of standard and reversible layers. + loss_layer: The final layer of the model; it can have trainable weights + but should end with a loss: it is required to produce a scalar output. + optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`. + n_devices: An optional integer, number of accelerator devices to use; + by default, all available accelerators will be used. + memoize_jit: Whether to memoize JITed functions; this significantly speeds + up XLA compilation of larger models, but it uses `repr(layer)` as keys + to memoize so it could fail if two layers with different functionality + had the same string representaion. We have not encountered such case + yet so this is turned on by default, but consider turning it off or + reviewing your model if you use custom layers and encounter a problem. + free_accelerators_on_step: If true, frees memory on accelerators when + starting a step. All layers and arguments must be on host for that, + otherwise it can lead to failures. Can prevent memory fragmentation. + adasum: if True, use adaptive summation to gather multi-device gradients. 
+ """ + self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks] + self._loss_layer = loss_layer + self._optimizer_fn = optimizer_fn + self._n_devices = n_devices or fastmath.local_device_count() + self._adasum = adasum + self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks]) + self._n_steps_per_log = 100 # Log layers and stats every 100 steps. + self._n_async_layers = 1 # How many layers to run asynchronously. + self._jit_memory = {} if memoize_jit else None + self._do_free = free_accelerators_on_step + self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs, backend="cpu") + + # Create accelerated versions of layers as pmaped/jited pure_fn. + self._accelerated_layer_fns = fastmath.nested_map( + lambda layer: self._pjit(layer.pure_fn, f"fwd {repr(layer)}"), self._blocks + ) - It is more memory-efficient though since weights are stored on CPU and only - sent to accelerator layer-by-layer. Blocks are pairs consisting of a list - of standard (arbitrary) layers and a list of reversible layers which help - save memory thanks to being reversible. + # Create per-layer optimizers and replicate opt_params. + def _make_optimizer(layer): + opt = optimizer_fn() + opt.tree_init(layer.weights) + opt.slots = tl.on_cpu(opt.slots) + return opt - Args: - blocks: A list of pairs of lists of standard and reversible layers. - loss_layer: The final layer of the model; it can have trainable weights - but should end with a loss: it is required to produce a scalar output. - optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`. - n_devices: An optional integer, number of accelerator devices to use; - by default, all available accelerators will be used. - memoize_jit: Whether to memoize JITed functions; this significantly speeds - up XLA compilation of larger models, but it uses `repr(layer)` as keys - to memoize so it could fail if two layers with different functionality - had the same string representaion. 
We have not encountered such case - yet so this is turned on by default, but consider turning it off or - reviewing your model if you use custom layers and encounter a problem. - free_accelerators_on_step: If true, frees memory on accelerators when - starting a step. All layers and arguments must be on host for that, - otherwise it can lead to failures. Can prevent memory fragmentation. - adasum: if True, use adaptive summation to gather multi-device gradients. - """ - self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks] - self._loss_layer = loss_layer - self._optimizer_fn = optimizer_fn - self._n_devices = n_devices or fastmath.local_device_count() - self._adasum = adasum - self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks]) - self._n_steps_per_log = 100 # Log layers and stats every 100 steps. - self._n_async_layers = 1 # How many layers to run asynchronously. - self._jit_memory = {} if memoize_jit else None - self._do_free = free_accelerators_on_step - self._jit_per_device_rngs = fastmath.jit( - self._per_device_rngs, backend='cpu') - - # Create accelerated versions of layers as pmaped/jited pure_fn. - self._accelerated_layer_fns = fastmath.nested_map( - lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'), - self._blocks) - - # Create per-layer optimizers and replicate opt_params. - def _make_optimizer(layer): - opt = optimizer_fn() - opt.tree_init(layer.weights) - opt.slots = tl.on_cpu(opt.slots) - return opt - - self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks) - self._replicated_opt_params = fastmath.nested_map( - lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers) - - self._loss_opt = _make_optimizer(loss_layer) - self._replicated_loss_opt_params = self._replicate_cpu( - self._loss_opt.opt_params) - - # Forward + backward + optimizer-update functions for all layers. - # We call them in short FBO for "Forward + Backward + Optimizer update". 
- # Reversible layers define a reverse_and_fbo function that also reverses. - - self._fbos = [] - for i, (std_layer, rev_layers) in enumerate(self._blocks): - (std_opt, rev_opts) = self._optimizers[i] - std_fbo = _fbo_with_layer_and_opt( - std_layer, std_opt, self._n_devices, adasum=self._adasum) - rev_and_fbos = [] - for layer, opt in zip(rev_layers, rev_opts): - rev_and_fbo = _reverse_and_fbo_with_layer_and_opt( - layer, opt, self._n_devices, self._adasum) - # The donated args are (outputs, weights, grads) and we can donate - # them because weights and grads are immediately replaced and in - # case of reversible layers, the outputs are never used again. - rev_and_fbos.append(self._pjit( - rev_and_fbo, f'rev+bwd {repr(layer)}', donate_argnums=(0, 1, 2))) - # In standard layers, the inputs cannot be donated as they may be used - # as outputs for the reversible block below, but weights and grads can. - jit_std_fbo = self._pjit( - std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2)) - self._fbos.append((jit_std_fbo, rev_and_fbos)) - - loss_fbo = _fbo_with_layer_and_opt( - self._loss_layer, self._loss_opt, self._n_devices, 'loss', self._adasum) - self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2)) - - @property - def loss_layer(self): - """Returns the loss layer used to initialize this class.""" - return self._loss_layer - - @property - def all_layers(self): - """Returns all layers that compose the model and loss in this class.""" - layers = [] - for (std_layer, rev_layers) in self._blocks: - layers.append(std_layer) - layers.extend(rev_layers) - layers.append(self._loss_layer) - return layers - - @property - def optimizer_fn(self): - """Returns the optimizer function used to initialize this class.""" - return self._optimizer_fn - - @property - def slots(self): - """Returns the slots of all optimizers.""" - optimizers = list(self._optimizers) + [self._loss_opt] - return fastmath.nested_map(lambda opt: opt.slots, optimizers) - - @slots.setter - def 
slots(self, slots): - """Sets the slots of all optimizers.""" - for ((s_opt, r_opts), (s_slots, r_slots)) in zip( - self._optimizers, slots[:-1]): - for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots): - opt.slots = slot - self._loss_opt.slots = slots[-1] - - def _pjit(self, f, memory_key=None, donate_argnums=()): - """JIT f if 1 device is available and pmap if more are available.""" - should_memoize = self._jit_memory is not None and memory_key is not None - if (should_memoize and memory_key in self._jit_memory): - logging.info('Found JITed function in memory for: %s', memory_key) - return self._jit_memory[memory_key] - if self._n_devices == 1: - res = fastmath.jit(f, donate_argnums=donate_argnums) - else: - res = fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums) - if should_memoize: - self._jit_memory[memory_key] = res - return res - - def _replicate(self, x): - if self._n_devices > 1: - return tl.for_n_devices(x, self._n_devices) - return tl.on_accelerator(x) - - def _replicate_cpu(self, x): - # TODO(lukaszkaiser): move it to layers/acceleration to be together with - # tl.for_n_devices and other functions like that, possibly refactor them. 
- def f(x): - if self._n_devices > 1: - return np.broadcast_to(x, (self._n_devices,) + np.asarray(x).shape) - else: - return x - return tl.on_cpu(fastmath.nested_map(f, x)) - - def _unreplicate(self, x): - if self._n_devices == 1: - return tl.on_cpu(x) - return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x)) - - def _lazy_unreplicate(self, x): - def unreplicate_and_start_async_copy(y): - unreplicated = y if self._n_devices == 1 else y[0] - unreplicated.copy_to_host_async() - return unreplicated - return fastmath.nested_map(unreplicate_and_start_async_copy, x) - - def _collect_weights(self, layer): - layer.weights = fastmath.nested_map(np.asarray, layer.weights) - - def _free_accelerators(self, exceptions=(), keep_constants=True): - """Deletes all live buffers from accelerator with no safety guarantees.""" - backend = jax.lib.xla_bridge.get_backend() - live_buffers = backend.live_buffers() - logging.info('Deleting %d live buffers.', len(live_buffers)) - exceptions_buffers = [] - for x in fastmath.tree_flatten(exceptions): - if hasattr(x, 'device_buffer'): # DeviceArray - exceptions_buffers.append(x.device_buffer) - if hasattr(x, 'device_buffers'): # ShardedDeviceArray - exceptions_buffers.extend(x.device_buffers) - for b in live_buffers: - should_delete = True - for e in exceptions_buffers: - if b is e: - should_delete = False - if keep_constants and not b.shape: - should_delete = False - if should_delete: - b.delete() - - def _per_device_rngs(self, rng): - """Create per-device RNGs from a given rng.""" - # Splitting by device first to be identical with default trainer. - per_device_rng = fastmath.random.split(rng, self._n_devices) - per_device_rngs = [ - fastmath.random.split(r, self._n_layers) for r in per_device_rng] - rngs = [jnp.stack([r[i] for r in per_device_rngs]) - for i in range(self._n_layers)] - return rngs - - def one_step(self, batch, rng, step=0, learning_rate=None): - """Updates layers weights/state and optimizers slots by running one step. 
+ self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks) + self._replicated_opt_params = fastmath.nested_map( + lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers + ) - Args: - batch: Batch of data to use for optimization. - rng: Random number generator to use for running this step. - step: Which step of the training are we running. - learning_rate: Learning rate to use instead of the default one. + self._loss_opt = _make_optimizer(loss_layer) + self._replicated_loss_opt_params = self._replicate_cpu( + self._loss_opt.opt_params + ) - Returns: - Tuple (loss, stats) with new values from one step - of training, where stats are all optimizer statistics. - """ - # Update the learning rate if needed. - if learning_rate is not None: - self._replicated_loss_opt_params['learning_rate'] = self._replicate_cpu( - learning_rate) - for (std_op, rev_ops) in self._replicated_opt_params: - std_op['learning_rate'] = self._replicate_cpu(learning_rate) - for op in rev_ops: - op['learning_rate'] = self._replicate_cpu(learning_rate) - - # Batch needs to be split across the local devices -- the difference - # between _for_n_devices and _reshape_by_device is that the latter splits - # the batch dim to batch // n_devices, vs _for_n_devices - # broadcasts/replicates to n_devices dimension. - step_int = step - if self._n_devices > 1: - batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True) - step = np.repeat(step, self._n_devices) - - # Create separate rng for each device and layer. - if self._n_devices == 1: - rngs = fastmath.random.split(rng, self._n_layers) - else: - # JIT the function and run it on CPU to avoid memory fragmentation. - rngs = self._jit_per_device_rngs(tl.on_cpu(rng)) - # Group rngs by layer blocks. - rng_blocks, rng_i = [], 0 - for _, rev_layers in self._blocks: - l = len(rev_layers) - rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1])) - rng_i += l + 1 - - # Run the layers forward upto the loss layer. 
- if self._do_free: - self._free_accelerators() - process = psutil.Process(os.getpid()) - if isinstance(batch, (list, tuple)): - batch_shapes = [x.shape for x in batch] - else: - batch_shapes = batch.shape - logging.info('running step %d on shapes %s', step_int, str(batch_shapes)) - if step_int % self._n_steps_per_log == 1: - logging.info('run fwd: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - - stack = batch - block_inputs_states = [] - for i, (std_layer, rev_layers) in enumerate(self._blocks): - acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i] - std_rng, rev_rngs = rng_blocks[i] - # Run the standard layer. - stack, std_inputs, std_state = self._run_forward_standard( - stack, std_layer, acc_std_layer_fn, std_rng, step_int) - - # Run the reversible layers and collect old and new states. - stack, rev_old_states, rev_new_states = self._run_forward_reversible( - stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int) - block_inputs_states.append(tl.on_cpu( - ((std_inputs, std_state), (rev_old_states, rev_new_states)))) - - # Run the loss layer forward and backward with optimizer update. - if step_int % self._n_steps_per_log == 1: - logging.info('run loss: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - loss_state = self._replicate(self._loss_layer.state) - loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in) - loss_stats, grad_stack = self._run_backward_standard( - None, step, self._loss_layer, loss_inputs, - loss_state, self._loss_fbo, rngs[-1], self._loss_opt, - self._replicated_loss_opt_params) - self._collect_weights(self._loss_layer) - stats = [tl.on_cpu(loss_stats)] - - # De-fragment memory. - if self._do_free: - stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack) - self._free_accelerators() - - # Run the layers backward and run optimizer updates. 
- if step_int % self._n_steps_per_log == 1: - logging.info('run bwd: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - for i in range(len(self._blocks) - 1, -1, -1): - std_layer, rev_layers = self._blocks[i] - (std_inputs, std_state), (rev_old_states, - rev_new_states) = block_inputs_states[i] - std_fbo, rev_fbos = self._fbos[i] - std_opt, rev_opts = self._optimizers[i] - std_rng, rev_rngs = rng_blocks[i] - repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i] - - # Run reversible layers backward with optimizer update. - stack, grad_stack, new_stats = self._run_backward_reversible( - stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states, - rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params) - stats.extend(tl.on_cpu(new_stats)) - - # Run the standard layer forward-and-backward pass and optimizer update. - std_layer_stats, grad_stack = self._run_backward_standard( - grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng, - std_opt, repl_std_opt_params) - stack = cb.outputs_onto_stack( # Put layer inputs on the stack. - std_inputs, stack, std_layer.n_out) - stats.append(tl.on_cpu(std_layer_stats)) - - # Collect lazily unreplicated layer weights. - for rev_layer_id in range(self._n_async_layers): - self._collect_weights(rev_layers[rev_layer_id]) - self._collect_weights(std_layer) - - # Join stats from different optimizers into one. 
- joint_stats = {} - for i, stat in enumerate(reversed(stats)): - for k, v in stat.items(): - joint_stats[f'layer{i}/' + k] = v - return stats[0]['loss'], joint_stats - - def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step): - """Run standard layer forward.""" - if step % self._n_steps_per_log == 1: - logging.info('running forward standard layer %s', str(layer)) - layer_inputs = cb.inputs_from_stack(stack, layer.n_in) - layer_weights = self._replicate(layer.weights) - layer_state = self._replicate(layer.state) - outputs, layer_new_state = accelerated_fn( - layer_inputs, layer_weights, layer_state, rng) - stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) - return stack, layer_inputs, layer_new_state - - def _run_forward_reversible(self, stack, rev_layers, accelerated_fns, - rngs, step): - """Run reversible layers forward, collect states for backwards pass.""" - old_states, new_states = [], [] - for i, layer in enumerate(rev_layers): - if step % self._n_steps_per_log == 1: - logging.info('running forward reversible layer %s', str(layer)) - weights = self._replicate(layer.weights) # also copies cpu -> accelerator - state = self._replicate(layer.state) - old_states.append(state) - inputs = cb.inputs_from_stack(stack, layer.n_in) - outputs, new_state = accelerated_fns[i]( - inputs, weights, state, rngs[i]) - stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) - new_states.append(new_state) - return stack, old_states, new_states - - def _run_backward_standard(self, grad_stack, step, layer, inp, state, - fbo_fn, rng, optimizer, replicated_opt_params): - """Run reversible layers backwards.""" - step_int = int(step) if self._n_devices < 2 else int(step[0]) - if step_int % self._n_steps_per_log == 1: - logging.info('running backward standard layer %s', str(layer)) - if grad_stack is not None: - grads = cb.inputs_from_stack(grad_stack, layer.n_out) - else: - grads = None - slots = self._replicate(optimizer.slots) - weights = 
self._replicate(layer.weights) - # Ensure all arguments are on accelerator. - state = tl.on_accelerator(state) - replicated_opt_params = tl.on_accelerator(replicated_opt_params) - rng = tl.on_accelerator(rng) - grads = tl.on_accelerator(grads) - inp = tl.on_accelerator(inp) - new_weights, new_state, new_slots, new_grads, stats = fbo_fn( - inp, weights, grads, state, slots, replicated_opt_params, rng, step) - layer.weights = self._lazy_unreplicate(new_weights) - layer.state = self._unreplicate(new_state) - optimizer.slots = self._unreplicate(new_slots) - if grad_stack is not None: - grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out) - else: - grad_stack = new_grads - return stats, grad_stack - - def _run_backward_reversible(self, stack, grad_stack, step, - rev_layers, rev_and_fbos, - old_states, new_states, rngs, - optimizers, replicated_opt_params): - """Run reversible layers backwards.""" - counter = 0 - stats = [] - step_int = int(step) if self._n_devices < 2 else int(step[0]) - for layer, reverse_and_fbo, old_state, new_state, rng in reversed(list(zip( - rev_layers, rev_and_fbos, - old_states, new_states, rngs))): - if step_int % self._n_steps_per_log == 1: - logging.info('running backward reversible layer %s', str(layer)) - counter -= 1 - stack, grad_stack, layer_stats = self._run_backward_one_reversible( - layer, stack, grad_stack, step, rng, optimizers[counter], - replicated_opt_params[counter], reverse_and_fbo, old_state, new_state) - stats.append(layer_stats) - if counter + self._n_async_layers < 0: - self._collect_weights(rev_layers[counter + self._n_async_layers]) - return stack, grad_stack, stats - - def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng, - optimizer, opt_params, reverse_and_fbo, - old_state, new_state): - """Run one reversible layer backwards.""" - # We are running backwards and reversing, so we get *outputs* from stack. 
- outputs = cb.inputs_from_stack(stack, layer.n_out) - grads = cb.inputs_from_stack(grad_stack, layer.n_out) - slots = self._replicate(optimizer.slots) - weights = self._replicate(layer.weights) # cpu -> accelerator - # Ensure all arguments are on accelerator. - outputs = tl.on_accelerator(outputs) - grads = tl.on_accelerator(grads) - old_state = tl.on_accelerator(old_state) - new_state = tl.on_accelerator(new_state) - opt_params = tl.on_accelerator(opt_params) - rng = tl.on_accelerator(rng) - new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo( - outputs, weights, grads, old_state, new_state, - slots, opt_params, rng, step) - layer.weights = self._lazy_unreplicate(new_weights) # accelerator -> cpu - layer.state = self._unreplicate(new_state) - optimizer.slots = self._unreplicate(new_slots) - stack = cb.outputs_onto_stack(inputs, stack, layer.n_out) - grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out) - return stack, grad_stack, layer_stats + # Forward + backward + optimizer-update functions for all layers. + # We call them in short FBO for "Forward + Backward + Optimizer update". + # Reversible layers define a reverse_and_fbo function that also reverses. + + self._fbos = [] + for i, (std_layer, rev_layers) in enumerate(self._blocks): + (std_opt, rev_opts) = self._optimizers[i] + std_fbo = _fbo_with_layer_and_opt( + std_layer, std_opt, self._n_devices, adasum=self._adasum + ) + rev_and_fbos = [] + for layer, opt in zip(rev_layers, rev_opts): + rev_and_fbo = _reverse_and_fbo_with_layer_and_opt( + layer, opt, self._n_devices, self._adasum + ) + # The donated args are (outputs, weights, grads) and we can donate + # them because weights and grads are immediately replaced and in + # case of reversible layers, the outputs are never used again. 
+ rev_and_fbos.append( + self._pjit( + rev_and_fbo, f"rev+bwd {repr(layer)}", donate_argnums=(0, 1, 2) + ) + ) + # In standard layers, the inputs cannot be donated as they may be used + # as outputs for the reversible block below, but weights and grads can. + jit_std_fbo = self._pjit( + std_fbo, f"bwd {repr(std_layer)}", donate_argnums=(1, 2) + ) + self._fbos.append((jit_std_fbo, rev_and_fbos)) + + loss_fbo = _fbo_with_layer_and_opt( + self._loss_layer, self._loss_opt, self._n_devices, "loss", self._adasum + ) + self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2)) + + @property + def loss_layer(self): + """Returns the loss layer used to initialize this class.""" + return self._loss_layer + + @property + def all_layers(self): + """Returns all layers that compose the model and loss in this class.""" + layers = [] + for (std_layer, rev_layers) in self._blocks: + layers.append(std_layer) + layers.extend(rev_layers) + layers.append(self._loss_layer) + return layers + + @property + def optimizer_fn(self): + """Returns the optimizer function used to initialize this class.""" + return self._optimizer_fn + + @property + def slots(self): + """Returns the slots of all optimizers.""" + optimizers = list(self._optimizers) + [self._loss_opt] + return fastmath.nested_map(lambda opt: opt.slots, optimizers) + + @slots.setter + def slots(self, slots): + """Sets the slots of all optimizers.""" + for ((s_opt, r_opts), (s_slots, r_slots)) in zip(self._optimizers, slots[:-1]): + for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots): + opt.slots = slot + self._loss_opt.slots = slots[-1] + + def _pjit(self, f, memory_key=None, donate_argnums=()): + """JIT f if 1 device is available and pmap if more are available.""" + should_memoize = self._jit_memory is not None and memory_key is not None + if should_memoize and memory_key in self._jit_memory: + logging.info("Found JITed function in memory for: %s", memory_key) + return self._jit_memory[memory_key] + if self._n_devices == 
1: + res = fastmath.jit(f, donate_argnums=donate_argnums) + else: + res = fastmath.pmap(f, axis_name="batch", donate_argnums=donate_argnums) + if should_memoize: + self._jit_memory[memory_key] = res + return res + + def _replicate(self, x): + if self._n_devices > 1: + return tl.for_n_devices(x, self._n_devices) + return tl.on_accelerator(x) + + def _replicate_cpu(self, x): + # TODO(lukaszkaiser): move it to layers/acceleration to be together with + # tl.for_n_devices and other functions like that, possibly refactor them. + def f(x): + if self._n_devices > 1: + return np.broadcast_to(x, (self._n_devices,) + np.asarray(x).shape) + else: + return x + + return tl.on_cpu(fastmath.nested_map(f, x)) + + def _unreplicate(self, x): + if self._n_devices == 1: + return tl.on_cpu(x) + return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x)) + + def _lazy_unreplicate(self, x): + def unreplicate_and_start_async_copy(y): + unreplicated = y if self._n_devices == 1 else y[0] + unreplicated.copy_to_host_async() + return unreplicated + + return fastmath.nested_map(unreplicate_and_start_async_copy, x) + + def _collect_weights(self, layer): + layer.weights = fastmath.nested_map(np.asarray, layer.weights) + + def _free_accelerators(self, exceptions=(), keep_constants=True): + """Deletes all live buffers from accelerator with no safety guarantees.""" + backend = jax.lib.xla_bridge.get_backend() + live_buffers = backend.live_buffers() + logging.info("Deleting %d live buffers.", len(live_buffers)) + exceptions_buffers = [] + for x in fastmath.tree_flatten(exceptions): + if hasattr(x, "device_buffer"): # DeviceArray + exceptions_buffers.append(x.device_buffer) + if hasattr(x, "device_buffers"): # ShardedDeviceArray + exceptions_buffers.extend(x.device_buffers) + for b in live_buffers: + should_delete = True + for e in exceptions_buffers: + if b is e: + should_delete = False + if keep_constants and not b.shape: + should_delete = False + if should_delete: + b.delete() + + def 
_per_device_rngs(self, rng): + """Create per-device RNGs from a given rng.""" + # Splitting by device first to be identical with default trainer. + per_device_rng = fastmath.random.split(rng, self._n_devices) + per_device_rngs = [ + fastmath.random.split(r, self._n_layers) for r in per_device_rng + ] + rngs = [ + jnp.stack([r[i] for r in per_device_rngs]) for i in range(self._n_layers) + ] + return rngs + + def one_step(self, batch, rng, step=0, learning_rate=None): + """Updates layers weights/state and optimizers slots by running one step. + + Args: + batch: Batch of data to use for optimization. + rng: Random number generator to use for running this step. + step: Which step of the training are we running. + learning_rate: Learning rate to use instead of the default one. + + Returns: + Tuple (loss, stats) with new values from one step + of training, where stats are all optimizer statistics. + """ + # Update the learning rate if needed. + if learning_rate is not None: + self._replicated_loss_opt_params["learning_rate"] = self._replicate_cpu( + learning_rate + ) + for (std_op, rev_ops) in self._replicated_opt_params: + std_op["learning_rate"] = self._replicate_cpu(learning_rate) + for op in rev_ops: + op["learning_rate"] = self._replicate_cpu(learning_rate) + + # Batch needs to be split across the local devices -- the difference + # between _for_n_devices and _reshape_by_device is that the latter splits + # the batch dim to batch // n_devices, vs _for_n_devices + # broadcasts/replicates to n_devices dimension. + step_int = step + if self._n_devices > 1: + batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True) + step = np.repeat(step, self._n_devices) + + # Create separate rng for each device and layer. + if self._n_devices == 1: + rngs = fastmath.random.split(rng, self._n_layers) + else: + # JIT the function and run it on CPU to avoid memory fragmentation. + rngs = self._jit_per_device_rngs(tl.on_cpu(rng)) + # Group rngs by layer blocks. 
+ rng_blocks, rng_i = [], 0 + for _, rev_layers in self._blocks: + l = len(rev_layers) + rng_blocks.append((rngs[rng_i], rngs[rng_i + 1 : rng_i + l + 1])) + rng_i += l + 1 + + # Run the layers forward upto the loss layer. + if self._do_free: + self._free_accelerators() + process = psutil.Process(os.getpid()) + if isinstance(batch, (list, tuple)): + batch_shapes = [x.shape for x in batch] + else: + batch_shapes = batch.shape + logging.info("running step %d on shapes %s", step_int, str(batch_shapes)) + if step_int % self._n_steps_per_log == 1: + logging.info( + "run fwd: cpu memory use (MB): %.2f", + process.memory_info().rss / float(1024 * 1024), + ) + + stack = batch + block_inputs_states = [] + for i, (std_layer, rev_layers) in enumerate(self._blocks): + acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i] + std_rng, rev_rngs = rng_blocks[i] + # Run the standard layer. + stack, std_inputs, std_state = self._run_forward_standard( + stack, std_layer, acc_std_layer_fn, std_rng, step_int + ) + + # Run the reversible layers and collect old and new states. + stack, rev_old_states, rev_new_states = self._run_forward_reversible( + stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int + ) + block_inputs_states.append( + tl.on_cpu(((std_inputs, std_state), (rev_old_states, rev_new_states))) + ) + + # Run the loss layer forward and backward with optimizer update. + if step_int % self._n_steps_per_log == 1: + logging.info( + "run loss: cpu memory use (MB): %.2f", + process.memory_info().rss / float(1024 * 1024), + ) + loss_state = self._replicate(self._loss_layer.state) + loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in) + loss_stats, grad_stack = self._run_backward_standard( + None, + step, + self._loss_layer, + loss_inputs, + loss_state, + self._loss_fbo, + rngs[-1], + self._loss_opt, + self._replicated_loss_opt_params, + ) + self._collect_weights(self._loss_layer) + stats = [tl.on_cpu(loss_stats)] + + # De-fragment memory. 
+ if self._do_free: + stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack) + self._free_accelerators() + + # Run the layers backward and run optimizer updates. + if step_int % self._n_steps_per_log == 1: + logging.info( + "run bwd: cpu memory use (MB): %.2f", + process.memory_info().rss / float(1024 * 1024), + ) + for i in range(len(self._blocks) - 1, -1, -1): + std_layer, rev_layers = self._blocks[i] + (std_inputs, std_state), ( + rev_old_states, + rev_new_states, + ) = block_inputs_states[i] + std_fbo, rev_fbos = self._fbos[i] + std_opt, rev_opts = self._optimizers[i] + std_rng, rev_rngs = rng_blocks[i] + repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i] + + # Run reversible layers backward with optimizer update. + stack, grad_stack, new_stats = self._run_backward_reversible( + stack, + grad_stack, + step, + rev_layers, + rev_fbos, + rev_old_states, + rev_new_states, + rev_rngs, + rev_opts, + repl_rev_opts_params, + ) + stats.extend(tl.on_cpu(new_stats)) + + # Run the standard layer forward-and-backward pass and optimizer update. + std_layer_stats, grad_stack = self._run_backward_standard( + grad_stack, + step, + std_layer, + std_inputs, + std_state, + std_fbo, + std_rng, + std_opt, + repl_std_opt_params, + ) + stack = cb.outputs_onto_stack( # Put layer inputs on the stack. + std_inputs, stack, std_layer.n_out + ) + stats.append(tl.on_cpu(std_layer_stats)) + + # Collect lazily unreplicated layer weights. + for rev_layer_id in range(self._n_async_layers): + self._collect_weights(rev_layers[rev_layer_id]) + self._collect_weights(std_layer) + + # Join stats from different optimizers into one. 
+ joint_stats = {} + for i, stat in enumerate(reversed(stats)): + for k, v in stat.items(): + joint_stats[f"layer{i}/" + k] = v + return stats[0]["loss"], joint_stats + + def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step): + """Run standard layer forward.""" + if step % self._n_steps_per_log == 1: + logging.info("running forward standard layer %s", str(layer)) + layer_inputs = cb.inputs_from_stack(stack, layer.n_in) + layer_weights = self._replicate(layer.weights) + layer_state = self._replicate(layer.state) + outputs, layer_new_state = accelerated_fn( + layer_inputs, layer_weights, layer_state, rng + ) + stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) + return stack, layer_inputs, layer_new_state + + def _run_forward_reversible(self, stack, rev_layers, accelerated_fns, rngs, step): + """Run reversible layers forward, collect states for backwards pass.""" + old_states, new_states = [], [] + for i, layer in enumerate(rev_layers): + if step % self._n_steps_per_log == 1: + logging.info("running forward reversible layer %s", str(layer)) + weights = self._replicate(layer.weights) # also copies cpu -> accelerator + state = self._replicate(layer.state) + old_states.append(state) + inputs = cb.inputs_from_stack(stack, layer.n_in) + outputs, new_state = accelerated_fns[i](inputs, weights, state, rngs[i]) + stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) + new_states.append(new_state) + return stack, old_states, new_states + + def _run_backward_standard( + self, + grad_stack, + step, + layer, + inp, + state, + fbo_fn, + rng, + optimizer, + replicated_opt_params, + ): + """Run reversible layers backwards.""" + step_int = int(step) if self._n_devices < 2 else int(step[0]) + if step_int % self._n_steps_per_log == 1: + logging.info("running backward standard layer %s", str(layer)) + if grad_stack is not None: + grads = cb.inputs_from_stack(grad_stack, layer.n_out) + else: + grads = None + slots = self._replicate(optimizer.slots) + 
weights = self._replicate(layer.weights) + # Ensure all arguments are on accelerator. + state = tl.on_accelerator(state) + replicated_opt_params = tl.on_accelerator(replicated_opt_params) + rng = tl.on_accelerator(rng) + grads = tl.on_accelerator(grads) + inp = tl.on_accelerator(inp) + new_weights, new_state, new_slots, new_grads, stats = fbo_fn( + inp, weights, grads, state, slots, replicated_opt_params, rng, step + ) + layer.weights = self._lazy_unreplicate(new_weights) + layer.state = self._unreplicate(new_state) + optimizer.slots = self._unreplicate(new_slots) + if grad_stack is not None: + grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out) + else: + grad_stack = new_grads + return stats, grad_stack + + def _run_backward_reversible( + self, + stack, + grad_stack, + step, + rev_layers, + rev_and_fbos, + old_states, + new_states, + rngs, + optimizers, + replicated_opt_params, + ): + """Run reversible layers backwards.""" + counter = 0 + stats = [] + step_int = int(step) if self._n_devices < 2 else int(step[0]) + for layer, reverse_and_fbo, old_state, new_state, rng in reversed( + list(zip(rev_layers, rev_and_fbos, old_states, new_states, rngs)) + ): + if step_int % self._n_steps_per_log == 1: + logging.info("running backward reversible layer %s", str(layer)) + counter -= 1 + stack, grad_stack, layer_stats = self._run_backward_one_reversible( + layer, + stack, + grad_stack, + step, + rng, + optimizers[counter], + replicated_opt_params[counter], + reverse_and_fbo, + old_state, + new_state, + ) + stats.append(layer_stats) + if counter + self._n_async_layers < 0: + self._collect_weights(rev_layers[counter + self._n_async_layers]) + return stack, grad_stack, stats + + def _run_backward_one_reversible( + self, + layer, + stack, + grad_stack, + step, + rng, + optimizer, + opt_params, + reverse_and_fbo, + old_state, + new_state, + ): + """Run one reversible layer backwards.""" + # We are running backwards and reversing, so we get *outputs* from stack. 
+ outputs = cb.inputs_from_stack(stack, layer.n_out) + grads = cb.inputs_from_stack(grad_stack, layer.n_out) + slots = self._replicate(optimizer.slots) + weights = self._replicate(layer.weights) # cpu -> accelerator + # Ensure all arguments are on accelerator. + outputs = tl.on_accelerator(outputs) + grads = tl.on_accelerator(grads) + old_state = tl.on_accelerator(old_state) + new_state = tl.on_accelerator(new_state) + opt_params = tl.on_accelerator(opt_params) + rng = tl.on_accelerator(rng) + new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo( + outputs, weights, grads, old_state, new_state, slots, opt_params, rng, step + ) + layer.weights = self._lazy_unreplicate(new_weights) # accelerator -> cpu + layer.state = self._unreplicate(new_state) + optimizer.slots = self._unreplicate(new_slots) + stack = cb.outputs_onto_stack(inputs, stack, layer.n_out) + grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out) + return stack, grad_stack, layer_stats # Forward + backward + optimizer-update functions for all layers. # We call them in short FBO for "Forward + Backward + Optimizer update". -def _fbo_with_layer_and_opt(layer, optimizer, n_devices, - stats_name=None, adasum=False): - """Create the fbo function for a given layer and optimizer.""" - def fbo(inputs, weights, grads, state, slots, opt_params, rng, step): - """FBO of the layer.""" - # We need a layer pure_fn but only for inputs and weights. - def pure_fn_without_state_and_rng(x, w): - return layer.pure_fn(x, w, state, rng) +def _fbo_with_layer_and_opt(layer, optimizer, n_devices, stats_name=None, adasum=False): + """Create the fbo function for a given layer and optimizer.""" - # Calculate the vector-Jacobian product of the reduced pure fn. 
- activations, vjp_fn, new_state = fastmath.vjp( - pure_fn_without_state_and_rng, inputs, weights, has_aux=True) + def fbo(inputs, weights, grads, state, slots, opt_params, rng, step): + """FBO of the layer.""" + # We need a layer pure_fn but only for inputs and weights. + def pure_fn_without_state_and_rng(x, w): + return layer.pure_fn(x, w, state, rng) - # In the loss layer, set gradients to 1 with the dtype of activations=loss. - if grads is None and stats_name is not None: - grads = jnp.ones((), dtype=activations.dtype) + # Calculate the vector-Jacobian product of the reduced pure fn. + activations, vjp_fn, new_state = fastmath.vjp( + pure_fn_without_state_and_rng, inputs, weights, has_aux=True + ) - # The vjp function returns gradients with respect to inputs and weights. - grads_inputs, grads_weights = vjp_fn(grads) + # In the loss layer, set gradients to 1 with the dtype of activations=loss. + if grads is None and stats_name is not None: + grads = jnp.ones((), dtype=activations.dtype) - # For non-trainable layers, return the calculated arguments. - if _is_empty_tuple(weights): - stats = {} - if stats_name is not None: - stats[stats_name] = activations - return weights, new_state, slots, grads_inputs, stats + # The vjp function returns gradients with respect to inputs and weights. + grads_inputs, grads_weights = vjp_fn(grads) - # In multi-device setting, average gradients from multiple devices. - if n_devices > 1: - grads_weights = _average_multidevice_gradients( - grads_weights, adasum=adasum) + # For non-trainable layers, return the calculated arguments. + if _is_empty_tuple(weights): + stats = {} + if stats_name is not None: + stats[stats_name] = activations + return weights, new_state, slots, grads_inputs, stats - # Run the optimizer. 
- new_weights, new_slots, stats = optimizer.tree_update( - step, grads_weights, weights, slots, opt_params, store_slots=False) - if stats_name is not None: - stats[stats_name] = activations - return new_weights, new_state, new_slots, grads_inputs, stats + # In multi-device setting, average gradients from multiple devices. + if n_devices > 1: + grads_weights = _average_multidevice_gradients(grads_weights, adasum=adasum) - return fbo + # Run the optimizer. + new_weights, new_slots, stats = optimizer.tree_update( + step, grads_weights, weights, slots, opt_params, store_slots=False + ) + if stats_name is not None: + stats[stats_name] = activations + return new_weights, new_state, new_slots, grads_inputs, stats + + return fbo # Reversible layers define a reverse_and_fbo function that both reverses @@ -757,149 +863,161 @@ def pure_fn_without_state_and_rng(x, w): def _reverse_and_fbo_with_layer_and_opt(layer, optimizer, n_devices, adasum): - """Create the reverse_and_fbo function for a given layer and optimizer.""" - def reverse_and_fbo(output, weights, grads, state, new_state, - slots, opt_params, rng, step): - """Reverse and FBO of the layer.""" - # Call the reverse_and_grad method of the layer. - inputs, (grads_inputs, grads_weights) = layer.reverse_and_grad( - output, grads, weights, state, new_state, rng=rng) + """Create the reverse_and_fbo function for a given layer and optimizer.""" + + def reverse_and_fbo( + output, weights, grads, state, new_state, slots, opt_params, rng, step + ): + """Reverse and FBO of the layer.""" + # Call the reverse_and_grad method of the layer. + inputs, (grads_inputs, grads_weights) = layer.reverse_and_grad( + output, grads, weights, state, new_state, rng=rng + ) - # For non-trainable layers, return the calculated arguments. - if _is_empty_tuple(weights): - return weights, slots, inputs, grads_inputs, {} + # For non-trainable layers, return the calculated arguments. 
+ if _is_empty_tuple(weights): + return weights, slots, inputs, grads_inputs, {} - # In multi-device setting, average gradients from multiple devices. - if n_devices > 1: - grads_weights = _average_multidevice_gradients( - grads_weights, adasum=adasum) + # In multi-device setting, average gradients from multiple devices. + if n_devices > 1: + grads_weights = _average_multidevice_gradients(grads_weights, adasum=adasum) - # Run the optimizer. - new_weights, new_slots, stats = optimizer.tree_update( - step, grads_weights, weights, slots, opt_params, store_slots=False) + # Run the optimizer. + new_weights, new_slots, stats = optimizer.tree_update( + step, grads_weights, weights, slots, opt_params, store_slots=False + ) - return new_weights, new_slots, inputs, grads_inputs, stats + return new_weights, new_slots, inputs, grads_inputs, stats - return reverse_and_fbo + return reverse_and_fbo def _is_empty_tuple(x): - """Check if x is either empty or a tuple of (tuples of) empty things.""" - if not isinstance(x, (list, tuple)): - return False - for y in x: - if not _is_empty_tuple(y): - return False - return True + """Check if x is either empty or a tuple of (tuples of) empty things.""" + if not isinstance(x, (list, tuple)): + return False + for y in x: + if not _is_empty_tuple(y): + return False + return True def extract_reversible_blocks(layers, loss_chunk_size=0): - """Extracts blocks and loss layer for use with ReversibleSerialTrainer. - - Args: - layers: a list of layers of a single layer to extract blocks from; - should end with a loss, e.g., [model, loss] or tl.Serial(model, loss). - loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory - in models with larger vocabulary; requires the last sublayers of loss - are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order. - - Returns: - a pair (blocks, loss_layer) to use with ReversibleSerialTrainer. - """ - def _flatten(l): - """Flatten all Serial layers and sub(sub-...) 
layers into a list.""" - if isinstance(l, (list, tuple)): - return [x for layer in l for x in _flatten(layer)] # pylint: disable=g-complex-comprehension - elif isinstance(l, tl.Serial): - return _flatten(l.sublayers) - else: - return [l] - - # Extract standard and reversible layer blocks. - blocks, std_layers, rev_layers = [], [], [] - for layer in _flatten(layers): - if isinstance(layer, tl.ReversibleLayer): - rev_layers.append(layer) - elif not rev_layers: - std_layers.append(layer) + """Extracts blocks and loss layer for use with ReversibleSerialTrainer. + + Args: + layers: a list of layers of a single layer to extract blocks from; + should end with a loss, e.g., [model, loss] or tl.Serial(model, loss). + loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory + in models with larger vocabulary; requires the last sublayers of loss + are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order. + + Returns: + a pair (blocks, loss_layer) to use with ReversibleSerialTrainer. + """ + + def _flatten(l): + """Flatten all Serial layers and sub(sub-...) layers into a list.""" + if isinstance(l, (list, tuple)): + return [ + x for layer in l for x in _flatten(layer) + ] # pylint: disable=g-complex-comprehension + elif isinstance(l, tl.Serial): + return _flatten(l.sublayers) + else: + return [l] + + # Extract standard and reversible layer blocks. + blocks, std_layers, rev_layers = [], [], [] + for layer in _flatten(layers): + if isinstance(layer, tl.ReversibleLayer): + rev_layers.append(layer) + elif not rev_layers: + std_layers.append(layer) + else: + blocks.append((std_layers, rev_layers)) + std_layers, rev_layers = [], [] + std_layers.append(layer) + if rev_layers: + raise ValueError("The final layer must be a standard loss, not reversible.") + if loss_chunk_size > 0: + # For now we only do chunking of [Dense, LogSoftmax, CrossEntopy, Mean] + # Let's check that these are the last 4 layers. 
+ border_layers = ["StripFromConcatenateWithPadding", "Select"] + + loss_start = None + for index, layer in enumerate(std_layers): + if layer.name in border_layers: + loss_start = index + 1 + if loss_start is None: + raise ValueError( + "Loss layer should be preceeded by one of {}; got {}".format( + border_layers, [l.name for l in std_layers] + ) + ) + if len(std_layers) - loss_start < 4: + raise ValueError("Too short loss layer for chunking") + last_3_names = " ".join([l.name for l in std_layers[-3:]]) + if last_3_names != "LogSoftmax _CrossEntropy _WeightedMean": + raise ValueError( + 'Loss chunking only works with last layers being "' + 'LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' + last_3_names + ) + + # Create chunked dense+logsoftmax+cross-entropy-loss. + chunked_xent = tl.Chunk(tl.Serial(std_layers[loss_start:-1]), loss_chunk_size) + # The chunked loss should operate on a merged batch dimension, e.g., + # including both length and batch size. Need to merge and un-merge later. + def _reshape_to_batch_and_copy_targets(preds, targets): + batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]]) + batched_targets = jnp.reshape(targets, [-1]) + return batched_preds, batched_targets, targets + + def _reshape_xent_back(xent, targets): + return jnp.reshape(xent, targets.shape) + + batched_xent = tl.Serial( + tl.Fn("pre_xent_rebatch", _reshape_to_batch_and_copy_targets, n_out=3), + chunked_xent, + tl.Fn("after_xent_rebatch", _reshape_xent_back), + ) + loss_layer = tl.Serial(std_layers[:loss_start] + [batched_xent], std_layers[-1]) else: - blocks.append((std_layers, rev_layers)) - std_layers, rev_layers = [], [] - std_layers.append(layer) - if rev_layers: - raise ValueError('The final layer must be a standard loss, not reversible.') - if loss_chunk_size > 0: - # For now we only do chunking of [Dense, LogSoftmax, CrossEntopy, Mean] - # Let's check that these are the last 4 layers. 
- border_layers = ['StripFromConcatenateWithPadding', 'Select'] - - loss_start = None - for index, layer in enumerate(std_layers): - if layer.name in border_layers: - loss_start = index + 1 - if loss_start is None: - raise ValueError('Loss layer should be preceeded by one of {}; got {}' - .format(border_layers, [l.name for l in std_layers])) - if len(std_layers) - loss_start < 4: - raise ValueError('Too short loss layer for chunking') - last_3_names = ' '.join([l.name for l in std_layers[-3:]]) - if last_3_names != 'LogSoftmax _CrossEntropy _WeightedMean': - raise ValueError('Loss chunking only works with last layers being "' - 'LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' + - last_3_names) - - # Create chunked dense+logsoftmax+cross-entropy-loss. - chunked_xent = tl.Chunk(tl.Serial(std_layers[loss_start:-1]), - loss_chunk_size) - # The chunked loss should operate on a merged batch dimension, e.g., - # including both length and batch size. Need to merge and un-merge later. - def _reshape_to_batch_and_copy_targets(preds, targets): - batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]]) - batched_targets = jnp.reshape(targets, [-1]) - return batched_preds, batched_targets, targets - def _reshape_xent_back(xent, targets): - return jnp.reshape(xent, targets.shape) - batched_xent = tl.Serial( - tl.Fn('pre_xent_rebatch', _reshape_to_batch_and_copy_targets, n_out=3), - chunked_xent, - tl.Fn('after_xent_rebatch', _reshape_xent_back) - ) - loss_layer = tl.Serial(std_layers[:loss_start] + [batched_xent], - std_layers[-1]) - else: - loss_layer = tl.Serial(std_layers) - return blocks, loss_layer + loss_layer = tl.Serial(std_layers) + return blocks, loss_layer def init_reversible_blocks(blocks, loss_layer, input_signature, rng): - """Initialize reversible blocks and the loss layer and place weights on CPU. - - Args: - blocks: List of reversible blocks (pairs of layer lists). - loss_layer: The final loss layer to initialize. 
- input_signature: The signature of the input to the blocks. - rng: Random key used to initialize the layers. - """ - sig_stack = input_signature - process = psutil.Process(os.getpid()) - mem_use = process.memory_info().rss - for (std_layers, rev_layers) in blocks: - rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1) - rng = rngs[0] - for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]): - sig = cb.inputs_from_stack(sig_stack, layer.n_in) - layer.init(sig, rng=layer_rng) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - layer.state = tl.on_cpu(layer.state) # store weights in cpu memory - logging.info('init: layer %s\nadded cpu memory (MB): %.2f', str(layer), - (process.memory_info().rss - mem_use) / float(1024 * 1024)) - mem_use = process.memory_info().rss - logging.info('init: cpu memory use (MB): %.2f', - mem_use / float(1024 * 1024)) - out_sig = layer.output_signature(sig) - sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in) - loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng) - loss_layer.weights = tl.on_cpu(loss_layer.weights) - loss_layer.state = tl.on_cpu(loss_layer.state) - + """Initialize reversible blocks and the loss layer and place weights on CPU. + Args: + blocks: List of reversible blocks (pairs of layer lists). + loss_layer: The final loss layer to initialize. + input_signature: The signature of the input to the blocks. + rng: Random key used to initialize the layers. 
+ """ + sig_stack = input_signature + process = psutil.Process(os.getpid()) + mem_use = process.memory_info().rss + for (std_layers, rev_layers) in blocks: + rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1) + rng = rngs[0] + for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]): + sig = cb.inputs_from_stack(sig_stack, layer.n_in) + layer.init(sig, rng=layer_rng) + layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory + layer.state = tl.on_cpu(layer.state) # store weights in cpu memory + logging.info( + "init: layer %s\nadded cpu memory (MB): %.2f", + str(layer), + (process.memory_info().rss - mem_use) / float(1024 * 1024), + ) + mem_use = process.memory_info().rss + logging.info( + "init: cpu memory use (MB): %.2f", mem_use / float(1024 * 1024) + ) + out_sig = layer.output_signature(sig) + sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in) + loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng) + loss_layer.weights = tl.on_cpu(loss_layer.weights) + loss_layer.state = tl.on_cpu(loss_layer.state) diff --git a/trax/optimizers/trainer_test.py b/trax/optimizers/trainer_test.py deleted file mode 100644 index bda6d191f..000000000 --- a/trax/optimizers/trainer_test.py +++ /dev/null @@ -1,344 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for accelerated optimization of loss layers.""" - -import time -from absl.testing import absltest - -from jax.config import config -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import optimizers -from trax import shapes -from trax.layers import base -from trax.models.research import terraformer - - -class TrainerTest(absltest.TestCase): - - def _assert_all_equal(self, t1, t2, tol=1e-5): - def eq(x1, x2): - diff = np.maximum(np.abs(x1 - x2) - tol, 0.0) - self.assertLessEqual(np.sum(diff), 0.0, - msg=f'\n{x1}\n !=\n{x2}\n diff:\n{x1-x2}') - fastmath.nested_map_multiarg(eq, t1, t2) - - def test_run_simple_task(self): - """Runs an accelerated optimizer on a simple task.""" - inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch - targets_batch = np.pi * np.ones_like(inputs_batch) - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss()) - loss_layer.init(labeled_batch) - optimizer = optimizers.SGD(.01) - optimizer.tree_init(loss_layer.weights) - trainer = optimizers.Trainer(loss_layer, optimizer) - rng = fastmath.random.get_prng(0) - trainer.one_step(labeled_batch, rng) - - - def test_run_sharded_terraformer(self): - """Runs Terraformer with sharded weights (only on 2+-device systems).""" - if fastmath.local_device_count() == 1: - return - base.N_WEIGHTS_SHARDS = fastmath.local_device_count() - inputs_batch = np.arange(8).reshape((2, 4)) + 1 - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) - input_sig = (int_sig, int_sig, int_sig) - # We want to test rng propagation too, so adding some dropout layers. 
- model = terraformer.ConfigurableTerraformer( - 20, d_model=8, d_ff=32, n_heads=1, dropout=0.0, - n_encoder_layers=2, n_decoder_layers=2, - ff_sparsity=(4, 8, 0.0, 1.0), - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - pos_type=None, reversible_encoder=True) - loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) - model_with_loss = tl.Serial(model, loss) - rng_init = fastmath.random.get_prng(12) - model_with_loss.init(input_sig, rng=rng_init) - - # Make a step with the trainer. - optimizer = optimizers.Adafactor(0.01) - split_w = fastmath.nested_map( - lambda x: x[0], - tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS)) - optimizer.tree_init(split_w) - trainer = optimizers.Trainer(model_with_loss, optimizer) - rng_step1 = fastmath.random.get_prng(7) - trainer.one_step(labeled_batch, rng_step1) - # Reset shards back to default. - base.N_WEIGHTS_SHARDS = 1 - - def test_run_reversible_slots(self): - """Tests that slots can be read and assigned in reversible trainer.""" - layers = [tl.Dense(4), tl.Dup()] - rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - trainer = optimizers.ReversibleSerialTrainer( - [(layers, rev_layers)], loss_layer, optimizers.Adam) - slots = trainer.slots - trainer.slots = slots - self.assertEqual(slots, trainer.slots) - - def test_run_reversible_same_as_default_basic(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - # We want to test rng propagation too, so adding some dropout layers. 
- first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) - rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - model = tl.Serial([first_layer] + rev_layers + [loss_layer]) - rng_init = fastmath.random.get_prng(12) - model.init(labeled_batch, rng=rng_init) - optimizer_fn = optimizers.Adam # to test slots - - # Make 2 steps with the original trainer. - optimizer = optimizer_fn() - optimizer.tree_init(model.weights) - trainer = optimizers.Trainer(model, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - first_layer_weights1 = first_layer.weights - rev_layer0_weights1 = rev_layers[0].weights - rev_layer2_weights1 = rev_layers[2].weights - loss_layer_weights1 = loss_layer.weights - - # Now make 2 steps with reversible trainer. - model.init(labeled_batch, rng=rng_init) - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer.sublayers, rev_layers)], loss_layer, optimizer_fn) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - - # Check that weights end up the same. 
- self._assert_all_equal(loss_layer_weights1, loss_layer.weights) - self._assert_all_equal(rev_layer2_weights1, rev_layers[2].weights) - self._assert_all_equal(rev_layer0_weights1, rev_layers[0].weights) - self._assert_all_equal(first_layer_weights1, first_layer.weights) - - def test_run_reversible_same_as_default_extended(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - # We want to test rng propagation too, so adding some dropout layers. - first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) - rev_layers1 = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), - tl.ReversibleSwap()] - mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup()) - rev_layers2 = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] + - rev_layers2 + [loss_layer]) - rng_init = fastmath.random.get_prng(12) - model.init(labeled_batch, rng=rng_init) - optimizer_fn = optimizers.Adam # to test slots - - # Make 3 steps with the original trainer. 
- optimizer = optimizer_fn() - optimizer.tree_init(model.weights) - trainer = optimizers.Trainer(model, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - rng_step3 = fastmath.random.get_prng(9) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - first_layer_weights1 = first_layer.weights - rev_layer12_weights1 = rev_layers1[2].weights - mid_layer_weights1 = mid_layer.weights - rev_layer20_weights1 = rev_layers2[0].weights - loss_layer_weights1 = loss_layer.weights - - # Now make 3 steps with reversible trainer. - model.init(labeled_batch, rng=rng_init) - # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why? - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer.sublayers, rev_layers1), - (mid_layer.sublayers, rev_layers2)], - loss_layer, optimizer_fn, memoize_jit=False) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - - # Check that weights end up the same. 
- self._assert_all_equal(loss_layer_weights1, loss_layer.weights) - self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights) - self._assert_all_equal(mid_layer_weights1, mid_layer.weights) - self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights) - self._assert_all_equal(first_layer_weights1, first_layer.weights) - - def test_run_reversible_same_as_default_terraformer(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) + 1 - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) - input_sig = (int_sig, int_sig, int_sig) - # We want to test rng propagation too, so adding some dropout layers. - model = terraformer.ConfigurableTerraformer( - 20, d_model=8, d_ff=32, n_heads=1, dropout=0.0, n_encoder_layers=2, - n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0), pos_type=None, - reversible_encoder=True) - loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) - optimizer_fn = optimizers.Adafactor - blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( - [model, loss], loss_chunk_size=4) - blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks] - model_with_loss = tl.Serial(model, loss) - rng_init = fastmath.random.get_prng(12) - model_with_loss.init(input_sig, rng=rng_init) - - # Make 3 steps with the original trainer. 
- optimizer = optimizer_fn() - optimizer.tree_init(model_with_loss.weights) - trainer = optimizers.Trainer(model_with_loss, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - rng_step3 = fastmath.random.get_prng(9) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - first_weights = blocks_serial[0][0].weights - first_rev_weights = blocks[0][1][0].weights - loss_weights = loss_layer.weights - - # Now make 3 steps with reversible trainer. - model_with_loss.init(input_sig, rng=rng_init) - trainer = optimizers.ReversibleSerialTrainer( - blocks, loss_layer, optimizer_fn) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - - # Check that weights end up the same. - self._assert_all_equal(loss_weights, loss_layer.weights) - self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights) - self._assert_all_equal(first_weights, blocks_serial[0][0].weights) - - def test_run_reversible_large_weights(self): - """Runs the reversible trainer with a lot of weights to test memory use.""" - # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU - # and CPU when you run it locally, but it's too big for unit-testing. - ram_limited = True # Set to False to run this test locally. - if fastmath.global_device_count() == 1 and ram_limited: - return - - # Create inputs and rngs. - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup()) - rng_init = fastmath.random.get_prng(12) - rng_step = fastmath.random.get_prng(13) - - # Initialize layers. 
- first_layer.init(labeled_batch, rng=rng_init) - n_layers = 18 # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram - rev_layers = [] - int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32) - shape = shapes.ShapeDtype((2, 4, 16*1024)) - sig = (shape, shape) - for _ in range(n_layers): - layer = tl.ReversibleHalfResidual(tl.Dense(16*1024)) - layer.init(sig, rng=rng_init) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - rev_layers.append(layer) - rev_layers.append(tl.ReversibleSwap()) - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - loss_layer.init((shape, shape, int_shape, int_shape)) - optimizer_fn = optimizers.Adafactor - - # Make a step with reversible trainer. - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer, rev_layers)], loss_layer, optimizer_fn) - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - # Set to true to run again, e.g., for profiling. - run_twice = False - if run_twice: - t = time.time() - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss)) - - def test_run_reversible_weights_trainsfer_xprof(self): - """Runs the reversible trainer and profiles weight transfer stats.""" - run_this_test = False # We only run this test manually. - if not run_this_test or fastmath.global_device_count() == 1: # TPU only - return - - # Create inputs and rngs. - inputs_batch = np.ones((1024, 128), dtype=np.int32) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup()) - rng_init = fastmath.random.get_prng(12) - rng_step = fastmath.random.get_prng(13) - - # Initialize layers. 
- first_layer.init(labeled_batch, rng=rng_init) - n_layers = 6 - rev_layers = [] - int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32) - shape = shapes.ShapeDtype((1024, 128, 1024)) - sig = (shape, shape) - for _ in range(n_layers): - layer = tl.ReversibleHalfResidual(tl.Dense(1024)) - layer.init(sig, rng=rng_init) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - rev_layers.append(layer) - rev_layers.append(tl.ReversibleSwap()) - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - loss_layer.init((shape, shape, int_shape, int_shape)) - optimizer_fn = optimizers.SGD - - # Make a step with reversible trainer. - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer, rev_layers)], loss_layer, optimizer_fn) - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - # We profile here. - t = time.time() - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. 
- print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss)) - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() diff --git a/trax/predict_drop.py b/trax/predict_drop.py index 0b00a12e9..9cae41e41 100644 --- a/trax/predict_drop.py +++ b/trax/predict_drop.py @@ -41,287 +41,311 @@ FLAGS = flags.FLAGS -flags.DEFINE_string('checkpoint_dir', '', - 'Path to model checkpoint.') -flags.DEFINE_integer('max_answer_len', 1024, - 'Maximum length of answers to produce.') -flags.DEFINE_integer('batch_size', 1, 'Batch size for eval.') -flags.DEFINE_integer('num_examples', 1, 'Number of examples to infer.') -flags.DEFINE_integer('n_hashes', None, - 'n_hashes parameter to override in attentions.') -flags.DEFINE_integer('example_repetitions', 1, - 'How many times to infer an example.') -flags.DEFINE_bool('use_eval_mode', False, - 'If True, use the slower but easier to debug eval mode.') -flags.DEFINE_bool('use_eval_set', False, - 'If True, use eval set for evaluation.') +flags.DEFINE_string("checkpoint_dir", "", "Path to model checkpoint.") +flags.DEFINE_integer("max_answer_len", 1024, "Maximum length of answers to produce.") +flags.DEFINE_integer("batch_size", 1, "Batch size for eval.") +flags.DEFINE_integer("num_examples", 1, "Number of examples to infer.") +flags.DEFINE_integer("n_hashes", None, "n_hashes parameter to override in attentions.") +flags.DEFINE_integer("example_repetitions", 1, "How many times to infer an example.") flags.DEFINE_bool( - 'use_beam_search', False, - 'If True, use beam search, otherwise use autoregresive sampling.') -flags.DEFINE_float('autoregressive_sample_temp', 1, - 'The temperature for autoregressive sampling.') -flags.DEFINE_integer('n_beams', 4, 'How many beams to use in beam search.') + "use_eval_mode", False, "If True, use the slower but easier to debug eval mode." 
+) +flags.DEFINE_bool("use_eval_set", False, "If True, use eval set for evaluation.") +flags.DEFINE_bool( + "use_beam_search", + False, + "If True, use beam search, otherwise use autoregresive sampling.", +) +flags.DEFINE_float( + "autoregressive_sample_temp", 1, "The temperature for autoregressive sampling." +) +flags.DEFINE_integer("n_beams", 4, "How many beams to use in beam search.") flags.DEFINE_string( - 'output_dir', '', 'Path to the output directory where articles, abstracts, ' - 'and predictions would be stored.') -flags.DEFINE_integer('starting_example', 0, - 'Example index for starting decoding.') -flags.DEFINE_integer('reload_after', 1000, - 'Reload checkpoint after reload_after examples.') -flags.DEFINE_multi_string('config_file', None, - 'Configuration file with parameters (.gin).') + "output_dir", + "", + "Path to the output directory where articles, abstracts, " + "and predictions would be stored.", +) +flags.DEFINE_integer("starting_example", 0, "Example index for starting decoding.") +flags.DEFINE_integer( + "reload_after", 1000, "Reload checkpoint after reload_after examples." +) +flags.DEFINE_multi_string( + "config_file", None, "Configuration file with parameters (.gin)." +) def _check_exists(file_path): - if not tf.io.gfile.exists(file_path): - print('No such file: %s' % file_path, flush=True) - exit(1) + if not tf.io.gfile.exists(file_path): + print("No such file: %s" % file_path, flush=True) + exit(1) def multiply_examples(example): - for i in range(FLAGS.example_repetitions): - yield i, example + for i in range(FLAGS.example_repetitions): + yield i, example def prepare_model(model_file, batch_size=1): - """Prepare the model.""" - mode = 'eval' if FLAGS.use_eval_mode else 'predict' - print('Initializing the model in %s mode.' % mode, flush=True) + """Prepare the model.""" + mode = "eval" if FLAGS.use_eval_mode else "predict" + print("Initializing the model in %s mode." 
% mode, flush=True) - # Read the model name from the gin file - model_reference = gin.query_parameter( - 'trax.supervised.trainer_lib.train.model') - model = model_reference.scoped_configurable_fn(mode=mode) + # Read the model name from the gin file + model_reference = gin.query_parameter("trax.supervised.trainer_lib.train.model") + model = model_reference.scoped_configurable_fn(mode=mode) - dec_len = 32 if FLAGS.use_eval_mode else 1 - batch_size_pd = max(1, batch_size // jax.local_device_count()) - shape11 = shapes.ShapeDtype((batch_size_pd, dec_len), dtype=np.int32) - # shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model.init_from_file( - model_file, weights_only=True, input_signature=(shape11, shape11)) - model = tl.Accelerate(model) + dec_len = 32 if FLAGS.use_eval_mode else 1 + batch_size_pd = max(1, batch_size // jax.local_device_count()) + shape11 = shapes.ShapeDtype((batch_size_pd, dec_len), dtype=np.int32) + # shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model.init_from_file( + model_file, weights_only=True, input_signature=(shape11, shape11) + ) + model = tl.Accelerate(model) - initial_state = model.state - vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH) + initial_state = model.state + vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH) - return vocab, model, initial_state + return vocab, model, initial_state def is_number(s): - try: - float(s) - return True - except ValueError: - return False + try: + float(s) + return True + except ValueError: + return False def main(argv): - if len(argv) > 1: - raise absl_app.UsageError('Too many command-line arguments.') - if not FLAGS.output_dir: - raise absl_app.UsageError('--output_dir needs to be provided.') - - tf.compat.v1.enable_eager_execution() - - # Check that checkpoint_dir is correct: should contain model.pkl.gz file. 
- model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz') - _check_exists(model_file) - - gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin')) - # Batching on our own because of possible repetitions of examples. - gin.bind_parameter('data.Batch.batch_size', 1) - if FLAGS.n_hashes is not None: - gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes) - gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes', FLAGS.n_hashes) - - vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) - - host_id, host_count = jax.host_id(), jax.host_count() - print('Running on host %d out of %d.' % (host_id, host_count)) - - example_count = 0 - start_time = time.time() - - # Creates all intermediate directories if they do not exist - tf.io.gfile.makedirs(FLAGS.output_dir) - - json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id) - all_jsons = [] - - # In a case of a reset we have to check how much work was already done. - # We can check whether the processing of an example was finished, but - # currently we are only checking whether it was started. - done = FLAGS.starting_example - reload_count = 0 - all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir) - for filename in all_existing_files: - if 'processing' in filename: - # The definition of digits looks for a number after the infix "processing" - # in the file name. Example: tom_processing_532 will lead to - # digits = "processing_532" and number equal to "532". 
- digits = filename[filename.find('processing'):] - number = ''.join(d for d in digits if d.isdigit()) - if is_number( - number) and int(number) < FLAGS.num_examples + FLAGS.starting_example: - done = max(done, int(number)) - print('The done number is {}'.format(done)) - - if FLAGS.use_eval_set: - drop_gen = trax_data.CreateDropInputs(train=False)() - else: - drop_gen = trax_data.CreateDropInputs(train=True)() - padding_fun = trax_data.PadToLength() - - # TODO(henrykm): improve managment of the counters. - # example_count_total - all numeric examples - # example_count - all numeric examples above starting_example - # reload_count - if we processed FLAGS.reload_after examples, - # then the checkpoint should be reloaded. - # idx - total number of exaples - example_count_total = 0 - reload_count += 1 - for idx, e in enumerate(drop_gen): - if reload_count >= FLAGS.reload_after: - vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) - reload_count = 0 - if example_count >= FLAGS.num_examples: - print('Reached the example_count {} - breaking'.format(example_count)) - break - if not is_number(e[1]): - continue - target_answer = float(e[1]) - - # We count numeric starting examples - example_count_total += 1 - if example_count_total <= FLAGS.starting_example: - print('Skipping example_count_total {} because it is below {}'.format( - example_count_total, FLAGS.starting_example)) - continue - - if example_count % 10 == 0: - elapsed_time = time.time() - start_time - start_time = time.time() - print('Starting inference on example %d, %.2fs since last log' % - (example_count, elapsed_time), flush=True) - - example_count += 1 - if example_count <= done - FLAGS.starting_example + 1: - print('Skipping example_count {} because it is below {}'.format( - example_count, done - FLAGS.starting_example)) - # We are increasing the example_count because the example - # was processed before - continue - - if example_count % host_count != host_id: - continue - - # At this 
point we are committed to the processing of an example with - # index example_count - processing_file = os.path.join(FLAGS.output_dir, 'processing_') - data_id = str(example_count + FLAGS.starting_example) - with tf.io.gfile.GFile(processing_file + data_id, 'w') as w: - w.write('Procesing started.') - for repetition_id, example in multiply_examples(e): - question = example[0] - question_text = question[question.find(':') + 2:] - question_text = question_text.replace('-', ' - ') - question = 'infer full calculation: ' + question_text - - list_num = [ - float(num.replace(',', '').rstrip('.')) for num in re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question) - ] - for i in range(len(list_num)): - question += ' n{} = {}'.format(i, list_num[i]) - - # print('Question {}'.format(question)) - tokenized_question = next( - padding_fun( - trax_data.tokenize([ - question, - ], - vocab_file=gin.query_parameter( - 'trax.data.Tokenize.vocab_file')))) - state = model.state - if FLAGS.use_beam_search: - answer_beams = decoding.beam_search( - model, - tokenized_question[None, :], - n_beams=FLAGS.n_beams, - max_length=FLAGS.max_answer_len, - accelerate=False) - model.state = state - else: - answer_beams = [] - # We recycle the n_beams flag to control the number - # of autoregressive samples. 
- for i in range(FLAGS.n_beams): - answer = decoding.autoregressive_sample( - model, - tokenized_question[None, :], - temperature=FLAGS.autoregressive_sample_temp, - max_length=FLAGS.max_answer_len, - accelerate=False) - model.state = state - answer_beams.append(answer) - - correct_example_index = -1 - - for i in range(len(answer_beams)): - if FLAGS.use_beam_search: - answer = trax_data.detokenize( - answer_beams[i][0][0], - vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file')) - else: - answer = trax_data.detokenize( - answer_beams[i][0], - vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file')) - print('Proposed computation {}'.format(answer)) - list_op = answer.split('|') - if not list_op[-1]: - list_op = list_op[:-1] - - try: - result = trax_data.tf_inputs.compute_result(list_op, list_num) - if target_answer in result: - correct_example_index = result.index(target_answer) + if len(argv) > 1: + raise absl_app.UsageError("Too many command-line arguments.") + if not FLAGS.output_dir: + raise absl_app.UsageError("--output_dir needs to be provided.") + + tf.compat.v1.enable_eager_execution() + + # Check that checkpoint_dir is correct: should contain model.pkl.gz file. + model_file = os.path.join(FLAGS.checkpoint_dir, "model.pkl.gz") + _check_exists(model_file) + + gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, "config.gin")) + # Batching on our own because of possible repetitions of examples. + gin.bind_parameter("data.Batch.batch_size", 1) + if FLAGS.n_hashes is not None: + gin.bind_parameter("LSHSelfAttention.n_hashes", FLAGS.n_hashes) + gin.bind_parameter("ref2_encoder/LSHSelfAttention.n_hashes", FLAGS.n_hashes) + + vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) + + host_id, host_count = jax.host_id(), jax.host_count() + print("Running on host %d out of %d." 
% (host_id, host_count)) + + example_count = 0 + start_time = time.time() + + # Creates all intermediate directories if they do not exist + tf.io.gfile.makedirs(FLAGS.output_dir) + + json_to_write = os.path.join(FLAGS.output_dir, "output%d.json" % host_id) + all_jsons = [] + + # In a case of a reset we have to check how much work was already done. + # We can check whether the processing of an example was finished, but + # currently we are only checking whether it was started. + done = FLAGS.starting_example + reload_count = 0 + all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir) + for filename in all_existing_files: + if "processing" in filename: + # The definition of digits looks for a number after the infix "processing" + # in the file name. Example: tom_processing_532 will lead to + # digits = "processing_532" and number equal to "532". + digits = filename[filename.find("processing") :] + number = "".join(d for d in digits if d.isdigit()) + if ( + is_number(number) + and int(number) < FLAGS.num_examples + FLAGS.starting_example + ): + done = max(done, int(number)) + print("The done number is {}".format(done)) + + if FLAGS.use_eval_set: + drop_gen = trax_data.CreateDropInputs(train=False)() + else: + drop_gen = trax_data.CreateDropInputs(train=True)() + padding_fun = trax_data.PadToLength() + + # TODO(henrykm): improve managment of the counters. + # example_count_total - all numeric examples + # example_count - all numeric examples above starting_example + # reload_count - if we processed FLAGS.reload_after examples, + # then the checkpoint should be reloaded. 
+ # idx - total number of exaples + example_count_total = 0 + reload_count += 1 + for idx, e in enumerate(drop_gen): + if reload_count >= FLAGS.reload_after: + vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) + reload_count = 0 + if example_count >= FLAGS.num_examples: + print("Reached the example_count {} - breaking".format(example_count)) break - # This is a temporary hack with "broad" exceptions - the computations - # must fail sometime, because we evaluate arbitrary sequences; I am in - # the process of checking what are possible failure modes. - except Exception as e: # pylint: disable=broad-except - print(e) - try: - result = trax_data.tf_inputs.compute_result(list_op[:-1], list_num) - if target_answer in result: - correct_example_index = result.index(target_answer) - break - except Exception as e: # pylint: disable=broad-except - print(e) - print('Infered incorrect computation.') - - if correct_example_index == -1: - continue - - json_record = { - 'question': question_text, - 'input': question, - 'calculation': '|'.join(list_op[:correct_example_index + 1]), - 'target_answer': target_answer - } - all_jsons.append(json.dumps(json_record) + '\n') - # Outputting the inferred data in JSONL format. 
- data_id = str(example_count + FLAGS.starting_example) - with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w: - w.write(json.dumps(json_record) + '\n') - with tf.io.gfile.GFile(processing_file + data_id, 'w') as w: - w.write('Procesing finished.') - - with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example), - 'w') as w: - for record in all_jsons: - w.write(record) - - -if __name__ == '__main__': - absl_app.run(main) + if not is_number(e[1]): + continue + target_answer = float(e[1]) + + # We count numeric starting examples + example_count_total += 1 + if example_count_total <= FLAGS.starting_example: + print( + "Skipping example_count_total {} because it is below {}".format( + example_count_total, FLAGS.starting_example + ) + ) + continue + + if example_count % 10 == 0: + elapsed_time = time.time() - start_time + start_time = time.time() + print( + "Starting inference on example %d, %.2fs since last log" + % (example_count, elapsed_time), + flush=True, + ) + + example_count += 1 + if example_count <= done - FLAGS.starting_example + 1: + print( + "Skipping example_count {} because it is below {}".format( + example_count, done - FLAGS.starting_example + ) + ) + # We are increasing the example_count because the example + # was processed before + continue + + if example_count % host_count != host_id: + continue + + # At this point we are committed to the processing of an example with + # index example_count + processing_file = os.path.join(FLAGS.output_dir, "processing_") + data_id = str(example_count + FLAGS.starting_example) + with tf.io.gfile.GFile(processing_file + data_id, "w") as w: + w.write("Procesing started.") + for repetition_id, example in multiply_examples(e): + question = example[0] + question_text = question[question.find(":") + 2 :] + question_text = question_text.replace("-", " - ") + question = "infer full calculation: " + question_text + + list_num = [ + float(num.replace(",", "").rstrip(".")) + for num in re.findall( + 
r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", question + ) + ] + for i in range(len(list_num)): + question += " n{} = {}".format(i, list_num[i]) + + # print('Question {}'.format(question)) + tokenized_question = next( + padding_fun( + trax_data.tokenize( + [ + question, + ], + vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"), + ) + ) + ) + state = model.state + if FLAGS.use_beam_search: + answer_beams = decoding.beam_search( + model, + tokenized_question[None, :], + n_beams=FLAGS.n_beams, + max_length=FLAGS.max_answer_len, + accelerate=False, + ) + model.state = state + else: + answer_beams = [] + # We recycle the n_beams flag to control the number + # of autoregressive samples. + for i in range(FLAGS.n_beams): + answer = decoding.autoregressive_sample( + model, + tokenized_question[None, :], + temperature=FLAGS.autoregressive_sample_temp, + max_length=FLAGS.max_answer_len, + accelerate=False, + ) + model.state = state + answer_beams.append(answer) + + correct_example_index = -1 + + for i in range(len(answer_beams)): + if FLAGS.use_beam_search: + answer = trax_data.detokenize( + answer_beams[i][0][0], + vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"), + ) + else: + answer = trax_data.detokenize( + answer_beams[i][0], + vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"), + ) + print("Proposed computation {}".format(answer)) + list_op = answer.split("|") + if not list_op[-1]: + list_op = list_op[:-1] + + try: + result = trax_data.tf_inputs.compute_result(list_op, list_num) + if target_answer in result: + correct_example_index = result.index(target_answer) + break + # This is a temporary hack with "broad" exceptions - the computations + # must fail sometime, because we evaluate arbitrary sequences; I am in + # the process of checking what are possible failure modes. 
+ except Exception as e: # pylint: disable=broad-except + print(e) + try: + result = trax_data.tf_inputs.compute_result( + list_op[:-1], list_num + ) + if target_answer in result: + correct_example_index = result.index(target_answer) + break + except Exception as e: # pylint: disable=broad-except + print(e) + print("Infered incorrect computation.") + + if correct_example_index == -1: + continue + + json_record = { + "question": question_text, + "input": question, + "calculation": "|".join(list_op[: correct_example_index + 1]), + "target_answer": target_answer, + } + all_jsons.append(json.dumps(json_record) + "\n") + # Outputting the inferred data in JSONL format. + data_id = str(example_count + FLAGS.starting_example) + with tf.io.gfile.GFile(json_to_write + data_id, "w") as w: + w.write(json.dumps(json_record) + "\n") + with tf.io.gfile.GFile(processing_file + data_id, "w") as w: + w.write("Procesing finished.") + + with tf.io.gfile.GFile(json_to_write + "_" + str(FLAGS.starting_example), "w") as w: + for record in all_jsons: + w.write(record) + + +if __name__ == "__main__": + absl_app.run(main) diff --git a/trax/rl_trainer.py b/trax/rl_trainer.py index 7487a0c46..7201ecaf9 100644 --- a/trax/rl_trainer.py +++ b/trax/rl_trainer.py @@ -31,97 +31,95 @@ import faulthandler +import gin +import jax from absl import app from absl import flags from absl import logging -import gin -import jax from jax.config import config + from trax import fastmath -from trax import rl # pylint: disable=unused-import -from trax import trainer_flags # pylint: disable=unused-import from trax.rl import task as rl_task from trax.rl import training as light_trainers from trax.tf_numpy import numpy as tf_np - FLAGS = flags.FLAGS # Not just 'train' to avoid a conflict with trax.train in GIN files. 
-@gin.configurable(denylist=['output_dir'], module='trax') +@gin.configurable(denylist=["output_dir"], module="trax") def train_rl( output_dir, n_epochs=10000, light_rl=True, - light_rl_trainer=light_trainers.PolicyGradient): - """Train the RL agent. - - Args: - output_dir: Output directory. - n_epochs: Number epochs to run the training for. - light_rl: deprecated, always True, left out for old gin configs. - light_rl_trainer: which light RL trainer to use (experimental). - """ - del light_rl - tf_np.set_allow_float64(FLAGS.tf_allow_float64) - task = rl_task.RLTask() - env_name = task.env_name - - - if FLAGS.jax_debug_nans: - config.update('jax_debug_nans', True) - - if FLAGS.use_tpu: - config.update('jax_platform_name', 'tpu') - else: - config.update('jax_platform_name', '') - - - trainer = light_rl_trainer(task=task, output_dir=output_dir) - def light_training_loop(): - """Run the trainer for n_epochs and call close on it.""" - try: - logging.info('Starting RL training for %d epochs.', n_epochs) - trainer.run(n_epochs, n_epochs_is_total_epochs=True) - logging.info('Completed RL training for %d epochs.', n_epochs) - trainer.close() - logging.info('Trainer is now closed.') - except Exception as e: - raise e - finally: - logging.info('Encountered an exception, still calling trainer.close()') - trainer.close() - logging.info('Trainer is now closed.') - - if FLAGS.jax_debug_nans or FLAGS.disable_jit: - fastmath.disable_jit() - with jax.disable_jit(): - light_training_loop() - else: - light_training_loop() + light_rl_trainer=light_trainers.PolicyGradient, +): + """Train the RL agent. + + Args: + output_dir: Output directory. + n_epochs: Number epochs to run the training for. + light_rl: deprecated, always True, left out for old gin configs. + light_rl_trainer: which light RL trainer to use (experimental). 
+ """ + del light_rl + tf_np.set_allow_float64(FLAGS.tf_allow_float64) + task = rl_task.RLTask() + env_name = task.env_name + + if FLAGS.jax_debug_nans: + config.update("jax_debug_nans", True) + + if FLAGS.use_tpu: + config.update("jax_platform_name", "tpu") + else: + config.update("jax_platform_name", "") + + trainer = light_rl_trainer(task=task, output_dir=output_dir) + + def light_training_loop(): + """Run the trainer for n_epochs and call close on it.""" + try: + logging.info("Starting RL training for %d epochs.", n_epochs) + trainer.run(n_epochs, n_epochs_is_total_epochs=True) + logging.info("Completed RL training for %d epochs.", n_epochs) + trainer.close() + logging.info("Trainer is now closed.") + except Exception as e: + raise e + finally: + logging.info("Encountered an exception, still calling trainer.close()") + trainer.close() + logging.info("Trainer is now closed.") + + if FLAGS.jax_debug_nans or FLAGS.disable_jit: + fastmath.disable_jit() + with jax.disable_jit(): + light_training_loop() + else: + light_training_loop() def main(argv): - del argv - logging.info('Starting RL training.') + del argv + logging.info("Starting RL training.") - gin_configs = FLAGS.config if FLAGS.config is not None else [] - gin.enter_interactive_mode() - gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs) - gin.exit_interactive_mode() + gin_configs = FLAGS.config if FLAGS.config is not None else [] + gin.enter_interactive_mode() + gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs) + gin.exit_interactive_mode() - logging.info('Gin config:') - logging.info(gin_configs) + logging.info("Gin config:") + logging.info(gin_configs) - train_rl(output_dir=FLAGS.output_dir) + train_rl(output_dir=FLAGS.output_dir) - # TODO(afrozm): This is for debugging. - logging.info('Dumping stack traces of all stacks.') - faulthandler.dump_traceback(all_threads=True) + # TODO(afrozm): This is for debugging. 
+ logging.info("Dumping stack traces of all stacks.") + faulthandler.dump_traceback(all_threads=True) - logging.info('Training is done, should exit.') + logging.info("Training is done, should exit.") -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/trax/shapes.py b/trax/shapes.py index ee58a7e7c..8db9c73d6 100644 --- a/trax/shapes.py +++ b/trax/shapes.py @@ -20,121 +20,124 @@ class ShapeDtype: - """A NumPy ndarray-like object abstracted as shape and dtype. + """A NumPy ndarray-like object abstracted as shape and dtype. - Main use is for representing input and output signatures. - """ - __slots__ = ['shape', 'dtype'] - - def __init__(self, shape, dtype=np.float32): - """Creates a `ShapeDtype` instance, with canonicalized `shape` and `dtype`. - - Args: - shape: A tuple or list, each element of which is an int or, less often, - `None`. - dtype: A `dtype` object, either from NumPy or TensorFlow. - - Returns: - A `ShapeDtype` instance whose `shape` is a tuple and `dtype` is a NumPy - `dtype` object. + Main use is for representing input and output signatures. """ - # Canonicalize shape and dtype. - if isinstance(shape, tf.TensorShape): - shape = shape.as_list() - if isinstance(shape, list): - shape = tuple(shape) - if not isinstance(shape, tuple): - raise TypeError('shape must be tuple or list; got: {}'.format(shape)) - if isinstance(dtype, tf.DType): - dtype = dtype.as_numpy_dtype - - self.shape = shape - self.dtype = dtype - - def __eq__(self, other): - return (isinstance(other, self.__class__) + + __slots__ = ["shape", "dtype"] + + def __init__(self, shape, dtype=np.float32): + """Creates a `ShapeDtype` instance, with canonicalized `shape` and `dtype`. + + Args: + shape: A tuple or list, each element of which is an int or, less often, + `None`. + dtype: A `dtype` object, either from NumPy or TensorFlow. + + Returns: + A `ShapeDtype` instance whose `shape` is a tuple and `dtype` is a NumPy + `dtype` object. 
+ """ + # Canonicalize shape and dtype. + if isinstance(shape, tf.TensorShape): + shape = shape.as_list() + if isinstance(shape, list): + shape = tuple(shape) + if not isinstance(shape, tuple): + raise TypeError("shape must be tuple or list; got: {}".format(shape)) + if isinstance(dtype, tf.DType): + dtype = dtype.as_numpy_dtype + + self.shape = shape + self.dtype = dtype + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) and self.shape == other.shape - and self.dtype == other.dtype) + and self.dtype == other.dtype + ) - def __ne__(self, other): - return not self == other + def __ne__(self, other): + return not self == other - def __repr__(self): - return 'ShapeDtype{{shape:{}, dtype:{}}}'.format(self.shape, self.dtype) + def __repr__(self): + return "ShapeDtype{{shape:{}, dtype:{}}}".format(self.shape, self.dtype) - def __len__(self): - """Returns length of 1; relevant to input and output signatures.""" - return 1 + def __len__(self): + """Returns length of 1; relevant to input and output signatures.""" + return 1 - def as_tuple(self): - return self.shape, self.dtype + def as_tuple(self): + return self.shape, self.dtype - def replace(self, **kwargs): - """Creates a copy of the object with some parameters replaced.""" - return type(self)( - shape=kwargs.pop('shape', self.shape), - dtype=kwargs.pop('dtype', self.dtype), - ) + def replace(self, **kwargs): + """Creates a copy of the object with some parameters replaced.""" + return type(self)( + shape=kwargs.pop("shape", self.shape), + dtype=kwargs.pop("dtype", self.dtype), + ) def signature(obj): - """Returns a `ShapeDtype` signature for the given `obj`. - - A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype` - instances. 
Note that this function is permissive with respect to its inputs - (accepts lists or tuples or dicts, and underlying objects can be any type - as long as they have shape and dtype attributes) and returns the corresponding - nested structure of `ShapeDtype`. - - Args: - obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict - of such objects. - - Returns: - A corresponding nested structure of `ShapeDtype` instances. - """ - if isinstance(obj, (list, tuple)): - output = tuple(signature(x) for x in obj) - return output if isinstance(obj, tuple) else list(output) - elif isinstance(obj, dict): - return {k: signature(v) for (k, v) in obj.items()} - else: - return ShapeDtype(obj.shape, obj.dtype) + """Returns a `ShapeDtype` signature for the given `obj`. + A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype` + instances. Note that this function is permissive with respect to its inputs + (accepts lists or tuples or dicts, and underlying objects can be any type + as long as they have shape and dtype attributes) and returns the corresponding + nested structure of `ShapeDtype`. -def splice_signatures(*sigs): - """Creates a new signature by splicing together any number of signatures. - - The splicing effectively flattens the top level input signatures. For - instance, it would perform the following mapping: - - - `*sigs: sd1, (sd2, sd3, sd4), (), sd5` - - return: `(sd1, sd2, sd3, sd4, sd5)` - - Args: - *sigs: Any number of signatures. A signature is either a `ShapeDtype` - instance or a tuple of `ShapeDtype` instances. - - Returns: - A single `ShapeDtype` instance if the spliced signature has one element, - else a tuple of `ShapeDtype` instances. - """ - result_sigs = [] - for sig in sigs: - if isinstance(sig, (list, tuple)): - result_sigs.extend(sig) + Args: + obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict + of such objects. + + Returns: + A corresponding nested structure of `ShapeDtype` instances. 
+ """ + if isinstance(obj, (list, tuple)): + output = tuple(signature(x) for x in obj) + return output if isinstance(obj, tuple) else list(output) + elif isinstance(obj, dict): + return {k: signature(v) for (k, v) in obj.items()} else: - result_sigs.append(sig) - return result_sigs[0] if len(result_sigs) == 1 else tuple(result_sigs) + return ShapeDtype(obj.shape, obj.dtype) + + +def splice_signatures(*sigs): + """Creates a new signature by splicing together any number of signatures. + + The splicing effectively flattens the top level input signatures. For + instance, it would perform the following mapping: + + - `*sigs: sd1, (sd2, sd3, sd4), (), sd5` + - return: `(sd1, sd2, sd3, sd4, sd5)` + + Args: + *sigs: Any number of signatures. A signature is either a `ShapeDtype` + instance or a tuple of `ShapeDtype` instances. + + Returns: + A single `ShapeDtype` instance if the spliced signature has one element, + else a tuple of `ShapeDtype` instances. + """ + result_sigs = [] + for sig in sigs: + if isinstance(sig, (list, tuple)): + result_sigs.extend(sig) + else: + result_sigs.append(sig) + return result_sigs[0] if len(result_sigs) == 1 else tuple(result_sigs) def assert_shape_equals(array, shape): - """Asserts that an array has the given shape.""" - assert array.shape == shape, ( - 'Invalid shape {}; expected {}.'.format(array.shape, shape) - ) + """Asserts that an array has the given shape.""" + assert array.shape == shape, "Invalid shape {}; expected {}.".format( + array.shape, shape + ) def assert_same_shape(array1, array2): - """Asserts that two arrays have the same shapes.""" - assert_shape_equals(array1, array2.shape) + """Asserts that two arrays have the same shapes.""" + assert_shape_equals(array1, array2.shape) diff --git a/trax/shapes_test.py b/trax/shapes_test.py deleted file mode 100644 index 4266195e5..000000000 --- a/trax/shapes_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.shapes.""" -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.shapes import ShapeDtype - - -class ShapesTest(absltest.TestCase): - - def test_constructor_and_read_properties(self): - sd = ShapeDtype((2, 3), np.int32) - self.assertEqual(sd.shape, (2, 3)) - self.assertEqual(sd.dtype, np.int32) - - def test_default_dtype_is_float32(self): - sd = ShapeDtype((2, 3)) - self.assertEqual(sd.shape, (2, 3)) - self.assertEqual(sd.dtype, np.float32) - - def test_signature_on_ndarray(self): - array = np.array([[2, 3, 5, 7], - [11, 13, 17, 19]], - dtype=np.int16) - sd = shapes.signature(array) - self.assertEqual(sd.shape, (2, 4)) - self.assertEqual(sd.dtype, np.int16) - - def test_shape_dtype_repr(self): - sd = ShapeDtype((2, 3)) - repr_string = '{}'.format(sd) - self.assertEqual(repr_string, - "ShapeDtype{shape:(2, 3), dtype:}") - - def test_splice_signatures(self): - sd1 = ShapeDtype((1,)) - sd2 = ShapeDtype((2,)) - sd3 = ShapeDtype((3,)) - sd4 = ShapeDtype((4,)) - sd5 = ShapeDtype((5,)) - - # Signatures can be ShapeDtype instances, tuples of 2+ ShapeDtype instances, - # or empty tuples. 
- sig1 = sd1 - sig2 = (sd2, sd3, sd4) - sig3 = () - sig4 = sd5 - spliced = shapes.splice_signatures(sig1, sig2, sig3, sig4) - self.assertEqual(spliced, (sd1, sd2, sd3, sd4, sd5)) - - def test_len_signature(self): - """Signatures of all sizes should give correct length when asked.""" - x1 = np.array([1, 2, 3]) - x2 = np.array([10, 20, 30]) - inputs0 = () - inputs1 = x1 # NOT in a tuple - inputs2 = (x1, x2) - - sig0 = shapes.signature(inputs0) - sig1 = shapes.signature(inputs1) - sig2 = shapes.signature(inputs2) - - # pylint: disable=g-generic-assert - self.assertEqual(len(sig0), 0) - self.assertEqual(len(sig1), 1) - self.assertEqual(len(sig2), 2) - # pylint: enable=g-generic-assert - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/callbacks.py b/trax/supervised/callbacks.py index 9c9b826b9..3da289cd2 100644 --- a/trax/supervised/callbacks.py +++ b/trax/supervised/callbacks.py @@ -40,209 +40,211 @@ class TrainingStepCallback: - """Callback triggered before and after a training step.""" + """Callback triggered before and after a training step.""" - def __init__(self, loop): - """Initializes the callback with a `supervised.training.Loop` instance.""" - self._loop = loop + def __init__(self, loop): + """Initializes the callback with a `supervised.training.Loop` instance.""" + self._loop = loop - def call_at(self, step): - """Returns whether the callback should be called at a given step.""" - raise NotImplementedError + def call_at(self, step): + """Returns whether the callback should be called at a given step.""" + raise NotImplementedError - def on_step_begin(self, step): - """Called by Loop before training steps, when call_at returned True.""" - raise NotImplementedError + def on_step_begin(self, step): + """Called by Loop before training steps, when call_at returned True.""" + raise NotImplementedError - def on_step_end(self, step): - """Called by Loop after training steps, when call_at returned True.""" - raise NotImplementedError + def 
on_step_end(self, step): + """Called by Loop after training steps, when call_at returned True.""" + raise NotImplementedError @gin.configurable class SerializedModelEvaluation(TrainingStepCallback): - """Evaluates serialized sequence prediction models. - - Example: time series prediction. We can serialize a time series into - a sequence of discrete tokens and model this sequence using an autoregressive - sequence model, such as Transformer - see - `trax.rl.serialization_utils.SerializedModel`. Then we can use this callback - to evaluate long-horizon predictions of such a model. - """ - - def __init__( - self, - loop, - model=None, - eval_at=1000, - eval_task=None, - context_lengths=(1,), - horizon_lengths=(1,), - n_steps=1, - accelerate_model=True, - ): - """Initializes SerializedModelEvaluation. - - Args: - loop: Instance of `trax.supervised.training.Loop` or `None`. Can be set to - `None` for testing - in such a case, `model` and `eval_task` must be - provided. - model: Instance of `trax.rl.serialization_utils.SerializedModel`. Not - required if `loop` is provided. - eval_at: When to evaluate. Either int (every how many steps to evaluate), - or a list of ints (step numbers), or a function int -> bool (step - predicate). - eval_task: Instance of `trax.supervised.training.EvalTask` with the - evaluation data, or None. If not provided, the task will be taken from - `loop`. - context_lengths: List of lengths of the context sequence fed into the - model before starting prediction. - horizon_lengths: List of lengths of the predicted sequence. - n_steps: Number of batches to run evaluation for. - accelerate_model (bool): Whether to wrap the model in `tl.Accelerate`. 
- """ - super().__init__(loop) - - if model is None: - model = loop.model - - observation_serializer = model.observation_serializer - action_serializer = model.action_serializer - - predict_model = model.make_predict_model() - if accelerate_model: - predict_model = tl.Accelerate(predict_model) - self._predict_model = predict_model - self._obs_serializer = observation_serializer - self._act_serializer = action_serializer - - if isinstance(eval_at, int): - self._eval_at = lambda step: step % eval_at == 1 - elif hasattr(eval_at, '__in__'): - self._eval_at = lambda step: step in eval_at - elif callable(eval_at): - self._eval_at = eval_at - else: - raise TypeError(f'Unsupported type for eval_at: {type(eval_at)}.') - - if eval_task is None: - if len(loop.eval_tasks) != 1: - raise ValueError( - 'If eval_task is not provided, the number of eval_tasks registered ' - 'in Loop must be exactly 1.' - ) - eval_task = loop.eval_tasks[0] - self._eval_task = eval_task + """Evaluates serialized sequence prediction models. - self._context_lengths = list(sorted(context_lengths)) - self._horizon_lengths = list(sorted(horizon_lengths)) - self._n_steps = n_steps + Example: time series prediction. We can serialize a time series into + a sequence of discrete tokens and model this sequence using an autoregressive + sequence model, such as Transformer - see + `trax.rl.serialization_utils.SerializedModel`. Then we can use this callback + to evaluate long-horizon predictions of such a model. + """ - self._batch_size = eval_task.sample_batch[0].shape[0] - (_, self._init_state) = predict_model.init( - shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32) - ) + def __init__( + self, + loop, + model=None, + eval_at=1000, + eval_task=None, + context_lengths=(1,), + horizon_lengths=(1,), + n_steps=1, + accelerate_model=True, + ): + """Initializes SerializedModelEvaluation. + + Args: + loop: Instance of `trax.supervised.training.Loop` or `None`. 
Can be set to + `None` for testing - in such a case, `model` and `eval_task` must be + provided. + model: Instance of `trax.rl.serialization_utils.SerializedModel`. Not + required if `loop` is provided. + eval_at: When to evaluate. Either int (every how many steps to evaluate), + or a list of ints (step numbers), or a function int -> bool (step + predicate). + eval_task: Instance of `trax.supervised.training.EvalTask` with the + evaluation data, or None. If not provided, the task will be taken from + `loop`. + context_lengths: List of lengths of the context sequence fed into the + model before starting prediction. + horizon_lengths: List of lengths of the predicted sequence. + n_steps: Number of batches to run evaluation for. + accelerate_model (bool): Whether to wrap the model in `tl.Accelerate`. + """ + super().__init__(loop) + + if model is None: + model = loop.model + + observation_serializer = model.observation_serializer + action_serializer = model.action_serializer + + predict_model = model.make_predict_model() + if accelerate_model: + predict_model = tl.Accelerate(predict_model) + self._predict_model = predict_model + self._obs_serializer = observation_serializer + self._act_serializer = action_serializer + + if isinstance(eval_at, int): + self._eval_at = lambda step: step % eval_at == 1 + elif hasattr(eval_at, "__in__"): + self._eval_at = lambda step: step in eval_at + elif callable(eval_at): + self._eval_at = eval_at + else: + raise TypeError(f"Unsupported type for eval_at: {type(eval_at)}.") + + if eval_task is None: + if len(loop.eval_tasks) != 1: + raise ValueError( + "If eval_task is not provided, the number of eval_tasks registered " + "in Loop must be exactly 1." 
+ ) + eval_task = loop.eval_tasks[0] + self._eval_task = eval_task + + self._context_lengths = list(sorted(context_lengths)) + self._horizon_lengths = list(sorted(horizon_lengths)) + self._n_steps = n_steps + + self._batch_size = eval_task.sample_batch[0].shape[0] + (_, self._init_state) = predict_model.init( + shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32) + ) - @property - def predict_model(self): - return self._predict_model + @property + def predict_model(self): + return self._predict_model - def call_at(self, step): - return self._eval_at(step) + def call_at(self, step): + return self._eval_at(step) - def on_step_begin(self, step): - pass + def on_step_begin(self, step): + pass - def on_step_end(self, step): - summary_writer = jaxboard.SummaryWriter( - os.path.join(self._loop.output_dir, 'srl_eval') - ) - try: - weights = self._loop.eval_model.seq_model_weights - metrics = self.evaluate(weights) - self._loop.log_summary(metrics, summary_writer, '', 'srl_eval') - finally: - summary_writer.close() - - def evaluate(self, weights): - """Evaluates the model and returns the metrics.""" - self._predict_model.weights = weights - - metrics = collections.defaultdict(list) - for _ in range(self._n_steps): - batch = self._eval_task.next_batch() - step_metrics = self._evaluate_batch(batch) - for (key, value) in step_metrics.items(): - metrics[key].append(value) - - metrics = {k: np.array(v) for (k, v) in metrics.items()} - - def metric_name(context, horizon): - return f'pred_error/context_{context}/horizon_{horizon}' - - return { - metric_name(context, horizon): - np.sum(errors) / (np.sum(errors != 0) + 1e-6) - for ((context, horizon), errors) in metrics.items() - } - - def _evaluate_batch(self, batch): - """Performs evaluation on a single batch.""" - (obs, act, _, mask) = batch - obs_repr = serialization_utils.Serialize(self._obs_serializer)(obs) - act_repr = serialization_utils.Serialize(self._act_serializer)(act) - - errors = {} - last_context = 0 - last_state 
= self._init_state - last_start_id = 0 - for context in self._context_lengths: - self._predict_model.state = last_state - start_id = last_start_id - - if context > last_context: - context_seq = serialization_utils.Interleave()(( - obs_repr[:, last_context:context], act_repr[:, last_context:context] - )) - consume_sequence(self._predict_model, start_id, context_seq[:, :-1]) - last_start_id = start_id = context_seq[:, -1:] - last_state = self._predict_model.state - last_context = context - - for timestep in range(max(self._horizon_lengths)): - pred_repr = decoding.autoregressive_sample( - self._predict_model, - start_id=start_id, - eos_id=-1, - batch_size=self._batch_size, - max_length=self._obs_serializer.representation_length, - accelerate=False, - ) - horizon = timestep + 1 - if horizon in self._horizon_lengths: - pred = self._obs_serializer.deserialize(pred_repr) - error = self._calculate_error(pred, obs[:, context + timestep]) - errors[context, horizon] = error * mask[:, context + timestep] - - start_id = pred_repr[:, -1:] - consume_sequence( - self._predict_model, start_id, act_repr[:, context + timestep, :-1] + def on_step_end(self, step): + summary_writer = jaxboard.SummaryWriter( + os.path.join(self._loop.output_dir, "srl_eval") ) - start_id = act_repr[:, context + timestep, -1:] - - return errors - - def _calculate_error(self, prediction, ground_truth): - return (prediction - ground_truth) ** 2 + try: + weights = self._loop.eval_model.seq_model_weights + metrics = self.evaluate(weights) + self._loop.log_summary(metrics, summary_writer, "", "srl_eval") + finally: + summary_writer.close() + + def evaluate(self, weights): + """Evaluates the model and returns the metrics.""" + self._predict_model.weights = weights + + metrics = collections.defaultdict(list) + for _ in range(self._n_steps): + batch = self._eval_task.next_batch() + step_metrics = self._evaluate_batch(batch) + for (key, value) in step_metrics.items(): + metrics[key].append(value) + + metrics = {k: 
np.array(v) for (k, v) in metrics.items()} + + def metric_name(context, horizon): + return f"pred_error/context_{context}/horizon_{horizon}" + + return { + metric_name(context, horizon): np.sum(errors) / (np.sum(errors != 0) + 1e-6) + for ((context, horizon), errors) in metrics.items() + } + + def _evaluate_batch(self, batch): + """Performs evaluation on a single batch.""" + (obs, act, _, mask) = batch + obs_repr = serialization_utils.Serialize(self._obs_serializer)(obs) + act_repr = serialization_utils.Serialize(self._act_serializer)(act) + + errors = {} + last_context = 0 + last_state = self._init_state + last_start_id = 0 + for context in self._context_lengths: + self._predict_model.state = last_state + start_id = last_start_id + + if context > last_context: + context_seq = serialization_utils.Interleave()( + ( + obs_repr[:, last_context:context], + act_repr[:, last_context:context], + ) + ) + consume_sequence(self._predict_model, start_id, context_seq[:, :-1]) + last_start_id = start_id = context_seq[:, -1:] + last_state = self._predict_model.state + last_context = context + + for timestep in range(max(self._horizon_lengths)): + pred_repr = decoding.autoregressive_sample( + self._predict_model, + start_id=start_id, + eos_id=-1, + batch_size=self._batch_size, + max_length=self._obs_serializer.representation_length, + accelerate=False, + ) + horizon = timestep + 1 + if horizon in self._horizon_lengths: + pred = self._obs_serializer.deserialize(pred_repr) + error = self._calculate_error(pred, obs[:, context + timestep]) + errors[context, horizon] = error * mask[:, context + timestep] + + start_id = pred_repr[:, -1:] + consume_sequence( + self._predict_model, start_id, act_repr[:, context + timestep, :-1] + ) + start_id = act_repr[:, context + timestep, -1:] + + return errors + + def _calculate_error(self, prediction, ground_truth): + return (prediction - ground_truth) ** 2 def consume_sequence(model, start_id, sequence): - decoding.autoregressive_sample( - model, 
- start_id=start_id, - eos_id=-1, - inputs=sequence, - batch_size=sequence.shape[0], - max_length=1, - accelerate=False, - ) + decoding.autoregressive_sample( + model, + start_id=start_id, + eos_id=-1, + inputs=sequence, + batch_size=sequence.shape[0], + max_length=1, + accelerate=False, + ) diff --git a/trax/supervised/callbacks_test.py b/trax/supervised/callbacks_test.py deleted file mode 100644 index 3eaf328f8..000000000 --- a/trax/supervised/callbacks_test.py +++ /dev/null @@ -1,226 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.supervised.callbacks.""" - -import functools -import io -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import gym -import numpy as np - -from trax import models -from trax import test_utils -from trax.data import inputs -from trax.layers import test_utils as tl_test_utils -from trax.rl import serialization_utils -from trax.rl import space_serializer -from trax.supervised import callbacks -from trax.supervised import lr_schedules -from trax.supervised import trainer_lib -from trax.supervised import training - - -def random_inputs(seq_len, batch_size): - def stream_fn(num_devices): - del num_devices - while True: - x = np.random.uniform(size=(batch_size, seq_len)) - y = np.random.uniform(size=(batch_size, seq_len)) - mask = np.ones_like(x).astype(np.float32) - yield (x, y, x, mask) - - return inputs.Inputs( - train_stream=stream_fn, - eval_stream=stream_fn, - ) - - -def make_multibonacci_modulo(history_length, limit): - """Creates a function that generates the Multibonacci sequence modulo n.""" - def sequence_fn(seq): - return np.sum(seq[-history_length:]) % limit - return sequence_fn - - -def generate_trajectory(sequence_fn, space, n_steps): - """Generates random actions and observations that follow sequence_fn.""" - act = [space.sample() for _ in range(n_steps)] - obs = [space.sample()] - - for (o, a) in zip( - obs, - act[:-1], # Don't generate the last observation. 
- ): - context = list(np.array([o, a]).flatten()) - symbols = [] - for _ in range(np.array(o).size): - symbol = sequence_fn(context + symbols) - symbols.append(symbol) - obs.append(np.reshape(symbols, space.shape)) - - obs = np.array([obs]) - act = np.array([act]) - return (obs, act) - - -def make_singleton_eval_task(observations, actions): - """Creates an EvalTask with just one example.""" - mask = np.ones(observations.shape[:2]) - def data(): - while True: - yield (observations, actions, observations, mask) - - return training.EvalTask( - labeled_data=data(), - metrics=[], - ) - - -def make_serialized_model(seq_model, space, vocab_size): - srl = space_serializer.create(space, vocab_size) - return serialization_utils.SerializedModel( - functools.partial(seq_model, vocab_size=vocab_size), - observation_serializer=srl, - action_serializer=srl, - significance_decay=0.7, - ) - - -class CallbacksTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - @mock.patch('sys.stdout', new_callable=io.StringIO) - def test_serialized_model_evaluation(self, mock_stdout): - precision = 1 - vocab_size = 2 - srl = space_serializer.BoxSpaceSerializer( - space=gym.spaces.Box(shape=(), low=0.0, high=1.0), - vocab_size=vocab_size, - precision=precision, - ) - - def inner_model(mode): - return models.TransformerLM( - mode=mode, - vocab_size=vocab_size, - d_model=2, - d_ff=4, - n_layers=1, - n_heads=1, - ) - - serialized_model_fn = functools.partial( - serialization_utils.SerializedModel, - inner_model, - observation_serializer=srl, - action_serializer=srl, - significance_decay=0.7, - ) - eval_callback = functools.partial( - callbacks.SerializedModelEvaluation, eval_at=5 - ) - - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir=output_dir, - model=serialized_model_fn, - inputs=functools.partial(random_inputs, seq_len=4, batch_size=64), - lr_schedule_fn=functools.partial(lr_schedules.constant, 0.01), - 
callbacks=[eval_callback], - steps=10, - ) - self.assertTrue(_has_metric('pred_error', mock_stdout)) - - @parameterized.product( - context_lengths=((2,), (1, 3)), - horizon_lengths=((1,), (1, 2)), - ) - def test_srl_eval_feeds_correct_sequence( - self, context_lengths, horizon_lengths - ): - vocab_size = 10 - n_steps = 5 - - multibonacci_modulo = make_multibonacci_modulo(2, vocab_size) - space = gym.spaces.Discrete(n=vocab_size) - (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) - eval_task = make_singleton_eval_task(obs, act) - seq_model = functools.partial( - tl_test_utils.MockTransformerLM, - sequence_fn=multibonacci_modulo, - ) - serialized_model = make_serialized_model(seq_model, space, vocab_size) - callback = callbacks.SerializedModelEvaluation( - loop=None, - eval_task=eval_task, - model=serialized_model, - context_lengths=context_lengths, - horizon_lengths=horizon_lengths, - accelerate_model=False, - ) - callback.evaluate(weights=None) - - expected_seq = np.zeros(2 * n_steps + 1) - expected_seq[1::2] = obs - expected_seq[2::2] = act - seen_len = (context_lengths[-1] + horizon_lengths[-1]) * 2 - callback.predict_model.assert_prediction_buffers_equal( - [expected_seq[:seen_len]] - ) - - @parameterized.named_parameters(('one_symbol', 1), ('two_symbols', 2)) - def test_srl_eval_reports_zero_error_for_perfect_model(self, precision): - vocab_size = 100 - n_steps = 5 - - multibonacci_modulo = make_multibonacci_modulo(2 * precision, vocab_size) - space = gym.spaces.MultiDiscrete(nvec=([vocab_size] * precision)) - (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) - eval_task = make_singleton_eval_task(obs, act) - seq_model = functools.partial( - tl_test_utils.MockTransformerLM, - sequence_fn=multibonacci_modulo, - ) - serialized_model = make_serialized_model(seq_model, space, vocab_size) - callback = callbacks.SerializedModelEvaluation( - loop=None, - eval_task=eval_task, - model=serialized_model, - context_lengths=(1,), - 
horizon_lengths=(4,), - accelerate_model=False, - ) - metrics = callback.evaluate(weights=None) - error = next( - value for (name, value) in metrics.items() if 'pred_error' in name - ) - assert error == 0 - - -def _has_metric(metric_name, stdout): - log = stdout.getvalue() - metric_logs = [line for line in log.split('\n') if metric_name in line] - return bool(metric_logs) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/decoding.py b/trax/supervised/decoding.py index d8902c1bc..7ecda87c2 100644 --- a/trax/supervised/decoding.py +++ b/trax/supervised/decoding.py @@ -20,244 +20,279 @@ from trax import layers as tl -def autoregressive_sample_stream(model, inputs=None, - batch_size=1, temperature=1.0, - start_id=0, accelerate=True, - eval_mode=False, eval_min_length=1): - """Yields samples from `model`, in autoregressive language model fashion. +def autoregressive_sample_stream( + model, + inputs=None, + batch_size=1, + temperature=1.0, + start_id=0, + accelerate=True, + eval_mode=False, + eval_min_length=1, +): + """Yields samples from `model`, in autoregressive language model fashion. - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. The - new output becomes the next position's input, and further calls to - `autoregressive_sample_stream` repeat the process for successive positions - indefinitely. + This function uses `model` to generate outputs one position at a time, with + access to inputs for the current position and all preceding positions. The + new output becomes the next position's input, and further calls to + `autoregressive_sample_stream` repeat the process for successive positions + indefinitely. - Inputs and outputs always come in batches, even if size 1. If `inputs` is - present, it must have shape (`batch_size`, inputs_sequence_length), and each - output in the stream has shape (`batch_size`, 1). 
+ Inputs and outputs always come in batches, even if size 1. If `inputs` is + present, it must have shape (`batch_size`, inputs_sequence_length), and each + output in the stream has shape (`batch_size`, 1). - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as an autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), - except if `eval_mode` is set -- any model can be sampled then, - but the sampling process may be much slower. - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model generates the first output - based on just the start symbol. - batch_size: Number of sequences to generate in parallel as a batch. - temperature: Parameter that controls the sharpness of the softmax that - feeds the sampling process. Values range from 0.0 (all probability mass - goes to one candidate; like an argmax) to positive infinity (all - candidates have equal probability). - start_id: Integer representing the start symbol for the autoregressive - process, or array of shape (`batch_size`, 1) of such integers. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. - eval_mode: If True, assume the model is created in `eval` mode and sample - by collecting all previous outputs and passing the whole tensor. - eval_min_length: If set, the minimum length to pad to in eval mode. + Args: + model: A layer object (subclass of `trax.layers.Layer`) created in + `'predict'` mode and initialized from trained weights. The model + must have a structure that allows it to run as an autoregressive + one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), + except if `eval_mode` is set -- any model can be sampled then, + but the sampling process may be much slower. 
+ inputs: Sequence of symbols the model sees as input the first time it + generates an output. If None, the model generates the first output + based on just the start symbol. + batch_size: Number of sequences to generate in parallel as a batch. + temperature: Parameter that controls the sharpness of the softmax that + feeds the sampling process. Values range from 0.0 (all probability mass + goes to one candidate; like an argmax) to positive infinity (all + candidates have equal probability). + start_id: Integer representing the start symbol for the autoregressive + process, or array of shape (`batch_size`, 1) of such integers. + accelerate: If True, create an accelerated version of `model` and use it + for generating outputs. + eval_mode: If True, assume the model is created in `eval` mode and sample + by collecting all previous outputs and passing the whole tensor. + eval_min_length: If set, the minimum length to pad to in eval mode. - Yields: - Tensor of integers with shape (`batch_size`, 1), representing the batch of - outputs for the next position in the stream. - """ - if inputs is not None and inputs.shape[0] != batch_size: - raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match ' - f'batch_size arg ({batch_size}.') + Yields: + Tensor of integers with shape (`batch_size`, 1), representing the batch of + outputs for the next position in the stream. + """ + if inputs is not None and inputs.shape[0] != batch_size: + raise ValueError( + f"Inputs batch size ({inputs.shape[0]}) does not match " + f"batch_size arg ({batch_size}." 
+ ) - fast_model = tl.Accelerate(model) if accelerate else model - if np.isscalar(start_id): - start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) - else: - start_symbol = start_id - if model.n_in == 1 and inputs is not None: - current_symbols = np.concatenate([start_symbol, inputs], axis=1) - else: - current_symbols = start_symbol - - if eval_mode: - # no start symbol needed in eval mode - current_symbols = current_symbols[:, 1:] - - while True: - # Pad inputs to power-of-2 length if needed. - if eval_mode: - # one extra symbol as an initial one will be added - l = max(eval_min_length, current_symbols.shape[1] + 1) - pad_len = int(2**np.ceil(np.log2(l))) - current_symbols.shape[1] - unpadded_symbols = current_symbols - current_symbols = np.pad( - current_symbols, [[0, 0], [0, pad_len]], mode='constant') - last_index = -pad_len # no -1 as the starting one will be added + fast_model = tl.Accelerate(model) if accelerate else model + if np.isscalar(start_id): + start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) else: - last_index = -1 - # Run the model. - if model.n_in > 1 and inputs is not None: - logits = fast_model((inputs, current_symbols))[0] + start_symbol = start_id + if model.n_in == 1 and inputs is not None: + current_symbols = np.concatenate([start_symbol, inputs], axis=1) else: - logits = fast_model(current_symbols) - logits = tl.log_softmax(logits[:, last_index, :]) - sample = tl.logsoftmax_sample(logits, temperature=temperature) - yield sample + current_symbols = start_symbol + if eval_mode: - current_symbols = np.concatenate( - [unpadded_symbols, sample[:, None]], axis=1) - else: - # NOTE: Because the model is autoregressive and in 'predict' mode, its - # history is cached in the model state and the next input is the single - # symbol just sampled. 
- current_symbols = sample[:, None] + # no start symbol needed in eval mode + current_symbols = current_symbols[:, 1:] + + while True: + # Pad inputs to power-of-2 length if needed. + if eval_mode: + # one extra symbol as an initial one will be added + l = max(eval_min_length, current_symbols.shape[1] + 1) + pad_len = int(2 ** np.ceil(np.log2(l))) - current_symbols.shape[1] + unpadded_symbols = current_symbols + current_symbols = np.pad( + current_symbols, [[0, 0], [0, pad_len]], mode="constant" + ) + last_index = -pad_len # no -1 as the starting one will be added + else: + last_index = -1 + # Run the model. + if model.n_in > 1 and inputs is not None: + logits = fast_model((inputs, current_symbols))[0] + else: + logits = fast_model(current_symbols) + logits = tl.log_softmax(logits[:, last_index, :]) + sample = tl.logsoftmax_sample(logits, temperature=temperature) + yield sample + if eval_mode: + current_symbols = np.concatenate( + [unpadded_symbols, sample[:, None]], axis=1 + ) + else: + # NOTE: Because the model is autoregressive and in 'predict' mode, its + # history is cached in the model state and the next input is the single + # symbol just sampled. + current_symbols = sample[:, None] -def autoregressive_sample(model, inputs=None, - batch_size=1, temperature=1.0, - start_id=0, eos_id=1, max_length=100, - accelerate=True, eval_mode=False, eval_min_length=1): - """Returns a batch of sequences created by autoregressive sampling. +def autoregressive_sample( + model, + inputs=None, + batch_size=1, + temperature=1.0, + start_id=0, + eos_id=1, + max_length=100, + accelerate=True, + eval_mode=False, + eval_min_length=1, +): + """Returns a batch of sequences created by autoregressive sampling. - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. 
The - new output becomes the next position's input, and this loop repeats until - either the model outputs the `eos_id` value or the output sequence reaches - `max_length` items. + This function uses `model` to generate outputs one position at a time, with + access to inputs for the current position and all preceding positions. The + new output becomes the next position's input, and this loop repeats until + either the model outputs the `eos_id` value or the output sequence reaches + `max_length` items. - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), - except if `eval_mode` is set -- any model can be sampled then, - but the sampling process may be much slower. - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model must generate the first output - with no input to guide it. - batch_size: Number of sequences to generate in parallel as a batch. - temperature: Parameter that controls the sharpness of the softmax that - feeds the sampling process. Values range from 0.0 (all probability mass - goes to one candidate; like an argmax) to positive infinity (all - candidates have equal probability). - start_id: The start symbol (ID/integer) for the autoregressive process, - or array of shape (`batch_size`, 1) of such integers. - eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive - process. - max_length: Maximum length for generated sequences. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. - eval_mode: If True, assume the model is created in `eval` mode and sample - by collecting all previous outputs and passing the whole tensor. - eval_min_length: If set, the minimum length to pad to in eval mode. 
+ Args: + model: A layer object (subclass of `trax.layers.Layer`) created in + `'predict'` mode and initialized from trained weights. The model + must have a structure that allows it to run as autoregressive + one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), + except if `eval_mode` is set -- any model can be sampled then, + but the sampling process may be much slower. + inputs: Sequence of symbols the model sees as input the first time it + generates an output. If None, the model must generate the first output + with no input to guide it. + batch_size: Number of sequences to generate in parallel as a batch. + temperature: Parameter that controls the sharpness of the softmax that + feeds the sampling process. Values range from 0.0 (all probability mass + goes to one candidate; like an argmax) to positive infinity (all + candidates have equal probability). + start_id: The start symbol (ID/integer) for the autoregressive process, + or array of shape (`batch_size`, 1) of such integers. + eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive + process. + max_length: Maximum length for generated sequences. + accelerate: If True, create an accelerated version of `model` and use it + for generating outputs. + eval_mode: If True, assume the model is created in `eval` mode and sample + by collecting all previous outputs and passing the whole tensor. + eval_min_length: If set, the minimum length to pad to in eval mode. - Returns: - Tensor of integers with shape (`batch_size`, output_length) representing - a batch of output sequences. output_length is the maximum length of the - output sequences, where each sequence can be no longer than `max_length`. 
- """ - result = [] - eos_seen = [] - counter = 0 - for sample in autoregressive_sample_stream( - model, inputs, batch_size=batch_size, temperature=temperature, - start_id=start_id, accelerate=accelerate, eval_mode=eval_mode, - eval_min_length=eval_min_length): - sample = sample[:, None] - result.append(sample) - counter += 1 - if counter >= max_length: - return np.concatenate(result, axis=1) - # Check at which batch positions have we already encountered EOS. - for j in range(batch_size): - if int(sample[j, 0]) == eos_id: - eos_seen.append(j) - # If EOS has been seen on all positions, stop. - if all([j in eos_seen for j in range(batch_size)]): - return np.concatenate(result, axis=1) - return np.concatenate(result, axis=1) + Returns: + Tensor of integers with shape (`batch_size`, output_length) representing + a batch of output sequences. output_length is the maximum length of the + output sequences, where each sequence can be no longer than `max_length`. + """ + result = [] + eos_seen = [] + counter = 0 + for sample in autoregressive_sample_stream( + model, + inputs, + batch_size=batch_size, + temperature=temperature, + start_id=start_id, + accelerate=accelerate, + eval_mode=eval_mode, + eval_min_length=eval_min_length, + ): + sample = sample[:, None] + result.append(sample) + counter += 1 + if counter >= max_length: + return np.concatenate(result, axis=1) + # Check at which batch positions have we already encountered EOS. + for j in range(batch_size): + if int(sample[j, 0]) == eos_id: + eos_seen.append(j) + # If EOS has been seen on all positions, stop. + if all([j in eos_seen for j in range(batch_size)]): + return np.concatenate(result, axis=1) + return np.concatenate(result, axis=1) -def beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0, - eos_id=1, max_length=100, length_penalty=1.0, accelerate=True): - """Returns a batch of n_beams-sequences created by beam search. 
+def beam_search( + model, + inputs=None, + batch_size=1, + n_beams=2, + start_id=0, + eos_id=1, + max_length=100, + length_penalty=1.0, + accelerate=True, +): + """Returns a batch of n_beams-sequences created by beam search. - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. The - new output becomes the next position's input, and this loop repeats until - either the model outputs the `eos_id` value or the output sequence reaches - `max_length` items -- but keeping n_beams top beams. + This function uses `model` to generate outputs one position at a time, with + access to inputs for the current position and all preceding positions. The + new output becomes the next position's input, and this loop repeats until + either the model outputs the `eos_id` value or the output sequence reaches + `max_length` items -- but keeping n_beams top beams. - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`). - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model must generate the first output - with no input to guide it. - batch_size: Number of sequences to generate in parallel as a batch. - n_beams: How many beams to consider at the same time. - start_id: The start symbol (ID/integer) for the autoregressive process, - or array of shape (`batch_size`, 1) of such integers. - eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive - process. - max_length: Maximum length for generated sequences. - length_penalty: Factor alpha in calculating the length penalty for beams. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. 
+ Args: + model: A layer object (subclass of `trax.layers.Layer`) created in + `'predict'` mode and initialized from trained weights. The model + must have a structure that allows it to run as autoregressive + one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`). + inputs: Sequence of symbols the model sees as input the first time it + generates an output. If None, the model must generate the first output + with no input to guide it. + batch_size: Number of sequences to generate in parallel as a batch. + n_beams: How many beams to consider at the same time. + start_id: The start symbol (ID/integer) for the autoregressive process, + or array of shape (`batch_size`, 1) of such integers. + eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive + process. + max_length: Maximum length for generated sequences. + length_penalty: Factor alpha in calculating the length penalty for beams. + accelerate: If True, create an accelerated version of `model` and use it + for generating outputs. - Returns: - Tensor of integers with shape (`batch_size`, n_beams, output_length) with - a batch of output sequences. output_length is the maximum length of the - output sequences, where each sequence can be no longer than `max_length`. - """ - del eos_id, length_penalty # TODO(lukaszkaiser): add length penalty, eos - assert batch_size == 1, 'Batch size > 1 not supported yet' - if inputs is not None and inputs.shape[0] != batch_size: - raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match ' - f'batch_size arg ({batch_size}.') + Returns: + Tensor of integers with shape (`batch_size`, n_beams, output_length) with + a batch of output sequences. output_length is the maximum length of the + output sequences, where each sequence can be no longer than `max_length`. 
+ """ + del eos_id, length_penalty # TODO(lukaszkaiser): add length penalty, eos + assert batch_size == 1, "Batch size > 1 not supported yet" + if inputs is not None and inputs.shape[0] != batch_size: + raise ValueError( + f"Inputs batch size ({inputs.shape[0]}) does not match " + f"batch_size arg ({batch_size}." + ) - fast_model = tl.Accelerate(model) if accelerate else model - if np.isscalar(start_id): - start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) - else: - start_symbol = start_id - if model.n_in == 1 and inputs is not None: - current_symbols = np.concatenate([start_symbol, inputs], axis=1) - else: - current_symbols = start_symbol + fast_model = tl.Accelerate(model) if accelerate else model + if np.isscalar(start_id): + start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) + else: + start_symbol = start_id + if model.n_in == 1 and inputs is not None: + current_symbols = np.concatenate([start_symbol, inputs], axis=1) + else: + current_symbols = start_symbol - beams = [current_symbols for _ in range(n_beams)] - results = [([], 0.0) for _ in range(n_beams)] - states = [fast_model.state for _ in range(n_beams)] - top_k = [None] * n_beams - counter = 0 - while counter < max_length: - counter += 1 - # Run the model on all beams, collect states and top_k for each beam. - for beam_id in range(n_beams if counter > 1 else 1): - fast_model.state = states[beam_id] - if model.n_in > 1 and inputs is not None: - logits = fast_model((inputs, beams[beam_id]))[0] - else: - logits = fast_model(beams[beam_id]) - logits = tl.log_softmax(logits[:, -1, :]) - states[beam_id] = fast_model.state - top_k[beam_id] = fastmath.top_k(logits, k=n_beams) + beams = [current_symbols for _ in range(n_beams)] + results = [([], 0.0) for _ in range(n_beams)] + states = [fast_model.state for _ in range(n_beams)] + top_k = [None] * n_beams + counter = 0 + while counter < max_length: + counter += 1 + # Run the model on all beams, collect states and top_k for each beam. 
+ for beam_id in range(n_beams if counter > 1 else 1): + fast_model.state = states[beam_id] + if model.n_in > 1 and inputs is not None: + logits = fast_model((inputs, beams[beam_id]))[0] + else: + logits = fast_model(beams[beam_id]) + logits = tl.log_softmax(logits[:, -1, :]) + states[beam_id] = fast_model.state + top_k[beam_id] = fastmath.top_k(logits, k=n_beams) - # Select new beams. - cur_values = [] # will hold triples (sum-of-logprobs, beam-id, symbol) - for beam_id in range(n_beams if counter > 1 else 1): - for k in range(n_beams): - values, symbols = top_k[beam_id] - value, symbol = values[:, k], symbols[:, k] - cur_values.append((results[beam_id][1] + value, beam_id, symbol)) - cur_values.sort(key=lambda x: -x[0][0]) # x[0][0] as batch_size=1 - # Collect top beams to the new states and results. - new_results, new_states, new_beams = [], [], [] - for (value, beam_id, symbol) in cur_values[:n_beams]: - new_results.append((results[beam_id][0] + [symbol], value)) - new_states.append(states[beam_id]) # copy? - new_beams.append(symbol[:, None]) - results, states, beams = new_results, new_states, new_beams + # Select new beams. + cur_values = [] # will hold triples (sum-of-logprobs, beam-id, symbol) + for beam_id in range(n_beams if counter > 1 else 1): + for k in range(n_beams): + values, symbols = top_k[beam_id] + value, symbol = values[:, k], symbols[:, k] + cur_values.append((results[beam_id][1] + value, beam_id, symbol)) + cur_values.sort(key=lambda x: -x[0][0]) # x[0][0] as batch_size=1 + # Collect top beams to the new states and results. + new_results, new_states, new_beams = [], [], [] + for (value, beam_id, symbol) in cur_values[:n_beams]: + new_results.append((results[beam_id][0] + [symbol], value)) + new_states.append(states[beam_id]) # copy? 
+ new_beams.append(symbol[:, None]) + results, states, beams = new_results, new_states, new_beams - return [(np.stack(r, axis=-1), v) for (r, v) in results] + return [(np.stack(r, axis=-1), v) for (r, v) in results] diff --git a/trax/supervised/decoding_test.py b/trax/supervised/decoding_test.py deleted file mode 100644 index afaad725c..000000000 --- a/trax/supervised/decoding_test.py +++ /dev/null @@ -1,453 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for decoding.""" - -import functools -import os - -import gin -from jax.config import config -import numpy as np -from tensorflow.compat.v2 import test - -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import shapes -from trax.supervised import decoding - - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, 'configs/') - - -class DecodingTest(test.TestCase): - - def test_autoregressive_sample_transformerlm(self): - model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1, - n_heads=2, mode='predict') - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - batch_per_device = 2 // fastmath.local_device_count() - model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) - s2 = decoding.autoregressive_sample( - model, batch_size=2, max_length=10) - self.assertEqual(s2.shape[0], 2) - self.assertLess(s2.shape[1], 11) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - prefix = np.array([[1, 2, 3]]) - s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, max_length=10, - batch_size=1) - self.assertEqual(s3.shape[0], 1) - self.assertEqual(s3.shape[1], 10) - - def test_autoregressive_sample_transformerlm_tfnp(self): - with fastmath.use_backend(fastmath.Backend.TFNP): - model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1, - n_heads=2, mode='predict') - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - batch_per_device = 2 // fastmath.local_device_count() - model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) - s2 = decoding.autoregressive_sample( - model, 
batch_size=2, max_length=10) - self.assertEqual(s2.shape[0], 2) - self.assertLess(s2.shape[1], 11) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - prefix = np.array([[1, 2, 3]]) - s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, - max_length=10, batch_size=1) - self.assertEqual(s3.shape[0], 1) - self.assertEqual(s3.shape[1], 10) - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _pure_lsh_self_attention_fn(self, n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.0, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - pure_lsh_implementation=tl.PureLSHSelfAttention, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False, chunk_len=64): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=chunk_len, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def test_autoregressive_sample_reformerlm(self): - lsh_self_attention = self._lsh_self_attention_fn() - timebin_self_attention = self._timebin_self_attention_fn() - - model = models.ReformerLM(vocab_size=256, - d_model=256, - d_ff=512, - d_attention_key=128, - d_attention_value=128, - n_layers=2, - n_heads=2, - dropout=0.05, - max_len=65536, - attention_type=[timebin_self_attention, - lsh_self_attention], - pos_axial_shape=(256, 256), - pos_d_axial_embs=(128, 128), - ff_activation=tl.Relu, - ff_use_sru=0, - mode='predict', - ) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = 
decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - - def test_autoregressive_sample_transformer(self): - model = models.Transformer(10, d_model=32, d_ff=64, n_encoder_layers=1, - n_decoder_layers=1, n_heads=2, mode='predict') - inputs = np.ones((1, 3), dtype=np.int32) - model.init((shapes.signature(inputs), - shapes.ShapeDtype((1, 1), dtype=np.int32))) - s = decoding.autoregressive_sample(model, inputs=inputs, - eos_id=-1, max_length=10) - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_transformerlm_quality(self): - pred_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='predict') - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, inputs, - max_length=6, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4]') - - def test_autoregressive_sample_transformerlm_quality_eval(self): - eval_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='eval') - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - eval_model.init_from_file(model_path) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(eval_model, inputs, eval_mode=True, - max_length=6, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4]') - - def test_autoregressive_sample_transformerlm_quality_beam(self): - pred_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='predict') - shape11 
= shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6) - self.assertEqual(len(s), 3) # 3 beams - self.assertEqual(str(s[0][0][0]), '[3 7 5 3 2 4]') - self.assertEqual(str(s[1][0][0]), '[3 7 5 3 2 2]') # different from above - self.assertEqual(str(s[2][0][0]), '[3 7 5 3 3 2]') # different from above - - def test_autoregressive_sample_transformer_quality(self): - pred_model = models.Transformer( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_encoder_layers=2, n_decoder_layers=2, input_vocab_size=13, - mode='predict') - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformer_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, inputs=inputs, - eos_id=1, max_length=10, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4 1]') - - def test_autoregressive_sample_terraformer_lsh(self): - max_len = 128 - - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=max_len, - n_heads=4, - n_encoder_layers=1, - n_decoder_layers=1, - ff_use_sru=1, - d_attention_key=64, - d_attention_value=64, - encoder_attention_type=self._lsh_self_attention_fn(), - encoder_decoder_attention_type=self._lsh_self_attention_fn(), - input_vocab_size=256, - pos_axial_shape=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - - # 0w0w - inputs = np.array( - [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 
3, 7, 5, 3, 2, 4, 1, 8]], - dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_terraformer_lsh_attn_quality(self): - gin.add_config_file_search_path(_CONFIG_DIR) - max_len = 32 # 32 is the max length we trained the checkpoint for. - test_lengths = [8, 16, 32] - vocab_size = 13 - # The checkpoint is correct on ~90% sequences, set random seed to deflake. - np.random.seed(0) - for test_len in test_lengths: - gin.clear_config() - gin.parse_config_file('terraformer_copy.gin') - gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len) - gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len) - - pred_model = models.ConfigurableTerraformer(mode='predict') - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - - model_path = os.path.join(_TESTDATA, 'terraformer_copy_lsh_attn.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape1l, shape11)) - initial_state = pred_model.state - - for _ in range(2): # Set low to make the test run reasonably fast. - # Pick a length in [1, test_len] at random. - inp_len = np.random.randint(low=1, high=test_len + 1) - inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, max_len)) - # TODO(jaszczur): properly fix padding in terraformer predict mode, - # and add a test here. - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len, - temperature=0.0) - np.testing.assert_equal(s[0], inputs[0, :inp_len]) - pred_model.state = initial_state - gin.clear_config() # Make sure to not affect other tests. 
- - def test_autoregressive_sample_reformerlm_lsh(self): - max_len = 32 - - pred_model = models.ReformerLM( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=2 * max_len, - n_heads=4, - n_layers=3, - ff_use_sru=0, - d_attention_key=64, - d_attention_value=64, - attention_type=functools.partial(tl.LSHSelfAttention, - chunk_len=16, - n_hashes=2, - n_buckets=[32, 32], - predict_drop_len=max_len, - predict_mem_len=max_len, - max_length_for_buckets=1024), - vocab_size=13, - pos_type='fixed-base', - pos_d_axial_embs=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - pred_model.init(shape11) - - # 0w0 - inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_reformerlm_lsh_quality(self): - max_len = 32 - - pred_model = models.ReformerLM( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=2 * max_len, - n_heads=4, - n_layers=3, - ff_use_sru=0, - d_attention_key=64, - d_attention_value=64, - attention_type=functools.partial(tl.LSHSelfAttention, - chunk_len=16, - n_hashes=2, - n_buckets=[32, 32], - predict_drop_len=max_len, - predict_mem_len=max_len, - max_length_for_buckets=1024), - vocab_size=13, - pos_type='fixed-base', - pos_d_axial_embs=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - - model_path = os.path.join( - _TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=shape11) - - # 0w0 - inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) - inp_len = inputs.shape[1] - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len-2, - temperature=0.0) - - 
np.testing.assert_equal(s[0], inputs[0, 1:inp_len-1]) - # pylint: enable=unreachable - - def test_autoregressive_sample_terraformer_pure_lsh(self): - max_len = 128 - - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=max_len, - n_heads=4, - n_encoder_layers=1, - n_decoder_layers=1, - ff_use_sru=1, - d_attention_key=64, - d_attention_value=64, - encoder_attention_type=self._pure_lsh_self_attention_fn( - n_chunks_after=1), - encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(), - input_vocab_size=256, - pos_axial_shape=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - - # 0w0w - inputs = np.array( - [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], - dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self): - gin.add_config_file_search_path(_CONFIG_DIR) - max_len = 32 # 32 is the max length we trained the checkpoint for. - test_lengths = [8, 16, 32] - vocab_size = 13 - # The checkpoint is correct on ~90% sequences, set random seed to deflake. 
- np.random.seed(0) - for test_len in test_lengths: - gin.clear_config() - gin.parse_config_file('terraformer_purelsh_copy.gin') - gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len) - gin.bind_parameter('PureLSHSelfAttention.predict_drop_len', 2 * max_len) - gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False) - gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2) - - pred_model = models.ConfigurableTerraformer(mode='predict') - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - - model_path = os.path.join(_TESTDATA, 'terraformer_purelsh_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape1l, shape11)) - initial_state = pred_model.state - - for _ in range(2): # Set low to make the test run reasonably fast. - # Pick a length in [1, test_len] at random. - inp_len = np.random.randint(low=1, high=test_len + 1) - inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, max_len)) - # TODO(jaszczur): properly fix padding in terraformer predict mode, - # and add a test here. - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len, - temperature=0.0) - - np.testing.assert_equal(s[0], inputs[0, :inp_len]) - pred_model.state = initial_state - gin.clear_config() # Make sure to not affect other tests. - - -if __name__ == '__main__': - config.config_with_absl() - test.main() diff --git a/trax/supervised/decoding_timing_test.py b/trax/supervised/decoding_timing_test.py deleted file mode 100644 index 48faf156e..000000000 --- a/trax/supervised/decoding_timing_test.py +++ /dev/null @@ -1,439 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Timing tests for decoding.""" - -import copy -import functools -import gc -import os -import time -from jax.config import config -import numpy as np -import psutil -from tensorflow.compat.v2 import test - -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import shapes -from trax.supervised import decoding - - -def _size_of_model(model): - def _size(x): - try: - return x.size - except Exception: # pylint: disable=broad-except - return 0 - sizes = fastmath.nested_map(_size, model.weights) - total_size = sum(fastmath.tree_flatten(sizes)) - return total_size - - -def _recurrent_delete(w): - if 'delete' in dir(w): - # Object has a 'delete' method, so it is a DeviceArray or something similar, - # so we want to delete it. - w.delete() - elif isinstance(w, (list, tuple)): - for x in w: - _recurrent_delete(x) - elif isinstance(w, dict): - for x in w.values(): - _recurrent_delete(x) - else: - raise ValueError('Unknown type encountered in weights: {}'.format(type(w))) - - -def _memory_usage(): - gc.collect() - return psutil.Process(os.getpid()).memory_info().rss - - -class DecodingTimingTest(test.TestCase): - - def _terraformer_decoding_time(self, settings): - # Garbage collection influences the timing, so we turn it off. 
- gc.disable() - max_len = 16 - - def _self_attention_fn(): - return functools.partial( - tl.SelfAttention, - predict_drop_len=2 * max_len, - predict_mem_len=2 * max_len) - - def _causal_attention_fn(): - attn_layer, attn_kwargs = settings['attn'] - return functools.partial( - attn_layer, - max_inference_length=2 * max_len, **attn_kwargs) - - if settings['model'] == 'terraformer': - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=settings['d_model'], - d_ff=settings['d_ff'], - dropout=0.1, - max_len=max_len, - n_heads=settings['n_heads'], - n_encoder_layers=settings['encoder_layers'], - n_decoder_layers=settings['decoder_layers'], - encoder_attention_type=_self_attention_fn(), - encoder_decoder_attention_type=_causal_attention_fn(), - input_vocab_size=settings['vocab'], - ff_sparsity=settings['ff_sparsity'], - ff_use_sru=settings['ff_use_sru'], - ff_dropout=0.1, - # ff_chunk_size=1024, - # attention_chunk_size=1, - n_decoder_attention_layers=settings['attention_layers'], - loss_sparsity=settings['loss_sparsity'], - pos_axial_shape=None, - use_bfloat16=True, - ) - elif settings['model'] == 'transformer': - pred_model = models.ConfigurableTransformer( - mode='predict', - d_model=settings['d_model'], - d_ff=settings['d_ff'], - dropout=0.1, - max_len=max_len, - n_heads=settings['n_heads'], - n_encoder_layers=settings['encoder_layers'], - n_decoder_layers=settings['decoder_layers'], - # encoder_attention_type=_self_attention_fn(), - encoder_decoder_attention_type=_causal_attention_fn(), - input_vocab_size=settings['vocab'], - ff_sparsity=settings['ff_sparsity'], - ff_use_sru=settings['ff_use_sru'], - # ff_dropout=0.1, - # ff_chunk_size=1024, - # attention_chunk_size=1, - # n_decoder_attention_layers=settings['attention_layers'], - loss_sparsity=settings['loss_sparsity'], - pos_axial_shape=None, - # enc_dec_attention_sparsity=settings['enc_dec_sparsity'], - # use_bfloat16=True, - ) - else: - assert False - # We put acceleration outside of 
autoregressive_sample_stream, because - # we want to have a separate run (separate input) for model compilation. - pred_model = tl.Accelerate(pred_model) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - original_state = copy.deepcopy(pred_model.state) - - inputs_warmup = np.zeros((1, max_len), dtype=np.int32) - inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len) - - # This is a warm-up run, for compilation. - result, current_time = [], time.time() - elapsed_warmup_times = [] - for index, sample in zip(range(0, 4), decoding.autoregressive_sample_stream( - pred_model, inputs_warmup, temperature=0.0, accelerate=False)): - del index # unused - result.append(sample[:, None]) # to be sure that the result is computed - - current_time, start_time = time.time(), current_time - elapsed_warmup_times.append(current_time - start_time) - - # This is a real decoding timing run that we measure. - pred_model.state = original_state - result, current_time = [], time.time() - elapsed_times = [] - for index, sample in zip(range(12), decoding.autoregressive_sample_stream( - pred_model, inputs, temperature=0.0, accelerate=False)): - del index # unused - result.append(sample[:, None]) # to be sure that the result is computed - - current_time, start_time = time.time(), current_time - elapsed_times.append(current_time - start_time) - peak_memory = _memory_usage() - - if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]): - print('WARNING! High variance found in elapsed times! Settings: {} ; ' - 'elapsed times: {} ; Probably more warm-up steps should be used, ' - 'or model size should be increased.'.format(settings, - elapsed_times)) - # Check resulting shapes. 
- s = np.concatenate(result, axis=1) - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 12) - model_size = int(_size_of_model(pred_model)) - - # We delete the model weights, because in some situations they won't be - # deleted automatically. - _recurrent_delete(pred_model.weights) - gc.enable() - return model_size, elapsed_times, peak_memory - - def test_autoregressive_sample_terraformer_timing(self): - template_to_use = 'medium_transformer' - - settings_templates = { - # full model - # # 54B params - # 'full_model': { - # 'encoder_layers': 6, 'decoder_layers': 36, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': 64*1024, 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': (1, 64), 'ff_sparsity': (256, 32), - # 'loss_sparsity': 8, - # 'attn': (tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64})}, - - # 1/18 of model (1/6 of encoder, 1/18 of decoder, full vocab) - # 4B params - # 'big_terraformer': { - # 'model': 'terraformer', - # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - # 'big_transformer': { - # 'model': 'transformer', - # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - # medium model - # 275M params (only decoder) - 'medium_transformer': { - 'model': 'transformer', - 'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 32000, - 'attention_layers': 2, - 'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16, - 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - 'attn': (tl.CausalAttention, {})}, - # 'medium_terraformer': { - # 'model': 'terraformer', - # 'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 
32000, - # 'attention_layers': 2, - # 'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - } - - sweep_settings = { - # 'big_transformer': [ # for big - # dict(), # baseline - # {'ff_sparsity': (256, 32)}, # + Sparse FF - # {'attn': ( # + Sparse QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64}), - # 'd_ff': 64*1024, - # }, - # {'ff_sparsity': (256, 32), - # 'attn': ( # + Sparse FF+QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64}), - # 'd_ff': 64*1024, - # }, - # ], - - 'medium_transformer': [ # for medium - dict(), # baseline - - {'ff_sparsity': 64, - 'attn': ( # Sparse FF+QKV - tl.MultiplicativeConvCausalAttention, - {'length_kernel_size': 3, 'sparsity': 16}), - 'd_ff': 6*1024, - }, - - # {'ff_sparsity': 64, # Sparse FF+QKV + Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - - # {'attn': ( # Sparse QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # }, - # {'loss_sparsity': 4}, # Sparse Loss - # {'ff_sparsity': 64}, # Sparse FF - - # {'ff_sparsity': 128}, # + Sparse FF 128 - - # APPENDIX below - - # different loss layers - # {'loss_sparsity': 8}, - # {'loss_sparsity': 2}, - # {'loss_sparsity': 0}, - ], - - # 'big_terraformer': [ # for big terraformer - # dict(), # baseline - # {'ff_sparsity': 64}, # + Sparse FF / Sparse FF 64 - # {'ff_sparsity': 64, - # 'attn': ( # + Sparse FF+QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # }, - # {'ff_sparsity': 64, # + Sparse FF+QKV+Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - - # ], - - # 
'medium_terraformer': [ # for medium terraformer - # {'ff_sparsity': 64, # + Sparse FF+QKV+Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - # ], - } - - encoding_times = [] - decoding_times = [] - sizes = [] - memories = [] - messages = [] - for override_settings in sweep_settings[template_to_use]: - settings = copy.deepcopy(settings_templates[template_to_use]) - settings.update(override_settings) - - init_memory = _memory_usage() - size, elapsed_times, peak_memory = ( - self._terraformer_decoding_time(settings)) - - # TODO(jaszczur): Why is elapsed_times[0] always small? - encoding_time = elapsed_times[1] - decoding_time_10 = sum(elapsed_times[2:]) - - after_memory = _memory_usage() - model_memory_gigabytes = (peak_memory-init_memory)/1024**3 - decoding_time_diff = (max(elapsed_times[2:]) - min(elapsed_times[2:])) / 2 - decoding_time_diff_percent = int( - decoding_time_diff / np.mean(elapsed_times) * 100) - message = ( - '\n\n' - 'Params: {}\n' - 'Settings: {}\n' - 'Override: {}\n' - 'Init memory: {:.1f} GiB\n' - 'Peak memory: {:.1f} GiB\n' - 'After memory: {:.1f} GiB\n' - 'Estimated model memory: {:.1f} GiB\n' - 'Times for each step: {}\n' - 'Time for encoding: {:.4f} s\n' - 'Time for decoding 10 tokens: {:.4f} s +/- {} %\n' - '\n\n' - .format(size, settings, override_settings, - init_memory/1024**3, peak_memory/1024**3, - after_memory/1024**3, model_memory_gigabytes, - elapsed_times, encoding_time, - decoding_time_10, decoding_time_diff_percent)) - print(message) - messages.append(message) - encoding_times.append(encoding_time) - decoding_times.append(decoding_time_10) - sizes.append(size) - memories.append(model_memory_gigabytes) - - print('Final results (recap):') - for message in messages: - print(message) - - # This is useful for copying results into a spreadsheet etc. 
- # for i in range(len(sweep_settings)): - # print('{}\t{}\t{}\t{:.1f}'.format( - # sizes[i], encoding_times[i], decoding_times[i], memories[i])) - - def test_loss_layer_timing(self): - all_settings = [ - # The first run is sometimes slower, less reliable. - {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - - {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': False}, - - {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': True}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': True}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': True}, - ] - - messages = [] - for settings in all_settings: - pred_model = tl.SparseDenseWithOptions( - n_units=settings['output'], - d_input=settings['input'], - sparsity_type=settings['type'], - sparsity=settings['sparsity'], - d_lowrank=settings['lowrank'], - prob_sparse=settings['prob'], - use_bias=settings['use_bias'], - mode='predict', - ) - pred_model = tl.Accelerate(pred_model) - - shape1l = shapes.ShapeDtype((1, settings['input'])) - pred_model.init(input_signature=shape1l) - inputs = np.ones((1, settings['input'])) - - total_time = 0.0 - for counter in range(-50, 100): - start_time = time.time() - y = pred_model(inputs) - self.assertEqual(y.shape, (1, settings['output'])) - elapsed_time = time.time() - start_time - if counter >= 0: - total_time += elapsed_time - - message = ( - '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n' - % (_size_of_model(pred_model), settings, total_time)) - 
messages.append(message) - print(message) - - print('Final results (recap):') - for message in messages: - print(message) - - -if __name__ == '__main__': - config.config_with_absl() - test.main() diff --git a/trax/supervised/history_test.py b/trax/supervised/history_test.py deleted file mode 100644 index 3aee06a64..000000000 --- a/trax/supervised/history_test.py +++ /dev/null @@ -1,56 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.supervised.history.""" - -from absl.testing import absltest - -from trax.supervised import history as trax_history - - -class HistoryTest(absltest.TestCase): - - def test_unknown_mode(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - self.assertEqual(history.get('unknown_mode', 'metric1'), []) - - def test_unknown_metric(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - self.assertEqual(history.get('train', 'unknown_metric'), []) - - def test_serializer_and_deserializer(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - json_object = history.to_dict() - history2 = trax_history.History.from_dict(json_object) - self.assertEqual(history2.get('train', 'metric1'), [(1, 0.1)]) - - def test_modes(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - history.append('test', 'metric2', 2, 0.2) - self.assertEqual(history.modes, ['test', 'train']) - - def test_metrics_for_mode(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - history.append('train', 'metric2', 2, 0.2) - self.assertEqual(history.metrics_for_mode('train'), ['metric1', 'metric2']) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/lr_schedules_test.py b/trax/supervised/lr_schedules_test.py deleted file mode 100644 index 1973686bf..000000000 --- a/trax/supervised/lr_schedules_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests of learning rate schedules.""" - -import math - -from absl.testing import absltest - -from trax.supervised import lr_schedules - - -class LRFunctionsTest(absltest.TestCase): - - def test_warmup(self): - lr_fn = lr_schedules.warmup(9, .01) - - # Linear warm-up. - self.assertAlmostEqual(.001, lr_fn(1)) - self.assertAlmostEqual(.002, lr_fn(2)) - self.assertAlmostEqual(.005, lr_fn(5)) - self.assertAlmostEqual(.009, lr_fn(9)) - - # Constant thereafter. - self.assertAlmostEqual(.01, lr_fn(10)) - self.assertAlmostEqual(.01, lr_fn(11)) - self.assertAlmostEqual(.01, lr_fn(20)) - self.assertAlmostEqual(.01, lr_fn(300)) - self.assertAlmostEqual(.01, lr_fn(4000)) - - def test_constant(self): - lr_fn = lr_schedules.constant(.02) - self.assertEqual(.02, lr_fn(1)) - self.assertEqual(.02, lr_fn(20)) - self.assertEqual(.02, lr_fn(300)) - self.assertEqual(.02, lr_fn(4000)) - self.assertEqual(.02, lr_fn(50000)) - self.assertEqual(.02, lr_fn(600000)) - self.assertEqual(.02, lr_fn(7000000)) - self.assertEqual(.02, lr_fn(80000000)) - self.assertEqual(.02, lr_fn(900000000)) - - def test_warmup_and_rsqrt_decay(self): - lr_fn = lr_schedules.warmup_and_rsqrt_decay(24, .25) - - # Warm-up. - self.assertAlmostEqual(.01, lr_fn(1)) - self.assertAlmostEqual(.02, lr_fn(2)) - self.assertAlmostEqual(.23, lr_fn(23)) - self.assertAlmostEqual(.24, lr_fn(24)) - - # Reciprocal square-root decay. 
- self.assertAlmostEqual(.25 * (5 / math.sqrt(25)), lr_fn(25)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(26)), lr_fn(26)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(27)), lr_fn(27)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(300)), lr_fn(300)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(4000)), lr_fn(4000)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(50000)), lr_fn(50000)) - - def test_cosine_sawtooth(self): - tail_fn = lr_schedules._CosineSawtoothTail(180, min_value=.1) - lr_fn = lr_schedules._BodyAndTail(.3, tail_start=0, tail_fn=tail_fn) - - # First cycle - self.assertAlmostEqual(.29998477, lr_fn(1)) - self.assertAlmostEqual(.28660254, lr_fn(30)) - self.assertAlmostEqual(.25, lr_fn(60)) - self.assertAlmostEqual(.20, lr_fn(90)) - self.assertAlmostEqual(.15, lr_fn(120)) - self.assertAlmostEqual(.10001523, lr_fn(179)) - - # Second cycle - self.assertEqual(.3, lr_fn(180)) - self.assertAlmostEqual(.29998477, lr_fn(181)) - self.assertAlmostEqual(.28660254, lr_fn(210)) - self.assertAlmostEqual(.25, lr_fn(240)) - self.assertAlmostEqual(.20, lr_fn(270)) - self.assertAlmostEqual(.15, lr_fn(300)) - self.assertAlmostEqual(.10001523, lr_fn(359)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/trainer_lib.py b/trax/supervised/trainer_lib.py index 4ffef14f3..7f1775fe9 100644 --- a/trax/supervised/trainer_lib.py +++ b/trax/supervised/trainer_lib.py @@ -47,910 +47,985 @@ # TODO(afrozm): Maybe flatten everything from OptState into TrainerState. -TrainerState = collections.namedtuple('_TrainerState', [ - 'step', # Current training step number. - 'opt_state', # OptState. - 'history', # trax.history.History. - 'model_state', # Auxilliary state of the model. -]) +TrainerState = collections.namedtuple( + "_TrainerState", + [ + "step", # Current training step number. + "opt_state", # OptState. + "history", # trax.history.History. + "model_state", # Auxilliary state of the model. 
+ ], +) -OptState = collections.namedtuple('_OptState', [ - 'weights', # Model weights. - 'slots', # Per-parameter optimizer state, e.g. gradient moments. - 'opt_params', # Optimizer (hyper)parameters, e.g. learning rate, momentum. -]) +OptState = collections.namedtuple( + "_OptState", + [ + "weights", # Model weights. + "slots", # Per-parameter optimizer state, e.g. gradient moments. + "opt_params", # Optimizer (hyper)parameters, e.g. learning rate, momentum. + ], +) _DEFAULT_METRICS = { - 'loss': tl.WeightedCategoryCrossEntropy(), - 'accuracy': tl.WeightedCategoryAccuracy(), - 'sequence_accuracy': tl.MaskedSequenceAccuracy(), - 'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(), - tl.Negate()), - 'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()), + "loss": tl.WeightedCategoryCrossEntropy(), + "accuracy": tl.WeightedCategoryAccuracy(), + "sequence_accuracy": tl.MaskedSequenceAccuracy(), + "neg_log_perplexity": tl.Serial(tl.WeightedCategoryCrossEntropy(), tl.Negate()), + "weights_per_batch_per_core": tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()), } -NamedStream = collections.namedtuple( - 'NamedStream', ['name', 'stream'] -) +NamedStream = collections.namedtuple("NamedStream", ["name", "stream"]) @gin.configurable def named_stream(name=gin.REQUIRED, stream=gin.REQUIRED): - return NamedStream(name=name, stream=stream) + return NamedStream(name=name, stream=stream) class Trainer: - """Trax trainer. - - A trainer allows to make training steps, train for full epochs, - save the training state and access evaluation data. 
- """ - - def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, - output_dir=None, random_seed=None, n_devices=None, - checkpoints_at=None, should_save_checkpoints=True, - should_write_summaries=True, - metrics=None, checkpoint_highest=None, - checkpoint_lowest=None, - init_checkpoint=None): - - self._is_chief, _, self._n_devices, rng = ( - training.init_host_and_devices(n_devices, random_seed)) - self._should_save_checkpoints = should_save_checkpoints and self._is_chief - self._checkpoints_at = checkpoints_at if checkpoints_at is not None else [] - self._should_write_summaries = should_write_summaries - if not output_dir: - self._should_save_checkpoints = False - self._should_write_summaries = False - self._checkpoint_highest = checkpoint_highest - self._checkpoint_lowest = checkpoint_lowest - self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS - # Inputs is either an Inputs instance or a function that returns it. - self._inputs = inputs - if callable(inputs): # If we pass a function, e.g., through gin, call it. - self._inputs = inputs() - # Initialize the learning rate to a dummy value. It will be set in reset(). - opt = optimizer(learning_rate=0.0) - - # Setup the model. - model_train = model(mode='train') - model_predict_eval = model(mode='eval') - # Should work for fine-tuning of T5. - if init_checkpoint: - model_train.init_from_file(init_checkpoint, weights_only=True) - model_predict_eval.init_from_file(init_checkpoint, weights_only=True) - self._model_with_loss = tl.Serial(model_train, loss_fn) - - # Setup state. 
- rng, init_rng = jax_random.split(rng) - self._rngs = np.stack(jax_random.split(rng, self._n_devices)) - shapes, dtypes = self._inputs.example_shape_dtype - input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes)) - - def new_opt_state_and_model_state(rng): - """Returns optimizer and model states suitable for training a model.""" - weights, state = self._model_with_loss.init(input_signature, rng=rng) - (slots, opt_params) = opt.tree_init(weights) - return (OptState(weights, slots, opt_params), state) - - if fastmath.is_backend(fastmath.Backend.JAX): - # JIT parameter initialization to avoid memory fragmentation - new_opt_state_and_model_state = ( - fastmath.jit(new_opt_state_and_model_state)) - self._new_opt_state_and_model_state = ( - lambda: new_opt_state_and_model_state(init_rng)) - - # Arrange and initialize metrics layers. - self._metrics = list(sorted(self._metrics_dict.keys())) - metrics_layers = [self._metrics_dict[m] for m in self._metrics] - metrics_in_parallel = tl.Branch(*metrics_layers) - metrics_in_parallel.rng = init_rng - example_signature = tuple( - ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype) - ) - model_predict_eval.init(example_signature) - self._input_signature = example_signature - output_signature = model_predict_eval.output_signature(example_signature) - m_weights, m_state = metrics_in_parallel.init(output_signature) - self._metrics_weights = self._for_n_devices(m_weights) - self._metrics_state = self._for_n_devices(m_state) - - # Jit model_predict and update so they're fast. - self._jit_eval = _jit_predict_fn( - model_predict_eval, metrics_in_parallel, self._n_devices) - self._jit_update_fn = _jit_update_fn( - model_train, loss_fn, opt, self._n_devices) - - self._model_train = model_train - self._model_predict_eval = model_predict_eval - self._loss_fn = loss_fn - self._lr_schedule = lr_schedule - - # Those fields will be set in reset(). 
- self._output_dir = None - self._train_sw = None - self._eval_sw = None - self._history = None - self._opt_state = None - self._step = None - self._model_state = None - self.reset(output_dir) - - @property - def n_devices(self): - return self._n_devices - - @property - def step(self): - return self._step - - @property - def model_weights(self): - # Currently we need to pick [0] as we ignore loss weights (empty). - weights = self._opt_state.weights[0] - if self.n_devices > 1: - unreplicate = lambda x: x[0] - weights = fastmath.nested_map(unreplicate, weights) - return weights - - @model_weights.setter - def model_weights(self, weights): - new_model_weights = self._for_n_devices(weights) - if isinstance(self._opt_state.weights, list): - self._opt_state.weights[0] = new_model_weights - else: # weights are a tuple, need to re-create - new_weights = [new_model_weights] + list(self._opt_state.weights[1:]) - self._opt_state = self._opt_state._replace(weights=new_weights) - - @property - def model_state(self): - # Currently we need to pick [0] as we ignore loss state (empty). - state = self._model_state[0] - if self.n_devices > 1: - unreplicate = lambda x: x[0] - state = fastmath.nested_map(unreplicate, state) - return state - - @model_state.setter - def model_state(self, state): - new_model_state = self._for_n_devices(state) - if isinstance(self._model_state, list): - self._model_state[0] = new_model_state - else: # weights are a tuple, need to re-create - self._model_state = [new_model_state] + list(self._model_state[1:]) - - @property - def state(self): - return TrainerState( - opt_state=self._opt_state, step=self._step, history=self._history, - model_state=self._model_state) + """Trax trainer. - @property - def learning_rate(self): - with fastmath.use_backend(fastmath.Backend.NUMPY): - return self._lr_schedule(self._step) - - def reset(self, output_dir, init_checkpoint=None): - """Reset the model parameters. 
- - Restores the parameters from the given output_dir if a checkpoint exists, - otherwise randomly initializes them. - - Does not re-jit the model. - - Args: - output_dir: Output directory. - init_checkpoint: Initial checkpoint (default $output_dir/model.pkl.gz) + A trainer allows to make training steps, train for full epochs, + save the training state and access evaluation data. """ - self.close() - self._output_dir = output_dir - if output_dir is not None: - tf.io.gfile.makedirs(output_dir) - else: - assert not self._should_save_checkpoints - assert not self._should_write_summaries - - # Create summary writers and history. - if self._should_write_summaries: - self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'), - enable=self._is_chief) - self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'), - enable=self._is_chief) - - # Reset the train and eval streams. - self._train_stream = _repeat_stream(self._inputs.train_stream, - self._n_devices) - # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval - # set by adding a padding and stopping the stream when too large. - self._eval_stream = _repeat_stream( - self._inputs.eval_stream, self._n_devices) - self._train_eval_stream = _repeat_stream( - self._inputs.train_eval_stream, self._n_devices) - - # Restore the training state. 
- if output_dir is not None: - state = load_trainer_state(output_dir, self._model_with_loss, - init_checkpoint) - else: - state = TrainerState(step=None, opt_state=None, - history=trax_history.History(), model_state=None) - self._step = state.step or 0 - history = state.history - self._history = history - if state.opt_state: - opt_state = state.opt_state - model_state = state.model_state - else: - opt_state, model_state = self._new_opt_state_and_model_state() - model_state = self._for_n_devices(model_state) - self._opt_state = OptState(*self._for_n_devices(opt_state)) - self._model_state = model_state - if not state.opt_state and self._should_save_checkpoints: - self.save_state(keep=False) - - def train_epoch(self, n_steps, n_eval_steps): - """Runs `n_steps` of training, with periodic logging, saving, and evals.""" - # TODO(jonni): Clarify how this method relates to the stricter notion of - # epoch (training for as many steps as needed for a full pass through the - # training data). - print() # Add visual separator in logs for start of training epoch. - start_time = time.time() - - for _ in range(n_steps): - batch = next(self._train_stream) - if self.n_devices > 1: # TODO(lukaszkaiser): use everywhere if possible. - batch = _reshape_by_device(batch, self.n_devices) - self.train_step(batch) - if self._should_save_now(): - self.save_state(keep=True) - if self._should_log_now(): - self._train_sw.scalar('training/learning_rate', self.learning_rate) - - # At end of n_steps, do bookkeeping, run evals, and save state. 
- elapsed_time = time.time() - start_time - self.log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time)) - if self._train_sw and n_steps > 1: - self._train_sw.scalar('training/steps per second', - n_steps / elapsed_time, step=self._step) - self._train_sw.flush() - self.evaluate(n_eval_steps) - if self._eval_sw: - self._eval_sw.flush() - if self._should_save_checkpoints: - self.save_state(keep=False) - if self._should_save_checkpoints and self._current_step_is_best(high=True): - self.save_state(keep=False, prefix='highest_' + self._checkpoint_highest) - if self._should_save_checkpoints and self._current_step_is_best(high=False): - self.save_state(keep=False, prefix='lowest_' + self._checkpoint_lowest) - - def train_step(self, batch): - """Run one training step and update self._opt_state.""" - # Calculate the current optimizer parameters. - opt_param_updates = self._for_n_devices( - {'learning_rate': np.array(self.learning_rate)}) - opt_state = self._opt_state - opt_state.opt_params.update(opt_param_updates) - - # Run the update. - weights, slots, opt_params = opt_state - (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn( - (weights, slots), self._step, opt_params, batch, - self._model_state, self._rngs) - self._opt_state = opt_state._replace(weights=weights, slots=slots) - if self._should_log_now(): - for name, value in stat.items(): - # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here - # with a device put array error complaining that it should be an array. - # On multiple devices, take the mean. - scalar_value = np.mean(np.array(value)) - self._train_sw.scalar('training/' + name, scalar_value, step=self._step) - self._step += 1 - - def evaluate(self, n_eval_steps): - """Evaluate the model and log metrics.""" - _, rng = jax_random.split(self._rngs[0]) - # TODO(lukaszkaiser): both model state and parameters by default include - # the loss layer. 
Currently, we access the pure-model parameters by just - # indexing, [0] here. But we should make it more explicit in a better API. - weights = (self._opt_state.weights[0], self._metrics_weights) - state = (self._model_state[0], self._metrics_state) - self.log_step('Evaluation') - train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps) - train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state, - rng) - self.log_metrics(train_metrics, self._train_sw, 'train') - eval_slice = itertools.islice(self._eval_stream, n_eval_steps) - eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng) - self.log_metrics(eval_metrics, self._eval_sw, 'eval') - self.log_step('Finished evaluation') - - # Save the learning rate in history. - self._history.append('train', 'training/learning_rate', - self._step, self.learning_rate) - - def evaluation_round(self, inputs_stream, weights, state, rng): - """Evaluate. + + def __init__( + self, + model, + loss_fn, + optimizer, + lr_schedule, + inputs, + output_dir=None, + random_seed=None, + n_devices=None, + checkpoints_at=None, + should_save_checkpoints=True, + should_write_summaries=True, + metrics=None, + checkpoint_highest=None, + checkpoint_lowest=None, + init_checkpoint=None, + ): + self._is_chief, _, self._n_devices, rng = training.init_host_and_devices( + n_devices, random_seed + ) + self._should_save_checkpoints = should_save_checkpoints and self._is_chief + self._checkpoints_at = checkpoints_at if checkpoints_at is not None else [] + self._should_write_summaries = should_write_summaries + if not output_dir: + self._should_save_checkpoints = False + self._should_write_summaries = False + self._checkpoint_highest = checkpoint_highest + self._checkpoint_lowest = checkpoint_lowest + self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS + # Inputs is either an Inputs instance or a function that returns it. 
+ self._inputs = inputs + if callable(inputs): # If we pass a function, e.g., through gin, call it. + self._inputs = inputs() + # Initialize the learning rate to a dummy value. It will be set in reset(). + opt = optimizer(learning_rate=0.0) + + # Setup the model. + model_train = model(mode="train") + model_predict_eval = model(mode="eval") + # Should work for fine-tuning of T5. + if init_checkpoint: + model_train.init_from_file(init_checkpoint, weights_only=True) + model_predict_eval.init_from_file(init_checkpoint, weights_only=True) + self._model_with_loss = tl.Serial(model_train, loss_fn) + + # Setup state. + rng, init_rng = jax_random.split(rng) + self._rngs = np.stack(jax_random.split(rng, self._n_devices)) + shapes, dtypes = self._inputs.example_shape_dtype + input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes)) + + def new_opt_state_and_model_state(rng): + """Returns optimizer and model states suitable for training a model.""" + weights, state = self._model_with_loss.init(input_signature, rng=rng) + (slots, opt_params) = opt.tree_init(weights) + return (OptState(weights, slots, opt_params), state) + + if fastmath.is_backend(fastmath.Backend.JAX): + # JIT parameter initialization to avoid memory fragmentation + new_opt_state_and_model_state = fastmath.jit(new_opt_state_and_model_state) + self._new_opt_state_and_model_state = lambda: new_opt_state_and_model_state( + init_rng + ) + + # Arrange and initialize metrics layers. 
+ self._metrics = list(sorted(self._metrics_dict.keys())) + metrics_layers = [self._metrics_dict[m] for m in self._metrics] + metrics_in_parallel = tl.Branch(*metrics_layers) + metrics_in_parallel.rng = init_rng + example_signature = tuple( + ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype) + ) + model_predict_eval.init(example_signature) + self._input_signature = example_signature + output_signature = model_predict_eval.output_signature(example_signature) + m_weights, m_state = metrics_in_parallel.init(output_signature) + self._metrics_weights = self._for_n_devices(m_weights) + self._metrics_state = self._for_n_devices(m_state) + + # Jit model_predict and update so they're fast. + self._jit_eval = _jit_predict_fn( + model_predict_eval, metrics_in_parallel, self._n_devices + ) + self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, self._n_devices) + + self._model_train = model_train + self._model_predict_eval = model_predict_eval + self._loss_fn = loss_fn + self._lr_schedule = lr_schedule + + # Those fields will be set in reset(). + self._output_dir = None + self._train_sw = None + self._eval_sw = None + self._history = None + self._opt_state = None + self._step = None + self._model_state = None + self.reset(output_dir) + + @property + def n_devices(self): + return self._n_devices + + @property + def step(self): + return self._step + + @property + def model_weights(self): + # Currently we need to pick [0] as we ignore loss weights (empty). 
+ weights = self._opt_state.weights[0] + if self.n_devices > 1: + unreplicate = lambda x: x[0] + weights = fastmath.nested_map(unreplicate, weights) + return weights + + @model_weights.setter + def model_weights(self, weights): + new_model_weights = self._for_n_devices(weights) + if isinstance(self._opt_state.weights, list): + self._opt_state.weights[0] = new_model_weights + else: # weights are a tuple, need to re-create + new_weights = [new_model_weights] + list(self._opt_state.weights[1:]) + self._opt_state = self._opt_state._replace(weights=new_weights) + + @property + def model_state(self): + # Currently we need to pick [0] as we ignore loss state (empty). + state = self._model_state[0] + if self.n_devices > 1: + unreplicate = lambda x: x[0] + state = fastmath.nested_map(unreplicate, state) + return state + + @model_state.setter + def model_state(self, state): + new_model_state = self._for_n_devices(state) + if isinstance(self._model_state, list): + self._model_state[0] = new_model_state + else: # weights are a tuple, need to re-create + self._model_state = [new_model_state] + list(self._model_state[1:]) + + @property + def state(self): + return TrainerState( + opt_state=self._opt_state, + step=self._step, + history=self._history, + model_state=self._model_state, + ) + + @property + def learning_rate(self): + with fastmath.use_backend(fastmath.Backend.NUMPY): + return self._lr_schedule(self._step) + + def reset(self, output_dir, init_checkpoint=None): + """Reset the model parameters. + + Restores the parameters from the given output_dir if a checkpoint exists, + otherwise randomly initializes them. + + Does not re-jit the model. + + Args: + output_dir: Output directory. 
+ init_checkpoint: Initial checkpoint (default $output_dir/model.pkl.gz) + """ + self.close() + self._output_dir = output_dir + if output_dir is not None: + tf.io.gfile.makedirs(output_dir) + else: + assert not self._should_save_checkpoints + assert not self._should_write_summaries + + # Create summary writers and history. + if self._should_write_summaries: + self._train_sw = jaxboard.SummaryWriter( + os.path.join(output_dir, "train"), enable=self._is_chief + ) + self._eval_sw = jaxboard.SummaryWriter( + os.path.join(output_dir, "eval"), enable=self._is_chief + ) + + # Reset the train and eval streams. + self._train_stream = _repeat_stream(self._inputs.train_stream, self._n_devices) + # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval + # set by adding a padding and stopping the stream when too large. + self._eval_stream = _repeat_stream(self._inputs.eval_stream, self._n_devices) + self._train_eval_stream = _repeat_stream( + self._inputs.train_eval_stream, self._n_devices + ) + + # Restore the training state. 
+ if output_dir is not None: + state = load_trainer_state( + output_dir, self._model_with_loss, init_checkpoint + ) + else: + state = TrainerState( + step=None, + opt_state=None, + history=trax_history.History(), + model_state=None, + ) + self._step = state.step or 0 + history = state.history + self._history = history + if state.opt_state: + opt_state = state.opt_state + model_state = state.model_state + else: + opt_state, model_state = self._new_opt_state_and_model_state() + model_state = self._for_n_devices(model_state) + self._opt_state = OptState(*self._for_n_devices(opt_state)) + self._model_state = model_state + if not state.opt_state and self._should_save_checkpoints: + self.save_state(keep=False) + + def train_epoch(self, n_steps, n_eval_steps): + """Runs `n_steps` of training, with periodic logging, saving, and evals.""" + # TODO(jonni): Clarify how this method relates to the stricter notion of + # epoch (training for as many steps as needed for a full pass through the + # training data). + print() # Add visual separator in logs for start of training epoch. + start_time = time.time() + + for _ in range(n_steps): + batch = next(self._train_stream) + if self.n_devices > 1: # TODO(lukaszkaiser): use everywhere if possible. + batch = _reshape_by_device(batch, self.n_devices) + self.train_step(batch) + if self._should_save_now(): + self.save_state(keep=True) + if self._should_log_now(): + self._train_sw.scalar("training/learning_rate", self.learning_rate) + + # At end of n_steps, do bookkeeping, run evals, and save state. 
+ elapsed_time = time.time() - start_time + self.log_step("Ran %d train steps in %0.2f secs" % (n_steps, elapsed_time)) + if self._train_sw and n_steps > 1: + self._train_sw.scalar( + "training/steps per second", n_steps / elapsed_time, step=self._step + ) + self._train_sw.flush() + self.evaluate(n_eval_steps) + if self._eval_sw: + self._eval_sw.flush() + if self._should_save_checkpoints: + self.save_state(keep=False) + if self._should_save_checkpoints and self._current_step_is_best(high=True): + self.save_state(keep=False, prefix="highest_" + self._checkpoint_highest) + if self._should_save_checkpoints and self._current_step_is_best(high=False): + self.save_state(keep=False, prefix="lowest_" + self._checkpoint_lowest) + + def train_step(self, batch): + """Run one training step and update self._opt_state.""" + # Calculate the current optimizer parameters. + opt_param_updates = self._for_n_devices( + {"learning_rate": np.array(self.learning_rate)} + ) + opt_state = self._opt_state + opt_state.opt_params.update(opt_param_updates) + + # Run the update. + weights, slots, opt_params = opt_state + (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn( + (weights, slots), + self._step, + opt_params, + batch, + self._model_state, + self._rngs, + ) + self._opt_state = opt_state._replace(weights=weights, slots=slots) + if self._should_log_now(): + for name, value in stat.items(): + # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here + # with a device put array error complaining that it should be an array. + # On multiple devices, take the mean. + scalar_value = np.mean(np.array(value)) + self._train_sw.scalar("training/" + name, scalar_value, step=self._step) + self._step += 1 + + def evaluate(self, n_eval_steps): + """Evaluate the model and log metrics.""" + _, rng = jax_random.split(self._rngs[0]) + # TODO(lukaszkaiser): both model state and parameters by default include + # the loss layer. 
Currently, we access the pure-model parameters by just + # indexing, [0] here. But we should make it more explicit in a better API. + weights = (self._opt_state.weights[0], self._metrics_weights) + state = (self._model_state[0], self._metrics_state) + self.log_step("Evaluation") + train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps) + train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state, rng) + self.log_metrics(train_metrics, self._train_sw, "train") + eval_slice = itertools.islice(self._eval_stream, n_eval_steps) + eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng) + self.log_metrics(eval_metrics, self._eval_sw, "eval") + self.log_step("Finished evaluation") + + # Save the learning rate in history. + self._history.append( + "train", "training/learning_rate", self._step, self.learning_rate + ) + + def evaluation_round(self, inputs_stream, weights, state, rng): + """Evaluate. + + Args: + inputs_stream: Iterable of inputs to evaluate on. + weights: Weights for each f in eval_fns. + state: State for each f in eval_fns. + rng: Single-use random number generator (JAX PRNG key). + + Returns: + Tuple of `(metrics, state)`. `metrics` is a dict from metric name to + metric value averaged over the number of inputs, and `state` is the end + state returned by this trainer's `predict_fn`. 
+ """ + metrics = collections.defaultdict(float) + count = 0 + for inp in inputs_stream: + count += 1 + rng, subrng = jax_random.split(rng) + metric_values, _ = self._jit_eval(inp, weights, state, subrng) + try: + metric_values = list(metric_values) + except (TypeError, IndexError): + metric_values = [float(metric_values)] + for m, v in zip(self._metrics, metric_values): + metrics[m] += v + return {m: v / count for (m, v) in metrics.items()}, state + + def save_gin(self): + """ "Saves the operative gin config, only if it is the chief.""" + if not self._is_chief: + return + assert self._output_dir is not None + config_path = os.path.join(self._output_dir, "config.gin") + config_str = gin.operative_config_str() + with tf.io.gfile.GFile(config_path, "w") as f: + f.write(config_str) + sw = self._train_sw + if sw: + sw.text("gin_config", jaxboard.markdownify_operative_config_str(config_str)) + + def _save_state_dict(self, trainer_state_dict, weights_file): + training.pickle_to_file(trainer_state_dict, weights_file, gzip=True) + log("Model saved to %s" % weights_file, stdout=False) + + def save_state(self, keep, prefix="model"): + """Save trainer state given a possibly replicated opt_state.""" + opt_state = self._opt_state + if self.n_devices > 1: + first_replica = lambda x: x[0] + opt_state = OptState(*fastmath.nested_map(first_replica, opt_state)) + # This line, while optional, allows JAX to transfer arrays from the device + # to the host in parallel, which is particularly important for cloud TPU. + if fastmath.is_backend(fastmath.Backend.JAX): + opt_state = jax.device_get(opt_state) + step, history, model_state = self._step, self._history, self._model_state + output_dir = self._output_dir + + weights_file = os.path.join(output_dir, prefix + ".pkl.gz") + + # This dict will be stored as the model. 
+ trainer_state_dict = make_trainer_state_dict( + step, opt_state, history, model_state, self._input_signature + ) + self._save_state_dict(trainer_state_dict, weights_file) + + if keep: + weights_file = os.path.join(output_dir, "{}_{}.pkl.gz".format(prefix, step)) + self._save_state_dict(trainer_state_dict, weights_file) + + def save_computation_graphs(self): + """Dump computation graphs to files.""" + if self.n_devices != 1: + return # TODO(lukaszkaiser): make this work with more devices. + batch = next(self._train_stream) + output_dir = self._output_dir + if self.n_devices > 1: + batch = _reshape_by_device(batch, self.n_devices) + weights = self._opt_state.weights[0] + forward_computation = ( + jax.jit(self._model_predict_eval) + .lower( + batch, weights=weights, state=self._model_state[0], rng=self._rngs[0] + ) + .compiler_ir(dialect="hlo") + ) + with tf.io.gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f: + f.write(forward_computation.as_hlo_text()) + with tf.io.gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f: + f.write(forward_computation.as_hlo_dot_graph()) + + def log_step(self, step_message): + log("Step % 6d: %s" % (self.step, step_message)) + + def log_metrics(self, metrics, summ_writer, log_prefix): + """Log metrics to summary writer and history.""" + history = self._history + rjust_len = max([0] + [len(name) for name in metrics]) + for name, value in metrics.items(): + self.log_step( + "%s %s | % .8f" % (log_prefix.ljust(5), name.rjust(rjust_len), value) + ) + full_name = "metrics/" + name + if history: + history.append(log_prefix, full_name, self.step, value) + if summ_writer: + summ_writer.scalar(full_name, value, self.step) + + def print_n_weights(self): + """Prints the total count of trainable weights.""" + opt_state = self._opt_state + sizes = _sizes(opt_state.weights) + if self.n_devices > 1: + unreplicate = lambda x: x[0] + single_weights = fastmath.nested_map(unreplicate, opt_state.weights) + sizes = 
_sizes(single_weights) + total_size = _nested_reduce(sum, sizes) + self.log_step("Total number of trainable weights: %d" % total_size) + + def _should_save_now(self): + return self._should_save_checkpoints and self._step in self._checkpoints_at + + def _current_step_is_best(self, high): + """Is the current step the best (highest if high, else lowest).""" + metric = self._checkpoint_highest if high else self._checkpoint_lowest + if metric is None: + return False + # History is a list of pairs (step, value). + history = self._history.get("eval", "metrics/" + metric) + sequence = [float(i[1]) for i in history] # Just the values. + best = max(sequence) if high else min(sequence) # Best value. + last_is_best = float(history[-1][1]) == best # Is last the best? + cur_step = history[-1][0] == self._step # Is last the current step? + return cur_step and last_is_best + + def _should_log_now(self): + return self._train_sw is not None and (self._step == 1 or self._step % 10 == 0) + + def _for_n_devices(self, x): + """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.""" + return tl.for_n_devices(x, self.n_devices) # pylint: disable=protected-access + + def close(self): + if self._train_sw is not None: + self._train_sw.close() + self._train_sw = None + if self._eval_sw is not None: + self._eval_sw.close() + self._eval_sw = None + + +@gin.configurable(denylist=["output_dir"]) +def train( + output_dir, + model=gin.REQUIRED, + loss_fn=tl.WeightedCategoryCrossEntropy(), + inputs=trax_inputs.batcher, + optimizer=trax_opt.Adafactor, + lr_schedule_fn=lr.multifactor, + trainer_class=Trainer, + steps=1000, + checkpoints_at=None, + permanent_checkpoints_at=None, + eval_steps=10, + eval_frequency=100, + permanent_checkpoint_frequency=None, + random_seed=None, + save_graphs=True, + metrics=None, + checkpoint_highest=None, + checkpoint_lowest=None, + use_loop=True, + loss_chunk_size=0, + use_memory_efficient_trainer=False, + adasum=False, + init_checkpoint=None, + 
callbacks=None, + n_weights_shards=1, + additional_train_tasks=None, + additional_eval_tasks=None, + additional_eval_streams=None, +): + """Train the model on the inputs. Args: - inputs_stream: Iterable of inputs to evaluate on. - weights: Weights for each f in eval_fns. - state: State for each f in eval_fns. - rng: Single-use random number generator (JAX PRNG key). + output_dir: Directory where to put the logs and checkpoints. + model: The model to train as a callable returning 2 callables, an init_fn + and apply_fn. + loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state, + rng -> loss. + inputs: callable returning trax.inputs.Inputs. + optimizer: The optimizer (see optimizers/base.py for signature). + lr_schedule_fn: A learning rate schedule function, that when called returns + a function from step to learning rate (a float). + trainer_class: The trainer class to use. + steps: int, total number of training steps. + checkpoints_at: list of integers. Save a checkpoint for each training step + in the list. + permanent_checkpoints_at: list of integers. Save a permanent checkpoint for + each training step in the list. + eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. + eval_frequency: int, how often to run evaluation (every eval_frequency + steps). If None or 0, eval disabled. + permanent_checkpoint_frequency: int, how often to save permanent checkpoints + (every permanent_checkpoint_frequency steps). + random_seed: the random seed to use; time/os dependent if None (default). + save_graphs: bool, if True, save computation graph to file. + metrics: optionally override the default metrics dictionary. + checkpoint_highest: save the checkpoint highest at this metric. + checkpoint_lowest: save the checkpoint lowest at this metric. + use_loop: whether to use training.Loop instead of Trainer. + loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory. 
+ use_memory_efficient_trainer: whether to use memory-efficient trainer. + adasum: if True, use adaptive summation for multi-device gradients. + init_checkpoint: a checkpoint for fine tuning. + callbacks: a list of callbacks to call during training. + n_weights_shards: shard weights into this many devices. + additional_train_tasks: additional tasks which should be performed during + training. + additional_eval_tasks: additional tasks which should be performed during + evaluation. + additional_eval_streams: List[NamedStream], additional data streams that + should be used during evaluation. Can be provided independently of + additional_eval_tasks. Returns: - Tuple of `(metrics, state)`. `metrics` is a dict from metric name to - metric value averaged over the number of inputs, and `state` is the end - state returned by this trainer's `predict_fn`. + trax.TrainerState or training.Loop if use_loop is True """ - metrics = collections.defaultdict(float) - count = 0 - for inp in inputs_stream: - count += 1 - rng, subrng = jax_random.split(rng) - metric_values, _ = self._jit_eval(inp, weights, state, subrng) - try: - metric_values = list(metric_values) - except (TypeError, IndexError): - metric_values = [float(metric_values)] - for m, v in zip(self._metrics, metric_values): - metrics[m] += v - return {m: v / count for (m, v) in metrics.items()}, state - - def save_gin(self): - """"Saves the operative gin config, only if it is the chief.""" - if not self._is_chief: - return - assert self._output_dir is not None - config_path = os.path.join(self._output_dir, 'config.gin') - config_str = gin.operative_config_str() - with tf.io.gfile.GFile(config_path, 'w') as f: - f.write(config_str) - sw = self._train_sw - if sw: - sw.text('gin_config', - jaxboard.markdownify_operative_config_str(config_str)) - - def _save_state_dict(self, trainer_state_dict, weights_file): - training.pickle_to_file(trainer_state_dict, weights_file, gzip=True) - log('Model saved to %s' % weights_file, 
stdout=False) - - def save_state(self, keep, prefix='model'): - """Save trainer state given a possibly replicated opt_state.""" - opt_state = self._opt_state - if self.n_devices > 1: - first_replica = lambda x: x[0] - opt_state = OptState(*fastmath.nested_map(first_replica, opt_state)) - # This line, while optional, allows JAX to transfer arrays from the device - # to the host in parallel, which is particularly important for cloud TPU. - if fastmath.is_backend(fastmath.Backend.JAX): - opt_state = jax.device_get(opt_state) - step, history, model_state = self._step, self._history, self._model_state - output_dir = self._output_dir - - weights_file = os.path.join(output_dir, prefix + '.pkl.gz') - - # This dict will be stored as the model. - trainer_state_dict = make_trainer_state_dict( - step, opt_state, history, model_state, self._input_signature) - self._save_state_dict(trainer_state_dict, weights_file) - - if keep: - weights_file = os.path.join(output_dir, - '{}_{}.pkl.gz'.format(prefix, step)) - self._save_state_dict(trainer_state_dict, weights_file) - - def save_computation_graphs(self): - """Dump computation graphs to files.""" - if self.n_devices != 1: - return # TODO(lukaszkaiser): make this work with more devices. 
- batch = next(self._train_stream) - output_dir = self._output_dir - if self.n_devices > 1: - batch = _reshape_by_device(batch, self.n_devices) - weights = self._opt_state.weights[0] - forward_computation = jax.jit(self._model_predict_eval).lower( - batch, weights=weights, state=self._model_state[0], - rng=self._rngs[0]).compiler_ir(dialect='hlo') - with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f: - f.write(forward_computation.as_hlo_text()) - with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f: - f.write(forward_computation.as_hlo_dot_graph()) - - def log_step(self, step_message): - log('Step % 6d: %s' % (self.step, step_message)) - - def log_metrics(self, metrics, summ_writer, log_prefix): - """Log metrics to summary writer and history.""" - history = self._history - rjust_len = max([0] + [len(name) for name in metrics]) - for name, value in metrics.items(): - self.log_step('%s %s | % .8f' % ( - log_prefix.ljust(5), name.rjust(rjust_len), value)) - full_name = 'metrics/' + name - if history: - history.append(log_prefix, full_name, self.step, value) - if summ_writer: - summ_writer.scalar(full_name, value, self.step) - - def print_n_weights(self): - """Prints the total count of trainable weights.""" - opt_state = self._opt_state - sizes = _sizes(opt_state.weights) - if self.n_devices > 1: - unreplicate = lambda x: x[0] - single_weights = fastmath.nested_map(unreplicate, opt_state.weights) - sizes = _sizes(single_weights) - total_size = _nested_reduce(sum, sizes) - self.log_step('Total number of trainable weights: %d' % total_size) - - def _should_save_now(self): - return self._should_save_checkpoints and self._step in self._checkpoints_at - - def _current_step_is_best(self, high): - """Is the current step the best (highest if high, else lowest).""" - metric = self._checkpoint_highest if high else self._checkpoint_lowest - if metric is None: - return False - # History is a list of pairs (step, value). 
- history = self._history.get('eval', 'metrics/' + metric) - sequence = [float(i[1]) for i in history] # Just the values. - best = max(sequence) if high else min(sequence) # Best value. - last_is_best = float(history[-1][1]) == best # Is last the best? - cur_step = history[-1][0] == self._step # Is last the current step? - return cur_step and last_is_best - - def _should_log_now(self): - return (self._train_sw is not None - and (self._step == 1 or self._step % 10 == 0)) - - def _for_n_devices(self, x): - """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.""" - return tl.for_n_devices(x, self.n_devices) # pylint: disable=protected-access - - def close(self): - if self._train_sw is not None: - self._train_sw.close() - self._train_sw = None - if self._eval_sw is not None: - self._eval_sw.close() - self._eval_sw = None - - -@gin.configurable(denylist=['output_dir']) -def train(output_dir, - model=gin.REQUIRED, - loss_fn=tl.WeightedCategoryCrossEntropy(), - inputs=trax_inputs.batcher, - optimizer=trax_opt.Adafactor, - lr_schedule_fn=lr.multifactor, - trainer_class=Trainer, - steps=1000, - checkpoints_at=None, - permanent_checkpoints_at=None, - eval_steps=10, - eval_frequency=100, - permanent_checkpoint_frequency=None, - random_seed=None, - save_graphs=True, - metrics=None, - checkpoint_highest=None, - checkpoint_lowest=None, - use_loop=True, - loss_chunk_size=0, - use_memory_efficient_trainer=False, - adasum=False, - init_checkpoint=None, - callbacks=None, - n_weights_shards=1, - additional_train_tasks=None, - additional_eval_tasks=None, - additional_eval_streams=None): - """Train the model on the inputs. - - Args: - output_dir: Directory where to put the logs and checkpoints. - model: The model to train as a callable returning 2 callables, an init_fn - and apply_fn. - loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state, - rng -> loss. - inputs: callable returning trax.inputs.Inputs. 
- optimizer: The optimizer (see optimizers/base.py for signature). - lr_schedule_fn: A learning rate schedule function, that when called returns - a function from step to learning rate (a float). - trainer_class: The trainer class to use. - steps: int, total number of training steps. - checkpoints_at: list of integers. Save a checkpoint for each training step - in the list. - permanent_checkpoints_at: list of integers. Save a permanent checkpoint for - each training step in the list. - eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. - eval_frequency: int, how often to run evaluation (every eval_frequency - steps). If None or 0, eval disabled. - permanent_checkpoint_frequency: int, how often to save permanent checkpoints - (every permanent_checkpoint_frequency steps). - random_seed: the random seed to use; time/os dependent if None (default). - save_graphs: bool, if True, save computation graph to file. - metrics: optionally override the default metrics dictionary. - checkpoint_highest: save the checkpoint highest at this metric. - checkpoint_lowest: save the checkpoint lowest at this metric. - use_loop: whether to use training.Loop instead of Trainer. - loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory. - use_memory_efficient_trainer: whether to use memory-efficient trainer. - adasum: if True, use adaptive summation for multi-device gradients. - init_checkpoint: a checkpoint for fine tuning. - callbacks: a list of callbacks to call during training. - n_weights_shards: shard weights into this many devices. - additional_train_tasks: additional tasks which should be performed during - training. - additional_eval_tasks: additional tasks which should be performed during - evaluation. - additional_eval_streams: List[NamedStream], additional data streams that - should be used during evaluation. Can be provided independently of - additional_eval_tasks. 
- - Returns: - trax.TrainerState or training.Loop if use_loop is True - """ - base.N_WEIGHTS_SHARDS = n_weights_shards - if (permanent_checkpoint_frequency is not None - and permanent_checkpoints_at is not None): - raise ValueError('Only one of ["permanent_checkpoint_frequency", ' - '"permanent_checkpoints_at"] should be set.') - if use_loop: - n_devices = num_devices() or fastmath.local_device_count() - - # Prepare the training task. - # Inputs is either an Inputs instance or a function that returns it. - if callable(inputs): # If we pass a function, e.g., through gin, call it. - inputs = inputs() - opt = optimizer if use_memory_efficient_trainer else optimizer() - train_task = training.TrainTask( - inputs.train_stream(n_devices), - loss_layer=loss_fn, - optimizer=opt, - lr_schedule=lr_schedule_fn(), - n_steps_per_checkpoint=eval_frequency, - n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency) - - if additional_train_tasks is None: - additional_train_tasks = [] - - # Prepare the evaluation. - metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS - names, metrics = zip(*metrics_dict.items()) - eval_task = training.EvalTask(inputs.eval_stream(n_devices), - metrics, - metric_names=names, - n_eval_batches=eval_steps) - - if additional_eval_tasks is None: - additional_eval_tasks = [] - - additional_eval_tasks_from_streams = [] - if additional_eval_streams is not None: - for stream in additional_eval_streams: - additional_eval_tasks_from_streams.append( - training.EvalTask(stream.stream, - metrics, - metric_names=names, - n_eval_batches=eval_steps, - export_prefix=stream.name)) - - # Prepare the training loop. - checkpoint_at = None - if checkpoints_at is not None: - checkpoint_at = lambda step: step in checkpoints_at - permanent_checkpoint_at = None - if permanent_checkpoints_at is not None: - permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at) - - # Setup the model. 
- model_train = model(mode='train') - model_predict_eval = model(mode='eval') - if init_checkpoint: - model_train.init_from_file(init_checkpoint, weights_only=True) - model_predict_eval.init_from_file(init_checkpoint, weights_only=True) - loop = training.Loop( - model_train, [train_task] + additional_train_tasks, - eval_model=model_predict_eval, - eval_tasks=[eval_task] + - additional_eval_tasks + additional_eval_tasks_from_streams, - output_dir=output_dir, - checkpoint_at=checkpoint_at, - checkpoint_low_metric=checkpoint_lowest, - checkpoint_high_metric=checkpoint_highest, - permanent_checkpoint_at=permanent_checkpoint_at, - n_devices=n_devices, - loss_chunk_size=loss_chunk_size, - use_memory_efficient_trainer=use_memory_efficient_trainer, - adasum=adasum, + base.N_WEIGHTS_SHARDS = n_weights_shards + if ( + permanent_checkpoint_frequency is not None + and permanent_checkpoints_at is not None + ): + raise ValueError( + 'Only one of ["permanent_checkpoint_frequency", ' + '"permanent_checkpoints_at"] should be set.' + ) + if use_loop: + n_devices = num_devices() or fastmath.local_device_count() + + # Prepare the training task. + # Inputs is either an Inputs instance or a function that returns it. + if callable(inputs): # If we pass a function, e.g., through gin, call it. + inputs = inputs() + opt = optimizer if use_memory_efficient_trainer else optimizer() + train_task = training.TrainTask( + inputs.train_stream(n_devices), + loss_layer=loss_fn, + optimizer=opt, + lr_schedule=lr_schedule_fn(), + n_steps_per_checkpoint=eval_frequency, + n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency, + ) + + if additional_train_tasks is None: + additional_train_tasks = [] + + # Prepare the evaluation. 
+ metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS + names, metrics = zip(*metrics_dict.items()) + eval_task = training.EvalTask( + inputs.eval_stream(n_devices), + metrics, + metric_names=names, + n_eval_batches=eval_steps, + ) + + if additional_eval_tasks is None: + additional_eval_tasks = [] + + additional_eval_tasks_from_streams = [] + if additional_eval_streams is not None: + for stream in additional_eval_streams: + additional_eval_tasks_from_streams.append( + training.EvalTask( + stream.stream, + metrics, + metric_names=names, + n_eval_batches=eval_steps, + export_prefix=stream.name, + ) + ) + + # Prepare the training loop. + checkpoint_at = None + if checkpoints_at is not None: + checkpoint_at = lambda step: step in checkpoints_at + permanent_checkpoint_at = None + if permanent_checkpoints_at is not None: + permanent_checkpoint_at = lambda step: step in permanent_checkpoints_at + + # Setup the model. + model_train = model(mode="train") + model_predict_eval = model(mode="eval") + if init_checkpoint: + model_train.init_from_file(init_checkpoint, weights_only=True) + model_predict_eval.init_from_file(init_checkpoint, weights_only=True) + loop = training.Loop( + model_train, + [train_task] + additional_train_tasks, + eval_model=model_predict_eval, + eval_tasks=[eval_task] + + additional_eval_tasks + + additional_eval_tasks_from_streams, + output_dir=output_dir, + checkpoint_at=checkpoint_at, + checkpoint_low_metric=checkpoint_lowest, + checkpoint_high_metric=checkpoint_highest, + permanent_checkpoint_at=permanent_checkpoint_at, + n_devices=n_devices, + loss_chunk_size=loss_chunk_size, + use_memory_efficient_trainer=use_memory_efficient_trainer, + adasum=adasum, + random_seed=random_seed, + callbacks=callbacks, + ) + + steps_to_go = steps - loop.step + if steps_to_go <= 0: + log("Stop training, already reached the total training steps %d" % steps) + return loop + + # Train and return the loop. 
+ loop.run(steps_to_go) + return loop + + n_devices = num_devices() + trainer = trainer_class( + model, + loss_fn, + optimizer, + lr_schedule_fn(), + inputs, + output_dir, random_seed=random_seed, - callbacks=callbacks, + n_devices=n_devices, + checkpoints_at=checkpoints_at, + metrics=metrics, + checkpoint_lowest=checkpoint_lowest, + checkpoint_highest=checkpoint_highest, + init_checkpoint=init_checkpoint, ) - steps_to_go = steps - loop.step - if steps_to_go <= 0: - log('Stop training, already reached the total training steps %d' % steps) - return loop - - # Train and return the loop. - loop.run(steps_to_go) - return loop - - n_devices = num_devices() - trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs, - output_dir, - random_seed=random_seed, - n_devices=n_devices, - checkpoints_at=checkpoints_at, - metrics=metrics, - checkpoint_lowest=checkpoint_lowest, - checkpoint_highest=checkpoint_highest, - init_checkpoint=init_checkpoint) - - epoch_steps = [steps] # Only training if eval_frequency is 0 or None - if eval_frequency and eval_steps > 0: - epoch_steps = itertools.chain([1, # first epoch only 1 step - eval_frequency - 1], - itertools.repeat(eval_frequency)) - trainer.log_step('Starting training using %d devices' % trainer.n_devices) - trainer.print_n_weights() - - try: - for epoch_steps in epochs(steps, trainer.step, epoch_steps): - trainer.train_epoch(epoch_steps, eval_steps) - - # Bookkeeping we do at the first step - if trainer.step == 1: - # Save computation graph (single-device only for now) - if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)): - trainer.save_computation_graphs() - - # Save Gin config - trainer.save_gin() - - trainer.log_step('Training done') - except Exception as e: - raise e - finally: - trainer.close() - return trainer.state + epoch_steps = [steps] # Only training if eval_frequency is 0 or None + if eval_frequency and eval_steps > 0: + epoch_steps = itertools.chain( + [1, eval_frequency - 1], # first 
epoch only 1 step + itertools.repeat(eval_frequency), + ) + trainer.log_step("Starting training using %d devices" % trainer.n_devices) + trainer.print_n_weights() + + try: + for epoch_steps in epochs(steps, trainer.step, epoch_steps): + trainer.train_epoch(epoch_steps, eval_steps) + + # Bookkeeping we do at the first step + if trainer.step == 1: + # Save computation graph (single-device only for now) + if save_graphs and fastmath.is_backend(fastmath.Backend.JAX): + trainer.save_computation_graphs() + + # Save Gin config + trainer.save_gin() + + trainer.log_step("Training done") + except Exception as e: + raise e + finally: + trainer.close() + return trainer.state @gin.configurable def num_devices(value=None): - """Returns how many devices to use (if None, default, use all available).""" - return value + """Returns how many devices to use (if None, default, use all available).""" + return value @gin.configurable def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True): - """Returns a (JIT-compiled) function that computes updates for one step.""" - model_and_loss = tl.Serial(predict_fn, loss_fn) - # Gradients are always wrt. the first argument, so putting weights first. - def model_and_loss_call(weights, batch, state, rng): - res = model_and_loss(batch, weights=weights, state=state, rng=rng) - return res, model_and_loss.state - if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. 
- def single_update(weights_and_slots, i, opt_params, batch, state, rng): - weights, slots = weights_and_slots - rng, subrng = jax_random.split(rng[0]) - grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) - grads, state = grad_fn(weights, batch, state, rng) - new_weights, new_slots, stats = optimizer.tree_update( - i, grads, weights, slots, opt_params) - return (new_weights, new_slots), stats, state, [subrng] - if jit: - # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU - return fastmath.jit(single_update) - else: - return single_update - - # Else, for n_devices > 1: - @functools.partial(fastmath.pmap, axis_name='batch') # donate_argnums=(0,)) - def mapped_update(weights_and_slots, i, opt_params, batch, state, rng): - """This is a multi-device version of the update function above.""" - # We assume all tensors have the first dimension = n_devices. - weights, slots = weights_and_slots - rng, subrng = jax_random.split(rng) - grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) - grads, state = grad_fn(weights, batch, state, rng) - # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just - # the number of devices on this host machine, however psum goes over all - # devices of all hosts (ex: a TPU pod) and we need to be averaging over all - # of them. - # - # Collect all gradients. - grads = fastmath.psum(grads, 'batch') - n_devices_total = fastmath.psum(np.array(1.0), 'batch') - # Average across hosts. 
- grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads) - - new_weights, new_slots, stats = optimizer.tree_update( - i, grads, weights, slots, opt_params) - return (new_weights, new_slots), stats, state, subrng - - def update(weights_and_slots, i, opt_params, batch, state, rng): - return mapped_update(weights_and_slots, np.repeat(i, n_devices), - opt_params, batch, state, rng) - - return update + """Returns a (JIT-compiled) function that computes updates for one step.""" + model_and_loss = tl.Serial(predict_fn, loss_fn) + + # Gradients are always wrt. the first argument, so putting weights first. + def model_and_loss_call(weights, batch, state, rng): + res = model_and_loss(batch, weights=weights, state=state, rng=rng) + return res, model_and_loss.state + + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + + def single_update(weights_and_slots, i, opt_params, batch, state, rng): + weights, slots = weights_and_slots + rng, subrng = jax_random.split(rng[0]) + grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(weights, batch, state, rng) + new_weights, new_slots, stats = optimizer.tree_update( + i, grads, weights, slots, opt_params + ) + return (new_weights, new_slots), stats, state, [subrng] + + if jit: + # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU + return fastmath.jit(single_update) + else: + return single_update + + # Else, for n_devices > 1: + @functools.partial(fastmath.pmap, axis_name="batch") # donate_argnums=(0,)) + def mapped_update(weights_and_slots, i, opt_params, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. 
+ weights, slots = weights_and_slots + rng, subrng = jax_random.split(rng) + grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(weights, batch, state, rng) + # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just + # the number of devices on this host machine, however psum goes over all + # devices of all hosts (ex: a TPU pod) and we need to be averaging over all + # of them. + # + # Collect all gradients. + grads = fastmath.psum(grads, "batch") + n_devices_total = fastmath.psum(np.array(1.0), "batch") + # Average across hosts. + grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads) + + new_weights, new_slots, stats = optimizer.tree_update( + i, grads, weights, slots, opt_params + ) + return (new_weights, new_slots), stats, state, subrng + + def update(weights_and_slots, i, opt_params, batch, state, rng): + return mapped_update( + weights_and_slots, np.repeat(i, n_devices), opt_params, batch, state, rng + ) + + return update @gin.configurable def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True): - """Returns a JIT-compiled predict function (unless jit=False).""" - model = tl.Serial(model_predict, metric_fn) - if not jit: - return model.pure_fn + """Returns a JIT-compiled predict function (unless jit=False).""" + model = tl.Serial(model_predict, metric_fn) + if not jit: + return model.pure_fn - return tl.jit_forward(model.pure_fn, n_devices) + return tl.jit_forward(model.pure_fn, n_devices) @gin.configurable def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True): - """Returns a (JIT-compiled) function that computes the loss for one step.""" - if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. 
- def single_compute_loss(opt_state, batch, state, rng): - rng, subrng = jax_random.split(rng[0]) - loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) - return loss_val, state, [subrng] - return fastmath.jit(single_compute_loss) if jit else single_compute_loss - - # Else, for n_devices > 1: - @functools.partial(fastmath.pmap, axis_name='batch') - def mapped_compute_loss(opt_state, batch, state, rng): - """This is a multi-device version of the update function above.""" - # We assume all tensors have the first dimension = n_devices. - rng, subrng = jax_random.split(rng) - loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) - return loss_val, state, subrng - - def compute_loss(opt_state, batch, state, rng): - return mapped_compute_loss( - opt_state, _reshape_by_device(batch, n_devices), state, rng) - - return compute_loss + """Returns a (JIT-compiled) function that computes the loss for one step.""" + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + + def single_compute_loss(opt_state, batch, state, rng): + rng, subrng = jax_random.split(rng[0]) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, [subrng] + + return fastmath.jit(single_compute_loss) if jit else single_compute_loss + + # Else, for n_devices > 1: + @functools.partial(fastmath.pmap, axis_name="batch") + def mapped_compute_loss(opt_state, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. 
+ rng, subrng = jax_random.split(rng) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, subrng + + def compute_loss(opt_state, batch, state, rng): + return mapped_compute_loss( + opt_state, _reshape_by_device(batch, n_devices), state, rng + ) + + return compute_loss def log(s, stdout=True): - logging.info(s) - if stdout: - print(s) - sys.stdout.flush() + logging.info(s) + if stdout: + print(s) + sys.stdout.flush() def epochs(total_steps, steps_to_skip, epoch_steps): - """Generates the number of steps in each epoch before reaching total_steps. - - Args: - total_steps: int, total number of steps. - steps_to_skip: int, number of steps to skip because of a restart. - epoch_steps: iterable of int, numbers of steps in each epoch. - - Yields: - epoch_steps: int, number of steps in this epoch - """ - steps_to_go = total_steps - steps_to_skip - epoch_steps = iter(epoch_steps) - - # Remove the desired number of steps from the stream. - for steps_this_epoch in epoch_steps: - if steps_this_epoch > steps_to_skip: - # Put back the number of steps left in the unfinished epoch. - epoch_steps = itertools.chain( - [steps_this_epoch - steps_to_skip], epoch_steps) - if steps_this_epoch >= steps_to_skip: - break - steps_to_skip -= steps_this_epoch - - # Yield the remaining steps per epoch up to total_steps. - for steps_this_epoch in epoch_steps: - steps_this_epoch = min(steps_this_epoch, steps_to_go) - yield steps_this_epoch - steps_to_go -= steps_this_epoch - if steps_to_go == 0: - break - - -def make_trainer_state_dict(step, - opt_state, - history, - model_state, - input_signature): - """Creates a trainer state dictionary to save to disk. - - Args: - step: int, a step number - opt_state: OptState namedtuple - history: `trax.history.History`, the history object. - model_state: A nested structure of the model state. - input_signature: signature of model inputs. 
- - Returns: - A dictionary with the fields of TrainerState and OptState flattened. - """ - flat_weights, flat_state = tl.flatten_weights_and_state( - opt_state.weights, model_state) - return { - 'step': step, - 'flat_weights': flat_weights, - 'slots': opt_state.slots, - 'opt_params': opt_state.opt_params, - 'history': history, - 'flat_state': flat_state, - 'input_signature': input_signature, - 'version_timestamp': 'Jun-18-2020' # To update in the future if needed. - } + """Generates the number of steps in each epoch before reaching total_steps. + + Args: + total_steps: int, total number of steps. + steps_to_skip: int, number of steps to skip because of a restart. + epoch_steps: iterable of int, numbers of steps in each epoch. + + Yields: + epoch_steps: int, number of steps in this epoch + """ + steps_to_go = total_steps - steps_to_skip + epoch_steps = iter(epoch_steps) + + # Remove the desired number of steps from the stream. + for steps_this_epoch in epoch_steps: + if steps_this_epoch > steps_to_skip: + # Put back the number of steps left in the unfinished epoch. + epoch_steps = itertools.chain( + [steps_this_epoch - steps_to_skip], epoch_steps + ) + if steps_this_epoch >= steps_to_skip: + break + steps_to_skip -= steps_this_epoch + + # Yield the remaining steps per epoch up to total_steps. + for steps_this_epoch in epoch_steps: + steps_this_epoch = min(steps_this_epoch, steps_to_go) + yield steps_this_epoch + steps_to_go -= steps_this_epoch + if steps_to_go == 0: + break + + +def make_trainer_state_dict(step, opt_state, history, model_state, input_signature): + """Creates a trainer state dictionary to save to disk. + + Args: + step: int, a step number + opt_state: OptState namedtuple + history: `trax.history.History`, the history object. + model_state: A nested structure of the model state. + input_signature: signature of model inputs. + + Returns: + A dictionary with the fields of TrainerState and OptState flattened. 
+ """ + flat_weights, flat_state = tl.flatten_weights_and_state( + opt_state.weights, model_state + ) + return { + "step": step, + "flat_weights": flat_weights, + "slots": opt_state.slots, + "opt_params": opt_state.opt_params, + "history": history, + "flat_state": flat_state, + "input_signature": input_signature, + "version_timestamp": "Jun-18-2020", # To update in the future if needed. + } def trainer_state_from_dict(trainer_state_dict, model): - """Given the trainer state dictionary, returns `TrainerState`.""" - # TODO(afrozm): This becomes simpler if OptState is flattened into - # TrainerState. - step = trainer_state_dict['step'] - history = trainer_state_dict['history'] - input_signature = trainer_state_dict['input_signature'] - weights_and_state_sig = model.weights_and_state_signature(input_signature) - weights, model_state = tl.unflatten_weights_and_state( - trainer_state_dict['flat_weights'], trainer_state_dict['flat_state'], - weights_and_state_sig) - opt_state = OptState( - weights=weights, - slots=trainer_state_dict['slots'], - opt_params=trainer_state_dict['opt_params']) - return TrainerState(step=step, opt_state=OptState(*opt_state), - history=history, model_state=model_state) + """Given the trainer state dictionary, returns `TrainerState`.""" + # TODO(afrozm): This becomes simpler if OptState is flattened into + # TrainerState. 
+ step = trainer_state_dict["step"] + history = trainer_state_dict["history"] + input_signature = trainer_state_dict["input_signature"] + weights_and_state_sig = model.weights_and_state_signature(input_signature) + weights, model_state = tl.unflatten_weights_and_state( + trainer_state_dict["flat_weights"], + trainer_state_dict["flat_state"], + weights_and_state_sig, + ) + opt_state = OptState( + weights=weights, + slots=trainer_state_dict["slots"], + opt_params=trainer_state_dict["opt_params"], + ) + return TrainerState( + step=step, + opt_state=OptState(*opt_state), + history=history, + model_state=model_state, + ) def load_trainer_state(output_dir, model, weights_file=None): - """Returns a TrainerState instance loaded from the given `output_dir`.""" - if weights_file is None: - weights_file = os.path.join(output_dir, 'model.pkl.gz') - if not tf.io.gfile.exists(weights_file): - return TrainerState(step=None, opt_state=None, - history=trax_history.History(), model_state=None) - elif not tf.io.gfile.exists(weights_file): - raise ValueError('File not found: %s' % weights_file) - - trainer_state_dict = training.unpickle_from_file(weights_file, gzip=True) - trainer_state = trainer_state_from_dict(trainer_state_dict, model) - log('Model loaded from %s at step %d' % (weights_file, trainer_state.step)) - logging.debug('From loaded model : history = %s', trainer_state.history) - return trainer_state + """Returns a TrainerState instance loaded from the given `output_dir`.""" + if weights_file is None: + weights_file = os.path.join(output_dir, "model.pkl.gz") + if not tf.io.gfile.exists(weights_file): + return TrainerState( + step=None, + opt_state=None, + history=trax_history.History(), + model_state=None, + ) + elif not tf.io.gfile.exists(weights_file): + raise ValueError("File not found: %s" % weights_file) + + trainer_state_dict = training.unpickle_from_file(weights_file, gzip=True) + trainer_state = trainer_state_from_dict(trainer_state_dict, model) + log("Model loaded 
from %s at step %d" % (weights_file, trainer_state.step)) + logging.debug("From loaded model : history = %s", trainer_state.history) + return trainer_state def _reshape_by_device(x, n_devices): - """Reshapes possibly nested x into a shape (n_devices, ...).""" - return tl.reshape_by_device(x, n_devices) # pylint: disable=protected-access + """Reshapes possibly nested x into a shape (n_devices, ...).""" + return tl.reshape_by_device(x, n_devices) # pylint: disable=protected-access def _nested_reduce(f, x): - """Fold the function f to the nested structure x (dicts, tuples, lists).""" - if isinstance(x, list): - return f([_nested_reduce(f, y) for y in x]) - if isinstance(x, tuple): - return f([_nested_reduce(f, y) for y in x]) - if isinstance(x, dict): - return f([_nested_reduce(f, v) for (_, v) in x.items()]) - return x + """Fold the function f to the nested structure x (dicts, tuples, lists).""" + if isinstance(x, list): + return f([_nested_reduce(f, y) for y in x]) + if isinstance(x, tuple): + return f([_nested_reduce(f, y) for y in x]) + if isinstance(x, dict): + return f([_nested_reduce(f, v) for (_, v) in x.items()]) + return x def _sizes(x): - """Get a structure of sizes for a structure of nested arrays.""" - def size(x): - try: - return x.size - except Exception: # pylint: disable=broad-except - return 0 - return fastmath.nested_map(size, x) + """Get a structure of sizes for a structure of nested arrays.""" + + def size(x): + try: + return x.size + except Exception: # pylint: disable=broad-except + return 0 + + return fastmath.nested_map(size, x) def _repeat_stream(stream, n_devices): - """Repeat a stream indefinitely.""" - while True: - for example in stream(n_devices): - yield example + """Repeat a stream indefinitely.""" + while True: + for example in stream(n_devices): + yield example diff --git a/trax/supervised/trainer_lib_test.py b/trax/supervised/trainer_lib_test.py deleted file mode 100644 index 6464cdf2c..000000000 --- 
a/trax/supervised/trainer_lib_test.py +++ /dev/null @@ -1,555 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.supervised.trainer_lib.""" - -import functools -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -from jax.config import config -import tensorflow.compat.v2 as tf -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import optimizers as trax_opt -from trax import shapes as trax_shapes -from trax import test_utils -from trax.data import inputs as inputs_lib -from trax.fastmath import numpy as jnp -from trax.supervised import lr_schedules as lr -from trax.supervised import trainer_lib -from trax.tf_numpy import extensions as npe -from trax.tf_numpy import numpy as tf_np - - - -def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)): - """Make trainer_lib.inputs.Inputs.""" - batch_size = 2 * jax.device_count() - - def input_stream(n_devices): - del n_devices - key = fastmath.random.get_prng(0) - while True: - keys = fastmath.random.split(key, 4) - key = keys[0] - inputs = fastmath.random.uniform( - keys[1], [batch_size] + list(input_shape)) - targets = fastmath.random.randint( - keys[2], [batch_size], dtype=jnp.int32, minval=0, maxval=n_classes) - weights = fastmath.random.uniform(keys[3], [batch_size]) - if with_weights: - yield inputs, targets, weights - else: - yield inputs, 
targets - - def input_stream_masked(n_devices): - return inputs_lib.add_loss_weights(input_stream(n_devices)) - - return inputs_lib.Inputs(input_stream_masked) - - -def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2): - """Make trainer_lib.inputs.Inputs for language model.""" - batch_size = per_device_batch_size * jax.device_count() - - def input_stream(_): - def make_batch(key): - return fastmath.random.randint( - key, [batch_size, seq_len], dtype=jnp.int32, minval=0, - maxval=vocab_size) - key = fastmath.random.get_prng(0) - while True: - keys = fastmath.random.split(key, 3) - key = keys[0] - inputs = make_batch(keys[1]) - targets = make_batch(keys[2]) - yield inputs, targets - - def input_stream_masked(n_devices): - return inputs_lib.add_loss_weights(input_stream(n_devices)) - - return inputs_lib.Inputs(input_stream_masked) - - - -BACKENDS = [fastmath.Backend.JAX] -BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [('Simple', None)])] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -def opt_name(opt): - if opt is None: - return 'None' - return opt.__name__ - - -def _pure_lsh_self_attention_fn(n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.1, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - pure_lsh_implementation=tl.PureLSHSelfAttention, - ) - - -def _mixed_lsh_self_attention_fn(n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.1, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - 
pure_lsh_implementation=tl.MixedLSHSelfAttention, - ) - - -class TraxTest(parameterized.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super().__init__(methodName) - if npe.tpu_devices(): - # Initialize TPU for TF - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') - tf.tpu.experimental.initialize_tpu_system(resolver) - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - self._old_is_allow_float64 = tf_np.is_allow_float64() - tf_np.set_allow_float64(False) - - def tearDown(self): - tf_np.set_allow_float64(self._old_is_allow_float64) - super().tearDown() - - def _test_train_eval_predict(self, backend, model_name='Simple', - optimizer=None): - with fastmath.use_backend(backend): - # Prepare model and inputs - steps = 2 - eval_steps = 2 - - if model_name == 'Simple': - n_classes = 4 - # Adds Dropout and BatchNorm to test state handling. - def model_fn(mode='train'): - return tl.Serial( - tl.Dropout(mode=mode, rate=0.1), - tl.BatchNorm(mode=mode), - models.MLP(layer_widths=(16, 16, n_classes), mode=mode)) - inputs = _test_inputs(n_classes) - n_in = 1 - elif model_name == 'Resnet50': - n_classes = 4 - model_fn = models.Resnet50 - inputs = _test_inputs(n_classes, input_shape=(224, 224, 3)) - n_in = 1 - elif model_name == 'Transformer': - vocab_size = 32 - seq_len = 16 - inputs = _test_inputs_lm(vocab_size, seq_len) - model_fn = functools.partial( - models.Transformer, - input_vocab_size=vocab_size) - n_in = 2 - else: - raise ValueError('Unrecognized model name: ' + model_name) - - kwargs = {} - if optimizer is not None: - kwargs['optimizer'] = optimizer - - # Train and evaluate - output_dir = self.create_tempdir().full_path - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, # eval at every step. 
- **kwargs) - - # Assert total train steps - self.assertEqual(steps, loop.step) - - inputs = inputs.train_stream(1) - - # Predict with final weights - model = model_fn() - weights = loop.model.weights - state = loop.model.state - model(next(inputs)[:n_in], weights=weights, state=state) - - # Predict with weights loaded from file. - model = model_fn() - model.init_from_file(os.path.join(output_dir, 'model.pkl.gz')) - model(next(inputs)[:n_in]) - - @parameterized.named_parameters( - ('_%s_%s_%s' % (short_name(backend), model_name, opt_name(opt)), # pylint: disable=g-complex-comprehension - backend, model_name, opt) - for backend, configs in BACKENDS_AND_CONFIGS - for model_name, opt in configs) - def test_train_eval_predict(self, backend, model_name, opt): - self._test_train_eval_predict(backend, model_name, opt) - - @parameterized.parameters(BACKENDS) - def test_train_eval_predict_sm3(self, backend): - self._test_train_eval_predict(backend, 'Simple', trax_opt.SM3) - - @parameterized.parameters(BACKENDS) - def test_train_restart(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Restart training - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=(2 * steps), - eval_steps=eval_steps, - eval_frequency=1) - - # Assert total train steps - self.assertEqual(loop.step, 2 * steps) - - @parameterized.parameters(BACKENDS) - def test_train_permanent_checkpoints(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 5 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, 
n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - - # Steps 1 -> 5 - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, - permanent_checkpoint_frequency=2) - - # Steps 6 -> 10 - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=(2 * steps), - eval_steps=eval_steps, - eval_frequency=1, - permanent_checkpoints_at=[7, 8, 10]) - - path = os.path.join(output_dir, 'model.pkl.gz') - self.assertTrue(tf.io.gfile.exists(path)) - - for step in range(11): - filename = 'model_{}.pkl.gz'.format(step) - path = os.path.join(output_dir, filename) - if step in [1, 2, 4, 7, 8, 10]: - self.assertTrue(tf.io.gfile.exists(path), - msg='No model for step: {} in dir {}.'.format( - step, tf.io.gfile.listdir(output_dir))) - else: - self.assertFalse(tf.io.gfile.exists(path), - msg='Model for step: {} in dir {}.'.format( - step, tf.io.gfile.listdir(output_dir))) - - # Assert total train steps - self.assertEqual(loop.step, 10) - - @parameterized.parameters(BACKENDS) - def test_train_restart_with_same_steps(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Restart training - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Assert total train steps - self.assertEqual(loop.step, steps) - - def test_train_with_pure_lsh_attention(self, backend=fastmath.Backend.JAX): - with fastmath.use_backend(backend): - # Prepare model and inputs 
- def model(mode='train'): - return models.ConfigurableTerraformer( - mode=mode, - d_model=16, - d_ff=16, - n_heads=2, - dropout=0.05, - n_decoder_layers=1, - n_encoder_layers=1, - input_vocab_size=256, - encoder_attention_type=_pure_lsh_self_attention_fn(), - encoder_decoder_attention_type=_pure_lsh_self_attention_fn(), - ) - - max_len = 128 - inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) - - steps = 1 - eval_steps = 1 - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Read checkpoint - model_file = os.path.join(output_dir, 'model.pkl.gz') - - shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) - shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) - - model_predict = model(mode='predict') - model_predict.init_from_file( - model_file, weights_only=True, input_signature=(shape1l, shape11)) - - def test_train_with_mixed_lsh_attention(self, backend=fastmath.Backend.JAX): - with fastmath.use_backend(backend): - # Prepare model and inputs - - def model(mode='train'): - return models.ConfigurableTerraformer( - mode=mode, - d_model=16, - d_ff=16, - n_heads=2, - dropout=0.05, - n_decoder_layers=1, - n_encoder_layers=1, - input_vocab_size=256, - encoder_attention_type=_mixed_lsh_self_attention_fn(), - encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(), - ) - - max_len = 128 - inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) - - steps = 1 - eval_steps = 1 - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Read checkpoint - model_file = os.path.join(output_dir, 'model.pkl.gz') - - shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) - shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) - - model_predict = 
model(mode='predict') - model_predict.init_from_file(model_file, weights_only=True, - input_signature=(shape1l, shape11)) - - @parameterized.parameters(BACKENDS) - def test_train_fills_in_missing_eval_metrics(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - additional_eval_stream = trainer_lib.NamedStream( - # deliberately duplicating eval data - stream=inputs.eval_stream(1), - name='additional_eval_task') - - # Train and evaluate - output_dir = self.create_tempdir().full_path - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, - additional_eval_streams=[additional_eval_stream]) - - self.assertLen(loop.eval_tasks, 2) - eval_task_1, eval_task_2 = loop.eval_tasks - self.assertCountEqual(eval_task_1.metrics, eval_task_2.metrics) - self.assertCountEqual(eval_task_1.metric_names, eval_task_2.metric_names) - - @parameterized.named_parameters( - ('_%s' % short_name(backend), backend) - for backend in BACKENDS) - def test_train_with_weights(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes, with_weights=True) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - state = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps) - - # Assert total train steps - self.assertEqual(state.step, steps) - - @parameterized.parameters(BACKENDS) - def test_reset_twice(self, backend): - with fastmath.use_backend(backend): - n_classes = 4 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - 
trainer = trainer_lib.Trainer( - model=model_fn, - loss_fn=tl.WeightedCategoryCrossEntropy(), - optimizer=trax_opt.SM3, - lr_schedule=lr.multifactor(), - inputs=inputs, - ) - - output_dir1 = self.create_tempdir(name='output_dir1').full_path - trainer.reset(output_dir1) - trainer.evaluate(1) - output_dir2 = self.create_tempdir(name='output_dir2').full_path - trainer.reset(output_dir2) - trainer.evaluate(1) - - def test_tf_xla_forced_compile(self): - # TODO(wangpeng): re-enable this test - self.skipTest('Needs --config=cuda to pass this test') - old_flag = fastmath.tf.tf_xla_forced_compile_enabled() - fastmath.tf.set_tf_xla_forced_compile(True) - self._test_train_eval_predict('tf') - fastmath.tf.set_tf_xla_forced_compile(old_flag) - - - -class EpochsTest(absltest.TestCase): - - def test_cuts_epoch_when_total_steps_reached(self): - epoch_steps = trainer_lib.epochs( - total_steps=5, steps_to_skip=0, epoch_steps=[1, 2, 3]) - self.assertEqual(list(epoch_steps), [1, 2, 2]) - - def test_skips_full_epoch(self): - epoch_steps = trainer_lib.epochs( - total_steps=4, steps_to_skip=2, epoch_steps=[2, 2]) - self.assertEqual(list(epoch_steps), [2]) - - def test_skips_part_of_epoch(self): - epoch_steps = trainer_lib.epochs( - total_steps=4, steps_to_skip=1, epoch_steps=[2, 2]) - self.assertEqual(list(epoch_steps), [1, 2]) - - -if __name__ == '__main__': - config.config_with_absl() - tf.compat.v1.enable_eager_execution() - absltest.main() diff --git a/trax/supervised/training.py b/trax/supervised/training.py index e65709ae1..366f7de90 100644 --- a/trax/supervised/training.py +++ b/trax/supervised/training.py @@ -62,1326 +62,1421 @@ from trax.supervised import history as trax_history -_Evaluator = collections.namedtuple( - '_Evaluator', ['weights', 'state', 'metrics_fn'] -) +_Evaluator = collections.namedtuple("_Evaluator", ["weights", "state", "metrics_fn"]) class Loop: - """Loop that can run for a given number of steps to train a supervised model. 
- - Can train the model on multiple tasks by interleaving updates according to the - ``which_task`` argument. - - The typical supervised training process randomly initializes a model and - updates its weights via feedback (loss-derived gradients) from a training - task, by looping through batches of labeled data. A training loop can also - be configured to run periodic evals and save intermediate checkpoints. - - For speed, the implementation takes advantage of JAX's composable function - transformations (specifically, ``jit`` and ``grad``). It creates JIT-compiled - pure functions derived from variants of the core model; schematically: - - - training variant: `jit(grad(pure_function(model+loss)))` - - evals variant: `jit(pure_function(model+evals))` - - In training or during evals, these variants are called with explicit - arguments for all relevant input data, model weights/state, optimizer slots, - and random number seeds: - - - batch: labeled data - - model weights/state: trainable weights and input-related state (e.g., as - used by batch norm) - - optimizer slots: weights in the optimizer that evolve during the training - process - - random number seeds: JAX PRNG keys that enable high-quality, distributed, - repeatable generation of pseudo-random numbers - """ - - def __init__( - self, - model, - tasks, - eval_model=None, - eval_tasks=None, - output_dir=None, - checkpoint_at=None, - checkpoint_low_metric=None, - checkpoint_high_metric=None, - permanent_checkpoint_at=None, - eval_at=None, - which_task=None, - n_devices=None, - random_seed=None, - loss_chunk_size=0, - use_memory_efficient_trainer=False, - adasum=False, - callbacks=None, - ): - """Configures a training ``Loop``, including a random initialization. - - Args: - model: Trax layer, representing the core model to be trained. Loss - functions and eval functions (a.k.a. metrics) are considered to be - outside the core model, taking core model output and data labels as - their two inputs. 
- tasks: List of :py:class:`TrainTask` instances, which define the training - data, loss function, and optimizer to be used in respective tasks in - this training loop. It can also be a single :py:class:`TrainTask` - instance which is treated in the same way as a singleton list. - eval_model: Optional Trax layer, representing model used for evaluation, - e.g., with dropout turned off. If ``None``, the training model (model) - will be used. - eval_tasks: List of :py:class:`EvalTask` instances which define how to - evaluate the model: which validation data to use and which metrics to - report. Evaluation on each of the tasks and will run and be reported - separately which allows to score a model on different subtasks. This - argument can also be ``None``, in which case no evals will be run, or - a single :py:class:`EvalTask`, which wil be treated in the same way - as a singleton list. - output_dir: Path telling where to save outputs (evals and checkpoints). - Can be ``None`` if both ``eval_task`` and ``checkpoint_at`` are - ``None``. - checkpoint_at: Function (integer --> boolean) telling, for step n, whether - that step should have its checkpoint saved. If ``None``, the default - is periodic checkpointing at ``task.n_steps_per_checkpoint``. - checkpoint_low_metric: Name of metric, or None. The metric name must - be one of the metric names from the evals in ``eval_tasks``. At - checkpoint times determined by ``checkpoint_at``, a separate - specially named checkpoint will be saved (overwriting any previous - version) if the designated metric reaches a value less than or equal - to any previous recorded low value. No such checkpoint is saved if - arg value is `None`. - checkpoint_high_metric: Name of metric, or None. The metric name must - be one of the metric names from the evals in ``eval_tasks``. 
At - checkpoint times determined by ``checkpoint_at``, a separate - specially named checkpoint will be saved (overwriting any previous - version) if the designated metric reaches a value greater than or - equal to any previous recorded high value. No such checkpoint is - saved if arg value is `None`. - permanent_checkpoint_at: Function (integer --> boolean) telling, - for step n, whether that step should have its checkpoint saved - permanently. If ``None``, the default is periodic checkpointing at - ``task.n_steps_per_permanent_checkpoint``. - eval_at: Function (integer --> boolean) that says, for training step n, - whether that step should run evals. If ``None``, run evals on the - first step and on every N'th step, as determined by the first - training task. - which_task: Function (integer --> integer) indicating which task should be - used at which training step. Can be set to ``None`` in single-task - training. - n_devices: integer or ``None``, the number of devices for this - computation. - random_seed: the random seed to use; time/os dependent if ``None`` - (default). - loss_chunk_size: int, if > 0 use chunks of this size to make loss - computation more more memory-efficient. - use_memory_efficient_trainer: whether to use a special memory-efficient - trainer; if set to 2, the memory efficiency if very aggressive - adasum: if True, use adaptive summation for multi-device gradients - callbacks: List of subclasses of StepCallback to call on training - steps. + """Loop that can run for a given number of steps to train a supervised model. + + Can train the model on multiple tasks by interleaving updates according to the + ``which_task`` argument. + + The typical supervised training process randomly initializes a model and + updates its weights via feedback (loss-derived gradients) from a training + task, by looping through batches of labeled data. A training loop can also + be configured to run periodic evals and save intermediate checkpoints. 
+ + For speed, the implementation takes advantage of JAX's composable function + transformations (specifically, ``jit`` and ``grad``). It creates JIT-compiled + pure functions derived from variants of the core model; schematically: + + - training variant: `jit(grad(pure_function(model+loss)))` + - evals variant: `jit(pure_function(model+evals))` + + In training or during evals, these variants are called with explicit + arguments for all relevant input data, model weights/state, optimizer slots, + and random number seeds: + + - batch: labeled data + - model weights/state: trainable weights and input-related state (e.g., as + used by batch norm) + - optimizer slots: weights in the optimizer that evolve during the training + process + - random number seeds: JAX PRNG keys that enable high-quality, distributed, + repeatable generation of pseudo-random numbers """ - self._is_chief, self._n_hosts, self._n_devices, self._rng = ( - init_host_and_devices(n_devices, random_seed)) - if use_memory_efficient_trainer: - self._rng = tl.on_cpu(self._rng) - - # Handle single task case without lists too. - if not isinstance(tasks, (list, tuple)): - tasks = [tasks] - - if not tasks: - raise ValueError('Must provide at least one training task.') - if eval_tasks is None: - eval_tasks = [] - eval_at = _never - else: - if not isinstance(eval_tasks, (list, tuple)): - eval_tasks = [eval_tasks] - - self._tasks = tasks - self._model = model - self._eval_model = eval_model or model - - self._use_memory_efficient_trainer = use_memory_efficient_trainer - self._loss_chunk_size = loss_chunk_size - self._adasum = adasum - # TODO(lukaszkaiser): can we have different eval models and save memory? 
- if use_memory_efficient_trainer: - assert len(tasks) == 1, 'only single task supported for now' - self._eval_model = model - - default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint) - permanent_default_at = _at_step_1_and_every_nth_step( - tasks[0].n_steps_per_permanent_checkpoint) - if output_dir is not None: - self._output_dir = os.path.expanduser(output_dir) - tf.io.gfile.makedirs(self._output_dir) - inputs.load_data_counters(self._output_dir) - else: - self._output_dir = None - - # Prepare training components. - self._step = 0 - self._history = trax_history.History() - self._checkpoint_at = checkpoint_at or default_at - self._checkpoint_low_metric = checkpoint_low_metric - self._checkpoint_high_metric = checkpoint_high_metric - self._permanent_checkpoint_at = ( - permanent_checkpoint_at or permanent_default_at) - if which_task is None: - # If which task is not passed, then we permute tasks one by one. - # If len(tasks) = 1, then which_task is a constant function equal to 0. - which_task = lambda n: n % len(tasks) - self._which_task = which_task - - # Initialize using the given random seed. - # NOTE: If random_seed is None then self._rng will be different on - # different hosts, leading to different weights on the different hosts. - self._batch_signature = shapes.signature(tasks[0].sample_batch) - self._model.rng = self.new_rng() - # In the memory-efficient case, we initialize in init_trainer. - if not use_memory_efficient_trainer: - if _is_uninitialized(self._model): - self._model.init(self._batch_signature) - self._eval_model.rng = self.new_rng() - if _is_uninitialized(self._eval_model): - self._eval_model.init(self._batch_signature) - - # To handle the above case (i.e. random_seed = None), we psum the weights - # and state and average them. - # NOTE: This adds time (how much?) so we prefer not to do it if it is - # unnecessary, i.e. random_seed was set. 
- # NOTE: Averaging the weights across devices can screw up the initial weight - # statistics. - # TODO(pkozakowski): Broadcast from one of the devices instead? - if (random_seed is None and self._n_hosts > 1 and - not use_memory_efficient_trainer): - logging.info('Syncing weights/state across %d hosts.', self._n_hosts) - # Do self._sync_weights_and_state_across_hosts() but layer-by-layer - # to save memory. - blocks, last_layer = optimizers.trainer.extract_reversible_blocks( - [self._model]) - all_layers = [] - for (std_layer, rev_layers) in blocks: - all_layers.append(tl.Serial(std_layer)) - all_layers.extend(rev_layers) - all_layers.append(last_layer) - for layer in all_layers: - weights_and_state = (layer.weights, layer.state) - if not _is_empty(weights_and_state): - layer.weights, layer.state = tl.on_cpu(self._unreplicate( - _make_weights_and_state_same_across_hosts( - self._for_n_devices(weights_and_state)))) - - # Create the optimizer for the training loss function. - self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks) - - # Sync layers weights/state in memory effcient trainer layers. - if (random_seed is None and self._n_hosts > 1 and - use_memory_efficient_trainer): - logging.info('Syncing layers across %d hosts.', self._n_hosts) - for layer in self._trainer_per_task[0].all_layers: - weights_and_state = (layer.weights, layer.state) - if not _is_empty(weights_and_state): - layer.weights, layer.state = tl.on_cpu(self._unreplicate( - _make_weights_and_state_same_across_hosts( - self._for_n_devices(weights_and_state)))) - - # Load checkpoint if it exists. - self.load_checkpoint() - - # Prepare eval components. 
- self._eval_at = eval_at or default_at - self._eval_tasks = eval_tasks - loss_names = [task.loss_name for task in self._tasks] - metric_names = [ - name # pylint: disable=g-complex-comprehension - for eval_task in self._eval_tasks - for name in eval_task.metric_names - ] - self._rjust_len = max(map(len, loss_names + metric_names)) - self._evaluator_per_task = tuple( - self._init_evaluator(eval_task) for eval_task in self._eval_tasks) - - if self._output_dir is None: - _log('Will not write evaluation metrics, because output_dir is None.') - - def task_output_dir(task_index, task_list): - if self._output_dir is not None: - if len(task_list) < 2: - output_dir = self._output_dir - else: - output_dir = os.path.join( - self._output_dir, - task_list[task_index].export_prefix or str(task_index)) - tf.io.gfile.makedirs(output_dir) - return output_dir - else: - return None - self._output_dir_per_eval_task = [ - task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))] - self._output_dir_per_train_task = [ - task_output_dir(i, tasks) for i in range(len(tasks))] - - callbacks = callbacks or [] - self._callbacks = [ - callback_class(self) for callback_class in callbacks - ] - - def _init_trainer(self, task): - """Initializes the per-task trainer.""" - # Build the per-task model, sharing weights with other tasks. - if not self._use_memory_efficient_trainer: - model_in_training = _model_with_ends( - self._model, - [task.loss_layer], - shapes.signature(task.sample_batch) - ) - if base.N_WEIGHTS_SHARDS > 1: - sharded_weights = fastmath.nested_map( - lambda x: x[0], tl.shard(model_in_training.weights)) - task.optimizer.tree_init(sharded_weights) - else: - task.optimizer.tree_init(model_in_training.weights) - return optimizers.Trainer( - model_in_training, task.optimizer, adasum=self._adasum) - # In the memory-efficient path, we initialize the model here. 
- blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( - [self._model, task.loss_layer], loss_chunk_size=self._loss_chunk_size) - rng = self._model.rng - sig = shapes.signature(task.sample_batch) - optimizers.trainer.init_reversible_blocks(blocks, loss_layer, sig, rng) - # TODO(lukaszkaiser): here optimizer is a function, revisit this. - return optimizers.ReversibleSerialTrainer( - blocks, loss_layer, task.optimizer, - free_accelerators_on_step=(self._use_memory_efficient_trainer == 2), - adasum=self._adasum) - - def _init_evaluator(self, eval_task): - """Initializes the per-task evaluator.""" - model_with_metrics = _model_with_metrics( - self._eval_model, eval_task) - if self._use_memory_efficient_trainer: - return _Evaluator( - weights=tl.on_cpu(model_with_metrics.weights[1]), - state=tl.on_cpu(model_with_metrics.state[1]), - metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0) - ) - else: - return _Evaluator( - # Replicate the eval part of weights and state. 
- weights=self._for_n_devices(model_with_metrics.weights[1]), - state=self._for_n_devices(model_with_metrics.state[1]), - metrics_fn=_accelerate_model_with_metrics( - model_with_metrics, self.n_devices) - ) - - def _sync_weights_and_state_across_hosts(self): - """Sync weights and state across all the hosts in the computation.""" - - if logging.vlog_is_on(1): - logging.debug( - 'Input training weights shape: %s', - fastmath.nested_map(lambda x: x.shape, - self._model.weights)) - logging.debug('Input training weights: %s', self._model.weights) - logging.debug('Input training state: %s', self._model.state) - logging.debug('Input eval weights: %s', self._eval_model.weights) - logging.debug('Input eval state: %s', self._eval_model.state) - - (self._model.weights, self._model.state, - self._eval_model.weights, self._eval_model.state) = self._unreplicate( - _make_weights_and_state_same_across_hosts( - self._for_n_devices( - (self._model.weights, self._model.state, - self._eval_model.weights, - self._eval_model.state)))) - - if logging.vlog_is_on(1): - logging.debug( - 'Output training weights shape: %s', - fastmath.nested_map(lambda x: x.shape, self._model.weights)) - logging.debug('Output training weights: %s', self._model.weights) - logging.debug('Output training state: %s', self._model.state) - logging.debug('Output eval weights: %s', self._eval_model.weights) - logging.debug('Output eval state: %s', self._eval_model.state) - - def run(self, n_steps=1): - """Runs this training loop for n steps. - - Optionally runs evals and saves checkpoints at specified points. - Args: - n_steps: Stop training after completing n steps. 
- """ - with self._open_summary_writers() as ( - train_summary_writers, eval_summary_writers): - process = psutil.Process(os.getpid()) - loss_acc, step_acc = 0.0, 0 - start_time = time.time() - optimizer_metrics_acc = collections.defaultdict(float) - for i in range(n_steps): - prev_task_index = self._which_task(self._step) - self._step += 1 - task_index = self._which_task(self._step) - task_changed = task_index != prev_task_index + def __init__( + self, + model, + tasks, + eval_model=None, + eval_tasks=None, + output_dir=None, + checkpoint_at=None, + checkpoint_low_metric=None, + checkpoint_high_metric=None, + permanent_checkpoint_at=None, + eval_at=None, + which_task=None, + n_devices=None, + random_seed=None, + loss_chunk_size=0, + use_memory_efficient_trainer=False, + adasum=False, + callbacks=None, + ): + """Configures a training ``Loop``, including a random initialization. + + Args: + model: Trax layer, representing the core model to be trained. Loss + functions and eval functions (a.k.a. metrics) are considered to be + outside the core model, taking core model output and data labels as + their two inputs. + tasks: List of :py:class:`TrainTask` instances, which define the training + data, loss function, and optimizer to be used in respective tasks in + this training loop. It can also be a single :py:class:`TrainTask` + instance which is treated in the same way as a singleton list. + eval_model: Optional Trax layer, representing model used for evaluation, + e.g., with dropout turned off. If ``None``, the training model (model) + will be used. + eval_tasks: List of :py:class:`EvalTask` instances which define how to + evaluate the model: which validation data to use and which metrics to + report. Evaluation on each of the tasks and will run and be reported + separately which allows to score a model on different subtasks. 
This + argument can also be ``None``, in which case no evals will be run, or + a single :py:class:`EvalTask`, which wil be treated in the same way + as a singleton list. + output_dir: Path telling where to save outputs (evals and checkpoints). + Can be ``None`` if both ``eval_task`` and ``checkpoint_at`` are + ``None``. + checkpoint_at: Function (integer --> boolean) telling, for step n, whether + that step should have its checkpoint saved. If ``None``, the default + is periodic checkpointing at ``task.n_steps_per_checkpoint``. + checkpoint_low_metric: Name of metric, or None. The metric name must + be one of the metric names from the evals in ``eval_tasks``. At + checkpoint times determined by ``checkpoint_at``, a separate + specially named checkpoint will be saved (overwriting any previous + version) if the designated metric reaches a value less than or equal + to any previous recorded low value. No such checkpoint is saved if + arg value is `None`. + checkpoint_high_metric: Name of metric, or None. The metric name must + be one of the metric names from the evals in ``eval_tasks``. At + checkpoint times determined by ``checkpoint_at``, a separate + specially named checkpoint will be saved (overwriting any previous + version) if the designated metric reaches a value greater than or + equal to any previous recorded high value. No such checkpoint is + saved if arg value is `None`. + permanent_checkpoint_at: Function (integer --> boolean) telling, + for step n, whether that step should have its checkpoint saved + permanently. If ``None``, the default is periodic checkpointing at + ``task.n_steps_per_permanent_checkpoint``. + eval_at: Function (integer --> boolean) that says, for training step n, + whether that step should run evals. If ``None``, run evals on the + first step and on every N'th step, as determined by the first + training task. + which_task: Function (integer --> integer) indicating which task should be + used at which training step. 
Can be set to ``None`` in single-task + training. + n_devices: integer or ``None``, the number of devices for this + computation. + random_seed: the random seed to use; time/os dependent if ``None`` + (default). + loss_chunk_size: int, if > 0 use chunks of this size to make loss + computation more more memory-efficient. + use_memory_efficient_trainer: whether to use a special memory-efficient + trainer; if set to 2, the memory efficiency if very aggressive + adasum: if True, use adaptive summation for multi-device gradients + callbacks: List of subclasses of StepCallback to call on training + steps. + """ + ( + self._is_chief, + self._n_hosts, + self._n_devices, + self._rng, + ) = init_host_and_devices(n_devices, random_seed) + if use_memory_efficient_trainer: + self._rng = tl.on_cpu(self._rng) + + # Handle single task case without lists too. + if not isinstance(tasks, (list, tuple)): + tasks = [tasks] + + if not tasks: + raise ValueError("Must provide at least one training task.") + if eval_tasks is None: + eval_tasks = [] + eval_at = _never + else: + if not isinstance(eval_tasks, (list, tuple)): + eval_tasks = [eval_tasks] + + self._tasks = tasks + self._model = model + self._eval_model = eval_model or model + + self._use_memory_efficient_trainer = use_memory_efficient_trainer + self._loss_chunk_size = loss_chunk_size + self._adasum = adasum + # TODO(lukaszkaiser): can we have different eval models and save memory? + if use_memory_efficient_trainer: + assert len(tasks) == 1, "only single task supported for now" + self._eval_model = model + + default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint) + permanent_default_at = _at_step_1_and_every_nth_step( + tasks[0].n_steps_per_permanent_checkpoint + ) + if output_dir is not None: + self._output_dir = os.path.expanduser(output_dir) + tf.io.gfile.makedirs(self._output_dir) + inputs.load_data_counters(self._output_dir) + else: + self._output_dir = None + + # Prepare training components. 
+ self._step = 0 + self._history = trax_history.History() + self._checkpoint_at = checkpoint_at or default_at + self._checkpoint_low_metric = checkpoint_low_metric + self._checkpoint_high_metric = checkpoint_high_metric + self._permanent_checkpoint_at = permanent_checkpoint_at or permanent_default_at + if which_task is None: + # If which task is not passed, then we permute tasks one by one. + # If len(tasks) = 1, then which_task is a constant function equal to 0. + which_task = lambda n: n % len(tasks) + self._which_task = which_task + + # Initialize using the given random seed. + # NOTE: If random_seed is None then self._rng will be different on + # different hosts, leading to different weights on the different hosts. + self._batch_signature = shapes.signature(tasks[0].sample_batch) + self._model.rng = self.new_rng() + # In the memory-efficient case, we initialize in init_trainer. + if not use_memory_efficient_trainer: + if _is_uninitialized(self._model): + self._model.init(self._batch_signature) + self._eval_model.rng = self.new_rng() + if _is_uninitialized(self._eval_model): + self._eval_model.init(self._batch_signature) + + # To handle the above case (i.e. random_seed = None), we psum the weights + # and state and average them. + # NOTE: This adds time (how much?) so we prefer not to do it if it is + # unnecessary, i.e. random_seed was set. + # NOTE: Averaging the weights across devices can screw up the initial weight + # statistics. + # TODO(pkozakowski): Broadcast from one of the devices instead? + if ( + random_seed is None + and self._n_hosts > 1 + and not use_memory_efficient_trainer + ): + logging.info("Syncing weights/state across %d hosts.", self._n_hosts) + # Do self._sync_weights_and_state_across_hosts() but layer-by-layer + # to save memory. 
+ blocks, last_layer = optimizers.trainer.extract_reversible_blocks( + [self._model] + ) + all_layers = [] + for (std_layer, rev_layers) in blocks: + all_layers.append(tl.Serial(std_layer)) + all_layers.extend(rev_layers) + all_layers.append(last_layer) + for layer in all_layers: + weights_and_state = (layer.weights, layer.state) + if not _is_empty(weights_and_state): + layer.weights, layer.state = tl.on_cpu( + self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices(weights_and_state) + ) + ) + ) + + # Create the optimizer for the training loss function. + self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks) + + # Sync layers weights/state in memory effcient trainer layers. + if random_seed is None and self._n_hosts > 1 and use_memory_efficient_trainer: + logging.info("Syncing layers across %d hosts.", self._n_hosts) + for layer in self._trainer_per_task[0].all_layers: + weights_and_state = (layer.weights, layer.state) + if not _is_empty(weights_and_state): + layer.weights, layer.state = tl.on_cpu( + self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices(weights_and_state) + ) + ) + ) + + # Load checkpoint if it exists. + self.load_checkpoint() + + # Prepare eval components. 
+ self._eval_at = eval_at or default_at + self._eval_tasks = eval_tasks + loss_names = [task.loss_name for task in self._tasks] + metric_names = [ + name # pylint: disable=g-complex-comprehension + for eval_task in self._eval_tasks + for name in eval_task.metric_names + ] + self._rjust_len = max(map(len, loss_names + metric_names)) + self._evaluator_per_task = tuple( + self._init_evaluator(eval_task) for eval_task in self._eval_tasks + ) - if task_changed: - loss_acc, step_acc = 0.0, 0 + if self._output_dir is None: + _log("Will not write evaluation metrics, because output_dir is None.") + + def task_output_dir(task_index, task_list): + if self._output_dir is not None: + if len(task_list) < 2: + output_dir = self._output_dir + else: + output_dir = os.path.join( + self._output_dir, + task_list[task_index].export_prefix or str(task_index), + ) + tf.io.gfile.makedirs(output_dir) + return output_dir + else: + return None + + self._output_dir_per_eval_task = [ + task_output_dir(i, eval_tasks) for i in range(len(eval_tasks)) + ] + self._output_dir_per_train_task = [ + task_output_dir(i, tasks) for i in range(len(tasks)) + ] + + callbacks = callbacks or [] + self._callbacks = [callback_class(self) for callback_class in callbacks] + + def _init_trainer(self, task): + """Initializes the per-task trainer.""" + # Build the per-task model, sharing weights with other tasks. + if not self._use_memory_efficient_trainer: + model_in_training = _model_with_ends( + self._model, [task.loss_layer], shapes.signature(task.sample_batch) + ) + if base.N_WEIGHTS_SHARDS > 1: + sharded_weights = fastmath.nested_map( + lambda x: x[0], tl.shard(model_in_training.weights) + ) + task.optimizer.tree_init(sharded_weights) + else: + task.optimizer.tree_init(model_in_training.weights) + return optimizers.Trainer( + model_in_training, task.optimizer, adasum=self._adasum + ) + # In the memory-efficient path, we initialize the model here. 
+ blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( + [self._model, task.loss_layer], loss_chunk_size=self._loss_chunk_size + ) + rng = self._model.rng + sig = shapes.signature(task.sample_batch) + optimizers.trainer.init_reversible_blocks(blocks, loss_layer, sig, rng) + # TODO(lukaszkaiser): here optimizer is a function, revisit this. + return optimizers.ReversibleSerialTrainer( + blocks, + loss_layer, + task.optimizer, + free_accelerators_on_step=(self._use_memory_efficient_trainer == 2), + adasum=self._adasum, + ) - loss, optimizer_metrics = self._run_one_step(task_index, task_changed) + def _init_evaluator(self, eval_task): + """Initializes the per-task evaluator.""" + model_with_metrics = _model_with_metrics(self._eval_model, eval_task) + if self._use_memory_efficient_trainer: + return _Evaluator( + weights=tl.on_cpu(model_with_metrics.weights[1]), + state=tl.on_cpu(model_with_metrics.state[1]), + metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0), + ) + else: + return _Evaluator( + # Replicate the eval part of weights and state. 
+ weights=self._for_n_devices(model_with_metrics.weights[1]), + state=self._for_n_devices(model_with_metrics.state[1]), + metrics_fn=_accelerate_model_with_metrics( + model_with_metrics, self.n_devices + ), + ) + + def _sync_weights_and_state_across_hosts(self): + """Sync weights and state across all the hosts in the computation.""" + + if logging.vlog_is_on(1): + logging.debug( + "Input training weights shape: %s", + fastmath.nested_map(lambda x: x.shape, self._model.weights), + ) + logging.debug("Input training weights: %s", self._model.weights) + logging.debug("Input training state: %s", self._model.state) + logging.debug("Input eval weights: %s", self._eval_model.weights) + logging.debug("Input eval state: %s", self._eval_model.state) + + ( + self._model.weights, + self._model.state, + self._eval_model.weights, + self._eval_model.state, + ) = self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices( + ( + self._model.weights, + self._model.state, + self._eval_model.weights, + self._eval_model.state, + ) + ) + ) + ) - # optimizer_metrics and loss are replicated on self.n_devices, a few - # metrics are replicated (ex: gradients_l2, weights_l2) - i.e. they are - # the same across devices, whereas some (ex: loss) aren't because they - # are different on different devices (due to different data). - # Taking the average does the correct thing in both the cases. - # - # NOTE: Only the weights and gradients are synced across the hosts. This - # implies the loss here is averaged from this hosts' devices and not - # across all hosts. - optimizer_metrics, loss = fastmath.nested_map( - functools.partial(tl.mean_or_pmean, self._n_devices), - (optimizer_metrics, loss)) - - loss_acc += loss - # Log loss every 50 steps, every step in memory-efficient trainer. 
- if self._step % 50 == 0 or self._use_memory_efficient_trainer: - self._log_step('Loss: %.4f' % loss, stdout=False) - step_acc += 1 - for metric_name, value in optimizer_metrics.items(): - optimizer_metrics_acc[metric_name] += value - - # TODO(yuwenyan): Finds a way to log the last round eval step in - # history. + if logging.vlog_is_on(1): + logging.debug( + "Output training weights shape: %s", + fastmath.nested_map(lambda x: x.shape, self._model.weights), + ) + logging.debug("Output training weights: %s", self._model.weights) + logging.debug("Output training state: %s", self._model.state) + logging.debug("Output eval weights: %s", self._eval_model.weights) + logging.debug("Output eval state: %s", self._eval_model.state) + + def run(self, n_steps=1): + """Runs this training loop for n steps. + + Optionally runs evals and saves checkpoints at specified points. + + Args: + n_steps: Stop training after completing n steps. + """ + with self._open_summary_writers() as ( + train_summary_writers, + eval_summary_writers, + ): + process = psutil.Process(os.getpid()) + loss_acc, step_acc = 0.0, 0 + start_time = time.time() + optimizer_metrics_acc = collections.defaultdict(float) + for i in range(n_steps): + prev_task_index = self._which_task(self._step) + self._step += 1 + task_index = self._which_task(self._step) + task_changed = task_index != prev_task_index + + if task_changed: + loss_acc, step_acc = 0.0, 0 + + loss, optimizer_metrics = self._run_one_step(task_index, task_changed) + + # optimizer_metrics and loss are replicated on self.n_devices, a few + # metrics are replicated (ex: gradients_l2, weights_l2) - i.e. they are + # the same across devices, whereas some (ex: loss) aren't because they + # are different on different devices (due to different data). + # Taking the average does the correct thing in both the cases. + # + # NOTE: Only the weights and gradients are synced across the hosts. 
This + # implies the loss here is averaged from this hosts' devices and not + # across all hosts. + optimizer_metrics, loss = fastmath.nested_map( + functools.partial(tl.mean_or_pmean, self._n_devices), + (optimizer_metrics, loss), + ) + + loss_acc += loss + # Log loss every 50 steps, every step in memory-efficient trainer. + if self._step % 50 == 0 or self._use_memory_efficient_trainer: + self._log_step("Loss: %.4f" % loss, stdout=False) + step_acc += 1 + for metric_name, value in optimizer_metrics.items(): + optimizer_metrics_acc[metric_name] += value + + # TODO(yuwenyan): Finds a way to log the last round eval step in + # history. + # + # Right now, the last round eval log is missing in history since the + # checkpoint is saved before it. However sometimes the eval step will + # fail for some reasons, and it's not acceptable to loose the whole + # checkpoint in this case. Stays with the old way for now, and fixes it + # when the checkpoint format is changed to storing weights separately + # from a small file with history and other data. + if self._checkpoint_at(self.step): + self.save_checkpoint("model") + if self._permanent_checkpoint_at(self.step): + self.save_checkpoint(f"model_{self.step}") + if self._eval_at(self.step): + logging.info( + "cpu memory use (MB): %.2f", + process.memory_info().rss / float(1024 * 1024), + ) + elapsed_time = time.time() - start_time + self._log_training_progress( + task=self._tasks[task_index], + total_loss=loss_acc, + n_steps=step_acc, + elapsed_time=elapsed_time, + optimizer_metrics=optimizer_metrics_acc, + summary_writer=train_summary_writers[task_index], + ) + self.run_evals(eval_summary_writers) + loss_acc, step_acc = 0.0, 0 + start_time = time.time() + optimizer_metrics_acc = collections.defaultdict(float) + + # For the current step, after all evals are run and recorded in the + # event history, check if we need to save special checkpoints because + # of a new low metric value or a new high metric value. 
+ if self._checkpoint_at(self.step): + if self._checkpoint_low_metric is not None and self._at_lowest(): + self.save_checkpoint(f"lowest_{self._checkpoint_low_metric}") + if self._checkpoint_high_metric is not None and self._at_highest(): + self.save_checkpoint(f"highest_{self._checkpoint_high_metric}") + + # Store the final values back into their respective objects, for testing + # or other inspection/use. # - # Right now, the last round eval log is missing in history since the - # checkpoint is saved before it. However sometimes the eval step will - # fail for some reasons, and it's not acceptable to loose the whole - # checkpoint in this case. Stays with the old way for now, and fixes it - # when the checkpoint format is changed to storing weights separately - # from a small file with history and other data. - if self._checkpoint_at(self.step): - self.save_checkpoint('model') - if self._permanent_checkpoint_at(self.step): - self.save_checkpoint(f'model_{self.step}') - if self._eval_at(self.step): - logging.info('cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - elapsed_time = time.time() - start_time - self._log_training_progress( - task=self._tasks[task_index], - total_loss=loss_acc, - n_steps=step_acc, - elapsed_time=elapsed_time, - optimizer_metrics=optimizer_metrics_acc, - summary_writer=train_summary_writers[task_index], - ) - self.run_evals(eval_summary_writers) - loss_acc, step_acc = 0.0, 0 - start_time = time.time() - optimizer_metrics_acc = collections.defaultdict(float) - - # For the current step, after all evals are run and recorded in the - # event history, check if we need to save special checkpoints because - # of a new low metric value or a new high metric value. 
- if self._checkpoint_at(self.step): - if self._checkpoint_low_metric is not None and self._at_lowest(): - self.save_checkpoint(f'lowest_{self._checkpoint_low_metric}') - if self._checkpoint_high_metric is not None and self._at_highest(): - self.save_checkpoint(f'highest_{self._checkpoint_high_metric}') - - # Store the final values back into their respective objects, for testing - # or other inspection/use. - # - # We keep the standard model weights/state unreplicated and - # tl.Accelerate(model) will carry the replicated weights/state. - # TODO(afrozm): Try to use tl.Accelerate(model) everywhere in the Loop. - self._eval_model.weights = self._model.weights - - def _at_lowest(self): - low_items = self.history.get('eval', - f'metrics/{self._checkpoint_low_metric}') - vals = [float(obj[1]) for obj in low_items] - return vals[-1] == min(vals) - - def _at_highest(self): - high_items = self.history.get('eval', - f'metrics/{self._checkpoint_high_metric}') - vals = [float(obj[1]) for obj in high_items] - return vals[-1] == max(vals) - - @property - def step(self): - """Returns current step number in this training session.""" - return self._step - - @property - def history(self): - """Returns history in this training session.""" - return self._history - - @property - def n_devices(self): - """Returns the number of devices to be used in this computation.""" - return self._n_devices - - @property - def is_chief(self): - """Returns true if this Loop is the chief.""" - return self._is_chief - - @property - def model(self): - """Returns the model that is training.""" - return self._model - - @property - def tasks(self): - """Returns the training tasks.""" - return self._tasks - - @property - def eval_model(self): - """Returns the model used for evaluation.""" - return self._eval_model - - @property - def eval_tasks(self): - """Returns the evaluation tasks.""" - return self._eval_tasks - - @property - def output_dir(self): - """Returns the output directory.""" - return 
self._output_dir - - def new_rng(self): - """Returns a new single-use random number generator (JAX PRNG key).""" - self._rng, rng = fastmath.random.split(self._rng) - if self._use_memory_efficient_trainer: - self._rng = tl.on_cpu(self._rng) - rng = tl.on_cpu(rng) - return rng - - def _for_n_devices(self, x): - """Replicates/broadcasts ``x`` for n devices if ``self.n_devicess > 1``.""" - return tl.for_n_devices(x, self.n_devices) - - def _unreplicate(self, x): - if self.n_devices == 1: - return x - - unreplicate_fn = lambda x: x[0] - return fastmath.nested_map(unreplicate_fn, x) - - def _reshape_by_device(self, x): - if self.n_devices == 1: - return x - return tl.reshape_by_device(x, self.n_devices) - - def update_weights_and_state(self, weights=None, state=None): - """Updates the weights and state of the trained model. - - Sends this data both to the singleton model accessible via Loop.model - and to the replicated model on the accelerator. - - Useful when the weights or state are modified outside of training, e.g. - during data collection in RL agents. - - Args: - weights: Model weights or ``None``. If ``None``, don't set. - state: Model state or ``None``. If ``None``, don't set. - """ - for trainer in self._trainer_per_task: - acc_model_with_loss = trainer.accelerated_model_with_loss - if weights is not None: - self._model.weights = weights - acc_model_with_loss.replicate_weights(trainer.model_with_loss.weights) - if state is not None: - self._model.state = state - acc_model_with_loss.replicate_state(trainer.model_with_loss.state) - - def _run_one_step(self, task_index, task_changed): - """Updates model weights/state and optimizer slots by running one step. - - Args: - task_index (int): Index of the task to train on. - task_changed (bool): Whether the state has changed since the last step. - - Returns: - Tuple (loss, stats) with loss value from one step - of training and stats, the current optimizer statistics. 
- """ - step = self.step - for callback in self._callbacks: - if callback.call_at(step): - callback.on_step_begin(step) - - learning_rate = self._tasks[task_index].learning_rate(step) - batch = self._tasks[task_index].next_batch() - rng = self.new_rng() - trainer = self._trainer_per_task[task_index] - if task_changed: - # Re-replicate weights and state to synchronize them between tasks. - self.update_weights_and_state(self._model.weights, self._model.state) - - (loss, stats) = trainer.one_step( - batch, rng, step=step, learning_rate=learning_rate - ) - - for callback in self._callbacks: - if callback.call_at(step): - callback.on_step_end(step) - - return (loss, stats) - - def _log_training_progress(self, task, total_loss, n_steps, elapsed_time, - optimizer_metrics, summary_writer): - """Logs training related metrics. - - Logs: - * current learning rate, - * steps per second, - * average training loss, - * average metrics returned from the optimizer - to the provided summary writer. Training loss is also logged to stdout. - - Args: - task: Current training task. - total_loss: Total training loss accumulated over n_steps training steps. - n_steps: Number of steps over which the metrics were accumulated. - elapsed_time: Time of execution of n_steps training steps. - optimizer_metrics: Dict from optimizer metric name to metric values. - summary_writer: Jaxboard summary writer for saving provided metrics. - """ - # only here do avoid potential divide-by-0 - n_steps = max(1, n_steps) - _log('') # Separator for visibility on terminals. 
- if self.step == 1: - self._log_n_weights() - self._log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time)) - self.log_summary( - {task.loss_name: total_loss / float(n_steps)}, - summary_writer, 'metrics/', 'train') - if self.step == 1: - self._save_gin(summary_writer) - train_parameters = { - 'learning_rate': task.learning_rate(self.step), - 'steps per second': n_steps / elapsed_time, - } - # Average optimizer_metrics over n_steps. - optimizer_metrics = {k: v / n_steps for k, v in optimizer_metrics.items()} - train_parameters.update(optimizer_metrics) - self.log_summary( - train_parameters, summary_writer, 'training/', 'train', stdout=False) - - def _save_gin(self, summary_writer): - """"Saves the operative gin config.""" - if not self.is_chief or self._output_dir is None: - return - config_path = os.path.join(self._output_dir, 'config.gin') - config_str = gin.operative_config_str() - with tf.io.gfile.GFile(config_path, 'w') as f: - f.write(config_str) - if summary_writer is not None: - summary_writer.text( - 'gin_config', jaxboard.markdownify_operative_config_str(config_str) - ) - - def _log_n_weights(self): - """"Logs the number of weights in the training model.""" - def _size(x): - try: - return x.size - except Exception: # pylint: disable=broad-except - return 0 - sizes = fastmath.nested_map(_size, self._model.weights) - total_size = sum(fastmath.tree_flatten(sizes)) - total_size *= base.N_WEIGHTS_SHARDS - self._log_step('Total number of trainable weights: %d' % total_size) - - # TODO(afrozm): Fix multi-host evals, right now the reported numbers in the - # summary writer are only from the chief and not averaged across hosts. - def run_evals(self, summary_writers=None): - """Runs and records evals for this training session. - - Args: - summary_writers: List of per-task Jaxboard summary writers to log metrics. 
- """ - if summary_writers is None: - summary_writers = (None,) * len(self._eval_tasks) - - self._eval_model.weights = self._model.weights - self._eval_model.state = self._model.state - - def recursively_look_for_printable_states(state): - if isinstance(state, (tuple, list)): - for substate in state: - for item in recursively_look_for_printable_states(substate): - yield item - if isinstance(state, dict): - for key, value in state.items(): - if isinstance(key, str) and key.startswith('summary_'): - for device_id, device_value in enumerate(value): - yield ('device{}/{}'.format(device_id, key[len('summary_'):]), - device_value) - - # The most recently trained weights are in this trainer, use those for eval. - cur_train_task_index = self._which_task(self._step) - trainer = self._trainer_per_task[cur_train_task_index] - - for eval_task_index in range(len(self._eval_tasks)): - eval_task = self._eval_tasks[eval_task_index] - evaluator = self._evaluator_per_task[eval_task_index] - if eval_task is None: - continue - - # Extract the actual model weights and state, excluding the loss layer. - if self._use_memory_efficient_trainer: - model_weights, model_state = self._model.weights, self._model.state - else: - model_weights = trainer.accelerated_model_with_loss.weights[0] - model_state = trainer.accelerated_model_with_loss.state[0] - - # evaluator.{weights,state} are already replicated. - metrics_weights = (model_weights, evaluator.weights) - metrics_state = (model_state, evaluator.state) - - n_batches = eval_task.n_eval_batches - n_metrics = len(eval_task.metrics) - sums = np.zeros((n_metrics,)) - for _ in range(n_batches): + # We keep the standard model weights/state unreplicated and + # tl.Accelerate(model) will carry the replicated weights/state. + # TODO(afrozm): Try to use tl.Accelerate(model) everywhere in the Loop. 
+ self._eval_model.weights = self._model.weights + + def _at_lowest(self): + low_items = self.history.get("eval", f"metrics/{self._checkpoint_low_metric}") + vals = [float(obj[1]) for obj in low_items] + return vals[-1] == min(vals) + + def _at_highest(self): + high_items = self.history.get("eval", f"metrics/{self._checkpoint_high_metric}") + vals = [float(obj[1]) for obj in high_items] + return vals[-1] == max(vals) + + @property + def step(self): + """Returns current step number in this training session.""" + return self._step + + @property + def history(self): + """Returns history in this training session.""" + return self._history + + @property + def n_devices(self): + """Returns the number of devices to be used in this computation.""" + return self._n_devices + + @property + def is_chief(self): + """Returns true if this Loop is the chief.""" + return self._is_chief + + @property + def model(self): + """Returns the model that is training.""" + return self._model + + @property + def tasks(self): + """Returns the training tasks.""" + return self._tasks + + @property + def eval_model(self): + """Returns the model used for evaluation.""" + return self._eval_model + + @property + def eval_tasks(self): + """Returns the evaluation tasks.""" + return self._eval_tasks + + @property + def output_dir(self): + """Returns the output directory.""" + return self._output_dir + + def new_rng(self): + """Returns a new single-use random number generator (JAX PRNG key).""" + self._rng, rng = fastmath.random.split(self._rng) + if self._use_memory_efficient_trainer: + self._rng = tl.on_cpu(self._rng) + rng = tl.on_cpu(rng) + return rng + + def _for_n_devices(self, x): + """Replicates/broadcasts ``x`` for n devices if ``self.n_devicess > 1``.""" + return tl.for_n_devices(x, self.n_devices) + + def _unreplicate(self, x): + if self.n_devices == 1: + return x + + unreplicate_fn = lambda x: x[0] + return fastmath.nested_map(unreplicate_fn, x) + + def _reshape_by_device(self, x): + if 
self.n_devices == 1: + return x + return tl.reshape_by_device(x, self.n_devices) + + def update_weights_and_state(self, weights=None, state=None): + """Updates the weights and state of the trained model. + + Sends this data both to the singleton model accessible via Loop.model + and to the replicated model on the accelerator. + + Useful when the weights or state are modified outside of training, e.g. + during data collection in RL agents. + + Args: + weights: Model weights or ``None``. If ``None``, don't set. + state: Model state or ``None``. If ``None``, don't set. + """ + for trainer in self._trainer_per_task: + acc_model_with_loss = trainer.accelerated_model_with_loss + if weights is not None: + self._model.weights = weights + acc_model_with_loss.replicate_weights(trainer.model_with_loss.weights) + if state is not None: + self._model.state = state + acc_model_with_loss.replicate_state(trainer.model_with_loss.state) + + def _run_one_step(self, task_index, task_changed): + """Updates model weights/state and optimizer slots by running one step. + + Args: + task_index (int): Index of the task to train on. + task_changed (bool): Whether the state has changed since the last step. + + Returns: + Tuple (loss, stats) with loss value from one step + of training and stats, the current optimizer statistics. 
+ """ + step = self.step + for callback in self._callbacks: + if callback.call_at(step): + callback.on_step_begin(step) + + learning_rate = self._tasks[task_index].learning_rate(step) + batch = self._tasks[task_index].next_batch() rng = self.new_rng() - batch = eval_task.next_batch() - metric_values, _ = evaluator.metrics_fn( - batch, metrics_weights, metrics_state, rng) - sums += metric_values - averages = sums / n_batches - all_metrics = dict(zip(eval_task.metric_names, averages)) - summary_writer = summary_writers[eval_task_index] - self.log_summary(all_metrics, summary_writer, 'metrics/', 'eval') - summary_metrics = dict(recursively_look_for_printable_states( - model_state)) - self.log_summary(summary_metrics, summary_writer, 'summary_', 'eval') - - def log_summary(self, values, summary_writer, value_prefix, log_prefix, - stdout=True): - """Logs and saves provided metrics. - - Args: - values: Dict from metric name to metric value. - summary_writer: Jaxboard summary writer. - value_prefix: String appended in front of summary_writer entries. - log_prefix: String appended in front of logs. - stdout: Boolean saying if logs should be logged to stdout as well. 
- """ - history = self._history - should_write_summaries = self.is_chief and summary_writer is not None - for name, value in values.items(): - full_name = value_prefix + name - s = tuple(jnp.shape(value)) - if not s: - self._log_step( - '%s %s | % .8f' % - (log_prefix.ljust(5), name.rjust(self._rjust_len), value), - stdout=stdout) - if should_write_summaries: - summary_writer.scalar(full_name, value, self.step) - else: - if should_write_summaries: - summary_writer.image(full_name, value, self.step) - if history: - history.append(log_prefix, full_name, self.step, value) - if should_write_summaries: - summary_writer.flush() - - def _log_step(self, msg, stdout=True): - """Logs message, labeled with the current training step number.""" - _log('Step % 6d: %s' % (self.step, msg), stdout=stdout) - - def save_checkpoint(self, basename): - """Saves checkpoint (multiple files) to disk for the current training step. + trainer = self._trainer_per_task[task_index] + if task_changed: + # Re-replicate weights and state to synchronize them between tasks. + self.update_weights_and_state(self._model.weights, self._model.state) - Saving a checkpoint will overwrite any previous checkpoint saved with the - same ``basename``. Use differing ``basename`` values to save multiple - checkpoints or multiple copies of the same checkpoint. + (loss, stats) = trainer.one_step( + batch, rng, step=step, learning_rate=learning_rate + ) - Args: - basename: Basename for saving a checkpoint. Full file paths for the saved - checkpoint will combine the output dir, basename, and relevant file - extensions (e.g., `.weights.npy.gz`). 
- """ - if self._output_dir is None: - _log('Did not save checkpoint as output_dir is None') - return - - inputs.save_data_counters(self._output_dir) - if not self.is_chief: - _log('Did not save checkpoint as we are not chief.') - return - - dir_and_basename = os.path.join(self._output_dir, basename) - pkl_file = dir_and_basename + '.pkl.gz' - - _log('Saving checkpoint to %s' % pkl_file, stdout=False) - weights = self._model.weights - if base.N_WEIGHTS_SHARDS > 1: - weights = self._trainer_per_task[0].accelerated_model_with_loss.weights - weights = tl.unshard(weights) - state = self._model.state - compresslevel = 0 if self._use_memory_efficient_trainer else 2 - # Serialize optimizer slots. - for i, trainer in enumerate(self._trainer_per_task): - flat_slots = _flatten_and_remove_empty(trainer.slots) - tl.np_to_file(self._to_bits(flat_slots), - f'{dir_and_basename}.opt_slots{i}.npy.gz', - compresslevel=compresslevel) - # We only need the input signature for the body, not for the loss layers. - # That part is the same across tasks - take it from the first one. - input_signature = self._batch_signature[:self._model.n_in] - flat_weights, flat_state = tl.flatten_weights_and_state(weights, state) - _, flat_eval_state = tl.flatten_weights_and_state( - weights, self._eval_model.state) - tl.np_to_file(self._to_bits(flat_weights), - f'{dir_and_basename}.weights.npy.gz', - compresslevel=compresslevel) - d = { - 'step': self.step, - 'flat_weights': compresslevel, # for compatibility with older format - 'flat_state': flat_state, - 'flat_eval_state': flat_eval_state, - 'history': self._history.to_dict(), - 'slots_per_task': compresslevel, # for compatibility with older format - 'input_signature': input_signature, - 'version_timestamp': 'Mar-10-2021' # To update in the future if needed. 
- } - pickle_to_file(d, pkl_file, gzip=True) - _log('Checkpoint saved in %s' % pkl_file, stdout=False) - - def _to_bits(self, weights): - """Converts a list of weights to bit-cast weights and their types.""" - # This is currently needed to pickle bfloat16 arrays from JAX. - # TODO(lukaszkaiser): remove once it is not needed (the following unit test - # checks it: training_test/test_restores_step_bfloat16). - if not fastmath.is_backend(fastmath.Backend.JAX): - return weights - bits = [] - for w in weights: - if w.dtype == jnp.bfloat16: - converted = jax.lax.bitcast_convert_type(w, np.uint16) - bits.append(np.asarray(converted.astype(np.uint16))) - else: # for non-bfloat16 weights, be compatible with earlier checkpoints - bits.append(np.asarray(w)) - return bits - - def _from_bits(self, bits): - """Converts a list of bit-cast weights back to weights.""" - # This is the reverse of _to_bits, see above for explanation. - if not fastmath.is_backend(fastmath.Backend.JAX): - return bits - weights = [] - for b in bits: - if b.dtype == np.uint16: # currently all uint16 are bfloat16s - w = jax.lax.bitcast_convert_type(b, jnp.bfloat16) - weights.append(np.asarray(w)) - else: - weights.append(b) - return weights - - def load_checkpoint(self, directory=None, filename=None): - """Loads model weights and step from a checkpoint on disk. + for callback in self._callbacks: + if callback.call_at(step): + callback.on_step_end(step) + + return (loss, stats) + + def _log_training_progress( + self, task, total_loss, n_steps, elapsed_time, optimizer_metrics, summary_writer + ): + """Logs training related metrics. + + Logs: + * current learning rate, + * steps per second, + * average training loss, + * average metrics returned from the optimizer + to the provided summary writer. Training loss is also logged to stdout. + + Args: + task: Current training task. + total_loss: Total training loss accumulated over n_steps training steps. 
+ n_steps: Number of steps over which the metrics were accumulated. + elapsed_time: Time of execution of n_steps training steps. + optimizer_metrics: Dict from optimizer metric name to metric values. + summary_writer: Jaxboard summary writer for saving provided metrics. + """ + # only here do avoid potential divide-by-0 + n_steps = max(1, n_steps) + _log("") # Separator for visibility on terminals. + if self.step == 1: + self._log_n_weights() + self._log_step("Ran %d train steps in %0.2f secs" % (n_steps, elapsed_time)) + self.log_summary( + {task.loss_name: total_loss / float(n_steps)}, + summary_writer, + "metrics/", + "train", + ) + if self.step == 1: + self._save_gin(summary_writer) + train_parameters = { + "learning_rate": task.learning_rate(self.step), + "steps per second": n_steps / elapsed_time, + } + # Average optimizer_metrics over n_steps. + optimizer_metrics = {k: v / n_steps for k, v in optimizer_metrics.items()} + train_parameters.update(optimizer_metrics) + self.log_summary( + train_parameters, summary_writer, "training/", "train", stdout=False + ) - Args: - directory: Directory with the checkpoint (self._output_dir by default). - filename: Checkpoint file name (model.pkl.gz by default). - """ - directory = directory or self._output_dir - if directory is None: - _log('Not loading as both directory and output_dir are None.', - stdout=False) - return - filename = filename or 'model' - path = os.path.join(directory, filename) - pkl_path = path + '.pkl.gz' - if not tf.io.gfile.exists(pkl_path): - _log(f'Not loading as checkpoint file does not exist: {pkl_path}', - stdout=False) - return - _log('Loading checkpoint from %s' % pkl_path, stdout=False) - d = unpickle_from_file(pkl_path, gzip=True) - # Weights are stored in a separate non-pickled file in the new checkpoint - # format. We support loading old checkpoints with this hack. - # TODO(lukaszkaiser): remove the hack when not needed any more. 
- if isinstance(d['flat_weights'], int): - weights = tl.np_from_file(path + '.weights.npy.gz', - compresslevel=d['flat_weights']) - d['flat_weights'] = weights - else: - d['flat_weights'] = d['flat_weights'] - # The same holds for optimizer slots. - if 'slots' in d: # Old checkpoints had just 'slots' for one task. - if len(self._tasks) != 1: - raise ValueError( - 'Can\'t load a single-task checkpoint into a multitask Loop.' + def _save_gin(self, summary_writer): + """ "Saves the operative gin config.""" + if not self.is_chief or self._output_dir is None: + return + config_path = os.path.join(self._output_dir, "config.gin") + config_str = gin.operative_config_str() + with tf.io.gfile.GFile(config_path, "w") as f: + f.write(config_str) + if summary_writer is not None: + summary_writer.text( + "gin_config", jaxboard.markdownify_operative_config_str(config_str) + ) + + def _log_n_weights(self): + """ "Logs the number of weights in the training model.""" + + def _size(x): + try: + return x.size + except Exception: # pylint: disable=broad-except + return 0 + + sizes = fastmath.nested_map(_size, self._model.weights) + total_size = sum(fastmath.tree_flatten(sizes)) + total_size *= base.N_WEIGHTS_SHARDS + self._log_step("Total number of trainable weights: %d" % total_size) + + # TODO(afrozm): Fix multi-host evals, right now the reported numbers in the + # summary writer are only from the chief and not averaged across hosts. + def run_evals(self, summary_writers=None): + """Runs and records evals for this training session. + + Args: + summary_writers: List of per-task Jaxboard summary writers to log metrics. 
+ """ + if summary_writers is None: + summary_writers = (None,) * len(self._eval_tasks) + + self._eval_model.weights = self._model.weights + self._eval_model.state = self._model.state + + def recursively_look_for_printable_states(state): + if isinstance(state, (tuple, list)): + for substate in state: + for item in recursively_look_for_printable_states(substate): + yield item + if isinstance(state, dict): + for key, value in state.items(): + if isinstance(key, str) and key.startswith("summary_"): + for device_id, device_value in enumerate(value): + yield ( + "device{}/{}".format(device_id, key[len("summary_") :]), + device_value, + ) + + # The most recently trained weights are in this trainer, use those for eval. + cur_train_task_index = self._which_task(self._step) + trainer = self._trainer_per_task[cur_train_task_index] + + for eval_task_index in range(len(self._eval_tasks)): + eval_task = self._eval_tasks[eval_task_index] + evaluator = self._evaluator_per_task[eval_task_index] + if eval_task is None: + continue + + # Extract the actual model weights and state, excluding the loss layer. + if self._use_memory_efficient_trainer: + model_weights, model_state = self._model.weights, self._model.state + else: + model_weights = trainer.accelerated_model_with_loss.weights[0] + model_state = trainer.accelerated_model_with_loss.state[0] + + # evaluator.{weights,state} are already replicated. 
+ metrics_weights = (model_weights, evaluator.weights) + metrics_state = (model_state, evaluator.state) + + n_batches = eval_task.n_eval_batches + n_metrics = len(eval_task.metrics) + sums = np.zeros((n_metrics,)) + for _ in range(n_batches): + rng = self.new_rng() + batch = eval_task.next_batch() + metric_values, _ = evaluator.metrics_fn( + batch, metrics_weights, metrics_state, rng + ) + sums += metric_values + averages = sums / n_batches + all_metrics = dict(zip(eval_task.metric_names, averages)) + summary_writer = summary_writers[eval_task_index] + self.log_summary(all_metrics, summary_writer, "metrics/", "eval") + summary_metrics = dict(recursively_look_for_printable_states(model_state)) + self.log_summary(summary_metrics, summary_writer, "summary_", "eval") + + def log_summary( + self, values, summary_writer, value_prefix, log_prefix, stdout=True + ): + """Logs and saves provided metrics. + + Args: + values: Dict from metric name to metric value. + summary_writer: Jaxboard summary writer. + value_prefix: String appended in front of summary_writer entries. + log_prefix: String appended in front of logs. + stdout: Boolean saying if logs should be logged to stdout as well. 
+ """ + history = self._history + should_write_summaries = self.is_chief and summary_writer is not None + for name, value in values.items(): + full_name = value_prefix + name + s = tuple(jnp.shape(value)) + if not s: + self._log_step( + "%s %s | % .8f" + % (log_prefix.ljust(5), name.rjust(self._rjust_len), value), + stdout=stdout, + ) + if should_write_summaries: + summary_writer.scalar(full_name, value, self.step) + else: + if should_write_summaries: + summary_writer.image(full_name, value, self.step) + if history: + history.append(log_prefix, full_name, self.step, value) + if should_write_summaries: + summary_writer.flush() + + def _log_step(self, msg, stdout=True): + """Logs message, labeled with the current training step number.""" + _log("Step % 6d: %s" % (self.step, msg), stdout=stdout) + + def save_checkpoint(self, basename): + """Saves checkpoint (multiple files) to disk for the current training step. + + Saving a checkpoint will overwrite any previous checkpoint saved with the + same ``basename``. Use differing ``basename`` values to save multiple + checkpoints or multiple copies of the same checkpoint. + + Args: + basename: Basename for saving a checkpoint. Full file paths for the saved + checkpoint will combine the output dir, basename, and relevant file + extensions (e.g., `.weights.npy.gz`). 
+ """ + if self._output_dir is None: + _log("Did not save checkpoint as output_dir is None") + return + + inputs.save_data_counters(self._output_dir) + if not self.is_chief: + _log("Did not save checkpoint as we are not chief.") + return + + dir_and_basename = os.path.join(self._output_dir, basename) + pkl_file = dir_and_basename + ".pkl.gz" + + _log("Saving checkpoint to %s" % pkl_file, stdout=False) + weights = self._model.weights + if base.N_WEIGHTS_SHARDS > 1: + weights = self._trainer_per_task[0].accelerated_model_with_loss.weights + weights = tl.unshard(weights) + state = self._model.state + compresslevel = 0 if self._use_memory_efficient_trainer else 2 + # Serialize optimizer slots. + for i, trainer in enumerate(self._trainer_per_task): + flat_slots = _flatten_and_remove_empty(trainer.slots) + tl.np_to_file( + self._to_bits(flat_slots), + f"{dir_and_basename}.opt_slots{i}.npy.gz", + compresslevel=compresslevel, + ) + # We only need the input signature for the body, not for the loss layers. + # That part is the same across tasks - take it from the first one. + input_signature = self._batch_signature[: self._model.n_in] + flat_weights, flat_state = tl.flatten_weights_and_state(weights, state) + _, flat_eval_state = tl.flatten_weights_and_state( + weights, self._eval_model.state ) - d['slots_per_task'] = [d['slots']] - # Read from separate files if optimizer slots are in them. 
- if 'slots_per_task' in d and isinstance(d['slots_per_task'], int): - compresslevel = d['slots_per_task'] - d['slots_per_task'] = [] - for i in range(len(self._trainer_per_task)): - slots = tl.np_from_file(path + f'.opt_slots{i}.npy.gz', - compresslevel=compresslevel) - d['slots_per_task'].append(slots) - for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']): - matched_flat_slots = _match_by_shape( - self._to_bits(_flatten_and_remove_empty(trainer.slots)), - _flatten_and_remove_empty(slots)) - matched_slots, _ = fastmath.tree_unflatten( - self._from_bits(matched_flat_slots), - trainer.slots, copy_from_tree=[None, ()]) - trainer.slots = matched_slots - self._step = d['step'] - self._history = trax_history.History.from_dict(d['history']) - # This is self._model.init_from_file but optimized to not re-read. - input_signature = d['input_signature'] - weights_and_state_sig = self._model.weights_and_state_signature( - input_signature) - flat_init_weights, flat_init_state = tl.flatten_weights_and_state( - self._model.weights, self._model.state) - if len(d['flat_weights']) < len(flat_init_weights): - _log('Checkpoint has less weights than the model, loading first ones.') - matched_weights = _match_by_shape(self._to_bits(flat_init_weights), - d['flat_weights']) - matched_weights = self._from_bits(matched_weights) - try: - restored_state = True - matched_state = _match_by_shape(self._to_bits(flat_init_state), - d['flat_state']) - matched_state = self._from_bits(matched_state) - weights, state = tl.unflatten_weights_and_state( - matched_weights, matched_state, weights_and_state_sig) - self._model.state = state - except IndexError: - _log('Failed loading model state from checkpoint, loading weights only.') - restored_state = False - weights, _ = tl.unflatten_weights_and_state( - matched_weights, (), weights_and_state_sig, weights_only=True) - self._model.weights = weights - self._eval_model.weights = self._model.weights - # Restore eval model state; note: 
it's not always the same as train state. - if restored_state: - if 'flat_eval_state' in d: - flat_eval_state = d['flat_eval_state'] - else: # It wasn't saved in old checkpoints; remove this branch once done. - flat_eval_state = d['flat_state'] - _, eval_state = tl.unflatten_weights_and_state( - matched_weights, flat_eval_state, weights_and_state_sig) - self._eval_model.state = eval_state - _log('Checkpoint loaded from %s' % pkl_path, stdout=False) - - @contextlib.contextmanager - def _open_summary_writers(self): - """Opens the Jaxboard summary writers wrapped by context manager. - - Yields: - A pair (train_summary_writers, eval_summary_writers) of lists of - Jaxboard summary writers wrapped in a GeneratorContextManager object. - Elements of the lists correspond to the training and evaluation task - directories created during initialization. If there was no output_dir - provided, yields lists of Nones with the appropriate length. - """ - if self._output_dir is not None: - _log(f'Metrics will be written in {self._output_dir}.', stdout=False) - train_writers = [jaxboard.SummaryWriter(os.path.join(output_dir, 'train')) - for output_dir in self._output_dir_per_train_task] - eval_writers = [jaxboard.SummaryWriter(os.path.join(output_dir, 'eval')) - for output_dir in self._output_dir_per_eval_task] - try: - yield (train_writers, eval_writers) - finally: - for writer in train_writers + eval_writers: - writer.close() - _log(f'Metrics were written in {self._output_dir}', stdout=False) - else: - yield ([None] * len(self._tasks), [None] * len(self._eval_tasks)) + tl.np_to_file( + self._to_bits(flat_weights), + f"{dir_and_basename}.weights.npy.gz", + compresslevel=compresslevel, + ) + d = { + "step": self.step, + "flat_weights": compresslevel, # for compatibility with older format + "flat_state": flat_state, + "flat_eval_state": flat_eval_state, + "history": self._history.to_dict(), + "slots_per_task": compresslevel, # for compatibility with older format + "input_signature": 
input_signature, + "version_timestamp": "Mar-10-2021", # To update in the future if needed. + } + pickle_to_file(d, pkl_file, gzip=True) + _log("Checkpoint saved in %s" % pkl_file, stdout=False) + + def _to_bits(self, weights): + """Converts a list of weights to bit-cast weights and their types.""" + # This is currently needed to pickle bfloat16 arrays from JAX. + # TODO(lukaszkaiser): remove once it is not needed (the following unit test + # checks it: training_test/test_restores_step_bfloat16). + if not fastmath.is_backend(fastmath.Backend.JAX): + return weights + bits = [] + for w in weights: + if w.dtype == jnp.bfloat16: + converted = jax.lax.bitcast_convert_type(w, np.uint16) + bits.append(np.asarray(converted.astype(np.uint16))) + else: # for non-bfloat16 weights, be compatible with earlier checkpoints + bits.append(np.asarray(w)) + return bits + + def _from_bits(self, bits): + """Converts a list of bit-cast weights back to weights.""" + # This is the reverse of _to_bits, see above for explanation. + if not fastmath.is_backend(fastmath.Backend.JAX): + return bits + weights = [] + for b in bits: + if b.dtype == np.uint16: # currently all uint16 are bfloat16s + w = jax.lax.bitcast_convert_type(b, jnp.bfloat16) + weights.append(np.asarray(w)) + else: + weights.append(b) + return weights + + def load_checkpoint(self, directory=None, filename=None): + """Loads model weights and step from a checkpoint on disk. + + Args: + directory: Directory with the checkpoint (self._output_dir by default). + filename: Checkpoint file name (model.pkl.gz by default). 
+ """ + directory = directory or self._output_dir + if directory is None: + _log("Not loading as both directory and output_dir are None.", stdout=False) + return + filename = filename or "model" + path = os.path.join(directory, filename) + pkl_path = path + ".pkl.gz" + if not tf.io.gfile.exists(pkl_path): + _log( + f"Not loading as checkpoint file does not exist: {pkl_path}", + stdout=False, + ) + return + _log("Loading checkpoint from %s" % pkl_path, stdout=False) + d = unpickle_from_file(pkl_path, gzip=True) + # Weights are stored in a separate non-pickled file in the new checkpoint + # format. We support loading old checkpoints with this hack. + # TODO(lukaszkaiser): remove the hack when not needed any more. + if isinstance(d["flat_weights"], int): + weights = tl.np_from_file( + path + ".weights.npy.gz", compresslevel=d["flat_weights"] + ) + d["flat_weights"] = weights + else: + d["flat_weights"] = d["flat_weights"] + # The same holds for optimizer slots. + if "slots" in d: # Old checkpoints had just 'slots' for one task. + if len(self._tasks) != 1: + raise ValueError( + "Can't load a single-task checkpoint into a multitask Loop." + ) + d["slots_per_task"] = [d["slots"]] + # Read from separate files if optimizer slots are in them. 
+ if "slots_per_task" in d and isinstance(d["slots_per_task"], int): + compresslevel = d["slots_per_task"] + d["slots_per_task"] = [] + for i in range(len(self._trainer_per_task)): + slots = tl.np_from_file( + path + f".opt_slots{i}.npy.gz", compresslevel=compresslevel + ) + d["slots_per_task"].append(slots) + for (trainer, slots) in zip(self._trainer_per_task, d["slots_per_task"]): + matched_flat_slots = _match_by_shape( + self._to_bits(_flatten_and_remove_empty(trainer.slots)), + _flatten_and_remove_empty(slots), + ) + matched_slots, _ = fastmath.tree_unflatten( + self._from_bits(matched_flat_slots), + trainer.slots, + copy_from_tree=[None, ()], + ) + trainer.slots = matched_slots + self._step = d["step"] + self._history = trax_history.History.from_dict(d["history"]) + # This is self._model.init_from_file but optimized to not re-read. + input_signature = d["input_signature"] + weights_and_state_sig = self._model.weights_and_state_signature(input_signature) + flat_init_weights, flat_init_state = tl.flatten_weights_and_state( + self._model.weights, self._model.state + ) + if len(d["flat_weights"]) < len(flat_init_weights): + _log("Checkpoint has less weights than the model, loading first ones.") + matched_weights = _match_by_shape( + self._to_bits(flat_init_weights), d["flat_weights"] + ) + matched_weights = self._from_bits(matched_weights) + try: + restored_state = True + matched_state = _match_by_shape( + self._to_bits(flat_init_state), d["flat_state"] + ) + matched_state = self._from_bits(matched_state) + weights, state = tl.unflatten_weights_and_state( + matched_weights, matched_state, weights_and_state_sig + ) + self._model.state = state + except IndexError: + _log("Failed loading model state from checkpoint, loading weights only.") + restored_state = False + weights, _ = tl.unflatten_weights_and_state( + matched_weights, (), weights_and_state_sig, weights_only=True + ) + self._model.weights = weights + self._eval_model.weights = self._model.weights + # 
Restore eval model state; note: it's not always the same as train state. + if restored_state: + if "flat_eval_state" in d: + flat_eval_state = d["flat_eval_state"] + else: # It wasn't saved in old checkpoints; remove this branch once done. + flat_eval_state = d["flat_state"] + _, eval_state = tl.unflatten_weights_and_state( + matched_weights, flat_eval_state, weights_and_state_sig + ) + self._eval_model.state = eval_state + _log("Checkpoint loaded from %s" % pkl_path, stdout=False) + + @contextlib.contextmanager + def _open_summary_writers(self): + """Opens the Jaxboard summary writers wrapped by context manager. + + Yields: + A pair (train_summary_writers, eval_summary_writers) of lists of + Jaxboard summary writers wrapped in a GeneratorContextManager object. + Elements of the lists correspond to the training and evaluation task + directories created during initialization. If there was no output_dir + provided, yields lists of Nones with the appropriate length. + """ + if self._output_dir is not None: + _log(f"Metrics will be written in {self._output_dir}.", stdout=False) + train_writers = [ + jaxboard.SummaryWriter(os.path.join(output_dir, "train")) + for output_dir in self._output_dir_per_train_task + ] + eval_writers = [ + jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) + for output_dir in self._output_dir_per_eval_task + ] + try: + yield (train_writers, eval_writers) + finally: + for writer in train_writers + eval_writers: + writer.close() + _log(f"Metrics were written in {self._output_dir}", stdout=False) + else: + yield ([None] * len(self._tasks), [None] * len(self._eval_tasks)) def _model_with_ends(model, end_layers, batch_signature): - """Returns a model+ends layer built on an already initialized model. + """Returns a model+ends layer built on an already initialized model. - Ends can be loss or metric layers. + Ends can be loss or metric layers. - Args: - model: Layer with initialized weights and state. - end_layers: List of end layers. 
- batch_signature: Signature of the model input batch. + Args: + model: Layer with initialized weights and state. + end_layers: List of end layers. + batch_signature: Signature of the model input batch. - Returns: - An initialized, combined model+ends layer, preserving the initialization - of ``model``. - """ - # TODO(jonni): Redo this function as part of an initialization refactor? - metrics_layer = tl.Branch(*end_layers) - metrics_input_signature = model.output_signature(batch_signature) - _, _ = metrics_layer.init(metrics_input_signature) + Returns: + An initialized, combined model+ends layer, preserving the initialization + of ``model``. + """ + # TODO(jonni): Redo this function as part of an initialization refactor? + metrics_layer = tl.Branch(*end_layers) + metrics_input_signature = model.output_signature(batch_signature) + _, _ = metrics_layer.init(metrics_input_signature) - model_with_metrics = tl.Serial(model, metrics_layer) - return model_with_metrics + model_with_metrics = tl.Serial(model, metrics_layer) + return model_with_metrics def _model_with_metrics(model, eval_task): - """Returns a model+metrics layer built on an already initialized model. + """Returns a model+metrics layer built on an already initialized model. - Args: - model: Layer with initialized weights and state. - eval_task: :py:class:`EvalTask` instance. + Args: + model: Layer with initialized weights and state. + eval_task: :py:class:`EvalTask` instance. - Returns: - An initialized, combined model+metrics layer, preserving the initialization - of ``model``. - """ - return _model_with_ends( - model, eval_task.metrics, shapes.signature(eval_task.sample_batch) - ) + Returns: + An initialized, combined model+metrics layer, preserving the initialization + of ``model``. 
+ """ + return _model_with_ends( + model, eval_task.metrics, shapes.signature(eval_task.sample_batch) + ) @gin.configurable class TrainTask: - """A supervised task (labeled data + feedback mechanism) for training.""" - - def __init__(self, labeled_data, loss_layer, optimizer, - lr_schedule=None, n_steps_per_checkpoint=100, - n_steps_per_permanent_checkpoint=None, loss_name=None, - sample_batch=None, export_prefix=None): - r"""Configures a training task. - - Args: - labeled_data: Iterator of batches of labeled data tuples. Each tuple has - 1+ data (input value) tensors followed by 1 label (target value) - tensor. All tensors are NumPy ndarrays or their JAX counterparts. - loss_layer: Layer that computes a scalar value (the "loss") by comparing - model output :math:`\hat{y}=f(x)` to the target :math:`y`. - optimizer: Optimizer object that computes model weight updates from - loss-function gradients. - lr_schedule: Learning rate schedule, a function step -> learning_rate. - n_steps_per_checkpoint: How many steps to run between checkpoints. - n_steps_per_permanent_checkpoint: How many steps to run between permanent - checkpoints. - loss_name: Name for the loss metric. - sample_batch: Optional sample batch for model initialization. If not - provided, it will be taken from ``labeled_data``. - export_prefix: Optional task name to be used as prefix for exporting - metrics during training in Loop. 
- """ - self._export_prefix = export_prefix - self._labeled_data = labeled_data - self._loss_layer = loss_layer - self._optimizer = optimizer - self._lr_schedule = lr_schedule - self._sample_batch = sample_batch or next(labeled_data) - self._n_steps_per_checkpoint = n_steps_per_checkpoint - self._n_steps_per_permanent_checkpoint = n_steps_per_permanent_checkpoint - self._loss_name = loss_name or self._loss_layer.name - - @property - def labeled_data(self): - return self._labeled_data - - @property - def sample_batch(self): - return self._sample_batch - - def next_batch(self): - """Returns one batch of labeled data: a tuple of input(s) plus label.""" - return next(self._labeled_data) - - @property - def export_prefix(self): - return self._export_prefix - - @property - def loss_layer(self): - return self._loss_layer - - @property - def loss_name(self): - return self._loss_name - - @property - def n_steps_per_checkpoint(self): - return self._n_steps_per_checkpoint - - @property - def n_steps_per_permanent_checkpoint(self): - return self._n_steps_per_permanent_checkpoint - - @property - def optimizer(self): - return self._optimizer - - def learning_rate(self, step): - """Return the learning rate for the given step.""" - if self._lr_schedule is not None: - with fastmath.use_backend(fastmath.Backend.NUMPY): - return self._lr_schedule(step) - opt = self._optimizer - if callable(opt): # when optimizer is a function, like Adam, not Adam() - opt = opt() - params = opt._init_opt_params # pylint: disable=protected-access - return params['learning_rate'] + """A supervised task (labeled data + feedback mechanism) for training.""" + + def __init__( + self, + labeled_data, + loss_layer, + optimizer, + lr_schedule=None, + n_steps_per_checkpoint=100, + n_steps_per_permanent_checkpoint=None, + loss_name=None, + sample_batch=None, + export_prefix=None, + ): + r"""Configures a training task. + + Args: + labeled_data: Iterator of batches of labeled data tuples. 
Each tuple has + 1+ data (input value) tensors followed by 1 label (target value) + tensor. All tensors are NumPy ndarrays or their JAX counterparts. + loss_layer: Layer that computes a scalar value (the "loss") by comparing + model output :math:`\hat{y}=f(x)` to the target :math:`y`. + optimizer: Optimizer object that computes model weight updates from + loss-function gradients. + lr_schedule: Learning rate schedule, a function step -> learning_rate. + n_steps_per_checkpoint: How many steps to run between checkpoints. + n_steps_per_permanent_checkpoint: How many steps to run between permanent + checkpoints. + loss_name: Name for the loss metric. + sample_batch: Optional sample batch for model initialization. If not + provided, it will be taken from ``labeled_data``. + export_prefix: Optional task name to be used as prefix for exporting + metrics during training in Loop. + """ + self._export_prefix = export_prefix + self._labeled_data = labeled_data + self._loss_layer = loss_layer + self._optimizer = optimizer + self._lr_schedule = lr_schedule + self._sample_batch = sample_batch or next(labeled_data) + self._n_steps_per_checkpoint = n_steps_per_checkpoint + self._n_steps_per_permanent_checkpoint = n_steps_per_permanent_checkpoint + self._loss_name = loss_name or self._loss_layer.name + + @property + def labeled_data(self): + return self._labeled_data + + @property + def sample_batch(self): + return self._sample_batch + + def next_batch(self): + """Returns one batch of labeled data: a tuple of input(s) plus label.""" + return next(self._labeled_data) + + @property + def export_prefix(self): + return self._export_prefix + + @property + def loss_layer(self): + return self._loss_layer + + @property + def loss_name(self): + return self._loss_name + + @property + def n_steps_per_checkpoint(self): + return self._n_steps_per_checkpoint + + @property + def n_steps_per_permanent_checkpoint(self): + return self._n_steps_per_permanent_checkpoint + + @property + def 
optimizer(self): + return self._optimizer + + def learning_rate(self, step): + """Return the learning rate for the given step.""" + if self._lr_schedule is not None: + with fastmath.use_backend(fastmath.Backend.NUMPY): + return self._lr_schedule(step) + opt = self._optimizer + if callable(opt): # when optimizer is a function, like Adam, not Adam() + opt = opt() + params = opt._init_opt_params # pylint: disable=protected-access + return params["learning_rate"] @gin.configurable class EvalTask: - """Labeled data plus scalar functions for (periodically) measuring a model. + """Labeled data plus scalar functions for (periodically) measuring a model. - An eval task specifies how (``labeled_data`` + ``metrics``) and with what - precision (``n_eval_batches``) to measure a model as it is training. - The variance of each scalar output is reduced by measuring over multiple - (``n_eval_batches``) batches and reporting the average from those - measurements. - """ - - def __init__(self, labeled_data, metrics, - metric_names=None, n_eval_batches=1, sample_batch=None, - export_prefix=None): - r"""Configures an eval task: named metrics run with a given data source. - - Args: - labeled_data: Iterator of batches of labeled data tuples. Each tuple has - 1+ data tensors (NumPy ndarrays) followed by 1 label (target value) - tensor. - metrics: List of layers; each computes a scalar value per batch by - comparing model output :math:`\hat{y}=f(x)` to the target :math:`y`. - metric_names: List of names, one for each item in ``metrics``, in matching - order, to be used when recording/reporting eval output. If ``None``, - generate default names using layer names from metrics. - n_eval_batches: Integer N that specifies how many eval batches to run; - the output is then the average of the outputs from the N batches. - sample_batch: Optional sample batch for model initialization. If not - provided, it will be taken from ``labeled_data``. 
- export_prefix: Optional task name to be used as prefix for exporting - metrics during evaluation in Loop. + An eval task specifies how (``labeled_data`` + ``metrics``) and with what + precision (``n_eval_batches``) to measure a model as it is training. + The variance of each scalar output is reduced by measuring over multiple + (``n_eval_batches``) batches and reporting the average from those + measurements. """ - self._export_prefix = export_prefix - self._labeled_data = labeled_data - self._metrics = metrics - self._metric_names = metric_names or self._default_names() - self._n_eval_batches = n_eval_batches # pylint: disable=invalid-name - self._sample_batch = sample_batch or next(labeled_data) - self._check_init_values() + def __init__( + self, + labeled_data, + metrics, + metric_names=None, + n_eval_batches=1, + sample_batch=None, + export_prefix=None, + ): + r"""Configures an eval task: named metrics run with a given data source. + + Args: + labeled_data: Iterator of batches of labeled data tuples. Each tuple has + 1+ data tensors (NumPy ndarrays) followed by 1 label (target value) + tensor. + metrics: List of layers; each computes a scalar value per batch by + comparing model output :math:`\hat{y}=f(x)` to the target :math:`y`. + metric_names: List of names, one for each item in ``metrics``, in matching + order, to be used when recording/reporting eval output. If ``None``, + generate default names using layer names from metrics. + n_eval_batches: Integer N that specifies how many eval batches to run; + the output is then the average of the outputs from the N batches. + sample_batch: Optional sample batch for model initialization. If not + provided, it will be taken from ``labeled_data``. + export_prefix: Optional task name to be used as prefix for exporting + metrics during evaluation in Loop. 
+ """ + self._export_prefix = export_prefix + self._labeled_data = labeled_data + self._metrics = metrics + self._metric_names = metric_names or self._default_names() + self._n_eval_batches = n_eval_batches # pylint: disable=invalid-name + + self._sample_batch = sample_batch or next(labeled_data) + self._check_init_values() + + @property + def labeled_data(self): + return self._labeled_data + + @property + def sample_batch(self): + return self._sample_batch + + def next_batch(self): + """Returns one batch of labeled data: a tuple of input(s) plus label.""" + return next(self._labeled_data) + + @property + def export_prefix(self): + return self._export_prefix + + @property + def metrics(self): + return self._metrics + + @property + def metric_names(self): + return self._metric_names + + @property + def n_eval_batches(self): + return self._n_eval_batches + + def _default_names(self): + return [m.name for m in self._metrics] + + def _check_init_values(self): + if len(self._metrics) != len(self._metric_names): + raise ValueError( + f"Number of metrics ({len(self._metrics)}) does not equal " + f"number of metric names ({len(self._metric_names)})." 
+ ) - @property - def labeled_data(self): - return self._labeled_data - @property - def sample_batch(self): - return self._sample_batch +def _never(*args): + """Returns False for all step numbers.""" + del args + return False - def next_batch(self): - """Returns one batch of labeled data: a tuple of input(s) plus label.""" - return next(self._labeled_data) - @property - def export_prefix(self): - return self._export_prefix +def _at_step_1_and_every_nth_step(period): + """A function that's true at 1 and n when n % period == 0.""" + if period is None: + return lambda step_n: False - @property - def metrics(self): - return self._metrics + def _at_1_and_periodically(step_n): + return (step_n == 1) or (step_n > 0 and (step_n % period == 0)) - @property - def metric_names(self): - return self._metric_names + return _at_1_and_periodically - @property - def n_eval_batches(self): - return self._n_eval_batches - def _default_names(self): - return [m.name for m in self._metrics] +def _log(s, stdout=True): + logging.info(s) + if stdout: + print(s) + sys.stdout.flush() - def _check_init_values(self): - if len(self._metrics) != len(self._metric_names): - raise ValueError( - f'Number of metrics ({len(self._metrics)}) does not equal ' - f'number of metric names ({len(self._metric_names)}).') +def pickle_to_file(obj, file_path, gzip=False): + """Pickle obj to file_path with gzipping and failure protection.""" + # Pickle to tmp file and overwrite to prevent writing partial files. + tmp_file_path = file_path + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + if not gzip: + pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + else: + with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: + pickle.dump(obj, gzipf, protocol=pickle.HIGHEST_PROTOCOL) + # Moving a file is much less error-prone than pickling large files. 
+ tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) -def _never(*args): - """Returns False for all step numbers.""" - del args - return False +def unpickle_from_file(file_path, gzip=False): + """Unpickle obj from file_path with gzipping.""" + with tf.io.gfile.GFile(file_path, "rb") as f: + if not gzip: + obj = pickle.load(f) + else: + with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: + obj = pickle.load(gzipf) + return obj -def _at_step_1_and_every_nth_step(period): - """A function that's true at 1 and n when n % period == 0.""" - if period is None: - return lambda step_n: False - def _at_1_and_periodically(step_n): - return (step_n == 1) or (step_n > 0 and (step_n % period == 0)) - return _at_1_and_periodically +def _init_random_number_generators(seed=None): + """Initializes random generators for Python, NumPy, TensorFlow, and JAX.""" + # Seed Python random (None as seed is okay), then use it to seed the others. + random.seed(seed) + if seed is None: + seed = random.randint(0, 2**31 - 1) + logging.info("using seed %d", seed) + np.random.seed(seed) + tf.random.set_seed(seed) + return jax_random.get_prng(seed) -def _log(s, stdout=True): - logging.info(s) - if stdout: - print(s) - sys.stdout.flush() +def init_host_and_devices(n_devices=None, random_seed=None): + """Initializes host and device attributes for this trainer. + Args: + n_devices: Number of devices this trainer will use. If ``None``, get the + number from the backend. + random_seed: Random seed as the starting point for all random numbers used + by the trainer. If ``None``, calculate one from system time and host id. -def pickle_to_file(obj, file_path, gzip=False): - """Pickle obj to file_path with gzipping and failure protection.""" - # Pickle to tmp file and overwrite to prevent writing partial files. 
- tmp_file_path = file_path + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - if not gzip: - pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + Returns: + is_chief: True if this trainer has special chief responsibilities. + host_count: Number of hosts in this computation. + n_devices: The passed in value of n_devices or a computed default (for this + host). + random_seed: The passed in value of random_seed or a computed default. + """ + if fastmath.is_backend(fastmath.Backend.JAX): + host_id = jax.process_index() + host_count = jax.host_count() else: - with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: - pickle.dump(obj, gzipf, protocol=pickle.HIGHEST_PROTOCOL) - # Moving a file is much less error-prone than pickling large files. - tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) + host_id = 0 + host_count = 1 + is_chief = host_id == 0 + + logging.info( + "Initializing hosts and devices: host_id %d, host_count %d, " "is_chief %d", + host_id, + host_count, + is_chief, + ) + device_count = fastmath.local_device_count() + n_devices = n_devices or device_count + # TODO(lukaszkaiser): remove this restriction when possible. 
+ if n_devices != device_count and fastmath.is_backend(fastmath.Backend.JAX): + raise ValueError( + "JAX cannot work yet with n_devices != all devices: " + "%d != %d" % (n_devices, device_count) + ) -def unpickle_from_file(file_path, gzip=False): - """Unpickle obj from file_path with gzipping.""" - with tf.io.gfile.GFile(file_path, 'rb') as f: - if not gzip: - obj = pickle.load(f) - else: - with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: - obj = pickle.load(gzipf) - return obj + if random_seed is None and host_count > 1: + random_seed = int(1e6 * (host_id + time.time())) % 2**31 + return ( + is_chief, + host_count, + n_devices, + _init_random_number_generators(random_seed), + ) -def _init_random_number_generators(seed=None): - """Initializes random generators for Python, NumPy, TensorFlow, and JAX.""" - # Seed Python random (None as seed is okay), then use it to seed the others. - random.seed(seed) - if seed is None: - seed = random.randint(0, 2**31 - 1) - logging.info('using seed %d', seed) - np.random.seed(seed) - tf.random.set_seed(seed) - return jax_random.get_prng(seed) +def _accelerate_model_with_metrics( + model_with_metrics, n_devices, accelerate=True, do_mean=True +): + if not accelerate: + return model_with_metrics.pure_fn + return tl.jit_forward(model_with_metrics.pure_fn, n_devices, do_mean=do_mean) -def init_host_and_devices(n_devices=None, random_seed=None): - """Initializes host and device attributes for this trainer. - - Args: - n_devices: Number of devices this trainer will use. If ``None``, get the - number from the backend. - random_seed: Random seed as the starting point for all random numbers used - by the trainer. If ``None``, calculate one from system time and host id. - - Returns: - is_chief: True if this trainer has special chief responsibilities. - host_count: Number of hosts in this computation. - n_devices: The passed in value of n_devices or a computed default (for this - host). 
- random_seed: The passed in value of random_seed or a computed default. - """ - if fastmath.is_backend(fastmath.Backend.JAX): - host_id = jax.process_index() - host_count = jax.host_count() - else: - host_id = 0 - host_count = 1 - is_chief = (host_id == 0) - - logging.info('Initializing hosts and devices: host_id %d, host_count %d, ' - 'is_chief %d', host_id, host_count, is_chief) - - device_count = fastmath.local_device_count() - n_devices = n_devices or device_count - # TODO(lukaszkaiser): remove this restriction when possible. - if n_devices != device_count and fastmath.is_backend(fastmath.Backend.JAX): - raise ValueError('JAX cannot work yet with n_devices != all devices: ' - '%d != %d' % (n_devices, device_count)) - - if random_seed is None and host_count > 1: - random_seed = int(1e6 * (host_id + time.time())) % 2**31 - return (is_chief, host_count, n_devices, - _init_random_number_generators(random_seed)) - - -def _accelerate_model_with_metrics(model_with_metrics, n_devices, - accelerate=True, do_mean=True): - if not accelerate: - return model_with_metrics.pure_fn - - return tl.jit_forward(model_with_metrics.pure_fn, n_devices, do_mean=do_mean) - - -@functools.partial(fastmath.pmap, axis_name='devices', donate_argnums=(0,)) + +@functools.partial(fastmath.pmap, axis_name="devices", donate_argnums=(0,)) def _make_weights_and_state_same_across_hosts(weights_and_state): - """Makes train and eval model's weights and state the same across hosts.""" + """Makes train and eval model's weights and state the same across hosts.""" - # We assume that weights_and_state have been already replicated, i.e the - # leading axis is self._n_devices + # We assume that weights_and_state have been already replicated, i.e the + # leading axis is self._n_devices - # This is the total number of devices across all hosts. - n_devices_total = fastmath.psum(jnp.array(1.0), 'devices').astype(jnp.int32) + # This is the total number of devices across all hosts. 
+ n_devices_total = fastmath.psum(jnp.array(1.0), "devices").astype(jnp.int32) - # We average the weights and state across all devices. - # We also make sure we don't change the type of the weights and state. - return fastmath.nested_map( - lambda x: (fastmath.psum(x, 'devices') / n_devices_total).astype(x.dtype), - weights_and_state) + # We average the weights and state across all devices. + # We also make sure we don't change the type of the weights and state. + return fastmath.nested_map( + lambda x: (fastmath.psum(x, "devices") / n_devices_total).astype(x.dtype), + weights_and_state, + ) def _is_empty(x): - if isinstance(x, (list, tuple)): - return all(_is_empty(y) for y in x) - else: - return x is None + if isinstance(x, (list, tuple)): + return all(_is_empty(y) for y in x) + else: + return x is None def _is_uninitialized(model): - """Checks whether no weights in the model have been initialized.""" - if not _is_empty(model.weights): - return False - return all(_is_uninitialized(l) for l in model.sublayers) + """Checks whether no weights in the model have been initialized.""" + if not _is_empty(model.weights): + return False + return all(_is_uninitialized(l) for l in model.sublayers) def _match_by_shape(full, partial): - """Puts partial into full matching by shape.""" - partial_idx = 0 - res = [] - for w in full: - if partial_idx >= len(partial): - res.append(w) # read everything from parial list, just fill - elif w is None and partial[partial_idx] is None: # both Nones, move on - res.append(None) - partial_idx += 1 - elif w is None or partial[partial_idx] is None: # one None but not both - res.append(w) - elif w.shape == partial[partial_idx].shape: - res.append(partial[partial_idx]) - partial_idx += 1 - else: - res.append(w) - if partial_idx < len(partial): - _log('Did not manage to match shapes in model for all checkpoint weights.') - for w in partial[:partial_idx]: - _log(' Inserted tensor of shape %s' % str(w.shape)) - for i, w in 
enumerate(partial[partial_idx:]): - _log(' Not inserted tensor of shape %s' % str(w.shape)) - model_weight_shape = str(full[i + partial_idx].shape) - _log(' Tensor in that place has shape: %s' % model_weight_shape) - raise IndexError - return res + """Puts partial into full matching by shape.""" + partial_idx = 0 + res = [] + for w in full: + if partial_idx >= len(partial): + res.append(w) # read everything from parial list, just fill + elif w is None and partial[partial_idx] is None: # both Nones, move on + res.append(None) + partial_idx += 1 + elif w is None or partial[partial_idx] is None: # one None but not both + res.append(w) + elif w.shape == partial[partial_idx].shape: + res.append(partial[partial_idx]) + partial_idx += 1 + else: + res.append(w) + if partial_idx < len(partial): + _log("Did not manage to match shapes in model for all checkpoint weights.") + for w in partial[:partial_idx]: + _log(" Inserted tensor of shape %s" % str(w.shape)) + for i, w in enumerate(partial[partial_idx:]): + _log(" Not inserted tensor of shape %s" % str(w.shape)) + model_weight_shape = str(full[i + partial_idx].shape) + _log(" Tensor in that place has shape: %s" % model_weight_shape) + raise IndexError + return res def _flatten_and_remove_empty(x): - flat = fastmath.tree_flatten(x) - return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison + flat = fastmath.tree_flatten(x) + return [ + f for f in flat if f is not None and f is not () + ] # pylint: disable=literal-comparison diff --git a/trax/supervised/training_test.py b/trax/supervised/training_test.py deleted file mode 100644 index 7d6d6bf5a..000000000 --- a/trax/supervised/training_test.py +++ /dev/null @@ -1,674 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for supervised training: core classes and flows.""" - -import collections -import os -import time - -from absl.testing import absltest -from jax.config import config -import numpy as np - -from trax import data -from trax import fastmath -from trax import layers as tl -from trax import optimizers -from trax import shapes -from trax import test_utils -from trax.layers import base -from trax.models import transformer -from trax.supervised import callbacks -from trax.supervised import training - - -class TrainingTest(absltest.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - def test_loop_no_eval_task(self): - """Runs a training loop with no eval task(s).""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - training_session = training.Loop(model, [task]) - # Loop should initialize and run successfully, even with no eval task. 
- training_session.run(n_steps=5) - - - def test_loop_checkpoint_low_metric(self): - """Runs a training loop that saves checkpoints for low metric values.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01)) - eval_metric = tl.L2Loss() - eval_task = training.EvalTask(_very_simple_data(), - [eval_metric], - metric_names=['l2_loss']) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, - [task], - eval_tasks=[eval_task], - output_dir=tmp_dir, - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - checkpoint_low_metric='l2_loss') - call_counter = collections.Counter() - loop.save_checkpoint = lambda name: call_counter.update([name]) - loop.run(n_steps=10) - - # Eval metric steadily descends, so low checkpoint triggered all 5 times. - # High checkpoint not defined, so never triggered. - self.assertEqual(call_counter['model'], 5) - self.assertEqual(call_counter['lowest_l2_loss'], 5) - self.assertEqual(call_counter['highest_l2_loss'], 0) - - def test_loop_checkpoint_high_metric(self): - """Runs a training loop that saves checkpoints for high metric values.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01)) - eval_metric = tl.L2Loss() - eval_task = training.EvalTask(_very_simple_data(), - [eval_metric], - metric_names=['l2_loss']) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, - [task], - eval_tasks=[eval_task], - output_dir=tmp_dir, - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - checkpoint_high_metric='l2_loss') - call_counter = collections.Counter() - loop.save_checkpoint = lambda name: call_counter.update([name]) - loop.run(n_steps=10) - - # Eval metric steadily descends, so high checkpoint triggered only once. - # Low checkpoint not defined, so never triggered. 
- self.assertEqual(call_counter['model'], 5) - self.assertEqual(call_counter['lowest_l2_loss'], 0) - self.assertEqual(call_counter['highest_l2_loss'], 1) - - def test_train_dense_layer(self): - """Trains a very simple network on a very simple task.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=15) - self.assertEqual(15, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(20, training_session.step) - - def test_loop_with_initialized_model(self): - """Check that loop does not re-initialize an already initialized model.""" - model = tl.Serial(tl.Dense(1)) - example_data = next(_very_simple_data()) - model.init(example_data) - w = model.weights[0][0] - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - loop = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, loop.step) - self.assertEqual(loop.model.weights[0][0], w) - - def test_train_save_restore_dense(self): - """Saves and restores a checkpoint to check for equivalence.""" - self.skipTest('Broken by https://github.com/google/jax/pull/11234') - train_data = data.Serial(lambda _: _very_simple_data(), - data.CountAndSkip('simple_data')) - task = training.TrainTask( - train_data(), tl.L2Loss(), optimizers.Adam(.0001)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - 
metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = tl.Serial(tl.Dense(1)) - ts = training.Loop(m, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - model, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - self.assertEqual(data.inputs.data_counters['simple_data'], 2) - data.inputs.data_counters['simple_data'] = 0 # reset manually - self.assertEqual(data.inputs.data_counters['simple_data'], 0) # check - model2, training_session2 = _make_model_and_session() - self.assertEqual(data.inputs.data_counters['simple_data'], 2) # restored - - x = np.ones((8, 1)) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertEqual(str(y1), str(y2)) - - training_session2.run(n_steps=1) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertNotEqual(str(y1), str(y2)) - - slots1 = training_session._trainer_per_task[0].slots - slots2 = training_session2._trainer_per_task[0].slots - np.testing.assert_array_equal(slots1, slots2) - - def test_train_save_restore_sharded(self): - """Saves and restores a sharded checkpoint to check for equivalence.""" - if fastmath.local_device_count() < 2: - return # multi-accelerator only - base.N_WEIGHTS_SHARDS = fastmath.local_device_count() - train_data = data.Serial(lambda _: _very_simple_data(2, 2), - data.CountAndSkip('simple_data')) - task = training.TrainTask( - train_data(), tl.L2Loss(), optimizers.Adam(.0001)) - eval_task = training.EvalTask( - _very_simple_data(2, 2), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = tl.Serial(tl.Dense(2)) - ts = training.Loop(m, [task], 
eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - _, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - _, training_session2 = _make_model_and_session() - training_session2.run(n_steps=1) - base.N_WEIGHTS_SHARDS = 1 - - def test_train_save_restore_transformer(self): - """Saves and restores a checkpoint to check for equivalence.""" - vocab_size = 8 - task = training.TrainTask( - _very_simple_transformer_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_transformer_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = transformer.TransformerLM( - vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.) - ts = training.Loop(m, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - model, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - model2, training_session2 = _make_model_and_session() - - x = np.ones((2, 2)).astype(np.int32) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertEqual(str(y1), str(y2)) - - training_session2.run(n_steps=1) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertNotEqual(str(y1), str(y2)) - - def test_train_dense_layer_with_momentum(self): - """Trains with an optimizer that has slots / requires initialization.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Momentum(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # 
deliberately re-using training data - [tl.L2Loss()], - metric_names=['Momentum.L2Loss']) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=20) - self.assertEqual(20, training_session.step) - - def test_train_dense_layer_evals(self): - """Trains a very simple network on a very simple task, 2 epochs.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()]) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: False) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=10) - self.assertEqual(10, training_session.step) - training_session.run_evals() - self.assertEqual(10, training_session.step) # Unchanged - - def test_summaries_are_written(self): - """Training writes down metrics when writing is turned on.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - expected_train_metric_dir = os.path.join(tmp_dir, 'train') - expected_eval_metric_dir = os.path.join(tmp_dir, 'eval') - for directory in [expected_train_metric_dir, expected_eval_metric_dir]: - self.assertFalse( - os.path.isdir(directory), 'Failed for directory %s.' 
% directory) - training_session.run(n_steps=15) - time.sleep(1) # wait for the files to be closed - for directory in [expected_train_metric_dir, expected_eval_metric_dir]: - self.assertTrue( - os.path.isdir(directory), 'Failed for directory %s.' % directory) - self.assertEqual( - 1, _count_files(directory), 'Failed for directory %s.' % directory) - training_session.run(n_steps=5) - time.sleep(1) # wait for the files to be closed - for directory in [expected_train_metric_dir, expected_eval_metric_dir]: - self.assertEqual( - 2, _count_files(directory), 'Failed for directory %s.' % directory) - - def test_restores_step(self): - """Training restores step from directory where it saved it.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - loop2 = training.Loop(model, [task], output_dir=tmp_dir) - self.assertEqual(4, loop2.step) - - def test_restores_memory_efficient_from_standard(self): - """Training restores step from directory where it saved it.""" - self.skipTest('Broken by https://github.com/google/jax/pull/11234') - model = tl.Serial(tl.Dense(4), tl.Dense(1)) - task_std = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001)) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task_std], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - task_memeff = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Adam) - loop2 = training.Loop(model, [task_memeff], output_dir=tmp_dir, - use_memory_efficient_trainer=True) - loop2.run(2) - self.assertEqual(6, loop2.step) - - def test_restores_from_smaller_model(self): - """Training restores from a checkpoint created with smaller model.""" - self.skipTest('Broken by 
https://github.com/google/jax/pull/11234') - model1 = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01)) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model1, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(2) - model2 = tl.Serial(tl.Dense(1), tl.Dense(1)) - loop2 = training.Loop(model2, [task], output_dir=tmp_dir) - self.assertEqual(2, loop2.step) - - def test_restore_fails_different_model(self): - """Training restores from a checkpoint created with smaller model.""" - model1 = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model1, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(2) - model2 = tl.Serial(tl.Dense(2)) - with self.assertRaises(IndexError): - training.Loop(model2, [task], output_dir=tmp_dir) - - def test_restores_step_bfloat16(self): - """Training restores step from directory where it saved it, w/ bfloat16.""" - self.skipTest('Broken by https://github.com/google/jax/pull/11234') - model = tl.Serial(tl.Dense(1, use_bfloat16=True)) - # We'll also use Adafactor with bfloat16 to check restoring bfloat slots. 
- opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True) - task = training.TrainTask(_very_simple_data(), tl.L2Loss(), opt) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - loop2 = training.Loop(model, [task], output_dir=tmp_dir) - self.assertEqual(4, loop2.step) - loop2.run(2) # check that continued training works - self.assertEqual(6, loop2.step) - - def test_restores_step_sharded(self): - """Training restores step from directory where it saved it, sharded.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir, use_memory_efficient_trainer=True) - loop.run(4) - loop2 = training.Loop(model, [task], - output_dir=tmp_dir, use_memory_efficient_trainer=True) - self.assertEqual(4, loop2.step) - - def test_restores_step_sharded_bfloat16(self): - """Training restores step from where it saved it, sharded and bfloat16.""" - model = tl.Serial(tl.Dense(1, use_bfloat16=True)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir, use_memory_efficient_trainer=True) - loop.run(4) - loop2 = training.Loop(model, [task], - output_dir=tmp_dir, use_memory_efficient_trainer=True) - self.assertEqual(4, loop2.step) - loop2.run(2) # check that continued training works - self.assertEqual(6, loop2.step) - - def test_restores_history(self): - """Training restores history from directory where it saved it.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), tl.L2Loss(), - optimizers.SGD(.01)) - eval_task = training.EvalTask( - 
_very_simple_data(), # deliberately re-using training data - [tl.L2Loss()]) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop( - model, [task], - eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - loop2 = training.Loop(model, [task], output_dir=tmp_dir) - self.assertLen(loop2.history.modes, 2) - self.assertLen(loop2.history.metrics_for_mode('train'), 6) - self.assertLen(loop2.history.metrics_for_mode('eval'), 1) - for mode, metric in [ - ('train', 'metrics/L2Loss'), - ('train', 'training/learning_rate'), - ('train', 'training/steps per second'), - ('train', 'training/gradients_l2'), - ('train', 'training/loss'), - ('train', 'training/weights_l2'), - ('eval', 'metrics/L2Loss'), - ]: - self.assertLen(loop2.history.get(mode, metric), 1) - self.assertEqual(2, loop2.history.get(mode, metric)[0][0]) - - def test_trains_on_two_tasks(self): - """Trains a very simple network on two very simple tasks.""" - model = tl.Serial(tl.Dense(3), tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01) - ) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - ) - training_session = training.Loop( - model, - tasks=(task, task), - eval_tasks=(eval_task, eval_task), - which_task=lambda step_n: step_n % 2, - ) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=15) - self.assertEqual(15, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(20, training_session.step) - - def test_train_one_task_eval_two_tasks(self): - """Trains a very simple network on one task and evaluates on two tasks.""" - model = tl.Serial(tl.Dense(3), tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01) - ) - export_prefix_1 = 'eval_1' - eval_task_1 = training.EvalTask( - _very_simple_data(), # 
deliberately re-using training data - [tl.L2Loss()], - export_prefix=export_prefix_1, - ) - export_prefix_2 = 'eval_2' - eval_task_2 = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - export_prefix=export_prefix_2, - ) - training_session = training.Loop( - model, - tasks=(task,), - eval_tasks=(eval_task_1, eval_task_2), - ) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(5, training_session.step) - export_prefixes = [task.export_prefix - for task in training_session.eval_tasks] - self.assertCountEqual([export_prefix_1, export_prefix_2], - export_prefixes) - - def test_can_predict_with_trained_model(self): - model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2))) - train_tasks, eval_tasks = [], [] - for output_dim in [1, 2]: - # The head we select from the model: 0 for output_dim 1 and 1 for 2. - head_index = output_dim - 1 - train_tasks.append(training.TrainTask( - _very_simple_data(output_dim), - tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()), - optimizers.SGD(.01) - )) - eval_tasks.append(training.EvalTask( - _very_simple_data(output_dim), # deliberately re-use training data - [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())] - )) - tmp_dir = self.create_tempdir().full_path - training_session = training.Loop( - model, - tasks=train_tasks, - eval_tasks=eval_tasks, - checkpoint_at=lambda step_n: step_n == 1, - output_dir=tmp_dir, - which_task=lambda step_n: step_n % 2, - ) - training_session.run(n_steps=2) - - trained_model = training_session.eval_model - inp = next(_very_simple_data())[0] - out = trained_model(inp) - self.assertEqual( - shapes.signature(out), - (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))), - ) - - def test_train_memory_efficient(self): - """Trains a large network in a memory-efficient way.""" - # This test requires > 16GB RAM, only run on TPUs. 
It does pass on GPU - # and CPU when you run it locally, but it's too big for unit-testing. - ram_limited = True # Set to False to run this test locally. - if fastmath.global_device_count() == 1 and ram_limited: - return - - # Create the model. - n_layers = 16 # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram - model = tl.Serial( - tl.Embedding(9, 16*1024), - tl.Dup(), - [[tl.ReversibleHalfResidual(tl.Dense(16*1024)), tl.ReversibleSwap()] - for _ in range(n_layers)], - tl.Concatenate(), - tl.Dense(9), - ) - - # Create inputs. - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - def _data_gen(): - while True: - yield labeled_batch - - # Run training. - loss_layer = tl.WeightedCategoryCrossEntropy() - task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor) - eval_task = training.EvalTask(_data_gen(), - [tl.WeightedCategoryCrossEntropy()]) - loop = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n == 2, - use_memory_efficient_trainer=True) - self.assertEqual(0, loop.step) - loop.run(n_steps=2) - self.assertEqual(2, loop.step) - - def test_initializes_step_callbacks_with_loop_instance(self): - """Runs a training loop, asserting that callbacks are initialized.""" - - class ActualLoop: - # Wrapper object to make the Loop reference mutable. 
- loop = None - - class TestCallback(callbacks.TrainingStepCallback): - - def __init__(self, loop): - super().__init__(loop) - ActualLoop.loop = loop - - def call_at(self, step): - return False - - def on_step_begin(self, step): - del step - - def on_step_end(self, step): - del step - - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01) - ) - expected_loop = training.Loop( - model, [task], callbacks=[TestCallback] - ) - self.assertIs(ActualLoop.loop, expected_loop) - - def test_calls_step_callbacks(self): - """Runs a training loop, asserting that callbacks are called.""" - call_at_steps = [1, 3, 4] - begin_steps = [] - end_steps = [] - test_case = self - - class TestCallback(callbacks.TrainingStepCallback): - - def call_at(self, step): - return step in call_at_steps - - def on_step_begin(self, step): - begin_steps.append(step) - - def on_step_end(self, step): - # Assert that on_step_begin() was called before. - test_case.assertIn(step, begin_steps) - end_steps.append(step) - - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01) - ) - loop = training.Loop(model, [task], callbacks=[TestCallback]) - loop.run(n_steps=5) - - # Assert that the callback has been called at the appropriate steps. 
- self.assertEqual(begin_steps, call_at_steps) - self.assertEqual(end_steps, call_at_steps) - - -def _very_simple_data(output_dim=1, input_dim=1): - """"Returns stream of labeled data that maps small integers to constant pi.""" - inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch - inputs_batch = np.concatenate([inputs_batch] * input_dim, axis=1) - targets_batch = np.pi * np.ones((8, output_dim)) - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - while True: - yield labeled_batch - - -def _very_simple_transformer_data(): - """"Returns stream of labeled data that maps small integers to constant pi.""" - inputs_batch = np.ones((2, 2)).astype(np.int32) - targets_batch = np.ones((2, 2, 8)).astype(np.int32) - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - while True: - yield labeled_batch - - -def _count_files(path): - """Returns number of files in a given directory.""" - return len([filename for filename in os.listdir(path) - if os.path.isfile(os.path.join(path, filename))]) - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() diff --git a/trax/test_fork.py b/trax/test_fork.py new file mode 100644 index 000000000..2f4341571 --- /dev/null +++ b/trax/test_fork.py @@ -0,0 +1,9 @@ +print("Hello I am forked test lib 1") +print("Hello I am forked test lib 2") + +def fun_1(): + print("Hello I am forked test lib fun 1") + + +def fun_2(): + print("Hello I am forked test lib fun 2") \ No newline at end of file diff --git a/trax/test_utils.py b/trax/test_utils.py index cca9f722d..683baba48 100644 --- a/trax/test_utils.py +++ b/trax/test_utils.py @@ -19,17 +19,14 @@ from absl import flags -FLAGS = flags.FLAGS - - # pytest doesn't run the test as a main, so it doesn't parse the flags # so if flags are required in tests, this will ensure that flags are manually # parsed and the desired flag exists. 
def ensure_flag(flag_str): - try: - getattr(FLAGS, flag_str) - except flags.UnparsedFlagAccessError: - # Manually parse flags. - FLAGS(sys.argv) - finally: - assert getattr(FLAGS, flag_str) + try: + getattr(flags.FLAGS, flag_str) + except flags.UnparsedFlagAccessError: + # Manually parse flags. + flags.FLAGS(sys.argv) + finally: + assert getattr(flags.FLAGS, flag_str) diff --git a/trax/tf_numpy/examples/mnist/dataset.py b/trax/tf_numpy/examples/mnist/dataset.py deleted file mode 100644 index 755f724bf..000000000 --- a/trax/tf_numpy/examples/mnist/dataset.py +++ /dev/null @@ -1,85 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Load pickled MNIST data.""" -import gzip -import os -import pickle -import random -import urllib -import numpy as np - - -def load(): - """Loads the dataset. - - Looks for the dataset at /tmp/mnist.pkl.gz and downloads it if it is not there - already. - - Note: The training data is shuffled. - - Returns: - ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)). 
- Shapes: - train_x: num_training_examples x image_size - train_y: num_training_examples x num_classes - valid_x: num_validation_examples x image_size - valid_y: num_validation_examples x num_classes - test_x: num_test_examples x image_size - test_y: num_test_examples x num_classes - """ - filepath = _maybe_download() - with gzip.open(os.path.join(filepath), 'rb') as f: - training_data, validation_data, test_data = pickle.load(f) - training_data = (training_data[0], [to_one_hot(x) for x in training_data[1]]) - validation_data = (validation_data[0], - [to_one_hot(x) for x in validation_data[1]]) - test_data = (test_data[0], [to_one_hot(x) for x in test_data[1]]) - - def shuffle(data): - zipped = zip(*data) - random.shuffle(zipped) - return zip(*zipped) - - return (shuffle(training_data), validation_data, test_data) - - -def to_one_hot(label, num_classes=10): - vec = np.zeros(num_classes, dtype=np.float32) - vec[label] = 1. - return vec - - -def _maybe_download(): - """Downloads the MNIST dataset if it is not there already.""" - data_url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' - filename = data_url.split('/')[-1] - filepath = os.path.join(_get_data_dir(), filename) - if not os.path.exists(filepath): - - def _progress(count, block_size, total_size): - print('\r>> Downloading %s %.1f%%' % - (filename, float(count * block_size) / float(total_size) * 100.0)) - - filepath, _ = urllib.urlretrieve(data_url, filepath, _progress) - statinfo = os.stat(filepath) - print('Successfully downloaded %s %d bytes.' % (filename, statinfo.st_size)) - else: - print('Data already present on disk.') - return filepath - - -def _get_data_dir(): - return '/tmp' diff --git a/trax/tf_numpy/examples/mnist/model.py b/trax/tf_numpy/examples/mnist/model.py deleted file mode 100644 index 8f5057b53..000000000 --- a/trax/tf_numpy/examples/mnist/model.py +++ /dev/null @@ -1,132 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model for training on MNIST data.""" -from numpy import float32 -from numpy import int32 - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import numpy as np - - -class Model(object): - """A simple neural network with dense layers and sigmoid non-linearity. - - The network consists of `len(hidden_layers) + 1` dense layers. The sizes of - the hidden layers are specified by the user in `hidden_layers` and the - network takes care of adding layers to match the input and output size. - - Attributes: - weights: A list of 2-d float32 arrays containing the layer weights. - biases: A list of 2-d float32 arrays containing the layer biases. - - Methods: - forward: Can be used to perform a forward pass on a batch of - flattened images. Output is returned as a batch of one-hot vectors of the - classes. - train: method performs a forward and backward pass and updates the - weights and biases. - evaluate: method can be used to evaluate the network on a batch of - examples. - """ - - def __init__(self, hidden_layers, input_size=784, num_classes=10): - """Initializes the neural network. - - Args: - hidden_layers: List of ints specifying the sizes of hidden layers. Could - be empty. - input_size: Length of the input array. The network receives the input - image as a flattened 1-d array. Defaults to 784(28*28), the default - image size for MNIST. - num_classes: The number of output classes. Defaults to 10. 
- """ - hidden_layers = [input_size] + hidden_layers + [num_classes] - self.weights = [] - self.biases = [] - for i in range(len(hidden_layers) - 1): - # TODO(srbs): This is manually cast to float32 to avoid the cast in - # np.dot since backprop fails for tf.cast op. - self.weights.append( - np.array( - np.random.randn(hidden_layers[i + 1], hidden_layers[i]), - copy=False, - dtype=float32)) - self.biases.append( - np.array( - np.random.randn(hidden_layers[i + 1]), copy=False, dtype=float32)) - - def forward(self, x): - """Performs the forward pass. - - Args: - x: 2-d array of size batch_size x image_size. - - Returns: - A 2-d array of size batch_size x num_classes. - """ - - def sigmoid(x): - return 1.0 / (1.0 + np.exp(-x)) - - for w, b in zip(self.weights, self.biases): - x = sigmoid(np.dot(w, x.T).T + b) - return x - - def train(self, x, y, learning_rate=0.01): - """Runs a single training pass. - - Args: - x: 2-d array of size batch_size x image_size. - y: 2-d array of size batch_size x num_classes in one-hot notation. - learning_rate: The learning rate. - """ - x = np.array(x, copy=False) - y = np.array(y, copy=False) - - def mean_squared_error(x, y): - diff = x - y - return np.sum(diff * diff) / len(x) - - wb_tensors = self.weights + self.biases - with tf.GradientTape() as g: - g.watch(wb_tensors) - loss = mean_squared_error(self.forward(x), y) - gradients = g.gradient(loss, wb_tensors) - gradients = [np.asarray(grad) for grad in gradients] - - new_weights_and_biases = [] - for v, dv in zip(self.weights + self.biases, gradients): - new_weights_and_biases.append(v - learning_rate * dv) - - total_len = len(new_weights_and_biases) - self.weights = new_weights_and_biases[:total_len // 2] - self.biases = new_weights_and_biases[total_len // 2:] - - def evaluate(self, x, y): - """Returns the number of correct predictions. - - Args: - x: 2-d array of size batch_size x image_size. - y: 2-d array of size batch_size x num_classes. 
- - Returns: - A scalar, the number of correct predictions. - """ - y_actual = np.argmax(y, axis=1) - y_predicted = np.argmax(self.forward(x), axis=1) - return int( - np.sum(np.array(y_actual == y_predicted, copy=False, dtype=int32))) diff --git a/trax/tf_numpy/examples/mnist/train.py b/trax/tf_numpy/examples/mnist/train.py deleted file mode 100644 index 766c71998..000000000 --- a/trax/tf_numpy/examples/mnist/train.py +++ /dev/null @@ -1,84 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Perform training.""" -from absl import app -from absl import flags - -from six.moves import range -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.examples.mnist import dataset -from trax.tf_numpy.examples.mnist import model as model_lib - -FLAGS = flags.FLAGS - -flags.DEFINE_integer('batch_size', 50, 'Batch size.') -flags.DEFINE_integer('num_training_iters', 10000, - 'Number of iterations to train for.') -flags.DEFINE_integer( - 'validation_steps', 100, - 'Validation is performed every these many training steps.') -flags.DEFINE_float('learning_rate', 5.0, 'Learning rate.') - - -def train(batch_size, learning_rate, num_training_iters, validation_steps): - """Runs the training.""" - print('Loading data') - training_data, validation_data, test_data = dataset.load() - print('Loaded dataset with {} training, {} validation and {} test examples.'. 
- format( - len(training_data[0]), len(validation_data[0]), len(test_data[0]))) - - assert len(training_data[0]) % batch_size == 0 - assert len(validation_data[0]) % batch_size == 0 - assert len(test_data[0]) % batch_size == 0 - - def build_iterator(data, infinite=True): - """Build the iterator for inputs.""" - index = 0 - size = len(data[0]) - while True: - if index + batch_size > size: - if infinite: - index = 0 - else: - return - yield data[0][index:index + batch_size], data[1][index:index + batch_size] - index += batch_size - - train_iter = build_iterator(training_data) - model = model_lib.Model([30]) - - for i in range(num_training_iters): - train_x, train_y = next(train_iter) - model.train(train_x, train_y, learning_rate) - if (i + 1) % validation_steps == 0: - validation_iter = build_iterator(validation_data, infinite=False) - correct_predictions = 0 - for valid_x, valid_y in validation_iter: - correct_predictions += model.evaluate(valid_x, valid_y) - print('{}/{} correct validation predictions.'.format( - correct_predictions, len(validation_data[0]))) - - -def main(unused_argv): - train(FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_training_iters, - FLAGS.validation_steps) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - app.run(main) diff --git a/trax/tf_numpy/examples/mnist/train_test.py b/trax/tf_numpy/examples/mnist/train_test.py deleted file mode 100644 index 55a6a5eb4..000000000 --- a/trax/tf_numpy/examples/mnist/train_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test that the example training script works on fake data.""" -import mock -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.examples.mnist import dataset -from trax.tf_numpy.examples.mnist import train - - -class TFNumpyMnistExampleTest(tf.test.TestCase): - - def testRuns(self): - with mock.patch.object(dataset, 'load', new=fake_mnist_data): - train.train( - batch_size=1, - learning_rate=0.1, - num_training_iters=10, - validation_steps=5) - train.train( - batch_size=2, - learning_rate=0.1, - num_training_iters=5, - validation_steps=2) - train.train( - batch_size=10, - learning_rate=0.1, - num_training_iters=1, - validation_steps=1) - - -def fake_mnist_data(): - - def gen_examples(num_examples): - x = np.array( - np.random.randn(num_examples, 784), copy=False, dtype=np.float32) - y = np.zeros((num_examples, 10), dtype=np.float32) - y[:][0] = 1. 
- return (x, y) - - return (gen_examples(100), gen_examples(10), gen_examples(10)) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py index 46c4261e8..2171f9e3b 100644 --- a/trax/tf_numpy/extensions/extensions.py +++ b/trax/tf_numpy/extensions/extensions.py @@ -24,729 +24,794 @@ import numpy as np import six + import tensorflow.compat.v2 as tf import trax.tf_numpy.numpy as tf_np _int_dtype_lower_bounds = [ - -2**63, -2**31, -2**15, -2**7, 0, 2**7, 2**15, 2**31, 2**64 + -(2**63), + -(2**31), + -(2**15), + -(2**7), + 0, + 2**7, + 2**15, + 2**31, + 2**64, ] _int_dtypes = [ - tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32, - tf.uint64 + tf.int64, + tf.int32, + tf.int16, + tf.int8, + tf.uint8, + tf.uint16, + tf.uint32, + tf.uint64, ] -_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose], - 2: [tf.nn.conv2d, tf.nn.conv2d_transpose], - 3: [tf.nn.conv3d, tf.nn.conv3d_transpose]} +_tf_nn_APIs = { + 1: [tf.nn.conv1d, tf.nn.conv1d_transpose], + 2: [tf.nn.conv2d, tf.nn.conv2d_transpose], + 3: [tf.nn.conv3d, tf.nn.conv3d_transpose], +} remat = tf.recompute_grad def most_precise_int_dtype(x): - if not isinstance(x, six.integer_types) or isinstance(x, bool): - return None - i = bisect.bisect_right(_int_dtype_lower_bounds, x) - if i in (0, len(_int_dtype_lower_bounds)): - raise ValueError("Integer %s is out of bounds" % x) - assert len(_int_dtype_lower_bounds) == len(_int_dtypes) + 1 - return _int_dtypes[i - 1] + if not isinstance(x, six.integer_types) or isinstance(x, bool): + return None + i = bisect.bisect_right(_int_dtype_lower_bounds, x) + if i in (0, len(_int_dtype_lower_bounds)): + raise ValueError(f"Integer {x} is out of bounds") + assert len(_int_dtype_lower_bounds) == len(_int_dtypes) + 1 + return _int_dtypes[i - 1] def _canonicalize_jit_arg(x): - if isinstance(x, tf_np.ndarray): - return x - try: - # We need to 
convert `int` to the most precise dtype, otherwise the dtype - # of the result may be different from numpy's. For example, when a binary - # op takes in a Python integer 5 and an array of uint32, numpy will pick - # uint32 as 5's dtype, while tf.convert_to_tensor will choose int32 which - # will cause the two arguments to be promoted to int64. We pick uint8 - # here, which will be promoted to uint32 by the binary op. - # Note that we prefer unsigned int to signed int when both are equally - # precise. For example, for 5, we pick uint8 instead of int8. There is no - # reason to prefer one to the other, because for each there is a case - # where the behavior diverges from numpy. If we prefer signed int, - # consider the case where the first operand is 5 and the second is - # 2**64-1. Numpy picks uint64 as the result dtype, but because we choose a - # signed type for 5 such as int8, the result type will be float64. On the - # other hand, if we prefer unsigned int, consider the case where the first - # operand is 2**31-1 and the second is -1. Numpy will pick int32, but - # because we choose uint32 for 2*32-1, the result will be int64. The root - # of the problem is that `jit` converts `int` to tensors (hence committing - # to a dtype) too early, when we don't have enough information about the - # jitted function (e.g. which subset of the arguments should be promoted - # together using np.result_type). tf.function doesn't have this problem - # because it doesn't convert `int` to tensors. jax.jit doesn't have this - # problem because it converts `int` to "int tracer" which doesn't commit - # to a dtype. - # TODO(wangpeng): Revisit this design and see whether we can improve `jit` - # and tf.function. 
- dtype = most_precise_int_dtype(x) - if dtype is None and isinstance(x, float): - dtype = tf_np.default_float_type() - return tf.convert_to_tensor(value=x, dtype=dtype) - except (TypeError, ValueError): - return x + if isinstance(x, tf_np.ndarray): + return x + try: + # We need to convert `int` to the most precise dtype, otherwise the dtype + # of the result may be different from numpy's. For example, when a binary + # op takes in a Python integer 5 and an array of uint32, numpy will pick + # uint32 as 5's dtype, while tf.convert_to_tensor will choose int32 which + # will cause the two arguments to be promoted to int64. We pick uint8 + # here, which will be promoted to uint32 by the binary op. + # Note that we prefer unsigned int to signed int when both are equally + # precise. For example, for 5, we pick uint8 instead of int8. There is no + # reason to prefer one to the other, because for each there is a case + # where the behavior diverges from numpy. If we prefer signed int, + # consider the case where the first operand is 5 and the second is + # 2**64-1. Numpy picks uint64 as the result dtype, but because we choose a + # signed type for 5 such as int8, the result type will be float64. On the + # other hand, if we prefer unsigned int, consider the case where the first + # operand is 2**31-1 and the second is -1. Numpy will pick int32, but + # because we choose uint32 for 2**31-1, the result will be int64. The root + # of the problem is that `jit` converts `int` to tensors (hence committing + # to a dtype) too early, when we don't have enough information about the + # jitted function (e.g. which subset of the arguments should be promoted + # together using np.result_type). tf.function doesn't have this problem + # because it doesn't convert `int` to tensors. jax.jit doesn't have this + # problem because it converts `int` to "int tracer" which doesn't commit + # to a dtype.
+ # TODO(wangpeng): Revisit this design and see whether we can improve `jit` + # and tf.function. + dtype = most_precise_int_dtype(x) + if dtype is None and isinstance(x, float): + dtype = tf_np.default_float_type() + return tf.convert_to_tensor(value=x, dtype=dtype) + except (TypeError, ValueError): + return x def _canonicalize_jit_arguments(inp): - """Canonicalize arguments to be used for jit. + """Canonicalize arguments to be used for jit. - Args: - inp: a nested structure of arguments to be canonicalized (i.e. to be - converted to Tensors). Only tf_np.ndarray and things accepted by - `tf.convert_to_tensor` will be converted. + Args: + inp: a nested structure of arguments to be canonicalized (i.e. to be + converted to Tensors). Only tf_np.ndarray and things accepted by + `tf.convert_to_tensor` will be converted. - Returns: - The canonicalized version. - """ - return tf.nest.map_structure(_canonicalize_jit_arg, inp) + Returns: + The canonicalized version. + """ + return tf.nest.map_structure(_canonicalize_jit_arg, inp) def _tf_to_np(inp): + def f(x): + if type(x).__name__ == "ndarray": + data = x._data - def f(x): - if isinstance(x, tf.IndexedSlices): - return tf_np.asarray(x) - else: - return x + if isinstance(data, tf.IndexedSlices): + data = tf.convert_to_tensor(data) + return tf_np.asarray(data) - return tf.nest.map_structure(f, inp) + if isinstance(x, tf.IndexedSlices): + return tf_np.asarray(x) + else: + return x + return tf.nest.map_structure(f, inp) -def stop_gradient(x): - def static_stop_gradient(x): - # `tf.stop_gradient` is a no-op for non-Tensor. Returning the original type - # allows it to be used in the conditional without Autograph, if static. For - # example: - # `if fastmath.stop_gradient(5) > 4:` - return tf.stop_gradient(x) if tf.is_tensor(x) else x +def stop_gradient(x): + def static_stop_gradient(x): + # `tf.stop_gradient` is a no-op for non-Tensor. 
Returning the original type + # allows it to be used in the conditional without Autograph, if static. For + # example: + # `if fastmath.stop_gradient(5) > 4:` + return tf.stop_gradient(x) if tf.is_tensor(x) else x - return _tf_to_np(tf.nest.map_structure(static_stop_gradient, x)) + return _tf_to_np(tf.nest.map_structure(static_stop_gradient, x)) def custom_grad(f_vjp, f_original=None): - """Decorator to define a function with a custom gradient. + """Decorator to define a function with a custom gradient. - This function is very similar to `tf.custom_gradient`. See the documentation - of `tf.custom_gradient` for detailed usage. + This function is very similar to `tf.custom_gradient`. See the documentation + of `tf.custom_gradient` for detailed usage. - The differences with `tf.custom_gradient` are: + The differences with `tf.custom_gradient` are: - - All arguments and results are tf_np.ndarrays instead of tensors. + - All arguments and results are tf_np.ndarrays instead of tensors. - - The `grad_fn` returned by `f_vjp` accepts and returns nested structures, - unlike that in `tf.custom_gradient` which only accepts and returns lists. + - The `grad_fn` returned by `f_vjp` accepts and returns nested structures, + unlike that in `tf.custom_gradient` which only accepts and returns lists. - Args: - f_vjp: the same as the `f` argument of `tf.custom_gradient`. Note that all - inputs and outputs of `f_vjp` and of the `grad_fn` function it returns can - be nested structures. - f_original: (optional) not used. + Args: + f_vjp: the same as the `f` argument of `tf.custom_gradient`. Note that all + inputs and outputs of `f_vjp` and of the `grad_fn` function it returns can + be nested structures. + f_original: (optional) not used. - Returns: - The same as `tf.custom_gradient`. - """ - del f_original + Returns: + The same as `tf.custom_gradient`. 
+ """ + del f_original - @tf.custom_gradient - def tf_f(*tf_args, **tf_kwargs): - np_args = _tf_to_np(tf_args) - np_kwargs = _tf_to_np(tf_kwargs) - np_y, np_vjp = f_vjp(*np_args, **np_kwargs) - tf_y = np_y + @tf.custom_gradient + def tf_f(*tf_args, **tf_kwargs): + np_args = _tf_to_np(tf_args) + np_kwargs = _tf_to_np(tf_kwargs) + np_y, np_vjp = f_vjp(*np_args, **np_kwargs) + tf_y = np_y - def tf_vjp(*flat_tf_dy): - tf_dy = tf.nest.pack_sequence_as(tf_y, flat_tf_dy) - np_dy = _tf_to_np(tf_dy) - np_dx = np_vjp(np_dy) - return tf.nest.flatten(np_dx) + def tf_vjp(*flat_tf_dy): + tf_dy = tf.nest.pack_sequence_as(tf_y, flat_tf_dy) + np_dy = _tf_to_np(tf_dy) + np_dx = np_vjp(np_dy) + return tf.nest.flatten(np_dx) - return tf_y, tf_vjp + return tf_y, tf_vjp - def np_f(*args, **kwargs): - return _tf_to_np(tf_f(*args), **kwargs) + def np_f(*args, **kwargs): + return _tf_to_np(tf_f(*args), **kwargs) - return np_f + return np_f def vjp(f, *primals, has_aux=False): - """Returns the result and the VJP function of `f`. - - This function returns the result and the vector-Jacobian-product (VJP) - function of `f`. - - Args: - f: a function from (nested structures of) tf_np.ndarrays to a (nested - structure of) tf_np.ndarray. If `has_aux` is True, it should return an - extra output. - *primals: the inputs to be fed to `f`. - has_aux: if True, the second output of `f` will be regarded as an auxiliary, - non-differentiable output that will be ignored by the VJP function. - - Returns: - A pair `(y, vjpfun)` if `has_aux` is False; a tuple `(y, vjpfun, aux)` - otherwise. `y` and `aux` are the outputs of `f`, i.e. `y, aux = - f(*primals)`. `vjpfun` is a function `dx = vjpfun(dy)`, where `dy` is the - cotengents of `y`, having the same structures, shapes and dtypes as - `y`. `dx` is the cotengents of `x`, having the same structures, shapes and - dtypes as `x`. 
- """ - with tf.GradientTape(persistent=True) as tape: - tape.watch(tf.nest.flatten(primals)) - outputs = f(*primals) - if has_aux: - np_out, aux = outputs - else: - np_out = outputs + """Returns the result and the VJP function of `f`. - def _vjp(dy): - tf_dx = tape.gradient(np_out, primals, output_gradients=dy) - return _tf_to_np(tf_dx) + This function returns the result and the vector-Jacobian-product (VJP) + function of `f`. - if has_aux: - ret = (np_out, _vjp, aux) - else: - ret = (np_out, _vjp) - return ret + Args: + f: a function from (nested structures of) tf_np.ndarrays to a (nested + structure of) tf_np.ndarray. If `has_aux` is True, it should return an + extra output. + *primals: the inputs to be fed to `f`. + has_aux: if True, the second output of `f` will be regarded as an auxiliary, + non-differentiable output that will be ignored by the VJP function. + + Returns: + A pair `(y, vjpfun)` if `has_aux` is False; a tuple `(y, vjpfun, aux)` + otherwise. `y` and `aux` are the outputs of `f`, i.e. `y, aux = + f(*primals)`. `vjpfun` is a function `dx = vjpfun(dy)`, where `dy` is the + cotengents of `y`, having the same structures, shapes and dtypes as + `y`. `dx` is the cotengents of `x`, having the same structures, shapes and + dtypes as `x`. + """ + with tf.GradientTape(persistent=True) as tape: + tape.watch(tf.nest.flatten(primals)) + outputs = f(*primals) + if has_aux: + np_out, aux = outputs + else: + np_out = outputs + + def _vjp(dy): + tf_dx = tape.gradient(np_out, primals, output_gradients=dy) + return _tf_to_np(tf_dx) + + if has_aux: + ret = (np_out, _vjp, aux) + else: + ret = (np_out, _vjp) + return ret # TODO(wangpeng): match JAX's handling of kwargs and non-ndarray args def grad(f, has_aux=False): - """Returns a function that computes gradient of f. - - Gradients can only be computed through numpy and tensorflow operations and not - through python float operations and values. - - Args: - f: a function of type (params, *args) -> scalar. 
'params' can be a nested - structure (made of lists and tuples) of ndarrays and the gradient is - evaluated against it. `scalar` is a scalar ndarray. - has_aux: bool, indicates whether fun returns a pair where the first element - is considered the output of the mathematical function to be differentiated - and the second element is auxiliary data. - - Returns: - A gradient function of type (params, *args) -> gradients, where the result - 'gradients' has the same structure and shapes as 'params'. - """ - - def check_loss_shape(np_loss): - if not isinstance(np_loss, tf_np.ndarray): - raise ValueError( - "The result of the function to take gradient must be an ndarray.") - if not np_loss.shape.is_compatible_with([]): - raise ValueError( - "The result of the function to take gradient must be a scalar.") - - def _f(params, *args): - """The gradient function to be returned.""" - with tf.GradientTape() as g: - g.watch(tf.nest.flatten(params)) - outputs = f(params, *args) - if has_aux: - np_loss, aux = outputs - else: - np_loss = outputs - check_loss_shape(np_loss) - tf_grads = g.gradient(np_loss, params) - if has_aux: - res = (tf_grads, aux) - else: - res = tf_grads - return _tf_to_np(res) - - return _f + """Returns a function that computes gradient of f. + + Gradients can only be computed through numpy and tensorflow operations and not + through python float operations and values. + + Args: + f: a function of type (params, *args) -> scalar. 'params' can be a nested + structure (made of lists and tuples) of ndarrays and the gradient is + evaluated against it. `scalar` is a scalar ndarray. + has_aux: bool, indicates whether fun returns a pair where the first element + is considered the output of the mathematical function to be differentiated + and the second element is auxiliary data. + + Returns: + A gradient function of type (params, *args) -> gradients, where the result + 'gradients' has the same structure and shapes as 'params'. 
+ """ + + def check_loss_shape(np_loss): + if not isinstance(np_loss, tf_np.ndarray): + raise ValueError( + "The result of the function to take gradient must be an ndarray." + ) + # TensorFlow 1.x has + # TensorFlow 2.x does not contain such method as is_compatible_with. + # We can change it to compare the shape with () + if not np_loss.shape == (): + raise ValueError( + "The result of the function to take gradient must be a scalar." + ) + + def _f(params, *args): + """The gradient function to be returned.""" + with tf.GradientTape() as g: + g.watch(tf.nest.flatten(params)) + outputs = f(params, *args) + + if has_aux: + np_loss, aux = outputs + else: + np_loss = outputs + + check_loss_shape(np_loss) + + tf_grads = g.gradient(np_loss, params) + tf_grads = _tf_to_np(tf_grads) + + if has_aux: + res = (tf_grads, aux) + else: + res = tf_grads + return _tf_to_np(res) + + return _f def _record_result_type(recorder, f): - """A decorator that records some information about the function. - - Args: - recorder: a function of signature `(args, kwargs, res) -> res`. - f: the original function. - - Returns: - A transformed function that calls the original function and then the - recorder afterwards. - """ - def wrapper(*args, **kwargs): - res = f(*args, **kwargs) - res = recorder(args, kwargs, res) - return res - - return wrapper - - -def jit(f, - static_argnums=(), - xla_forced_compile=False, - input_signature=None, - autograph=False, - experimental_compile=False): - """Returns a function that runs a trace-compiled version of `f`. - - A trace-compiled version of a function `f` has the same behavior as `f` (when - called with the same "static arguments", see below), but runs faster because - the whole computation is compiled into a computation graph once which is - reused for subsequent executions. - - The trace compilation happens lazily, when the returned function is called for - the first time. 
The compiled function may not be cached implicitly and - multiple calls to `jit` may not share the compiled function (see below for - "static" vs "dynamic" arguments). - - Args: - f: a function that takes any positional arguments `args` and any keyword - arguments `kwargs`. `ndarray`s and things accepted by - `tf.convert_to_tensor` in `args` and `kwargs` will be treated as 'dynamic - arguments' in the sense that calling the function with different values - for these arguments will not cause retracing. In contrast, arguments of - other types in `args` and `kwargs` are treated as 'static arguments' and - calling the function with different values of them will cause - re-compiling. Positional arguments whose positions are in `static_argnums` - are always treated as static arguments. - static_argnums: a tuple of positions of arguments that will be treated as - static arguments. Note that as aforementioned, any arguments that were not - convertible to tensor will also be static. - xla_forced_compile: if true, it will use XLA to force-compile the graph. - This requires that the function only contain ops that are XLA - compatible. It will compile the entire function into a single XLA op. - input_signature: a list of `tf.TensorSpec`, as the input signature to - control tracing behavior. See the - [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. - autograph: whether to use autograph to convert Python constructs such as - `if` and `while` to their TensorFlow counterparts. See the - [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. - experimental_compile: the `experimental_compile` flag for `tf.function`. See - the [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. This is the recommended way to turn on XLA for - tf.function, but unlike xla_forced_compile, it doesn't force-compile the - entire function into a single XLA op. 
- - Returns: - A trace-compiled version of f. - """ - - @tf.function(input_signature=input_signature, autograph=autograph, - experimental_compile=experimental_compile) - def _tf_f(*args, **kwargs): - """Accelerated function with tensor inputs/outputs.""" - np_args = _tf_to_np(args) - kwargs = {k: _tf_to_np(v) for k, v in kwargs.items()} - if xla_forced_compile: - # Use list for mutability - output_is_list = [False] - output_is_empty = [False] - output_structure = [None] - def recorder(args, kwargs, res): - del args, kwargs - # Workaround b/121383831 - output_is_list[0] = isinstance(res, list) - # If outputs are empty, xla.compile returns an `Operation`, which we - # don't want. - if tf.nest.flatten(res): - output_is_empty[0] = False - output_structure[0] = None - else: - output_is_empty[0] = True - # Without deepcopy, xla.compile will change output_structure[0] to a - # list of `Operation`. - output_structure[0] = copy.deepcopy(res) + """A decorator that records some information about the function. + + Args: + recorder: a function of signature `(args, kwargs, res) -> res`. + f: the original function. + + Returns: + A transformed function that calls the original function and then the + recorder afterwards. 
+ """ + + def wrapper(*args, **kwargs): + res = f(*args, **kwargs) + res = recorder(args, kwargs, res) return res - f_ = _record_result_type(recorder, f) - np_out = tf.xla.experimental.compile(lambda: f_(*np_args, **kwargs)) - # Workaround b/121383831 - if output_is_empty[0]: - np_out = output_structure[0] - elif (isinstance(np_out, list) and len(np_out) == 1 and - not output_is_list[0]): - np_out = np_out[0] - else: - np_out = f(*np_args, **kwargs) - return np_out - def _f(*args, **kwargs): - args = [ - _canonicalize_jit_arguments(arg) if i not in static_argnums else arg - for i, arg in enumerate(args) - ] - kwargs = {k: _canonicalize_jit_arguments(v) for k, v in kwargs.items()} - tf_out = _tf_f(*args, **kwargs) - return _tf_to_np(tf_out) + return wrapper + - _f.tf_function = _tf_f +def jit( + f, + static_argnums=(), + xla_forced_compile=False, + input_signature=None, + autograph=False, + experimental_compile=False, +): + """Returns a function that runs a trace-compiled version of `f`. - return _f + A trace-compiled version of a function `f` has the same behavior as `f` (when + called with the same "static arguments", see below), but runs faster because + the whole computation is compiled into a computation graph once which is + reused for subsequent executions. + + The trace compilation happens lazily, when the returned function is called for + the first time. The compiled function may not be cached implicitly and + multiple calls to `jit` may not share the compiled function (see below for + "static" vs "dynamic" arguments). + + Args: + f: a function that takes any positional arguments `args` and any keyword + arguments `kwargs`. `ndarray`s and things accepted by + `tf.convert_to_tensor` in `args` and `kwargs` will be treated as 'dynamic + arguments' in the sense that calling the function with different values + for these arguments will not cause retracing. 
In contrast, arguments of + other types in `args` and `kwargs` are treated as 'static arguments' and + calling the function with different values of them will cause + re-compiling. Positional arguments whose positions are in `static_argnums` + are always treated as static arguments. + static_argnums: a tuple of positions of arguments that will be treated as + static arguments. Note that as aforementioned, any arguments that were not + convertible to tensor will also be static. + xla_forced_compile: if true, it will use XLA to force-compile the graph. + This requires that the function only contain ops that are XLA + compatible. It will compile the entire function into a single XLA op. + input_signature: a list of `tf.TensorSpec`, as the input signature to + control tracing behavior. See the + [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of + `tf.function` for details. + autograph: whether to use autograph to convert Python constructs such as + `if` and `while` to their TensorFlow counterparts. See the + [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of + `tf.function` for details. + experimental_compile: the `experimental_compile` flag for `tf.function`. See + the [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of + `tf.function` for details. This is the recommended way to turn on XLA for + tf.function, but unlike xla_forced_compile, it doesn't force-compile the + entire function into a single XLA op. + + Returns: + A trace-compiled version of f. 
+ """ + + @tf.function( + input_signature=input_signature, + autograph=autograph, + experimental_compile=experimental_compile, + ) + def _tf_f(*args, **kwargs): + """Accelerated function with tensor inputs/outputs.""" + np_args = _tf_to_np(args) + kwargs = {k: _tf_to_np(v) for k, v in kwargs.items()} + if xla_forced_compile: + # Use list for mutability + output_is_list = [False] + output_is_empty = [False] + output_structure = [None] + + def recorder(args, kwargs, res): + del args, kwargs + # Workaround b/121383831 + output_is_list[0] = isinstance(res, list) + # If outputs are empty, xla.compile returns an `Operation`, which we + # don't want. + if tf.nest.flatten(res): + output_is_empty[0] = False + output_structure[0] = None + else: + output_is_empty[0] = True + # Without deepcopy, xla.compile will change output_structure[0] to a + # list of `Operation`. + output_structure[0] = copy.deepcopy(res) + return res + + f_ = _record_result_type(recorder, f) + np_out = tf.xla.experimental.compile(lambda: f_(*np_args, **kwargs)) + # Workaround b/121383831 + if output_is_empty[0]: + np_out = output_structure[0] + elif ( + isinstance(np_out, list) and len(np_out) == 1 and not output_is_list[0] + ): + np_out = np_out[0] + else: + np_out = f(*np_args, **kwargs) + return np_out + + def _f(*args, **kwargs): + args = [ + _canonicalize_jit_arguments(arg) if i not in static_argnums else arg + for i, arg in enumerate(args) + ] + kwargs = {k: _canonicalize_jit_arguments(v) for k, v in kwargs.items()} + tf_out = _tf_f(*args, **kwargs) + return _tf_to_np(tf_out) + + _f.tf_function = _tf_f + + return _f def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False): - """Returns a function that evaluates `f` given input shapes and dtypes. - - It transforms function `f` to a function that performs the same computation as - `f` but only on shapes and dtypes (a.k.a. shape inference). - - Args: - f: the function to be transformed. - static_argnums: see documentation of `jit`. 
- allow_static_outputs: whether to allow non-array outputs. If True, non-array - outputs (e.g. Python integers) will be returned as-is; otherwise, they - will be converted to ndarrays, and then specs of those ndarrays will be - returned. - - Returns: - A function whose input arguments can be either the same as `f`'s or only - their shapes/dtypes represented by `tf.TensorSpec`, and whose return values - are `tf.TensorSpec`s with the same nested structure as `f`'s return - values. If `allow_static_outputs` is True, when `f` returns some non-array - outputs (e.g. Python integers), the converted function will return them - as-is instead of returning `tf.TensorSpec`s for them. - """ - def abstractify(args): - def _abstractify(x): - x = _canonicalize_jit_arg(x) - if isinstance(x, (tf.Tensor, tf_np.ndarray)): - return tf.TensorSpec(x.shape, x.dtype) - else: - return x - new_args = [] - for i, arg in enumerate(args): - if i in static_argnums: - new_args.append(arg) - else: - new_args.append(tf.nest.map_structure(_abstractify, arg)) - return new_args - - if allow_static_outputs: - # When `tf_f` below is called (via get_concrete_function) with the same - # arugments (after abstraction), the Python function `f` won't be run, so we - # need this python_outputs_map to retrieve the Python outputs we've seen - # before that correspond the arguments. - python_outputs_map = {} - def recorder(args, kwargs, res): - # Since the get_concrete_function below only uses positional args, we also - # only positional args here. - del args, kwargs - def is_tensor_like(x): - if hasattr(x, "_type_spec"): - return True # x is a CompositeTensor - return isinstance(x, (tf_np.ndarray, tf.Tensor)) - py_values = tf.nest.map_structure( - lambda x: None if is_tensor_like(x) else x, - res) - key = id(tf.compat.v1.get_default_graph()) - python_outputs_map[key] = py_values - # Set non-tensor outputs to None to avoid tf.function calling - # tf.convert_to_tensor on them. 
- res = tf.nest.map_structure( - lambda x: None if not is_tensor_like(x) else x, - res) - return res - f = _record_result_type(recorder, f) - - # TODO(wangpeng): tf.function could add a knob to turn off materializing the - # graph, so that we don't waste computation and memory when we just want - # shape inference. - tf_f = jit(f, static_argnums=static_argnums).tf_function - - # pylint: disable=missing-docstring - def f_return(*args): - def to_tensor_spec(x): - if isinstance(x, tf.Tensor): - return tf.TensorSpec(x.shape, x.dtype) - else: - return x + """Returns a function that evaluates `f` given input shapes and dtypes. - new_args = abstractify(args) - cfun = tf_f.get_concrete_function(*new_args) - res = cfun.structured_outputs - res = tf.nest.map_structure(to_tensor_spec, res) + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). + + Args: + f: the function to be transformed. + static_argnums: see documentation of `jit`. + allow_static_outputs: whether to allow non-array outputs. If True, non-array + outputs (e.g. Python integers) will be returned as-is; otherwise, they + will be converted to ndarrays, and then specs of those ndarrays will be + returned. + + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `tf.TensorSpec`, and whose return values + are `tf.TensorSpec`s with the same nested structure as `f`'s return + values. If `allow_static_outputs` is True, when `f` returns some non-array + outputs (e.g. Python integers), the converted function will return them + as-is instead of returning `tf.TensorSpec`s for them. 
+ """ + + def abstractify(args): + def _abstractify(x): + x = _canonicalize_jit_arg(x) + if isinstance(x, (tf.Tensor, tf_np.ndarray)): + return tf.TensorSpec(x.shape, x.dtype) + else: + return x + + new_args = [] + for i, arg in enumerate(args): + if i in static_argnums: + new_args.append(arg) + else: + new_args.append(tf.nest.map_structure(_abstractify, arg)) + return new_args if allow_static_outputs: - key = id(cfun.graph) - py_values = python_outputs_map[key] - # We can also call tf.get_static_value on structured_outputs to retrieve - # the Python values, but since we'll need to use python_outputs_map to - # record "which outputs are static?" anyway, we choose to directly store - # the Python values in python_outputs_map. - res = tf.nest.map_structure( - lambda x, python_value: x if python_value is None else python_value, - res, py_values) + # When `tf_f` below is called (via get_concrete_function) with the same + # arugments (after abstraction), the Python function `f` won't be run, so we + # need this python_outputs_map to retrieve the Python outputs we've seen + # before that correspond the arguments. + python_outputs_map = {} + + def recorder(args, kwargs, res): + # Since the get_concrete_function below only uses positional args, we also + # only positional args here. + del args, kwargs + + def is_tensor_like(x): + if hasattr(x, "_type_spec"): + return True # x is a CompositeTensor + return isinstance(x, (tf_np.ndarray, tf.Tensor)) + + py_values = tf.nest.map_structure( + lambda x: None if is_tensor_like(x) else x, res + ) + key = id(tf.compat.v1.get_default_graph()) + python_outputs_map[key] = py_values + # Set non-tensor outputs to None to avoid tf.function calling + # tf.convert_to_tensor on them. 
+ res = tf.nest.map_structure( + lambda x: None if not is_tensor_like(x) else x, res + ) + return res + + f = _record_result_type(recorder, f) + + # TODO(wangpeng): tf.function could add a knob to turn off materializing the + # graph, so that we don't waste computation and memory when we just want + # shape inference. + tf_f = jit(f, static_argnums=static_argnums).tf_function + + # pylint: disable=missing-docstring + def f_return(*args): + def to_tensor_spec(x): + if isinstance(x, tf.Tensor): + return tf.TensorSpec(x.shape, x.dtype) + else: + return x + + new_args = abstractify(args) + cfun = tf_f.get_concrete_function(*new_args) + res = cfun.structured_outputs + res = tf.nest.map_structure(to_tensor_spec, res) + + if allow_static_outputs: + key = id(cfun.graph) + py_values = python_outputs_map[key] + # We can also call tf.get_static_value on structured_outputs to retrieve + # the Python values, but since we'll need to use python_outputs_map to + # record "which outputs are static?" anyway, we choose to directly store + # the Python values in python_outputs_map. + res = tf.nest.map_structure( + lambda x, python_value: x if python_value is None else python_value, + res, + py_values, + ) - return res + return res - # Provides access to `tf_f` for testing purpose. - f_return._tf_function = tf_f # pylint: disable=protected-access - return f_return + # Provides access to `tf_f` for testing purpose. + f_return._tf_function = tf_f # pylint: disable=protected-access + return f_return def _index_update_helper(updater, x, idx, y): - x = tf_np.asarray(x) - y = tf_np.asarray(y) - # TODO(b/164251540): Remove this expensive manual broadcasting once - # tf.raw_ops.tensor_strided_slice_update and tf.tensor_scatter_nd_update - # support broadcasting. 
- y = tf.broadcast_to(y, tf.shape(x[idx])) - return updater(x, idx, y) + x = tf_np.asarray(x) + y = tf_np.asarray(y) + # TODO(b/164251540): Remove this expensive manual broadcasting once + # tf.raw_ops.tensor_strided_slice_update and tf.tensor_scatter_nd_update + # support broadcasting. + y = tf.broadcast_to(y, tf.shape(x[idx])) + return updater(x, idx, y) # pylint: disable=protected-access def index_update(x, idx, y): - """Pure equivalent of `x[idx] = y`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = y`. Because it's a pure function, `x` itself won't be - changed. + """Pure equivalent of `x[idx] = y`. - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = y`. Because it's a pure function, `x` itself won't be + changed. - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_update, x, idx, y) + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_update, x, idx, y) def index_add(x, idx, y): - """Pure equivalent of `x[idx] += y`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] += y`. Because it's a pure function, `x` itself won't be - changed. + """Pure equivalent of `x[idx] += y`. 
- Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] += y`. Because it's a pure function, `x` itself won't be + changed. - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_add, x, idx, y) + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_add, x, idx, y) def index_min(x, idx, y): - """Pure equivalent of `x[idx] = minimum(x[idx], y)`. + """Pure equivalent of `x[idx] = minimum(x[idx], y)`. - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = minimum(x[idx], y)`. Because it's a pure function, `x` - itself won't be changed. + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = minimum(x[idx], y)`. Because it's a pure function, `x` + itself won't be changed. - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. 
- """ - return _index_update_helper(tf_np.ndarray._with_index_min, x, idx, y) + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_min, x, idx, y) def index_max(x, idx, y): - """Pure equivalent of `x[idx] = maximum(x[idx], y)`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = maximum(x[idx], y)`. Because it's a pure function, `x` - itself won't be changed. - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_max, x, idx, y) + """Pure equivalent of `x[idx] = maximum(x[idx], y)`. + + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = maximum(x[idx], y)`. Because it's a pure function, `x` + itself won't be changed. + + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. 
+ """ + return _index_update_helper(tf_np.ndarray._with_index_max, x, idx, y) + + # pylint: enable=protected-access def logsumexp(x, axis=None, keepdims=None): - """Computes log(sum(exp(elements across dimensions of a tensor))). - - Reduces `x` along the dimensions given in `axis`. - Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a - tensor with a single element is returned. - This function is more numerically stable than log(sum(exp(input))). It avoids - overflows caused by taking the exp of large inputs and underflows caused by - taking the log of small inputs. - - Args: - x: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all - dimensions. Must be in the range `[-rank(x), rank(x))`. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - The reduced tensor. - """ - return tf_np.asarray( - tf.math.reduce_logsumexp( - input_tensor=x, axis=axis, keepdims=keepdims)) + """Computes log(sum(exp(elements across dimensions of a tensor))). + + Reduces `x` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + If `axis` has no entries, all dimensions are reduced, and a + tensor with a single element is returned. + This function is more numerically stable than log(sum(exp(input))). It avoids + overflows caused by taking the exp of large inputs and underflows caused by + taking the log of small inputs. + + Args: + x: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(x), rank(x))`. + keepdims: If true, retains reduced dimensions with length 1. 
+ + Returns: + The reduced tensor. + """ + return tf_np.asarray( + tf.math.reduce_logsumexp(input_tensor=x, axis=axis, keepdims=keepdims) + ) def expit(x): - """Compute 1 / (1 + exp(-x)).""" - return tf_np.asarray(tf.math.sigmoid(x)) + """Compute 1 / (1 + exp(-x)).""" + return tf_np.asarray(tf.math.sigmoid(x)) def erf(x): - """Computes the Gauss error function of x element-wise.""" - return tf_np.asarray(tf.math.erf(x)) + """Computes the Gauss error function of x element-wise.""" + return tf_np.asarray(tf.math.erf(x)) def _minus(a, b): - return [x for x in a if x not in b] - + return [x for x in a if x not in b] -def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, - lhs_batch, rhs_batch): - """Compose the output string representation. - e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik - aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik +def _compose_output_rep( + lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, lhs_batch, rhs_batch +): + """Compose the output string representation. - Args: - lhs_rep: A string representation for the left-hand side input array - rhs_rep: A string representation for the right-hand side input array - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - lhs_batch: Sequence[int] (the batch dimensions of lhs) - rhs_batch: Sequence[int] (the batch dimensions of rhs) + e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik + aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik - Returns: - A string representation of the result array. 
- """ - output_rep = [] - for dim in lhs_batch: - output_rep.append(lhs_rep[dim]) + Args: + lhs_rep: A string representation for the left-hand side input array + rhs_rep: A string representation for the right-hand side input array + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + lhs_batch: Sequence[int] (the batch dimensions of lhs) + rhs_batch: Sequence[int] (the batch dimensions of rhs) + + Returns: + A string representation of the result array. + """ + output_rep = [] + for dim in lhs_batch: + output_rep.append(lhs_rep[dim]) - for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): - output_rep.append(lhs_rep[i]) - for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): - output_rep.append(rhs_rep[i]) - return "".join(output_rep) + for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): + output_rep.append(lhs_rep[i]) + for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): + output_rep.append(rhs_rep[i]) + return "".join(output_rep) def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): - """Compute the non-batched matrix multiplication. + """Compute the non-batched matrix multiplication. - If it is the general non-batched/single-batched matrix multiplication, - use the highly optimized kernel `tf.tensordot` to handle it. + If it is the general non-batched/single-batched matrix multiplication, + use the highly optimized kernel `tf.tensordot` to handle it. 
- Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - Returns: - An array that contains the result. - """ - return tf.tensordot( - lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction))) + Returns: + An array that contains the result. + """ + return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction))) def tf_dot_general(lhs, rhs, dimension_numbers): - """The general dot operation for TensorFlow. - - An equivalent general dot operation as that in JAX - - - Although there is an implementation in TF XLA, avoid directly using XLA when - possible. - - e.g., non-batched: ij,jk->ik - batched: ijk,ikl->ijl - - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], - Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form - ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, - rhs_batch_dims)) - - Returns: - An array that contains the result. 
- """ - char_list = list(string.ascii_lowercase) - char_list = char_list[8:] + char_list[:8] - lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) - lhs_rep = char_list[:lhs_rank] - rhs_rep = char_list[lhs_rank:lhs_rank + rhs_rank] - contraction, batch = dimension_numbers - lhs_contraction, rhs_contraction = contraction - if len(lhs_contraction) != len(rhs_contraction): - raise ValueError( - "The input matrices are required to have the same number " - "of contraction dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_contraction), len(rhs_contraction))) - lhs_batch, rhs_batch = batch - if len(lhs_batch) != len(rhs_batch): - raise ValueError("The input matrices are required to have the same number " - "of batch dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_batch), len(rhs_batch))) - - if not lhs_batch and not rhs_batch: - return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) - - if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) and - lhs_contraction == (2,) and rhs_contraction == (1,)): - return tf.linalg.matmul(lhs, rhs) - - for i in range(len(lhs_contraction)): - rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] - for i in range(len(lhs_batch)): - rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] - - output_rep = _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, - rhs_contraction, lhs_batch, rhs_batch) - equation = "".join(lhs_rep) + "," + "".join(rhs_rep) + "->" + output_rep - return tf.einsum(equation, lhs, rhs) - - -def _conv_general_param_type_converter(window_strides, lhs_dilation, - rhs_dilation, dim): - """Convert strides, lhs_dilation, rhs_dilation to match TF convention. 
- - For example, - in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] - if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] - - Args: - window_strides: window_strides to be converted - lhs_dilation: lhs_dilation to be converted - rhs_dilation: rhs_dilation to be converted - dim: dim to be converted - - Returns: - The updated window_strides, lhs_dilation and rhs_dilation - """ - def _as_list_of_size(item, size): - if item is None: - return None - return [item] * size if isinstance(item, int) else list(item) - return (_as_list_of_size(window_strides, dim), - _as_list_of_size(lhs_dilation, dim), - _as_list_of_size(rhs_dilation, dim)) + """The general dot operation for TensorFlow. + + An equivalent general dot operation as that in JAX - + + Although there is an implementation in TF XLA, avoid directly using XLA when + possible. + + e.g., non-batched: ij,jk->ik + batched: ijk,ikl->ijl + + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], + Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form + ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims)) + + Returns: + An array that contains the result. 
+ """ + char_list = list(string.ascii_lowercase) + char_list = char_list[8:] + char_list[:8] + lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) + lhs_rep = char_list[:lhs_rank] + rhs_rep = char_list[lhs_rank : lhs_rank + rhs_rank] + contraction, batch = dimension_numbers + lhs_contraction, rhs_contraction = contraction + if len(lhs_contraction) != len(rhs_contraction): + raise ValueError( + "The input matrices are required to have the same number " + "of contraction dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_contraction), len(rhs_contraction) + ) + ) + lhs_batch, rhs_batch = batch + if len(lhs_batch) != len(rhs_batch): + raise ValueError( + "The input matrices are required to have the same number " + "of batch dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_batch), len(rhs_batch) + ) + ) + + if not lhs_batch and not rhs_batch: + return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) + + if ( + lhs_rank == rhs_rank == 3 + and lhs_batch == (0,) + and rhs_batch == (0,) + and lhs_contraction == (2,) + and rhs_contraction == (1,) + ): + return tf.linalg.matmul(lhs, rhs) + + for i in range(len(lhs_contraction)): + rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] + for i in range(len(lhs_batch)): + rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] + + output_rep = _compose_output_rep( + lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, lhs_batch, rhs_batch + ) + equation = "".join(lhs_rep) + "," + "".join(rhs_rep) + "->" + output_rep + return tf.einsum(equation, lhs, rhs) + + +def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation, dim): + """Convert strides, lhs_dilation, rhs_dilation to match TF convention. 
+ + For example, + in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] + if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] + + Args: + window_strides: window_strides to be converted + lhs_dilation: lhs_dilation to be converted + rhs_dilation: rhs_dilation to be converted + dim: dim to be converted + + Returns: + The updated window_strides, lhs_dilation and rhs_dilation + """ + + def _as_list_of_size(item, size): + if item is None: + return None + return [item] * size if isinstance(item, int) else list(item) + + return ( + _as_list_of_size(window_strides, dim), + _as_list_of_size(lhs_dilation, dim), + _as_list_of_size(rhs_dilation, dim), + ) # pylint: disable=g-bad-todo @@ -755,518 +820,588 @@ def _as_list_of_size(item, size): # TODO(DarrenZhang01): Support feature_group_count, batch_group_count and # precision, and allow lhs_dilation and rhs_dilation to happen at the same time. # pylint: enable=g-bad-todo -def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, - lhs_dilation=None, rhs_dilation=None, - dimension_numbers=None, feature_group_count=1, - batch_group_count=1, precision=None): - """A general conv API for TensorFlow. - - According JAX version: - https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html - - Args: - lhs: a rank n+2 dimensional input array. - rhs: a rank n+2 dimensional array of kernel weights. - window_strides: a sequence of n integers, representing the inter-window - strides. - padding: either the string β€˜SAME’, the string β€˜VALID’, or a sequence of n - (low, high) integer pairs that give the padding to apply before and - after each spatial dimension. - output_shape: the output shape of the convolution (only required for - transpose convolution). - lhs_dilation: None, or a sequence of n integers, giving the dilation factor - to apply in each spatial dimension of lhs. LHS dilation is - also known as transposed convolution. 
- rhs_dilation: None, or a sequence of n integers, giving the dilation factor - to apply in each spatial dimension of rhs. RHS dilation is - also known as atrous convolution. - dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple - (lhs_spec, rhs_spec, out_spec), where each element is a - string of length n+2. - feature_group_count: integer, default 1. Changing this is currently not - supported. - batch_group_count: integer, default 1. Changing this is currently not - supported. - precision: Optional. Either None, which means the default precision for the - backend, or a Precision enum value. - - Returns: - A TF NumPy array that contains the convolution result. - """ - dim = None - lhs_spec, rhs_spec, out_spec = dimension_numbers - if lhs_spec != out_spec: - raise ValueError("Current implementation requires the `data_format` of the " - "inputs and outputs to be the same.") - if len(lhs_spec) >= 6: - raise ValueError("Current implmentation does not support 4 or higher" - "dimensional convolution, but got: ", len(lhs_spec) - 2) - dim = len(lhs_spec) - 2 - if lhs_dilation and rhs_dilation: - if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: - lhs_dilation, rhs_dilation = None, None +def tf_conv_general_dilated( + lhs, + rhs, + window_strides, + padding, + output_shape, + lhs_dilation=None, + rhs_dilation=None, + dimension_numbers=None, + feature_group_count=1, + batch_group_count=1, + precision=None, +): + """A general conv API for TensorFlow. + + According JAX version: + https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html + + Args: + lhs: a rank n+2 dimensional input array. + rhs: a rank n+2 dimensional array of kernel weights. + window_strides: a sequence of n integers, representing the inter-window + strides. + padding: either the string β€˜SAME’, the string β€˜VALID’, or a sequence of n + (low, high) integer pairs that give the padding to apply before and + after each spatial dimension. 
+ output_shape: the output shape of the convolution (only required for + transpose convolution). + lhs_dilation: None, or a sequence of n integers, giving the dilation factor + to apply in each spatial dimension of lhs. LHS dilation is + also known as transposed convolution. + rhs_dilation: None, or a sequence of n integers, giving the dilation factor + to apply in each spatial dimension of rhs. RHS dilation is + also known as atrous convolution. + dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple + (lhs_spec, rhs_spec, out_spec), where each element is a + string of length n+2. + feature_group_count: integer, default 1. Changing this is currently not + supported. + batch_group_count: integer, default 1. Changing this is currently not + supported. + precision: Optional. Either None, which means the default precision for the + backend, or a Precision enum value. + + Returns: + A TF NumPy array that contains the convolution result. + """ + dim = None + lhs_spec, rhs_spec, out_spec = dimension_numbers + if lhs_spec != out_spec: + raise ValueError( + "Current implementation requires the `data_format` of the " + "inputs and outputs to be the same." 
+ ) + if len(lhs_spec) >= 6: + raise ValueError( + "Current implmentation does not support 4 or higher" + "dimensional convolution, but got: ", + len(lhs_spec) - 2, + ) + dim = len(lhs_spec) - 2 + if lhs_dilation and rhs_dilation: + if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: + lhs_dilation, rhs_dilation = None, None + else: + raise ValueError( + "Current implementation does not support that " + "deconvolution and dilation to be performed at the same " + "time, but got lhs_dilation: {}, rhs_dilation: {}".format( + lhs_dilation, rhs_dilation + ) + ) + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Current implementation requires the padding parameter" + "to be either 'VALID' or 'SAME', but got: ", + padding, + ) + if batch_group_count != 1 or feature_group_count != 1: + raise NotImplementedError( + "batch_group_count and feature_group_count " + "other than 1 is currently not supported, but" + " got feature_group_count: {}, batch_group_count" + ": {}".format(feature_group_count, batch_group_count) + ) + if precision is not None: + raise NotImplementedError( + "precision other than `None` is currently not " + "supported, but got: {}".format(precision) + ) + # Convert params from int/Sequence[int] to list of ints. 
+ strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( + window_strides, lhs_dilation, rhs_dilation, dim + ) + # Preprocess the shapes + dim_maps = {} + if isinstance(lhs_spec, str): + dim_maps["I"] = list(rhs_spec).index("I") + dim_maps["O"] = list(rhs_spec).index("O") + dim_maps["N"] = list(lhs_spec).index("N") + dim_maps["C"] = list(lhs_spec).index("C") else: - raise ValueError("Current implementation does not support that " - "deconvolution and dilation to be performed at the same " - "time, but got lhs_dilation: {}, rhs_dilation: {}" - .format(lhs_dilation, rhs_dilation)) - if padding not in ["SAME", "VALID"]: - raise ValueError("Current implementation requires the padding parameter" - "to be either 'VALID' or 'SAME', but got: ", padding) - if batch_group_count != 1 or feature_group_count != 1: - raise NotImplementedError("batch_group_count and feature_group_count " - "other than 1 is currently not supported, but" - " got feature_group_count: {}, batch_group_count" - ": {}".format(feature_group_count, - batch_group_count)) - if precision is not None: - raise NotImplementedError("precision other than `None` is currently not " - "supported, but got: {}".format(precision)) - # Convert params from int/Sequence[int] to list of ints. - strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( - window_strides, lhs_dilation, rhs_dilation, dim - ) - # Preprocess the shapes - dim_maps = {} - if isinstance(lhs_spec, str): - dim_maps["I"] = list(rhs_spec).index("I") - dim_maps["O"] = list(rhs_spec).index("O") - dim_maps["N"] = list(lhs_spec).index("N") - dim_maps["C"] = list(lhs_spec).index("C") - else: - dim_maps["I"] = rhs_spec[1] - dim_maps["O"] = rhs_spec[0] - dim_maps["N"] = lhs_spec[0] - dim_maps["C"] = lhs_spec[1] - - lhs = tf_np.moveaxis(lhs, (dim_maps["N"], dim_maps["C"]), (0, dim + 1)) - # Adjust the filters, put the dimension 'I' and 'O' at last. 
- rhs = tf_np.moveaxis(rhs, (dim_maps["O"], dim_maps["I"]), (dim + 1, dim)) - spatial_dim_maps = {1: "W", 2: "HW", 3: "DHW"} - data_format = "N" + spatial_dim_maps[dim] + "C" - - if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): - output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, - rhs_dilation) - else: - output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, - padding, data_format, lhs_dilation) - output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps["N"], dim_maps["C"])) - return output - - -def conv(inp, - fltr, - window_strides, - padding, - dimension_numbers, - filter_dilation=None): - """Convolution over an N-D array. - - See https://www.tensorflow.org/api_docs/python/tf/nn/convolution and - https://www.tensorflow.org/xla/operation_semantics#conv_convolution for - reference. - - Args: - inp: an (N+2)-D array. The input of the convolution. - fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. - window_strides: a sequence of N ints, the strides for moving the convolution - window. - padding: a string, either "VALID" or "SAME". The padding algorithm. - dimension_numbers: a tuple of three strings encoding the data format of - input, filter and output. "I" means input; "O" means output; "C" means - channel; other characters such as "W", "H" and "D" means spatial - dimensions. - filter_dilation: the dilation rates for the filter. Dilating the filter - means adding "holes" to the filter. - - Returns: - An (N+2)-D array. The convolution result. 
- """ - input_spec, filter_spec, output_spec = dimension_numbers - if input_spec != output_spec: - raise ValueError("Input and output data formats must be the same; got %s " - "and %s" % (input_spec, output_spec)) - supported_filter_spec = ["WIO", "HWIO", "DHWIO"] - if filter_spec not in supported_filter_spec: - raise ValueError("The supported data format for the filter are %s; got %s" % - (supported_filter_spec, filter_spec)) - if input_spec[1:-1] != filter_spec[:-2]: - raise ValueError("Input data format (%s) is not compatible with filter " - "data format (%s)" % (input_spec, filter_spec)) - # No type promotion in order to prevent accidentally doing more expensive - # computation. - dtype = tf_np.result_type(inp, fltr) - inp = tf_np.asarray(inp, dtype) - fltr = tf_np.asarray(fltr, dtype) - return tf_np.asarray( - tf.nn.convolution( - input=inp, - filters=fltr, - padding=padding, - strides=window_strides, - dilations=filter_dilation, - data_format=input_spec)) + dim_maps["I"] = rhs_spec[1] + dim_maps["O"] = rhs_spec[0] + dim_maps["N"] = lhs_spec[0] + dim_maps["C"] = lhs_spec[1] + + lhs = tf_np.moveaxis(lhs, (dim_maps["N"], dim_maps["C"]), (0, dim + 1)) + # Adjust the filters, put the dimension 'I' and 'O' at last. + rhs = tf_np.moveaxis(rhs, (dim_maps["O"], dim_maps["I"]), (dim + 1, dim)) + spatial_dim_maps = {1: "W", 2: "HW", 3: "DHW"} + data_format = "N" + spatial_dim_maps[dim] + "C" + + if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): + output = _tf_nn_APIs[dim][0]( + lhs, rhs, strides, padding, data_format, rhs_dilation + ) + else: + output = _tf_nn_APIs[dim][1]( + lhs, + rhs, + tf.constant(output_shape), + strides, + padding, + data_format, + lhs_dilation, + ) + output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps["N"], dim_maps["C"])) + return output + + +def conv(inp, fltr, window_strides, padding, dimension_numbers, filter_dilation=None): + """Convolution over an N-D array. 
+ + See https://www.tensorflow.org/api_docs/python/tf/nn/convolution and + https://www.tensorflow.org/xla/operation_semantics#conv_convolution for + reference. + + Args: + inp: an (N+2)-D array. The input of the convolution. + fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. + window_strides: a sequence of N ints, the strides for moving the convolution + window. + padding: a string, either "VALID" or "SAME". The padding algorithm. + dimension_numbers: a tuple of three strings encoding the data format of + input, filter and output. "I" means input; "O" means output; "C" means + channel; other characters such as "W", "H" and "D" means spatial + dimensions. + filter_dilation: the dilation rates for the filter. Dilating the filter + means adding "holes" to the filter. + + Returns: + An (N+2)-D array. The convolution result. + """ + input_spec, filter_spec, output_spec = dimension_numbers + if input_spec != output_spec: + raise ValueError( + "Input and output data formats must be the same; got %s " + "and %s" % (input_spec, output_spec) + ) + supported_filter_spec = ["WIO", "HWIO", "DHWIO"] + if filter_spec not in supported_filter_spec: + raise ValueError( + "The supported data format for the filter are %s; got %s" + % (supported_filter_spec, filter_spec) + ) + if input_spec[1:-1] != filter_spec[:-2]: + raise ValueError( + "Input data format (%s) is not compatible with filter " + "data format (%s)" % (input_spec, filter_spec) + ) + # No type promotion in order to prevent accidentally doing more expensive + # computation. + dtype = tf_np.result_type(inp, fltr) + inp = tf_np.asarray(inp, dtype) + fltr = tf_np.asarray(fltr, dtype) + return tf_np.asarray( + tf.nn.convolution( + input=inp, + filters=fltr, + padding=padding, + strides=window_strides, + dilations=filter_dilation, + data_format=input_spec, + ) + ) def avg_pool(x, pool_size, strides, padding): - """Performs an N-D average pooling. 
- - Args: - x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + - [num_channels]`. Pooling happens over the spatial dimensions only. - pool_size: sequence of N ints. - strides: sequence of N ints. - padding: a string, the padding algorithm. Must be "SAME" or "VALID". - - Returns: - An (N+2)-D array, of shape - [batch_size] + output_spatial_shape + [num_channels], - where `output_spatial_shape` depends on the value of padding: - If padding = "SAME": - output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) - If padding = "VALID": - output_spatial_shape[i] = - ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). - """ - x = tf_np.asarray(x) - return tf_np.asarray( - tf.nn.pool( - input=x, - window_shape=pool_size, - pooling_type="AVG", - strides=strides, - padding=padding)) + """Performs an N-D average pooling. + + Args: + x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + + [num_channels]`. Pooling happens over the spatial dimensions only. + pool_size: sequence of N ints. + strides: sequence of N ints. + padding: a string, the padding algorithm. Must be "SAME" or "VALID". + + Returns: + An (N+2)-D array, of shape + [batch_size] + output_spatial_shape + [num_channels], + where `output_spatial_shape` depends on the value of padding: + If padding = "SAME": + output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) + If padding = "VALID": + output_spatial_shape[i] = + ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). + """ + x = tf_np.asarray(x) + return tf_np.asarray( + tf.nn.pool( + input=x, + window_shape=pool_size, + pooling_type="AVG", + strides=strides, + padding=padding, + ) + ) def max_pool(x, pool_size, strides, padding): - """Performs an N-D max pooling. - - Args: - x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + - [num_channels]`. Pooling happens over the spatial dimensions only. - pool_size: sequence of N ints. - strides: sequence of N ints. 
- padding: a string, the padding algorithm. Must be "SAME" or "VALID". - - Returns: - An (N+2)-D array, of shape - [batch_size] + output_spatial_shape + [num_channels], - where `output_spatial_shape` depends on the value of padding: - If padding = "SAME": - output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) - If padding = "VALID": - output_spatial_shape[i] = - ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). - """ - x = tf_np.asarray(x) - return tf_np.asarray( - tf.nn.pool( - input=x, - window_shape=pool_size, - pooling_type="MAX", - strides=strides, - padding=padding)) + """Performs an N-D max pooling. + + Args: + x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + + [num_channels]`. Pooling happens over the spatial dimensions only. + pool_size: sequence of N ints. + strides: sequence of N ints. + padding: a string, the padding algorithm. Must be "SAME" or "VALID". + + Returns: + An (N+2)-D array, of shape + [batch_size] + output_spatial_shape + [num_channels], + where `output_spatial_shape` depends on the value of padding: + If padding = "SAME": + output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) + If padding = "VALID": + output_spatial_shape[i] = + ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). + """ + x = tf_np.asarray(x) + return tf_np.asarray( + tf.nn.pool( + input=x, + window_shape=pool_size, + pooling_type="MAX", + strides=strides, + padding=padding, + ) + ) def sort_key_val(keys, values, dimension=-1): - """Sorts keys along a dimension and applies same permutation to values. - - Args: - keys: an array. The dtype must be comparable numbers (integers and reals). - values: an array, with the same shape of `keys`. - dimension: an `int`. The dimension along which to sort. - - Returns: - Permuted keys and values. 
- """ - keys = tf_np.asarray(keys) - values = tf_np.asarray(values) - rank = keys.shape.ndims - if rank is None: - rank = values.shape.ndims - if rank is None: - # We need to know the rank because tf.gather requires batch_dims to be `int` - raise ValueError("The rank of either keys or values must be known, but " - "both are unknown (i.e. their shapes are both None).") - if dimension in (-1, rank - 1): - - def maybe_swapaxes(a): - return a - else: - - def maybe_swapaxes(a): - return tf_np.swapaxes(a, dimension, -1) - - # We need to swap axes because tf.gather (and tf.gather_nd) supports - # batch_dims on the left but not on the right. - # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis. - keys = maybe_swapaxes(keys) - values = maybe_swapaxes(values) - idxs = tf_np.argsort(keys) - - # Using tf.gather rather than np.take because the former supports batch_dims - def gather(a): - return tf_np.asarray(tf.gather(a, idxs, batch_dims=rank - 1)) - - keys = gather(keys) - values = gather(values) - keys = maybe_swapaxes(keys) - values = maybe_swapaxes(values) - return keys, values + """Sorts keys along a dimension and applies same permutation to values. + Args: + keys: an array. The dtype must be comparable numbers (integers and reals). + values: an array, with the same shape of `keys`. + dimension: an `int`. The dimension along which to sort. + + Returns: + Permuted keys and values. + """ + keys = tf_np.asarray(keys) + values = tf_np.asarray(values) + rank = keys.shape.ndims + if rank is None: + rank = values.shape.ndims + if rank is None: + # We need to know the rank because tf.gather requires batch_dims to be `int` + raise ValueError( + "The rank of either keys or values must be known, but " + "both are unknown (i.e. their shapes are both None)." + ) + if dimension in (-1, rank - 1): + + def maybe_swapaxes(a): + return a -def scan(f, init, xs, length=None, reverse=False): - """Scan a function over leading array axes while carrying along state. 
- - See the docstring of `jax.lax.scan` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for - details. - - Args: - f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning - that ``f`` accepts two arguments where the first is a value of the loop - carry and the second is a slice of ``xs`` along its leading axis, and that - ``f`` returns a pair where the first element represents a new value for - the loop carry and the second represents a slice of the output. Note that - the input and output carry must have the same dtype. - init: an initial loop carry value of type ``c``, which can be a scalar, - array, or any pytree (nested Python tuple/list/dict) thereof, representing - the initial loop carry value. This value must have the same structure as - the first element of the pair returned by ``f``. - xs: the value of type ``[a]`` over which to scan along the leading axis, - where ``[a]`` can be an array or any pytree (nested Python - tuple/list/dict) thereof with consistent leading axis sizes. - length: optional integer specifying the number of loop iterations, which - must agree with the sizes of leading axes of the arrays in ``xs`` (but can - be used to perform scans where no input ``xs`` are needed). - reverse: optional boolean specifying whether to run the scan iteration - forward (the default) or in reverse, equivalent to reversing the leading - axes of the arrays in both ``xs`` and in ``ys``. - - Returns: - A pair of type ``(c, [b])`` where the first element represents the final - loop carry value and the second element represents the stacked outputs of - the second output of ``f`` when scanned over the leading axis of the inputs. 
- """ - init, xs = tf.nest.map_structure( - lambda x: tf_np.asarray(x) if x is not None else None, (init, xs)) - if length is not None: - length = int(length) - def get_length(x): - if x is None: - return None - if x.shape.rank == 0: - raise ValueError("Some array in `xs` doesn't have a leading dimension") - return x.shape[0] - lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs)) - for l in lengths: - if l is not None: - if length is None: - length = l - elif length != l: - raise ValueError("There are two different leading-dimension lengths: " - f"{length} and {l}") - if length is None: - raise ValueError( - "Can't determine length. Please set the `length` argument.") - xs_ta = tf.nest.map_structure( - lambda t: (tf.TensorArray(t.dtype, size=length, dynamic_size=False) # pylint: disable=g-long-lambda - .unstack(t) if t is not None else None), - xs) - # tf.while_loop doesn't allow None in loop_vars, so we mask them. - is_init_none = tf.nest.map_structure(lambda x: x is None, init) - def to_safe(carry): - return tf.nest.map_structure( - lambda x, is_none: tf.zeros([]) if is_none else x, carry, is_init_none) - def from_safe(safe_carry): - return tf.nest.map_structure( - lambda x, is_none: None if is_none else x, safe_carry, is_init_none) - def body(i, safe_carry, ys_ta): - carry = from_safe(safe_carry) - if reverse: - i_ = length - 1 - i else: - i_ = i - xs = tf.nest.map_structure( - lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta) - carry, ys = f(*_tf_to_np((carry, xs))) + + def maybe_swapaxes(a): + return tf_np.swapaxes(a, dimension, -1) + + # We need to swap axes because tf.gather (and tf.gather_nd) supports + # batch_dims on the left but not on the right. + # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis. 
+ keys = maybe_swapaxes(keys) + values = maybe_swapaxes(values) + idxs = tf_np.argsort(keys) + + # Using tf.gather rather than np.take because the former supports batch_dims + def gather(a): + return tf_np.asarray(tf.gather(a, idxs, batch_dims=rank - 1)) + + keys = gather(keys) + values = gather(values) + keys = maybe_swapaxes(keys) + values = maybe_swapaxes(values) + return keys, values + + +def scan(f, init, xs, length=None, reverse=False): + """Scan a function over leading array axes while carrying along state. + + See the docstring of `jax.lax.scan` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for + details. + + Args: + f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning + that ``f`` accepts two arguments where the first is a value of the loop + carry and the second is a slice of ``xs`` along its leading axis, and that + ``f`` returns a pair where the first element represents a new value for + the loop carry and the second represents a slice of the output. Note that + the input and output carry must have the same dtype. + init: an initial loop carry value of type ``c``, which can be a scalar, + array, or any pytree (nested Python tuple/list/dict) thereof, representing + the initial loop carry value. This value must have the same structure as + the first element of the pair returned by ``f``. + xs: the value of type ``[a]`` over which to scan along the leading axis, + where ``[a]`` can be an array or any pytree (nested Python + tuple/list/dict) thereof with consistent leading axis sizes. + length: optional integer specifying the number of loop iterations, which + must agree with the sizes of leading axes of the arrays in ``xs`` (but can + be used to perform scans where no input ``xs`` are needed). + reverse: optional boolean specifying whether to run the scan iteration + forward (the default) or in reverse, equivalent to reversing the leading + axes of the arrays in both ``xs`` and in ``ys``. 
+ + Returns: + A pair of type ``(c, [b])`` where the first element represents the final + loop carry value and the second element represents the stacked outputs of + the second output of ``f`` when scanned over the leading axis of the inputs. + """ + init, xs = tf.nest.map_structure( + lambda x: tf_np.asarray(x) if x is not None else None, (init, xs) + ) + if length is not None: + length = int(length) + + def get_length(x): + if x is None: + return None + if x.shape.rank == 0: + raise ValueError("Some array in `xs` doesn't have a leading dimension") + return x.shape[0] + + lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs)) + for l in lengths: + if l is not None: + if length is None: + length = l + elif length != l: + raise ValueError( + "There are two different leading-dimension lengths: " + f"{length} and {l}" + ) + if length is None: + raise ValueError("Can't determine length. Please set the `length` argument.") + xs_ta = tf.nest.map_structure( + lambda t: ( + tf.TensorArray( + t.dtype, size=length, dynamic_size=False + ).unstack( # pylint: disable=g-long-lambda + t + ) + if t is not None + else None + ), + xs, + ) + # tf.while_loop doesn't allow None in loop_vars, so we mask them. 
+ is_init_none = tf.nest.map_structure(lambda x: x is None, init) + + def to_safe(carry): + return tf.nest.map_structure( + lambda x, is_none: tf.zeros([]) if is_none else x, carry, is_init_none + ) + + def from_safe(safe_carry): + return tf.nest.map_structure( + lambda x, is_none: None if is_none else x, safe_carry, is_init_none + ) + + def body(i, safe_carry, ys_ta): + carry = from_safe(safe_carry) + if reverse: + i_ = length - 1 - i + else: + i_ = i + xs = tf.nest.map_structure( + lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta + ) + carry, ys = f(*_tf_to_np((carry, xs))) + ys_ta = tf.nest.map_structure( + lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), ys_ta, ys + ) + i = i + 1 + safe_carry = to_safe(carry) + return i, safe_carry, ys_ta + + xs_spec = tf.nest.map_structure( + lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, xs + ) + _, ys_spec = eval_on_shapes(f)(init, xs_spec) + # ys_ta can't contain None because tf.while_loop doesn't allow None in + # loop_vars. ys_ta = tf.nest.map_structure( - lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), - ys_ta, ys) - i = i + 1 - safe_carry = to_safe(carry) - return i, safe_carry, ys_ta - xs_spec = tf.nest.map_structure( - lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, - xs) - _, ys_spec = eval_on_shapes(f)(init, xs_spec) - # ys_ta can't contain None because tf.while_loop doesn't allow None in - # loop_vars. 
- ys_ta = tf.nest.map_structure( - lambda y: tf.TensorArray(y.dtype if y is not None else tf.float32, # pylint: disable=g-long-lambda - size=length, dynamic_size=False), - ys_spec) - safe_init = to_safe(init) - _, safe_carry, ys_ta = tf.while_loop( - lambda i, *_: i < length, body, (0, safe_init, ys_ta), - maximum_iterations=length) - carry = from_safe(safe_carry) - def _stack(a, spec): - if spec is None: - return None - a = a.stack() - a.set_shape((length,) + a.shape[1:]) - return a - ys = tf.nest.map_structure(_stack, ys_ta, ys_spec) - return _tf_to_np((carry, ys)) + lambda y: tf.TensorArray( + y.dtype if y is not None else tf.float32, # pylint: disable=g-long-lambda + size=length, + dynamic_size=False, + ), + ys_spec, + ) + safe_init = to_safe(init) + _, safe_carry, ys_ta = tf.while_loop( + lambda i, *_: i < length, body, (0, safe_init, ys_ta), maximum_iterations=length + ) + carry = from_safe(safe_carry) + + def _stack(a, spec): + if spec is None: + return None + a = a.stack() + a.set_shape((length,) + a.shape[1:]) + return a + + ys = tf.nest.map_structure(_stack, ys_ta, ys_spec) + return _tf_to_np((carry, ys)) # named "tf_map" instead of "map" as in JAX to avoid conflict with Python `map` def tf_map(f, xs): - """Map a function over leading array axes. + """Map a function over leading array axes. - See the docstring of `jax.lax.map` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html) for - details. + See the docstring of `jax.lax.map` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html) for + details. - Args: - f: a Python function to apply element-wise over the first axis or axes of - `xs`. - xs: values over which to map along the leading axis. + Args: + f: a Python function to apply element-wise over the first axis or axes of + `xs`. + xs: values over which to map along the leading axis. - Returns: - Mapped values. 
- """ - def g(unused, x): - return unused, f(x) - carry = tf.nest.map_structure(lambda _: None, xs) - return scan(g, carry, xs)[1] + Returns: + Mapped values. + """ + + def g(unused, x): + return unused, f(x) + + carry = tf.nest.map_structure(lambda _: None, xs) + return scan(g, carry, xs)[1] def _get_dynamic_indices(operand, start_indices, slice_sizes): - """Calcuates the indices for `tf.gather_nd` from slices. - - Args: - operand: a Tensor to slice. - start_indices: a vector Tensor of integers, one per dimension. The starts of - the slice. The vector can be dynamic. - slice_sizes: a list of integers, one per dimension. The sizes of the slice. - - Returns: - An index array suitable for `tf.gather_nd` and `tf.scatter_nd`, or `None` if - `operand` is a scalar. - """ - rank = len(slice_sizes) - operand_rank = tf.rank(operand) - tf.debugging.Assert(operand_rank == rank, [operand_rank, rank]) - starts_rank = tf.rank(start_indices) - tf.debugging.Assert(starts_rank == 1, [starts_rank]) - num_starts = tf.shape(start_indices)[0] - tf.debugging.Assert(num_starts == rank, [num_starts, rank]) - operand_shape = tf.shape(operand) - tf.debugging.Assert(tf.reduce_all(slice_sizes <= operand_shape), - [slice_sizes, operand_shape]) - if rank == 0: - return None - start_indices = tf.where( - start_indices < 0, start_indices + operand_shape, start_indices) - idx_list = [] - for i in range(rank): - start = start_indices[i] - size = slice_sizes[i] - dim = operand_shape[i] - start = tf.clip_by_value(start, 0, dim - size) - # XLA requires tf.range's `start` to be compile-time constant, so we can't - # do tf.range(start, ...). - idx = start + tf.range(size) - shape = [1] * rank - shape[i] = size - idx = tf.reshape(idx, shape) - idx_list.append(idx) - slice_sizes_tensor = tf.convert_to_tensor(slice_sizes) - # tf.stack doesn't support broadcasting, so we need to broadcast manually. - # TODO(wangpeng): Reduce peak memory by broadcasting one-by-one instead of - # all-together. 
- idx_list = [tf.broadcast_to(x, slice_sizes_tensor) for x in idx_list] - return tf.stack(idx_list, axis=-1) + """Calcuates the indices for `tf.gather_nd` from slices. + + Args: + operand: a Tensor to slice. + start_indices: a vector Tensor of integers, one per dimension. The starts of + the slice. The vector can be dynamic. + slice_sizes: a list of integers, one per dimension. The sizes of the slice. + + Returns: + An index array suitable for `tf.gather_nd` and `tf.scatter_nd`, or `None` if + `operand` is a scalar. + """ + rank = len(slice_sizes) + operand_rank = tf.rank(operand) + tf.debugging.Assert(operand_rank == rank, [operand_rank, rank]) + starts_rank = tf.rank(start_indices) + tf.debugging.Assert(starts_rank == 1, [starts_rank]) + num_starts = tf.shape(start_indices)[0] + tf.debugging.Assert(num_starts == rank, [num_starts, rank]) + operand_shape = tf.shape(operand) + tf.debugging.Assert( + tf.reduce_all(slice_sizes <= operand_shape), [slice_sizes, operand_shape] + ) + if rank == 0: + return None + start_indices = tf.where( + start_indices < 0, start_indices + operand_shape, start_indices + ) + idx_list = [] + for i in range(rank): + start = start_indices[i] + size = slice_sizes[i] + dim = operand_shape[i] + start = tf.clip_by_value(start, 0, dim - size) + # XLA requires tf.range's `start` to be compile-time constant, so we can't + # do tf.range(start, ...). + idx = start + tf.range(size) + shape = [1] * rank + shape[i] = size + idx = tf.reshape(idx, shape) + idx_list.append(idx) + slice_sizes_tensor = tf.convert_to_tensor(slice_sizes) + # tf.stack doesn't support broadcasting, so we need to broadcast manually. + # TODO(wangpeng): Reduce peak memory by broadcasting one-by-one instead of + # all-together. + idx_list = [tf.broadcast_to(x, slice_sizes_tensor) for x in idx_list] + return tf.stack(idx_list, axis=-1) def dynamic_slice(operand, start_indices, slice_sizes): - """Slicing operation where the indices can be dynamic vlaues. 
- - See the docstring of `jax.lax.dynamic_slice` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) - for details. - - Args: - operand: an array to slice. - start_indices: a vector of integers, one per dimension. The starts of the - slice. The vector can be dynamic. - slice_sizes: a list of integers, one per dimension. The sizes of the slice. - - Returns: - An array containing the slice, with shape equal to `slice_sizes`. - """ - # This implementation uses tf.gather_nd to implement dynamic_slice, which is - # memory inefficient because the size of `indices` given to gather_nd is - # large. - operand = tf_np.asarray(operand).data - start_indices = tf_np.asarray(start_indices, np.int32).data - idx = _get_dynamic_indices(operand, start_indices, slice_sizes) - if idx is not None: - operand = tf.gather_nd(operand, idx) - return tf_np.asarray(operand) + """Slicing operation where the indices can be dynamic vlaues. + + See the docstring of `jax.lax.dynamic_slice` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) + for details. + + Args: + operand: an array to slice. + start_indices: a vector of integers, one per dimension. The starts of the + slice. The vector can be dynamic. + slice_sizes: a list of integers, one per dimension. The sizes of the slice. + + Returns: + An array containing the slice, with shape equal to `slice_sizes`. + """ + # This implementation uses tf.gather_nd to implement dynamic_slice, which is + # memory inefficient because the size of `indices` given to gather_nd is + # large. + operand = tf_np.asarray(operand).data + start_indices = tf_np.asarray(start_indices, np.int32).data + idx = _get_dynamic_indices(operand, start_indices, slice_sizes) + if idx is not None: + operand = tf.gather_nd(operand, idx) + return tf_np.asarray(operand) def dynamic_update_slice(operand, update, start_indices): - """Updates a dynamic slice. 
- - See the docstring of `jax.lax.dynamic_update_slice` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_update_slice.html) - for details. - - Args: - operand: an array to slice. - update: an array containing the new values to write onto `operand`. - start_indices: a vector of integers, one per dimension. The starts of the - slice. The vector can be dynamic. - - Returns: - The updated version of `operand`. - """ - operand = tf_np.asarray(operand).data - update = tf_np.asarray(update).data - start_indices = tf_np.asarray(start_indices, np.int32).data - if not update.shape.is_fully_defined(): - raise ValueError("update's shape must be fully defined") - slice_sizes = update.shape - idx = _get_dynamic_indices(operand, start_indices, slice_sizes) - if idx is None: - # `np.zeros([])[()] = 1.0` will result in a scalar array of 1.0 - return tf_np.asarray(update) - operand = tf.tensor_scatter_nd_update(operand, idx, update) - return tf_np.asarray(operand) + """Updates a dynamic slice. + + See the docstring of `jax.lax.dynamic_update_slice` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_update_slice.html) + for details. + + Args: + operand: an array to slice. + update: an array containing the new values to write onto `operand`. + start_indices: a vector of integers, one per dimension. The starts of the + slice. The vector can be dynamic. + + Returns: + The updated version of `operand`. 
+ """ + operand = tf_np.asarray(operand).data + update = tf_np.asarray(update).data + start_indices = tf_np.asarray(start_indices, np.int32).data + if not update.shape.is_fully_defined(): + raise ValueError("update's shape must be fully defined") + slice_sizes = update.shape + idx = _get_dynamic_indices(operand, start_indices, slice_sizes) + if idx is None: + # `np.zeros([])[()] = 1.0` will result in a scalar array of 1.0 + return tf_np.asarray(update) + operand = tf.tensor_scatter_nd_update(operand, idx, update) + return tf_np.asarray(operand) def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0): - """Convenience wrapper around dynamic_slice applying to one dimension.""" - operand = tf_np.asarray(operand) - start_indices = [0] * operand.ndim - slice_sizes = list(operand.shape) - axis = int(axis) - start_indices[axis] = start_index - slice_sizes[axis] = int(slice_size) - return dynamic_slice(operand, start_indices, slice_sizes) + """Convenience wrapper around dynamic_slice applying to one dimension.""" + operand = tf_np.asarray(operand) + start_indices = [0] * operand.ndim + slice_sizes = list(operand.shape) + axis = int(axis) + start_indices[axis] = start_index + slice_sizes[axis] = int(slice_size) + return dynamic_slice(operand, start_indices, slice_sizes) def dynamic_update_slice_in_dim(operand, update, start_index, axis): - """Convenience wrapper around dynamic_update_slice for one dimension.""" - operand = tf_np.asarray(operand) - axis = int(axis) - start_indices = [0] * operand.ndim - start_indices[axis] = start_index - return dynamic_update_slice(operand, update, start_indices) + """Convenience wrapper around dynamic_update_slice for one dimension.""" + operand = tf_np.asarray(operand) + axis = int(axis) + start_indices = [0] * operand.ndim + start_indices[axis] = start_index + return dynamic_update_slice(operand, update, start_indices) # Use int64 instead of int32 to avoid TF's "int32 problem" @@ -1274,229 +1409,229 @@ def 
dynamic_update_slice_in_dim(operand, update, start_index, axis): def _key2seed(a): - """Converts an RNG key to an RNG seed. + """Converts an RNG key to an RNG seed. - Args: - a: an RNG key, an ndarray of shape [] and dtype `np.int64`. + Args: + a: an RNG key, an ndarray of shape [] and dtype `np.int64`. - Returns: - an RNG seed, a tensor of shape [2] and dtype `tf.int32`. - """ + Returns: + an RNG seed, a tensor of shape [2] and dtype `tf.int32`. + """ - def int64_to_int32s(a): - """Converts an int64 tensor of shape [] to an int32 tensor of shape [2].""" - a = tf.cast(a, tf.uint64) - fst = tf.cast(a, tf.uint32) - snd = tf.cast( - tf.bitwise.right_shift(a, tf.constant(32, tf.uint64)), tf.uint32) - a = [fst, snd] - a = tf.nest.map_structure(lambda x: tf.cast(x, tf.int32), a) - a = tf.stack(a) - return a + def int64_to_int32s(a): + """Converts an int64 tensor of shape [] to an int32 tensor of shape [2].""" + a = tf.cast(a, tf.uint64) + fst = tf.cast(a, tf.uint32) + snd = tf.cast(tf.bitwise.right_shift(a, tf.constant(32, tf.uint64)), tf.uint32) + a = [fst, snd] + a = tf.nest.map_structure(lambda x: tf.cast(x, tf.int32), a) + a = tf.stack(a) + return a - return int64_to_int32s(a) + return int64_to_int32s(a) def _seed2key(a): - """Converts an RNG seed to an RNG key. + """Converts an RNG seed to an RNG key. - Args: - a: an RNG seed, a tensor of shape [2] and dtype `tf.int32`. + Args: + a: an RNG seed, a tensor of shape [2] and dtype `tf.int32`. - Returns: - an RNG key, an ndarray of shape [] and dtype `np.int64`. - """ + Returns: + an RNG key, an ndarray of shape [] and dtype `np.int64`. 
+ """ - def int32s_to_int64(a): - """Converts an int32 tensor of shape [2] to an int64 tensor of shape [].""" - a = tf.bitwise.bitwise_or( - tf.cast(a[0], tf.uint64), - tf.bitwise.left_shift( - tf.cast(a[1], tf.uint64), tf.constant(32, tf.uint64))) - a = tf.cast(a, tf.int64) - return a + def int32s_to_int64(a): + """Converts an int32 tensor of shape [2] to an int64 tensor of shape [].""" + a = tf.bitwise.bitwise_or( + tf.cast(a[0], tf.uint64), + tf.bitwise.left_shift(tf.cast(a[1], tf.uint64), tf.constant(32, tf.uint64)), + ) + a = tf.cast(a, tf.int64) + return a - return tf_np.asarray(int32s_to_int64(a)) + return tf_np.asarray(int32s_to_int64(a)) def prng(s): - """Creates RNG state from seed. + """Creates RNG state from seed. - Args: - s: the seed, an integer. + Args: + s: the seed, an integer. - Returns: - An RNG state, as a scalar array of dtype `np.int64`. - """ - # TODO(wangpeng): Become bitwise-identical to JAX when TF stateless RNGs get - # improved. - return tf_np.asarray(s, dtype=_RNG_KEY_DTYPE) + Returns: + An RNG state, as a scalar array of dtype `np.int64`. + """ + # TODO(wangpeng): Become bitwise-identical to JAX when TF stateless RNGs get + # improved. + return tf_np.asarray(s, dtype=_RNG_KEY_DTYPE) def stateless_split(seed, num=2): - """Splits an RNG seed into `num` new seeds by adding a leading axis. - - Example: - - >>> seed = [1, 2] - >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) - >>> print(new_seeds) - tf.Tensor( - [[1105988140 1738052849] - [-335576002 370444179] - [ 10670227 -246211131]], shape=(3, 2), dtype=int32) - >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) - - - Args: - seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`). - (When using XLA, only `int32` is allowed.) - num: optional, a positive integer or scalar tensor indicating the number of - seeds to produce (default 2). - - Returns: - A tensor with shape [num, 2] representing `num` new seeds. 
It will have the - same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype - will be determined by `tf.convert_to_tensor`). - """ - seed = tf.convert_to_tensor(seed) - return tf.random.stateless_uniform( - shape=[num, 2], seed=seed, dtype=seed.dtype, minval=None, maxval=None) + """Splits an RNG seed into `num` new seeds by adding a leading axis. + + Example: + + >>> seed = [1, 2] + >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) + >>> print(new_seeds) + tf.Tensor( + [[1105988140 1738052849] + [-335576002 370444179] + [ 10670227 -246211131]], shape=(3, 2), dtype=int32) + >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) + + + Args: + seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`). + (When using XLA, only `int32` is allowed.) + num: optional, a positive integer or scalar tensor indicating the number of + seeds to produce (default 2). + + Returns: + A tensor with shape [num, 2] representing `num` new seeds. It will have the + same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype + will be determined by `tf.convert_to_tensor`). + """ + seed = tf.convert_to_tensor(seed) + return tf.random.stateless_uniform( + shape=[num, 2], seed=seed, dtype=seed.dtype, minval=None, maxval=None + ) def split(state, num): - """Creates new independent RNG states from an existing state. - - Args: - state: the existing state. - num: the number of the new states. - - Returns: - A tuple of new states. - """ - state = tf_np.asarray(state, dtype=_RNG_KEY_DTYPE) - state = _key2seed(state) - try: - states = tf.random.experimental.stateless_split(state, num) - except AttributeError as e: # pylint: disable=unused-variable - # TODO(afrozm): For TF < 2.3 we need to do this. Delete once 2.3 launches. 
- states = stateless_split(state, num) - states = tf.unstack(states, num) - states = tf.nest.map_structure(_seed2key, states) - return states - - -def uniform(key, - shape, - dtype=tf_np.random.DEFAULT_RANDN_DTYPE, - minval=0., - maxval=1.): - """Sample uniform random values in range [`minval`, `maxval`). - - Args: - key: the RNG key. - shape: the shape of the result. - dtype: the dtype of the result. - minval: the minimal value (inclusive). - maxval: the maximal value (exclusive). - - Returns: - An ndarray with shape `shape` and dtype `dtype`. Each value in the ndarray - is sampled uniformly randomly in range [`minval`, `maxval`). - """ - minval = tf.cast(minval, dtype) - maxval = tf.cast(maxval, dtype) - key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) - return tf_np.asarray( - tf.random.stateless_uniform( - shape, seed=_key2seed(key), dtype=dtype, minval=minval, - maxval=maxval)) + """Creates new independent RNG states from an existing state. + + Args: + state: the existing state. + num: the number of the new states. + + Returns: + A tuple of new states. + """ + state = tf_np.asarray(state, dtype=_RNG_KEY_DTYPE) + state = _key2seed(state) + try: + states = tf.random.experimental.stateless_split(state, num) + except AttributeError as e: # pylint: disable=unused-variable + # TODO(afrozm): For TF < 2.3 we need to do this. Delete once 2.3 launches. + states = stateless_split(state, num) + states = tf.unstack(states, num) + states = tf.nest.map_structure(_seed2key, states) + return states + + +def uniform(key, shape, dtype=tf_np.random.DEFAULT_RANDN_DTYPE, minval=0.0, maxval=1.0): + """Sample uniform random values in range [`minval`, `maxval`). + + Args: + key: the RNG key. + shape: the shape of the result. + dtype: the dtype of the result. + minval: the minimal value (inclusive). + maxval: the maximal value (exclusive). + + Returns: + An ndarray with shape `shape` and dtype `dtype`. 
Each value in the ndarray + is sampled uniformly randomly in range [`minval`, `maxval`). + """ + minval = tf.cast(minval, dtype) + maxval = tf.cast(maxval, dtype) + key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) + return tf_np.asarray( + tf.random.stateless_uniform( + shape, seed=_key2seed(key), dtype=dtype, minval=minval, maxval=maxval + ) + ) def normal(key, shape, dtype=tf.float32): - """Sample standard-normal random values. + """Sample standard-normal random values. - Args: - key: the RNG key. - shape: the shape of the result. - dtype: the dtype of the result. + Args: + key: the RNG key. + shape: the shape of the result. + dtype: the dtype of the result. - Returns: - Random values in standard-normal distribution. - """ - key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) - return tf_np.asarray( - tf.random.stateless_normal(shape, seed=_key2seed(key), dtype=dtype)) + Returns: + Random values in standard-normal distribution. + """ + key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) + return tf_np.asarray( + tf.random.stateless_normal(shape, seed=_key2seed(key), dtype=dtype) + ) def bernoulli(key, mean=np.float32(0.5), shape=None): - """Sample Bernoulli random values with given shape and mean. - - Args: - key: the RNG key. - mean: optional, an array_like broadcastable to `shape` for the mean of the - random variables (default 0.5). - shape: optional, a tuple of nonnegative integers representing the shape - (default to `mean`'s shape). + """Sample Bernoulli random values with given shape and mean. - Returns: - A random array with the specified shape and boolean dtype. - """ - mean = tf_np.asarray(mean) - if shape is None: - shape = mean.shape - return uniform(key, shape) < mean + Args: + key: the RNG key. + mean: optional, an array_like broadcastable to `shape` for the mean of the + random variables (default 0.5). + shape: optional, a tuple of nonnegative integers representing the shape + (default to `mean`'s shape). 
+ + Returns: + A random array with the specified shape and boolean dtype. + """ + mean = tf_np.asarray(mean) + if shape is None: + shape = mean.shape + return uniform(key, shape) < mean def _eager_dataset_iterator(dataset): - for item in dataset: - yield tf.nest.map_structure(tf_np.asarray, item) + for item in dataset: + yield tf.nest.map_structure(tf_np.asarray, item) def dataset_as_numpy(dataset): - """Converts a `tf.data.Dataset` to an iterable of ndarrays. - - `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s - and `tf.Tensor`s to iterables of ndarrays and ndarrays, respectively. This - function must be run in eager mode outside tf.function. - - Args: - dataset: a possibly nested structure of `tf.data.Dataset`s and/or - `tf.Tensor`s. - - Returns: - A structure matching `dataset` where `tf.data.Dataset`s are converted to - generators of ndarrays and `tf.Tensor`s are converted to ndarrays. - """ - if not tf.executing_eagerly(): - raise ValueError( - "dataset_as_numpy must be run in eager mode outside tf.function") - nested_ds = dataset - del dataset - - # Flatten - flat_ds = tf.nest.flatten(nested_ds) - flat_np = [] - - # Type check for Tensors and Datasets - for ds_el in flat_ds: - if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)): - types = tf.nest.map_structure(type, nested_ds) - raise ValueError("Arguments to dataset_as_numpy must be (possibly nested " - "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % - types) - - for ds_el in flat_ds: - if isinstance(ds_el, tf.Tensor): - np_el = tf_np.asarray(ds_el) - elif isinstance(ds_el, tf.data.Dataset): - np_el = _eager_dataset_iterator(ds_el) - else: - assert False - flat_np.append(np_el) + """Converts a `tf.data.Dataset` to an iterable of ndarrays. - return tf.nest.pack_sequence_as(nested_ds, flat_np) + `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s + and `tf.Tensor`s to iterables of ndarrays and ndarrays, respectively. 
This + function must be run in eager mode outside tf.function. + + Args: + dataset: a possibly nested structure of `tf.data.Dataset`s and/or + `tf.Tensor`s. + + Returns: + A structure matching `dataset` where `tf.data.Dataset`s are converted to + generators of ndarrays and `tf.Tensor`s are converted to ndarrays. + """ + if not tf.executing_eagerly(): + raise ValueError( + "dataset_as_numpy must be run in eager mode outside tf.function" + ) + nested_ds = dataset + del dataset + + # Flatten + flat_ds = tf.nest.flatten(nested_ds) + flat_np = [] + + # Type check for Tensors and Datasets + for ds_el in flat_ds: + if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)): + types = tf.nest.map_structure(type, nested_ds) + raise ValueError( + "Arguments to dataset_as_numpy must be (possibly nested " + "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % types + ) + + for ds_el in flat_ds: + if isinstance(ds_el, tf.Tensor): + np_el = tf_np.asarray(ds_el) + elif isinstance(ds_el, tf.data.Dataset): + np_el = _eager_dataset_iterator(ds_el) + else: + assert False + flat_np.append(np_el) + + return tf.nest.pack_sequence_as(nested_ds, flat_np) # TODO(nareshmodi): Group key should change based on the set of devices that we @@ -1511,80 +1646,84 @@ def dataset_as_numpy(dataset): # TODO(b/142565636): Ensure that multiple concurrent calls to a tf.function # containing a collective op run reasonably. def _get_instance_key(): - global _INSTANCE_KEY - global _INSTANCE_LOCK - with _INSTANCE_LOCK: - _INSTANCE_KEY = _INSTANCE_KEY + 1 - return _INSTANCE_KEY + global _INSTANCE_KEY + global _INSTANCE_LOCK + with _INSTANCE_LOCK: + _INSTANCE_KEY = _INSTANCE_KEY + 1 + return _INSTANCE_KEY # Don't use a namedtuple since nest considers that a tuple and unflattens and # flattens it. class ShardedNdArray(object): - """Wrapper over ndarray that can contain tensors on multiple devices. + """Wrapper over ndarray that can contain tensors on multiple devices. 
This is returned by extensions.pmap, and contains the individual tensors on different devices. - """ + """ - def __init__(self, tensors): - """Initializes the ShardedNdArray. + def __init__(self, tensors): + """Initializes the ShardedNdArray. - Note that the tensors should be ordered in the way the pmap producing these - tensors is run. + Note that the tensors should be ordered in the way the pmap producing these + tensors is run. - Args: - tensors: list or tuple of eager tensors, one for each device. - """ + Args: + tensors: list or tuple of eager tensors, one for each device. + """ - if not isinstance(tensors, (list, tuple)) or not tensors: - raise ValueError( - "Unable to create a ShardedNdArray without a list of tensors.") - self.tensors = tensors - self.n_devices = len(tensors) + if not isinstance(tensors, (list, tuple)) or not tensors: + raise ValueError( + "Unable to create a ShardedNdArray without a list of tensors." + ) + self.tensors = tensors + self.n_devices = len(tensors) - def __getitem__(self, i): - return tf_np.asarray(self.tensors[i]) + def __getitem__(self, i): + return tf_np.asarray(self.tensors[i]) - @property - def shape(self): - return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access + @property + def shape(self): + return (self.n_devices,) + self.tensors[ + 0 + ]._shape_tuple() # pylint: disable=protected-access - @property - def dtype(self): - return self.tensors[0].dtype + @property + def dtype(self): + return self.tensors[0].dtype def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs): - del args, kwargs - # TODO(nareshmodi): Consider a collective op to gather the tensors from the - # various devices for performance reasons. - return tf.stack(value.tensors) + del args, kwargs + # TODO(nareshmodi): Consider a collective op to gather the tensors from the + # various devices for performance reasons. 
+ return tf.stack(value.tensors) -tf.register_tensor_conversion_function(ShardedNdArray, - convert_sharded_tensor_to_eager_tensor) +tf.register_tensor_conversion_function( + ShardedNdArray, convert_sharded_tensor_to_eager_tensor +) class _PmapConfig(threading.local): - """Simple config used to maintain state related to a current pmap call.""" + """Simple config used to maintain state related to a current pmap call.""" - def __init__(self): - super(_PmapConfig, self).__init__() - self._axis_name = None - self._devices = None + def __init__(self): + super(_PmapConfig, self).__init__() + self._axis_name = None + self._devices = None - def axis_name(self): - return self._axis_name + def axis_name(self): + return self._axis_name - def set_axis_name(self, axis_name): - self._axis_name = axis_name + def set_axis_name(self, axis_name): + self._axis_name = axis_name - def devices(self): - return self._devices + def devices(self): + return self._devices - def set_devices(self, devices): - self._devices = devices + def set_devices(self, devices): + self._devices = devices _pmap_config = _PmapConfig() @@ -1592,404 +1731,426 @@ def set_devices(self, devices): @contextlib.contextmanager def pmap_config(axis_name, devices): - """Records axis_name and devices for this context.""" - old_axis_name = _pmap_config.axis_name() - old_devices = _pmap_config.devices() - _pmap_config.set_axis_name(axis_name) - _pmap_config.set_devices(devices) - try: - yield - finally: - _pmap_config.set_axis_name(old_axis_name) - _pmap_config.set_devices(old_devices) + """Records axis_name and devices for this context.""" + old_axis_name = _pmap_config.axis_name() + old_devices = _pmap_config.devices() + _pmap_config.set_axis_name(axis_name) + _pmap_config.set_devices(devices) + try: + yield + finally: + _pmap_config.set_axis_name(old_axis_name) + _pmap_config.set_devices(old_devices) def _psum(tensor, axis_name=None): - """Sum all-reduction. - - Args: - tensor: A tensor. 
- axis_name: The axis name to reduce. Must equal to that of the surrounding - pmap. - - Returns: - The sum of the `tensor` replicas on each participating devices. - """ - if axis_name != _pmap_config.axis_name(): - raise ValueError("axis_name (%s) is not equal to that of the surrounding " - "pmap (%s)" % (axis_name, _pmap_config.axis_name())) - devices = _pmap_config.devices() - if devices is None: - raise ValueError("Can't retrieve the device list from the surrounding pmap") - tensor = tf_np.asarray(tensor) - if tpu_devices(devices): - # TODO(b/170895907): Remove this workaround when tpu.cross_replica_sum - # supports int64/float64. - is_int64 = False - is_float64 = False - if tensor.dtype == np.int64: - is_int64 = True - tensor = tensor.astype(np.int32) - elif tensor.dtype == np.float64: - is_float64 = True - tensor = tensor.astype(np.float32) - # TODO(wangpeng): Supply the `group_assignment` argument to - # tpu.cross_replica_sum, calculated from `devices`. - tensor = tf.compat.v1.tpu.cross_replica_sum(tensor) - if is_int64: - tensor = tf.cast(tensor, tf.int64) - elif is_float64: - tensor = tf.cast(tensor, tf.float64) - else: - tensor = tf.raw_ops.CollectiveReduce( - input=tensor, - group_size=len(devices), - group_key=_GROUP_KEY, - instance_key=_get_instance_key(), - merge_op="Add", - final_op="Id", - subdiv_offsets=(0,)) - return tf_np.asarray(tensor) + """Sum all-reduction. + + Args: + tensor: A tensor. + axis_name: The axis name to reduce. Must equal to that of the surrounding + pmap. + + Returns: + The sum of the `tensor` replicas on each participating devices. 
+ """ + if axis_name != _pmap_config.axis_name(): + raise ValueError( + "axis_name (%s) is not equal to that of the surrounding " + "pmap (%s)" % (axis_name, _pmap_config.axis_name()) + ) + devices = _pmap_config.devices() + if devices is None: + raise ValueError("Can't retrieve the device list from the surrounding pmap") + tensor = tf_np.asarray(tensor) + if tpu_devices(devices): + # TODO(b/170895907): Remove this workaround when tpu.cross_replica_sum + # supports int64/float64. + is_int64 = False + is_float64 = False + if tensor.dtype == np.int64: + is_int64 = True + tensor = tensor.astype(np.int32) + elif tensor.dtype == np.float64: + is_float64 = True + tensor = tensor.astype(np.float32) + # TODO(wangpeng): Supply the `group_assignment` argument to + # tpu.cross_replica_sum, calculated from `devices`. + tensor = tf.compat.v1.tpu.cross_replica_sum(tensor) + if is_int64: + tensor = tf.cast(tensor, tf.int64) + elif is_float64: + tensor = tf.cast(tensor, tf.float64) + else: + tensor = tf.raw_ops.CollectiveReduce( + input=tensor, + group_size=len(devices), + group_key=_GROUP_KEY, + instance_key=_get_instance_key(), + merge_op="Add", + final_op="Id", + subdiv_offsets=(0,), + ) + return tf_np.asarray(tensor) def psum(tensors, axis_name=None): - return tf.nest.map_structure( - functools.partial(_psum, axis_name=axis_name), tensors) + return tf.nest.map_structure(functools.partial(_psum, axis_name=axis_name), tensors) # Note this is not available in the jax api, but seemed like a reasonable API # to have. def pmean(tensor, axis_name=None): - """Mean all-reduction. - - Args: - tensor: A tensor. - axis_name: The axis name to reduce. Must equal to that of the surrounding - pmap. - - Returns: - The mean of the `tensor` replicas on each participating devices. 
- """ - if axis_name != _pmap_config.axis_name(): - raise ValueError("axis_name (%s) is not equal to that of the surrounding " - "pmap (%s)" % (axis_name, _pmap_config.axis_name())) - devices = _pmap_config.devices() - if devices is None: - raise ValueError("Can't retrieve the device list from the surrounding pmap") - if tpu_devices(devices): - # TODO(wangpeng): Implement this. - raise ValueError("pmean for TPU is not supported yet.") - else: - return tf.raw_ops.CollectiveReduce( - input=tensor, - group_size=len(devices), - group_key=_GROUP_KEY, - instance_key=_get_instance_key(), - merge_op="Add", - final_op="Div", - subdiv_offsets=(0,)) + """Mean all-reduction. + + Args: + tensor: A tensor. + axis_name: The axis name to reduce. Must equal to that of the surrounding + pmap. + + Returns: + The mean of the `tensor` replicas on each participating devices. + """ + if axis_name != _pmap_config.axis_name(): + raise ValueError( + "axis_name (%s) is not equal to that of the surrounding " + "pmap (%s)" % (axis_name, _pmap_config.axis_name()) + ) + devices = _pmap_config.devices() + if devices is None: + raise ValueError("Can't retrieve the device list from the surrounding pmap") + if tpu_devices(devices): + # TODO(wangpeng): Implement this. + raise ValueError("pmean for TPU is not supported yet.") + else: + return tf.raw_ops.CollectiveReduce( + input=tensor, + group_size=len(devices), + group_key=_GROUP_KEY, + instance_key=_get_instance_key(), + merge_op="Add", + final_op="Div", + subdiv_offsets=(0,), + ) def _get_pmap_impl(f, devices, has_tpu): - """This is a helper function to return the pmap impl. - - Args: - f: a function that takes ndarrays and returns ndarrays. - devices: a list of strings; the device list. - has_tpu: boolean; whether `devices` contains TPU devices. - - Returns: - A function that takes tensors and returns tensors. 
- """ - if has_tpu: - # Workaround b/121383831 - output_is_list = [False] # Use list for mutability - def recorder(args, kwargs, res): - del args, kwargs - output_is_list[0] = isinstance(res, list) - return res - f = _record_result_type(recorder, f) - - def tf_f(*tf_args): - """A wrapper for `f` that takes/returns tensors.""" - np_args = _tf_to_np(tf_args) - np_out = f(*np_args) - return np_out - - if has_tpu: - - @tf.function(autograph=False) - def fn(inputs): - # TODO(wangpeng): Supply the `device_assignment` argument to - # tpu.replicate, calculated from `devices`. - res = tf.compat.v1.tpu.replicate(tf_f, inputs) - # Workaround b/121383831 - if (res and isinstance(res[0], list) and len(res[0]) == 1 and - not output_is_list[0]): - res = [x[0] for x in res] - return res - - return fn - else: - # This is run in a tf.function so that the various underlying functions can - # be run in parallel. - # The trace happens on the client, so any devices should not depend on any - # side effects. - - jit_tf_f = tf.function(tf_f, autograph=False) - - @tf.function(autograph=False) - def fn(all_per_device_args): - """Multi-device function with calls placed on the correct device.""" - - results = [] - for per_device_args, device in zip(all_per_device_args, devices): - with tf.device(device): - results.append(jit_tf_f(*per_device_args)) - return results - - return fn + """This is a helper function to return the pmap impl. + + Args: + f: a function that takes ndarrays and returns ndarrays. + devices: a list of strings; the device list. + has_tpu: boolean; whether `devices` contains TPU devices. + + Returns: + A function that takes tensors and returns tensors. 
+ """ + if has_tpu: + # Workaround b/121383831 + output_is_list = [False] # Use list for mutability + + def recorder(args, kwargs, res): + del args, kwargs + output_is_list[0] = isinstance(res, list) + return res + + f = _record_result_type(recorder, f) + + def tf_f(*tf_args): + """A wrapper for `f` that takes/returns tensors.""" + np_args = _tf_to_np(tf_args) + np_out = f(*np_args) + return np_out + + if has_tpu: + + @tf.function(autograph=False) + def fn(inputs): + # TODO(wangpeng): Supply the `device_assignment` argument to + # tpu.replicate, calculated from `devices`. + res = tf.compat.v1.tpu.replicate(tf_f, inputs) + # Workaround b/121383831 + if ( + res + and isinstance(res[0], list) + and len(res[0]) == 1 + and not output_is_list[0] + ): + res = [x[0] for x in res] + return res + + return fn + else: + # This is run in a tf.function so that the various underlying functions can + # be run in parallel. + # The trace happens on the client, so any devices should not depend on any + # side effects. + + jit_tf_f = tf.function(tf_f, autograph=False) + + @tf.function(autograph=False) + def fn(all_per_device_args): + """Multi-device function with calls placed on the correct device.""" + + results = [] + for per_device_args, device in zip(all_per_device_args, devices): + with tf.device(device): + results.append(jit_tf_f(*per_device_args)) + return results + + return fn def pmap(f, axis_name=None, devices=None): - """Transforms a function into a multi-device function. - - The semantics are similar to JAX's pmap. - - Args: - f: The function to be converted. - axis_name: Used for nested pmap, which is not supported yet. - devices: The devices over which the returned function will run. - - Returns: - A function that runs the underlying function `f` on `devices`. Its arguments - can be `ShardedNdArray`s, tensors or other Python objects, and its return - values are all `ShardedNdArray`s. 
If an input is a tensor, the length of its - first dimension must equal the number of devices, and the tensor will be - splitted along its first dimension among the devices. If an input is an - unknown Python object, it will be replicated among the devices. - """ - if devices is None: - devices = accelerators() - if not isinstance(devices, (list, tuple)): - raise ValueError("Must pass a list or tuple of devices") - num_devices = len(devices) - if not num_devices: - raise ValueError("There must be at least 1 device") - has_tpu = bool(tpu_devices(devices)) - - pmap_fn = _get_pmap_impl(f, devices, has_tpu) - - def wrapper(*args): - """Wrapper that wraps/unwraps args, retvals, and runs the function.""" - if _pmap_config.devices() is not None: - raise ValueError("Found a surrounding pmap. Nested pmap is not supported " - "yet.") - # TODO(wangpeng): Maybe we should use `asarray` to convert everything - # to ndarray first. - - flattened_input_args = tf.nest.flatten(args) - flattened_per_device_args = [[] for _ in devices] - for arg in flattened_input_args: - if isinstance(arg, tf.Tensor): - # TODO(nareshmodi): Try and use the dynamic shape instead. - if (not arg.shape.rank) or arg.shape[0] != len(devices): - # TODO(nareshmodi): Fix this restriction - raise ValueError( - "Input tensors need to have a first dimension equal to " - "the number of devices; got tensor of shape %s and %s devices" % - (arg.shape, len(devices))) - # NOTE: Alternatively use tf.split, and place the split tensors on the - # appropriate device. The best solution for this is to have an API that - # splits a tensor across devices. - for j, device in enumerate(devices): - updated_arg = tf.gather(arg, j) - # TODO(wangpeng): Investigate whether we need a tf.identity for TPU. 
- if not has_tpu: - with tf.device(device): - updated_arg = tf.identity(updated_arg) - flattened_per_device_args[j].append(updated_arg) - elif isinstance(arg, ShardedNdArray): - for device_args, tensor in zip(flattened_per_device_args, arg.tensors): - device_args.append(tensor) - else: - for device_args in flattened_per_device_args: - device_args.append(arg) - - all_per_device_args = [ - tf.nest.pack_sequence_as(args, device_args) - for device_args in flattened_per_device_args - ] - - with pmap_config(axis_name, devices): - results = pmap_fn(all_per_device_args) - - # Rewrap things. This can probably be written better. - flattened_results = [tf.nest.flatten(result) for result in results] - final_tree = [] - - # TODO(nareshmodi): assert all items in flattened_results have the same - # structures - - for i in range(len(flattened_results[0])): - tensors = [] - for j, device in enumerate(devices): - assert isinstance( - flattened_results[j][i], - tf.Tensor), ("currently only tensor return items are supported") - tensors.append(flattened_results[j][i]) - final_tree.append(ShardedNdArray(tensors)) - - return tf.nest.pack_sequence_as(results[0], final_tree) - - return wrapper + """Transforms a function into a multi-device function. + + The semantics are similar to JAX's pmap. + + Args: + f: The function to be converted. + axis_name: Used for nested pmap, which is not supported yet. + devices: The devices over which the returned function will run. + + Returns: + A function that runs the underlying function `f` on `devices`. Its arguments + can be `ShardedNdArray`s, tensors or other Python objects, and its return + values are all `ShardedNdArray`s. If an input is a tensor, the length of its + first dimension must equal the number of devices, and the tensor will be + splitted along its first dimension among the devices. If an input is an + unknown Python object, it will be replicated among the devices. 
+ """ + if devices is None: + devices = accelerators() + if not isinstance(devices, (list, tuple)): + raise ValueError("Must pass a list or tuple of devices") + num_devices = len(devices) + if not num_devices: + raise ValueError("There must be at least 1 device") + has_tpu = bool(tpu_devices(devices)) + + pmap_fn = _get_pmap_impl(f, devices, has_tpu) + + def wrapper(*args): + """Wrapper that wraps/unwraps args, retvals, and runs the function.""" + if _pmap_config.devices() is not None: + raise ValueError( + "Found a surrounding pmap. Nested pmap is not supported " "yet." + ) + # TODO(wangpeng): Maybe we should use `asarray` to convert everything + # to ndarray first. + + flattened_input_args = tf.nest.flatten(args) + flattened_per_device_args = [[] for _ in devices] + for arg in flattened_input_args: + if isinstance(arg, tf.Tensor): + # TODO(nareshmodi): Try and use the dynamic shape instead. + if (not arg.shape.rank) or arg.shape[0] != len(devices): + # TODO(nareshmodi): Fix this restriction + raise ValueError( + "Input tensors need to have a first dimension equal to " + "the number of devices; got tensor of shape %s and %s devices" + % (arg.shape, len(devices)) + ) + # NOTE: Alternatively use tf.split, and place the split tensors on the + # appropriate device. The best solution for this is to have an API that + # splits a tensor across devices. + for j, device in enumerate(devices): + updated_arg = tf.gather(arg, j) + # TODO(wangpeng): Investigate whether we need a tf.identity for TPU. 
+ if not has_tpu: + with tf.device(device): + updated_arg = tf.identity(updated_arg) + flattened_per_device_args[j].append(updated_arg) + elif isinstance(arg, ShardedNdArray): + for device_args, tensor in zip(flattened_per_device_args, arg.tensors): + device_args.append(tensor) + else: + for device_args in flattened_per_device_args: + device_args.append(arg) + + all_per_device_args = [ + tf.nest.pack_sequence_as(args, device_args) + for device_args in flattened_per_device_args + ] + + with pmap_config(axis_name, devices): + results = pmap_fn(all_per_device_args) + + # Rewrap things. This can probably be written better. + flattened_results = [tf.nest.flatten(result) for result in results] + final_tree = [] + + # TODO(nareshmodi): assert all items in flattened_results have the same + # structures + + for i in range(len(flattened_results[0])): + tensors = [] + for j, device in enumerate(devices): + assert isinstance( + flattened_results[j][i], tf.Tensor + ), "currently only tensor return items are supported" + tensors.append(flattened_results[j][i]) + final_tree.append(ShardedNdArray(tensors)) + + return tf.nest.pack_sequence_as(results[0], final_tree) + + return wrapper def find_devices(device_type, devices=None): - if not devices: - devices = [d.name for d in tf.config.experimental.list_logical_devices()] - devices = [(d, tf.DeviceSpec.from_string(d)) for d in devices] - results = [name for name, d in devices if d.device_type == device_type] - return results + if not devices: + devices = [d.name for d in tf.config.experimental.list_logical_devices()] + devices = [(d, tf.DeviceSpec.from_string(d)) for d in devices] + results = [name for name, d in devices if d.device_type == device_type] + return results def tpu_devices(devices=None): - """Gets TPU devices out of `devices`. + """Gets TPU devices out of `devices`. - Args: - devices: A device list (as a list of strings). If None, the list of all - available devices will be used for it. 
+ Args: + devices: A device list (as a list of strings). If None, the list of all + available devices will be used for it. - Returns: - Those in `devices` that are TPUs. - """ - return find_devices("TPU", devices) + Returns: + Those in `devices` that are TPUs. + """ + return find_devices("TPU", devices) def gpu_devices(devices=None): - """Gets GPU devices out of `devices`. + """Gets GPU devices out of `devices`. - Args: - devices: A device list (as a list of strings). If None, the list of all - available devices will be used for it. + Args: + devices: A device list (as a list of strings). If None, the list of all + available devices will be used for it. - Returns: - Those in `devices` that are GPUs. - """ - return find_devices("GPU", devices) + Returns: + Those in `devices` that are GPUs. + """ + return find_devices("GPU", devices) def accelerators(devices=None): - return tpu_devices(devices) or gpu_devices(devices) + return tpu_devices(devices) or gpu_devices(devices) def _tree_broadcast(to, s): - """Broadcasts `s` to the nested structure `to`.""" - if not isinstance(to, (list, tuple, dict)): - if not isinstance(s, (int, type(None))): - raise ValueError - return s - if isinstance(s, (int, type(None))): - return tf.nest.map_structure(lambda x: s, to) - if isinstance(to, (list, tuple)): - if len(to) != len(s): - raise ValueError - new_s = [_tree_broadcast(x, y) for x, y in zip(to, s)] - if isinstance(to, tuple): - new_s = tuple(new_s) - return new_s - elif isinstance(to, dict): - return {k: _tree_broadcast(to[k], s[k]) for k in to} - else: - raise TypeError("Unsupported type %s" % type(to)) + """Broadcasts `s` to the nested structure `to`.""" + if not isinstance(to, (list, tuple, dict)): + if not isinstance(s, (int, type(None))): + raise ValueError + return s + if isinstance(s, (int, type(None))): + return tf.nest.map_structure(lambda x: s, to) + if isinstance(to, (list, tuple)): + if len(to) != len(s): + raise ValueError + new_s = [_tree_broadcast(x, y) for x, y in 
zip(to, s)] + if isinstance(to, tuple): + new_s = tuple(new_s) + return new_s + elif isinstance(to, dict): + return {k: _tree_broadcast(to[k], s[k]) for k in to} + else: + raise TypeError("Unsupported type %s" % type(to)) def vmap(f, in_axes=0, out_axes=0): - """Returns a function that maps `f` over first dimension of inputs.""" - in_axes_flat = tf.nest.flatten(in_axes) - if not all(isinstance(l, (type(None), int)) - for l in in_axes_flat): - raise TypeError( - "vmap in_axes must be an int, None, or (nested) container with " - "those types as leaves, but got {}.".format(in_axes)) - if all(isinstance(l, type(None)) for l in in_axes_flat): - raise ValueError("vmap must have at least one non-None value in in_axes") - - out_axes_flat = tf.nest.flatten(out_axes) - if not all(isinstance(l, (type(None), int)) - for l in out_axes_flat): - raise TypeError( - "vmap out_axes must be an int, None, or (nested) container with " - "those types as leaves, but got {}.".format(out_axes)) - - def _f(*args): - flat_args = tf.nest.flatten(args) - try: - f_in_axes = _tree_broadcast(args, in_axes) - except ValueError: - six.reraise( - ValueError, - ValueError( - "vmap in_axes specification must be a tree prefix of the " - r"corresponding value, got specification %s for value tree %s" % ( - in_axes, args)), - sys.exc_info()[2]) - f_in_axes_flat = tf.nest.flatten(f_in_axes) - - def tf_f(tf_args): - """Function passed to tf.vectorized_map call.""" - # Note that unbatched arguments are not passed to tf_f. Here we fill thos - # arguments back before calling `f`. - tf_flat_args = [] - j = 0 - for arg, axis in zip(flat_args, f_in_axes_flat): - if axis is None: - tf_flat_args.append(arg) - else: - tf_flat_args.append(tf_args[j]) - j += 1 - unbatched_args = tf.nest.pack_sequence_as(args, tf_flat_args) - return f(*unbatched_args) - - # Constructs arguments to pass to `tf_f`. - # Unbatch arguments are skipped. Arguments with non-zero axis are - # transposed. 
- tf_args = [] - for arg, axis in zip(flat_args, f_in_axes_flat): - if axis is None: - continue - arg = tf_np.asarray(arg) - if axis != 0: - arg = tf_np.moveaxis(arg, axis, 0) - tf_args.append(arg) - # TODO(agarwal): consider creating a tf.function outside of _f and reusing - # that to avoid overheads of re-vectorizing the code when running eagerly. - outputs = tf.vectorized_map(tf_f, tf_args) - try: - f_out_axes = _tree_broadcast(outputs, out_axes) - except ValueError: - six.reraise( - ValueError, - ValueError( - "vmap out_axes specification must be a tree prefix of the " - r"corresponding value, got specification %s for value tree %s" % ( - out_axes, outputs)), - sys.exc_info()[2]) - - def map_output(x, axis): - """Maps output of tf.vectorized_map to the final output.""" - x = tf_np.asarray(x) - if axis is None: - # Note that `tf.vectorized_map always batches the outputs. - # Here we unbatch it again. - return x[0, ...] - elif axis == 0: - return x - else: - # Need to transpose the output. 
- return tf_np.moveaxis(x, 0, axis) - new_outputs = [map_output(output, axis) for output, axis in zip( - tf.nest.flatten(outputs), tf.nest.flatten(f_out_axes))] - return tf.nest.pack_sequence_as(outputs, new_outputs) - - return _f + """Returns a function that maps `f` over first dimension of inputs.""" + in_axes_flat = tf.nest.flatten(in_axes) + if not all(isinstance(l, (type(None), int)) for l in in_axes_flat): + raise TypeError( + "vmap in_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(in_axes) + ) + if all(isinstance(l, type(None)) for l in in_axes_flat): + raise ValueError("vmap must have at least one non-None value in in_axes") + + out_axes_flat = tf.nest.flatten(out_axes) + if not all(isinstance(l, (type(None), int)) for l in out_axes_flat): + raise TypeError( + "vmap out_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(out_axes) + ) + + def _f(*args): + flat_args = tf.nest.flatten(args) + try: + f_in_axes = _tree_broadcast(args, in_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap in_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" + % (in_axes, args) + ), + sys.exc_info()[2], + ) + f_in_axes_flat = tf.nest.flatten(f_in_axes) + + def tf_f(tf_args): + """Function passed to tf.vectorized_map call.""" + # Note that unbatched arguments are not passed to tf_f. Here we fill thos + # arguments back before calling `f`. + tf_flat_args = [] + j = 0 + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + tf_flat_args.append(arg) + else: + tf_flat_args.append(tf_args[j]) + j += 1 + unbatched_args = tf.nest.pack_sequence_as(args, tf_flat_args) + return f(*unbatched_args) + + # Constructs arguments to pass to `tf_f`. + # Unbatch arguments are skipped. Arguments with non-zero axis are + # transposed. 
+ tf_args = [] + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + continue + arg = tf_np.asarray(arg) + if axis != 0: + arg = tf_np.moveaxis(arg, axis, 0) + tf_args.append(arg) + # TODO(agarwal): consider creating a tf.function outside of _f and reusing + # that to avoid overheads of re-vectorizing the code when running eagerly. + outputs = tf.vectorized_map(tf_f, tf_args) + try: + f_out_axes = _tree_broadcast(outputs, out_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap out_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" + % (out_axes, outputs) + ), + sys.exc_info()[2], + ) + + def map_output(x, axis): + """Maps output of tf.vectorized_map to the final output.""" + x = tf_np.asarray(x) + if axis is None: + # Note that `tf.vectorized_map always batches the outputs. + # Here we unbatch it again. + return x[0, ...] + elif axis == 0: + return x + else: + # Need to transpose the output. + return tf_np.moveaxis(x, 0, axis) + + new_outputs = [ + map_output(output, axis) + for output, axis in zip( + tf.nest.flatten(outputs), tf.nest.flatten(f_out_axes) + ) + ] + return tf.nest.pack_sequence_as(outputs, new_outputs) + + return _f diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py deleted file mode 100644 index 065dbdeef..000000000 --- a/trax/tf_numpy/extensions/extensions_test.py +++ /dev/null @@ -1,1060 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy mathematical methods.""" -import functools -import itertools - -from absl import flags -from absl.testing import parameterized - -from jax import lax -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import extensions -import trax.tf_numpy.numpy as tf_np - - -FLAGS = flags.FLAGS - -flags.DEFINE_bool("requires_tpu", False, "Requires TPU.") - - -def generate_params_inputs_targets(num_examples=1000): - params = (tf_np.asarray(tf.constant(5.)), tf_np.asarray(tf.constant(0.))) - - params_true = (tf_np.asarray(tf.constant(3.)), tf_np.asarray(tf.constant(2.))) - - inputs = tf_np.asarray(tf.random.normal(shape=[num_examples])) - noise = tf_np.asarray(tf.random.normal(shape=[num_examples])) - targets = inputs * params_true[0] + params_true[1] + noise - - return params, params_true, inputs, targets - - -def loss_fn(params, inputs, targets): - predicted = params[0] * inputs + params[1] - loss = tf.reduce_mean(input_tensor=tf.square(predicted - targets)) - return tf_np.asarray(loss) - - -def train_step(params, inputs, targets, learning_rate=0.1): - grad_fn = extensions.grad(loss_fn) - grads = grad_fn(params, inputs, targets) - new_w = params[0] - (grads[0] * learning_rate) - new_b = params[1] - (grads[1] * learning_rate) - - return new_w, new_b - - -def uniform(rng, shape, dtype): - if np.issubdtype(dtype, np.integer): - minval = None - else: - minval = 0 - return tf_np.asarray(rng.uniform(shape=shape, dtype=dtype, minval=minval)) - - -def to_np(a): - return tf.nest.map_structure(tf_np.asarray, a) - - -def to_tf_fn(f): - return lambda *args: f(*to_np(args)) - - -def scan_reference(f, init, xs): - carry = init - ys = [] - for x in xs: - (carry, y) = f(carry, x) - ys.append(tf_np.reshape(y, (1,) + y.shape)) - ys = tf_np.concatenate(ys, 0) - return carry, ys - - -def spec(*args): - return tf.TensorSpec(args, tf.float32) - - 
-class ExtensionsTest(tf.test.TestCase, parameterized.TestCase): - - def __init__(self, methodName="runTest"): # pylint: disable=invalid-name - super().__init__(methodName) - physical_devices = tf.config.experimental.list_physical_devices("CPU") - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], [ - tf.config.experimental.VirtualDeviceConfiguration(), - tf.config.experimental.VirtualDeviceConfiguration() - ]) - if extensions.tpu_devices(): - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local") - tf.tpu.experimental.initialize_tpu_system(resolver) - - def _hasGPU(self): - physical_devices = tf.config.experimental.list_physical_devices("GPU") - return physical_devices - - def testCustomGrad(self): - """Test for custom_grad.""" - x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) - y_shape = (tf.TensorShape([])) - dtype = np.float32 - scale1 = 5.0 - scale2 = 6.0 - - def fwd(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - @extensions.custom_grad - def f(a, b): - y = fwd(a, b) - - def vjp(dy): - return dy * scale1 * a, dy * scale2 * b - - return y, vjp - - rng = tf.random.Generator.from_seed(1234) - x, dy = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype), - [x_shape, y_shape]) - expected_y = fwd(*x) - expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1]) - y, vjp = extensions.vjp(f, *x) - dx = vjp(dy) - self.assertAllClose(expected_y, y) - self.assertAllClose(expected_dx, dx) - - @parameterized.named_parameters([ - ( # pylint: disable=g-complex-comprehension - ("_%s_%s_%s" % (decorator_id, x_struct, y_struct)).replace( - " ", "").replace("None", ""), decorator, x_struct, y_struct) - for y_struct in [[None, ()], (None, (), [], (None, ((), None)))] - for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))] - for decorator_id, decorator in enumerate([lambda f: f, extensions.jit]) - ]) - def testCustomGradStructure(self, decorator, x_struct, y_struct): - """Tests that 
custom_grad can handle structured inputs/outputs.""" - - def zeros(x): - return tf.nest.map_structure(lambda _: tf_np.zeros([], np.float32), x) - - def get_struct(x): - return tf.nest.map_structure(lambda _: None, x) - - @extensions.custom_grad - def f(*x): - del x - - def vjp(dy): - self.assertEqual(y_struct, get_struct(dy)) - return zeros(x_struct) - - return zeros(y_struct), vjp - - x, dy = zeros([x_struct, y_struct]) - - @decorator - def run(x, dy): - y, vjp = extensions.vjp(f, *x) - dx = vjp(dy) - return dx, y - - dx, y = run(x, dy) - self.assertEqual(x_struct, get_struct(dx)) - self.assertEqual(y_struct, get_struct(y)) - - @parameterized.named_parameters([ - ("_%s" % has_aux, has_aux) for has_aux in [True, False] - ]) - def testVjp(self, has_aux): - x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) - y_shape = (tf.TensorShape([])) - dtype = np.float32 - - def f(a, b): - y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - if has_aux: - return y, tf_np.asarray(1) - else: - return y - - rng = tf.random.Generator.from_seed(1234) - x, dy_list = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype), - [x_shape, [y_shape] * 2]) - tf_x = x - outputs = extensions.vjp(f, *x, has_aux=has_aux) - if has_aux: - y, vjp, aux = outputs - else: - y, vjp = outputs - with tf.GradientTape(persistent=True) as tape: - tape.watch(tf_x) - outputs = f(*x) - if has_aux: - expected_y, expected_aux = outputs - self.assertAllClose(expected_aux, aux) - else: - expected_y = outputs - self.assertAllClose(expected_y, y) - for dy in dy_list: - expected_dx = tape.gradient( - expected_y, tf_x, output_gradients=dy) - self.assertAllClose(expected_dx, vjp(dy)) - - def testGrad(self): - - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - g = extensions.grad(f) - - def compare(a, b): - with tf.GradientTape() as tape: - tape.watch(a) - r = f(a, b) - expected = tape.gradient(r, a) - self.assertAllEqual(expected, g(a, b)) - - shape = [10] - a = tf_np.random.randn(*shape) - b = 
tf_np.random.randn(*shape) - compare(a, b) - - def testGradNonArrayOutput(self): - - def f(_): - return 1.0 - - g = extensions.grad(f) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be an ndarray"): - g(tf_np.asarray(1.0)) - - def testGradNonScalarOutput(self): - - def f(a): - return a - - g = extensions.grad(f) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be a scalar"): - g(tf_np.asarray([1.0, 2.0])) - - @extensions.jit - def g_jitted(a): - return extensions.grad(f)(a) - - g_jitted(tf_np.asarray(1.0)) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be a scalar"): - g_jitted(tf_np.asarray([1.0, 2.0])) - - def testJit(self): - - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_jitted = extensions.jit(f) - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - self.assertAllClose(f(a, b), f_jitted(a, b)) - # Call again since the code path is different on second call - self.assertAllClose(f(a, b), f_jitted(a, b)) - - def testJitNoUnnecessaryTracing(self): - - def num_traces(f): - return len(f.tf_function._list_all_concrete_functions_for_serialization()) - - def check_trace_only_once(arg1, arg2): - - @extensions.jit - def f(a): - return a + 1 - - self.assertAllEqual(0, num_traces(f)) - f(arg1) - self.assertAllEqual(1, num_traces(f)) - f(arg2) - self.assertAllEqual(1, num_traces(f)) - - check_trace_only_once(1, 2) - check_trace_only_once(1.1, 2.1) - check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) - check_trace_only_once( - tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)) - - def _testEvalOnShapes(self, transformer, allow_static_outputs): - - # A class that's not convertable to tensor - class Thing: - - def __init__(self, value): - self.value = value - - def f(a, b, reverse=False): - res = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - res = (res, 10) - if allow_static_outputs: - res = res + (Thing(20),) - if reverse: - res = 
tuple(reversed(res)) - return res - - f_prime = transformer( - f, static_argnums=(2,), allow_static_outputs=allow_static_outputs) - shape = [10] - dtype = np.float16 - a = tf_np.zeros(shape=shape, dtype=dtype) - b = tf_np.zeros(shape=shape, dtype=dtype) - expected, *_ = f(a, b) - got = f_prime(a, b) - def check(got): - self.assertIsInstance(got[0], (tf.TensorSpec, tf_np.ndarray)) - self.assertAllEqual(expected.shape, got[0].shape) - self.assertAllEqual(expected.dtype, got[0].dtype) - if allow_static_outputs: - self.assertIsInstance(got[1], int) - self.assertEqual(10, got[1]) - self.assertIsInstance(got[2], Thing) - self.assertEqual(20, got[2].value) - else: - self.assertIsInstance(got[1], (tf.TensorSpec, tf_np.ndarray)) - self.assertAllEqual((), got[1].shape) - check(got) - # Call again since the code path is different on second call - got = f_prime(a, b) - check(got) - # Retrace and check again - got = f_prime(a, b, True) - check(tuple(reversed(got))) - got = f_prime(a, b, True) - check(tuple(reversed(got))) - - @parameterized.named_parameters(("_%s" % b, b) for b in [False, True]) - def testEvalOnShapes(self, allow_static_outputs): - self._testEvalOnShapes(extensions.eval_on_shapes, allow_static_outputs) - - def testEvalOnShapesNested(self): - transformer = functools.partial(extensions.eval_on_shapes, - allow_static_outputs=True) - @transformer - def outer(): - @transformer - def inner(): - return 1 - return inner() + 2 - r = outer() - self.assertIsInstance(r, int) - self.assertEqual(3, r) - - def testJitOfEvalOnShapes(self): - """Tests that eval_on_shapes can be called within jit.""" - - def transformer(f, **kwargs): - def f_prime(*args): - res = extensions.eval_on_shapes(f, **kwargs)(*args) - return tf.nest.map_structure( - lambda x: tf_np.zeros(x.shape, x.dtype), res) - return extensions.jit(f_prime, kwargs.get("static_argnums", ())) - - self._testEvalOnShapes(transformer, False) - - def testEvalOnShapesNoUnnecessaryTracing(self): - - def num_traces(f): - 
return len( - f._tf_function._list_all_concrete_functions_for_serialization()) - - def check_trace_only_once(arg1, arg2): - - @extensions.eval_on_shapes - def f(a): - return a + 1 - - self.assertAllEqual(0, num_traces(f)) - f(arg1) - self.assertAllEqual(1, num_traces(f)) - f(arg2) - self.assertAllEqual(1, num_traces(f)) - - check_trace_only_once(1, 2) - check_trace_only_once(1.1, 2.1) - check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) - check_trace_only_once( - tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)) - - @parameterized.parameters( - { - "lhs_np": np.ones((5, 3)), - "rhs_np": np.ones((3, 2)), - "dims": (((1,), (0,)), ((), ())) - }, - { - "lhs_np": np.ones((5, 3)), - "rhs_np": np.ones((5, 3)), - "dims": (((0, 1), (0, 1)), ((), ())) - }, - { - "lhs_np": np.ones((5, 3, 2)), - "rhs_np": np.ones((2, 3, 2)), - "dims": (((1, 2), (1, 0)), ((), ())) - }, - { - "lhs_np": np.ones((6, 5, 3)), - "rhs_np": np.ones((6, 3, 2)), - "dims": (((2,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((6, 3, 5)), - "rhs_np": np.ones((6, 3, 2)), - "dims": (((1,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((5, 3, 2, 2)), - "rhs_np": np.ones((5, 2, 2, 6)), - "dims": (((2, 3), (1, 2)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((2, 2, 5, 3)), - "rhs_np": np.ones((2, 2, 3, 2)), - "dims": (((3,), (2,)), ((0, 1), (0, 1))) - }, - { - "lhs_np": np.ones((2, 2, 5, 2)), - "rhs_np": np.ones((2, 2, 3, 2)), - "dims": (((3,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((2, 2, 5, 3, 3)), - "rhs_np": np.ones((2, 3, 2, 3, 2)), - "dims": (((4,), (1,)), ((0,), (0,))) - }, - ) - def test_tf_dot_general(self, lhs_np, rhs_np, dims): - ans = lax.dot_general(lhs_np, rhs_np, dims) - result = extensions.tf_dot_general(lhs_np, rhs_np, dims) - self.assertAllClose(result, np.array(ans)) - - @parameterized.named_parameters([ - ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}" # pylint: disable=g-complex-comprehension - "_lhs_dilation={}_rhs_dilation={}" - 
"_feature_group_count={}_batch_group_count={}_dims={}" - "_perms={}".format(lhs_shape, rhs_shape, - strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, ",".join( - dimension_numbers), perms), - lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, dimension_numbers, perms) - for batch_group_count, feature_group_count in [(1, 1)] - for lhs_shape, rhs_shape in [ - ((b * batch_group_count, i * feature_group_count, 9, w), - (j * feature_group_count * batch_group_count, i, 4, 5)) - for w in [0, 10] - for b, i, j in itertools.product([2, 3], repeat=3)] - for strides in [(1, 1), (2, 1)] - for padding in ["SAME"] - for lhs_dilation, rhs_dilation in [ - (None, (1, 1)) - ] - for dimension_numbers, perms in [ - (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) - ]]) - def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, - padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, - dimension_numbers, perms): - lhs_perm, rhs_perm = perms # permute to compatible shapes - - lhs = np.transpose(np.ones(lhs_shape), lhs_perm) - rhs = np.transpose(np.ones(rhs_shape), rhs_perm) - - jax_conv = lax.conv_general_dilated(lhs, rhs, strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count, - batch_group_count) - - tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides, - padding, None, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count, - batch_group_count) - - self.assertAllClose(tf_conv, tf_np.asarray(jax_conv)) - - def testConv(self): - y = extensions.conv( - np.ones([5, 320, 480, 3], dtype=np.float32), - np.ones([3, 4, 3, 11], dtype=np.float32), [1, 1], "SAME", - ("NHWC", "HWIO", "NHWC")) - self.assertAllClose(y.shape, [5, 320, 480, 11]) - self.assertAllClose( - y, - tf.nn.conv2d( - input=tf.ones([5, 320, 480, 3], dtype=tf.float32), - filters=tf.ones([3, 4, 3, 11], dtype=tf.float32), - strides=1, 
- padding="SAME")) - - def testAvgPool(self): - y = extensions.avg_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") - self.assertAllEqual( - y, - tf.nn.pool( - input=tf.ones([5, 320, 480, 3]), - window_shape=[3, 5], - pooling_type="AVG", - padding="VALID", - strides=[2, 3], - )) - - def testMaxPool(self): - y = extensions.max_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") - self.assertAllEqual( - y, - tf.nn.pool( - input=tf.ones([5, 320, 480, 3]), - window_shape=[3, 5], - pooling_type="MAX", - padding="VALID", - strides=[2, 3], - )) - - def assertDTypesEqual(self, a, b): - get_dtype = lambda t: t.dtype - self.assertEqual(tf.nest.map_structure(get_dtype, a), - tf.nest.map_structure(get_dtype, b)) - - @parameterized.named_parameters( - (f"_{jit_scan}_{jit_f}", jit_scan, jit_f) # pylint: disable=g-complex-comprehension - for jit_f in [False, True] - for jit_scan in ["no", "no_xla", "xla_forced_compile"]) - def testScanImpl(self, jit_scan, jit_f): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c, a): - assert a.shape == (3,) - assert c.shape == (4,) - b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) + - tf_np.sum(tf_np.tan(d))) - c = tf_np.sin(c * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - return c, b - - if jit_f: - f = extensions.jit(f) - - if jit_scan == "no_xla": - scan = extensions.jit(extensions.scan, static_argnums=(0,)) - elif jit_scan == "xla_forced_compile": - scan = extensions.jit(extensions.scan, static_argnums=(0,), - xla_forced_compile=True) - else: - scan = extensions.scan - - xs = rng.randn(5, 3) - c = rng.randn(4) - - ans = scan(f, c, xs) - expected = scan_reference(f, c, xs) - if jit_scan == "xla_forced_compile": - # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so - # we canonicalize them to lists here. 
- expected = list(expected) - ans = list(ans) - self.assertDTypesEqual(expected, ans) - self.assertAllClose(expected, ans) - - def testScanStruct(self): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c_g_i, a_e_h): - c_g, i = c_g_i - c, g = c_g - a, e_h = a_e_h - e, h = e_h - assert a.shape == (3,) - assert e.shape == () # pylint: disable=g-explicit-bool-comparison - assert c.shape == (4,) - assert g.shape == (2,) - assert i is None - assert h is None - b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) + - tf_np.sum(tf_np.tan(d))) - f = tf_np.cos(a) - c = tf_np.sin(c * b) - g = tf_np.sin(g * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - assert f.shape == (3,) - return [(c, g), i], (b, [f, h]) - - xs = (rng.randn(5, 3), [rng.randn(5), None]) - init = [(rng.randn(4), rng.randn(2)), None] - - c_g_i, b_f_h = extensions.scan(f, init, xs) - self.assertIsInstance(c_g_i, list) - self.assertIsInstance(b_f_h, tuple) - c_g, i = c_g_i - c, g = c_g - self.assertIsInstance(c_g, tuple) - self.assertEqual((4,), c.shape) - self.assertEqual((2,), g.shape) - self.assertIsNone(i) - b, f_h = b_f_h - f, h = f_h - self.assertIsInstance(f_h, list) - self.assertEqual((5,), b.shape) - self.assertEqual((5, 3), f.shape) - self.assertIsNone(h) - - @parameterized.named_parameters( - (f"_{jit_grad}_{jit_scan}_{jit_f}", jit_grad, jit_scan, jit_f) # pylint: disable=g-complex-comprehension - for jit_f in [False, True] - for jit_scan in ["no", "no_xla", "xla_forced_compile"] - for jit_grad in ["no", "no_xla", "xla_forced_compile"]) - def testScanGrad(self, jit_grad, jit_scan, jit_f): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c, a): - assert a.shape == (3,) - assert c.shape == (4,) - b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) + - tf_np.sum(tf_np.sin(d))) - c = tf_np.sin(c * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - return c, b - - if jit_f: - f = extensions.jit(f) - - if jit_scan == 
"no_xla": - scan = extensions.jit(extensions.scan, static_argnums=(0,)) - elif jit_scan == "xla_forced_compile": - # TODO(b/187107596): Remove `skipTest` - self.skipTest( - "Taking gradients of `jit(scan, experimental_compile=True)` triggers " - "'Support for TensorList crossing the XLA/TF boundary is not " - "implemented' error") - # `xla_forced_compile=True` doesn't support gradients, so we use - # `experimental_compile=True`. - scan = extensions.jit(extensions.scan, static_argnums=(0,), - experimental_compile=True) - else: - scan = extensions.scan - - xs = tf_np.asarray(rng.randn(5, 3)) - c = tf_np.asarray(rng.randn(4)) - - def losses(scan, c, xs): - c, ys = scan(f, c, xs) - return tf_np.concatenate(tf.nest.flatten(tf.nest.map_structure( - lambda a: tf_np.reshape(a, [-1]), (c, ys)))) - def loss(scan, c, xs): - return tf_np.sum(losses(scan, c, xs)) - - def grad_origin(c, xs): - return extensions.grad(functools.partial(loss, scan))(c, xs) - - if jit_grad == "no_xla": - grad_jit = extensions.jit(grad_origin) - elif jit_grad == "xla_forced_compile": - grad_jit = extensions.jit(grad_origin, xla_forced_compile=True) - else: - grad_jit = grad_origin - - ans = grad_jit(c, xs) - expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs) - self.assertDTypesEqual(expected, ans) - self.assertAllClose(expected, ans) - - theoretical, numerical = tf.test.compute_gradient( - to_tf_fn(functools.partial(losses, scan)), (c, xs)) - self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4) - - @parameterized.named_parameters( - (f"_{i}", *args) # pylint: disable=g-complex-comprehension - for i, args in enumerate([ - (lambda c, x: (c + 1, tf_np.sum(c + x, 0)), - [spec(2), spec(4, 3, 2)], [spec(2), spec(4, 2)]), - (lambda c, x: (c + 1, tf_np.sum(c + x, 0)), - [spec(2), spec(0, 3, 2), 0], [spec(2), spec(0, 2)]), - ])) - def testScanShape(self, f, inputs, expected_outputs): - outputs = extensions.eval_on_shapes( - functools.partial(extensions.scan, f), 
static_argnums=(2,))(*inputs) - self.assertAllEqual(expected_outputs, outputs) - - def testMap(self): - shape = [2, 3] - dtype = tf_np.int32 - xs1 = tf_np.zeros(shape, dtype) - xs2 = tf_np.ones(shape, dtype) - ys_expected = [xs2 + 10, xs1 + 20] - def f(x): - self.assertIsInstance(x, tuple) - for a in x: - self.assertEqual(a.shape, shape[1:]) - x1, x2 = x - return [x2 + 10, x1 + 20] - ys = extensions.tf_map(f, (xs1, xs2)) - self.assertIsInstance(ys, list) - self.assertAllClose(ys, ys_expected) - - def testPrng(self): - self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123)) - - def testUniform(self): - minval = 0.43 - maxval = 3.10 - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.uniform(123, shape, minval=minval, maxval=maxval) - self.assertAllClose((minval + maxval) / 2.0, np.mean(outputs), atol=atol) - - def testNormal(self): - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.normal(123, shape) - self.assertAllClose(0, np.mean(outputs), atol=atol) - self.assertAllClose(1, np.std(outputs), atol=atol) - - def testBernoulli(self): - mean = 0.23 - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.bernoulli(123, mean, shape) - self.assertAllClose(mean, np.mean(outputs), atol=atol) - - def testBernoulliWrongShape(self): - mean = [0.1, 0.2] - shape = [3] - with self.assertRaisesIncompatibleShapesError(): - extensions.bernoulli(123, mean, shape) - - def testDatasetAsNumpy(self): - arrs = extensions.dataset_as_numpy( - [tf.constant([1, 2]), tf.constant([3, 4])]) - for a in arrs: - self.assertIsInstance(a, tf_np.ndarray) - with self.assertRaisesWithPredicateMatch( - ValueError, - r"dataset_as_numpy must be run in eager mode outside tf.function"): - - @tf.function - def f(): - return extensions.dataset_as_numpy([tf.constant([1, 2])]) - - f() - - def _get_two_devices(self, require_same_type=False): - tpus = extensions.tpu_devices() - if FLAGS.requires_tpu: - if len(tpus) == 2: - res = tpus - else: - raise ValueError("This test 
requires 2 TPU cores but %s are found" % - len(tpus)) - else: - if len(tpus) == 2: - res = tpus - elif self._hasGPU() and not require_same_type: - res = ("CPU:0", "GPU:0") - else: - res = ("CPU:0", "CPU:1") - return res - - def testPmap(self): - devices = self._get_two_devices() - - @functools.partial(extensions.pmap, devices=devices) - def return_three(f): - return f, f + 1.0, f + 2.0 - - result = return_three(tf.ones((2, 20))) - # The function returned 3 items, so we got 3 items back. - self.assertLen(result, 3) - - # Each of the items should be a ShardedNdarray that when converted to tensor - # should produce a tensor of shape (2, 20) - converted = tf.nest.map_structure(tf.convert_to_tensor, result) - - self.assertLen(result, 3) - - self.assertAllEqual(converted[0].shape, converted[1].shape) - self.assertAllEqual(converted[0].shape, converted[2].shape) - - self.assertAllEqual(converted[0], tf.ones((2, 20))) - self.assertAllEqual(converted[1], 1 + tf.ones((2, 20))) - self.assertAllEqual(converted[2], 2 + tf.ones((2, 20))) - - @functools.partial(extensions.pmap, devices=devices) - def return_one(f): - return f + 2.0 - - result = return_one(tf.ones((2, 20))) - - # Only a single item is returned, so we can convert it directly. - converted = tf.convert_to_tensor(value=result) - self.assertAllEqual(converted, 2 + tf.ones((2, 20))) - - @functools.partial(extensions.pmap, devices=devices) - def return_list(f): - return [f + 2.0] - - result = return_list(tf.ones((2, 20))) - - # A singleton list is returned. - self.assertLen(result, 1) - converted = tf.convert_to_tensor(value=result[0]) - self.assertAllEqual(converted, 2 + tf.ones((2, 20))) - - def testGradSimpleModel(self): - params, params_true, inputs, targets = generate_params_inputs_targets() - - for _ in range(50): - params = train_step(params, inputs, targets) - - # This is not trained super well, but it usually gets "close". 
- self.assertAllClose(params[0], params_true[0], atol=1e-1) - self.assertAllClose(params[1], params_true[1], atol=1e-1) - - # NOTE: Compare to testGradSimpleModel to see the differences when pmapping. - def testPmapSimpleModel(self): - devices = self._get_two_devices(require_same_type=True) - n_devices = len(devices) - - params, params_true, inputs, targets = generate_params_inputs_targets() - - def _train_and_reduce(params, inputs, targets, learning_rate=0.1): - new_w, new_b = train_step(params, inputs, targets, learning_rate) - - return (extensions.psum(new_w) / n_devices, - extensions.psum(new_b) / n_devices) - - train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices) - - def replicate(x, num_devices=2): - return tf_np.broadcast_to(x, (num_devices,) + x.shape) - - params = tf.nest.map_structure(replicate, params) - - def reshape(x, num_devices=2): - x_shape = list(x.shape) - batch_size = x_shape[0] - batch_size_per_device = batch_size // num_devices - - # New shape. - new_shape_prefix = [num_devices, batch_size_per_device] - return tf_np.reshape(x, new_shape_prefix + x_shape[1:]) - - inputs = tf.nest.map_structure(reshape, inputs) - targets = tf.nest.map_structure(reshape, targets) - - for _ in range(50): - params = train_step_pmapped(params, inputs, targets) - - # PMAP returns sharded tensors. - - # Since the inputs are identical, the returned tensors should be identical - self.assertAllClose(params[0][0], params[0][1]) - self.assertAllClose(params[1][0], params[1][1]) - - # This is not trained super well, but it usually gets "close". 
- self.assertAllClose(params[0][0], params_true[0], atol=1e-1) - self.assertAllClose(params[1][0], params_true[1], atol=1e-1) - - def testPsum(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f) - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_sum, devices=devices) - result = pmapped(data) - - self.assertAllClose(result[0], 4) - self.assertAllClose(result[1], 4) - - def testPsumStruct(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(a): - a = extensions.psum(a) - tf.nest.map_structure( - lambda x: self.assertIsInstance(x, tf_np.ndarray), a) - return a - - data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)] - pmapped = extensions.pmap(reduce_sum, devices=devices) - result = pmapped(data) - - self.assertIsInstance(result[0][0], tf_np.ndarray) - self.assertIsInstance(result[0][1], tf_np.ndarray) - self.assertIsInstance(result[1][0], tf_np.ndarray) - self.assertIsInstance(result[1][1], tf_np.ndarray) - self.assertAllClose(result[0][0], 4) - self.assertAllClose(result[0][1], 4) - self.assertAllClose(result[1][0], 6) - self.assertAllClose(result[1][1], 6) - - def testPmean(self): - if extensions.tpu_devices(): - self.skipTest("pmean for TPU is not supported yet") - devices = self._get_two_devices(require_same_type=True) - - def reduce_mean(f): - return extensions.pmean(f) - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_mean, devices=devices) - result = pmapped(data) - - self.assertAllClose(result[0], 2) - self.assertAllClose(result[1], 2) - - def testAxisName(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f, axis_name="foo") - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) - pmapped(data) - - def 
testWrongAxisName(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f, axis_name="bar") - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - with self.assertRaisesWithPredicateMatch( - ValueError, r"axis_name (.*) is not equal to that of the surrounding"): - pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) - pmapped(data) - - def testNoNestedPmap(self): - devices = self._get_two_devices(require_same_type=True) - - def f(x): - return x + 1.0 - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - with self.assertRaisesWithPredicateMatch(ValueError, - r"Nested pmap is not supported"): - f = extensions.pmap(f, devices=devices) - f = extensions.pmap(f, devices=devices) - f(data) - - def testVmap(self): - fn1 = extensions.vmap(lambda z: z * z) - - x = tf_np.arange(10) - self.assertAllClose(x * x, fn1(x)) - - y = tf.range(10) - np_y = tf_np.asarray(y) - output = fn1(y) - self.assertIsInstance(output, tf_np.ndarray) - self.assertAllClose(np_y * np_y, output) - - fn2 = extensions.vmap(lambda x, y: x + y) - x = tf_np.random.randn(10, 3) - y = tf_np.random.randn(10, 2, 3) - self.assertAllClose(tf_np.expand_dims(x, 1) + y, fn2(x, y)) - - def testRemat(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.grad(f_remat)(a, b) - expected = extensions.grad(f)(a, b) - self.assertAllClose(actual, expected) - - def testRematLambdaFunction(self): - f = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.grad(f_remat)(a, b) - expected = extensions.grad(f)(a, b) - self.assertAllClose(actual, expected) - - def testRematJit(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) 
- - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.jit(extensions.grad(f_remat))(a, b) - expected = extensions.jit(extensions.grad(f))(a, b) - self.assertAllClose(actual, expected) - - def testRematJitXla(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.jit( - extensions.grad(f_remat), xla_forced_compile=True)(a, b) - expected = extensions.jit(extensions.grad(f), xla_forced_compile=True)(a, b) - self.assertAllClose(actual, expected) - - actual = extensions.jit( - extensions.grad(f_remat), experimental_compile=True)(a, b) - expected = extensions.jit( - extensions.grad(f), experimental_compile=True)(a, b) - self.assertAllClose(actual, expected) - - def testStaticStopGradient(self): - self.assertEqual(extensions.stop_gradient(5.), 5.) - self.assertEqual(type(extensions.stop_gradient(5.)), type(5.)) - - self.assertEqual(extensions.stop_gradient(tf_np.asarray(5.)), 5.) - self.assertNotEqual( - type(extensions.stop_gradient(tf_np.asarray(5.))), type(5.)) - - -if __name__ == "__main__": - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/jax_tests/config.py b/trax/tf_numpy/jax_tests/config.py deleted file mode 100644 index 5da9f1b1e..000000000 --- a/trax/tf_numpy/jax_tests/config.py +++ /dev/null @@ -1,151 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -def bool_env(varname: str, default: bool) -> bool: - """Read an environment variable and interpret it as a boolean. - - True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; - false values are 'n', 'no', 'f', 'false', 'off', and '0'. - - Args: - varname: the name of the variable - default: the default boolean value - Raises: ValueError if the environment variable is anything else. 
- """ - val = os.getenv(varname, str(default)) - val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): - return True - elif val in ('n', 'no', 'f', 'false', 'off', '0'): - return False - else: - raise ValueError("invalid truth value %r for environment %r" % (val, varname)) - - -class Config(object): - def __init__(self): - self.values = {} - self.meta = {} - self.FLAGS = NameSpace(self.read) - self.use_absl = False - - def update(self, name, val): - if self.use_absl: - setattr(self.absl_flags.FLAGS, name, val) - else: - self.check_exists(name) - if name not in self.values: - raise Exception("Unrecognized config option: {}".format(name)) - self.values[name] = val - - def read(self, name): - if self.use_absl: - return getattr(self.absl_flags.FLAGS, name) - else: - self.check_exists(name) - return self.values[name] - - def add_option(self, name, default, opt_type, meta_args, meta_kwargs): - if name in self.values: - raise Exception("Config option {} already defined".format(name)) - self.values[name] = default - self.meta[name] = (opt_type, meta_args, meta_kwargs) - - def check_exists(self, name): - if name not in self.values: - raise Exception("Unrecognized config option: {}".format(name)) - - def DEFINE_bool(self, name, default, *args, **kwargs): - self.add_option(name, default, bool, args, kwargs) - - def DEFINE_integer(self, name, default, *args, **kwargs): - self.add_option(name, default, int, args, kwargs) - - def DEFINE_string(self, name, default, *args, **kwargs): - self.add_option(name, default, str, args, kwargs) - - def DEFINE_enum(self, name, default, *args, **kwargs): - self.add_option(name, default, 'enum', args, kwargs) - - def config_with_absl(self): - # Run this before calling `app.run(main)` etc - import absl.flags as absl_FLAGS - from absl import app, flags as absl_flags - - self.use_absl = True - self.absl_flags = absl_flags - absl_defs = { bool: absl_flags.DEFINE_bool, - int: absl_flags.DEFINE_integer, - str: absl_flags.DEFINE_string, 
- 'enum': absl_flags.DEFINE_enum } - - for name, val in self.values.items(): - flag_type, meta_args, meta_kwargs = self.meta[name] - absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) - - app.call_after_init(lambda: self.complete_absl_config(absl_flags)) - - def complete_absl_config(self, absl_flags): - for name, _ in self.values.items(): - self.update(name, getattr(absl_flags.FLAGS, name)) - - def parse_flags_with_absl(self): - global already_configured_with_absl - if not already_configured_with_absl: - import absl.flags - self.config_with_absl() - absl.flags.FLAGS(sys.argv, known_only=True) - self.complete_absl_config(absl.flags) - already_configured_with_absl = True - - -class NameSpace(object): - def __init__(self, getter): - self._getter = getter - - def __getattr__(self, name): - return self._getter(name) - - -config = Config() -flags = config -FLAGS = flags.FLAGS - -already_configured_with_absl = False - -flags.DEFINE_bool( - 'jax_enable_checks', - bool_env('JAX_ENABLE_CHECKS', False), - help='Turn on invariant checking (core.skip_checks = False)') - -flags.DEFINE_bool('tf_numpy_additional_tests', True, - 'Run tests added specifically for TF numpy') diff --git a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py b/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py deleted file mode 100644 index cb583abae..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py +++ /dev/null @@ -1,359 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict # pylint: disable=g-importing-member -import itertools - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.jax_tests.config import config -import trax.tf_numpy.jax_tests.test_util as jtu -import trax.tf_numpy.numpy as jnp - - -config.parse_flags_with_absl() - - -class EinsumTest(jtu.TestCase): - - def _check(self, s, *ops): - a = np.einsum(s, *ops) - b = jnp.einsum(s, *ops) - self.assertAllClose(a, b, check_dtypes=True, atol=1e-4, rtol=1e-4) - - def test_three_operands_1(self): - r = self.rng() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_three_operands_2(self): - r = self.rng() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_two_operands_1(self): - r = self.rng() - x = r.randn(3, 4) - y = r.randn(4) - s = 'ij,j->i' - self._check(s, x, y) - - def test_two_operands_2(self): - r = self.rng() - x = r.randn(3, 4, 5) - y = r.randn(4) - s = 'ijk,j->i' - self._check(s, x, y) - - def test_two_operands_3(self): - r = self.rng() - x = r.randn(3, 4, 3) - y = r.randn(3) - s = 'iji,i->j' - self._check(s, x, y) - - def test_two_operands_4(self): - r = self.rng() - x = r.randn(3, 4) - y = r.randn(3, 4) - s = 'ij,ij->' - self._check(s, x, y) - - def test_two_operands_5(self): - r = self.rng() - x = r.randn(10, 2, 3) - y = r.randn(3, 4) - s = 'nij,jk->nik' - self._check(s, x, y) - - def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 - r = self.rng() - x = r.randn(2, 1) - y = r.randn(2, 3, 4) - s = 'sa,shb->shab' - self._check(s, x, y) - - def test_one_operand_1(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->j' - self._check(s, x) - - def 
test_one_operand_2(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->kij' - self._check(s, x) - - def test_one_operand_3(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_4(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_5(self): - r = self.rng() - x = r.randn(2, 3, 4, 5) - s = '...ijk->...ki' - self._check(s, x) - - def test_one_operand_6(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = '...ijk->ki' - self._check(s, x) - - def test_one_operand_7(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ii->' - self._check(s, x) - - def test_one_operand_8(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ij->' - self._check(s, x) - - def test_one_operand_9(self): - r = self.rng() - x = r.randn(3, 3, 3) - s = 'iii->' - self._check(s, x) - - def test_one_operand_10(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ii->i' - self._check(s, x) - - def test_one_operand_11(self): - r = self.rng() - x = r.randn(3, 3, 4) - s = 'iij->i' - self._check(s, x) - - def test_one_operand_12(self): - r = self.rng() - x = r.randn(3, 3, 3) - s = 'iii->i' - self._check(s, x) - - def test_one_operand_13(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->i' - self._check(s, x) - - def test_one_operand_14(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->ik' - self._check(s, x) - - def test_one_operand_15(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkl->il' - self._check(s, x) - - def test_one_operand_16(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ij->ij' - self._check(s, x) - - def test_tf_unsupported_1(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3, 5, 1) - y = r.randn(3, 4, 5, 1) - s = 'ij...,jk...->ik...' 
- self._check(s, x, y) - - def test_tf_unsupported_2(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3, 3) - y = r.randn(4) - s = 'ijj,k->ik' - self._check(s, x, y) - - def test_tf_unsupported_3(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3) - y = r.randn(2, 3) - z = r.randn(3, 4) - s = 'ij,ij,jk->ik' - self._check(s, x, y, z) - - # these tests are based on https://github.com/dask/dask/pull/3412/files - @parameterized.named_parameters( - {'testcase_name': '_{}_dtype={}'.format(einstr, dtype.__name__), # pylint: disable=g-complex-comprehension - 'einstr': einstr, 'dtype': dtype} - for einstr in [ - 'abc,bad->abcd', - 'abcdef,bcdfg->abcdeg', - 'ea,fb,abcd,gc,hd->efgh', - 'ab,b', - 'aa', - 'a,a->', - 'a,a->a', - 'a,a', - 'a,b', - 'a,b,c', - 'a', - 'ba,b', - 'ba,b->', - 'defab,fedbc->defac', - 'ab...,bc...->ac...', - 'a...a', - 'abc...->cba...', - '...ab->...a', - 'a...a->a...', - # Following 2 from # https://stackoverflow.com/a/19203475/1611416 - '...abc,...abcd->...d', - 'ab...,b->ab...', - # https://github.com/dask/dask/pull/3412#discussion_r182413444 - 'aa->a', - 'ab,ab,c->c', - 'aab,bc->ac', - 'aab,bcc->ac', - 'fdf,cdd,ccd,afe->ae', - 'fff,fae,bef,def->abd', - ] - # TODO(wangpeng): Add jnp.bool_ to dtype list - for dtype in [jnp.float32, jnp.int32, jnp.complex64]) - def test_from_dask(self, einstr, dtype): - r = jtu.rand_default() - if '->' in einstr: - input_str, _ = einstr.split('->') - else: - input_str = einstr - input_names = input_str.split(',') - - dims = itertools.cycle([2, 3, 4]) - shapes = defaultdict(lambda: next(dims)) - input_shapes = [tuple(shapes[c] for c in names.replace('...', '01')) - for names in input_names] - operands = [r(shape, dtype) for shape in input_shapes] - - self._check(einstr, *operands) - - def test_ordered_front_batch_dim_case(self): - x = np.ones((1, 8, 20, 4)) - y = np.ones((1, 8, 20, 4)) - s = 'ijkl,ijml->ijkm' - 
self._check(s, x, y) - - # pylint: disable=invalid-name - def test_einsum_path(self): - # just check examples from np.einsum_path docstring - a = self.rng().rand(2, 2) - b = self.rng().rand(2, 5) - c = self.rng().rand(5, 2) - - path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') - self.assertEqual(str(path_info[0]), "['einsum_path', (1, 2), (0, 1)]") - self.assertEqual(path_info[1].split('\n')[0], - ' Complete contraction: ij,jk,kl->il') - - # check this doesn't crash - I = self.rng().rand(10, 10, 10, 10) - C = self.rng().rand(10, 10) - np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy') - - @jtu.disable - def test_einsum_kpmurphy_example(self): - # code from an email with @murphyk - N = 2 - C = 3 - D = 4 - K = 5 - T = 6 - r = self.rng() - S = r.randn(N, T, K) - W = r.randn(K, D) - V = r.randn(D, C) - L = np.zeros((N, C)) - for n in range(N): - for c in range(C): - s = 0 - for d in range(D): - for k in range(K): - for t in range(T): - s += S[n, t, k] * W[k, d] * V[d, c] - L[n, c] = s - - path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0] - rtol = 1e-2 if jtu.device_under_test() == 'tpu' else None - self.assertAllClose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path), - check_dtypes=False, rtol=rtol) - # pylint: enable=invalid-name - - @jtu.disable - def test_contraction_broadcasting(self): - r = self.rng() - x = r.randn(3, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_batch_broadcasting(self): - r = self.rng() - x = r.randn(1, 4, 5) - y = r.randn(3, 5, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_batch_and_contraction_broadcasting(self): - r = self.rng() - x = r.randn(1, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_broadcasting_issue_2189(self): - r = self.rng() - x = r.randn(2, 1, 3, 3) - y = r.randn(2, 4, 3) - s = '...ij,...j' - self._check(s, x, y) - - -if 
__name__ == '__main__': - tf.enable_v2_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py b/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py deleted file mode 100644 index 7f0a13f03..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py +++ /dev/null @@ -1,1000 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import collections -import enum -from functools import partial -import itertools -import unittest - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as onp -import tensorflow.compat.v2 as tf - -import trax.tf_numpy.extensions as npe -import trax.tf_numpy.numpy as jnp - -from trax.tf_numpy.jax_tests.config import config -import trax.tf_numpy.jax_tests.test_util as jtu - -config.parse_flags_with_absl() - - -# We disable the whitespace continuation check in this file because otherwise it -# makes the test name formatting unwieldy. -# pylint: disable=bad-continuation -# We also disable undefined-variable till we start enabling tests. -# pylint: disable=undefined-variable - - -def subvals(lst, replace): - lst = list(lst) - for i, v in replace: - lst[i] = v - return tuple(lst) - - -float_dtypes = [onp.float32, onp.float64] -int_dtypes = [onp.int32, onp.int64] -bool_types = [onp.bool_] -default_dtypes = float_dtypes + int_dtypes -all_dtypes = float_dtypes + int_dtypes + bool_types - -IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"]) - - -suppress_deprecated_indexing_warnings = partial( - jtu.ignore_warning, category=FutureWarning, - message='Using a non-tuple sequence.*') - - -STATIC_INDEXING_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(10,), indexer=slice(1, 3)), - IndexSpec(shape=(10,), indexer=slice(1, -1)), - IndexSpec(shape=(10,), indexer=slice(None, -1)), - IndexSpec(shape=(10,), indexer=slice(None, None, None)), - IndexSpec(shape=(10, 8), indexer=slice(1, 
3)), - IndexSpec(shape=(10, 8), indexer=slice(1, None)), - IndexSpec(shape=(10, 8), indexer=slice(None, 3)), - IndexSpec(shape=(10, 8), indexer=slice(-3, None)), - ]), - ("OneSliceIndexNegativeStride", [ - IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10,), indexer=slice(None, None, -1)), - IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result - IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), - ]), - ("OneSliceIndexNonUnitStride", [ - IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), - IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), - IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), - IndexSpec(shape=(10,), indexer=slice(1, None, 2)), - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), - IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), - IndexSpec( - shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(10, 4), 
indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - ("EmptyIndex", [ - IndexSpec(shape=(), indexer=()), - IndexSpec(shape=(3,), indexer=()), - IndexSpec(shape=(3, 4), indexer=()), - ]), -] - -STATIC_INDEXING_GRAD_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5,), indexer=slice(1, -1)), - IndexSpec(shape=(5,), indexer=slice(None, -1)), - IndexSpec(shape=(5,), indexer=slice(None, None, None)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, None)), - IndexSpec(shape=(5, 4), indexer=slice(None, 3)), - IndexSpec(shape=(5, 4), indexer=slice(-3, None)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 
None))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - # TODO(mattjj): these fail for uninteresting dtype reasons - # ("EmptyIndex", - # [IndexSpec(shape=(), indexer=()), - # IndexSpec(shape=(3,), indexer=()), - # IndexSpec(shape=(3, 4), indexer=()), - # ]), -] - -ADVANCED_INDEXING_TESTS = [ - ("One1DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([0, 1])), - IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), - IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), - ]), - ("One2DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])), - IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], - [0, 1, -1]])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], - [-1, -2, 1, 0]])), - 
]), - ("Two1DIntArrayIndicesNoBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), - onp.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, 0, 1]), - onp.array([-1, 0, -1, 2]))), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), - onp.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2, 0, 1]]), - onp.array([-1, 0, -1, 2]))), - ]), - ("TupleOfListsOfPythonInts", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]])), - ]), - ("TupleOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, - onp.array([[2, 3, 0, 3]]))), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], - onp.array([[2, 3, 0, 3]]))), - ]), -] - -ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("One1DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([0, 1])), - IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 0])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), - # Fails with a TF/XLA error. 
- # IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), - ]), - ("One2DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([[0, 1]])), - IndexSpec(shape=(6, 6), indexer=onp.array([[1, 2, 0], - [3, 4, -1]])), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), - onp.array([1, 2]))), - IndexSpec(shape=(4, 5, 6), indexer=(onp.array([0, 2, 1, 3]), - onp.array([-1, 0, -2, 1]))), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), - onp.array([1, 2]))), - IndexSpec(shape=(4, 5, 6), indexer=(onp.array([[0, 2, -1, 1]]), - onp.array([-1, 0, -2, 2]))), - ]), - ("TupleOfListsOfPythonInts", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]])), - ]), - ("TupleOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, - onp.array([[2, 3, 0]]))), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], - onp.array([[2, 3, 0]]))), - ]), -] - -MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("SlicesAndOneIntArrayIndex", - [IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))), - IndexSpec(shape=(2, 3), indexer=(slice(0, 2), - onp.array([0, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([[0, 2], [1, 3]]), - slice(None))), - ]), - ("SlicesAndTwoIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - Ellipsis, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - onp.array([-1, 2]), - Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - onp.array([-1, 
2]), - slice(1, 3))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - slice(1, 3), - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]), - slice(None, None, 2), - onp.array([-1, 2, 1]))), - ]), - ("NonesAndIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - None, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - None, - None, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - None, - None, - onp.array([-1, 2]))), - ]), - ("IntArrayWithInt32Type", - [IndexSpec(shape=(3, 4), indexer=(Ellipsis, onp.array(1, dtype=onp.int32))) - ]), -] - -MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [ - ("SlicesAndOneIntArrayIndex", - [ - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([[0, 2], [1, 1]]), - slice(None))), - ]), - ("SlicesAndTwoIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]), - slice(None, None, 2), - onp.array([-1, 2, -1]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]), - Ellipsis, - onp.array([[1, 0], [1, 0]]))), - ]),] - - -def dynamic_slice_reference(operand, start_indices, slice_sizes): - out = onp.zeros(slice_sizes, dtype=operand.dtype) - idx = tuple(slice(start, start+size) - for start, size in zip(start_indices, slice_sizes)) - section = operand[idx] - out[tuple(slice(None, stop) for stop in section.shape)] = section - return out - - -def dynamic_update_slice_reference(operand, update, start_indices): - slices = tuple(map( - slice, start_indices, onp.add(start_indices, update.shape))) - updated_operand = onp.copy(operand) - updated_operand[slices] = update - return updated_operand - - -class IndexingTest(jtu.TestCase): - """Tests for Numpy indexing translation rules.""" - - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": "{}_inshape={}_indexer={}".format( - name, jtu.format_shape_dtype_string( shape, 
dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexing(self, shape, dtype, rng_factory, indexer): - # TODO(rohanj): Revisit passing in self.rng() to this to customize further. - # This would need updating lax_numpy_test as well. - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype)] - onp_fun = lambda x: x[indexer] - jnp_fun = lambda x: jnp.asarray(x)[indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - def _ReplaceSlicesWithTuples(self, idx): - """Helper method to replace slices with tuples for dynamic indexing args.""" - if isinstance(idx, slice): - triple = idx.start, idx.stop, idx.step - isnone = [i for i, elt in enumerate(triple) if elt is None] - zeros = itertools.repeat(0) - nones = itertools.repeat(None) - out = subvals(triple, zip(isnone, zeros)) - return out, lambda out: slice(*subvals(out, zip(isnone, nones))) - elif isinstance(idx, (tuple, list)) and idx: - t = type(idx) - elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) - return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) - else: - return idx, lambda x: x - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} - for name, index_specs in [ - ("OneSliceIndex", - [IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), - ("TwoSliceIndices", - [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), - ("NonUnitStrides", [ - 
IndexSpec(shape=(3,), indexer=slice(None, None, -1)), - IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), - IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) - ]), - ("OnlyStartOrStopDynamic", [ - IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) - ]), - ] - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer): - rng = rng_factory() - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - def onp_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return x[indexer] - - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off because some - # compile-time-constant requirements are violated. Investigate and turn it - # on. 
- self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_eval_on_shapes=False, - check_incomplete_shape=True, - check_xla_forced_compile=False) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} - for name, index_specs in [ - ("OneIntIndex", - [IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2)]), - ("TwoIntIndices", - [IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), - ("ThreeIntIndices", - [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ] - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): - # TODO(rohanj): Revisit passing in self.rng() to this to customize further. - # This would need updating lax_numpy_test as well. 
- rng = rng_factory() - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - def onp_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return x[indexer] - - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - @parameterized.named_parameters( - {"testcase_name": "_{}_inshape={}_indexer={}" # pylint: disable=g-complex-comprehension - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "name": name, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer} - for name, index_specs in ADVANCED_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, - indexer): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), indexer] - onp_fun = lambda x, idx: x[idx] - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off for - # ListOfPythonIntsAndIntArrays because it throws "The number of output - # elements has to equal to number of input elements that are sliced when - # input indices are not constant". Investigate and turn it on. 
- check_xla = (name != "ListOfPythonIntsAndIntArrays") - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters( - {"testcase_name": "_{}_inshape={}_indexer={}" # pylint: disable=g-complex-comprehension - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "name": name, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer} - for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, - indexer): - rng = rng_factory() - indexer_with_dummies = [e if isinstance(e, onp.ndarray) else () - for e in indexer] - substitutes = [(i, e) for i, e in enumerate(indexer) - if not isinstance(e, onp.ndarray)] - args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] - - def np_fun(x, indexer_with_dummies): - idx = type(indexer)(subvals(indexer_with_dummies, substitutes)) - return x[idx] - - jnp_fun = lambda x, idx: np_fun(jnp.asarray(x), idx) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off for - # IntArrayWithInt32Type because it throws "The number of output elements has - # to equal to number of input elements that are sliced when input indices - # are not constant". Investigate and turn it on. 
- check_xla = (name != "IntArrayWithInt32Type") - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, - check_xla_forced_compile=check_xla) - - def testAdvancedIndexingManually(self): - x = onp.random.RandomState(0).randn(3, 4, 5) - index_array = onp.array([0, 2, -1, 0]) - - op = lambda x, index_array: x[..., index_array, :] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - op = lambda x, index_array: x[..., index_array, :, index_array, None] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - # Note that we don't currently allow __iter__ in graph mode. So this test only - # iterates over eager tensor. 
- def testUnpacking(self): - - def foo(x): - a, b, c = x - return a + b + c - - a1 = foo(onp.arange(3)) - a2 = foo(jnp.arange(3)) - - self.assertAllClose(a1, a2, check_dtypes=True) - - def testBooleanIndexingArray1D(self): - idx = onp.array([True, True, False]) - x = jnp.asarray(onp.arange(3)) - ans = x[idx] - expected = onp.arange(3)[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList1D(self): - idx = [True, True, False] - x = jnp.asarray(onp.arange(3)) - ans = x[idx] - expected = onp.arange(3)[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingArray2DBroadcast(self): - idx = onp.array([True, True, False, True]) - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList2DBroadcast(self): - idx = [True, True, False, True] - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingArray2D(self): - idx = onp.array([[True, False], - [False, True], - [False, False], - [True, True]]) - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingDynamicShape(self): - x = onp.zeros(3) - i = onp.array([True, True, False]) - ans = x[i] - expected = jnp.asarray(x)[i] - self.assertAllClose(ans, expected, check_dtypes=True) - - def testIssue187(self): - x = jnp.ones((5, 5)) - x[[0, 2, 4], [0, 2, 4]] # doesn't crash - - x = onp.arange(25).reshape((5, 5)) - ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) - expected = x[[0, 2, 4], [0, 2, 4]] - self.assertAllClose(ans, expected, check_dtypes=False) - - # TODO(agarwal): Fix this use case. 
- @jtu.disable - def testIndexingEmptyDimension(self): - # Issue 2671: XLA error when indexing into dimension of size 0 - x = jnp.ones((2, 0)) - # The following work, even on axis 1 of size 0 - _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] - - with self.assertRaisesRegex(IndexError, - "index .* is out of bounds for axis .* with size 0"): - _ = onp.ones((2, 0))[0, 0] # The numpy error - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - _ = x[0, 0] # JAX indexing - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - npe.jit(lambda i: x[0, i])(0) # JAX indexing under jit - - def testBooleanIndexingWithEmptyResult(self): - # based on a TensorFlow Probability test that started failing after #1623 - x = jnp.array([-1]) - mask = jnp.array([False]) - ans = x[mask] # doesn't crash - - expected = onp.array([-1])[onp.array([False])] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testFloatIndexingError(self): - error_regex = "only integers, slices.*are valid indices" - # Verify onp behavior - with self.assertRaisesRegex(IndexError, error_regex): - _ = onp.zeros((2, 2))[(0, 0.)] - # Test jnp - with self.assertRaisesRegex(IndexError, error_regex): - jnp.zeros(2)[0.] 
- with self.assertRaisesRegex(IndexError, error_regex): - jnp.zeros((2, 2))[(0, 0.)] - # Test with jit - with self.assertRaisesRegex(IndexError, error_regex): - npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)) - - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 - array = jnp.ones(5) - self.assertAllClose(array, array[:10], check_dtypes=True) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices, "rng_factory": rng_factory} - for shape, start_indices, size_indices in [ - [(3,), onp.array((1,)), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(5, 3), onp.array((1, 1)), (3, 1)], - [(7, 5, 3), onp.array((4, 1, 0)), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicSlice(self, shape, dtype, start_indices, size_indices, - rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] - op = lambda x, starts: npe.dynamic_slice(x, starts, size_indices) - self._CompileAndCheck(op, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices, "rng_factory": rng_factory} - for shape, start_indices, size_indices in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def 
testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices, - size_indices, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] - op = lambda x, s: npe.dynamic_slice(x, s, size_indices) - numpy_op = lambda x, s: dynamic_slice_reference(x, s, size_indices) - self._CheckAgainstNumpy(numpy_op, op, args_maker) - - def testDynamicSliceInDim(self): - rng = jtu.rand_default() - x = rng((6, 7), onp.int32) - self.assertAllClose(npe.dynamic_slice_in_dim(x, 2, 3), x[2:5], - check_dtypes=True) - - -def _broadcastable_shapes(shape): - """Returns all shapes that broadcast to `shape`.""" - def f(rshape): - yield [] - if rshape: - for s in f(rshape[1:]): - yield rshape[0:1] + s - if rshape[0] != 1: - for s in f(rshape[1:]): - yield [1] + s - for x in f(list(reversed(shape))): - yield list(reversed(x)) - - -def _update_shape(shape, indexer): - return onp.zeros(shape)[indexer].shape - - -class UpdateOps(enum.Enum): - UPDATE = 0 - ADD = 1 - # MUL = 2 - MIN = 3 - MAX = 4 - - def np_fn(op, indexer, x, y): # pylint: disable=no-self-argument - x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - # UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.MIN: lambda: onp.minimum(x[indexer], y), - UpdateOps.MAX: lambda: onp.maximum(x[indexer], y), - }[op]() - return x - - def tfnp_fn(op, indexer, x, y): # pylint: disable=no-self-argument - return { - UpdateOps.UPDATE: npe.index_update, - UpdateOps.ADD: npe.index_add, - # UpdateOps.MUL: npe.index_mul, - UpdateOps.MIN: npe.index_min, - UpdateOps.MAX: npe.index_max, - }[op](x, indexer, y) - - -# a test to workaround b/123559667 -def has_non_trivial_stride(indexer): - def has(idx): - return isinstance(idx, slice) and idx.step not in (1, -1, None) - return any(has(idx) for idx in tf.nest.flatten(indexer)) - - -class IndexedUpdateTest(jtu.TestCase): - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: 
disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error "Missing - # xla_context 0-th output from". Investigate. 
- check_xla = (not has_non_trivial_stride(indexer) and # b/123559667 - not (isinstance(indexer, slice) and indexer.stop == 8 and - indexer.step == -1)) - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in 
MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - check_xla = not has_non_trivial_stride(indexer) # b/123559667 - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for op in [UpdateOps.ADD, UpdateOps.UPDATE] - for dtype in float_dtypes - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - x = rng(shape, dtype) - y = rng(update_shape, update_dtype) - self.check_grads(tfnp_fn, (x, y), rtol=1e-3, atol=1e-3, delta=1.) 
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape, "rng_factory": rng_factory} - for shape, start_indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape, - rng_factory): - rng = rng_factory() - def args_maker(): - return [rng(shape, dtype), rng(update_shape, dtype), - onp.array(start_indices)] - # update's shape must be fully known. - # TODO(wangpeng): Support turning off check_incomplete_shape for individual - # arguments. - self._CompileAndCheck(npe.dynamic_update_slice, args_maker, - check_incomplete_shape=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape, "rng_factory": rng_factory} - for shape, start_indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices, - update_shape, rng_factory): - rng = rng_factory() - def args_maker(): - return [rng(shape, dtype), rng(update_shape, dtype), - onp.array(start_indices)] - self._CheckAgainstNumpy(dynamic_update_slice_reference, - 
npe.dynamic_update_slice, args_maker) - - def testDynamicUpdateSliceInDim(self): - rng = jtu.rand_default() - x = rng((6, 7), onp.int32) - y = rng((3, 7), onp.int32) - z = x.copy() - z[2:5] = y - self.assertAllClose(npe.dynamic_update_slice_in_dim(x, y, 2, 0), z, - check_dtypes=True) - - -if __name__ == "__main__": - tf.config.set_soft_device_placement(False) - jnp.enable_numpy_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py deleted file mode 100644 index e973ef79f..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ /dev/null @@ -1,3085 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import collections -import functools -from functools import partial -import itertools -import operator -import unittest -from unittest import SkipTest -import warnings - -from absl.testing import absltest -from absl.testing import parameterized -import six - -import numpy as onp - - -import tensorflow.compat.v2 as tf -import trax.tf_numpy.numpy as lnp -import trax.tf_numpy.extensions as npe -from trax.tf_numpy.jax_tests.config import config, FLAGS -import trax.tf_numpy.jax_tests.test_util as jtu - - -from tensorflow.python.framework import ops -from tensorflow.python.ops.numpy_ops import np_config - -config.parse_flags_with_absl() - - -nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] -nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes -empty_array_shapes = [(0,), (0, 4), (3, 0),] - -scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] -array_shapes = nonempty_array_shapes + empty_array_shapes -nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes -nonempty_shapes = scalar_shapes + nonempty_array_shapes -all_shapes = scalar_shapes + array_shapes - -# TODO(wangpeng): float_dtypes = [lnp.bfloat16, onp.float16, onp.float32, -# onp.float64] -float_dtypes = [onp.float16, onp.float32, onp.float64] -complex_dtypes = [onp.complex64, onp.complex128] -int_dtypes = [onp.int32, onp.int64] -unsigned_dtypes = [onp.uint32, onp.uint64] -bool_dtypes = [onp.bool_] -default_dtypes = float_dtypes + int_dtypes -inexact_dtypes = float_dtypes + complex_dtypes -number_dtypes = float_dtypes + complex_dtypes + int_dtypes -all_dtypes = number_dtypes + bool_dtypes - - -python_scalar_dtypes = [lnp.bool_, lnp.int_, lnp.float_, lnp.complex_] - -def _valid_dtypes_for_shape(shape, dtypes): - # Not all (shape, dtype) pairs are valid. In particular, Python scalars only - # have one type in each category (float, bool, etc.) 
- if shape is jtu.PYTHON_SCALAR_SHAPE: - return [t for t in dtypes if t in python_scalar_dtypes] - return dtypes - -def _shape_and_dtypes(shapes, dtypes): - for shape in shapes: - for dtype in _valid_dtypes_for_shape(shape, dtypes): - yield (shape, dtype) - -OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", - "test_name", "check_dtypes", "tolerance", "inexact", - "check_incomplete_shape"]) - -def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name=None, check_dtypes=True, tolerance=None, inexact=False, - check_incomplete_shape=True): - test_name = test_name or name - return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name, check_dtypes, tolerance, inexact, - check_incomplete_shape) - - -def minus(a, b): - return [x for x in a if x not in b] - - -JAX_ONE_TO_ONE_OP_RECORDS = [ - op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("float_power", 2, inexact_dtypes, all_shapes, - partial(jtu.rand_default, scale=1), ["rev"], - tolerance={ - # TODO(wangpeng): lnp.bfloat16: 1e-2, - onp.float32: 1e-3, - onp.float64: 1e-12, onp.complex64: 2e-4, - onp.complex128: 1e-12}, check_dtypes=False), - op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("greater", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("greater_equal", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), 
- op_record("less", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("less_equal", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("maximum", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf, []), - op_record("minimum", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf, []), - op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("nextafter", 2, [f for f in float_dtypes - if f not in (lnp.bfloat16, onp.float16)], - all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), - op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"]), - op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("tan", 1, number_dtypes, all_shapes, - partial(jtu.rand_uniform, -1.5, 1.5), ["rev"], - tolerance={onp.complex64: 3e-5, onp.complex128: 4e-14}, - inexact=True), - # TODO(wangpeng): Add float16 support - op_record("sinh", 1, minus(number_dtypes, 
[onp.float16]), all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cosh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_default, ["rev"], - inexact=True), - # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to - # ~float32 precision. - # TODO(b/143135720): on GPU, tanh has only ~float32 precision. - op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.float64: 1e-7, onp.complex128: 1e-7}, - inexact=True), - op_record("arcsin", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arccos", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan2", 2, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arcsinh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("arccosh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("arctanh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), -] - -JAX_COMPOUND_OP_RECORDS = [ - # angle has inconsistent 32/64-bit return types across numpy versions. 
- op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False, inexact=True), - op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"], - inexact=six.PY3), - op_record("divmod", 2, minus(int_dtypes + float_dtypes, [onp.float16]), - all_shapes, jtu.rand_nonzero, []), - op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={ - # TODO(wangpeng): lnp.bfloat16: 2e-2, - onp.float16: 1e-2}, inexact=True), - # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 - # precision. 
- op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="expm1_large", tolerance={onp.float64: 1e-8}, inexact=True), - op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive, - [], tolerance={onp.float64: 1e-8}, inexact=True), - op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("floor_divide", 2, minus(number_dtypes, complex_dtypes), - all_shapes, jtu.rand_nonzero, ["rev"]), - op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, [], - check_incomplete_shape=False), - op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isfinite", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isinf", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isnan", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="log1p_large", tolerance={onp.float64: 1e-12}, - inexact=True), - op_record("log1p", 1, 
number_dtypes, all_shapes, jtu.rand_small_positive, [], - tolerance={onp.float64: 1e-12}, inexact=True), - op_record("logaddexp", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={onp.float64: 1e-12}, inexact=True), - op_record("logaddexp2", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={onp.float16: 1e-2}, inexact=True), - op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, - jtu.rand_default, [], check_dtypes=False, - tolerance={onp.float16: 1e-2, onp.float64: 1e-12}, - check_incomplete_shape=False), - op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - tolerance={onp.complex128: 1e-14}), - op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, [], - tolerance={onp.float64: 5e-6}), - op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("remainder", 2, minus(default_dtypes, [onp.float16]), all_shapes, - jtu.rand_nonzero, [], tolerance={onp.float16: 1e-2}), - op_record("mod", 2, minus(default_dtypes, [onp.float16]), all_shapes, - jtu.rand_nonzero, []), - op_record("sinc", 1, [t for t in number_dtypes if t != lnp.bfloat16], - all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.complex64: 1e-5}, inexact=True, - check_dtypes=False), - op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"], - check_dtypes=False), - op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero, - ["rev"], inexact=True), - op_record("diff", 1, number_dtypes, nonzerodim_shapes, jtu.rand_default, - ["rev"], check_incomplete_shape=False), -] - -JAX_BITWISE_OP_RECORDS = [ - 
op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), -] - -JAX_REDUCER_RECORDS = [ - op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmean", 1, minus(inexact_dtypes, complex_dtypes), - nonempty_shapes, jtu.rand_some_nan, [], inexact=True), - op_record("nanprod", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_nan, []), - op_record("nansum", 1, minus(number_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_nan, []), -] - -JAX_REDUCER_NO_DTYPE_RECORDS = [ - op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("max", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - op_record("min", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), -] - -JAX_ARGMINMAX_RECORDS = [ - op_record("argmin", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_some_equal, []), - op_record("argmax", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_some_equal, []), -] - -JAX_OPERATOR_OVERLOADS = [ - op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__mul__", 2, 
number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], - tolerance={onp.float32: 2e-4, onp.complex64: 2e-4, onp.complex128: 1e-14}), - op_record("__mod__", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [], - tolerance={onp.float16: 1e-1}), - op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 - op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), - # TODO(mattjj): investigate these failures - # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - # TODO(mattjj): lshift, rshift -] - -JAX_RIGHT_OPERATOR_OVERLOADS = [ - op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, 
[], - tolerance={onp.float32: 2e-4, onp.complex64: 1e-3}), - op_record("__rmod__", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [], - tolerance={onp.float16: 1e-1}), - op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), -] - -numpy_version = tuple(map(int, onp.version.version.split('.'))) -if numpy_version >= (1, 15): - JAX_COMPOUND_OP_RECORDS += [ - op_record("isclose", 2, [t for t in all_dtypes if t != lnp.bfloat16], - all_shapes, jtu.rand_small_positive, []), - op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []), - op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []), - ] - JAX_REDUCER_NO_DTYPE_RECORDS += [ - op_record("ptp", 1, minus(number_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - ] - -if six.PY2: - JAX_OPERATOR_OVERLOADS += [ - op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - ] - JAX_RIGHT_OPERATOR_OVERLOADS += [ - op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - ] - - -CombosWithReplacement = itertools.combinations_with_replacement - - -def _dtypes_are_compatible_for_bitwise_ops(args): - if len(args) <= 1: - return True - is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger) - width = lambda dtype: lnp.iinfo(dtype).bits - x, y = args - # `lnp.iinfo(dtype).bits` can't be called on bools, so we convert bools to - # ints. 
- if x == lnp.bool_: - x = lnp.int32 - if y == lnp.bool_: - y = lnp.int32 - if width(x) > width(y): - x, y = y, x - if x == lnp.uint32 and y == lnp.uint64: - return False - # The following condition seems a little ad hoc, but seems to capture what - # numpy actually implements. - return ( - is_signed(x) == is_signed(y) - or (width(x) == 32 and width(y) == 32) - or (width(x) == 32 and width(y) == 64 and is_signed(y))) - - -def _shapes_are_broadcast_compatible(shapes): - accumulator = onp.zeros([]) - for shape in shapes: - try: - accumulator = accumulator + onp.zeros(shape) - except ValueError: - return False - return True - -def _shapes_are_equal_length(shapes): - return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) - - -def _promote_like_lnp(fun, inexact=False): - """Decorator that promotes the arguments of `fun` to `lnp.result_type(*args)`. - - lnp and onp have different type promotion semantics; this decorator allows - tests make an onp reference implementation act more like an lnp - implementation. 
- """ - def wrapper(*args, **kw): - flat_args = tf.nest.flatten(args) - if inexact and not any( - lnp.issubdtype(lnp.result_type(x).as_numpy_dtype, lnp.inexact) - for x in flat_args): - dtype = lnp.result_type(lnp.float_, *flat_args) - else: - dtype = lnp.result_type(*flat_args) - dtype = dtype.as_numpy_dtype - args = tf.nest.map_structure(lambda a: onp.asarray(a, dtype), args) - return fun(*args, **kw) - return wrapper - - -def new_test(f): - - def wrapper(self, *args, **kwargs): - if not FLAGS.tf_numpy_additional_tests: - self.skipTest("Newly added test is disabled, since flag is False.") - else: - f(self, *args, **kwargs) - - return wrapper - - -def named_parameters(ls): - """A version that allows an empty param list.""" - def noop(_): - def wrapper(self, *args, **kwargs): - self.skipTest("Empty parameter list") - return wrapper - if isinstance(ls, (list, tuple)) and not ls: - return noop - if isinstance(ls, itertools.chain): - try: - first = next(ls) - except StopIteration: - return noop - else: - ls = itertools.chain([first], ls) - return parameterized.named_parameters(ls) - - -# TODO(wangpeng): Enable all disabled tests in this class -class LaxBackedNumpyTests(jtu.TestCase): - """Tests for LAX-backed Numpy implementation.""" - - def _GetArgsMaker(self, rng, shapes, dtypes, onp_arrays=True): - def f(): - out = [rng(shape, dtype or lnp.float_) - for shape, dtype in zip(shapes, dtypes)] - return out if onp_arrays else [lnp.asarray(a) for a in out] - return f - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance, - "inexact": rec.inexact, - "check_incomplete_shape": rec.check_incomplete_shape} - for shapes in filter( - _shapes_are_broadcast_compatible, - 
CombosWithReplacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, - JAX_COMPOUND_OP_RECORDS))) - def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes, - tolerance, inexact, check_incomplete_shape): - # TODO(b/147769803): Remove this skipping - if lnp_op.__name__ == "kron" and shapes == ((2, 3, 4), (2, 3, 4)): - self.skipTest("Case disabled because of b/147769803") - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) - tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) - tol = functools.reduce(jtu.join_tolerance, - [tolerance, tol, jtu.default_tolerance()]) - self._CheckAgainstNumpy(_promote_like_lnp(onp_op, inexact), lnp_op, - args_maker, check_dtypes=check_dtypes, tol=tol) - # tf.math.pow doesn't support int32/int64 on XLA (b/169191476). - check_xla = not (lnp_op.__name__ == "power" and set(dtypes).intersection( - (onp.int32, onp.int64))) - self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol, - check_incomplete_shape=check_incomplete_shape, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "tol": rec.tolerance} - for shapes in filter( - _shapes_are_broadcast_compatible, - CombosWithReplacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_OPERATOR_OVERLOADS)) - def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): - rng = rng_factory() - # onp and lnp arrays have different type promotion rules; force the use of - # lnp arrays. 
- args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) - fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) - scalar_arg = (jtu.PYTHON_SCALAR_SHAPE in shapes or - jtu.NUMPY_SCALAR_SHAPE in shapes or - () in shapes) - empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) - self._CompileAndCheck( - fun, args_maker, check_dtypes=True, #not scalar_arg and not empty_shape, - atol=tol, rtol=tol) - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "op_tolerance": rec.tolerance} - for shapes in filter( - _shapes_are_broadcast_compatible, - CombosWithReplacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) - def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, - op_tolerance): - if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: - raise SkipTest() # TODO(mattjj): clean up - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) - fun = lambda fst, snd: getattr(snd, name)(fst) - tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) - scalar_arg = (jtu.PYTHON_SCALAR_SHAPE in shapes or - jtu.NUMPY_SCALAR_SHAPE in shapes or - () in shapes) - empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) - self._CompileAndCheck( - fun, args_maker, check_dtypes=True, # not scalar_arg and not empty_shape, - atol=tol, rtol=tol) - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.test_name, shapes, dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)} - for shapes in filter( - 
_shapes_are_broadcast_compatible, - CombosWithReplacement(rec.shapes, rec.nargs)) - for dtypes in filter( - _dtypes_are_compatible_for_bitwise_ops, - CombosWithReplacement(rec.dtypes, rec.nargs))) - for rec in JAX_BITWISE_OP_RECORDS)) - def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - if onp_op == onp.bitwise_not and has_python_scalar: - # For bitwise_not with a Python `int`, npe.jit may choose a different - # dtype for the `int` from onp's choice, which may result in a different - # result value, so we skip _CompileAndCheck. - return - # Numpy does value-dependent dtype promotion on Python/numpy/array scalars - # which `jit` can't do (when np.result_type is called inside `jit`, tensor - # values are not available), so we skip dtype check in this case. - check_dtypes = not(set(shapes) & set([jtu.NUMPY_SCALAR_SHAPE, - jtu.PYTHON_SCALAR_SHAPE, ()])) - self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes) - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, - "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for out_dtype in [None] + rec.dtypes - for axis in set(range(-len(shape), len(shape))) | set([None]) - for keepdims in [False, True]) - for rec in JAX_REDUCER_RECORDS)) - def testReducer(self, onp_op, lnp_op, rng_factory, shape, dtype, out_dtype, - axis, 
keepdims, inexact): - rng = rng_factory() - def onp_fun(x): - x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32) - t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32 - return onp_op(x_cast, axis, dtype=t, keepdims=keepdims) - onp_fun = _promote_like_lnp(onp_fun, inexact) - lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3, - onp.float64: 1e-5, onp.complex128: 1e-5} - tol = jtu.tolerance(dtype, tol_spec) - tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=lnp.bfloat16 not in (dtype, out_dtype), - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for rec in JAX_REDUCER_NO_DTYPE_RECORDS - for shape in rec.shapes for dtype in rec.dtypes - for axis in set(range(-len(shape), len(shape))) | set([None]) - for keepdims in [False, True])) - def testReducerNoDtype(self, onp_op, lnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact): - rng = rng_factory() - onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims) - onp_fun = _promote_like_lnp(onp_fun, inexact) - lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}".format( - 
jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for shape in all_shapes for dtype in all_dtypes - for axis in set(range(-len(shape), len(shape))) | set([None]))) - def testCountNonzero(self, shape, dtype, axis): - rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.count_nonzero(x, axis) - lnp_fun = lambda x: lnp.count_nonzero(x, axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testNonzero(self, shape, dtype): - rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.nonzero(x) - lnp_fun = lambda x: lnp.nonzero(x) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) - # The shapes of `nonzero`'s results are value-dependent, so `eval_on_shapes` - # won't return concrete shapes. - # Also, `nonzero` requires a known rank. - # Turns off XLA check because there are no XLA kernels for `Where`, which - # XLA can't support because it's output shape is dynamic. 
- self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_eval_on_shapes=False, - check_incomplete_shape=True, check_unknown_rank=False, - check_experimental_compile=False, check_xla_forced_compile=False) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), - "axis": axis} - for rec in JAX_ARGMINMAX_RECORDS - for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) - for axis in range(-len(shape), len(shape)))) - def testArgMinMax(self, onp_op, lnp_op, rng_factory, shape, dtype, axis): - rng = rng_factory() - if dtype == onp.complex128 and jtu.device_under_test() == "gpu": - raise unittest.SkipTest("complex128 reductions not supported on GPU") - - def onp_fun(array_to_reduce): - return onp_op(array_to_reduce, axis).astype(lnp.int_) - - def lnp_fun(array_to_reduce): - return lnp_op(array_to_reduce, axis) - - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), - axes), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "axes": axes, "rng_factory": rng_factory} - for rng_factory in [jtu.rand_default] - for lhs_shape, rhs_shape, axes in [ - [(2,), (2,), (-1, -1, -1, None)], # scalar output - [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors - [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors - [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting - [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # 
different axes - [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting - [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors - [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting - [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing - [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before - ] - for lhs_dtype, rhs_dtype in CombosWithReplacement( - minus(number_dtypes, complex_dtypes), 2))) - def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - axisa, axisb, axisc, axis = axes - lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis) - def onp_fun(a, b): - a = a.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else a - b = b.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else b - out = onp.cross(a, b, axisa, axisb, axisc, axis) - return out.astype(lnp.promote_types(lhs_dtype, rhs_dtype)) - tol_spec = { - # TODO(wangpeng): dtypes.bfloat16: 3e-1, - onp.float16: 0.15} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - name, - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "rng_factory": rng_factory} - for rng_factory in [jtu.rand_default] - for name, lhs_shape, rhs_shape in [ - ("matrix-scalar", (3, 3), ()), - ("scalar-matrix", (), (3, 3)), - ("matrix-vector", (4, 5), (5,)), - ("vector-matrix", (6,), (6, 4)), - ("matrix-matrix", (3, 4), (4, 5)), - ("tensor-vector", (4, 3, 2), (2,)), - 
("vector-tensor", (2,), (3, 2, 4)), - ("tensor-matrix", (4, 3, 2), (2, 5)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-tensor", (2, 3, 4), (5, 4, 1))] - for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2))) - def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 1e-14, - onp.complex128: 1e-14} - if jtu.device_under_test() == "tpu": - tol[onp.float32] = tol[onp.complex64] = 2e-1 - def onp_dot(x, y): - x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x - y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y - # `onp.dot(x, y).dtype` sometimes differs from `onp.result_type(x, y)` - # (e.g. when x is float64[] and y is complex64[3,3], or when x is - # float16[3,3] and y is int64[]). We ignore this corner case and pretend - # that they agree. - return onp.dot(x, y).astype(onp.result_type(x, y)) - self._CheckAgainstNumpy(onp_dot, lnp.dot, args_maker, - check_dtypes=True, tol=tol) - # We disable dtype check in the following cases because `np.dot` does - # value-dependent type promotion in those cases. - check_dtypes = () not in (lhs_shape, rhs_shape) - # XLA lacks int32/int64 MatMul kernels (b/168657656). 
- check_xla = not set((lhs_dtype, rhs_dtype)).intersection( - (onp.int32, onp.int64)) - self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - name, - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "rng_factory": rng_factory} - for rng_factory in [jtu.rand_default] - for name, lhs_shape, rhs_shape in [ - ("vector-vector", (3,), (3,)), - ("matrix-vector", (3, 3), (3,)), - ("vector-matrix", (3,), (3, 3)), - ("matrix-matrix", (3, 3), (3, 3)), - ("vector-tensor", (3,), (5, 3, 2)), - ("tensor-vector", (5, 3, 2), (2,)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-matrix", (5, 2, 3), (3, 2)), - ("tensor-tensor", (5, 3, 4), (5, 4, 1)), - ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))] - for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2))) - def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): - rng = rng_factory() - def onp_fun(x, y): - dtype = lnp.promote_types(lhs_dtype, rhs_dtype) - return (onp.matmul(x, y).astype(dtype), - onp.array(x).__matmul__(y).astype(dtype), - onp.array(y).__rmatmul__(x).astype(dtype)) - def lnp_fun(x, y): - return (lnp.matmul(x, y), - lnp.array(x).__matmul__(y), - lnp.array(y).__rmatmul__(x)) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12, - onp.complex128: 1e-12} - if jtu.device_under_test() == "tpu": - tol[onp.float32] = tol[onp.complex64] = 4e-2 - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=True, tol=tol) - # XLA lacks int32/int64 MatMul kernels (b/168657656). 
- check_xla = not set((lhs_dtype, rhs_dtype)).intersection( - (onp.int32, onp.int64)) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - name, - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "rng_factory": rng_factory} - for rng_factory in [jtu.rand_default] - for name, lhs_shape, rhs_shape in [ - ("vector-vector", (3,), (3,)), - ("vector-matrix", (9,), (3, 3)), - ("matrix-matrix", (3, 3), (3, 3)), - ("tensor-vector", (5, 3, 2), (30,))] - for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2))) - @new_test - def testVDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12, - onp.complex128: 1e-12} - self._CheckAgainstNumpy(onp.vdot, lnp.vdot, args_maker, - check_dtypes=True, tol=tol) - # XLA lacks int32/int64 MatMul kernels (b/168657656). 
- check_xla = not set((lhs_dtype, rhs_dtype)).intersection( - (onp.int32, onp.int64)) - self._CompileAndCheck(lnp.vdot, args_maker, check_dtypes=True, atol=tol, - rtol=tol, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), - axes), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "axes": axes, "rng_factory": rng_factory} - for rng_factory in [jtu.rand_default] - for lhs_shape, rhs_shape, axes in [ - [(2, 3, 4), (5, 6, 7), 0], # from issue #740 - [(2, 3, 4), (3, 4, 5, 6), 2], - [(2, 3, 4), (5, 4, 3, 6), [1, 2]], - [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], - [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], - ] - for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2))) - def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - lnp_fun = lambda a, b: lnp.tensordot(a, b, axes) - def onp_fun(a, b): - a = a if lhs_dtype != lnp.bfloat16 else a.astype(onp.float32) - b = b if rhs_dtype != lnp.bfloat16 else b.astype(onp.float32) - dtype = lnp.promote_types(lhs_dtype, rhs_dtype) - return onp.tensordot(a, b, axes).astype(dtype) - tol = {onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 1e-12, - onp.complex64: 1e-3, onp.complex128: 1e-12} - if jtu.device_under_test() == "tpu": - tol[onp.float32] = tol[onp.complex64] = 2e-1 - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, - tol=tol) - # XLA lacks int32/int64 MatMul kernels (b/168657656). 
- check_xla = not set((lhs_dtype, rhs_dtype)).intersection( - (onp.int32, onp.int64)) - - tol = {onp.float64: 1e-14, onp.float16: 0.04, onp.complex128: 6e-15} - tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol)) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla, - atol = tol, - rtol = tol) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "rng_factory": jtu.rand_default} - # TODO(phawkins): support integer dtypes too. - for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - if len(jtu._dims_of_shape(lhs_shape)) == 0 - or len(jtu._dims_of_shape(rhs_shape)) == 0 - or lhs_shape[-1] == rhs_shape[-1])) - def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - def onp_fun(lhs, rhs): - lhs = lhs if lhs_dtype != lnp.bfloat16 else lhs.astype(onp.float32) - rhs = rhs if rhs_dtype != lnp.bfloat16 else rhs.astype(onp.float32) - dtype = lnp.promote_types(lhs_dtype, rhs_dtype) - return onp.inner(lhs, rhs).astype(dtype) - lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs) - tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 2e-6} - if jtu.device_under_test() == "tpu": - tol_spec[onp.float32] = tol_spec[onp.complex64] = 2e-1 - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - # TODO(phawkins): there are float32/float64 disagreements for some inputs. 
- self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol, - rtol=tol, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_amin={}_amax={}".format( - jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), - "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, - "rng_factory": jtu.rand_default} - for shape in all_shapes for dtype in minus(number_dtypes, complex_dtypes) - for a_min, a_max in [(-1, None), (None, 1), (-1, 1), - (-onp.ones(1), None), - (None, onp.ones(1)), - (-onp.ones(1), onp.ones(1))])) - def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) - lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {onp.float64: 2e-7} - tol = jtu.tolerance(dtype, tol_spec) - is_x32_scalar = (dtype in [onp.int32, onp.float32] and - shape in [jtu.NUMPY_SCALAR_SHAPE, ()]) - # Turns check_dtypes off if is_x32_scalar is True because there is - # a weird promotion inconsistency in numpy: - # ``` - # print(np.result_type(np.ones([], np.int32), 1)) - # print(np.result_type(np.ones([1], np.int32), 1)) - # print(np.result_type(np.int32(1), 1)) - # print(np.result_type(np.int32, 1)) - # print(np.result_type(np.ones([], np.float32), 1)) - # print(np.result_type(np.ones([1], np.float32), 1)) - # print(np.result_type(np.float32(1), 1)) - # print(np.result_type(np.float32, 1)) - # ``` - # >>> - # int64 - # int32 - # int64 - # int32 - # float64 - # float32 - # float64 - # float32 - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=not is_x32_scalar, tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=not is_x32_scalar, - atol=tol, rtol=tol, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_{}_amin={}_amax={}".format( - jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), - "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, - "rng_factory": jtu.rand_default} - for shape in array_shapes + [jtu.NUMPY_SCALAR_SHAPE] - for dtype in minus(number_dtypes, complex_dtypes) - for a_min, a_max in [(-1, None), (None, 1), (-1, 1), - (-onp.ones(1), None), - (None, onp.ones(1)), - (-onp.ones(1), onp.ones(1))])) - @new_test - def testClipAsMethodStaticBounds( - self, shape, dtype, a_min, a_max, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) - lnp_fun = lambda x: lnp.asarray(x).clip(a_min=a_min, a_max=a_max) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {onp.float64: 2e-7} - tol = jtu.tolerance(dtype, tol_spec) - is_x32_scalar = (dtype in [onp.int32, onp.float32] and - shape in [jtu.NUMPY_SCALAR_SHAPE, ()]) - # Turns check_dtypes off if is_x32_scalar is True because there is - # a weird promotion inconsistency in numpy: - # ``` - # print(np.result_type(np.ones([], np.int32), 1)) - # print(np.result_type(np.ones([1], np.int32), 1)) - # print(np.result_type(np.int32(1), 1)) - # print(np.result_type(np.int32, 1)) - # print(np.result_type(np.ones([], np.float32), 1)) - # print(np.result_type(np.ones([1], np.float32), 1)) - # print(np.result_type(np.float32(1), 1)) - # print(np.result_type(np.float32, 1)) - # ``` - # >>> - # int64 - # int32 - # int64 - # int32 - # float64 - # float32 - # float64 - # float32 - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=not is_x32_scalar, tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=not is_x32_scalar, - atol=tol, rtol=tol, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_decimals={}".format( - jtu.format_shape_dtype_string(shape, dtype), decimals), - "shape": shape, "dtype": dtype, "decimals": decimals, - "rng_factory": jtu.rand_default} - for shape, dtype in 
_shape_and_dtypes( - all_shapes, minus(number_dtypes, complex_dtypes)) - for decimals in [0, 1, -2])) - def testRoundStaticDecimals(self, shape, dtype, decimals, rng_factory): - rng = rng_factory() - if lnp.issubdtype(dtype, onp.integer) and decimals < 0: - self.skipTest("Integer rounding with decimals < 0 not implemented") - onp_fun = lambda x: onp.round(x, decimals=decimals) - lnp_fun = lambda x: lnp.round(x, decimals=decimals) - args_maker = lambda: [rng(shape, dtype)] - tol = { - # TODO(b/154768983): lnp.bfloat16: 5e-2, - onp.float16: 1e-2} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol, check_incomplete_shape=True) - - def testOperatorRound(self): - self.assertAllClose(round(onp.float32(7.532), 1), - round(lnp.float32(7.5), 1), check_dtypes=True) - self.assertAllClose(round(onp.float32(1.234), 2), - round(lnp.float32(1.234), 2), check_dtypes=True) - self.assertAllClose(round(onp.float32(1.234)), - round(lnp.float32(1.234)), check_dtypes=False) - self.assertAllClose(round(onp.float32(7.532), 1), - round(lnp.array(7.5, lnp.float32), 1), check_dtypes=True) - self.assertAllClose(round(onp.float32(1.234), 2), - round(lnp.array(1.234, lnp.float32), 2), check_dtypes=True) - self.assertAllClose(round(onp.float32(1.234)), - round(lnp.array(1.234, lnp.float32)), - check_dtypes=False) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format( - jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank, - constant_values_rank), - "shape": shape, "dtype": dtype, "mode": mode, - "pad_width_rank": pad_width_rank, - "constant_values_rank": constant_values_rank, - "rng_factory": jtu.rand_default, - "irng_factory": partial(jtu.rand_int, 3)} - for mode, constant_values_rank, shapes in [ - ('constant', 0, all_shapes), 
- ('constant', 1, all_shapes), - ('constant', 2, all_shapes), - ('symmetric', None, nonempty_shapes), - ('reflect', None, nonempty_shapes), - ('wrap', None, nonempty_shapes), - ] - for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) - for pad_width_rank in range(3))) - @jtu.disable - def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank, - rng_factory, irng_factory): - rng = rng_factory() - irng = irng_factory() - pad_width = irng([len(shape), 2][2 - pad_width_rank:], onp.int32) - def onp_fun(x, kwargs): - if pad_width.size == 0: - return x - return onp.pad(x, pad_width, mode=mode, **kwargs) - def lnp_fun(x, kwargs): - return lnp.pad(x, pad_width, mode=mode, **kwargs) - - def args_maker(): - kwargs = {} - if constant_values_rank: - kwargs["constant_values"] = rng( - [len(shape), 2][2 - constant_values_rank:], dtype) - return rng(shape, dtype), kwargs - - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape=[{}]_reps={}".format( - jtu.format_shape_dtype_string(shape, dtype), reps), - "shape": shape, "dtype": dtype, "reps": reps, - "rng_factory": jtu.rand_default} - for reps in [(), (2,), (3, 4), (2, 3, 4)] - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - )) - def testTile(self, shape, dtype, reps, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: onp.tile(arg, reps) - lnp_fun = lambda arg: lnp.tile(arg, reps) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {onp.float64: 2e-7} - tol = jtu.tolerance(dtype, tol_spec) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_axis={}_baseshape=[{}]_dtypes=[{}]".format( - axis, ",".join(str(d) for d in base_shape), - ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)), - "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes, - "rng_factory": jtu.rand_default} - for num_arrs in [3] - for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape)))) - def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory): - rng = rng_factory() - wrapped_axis = axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def onp_fun(*args): - # TODO(nareshmodi): enable once bfloat16 has better support - # args = [x if x.dtype != bfloat16 else x.astype(onp.float32) - # for x in args] - dtype = functools.reduce(lnp.promote_types, arg_dtypes) - return onp.concatenate(args, axis=axis).astype(dtype) - lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( - axis, ",".join(str(d) for d in base_shape), - ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)), - "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes, - "rng_factory": jtu.rand_default} - for arg_dtypes in CombosWithReplacement(default_dtypes, 2) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape)))) - def testAppend(self, axis, base_shape, arg_dtypes, rng_factory): - rng = rng_factory() - wrapped_axis = axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for 
size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def onp_fun(arr, values): - arr = arr.astype(onp.float32) if lnp.bfloat16 == arr.dtype else arr - values = ( - values.astype(onp.float32) - if lnp.bfloat16 == values.dtype else values) - out = onp.append(arr, values, axis=axis) - return out.astype(lnp.promote_types(*arg_dtypes)) - lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, repeats), - "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats, - "rng_factory": jtu.rand_default} - for repeats in [0, 1, 2] - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - for axis in [None] + list(range(-len(shape), len(shape))))) - def testRepeat(self, axis, shape, dtype, repeats, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis) - onp_fun = _promote_like_lnp(onp_fun) - lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) - - args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False) - - def testIssue1233(self): - ''' - Following numpy test suite from `test_repeat` at https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py - ''' - tol = 1e-5 - - def test_single(m, args_maker, repeats, axis): - lax_ans = lnp.repeat(m, repeats, axis) - numpy_ans = onp.repeat(m, repeats, axis) - - self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol) - - lnp_fun = 
lambda arg: lnp.repeat(arg, repeats = repeats, axis=axis) - # Turns off XLA check because there are no XLA kernels for `Where` used by - # tf.repeat (b/169192730). - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False, - check_experimental_compile=False, check_xla_forced_compile=False) - - m = lnp.array([1,2,3,4,5,6]) - args_maker = lambda: [m] - - for repeats in [2, [1,3,2,1,1,2], [1,3,0,1,1,2], [2], lnp.array([1,3,2,1,1,2]), lnp.array([2])]: - test_single(m, args_maker, repeats, None) - - m_rect = m.reshape((2,3)) - args_maker = lambda: [m_rect] - - for repeats in [2, [2,1], [2], lnp.array([2,1]), lnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=0) - - for repeats in [2, [1,3,2], [2], lnp.array([1,3,2]), lnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=1) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( - op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype), - "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "rng_factory": jtu.rand_default, "lnp_op": getattr(lnp, op), - "onp_op": getattr(onp, op)} - for op in ["cumsum", "cumprod"] - for dtype in default_dtypes - for out_dtype in default_dtypes - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype) - lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype) - - args_maker = lambda: [rng(shape, dtype)] - - tol = max(jtu.tolerance(dtype), jtu.tolerance(out_dtype)) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, - tol=tol) - # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378). 
- check_xla = out_dtype != onp.int64 - rtol = None - if out_dtype == onp.float16: - rtol = 2e-3 - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, rtol=rtol, - check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_dtype={}_m={}_n={}_k={}".format( - onp.dtype(dtype).name, m, n, k), - "m": m, "n": n, "k": k, "dtype": dtype, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for n in [0, 4] - for m in [None, 0, 1, 3, 4] - for k in list(range(-4, 4)))) - def testTri(self, m, n, k, dtype, rng_factory): - rng = rng_factory() - onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype) - lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_shape={}_k={}".format( - op, jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "op": op, "k": k, - "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for op in ["tril", "triu"] - for k in list(range(-3, 3)))) - def testTriLU(self, dtype, shape, op, k, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: getattr(onp, op)(arg, k=k) - lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - # Incomplete shape support is not implemented at the moment. 
- self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_ndim={}_n={}".format(ndim, n), - "ndim": ndim, "n": n} - for ndim in [0, 1, 4] - for n in [0, 1, 7])) - def testDiagIndices(self, ndim, n): - onp.testing.assert_equal(onp.diag_indices(n, ndim), - lnp.diag_indices(n, ndim)) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] - for k in list(range(-4, 4)))) - def testDiag(self, shape, dtype, k, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: onp.diag(arg, k) - lnp_fun = lambda arg: lnp.diag(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( - jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2), - "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1, - "axis2": axis2, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in [a for a in range(-len(shape), len(shape)) - if a % len(shape) != axis1 % len(shape)] - for offset in list(range(-4, 4)))) - def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng_factory): - rng = rng_factory() - onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2) - lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, 
args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n), - "dtype": dtype, "n": n} - for dtype in default_dtypes - for n in list(range(4)))) - def testIdentity(self, n, dtype): - onp_fun = lambda: onp.identity(n, dtype) - lnp_fun = lambda: lnp.identity(n, dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( - jtu.format_shape_dtype_string(shape, dtype), - out_dtype, offset, axis1, axis2), - "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset, - "axis1": axis1, "axis2": axis2, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for out_dtype in [None] + number_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in range(-len(shape), len(shape)) - if (axis1 % len(shape)) != (axis2 % len(shape)) - for offset in list(range(-4, 4)))) - def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng_factory): - rng = rng_factory() - def onp_fun(arg): - if out_dtype == lnp.bfloat16: - return onp.trace(arg, offset, axis1, axis2, onp.float32).astype(lnp.bfloat16) - else: - return onp.trace(arg, offset, axis1, axis2, out_dtype) - lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - 
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis), - "shape": shape, "axis": axis, "dtypes": dtypes, "rng_factory": rng_factory} - for dtypes in [ - [onp.float32], - [onp.float32, onp.float32], - [onp.float32, onp.int32, onp.float32], - [onp.float32, onp.int64, onp.float32], - [onp.float32, onp.int32, onp.float64], - ] - for shape in [(), (2,), (3, 4), (1, 100)] - for axis in range(-len(shape), len(shape) + 1) - for rng_factory in [jtu.rand_default])) - def testStack(self, shape, axis, dtypes, rng_factory): - rng = rng_factory() - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - onp_fun = _promote_like_lnp(partial(onp.stack, axis=axis)) - lnp_fun = partial(lnp.stack, axis=axis) - self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_{}".format( - op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)), - "shape": shape, "op": op, "dtypes": dtypes, "rng_factory": rng_factory} - for op in ["hstack", "vstack", "dstack"] - for dtypes in [ - [onp.float32], - [onp.float32, onp.float32], - [onp.float32, onp.int32, onp.float32], - [onp.float32, onp.int64, onp.float32], - [onp.float32, onp.int32, onp.float64], - ] - for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)] - for rng_factory in [jtu.rand_default])) - def testHVDStack(self, shape, op, dtypes, rng_factory): - rng = rng_factory() - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - onp_fun = _promote_like_lnp(getattr(onp, op)) - lnp_fun = getattr(lnp, op) - self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outdtype={}".format( - jtu.format_shape_dtype_string(shape, fill_value_dtype), - 
onp.dtype(out_dtype).name if out_dtype else "None"), - "shape": shape, "fill_value_dtype": fill_value_dtype, - "out_dtype": out_dtype, "rng_factory": jtu.rand_default} - for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)] - for fill_value_dtype in default_dtypes - for out_dtype in [None] + default_dtypes)) - def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory): - rng = rng_factory() - onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype) - lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype) - args_maker = lambda: [rng((), fill_value_dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype), - "onp_op": getattr(onp, op), "lnp_op": getattr(lnp, op), - "shape": shape, "dtype": dtype} - for op in ["zeros", "ones"] - for shape in [2, (), (2,), (3, 0), onp.array((4, 5, 6), dtype=onp.int32), - onp.array(4, dtype=onp.int32)] - for dtype in all_dtypes)) - def testZerosOnes(self, onp_op, lnp_op, shape, dtype): - rng = jtu.rand_default() - def args_maker(): return [] - onp_op = partial(onp_op, shape, dtype) - lnp_op = partial(lnp_op, shape, dtype) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format( - jtu.format_shape_dtype_string(shape, in_dtype), - onp.dtype(fill_value_dtype).name, - onp.dtype(out_dtype).name), - "shape": shape, "in_dtype": in_dtype, - "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype, - "rng_factory": jtu.rand_default} - for shape in array_shapes - for in_dtype in default_dtypes - for fill_value_dtype in default_dtypes - 
for out_dtype in default_dtypes)) - def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng_factory): - rng = rng_factory() - onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype) - lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype) - args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_{}sections".format( - jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), - "shape": shape, "num_sections": num_sections, "axis": axis, - "dtype": dtype, "rng_factory": jtu.rand_default} - for shape, axis, num_sections in [ - ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] - for dtype in default_dtypes)) - def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.split(x, num_sections, axis=axis) - lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_{}sections".format( - jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), - "shape": shape, "num_sections": num_sections, "axis": axis, - "dtype": dtype, "rng_factory": jtu.rand_default} - for shape, axis, num_sections in [ - ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)] - for dtype in default_dtypes)) - def testHVDSplit(self, shape, num_sections, axis, dtype, rng_factory): - rng = rng_factory() - def fn(module, axis): - if axis == 0: - return 
module.vsplit - elif axis == 1: - return module.hsplit - else: - assert axis == 2 - return module.dsplit - - onp_fun = lambda x: fn(onp, axis)(x, num_sections) - lnp_fun = lambda x: fn(lnp, axis)(x, num_sections) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_order={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype), - order), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "order": order, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for order in ["C", "F"] - for arg_shape, out_shape in [ - (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), - ((), (1, 1, 1)), - ((7, 0), (0, 42, 101)), - ((3, 4), 12), - ((3, 4), (12,)), - ((3, 4), -1), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def testReshape(self, arg_shape, out_shape, dtype, order, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.reshape(x, out_shape, order=order) - lnp_fun = lambda x: lnp.reshape(x, out_shape, order=order) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype)), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for arg_shape, out_shape in [ - ((7, 0), (0, 42, 101)), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def testReshapeMethod(self, arg_shape, out_shape, dtype, rng_factory): - rng = 
rng_factory() - onp_fun = lambda x: onp.reshape(x, out_shape) - lnp_fun = lambda x: x.reshape(*out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_expanddim={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), dim), - "arg_shape": arg_shape, "dtype": dtype, "dim": dim, - "rng_factory": jtu.rand_default} - for arg_shape in [(), (3,), (3, 4)] - for dtype in default_dtypes - for dim in range(-len(arg_shape)+1, len(arg_shape)))) - def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.expand_dims(x, dim) - lnp_fun = lambda x: lnp.expand_dims(x, dim) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_axes=({},{})".format( - jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), - "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2, - "rng_factory": jtu.rand_default} - for arg_shape, ax1, ax2 in [ - ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), - ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] - for dtype in default_dtypes)) - def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) - lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_shape={}_axes=({},{})".format( - jtu.format_shape_dtype_string(arg_shape, dtype), source, destination), - "arg_shape": arg_shape, "dtype": dtype, "source": source, - "destination": destination, "rng_factory": jtu.rand_default} - for arg_shape, source, destination in [ - (tuple(range(6)), (0, 2), (3, 5)), - (tuple(range(6)), (0, 2), (-1, -3)), - (tuple(range(6)), (-6, -4),(3, 5)), - (tuple(range(6)), (-6, -4), (-1, -3)), - (tuple(range(6)), 0, 4), - (tuple(range(6)), -6, -2), - (tuple(range(6)), tuple(range(6)), tuple(range(6))), - (tuple(range(6)), tuple(range(6)), tuple(reversed(range(6)))), - (tuple(range(6)), (), ()), - ] for dtype in default_dtypes)) - @new_test - def testMoveaxisStaticAxes(self, arg_shape, dtype, source, destination, - rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.moveaxis(x, source, destination) - lnp_fun = lambda x: lnp.moveaxis(x, source, destination) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_axis={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), ax), - "arg_shape": arg_shape, "dtype": dtype, "ax": ax, - "rng_factory": jtu.rand_default} - for arg_shape, ax in [ - ((3, 1), None), - ((3, 1), 1), - ((1, 3, 1), (0, 2)), - ((1, 4, 1), (0,))] - for dtype in default_dtypes)) - def testSqueeze(self, arg_shape, dtype, ax, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.squeeze(x, ax) - lnp_fun = lambda x: lnp.squeeze(x, ax) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( - 
jtu.format_shape_dtype_string(shape, dtype), - axis, - (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)), - returned), - "rng_factory": jtu.rand_default, "shape": shape, "dtype": dtype, "axis": axis, - "weights_shape": weights_shape, "returned": returned} - for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes) - for axis in set(range(-len(shape), len(shape))) | set([None]) - # `weights_shape` is either `None`, same as the averaged axis, or same as - # that of the input - for weights_shape in ([None, shape] if axis is None or len(shape) == 1 - else [None, (shape[axis],), shape]) - for returned in [False, True])) - def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory): - rng = rng_factory() - if weights_shape is None: - onp_fun = lambda x: onp.average(x, axis, returned=returned) - lnp_fun = lambda x: lnp.average(x, axis, returned=returned) - args_maker = lambda: [rng(shape, dtype)] - else: - onp_fun = lambda x, weights: onp.average(x, axis, weights, returned) - lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned) - args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] - onp_fun = _promote_like_lnp(onp_fun, inexact=True) - tol = { - # TODO(b/154768983): lnp.bfloat16: 1e-1, - onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 2e-7, - onp.complex64: 1e-3, onp.complex128: 1e-10, - } - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - try: - self._CheckAgainstNumpy( - onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol) - except ZeroDivisionError: - self.skipTest("don't support checking for ZeroDivisionError") - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes, - rtol=tol, atol=tol, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_arg{}_ndmin={}".format(i, ndmin), - "arg": arg, "ndmin": ndmin, "dtype": dtype} - for i, (arg, dtype) in enumerate([ - ([True, False, True], lnp.bool_), - 
(3., lnp.float_), - ([1, 2, 3], lnp.int_), - ([1., 2., 3.], lnp.float_), - ([[1, 2], [3, 4], [5, 6]], lnp.int_), - ([[1, 2.], [3, 4], [5, 6]], lnp.float_), - ([[1., 2j], [3., 4.], [5., 6.]], lnp.complex_), - ([[3, onp.array(2, dtype=lnp.float_), 1], - onp.arange(3., dtype=lnp.float_)], lnp.float_), - ]) - for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2])) - def testArray(self, arg, ndmin, dtype): - args_maker = lambda: [arg] - dtype = lnp.canonicalize_dtype(dtype) - if ndmin is not None: - onp_fun = partial(onp.array, ndmin=ndmin, dtype=dtype) - lnp_fun = partial(lnp.array, ndmin=ndmin) - else: - onp_fun = partial(onp.array, dtype=dtype) - lnp_fun = lnp.array - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, static_argnums=[0]) - - def testIssue121(self): - assert not onp.isscalar(lnp.array(3)) - - @jtu.disable - def testArrayMethod(self): - class arraylike(object): - dtype = onp.float32 - def __array__(self, dtype=None): - return 3. - a = arraylike() - ans = lnp.array(a) - assert ans == 3. - - @jtu.skip_on_devices("tpu") # TODO(b/32368900): TPUs don't support uint8 yet. 
- @jtu.disable - def testMemoryView(self): - ans = lnp.array(bytearray(b'\x2a')) - self.assertAllClose( - ans, - onp.array([0x2a], dtype=onp.uint8), - check_dtypes=True) - - def testAllClose(self): - rng = onp.random.RandomState(0) - x = rng.randn(2, 2) - y = rng.randn(2) - - def same(list1, list2): - allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) - elements_close = list(map(allclose, list1, list2)) - return lnp.all(lnp.array(elements_close)) - - csame = npe.jit(same) - - a1 = same((x, y), (x, y)) - a2 = csame((x, y), (x, y)) - a3 = csame((x, y), (x, 2 * y)) - - self.assertTrue(a1) - self.assertTrue(a2) - self.assertFalse(a3) - - @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure - @jtu.disable - def testOnesBroadcastingConstantHandler(self): - # TODO(mattjj): update this test for jax3 - self.skipTest("test needs jax3 update") - - def fun(x): - ones = lnp.ones((3, 4)) - assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) - - # To check that the constant handler generates a Broadcast for stride-zero - # arrays, we monkey-patch the client instance. - # TODO(mattjj): once we have better HLO dumping and inspecting facilities, - # we can check the HLO more directly. - c = x._node.c - Broadcast = c.Broadcast # pylint: disable=invalid-name - was_called = [] - c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args) - out = x + ones # the ndarray constant handler should call Broadcast here - assert was_called, "Broadcast was not called." - - return out - - fun = api.jit(fun) - out_val = fun(lnp.ones(4)) - self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False) - - def testZeroStridesConstantHandler(self): - raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) - const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) - - def fun(x): - return x * const - - fun = npe.jit(fun) - out_val = fun(3.) - self.assertAllClose(out_val, 3. 
* const, check_dtypes=False) - - def testIsInstanceNdarrayDuringTracing(self): - arr = onp.ones(3) - - @npe.jit - def f(x): - self.assertIsInstance(x, lnp.ndarray) - return lnp.sum(x) - - f(arr) - - @jtu.disable - def testNonArrayErrorMessage(self): - x = [1., 2.] - y = onp.array([3., 4.]) - - def g(x, y): - return lnp.add(x, y) - - def f(x, y): - return lnp.dot(x, y) - - self.assertRaises(TypeError, lambda: g(x, y)) - self.assertRaises(TypeError, lambda: f(x, y)) - self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) - self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) - - @jtu.disable - def testAbstractionErrorMessage(self): - - @api.jit - def f(x, n): - for _ in range(n): - x = x * x - return x - - self.assertRaises(TypeError, lambda: f(3., 3)) - - @api.jit - def g(x): - if x > 0.: - return x * 2 - else: - return x + 2 - - self.assertRaises(TypeError, lambda: g(3.)) - - @jtu.disable - def testTracingPrimitiveWithNoTranslationErrorMessage(self): - # TODO(mattjj): update this for jax3 - self.skipTest("test needs jax3 update") - foo = lnp._not_implemented(lambda x: x) - - # No error if there's no tracing. 
- foo(onp.arange(3)) - - cfoo = api.jit(foo) - self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis} - for shape in [(3,), (2, 3)] - for dtype in default_dtypes - for axis in list(range(-len(shape), len(shape))) + [None] # Test negative axes - for rng_factory in [jtu.rand_default])) - def testFlip(self, shape, dtype, axis, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.flip(x, axis) - onp_op = lambda x: onp.flip(x, axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype} - for shape in [(3,), (2, 3), (3, 2, 4)] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testFlipud(self, shape, dtype, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.flipud(x) - onp_op = lambda x: onp.flipud(x) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype} - for shape in [(3, 2), (2, 3), (3, 2, 4)] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testFliplr(self, shape, dtype, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - 
lnp_op = lambda x: lnp.fliplr(x) - onp_op = lambda x: onp.fliplr(x) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), k, axes), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k, "axes": axes} - for shape, axes in [ - [(2, 3), (0, 1)], - [(2, 3), (1, 0)], - [(4, 3, 2), (0, 2)], - [(4, 3, 2), (2, 1)], - ] - for k in range(-3, 4) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testRot90(self, shape, dtype, k, axes, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.rot90(x, k, axes) - onp_op = lambda x: onp.rot90(x, k, axes) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), k, axes), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k, - "axes": axes} - for shape, axes in [ - [(2, 3), (-2, -1)], - [(2, 3), (-2, 1)], - [(4, 3, 2), (-1, -2)], - [(4, 3, 2), (2, -2)], - ] - for k in range(-3, 4) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - @new_test - # These tests are only added as a separate test from testRot90 since we would - # like to measure coverage directly against the existing baseline. Once we - # stop measuring that, we can combine this test with the above. 
- def testRot90Additional(self, shape, dtype, k, axes, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.rot90(x, k, axes) - onp_op = lambda x: onp.rot90(x, k, axes) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - # TODO(mattjj): test infix operator overrides - - def testRavel(self): - rng = onp.random.RandomState(0) - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True, - check_incomplete_shape=True) - - def testAstype(self): - rng = onp.random.RandomState(0) - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - op = lambda x: x.astype(lnp.int32) - self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True) - self._CompileAndCheck( - op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - # TODO(mattjj): test other ndarray-like method overrides - - def testOnpMean(self): - # from https://github.com/google/jax/issues/125 - x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.) - ans = onp.mean(x) - self.assertAllClose(ans, onp.array(1./3), check_dtypes=False) - - @jtu.disable - def testArangeOnFloats(self): - # from https://github.com/google/jax/issues/145 - expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_) - ans = lnp.arange(0.0, 1.0, 0.1) - self.assertAllClose(expected, ans, check_dtypes=True) - - def testSortManually(self): - - def _test(*args, **kwargs): - - raw_ans = lnp.sort(*args, **kwargs) - fn_ans = npe.jit(lnp.sort, static_argnums=(1,))(*args, **kwargs) - expected = onp.sort(*args, **kwargs) - - self.assertAllClose(expected, raw_ans, check_dtypes=True) - self.assertAllClose(expected, fn_ans, check_dtypes=True) - - # manual tests for sort are nice because we don't have to worry about ties. - # lax.sort is tested combinatorially. 
- _test(onp.array([16, 15, 23, 42, 8, 4])) - _test(onp.array([[1, 4], [3, 1]]), None) - _test(onp.array([[1, 4], [3, 1]])) - _test(onp.array([[1, 4], [3, 1]]), 0) - - def testArgsortManually(self): - - def _test(*args, **kwargs): - - raw_ans = lnp.argsort(*args, **kwargs) - fn_ans = npe.jit(lnp.argsort, static_argnums=(1,))(*args, **kwargs) - expected = onp.argsort(*args, **kwargs) - - self.assertAllClose(expected, raw_ans, check_dtypes=True) - self.assertAllClose(expected, fn_ans, check_dtypes=True) - - _test(onp.array([16, 15, 23, 42, 8, 4])) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), 0) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), 1) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), None) - _test(onp.array([[16, 15, 23], [42, 8, 4]])) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_shifts={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), - shifts, axis), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "shifts": shifts, - "axis": axis} - for dtype in all_dtypes - for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] - for shifts, axis in [ - (3, None), - (1, 1), - ((3,), (0,)), - ((-2,), (-2,)), - ((1, 2), (0, -1)) - ] - for rng_factory in [jtu.rand_default])) - def testRoll(self, shape, dtype, shifts, axis, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(shifts)] - lnp_op = partial(lnp.roll, axis=axis) - onp_op = partial(onp.roll, axis=axis) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_index={}_axis={}_mode={}".format( - jtu.format_shape_dtype_string(shape, dtype), - jtu.format_shape_dtype_string(index_shape, index_dtype), - axis, mode), - "rng_factory": rng_factory, "rng_indices_factory": rng_indices_factory, - "shape": shape, "index_shape": index_shape, "dtype": dtype, - "index_dtype": 
index_dtype, "axis": axis, "mode": mode} - for shape in [(3,), (3, 4), (3, 4, 5)] - for index_shape in scalar_shapes + [(3,), (2, 1, 3)] - for axis in itertools.chain(range(-len(shape), len(shape)), [None]) - for dtype in all_dtypes - for index_dtype in int_dtypes - for mode in ['wrap', 'clip'] - for rng_factory in [jtu.rand_default] - for rng_indices_factory in [partial(jtu.rand_int, -5, 5)])) - def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode, - rng_factory, rng_indices_factory): - def args_maker(): - x = rng(shape, dtype) - i = rng_indices(index_shape, index_dtype) - return x, i - - rng = rng_factory() - rng_indices = rng_indices_factory() - lnp_op = lambda x, i: lnp.take(x, i, axis=axis, mode=mode) - onp_op = lambda x, i: onp.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_ishape={}_axis={}".format( - jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis), - "rng_factory": rng_factory, "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype, - "axis": axis} - for x_shape, i_shape in filter( - _shapes_are_equal_length, - filter(_shapes_are_broadcast_compatible, - CombosWithReplacement(nonempty_nonscalar_array_shapes, 2))) - for axis in itertools.chain(range(len(x_shape)), [-1], [None]) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng_factory): - rng = rng_factory() - i_shape = onp.array(i_shape) - if axis is None: - i_shape = [onp.prod(i_shape, dtype=onp.int64)] - else: - # Test the case where the size of the axis doesn't necessarily broadcast. 
- i_shape[axis] *= 3 - i_shape = list(i_shape) - def args_maker(): - x = rng(x_shape, dtype) - n = onp.prod(x_shape, dtype=onp.int32) if axis is None else x_shape[axis] - i = rng(i_shape, onp.int32) % (2 * n - 1) - (n - 1) - return x, i - - lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis) - - if hasattr(onp, "take_along_axis"): - onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_n={}_increasing={}".format( - jtu.format_shape_dtype_string([shape], dtype), - n, increasing), - "dtype": dtype, "shape": shape, "n": n, "increasing": increasing, - "rng_factory": jtu.rand_default} - for dtype in inexact_dtypes - for shape in [0, 5] - for n in [2, 4] - for increasing in [False, True])) - def testVander(self, shape, dtype, n, increasing, rng_factory): - rng = rng_factory() - def onp_fun(arg): - arg = arg.astype(onp.float32) if dtype == lnp.bfloat16 else arg - return onp.vander(arg, N=n, increasing=increasing) - lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing) - args_maker = lambda: [rng([shape], dtype)] - # np.vander seems to return float64 for all floating types. We could obey - # those semantics, but they seem like a bug. 
- self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol={onp.float32: 1e-3}) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=False, check_incomplete_shape=True, - rtol={onp.complex128: 2e-15}) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("nan_to_num", [shape], - [dtype]), - "rng_factory": jtu.rand_some_inf_and_nan, "shape": shape, - "dtype": dtype} - for shape in all_shapes - for dtype in inexact_dtypes)) - @jtu.disable - def testNanToNum(self, rng_factory, shape, dtype): - rng = rng_factory() - dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type - def onp_fun(x): - if dtype == lnp.bfloat16: - x = onp.where(onp.isnan(x), dtype(0), x) - x = onp.where(onp.isposinf(x), lnp.finfo(dtype).max, x) - x = onp.where(onp.isneginf(x), lnp.finfo(dtype).min, x) - return x - else: - return onp.nan_to_num(x).astype(dtype) - - args_maker = lambda: [rng(shape, dtype)] - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(onp_fun, lnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - self._CompileAndCheck(lnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for shapes, dtypes in ( - ((), ()), - (((7,),), (onp.int32,)), - (((3,), (4,)), (onp.int32, onp.int32)), - (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)), - ))) - def testIx_(self, rng_factory, shapes, dtypes): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype) - for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker, - check_dtypes=True) - self._CompileAndCheck( - lnp.ix_, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": - "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format( - op, - 
jtu.format_shape_dtype_string(a_shape, a_dtype), - jtu.format_shape_dtype_string(q_shape, q_dtype), - axis, keepdims), - "a_rng": jtu.rand_default(), "q_rng": q_rng, "op": op, - "a_shape": a_shape, "a_dtype": a_dtype, - "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis, - "keepdims": keepdims} - for (op, q_rng) in ( - ("percentile", jtu.rand_uniform(low=0., high=100.)), - ("quantile", jtu.rand_uniform(low=0., high=1.)), - ("median", jtu.rand_uniform(low=0., high=1.)), - ) - for a_dtype in float_dtypes - for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((4, 101), 1), - ) - for q_dtype in [onp.float32] - for q_shape in scalar_shapes + [(4,)] - for keepdims in [False, True])) - @jtu.disable - def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, - axis, keepdims): - if op == "quantile" and numpy_version < (1, 15): - raise SkipTest("Numpy < 1.15 does not have np.quantile") - if op == "median": - args_maker = lambda: [a_rng(a_shape, a_dtype)] - else: - args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] - - def onp_fun(*args): - args = [x if lnp.result_type(x) != lnp.bfloat16 else - onp.asarray(x, onp.float32) for x in args] - return getattr(onp, op)(*args, axis=axis, keepdims=keepdims) - lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims) - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. 
- tol_spec = {onp.float32: 2e-4, onp.float64: 5e-6} - tol = max(jtu.tolerance(a_dtype, tol_spec), - jtu.tolerance(q_dtype, tol_spec)) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testWhereOneArgument(self, shape, dtype): - rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.where(x) - lnp_fun = lambda x: lnp.where(x) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) - # Turns off XLA check because there are no XLA kernels for `Where`, which - # XLA can't support because it's output shape is dynamic. - self._CompileAndCheck( - lnp.where, - args_maker, - check_dtypes=True, - check_eval_on_shapes=False, - check_incomplete_shape=True, - check_unknown_rank=False, - check_experimental_compile=False, check_xla_forced_compile=False) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format("_".join( - jtu.format_shape_dtype_string(shape, dtype) - for shape, dtype in zip(shapes, dtypes))), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for shapes in filter(_shapes_are_broadcast_compatible, - CombosWithReplacement(all_shapes, 3)) - for dtypes in CombosWithReplacement(all_dtypes, 3))) - def testWhereThreeArgument(self, rng_factory, shapes, dtypes): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes) - def onp_fun(cond, x, y): - return _promote_like_lnp(partial(onp.where, cond))(x, y) - self._CheckAgainstNumpy(onp_fun, lnp.where, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp.where, args_maker, check_dtypes=True, check_incomplete_shape=True) - - def 
testWhereScalarPromotion(self): - x = lnp.where(lnp.array([True, False]), 3, - lnp.ones((2,), dtype=lnp.float32)) - self.assertEqual(x.dtype, onp.dtype(onp.float32)) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, - (onp.bool_,) * n + dtypes), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for n in range(0, 3) - for shapes in filter( - _shapes_are_broadcast_compatible, - CombosWithReplacement(all_shapes, 2 * n + 1)) - for dtypes in CombosWithReplacement(all_dtypes, n + 1))) - def testSelect(self, rng_factory, shapes, dtypes): - rng = rng_factory() - n = len(dtypes) - 1 - def args_maker(): - condlist = [rng(shape, onp.bool_) for shape in shapes[:n]] - choicelist = [rng(shape, dtype) - for shape, dtype in zip(shapes[n:-1], dtypes[:n])] - default = rng(shapes[-1], dtypes[-1]) - return condlist, choicelist, default - # TODO(phawkins): float32/float64 type mismatches - def onp_fun(condlist, choicelist, default): - choicelist = [x if lnp.bfloat16 != lnp.result_type(x) - else x.astype(onp.float32) for x in choicelist] - dtype = lnp.result_type(default, *choicelist).as_numpy_dtype - return onp.select(condlist, - [onp.asarray(x, dtype=dtype) for x in choicelist], - onp.asarray(default, dtype=dtype)) - self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker, - check_dtypes=False) - self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True, - check_incomplete_shape=True, - rtol={onp.float64: 1e-7, onp.complex128: 1e-7}) - - - @jtu.disable - def testIssue330(self): - x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash - self.assertEqual(x[0, 0], 1) - - @jtu.disable - def testScalarDtypePromotion(self): - orig_numpy_result = (1 + onp.eye(1, dtype=onp.float32)).dtype - jax_numpy_result = (1 + lnp.eye(1, dtype=lnp.float32)).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - @jtu.disable - def testSymmetrizeDtypePromotion(self): - x = onp.eye(3, dtype=onp.float32) - 
orig_numpy_result = ((x + x.T) / 2).dtype - - x = lnp.eye(3, dtype=lnp.float32) - jax_numpy_result = ((x + x.T) / 2).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - @jtu.disable - def testIssue347(self): - # https://github.com/google/jax/issues/347 - def test_fail(x): - x = lnp.sqrt(lnp.sum(x ** 2, axis=1)) - ones = lnp.ones_like(x) - x = lnp.where(x > 0.5, x, ones) - return lnp.sum(x) - - x = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64) - result = api.grad(test_fail)(x) - assert not onp.any(onp.isnan(result)) - - def testIssue453(self): - # https://github.com/google/jax/issues/453 - a = onp.arange(6) + 1 - ans = lnp.reshape(a, (3, 2), order='F') - expected = onp.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}".format(op, pytype.__name__), - "pytype": pytype, "dtype": dtype, "op": op} - for pytype, dtype in [(int, lnp.int_), (float, lnp.float_), - (bool, lnp.bool_), (complex, lnp.complex_)] - for op in ["atleast_1d", "atleast_2d", "atleast_3d"])) - def testAtLeastNdLiterals(self, pytype, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 - onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype) - lnp_fun = lambda arg: getattr(lnp, op)(arg) - args_maker = lambda: [pytype(2)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - def testLongLong(self): - self.assertAllClose( - onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True) - - def testArange(self): - # test cases inspired by dask tests at - # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92 - self.assertAllClose(lnp.arange(77), - onp.arange(77, dtype=lnp.int_), check_dtypes=True) - self.assertAllClose(lnp.arange(2, 13), - onp.arange(2, 13, dtype=lnp.int_), check_dtypes=True) - 
self.assertAllClose(lnp.arange(4, 21, 9), - onp.arange(4, 21, 9, dtype=lnp.int_), check_dtypes=True) - self.assertAllClose(lnp.arange(53, 5, -3), - onp.arange(53, 5, -3, dtype=lnp.int_), - check_dtypes=True) - # TODO(mattjj): make these tests work when enable_x64=True - self.assertAllClose( - lnp.arange(77, dtype=float), - onp.arange(77, dtype=float), - check_dtypes=True) - self.assertAllClose( - lnp.arange(2, 13, dtype=int), - onp.arange(2, 13, dtype=int), - check_dtypes=True) - self.assertAllClose(lnp.arange(0, 1, -0.5), - onp.arange(0, 1, -0.5, dtype=lnp.float_), - check_dtypes=True) - - self.assertRaises(TypeError, lambda: lnp.arange()) - - # # The following have been disabled since they test JAX specific behavior - # # test that lnp.arange(N) doesn't instantiate an ndarray - # self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77))) - # self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77))) - - # # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray - # self.assertFalse(type(lnp.arange(77, dtype=lnp.int32)) == - # type(onp.arange(77, dtype=onp.int32))) - # self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == - # type(lax.iota(onp.int32, 77))) - - def testIssue830(self): - a = lnp.arange(4, dtype=lnp.complex64) - self.assertEqual(a.dtype, lnp.complex64) - - def testIssue728(self): - assert lnp.allclose(lnp.eye(5000), onp.eye(5000)) - self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) - - def testIssue746(self): - lnp.arange(12).reshape(3, 4) # doesn't crash - - def testIssue764(self): - x = lnp.linspace(190, 200, 4) - f = npe.grad(lambda x: lnp.sum(lnp.tanh(x))) - # Expected values computed with autograd in float64 precision. 
- expected = onp.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, - 7.66067839e-174], onp.float64) - self.assertAllClose(f(x), expected, check_dtypes=False) - - @jtu.disable - def testIssue776(self): - """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" - def f(u): - y = onp.ones(10,).at[[2, 4, 5]].add(u) - # The transpose rule for lax.tie_in returns a symbolic zero for its first - # argument. - return lax.tie_in(y, 7.) - - self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)), - check_dtypes=True) - - @jtu.disable - def testIssue777(self): - x = lnp.linspace(-200, 0, 4, dtype=onp.float32) - f = npe.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x)))) - self.assertAllClose(f(x), onp.array([0., 0., 0., 0.25], dtype=onp.float32), - check_dtypes=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), - "dtype": dtype, "op": op} - for dtype in float_dtypes - for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", - "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", - "log", "expm1", "log1p"))) - def testMathSpecialFloatValues(self, op, dtype): - onp_op = getattr(onp, op) - lnp_op = getattr(lnp, op) - dtype = onp.dtype(lnp.canonicalize_dtype(dtype)).type - for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf, - lnp.finfo(dtype).max, onp.sqrt(lnp.finfo(dtype).max), - onp.sqrt(lnp.finfo(dtype).max) * 2.): - if (op in ("sin", "cos", "tan", "arctan") and - jtu.device_under_test() == "tpu"): - continue # TODO(b/132196789, b/134175194): fix and reenable. - # TODO(b/158006398): fix and reenable. 
- if (op in ("cosh", "arccosh", "arcsinh", "arcsin", "sinh", "arccos", - "arctan", "arctanh") and dtype == onp.float16): - continue - x = dtype(x) - expected = onp_op(x) - actual = lnp_op(x) - tol = jtu.tolerance(dtype, {onp.float32: 1e-3, onp.float64: 1e-7}) - self.assertAllClose(expected, actual, check_dtypes=True, atol=tol, - rtol=tol) - - def testIssue883(self): - # from https://github.com/google/jax/issues/883 - - @partial(npe.jit, static_argnums=(1,)) - def f(x, v): - return x - - x = lnp.ones((10, 10)) - v = lnp.array([1, 2, 3]) - first_call = f(x, v) - second_call = f(x, v) # doesn't crash - - def testReductionOfOutOfBoundsAxis(self): # Issue 888 - x = lnp.ones((3, 4)) - self.assertRaises( - tf.errors.InvalidArgumentError, lambda: lnp.sum(x, axis=2)) - - @jtu.disable - def testIssue956(self): - self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1))) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, out_dtype, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims, "rng_factory": rng_factory} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True] - for rng_factory in [jtu.rand_default])) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - def onp_fun(x): - out = onp.var(x.astype(lnp.promote_types(onp.float32, dtype)), - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {onp.float16: 1e-1, onp.float32: 1e-3, - onp.float64: 1e-3, onp.complex128: 1e-6}) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, 
check_dtypes=True, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol, - atol=tol, check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( - shape, dtype, rowvar, ddof, bias), - "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof, - "bias": bias, "rng_factory": rng_factory} - for shape in [(5,), (10, 5), (5, 10)] - for dtype in all_dtypes - for rowvar in [True, False] - for bias in [True, False] - for ddof in [None, 2, 3] - for rng_factory in [jtu.rand_default])) - @jtu.skip_on_devices("gpu") # TODO(b/138003641): test fails on GPU. - @jtu.disable - def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - onp_fun = partial(onp.cov, rowvar=rowvar, ddof=ddof, bias=bias) - lnp_fun = partial(lnp.cov, rowvar=rowvar, ddof=ddof, bias=bias) - tol = {onp.float32: 1e-5, onp.float64: 1e-13, onp.complex128: 1e-13} - tol = 7e-2 if jtu.device_under_test() == "tpu" else tol - tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) - self._CheckAgainstNumpy( - onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol) - - def testIssue967(self): - self.assertRaises(TypeError, lambda: lnp.zeros(1.5)) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( - shape, dtype, rowvar, ddof, bias), - "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof, - "bias": bias, "rng_factory": rng_factory} - for shape in [(5,), (10, 5), (3, 10)] - for dtype in number_dtypes - for rowvar in [True, False] - for bias in [True, False] - for ddof in [None, 2, 3] - for rng_factory in [jtu.rand_default])) - @jtu.disable - def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory): - rng = rng_factory() - args_maker = 
self._GetArgsMaker(rng, [shape], [dtype]) - mat = onp.asarray([rng(shape, dtype)]) - onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) - lnp_fun = partial(lnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) - if not onp.any(onp.isclose(onp.std(mat), 0.0)): - self._CheckAgainstNumpy( - onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol=1e-2 if jtu.device_under_test() == "tpu" else None) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters( - jtu.cases_from_list( - { - "testcase_name": - "_shapes={}_dtype={}_indexing={}_sparse={}".format( - shapes, jtu.dtype_str(dtype), indexing, sparse), - "shapes": - shapes, - "dtype": - dtype, - "indexing": - indexing, - "sparse": - sparse, - "rng_factory": - rng_factory - } for shapes in [(), (5,), (5, 3)] for dtype in number_dtypes - for indexing in ["xy", "ij"] - for sparse in [False] # TODO(nareshmodi): Make sparse work - for rng_factory in [jtu.rand_default])) - def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], - [dtype] * len(shapes)) - onp_fun = partial(onp.meshgrid, indexing=indexing, sparse=sparse) - lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_retstep={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, retstep, dtype), - "start_shape": start_shape, "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "retstep": retstep, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for retstep in [True, False] 
- for dtype in number_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testLinspace(self, start_shape, stop_shape, num, endpoint, - retstep, dtype, rng_factory): - if not endpoint and onp.issubdtype(dtype, onp.integer): - # TODO(b/157597565): Support all dtypes when the tf op supports endpoint - # Currently, subtracting the step early leads to rounding errors for - # integers. - return - rng = rng_factory() - # relax default tolerances slightly - tol = jtu.tolerance(dtype if dtype else onp.float32) * 10 - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - lnp_op = lambda start, stop: lnp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - onp_op = lambda start, stop: onp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - # floating-point compute between jitted platforms and non-jit + rounding - # cause unavoidable variation in integer truncation for some inputs. 
- if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol, - check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_base={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, base, - dtype.__name__ if dtype else "None"), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "base": base, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for base in [10.0, 2, onp.e] - for dtype in inexact_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testLogspace(self, start_shape, stop_shape, num, - endpoint, base, dtype, rng_factory): - if (dtype in int_dtypes and - jtu.device_under_test() in ("gpu", "tpu") and - not FLAGS.enable_x64): - raise unittest.SkipTest("GPUx32 truncated exponentiation" - " doesn't exactly match other platforms.") - rng = rng_factory() - # relax default tolerances slightly - tol = {onp.float16: 2e-2, onp.float32: 1e-2, onp.float64: 1e-6, - onp.complex64: 1e-3, onp.complex128: 1e-6} - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - lnp_op = lambda start, stop: lnp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - onp_op = lambda start, stop: onp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - # Why do compiled and op-by-op float16 np.power numbers differ - # slightly more than expected? 
- atol = {onp.float16: 1e-2} - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=atol, rtol=tol, - check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_dtype={}").format( - start_shape, stop_shape, num, endpoint, dtype), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - # NB: numpy's geomspace gives nonsense results on integer types - for dtype in inexact_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testGeomspace(self, start_shape, stop_shape, num, - endpoint, dtype, rng_factory): - rng = rng_factory() - # relax default tolerances slightly - tol = {onp.float16: 4e-3, onp.float32: 2e-3, onp.complex128: 1e-14} - def args_maker(): - """Test the set of inputs onp.geomspace is well-defined on.""" - start, stop = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype])() - # onp.geomspace can't handle differently ranked tensors - # w. negative numbers! 
- start, stop = lnp.broadcast_arrays(start, stop) - if dtype in complex_dtypes: - return start, stop - # to avoid NaNs, non-complex start and stop cannot - # differ in sign, elementwise - start = start * lnp.sign(start) * lnp.sign(stop) - return start, stop - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - def lnp_op(start, stop): - return lnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, - axis=axis) - def onp_op(start, stop): - start = start.astype(onp.float32) if dtype == lnp.bfloat16 else start - stop = stop.astype(onp.float32) if dtype == lnp.bfloat16 else stop - return onp.geomspace( - start, stop, num, endpoint=endpoint, - dtype=dtype if dtype != lnp.bfloat16 else onp.float32, - axis=axis).astype(dtype) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol, - check_incomplete_shape=True) - - @jtu.disable - def testDisableNumpyRankPromotionBroadcasting(self): - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "allow" - lnp.ones(2) + lnp.ones((1, 2)) # works just fine - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "raise" - self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2))) - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "warn" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - lnp.ones(2) + lnp.ones((1, 2)) - assert len(w) > 0 - msg = str(w[-1].message) - expected_msg = ("Following NumPy automatic rank promotion for add on " - "shapes (2,) (1, 2).") - self.assertEqual(msg[:len(expected_msg)], expected_msg) - - prev_len = len(w) - lnp.ones(2) + 3 - 
self.assertEqual(len(w), prev_len) # don't want to warn for scalars - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 - @npe.jit - def foo(x): - return lnp.stack(x) - foo(onp.zeros(2)) # doesn't crash - - @npe.jit - def foo(x): - return lnp.concatenate(x) - foo(onp.zeros((2, 2))) # doesn't crash - - @jtu.disable - def testReluGradientConstants(self): - # This is a regression test that verifies that constants associated with the - # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the - # outermost jaxpr. This was producing some large materialized constants for - # every relu activation in a model. - def body(i, xy): - x, y = xy - y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.)))(x) - return x, y - - f = lambda y: lax.fori_loop(0, 5, body, (y, y)) - wrapped = linear_util.wrap_init(f) - pv = partial_eval.PartialVal( - (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit)) - _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) - self.assertFalse( - any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32)) - for x in consts)) - - @named_parameters( - {"testcase_name": "_from={}_to={}".format(from_shape, to_shape), - "rng_factory": rng_factory, "from_shape": from_shape, "to_shape": to_shape} - for from_shape, to_shape in [ - [(1, 3), (4, 3)], - [(3,), (2, 1, 3)], - [(3,), (3, 3)], - [(1,), (3,)], - ] - for rng_factory in [jtu.rand_default]) - def testBroadcastTo(self, from_shape, to_shape, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [from_shape], [onp.float32]) - onp_op = lambda x: onp.broadcast_to(x, to_shape) - lnp_op = lambda x: lnp.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - def testBroadcastToIssue1522(self): - self.assertRaisesRegex( - Exception, 
"Unable to broadcast", - lambda: lnp.broadcast_to(onp.ones((2, 3)), (1, 3))) - - def testBroadcastToIntIssue1548(self): - self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)), - check_dtypes=False) - - def testBroadcastToOnScalar(self): - self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray) - self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray) - - @jtu.disable - def testPrecision(self): - - ones_1d = onp.ones((2,)) - ones_2d = onp.ones((2, 2)) - ones_3d = onp.ones((2, 2, 2)) - HIGHEST = lax.Precision.HIGHEST - - jtu.assert_dot_precision(None, lnp.dot, ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.dot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.dot, precision=HIGHEST), - ones_3d, ones_3d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.matmul, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.vdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=2, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.einsum, 'i,i', precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.einsum, 'ij,ij', precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.inner, precision=HIGHEST), - ones_1d, ones_1d) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": - "_{}_{}_{}_{}".format( - shape, jtu.dtype_str(key_dtype), jtu.dtype_str(value_dtype), - dimension).replace(" ", ""), - "shape": shape, "key_dtype": key_dtype, "value_dtype": value_dtype, - "dimension": dimension, "rng_factory": rng_factory} - for shape in 
all_shapes - for key_dtype in minus(number_dtypes, complex_dtypes) - for value_dtype in all_dtypes - for dimension in range(-len(shape), len(shape)) - for rng_factory in [jtu.rand_default])) - @new_test - def testSortKeyValue(self, shape, key_dtype, value_dtype, dimension, - rng_factory): - def onp_ref(keys, values): - idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape])) - idxs[dimension] = onp.argsort(keys, axis=dimension) - return keys[tuple(idxs)], values[tuple(idxs)] - rng = rng_factory() - args_maker = self._GetArgsMaker( - rng, [shape, shape], [key_dtype, value_dtype]) - op = partial(npe.sort_key_val, dimension=dimension) - self._CheckAgainstNumpy(onp_ref, op, args_maker, - check_dtypes=True) - # sort_key_val requires known rank. - # XLA only has TopKV2 (used by tf.argsort) kernels on those dtypes - # (b/169194137). - check_xla = key_dtype in (onp.uint32, onp.int32, onp.float32, lnp.bfloat16) - self._CompileAndCheck(op, args_maker, check_dtypes=True, - check_incomplete_shape=True, check_unknown_rank=False, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - -# Most grad tests are at the lax level (see lax_test.py), but we add some here -# as needed for e.g. particular compound ops of interest. 
- -GradTestSpec = collections.namedtuple( - "GradTestSpec", - ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) -def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): - return GradTestSpec( - op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) - -GRAD_TEST_RECORDS = [ - grad_test_spec(lnp.arcsinh, nargs=1, order=2, - rng_factory=jtu.rand_positive, - dtypes=[onp.float64, onp.complex64], tol=1e-4), - grad_test_spec(lnp.arccosh, nargs=1, order=2, - rng_factory=jtu.rand_positive, - dtypes=[onp.float64, onp.complex64], tol=1e-4), - grad_test_spec(lnp.arctanh, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, -0.9, 0.9), - dtypes=[onp.float64, onp.complex64], tol=1e-4), -] - -GradSpecialValuesTestSpec = collections.namedtuple( - "GradSpecialValuesTestSpec", ["op", "values", "order"]) - -GRAD_SPECIAL_VALUE_TEST_RECORDS = [ - GradSpecialValuesTestSpec(lnp.arcsinh, [0., 1000.], 2), - GradSpecialValuesTestSpec(lnp.arccosh, [1000.], 2), - GradSpecialValuesTestSpec(lnp.arctanh, [0.], 2), - # TODO(wangpeng): Add `GradSpecialValuesTestSpec(lnp.sinc, [0.], 1)` -] - -def num_float_bits(dtype): - return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits - -class NumpyGradTests(jtu.TestCase): - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.name, shapes, itertools.repeat(dtype)), - "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, - "order": rec.order, "tol": rec.tol} - for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs) - for dtype in rec.dtypes) - for rec in GRAD_TEST_RECORDS)) - @jtu.disable - def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): - rng = rng_factory() - tol = {onp.float32: 1e-1, onp.complex64: 1e-1} - args = tuple(rng(shape, dtype) for shape in shapes) - check_grads(op, args, order, ["fwd", "rev"], tol, tol) - - @named_parameters(itertools.chain.from_iterable( - 
jtu.cases_from_list( - {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), - "op": rec.op, "special_value": special_value, "order": rec.order} - for special_value in rec.values) - for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) - @jtu.disable - def testOpGradSpecialValue(self, op, special_value, order): - check_grads(op, (special_value,), order, ["fwd", "rev"], - atol={onp.float32: 3e-3}) - - @jtu.disable - def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 - idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1)) - - def f(x): - y = x * lnp.arange(3.).reshape((1, 3)) - return lnp.take_along_axis(y, idx, -1).sum() - - check_grads(f, (1.,), order=1) - - -if __name__ == "__main__": - tf.enable_v2_behavior() - lnp.enable_numpy_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py deleted file mode 100644 index b12b04676..000000000 --- a/trax/tf_numpy/jax_tests/test_util.py +++ /dev/null @@ -1,902 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from contextlib import contextmanager -from distutils.util import strtobool -import functools -from functools import partial -import re -import itertools as it -import os -from typing import Dict, Sequence, Union -import sys -import unittest -import warnings -import zlib - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as onp -import numpy.random as npr -import scipy - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.jax_tests.config import flags, bool_env -import trax.tf_numpy.extensions as npe -import trax.tf_numpy.numpy as tf_np - - -tree_map = tf.nest.map_structure -tree_multimap = tf.nest.map_structure - - -FLAGS = flags.FLAGS - - -# TODO(wangpeng): Remove this flag after broken tests are fixed -flags.DEFINE_bool('enable_x64', - strtobool('False'), - 'Enable 64-bit types to be used.') - - -flags.DEFINE_enum( - 'test_dut', '', - enum_values=['', 'cpu', 'gpu', 'tpu'], - help= - 'Describes the device under test in case special consideration is required.' -) - - -flags.DEFINE_integer( - 'num_generated_cases', - 10, - help='Number of generated cases to test') - - -EPS = 1e-4 - - -# Default dtypes corresponding to Python scalars. 
-python_scalar_dtypes = { - bool: onp.dtype(onp.bool_), - int: onp.dtype(onp.int_), - float: onp.dtype(onp.float_), - complex: onp.dtype(onp.complex_), -} - - -def _dtype(x): - if isinstance(x, tf.Tensor): - return x.dtype.as_numpy_dtype - return (getattr(x, 'dtype', None) or - onp.dtype(python_scalar_dtypes.get(type(x), None)) or - onp.asarray(x).dtype) - - -def is_sequence(x): - try: - iter(x) - except TypeError: - return False - else: - return True - -_default_tolerance = { - onp.dtype(onp.bool_): 0, - onp.dtype(onp.int8): 0, - onp.dtype(onp.int16): 0, - onp.dtype(onp.int32): 0, - onp.dtype(onp.int64): 0, - onp.dtype(onp.uint8): 0, - onp.dtype(onp.uint16): 0, - onp.dtype(onp.uint32): 0, - onp.dtype(onp.uint64): 0, - # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-2, - onp.dtype(onp.float16): 1e-3, - onp.dtype(onp.float32): 1e-6, - onp.dtype(onp.float64): 1e-15, - onp.dtype(onp.complex64): 1e-6, - onp.dtype(onp.complex128): 1e-15, -} - -def default_tolerance(): - return _default_tolerance - -default_gradient_tolerance = { - # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-1, - onp.dtype(onp.float16): 1e-2, - onp.dtype(onp.float32): 2e-3, - onp.dtype(onp.float64): 1e-5, - onp.dtype(onp.complex64): 1e-3, - onp.dtype(onp.complex128): 1e-5, -} - -def _assert_numpy_allclose(a, b, atol=None, rtol=None): - # TODO(b/154768983): - # a = a.astype(onp.float32) if a.dtype == dtypes.bfloat16 else a - # b = b.astype(onp.float32) if b.dtype == dtypes.bfloat16 else b - kw = {} - if atol: kw["atol"] = atol - if rtol: kw["rtol"] = rtol - onp.testing.assert_allclose(a, b, **kw) - -def tolerance(dtype, tol=None): - tol = {} if tol is None else tol - if not isinstance(tol, dict): - return tol - tol = {onp.dtype(key): value for key, value in tol.items()} - dtype = onp.dtype(dtype) - return tol.get(dtype, default_tolerance()[dtype]) - -def _normalize_tolerance(tol): - tol = tol or 0 - if isinstance(tol, dict): - return {onp.dtype(k): v for k, v in tol.items()} - else: - return 
{k: tol for k in _default_tolerance} - -def join_tolerance(tol1, tol2): - tol1 = _normalize_tolerance(tol1) - tol2 = _normalize_tolerance(tol2) - out = tol1 - for k, v in tol2.items(): - out[k] = max(v, tol1.get(k, 0)) - return out - -def _assert_numpy_close(a, b, atol=None, rtol=None): - assert a.shape == b.shape - atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol)) - rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol)) - _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) - - -def check_eq(xs, ys): - tree_all(tree_multimap(_assert_numpy_allclose, xs, ys)) - - -def check_close(xs, ys, atol=None, rtol=None): - assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol) - tree_all(tree_multimap(assert_close, xs, ys)) - - -def inner_prod(xs, ys): - def contract(x, y): - return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1))) - return tree_reduce(onp.add, tree_multimap(contract, xs, ys)) - - -add = partial(tree_multimap, lambda x, y: onp.add(x, y, dtype=_dtype(x))) -sub = partial(tree_multimap, lambda x, y: onp.subtract(x, y, dtype=_dtype(x))) -conj = partial(tree_map, lambda x: onp.conj(x, dtype=_dtype(x))) - -def scalar_mul(xs, a): - return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs) - - -def rand_like(rng, x): - shape = onp.shape(x) - dtype = _dtype(x) - randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype) - if onp.issubdtype(dtype, onp.complexfloating): - return randn() + dtype.type(1.0j) * randn() - else: - return randn() - - -def numerical_jvp(f, primals, tangents, eps=EPS): - delta = scalar_mul(tangents, eps) - f_pos = f(*add(primals, delta)) - f_neg = f(*sub(primals, delta)) - return scalar_mul(sub(f_pos, f_neg), 0.5 / eps) - - -def _merge_tolerance(tol, default): - if tol is None: - return default - if not isinstance(tol, dict): - return tol - out = default.copy() - for k, v in tol.items(): - out[onp.dtype(k)] = v - return out - -def check_jvp(f, f_jvp, args, atol=None, rtol=None, 
eps=EPS): - atol = _merge_tolerance(atol, default_gradient_tolerance) - rtol = _merge_tolerance(rtol, default_gradient_tolerance) - rng = onp.random.RandomState(0) - tangent = tree_map(partial(rand_like, rng), args) - v_out, t_out = f_jvp(args, tangent) - v_out_expected = f(*args) - t_out_expected = numerical_jvp(f, args, tangent, eps=eps) - # In principle we should expect exact equality of v_out and v_out_expected, - # but due to nondeterminism especially on GPU (e.g., due to convolution - # autotuning) we only require "close". - check_close(v_out, v_out_expected, atol=atol, rtol=rtol) - check_close(t_out, t_out_expected, atol=atol, rtol=rtol) - - -def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS): - atol = _merge_tolerance(atol, default_gradient_tolerance) - rtol = _merge_tolerance(rtol, default_gradient_tolerance) - _rand_like = partial(rand_like, onp.random.RandomState(0)) - v_out, vjpfun = f_vjp(*args) - v_out_expected = f(*args) - check_close(v_out, v_out_expected, atol=atol, rtol=rtol) - tangent = tree_map(_rand_like, args) - tangent_out = numerical_jvp(f, args, tangent, eps=eps) - cotangent = tree_map(_rand_like, v_out) - cotangent_out = conj(vjpfun(conj(cotangent))) - ip = inner_prod(tangent, cotangent_out) - ip_expected = inner_prod(tangent_out, cotangent) - check_close(ip, ip_expected, atol=atol, rtol=rtol) - - -def device_under_test(): - return FLAGS.test_dut - -def if_device_under_test(device_type: Union[str, Sequence[str]], - if_true, if_false): - """Chooses `if_true` of `if_false` based on device_under_test.""" - if device_under_test() in ([device_type] if isinstance(device_type, str) - else device_type): - return if_true - else: - return if_false - -def supported_dtypes(): - if device_under_test() == "tpu": - return {onp.bool_, onp.int32, onp.uint32, dtypes.bfloat16, onp.float32, - onp.complex64} - else: - return {onp.bool_, onp.int8, onp.int16, onp.int32, onp.int64, - onp.uint8, onp.uint16, onp.uint32, onp.uint64, - dtypes.bfloat16, 
onp.float16, onp.float32, onp.float64, - onp.complex64, onp.complex128} - -def skip_if_unsupported_type(dtype): - if dtype not in supported_dtypes(): - raise unittest.SkipTest( - f"Type {dtype} not supported on {device_under_test()}") - -def skip_on_devices(*disabled_devices): - """A decorator for test methods to skip the test on certain devices.""" - def skip(test_method): - @functools.wraps(test_method) - def test_method_wrapper(self, *args, **kwargs): - device = device_under_test() - if device in disabled_devices: - test_name = getattr(test_method, '__name__', '[unknown test]') - raise unittest.SkipTest( - f"{test_name} not supported on {device.upper()}.") - return test_method(self, *args, **kwargs) - return test_method_wrapper - return skip - - -def skip_on_flag(flag_name, skip_value): - """A decorator for test methods to skip the test when flags are set.""" - def skip(test_method): # pylint: disable=missing-docstring - @functools.wraps(test_method) - def test_method_wrapper(self, *args, **kwargs): - flag_value = getattr(FLAGS, flag_name) - if flag_value == skip_value: - test_name = getattr(test_method, '__name__', '[unknown test]') - raise unittest.SkipTest( - f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}") - return test_method(self, *args, **kwargs) - return test_method_wrapper - return skip - -# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432 -# Delete this code after the minimum jaxlib version is 0.1.46 or greater. 
-skip_on_mac_linalg_bug = partial( - unittest.skipIf, - (sys.platform == "darwin" and scipy.version.version > "1.1.0" and - lib.version < (0, 1, 46)), - "Test fails on Mac with new scipy (issue #432)") - - -def format_test_name_suffix(opname, shapes, dtypes): - arg_descriptions = (format_shape_dtype_string(shape, dtype) - for shape, dtype in zip(shapes, dtypes)) - return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions)) - - -# We use special symbols, represented as singleton objects, to distinguish -# between NumPy scalars, Python scalars, and 0-D arrays. -class ScalarShape: - def __len__(self): return 0 - def __getitem__(self, i): - raise IndexError(f'index {i} out of range.') -class _NumpyScalar(ScalarShape): pass -class _PythonScalar(ScalarShape): pass -NUMPY_SCALAR_SHAPE = _NumpyScalar() -PYTHON_SCALAR_SHAPE = _PythonScalar() - - -def _dims_of_shape(shape): - """Converts `shape` to a tuple of dimensions.""" - if type(shape) in (list, tuple): - return shape - elif isinstance(shape, ScalarShape): - return () - else: - raise TypeError(type(shape)) - - -def _cast_to_shape(value, shape, dtype): - """Casts `value` to the correct Python type for `shape` and `dtype`.""" - if shape is NUMPY_SCALAR_SHAPE: - # explicitly cast to NumPy scalar in case `value` is a Python scalar. 
- return onp.dtype(dtype).type(value) - elif shape is PYTHON_SCALAR_SHAPE: - # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998 - return onp.asarray(value).item() - elif type(shape) in (list, tuple): - assert onp.shape(value) == tuple(shape) - return value - else: - raise TypeError(type(shape)) - - -def dtype_str(dtype): - return onp.dtype(dtype).name - - -def format_shape_dtype_string(shape, dtype): - if shape is NUMPY_SCALAR_SHAPE: - return dtype_str(dtype) - elif shape is PYTHON_SCALAR_SHAPE: - return 'py' + dtype_str(dtype) - elif type(shape) in (list, tuple): - shapestr = ','.join(str(dim) for dim in shape) - return '{}[{}]'.format(dtype_str(dtype), shapestr) - elif type(shape) is int: - return '{}[{},]'.format(dtype_str(dtype), shape) - elif isinstance(shape, onp.ndarray): - return '{}[{}]'.format(dtype_str(dtype), shape) - else: - raise TypeError(type(shape)) - - -def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x): - """Produce random values given shape, dtype, scale, and post-processor. - - Args: - rand: a function for producing random values of a given shape, e.g. a - bound version of either onp.RandomState.randn or onp.RandomState.rand. - shape: a shape value as a tuple of positive integers. - dtype: a numpy dtype. - scale: optional, a multiplicative scale for the random values (default 1). - post: optional, a callable for post-processing the random values (default - identity). - - Returns: - An ndarray of the given shape and dtype using random values based on a call - to rand but scaled, converted to the appropriate dtype, and post-processed. 
- """ - r = lambda: onp.asarray(scale * rand(*_dims_of_shape(shape)), dtype) - if onp.issubdtype(dtype, onp.complexfloating): - vals = r() + 1.0j * r() - else: - vals = r() - return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype) - - -def rand_default(scale=3): - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=scale) - - -def rand_nonzero(): - post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x) - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=3, post=post) - - -def rand_positive(): - post = lambda x: x + 1 - rand = npr.RandomState(0).rand - return partial(_rand_dtype, rand, scale=2, post=post) - - -def rand_small(): - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=1e-3) - - -def rand_not_small(offset=10.): - post = lambda x: x + onp.where(x > 0, offset, -offset) - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=3., post=post) - - -def rand_small_positive(): - rand = npr.RandomState(0).rand - return partial(_rand_dtype, rand, scale=2e-5) - -def rand_uniform(low=0.0, high=1.0): - assert low < high - rand = npr.RandomState(0).rand - post = lambda x: x * (high - low) + low - return partial(_rand_dtype, rand, post=post) - - -def rand_some_equal(): - randn = npr.RandomState(0).randn - rng = npr.RandomState(0) - - def post(x): - x_ravel = x.ravel() - if len(x_ravel) == 0: - return x - flips = rng.rand(*onp.shape(x)) < 0.5 - return onp.where(flips, x_ravel[0], x) - - return partial(_rand_dtype, randn, scale=100., post=post) - - -def rand_some_inf(): - """Return a random sampler that produces infinities in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - """ - TODO: Complex numbers are not correctly tested - If blocks should be switched in order, and relevant tests should be fixed - """ - def rand(shape, dtype): - """The random sampler function.""" - if not onp.issubdtype(dtype, onp.floating): - # only 
float types have inf - return base_rand(shape, dtype) - - if onp.issubdtype(dtype, onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - dims = _dims_of_shape(shape) - posinf_flips = rng.rand(*dims) < 0.1 - neginf_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) - vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -def rand_some_nan(): - """Return a random sampler that produces nans in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - def rand(shape, dtype): - """The random sampler function.""" - if onp.issubdtype(dtype, onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - if not onp.issubdtype(dtype, onp.floating): - # only float types have inf - return base_rand(shape, dtype) - - dims = _dims_of_shape(shape) - nan_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -def rand_some_inf_and_nan(): - """Return a random sampler that produces infinities in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - """ - TODO: Complex numbers are not correctly tested - If blocks should be switched in order, and relevant tests should be fixed - """ - def rand(shape, dtype): - """The random sampler function.""" - if not onp.issubdtype(dtype, onp.floating): - # only float types have inf - return base_rand(shape, dtype) - - if onp.issubdtype(dtype, 
onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - dims = _dims_of_shape(shape) - posinf_flips = rng.rand(*dims) < 0.1 - neginf_flips = rng.rand(*dims) < 0.1 - nan_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) - vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) - vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -# TODO(mattjj): doesn't handle complex types -def rand_some_zero(): - """Return a random sampler that produces some zeros.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - def rand(shape, dtype): - """The random sampler function.""" - dims = _dims_of_shape(shape) - zeros = rng.rand(*dims) < 0.5 - - vals = base_rand(shape, dtype) - vals = onp.where(zeros, onp.array(0, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - - -def rand_int(low, high=None): - randint = npr.RandomState(0).randint - def fn(shape, dtype): - return randint(low, high=high, size=shape, dtype=dtype) - return fn - -def rand_unique_int(): - randchoice = npr.RandomState(0).choice - def fn(shape, dtype): - return randchoice(onp.arange(onp.prod(shape), dtype=dtype), - size=shape, replace=False) - return fn - -def rand_bool(): - rng = npr.RandomState(0) - def generator(shape, dtype): - return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype) - return generator - -def check_raises(thunk, err_type, msg): - try: - thunk() - assert False - except err_type as e: - assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg) - -def check_raises_regexp(thunk, err_type, pattern): - try: - thunk() - assert False - except err_type as e: - 
assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern) - - -def _iter_eqns(jaxpr): - # TODO(necula): why doesn't this search in params? - for eqn in jaxpr.eqns: - yield eqn - for subjaxpr in core.subjaxprs(jaxpr): - yield from _iter_eqns(subjaxpr) - -def assert_dot_precision(expected_precision, fun, *args): - jaxpr = api.make_jaxpr(fun)(*args) - precisions = [eqn.params['precision'] for eqn in _iter_eqns(jaxpr.jaxpr) - if eqn.primitive == lax.dot_general_p] - for precision in precisions: - msg = "Unexpected precision: {} != {}".format(expected_precision, precision) - assert precision == expected_precision, msg - - -_CACHED_INDICES: Dict[int, Sequence[int]] = {} - -def cases_from_list(xs): - xs = list(xs) - n = len(xs) - k = min(n, FLAGS.num_generated_cases) - # Random sampling for every parameterized test is expensive. Do it once and - # cache the result. - indices = _CACHED_INDICES.get(n) - if indices is None: - rng = npr.RandomState(42) - _CACHED_INDICES[n] = indices = rng.permutation(n) - return [xs[i] for i in indices[:k]] - -def cases_from_gens(*gens): - sizes = [1, 3, 10] - cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1 - for size in sizes: - for i in range(cases_per_size): - yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens) - - -def to_np(a): - return tf.nest.map_structure(tf_np.asarray, a) - - -def to_tf_fn(f): - return lambda *args: f(*to_np(args)) - - -class TestCase(parameterized.TestCase): - """Base class for tests including numerical checks and boilerplate.""" - - # copied from jax.test_util - def setUp(self): - super().setUp() - self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - - # copied from jax.test_util - def rng(self): - return self._rng - - # TODO(mattjj): this obscures the error messages from failures, figure out how - # to re-enable it - # def tearDown(self) -> None: - # assert core.reset_trace_state() - - def assertArraysAllClose(self, x, y, check_dtypes, atol=None, 
rtol=None): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) - - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol) - - if check_dtypes: - self.assertDtypesMatch(x, y) - - def assertDtypesMatch(self, x, y): - if FLAGS.enable_x64: - self.assertEqual(_dtype(x), _dtype(y)) - - def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x: - self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol) - elif hasattr(x, '__array__') or onp.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or onp.isscalar(y)) - if check_dtypes: - self.assertDtypesMatch(x, y) - x = onp.asarray(x) - y = onp.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol) - elif x == y: - return - else: - raise TypeError((type(x), type(y))) - - def assertMultiLineStrippedEqual(self, expected, what): - """Asserts two strings are equal, after stripping each line.""" - ignore_space_re = re.compile(r'\s*\n\s*') - expected_clean = re.sub(ignore_space_re, '\n', expected.strip()) - what_clean = re.sub(ignore_space_re, '\n', what.strip()) - self.assertMultiLineEqual(expected_clean, what_clean, - msg="Found\n{}\nExpecting\n{}".format(what, expected)) - - def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, - check_dtypes=True, tol=None): - args = args_maker() - lax_ans = lax_op(*args) - numpy_ans = 
numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - def _CompileAndCheck(self, - fun, - args_maker, - check_dtypes=True, - rtol=None, - atol=None, - check_eval_on_shapes=True, - check_incomplete_shape=True, - check_unknown_rank=True, - static_argnums=(), - check_experimental_compile=True, - check_xla_forced_compile=True): - """Compiles the function and checks the results. - - Args: - fun: the function to be checked. - args_maker: a callable that returns a tuple which will be used as the - positional arguments. - check_dtypes: whether to check that the result dtypes from non-compiled - and compiled runs agree. - rtol: relative tolerance for allclose assertions. - atol: absolute tolerance for allclose assertions. - check_eval_on_shapes: whether to run `eval_on_shapes` on the function and - check that the result shapes and dtypes are correct. - check_incomplete_shape: whether to check that the function can handle - incomplete shapes (including those with and without a known rank). - check_unknown_rank: (only has effect when check_incomplete_shape is True) - whether to check that the function can handle unknown ranks. - static_argnums: indices of arguments to be treated as static arguments for - `jit` and `eval_on_shapes`. - check_experimental_compile: whether to check compilation with - experimental_compile=True (in addition to compilation without the flag). - check_xla_forced_compile: whether to check compilation with - forced_compile=True (in addition to compilation without the flag). This - flag is different from experimental_compile because it enforces - whole-function compilation while the latter doesn't. TPU requires - whole-function compilation. - """ - args = args_maker() - - for x in args: - if not hasattr(x, 'dtype'): - # If there is a input that doesn't have dtype info, jit and - # eval_on_shapes may pick a different dtype for it than numpy, so we - # skip the dtype check. 
- check_dtypes = False - - python_ans = fun(*args) - - python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans) - onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)), - python_ans) - self.assertEqual(python_shapes, onp_shapes) - - def check_compile(**kwargs): - # `wrapped_fun` and `python_should_be_executing` are used to check that - # when the jitted function is called the second time, the original Python - # function won't be executed. - def wrapped_fun(*args): - self.assertTrue(python_should_be_executing) - return fun(*args) - - cfun = npe.jit(wrapped_fun, static_argnums=static_argnums, **kwargs) - python_should_be_executing = True - monitored_ans = cfun(*args) - - python_should_be_executing = False - compiled_ans = cfun(*args) - - self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - # Run `cfun` with a different set of arguments to check that changing - # arguments won't cause recompilation. - - new_args = args_maker() - - skip_retracing_test = False - for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)): - if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new): - # If the old and new arguments result in different dtypes (because - # they fall into different value ranges), tf-numpy will retrace, so we - # skip the no-retrace test. - skip_retracing_test = True - - if not skip_retracing_test: - python_should_be_executing = True - new_python_ans = fun(*new_args) - python_should_be_executing = False - compiled_ans = cfun(*new_args) - self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol, - rtol) - - check_compile() - if check_experimental_compile: - check_compile(experimental_compile=True) - if check_xla_forced_compile: - check_compile(xla_forced_compile=True) - - if check_eval_on_shapes: - # Check that npe.eval_on_shapes can get complete output shapes given - # complete input shapes. 
- cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums) - compiled_ans = cfun(*args) - flat_python_ans = tf.nest.flatten(python_ans) - flat_compiled_ans = tf.nest.flatten(compiled_ans) - self.assertEqual(len(flat_python_ans), len(flat_compiled_ans)) - for a, b in zip(flat_python_ans, flat_compiled_ans): - if hasattr(a, 'shape'): - self.assertEqual(a.shape, b.shape) - if check_dtypes and hasattr(a, 'dtype'): - self.assertEqual(tf.as_dtype(a.dtype), b.dtype) - - # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we - # skip incomplete-shape checks, since shape specs need dtype. It's OK to - # skip since the same incomplete-shape checks will run for []-shaped arrays. - if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args): - # Check partial shapes with known ranks. - # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not - # `shape`. - if all(hasattr(x, 'shape') for x in args): - specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args] - cfun = npe.jit( - fun, static_argnums=static_argnums, input_signature=specs) - compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - if check_unknown_rank: - # Check unknown ranks. - specs = [tf.TensorSpec(None, x.dtype) for x in args] - cfun = npe.jit( - fun, static_argnums=static_argnums, input_signature=specs) - compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - def check_grads(self, f, args, atol=None, rtol=None, delta=None): - """Check gradients against finite differences. - - Args: - f: function to check at ``f(*args)``. - args: a list or tuple of argument values. - atol: absolute tolerance for gradient equality. - rtol: relative tolerance for gradient equality. - delta: step size used for finite differences. - """ - if delta is None: - # Optimal stepsize for central difference is O(epsilon^{1/3}). 
- dtype = tf_np.result_type(*args) - epsilon = onp.finfo(dtype).eps - delta = epsilon ** (1.0 / 3.0) - theoretical, numerical = tf.test.compute_gradient( - to_tf_fn(f), args, delta=delta) - self.assertAllClose(theoretical, numerical, check_dtypes=False, atol=atol, - rtol=rtol) - - -@contextmanager -def ignore_warning(**kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) - yield - - -def disable(_): - - def wrapper(self, *args, **kwargs): - self.skipTest('Test is disabled') - - return wrapper diff --git a/trax/tf_numpy/jax_tests/vmap_test.py b/trax/tf_numpy/jax_tests/vmap_test.py deleted file mode 100644 index b35f78808..000000000 --- a/trax/tf_numpy/jax_tests/vmap_test.py +++ /dev/null @@ -1,167 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from absl.testing import parameterized - -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import extensions -import trax.tf_numpy.numpy as tf_np - -from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=g-direct-tensorflow-import - - -class VmapTest(tf.test.TestCase, parameterized.TestCase): - - def test_vmap_in_axes_list(self): - # https://github.com/google/jax/issues/2367 - dictionary = {'a': 5., 'b': tf_np.ones(2)} - x = tf_np.zeros(3) - y = tf_np.arange(3.) 
- - def f(dct, x, y): - return dct['a'] + dct['b'] + x + y - - out1 = extensions.vmap(f, (None, 0, 0))(dictionary, x, y) - out2 = extensions.vmap(f, [None, 0, 0])(dictionary, x, y) - self.assertAllClose(out1, out2) - - def test_vmap_in_axes_tree_prefix_error(self): - # https://github.com/google/jax/issues/795 - self.assertRaisesRegex( - ValueError, - 'vmap in_axes specification must be a tree prefix of the corresponding ' - r'value, got specification \(0, 0\) for value tree ', - lambda: extensions.vmap(lambda x: x, in_axes=(0, 0))(tf_np.ones(3))) - - def test_vmap_in_axes_leaf_types(self): - with self.assertRaisesRegex(TypeError, - r'vmap in_axes must be an int, None, or .*'): - extensions.vmap( - lambda x: x, in_axes=(tf_np.array([1., 2.]),))( - tf_np.array([1., 2.])) - - def test_vmap_out_axes_leaf_types(self): - with self.assertRaisesRegex(TypeError, - r'vmap out_axes must be an int, None, or .*'): - extensions.vmap( - lambda x: x, out_axes=(tf_np.array([1., 2.]),))( - tf_np.array([1., 2.])) - - def test_vmap_unbatched_object_passthrough_issue_183(self): - # https://github.com/google/jax/issues/183 - fun = lambda f, x: f(x) - vfun = extensions.vmap(fun, (None, 0)) - ans = vfun(lambda x: x + 1, tf_np.arange(3)) - self.assertAllClose(ans, np.arange(1, 4)) - - def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): - # https://github.com/google/jax/issues/705 - with self.assertRaisesRegex( - ValueError, 'vmap must have at least one non-None value in in_axes'): - # If the output is mapped, there must be a non-None in_axes - extensions.vmap(lambda x: x, in_axes=None)(tf_np.array([1., 2.])) - - # Error is: TypeError: only integer scalar arrays can be converted to a - # scalar index - with self.assertRaisesRegex( - ValueError, 'vmap out_axes specification must be a tree prefix of the ' - 'corresponding value.*'): - extensions.vmap( - lambda x: x, in_axes=0, out_axes=(2, 3))( - tf_np.array([1., 2.])) - - def test_vmap_structured_in_axes(self): - a, b, c, d = 
2, 3, 4, 5 - k = 6 # batch size - x = np.ones((k, a, b)) # batch axis in different locations - y = np.ones((b, k, c)) - z = np.ones((c, d, k)) - - def foo(tree_arg): - x, (y, z) = tree_arg - return tf_np.dot(x, tf_np.dot(y, z)) - - tree = (x, (y, z)) - vfoo = extensions.vmap(foo, in_axes=((0, (1, 2)),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - Point = collections.namedtuple('Point', ['x', 'y']) - tree = (x, Point(y, z)) - vfoo = extensions.vmap(foo, in_axes=((0, Point(1, 2)),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - def foo2(tree_arg): - x, dct = tree_arg - y, z = dct['a'], dct['b'] - return tf_np.dot(x, tf_np.dot(y, z)) - - tree = (x, {'a': y, 'b': z}) - vfoo = extensions.vmap(foo2, in_axes=((0, {'a': 1, 'b': 2}),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - tree = (x, collections.OrderedDict([('a', y), ('b', z)])) - vfoo = extensions.vmap( - foo2, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - def test_vmap_out_axes(self): - f = extensions.vmap(lambda x: x, out_axes=0) - inp = tf_np.arange(6).reshape([2, 3]) - self.assertAllClose(inp, f(inp)) - self.assertAllClose([inp, inp], f((inp, inp))) - - f = extensions.vmap(lambda x: x, out_axes=-1) - self.assertAllClose(inp.T, f(inp)) - - f = extensions.vmap(lambda x: x, out_axes=None) - self.assertAllClose(inp[0], f(inp)) - - f = extensions.vmap(lambda x: x, out_axes=([0], (-1, None), {'a': 1})) - a, b, c = f(([inp], (inp, inp), {'a': inp})) - self.assertAllClose([inp], a) - self.assertAllClose((inp.T, inp[0]), b) - self.assertAllClose(inp.T, c['a']) - - def test_negative_axes(self): - x = np.arange(3 * 4 * 5).reshape(3, 4, 5) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-3)(x), tf_np.sum(x, axis=(1, 2))) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-2)(x), tf_np.sum(x, axis=(0, 2))) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-1)(x), tf_np.sum(x, axis=(0, 1))) - - identity = 
lambda y: y - self.assertAllClose(x, extensions.vmap(identity, in_axes=0, out_axes=-3)(x)) - self.assertAllClose( - x.transpose(1, 0, 2), - extensions.vmap(identity, in_axes=0, out_axes=-2)(x)) - self.assertAllClose( - x.transpose(1, 2, 0), - extensions.vmap(identity, in_axes=0, out_axes=-1)(x)) - - self.assertAllClose( - np.full((5,), 7), - extensions.vmap(lambda *xs: xs, in_axes=(0, None), - out_axes=(0, -1))(np.arange(5), 7)[1]) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - np_math_ops.enable_numpy_methods_on_tensor() - tf.test.main() diff --git a/trax/tf_numpy/numpy/__init__.py b/trax/tf_numpy/numpy/__init__.py index 2877116d7..49672a297 100644 --- a/trax/tf_numpy/numpy/__init__.py +++ b/trax/tf_numpy/numpy/__init__.py @@ -20,49 +20,51 @@ # pylint: disable=g-direct-tensorflow-import try: - # Note that this import will work in tf-nightly and TF versions 2.4 and - # higher. - from tensorflow.experimental.numpy import * - # TODO(agarwal): get rid of following imports. - from tensorflow.experimental.numpy import random - from tensorflow import bfloat16 - import numpy as onp - from tensorflow.python.ops.numpy_ops.np_dtypes import canonicalize_dtype - from tensorflow.python.ops.numpy_ops.np_dtypes import default_float_type - from tensorflow.python.ops.numpy_ops.np_dtypes import is_allow_float64 - from tensorflow.python.ops.numpy_ops.np_dtypes import set_allow_float64 + # Note that this import will work in tf-nightly and TF versions 2.4 and higher. + from tensorflow.experimental.numpy import * - random.DEFAULT_RANDN_DTYPE = onp.float32 -except ImportError: - try: - # Note that this import will work in TF 2.3 and higher. - from tensorflow.python.ops.numpy_ops import * + # TODO(agarwal): get rid of following imports. 
+ from tensorflow.experimental.numpy import random from tensorflow import bfloat16 + import numpy as onp + from tensorflow.python.ops.numpy_ops.np_dtypes import canonicalize_dtype + from tensorflow.python.ops.numpy_ops.np_dtypes import default_float_type + from tensorflow.python.ops.numpy_ops.np_dtypes import is_allow_float64 + from tensorflow.python.ops.numpy_ops.np_dtypes import set_allow_float64 + + random.DEFAULT_RANDN_DTYPE = onp.float32 +except ImportError: + try: + # Note that this import will work in TF 2.3 and higher. + from tensorflow.python.ops.numpy_ops import * + from tensorflow import bfloat16 - except ImportError: - # Note that this fallback will be needed for TF 2.2. - from tensorflow import newaxis + except ImportError: + # Note that this fallback will be needed for TF 2.2. + from tensorflow import newaxis - from trax.tf_numpy.numpy_impl import random + from trax.tf_numpy.numpy_impl import random - # pylint: disable=wildcard-import - from trax.tf_numpy.numpy_impl.array_ops import * - from trax.tf_numpy.numpy_impl.arrays import * - from trax.tf_numpy.numpy_impl.dtypes import * - from trax.tf_numpy.numpy_impl.math_ops import * - from trax.tf_numpy.numpy_impl.utils import finfo - from trax.tf_numpy.numpy_impl.utils import promote_types - from trax.tf_numpy.numpy_impl.utils import result_type - # pylint: enable=wildcard-import + # pylint: disable=wildcard-import + from trax.tf_numpy.numpy_impl.array_ops import * + from trax.tf_numpy.numpy_impl.arrays import * + from trax.tf_numpy.numpy_impl.dtypes import * + from trax.tf_numpy.numpy_impl.math_ops import * + from trax.tf_numpy.numpy_impl.utils import finfo + from trax.tf_numpy.numpy_impl.utils import promote_types + from trax.tf_numpy.numpy_impl.utils import result_type - max = amax # pylint: disable=redefined-builtin,undefined-variable - min = amin # pylint: disable=redefined-builtin,undefined-variable - round = around # pylint: disable=redefined-builtin,undefined-variable + # pylint: 
enable=wildcard-import + + max = amax # pylint: disable=redefined-builtin,undefined-variable + min = amin # pylint: disable=redefined-builtin,undefined-variable + round = around # pylint: disable=redefined-builtin,undefined-variable try: - from tensorflow.python.ops.numpy_ops.np_config import enable_numpy_behavior - # TODO(b/171429739): This should be moved to every individual file/test. - enable_numpy_behavior() + from tensorflow.python.ops.numpy_ops.np_config import enable_numpy_behavior + + # TODO(b/171429739): This should be moved to every individual file/test. + enable_numpy_behavior() except ImportError: - pass + pass diff --git a/trax/tf_numpy/numpy_impl/array_ops.py b/trax/tf_numpy/numpy_impl/array_ops.py index c47b827b3..60cd22a36 100644 --- a/trax/tf_numpy/numpy_impl/array_ops.py +++ b/trax/tf_numpy/numpy_impl/array_ops.py @@ -26,1199 +26,1280 @@ def empty(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an empty array with the specified shape and dtype. + """Returns an empty array with the specified shape and dtype. - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. + Args: + shape: A fully defined shape. Could be - NumPy array or a python scalar, + list or tuple of integers, - TensorFlow tensor/ndarray of integer type and + rank <=1. + dtype: Optional, defaults to float. The type of the resulting ndarray. Could + be a python type, a NumPy type or a TensorFlow `DType`. - Returns: - An ndarray. - """ - return zeros(shape, dtype) + Returns: + An ndarray. + """ + return zeros(shape, dtype) def empty_like(a, dtype=None): - """Returns an empty array with the shape and possibly type of the input array. + """Returns an empty array with the shape and possibly type of the input array. 
- Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can be + converted to a Tensor using `tf.convert_to_tensor`. + dtype: Optional, defaults to dtype of the input array. The type of the + resulting ndarray. Could be a python type, a NumPy type or a TensorFlow + `DType`. - Returns: - An ndarray. - """ - return zeros_like(a, dtype) + Returns: + An ndarray. + """ + return zeros_like(a, dtype) def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an ndarray with the given shape and type filled with zeros. + """Returns an ndarray with the given shape and type filled with zeros. - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. + Args: + shape: A fully defined shape. Could be - NumPy array or a python scalar, + list or tuple of integers, - TensorFlow tensor/ndarray of integer type and + rank <=1. + dtype: Optional, defaults to float. The type of the resulting ndarray. Could + be a python type, a NumPy type or a TensorFlow `DType`. - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(shape, arrays_lib.ndarray): - shape = shape.data - return arrays_lib.tensor_to_ndarray(tf.zeros(shape, dtype=dtype)) + Returns: + An ndarray. 
+ """ + if dtype: + dtype = utils.result_type(dtype) + if isinstance(shape, arrays_lib.ndarray): + shape = shape.data + return arrays_lib.tensor_to_ndarray(tf.zeros(shape, dtype=dtype)) def zeros_like(a, dtype=None): - """Returns an array of zeros with the shape and type of the input array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. - """ - if isinstance(a, arrays_lib.ndarray): - a = a.data - if dtype is None: - # We need to let utils.result_type decide the dtype, not tf.zeros_like - dtype = utils.result_type(a) - else: - # TF and numpy has different interpretations of Python types such as - # `float`, so we let `utils.result_type` decide. - dtype = utils.result_type(dtype) - dtype = tf.as_dtype(dtype) # Work around b/149877262 - return arrays_lib.tensor_to_ndarray(tf.zeros_like(a, dtype)) + """Returns an array of zeros with the shape and type of the input array. + + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can be + converted to a Tensor using `tf.convert_to_tensor`. + dtype: Optional, defaults to dtype of the input array. The type of the + resulting ndarray. Could be a python type, a NumPy type or a TensorFlow + `DType`. + + Returns: + An ndarray. + """ + if isinstance(a, arrays_lib.ndarray): + a = a.data + if dtype is None: + # We need to let utils.result_type decide the dtype, not tf.zeros_like + dtype = utils.result_type(a) + else: + # TF and numpy has different interpretations of Python types such as + # `float`, so we let `utils.result_type` decide. 
+ dtype = utils.result_type(dtype) + dtype = tf.as_dtype(dtype) # Work around b/149877262 + return arrays_lib.tensor_to_ndarray(tf.zeros_like(a, dtype)) def ones(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an ndarray with the given shape and type filled with ones. + """Returns an ndarray with the given shape and type filled with ones. - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. + Args: + shape: A fully defined shape. Could be - NumPy array or a python scalar, + list or tuple of integers, - TensorFlow tensor/ndarray of integer type and + rank <=1. + dtype: Optional, defaults to float. The type of the resulting ndarray. Could + be a python type, a NumPy type or a TensorFlow `DType`. - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(shape, arrays_lib.ndarray): - shape = shape.data - return arrays_lib.tensor_to_ndarray(tf.ones(shape, dtype=dtype)) + Returns: + An ndarray. + """ + if dtype: + dtype = utils.result_type(dtype) + if isinstance(shape, arrays_lib.ndarray): + shape = shape.data + return arrays_lib.tensor_to_ndarray(tf.ones(shape, dtype=dtype)) def ones_like(a, dtype=None): - """Returns an array of ones with the shape and type of the input array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. 
- """ - if isinstance(a, arrays_lib.ndarray): - a = a.data - if dtype is None: - dtype = utils.result_type(a) - else: - dtype = utils.result_type(dtype) - return arrays_lib.tensor_to_ndarray(tf.ones_like(a, dtype)) + """Returns an array of ones with the shape and type of the input array. + + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can be + converted to a Tensor using `tf.convert_to_tensor`. + dtype: Optional, defaults to dtype of the input array. The type of the + resulting ndarray. Could be a python type, a NumPy type or a TensorFlow + `DType`. + + Returns: + An ndarray. + """ + if isinstance(a, arrays_lib.ndarray): + a = a.data + if dtype is None: + dtype = utils.result_type(a) + else: + dtype = utils.result_type(dtype) + return arrays_lib.tensor_to_ndarray(tf.ones_like(a, dtype)) @utils.np_doc(np.eye) def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-docstring - if dtype: - dtype = utils.result_type(dtype) - if not M: - M = N - # Making sure N, M and k are `int` - N = int(N) - M = int(M) - k = int(k) - if k >= M or -k >= N: - # tf.linalg.diag will raise an error in this case - return zeros([N, M], dtype=dtype) - if k == 0: - return arrays_lib.tensor_to_ndarray(tf.eye(N, M, dtype=dtype)) - # We need the precise length, otherwise tf.linalg.diag will raise an error - diag_len = min(N, M) - if k > 0: - if N >= M: - diag_len -= k - elif N + k > M: - diag_len = M - k - elif k <= 0: - if M >= N: - diag_len += k - elif M - k > N: - diag_len = N + k - diagonal_ = tf.ones([diag_len], dtype=dtype) - return arrays_lib.tensor_to_ndarray( - tf.linalg.diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)) + if dtype: + dtype = utils.result_type(dtype) + if not M: + M = N + # Making sure N, M and k are `int` + N = int(N) + M = int(M) + k = int(k) + if k >= M or -k >= N: + # tf.linalg.diag will raise an error in this case + return zeros([N, M], dtype=dtype) + if k == 0: + return arrays_lib.tensor_to_ndarray(tf.eye(N, M, 
dtype=dtype)) + # We need the precise length, otherwise tf.linalg.diag will raise an error + diag_len = min(N, M) + if k > 0: + if N >= M: + diag_len -= k + elif N + k > M: + diag_len = M - k + elif k <= 0: + if M >= N: + diag_len += k + elif M - k > N: + diag_len = N + k + diagonal_ = tf.ones([diag_len], dtype=dtype) + return arrays_lib.tensor_to_ndarray( + tf.linalg.diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k) + ) def identity(n, dtype=float): - """Returns a square array with ones on the main diagonal and zeros elsewhere. + """Returns a square array with ones on the main diagonal and zeros elsewhere. - Args: - n: number of rows/cols. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. + Args: + n: number of rows/cols. + dtype: Optional, defaults to float. The type of the resulting ndarray. Could + be a python type, a NumPy type or a TensorFlow `DType`. - Returns: - An ndarray of shape (n, n) and requested type. - """ - return eye(N=n, M=n, dtype=dtype) + Returns: + An ndarray of shape (n, n) and requested type. + """ + return eye(N=n, M=n, dtype=dtype) def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name - """Returns an array with given shape and dtype filled with `fill_value`. + """Returns an array with given shape and dtype filled with `fill_value`. - Args: - shape: A valid shape object. Could be a native python object or an object - of type ndarray, numpy.ndarray or tf.TensorShape. - fill_value: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the `fill_value`. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. + Args: + shape: A valid shape object. Could be a native python object or an object + of type ndarray, numpy.ndarray or tf.TensorShape. + fill_value: array_like. 
Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + dtype: Optional, defaults to dtype of the `fill_value`. The type of the + resulting ndarray. Could be a python type, a NumPy type or a TensorFlow + `DType`. - Returns: - An ndarray. + Returns: + An ndarray. - Raises: - ValueError: if `fill_value` can not be broadcast to shape `shape`. - """ - fill_value = asarray(fill_value, dtype=dtype) - if utils.isscalar(shape): - shape = tf.reshape(shape, [1]) - return arrays_lib.tensor_to_ndarray(tf.broadcast_to(fill_value.data, shape)) + Raises: + ValueError: if `fill_value` can not be broadcast to shape `shape`. + """ + fill_value = asarray(fill_value, dtype=dtype) + if utils.isscalar(shape): + shape = tf.reshape(shape, [1]) + return arrays_lib.tensor_to_ndarray(tf.broadcast_to(fill_value.data, shape)) # Using doc only here since np full_like signature doesn't seem to have the # shape argument (even though it exists in the documentation online). 
@utils.np_doc_only(np.full_like) -def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): # pylint: disable=missing-docstring,redefined-outer-name - """order, subok and shape arguments mustn't be changed.""" - if order != 'K': - raise ValueError('Non-standard orders are not supported.') - if not subok: - raise ValueError('subok being False is not supported.') - if shape: - raise ValueError('Overriding the shape is not supported.') - - a = asarray(a).data - dtype = dtype or utils.result_type(a) - fill_value = asarray(fill_value, dtype=dtype) - return arrays_lib.tensor_to_ndarray( - tf.broadcast_to(fill_value.data, tf.shape(a))) +def full_like( + a, fill_value, dtype=None, order="K", subok=True, shape=None +): # pylint: disable=missing-docstring,redefined-outer-name + """order, subok and shape arguments mustn't be changed.""" + if order != "K": + raise ValueError("Non-standard orders are not supported.") + if not subok: + raise ValueError("subok being False is not supported.") + if shape: + raise ValueError("Overriding the shape is not supported.") + + a = asarray(a).data + dtype = dtype or utils.result_type(a) + fill_value = asarray(fill_value, dtype=dtype) + return arrays_lib.tensor_to_ndarray(tf.broadcast_to(fill_value.data, tf.shape(a))) # TODO(wangpeng): investigate whether we can make `copy` default to False. # TODO(wangpeng): utils.np_doc can't handle np.array because np.array is a # builtin function. Make utils.np_doc support builtin functions. def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name - """Creates an ndarray with the contents of val. - - Args: - val: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the `val`. The type of the resulting - ndarray. Could be a python type, a NumPy type or a TensorFlow `DType`. - copy: Determines whether to create a copy of the backing buffer. 
Since - Tensors are immutable, a copy is made only if val is placed on a different - device than the current one. Even if `copy` is False, a new Tensor may - need to be built to satisfy `dtype` and `ndim`. This is used only if `val` - is an ndarray or a Tensor. - ndmin: The minimum rank of the returned array. - - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(val, arrays_lib.ndarray): - result_t = val.data - else: - result_t = val - - if copy and isinstance(result_t, tf.Tensor): - # Note: In eager mode, a copy of `result_t` is made only if it is not on - # the context device. - result_t = tf.identity(result_t) - - if not isinstance(result_t, tf.Tensor): - if not dtype: - dtype = utils.result_type(result_t) - # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because - # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) - # while np.array allows them. We need to convert-then-cast. - def maybe_data(x): - if isinstance(x, arrays_lib.ndarray): - return x.data - return x - - # Handles lists of ndarrays - result_t = tf.nest.map_structure(maybe_data, result_t) - result_t = arrays_lib.convert_to_tensor(result_t) - result_t = tf.cast(result_t, dtype=dtype) - elif dtype: - result_t = tf.cast(result_t, dtype) - ndims = tf.rank(result_t) - - def true_fn(): - old_shape = tf.shape(result_t) - new_shape = tf.concat([tf.ones(ndmin - ndims, tf.int32), old_shape], axis=0) - return tf.reshape(result_t, new_shape) - - result_t = utils.cond(utils.greater(ndmin, ndims), true_fn, lambda: result_t) - return arrays_lib.tensor_to_ndarray(result_t) + """Creates an ndarray with the contents of val. + + Args: + val: array_like. Could be an ndarray, a Tensor or any object that can be + converted to a Tensor using `tf.convert_to_tensor`. + dtype: Optional, defaults to dtype of the `val`. The type of the resulting + ndarray. Could be a python type, a NumPy type or a TensorFlow `DType`. 
+ copy: Determines whether to create a copy of the backing buffer. Since + Tensors are immutable, a copy is made only if val is placed on a different + device than the current one. Even if `copy` is False, a new Tensor may + need to be built to satisfy `dtype` and `ndim`. This is used only if `val` + is an ndarray or a Tensor. + ndmin: The minimum rank of the returned array. + + Returns: + An ndarray. + """ + if dtype: + dtype = utils.result_type(dtype) + if isinstance(val, arrays_lib.ndarray): + result_t = val.data + else: + result_t = val + + if copy and isinstance(result_t, tf.Tensor): + # Note: In eager mode, a copy of `result_t` is made only if it is not on + # the context device. + result_t = tf.identity(result_t) + + if not isinstance(result_t, tf.Tensor): + if not dtype: + dtype = utils.result_type(result_t) + # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because + # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) + # while np.array allows them. We need to convert-then-cast. 
+ def maybe_data(x): + if isinstance(x, arrays_lib.ndarray): + return x.data + return x + + # Handles lists of ndarrays + result_t = tf.nest.map_structure(maybe_data, result_t) + result_t = arrays_lib.convert_to_tensor(result_t) + result_t = tf.cast(result_t, dtype=dtype) + elif dtype: + result_t = tf.cast(result_t, dtype) + ndims = tf.rank(result_t) + + def true_fn(): + old_shape = tf.shape(result_t) + new_shape = tf.concat([tf.ones(ndmin - ndims, tf.int32), old_shape], axis=0) + return tf.reshape(result_t, new_shape) + + result_t = utils.cond(utils.greater(ndmin, ndims), true_fn, lambda: result_t) + return arrays_lib.tensor_to_ndarray(result_t) @utils.np_doc(np.asarray) def asarray(a, dtype=None): - if dtype: - dtype = utils.result_type(dtype) - if isinstance(a, arrays_lib.ndarray) and (not dtype or dtype == a.dtype): - return a - return array(a, dtype, copy=False) + if dtype: + dtype = utils.result_type(dtype) + if isinstance(a, arrays_lib.ndarray) and (not dtype or dtype == a.dtype): + return a + return array(a, dtype, copy=False) @utils.np_doc(np.asanyarray) def asanyarray(a, dtype=None): - return asarray(a, dtype) + return asarray(a, dtype) @utils.np_doc(np.ascontiguousarray) def ascontiguousarray(a, dtype=None): - return array(a, dtype, ndmin=1) + return array(a, dtype, ndmin=1) # Numerical ranges. def arange(start, stop=None, step=1, dtype=None): - """Returns `step`-separated values in the range [start, stop). - - Args: - start: Start of the interval. Included in the range. - stop: End of the interval. If not specified, `start` is treated as 0 and - `start` value is used as `stop`. If specified, it is not included in the - range if `step` is integer. When `step` is floating point, it may or may - not be included. - step: The difference between 2 consecutive values in the output range. It is - recommended to use `linspace` instead of using non-integer values for - `step`. - dtype: Optional. Type of the resulting ndarray. 
Could be a python type, a - NumPy type or a TensorFlow `DType`. If not provided, the largest type of - `start`, `stop`, `step` is used. - - Raises: - ValueError: If step is zero. - """ - if not step: - raise ValueError('step must be non-zero.') - if dtype: - dtype = utils.result_type(dtype) - else: - if stop is None: - dtype = utils.result_type(start, step) + """Returns `step`-separated values in the range [start, stop). + + Args: + start: Start of the interval. Included in the range. + stop: End of the interval. If not specified, `start` is treated as 0 and + `start` value is used as `stop`. If specified, it is not included in the + range if `step` is integer. When `step` is floating point, it may or may + not be included. + step: The difference between 2 consecutive values in the output range. It is + recommended to use `linspace` instead of using non-integer values for + `step`. + dtype: Optional. Type of the resulting ndarray. Could be a python type, a + NumPy type or a TensorFlow `DType`. If not provided, the largest type of + `start`, `stop`, `step` is used. + + Raises: + ValueError: If step is zero. + """ + if not step: + raise ValueError("step must be non-zero.") + if dtype: + dtype = utils.result_type(dtype) else: - dtype = utils.result_type(start, step, stop) - if step > 0 and ((stop is not None and start > stop) or - (stop is None and start < 0)): - return array([], dtype=dtype) - if step < 0 and ((stop is not None and start < stop) or - (stop is None and start > 0)): - return array([], dtype=dtype) - # TODO(srbs): There are some bugs when start or stop is float type and dtype - # is integer type. 
- return arrays_lib.tensor_to_ndarray( - tf.cast(tf.range(start, limit=stop, delta=step), dtype=dtype)) + if stop is None: + dtype = utils.result_type(start, step) + else: + dtype = utils.result_type(start, step, stop) + if step > 0 and ( + (stop is not None and start > stop) or (stop is None and start < 0) + ): + return array([], dtype=dtype) + if step < 0 and ( + (stop is not None and start < stop) or (stop is None and start > 0) + ): + return array([], dtype=dtype) + # TODO(srbs): There are some bugs when start or stop is float type and dtype + # is integer type. + return arrays_lib.tensor_to_ndarray( + tf.cast(tf.range(start, limit=stop, delta=step), dtype=dtype) + ) @utils.np_doc(np.geomspace) -def geomspace(start, stop, num=50, endpoint=True, dtype=float): # pylint: disable=missing-docstring - if dtype: - dtype = utils.result_type(dtype) - if num < 0: - raise ValueError('Number of samples {} must be non-negative.'.format(num)) - if not num: - return empty([0]) - step = 1. - if endpoint: - if num > 1: - step = tf.pow((stop / start), 1 / (num - 1)) - else: - step = tf.pow((stop / start), 1 / num) - result = tf.cast(tf.range(num), step.dtype) - result = tf.pow(step, result) - result = tf.multiply(result, start) - if dtype: - result = tf.cast(result, dtype=dtype) - return arrays_lib.tensor_to_ndarray(result) +def geomspace( + start, stop, num=50, endpoint=True, dtype=float +): # pylint: disable=missing-docstring + if dtype: + dtype = utils.result_type(dtype) + if num < 0: + raise ValueError("Number of samples {} must be non-negative.".format(num)) + if not num: + return empty([0]) + step = 1.0 + if endpoint: + if num > 1: + step = tf.pow((stop / start), 1 / (num - 1)) + else: + step = tf.pow((stop / start), 1 / num) + result = tf.cast(tf.range(num), step.dtype) + result = tf.pow(step, result) + result = tf.multiply(result, start) + if dtype: + result = tf.cast(result, dtype=dtype) + return arrays_lib.tensor_to_ndarray(result) # Building matrices. 
@utils.np_doc(np.diag) def diag(v, k=0): # pylint: disable=missing-docstring - """Raises an error if input is not 1- or 2-d.""" - v = asarray(v).data - v_rank = tf.rank(v) - - v.shape.with_rank_at_most(2) - - # TODO(nareshmodi): Consider a utils.Assert version that will fail during - # tracing time if the shape is known. - tf.debugging.Assert( - utils.logical_or(tf.equal(v_rank, 1), tf.equal(v_rank, 2)), [v_rank]) - - def _diag(v, k): - return utils.cond( - tf.equal(tf.size(v), 0), - lambda: tf.zeros([abs(k), abs(k)], dtype=v.dtype), - lambda: tf.linalg.diag(v, k=k)) - - def _diag_part(v, k): - v_shape = tf.shape(v) - v, k = utils.cond( - utils.logical_or( - utils.less_equal(k, -1 * utils.getitem(v_shape, 0)), - utils.greater_equal(k, utils.getitem(v_shape, 1)), - ), lambda: (tf.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k)) - result = tf.linalg.diag_part(v, k=k) - return result - - result = utils.cond( - tf.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k)) - return utils.tensor_to_ndarray(result) + """Raises an error if input is not 1- or 2-d.""" + v = asarray(v).data + v_rank = tf.rank(v) + + v.shape.with_rank_at_most(2) + + # TODO(nareshmodi): Consider a utils.Assert version that will fail during + # tracing time if the shape is known. 
+ tf.debugging.Assert( + utils.logical_or(tf.equal(v_rank, 1), tf.equal(v_rank, 2)), [v_rank] + ) + + def _diag(v, k): + return utils.cond( + tf.equal(tf.size(v), 0), + lambda: tf.zeros([abs(k), abs(k)], dtype=v.dtype), + lambda: tf.linalg.diag(v, k=k), + ) + + def _diag_part(v, k): + v_shape = tf.shape(v) + v, k = utils.cond( + utils.logical_or( + utils.less_equal(k, -1 * utils.getitem(v_shape, 0)), + utils.greater_equal(k, utils.getitem(v_shape, 1)), + ), + lambda: (tf.zeros([0, 0], dtype=v.dtype), 0), + lambda: (v, k), + ) + result = tf.linalg.diag_part(v, k=k) + return result + + result = utils.cond( + tf.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k) + ) + return utils.tensor_to_ndarray(result) @utils.np_doc(np.diagonal) def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring - a = asarray(a).data - - maybe_rank = a.shape.rank - if maybe_rank is not None and offset == 0 and ( - axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or - axis2 == -1): - return utils.tensor_to_ndarray(tf.linalg.diag_part(a)) + a = asarray(a).data - a = moveaxis(utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data + maybe_rank = a.shape.rank + if ( + maybe_rank is not None + and offset == 0 + and (axis1 == maybe_rank - 2 or axis1 == -2) + and (axis2 == maybe_rank - 1 or axis2 == -1) + ): + return utils.tensor_to_ndarray(tf.linalg.diag_part(a)) - a_shape = tf.shape(a) + a = moveaxis(utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data - def _zeros(): # pylint: disable=missing-docstring - return (tf.zeros(tf.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0) + a_shape = tf.shape(a) - # All zeros since diag_part doesn't handle all possible k (aka offset). - # Written this way since cond will run shape inference on both branches, - # and diag_part shape inference will fail when offset is out of bounds. 
- a, offset = utils.cond( - utils.logical_or( - utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)), - utils.greater_equal(offset, utils.getitem(a_shape, -1)), - ), _zeros, lambda: (a, offset)) + def _zeros(): # pylint: disable=missing-docstring + return (tf.zeros(tf.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0) - a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset)) - return a + # All zeros since diag_part doesn't handle all possible k (aka offset). + # Written this way since cond will run shape inference on both branches, + # and diag_part shape inference will fail when offset is out of bounds. + a, offset = utils.cond( + utils.logical_or( + utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)), + utils.greater_equal(offset, utils.getitem(a_shape, -1)), + ), + _zeros, + lambda: (a, offset), + ) + + a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset)) + return a def diagflat(v, k=0): - """Returns a 2-d array with flattened `v` as diagonal. + """Returns a 2-d array with flattened `v` as diagonal. - Args: - v: array_like of any rank. Gets flattened when setting as diagonal. Could be - an ndarray, a Tensor or any object that can be converted to a Tensor using - `tf.convert_to_tensor`. - k: Position of the diagonal. Defaults to 0, the main diagonal. Positive - values refer to diagonals shifted right, negative values refer to - diagonals shifted left. + Args: + v: array_like of any rank. Gets flattened when setting as diagonal. Could be + an ndarray, a Tensor or any object that can be converted to a Tensor using + `tf.convert_to_tensor`. + k: Position of the diagonal. Defaults to 0, the main diagonal. Positive + values refer to diagonals shifted right, negative values refer to + diagonals shifted left. - Returns: - 2-d ndarray. - """ - v = asarray(v) - return diag(tf.reshape(v.data, [-1]), k) + Returns: + 2-d ndarray. 
+ """ + v = asarray(v) + return diag(tf.reshape(v.data, [-1]), k) def _promote_dtype(*arrays): - dtype = utils.result_type(*arrays) - return [asarray(a, dtype=dtype) for a in arrays] + dtype = utils.result_type(*arrays) + return [asarray(a, dtype=dtype) for a in arrays] def all(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin - """Whether all array elements or those along an axis evaluate to true. - - Casts the array to bool type if it is not already and uses `tf.reduce_all` to - compute the result. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. Could be an int or a tuple of integers. If not specified, - the reduction is performed over all array indices. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - An ndarray. Note that unlike NumPy this does not return a scalar bool if - `axis` is None. - """ - a = asarray(a, dtype=bool) - return utils.tensor_to_ndarray( - tf.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims)) + """Whether all array elements or those along an axis evaluate to true. + + Casts the array to bool type if it is not already and uses `tf.reduce_all` to + compute the result. + + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axis: Optional. Could be an int or a tuple of integers. If not specified, + the reduction is performed over all array indices. + keepdims: If true, retains reduced dimensions with length 1. + + Returns: + An ndarray. Note that unlike NumPy this does not return a scalar bool if + `axis` is None. + """ + a = asarray(a, dtype=bool) + return utils.tensor_to_ndarray( + tf.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims) + ) def any(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin - """Whether any element in the entire array or in an axis evaluates to true. 
- - Casts the array to bool type if it is not already and uses `tf.reduce_any` to - compute the result. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. Could be an int or a tuple of integers. If not specified, - the reduction is performed over all array indices. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - An ndarray. Note that unlike NumPy this does not return a scalar bool if - `axis` is None. - """ - a = asarray(a, dtype=bool) - return utils.tensor_to_ndarray( - tf.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims)) + """Whether any element in the entire array or in an axis evaluates to true. + + Casts the array to bool type if it is not already and uses `tf.reduce_any` to + compute the result. + + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axis: Optional. Could be an int or a tuple of integers. If not specified, + the reduction is performed over all array indices. + keepdims: If true, retains reduced dimensions with length 1. + + Returns: + An ndarray. Note that unlike NumPy this does not return a scalar bool if + `axis` is None. + """ + a = asarray(a, dtype=bool) + return utils.tensor_to_ndarray( + tf.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims) + ) def compress(condition, a, axis=None): - """Compresses `a` by selecting values along `axis` with `condition` true. + """Compresses `a` by selecting values along `axis` with `condition` true. - Uses `tf.boolean_mask`. + Uses `tf.boolean_mask`. - Args: - condition: 1-d array of bools. If `condition` is shorter than the array - axis (or the flattened array if axis is None), it is padded with False. - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. 
Axis along which to select elements. If None, `condition` is - applied on flattened array. + Args: + condition: 1-d array of bools. If `condition` is shorter than the array + axis (or the flattened array if axis is None), it is padded with False. + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axis: Optional. Axis along which to select elements. If None, `condition` is + applied on flattened array. - Returns: - An ndarray. + Returns: + An ndarray. - Raises: - ValueError: if `condition` is not of rank 1. - """ - condition = asarray(condition, dtype=bool) - a = asarray(a) + Raises: + ValueError: if `condition` is not of rank 1. + """ + condition = asarray(condition, dtype=bool) + a = asarray(a) - if condition.ndim != 1: - raise ValueError('condition must be a 1-d array.') - # `np.compress` treats scalars as 1-d arrays. - if a.ndim == 0: - a = ravel(a) + if condition.ndim != 1: + raise ValueError("condition must be a 1-d array.") + # `np.compress` treats scalars as 1-d arrays. + if a.ndim == 0: + a = ravel(a) - if axis is None: - a = ravel(a) - axis = 0 + if axis is None: + a = ravel(a) + axis = 0 - if axis < 0: - axis += a.ndim + if axis < 0: + axis += a.ndim - assert axis >= 0 and axis < a.ndim + assert axis >= 0 and axis < a.ndim - # `tf.boolean_mask` requires the first dimensions of array and condition to - # match. `np.compress` pads condition with False when it is shorter. - condition_t = condition.data - a_t = a.data - if condition.shape[0] < a.shape[axis]: - padding = tf.fill([a.shape[axis] - condition.shape[0]], False) - condition_t = tf.concat([condition_t, padding], axis=0) - return utils.tensor_to_ndarray(tf.boolean_mask(tensor=a_t, mask=condition_t, - axis=axis)) + # `tf.boolean_mask` requires the first dimensions of array and condition to + # match. `np.compress` pads condition with False when it is shorter. 
+ condition_t = condition.data + a_t = a.data + if condition.shape[0] < a.shape[axis]: + padding = tf.fill([a.shape[axis] - condition.shape[0]], False) + condition_t = tf.concat([condition_t, padding], axis=0) + return utils.tensor_to_ndarray( + tf.boolean_mask(tensor=a_t, mask=condition_t, axis=axis) + ) def copy(a): - """Returns a copy of the array.""" - return array(a, copy=True) + """Returns a copy of the array.""" + return array(a, copy=True) def _maybe_promote_to_int(a): - if tf.as_dtype(a.dtype).is_integer: - # If a is an integer type and its precision is less than that of `int`, - # the output type will be `int`. - output_type = np.promote_types(a.dtype, int) - if output_type != a.dtype: - a = asarray(a, dtype=output_type) + if tf.as_dtype(a.dtype).is_integer: + # If a is an integer type and its precision is less than that of `int`, + # the output type will be `int`. + output_type = np.promote_types(a.dtype, int) + if output_type != a.dtype: + a = asarray(a, dtype=output_type) - return a + return a @utils.np_doc(np.cumprod) def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring - a = asarray(a, dtype=dtype) + a = asarray(a, dtype=dtype) - if dtype is None: - a = _maybe_promote_to_int(a) + if dtype is None: + a = _maybe_promote_to_int(a) - # If axis is None, the input is flattened. - if axis is None: - a = ravel(a) - axis = 0 - elif axis < 0: - axis += tf.rank(a.data) - return utils.tensor_to_ndarray(tf.math.cumprod(a.data, axis)) + # If axis is None, the input is flattened. + if axis is None: + a = ravel(a) + axis = 0 + elif axis < 0: + axis += tf.rank(a.data) + return utils.tensor_to_ndarray(tf.math.cumprod(a.data, axis)) @utils.np_doc(np.cumsum) def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring - a = asarray(a, dtype=dtype) + a = asarray(a, dtype=dtype) - if dtype is None: - a = _maybe_promote_to_int(a) + if dtype is None: + a = _maybe_promote_to_int(a) - # If axis is None, the input is flattened. 
- if axis is None: - a = ravel(a) - axis = 0 - elif axis < 0: - axis += tf.rank(a.data) - return utils.tensor_to_ndarray(tf.cumsum(a.data, axis)) + # If axis is None, the input is flattened. + if axis is None: + a = ravel(a) + axis = 0 + elif axis < 0: + axis += tf.rank(a.data) + return utils.tensor_to_ndarray(tf.cumsum(a.data, axis)) def imag(a): - """Returns imaginary parts of all elements in `a`. + """Returns imaginary parts of all elements in `a`. - Uses `tf.imag`. + Uses `tf.imag`. - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. - Returns: - An ndarray with the same shape as `a`. - """ - a = asarray(a) - # TODO(srbs): np.imag returns a scalar if a is a scalar, whereas we always - # return an ndarray. - return utils.tensor_to_ndarray(tf.math.imag(a.data)) + Returns: + An ndarray with the same shape as `a`. + """ + a = asarray(a) + # TODO(srbs): np.imag returns a scalar if a is a scalar, whereas we always + # return an ndarray. + return utils.tensor_to_ndarray(tf.math.imag(a.data)) _TO_INT64 = 0 _TO_FLOAT = 1 -def _reduce(tf_fn, a, axis=None, dtype=None, keepdims=None, - promote_int=_TO_INT64, tf_bool_fn=None, preserve_bool=False): - """A general reduction function. - - Args: - tf_fn: the TF reduction function. - a: the array to be reduced. - axis: (optional) the axis along which to do the reduction. If None, all - dimensions are reduced. - dtype: (optional) the dtype of the result. - keepdims: (optional) whether to keep the reduced dimension(s). - promote_int: how to promote integer and bool inputs. There are three - choices: (1) _TO_INT64: always promote them to int64 or uint64; (2) - _TO_FLOAT: always promote them to a float type (determined by - dtypes.default_float_type); (3) None: don't promote. 
- tf_bool_fn: (optional) the TF reduction function for bool inputs. It - will only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s - dtype is `np.bool_` and `preserve_bool` is True. - preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype - is `np.bool_` (some reductions such as np.sum convert bools to - integers, while others such as np.max preserve bools. - - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if keepdims is None: - keepdims = False - a = asarray(a, dtype=dtype) - if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) - and tf_bool_fn is not None): - return utils.tensor_to_ndarray( - tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) - if dtype is None: - dtype = a.dtype - if np.issubdtype(dtype, np.integer) or dtype == np.bool_: - if promote_int == _TO_INT64: - # If a is an integer/bool type and whose bit width is less than 64, - # numpy up-casts it to 64-bit. - if dtype == np.bool_: - is_signed = True - width = 8 # We can use any number here that is less than 64 - else: - is_signed = np.issubdtype(dtype, np.signedinteger) - width = np.iinfo(dtype).bits - if width < 64: - if is_signed: - dtype = np.int64 - else: - dtype = np.uint64 - a = a.astype(dtype) - elif promote_int == _TO_FLOAT: - a = a.astype(dtypes.default_float_type()) +def _reduce( + tf_fn, + a, + axis=None, + dtype=None, + keepdims=None, + promote_int=_TO_INT64, + tf_bool_fn=None, + preserve_bool=False, +): + """A general reduction function. + + Args: + tf_fn: the TF reduction function. + a: the array to be reduced. + axis: (optional) the axis along which to do the reduction. If None, all + dimensions are reduced. + dtype: (optional) the dtype of the result. + keepdims: (optional) whether to keep the reduced dimension(s). + promote_int: how to promote integer and bool inputs. 
There are three + choices: (1) _TO_INT64: always promote them to int64 or uint64; (2) + _TO_FLOAT: always promote them to a float type (determined by + dtypes.default_float_type); (3) None: don't promote. + tf_bool_fn: (optional) the TF reduction function for bool inputs. It + will only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s + dtype is `np.bool_` and `preserve_bool` is True. + preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype + is `np.bool_` (some reductions such as np.sum convert bools to + integers, while others such as np.max preserve bools. + + Returns: + An ndarray. + """ + if dtype: + dtype = utils.result_type(dtype) + if keepdims is None: + keepdims = False + a = asarray(a, dtype=dtype) + if ( + dtype == np.bool_ or preserve_bool and a.dtype == np.bool_ + ) and tf_bool_fn is not None: + return utils.tensor_to_ndarray( + tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims) + ) + if dtype is None: + dtype = a.dtype + if np.issubdtype(dtype, np.integer) or dtype == np.bool_: + if promote_int == _TO_INT64: + # If a is an integer/bool type and whose bit width is less than 64, + # numpy up-casts it to 64-bit. 
+ if dtype == np.bool_: + is_signed = True + width = 8 # We can use any number here that is less than 64 + else: + is_signed = np.issubdtype(dtype, np.signedinteger) + width = np.iinfo(dtype).bits + if width < 64: + if is_signed: + dtype = np.int64 + else: + dtype = np.uint64 + a = a.astype(dtype) + elif promote_int == _TO_FLOAT: + a = a.astype(dtypes.default_float_type()) - return utils.tensor_to_ndarray( - tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) + return utils.tensor_to_ndarray( + tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims) + ) @utils.np_doc(np.sum) def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin - return _reduce(tf.reduce_sum, a, axis=axis, dtype=dtype, keepdims=keepdims, - tf_bool_fn=tf.reduce_any) + return _reduce( + tf.reduce_sum, + a, + axis=axis, + dtype=dtype, + keepdims=keepdims, + tf_bool_fn=tf.reduce_any, + ) @utils.np_doc(np.prod) def prod(a, axis=None, dtype=None, keepdims=None): - return _reduce(tf.reduce_prod, a, axis=axis, dtype=dtype, keepdims=keepdims, - tf_bool_fn=tf.reduce_all) + return _reduce( + tf.reduce_prod, + a, + axis=axis, + dtype=dtype, + keepdims=keepdims, + tf_bool_fn=tf.reduce_all, + ) @utils.np_doc(np.mean) def mean(a, axis=None, dtype=None, keepdims=None): - return _reduce(tf.math.reduce_mean, a, axis=axis, dtype=dtype, - keepdims=keepdims, promote_int=_TO_FLOAT) + return _reduce( + tf.math.reduce_mean, + a, + axis=axis, + dtype=dtype, + keepdims=keepdims, + promote_int=_TO_FLOAT, + ) @utils.np_doc(np.amax) def amax(a, axis=None, keepdims=None): - return _reduce(tf.reduce_max, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=None, tf_bool_fn=tf.reduce_any, preserve_bool=True) + return _reduce( + tf.reduce_max, + a, + axis=axis, + dtype=None, + keepdims=keepdims, + promote_int=None, + tf_bool_fn=tf.reduce_any, + preserve_bool=True, + ) @utils.np_doc(np.amin) def amin(a, axis=None, keepdims=None): - return _reduce(tf.reduce_min, a, axis=axis, dtype=None, 
keepdims=keepdims, - promote_int=None, tf_bool_fn=tf.reduce_all, preserve_bool=True) + return _reduce( + tf.reduce_min, + a, + axis=axis, + dtype=None, + keepdims=keepdims, + promote_int=None, + tf_bool_fn=tf.reduce_all, + preserve_bool=True, + ) # TODO(wangpeng): Remove this workaround once b/157232284 is fixed def _reduce_variance_complex(input_tensor, axis, keepdims): - f = functools.partial(tf.math.reduce_variance, axis=axis, keepdims=keepdims) - return f(tf.math.real(input_tensor)) + f(tf.math.imag(input_tensor)) + f = functools.partial(tf.math.reduce_variance, axis=axis, keepdims=keepdims) + return f(tf.math.real(input_tensor)) + f(tf.math.imag(input_tensor)) # TODO(wangpeng): Remove this workaround once b/157232284 is fixed def _reduce_std_complex(input_tensor, axis, keepdims): - y = _reduce_variance_complex(input_tensor=input_tensor, axis=axis, - keepdims=keepdims) - return tf.math.sqrt(y) + y = _reduce_variance_complex( + input_tensor=input_tensor, axis=axis, keepdims=keepdims + ) + return tf.math.sqrt(y) @utils.np_doc(np.var) def var(a, axis=None, keepdims=None): - def f(input_tensor, axis, keepdims): - if input_tensor.dtype in (tf.complex64, tf.complex128): - # A workaround for b/157232284 - fn = _reduce_variance_complex - else: - fn = tf.math.reduce_variance - return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) - return _reduce(f, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=_TO_FLOAT) + def f(input_tensor, axis, keepdims): + if input_tensor.dtype in (tf.complex64, tf.complex128): + # A workaround for b/157232284 + fn = _reduce_variance_complex + else: + fn = tf.math.reduce_variance + return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) + + return _reduce( + f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT + ) @utils.np_doc(np.std) def std(a, axis=None, keepdims=None): - def f(input_tensor, axis, keepdims): - if input_tensor.dtype in (tf.complex64, tf.complex128): - # A workaround for 
b/157232284 - fn = _reduce_std_complex - else: - fn = tf.math.reduce_std - return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) - return _reduce(f, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=_TO_FLOAT) + def f(input_tensor, axis, keepdims): + if input_tensor.dtype in (tf.complex64, tf.complex128): + # A workaround for b/157232284 + fn = _reduce_std_complex + else: + fn = tf.math.reduce_std + return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) + + return _reduce( + f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT + ) @utils.np_doc(np.ravel) def ravel(a): # pylint: disable=missing-docstring - a = asarray(a) - if a.ndim == 1: - return a - return utils.tensor_to_ndarray(tf.reshape(a.data, [-1])) + a = asarray(a) + if a.ndim == 1: + return a + return utils.tensor_to_ndarray(tf.reshape(a.data, [-1])) -setattr(arrays_lib.ndarray, 'ravel', ravel) +setattr(arrays_lib.ndarray, "ravel", ravel) def real(val): - """Returns real parts of all elements in `a`. + """Returns real parts of all elements in `a`. - Uses `tf.real`. + Uses `tf.real`. - Args: - val: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. + Args: + val: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. - Returns: - An ndarray with the same shape as `a`. - """ - val = asarray(val) - # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always - # return an ndarray. - return utils.tensor_to_ndarray(tf.math.real(val.data)) + Returns: + An ndarray with the same shape as `a`. + """ + val = asarray(val) + # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always + # return an ndarray. 
+ return utils.tensor_to_ndarray(tf.math.real(val.data)) @utils.np_doc(np.repeat) def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data - original_shape = a._shape_as_list() # pylint: disable=protected-access - # Best effort recovery of the shape. - if original_shape is not None and None not in original_shape: - if not original_shape: - original_shape = (repeats,) - else: - repeats_np = np.ravel(np.array(repeats)) - if repeats_np.size == 1: - repeats_np = repeats_np.item() - if axis is None: - original_shape = (repeats_np * np.prod(original_shape),) - else: - original_shape[axis] = repeats_np * original_shape[axis] - else: - if axis is None: - original_shape = (repeats_np.sum(),) + a = asarray(a).data + original_shape = a._shape_as_list() # pylint: disable=protected-access + # Best effort recovery of the shape. + if original_shape is not None and None not in original_shape: + if not original_shape: + original_shape = (repeats,) else: - original_shape[axis] = repeats_np.sum() - - repeats = asarray(repeats).data - result = tf.repeat(a, repeats, axis) - result.set_shape(original_shape) - - return utils.tensor_to_ndarray(result) + repeats_np = np.ravel(np.array(repeats)) + if repeats_np.size == 1: + repeats_np = repeats_np.item() + if axis is None: + original_shape = (repeats_np * np.prod(original_shape),) + else: + original_shape[axis] = repeats_np * original_shape[axis] + else: + if axis is None: + original_shape = (repeats_np.sum(),) + else: + original_shape[axis] = repeats_np.sum() + + repeats = asarray(repeats).data + result = tf.repeat(a, repeats, axis) + result.set_shape(original_shape) + + return utils.tensor_to_ndarray(result) @utils.np_doc(np.around) def around(a, decimals=0): # pylint: disable=missing-docstring - a = asarray(a) - dtype = a.dtype - factor = math.pow(10, decimals) - if np.issubdtype(dtype, np.inexact): - factor = tf.cast(factor, dtype) - else: - # Use float as the working dtype when a.dtype is exact (e.g. 
integer), - # because `decimals` can be negative. - float_dtype = dtypes.default_float_type() - a = a.astype(float_dtype).data - factor = tf.cast(factor, float_dtype) - a = tf.multiply(a, factor) - a = tf.round(a) - a = tf.math.divide(a, factor) - return utils.tensor_to_ndarray(a).astype(dtype) + a = asarray(a) + dtype = a.dtype + factor = math.pow(10, decimals) + if np.issubdtype(dtype, np.inexact): + factor = tf.cast(factor, dtype) + else: + # Use float as the working dtype when a.dtype is exact (e.g. integer), + # because `decimals` can be negative. + float_dtype = dtypes.default_float_type() + a = a.astype(float_dtype).data + factor = tf.cast(factor, float_dtype) + a = tf.multiply(a, factor) + a = tf.round(a) + a = tf.math.divide(a, factor) + return utils.tensor_to_ndarray(a).astype(dtype) round_ = around -setattr(arrays_lib.ndarray, '__round__', around) +setattr(arrays_lib.ndarray, "__round__", around) @utils.np_doc(np.reshape) -def reshape(a, newshape, order='C'): - """order argument can only b 'C' or 'F'.""" - if order not in {'C', 'F'}: - raise ValueError('Unsupported order argument {}'.format(order)) - - a = asarray(a) - if isinstance(newshape, arrays_lib.ndarray): - newshape = newshape.data - if isinstance(newshape, int): - newshape = [newshape] - - if order == 'F': - r = tf.transpose(tf.reshape(tf.transpose(a.data), newshape[::-1])) - else: - r = tf.reshape(a.data, newshape) +def reshape(a, newshape, order="C"): + """order argument can only b 'C' or 'F'.""" + if order not in {"C", "F"}: + raise ValueError("Unsupported order argument {}".format(order)) + + a = asarray(a) + if isinstance(newshape, arrays_lib.ndarray): + newshape = newshape.data + if isinstance(newshape, int): + newshape = [newshape] + + if order == "F": + r = tf.transpose(tf.reshape(tf.transpose(a.data), newshape[::-1])) + else: + r = tf.reshape(a.data, newshape) - return utils.tensor_to_ndarray(r) + return utils.tensor_to_ndarray(r) def _reshape_method_wrapper(a, *newshape, **kwargs): - 
order = kwargs.pop('order', 'C') - if kwargs: - raise ValueError('Unsupported arguments: {}'.format(kwargs.keys())) + order = kwargs.pop("order", "C") + if kwargs: + raise ValueError("Unsupported arguments: {}".format(kwargs.keys())) - if len(newshape) == 1 and not isinstance(newshape[0], int): - newshape = newshape[0] + if len(newshape) == 1 and not isinstance(newshape[0], int): + newshape = newshape[0] - return reshape(a, newshape, order=order) + return reshape(a, newshape, order=order) def expand_dims(a, axis): - """Expand the shape of an array. + """Expand the shape of an array. - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: int. axis on which to expand the shape. + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axis: int. axis on which to expand the shape. - Returns: - An ndarray with the contents and dtype of `a` and shape expanded on axis. - """ - a = asarray(a) - return utils.tensor_to_ndarray(tf.expand_dims(a.data, axis=axis)) + Returns: + An ndarray with the contents and dtype of `a` and shape expanded on axis. + """ + a = asarray(a) + return utils.tensor_to_ndarray(tf.expand_dims(a.data, axis=axis)) def squeeze(a, axis=None): - """Removes single-element axes from the array. + """Removes single-element axes from the array. - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: scalar or list/tuple of ints. + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axis: scalar or list/tuple of ints. - TODO(srbs): tf.squeeze throws error when axis is a Tensor eager execution - is enabled. So we cannot allow axis to be array_like here. Fix. 
+ TODO(srbs): tf.squeeze throws error when axis is a Tensor eager execution + is enabled. So we cannot allow axis to be array_like here. Fix. - Returns: - An ndarray. - """ - a = asarray(a) - return utils.tensor_to_ndarray(tf.squeeze(a, axis)) + Returns: + An ndarray. + """ + a = asarray(a) + return utils.tensor_to_ndarray(tf.squeeze(a, axis)) def transpose(a, axes=None): - """Permutes dimensions of the array. + """Permutes dimensions of the array. - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axes: array_like. A list of ints with length rank(a) or None specifying the - order of permutation. The i'th dimension of the output array corresponds - to axes[i]'th dimension of the `a`. If None, the axes are reversed. + Args: + a: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + axes: array_like. A list of ints with length rank(a) or None specifying the + order of permutation. The i'th dimension of the output array corresponds + to axes[i]'th dimension of the `a`. If None, the axes are reversed. - Returns: - An ndarray. - """ - a = asarray(a) - if axes is not None: - axes = asarray(axes) - return utils.tensor_to_ndarray(tf.transpose(a=a.data, perm=axes)) + Returns: + An ndarray. 
+ """ + a = asarray(a) + if axes is not None: + axes = asarray(axes) + return utils.tensor_to_ndarray(tf.transpose(a=a.data, perm=axes)) @utils.np_doc(np.swapaxes) def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring - a = asarray(a) + a = asarray(a) - a_rank = tf.rank(a) - if axis1 < 0: - axis1 += a_rank - if axis2 < 0: - axis2 += a_rank + a_rank = tf.rank(a) + if axis1 < 0: + axis1 += a_rank + if axis2 < 0: + axis2 += a_rank - perm = tf.range(a_rank) - perm = tf.tensor_scatter_nd_update(perm, [[axis1], [axis2]], [axis2, axis1]) - a = tf.transpose(a, perm) + perm = tf.range(a_rank) + perm = tf.tensor_scatter_nd_update(perm, [[axis1], [axis2]], [axis2, axis1]) + a = tf.transpose(a, perm) - return utils.tensor_to_ndarray(a) + return utils.tensor_to_ndarray(a) @utils.np_doc(np.moveaxis) def moveaxis(a, source, destination): # pylint: disable=missing-docstring - """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" - if not source and not destination: - return a + """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" + if not source and not destination: + return a - a = asarray(a).data + a = asarray(a).data - if isinstance(source, int): - source = (source,) - if isinstance(destination, int): - destination = (destination,) + if isinstance(source, int): + source = (source,) + if isinstance(destination, int): + destination = (destination,) - a_rank = utils._maybe_static(tf.rank(a)) # pylint: disable=protected-access + a_rank = utils._maybe_static(tf.rank(a)) # pylint: disable=protected-access - def _correct_axis(axis, rank): - if axis < 0: - return axis + rank - return axis + def _correct_axis(axis, rank): + if axis < 0: + return axis + rank + return axis - source = tuple(_correct_axis(axis, a_rank) for axis in source) - destination = tuple(_correct_axis(axis, a_rank) for axis in destination) + source = tuple(_correct_axis(axis, a_rank) for axis in source) + destination = tuple(_correct_axis(axis, a_rank) for axis 
in destination) - if a.shape.rank is not None: - perm = [i for i in range(a_rank) if i not in source] - for dest, src in sorted(zip(destination, source)): - assert dest <= len(perm) - perm.insert(dest, src) - else: - r = tf.range(a_rank) + if a.shape.rank is not None: + perm = [i for i in range(a_rank) if i not in source] + for dest, src in sorted(zip(destination, source)): + assert dest <= len(perm) + perm.insert(dest, src) + else: + r = tf.range(a_rank) - def _remove_indices(a, b): - """Remove indices (`b`) from `a`.""" - items = tf.unstack(tf.sort(tf.stack(b)), num=len(b)) + def _remove_indices(a, b): + """Remove indices (`b`) from `a`.""" + items = tf.unstack(tf.sort(tf.stack(b)), num=len(b)) - i = 0 - result = [] + i = 0 + result = [] - for item in items: - result.append(a[i:item]) - i = item + 1 + for item in items: + result.append(a[i:item]) + i = item + 1 - result.append(a[i:]) + result.append(a[i:]) - return tf.concat(result, 0) + return tf.concat(result, 0) - minus_sources = _remove_indices(r, source) - minus_dest = _remove_indices(r, destination) + minus_sources = _remove_indices(r, source) + minus_dest = _remove_indices(r, destination) - perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank]) - perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), - source) - a = tf.transpose(a, perm) + perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank]) + perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), source) + a = tf.transpose(a, perm) - return utils.tensor_to_ndarray(a) + return utils.tensor_to_ndarray(a) def _setitem(arr, index, value): - """Sets the `value` at `index` in the array `arr`. - - This works by replacing the slice at `index` in the tensor with `value`. - Since tensors are immutable, this builds a new tensor using the `tf.concat` - op. Currently, only 0-d and 1-d indices are supported. - - Note that this may break gradients e.g. 
- - a = tf_np.array([1, 2, 3]) - old_a_t = a.data - - with tf.GradientTape(persistent=True) as g: - g.watch(a.data) - b = a * 2 - a[0] = 5 - g.gradient(b.data, [a.data]) # [None] - g.gradient(b.data, [old_a_t]) # [[2., 2., 2.]] - - Here `d_b / d_a` is `[None]` since a.data no longer points to the same - tensor. - - Args: - arr: array_like. - index: scalar or 1-d integer array. - value: value to set at index. - - Returns: - ndarray - - Raises: - ValueError: if `index` is not a scalar or 1-d array. - """ - # TODO(srbs): Figure out a solution to the gradient problem. - arr = asarray(arr) - index = asarray(index) - if index.ndim == 0: - index = ravel(index) - elif index.ndim > 1: - raise ValueError('index must be a scalar or a 1-d array.') - value = asarray(value, dtype=arr.dtype) - if arr.shape[len(index):] != value.shape: - value = full(arr.shape[len(index):], value) - prefix_t = arr.data[:index.data[0]] - postfix_t = arr.data[index.data[0] + 1:] - if len(index) == 1: - arr._data = tf.concat( # pylint: disable=protected-access - [prefix_t, tf.expand_dims(value.data, 0), postfix_t], 0) - else: - subarray = arr[index.data[0]] - _setitem(subarray, index[1:], value) - arr._data = tf.concat( # pylint: disable=protected-access - [prefix_t, tf.expand_dims(subarray.data, 0), postfix_t], 0) - - -setattr(arrays_lib.ndarray, 'transpose', transpose) -setattr(arrays_lib.ndarray, 'reshape', _reshape_method_wrapper) -setattr(arrays_lib.ndarray, '__setitem__', _setitem) + """Sets the `value` at `index` in the array `arr`. + + This works by replacing the slice at `index` in the tensor with `value`. + Since tensors are immutable, this builds a new tensor using the `tf.concat` + op. Currently, only 0-d and 1-d indices are supported. + + Note that this may break gradients e.g. 
+ + a = tf_np.array([1, 2, 3]) + old_a_t = a.data + + with tf.GradientTape(persistent=True) as g: + g.watch(a.data) + b = a * 2 + a[0] = 5 + g.gradient(b.data, [a.data]) # [None] + g.gradient(b.data, [old_a_t]) # [[2., 2., 2.]] + + Here `d_b / d_a` is `[None]` since a.data no longer points to the same + tensor. + + Args: + arr: array_like. + index: scalar or 1-d integer array. + value: value to set at index. + + Returns: + ndarray + + Raises: + ValueError: if `index` is not a scalar or 1-d array. + """ + # TODO(srbs): Figure out a solution to the gradient problem. + arr = asarray(arr) + index = asarray(index) + if index.ndim == 0: + index = ravel(index) + elif index.ndim > 1: + raise ValueError("index must be a scalar or a 1-d array.") + value = asarray(value, dtype=arr.dtype) + if arr.shape[len(index) :] != value.shape: + value = full(arr.shape[len(index) :], value) + prefix_t = arr.data[: index.data[0]] + postfix_t = arr.data[index.data[0] + 1 :] + if len(index) == 1: + arr._data = tf.concat( # pylint: disable=protected-access + [prefix_t, tf.expand_dims(value.data, 0), postfix_t], 0 + ) + else: + subarray = arr[index.data[0]] + _setitem(subarray, index[1:], value) + arr._data = tf.concat( # pylint: disable=protected-access + [prefix_t, tf.expand_dims(subarray.data, 0), postfix_t], 0 + ) + + +setattr(arrays_lib.ndarray, "transpose", transpose) +setattr(arrays_lib.ndarray, "reshape", _reshape_method_wrapper) +setattr(arrays_lib.ndarray, "__setitem__", _setitem) def pad(ary, pad_width, mode, constant_values=0): - """Pads an array. - - Args: - ary: array_like of rank N. Input array. - pad_width: {sequence, array_like, int}. - Number of values padded to the edges of each axis. - ((before_1, after_1), ... (before_N, after_N)) unique pad widths - for each axis. - ((before, after),) yields same before and after pad for each axis. - (pad,) or int is a shortcut for before = after = pad width for all - axes. - mode: string. 
One of the following string values: - 'constant' - Pads with a constant value. - 'reflect' - Pads with the reflection of the vector mirrored on - the first and last values of the vector along each - axis. - 'symmetric' - Pads with the reflection of the vector mirrored - along the edge of the array. - **NOTE**: The supported list of `mode` does not match that of numpy's. - constant_values: scalar with same dtype as `array`. - Used in 'constant' mode as the pad value. Default is 0. - - - Returns: - An ndarray padded array of rank equal to `array` with shape increased - according to `pad_width`. - - Raises: - ValueError if `mode` is not supported. - """ - if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'): - raise ValueError('Unsupported padding mode: ' + mode) - mode = mode.upper() - ary = asarray(ary) - pad_width = asarray(pad_width, dtype=tf.int32) - return utils.tensor_to_ndarray(tf.pad( - tensor=ary.data, paddings=pad_width.data, mode=mode, - constant_values=constant_values)) + """Pads an array. + + Args: + ary: array_like of rank N. Input array. + pad_width: {sequence, array_like, int}. + Number of values padded to the edges of each axis. + ((before_1, after_1), ... (before_N, after_N)) unique pad widths + for each axis. + ((before, after),) yields same before and after pad for each axis. + (pad,) or int is a shortcut for before = after = pad width for all + axes. + mode: string. One of the following string values: + 'constant' + Pads with a constant value. + 'reflect' + Pads with the reflection of the vector mirrored on + the first and last values of the vector along each + axis. + 'symmetric' + Pads with the reflection of the vector mirrored + along the edge of the array. + **NOTE**: The supported list of `mode` does not match that of numpy's. + constant_values: scalar with same dtype as `array`. + Used in 'constant' mode as the pad value. Default is 0. 
+ + + Returns: + An ndarray padded array of rank equal to `array` with shape increased + according to `pad_width`. + + Raises: + ValueError if `mode` is not supported. + """ + if not (mode == "constant" or mode == "reflect" or mode == "symmetric"): + raise ValueError("Unsupported padding mode: " + mode) + mode = mode.upper() + ary = asarray(ary) + pad_width = asarray(pad_width, dtype=tf.int32) + return utils.tensor_to_ndarray( + tf.pad( + tensor=ary.data, + paddings=pad_width.data, + mode=mode, + constant_values=constant_values, + ) + ) @utils.np_doc(np.take) -def take(a, indices, axis=None, out=None, mode='clip'): - """out argument is not supported, and default mode is clip.""" - if out is not None: - raise ValueError('out argument is not supported in take.') +def take(a, indices, axis=None, out=None, mode="clip"): + """out argument is not supported, and default mode is clip.""" + if out is not None: + raise ValueError("out argument is not supported in take.") - if mode not in {'raise', 'clip', 'wrap'}: - raise ValueError("Invalid mode '{}' for take".format(mode)) + if mode not in {"raise", "clip", "wrap"}: + raise ValueError("Invalid mode '{}' for take".format(mode)) - a = asarray(a).data - indices = asarray(indices).data + a = asarray(a).data + indices = asarray(indices).data - if axis is None: - a = tf.reshape(a, [-1]) - axis = 0 + if axis is None: + a = tf.reshape(a, [-1]) + axis = 0 - axis_size = tf.shape(a, indices.dtype)[axis] - if mode == 'clip': - indices = tf.clip_by_value(indices, 0, axis_size-1) - elif mode == 'wrap': - indices = tf.math.floormod(indices, axis_size) - else: - raise ValueError("The 'raise' mode to take is not supported.") + axis_size = tf.shape(a, indices.dtype)[axis] + if mode == "clip": + indices = tf.clip_by_value(indices, 0, axis_size - 1) + elif mode == "wrap": + indices = tf.math.floormod(indices, axis_size) + else: + raise ValueError("The 'raise' mode to take is not supported.") - return utils.tensor_to_ndarray(tf.gather(a, 
indices, axis=axis)) + return utils.tensor_to_ndarray(tf.gather(a, indices, axis=axis)) @utils.np_doc_only(np.where) def where(condition, x=None, y=None): - """Raises ValueError if exactly one of x or y is not None.""" - condition = asarray(condition, dtype=np.bool_) - if x is None and y is None: - return nonzero(condition) - elif x is not None and y is not None: - x, y = _promote_dtype(x, y) - return utils.tensor_to_ndarray(tf.where(condition.data, x.data, y.data)) - raise ValueError('Both x and y must be ndarrays, or both must be None.') + """Raises ValueError if exactly one of x or y is not None.""" + condition = asarray(condition, dtype=np.bool_) + if x is None and y is None: + return nonzero(condition) + elif x is not None and y is not None: + x, y = _promote_dtype(x, y) + return utils.tensor_to_ndarray(tf.where(condition.data, x.data, y.data)) + raise ValueError("Both x and y must be ndarrays, or both must be None.") @utils.np_doc(np.select) def select(condlist, choicelist, default=0): # pylint: disable=missing-docstring - if len(condlist) != len(choicelist): - msg = 'condlist must have length equal to choicelist ({} vs {})' - raise ValueError(msg.format(len(condlist), len(choicelist))) - if not condlist: - raise ValueError('condlist must be non-empty') - choices = _promote_dtype(default, *choicelist) - choicelist = choices[1:] - output = choices[0] - # The traversal is in reverse order so we can return the first value in - # choicelist where condlist is True. 
- for cond, choice in zip(condlist[::-1], choicelist[::-1]): - output = where(cond, choice, output) - return output + if len(condlist) != len(choicelist): + msg = "condlist must have length equal to choicelist ({} vs {})" + raise ValueError(msg.format(len(condlist), len(choicelist))) + if not condlist: + raise ValueError("condlist must be non-empty") + choices = _promote_dtype(default, *choicelist) + choicelist = choices[1:] + output = choices[0] + # The traversal is in reverse order so we can return the first value in + # choicelist where condlist is True. + for cond, choice in zip(condlist[::-1], choicelist[::-1]): + output = where(cond, choice, output) + return output def shape(a): - """Return the shape of an array. + """Return the shape of an array. - Args: - a: array_like. Input array. + Args: + a: array_like. Input array. - Returns: - Tuple of ints. - """ - a = asarray(a) - return a.shape + Returns: + Tuple of ints. + """ + a = asarray(a) + return a.shape def ndim(a): - a = asarray(a) - return a.ndim + a = asarray(a) + return a.ndim def isscalar(a): - return ndim(a) == 0 + return ndim(a) == 0 def _boundaries_to_sizes(a, boundaries, axis): - """Converting boundaries of splits to sizes of splits. - - Args: - a: the array to be split. - boundaries: the boundaries, as in np.split. - axis: the axis along which to split. - - Returns: - A list of sizes of the splits, as in tf.split. - """ - if axis >= len(a.shape): - raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape)) - total_size = a.shape[axis] - sizes = [] - sizes_sum = 0 - prev = 0 - for i, b in enumerate(boundaries): - size = b - prev - if size < 0: - raise ValueError('The %s-th boundary %s is smaller than the previous ' - 'boundary %s' % (i, b, prev)) - size = min(size, max(0, total_size - sizes_sum)) - sizes.append(size) - sizes_sum += size - prev = b - sizes.append(max(0, total_size - sizes_sum)) - return sizes + """Converting boundaries of splits to sizes of splits. 
+ + Args: + a: the array to be split. + boundaries: the boundaries, as in np.split. + axis: the axis along which to split. + + Returns: + A list of sizes of the splits, as in tf.split. + """ + if axis >= len(a.shape): + raise ValueError("axis %s is out of bound for shape %s" % (axis, a.shape)) + total_size = a.shape[axis] + sizes = [] + sizes_sum = 0 + prev = 0 + for i, b in enumerate(boundaries): + size = b - prev + if size < 0: + raise ValueError( + "The %s-th boundary %s is smaller than the previous " + "boundary %s" % (i, b, prev) + ) + size = min(size, max(0, total_size - sizes_sum)) + sizes.append(size) + sizes_sum += size + prev = b + sizes.append(max(0, total_size - sizes_sum)) + return sizes @utils.np_doc(np.split) def split(ary, indices_or_sections, axis=0): - ary = asarray(ary) - if not isinstance(indices_or_sections, six.integer_types): - indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis) - result = tf.split(ary.data, indices_or_sections, axis=axis) - return [utils.tensor_to_ndarray(a) for a in result] + ary = asarray(ary) + if not isinstance(indices_or_sections, six.integer_types): + indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis) + result = tf.split(ary.data, indices_or_sections, axis=axis) + return [utils.tensor_to_ndarray(a) for a in result] def _split_on_axis(np_fun, axis): - @utils.np_doc(np_fun) - def f(ary, indices_or_sections): - return split(ary, indices_or_sections, axis=axis) - return f + @utils.np_doc(np_fun) + def f(ary, indices_or_sections): + return split(ary, indices_or_sections, axis=axis) + + return f vsplit = _split_on_axis(np.vsplit, axis=0) @@ -1228,318 +1309,341 @@ def f(ary, indices_or_sections): @utils.np_doc(np.broadcast_to) def broadcast_to(array, shape): # pylint: disable=redefined-outer-name - return full(shape, array) + return full(shape, array) @utils.np_doc(np.stack) def stack(arrays, axis=0): - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - 
unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return asarray(tf.stack(unwrapped_arrays, axis)) + arrays = _promote_dtype(*arrays) # pylint: disable=protected-access + unwrapped_arrays = [ + a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays + ] + return asarray(tf.stack(unwrapped_arrays, axis)) @utils.np_doc(np.hstack) def hstack(tup): - arrays = [atleast_1d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - rank = tf.rank(unwrapped_arrays[0]) - return utils.cond(rank == 1, lambda: tf.concat(unwrapped_arrays, axis=0), - lambda: tf.concat(unwrapped_arrays, axis=1)) + arrays = [atleast_1d(a) for a in tup] + arrays = _promote_dtype(*arrays) # pylint: disable=protected-access + unwrapped_arrays = [ + a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays + ] + rank = tf.rank(unwrapped_arrays[0]) + return utils.cond( + rank == 1, + lambda: tf.concat(unwrapped_arrays, axis=0), + lambda: tf.concat(unwrapped_arrays, axis=1), + ) @utils.np_doc(np.vstack) def vstack(tup): - arrays = [atleast_2d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return tf.concat(unwrapped_arrays, axis=0) + arrays = [atleast_2d(a) for a in tup] + arrays = _promote_dtype(*arrays) # pylint: disable=protected-access + unwrapped_arrays = [ + a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays + ] + return tf.concat(unwrapped_arrays, axis=0) @utils.np_doc(np.dstack) def dstack(tup): - arrays = [atleast_3d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return tf.concat(unwrapped_arrays, axis=2) + arrays = 
[atleast_3d(a) for a in tup] + arrays = _promote_dtype(*arrays) # pylint: disable=protected-access + unwrapped_arrays = [ + a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays + ] + return tf.concat(unwrapped_arrays, axis=2) def _pad_left_to(n, old_shape): - old_shape = asarray(old_shape, dtype=np.int32).data - new_shape = tf.pad( - old_shape, [[tf.math.maximum(n - tf.size(old_shape), 0), 0]], - constant_values=1) - return asarray(new_shape) + old_shape = asarray(old_shape, dtype=np.int32).data + new_shape = tf.pad( + old_shape, [[tf.math.maximum(n - tf.size(old_shape), 0), 0]], constant_values=1 + ) + return asarray(new_shape) def _atleast_nd(n, new_shape, *arys): - """Reshape arrays to be at least `n`-dimensional. - - Args: - n: The minimal rank. - new_shape: a function that takes `n` and the old shape and returns the - desired new shape. - *arys: ndarray(s) to be reshaped. - - Returns: - The reshaped array(s). - """ - - def f(x): - # pylint: disable=g-long-lambda - x = asarray(x) - return asarray( - utils.cond( - utils.greater(n, tf.rank(x)), - lambda: reshape(x, new_shape(n, tf.shape(x.data))).data, - lambda: x.data)) - - arys = list(map(f, arys)) - if len(arys) == 1: - return arys[0] - else: - return arys + """Reshape arrays to be at least `n`-dimensional. + + Args: + n: The minimal rank. + new_shape: a function that takes `n` and the old shape and returns the + desired new shape. + *arys: ndarray(s) to be reshaped. + + Returns: + The reshaped array(s). 
+ """ + + def f(x): + # pylint: disable=g-long-lambda + x = asarray(x) + return asarray( + utils.cond( + utils.greater(n, tf.rank(x)), + lambda: reshape(x, new_shape(n, tf.shape(x.data))).data, + lambda: x.data, + ) + ) + + arys = list(map(f, arys)) + if len(arys) == 1: + return arys[0] + else: + return arys @utils.np_doc(np.atleast_1d) def atleast_1d(*arys): - return _atleast_nd(1, _pad_left_to, *arys) + return _atleast_nd(1, _pad_left_to, *arys) @utils.np_doc(np.atleast_2d) def atleast_2d(*arys): - return _atleast_nd(2, _pad_left_to, *arys) + return _atleast_nd(2, _pad_left_to, *arys) @utils.np_doc(np.atleast_3d) def atleast_3d(*arys): # pylint: disable=missing-docstring - - def new_shape(_, old_shape): - # pylint: disable=g-long-lambda - ndim_ = tf.size(old_shape) - return utils.cond( - ndim_ == 0, lambda: tf.constant([1, 1, 1], dtype=tf.int32), - lambda: utils.cond( - ndim_ == 1, lambda: tf.pad(old_shape, [[1, 1]], constant_values=1), - lambda: tf.pad(old_shape, [[0, 1]], constant_values=1))) - - return _atleast_nd(3, new_shape, *arys) + def new_shape(_, old_shape): + # pylint: disable=g-long-lambda + ndim_ = tf.size(old_shape) + return utils.cond( + ndim_ == 0, + lambda: tf.constant([1, 1, 1], dtype=tf.int32), + lambda: utils.cond( + ndim_ == 1, + lambda: tf.pad(old_shape, [[1, 1]], constant_values=1), + lambda: tf.pad(old_shape, [[0, 1]], constant_values=1), + ), + ) + + return _atleast_nd(3, new_shape, *arys) @utils.np_doc(np.nonzero) def nonzero(a): - a = atleast_1d(a).data - if a.shape.rank is None: - raise ValueError("The rank of `a` is unknown, so we can't decide how many " - "arrays to return.") - return tf.nest.map_structure( - arrays_lib.tensor_to_ndarray, - tf.unstack(tf.where(tf.cast(a, tf.bool)), a.shape.rank, axis=1)) + a = atleast_1d(a).data + if a.shape.rank is None: + raise ValueError( + "The rank of `a` is unknown, so we can't decide how many " + "arrays to return." 
+ ) + return tf.nest.map_structure( + arrays_lib.tensor_to_ndarray, + tf.unstack(tf.where(tf.cast(a, tf.bool)), a.shape.rank, axis=1), + ) @utils.np_doc(np.diag_indices) def diag_indices(n, ndim=2): # pylint: disable=missing-docstring,redefined-outer-name - if n < 0: - raise ValueError('n argument to diag_indices must be nonnegative, got {}' - .format(n)) - if ndim < 0: - raise ValueError('ndim argument to diag_indices must be nonnegative, got {}' - .format(ndim)) + if n < 0: + raise ValueError( + "n argument to diag_indices must be nonnegative, got {}".format(n) + ) + if ndim < 0: + raise ValueError( + "ndim argument to diag_indices must be nonnegative, got {}".format(ndim) + ) - return (tf.range(n),) * ndim + return (tf.range(n),) * ndim @utils.np_doc(np.tri) def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-docstring - M = M if M is not None else N - if dtype is not None: - dtype = utils.result_type(dtype) - else: - dtype = dtypes.default_float_type() - - if k < 0: - lower = -k - 1 - if lower > N: - r = tf.zeros([N, M], dtype) + M = M if M is not None else N + if dtype is not None: + dtype = utils.result_type(dtype) else: - # Keep as tf bool, since we create an upper triangular matrix and invert - # it. - o = tf.ones([N, M], dtype=tf.bool) - r = tf.cast(tf.math.logical_not(tf.linalg.band_part(o, lower, -1)), dtype) - else: - o = tf.ones([N, M], dtype) - if k > M: - r = o + dtype = dtypes.default_float_type() + + if k < 0: + lower = -k - 1 + if lower > N: + r = tf.zeros([N, M], dtype) + else: + # Keep as tf bool, since we create an upper triangular matrix and invert + # it. 
+ o = tf.ones([N, M], dtype=tf.bool) + r = tf.cast(tf.math.logical_not(tf.linalg.band_part(o, lower, -1)), dtype) else: - r = tf.linalg.band_part(o, -1, k) - return utils.tensor_to_ndarray(r) + o = tf.ones([N, M], dtype) + if k > M: + r = o + else: + r = tf.linalg.band_part(o, -1, k) + return utils.tensor_to_ndarray(r) @utils.np_doc(np.tril) def tril(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data - m_shape = m.shape.as_list() + m = asarray(m).data + m_shape = m.shape.as_list() - if len(m_shape) < 2: - raise ValueError('Argument to tril must have rank at least 2') + if len(m_shape) < 2: + raise ValueError("Argument to tril must have rank at least 2") - if m_shape[-1] is None or m_shape[-2] is None: - raise ValueError('Currently, the last two dimensions of the input array ' - 'need to be known.') + if m_shape[-1] is None or m_shape[-2] is None: + raise ValueError( + "Currently, the last two dimensions of the input array " "need to be known." + ) - z = tf.constant(0, m.dtype) + z = tf.constant(0, m.dtype) - mask = tri(*m_shape[-2:], k=k, dtype=bool) - return utils.tensor_to_ndarray( - tf.where(tf.broadcast_to(mask, tf.shape(m)), m, z)) + mask = tri(*m_shape[-2:], k=k, dtype=bool) + return utils.tensor_to_ndarray(tf.where(tf.broadcast_to(mask, tf.shape(m)), m, z)) @utils.np_doc(np.triu) def triu(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data - m_shape = m.shape.as_list() + m = asarray(m).data + m_shape = m.shape.as_list() - if len(m_shape) < 2: - raise ValueError('Argument to triu must have rank at least 2') + if len(m_shape) < 2: + raise ValueError("Argument to triu must have rank at least 2") - if m_shape[-1] is None or m_shape[-2] is None: - raise ValueError('Currently, the last two dimensions of the input array ' - 'need to be known.') + if m_shape[-1] is None or m_shape[-2] is None: + raise ValueError( + "Currently, the last two dimensions of the input array " "need to be known." 
+ ) - z = tf.constant(0, m.dtype) + z = tf.constant(0, m.dtype) - mask = tri(*m_shape[-2:], k=k - 1, dtype=bool) - return utils.tensor_to_ndarray( - tf.where(tf.broadcast_to(mask, tf.shape(m)), z, m)) + mask = tri(*m_shape[-2:], k=k - 1, dtype=bool) + return utils.tensor_to_ndarray(tf.where(tf.broadcast_to(mask, tf.shape(m)), z, m)) @utils.np_doc(np.flip) def flip(m, axis=None): # pylint: disable=missing-docstring - m = asarray(m).data + m = asarray(m).data - if axis is None: - return utils.tensor_to_ndarray(tf.reverse(m, tf.range(tf.rank(m)))) + if axis is None: + return utils.tensor_to_ndarray(tf.reverse(m, tf.range(tf.rank(m)))) - axis = utils._canonicalize_axis(axis, tf.rank(m)) # pylint: disable=protected-access + axis = utils._canonicalize_axis( + axis, tf.rank(m) + ) # pylint: disable=protected-access - return utils.tensor_to_ndarray(tf.reverse(m, [axis])) + return utils.tensor_to_ndarray(tf.reverse(m, [axis])) @utils.np_doc(np.flipud) def flipud(m): # pylint: disable=missing-docstring - return flip(m, 0) + return flip(m, 0) @utils.np_doc(np.fliplr) def fliplr(m): # pylint: disable=missing-docstring - return flip(m, 1) + return flip(m, 1) @utils.np_doc(np.roll) def roll(a, shift, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data + a = asarray(a).data - if axis is not None: - return utils.tensor_to_ndarray(tf.roll(a, shift, axis)) + if axis is not None: + return utils.tensor_to_ndarray(tf.roll(a, shift, axis)) - # If axis is None, the roll happens as a 1-d tensor. - original_shape = tf.shape(a) - a = tf.roll(tf.reshape(a, [-1]), shift, 0) - return utils.tensor_to_ndarray(tf.reshape(a, original_shape)) + # If axis is None, the roll happens as a 1-d tensor. 
+ original_shape = tf.shape(a) + a = tf.roll(tf.reshape(a, [-1]), shift, 0) + return utils.tensor_to_ndarray(tf.reshape(a, original_shape)) @utils.np_doc(np.rot90) def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring - m_rank = tf.rank(m) - ax1, ax2 = utils._canonicalize_axes(axes, m_rank) # pylint: disable=protected-access - - k = k % 4 - if k == 0: - return m - elif k == 2: - return flip(flip(m, ax1), ax2) - else: - perm = tf.range(m_rank) - perm = tf.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1]) - - if k == 1: - return transpose(flip(m, ax2), perm) + m_rank = tf.rank(m) + ax1, ax2 = utils._canonicalize_axes( + axes, m_rank + ) # pylint: disable=protected-access + + k = k % 4 + if k == 0: + return m + elif k == 2: + return flip(flip(m, ax1), ax2) else: - return flip(transpose(m, perm), ax2) + perm = tf.range(m_rank) + perm = tf.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1]) + + if k == 1: + return transpose(flip(m, ax2), perm) + else: + return flip(transpose(m, perm), ax2) @utils.np_doc(np.vander) -def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name - x = asarray(x).data - - x_shape = tf.shape(x) - N = N or x_shape[0] - - N_temp = utils.get_static_value(N) # pylint: disable=invalid-name - if N_temp is not None: - N = N_temp - if N < 0: - raise ValueError('N must be nonnegative') - else: - tf.debugging.Assert(N >= 0, [N]) - - rank = tf.rank(x) - rank_temp = utils.get_static_value(rank) - if rank_temp is not None: - rank = rank_temp - if rank != 1: - raise ValueError('x must be a one-dimensional array') - else: - tf.debugging.Assert(rank == 1, [rank]) - - if increasing: - start = 0 - limit = N - delta = 1 - else: - start = N - 1 - limit = -1 - delta = -1 - - x = tf.expand_dims(x, -1) - return utils.tensor_to_ndarray( - tf.math.pow(x, tf.cast(tf.range(start, limit, delta), dtype=x.dtype))) +def vander( + x, N=None, increasing=False +): # pylint: disable=missing-docstring,invalid-name + x = 
asarray(x).data + + x_shape = tf.shape(x) + N = N or x_shape[0] + + N_temp = utils.get_static_value(N) # pylint: disable=invalid-name + if N_temp is not None: + N = N_temp + if N < 0: + raise ValueError("N must be nonnegative") + else: + tf.debugging.Assert(N >= 0, [N]) + + rank = tf.rank(x) + rank_temp = utils.get_static_value(rank) + if rank_temp is not None: + rank = rank_temp + if rank != 1: + raise ValueError("x must be a one-dimensional array") + else: + tf.debugging.Assert(rank == 1, [rank]) + + if increasing: + start = 0 + limit = N + delta = 1 + else: + start = N - 1 + limit = -1 + delta = -1 + + x = tf.expand_dims(x, -1) + return utils.tensor_to_ndarray( + tf.math.pow(x, tf.cast(tf.range(start, limit, delta), dtype=x.dtype)) + ) @utils.np_doc(np.ix_) def ix_(*args): # pylint: disable=missing-docstring - n = len(args) - output = [] - for i, a in enumerate(args): - a = asarray(a).data - a_rank = tf.rank(a) - a_rank_temp = utils.get_static_value(a_rank) - if a_rank_temp is not None: - a_rank = a_rank_temp - if a_rank != 1: - raise ValueError( - 'Arguments must be 1-d, got arg {} of rank {}'.format(i, a_rank)) - else: - tf.debugging.Assert(a_rank == 1, [a_rank]) - - new_shape = [1] * n - new_shape[i] = -1 - dtype = a.dtype - if dtype == tf.bool: - output.append( - utils.tensor_to_ndarray(tf.reshape(nonzero(a)[0].data, new_shape))) - elif dtype.is_integer: - output.append(utils.tensor_to_ndarray(tf.reshape(a, new_shape))) - else: - raise ValueError( - 'Only integer and bool dtypes are supported, got {}'.format(dtype)) + n = len(args) + output = [] + for i, a in enumerate(args): + a = asarray(a).data + a_rank = tf.rank(a) + a_rank_temp = utils.get_static_value(a_rank) + if a_rank_temp is not None: + a_rank = a_rank_temp + if a_rank != 1: + raise ValueError( + "Arguments must be 1-d, got arg {} of rank {}".format(i, a_rank) + ) + else: + tf.debugging.Assert(a_rank == 1, [a_rank]) + + new_shape = [1] * n + new_shape[i] = -1 + dtype = a.dtype + if dtype == 
tf.bool: + output.append( + utils.tensor_to_ndarray(tf.reshape(nonzero(a)[0].data, new_shape)) + ) + elif dtype.is_integer: + output.append(utils.tensor_to_ndarray(tf.reshape(a, new_shape))) + else: + raise ValueError( + "Only integer and bool dtypes are supported, got {}".format(dtype) + ) - return output + return output diff --git a/trax/tf_numpy/numpy_impl/arrays.py b/trax/tf_numpy/numpy_impl/arrays.py index 0329c25d0..d2d578736 100644 --- a/trax/tf_numpy/numpy_impl/arrays.py +++ b/trax/tf_numpy/numpy_impl/arrays.py @@ -23,217 +23,226 @@ def convert_to_tensor(value, dtype=None): - # A safer version of `tf.convert_to_tensor` to work around b/149876037. - # TODO(wangpeng): Remove this function once the bug is fixed. - if (dtype is None and isinstance(value, six.integer_types) - and value >= 2 ** 63): - dtype = tf.uint64 - elif (dtype is None and isinstance(value, float)): - dtype = dtypes.default_float_type() - return tf.convert_to_tensor(value, dtype=dtype) + # A safer version of `tf.convert_to_tensor` to work around b/149876037. + # TODO(wangpeng): Remove this function once the bug is fixed. + if dtype is None and isinstance(value, six.integer_types) and value >= 2**63: + dtype = tf.uint64 + elif dtype is None and isinstance(value, float): + dtype = dtypes.default_float_type() + return tf.convert_to_tensor(value, dtype=dtype) class ndarray(object): # pylint: disable=invalid-name - """Equivalent of numpy.ndarray backed by TensorFlow tensors. - - This does not support all features of NumPy ndarrays e.g. strides and - memory order since, unlike NumPy, the backing storage is not a raw memory - buffer. - - TODO(srbs): Clearly specify which attributes and methods are not supported - or if there are any differences in behavior. - """ - - def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin - """Initializes an ndarray. - - This is a low level interface for building ndarrays and should be avoided. 
- Users should instead use methods in array_creation.py. - - This class provides a numpy.ndarray like interface for a TF Tensor with a - fully-defined shape. Note that, unlike the backing buffer of np.ndarray, - Tensors are immutable. So, operations like `__setitem__` are performed by - replacing the Tensor. This restricts the ability to implement NumPy `view` - semantics. - - Compared to numpy.ndarray, this does not support `offset`, `strides` - and `order` arguments. - - Args: - shape: The shape of the array. Must be a scalar, an iterable of integers - or a `TensorShape` object. - dtype: Optional. The dtype of the array. Must be a python type, a numpy - type or a tensorflow `DType` object. - buffer: Optional. The backing buffer of the array. Must have shape - `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`. - - Raises: - ValueError: If `buffer` is specified and its shape does not match - `shape`. - """ - if dtype and not isinstance(dtype, tf.DType): - dtype = tf.as_dtype(np.dtype(dtype)) - if buffer is None: - buffer = tf.zeros(shape, dtype=dtype) - else: - if isinstance(buffer, ndarray): - buffer = buffer.data - elif isinstance(buffer, np.ndarray): - # If `buffer` is a np.ndarray, the Tensor will share the underlying - # storage of the array. - buffer = convert_to_tensor(value=buffer, dtype=dtype) - elif not isinstance(buffer, tf.Tensor): - raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,' - ' Tensor or np.ndarray.'.format(type(buffer))) - - if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access - # TODO(srbs): NumPy allows this. Investigate if/how to support this. - raise ValueError('shape arg must match buffer.shape.') - - assert isinstance(buffer, tf.Tensor) - if dtype and dtype != buffer.dtype: - buffer = tf.bitcast(buffer, dtype) - self._data = buffer - self.base = None - - @property - def data(self): - """Tensor object containing the array data. 
- - This has a few key differences from the Python buffer object used in - NumPy arrays. - 1. Tensors are immutable. So operations requiring in-place edit, e.g. - __setitem__, are performed by replacing the underlying buffer with a new - one. - 2. Tensors do not provide access to their raw buffer. - - Returns: - A Tensor. - """ - return self._data - - @property - def shape(self): - """Returns a tuple of array dimensions.""" - return self.data._shape_tuple() # pylint: disable=protected-access - - @property - def dtype(self): - return np.dtype(self.data.dtype.as_numpy_dtype) - - @property - def ndim(self): - return self.data.shape.ndims - - @property - def size(self): - """Returns the number of elements in the array.""" - return np.prod(self.shape) - - @property - def T(self): # pylint: disable=invalid-name - return self.transpose() - - def __len__(self): - if self.shape: - return self.shape[0] - else: - raise TypeError('len() of unsized object.') - - def astype(self, dtype): - if self.dtype == dtype: - return self - else: - return tensor_to_ndarray(tf.cast(self.data, dtype)) - - # Unary operations - def __neg__(self): - return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type - - def __pos__(self): - return self + """Equivalent of numpy.ndarray backed by TensorFlow tensors. - __hash__ = None + This does not support all features of NumPy ndarrays e.g. strides and + memory order since, unlike NumPy, the backing storage is not a raw memory + buffer. - def __int__(self): - return int(self.data) - - def __float__(self): - return float(self.data) - - def __nonzero__(self): - return bool(self.data) - - def __bool__(self): - return self.__nonzero__() + TODO(srbs): Clearly specify which attributes and methods are not supported + or if there are any differences in behavior. + """ - def __getitem__(self, slice_spec): - # TODO(srbs): Need to support better indexing. 
- result_t = self.data.__getitem__(slice_spec) - return tensor_to_ndarray(result_t) + def __init__( + self, shape, dtype=float, buffer=None + ): # pylint: disable=redefined-builtin + """Initializes an ndarray. + + This is a low level interface for building ndarrays and should be avoided. + Users should instead use methods in array_creation.py. + + This class provides a numpy.ndarray like interface for a TF Tensor with a + fully-defined shape. Note that, unlike the backing buffer of np.ndarray, + Tensors are immutable. So, operations like `__setitem__` are performed by + replacing the Tensor. This restricts the ability to implement NumPy `view` + semantics. + + Compared to numpy.ndarray, this does not support `offset`, `strides` + and `order` arguments. + + Args: + shape: The shape of the array. Must be a scalar, an iterable of integers + or a `TensorShape` object. + dtype: Optional. The dtype of the array. Must be a python type, a numpy + type or a tensorflow `DType` object. + buffer: Optional. The backing buffer of the array. Must have shape + `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`. + + Raises: + ValueError: If `buffer` is specified and its shape does not match + `shape`. + """ + if dtype and not isinstance(dtype, tf.DType): + dtype = tf.as_dtype(np.dtype(dtype)) + if buffer is None: + buffer = tf.zeros(shape, dtype=dtype) + else: + if isinstance(buffer, ndarray): + buffer = buffer.data + elif isinstance(buffer, np.ndarray): + # If `buffer` is a np.ndarray, the Tensor will share the underlying + # storage of the array. + buffer = convert_to_tensor(value=buffer, dtype=dtype) + elif not isinstance(buffer, tf.Tensor): + raise ValueError( + "Unexpected type for `buffer` {}. Must be an ndarray," + " Tensor or np.ndarray.".format(type(buffer)) + ) + + if ( + shape is not None and tuple(shape) != buffer._shape_tuple() + ): # pylint: disable=protected-access + # TODO(srbs): NumPy allows this. Investigate if/how to support this. 
+ raise ValueError("shape arg must match buffer.shape.") + + assert isinstance(buffer, tf.Tensor) + if dtype and dtype != buffer.dtype: + buffer = tf.bitcast(buffer, dtype) + self._data = buffer + self.base = None + + @property + def data(self): + """Tensor object containing the array data. + + This has a few key differences from the Python buffer object used in + NumPy arrays. + 1. Tensors are immutable. So operations requiring in-place edit, e.g. + __setitem__, are performed by replacing the underlying buffer with a new + one. + 2. Tensors do not provide access to their raw buffer. + + Returns: + A Tensor. + """ + return self._data + + @property + def shape(self): + """Returns a tuple of array dimensions.""" + return self.data._shape_tuple() # pylint: disable=protected-access + + @property + def dtype(self): + return np.dtype(self.data.dtype.as_numpy_dtype) + + @property + def ndim(self): + return self.data.shape.ndims + + @property + def size(self): + """Returns the number of elements in the array.""" + return np.prod(self.shape) + + @property + def T(self): # pylint: disable=invalid-name + return self.transpose() + + def __len__(self): + if self.shape: + return self.shape[0] + else: + raise TypeError("len() of unsized object.") + + def astype(self, dtype): + if self.dtype == dtype: + return self + else: + return tensor_to_ndarray(tf.cast(self.data, dtype)) + + # Unary operations + def __neg__(self): + return tensor_to_ndarray( + -self.data + ) # pylint: disable=invalid-unary-operand-type + + def __pos__(self): + return self + + __hash__ = None + + def __int__(self): + return int(self.data) + + def __float__(self): + return float(self.data) + + def __nonzero__(self): + return bool(self.data) + + def __bool__(self): + return self.__nonzero__() + + def __getitem__(self, slice_spec): + # TODO(srbs): Need to support better indexing. 
+ result_t = self.data.__getitem__(slice_spec) + return tensor_to_ndarray(result_t) + + def __iter__(self): + for i in range(self.shape[0]): + result_t = self.data[i] + yield tensor_to_ndarray(result_t) + return + + def __array__(self, dtype=None): + """Returns a NumPy ndarray. - def __iter__(self): - for i in range(self.shape[0]): - result_t = self.data[i] - yield tensor_to_ndarray(result_t) - return + This allows instances of this class to be directly used in NumPy routines. + However, doing that may force a copy to CPU. - def __array__(self, dtype=None): - """Returns a NumPy ndarray. + Args: + dtype: A NumPy compatible type. - This allows instances of this class to be directly used in NumPy routines. - However, doing that may force a copy to CPU. + Returns: + A NumPy ndarray. + """ + return np.asarray(self.data, dtype) - Args: - dtype: A NumPy compatible type. + __array_priority__ = 110 - Returns: - A NumPy ndarray. - """ - return np.asarray(self.data, dtype) + def __index__(self): + """Returns a python scalar. - __array_priority__ = 110 + This allows using an instance of this class as an array index. + Note that only arrays of integer types with size 1 can be used as array + indices. - def __index__(self): - """Returns a python scalar. + Returns: + A Python scalar. - This allows using an instance of this class as an array index. - Note that only arrays of integer types with size 1 can be used as array - indices. + Raises: + TypeError: If the array is not of an integer type. + ValueError: If the array does not have size 1. + """ + # TODO(wangpeng): Handle graph mode + return self.data.numpy().item() - Returns: - A Python scalar. + def tolist(self): + return self.data.numpy().tolist() - Raises: - TypeError: If the array is not of an integer type. - ValueError: If the array does not have size 1. 
- """ - # TODO(wangpeng): Handle graph mode - return self.data.numpy().item() + def __str__(self): + return "ndarray<{}>".format(self.data.__str__()) - def tolist(self): - return self.data.numpy().tolist() - - def __str__(self): - return 'ndarray<{}>'.format(self.data.__str__()) - - def __repr__(self): - return 'ndarray<{}>'.format(self.data.__repr__()) + def __repr__(self): + return "ndarray<{}>".format(self.data.__repr__()) def tensor_to_ndarray(tensor): - return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access + return ndarray( + tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor + ) # pylint: disable=protected-access def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False): - if as_ref: - raise ValueError('as_ref is not supported.') - if dtype and tf.as_dtype(arr.dtype) != dtype: - return tf.cast(arr.data, dtype) - result_t = arr.data - if name: - result_t = tf.identity(result_t, name=name) - return result_t + if as_ref: + raise ValueError("as_ref is not supported.") + if dtype and tf.as_dtype(arr.dtype) != dtype: + return tf.cast(arr.data, dtype) + result_t = arr.data + if name: + result_t = tf.identity(result_t, name=name) + return result_t tf.register_tensor_conversion_function(ndarray, ndarray_to_tensor) @@ -242,45 +251,50 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False): # Don't use a namedtuple since nest considers that a tuple and unflattens and # flattens it. class ShardedNdArray(object): - """Wrapper over ndarray that can contain tensors on multiple devices. + """Wrapper over ndarray that can contain tensors on multiple devices. This is returned by extensions.pmap, and contains the individual tensors on different devices. - """ + """ - def __init__(self, tensors): - """Initializes the ShardedNdArray. + def __init__(self, tensors): + """Initializes the ShardedNdArray. - Note that the tensors should be ordered in the way the pmap producing these - tensors is run. 
+ Note that the tensors should be ordered in the way the pmap producing these + tensors is run. - Args: - tensors: list or tuple of eager tensors, one for each device. - """ + Args: + tensors: list or tuple of eager tensors, one for each device. + """ - if not isinstance(tensors, (list, tuple)) or not tensors: - raise ValueError( - 'Unable to create a ShardedNdArray without a list of tensors.') - self.tensors = tensors - self.n_devices = len(tensors) + if not isinstance(tensors, (list, tuple)) or not tensors: + raise ValueError( + "Unable to create a ShardedNdArray without a list of tensors." + ) + self.tensors = tensors + self.n_devices = len(tensors) - def __getitem__(self, i): - return self.tensors[i] + def __getitem__(self, i): + return self.tensors[i] - @property - def shape(self): - return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access + @property + def shape(self): + return (self.n_devices,) + self.tensors[ + 0 + ]._shape_tuple() # pylint: disable=protected-access - @property - def dtype(self): - return x.tensors[0].dtype + @property + def dtype(self): + return x.tensors[0].dtype def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs): - del args, kwargs - # TODO(nareshmodi): Consider a collective op to gather the tensors from the - # various devices for performance reasons. - return tf.stack(value.tensors) + del args, kwargs + # TODO(nareshmodi): Consider a collective op to gather the tensors from the + # various devices for performance reasons. + return tf.stack(value.tensors) + tf.register_tensor_conversion_function( - ShardedNdArray, convert_sharded_tensor_to_eager_tensor) + ShardedNdArray, convert_sharded_tensor_to_eager_tensor +) diff --git a/trax/tf_numpy/numpy_impl/dtypes.py b/trax/tf_numpy/numpy_impl/dtypes.py index 6ba976313..811b653dd 100644 --- a/trax/tf_numpy/numpy_impl/dtypes.py +++ b/trax/tf_numpy/numpy_impl/dtypes.py @@ -21,40 +21,18 @@ # `if x.dtype.type is np.int64`. 
# pylint: disable=unused-import # pylint: disable=g-bad-import-order -from numpy import bool_ -from numpy import int_ -from numpy import int16 -from numpy import int32 -from numpy import int64 -from numpy import int8 -from numpy import uint16 -from numpy import uint32 -from numpy import uint64 -from numpy import uint8 -from numpy import float_ -from numpy import float16 from numpy import float32 from numpy import float64 -from numpy import complex_ -from numpy import complex64 -from numpy import complex128 - -from numpy import inexact - -from numpy import iinfo -from numpy import issubdtype - -from numpy import inf # TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's -from tensorflow.compat.v2 import bfloat16 + # pylint: enable=g-bad-import-order # pylint: enable=unused-import _to_float32 = { - np.dtype('float64'): np.dtype('float32'), - np.dtype('complex128'): np.dtype('complex64'), + np.dtype("float64"): np.dtype("float32"), + np.dtype("complex128"): np.dtype("complex64"), } @@ -62,33 +40,33 @@ def is_allow_float64(): - return _allow_float64 + return _allow_float64 def set_allow_float64(b): - global _allow_float64 - _allow_float64 = b + global _allow_float64 + _allow_float64 = b def canonicalize_dtype(dtype): - if not is_allow_float64(): - return _to_float32.get(dtype, dtype) - else: - return dtype + if not is_allow_float64(): + return _to_float32.get(dtype, dtype) + else: + return dtype def _result_type(*arrays_and_dtypes): - dtype = np.result_type(*arrays_and_dtypes) - return canonicalize_dtype(dtype) + dtype = np.result_type(*arrays_and_dtypes) + return canonicalize_dtype(dtype) def default_float_type(): - """Gets the default float type. - - Returns: - If `is_allow_float64()` is true, returns float64; otherwise returns float32. - """ - if is_allow_float64(): - return float64 - else: - return float32 + """Gets the default float type. + + Returns: + If `is_allow_float64()` is true, returns float64; otherwise returns float32. 
+ """ + if is_allow_float64(): + return float64 + else: + return float32 diff --git a/trax/tf_numpy/numpy_impl/math_ops.py b/trax/tf_numpy/numpy_impl/math_ops.py index aac45cc2c..05aadbeda 100644 --- a/trax/tf_numpy/numpy_impl/math_ops.py +++ b/trax/tf_numpy/numpy_impl/math_ops.py @@ -29,72 +29,78 @@ @utils.np_doc_only(np.dot) def dot(a, b): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - return utils.cond( - utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), - lambda: a * b, - lambda: utils.cond( # pylint: disable=g-long-lambda - tf.rank(b) == 1, - lambda: tf.tensordot(a, b, axes=[[-1], [-1]]), - lambda: tf.tensordot(a, b, axes=[[-1], [-2]]))) - return _bin_op(f, a, b) + def f(a, b): # pylint: disable=missing-docstring + return utils.cond( + utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), + lambda: a * b, + lambda: utils.cond( # pylint: disable=g-long-lambda + tf.rank(b) == 1, + lambda: tf.tensordot(a, b, axes=[[-1], [-1]]), + lambda: tf.tensordot(a, b, axes=[[-1], [-2]]), + ), + ) + + return _bin_op(f, a, b) # TODO(wangpeng): Make element-wise ops `ufunc`s def _bin_op(tf_fun, a, b, promote=True): - if promote: - a, b = array_ops._promote_dtype(a, b) # pylint: disable=protected-access - else: - a = array_ops.array(a) - b = array_ops.array(b) - return utils.tensor_to_ndarray(tf_fun(a.data, b.data)) + if promote: + a, b = array_ops._promote_dtype(a, b) # pylint: disable=protected-access + else: + a = array_ops.array(a) + b = array_ops.array(b) + return utils.tensor_to_ndarray(tf_fun(a.data, b.data)) @utils.np_doc(np.add) def add(x1, x2): - def add_or_or(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_or(x1, x2) - return tf.add(x1, x2) - return _bin_op(add_or_or, x1, x2) + def add_or_or(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + return tf.logical_or(x1, x2) + return tf.add(x1, x2) + + return _bin_op(add_or_or, x1, x2) @utils.np_doc(np.subtract) def subtract(x1, 
x2): - return _bin_op(tf.subtract, x1, x2) + return _bin_op(tf.subtract, x1, x2) @utils.np_doc(np.multiply) def multiply(x1, x2): - def mul_or_and(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_and(x1, x2) - return tf.multiply(x1, x2) - return _bin_op(mul_or_and, x1, x2) + def mul_or_and(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + return tf.logical_and(x1, x2) + return tf.multiply(x1, x2) + + return _bin_op(mul_or_and, x1, x2) @utils.np_doc(np.true_divide) def true_divide(x1, x2): - def _avoid_float64(x1, x2): - if x1.dtype == x2.dtype and x1.dtype in (tf.int32, tf.int64): - x1 = tf.cast(x1, dtype=tf.float32) - x2 = tf.cast(x2, dtype=tf.float32) - return x1, x2 - - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - float_ = dtypes.default_float_type() - x1 = tf.cast(x1, float_) - x2 = tf.cast(x2, float_) - if not dtypes.is_allow_float64(): - # tf.math.truediv in Python3 produces float64 when both inputs are int32 - # or int64. We want to avoid that when is_allow_float64() is False. - x1, x2 = _avoid_float64(x1, x2) - return tf.math.truediv(x1, x2) - return _bin_op(f, x1, x2) + def _avoid_float64(x1, x2): + if x1.dtype == x2.dtype and x1.dtype in (tf.int32, tf.int64): + x1 = tf.cast(x1, dtype=tf.float32) + x2 = tf.cast(x2, dtype=tf.float32) + return x1, x2 + + def f(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + float_ = dtypes.default_float_type() + x1 = tf.cast(x1, float_) + x2 = tf.cast(x2, float_) + if not dtypes.is_allow_float64(): + # tf.math.truediv in Python3 produces float64 when both inputs are int32 + # or int64. We want to avoid that when is_allow_float64() is False. 
+ x1, x2 = _avoid_float64(x1, x2) + return tf.math.truediv(x1, x2) + + return _bin_op(f, x1, x2) divide = true_divide @@ -102,24 +108,26 @@ def f(x1, x2): @utils.np_doc(np.floor_divide) def floor_divide(x1, x2): - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - return tf.math.floordiv(x1, x2) - return _bin_op(f, x1, x2) + def f(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + x1 = tf.cast(x1, tf.int8) + x2 = tf.cast(x2, tf.int8) + return tf.math.floordiv(x1, x2) + + return _bin_op(f, x1, x2) @utils.np_doc(np.mod) def mod(x1, x2): - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - return tf.math.mod(x1, x2) - return _bin_op(f, x1, x2) + def f(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + x1 = tf.cast(x1, tf.int8) + x2 = tf.cast(x2, tf.int8) + return tf.math.mod(x1, x2) + + return _bin_op(f, x1, x2) remainder = mod @@ -127,477 +135,533 @@ def f(x1, x2): @utils.np_doc(np.divmod) def divmod(x1, x2): - return floor_divide(x1, x2), mod(x1, x2) + return floor_divide(x1, x2), mod(x1, x2) @utils.np_doc(np.maximum) def maximum(x1, x2): - def max_or_or(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_or(x1, x2) - return tf.math.maximum(x1, x2) - return _bin_op(max_or_or, x1, x2) + def max_or_or(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + return tf.logical_or(x1, x2) + return tf.math.maximum(x1, x2) + + return _bin_op(max_or_or, x1, x2) @utils.np_doc(np.minimum) def minimum(x1, x2): - def min_or_and(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_and(x1, x2) - return tf.math.minimum(x1, x2) - return _bin_op(min_or_and, x1, x2) + def min_or_and(x1, x2): + if x1.dtype == tf.bool: + assert x2.dtype == tf.bool + return tf.logical_and(x1, x2) + return tf.math.minimum(x1, x2) + + return 
_bin_op(min_or_and, x1, x2) @utils.np_doc(np.clip) def clip(a, a_min, a_max): # pylint: disable=missing-docstring - if a_min is None and a_max is None: - raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.') - if a_min is None: - return minimum(a, a_max) - elif a_max is None: - return maximum(a, a_min) - else: - a, a_min, a_max = array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access - return utils.tensor_to_ndarray( - tf.clip_by_value(*utils.tf_broadcast(a.data, a_min.data, a_max.data))) + if a_min is None and a_max is None: + raise ValueError("Not more than one of `a_min` and `a_max` may be `None`.") + if a_min is None: + return minimum(a, a_max) + elif a_max is None: + return maximum(a, a_min) + else: + a, a_min, a_max = array_ops._promote_dtype( + a, a_min, a_max + ) # pylint: disable=protected-access + return utils.tensor_to_ndarray( + tf.clip_by_value(*utils.tf_broadcast(a.data, a_min.data, a_max.data)) + ) @utils.np_doc(np.matmul) def matmul(x1, x2): # pylint: disable=missing-docstring - def f(x1, x2): - try: - return utils.cond(tf.rank(x2) == 1, - lambda: tf.tensordot(x1, x2, axes=1), - lambda: utils.cond(tf.rank(x1) == 1, # pylint: disable=g-long-lambda - lambda: tf.tensordot( # pylint: disable=g-long-lambda - x1, x2, axes=[[0], [-2]]), - lambda: tf.matmul(x1, x2))) - except tf.errors.InvalidArgumentError as err: - six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2]) - return _bin_op(f, x1, x2) + def f(x1, x2): + try: + return utils.cond( + tf.rank(x2) == 1, + lambda: tf.tensordot(x1, x2, axes=1), + lambda: utils.cond( + tf.rank(x1) == 1, # pylint: disable=g-long-lambda + lambda: tf.tensordot( # pylint: disable=g-long-lambda + x1, x2, axes=[[0], [-2]] + ), + lambda: tf.matmul(x1, x2), + ), + ) + except tf.errors.InvalidArgumentError as err: + six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2]) + + return _bin_op(f, x1, x2) @utils.np_doc(np.tensordot) def tensordot(a, b, axes=2): - return 
_bin_op(lambda a, b: tf.tensordot(a, b, axes=axes), a, b) + return _bin_op(lambda a, b: tf.tensordot(a, b, axes=axes), a, b) @utils.np_doc_only(np.inner) def inner(a, b): - def f(a, b): - return utils.cond(utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), - lambda: a * b, - lambda: tf.tensordot(a, b, axes=[[-1], [-1]])) - return _bin_op(f, a, b) + def f(a, b): + return utils.cond( + utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), + lambda: a * b, + lambda: tf.tensordot(a, b, axes=[[-1], [-1]]), + ) + + return _bin_op(f, a, b) @utils.np_doc(np.cross) -def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - # We can't assign to captured variable `axisa`, so make a new variable - axis_a = axisa - axis_b = axisb - axis_c = axisc - if axis is not None: - axis_a = axis - axis_b = axis - axis_c = axis - if axis_a < 0: - axis_a = utils.add(axis_a, tf.rank(a)) - if axis_b < 0: - axis_b = utils.add(axis_b, tf.rank(b)) - def maybe_move_axis_to_last(a, axis): - def move_axis_to_last(a, axis): - return tf.transpose( - a, tf.concat( - [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]], - axis=0)) - return utils.cond( - axis == utils.subtract(tf.rank(a), 1), - lambda: a, - lambda: move_axis_to_last(a, axis)) - a = maybe_move_axis_to_last(a, axis_a) - b = maybe_move_axis_to_last(b, axis_b) - a_dim = utils.getitem(tf.shape(a), -1) - b_dim = utils.getitem(tf.shape(b), -1) - def maybe_pad_0(a, size_of_last_dim): - def pad_0(a): - return tf.pad(a, tf.concat([tf.zeros([tf.rank(a) - 1, 2], tf.int32), - tf.constant([[0, 1]], tf.int32)], axis=0)) - return utils.cond(size_of_last_dim == 2, - lambda: pad_0(a), - lambda: a) - a = maybe_pad_0(a, a_dim) - b = maybe_pad_0(b, b_dim) - c = tf.linalg.cross(*utils.tf_broadcast(a, b)) - if axis_c < 0: - axis_c = utils.add(axis_c, tf.rank(c)) - def move_last_to_axis(a, axis): - r = tf.rank(a) - return tf.transpose( - a, tf.concat( - [tf.range(axis), 
[r - 1], tf.range(axis, r - 1)], axis=0)) - c = utils.cond( - (a_dim == 2) & (b_dim == 2), - lambda: c[..., 2], - lambda: utils.cond( # pylint: disable=g-long-lambda - axis_c == utils.subtract(tf.rank(c), 1), - lambda: c, - lambda: move_last_to_axis(c, axis_c))) - return c - return _bin_op(f, a, b) +def cross( + a, b, axisa=-1, axisb=-1, axisc=-1, axis=None +): # pylint: disable=missing-docstring + def f(a, b): # pylint: disable=missing-docstring + # We can't assign to captured variable `axisa`, so make a new variable + axis_a = axisa + axis_b = axisb + axis_c = axisc + if axis is not None: + axis_a = axis + axis_b = axis + axis_c = axis + if axis_a < 0: + axis_a = utils.add(axis_a, tf.rank(a)) + if axis_b < 0: + axis_b = utils.add(axis_b, tf.rank(b)) + + def maybe_move_axis_to_last(a, axis): + def move_axis_to_last(a, axis): + return tf.transpose( + a, + tf.concat( + [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]], axis=0 + ), + ) + + return utils.cond( + axis == utils.subtract(tf.rank(a), 1), + lambda: a, + lambda: move_axis_to_last(a, axis), + ) + + a = maybe_move_axis_to_last(a, axis_a) + b = maybe_move_axis_to_last(b, axis_b) + a_dim = utils.getitem(tf.shape(a), -1) + b_dim = utils.getitem(tf.shape(b), -1) + + def maybe_pad_0(a, size_of_last_dim): + def pad_0(a): + return tf.pad( + a, + tf.concat( + [ + tf.zeros([tf.rank(a) - 1, 2], tf.int32), + tf.constant([[0, 1]], tf.int32), + ], + axis=0, + ), + ) + + return utils.cond(size_of_last_dim == 2, lambda: pad_0(a), lambda: a) + + a = maybe_pad_0(a, a_dim) + b = maybe_pad_0(b, b_dim) + c = tf.linalg.cross(*utils.tf_broadcast(a, b)) + if axis_c < 0: + axis_c = utils.add(axis_c, tf.rank(c)) + + def move_last_to_axis(a, axis): + r = tf.rank(a) + return tf.transpose( + a, tf.concat([tf.range(axis), [r - 1], tf.range(axis, r - 1)], axis=0) + ) + + c = utils.cond( + (a_dim == 2) & (b_dim == 2), + lambda: c[..., 2], + lambda: utils.cond( # pylint: disable=g-long-lambda + axis_c == utils.subtract(tf.rank(c), 1), 
+ lambda: c, + lambda: move_last_to_axis(c, axis_c), + ), + ) + return c + + return _bin_op(f, a, b) @utils.np_doc(np.power) def power(x1, x2): - return _bin_op(tf.math.pow, x1, x2) + return _bin_op(tf.math.pow, x1, x2) @utils.np_doc(np.float_power) def float_power(x1, x2): - return power(x1, x2) + return power(x1, x2) @utils.np_doc(np.arctan2) def arctan2(x1, x2): - return _bin_op(tf.math.atan2, x1, x2) + return _bin_op(tf.math.atan2, x1, x2) @utils.np_doc(np.nextafter) def nextafter(x1, x2): - return _bin_op(tf.math.nextafter, x1, x2) + return _bin_op(tf.math.nextafter, x1, x2) @utils.np_doc(np.heaviside) def heaviside(x1, x2): - def f(x1, x2): - return tf.where(x1 < 0, tf.constant(0, dtype=x2.dtype), - tf.where(x1 > 0, tf.constant(1, dtype=x2.dtype), x2)) - y = _bin_op(f, x1, x2) - if not np.issubdtype(y.dtype, np.inexact): - y = y.astype(dtypes.default_float_type()) - return y + def f(x1, x2): + return tf.where( + x1 < 0, + tf.constant(0, dtype=x2.dtype), + tf.where(x1 > 0, tf.constant(1, dtype=x2.dtype), x2), + ) + + y = _bin_op(f, x1, x2) + if not np.issubdtype(y.dtype, np.inexact): + y = y.astype(dtypes.default_float_type()) + return y @utils.np_doc(np.hypot) def hypot(x1, x2): - return sqrt(square(x1) + square(x2)) + return sqrt(square(x1) + square(x2)) @utils.np_doc(np.kron) def kron(a, b): - # pylint: disable=protected-access,g-complex-comprehension - a, b = array_ops._promote_dtype(a, b) - ndim = max(a.ndim, b.ndim) - if a.ndim < ndim: - a = array_ops.reshape(a, array_ops._pad_left_to(ndim, a.shape)) - if b.ndim < ndim: - b = array_ops.reshape(b, array_ops._pad_left_to(ndim, b.shape)) - a_reshaped = array_ops.reshape(a, [i for d in a.shape for i in (d, 1)]) - b_reshaped = array_ops.reshape(b, [i for d in b.shape for i in (1, d)]) - out_shape = tuple(np.multiply(a.shape, b.shape)) - return array_ops.reshape(a_reshaped * b_reshaped, out_shape) + # pylint: disable=protected-access,g-complex-comprehension + a, b = array_ops._promote_dtype(a, b) + ndim = 
max(a.ndim, b.ndim) + if a.ndim < ndim: + a = array_ops.reshape(a, array_ops._pad_left_to(ndim, a.shape)) + if b.ndim < ndim: + b = array_ops.reshape(b, array_ops._pad_left_to(ndim, b.shape)) + a_reshaped = array_ops.reshape(a, [i for d in a.shape for i in (d, 1)]) + b_reshaped = array_ops.reshape(b, [i for d in b.shape for i in (1, d)]) + out_shape = tuple(np.multiply(a.shape, b.shape)) + return array_ops.reshape(a_reshaped * b_reshaped, out_shape) @utils.np_doc(np.outer) def outer(a, b): - def f(a, b): - return tf.reshape(a, [-1, 1]) * tf.reshape(b, [-1]) - return _bin_op(f, a, b) + def f(a, b): + return tf.reshape(a, [-1, 1]) * tf.reshape(b, [-1]) + + return _bin_op(f, a, b) # This can also be implemented via tf.reduce_logsumexp @utils.np_doc(np.logaddexp) def logaddexp(x1, x2): - amax = maximum(x1, x2) - delta = x1 - x2 - return array_ops.where( - isnan(delta), - x1 + x2, # NaNs or infinities of the same sign. - amax + log1p(exp(-abs(delta)))) + amax = maximum(x1, x2) + delta = x1 - x2 + return array_ops.where( + isnan(delta), + x1 + x2, # NaNs or infinities of the same sign. + amax + log1p(exp(-abs(delta))), + ) @utils.np_doc(np.logaddexp2) def logaddexp2(x1, x2): - amax = maximum(x1, x2) - delta = x1 - x2 - return array_ops.where( - isnan(delta), - x1 + x2, # NaNs or infinities of the same sign. - amax + log1p(exp2(-abs(delta))) / np.log(2)) + amax = maximum(x1, x2) + delta = x1 - x2 + return array_ops.where( + isnan(delta), + x1 + x2, # NaNs or infinities of the same sign. + amax + log1p(exp2(-abs(delta))) / np.log(2), + ) @utils.np_doc(np.polyval) def polyval(p, x): - def f(p, x): - if p.shape.rank == 0: - p = tf.reshape(p, [1]) - p = tf.unstack(p) - # TODO(wangpeng): Make tf version take a tensor for p instead of a list. - y = tf.math.polyval(p, x) - # If the polynomial is 0-order, numpy requires the result to be broadcast to - # `x`'s shape. 
- if len(p) == 1: - y = tf.broadcast_to(y, x.shape) - return y - return _bin_op(f, p, x) + def f(p, x): + if p.shape.rank == 0: + p = tf.reshape(p, [1]) + p = tf.unstack(p) + # TODO(wangpeng): Make tf version take a tensor for p instead of a list. + y = tf.math.polyval(p, x) + # If the polynomial is 0-order, numpy requires the result to be broadcast to + # `x`'s shape. + if len(p) == 1: + y = tf.broadcast_to(y, x.shape) + return y + + return _bin_op(f, p, x) @utils.np_doc(np.isclose) -def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - dtype = a.dtype - if np.issubdtype(dtype.as_numpy_dtype, np.inexact): - rtol_ = tf.convert_to_tensor(rtol, dtype.real_dtype) - atol_ = tf.convert_to_tensor(atol, dtype.real_dtype) - result = (tf.math.abs(a - b) <= atol_ + rtol_ * tf.math.abs(b)) - if equal_nan: - result = result | (tf.math.is_nan(a) & tf.math.is_nan(b)) - return result - else: - return a == b - return _bin_op(f, a, b) +def isclose( + a, b, rtol=1e-05, atol=1e-08, equal_nan=False +): # pylint: disable=missing-docstring + def f(a, b): # pylint: disable=missing-docstring + dtype = a.dtype + if np.issubdtype(dtype.as_numpy_dtype, np.inexact): + rtol_ = tf.convert_to_tensor(rtol, dtype.real_dtype) + atol_ = tf.convert_to_tensor(atol, dtype.real_dtype) + result = tf.math.abs(a - b) <= atol_ + rtol_ * tf.math.abs(b) + if equal_nan: + result = result | (tf.math.is_nan(a) & tf.math.is_nan(b)) + return result + else: + return a == b + + return _bin_op(f, a, b) @utils.np_doc(np.allclose) def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - return array_ops.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) + return array_ops.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) def _tf_gcd(x1, x2): - def _gcd_cond_fn(x1, x2): - return tf.reduce_any(x2 != 0) - def _gcd_body_fn(x1, x2): - # tf.math.mod will raise an error when any element of x2 is 0. 
def _gcd_body_fn(x1, x2):
    # tf.math.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because + # they won't be used. + x2_safe = tf.where(x2 != 0, x2, tf.constant(1, x2.dtype)) + x1, x2 = ( + tf.where(x2 != 0, x2, x1), + tf.where(x2 != 0, tf.math.mod(x1, x2_safe), tf.constant(0, x2.dtype)), + ) + return (tf.where(x1 < x2, x2, x1), tf.where(x1 < x2, x1, x2)) + + if not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or not np.issubdtype( + x2.dtype.as_numpy_dtype, np.integer + ): + raise ValueError("Arguments to gcd must be integers.") + shape = tf.broadcast_static_shape(x1.shape, x2.shape) + x1 = tf.broadcast_to(x1, shape) + x2 = tf.broadcast_to(x2, shape) + gcd, _ = tf.while_loop( + _gcd_cond_fn, _gcd_body_fn, (tf.math.abs(x1), tf.math.abs(x2)) + ) + return gcd @utils.np_doc(np.gcd) def gcd(x1, x2): - return _bin_op(_tf_gcd, x1, x2) + return _bin_op(_tf_gcd, x1, x2) @utils.np_doc(np.lcm) def lcm(x1, x2): - def f(x1, x2): - d = _tf_gcd(x1, x2) - # Same as the `x2_safe` trick above - d_safe = tf.where(d == 0, tf.constant(1, d.dtype), d) - return tf.where(d == 0, tf.constant(0, d.dtype), - tf.math.abs(x1 * x2) // d_safe) - return _bin_op(f, x1, x2) + def f(x1, x2): + d = _tf_gcd(x1, x2) + # Same as the `x2_safe` trick above + d_safe = tf.where(d == 0, tf.constant(1, d.dtype), d) + return tf.where(d == 0, tf.constant(0, d.dtype), tf.math.abs(x1 * x2) // d_safe) + + return _bin_op(f, x1, x2) def _bitwise_binary_op(tf_fn, x1, x2): - def f(x1, x2): - is_bool = (x1.dtype == tf.bool) - if is_bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - r = tf_fn(x1, x2) - if is_bool: - r = tf.cast(r, tf.bool) - return r - return _bin_op(f, x1, x2) + def f(x1, x2): + is_bool = x1.dtype == tf.bool + if is_bool: + assert x2.dtype == tf.bool + x1 = tf.cast(x1, tf.int8) + x2 = tf.cast(x2, tf.int8) + r = tf_fn(x1, x2) + if is_bool: + r = tf.cast(r, tf.bool) + return r + + return _bin_op(f, x1, x2) @utils.np_doc(np.bitwise_and) def bitwise_and(x1, x2): - return 
_bitwise_binary_op(tf.bitwise.bitwise_and, x1, x2) + return _bitwise_binary_op(tf.bitwise.bitwise_and, x1, x2) @utils.np_doc(np.bitwise_or) def bitwise_or(x1, x2): - return _bitwise_binary_op(tf.bitwise.bitwise_or, x1, x2) + return _bitwise_binary_op(tf.bitwise.bitwise_or, x1, x2) @utils.np_doc(np.bitwise_xor) def bitwise_xor(x1, x2): - return _bitwise_binary_op(tf.bitwise.bitwise_xor, x1, x2) + return _bitwise_binary_op(tf.bitwise.bitwise_xor, x1, x2) @utils.np_doc(np.bitwise_not) def bitwise_not(x): - def f(x): - if x.dtype == tf.bool: - return tf.logical_not(x) - return tf.bitwise.invert(x) - return _scalar(f, x) + def f(x): + if x.dtype == tf.bool: + return tf.logical_not(x) + return tf.bitwise.invert(x) + + return _scalar(f, x) def _scalar(tf_fn, x, promote_to_float=False): - """Computes the tf_fn(x) for each element in `x`. - - Args: - tf_fn: function that takes a single Tensor argument. - x: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - promote_to_float: whether to cast the argument to a float dtype - (`dtypes.default_float_type`) if it is not already. - - Returns: - An ndarray with the same shape as `x`. The default output dtype is - determined by `dtypes.default_float_type`, unless x is an ndarray with a - floating point type, in which case the output type is same as x.dtype. - """ - x = array_ops.asarray(x) - if promote_to_float and not np.issubdtype(x.dtype, np.inexact): - x = x.astype(dtypes.default_float_type()) - return utils.tensor_to_ndarray(tf_fn(x.data)) + """Computes the tf_fn(x) for each element in `x`. + + Args: + tf_fn: function that takes a single Tensor argument. + x: array_like. Could be an ndarray, a Tensor or any object that can + be converted to a Tensor using `tf.convert_to_tensor`. + promote_to_float: whether to cast the argument to a float dtype + (`dtypes.default_float_type`) if it is not already. + + Returns: + An ndarray with the same shape as `x`. 
The default output dtype is + determined by `dtypes.default_float_type`, unless x is an ndarray with a + floating point type, in which case the output type is same as x.dtype. + """ + x = array_ops.asarray(x) + if promote_to_float and not np.issubdtype(x.dtype, np.inexact): + x = x.astype(dtypes.default_float_type()) + return utils.tensor_to_ndarray(tf_fn(x.data)) @utils.np_doc(np.log) def log(x): - return _scalar(tf.math.log, x, True) + return _scalar(tf.math.log, x, True) @utils.np_doc(np.exp) def exp(x): - return _scalar(tf.exp, x, True) + return _scalar(tf.exp, x, True) @utils.np_doc(np.sqrt) def sqrt(x): - return _scalar(tf.sqrt, x, True) + return _scalar(tf.sqrt, x, True) @utils.np_doc(np.abs) def abs(x): - return _scalar(tf.math.abs, x) + return _scalar(tf.math.abs, x) @utils.np_doc(np.absolute) def absolute(x): - return abs(x) + return abs(x) @utils.np_doc(np.fabs) def fabs(x): - return abs(x) + return abs(x) @utils.np_doc(np.ceil) def ceil(x): - return _scalar(tf.math.ceil, x, True) + return _scalar(tf.math.ceil, x, True) @utils.np_doc(np.floor) def floor(x): - return _scalar(tf.math.floor, x, True) + return _scalar(tf.math.floor, x, True) @utils.np_doc(np.conj) def conj(x): - return _scalar(tf.math.conj, x) + return _scalar(tf.math.conj, x) @utils.np_doc(np.negative) def negative(x): - return _scalar(tf.math.negative, x) + return _scalar(tf.math.negative, x) @utils.np_doc(np.reciprocal) def reciprocal(x): - return _scalar(tf.math.reciprocal, x) + return _scalar(tf.math.reciprocal, x) @utils.np_doc(np.signbit) def signbit(x): - def f(x): - if x.dtype == tf.bool: - return tf.fill(x.shape, False) - return x < 0 - return _scalar(f, x) + def f(x): + if x.dtype == tf.bool: + return tf.fill(x.shape, False) + return x < 0 + + return _scalar(f, x) @utils.np_doc(np.sin) def sin(x): - return _scalar(tf.math.sin, x, True) + return _scalar(tf.math.sin, x, True) @utils.np_doc(np.cos) def cos(x): - return _scalar(tf.math.cos, x, True) + return _scalar(tf.math.cos, x, 
True) @utils.np_doc(np.tan) def tan(x): - return _scalar(tf.math.tan, x, True) + return _scalar(tf.math.tan, x, True) @utils.np_doc(np.sinh) def sinh(x): - return _scalar(tf.math.sinh, x, True) + return _scalar(tf.math.sinh, x, True) @utils.np_doc(np.cosh) def cosh(x): - return _scalar(tf.math.cosh, x, True) + return _scalar(tf.math.cosh, x, True) @utils.np_doc(np.tanh) def tanh(x): - return _scalar(tf.math.tanh, x, True) + return _scalar(tf.math.tanh, x, True) @utils.np_doc(np.arcsin) def arcsin(x): - return _scalar(tf.math.asin, x, True) + return _scalar(tf.math.asin, x, True) @utils.np_doc(np.arccos) def arccos(x): - return _scalar(tf.math.acos, x, True) + return _scalar(tf.math.acos, x, True) @utils.np_doc(np.arctan) def arctan(x): - return _scalar(tf.math.atan, x, True) + return _scalar(tf.math.atan, x, True) @utils.np_doc(np.arcsinh) def arcsinh(x): - return _scalar(tf.math.asinh, x, True) + return _scalar(tf.math.asinh, x, True) @utils.np_doc(np.arccosh) def arccosh(x): - return _scalar(tf.math.acosh, x, True) + return _scalar(tf.math.acosh, x, True) @utils.np_doc(np.arctanh) def arctanh(x): - return _scalar(tf.math.atanh, x, True) + return _scalar(tf.math.atanh, x, True) @utils.np_doc(np.deg2rad) def deg2rad(x): - def f(x): - return x * (np.pi / 180.0) - return _scalar(f, x, True) + def f(x): + return x * (np.pi / 180.0) + + return _scalar(f, x, True) @utils.np_doc(np.rad2deg) def rad2deg(x): - return x * (180.0 / np.pi) + return x * (180.0 / np.pi) _tf_float_types = [tf.bfloat16, tf.float16, tf.float32, tf.float64] @@ -605,89 +669,93 @@ def rad2deg(x): @utils.np_doc(np.angle) def angle(z, deg=False): - def f(x): - if x.dtype in _tf_float_types: - # Workaround for b/147515503 - return tf.where(x < 0, np.pi, 0) - else: - return tf.math.angle(x) - y = _scalar(f, z, True) - if deg: - y = rad2deg(y) - return y + def f(x): + if x.dtype in _tf_float_types: + # Workaround for b/147515503 + return tf.where(x < 0, np.pi, 0) + else: + return tf.math.angle(x) + + y = 
_scalar(f, z, True) + if deg: + y = rad2deg(y) + return y @utils.np_doc(np.cbrt) def cbrt(x): - def f(x): - # __pow__ can't handle negative base, so we use `abs` here. - rt = tf.math.abs(x) ** (1.0 / 3) - return tf.where(x < 0, -rt, rt) - return _scalar(f, x, True) + def f(x): + # __pow__ can't handle negative base, so we use `abs` here. + rt = tf.math.abs(x) ** (1.0 / 3) + return tf.where(x < 0, -rt, rt) + + return _scalar(f, x, True) @utils.np_doc(np.conjugate) def conjugate(x): - return _scalar(tf.math.conj, x) + return _scalar(tf.math.conj, x) @utils.np_doc(np.exp2) def exp2(x): - def f(x): - return 2 ** x - return _scalar(f, x, True) + def f(x): + return 2**x + + return _scalar(f, x, True) @utils.np_doc(np.expm1) def expm1(x): - return _scalar(tf.math.expm1, x, True) + return _scalar(tf.math.expm1, x, True) @utils.np_doc(np.fix) def fix(x): - def f(x): - return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x)) - return _scalar(f, x, True) + def f(x): + return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x)) + + return _scalar(f, x, True) @utils.np_doc(np.iscomplex) def iscomplex(x): - return array_ops.imag(x) != 0 + return array_ops.imag(x) != 0 @utils.np_doc(np.isreal) def isreal(x): - return array_ops.imag(x) == 0 + return array_ops.imag(x) == 0 @utils.np_doc(np.iscomplexobj) def iscomplexobj(x): - x = array_ops.array(x) - return np.issubdtype(x.dtype, np.complexfloating) + x = array_ops.array(x) + return np.issubdtype(x.dtype, np.complexfloating) @utils.np_doc(np.isrealobj) def isrealobj(x): - return not iscomplexobj(x) + return not iscomplexobj(x) @utils.np_doc(np.isnan) def isnan(x): - return _scalar(tf.math.is_nan, x, True) + return _scalar(tf.math.is_nan, x, True) def _make_nan_reduction(onp_reduction, reduction, init_val): - """Helper to generate nan* functions.""" - @utils.np_doc(onp_reduction) - def nan_reduction(a, axis=None, dtype=None, keepdims=False): - a = array_ops.array(a) - v = array_ops.array(init_val, dtype=a.dtype) - return reduction( - 
array_ops.where(isnan(a), v, a), - axis=axis, - dtype=dtype, - keepdims=keepdims) - return nan_reduction + """Helper to generate nan* functions.""" + + @utils.np_doc(onp_reduction) + def nan_reduction(a, axis=None, dtype=None, keepdims=False): + a = array_ops.array(a) + v = array_ops.array(init_val, dtype=a.dtype) + return reduction( + array_ops.where(isnan(a), v, a), axis=axis, dtype=dtype, keepdims=keepdims + ) + + return nan_reduction nansum = _make_nan_reduction(np.nansum, array_ops.sum, 0) @@ -695,446 +763,471 @@ def nan_reduction(a, axis=None, dtype=None, keepdims=False): @utils.np_doc(np.nanmean) -def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring - a = array_ops.array(a) - if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer): - return array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) - nan_mask = logical_not(isnan(a)) - if dtype is None: - dtype = a.dtype - normalizer = array_ops.sum( - nan_mask, axis=axis, dtype=dtype, keepdims=keepdims) - return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer +def nanmean( + a, axis=None, dtype=None, keepdims=None +): # pylint: disable=missing-docstring + a = array_ops.array(a) + if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer): + return array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) + nan_mask = logical_not(isnan(a)) + if dtype is None: + dtype = a.dtype + normalizer = array_ops.sum(nan_mask, axis=axis, dtype=dtype, keepdims=keepdims) + return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer @utils.np_doc(np.isfinite) def isfinite(x): - return _scalar(tf.math.is_finite, x, True) + return _scalar(tf.math.is_finite, x, True) @utils.np_doc(np.isinf) def isinf(x): - return _scalar(tf.math.is_inf, x, True) + return _scalar(tf.math.is_inf, x, True) @utils.np_doc(np.isneginf) def isneginf(x): - return x == array_ops.full_like(x, -np.inf) + return x == array_ops.full_like(x, -np.inf) 
@utils.np_doc(np.isposinf) def isposinf(x): - return x == array_ops.full_like(x, np.inf) + return x == array_ops.full_like(x, np.inf) @utils.np_doc(np.log2) def log2(x): - return log(x) / np.log(2) + return log(x) / np.log(2) @utils.np_doc(np.log10) def log10(x): - return log(x) / np.log(10) + return log(x) / np.log(10) @utils.np_doc(np.log1p) def log1p(x): - return _scalar(tf.math.log1p, x, True) + return _scalar(tf.math.log1p, x, True) @utils.np_doc(np.positive) def positive(x): - return _scalar(lambda x: x, x) + return _scalar(lambda x: x, x) @utils.np_doc(np.sinc) def sinc(x): - def f(x): - pi_x = x * np.pi - return tf.where(x == 0, tf.ones_like(x), tf.math.sin(pi_x) / pi_x) - return _scalar(f, x, True) + def f(x): + pi_x = x * np.pi + return tf.where(x == 0, tf.ones_like(x), tf.math.sin(pi_x) / pi_x) + + return _scalar(f, x, True) @utils.np_doc(np.square) def square(x): - return _scalar(tf.math.square, x) + return _scalar(tf.math.square, x) @utils.np_doc(np.diff) def diff(a, n=1, axis=-1): - def f(a): - nd = a.shape.rank - if (axis + nd if axis < 0 else axis) >= nd: - raise ValueError("axis %s is out of bounds for array of dimension %s" % - (axis, nd)) - if n < 0: - raise ValueError("order must be non-negative but got %s" % n) - slice1 = [slice(None)] * nd - slice2 = [slice(None)] * nd - slice1[axis] = slice(1, None) - slice2[axis] = slice(None, -1) - slice1 = tuple(slice1) - slice2 = tuple(slice2) - op = tf.not_equal if a.dtype == tf.bool else tf.subtract - for _ in range(n): - a = op(a[slice1], a[slice2]) - return a - return _scalar(f, a) + def f(a): + nd = a.shape.rank + if (axis + nd if axis < 0 else axis) >= nd: + raise ValueError( + "axis %s is out of bounds for array of dimension %s" % (axis, nd) + ) + if n < 0: + raise ValueError("order must be non-negative but got %s" % n) + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1 = tuple(slice1) + slice2 = tuple(slice2) + op = 
tf.not_equal if a.dtype == tf.bool else tf.subtract + for _ in range(n): + a = op(a[slice1], a[slice2]) + return a + + return _scalar(f, a) def _flip_args(f): - def _f(a, b): - return f(b, a) - return _f - - -setattr(arrays.ndarray, '__abs__', absolute) -setattr(arrays.ndarray, '__floordiv__', floor_divide) -setattr(arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide)) -setattr(arrays.ndarray, '__mod__', mod) -setattr(arrays.ndarray, '__rmod__', _flip_args(mod)) -setattr(arrays.ndarray, '__add__', add) -setattr(arrays.ndarray, '__radd__', _flip_args(add)) -setattr(arrays.ndarray, '__sub__', subtract) -setattr(arrays.ndarray, '__rsub__', _flip_args(subtract)) -setattr(arrays.ndarray, '__mul__', multiply) -setattr(arrays.ndarray, '__rmul__', _flip_args(multiply)) -setattr(arrays.ndarray, '__pow__', power) -setattr(arrays.ndarray, '__rpow__', _flip_args(power)) -setattr(arrays.ndarray, '__truediv__', true_divide) -setattr(arrays.ndarray, '__rtruediv__', _flip_args(true_divide)) + def _f(a, b): + return f(b, a) + + return _f + + +setattr(arrays.ndarray, "__abs__", absolute) +setattr(arrays.ndarray, "__floordiv__", floor_divide) +setattr(arrays.ndarray, "__rfloordiv__", _flip_args(floor_divide)) +setattr(arrays.ndarray, "__mod__", mod) +setattr(arrays.ndarray, "__rmod__", _flip_args(mod)) +setattr(arrays.ndarray, "__add__", add) +setattr(arrays.ndarray, "__radd__", _flip_args(add)) +setattr(arrays.ndarray, "__sub__", subtract) +setattr(arrays.ndarray, "__rsub__", _flip_args(subtract)) +setattr(arrays.ndarray, "__mul__", multiply) +setattr(arrays.ndarray, "__rmul__", _flip_args(multiply)) +setattr(arrays.ndarray, "__pow__", power) +setattr(arrays.ndarray, "__rpow__", _flip_args(power)) +setattr(arrays.ndarray, "__truediv__", true_divide) +setattr(arrays.ndarray, "__rtruediv__", _flip_args(true_divide)) def _comparison(tf_fun, x1, x2, cast_bool_to_int=False): - dtype = utils.result_type(x1, x2) - # Cast x1 and x2 to the result_type if needed. 
- x1 = array_ops.array(x1, dtype=dtype) - x2 = array_ops.array(x2, dtype=dtype) - x1 = x1.data - x2 = x2.data - if cast_bool_to_int and x1.dtype == tf.bool: - x1 = tf.cast(x1, tf.int32) - x2 = tf.cast(x2, tf.int32) - return utils.tensor_to_ndarray(tf_fun(x1, x2)) + dtype = utils.result_type(x1, x2) + # Cast x1 and x2 to the result_type if needed. + x1 = array_ops.array(x1, dtype=dtype) + x2 = array_ops.array(x2, dtype=dtype) + x1 = x1.data + x2 = x2.data + if cast_bool_to_int and x1.dtype == tf.bool: + x1 = tf.cast(x1, tf.int32) + x2 = tf.cast(x2, tf.int32) + return utils.tensor_to_ndarray(tf_fun(x1, x2)) @utils.np_doc(np.equal) def equal(x1, x2): - return _comparison(tf.equal, x1, x2) + return _comparison(tf.equal, x1, x2) @utils.np_doc(np.not_equal) def not_equal(x1, x2): - return _comparison(tf.not_equal, x1, x2) + return _comparison(tf.not_equal, x1, x2) @utils.np_doc(np.greater) def greater(x1, x2): - return _comparison(tf.greater, x1, x2, True) + return _comparison(tf.greater, x1, x2, True) @utils.np_doc(np.greater_equal) def greater_equal(x1, x2): - return _comparison(tf.greater_equal, x1, x2, True) + return _comparison(tf.greater_equal, x1, x2, True) @utils.np_doc(np.less) def less(x1, x2): - return _comparison(tf.less, x1, x2, True) + return _comparison(tf.less, x1, x2, True) @utils.np_doc(np.less_equal) def less_equal(x1, x2): - return _comparison(tf.less_equal, x1, x2, True) + return _comparison(tf.less_equal, x1, x2, True) @utils.np_doc(np.array_equal) def array_equal(a1, a2): - def f(a1, a2): - if a1.shape != a2.shape: - return tf.constant(False) - return tf.reduce_all(tf.equal(a1, a2)) - return _comparison(f, a1, a2) + def f(a1, a2): + if a1.shape != a2.shape: + return tf.constant(False) + return tf.reduce_all(tf.equal(a1, a2)) + + return _comparison(f, a1, a2) def _logical_binary_op(tf_fun, x1, x2): - x1 = array_ops.array(x1, dtype=np.bool_) - x2 = array_ops.array(x2, dtype=np.bool_) - return utils.tensor_to_ndarray(tf_fun(x1.data, x2.data)) + x1 = 
array_ops.array(x1, dtype=np.bool_) + x2 = array_ops.array(x2, dtype=np.bool_) + return utils.tensor_to_ndarray(tf_fun(x1.data, x2.data)) @utils.np_doc(np.logical_and) def logical_and(x1, x2): - return _logical_binary_op(tf.logical_and, x1, x2) + return _logical_binary_op(tf.logical_and, x1, x2) @utils.np_doc(np.logical_or) def logical_or(x1, x2): - return _logical_binary_op(tf.logical_or, x1, x2) + return _logical_binary_op(tf.logical_or, x1, x2) @utils.np_doc(np.logical_xor) def logical_xor(x1, x2): - return _logical_binary_op(tf.math.logical_xor, x1, x2) + return _logical_binary_op(tf.math.logical_xor, x1, x2) @utils.np_doc(np.logical_not) def logical_not(x): - x = array_ops.array(x, dtype=np.bool_) - return utils.tensor_to_ndarray(tf.logical_not(x.data)) + x = array_ops.array(x, dtype=np.bool_) + return utils.tensor_to_ndarray(tf.logical_not(x.data)) -setattr(arrays.ndarray, '__invert__', logical_not) -setattr(arrays.ndarray, '__lt__', less) -setattr(arrays.ndarray, '__le__', less_equal) -setattr(arrays.ndarray, '__gt__', greater) -setattr(arrays.ndarray, '__ge__', greater_equal) -setattr(arrays.ndarray, '__eq__', equal) -setattr(arrays.ndarray, '__ne__', not_equal) + +setattr(arrays.ndarray, "__invert__", logical_not) +setattr(arrays.ndarray, "__lt__", less) +setattr(arrays.ndarray, "__le__", less_equal) +setattr(arrays.ndarray, "__gt__", greater) +setattr(arrays.ndarray, "__ge__", greater_equal) +setattr(arrays.ndarray, "__eq__", equal) +setattr(arrays.ndarray, "__ne__", not_equal) @utils.np_doc(np.linspace) def linspace( # pylint: disable=missing-docstring - start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0): - if dtype: - dtype = utils.result_type(dtype) - start = array_ops.array(start, dtype=dtype).data - stop = array_ops.array(stop, dtype=dtype).data - if num < 0: - raise ValueError('Number of samples {} must be non-negative.'.format(num)) - step = tf.convert_to_tensor(np.nan) - if endpoint: - result = tf.linspace(start, stop, num, 
axis=axis) - if num > 1: - step = (stop - start) / (num - 1) - else: - # tf.linspace does not support endpoint=False so we manually handle it - # here. - if num > 1: - step = ((stop - start) / num) - new_stop = tf.cast(stop, step.dtype) - step - start = tf.cast(start, new_stop.dtype) - result = tf.linspace(start, new_stop, num, axis=axis) + start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0 +): + if dtype: + dtype = utils.result_type(dtype) + start = array_ops.array(start, dtype=dtype).data + stop = array_ops.array(stop, dtype=dtype).data + if num < 0: + raise ValueError("Number of samples {} must be non-negative.".format(num)) + step = tf.convert_to_tensor(np.nan) + if endpoint: + result = tf.linspace(start, stop, num, axis=axis) + if num > 1: + step = (stop - start) / (num - 1) else: - result = tf.linspace(start, stop, num, axis=axis) - if dtype: - result = tf.cast(result, dtype) - if retstep: - return arrays.tensor_to_ndarray(result), arrays.tensor_to_ndarray(step) - else: - return arrays.tensor_to_ndarray(result) + # tf.linspace does not support endpoint=False so we manually handle it + # here. 
+ if num > 1: + step = (stop - start) / num + new_stop = tf.cast(stop, step.dtype) - step + start = tf.cast(start, new_stop.dtype) + result = tf.linspace(start, new_stop, num, axis=axis) + else: + result = tf.linspace(start, stop, num, axis=axis) + if dtype: + result = tf.cast(result, dtype) + if retstep: + return arrays.tensor_to_ndarray(result), arrays.tensor_to_ndarray(step) + else: + return arrays.tensor_to_ndarray(result) @utils.np_doc(np.logspace) def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): - dtype = utils.result_type(start, stop, dtype) - result = linspace( - start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data - result = tf.pow(tf.cast(base, result.dtype), result) - if dtype: - result = tf.cast(result, dtype) - return arrays.tensor_to_ndarray(result) + dtype = utils.result_type(start, stop, dtype) + result = linspace( + start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis + ).data + result = tf.pow(tf.cast(base, result.dtype), result) + if dtype: + result = tf.cast(result, dtype) + return arrays.tensor_to_ndarray(result) @utils.np_doc(np.ptp) def ptp(a, axis=None, keepdims=None): - return (array_ops.amax(a, axis=axis, keepdims=keepdims) - - array_ops.amin(a, axis=axis, keepdims=keepdims)) + return array_ops.amax(a, axis=axis, keepdims=keepdims) - array_ops.amin( + a, axis=axis, keepdims=keepdims + ) @utils.np_doc_only(np.concatenate) def concatenate(arys, axis=0): - if not isinstance(arys, (list, tuple)): - arys = [arys] - if not arys: - raise ValueError('Need at least one array to concatenate.') - dtype = utils.result_type(*arys) - arys = [array_ops.array(array, dtype=dtype).data for array in arys] - return arrays.tensor_to_ndarray(tf.concat(arys, axis)) + if not isinstance(arys, (list, tuple)): + arys = [arys] + if not arys: + raise ValueError("Need at least one array to concatenate.") + dtype = utils.result_type(*arys) + arys = [array_ops.array(array, dtype=dtype).data for array in arys] + 
return arrays.tensor_to_ndarray(tf.concat(arys, axis)) @utils.np_doc_only(np.tile) def tile(a, reps): - a = array_ops.array(a).data - reps = array_ops.array(reps, dtype=tf.int32).reshape([-1]).data + a = array_ops.array(a).data + reps = array_ops.array(reps, dtype=tf.int32).reshape([-1]).data - a_rank = tf.rank(a) - reps_size = tf.size(reps) - reps = tf.pad( - reps, [[tf.math.maximum(a_rank - reps_size, 0), 0]], - constant_values=1) - a_shape = tf.pad( - tf.shape(a), [[tf.math.maximum(reps_size - a_rank, 0), 0]], - constant_values=1) - a = tf.reshape(a, a_shape) + a_rank = tf.rank(a) + reps_size = tf.size(reps) + reps = tf.pad( + reps, [[tf.math.maximum(a_rank - reps_size, 0), 0]], constant_values=1 + ) + a_shape = tf.pad( + tf.shape(a), [[tf.math.maximum(reps_size - a_rank, 0), 0]], constant_values=1 + ) + a = tf.reshape(a, a_shape) - return arrays.tensor_to_ndarray(tf.tile(a, reps)) + return arrays.tensor_to_ndarray(tf.tile(a, reps)) @utils.np_doc(np.count_nonzero) def count_nonzero(a, axis=None): - return arrays.tensor_to_ndarray( - tf.math.count_nonzero(array_ops.array(a).data, axis)) + return arrays.tensor_to_ndarray( + tf.math.count_nonzero(array_ops.array(a).data, axis) + ) @utils.np_doc(np.argsort) -def argsort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring - # TODO(nareshmodi): make string tensors also work. - if kind not in ('quicksort', 'stable'): - raise ValueError("Only 'quicksort' and 'stable' arguments are supported.") - if order is not None: - raise ValueError("'order' argument to sort is not supported.") - stable = (kind == 'stable') +def argsort( + a, axis=-1, kind="quicksort", order=None +): # pylint: disable=missing-docstring + # TODO(nareshmodi): make string tensors also work. 
+ if kind not in ("quicksort", "stable"): + raise ValueError("Only 'quicksort' and 'stable' arguments are supported.") + if order is not None: + raise ValueError("'order' argument to sort is not supported.") + stable = kind == "stable" - a = array_ops.array(a).data + a = array_ops.array(a).data - def _argsort(a, axis, stable): - if axis is None: - a = tf.reshape(a, [-1]) - axis = 0 + def _argsort(a, axis, stable): + if axis is None: + a = tf.reshape(a, [-1]) + axis = 0 - return tf.argsort(a, axis, stable=stable) + return tf.argsort(a, axis, stable=stable) - tf_ans = tf.cond( - tf.rank(a) == 0, lambda: tf.constant([0]), - lambda: _argsort(a, axis, stable)) + tf_ans = tf.cond( + tf.rank(a) == 0, lambda: tf.constant([0]), lambda: _argsort(a, axis, stable) + ) - return array_ops.array(tf_ans, dtype=np.intp) + return array_ops.array(tf_ans, dtype=np.intp) @utils.np_doc(np.sort) -def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring - if kind != 'quicksort': - raise ValueError("Only 'quicksort' is supported.") - if order is not None: - raise ValueError("'order' argument to sort is not supported.") +def sort(a, axis=-1, kind="quicksort", order=None): # pylint: disable=missing-docstring + if kind != "quicksort": + raise ValueError("Only 'quicksort' is supported.") + if order is not None: + raise ValueError("'order' argument to sort is not supported.") - a = array_ops.array(a) + a = array_ops.array(a) - if axis is None: - result_t = tf.sort(tf.reshape(a.data, [-1]), 0) - return utils.tensor_to_ndarray(result_t) - else: - return utils.tensor_to_ndarray(tf.sort(a.data, axis)) + if axis is None: + result_t = tf.sort(tf.reshape(a.data, [-1]), 0) + return utils.tensor_to_ndarray(result_t) + else: + return utils.tensor_to_ndarray(tf.sort(a.data, axis)) def _argminmax(fn, a, axis=None): - a = array_ops.array(a) - if axis is None: - # When axis is None numpy flattens the array. 
- a_t = tf.reshape(a.data, [-1]) - else: - a_t = array_ops.atleast_1d(a).data - return utils.tensor_to_ndarray(fn(input=a_t, axis=axis)) + a = array_ops.array(a) + if axis is None: + # When axis is None numpy flattens the array. + a_t = tf.reshape(a.data, [-1]) + else: + a_t = array_ops.atleast_1d(a).data + return utils.tensor_to_ndarray(fn(input=a_t, axis=axis)) @utils.np_doc(np.argmax) def argmax(a, axis=None): - return _argminmax(tf.argmax, a, axis) + return _argminmax(tf.argmax, a, axis) @utils.np_doc(np.argmin) def argmin(a, axis=None): - def run_test(arr, repeats, *args, **kwargs): Tuple of ints is not ' - 'supported yet. Got type: %s' % type(axis)) - a = array_ops.array(a) - if weights is None: # Treat all weights as 1 - if not np.issubdtype(a.dtype, np.inexact): - a = a.astype(utils.result_type(a.dtype, dtypes.default_float_type())) - avg = tf.reduce_mean(a.data, axis=axis) - if returned: - if axis is None: - weights_sum = tf.size(a.data) - else: - weights_sum = tf.shape(a.data)[axis] - weights_sum = tf.cast(weights_sum, a.data.dtype) - else: - if np.issubdtype(a.dtype, np.inexact): - out_dtype = utils.result_type(a.dtype, weights) - else: - out_dtype = utils.result_type(a.dtype, weights, - dtypes.default_float_type()) - a = array_ops.array(a, out_dtype).data - weights = array_ops.array(weights, out_dtype).data - - def rank_equal_case(): - tf.debugging.Assert(tf.reduce_all(tf.shape(a) == tf.shape(weights)), - [tf.shape(a), tf.shape(weights)]) - weights_sum = tf.reduce_sum(weights, axis=axis) - avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum - return avg, weights_sum - if axis is None: - avg, weights_sum = rank_equal_case() +def average( + a, axis=None, weights=None, returned=False +): # pylint: disable=missing-docstring + if axis is not None and not isinstance(axis, six.integer_types): + # TODO(wangpeng): Support tuple of ints as `axis` + raise ValueError( + "`axis` must be an integer. Tuple of ints is not " + "supported yet. 
Got type: %s" % type(axis) + ) + a = array_ops.array(a) + if weights is None: # Treat all weights as 1 + if not np.issubdtype(a.dtype, np.inexact): + a = a.astype(utils.result_type(a.dtype, dtypes.default_float_type())) + avg = tf.reduce_mean(a.data, axis=axis) + if returned: + if axis is None: + weights_sum = tf.size(a.data) + else: + weights_sum = tf.shape(a.data)[axis] + weights_sum = tf.cast(weights_sum, a.data.dtype) else: - def rank_not_equal_case(): - tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)]) - weights_sum = tf.reduce_sum(weights) - axes = tf.convert_to_tensor([[axis], [0]]) - avg = tf.tensordot(a, weights, axes) / weights_sum + if np.issubdtype(a.dtype, np.inexact): + out_dtype = utils.result_type(a.dtype, weights) + else: + out_dtype = utils.result_type(a.dtype, weights, dtypes.default_float_type()) + a = array_ops.array(a, out_dtype).data + weights = array_ops.array(weights, out_dtype).data + + def rank_equal_case(): + tf.debugging.Assert( + tf.reduce_all(tf.shape(a) == tf.shape(weights)), + [tf.shape(a), tf.shape(weights)], + ) + weights_sum = tf.reduce_sum(weights, axis=axis) + avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum + return avg, weights_sum + + if axis is None: + avg, weights_sum = rank_equal_case() + else: + + def rank_not_equal_case(): + tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)]) + weights_sum = tf.reduce_sum(weights) + axes = tf.convert_to_tensor([[axis], [0]]) + avg = tf.tensordot(a, weights, axes) / weights_sum + return avg, weights_sum + + # We condition on rank rather than shape equality, because if we do the + # latter, when the shapes are partially unknown but the ranks are known + # and different, utils.cond will run shape checking on the true branch, + # which will raise a shape-checking error. 
+ avg, weights_sum = utils.cond( + tf.rank(a) == tf.rank(weights), rank_equal_case, rank_not_equal_case + ) + + avg = array_ops.array(avg) + if returned: + weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data)) return avg, weights_sum - # We condition on rank rather than shape equality, because if we do the - # latter, when the shapes are partially unknown but the ranks are known - # and different, utils.cond will run shape checking on the true branch, - # which will raise a shape-checking error. - avg, weights_sum = utils.cond(tf.rank(a) == tf.rank(weights), - rank_equal_case, rank_not_equal_case) - - avg = array_ops.array(avg) - if returned: - weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data)) - return avg, weights_sum - return avg + return avg @utils.np_doc(np.trace) -def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring - if dtype: - dtype = utils.result_type(dtype) - a = array_ops.asarray(a, dtype).data - - if offset == 0: - a_shape = a.shape - if a_shape.rank is not None: - rank = len(a_shape) - if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or - axis2 == rank - 1): - return utils.tensor_to_ndarray(tf.linalg.trace(a)) - - a = array_ops.diagonal(a, offset, axis1, axis2) - return array_ops.sum(a, -1, dtype) +def trace( + a, offset=0, axis1=0, axis2=1, dtype=None +): # pylint: disable=missing-docstring + if dtype: + dtype = utils.result_type(dtype) + a = array_ops.asarray(a, dtype).data + + if offset == 0: + a_shape = a.shape + if a_shape.rank is not None: + rank = len(a_shape) + if (axis1 == -2 or axis1 == rank - 2) and ( + axis2 == -1 or axis2 == rank - 1 + ): + return utils.tensor_to_ndarray(tf.linalg.trace(a)) + + a = array_ops.diagonal(a, offset, axis1, axis2) + return array_ops.sum(a, -1, dtype) @utils.np_doc(np.meshgrid) def meshgrid(*xi, **kwargs): - """This currently requires copy=True and sparse=False.""" - sparse = kwargs.get('sparse', False) - if sparse: - raise 
ValueError('tf.numpy doesnt support returning sparse arrays yet') + """This currently requires copy=True and sparse=False.""" + sparse = kwargs.get("sparse", False) + if sparse: + raise ValueError("tf.numpy doesnt support returning sparse arrays yet") - copy = kwargs.get('copy', True) - if not copy: - raise ValueError('tf.numpy only supports copy=True') + copy = kwargs.get("copy", True) + if not copy: + raise ValueError("tf.numpy only supports copy=True") - indexing = kwargs.get('indexing', 'xy') + indexing = kwargs.get("indexing", "xy") - xi = [array_ops.asarray(arg).data for arg in xi] - kwargs = {'indexing': indexing} + xi = [array_ops.asarray(arg).data for arg in xi] + kwargs = {"indexing": indexing} - outputs = tf.meshgrid(*xi, **kwargs) - outputs = [utils.tensor_to_ndarray(output) for output in outputs] + outputs = tf.meshgrid(*xi, **kwargs) + outputs = [utils.tensor_to_ndarray(output) for output in outputs] - return outputs + return outputs diff --git a/trax/tf_numpy/numpy_impl/random.py b/trax/tf_numpy/numpy_impl/random.py index 8ed3021eb..c6ca9c527 100644 --- a/trax/tf_numpy/numpy_impl/random.py +++ b/trax/tf_numpy/numpy_impl/random.py @@ -24,30 +24,29 @@ def randn(*args): - """Returns samples from a normal distribution. + """Returns samples from a normal distribution. - Uses `tf.random_normal`. + Uses `tf.random_normal`. - Args: - *args: The shape of the output array. + Args: + *args: The shape of the output array. - Returns: - An ndarray with shape `args` and dtype `float64`. - """ - # TODO(wangpeng): Use new stateful RNG - if utils.isscalar(args): - args = (args,) - return utils.tensor_to_ndarray( - tf.random.normal(args, dtype=DEFAULT_RANDN_DTYPE)) + Returns: + An ndarray with shape `args` and dtype `float64`. + """ + # TODO(wangpeng): Use new stateful RNG + if utils.isscalar(args): + args = (args,) + return utils.tensor_to_ndarray(tf.random.normal(args, dtype=DEFAULT_RANDN_DTYPE)) def seed(s): - """Sets the seed for the random number generator. 
+ """Sets the seed for the random number generator. - Uses `tf.set_random_seed`. + Uses `tf.set_random_seed`. - Args: - s: an integer. - """ - # TODO(wangpeng): make the signature the same as numpy - tf.random.set_seed(s) + Args: + s: an integer. + """ + # TODO(wangpeng): make the signature the same as numpy + tf.random.set_seed(s) diff --git a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py b/trax/tf_numpy/numpy_impl/tests/array_ops_test.py deleted file mode 100644 index b74992ba2..000000000 --- a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py +++ /dev/null @@ -1,1130 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Converts a native python or TF DType to numpy type. + """Converts a native python or TF DType to numpy type. - Args: - dtype: Could be a python type, a numpy type or a TF DType. + Args: + dtype: Could be a python type, a numpy type or a TF DType. - Returns: - A NumPy `dtype`. - """ - if isinstance(dtype, tf.DType): - return dtype.as_numpy_dtype - return np.dtype(dtype) + Returns: + A NumPy `dtype`. + """ + if isinstance(dtype, tf.DType): + return dtype.as_numpy_dtype + return np.dtype(dtype) def finfo(dtype): - """Returns properties of floating point types. + """Returns properties of floating point types. - Note that currently it just forwards to the numpy namesake, while tensorflow - and numpy dtypes may have different properties. 
+ Note that currently it just forwards to the numpy namesake, while tensorflow + and numpy dtypes may have different properties. - Args: - dtype: Could be a python type, a numpy type or a TF DType. + Args: + dtype: Could be a python type, a numpy type or a TF DType. - Returns: - A class describing properties of `dtype`, as described by - https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html - """ - return np.finfo(_to_numpy_type(dtype)) + Returns: + A class describing properties of `dtype`, as described by + https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html + """ + return np.finfo(_to_numpy_type(dtype)) def isscalar(val): - """Returns whether `val` is a scalar value or scalar Tensor.""" - if isinstance(val, (np.ndarray, arrays.ndarray, tf.Tensor)): - return len(val.shape) == 0 # pylint: disable=g-explicit-length-test - return np.isscalar(val) + """Returns whether `val` is a scalar value or scalar Tensor.""" + if isinstance(val, (np.ndarray, arrays.ndarray, tf.Tensor)): + return len(val.shape) == 0 # pylint: disable=g-explicit-length-test + return np.isscalar(val) # Can't use np_doc because np.result_type is a builtin function. def result_type(*arrays_and_dtypes): - """Returns the type resulting from applying NumPy type promotion to arguments. - - Args: - *arrays_and_dtypes: A list of array_like objects or dtypes. - - Returns: - A numpy dtype. - """ - def maybe_get_dtype(x): - # Don't put np.ndarray in this list, because np.result_type looks at the - # value (not just dtype) of np.ndarray to decide the result type. - if isinstance(x, (arrays.ndarray, arrays.ShardedNdArray, - tf.Tensor, tf.IndexedSlices)): - return _to_numpy_type(x.dtype) - elif isinstance(x, tf.DType): - return _to_numpy_type(x) - return x - arrays_and_dtypes = [maybe_get_dtype(x) for x in - tf.nest.flatten(arrays_and_dtypes)] - if not arrays_and_dtypes: - # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is. 
- arrays_and_dtypes = [np.asarray([])] - return dtypes._result_type(*arrays_and_dtypes) + """Returns the type resulting from applying NumPy type promotion to arguments. + + Args: + *arrays_and_dtypes: A list of array_like objects or dtypes. + + Returns: + A numpy dtype. + """ + + def maybe_get_dtype(x): + # Don't put np.ndarray in this list, because np.result_type looks at the + # value (not just dtype) of np.ndarray to decide the result type. + if isinstance( + x, (arrays.ndarray, arrays.ShardedNdArray, tf.Tensor, tf.IndexedSlices) + ): + return _to_numpy_type(x.dtype) + elif isinstance(x, tf.DType): + return _to_numpy_type(x) + return x + + arrays_and_dtypes = [maybe_get_dtype(x) for x in tf.nest.flatten(arrays_and_dtypes)] + if not arrays_and_dtypes: + # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is. + arrays_and_dtypes = [np.asarray([])] + return dtypes._result_type(*arrays_and_dtypes) def promote_types(type1, type2): - """Returns the type resulting from applying NumPy type promotion. + """Returns the type resulting from applying NumPy type promotion. - Args: - type1: A numpy type. - type2: A numpy type. + Args: + type1: A numpy type. + type2: A numpy type. - Returns: - A numpy type. - """ - type1 = _to_numpy_type(type1) - type2 = _to_numpy_type(type2) - return dtypes.canonicalize_dtype(np.promote_types(type1, type2)) + Returns: + A numpy type. 
+ """ + type1 = _to_numpy_type(type1) + type2 = _to_numpy_type(type2) + return dtypes.canonicalize_dtype(np.promote_types(type1, type2)) def _has_docstring(f): - return hasattr(f, '__doc__') and isinstance(f.__doc__, str) and f.__doc__ + return hasattr(f, "__doc__") and isinstance(f.__doc__, str) and f.__doc__ def _add_blank_line(s): - if s.endswith('\n'): - return s + '\n' - else: - return s + '\n\n' + if s.endswith("\n"): + return s + "\n" + else: + return s + "\n\n" def _np_signature(f): - """An enhanced funcsigs.signature that can handle numpy.ufunc.""" - if not isinstance(f, np.ufunc): - try: - return funcsigs.signature(f) - except ValueError: - return None - def names_from_num(prefix, n): - if n <= 0: - return [] - elif n == 1: - return [prefix] - else: - return [prefix + str(i + 1) for i in range(n)] - input_names = names_from_num('x', f.nin) - output_names = names_from_num('out', f.nout) - keyword_only_params = [ - ('where', True), - ('casting', 'same_kind'), - ('order', 'K'), - ('dtype', None), - ('subok', True), - ('signature', None), - ('extobj', None)] - params = [] - params += [funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY) - for name in input_names] - if f.nout > 1: - params += [funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY, - default=None) - for name in output_names] - params += [funcsigs.Parameter( - 'out', funcsigs.Parameter.POSITIONAL_OR_KEYWORD, - default=None if f.nout == 1 else (None,) * f.nout)] - params += [funcsigs.Parameter(name, funcsigs.Parameter.KEYWORD_ONLY, - default=default) - for name, default in keyword_only_params] - return funcsigs.Signature(params) + """An enhanced funcsigs.signature that can handle numpy.ufunc.""" + if not isinstance(f, np.ufunc): + try: + return funcsigs.signature(f) + except ValueError: + return None + + def names_from_num(prefix, n): + if n <= 0: + return [] + elif n == 1: + return [prefix] + else: + return [prefix + str(i + 1) for i in range(n)] + + input_names = 
names_from_num("x", f.nin) + output_names = names_from_num("out", f.nout) + keyword_only_params = [ + ("where", True), + ("casting", "same_kind"), + ("order", "K"), + ("dtype", None), + ("subok", True), + ("signature", None), + ("extobj", None), + ] + params = [] + params += [ + funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY) + for name in input_names + ] + if f.nout > 1: + params += [ + funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY, default=None) + for name in output_names + ] + params += [ + funcsigs.Parameter( + "out", + funcsigs.Parameter.POSITIONAL_OR_KEYWORD, + default=None if f.nout == 1 else (None,) * f.nout, + ) + ] + params += [ + funcsigs.Parameter(name, funcsigs.Parameter.KEYWORD_ONLY, default=default) + for name, default in keyword_only_params + ] + return funcsigs.Signature(params) # Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't # allow positional-only argument. So we conflate positional-only, keyword-only # and positional-or-keyword arguments here. def _is_compatible_param_kind(a, b): - def relax(k): - if k in (funcsigs.Parameter.POSITIONAL_ONLY, - funcsigs.Parameter.KEYWORD_ONLY): - return funcsigs.Parameter.POSITIONAL_OR_KEYWORD - return k - return relax(a) == relax(b) + def relax(k): + if k in (funcsigs.Parameter.POSITIONAL_ONLY, funcsigs.Parameter.KEYWORD_ONLY): + return funcsigs.Parameter.POSITIONAL_OR_KEYWORD + return k + + return relax(a) == relax(b) def np_doc(np_fun): - """Attachs numpy docstring to a function. - - Args: - np_fun: the numpy function whose docstring will be used. - - Returns: - A function decorator that attaches the docstring from `np_fun` to the - decorated function. 
- """ - np_sig = _np_signature(np_fun) - def decorator(f): - """The decorator.""" - unsupported_params = [] - if np_sig is not None: - sig = funcsigs.signature(f) - for name in np_sig.parameters: - if name not in sig.parameters: - unsupported_params.append(name) - f.__doc__ = _np_doc_helper(f, np_fun, unsupported_params) - return f - return decorator + """Attachs numpy docstring to a function. + + Args: + np_fun: the numpy function whose docstring will be used. + + Returns: + A function decorator that attaches the docstring from `np_fun` to the + decorated function. + """ + np_sig = _np_signature(np_fun) + + def decorator(f): + """The decorator.""" + unsupported_params = [] + if np_sig is not None: + sig = funcsigs.signature(f) + for name in np_sig.parameters: + if name not in sig.parameters: + unsupported_params.append(name) + f.__doc__ = _np_doc_helper(f, np_fun, unsupported_params) + return f + + return decorator def _np_doc_helper(f, np_f, unsupported_params=None): - """Helper to get docs.""" - if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f): - return np_f.__doc__ - doc = 'TensorFlow variant of `numpy.%s`.\n\n' % np_f.__name__ - if unsupported_params: - doc += 'Unsupported arguments: ' + ', '.join( - '`' + name + '`' for name in unsupported_params) + '.\n\n' - if _has_docstring(f): - doc += f.__doc__ - doc = _add_blank_line(doc) - if _has_docstring(np_f): - doc += 'Documentation for `numpy.%s`:\n\n' % np_f.__name__ - doc += np_f.__doc__ - return doc + """Helper to get docs.""" + if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f): + return np_f.__doc__ + doc = "TensorFlow variant of `numpy.%s`.\n\n" % np_f.__name__ + if unsupported_params: + doc += ( + "Unsupported arguments: " + + ", ".join("`" + name + "`" for name in unsupported_params) + + ".\n\n" + ) + if _has_docstring(f): + doc += f.__doc__ + doc = _add_blank_line(doc) + if _has_docstring(np_f): + doc += "Documentation for `numpy.%s`:\n\n" % 
np_f.__name__ + doc += np_f.__doc__ + return doc def np_doc_only(np_f): - """Attachs numpy docstring to a function. + """Attachs numpy docstring to a function. - This differs from np_doc in that it doesn't check for a match in signature. + This differs from np_doc in that it doesn't check for a match in signature. - Args: - np_f: the numpy function whose docstring will be used. + Args: + np_f: the numpy function whose docstring will be used. - Returns: - A function decorator that attaches the docstring from `np_f` to the - decorated function. - """ + Returns: + A function decorator that attaches the docstring from `np_f` to the + decorated function. + """ - def decorator(f): - f.__doc__ = _np_doc_helper(f, np_f) - return f + def decorator(f): + f.__doc__ = _np_doc_helper(f, np_f) + return f - return decorator + return decorator def tf_broadcast(*args): - """Broadcast tensors. + """Broadcast tensors. - Args: - *args: a list of tensors whose shapes are broadcastable against each other. + Args: + *args: a list of tensors whose shapes are broadcastable against each other. - Returns: - Tensors broadcasted to the common shape. - """ - if len(args) <= 1: - return args - sh = tf.shape(args[0]) - for arg in args[1:]: - sh = tf.broadcast_dynamic_shape(sh, tf.shape(arg)) - return [tf.broadcast_to(arg, sh) for arg in args] + Returns: + Tensors broadcasted to the common shape. + """ + if len(args) <= 1: + return args + sh = tf.shape(args[0]) + for arg in args[1:]: + sh = tf.broadcast_dynamic_shape(sh, tf.shape(arg)) + return [tf.broadcast_to(arg, sh) for arg in args] # TODO(wangpeng): Move the following functions to a separate file and check for @@ -280,28 +297,28 @@ def tf_broadcast(*args): def get_static_value(x): - """A version of tf.get_static_value that returns None on float dtypes. + """A version of tf.get_static_value that returns None on float dtypes. - It returns None on float dtypes in order to avoid breaking gradients. 
+ It returns None on float dtypes in order to avoid breaking gradients. - Args: - x: a tensor. + Args: + x: a tensor. - Returns: - Same as `tf.get_static_value`, except that it returns None when `x` has a - float dtype. - """ - if isinstance(x, tf.Tensor) and (x.dtype.is_floating or x.dtype.is_complex): - return None - return tf.get_static_value(x) + Returns: + Same as `tf.get_static_value`, except that it returns None when `x` has a + float dtype. + """ + if isinstance(x, tf.Tensor) and (x.dtype.is_floating or x.dtype.is_complex): + return None + return tf.get_static_value(x) def _maybe_static(x): - value = get_static_value(x) - if value is None: - return x - else: - return value + value = get_static_value(x) + if value is None: + return x + else: + return value # All the following functions exist because get_static_value can't handle @@ -309,89 +326,89 @@ def _maybe_static(x): def cond(pred, true_fn, false_fn): - """A version of tf.cond that tries to evaluate the condition.""" - v = get_static_value(pred) - if v is None: - return tf.cond(pred, true_fn, false_fn) - if v: - return true_fn() - else: - return false_fn() + """A version of tf.cond that tries to evaluate the condition.""" + v = get_static_value(pred) + if v is None: + return tf.cond(pred, true_fn, false_fn) + if v: + return true_fn() + else: + return false_fn() def add(a, b): - """A version of tf.add that eagerly evaluates if possible.""" - return _maybe_static(a) + _maybe_static(b) + """A version of tf.add that eagerly evaluates if possible.""" + return _maybe_static(a) + _maybe_static(b) def subtract(a, b): - """A version of tf.subtract that eagerly evaluates if possible.""" - return _maybe_static(a) - _maybe_static(b) + """A version of tf.subtract that eagerly evaluates if possible.""" + return _maybe_static(a) - _maybe_static(b) def greater(a, b): - """A version of tf.greater that eagerly evaluates if possible.""" - return _maybe_static(a) > _maybe_static(b) + """A version of tf.greater that eagerly 
evaluates if possible.""" + return _maybe_static(a) > _maybe_static(b) def greater_equal(a, b): - """A version of tf.greater_equal that eagerly evaluates if possible.""" - return _maybe_static(a) >= _maybe_static(b) + """A version of tf.greater_equal that eagerly evaluates if possible.""" + return _maybe_static(a) >= _maybe_static(b) def less_equal(a, b): - """A version of tf.less_equal that eagerly evaluates if possible.""" - return _maybe_static(a) <= _maybe_static(b) + """A version of tf.less_equal that eagerly evaluates if possible.""" + return _maybe_static(a) <= _maybe_static(b) def logical_and(a, b): - """A version of tf.logical_and that eagerly evaluates if possible.""" - a_value = get_static_value(a) - if a_value is not None: - if np.isscalar(a_value): - if a_value: - return _maybe_static(b) - else: - return a_value + """A version of tf.logical_and that eagerly evaluates if possible.""" + a_value = get_static_value(a) + if a_value is not None: + if np.isscalar(a_value): + if a_value: + return _maybe_static(b) + else: + return a_value + else: + return a_value & _maybe_static(b) else: - return a_value & _maybe_static(b) - else: - return a & _maybe_static(b) + return a & _maybe_static(b) def logical_or(a, b): - """A version of tf.logical_or that eagerly evaluates if possible.""" - a_value = get_static_value(a) - if a_value is not None: - if np.isscalar(a_value): - if a_value: - return a_value - else: - return _maybe_static(b) + """A version of tf.logical_or that eagerly evaluates if possible.""" + a_value = get_static_value(a) + if a_value is not None: + if np.isscalar(a_value): + if a_value: + return a_value + else: + return _maybe_static(b) + else: + return a_value | _maybe_static(b) else: - return a_value | _maybe_static(b) - else: - return a | _maybe_static(b) + return a | _maybe_static(b) def getitem(a, slice_spec): - """A version of __getitem__ that eagerly evaluates if possible.""" - return _maybe_static(a)[slice_spec] + """A version of __getitem__ 
that eagerly evaluates if possible.""" + return _maybe_static(a)[slice_spec] def reduce_all(input_tensor, axis=None, keepdims=False): - """A version of tf.reduce_all that eagerly evaluates if possible.""" - v = get_static_value(input_tensor) - if v is None: - return tf.reduce_all(input_tensor, axis=axis, keepdims=keepdims) - else: - return v.all(axis=axis, keepdims=keepdims) + """A version of tf.reduce_all that eagerly evaluates if possible.""" + v = get_static_value(input_tensor) + if v is None: + return tf.reduce_all(input_tensor, axis=axis, keepdims=keepdims) + else: + return v.all(axis=axis, keepdims=keepdims) def reduce_any(input_tensor, axis=None, keepdims=False): - """A version of tf.reduce_any that eagerly evaluates if possible.""" - v = get_static_value(input_tensor) - if v is None: - return tf.reduce_any(input_tensor, axis=axis, keepdims=keepdims) - else: - return v.any(axis=axis, keepdims=keepdims) + """A version of tf.reduce_any that eagerly evaluates if possible.""" + v = get_static_value(input_tensor) + if v is None: + return tf.reduce_any(input_tensor, axis=axis, keepdims=keepdims) + else: + return v.any(axis=axis, keepdims=keepdims) diff --git a/trax/trainer.py b/trax/trainer.py index add33cb57..31e9682a7 100644 --- a/trax/trainer.py +++ b/trax/trainer.py @@ -19,16 +19,15 @@ import functools import os +import gin +import jax +import tensorflow.compat.v2 as tf from absl import app from absl import flags from absl import logging - -import gin -import jax from jax.lib import xla_extension as xc -import tensorflow.compat.v2 as tf + from trax import fastmath -from trax import trainer_flags # pylint: disable=unused-import from trax.supervised import trainer_lib from trax.tf_numpy import numpy as tf_np @@ -38,160 +37,165 @@ # TODO(afrozm): Share between trainer.py and rl_trainer.py def _tf_setup_from_flags(): - """Processes TensorFlow-relevant flags.""" - if FLAGS.enable_eager_execution: - tf.compat.v1.enable_eager_execution() - if FLAGS.tf_xla: - 
tf.config.optimizer.set_jit(True) - fastmath.tf.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile) - tf.config.optimizer.set_experimental_options({ - 'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host, - 'layout_optimizer': FLAGS.tf_opt_layout, - }) - tf_np.set_allow_float64(FLAGS.tf_allow_float64) + """Processes TensorFlow-relevant flags.""" + if FLAGS.enable_eager_execution: + tf.compat.v1.enable_eager_execution() + if FLAGS.tf_xla: + tf.config.optimizer.set_jit(True) + fastmath.tf.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile) + tf.config.optimizer.set_experimental_options( + { + "pin_to_host_optimization": FLAGS.tf_opt_pin_to_host, + "layout_optimizer": FLAGS.tf_opt_layout, + } + ) + tf_np.set_allow_float64(FLAGS.tf_allow_float64) # TODO(afrozm): Share between trainer.py and rl_trainer.py def _gin_parse_configs(): - """Initializes gin-controlled bindings.""" - # Imports for configurables - # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable - from trax import models as _trax_models - from trax import optimizers as _trax_opt - # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable - - configs = FLAGS.config if FLAGS.config is not None else [] - # Override with --dataset and --model - if FLAGS.dataset: - configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset) - if FLAGS.data_dir: - configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir) - if FLAGS.model: - configs.append('train.model=@trax.models.%s' % FLAGS.model) - gin.parse_config_files_and_bindings(FLAGS.config_file, configs) + """Initializes gin-controlled bindings.""" + # Imports for configurables + # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable + + # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable + + configs = FLAGS.config if FLAGS.config is not None else [] + # Override with --dataset and 
--model + if FLAGS.dataset: + configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset) + if FLAGS.data_dir: + configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir) + if FLAGS.model: + configs.append("train.model=@trax.models.%s" % FLAGS.model) + gin.parse_config_files_and_bindings(FLAGS.config_file, configs) def _output_dir_or_default(): - """Returns a path to the output directory.""" - if FLAGS.output_dir: - output_dir = FLAGS.output_dir - trainer_lib.log('Using --output_dir {}'.format(output_dir)) - return os.path.expanduser(output_dir) - - # Else, generate a default output dir (under the user's home directory). - try: - dataset_name = gin.query_parameter('data_streams.dataset_name') - except ValueError: - dataset_name = 'random' - output_name = '{model_name}_{dataset_name}_{timestamp}'.format( - model_name=gin.query_parameter('train.model').configurable.name, - dataset_name=dataset_name, - timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'), - ) - output_dir = os.path.join('~', 'trax', output_name) - output_dir = os.path.expanduser(output_dir) - print() - trainer_lib.log('No --output_dir specified') - trainer_lib.log('Using default output_dir: {}'.format(output_dir)) - return output_dir + """Returns a path to the output directory.""" + if FLAGS.output_dir: + output_dir = FLAGS.output_dir + trainer_lib.log("Using --output_dir {}".format(output_dir)) + return os.path.expanduser(output_dir) + + # Else, generate a default output dir (under the user's home directory). 
+ try: + dataset_name = gin.query_parameter("data_streams.dataset_name") + except ValueError: + dataset_name = "random" + output_name = "{model_name}_{dataset_name}_{timestamp}".format( + model_name=gin.query_parameter("train.model").configurable.name, + dataset_name=dataset_name, + timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"), + ) + output_dir = os.path.join("~", "trax", output_name) + output_dir = os.path.expanduser(output_dir) + print() + trainer_lib.log("No --output_dir specified") + trainer_lib.log("Using default output_dir: {}".format(output_dir)) + return output_dir # TODO(afrozm): Share between trainer.py and rl_trainer.py def _jax_and_tf_configure_for_devices(): # pylint: disable=missing-function-docstring - if FLAGS.use_tpu: - jax.config.update('jax_platform_name', 'tpu') - jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend) - jax.config.update('jax_backend_target', FLAGS.jax_backend_target) - if (FLAGS.enable_eager_execution and (fastmath.is_backend(Backend.NUMPY) or - fastmath.is_backend(Backend.JAX))): - # Numpy backend doesn't benefit from having the input pipeline run on GPU, - # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend. Gin must be + # set up first before determining the backend. + tf.config.experimental.set_visible_devices([], "GPU") def _train_using_tf(output_dir): - worker_cpu = tf_init_tpu() - with tf.device(worker_cpu): - if trainer_lib.num_devices() == 1: - # TF's device priority is GPU > CPU > TPU, so we need to explicitly make - # the TPU core the default device here. - with tf.device('/device:TPU:0'): - trainer_lib.train(output_dir=output_dir) - else: - trainer_lib.train(output_dir=output_dir) + worker_cpu = tf_init_tpu() + with tf.device(worker_cpu): + if trainer_lib.num_devices() == 1: + # TF's device priority is GPU > CPU > TPU, so we need to explicitly make + # the TPU core the default device here. + with tf.device("/device:TPU:0"): + trainer_lib.train(output_dir=output_dir) + else: + trainer_lib.train(output_dir=output_dir) @gin.configurable -def tf_init_tpu(worker='', protocol=None): - """Initializes TPU for TensorFlow. - - Args: - worker: The BNS address of the remote TPU worker. If it's empty (the default
    value), TF will assume the TPU devices are connected to the local host.
  protocol: The network protocol used to connect to the TPU worker.
  Returns:
    The device name of the TPU worker's CPU. If it's empty (the default + value), TF will assume the TPU devices are connected to the local host. + protocol: The network protocol used to connect to the TPU worker. + Returns: + The device name of the TPU worker's CPU. + """ + protocol = protocol or "grpc" + is_local = worker in ("", "local") + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=worker) + if not is_local: + tf.config.experimental_connect_to_cluster(resolver, protocol=protocol) + tf.tpu.experimental.initialize_tpu_system(resolver) + if is_local: + return "" + else: + return "/job:worker" def _make_jax_gpu_cluster(host_id, server_ip, n_hosts, server_port=5005): - """Make JAX GPU Cluster.""" + """Make JAX GPU Cluster.""" - addr = f'{server_ip}:{server_port}' - if host_id == 0: - logging.info('starting service on %s', addr) - service = xc.get_distributed_runtime_service(addr, n_hosts) - # We add an explicit call to shutdown the service via atexit as Python - # interpreter may not call the service destructor on process termination. - atexit.register(service.shutdown)

  logging.info("connecting to service on %s", addr)
  dist_client = xc.get_distributed_runtime_client(addr, host_id)
  dist_client.connect()
  atexit.register(dist_client.shutdown)

  # register dist gpu backend
  factory = functools.partial(
      jax.lib.xla_client.make_gpu_client, dist_client, host_id
  )
  jax.lib.xla_bridge.register_backend_factory("gpu", factory, priority=300) logging.set_verbosity(FLAGS.log_level) - - _tf_setup_from_flags() - _gin_parse_configs() - _jax_and_tf_configure_for_devices() - - # Create a JAX GPU cluster if using JAX and given a chief IP. - if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip: - _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id, - FLAGS.gpu_cluster_chief_ip, - FLAGS.gpu_cluster_n_hosts, - FLAGS.gpu_cluster_port) - - if FLAGS.disable_jit: - fastmath.disable_jit() - - output_dir = _output_dir_or_default() - if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP): - _train_using_tf(output_dir) - else: - trainer_lib.train(output_dir=output_dir) + logging.set_verbosity(FLAGS.log_level) + + _tf_setup_from_flags() + _gin_parse_configs() + _jax_and_tf_configure_for_devices() + + # Create a JAX GPU cluster if using JAX and given a chief IP. + if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip: + _make_jax_gpu_cluster( + FLAGS.gpu_cluster_host_id, + FLAGS.gpu_cluster_chief_ip, + FLAGS.gpu_cluster_n_hosts, + FLAGS.gpu_cluster_port, + ) + + if FLAGS.disable_jit: + fastmath.disable_jit() + + output_dir = _output_dir_or_default() + if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP): + _train_using_tf(output_dir) + else: + trainer_lib.train(output_dir=output_dir) - trainer_lib.log('Finished training.') + trainer_lib.log("Finished training.") -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/trax/trainer_flags.py b/trax/trainer_flags.py index 097a8ddac..2866e1454 100644 --- a/trax/trainer_flags.py +++ b/trax/trainer_flags.py @@ -21,73 +21,72 @@ from absl import flags from absl import logging + # Common flags. 
-flags.DEFINE_string('output_dir', - None, - 'Path to the directory to save logs and checkpoints.') -flags.DEFINE_multi_string('config_file', - None, - 'Configuration file with parameters (.gin).') -flags.DEFINE_multi_string('config', - None, - 'Configuration parameters (gin string).') +flags.DEFINE_string( + "output_dir", None, "Path to the directory to save logs and checkpoints." +) +flags.DEFINE_multi_string( + "config_file", None, "Configuration file with parameters (.gin)." +) +flags.DEFINE_multi_string("config", None, "Configuration parameters (gin string).") # TPU Flags -flags.DEFINE_bool('use_tpu', False, "Whether we're running on TPU.") -flags.DEFINE_string('jax_xla_backend', - '', - 'Either "xla" for the XLA service directly, or "tpu_driver"' - 'for a TPU Driver backend.') -flags.DEFINE_string('jax_backend_target', - 'local', - 'Either "local" or "rpc:address" to connect to a ' - 'remote service target.') +flags.DEFINE_bool("use_tpu", False, "Whether we're running on TPU.") +flags.DEFINE_string( + "jax_xla_backend", + "", + 'Either "xla" for the XLA service directly, or "tpu_driver"' + "for a TPU Driver backend.", +) +flags.DEFINE_string( + "jax_backend_target", + "local", + 'Either "local" or "rpc:address" to connect to a ' "remote service target.", +) # trainer.py flags. -flags.DEFINE_string('dataset', None, 'Which dataset to use.') -flags.DEFINE_string('model', None, 'Which model to train.') -flags.DEFINE_string('data_dir', None, 'Path to the directory with data.') -flags.DEFINE_integer('log_level', logging.INFO, 'Log level.') +flags.DEFINE_string("dataset", None, "Which dataset to use.") +flags.DEFINE_string("model", None, "Which model to train.") +flags.DEFINE_string("data_dir", None, "Path to the directory with data.") +flags.DEFINE_integer("log_level", logging.INFO, "Log level.") # JAX/XLA GPU cluster flags. 
-flags.DEFINE_string('gpu_cluster_chief_ip', '', 'IP of GPU cluster chief.') -flags.DEFINE_integer('gpu_cluster_n_hosts', 1, - 'Number of hosts in GPU cluster.') -flags.DEFINE_integer('gpu_cluster_host_id', 0, 'Host id inside GPU cluster.') -flags.DEFINE_integer('gpu_cluster_port', 5005, 'Port to use in GPU cluster.') +flags.DEFINE_string("gpu_cluster_chief_ip", "", "IP of GPU cluster chief.") +flags.DEFINE_integer("gpu_cluster_n_hosts", 1, "Number of hosts in GPU cluster.") +flags.DEFINE_integer("gpu_cluster_host_id", 0, "Host id inside GPU cluster.") +flags.DEFINE_integer("gpu_cluster_port", 5005, "Port to use in GPU cluster.") # TensorFlow Flags -flags.DEFINE_bool('enable_eager_execution', - True, - "Whether we're running TF in eager mode.") -flags.DEFINE_bool('tf_xla', True, 'Whether to turn on XLA for TF.') -flags.DEFINE_bool('tf_opt_pin_to_host', - False, - 'Whether to turn on TF pin-to-host optimization.') -flags.DEFINE_bool('tf_opt_layout', - False, - 'Whether to turn on TF layout optimization.') -flags.DEFINE_bool('tf_xla_forced_compile', - False, - 'Use forced-compilation instead of auto-clustering for XLA.' - 'This flag only has effects when --tf_xla is on.') -flags.DEFINE_bool('tf_allow_float64', False, 'Whether to allow float64 for TF.') +flags.DEFINE_bool( + "enable_eager_execution", True, "Whether we're running TF in eager mode." +) +flags.DEFINE_bool("tf_xla", True, "Whether to turn on XLA for TF.") +flags.DEFINE_bool( + "tf_opt_pin_to_host", False, "Whether to turn on TF pin-to-host optimization." +) +flags.DEFINE_bool("tf_opt_layout", False, "Whether to turn on TF layout optimization.") +flags.DEFINE_bool( + "tf_xla_forced_compile", + False, + "Use forced-compilation instead of auto-clustering for XLA." + "This flag only has effects when --tf_xla is on.", +) +flags.DEFINE_bool("tf_allow_float64", False, "Whether to allow float64 for TF.") # rl_trainer.py flags. 
-flags.DEFINE_boolean('jax_debug_nans', - False, - 'Setting to true will help to debug nans and disable jit.') -flags.DEFINE_boolean('disable_jit', False, 'Setting to true will disable jit.') -flags.DEFINE_string('envs_output_dir', '', 'Output dir for the envs.') -flags.DEFINE_bool('xm', False, 'Copy atari roms?') -flags.DEFINE_integer('train_batch_size', - 32, - 'Number of parallel environments during training.') -flags.DEFINE_integer('eval_batch_size', 4, 'Batch size for evaluation.') -flags.DEFINE_boolean('parallelize_envs', - False, - 'If true, sets parallelism to number of cpu cores.') -flags.DEFINE_string('trajectory_dump_dir', - '', - 'Directory to dump trajectories to.') -flags.DEFINE_bool('async_mode', False, 'Async mode.') +flags.DEFINE_boolean( + "jax_debug_nans", False, "Setting to true will help to debug nans and disable jit." +) +flags.DEFINE_boolean("disable_jit", False, "Setting to true will disable jit.") +flags.DEFINE_string("envs_output_dir", "", "Output dir for the envs.") +flags.DEFINE_bool("xm", False, "Copy atari roms?") +flags.DEFINE_integer( + "train_batch_size", 32, "Number of parallel environments during training." +) +flags.DEFINE_integer("eval_batch_size", 4, "Batch size for evaluation.") +flags.DEFINE_boolean( + "parallelize_envs", False, "If true, sets parallelism to number of cpu cores." 
+) +flags.DEFINE_string("trajectory_dump_dir", "", "Directory to dump trajectories to.") +flags.DEFINE_bool("async_mode", False, "Async mode.") diff --git a/trax/trax2keras.py b/trax/trax2keras.py index df84ef6f1..c92c83a17 100644 --- a/trax/trax2keras.py +++ b/trax/trax2keras.py @@ -17,7 +17,7 @@ import functools -import tensorflow.compat.v2 as tf +import tensorflow.compat.v2 as tf # type: ignore from trax import fastmath as math_lib from trax import shapes as shapes_lib @@ -26,164 +26,178 @@ def _replace_none_batch(x, batch_size=None): - if batch_size is None: + if batch_size is None: + return x + if isinstance(x, tf.Tensor) and x.shape[0] is None: + x.set_shape([batch_size] + x.shape[1:]) + return x + elif isinstance(x, tf.TensorShape) and x[0] is None: + return [batch_size] + x[1:] return x - if isinstance(x, tf.Tensor) and x.shape[0] is None: - x.set_shape([batch_size] + x.shape[1:]) - return x - elif isinstance(x, tf.TensorShape) and x[0] is None: - return [batch_size] + x[1:] - return x def tensor_shapes_to_shape_dtypes(shapes, dtype): - return math_lib.nested_map( - lambda s: shapes_lib.ShapeDtype(s.as_list(), dtype), shapes) + return math_lib.nested_map( + lambda s: shapes_lib.ShapeDtype(s.as_list(), dtype), shapes + ) def read_values(variables): - return math_lib.nested_map(lambda v: v.read_value(), variables) + return math_lib.nested_map(lambda v: v.read_value(), variables) def to_tensors(args): - return math_lib.nested_map(tf.convert_to_tensor, args) + return math_lib.nested_map(tf.convert_to_tensor, args) def to_arrays(args): - return math_lib.nested_map(jnp.asarray, args) + return math_lib.nested_map(jnp.asarray, args) class AsKeras(tf.keras.layers.Layer): - """A Keras layer built from a Trax layer. - - This subclass of `tf.keras.layers.Layer` takes in a Trax layer as a - constructor argument and wraps it to be a Keras layer. It uses
  `tf.Variable` to store weights and state (initialized according to the Trax
  layer), and uses the Trax layer's forward function as its forward function.

  Consider this code snippet::

    keras_layer = AsKeras(trax_layer, initializer_rng=initializer_rng,
                          rng=rng, rng_updater=rng_updater)
    keras_layer.build(...) # optional
    outputs = keras_layer(inputs)

  (Note that in Keras calling `Layer.build` is optional. If omitted, it will be - called automatically by `Layer.__call__`.) - - If `trax_layer` already has weights at `build` time, the snippet is roughly - equivalent to:: - - weights = trax_layer.weights - state = trax_layer.state - keras_layer = tf.keras.layers.Layer() - keras_layer._weights = tf.Variable(weights) - keras_layer._state = tf.Variable(state) - keras_layer._rng = tf.Variable(rng) - outputs, new_state = trax_layer(inputs, keras_layer._weights, - keras_layer._state, keras_layer._rng) - keras_layer._state.assign(new_state) - keras_layer._rng.assign(rng_updater(rng)) - - If `trax_layer` doesn't have weights at `build` time, the snippet is roughly - equivalent to:: - - weights, state = trax_layer.init(..., rng=initializer_rng) - keras_layer = ... - ... - - `AsKeras` uses `tf.Variable` to store weights, not shared with the - original Trax layer (which uses tensors to store weights), so using - `AsKeras` may double the memory footprint. This problem can be solved
  by making sure that the Trax layer's weights/state are cleared whenever
  `tf.Variable.assign` (and `tf.Variable.assign_add` etc.) is called, because
  `tf.Variable` is copy-on-write by default.

  Mutations in those `tf.Variable`s won't affect the Trax layer's weights, but
  `AsKeras`'s forward function calls the Trax layer's forward function,
  which caches the weights in the Trax layer object, so a forward pass may
  change the weights cached in the original Trax layer.

  Note that this class is not thread-safe. If the same `AsKeras` object - is used in multiple threads, the `tf.Variable` updates may happen in a - non-deterministic order. - """ - - def __init__(self, trax_layer, batch_size=None, initializer_rng=None, - rng=None, rng_updater=None, dtype=None): - """Creates a Keras layer wrapping around a Trax layer. - - Args: - trax_layer: an object of class `trax.layers.Layer`, the trax layer to - wrap. - batch_size: (optional) an integer, the batch size that this Keras layer - will be used on. Keras sometimes needs to generate a TF graph for a - layer (e.g. for acceleration or checkpointing). The inputs used to trace - the graph will have `None` as the length of their batch dimensions, so - as to generate a graph that can handle any batch size. Some Trax layers - can't handle tensors whose shapes contain `None`. If `batch_size` is set - to an integer, the graph will be traced with `batch_size` as the batch - size instead of `None`. Note that in this case the graph (and the Keras - layer) can only be used on a specific batch size. If you want to use a - different batch size, you need to create another `AsKeras` object - with a different `batch_size`. - initializer_rng: (optional) an RNG key used to create the weights and - state if `trax_layer` doesn't have them. If `None`, - `trax.fastmath.random.get_prng(0)` will be used. - rng: (optional) an RNG key for the forward function (aka the "forward - key"). If `None`, `trax.fastmath.random.get_prng(0)` will be used. - rng_updater: (optional) a function of type rng_key -> rng_key, used to - update the forward key after each forward pass. If `None`, the function - `lambda x: trax.fastmath.random.split(x, 1)[0]` will be used, which - advances the RNG key. - dtype: (optional) the dtype of the inputs. See the `dtype` argument of - `tf.keras.layers.Layer.__init__` for details. + """A Keras layer built from a Trax layer. 
+ + This subclass of `tf.keras.layers.Layer` takes in a Trax layer as a + constructor argument and wraps it to be a Keras layer. It uses + `tf.Variable` to store weights and state (initialized according to the Trax + layer), and uses the Trax layer's forward function as its forward function. + + Consider this code snippet:: + + keras_layer = AsKeras(trax_layer, initializer_rng=initializer_rng, + rng=rng, rng_updater=rng_updater) + keras_layer.build(...) # optional + outputs = keras_layer(inputs) + + (Note that in Keras calling `Layer.build` is optional. If omitted, it will be + called automatically by `Layer.__call__`.) + + If `trax_layer` already has weights at `build` time, the snippet is roughly + equivalent to:: + + weights = trax_layer.weights + state = trax_layer.state + keras_layer = tf.keras.layers.Layer() + keras_layer._weights = tf.Variable(weights) + keras_layer._state = tf.Variable(state) + keras_layer._rng = tf.Variable(rng) + outputs, new_state = trax_layer(inputs, keras_layer._weights, + keras_layer._state, keras_layer._rng) + keras_layer._state.assign(new_state) + keras_layer._rng.assign(rng_updater(rng)) + + If `trax_layer` doesn't have weights at `build` time, the snippet is roughly + equivalent to:: + + weights, state = trax_layer.init(..., rng=initializer_rng) + keras_layer = ... + ... + + `AsKeras` uses `tf.Variable` to store weights, not shared with the + original Trax layer (which uses tensors to store weights), so using + `AsKeras` may double the memory footprint. This problem can be solved + by making sure that the Trax layer's weights/state are cleared whenever + `tf.Variable.assign` (and `tf.Variable.assign_add` etc.) is called, because + `tf.Variable` is copy-on-write by default. 
+ + Mutations in those `tf.Variable`s won't affect the Trax layer's weights, but + `AsKeras`'s forward function calls the Trax layer's forward function, + which caches the weights in the Trax layer object, so a forward pass may + change the weights cached in the original Trax layer. + + Note that this class is not thread-safe. If the same `AsKeras` object + is used in multiple threads, the `tf.Variable` updates may happen in a + non-deterministic order. """ - super().__init__(dtype=dtype) - with math_lib.use_backend(math_lib.Backend.TFNP): - if initializer_rng is None: - initializer_rng = math_lib.random.get_prng(0) - if rng is None: - rng = math_lib.random.get_prng(0) - if rng_updater is None: - rng_updater = lambda x: math_lib.random.split(x, 1)[0] - self._trax_layer = trax_layer - self._batch_size = batch_size - self._initializer_rng = initializer_rng - self._forward_rng_init = rng - self._rng_updater = rng_updater - - def build(self, input_shape): - with math_lib.use_backend(math_lib.Backend.TFNP): - # Using `is` instead of `==` following Trax's practice - if self._trax_layer.weights is base.EMPTY_WEIGHTS: - sanitized_input_shape = math_lib.nested_map( - functools.partial(_replace_none_batch, batch_size=self._batch_size), - input_shape) - weights, state = self._trax_layer.init( - tensor_shapes_to_shape_dtypes(sanitized_input_shape, self.dtype), - rng=self._initializer_rng) - else: - weights = self._trax_layer.weights - state = self._trax_layer.state - # Note: `weights` may contain `EMPTY_WEIGHTS` - self._weights = math_lib.nested_map( - functools.partial(tf.Variable, trainable=True), weights) - self._state = math_lib.nested_map( - functools.partial(tf.Variable, trainable=False), state) - self._rng = tf.Variable(self._forward_rng_init, trainable=False) - super().build(input_shape) - - def call(self, inputs): - with math_lib.use_backend(math_lib.Backend.TFNP): - inputs = math_lib.nested_map( - functools.partial(_replace_none_batch, batch_size=self._batch_size), - 
inputs) - weights, state, rng = read_values([self._weights, self._state, self._rng]) - inputs, weights, state, rng = to_arrays([inputs, weights, state, rng]) - outputs, new_state = self._trax_layer.pure_fn(inputs, weights=weights, - state=state, rng=rng) - tf.nest.map_structure(lambda v, t: v.assign(t), self._state, new_state) - self._rng.assign(self._rng_updater(rng)) - outputs = to_tensors(outputs) - return outputs + + def __init__( + self, + trax_layer, + batch_size=None, + initializer_rng=None, + rng=None, + rng_updater=None, + dtype=None, + ): + """Creates a Keras layer wrapping around a Trax layer. + + Args: + trax_layer: an object of class `trax.layers.Layer`, the trax layer to + wrap. + batch_size: (optional) an integer, the batch size that this Keras layer + will be used on. Keras sometimes needs to generate a TF graph for a
      layer (e.g. for acceleration or checkpointing). The inputs used to trace
      the graph will have `None` as the length of their batch dimensions, so
      as to generate a graph that can handle any batch size. Some Trax layers
      can't handle tensors whose shapes contain `None`. If `batch_size` is set
      to an integer, the graph will be traced with `batch_size` as the batch
      size instead of `None`. Note that in this case the graph (and the Keras
      layer) can only be used on a specific batch size. If you want to use a
      different batch size, you need to create another `AsKeras` object
      with a different `batch_size`.
    initializer_rng: (optional) an RNG key used to create the weights and
      state if `trax_layer` doesn't have them. If `None`,
      `trax.fastmath.random.get_prng(0)` will be used.
    rng: (optional) an RNG key for the forward function (aka the "forward
      key"). If `None`, `trax.fastmath.random.get_prng(0)` will be used. + rng_updater: (optional) a function of type rng_key -> rng_key, used to + update the forward key after each forward pass. If `None`, the function + `lambda x: trax.fastmath.random.split(x, 1)[0]` will be used, which + advances the RNG key. + dtype: (optional) the dtype of the inputs. See the `dtype` argument of + `tf.keras.layers.Layer.__init__` for details. + """ + super().__init__(dtype=dtype) + with math_lib.use_backend(math_lib.Backend.TFNP): + if initializer_rng is None: + initializer_rng = math_lib.random.get_prng(0) + if rng is None: + rng = math_lib.random.get_prng(0) + if rng_updater is None: + rng_updater = lambda x: math_lib.random.split(x, 1)[0] + self._trax_layer = trax_layer + self._batch_size = batch_size + self._initializer_rng = initializer_rng + self._forward_rng_init = rng + self._rng_updater = rng_updater + + def build(self, input_shape): + with math_lib.use_backend(math_lib.Backend.TFNP): + # Using `is` instead of `==` following Trax's practice + if self._trax_layer.weights is base.EMPTY_WEIGHTS: + sanitized_input_shape = math_lib.nested_map( + functools.partial(_replace_none_batch, batch_size=self._batch_size), + input_shape, + ) + weights, state = self._trax_layer.init( + tensor_shapes_to_shape_dtypes(sanitized_input_shape, self.dtype), + rng=self._initializer_rng, + ) + else: + weights = self._trax_layer.weights + state = self._trax_layer.state + # Note: `weights` may contain `EMPTY_WEIGHTS` + self._weights = math_lib.nested_map( + functools.partial(tf.Variable, trainable=True), weights + ) + self._state = math_lib.nested_map( + functools.partial(tf.Variable, trainable=False), state + ) + self._rng = tf.Variable(self._forward_rng_init, trainable=False) + super().build(input_shape) + + def call(self, inputs): + with math_lib.use_backend(math_lib.Backend.TFNP): + inputs = math_lib.nested_map( + functools.partial(_replace_none_batch, batch_size=self._batch_size), + 
inputs, + ) + weights, state, rng = read_values([self._weights, self._state, self._rng]) + inputs, weights, state, rng = to_arrays([inputs, weights, state, rng]) + outputs, new_state = self._trax_layer.pure_fn( + inputs, weights=weights, state=state, rng=rng + ) + tf.nest.map_structure(lambda v, t: v.assign(t), self._state, new_state) + self._rng.assign(self._rng_updater(rng)) + outputs = to_tensors(outputs) + return outputs diff --git a/trax/trax2keras_test.py b/trax/trax2keras_test.py deleted file mode 100644 index 9fbf86f52..000000000 --- a/trax/trax2keras_test.py +++ /dev/null @@ -1,192 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); If `True`, we will also test checkpointing and restoring - using the model. - """ - with trax.fastmath.use_backend("tensorflow-numpy"): - make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = ( - _LAYERS[layer_id]) - # We make a fresh trax layer for each test case, so that different test - # cases won't interfere with each other. 
- trax_layer = make_trax_layer() - if not allow_none_batch and batch_size is None: - self.skipTest("This Trax layer can't handle None batch size.") - rng_updater = _RNG_UPDATERS[rng_updater_id] - input_shapes = math_lib.nested_map( - lambda s: [batch_size] + s, input_shapes_no_batch) - input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype) - initializer_rng = math_lib.random.get_prng(765) - weights, state = trax_layer.init(input_sig, rng=initializer_rng) - generator = tf.random.Generator.from_seed(567) - def get_inputs(): - return dummy_inputs(generator, input_sig) - if trax_has_weights: - trax_layer(to_arrays(get_inputs()), weights=weights, state=state) - rng = math_lib.random.get_prng(1234) - keras_layer = trax2keras.AsKeras( - trax_layer, batch_size=batch_size, initializer_rng=initializer_rng, - rng=rng, rng_updater=rng_updater) - if explicit_build: - keras_layer.build(input_shapes) - if use_model: - x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype) - y = keras_layer(x) - keras_model = tf.keras.Model(inputs=x, outputs=y) - lr = 0.1 # learning rate - for _ in range(3): - inputs = get_inputs() - with tf.GradientTape() as trax_tape: - trax_tape.watch(tf.nest.flatten(weights)) - trax_outputs, state = trax_layer.pure_fn( - to_arrays(inputs), weights=weights, state=state, rng=rng) - trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights])) - # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor` - # before multiplication. 
- weights = tf.nest.map_structure( - lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype), - weights, trax_grads) - rng = rng_updater(rng) - with tf.GradientTape() as keras_tape: - if use_model: - keras_outputs = keras_model(inputs) - else: - keras_outputs = keras_layer(inputs) - if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1: - keras_outputs = keras_outputs[0] - self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5) - keras_grads = keras_tape.gradient(keras_outputs, - keras_layer.trainable_variables) - tf.nest.map_structure( - lambda v, g: v.assign_add( # pylint: disable=g-long-lambda - tf.cast(lr * tf.convert_to_tensor(g), v.dtype)), - keras_layer.trainable_variables, keras_grads) - self.assertAllClose( - to_tensors(weights), read_values(keras_layer._weights), - rtol=2e-6, atol=4.5e-4 if has_gpu() else 1e-6) - self.assertAllClose(to_tensors(state), read_values(keras_layer._state)) - self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng)) - if use_model: - fname = os.path.join(self.get_temp_dir(), "checkpoint") - keras_model.save(fname) - loaded_model = tf.keras.models.load_model(fname) - for _ in range(2): - inputs = get_inputs() - self.assertAllClose(keras_model(inputs), loaded_model(inputs)) - - -if __name__ == "__main__": - absltest.main()